├── .cargo └── config.toml ├── .gitignore ├── .gitmodules ├── Cargo.lock ├── Cargo.toml ├── LICENSE ├── README.md ├── bacon.toml ├── crates ├── a │ ├── Cargo.toml │ └── src │ │ ├── gpt.rs │ │ ├── lib.rs │ │ ├── main.rs │ │ ├── record.rs │ │ └── util.rs ├── anthropic │ ├── Cargo.toml │ ├── README.md │ └── src │ │ ├── client.rs │ │ ├── complete.rs │ │ ├── error.rs │ │ └── lib.rs ├── b │ ├── Cargo.toml │ ├── README.md │ └── src │ │ ├── anthropic.rs │ │ ├── chats.rs │ │ ├── commands.rs │ │ ├── completions.rs │ │ ├── edits.rs │ │ ├── lib.rs │ │ ├── main.rs │ │ ├── tokenizer.rs │ │ └── utils.rs ├── c │ ├── Cargo.toml │ ├── README.md │ └── src │ │ ├── commands.rs │ │ ├── commands │ │ ├── anthropic.rs │ │ ├── nlpcloud.rs │ │ ├── ollama.rs │ │ ├── openai.rs │ │ └── vertex.rs │ │ ├── lib.rs │ │ ├── main.rs │ │ ├── session.rs │ │ └── utils.rs ├── d │ ├── Cargo.toml │ ├── README.md │ └── src │ │ ├── commands.rs │ │ ├── commands │ │ ├── chat.rs │ │ ├── embeddings.rs │ │ ├── sessions.rs │ │ └── vector.rs │ │ ├── constants.rs │ │ ├── main.rs │ │ ├── models.rs │ │ ├── printer.rs │ │ ├── sessions.rs │ │ ├── shutdown.rs │ │ ├── similarity.rs │ │ └── vector.rs ├── e │ ├── Cargo.toml │ └── src │ │ ├── anthropic.rs │ │ ├── args.rs │ │ ├── config.rs │ │ ├── error.rs │ │ ├── google.rs │ │ ├── main.rs │ │ ├── mistral.rs │ │ ├── mistral_fim.rs │ │ ├── openai.rs │ │ ├── prelude.rs │ │ └── printer.rs ├── fs │ ├── Cargo.toml │ ├── README.md │ └── src │ │ └── lib.rs ├── openai │ ├── Cargo.toml │ ├── README.md │ └── src │ │ ├── chats.rs │ │ ├── client.rs │ │ ├── completions.rs │ │ ├── edits.rs │ │ ├── error.rs │ │ ├── lib.rs │ │ ├── models.rs │ │ └── utils.rs └── spinner │ ├── Cargo.toml │ ├── README.md │ └── src │ └── lib.rs ├── lib ├── README.md ├── es_stream │ ├── Cargo.toml │ ├── examples │ │ ├── stream_anthropic.rs │ │ ├── stream_copilot.rs │ │ ├── stream_google.rs │ │ ├── stream_mistral.rs │ │ ├── stream_mistral_fim.rs │ │ └── stream_openai.rs │ └── src │ │ ├── anthropic.rs │ │ ├── 
error.rs │ │ ├── google.rs │ │ ├── lib.rs │ │ ├── mistral.rs │ │ ├── mistral_fim.rs │ │ ├── openai.rs │ │ └── requests.rs └── gpt_tokenizer │ ├── Cargo.toml │ ├── README.md │ ├── examples │ ├── custom_tokenizer.rs │ ├── default_tokenizer.rs │ └── regex.rs │ └── src │ ├── encoder.json │ ├── lib.rs │ └── vocab.bpe └── xtask ├── Cargo.toml └── src ├── cli.rs ├── main.rs ├── scripts.rs └── utils.rs /.cargo/config.toml: -------------------------------------------------------------------------------- 1 | [alias] 2 | xtask = "run --package xtask --bin xtask --" 3 | 4 | [env] 5 | CARGO_WORKSPACE_DIR = { value = "", relative = true } 6 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Generated by Cargo 2 | # will have compiled files and executables 3 | /target/ 4 | 5 | # Remove Cargo.lock from gitignore if creating an executable, leave it for libraries 6 | # More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html 7 | Cargo.lock 8 | 9 | # These are backup files generated by rustfmt 10 | **/*.rs.bk 11 | 12 | lib/bat 13 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "lib/bat"] 2 | path = lib/bat 3 | url = git@github.com:sharkdp/bat.git 4 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [workspace] 2 | members = ["xtask/", "lib/*", "crates/*"] 3 | resolver = "2" 4 | 5 | [workspace.package] 6 | edition = "2021" 7 | license = "MIT" 8 | authors = ["Guzmán Monné"] 9 | 10 | [profile.dev] 11 | # Disabling debug info speeds up builds. 
12 | debug = 0 13 | 14 | [profile.release] 15 | incremental = true 16 | # Set this to 1 or 2 to get more useful backtraces in debugger. 17 | debug = 0 18 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Guzman Monne 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /bacon.toml: -------------------------------------------------------------------------------- 1 | # This is a configuration file for the bacon tool 2 | # 3 | # Bacon repository: https://github.com/Canop/bacon 4 | # Complete help on configuration: https://dystroy.org/bacon/config/ 5 | # You can also check bacon's own bacon.toml file 6 | # as an example: https://github.com/Canop/bacon/blob/main/bacon.toml 7 | 8 | default_job = "check" 9 | 10 | [jobs.check] 11 | command = ["cargo", "check", "--color", "always"] 12 | need_stdout = false 13 | 14 | [jobs.check-all] 15 | command = ["cargo", "check", "--all-targets", "--color", "always"] 16 | need_stdout = false 17 | 18 | [jobs.clippy] 19 | command = [ 20 | "cargo", "clippy", "--fix", 21 | "--all-targets", 22 | "--color", "always", 23 | "--", 24 | "-W", "clippy::pedantic", 25 | "-W", "clippy::nursery", 26 | "-W", "clippy::unwrap_used", 27 | "-W", "clippy::expect_used", 28 | ] 29 | need_stdout = false 30 | 31 | [jobs.test] 32 | command = [ 33 | "cargo", "test", "--color", "always", 34 | "--", "--color", "always", # see https://github.com/Canop/bacon/issues/124 35 | ] 36 | need_stdout = true 37 | 38 | [jobs.doc] 39 | command = ["cargo", "doc", "--color", "always", "--no-deps"] 40 | need_stdout = false 41 | 42 | # If the doc compiles, then it opens in your browser and bacon switches 43 | # to the previous job 44 | [jobs.doc-open] 45 | command = ["cargo", "doc", "--color", "always", "--no-deps", "--open"] 46 | need_stdout = false 47 | on_success = "back" # so that we don't open the browser at each change 48 | 49 | # You can run your application and have the result displayed in bacon, 50 | # *if* it makes sense for this crate. You can run an example the same 51 | # way. Don't forget the `--color always` part or the errors won't be 52 | # properly parsed. 
53 | [jobs.run] 54 | command = [ 55 | "cargo", "run", 56 | "--color", "always", 57 | # put launch parameters for your program behind a `--` separator 58 | ] 59 | need_stdout = true 60 | allow_warnings = true 61 | 62 | # You may define here keybindings that would be specific to 63 | # a project, for example a shortcut to launch a specific job. 64 | # Shortcuts to internal functions (scrolling, toggling, etc.) 65 | # should go in your personal global prefs.toml file instead. 66 | [keybindings] 67 | j = "scroll-lines(1)" 68 | k = "scroll-lines(-1)" 69 | ctrl-d = "scroll-pages(1)" 70 | ctrl-u = "scroll-pages(-1)" 71 | g = "scroll-to-top" 72 | shift-g = "scroll-to-bottom" 73 | -------------------------------------------------------------------------------- /crates/a/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "gpt" 3 | version = "0.0.0" 4 | homepage = "https://github.com/cloudbridgeuy/a" 5 | description = "A CLI tool to interact with ChatGPT as a Software Engineer" 6 | autobins = false 7 | 8 | authors.workspace = true 9 | edition.workspace = true 10 | license.workspace = true 11 | 12 | [[bin]] 13 | name = "a" 14 | path = "src/main.rs" 15 | 16 | [lib] 17 | name = "a" 18 | path = "src/lib.rs" 19 | 20 | [dependencies] 21 | gpt_tokenizer = { version = "0.1.0", path = "../../lib/gpt_tokenizer" } 22 | bat = "0.24" 23 | copypasta-ext = "0.4.4" 24 | reqwest = { version = "0.11.14", features = ["blocking"] } 25 | serde = { version = "1.0.152", features = ["derive"] } 26 | serde_json = "1.0.93" 27 | env_logger = "0.10.0" 28 | log = "0.4.17" 29 | chrono = "0.4.23" 30 | 31 | [features] 32 | clipboard = [] 33 | default = ["clipboard"] 34 | -------------------------------------------------------------------------------- /crates/a/src/lib.rs: -------------------------------------------------------------------------------- 1 | use log::error; 2 | use std::error::Error; 3 | use std::io; 4 | 5 | pub mod gpt; 6 | 
pub mod record; 7 | pub mod util; 8 | 9 | /// Max tokens that will be used for the prompt. Thise leaves 10 | /// 1096 tokens for ChatGPT response. 11 | const MAX_TOKENS: u32 = 3000; 12 | const LAST_REQUEST_FILE: &str = "last_request.json"; 13 | const CONFIG_DIRECTORY_PATH: &str = "/tmp/a"; 14 | 15 | /// Gathers all arguments provided to the binary. If no arguments are provided then stdin 16 | /// is used. The first argument will always be considered the programming language. 17 | /// 18 | /// # Errors 19 | /// 20 | /// This function will return an error if . 21 | pub fn gather_args(args: &mut Vec) -> Result<(String, String), Box> { 22 | let lang; 23 | let mut prompt = String::new(); 24 | 25 | if args.is_empty() { 26 | return Err(Box::new(io::Error::new( 27 | io::ErrorKind::InvalidInput, 28 | "No arguments provided", 29 | ))); 30 | } 31 | 32 | args.remove(0); 33 | if args.is_empty() { 34 | if let Err(e) = io::stdin().read_line(&mut prompt) { 35 | error!("Could not read from stdin: {}", e); 36 | return Err(Box::new(e)); 37 | } 38 | 39 | let words: Vec = prompt.split_whitespace().map(|s| s.to_string()).collect(); 40 | if words.is_empty() { 41 | error!("Less than one word found"); 42 | return Err(Box::new(io::Error::new( 43 | io::ErrorKind::InvalidInput, 44 | "Less than one word found", 45 | ))); 46 | } 47 | 48 | if words[0] != "a" { 49 | lang = words[0].to_string(); 50 | } else if words.len() >= 2 { 51 | lang = words[1].to_string(); 52 | } else { 53 | error!("No language specified"); 54 | return Err(Box::new(io::Error::new( 55 | io::ErrorKind::InvalidInput, 56 | "No language specified", 57 | ))); 58 | } 59 | } else { 60 | lang = args[0].clone(); 61 | prompt = args.join(" "); 62 | } 63 | 64 | Ok((prompt, lang)) 65 | } 66 | -------------------------------------------------------------------------------- /crates/a/src/main.rs: -------------------------------------------------------------------------------- 1 | use log::{debug, error, info}; 2 | 3 | const WHISPER_TRIGGER: 
&str = "whisper"; 4 | 5 | fn main() { 6 | env_logger::init(); 7 | 8 | let mut args: Vec<_> = std::env::args().collect(); 9 | 10 | let tuple = match a::gather_args(&mut args) { 11 | Ok(args) => args, 12 | Err(e) => { 13 | error!("error parsing arguments: {}", e); 14 | std::process::exit(1); 15 | } 16 | }; 17 | 18 | debug!("tuple: {:?}", tuple); 19 | let api_key = match std::env::var("OPENAI_API_KEY") { 20 | Ok(key) => key, 21 | Err(_) => { 22 | error!("Please set the GPT_API_KEY not set environment variable"); 23 | std::process::exit(1); 24 | } 25 | }; 26 | 27 | let mut client = a::gpt::GPTClient::new(api_key.to_string()); 28 | 29 | let (prompt, lang) = if tuple.1 == crate::WHISPER_TRIGGER { 30 | let text = match a::record::whisper(api_key) { 31 | Ok(text) => text, 32 | Err(e) => { 33 | error!("error recording whisper: {}", e); 34 | std::process::exit(1); 35 | } 36 | }; 37 | ( 38 | text.clone(), 39 | text.split_whitespace().next().unwrap_or("text").to_string(), 40 | ) 41 | } else { 42 | tuple 43 | }; 44 | 45 | let mut response = match client.prompt(prompt) { 46 | Ok(response) => response, 47 | Err(e) => { 48 | error!("prompt error: {}", e); 49 | std::process::exit(2); 50 | } 51 | }; 52 | debug!("response: {:#?}", response); 53 | 54 | response.push('\n'); 55 | if let Some(r) = response.strip_suffix("\n\n") { 56 | response = String::from(r); 57 | } 58 | 59 | #[cfg(feature = "clipboard")] 60 | { 61 | a::util::copy_to_clipboard(&response); 62 | info!("copy to clipboard"); 63 | } 64 | 65 | info!("pretty print to stdout"); 66 | a::util::pretty_print(&a::util::remove_code_lines(&response), &lang); 67 | } 68 | -------------------------------------------------------------------------------- /crates/a/src/record.rs: -------------------------------------------------------------------------------- 1 | use log::info; 2 | use serde::{Deserialize, Serialize}; 3 | use std::io::{self, Read}; 4 | use std::process::{Command, Stdio}; 5 | 6 | #[derive(Deserialize, Serialize, Debug)] 7 | 
struct TranscriptionResponse { 8 | text: String, 9 | } 10 | 11 | fn wait_for_keypress() -> io::Result<()> { 12 | let mut buffer = [0; 1]; 13 | let stdin = io::stdin(); 14 | let mut handle = stdin.lock(); 15 | handle.read_exact(&mut buffer)?; 16 | Ok(()) 17 | } 18 | 19 | pub fn whisper(api_key: String) -> std::io::Result { 20 | info!("Spawning the rec command"); 21 | let tmp = String::from_utf8( 22 | Command::new("mktemp") 23 | .output() 24 | .expect("Failed to create temp file") 25 | .stdout, 26 | ) 27 | .expect("Failed to parse mktemp") 28 | .trim() 29 | .to_string() 30 | + ".wav"; 31 | let mut rec = Command::new("rec") 32 | .arg("-c") 33 | .arg("1") 34 | .arg("-r") 35 | .arg("48000") 36 | .arg(&tmp) 37 | .stdout(Stdio::null()) 38 | .spawn() 39 | .expect("Failed to spawn rec command"); 40 | 41 | println!("\n----------------------------------------------"); 42 | println!("Recording! Press any key to stop the recording"); 43 | println!("[Ctrl+C] cancels the recording"); 44 | println!("----------------------------------------------\n"); 45 | wait_for_keypress().expect("Failed to wait for keypress"); 46 | 47 | info!("Stopping the rec command"); 48 | rec.kill().expect("Failed to kill rec"); 49 | 50 | info!("Running ffmpeg"); 51 | Command::new("ffmpeg") 52 | .arg("-i") 53 | .arg(&tmp) 54 | .arg("-acodec") 55 | .arg("libmp3lame") 56 | .arg("-y") 57 | .arg(crate::CONFIG_DIRECTORY_PATH.to_owned() + "/whisper.mp3") 58 | .output() 59 | .expect("Failed to execute ffmpeg command"); 60 | 61 | info!("Removing the tmp file"); 62 | Command::new("rm") 63 | .arg("-rf") 64 | .arg(&tmp) 65 | .output() 66 | .expect("Failed to execute rm command"); 67 | 68 | info!("Getting audio transcription"); 69 | 70 | let response = String::from_utf8( 71 | Command::new("curl") 72 | .arg("-sX") 73 | .arg("POST") 74 | .arg("https://api.openai.com/v1/audio/transcriptions") 75 | .arg("-H") 76 | .arg("Content-Type: multipart/form-data") 77 | .arg("-H") 78 | .arg(("Authorization: Bearer ").to_owned() + 
&api_key) 79 | .arg("--form") 80 | .arg(("file=@").to_owned() + crate::CONFIG_DIRECTORY_PATH + "/whisper.mp3") 81 | .arg("--form") 82 | .arg("model=whisper-1") 83 | .output() 84 | .expect("Failed to execute request to the Whisper API") 85 | .stdout, 86 | ) 87 | .expect("Failed to parse reponse") 88 | .trim() 89 | .to_string(); 90 | 91 | info!("Removing the whisper.mp3 file"); 92 | Command::new("rm") 93 | .arg("-rf") 94 | .arg(crate::CONFIG_DIRECTORY_PATH.to_owned() + "/whisper.mp3") 95 | .output() 96 | .expect("Failed to execute rm command"); 97 | 98 | let body: TranscriptionResponse = serde_json::from_str(&response)?; 99 | 100 | println!("\n----------------------------------------------\n\n"); 101 | 102 | Ok(body.text) 103 | } 104 | -------------------------------------------------------------------------------- /crates/a/src/util.rs: -------------------------------------------------------------------------------- 1 | use bat::PrettyPrinter; 2 | use bat::Syntax; 3 | use chrono::prelude::*; 4 | use copypasta_ext::prelude::*; 5 | use copypasta_ext::x11_fork::ClipboardContext; 6 | use std::fs; 7 | use std::fs::File; 8 | use std::io::prelude::*; 9 | use std::io::Read; 10 | 11 | const THEME: &str = "ansi"; 12 | 13 | fn lang_exists(lang: &str, langs: &Vec) -> bool { 14 | for l in langs { 15 | if l.name.to_lowercase() == lang.to_lowercase() { 16 | return true; 17 | } 18 | for e in &l.file_extensions { 19 | if e == &lang.to_lowercase() { 20 | return true; 21 | } 22 | } 23 | } 24 | false 25 | } 26 | 27 | pub fn pretty_print(str: &str, mut lang: &str) { 28 | let mut pp = PrettyPrinter::new(); 29 | 30 | let langs: Vec<_> = pp.syntaxes().collect(); 31 | if !lang_exists(lang, &langs) { 32 | lang = "txt" 33 | } 34 | 35 | pp.input_from_bytes(str.as_bytes()) 36 | .language(lang) 37 | .use_italics(true) 38 | .theme(THEME) 39 | .print() 40 | .unwrap(); 41 | } 42 | 43 | pub fn copy_to_clipboard(str: &str) { 44 | let mut ctx = ClipboardContext::new().unwrap(); 45 | 
ctx.set_contents(str.to_string()).unwrap(); 46 | } 47 | 48 | pub fn write_to_file(file_path: &str, content: &str) -> std::io::Result<()> { 49 | // Open the file in write mode 50 | let mut file = File::create(file_path)?; 51 | 52 | // Write the content to the file 53 | file.write_all(content.as_bytes())?; 54 | 55 | Ok(()) 56 | } 57 | 58 | pub fn read_file(file_path: &str) -> String { 59 | match File::open(file_path) { 60 | Ok(mut file) => { 61 | let mut contents = String::new(); 62 | match file.read_to_string(&mut contents) { 63 | Ok(_) => contents, 64 | Err(_) => String::new(), 65 | } 66 | } 67 | Err(_) => String::new(), 68 | } 69 | } 70 | 71 | pub fn get_current_date() -> String { 72 | let local: DateTime = Local::now(); 73 | local.format("%Y-%m-%d").to_string() 74 | } 75 | 76 | pub fn remove_code_lines(text: &str) -> String { 77 | let mut result = String::new(); 78 | for line in text.lines() { 79 | if !line.trim_start().starts_with("```") { 80 | result.push_str(line); 81 | result.push('\n'); 82 | } 83 | } 84 | result 85 | } 86 | 87 | pub fn create_dir_if_not_exist(dir_path: &str) -> std::io::Result<()> { 88 | if fs::metadata(dir_path).is_err() { 89 | fs::create_dir_all(dir_path)?; 90 | } 91 | Ok(()) 92 | } 93 | -------------------------------------------------------------------------------- /crates/anthropic/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "anthropic" 3 | version = "0.0.0" 4 | homepage = "https://github.com/cloudbridgeuy/a" 5 | description = "Anthropic API Wrapper" 6 | autobins = false 7 | 8 | authors.workspace = true 9 | edition.workspace = true 10 | license.workspace = true 11 | 12 | [lib] 13 | name = "anthropic" 14 | path = "src/lib.rs" 15 | 16 | [dependencies] 17 | fs = { path = "../fs", version = "0.0.0" } 18 | gpt_tokenizer = { version = "0.1.0", path = "../../lib/gpt_tokenizer" } 19 | custom_error = "1.9.2" # Define custom errors without boilerplate using the 
custom_error! macro. 20 | env_logger = "0.10.0" 21 | log = "0.4.17" 22 | reqwest = { version = "0.11.16", features = ["json"] } 23 | serde = { version = "1.0.152", features = ["derive"] } 24 | serde_either = "0.2.1" # Simple set to enums to deserialize and serialize data that can either be string, struct or vec 25 | serde_json = "1.0.93" 26 | serde_yaml = "0.9.25" 27 | anyhow = "1.0.71" # Flexible concrete Error type built on std::error::Error 28 | tokio = { version = "1.27.0", features = ["full"] } # An event-driven, non-blocking I/O platform for writing asynchronous I/O backed applications.… 29 | tracing = "0.1.37" 30 | color-eyre = "0.6.2" 31 | tokio-stream = "0.1.14" 32 | reqwest-eventsource = "0.4.0" 33 | -------------------------------------------------------------------------------- /crates/anthropic/README.md: -------------------------------------------------------------------------------- 1 | # Anthropic 2 | 3 | Contains the necessary definitions and traits that are used by other traits to implement 4 | their desired behavior. 
5 | 6 | -------------------------------------------------------------------------------- /crates/anthropic/src/client.rs: -------------------------------------------------------------------------------- 1 | use std::time::Duration; 2 | 3 | use color_eyre::eyre::{self, Context, Result}; 4 | use reqwest::header::{HeaderMap, HeaderValue}; 5 | use reqwest::{Client as ReqwestClient, Response as ReqwestResponse}; 6 | use reqwest_eventsource::EventSource; 7 | 8 | #[derive(Clone, Debug, Default)] 9 | pub struct Client { 10 | reqwest: ReqwestClient, 11 | base_url: String, 12 | headers: HeaderMap, 13 | } 14 | 15 | const ANTHROPIC_API_URL: &str = "https://api.anthropic.com"; 16 | 17 | fn create_headers(api_key: String) -> Result { 18 | let mut headers = HeaderMap::new(); 19 | 20 | let authorization = 21 | HeaderValue::from_str(api_key.as_str()).context("can't create authorization header")?; 22 | let content_type = 23 | HeaderValue::from_str("application/json").context("can't create content-type header")?; 24 | let version = 25 | HeaderValue::from_str("2023-06-01").context("can't create anthropic-version header")?; 26 | 27 | headers.insert("anthropic-version", version); 28 | headers.insert("X-API-Key", authorization); 29 | headers.insert("Content-Type", content_type); 30 | 31 | Ok(headers) 32 | } 33 | 34 | impl Client { 35 | /// Creates a new client with the given API key. 36 | pub fn new(api_key: String) -> Result { 37 | let reqwest = ReqwestClient::builder() 38 | .timeout(Duration::from_secs(300)) 39 | .build() 40 | .context("can't create reqwest client")?; 41 | 42 | tracing::event!(tracing::Level::INFO, "Creating API client headers..."); 43 | let headers = create_headers(api_key).context("can't create headers")?; 44 | 45 | Ok(Self { 46 | reqwest, 47 | base_url: ANTHROPIC_API_URL.to_string(), 48 | headers, 49 | }) 50 | } 51 | 52 | /// Changes the client base url. 
53 | pub fn set_base_url(&mut self, base_url: String) -> &mut Self { 54 | self.base_url = base_url; 55 | self 56 | } 57 | 58 | /// Change the Anthropic API key. 59 | pub fn set_api_key(&mut self, api_key: String) -> Result<&mut Self> { 60 | self.headers = create_headers(api_key).context("can't create headers")?; 61 | Ok(self) 62 | } 63 | 64 | /// Makes a GET request to the Anthropic API. 65 | pub async fn get(&self, endpoint: &str) -> Result { 66 | let mut url = self.base_url.clone(); 67 | url.push_str(endpoint); 68 | 69 | tracing::event!(tracing::Level::INFO, "GET {}", url); 70 | 71 | self.reqwest 72 | .get(url) 73 | .headers(self.headers.clone()) 74 | .send() 75 | .await 76 | .context("can't send reqwest request") 77 | } 78 | 79 | /// Makes a POST request to the Anthropic API. 80 | pub async fn post(&self, endpoint: &str, body: String) -> Result { 81 | let mut url = self.base_url.clone(); 82 | url.push_str(endpoint); 83 | 84 | tracing::event!(tracing::Level::INFO, "POST {}", url); 85 | 86 | self.reqwest 87 | .post(url) 88 | .headers(self.headers.clone()) 89 | .body(body) 90 | .send() 91 | .await 92 | .context("can't send reqwest request") 93 | } 94 | 95 | /// Makes a POST request to the OpenAi API that returns a SSE stream. 
pub async fn post_stream(&self, endpoint: &str, body: String) -> Result<EventSource> {
version = "4.1.8", features = ["derive", "env"] } 23 | env_logger = "0.10.0" 24 | log = "0.4.17" 25 | openai = { path = "../openai", version = "0.0.0" } 26 | anthropic = { path = "../anthropic", version = "0.0.0" } 27 | serde = { version = "1.0.152", features = ["derive"] } 28 | serde_either = "0.2.1" # Simple set to enums to deserialize and serialize data that can either be string, struct or vec 29 | serde_json = "1.0.93" 30 | serde_yaml = "0.9.19" # YAML data format for Serde 31 | async-trait = "0.1.68" # Type erasure for async trait methods 32 | tokio = { version = "1.27.0", features = ["full"] } # An event-driven, non-blocking I/O platform for writing asynchronous I/O backed applications.… 33 | indicatif = "0.17.3" # A progress bar and cli reporting library for Rust 34 | anyhow = "1.0.71" # Flexible concrete Error type built on std::error::Error 35 | tokio-stream = "0.1.14" 36 | 37 | -------------------------------------------------------------------------------- /crates/b/README.md: -------------------------------------------------------------------------------- 1 | # b 2 | 3 | ## Snippets 4 | 5 | ```rust 6 | #[derive(Debug, Parser)] 7 | #[command(name = "v2")] 8 | #[command(about = "Interact with OpenAI's ChatGPT through the terminal")] 9 | struct Cli { 10 | #[command(subcommand)] 11 | command: Option, 12 | 13 | /// Name of the chat session. 14 | #[arg(short, long)] 15 | session: Option, 16 | 17 | /// URL endpoint of the OpenAI ChatGPT API. 18 | #[arg(short, long, default_value_t=String::from("https://api.openai.com/v1/chat/completions"))] 19 | url: String, 20 | 21 | /// ChatGPT model to use. 22 | #[arg(short, long, default_value_t=String::from("gpt-3.5-turbo"))] 23 | model: String, 24 | 25 | /// Temperature value of ChatGPT response. 26 | #[arg(long, default_value_t=0.0, value_parser = in_range)] 27 | temperature: f32, 28 | 29 | /// Top-p value of ChatGPT response. 
30 | #[arg(long, default_value_t=0.8, value_parser = in_range)] 31 | top_p: f32, 32 | 33 | /// Presence penalty value of ChatGPT response. 34 | #[arg(long, default_value_t=0.0, value_parser = in_range)] 35 | presence_penalty: f32, 36 | 37 | /// Frequencey penalty value of ChatGPT response. 38 | #[arg(long, default_value_t=0.0, value_parser = in_range)] 39 | frequency_penalty: f32, 40 | 41 | /// Prompt that should be send to ChatGPT. 42 | prompt: Vec, 43 | } 44 | 45 | fn in_range(s: &str) -> Result { 46 | let num: f32 = s.parse().map_err(|_| "Not a number".to_string())?; 47 | if &num < &0.0 { 48 | Err(String::from("Temperature must be positive")) 49 | } else if &num > &1.0 { 50 | Err(String::from("Temperature must be less than 1")) 51 | } else { 52 | Ok(num) 53 | } 54 | } 55 | 56 | #[derive(Debug, Subcommand)] 57 | enum Commands { 58 | /// Whisper to OpenAI 59 | Whisper, 60 | /// Create new resources 61 | New { 62 | #[command(subcommand)] 63 | command: NewCommand, 64 | }, 65 | } 66 | 67 | #[derive(Debug, Subcommand, Clone)] 68 | enum NewCommand { 69 | /// Create a new chat session 70 | Chat, 71 | } 72 | ``` 73 | -------------------------------------------------------------------------------- /crates/b/src/anthropic.rs: -------------------------------------------------------------------------------- 1 | use std::error::Error; 2 | 3 | use async_trait::async_trait; 4 | 5 | use anthropic::complete::{Api, Model, Response}; 6 | 7 | use crate::utils::read_from_stdin; 8 | use crate::{AnthropicCommands, Cli, CommandError, CommandHandle, CommandResult}; 9 | 10 | pub struct CompleteCreateCommand { 11 | pub api: Api, 12 | } 13 | 14 | impl CompleteCreateCommand { 15 | pub fn new(cli: &Cli, command: &AnthropicCommands) -> Result> { 16 | match command { 17 | AnthropicCommands::Create { 18 | prompt, 19 | system, 20 | model, 21 | max_tokens_to_sample, 22 | stop_sequences, 23 | temperature, 24 | top_k, 25 | top_p, 26 | session, 27 | max_supported_tokens, 28 | } => { 29 | let api_key 
= cli 30 | .anthropic_api_key 31 | .as_ref() 32 | .expect("api key not set") 33 | .to_string(); 34 | 35 | let mut api = if let Some(s) = session { 36 | log::debug!("loading session {}", s); 37 | Api::new_with_session(api_key, s.to_owned())? 38 | } else { 39 | log::debug!("creating new session"); 40 | Api::new(api_key)? 41 | }; 42 | 43 | api.prompt.push_str("\n\nHuman: "); 44 | 45 | if prompt == "-" { 46 | let stdin = read_from_stdin()?; 47 | api.prompt.push_str(&stdin); 48 | } else { 49 | api.prompt.push_str(prompt); 50 | } 51 | 52 | if let Some(m) = model { 53 | let model: Model = Model::from(*m); 54 | 55 | if api.model as u32 != model as u32 { 56 | api.model = model; 57 | }; 58 | } 59 | 60 | log::debug!("model: {:?}", api.model); 61 | 62 | if system.is_some() { 63 | api.system = system.clone(); 64 | } 65 | 66 | max_tokens_to_sample.map(|s| api.max_tokens_to_sample = Some(s)); 67 | max_supported_tokens.map(|s| api.max_supported_tokens = Some(s)); 68 | temperature.map(|s| api.set_temperature(s)); 69 | top_k.map(|s| api.set_top_k(s)); 70 | top_p.map(|s| api.set_top_p(s)); 71 | 72 | api.stream = Some(cli.stream); 73 | 74 | if let Some(s) = stop_sequences { 75 | api.stop_sequences = Some(s.to_vec()); 76 | } 77 | 78 | Ok(Self { api }) 79 | } 80 | } 81 | } 82 | } 83 | 84 | impl CommandResult for Response { 85 | type ResultError = CommandError; 86 | 87 | fn print_raw(&self, mut w: W) -> Result<(), Self::ResultError> { 88 | writeln!(w, "{}", self.completion)?; 89 | Ok(()) 90 | } 91 | } 92 | 93 | #[derive(Debug)] 94 | pub struct CompleteCreateCommandError; 95 | 96 | impl std::error::Error for CompleteCreateCommandError {} 97 | 98 | impl std::fmt::Display for CompleteCreateCommandError { 99 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 100 | write!(f, "Bad :(") 101 | } 102 | } 103 | 104 | #[async_trait] 105 | impl CommandHandle for CompleteCreateCommand { 106 | type CallError = CommandError; 107 | 108 | async fn call(&self) -> Result { 109 | todo!() 
110 | } 111 | } 112 | -------------------------------------------------------------------------------- /crates/b/src/chats.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | use std::error::Error; 3 | use tokio_stream::StreamExt; 4 | 5 | use async_trait::async_trait; 6 | use serde_either::SingleOrVec; 7 | use serde_json::from_str; 8 | 9 | use openai::chats::{Chat, ChatMessage, ChatsApi}; 10 | use openai::error::OpenAi as OpenAiError; 11 | 12 | use crate::utils::{read_from_stdin, Spinner}; 13 | use crate::{ChatsCommands, Cli, CommandError, CommandHandle, CommandResult}; 14 | 15 | pub struct ChatsCreateCommand { 16 | pub api: ChatsApi, 17 | } 18 | 19 | impl ChatsCreateCommand { 20 | pub fn new(cli: &Cli, command: &ChatsCommands) -> Result> { 21 | match command { 22 | ChatsCommands::Create { 23 | model, 24 | session, 25 | prompt, 26 | system, 27 | max_tokens, 28 | temperature, 29 | top_p, 30 | n, 31 | stop, 32 | presence_penalty, 33 | frequency_penalty, 34 | user, 35 | logit_bias, 36 | min_available_tokens, 37 | max_supported_tokens, 38 | functions, 39 | function_call, 40 | } => { 41 | let api_key = cli 42 | .openai_api_key 43 | .as_ref() 44 | .expect("No API Key provided") 45 | .to_string(); 46 | 47 | log::debug!("Session: {:?}", session); 48 | let mut api = if let Some(s) = session { 49 | ChatsApi::new_with_session(api_key, s.to_owned())? 50 | } else { 51 | ChatsApi::new(api_key)? 
};

                // Build the user turn: "-" reads the prompt from stdin, an
                // absent prompt becomes an empty user message.
                let message = match prompt {
                    Some(s) if s == "-" => ChatMessage {
                        content: Some(read_from_stdin()?),
                        role: "user".to_owned(),
                        ..Default::default()
                    },
                    Some(s) => ChatMessage {
                        content: Some(s.to_owned()),
                        role: "user".to_owned(),
                        ..Default::default()
                    },
                    None => ChatMessage {
                        content: Some("".to_owned()),
                        role: "user".to_owned(),
                        ..Default::default()
                    },
                };

                api.messages.push(message);

                // A --system flag replaces any existing leading system message.
                if let Some(s) = system {
                    // `map_or` instead of `first().unwrap()` removes the panic
                    // path on an empty message list.
                    if api.messages.first().map_or(false, |m| m.role == "system") {
                        api.messages.remove(0);
                    }
                    api.messages.insert(
                        0,
                        ChatMessage {
                            content: Some(s.to_owned()),
                            role: "system".to_owned(),
                            ..Default::default()
                        },
                    );
                }

                if let Some(m) = model {
                    api.model = m.to_owned();
                }

                // Functions / function_call come in as JSON strings.
                if let Some(f) = functions {
                    api.functions = match from_str(f) {
                        Ok(f) => f,
                        Err(e) => {
                            log::error!("Error parsing functions: {}", e);
                            return Err(Box::new(e));
                        }
                    };
                }

                if let Some(f) = function_call {
                    api.function_call = match from_str(f) {
                        Ok(f) => f,
                        Err(e) => {
                            log::error!("Error parsing function_call: {}", e);
                            return Err(Box::new(e));
                        }
                    };
                }

                log::debug!("Using model: {:?}", api.model);

                // Optional tuning knobs: apply only when provided (`if let`
                // instead of side-effect `Option::map`, per clippy).
                if let Some(s) = max_tokens {
                    api.max_tokens = Some(*s);
                }
                if let Some(s) = n {
                    api.n = Some(*s);
                }
                if let Some(s) = temperature {
                    api.set_temperature(*s);
                }
                if let Some(s) = top_p {
                    api.set_top_p(*s);
                }
                if let Some(s) = presence_penalty {
                    api.set_presence_penalty(*s);
                }
                if let Some(s) = frequency_penalty {
                    api.set_frequency_penalty(*s);
                }
                if let Some(s) = min_available_tokens {
                    api.min_available_tokens = Some(*s);
                }
                if let Some(s) = max_supported_tokens {
                    api.max_supported_tokens = Some(*s);
                }

                api.stream = Some(cli.stream);

                if &api.user != user {
                    api.user = user.to_owned();
                }

                if let Some(s) = stop {
                    api.set_stop(SingleOrVec::Vec(s.to_vec()));
                }

                // Merge CLI logit biases over any session-stored ones.
                if let Some(logit_bias) = logit_bias {
                    let mut map = api.logit_bias.unwrap_or_default();
                    for (key, value) in logit_bias {
                        map.insert(key.to_owned(), *value);
                    }
                    api.logit_bias = Some(map);
                }

                Ok(Self { api })
            }
        }
    }
}

impl CommandResult for Chat {
    type ResultError = CommandError;

    /// Prints the first choice's message content, or its function-call
    /// arguments when there is no content.
    fn print_raw(&self, mut w: W) -> Result<(), Self::ResultError> {
        match self.choices.first() {
            Some(choice) => {
                if let Some(content) = &choice.message.content {
                    writeln!(w, "{}", content)?;
                } else if let Some(fc) = &choice.message.function_call {
                    writeln!(w, "{}", fc.arguments)?;
                }
                Ok(())
            }
            None => Err(CommandError::from(OpenAiError::NoChoices)),
        }
    }
}

#[async_trait]
impl CommandHandle for ChatsCreateCommand {
    type CallError = OpenAiError;

    /// Executes the request; in stream mode the deltas are printed as they
    /// arrive and a default `Chat` is returned as a placeholder result.
    async fn call(&self) -> Result {
        let mut spinner = Spinner::new(false);

        log::debug!("Stream is: {:?}", self.api.stream);

        if Some(true) == self.api.stream {
            log::debug!("Creating stream");

            let chunks = match self.api.create_stream().await {
                Ok(chunks) => chunks,
                Err(e) => {
                    log::error!("Error creating stream: {}", e);
                    return Err(OpenAiError::StreamError);
                }
            };

            tokio::pin!(chunks);

            while let Some(chunk) = chunks.next().await {
                // Bail out on the first transport error (match instead of the
                // old is_err()/unwrap() pair).
                let chunk = match chunk {
                    Ok(chunk) => chunk,
                    Err(_) => {
                        log::error!("Error reading stream");
                        spinner.err("Error reading stream");
                        return Err(OpenAiError::StreamError);
                    }
                };

                // Print each streamed delta as soon as it arrives.
                if let Some(choice) = chunk.choices.first() {
                    if let Some(delta) = &choice.delta {
                        if let Some(content) = &delta.content {
                            spinner.print(content);
                        }
                    }
}
            }

            spinner.ok();
            Ok(openai::chats::Chat::default())
        } else {
            log::debug!("Creating chat");
            self.api.create().await
        }
    }
}

--------------------------------------------------------------------------------
/crates/b/src/commands.rs:
--------------------------------------------------------------------------------

use anthropic::complete::Response;
use openai::chats::Chat;
use openai::completions::Completions;
use openai::edits::Edit;
use serde::Serialize;

use crate::anthropic::CompleteCreateCommand;
use crate::chats::ChatsCreateCommand;
use crate::completions::CompletionsCreateCommand;
use crate::edits::EditsCreateCommand;
use crate::tokenizer::{
    TokenizerDecodeCommand, TokenizerDecodeResult, TokenizerEncodeCommand, TokenizerEncodeResult,
};
use crate::{CommandError, CommandHandle, CommandResult};

/// All executable CLI commands, one variant per subcommand.
pub enum CommandCallers {
    TokenizerDecode(TokenizerDecodeCommand),
    TokenizerEncode(TokenizerEncodeCommand),
    ChatsCreate(ChatsCreateCommand),
    EditsCreate(EditsCreateCommand),
    CompletionsCreate(CompletionsCreateCommand),
    AnthropicCompleteCreate(Response),
}
Self::ResultError> {
        match self {
            CommandResults::TokenizerEncode(result) => result.print_raw(w),
            CommandResults::TokenizerDecode(result) => result.print_raw(w),
            CommandResults::ChatsCreate(result) => result.print_raw(w),
            CommandResults::EditsCreate(result) => result.print_raw(w),
            CommandResults::CompletionsCreate(result) => result.print_raw(w),
            CommandResults::AnthropicCompleteCreate(result) => result.print_raw(w),
        }
    }
}

--------------------------------------------------------------------------------
/crates/b/src/completions.rs:
--------------------------------------------------------------------------------

use std::error::Error;
use std::fmt::Write;
use std::io::{IsTerminal, Read};
use std::string::String;

use async_trait::async_trait;
use serde_either::SingleOrVec;

use openai::completions::{Completions, CompletionsApi};
use openai::error::OpenAi as OpenAiError;

use crate::{Cli, CommandError, CommandHandle, CommandResult, CompletionsCommands};

/// Command wrapper around the OpenAI completions API.
pub struct CompletionsCreateCommand {
    pub api: CompletionsApi,
}

impl CompletionsCreateCommand {
    /// Builds a completions request from the parsed CLI options, prepending
    /// piped stdin (when present) to the first prompt.
    pub fn new(cli: &Cli, command: &CompletionsCommands) -> Result> {
        match command {
            CompletionsCommands::Create {
                model,
                prompt,
                suffix,
                max_tokens,
                temperature,
                top_p,
                n,
                logprobs,
                echo,
                stop,
                presence_penalty,
                frequency_penalty,
                best_of,
                user,
            } => {
                let api_key = cli
                    .openai_api_key
                    .as_ref()
                    .expect("No API key provided")
                    .to_string();
                let mut api = CompletionsApi::new(api_key)?;

                let mut stdin = Vec::new();
                // Read from stdin only when data is actually piped in; the old
                // unconditional read blocked forever on an interactive tty.
                if !std::io::stdin().is_terminal() {
                    // Scope the lock so `stdin` is unlocked right after the read.
                    let mut stdin_lock = std::io::stdin().lock();
                    stdin_lock.read_to_end(&mut stdin)?;
                }

                if !stdin.is_empty() {
                    if prompt.is_empty() {
                        api.prompt = Some(SingleOrVec::Single(
                            String::from_utf8_lossy(&stdin).to_string(),
                        ));
                    } else {
                        // Prepend the piped input to the first prompt.
                        let mut first = String::new();
                        write!(
                            first,
                            "{}\n{}",
                            String::from_utf8_lossy(&stdin),
                            prompt.first().unwrap(),
                        )?;
                        // Iterate the prompts directly instead of cloning the
                        // whole Vec first (clippy: redundant_clone).
                        let mut prompts: Vec<String> =
                            prompt.iter().skip(1).cloned().collect();
                        prompts.insert(0, first);
                        api.prompt = Some(SingleOrVec::Vec(prompts));
                    }
                } else {
                    api.prompt = Some(SingleOrVec::Vec(prompt.clone()));
                }

                api.model = model.to_string();
                api.max_tokens = *max_tokens;
                api.n = *n;

                if let Some(user) = user.as_ref() {
                    api.user = Some(user.to_string());
                }

                // Optional tuning knobs: apply only when provided.
                if let Some(s) = echo {
                    api.set_echo(*s);
                }
                if let Some(s) = suffix.as_ref() {
                    api.set_suffix(s.to_string());
                }
                if let Some(s) = logprobs {
                    api.set_logprobs(*s);
                }
                if let Some(s) = stop {
                    api.set_stop(SingleOrVec::Vec(s.to_vec()));
                }
                if let Some(s) = presence_penalty {
                    api.set_presence_penalty(*s);
                }
                if let Some(s) = frequency_penalty {
                    api.set_frequency_penalty(*s);
                }
                if let Some(s) = best_of {
                    api.set_best_of(*s);
                }
                if let Some(s) = temperature {
                    api.set_temperature(*s);
                }
                if let Some(s) = top_p {
                    api.set_top_p(*s);
                }

                _ = api.set_stream(cli.stream);

                Ok(Self { api })
            }
        }
    }
}

impl CommandResult for Completions {
    type ResultError = CommandError;

    /// Prints the first choice's text, or fails when no choices came back.
    fn print_raw(&self, mut w: W) -> Result<(), Self::ResultError> {
        match self.choices.first() {
            Some(choice) => {
                write!(w, "{}", choice.text)?;
                Ok(())
            }
            None => Err(CommandError::from(OpenAiError::NoChoices)),
        }
    }
}

#[async_trait]
impl CommandHandle for CompletionsCreateCommand {
    type CallError = OpenAiError;

    /// Fires the request against the completions endpoint.
    async fn call(&self) -> Result {
        self.api.create().await
    }
}

--------------------------------------------------------------------------------
/crates/b/src/edits.rs:
--------------------------------------------------------------------------------

use std::error::Error;

use async_trait::async_trait;
use openai::edits::{Edit, EditsApi};
use openai::error::OpenAi as OpenAiError;

use crate::{Cli, CommandError, CommandHandle, CommandResult, EditsCommands};

/// Command wrapper around the OpenAI edits API.
pub struct EditsCreateCommand {
    pub api: EditsApi,
}

impl EditsCreateCommand {
    /// Assembles an edits request from the parsed CLI options.
    pub fn new(cli: &Cli, command: &EditsCommands) -> Result> {
        match command {
            EditsCommands::Create {
                model,
                input,
                instruction,
                n,
                temperature,
                top_p,
            } => {
                let api_key = cli
                    .openai_api_key
                    .as_ref()
                    .expect("No API Key provided")
                    .to_string();

                let mut api = EditsApi::new(api_key)?;
                api.model = model.to_owned();
                api.instruction = instruction.to_owned();
                api.n = *n;

                if let Some(input) = input.as_ref() {
                    api.input = input.to_owned();
                }

                // Optional sampling knobs: apply only when provided.
                if let Some(t) = temperature {
                    api.set_temperature(*t);
                }
                if let Some(p) = top_p {
                    api.set_top_p(*p);
                }

                Ok(Self { api })
            }
        }
    }
}

impl CommandResult for Edit {
    type ResultError = CommandError;

    /// Prints the text of the first returned choice.
    fn print_raw(&self, mut w: W) -> Result<(), Self::ResultError> {
        let Some(choice) = self.choices.first() else {
            return Err(CommandError::from(OpenAiError::NoChoices));
        };
        write!(w, "{}", choice.text)?;
        Ok(())
    }
}

#[async_trait]
impl CommandHandle for EditsCreateCommand {
    type CallError = OpenAiError;

    /// Fires the request against the edits endpoint.
    async fn call(&self) -> Result {
        self.api.create().await
    }
}

--------------------------------------------------------------------------------
/crates/b/src/main.rs:
--------------------------------------------------------------------------------

use clap::Parser;

use b::anthropic::CompleteCreateCommand;
use b::chats::ChatsCreateCommand;
use
b::commands::CommandCallers;
use b::completions::CompletionsCreateCommand;
use b::edits::EditsCreateCommand;
use b::tokenizer::{TokenizerDecodeCommand, TokenizerEncodeCommand};
use b::utils::Spinner;
use b::{Cli, CommandError, CommandResult, Commands, Output, TokenizerCommands};

#[tokio::main]
async fn main() -> Result<(), CommandError> {
    env_logger::init();

    let cli = Cli::parse();

    // Map each CLI subcommand to its command runner.
    let command = match cli.command {
        Some(Commands::Anthropic { ref command }) => CommandCallers::AnthropicCompleteCreate(
            CompleteCreateCommand::new(&cli, command).expect("Failed to parse command"),
        ),
        Some(Commands::Chats { ref command }) => CommandCallers::ChatsCreate(
            ChatsCreateCommand::new(&cli, command).expect("Failed to parse command"),
        ),
        Some(Commands::Edits { ref command }) => CommandCallers::EditsCreate(
            EditsCreateCommand::new(&cli, command).expect("Failed to parse command"),
        ),
        Some(Commands::Completions { ref command }) => CommandCallers::CompletionsCreate(
            CompletionsCreateCommand::new(&cli, command).expect("Failed to parse command"),
        ),
        Some(Commands::Tokenizer { ref command }) => match command {
            TokenizerCommands::Encode { ref prompt } => CommandCallers::TokenizerEncode(
                TokenizerEncodeCommand::new(&cli, prompt.to_string()),
            ),
            TokenizerCommands::Decode { ref encoded } => {
                CommandCallers::TokenizerDecode(TokenizerDecodeCommand::new(&cli, encoded.to_vec()))
            }
        },
        None => {
            // No subcommand given: nothing to run.
            std::process::exit(1);
        }
    };

    // Hide the spinner in silent mode and while streaming (streamed output is
    // printed incrementally by the command itself).
    let mut spinner = Spinner::new(cli.silent || cli.stream);

    let result = match command.call().await {
        Ok(result) => {
            spinner.ok();
            result
        }
        Err(e) => {
            spinner.err(&e.to_string());
            std::process::exit(1);
        }
    };

    match cli.output {
        Output::Json => {
            result.print_json(std::io::stdout())?;
        }
        Output::Yaml => {
            result.print_yaml(std::io::stdout())?;
        }
        Output::Raw => {
            // When streaming, the raw output was already printed above.
            if !cli.stream {
                result.print_raw(std::io::stdout())?;
            }
        }
    }

    Ok(())
}

--------------------------------------------------------------------------------
/crates/b/src/tokenizer.rs:
--------------------------------------------------------------------------------

use async_trait::async_trait;
use gpt_tokenizer::Default as DefaultTokenizer;
use serde::Serialize;

use crate::{Cli, CommandError, CommandHandle, CommandResult};

/// Encodes a prompt into GPT BPE token ids.
pub struct TokenizerEncodeCommand {
    tokenizer: DefaultTokenizer,
    prompt: String,
}

#[derive(Serialize)]
pub struct TokenizerEncodeResult {
    pub value: Vec,
}

impl TokenizerEncodeCommand {
    pub fn new(_: &Cli, prompt: String) -> Self {
        Self {
            prompt,
            tokenizer: DefaultTokenizer::new(),
        }
    }
}

impl CommandResult for TokenizerEncodeResult {
    type ResultError = CommandError;

    /// Prints the token ids separated by spaces.
    fn print_raw(&self, mut w: W) -> Result<(), Self::ResultError> {
        for value in &self.value {
            write!(w, "{} ", value)?;
        }
        Ok(())
    }
}

#[async_trait]
impl CommandHandle for TokenizerEncodeCommand {
    type CallError = CommandError;

    async fn call(&self) -> Result {
        // Encode the stored prompt directly; the previous
        // `&self.prompt.to_string()` allocated a needless copy.
        let value = self.tokenizer.encode(&self.prompt);
        Ok(TokenizerEncodeResult { value })
    }
}

/// Decodes GPT BPE token ids back into text.
pub struct TokenizerDecodeCommand {
    tokenizer: DefaultTokenizer,
    encoded: Vec,
}

#[derive(Serialize)]
pub struct TokenizerDecodeResult {
    pub value: String,
}

impl TokenizerDecodeCommand {
    pub fn new(_: &Cli, encoded: Vec) -> Self {
        Self {
            encoded,
            tokenizer: DefaultTokenizer::new(),
        }
    }
}

impl CommandResult for TokenizerDecodeResult {
    type ResultError = CommandError;

    fn print_raw(&self, mut w: W)
-> Result<(), Self::ResultError> {
        write!(w, "{}", self.value)?;
        Ok(())
    }
}

#[async_trait]
impl CommandHandle for TokenizerDecodeCommand {
    type CallError = CommandError;

    async fn call(&self) -> Result {
        let value = self.tokenizer.decode(&self.encoded);
        Ok(TokenizerDecodeResult { value })
    }
}

--------------------------------------------------------------------------------
/crates/b/src/utils.rs:
--------------------------------------------------------------------------------

use std::io::Read;
use std::time::Duration;

use indicatif::{ProgressBar, ProgressStyle};

/// Spinner state
enum SpinnerState {
    /// Spinner is running
    Running,
    /// Spinner is stopped
    Stopped,
    /// Spinner is silent
    Silent,
    /// Spinner is errored
    Errored,
}

/// Terminal progress spinner shown while a command is in flight.
pub struct Spinner {
    progress_bar: ProgressBar,
    state: SpinnerState,
}

impl Spinner {
    /// Creates a new Spinner
    pub fn new(silent: bool) -> Self {
        let progress_bar = if silent {
            // A hidden bar keeps the call sites unconditional.
            ProgressBar::hidden()
        } else {
            let progress_bar = ProgressBar::new_spinner();
            progress_bar.enable_steady_tick(Duration::from_millis(100));
            progress_bar.set_style(
                ProgressStyle::with_template("{spinner:.magenta} {msg}")
                    .unwrap()
                    .tick_strings(&["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]),
            );
            progress_bar
        };
        Self {
            state: if silent {
                SpinnerState::Silent
            } else {
                SpinnerState::Running
            },
            progress_bar,
        }
    }

    /// Prints a message without disturbing the spinner line.
    pub fn print(&mut self, msg: &str) {
        if let SpinnerState::Running = self.state {
            self.progress_bar.suspend(|| {
                print!("{}", msg);
            });
        }
    }

    /// Stops the spinner successfully
    pub fn ok(&mut self) {
        if let SpinnerState::Running = self.state {
            self.state = SpinnerState::Stopped;
            self.progress_bar.finish_and_clear()
        }
    }

    /// Stops the spinner with an error
    pub fn err(&mut self, msg: &str) {
        if let SpinnerState::Running = self.state {
            self.progress_bar.abandon_with_message(msg.to_string());
            self.state = SpinnerState::Errored;
        }
    }
}

/// Reads stdin and returns a string with its content.
pub fn read_from_stdin() -> Result {
    let mut stdin = Vec::new();
    log::debug!("Reading from stdin...");
    let mut lock = std::io::stdin().lock();
    lock.read_to_end(&mut stdin)?;
    Ok(String::from_utf8_lossy(&stdin).to_string())
}

--------------------------------------------------------------------------------
/crates/c/Cargo.toml:
--------------------------------------------------------------------------------

[package]
name = "c"
version = "0.0.0"
homepage = "https://github.com/cloudbridgeuy/a"
description = "A CLI tool to interact with ChatGPT as a Software Engineer"
autobins = false

authors.workspace = true
edition.workspace = true
license.workspace = true

[[bin]]
name = "c"
path = "src/main.rs"

[lib]
name = "c"
path = "src/lib.rs"

[dependencies]
gpt_tokenizer = { version = "0.1.0", path = "../../lib/gpt_tokenizer" }
openai = { path = "../openai", version = "0.0.0" }
anthropic = { path = "../anthropic", version = "0.0.0" }
spinner = { path = "../spinner", version = "0.0.0" }
clap = { version = "4.1.8", features = ["derive", "env"] }
serde = { version = "1.0.152", features = ["derive"] }
serde_json = "1.0.93"
serde_yaml = "0.9.19"
color-eyre = "0.6.2"
tracing = { version = "0.1.37", features = ["max_level_debug", "release_max_level_warn"] }
tracing-subscriber = { version = "0.3.17", features = ["env-filter"] }
tokio = "1.29.1"
tokio-stream = "0.1.14"
indicatif = "0.17.5"
ulid = "1.0.0"
reqwest =
{ version = "0.11.16", features = ["json"] } 37 | reqwest-eventsource = "0.4.0" 38 | -------------------------------------------------------------------------------- /crates/c/README.md: -------------------------------------------------------------------------------- 1 | # c - Interact with Chat AI APIs 2 | 3 | Interact with AI chatbots through the command-line interface (CLI). 4 | 5 | ## Commands 6 | 7 | ### `anthropic` 8 | 9 | Uses Anthropic's Claude API. 10 | 11 | ``` 12 | c anthropic [OPTIONS] --anthropic-api-key [PROMPT] 13 | ``` 14 | 15 | #### Arguments 16 | 17 | | Argument | Description | 18 | |-|-| 19 | | `PROMPT` | The prompt text to send to Claude. | 20 | 21 | #### Options 22 | 23 | | Option | Description | 24 | |-|-| 25 | | `--session` | Chat session name to store context. | 26 | | `--model` | Claude model to use (claude-v1, claude2, etc). | 27 | | `--max-tokens` | Max tokens to generate. | 28 | | `--temperature` | Randomness of response. | 29 | | `--top-k` | Only sample from top k tokens. | 30 | | `--top-p` | Nucleus sampling top-p. | 31 | | `--stop-sequences` | Strings to stop generation. | 32 | | `--anthropic-api-key` | Anthropic API key. | 33 | | `--silent` | Silent mode. | 34 | | `--stream` | Stream response incrementally. | 35 | | `--pin` | Pin message to history. | 36 | | `--format` | Output format (raw, json, yaml). | 37 | | `-h, --help` | Print help. | 38 | 39 | 40 | ### `openai` 41 | 42 | Uses OpenAI's GPT API. 43 | 44 | ``` 45 | c openai [OPTIONS] --openai-api-key [PROMPT] 46 | ``` 47 | 48 | #### Arguments 49 | 50 | | Argument | Description | 51 | |-|-| 52 | | `PROMPT` | The prompt text to send GPT. | 53 | 54 | #### Options 55 | 56 | | Option | Description | 57 | |-|-| 58 | | `--session` | Chat session name to store context. | 59 | | `--model` | GPT model to use (gpt3, gpt4, etc). | 60 | | `--max-tokens` | Max tokens to generate. | 61 | | `--temperature` | Randomness of response. | 62 | | `--top-p` | Nucleus sampling top-p. 
| 63 | | `--stop` | Sequences to stop generation. | 64 | | `--openai-api-key` | OpenAI API key. | 65 | | `--silent` | Silent mode. | 66 | | `--stream` | Stream response incrementally. | 67 | | `--pin` | Pin message to history. | 68 | | `--format` | Output format (raw, json, yaml). | 69 | | `-h, --help` | Print help. | 70 | 71 | ### `vertex` 72 | 73 | Uses Google Vertex AI Code API. 74 | 75 | ``` 76 | c vertex [OPTIONS] --google-api-key [PROMPT] 77 | ``` 78 | 79 | #### Arguments 80 | 81 | | Argument | Description | 82 | |-|-| 83 | | `PROMPT` | The prompt text to send Vertex AI. | 84 | 85 | #### Options 86 | 87 | | Option | Description | 88 | |-|-| 89 | | `--session` | Chat session name to store context. | 90 | | `--model` | Vertex model to use. | 91 | | `--max-tokens` | Max tokens to generate. | 92 | | `--temperature` | Randomness of response. | 93 | | `--top-k` | Only sample from top k tokens. | 94 | | `--top-p` | Nucleus sampling top-p. | 95 | | `--stop-sequences` | Strings to stop generation. | 96 | | `--google-api-key` | Google Vertex API key. | 97 | | `--silent` | Silent mode. | 98 | | `--stream` | Stream response incrementally. | 99 | | `--pin` | Pin message to history. | 100 | | `--format` | Output format (raw, json, yaml). | 101 | | `-h, --help` | Print help. 
|

--------------------------------------------------------------------------------
/crates/c/src/commands.rs:
--------------------------------------------------------------------------------

// One module per supported vendor backend.
pub mod anthropic;
pub mod nlpcloud;
pub mod ollama;
pub mod openai;
pub mod vertex;

--------------------------------------------------------------------------------
/crates/c/src/lib.rs:
--------------------------------------------------------------------------------

use clap::{Parser, Subcommand, ValueEnum};

use serde::{Deserialize, Serialize};

pub mod commands;
pub mod session;
pub mod utils;

/// Top-level CLI parser.
#[derive(Debug, Parser)]
#[command(name = "v2")]
#[command(about = "Interact with OpenAI's ChatGPT through the terminal")]
pub struct Cli {
    #[command(subcommand)]
    pub command: Option,
}

#[derive(Debug, Subcommand)]
pub enum Commands {
    /// Anthropic Chat AI API
    #[clap(alias = "a")]
    Anthropic(commands::anthropic::CommandOptions),
    /// OpenAi Chat AI API
    #[clap(name = "openai", alias = "o")]
    OpenAi(commands::openai::CommandOptions),
    /// Google Vertex AI Chat Code API
    #[clap(name = "vertex", alias = "v")]
    Vertex(commands::vertex::CommandOptions),
    /// NLPCloud AI Chat Bot API
    #[clap(name = "nlpcloud", alias = "n")]
    NLPCloud(commands::nlpcloud::CommandOptions),
    /// Ollama AI Chat Bot API
    #[clap(name = "ollama", alias = "l")]
    Ollama(commands::ollama::CommandOptions),
}

/// Output rendering format shared by all subcommands.
#[derive(Default, ValueEnum, Debug, Clone, Serialize, Deserialize)]
#[clap(rename_all = "kebab-case")]
pub enum Output {
    #[default]
    /// Plain text
    Raw,
    /// JSON
    Json,
    /// YAML
    Yaml,
}

--------------------------------------------------------------------------------
/crates/c/src/main.rs:
--------------------------------------------------------------------------------

use clap::Parser;

#[tokio::main]
async fn main() -> color_eyre::eyre::Result<()> {
    color_eyre::install()?;
    tracing_subscriber::fmt::fmt().init();

    run().await?;

    Ok(())
}

/// Run the program
async fn run() -> color_eyre::eyre::Result<()> {
    match c::Cli::parse().command {
        Some(c::Commands::Anthropic(options)) => {
            // Scope a tracing span around the Anthropic command.
            let span = tracing::span!(tracing::Level::INFO, "Anthropic");
            let _enter = span.enter();
            c::commands::anthropic::run(options).await?;
        }
        Some(c::Commands::OpenAi(options)) => c::commands::openai::run(options).await?,
        Some(c::Commands::Vertex(options)) => c::commands::vertex::run(options).await?,
        Some(c::Commands::NLPCloud(options)) => c::commands::nlpcloud::run(options).await?,
        Some(c::Commands::Ollama(options)) => c::commands::ollama::run(options).await?,
        None => {
            color_eyre::eyre::bail!(
                "No subcommand provided. Use --help to see available subcommands."
            )
        }
    }

    Ok(())
}

--------------------------------------------------------------------------------
/crates/c/src/session.rs:
--------------------------------------------------------------------------------

use std::env;
use std::fs;
use std::path;

use color_eyre::eyre::Result;
use serde::{Deserialize, Serialize};

/// Chat LLM Vendor
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub enum Vendor {
    #[default]
    OpenAI,
    Anthropic,
    Google,
    NLPCloud,
}

/// Chat LLM Role
#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum Role {
    #[default]
    Human,
    User,
    Assistant,
    System,
}

/// Represents a chat message
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Message {
    pub content: String,
    pub role: Role,
    pub pin: bool,
}
36 | 37 | impl Message { 38 | /// Creates a new message 39 | pub fn new(content: String, role: Role, pin: bool) -> Self { 40 | Self { content, role, pin } 41 | } 42 | } 43 | 44 | /// Important data that are provided on each invocation 45 | #[derive(Debug, Clone, Default, Serialize, Deserialize)] 46 | pub struct Meta { 47 | path: String, 48 | pub format: crate::Output, 49 | pub key: String, 50 | pub pin: bool, 51 | pub reverse: bool, 52 | pub history_size: Option, 53 | pub save: bool, 54 | pub silent: bool, 55 | pub stream: bool, 56 | } 57 | 58 | /// Represents a chat session 59 | #[derive(Debug, Clone, Default, Serialize, Deserialize)] 60 | pub struct Session { 61 | id: String, 62 | vendor: Vendor, 63 | pub history: Vec, 64 | pub options: T, 65 | pub max_supported_tokens: u32, 66 | pub max_history: Option, 67 | #[serde(skip)] 68 | pub meta: Meta, 69 | } 70 | 71 | impl Deserialize<'a>> Session { 72 | /// Creates a new anonymous session 73 | pub fn anonymous(vendor: Vendor, options: T, max_supported_tokens: u32) -> Session { 74 | let id = ulid::Ulid::new().to_string(); 75 | let home = env::var("C_ROOT").unwrap_or(env::var("HOME").unwrap()); 76 | let path = format!("{home}/.c/sessions/anonymous/{id}.yaml"); 77 | Self { 78 | vendor, 79 | max_supported_tokens, 80 | options, 81 | meta: Meta { 82 | path, 83 | ..Default::default() 84 | }, 85 | ..Default::default() 86 | } 87 | } 88 | 89 | /// Creates a new session 90 | pub fn new(id: String, vendor: Vendor, options: T, max_supported_tokens: u32) -> Session { 91 | let home = env::var("C_ROOT").unwrap_or(env::var("HOME").unwrap()); 92 | let path = format!("{home}/.c/sessions/{id}.yaml"); 93 | Self { 94 | id, 95 | vendor, 96 | max_supported_tokens, 97 | options, 98 | meta: Meta { 99 | path, 100 | ..Default::default() 101 | }, 102 | ..Default::default() 103 | } 104 | } 105 | 106 | /// Checks if a session exists 107 | pub fn exists(id: &str) -> bool { 108 | let home = env::var("C_ROOT").unwrap_or(env::var("HOME").unwrap()); 109 
| let path = format!("{home}/.c/sessions/{id}.yaml"); 110 | fs::metadata(path).is_ok() 111 | } 112 | 113 | /// Tries to load a session from a file 114 | pub fn load(id: &str) -> Result> { 115 | let home = env::var("C_ROOT").unwrap_or(env::var("HOME")?); 116 | let path = format!("{home}/.c/sessions/{id}.yaml"); 117 | 118 | let meta = Meta { 119 | path: path.clone(), 120 | ..Default::default() 121 | }; 122 | 123 | let session = if fs::metadata(&path).is_ok() { 124 | let content = fs::read_to_string(&path)?; 125 | let mut session: Session = serde_yaml::from_str(&content)?; 126 | session.meta = meta; 127 | session 128 | } else { 129 | Err(color_eyre::eyre::eyre!("Session not found"))? 130 | }; 131 | 132 | Ok(session) 133 | } 134 | 135 | /// Saves the session to the filesystem 136 | pub fn save(&self) -> Result<()> { 137 | tracing::event!( 138 | tracing::Level::INFO, 139 | "saving session to {:?}", 140 | self.meta.path 141 | ); 142 | let parent = path::Path::new(&self.meta.path) 143 | .parent() 144 | .unwrap() 145 | .to_str() 146 | .unwrap(); 147 | 148 | if !directory_exists(parent) { 149 | fs::create_dir_all(parent)?; 150 | } 151 | 152 | fs::write(&self.meta.path, serde_yaml::to_string(&self)?)?; 153 | Ok(()) 154 | } 155 | } 156 | 157 | /// Chacks if a directory exists. 158 | pub fn directory_exists(dir_name: &str) -> bool { 159 | let p = path::Path::new(dir_name); 160 | p.exists() && p.is_dir() 161 | } 162 | -------------------------------------------------------------------------------- /crates/c/src/utils.rs: -------------------------------------------------------------------------------- 1 | use std::io::Read; 2 | use std::ops::RangeInclusive; 3 | 4 | /// Reads stdin and retusn a string with its content. 
5 | pub fn read_from_stdin() -> Result { 6 | let mut stdin = Vec::new(); 7 | tracing::event!(tracing::Level::INFO, "Reading from stdin..."); 8 | let mut lock = std::io::stdin().lock(); 9 | lock.read_to_end(&mut stdin)?; 10 | Ok(String::from_utf8_lossy(&stdin).to_string()) 11 | } 12 | 13 | /// The range of values for the `temperature` option which goes from 0 to 1. 14 | const TEMPERATURE_RANGE: RangeInclusive = 0.0..=2.0; 15 | /// The range of values for the `top_p` option which goes from 0 to 1. 16 | const TOP_P_RANGE: RangeInclusive = 0.0..=1.0; 17 | /// The range of values for the `top_k` option which goes from 0 to Infinity. 18 | const TOP_K_RANGE: RangeInclusive = 0.0..=f32::INFINITY; 19 | /// The range of values for the `repetition_penalty` option which goes from 0.0 to 8.0. 20 | const REPETITION_PENALTY_RANGE: RangeInclusive = 0.0..=8.0; 21 | /// The range of values for the `repetition_penalty_range` option which goes from 0 to 2048. 22 | const REPETITION_PENALTY_RANGE_RANGE: RangeInclusive = 0..=2048; 23 | 24 | /// Parses the temperature value. 25 | pub fn parse_temperature(s: &str) -> std::result::Result { 26 | let value = s.parse::().map_err(|_| { 27 | format!( 28 | "`{s}` must be a number between {} and {}", 29 | TEMPERATURE_RANGE.start(), 30 | TEMPERATURE_RANGE.end() 31 | ) 32 | })?; 33 | if !TEMPERATURE_RANGE.contains(&value) { 34 | return Err(format!( 35 | "`{s}` must be a number between {} and {}", 36 | TEMPERATURE_RANGE.start(), 37 | TEMPERATURE_RANGE.end() 38 | )); 39 | } 40 | Ok(value) 41 | } 42 | 43 | /// Parses the top_p value. 
44 | pub fn parse_top_p(s: &str) -> std::result::Result { 45 | let value = s.parse::().map_err(|_| { 46 | format!( 47 | "`{s}` must be a number between {} and {}", 48 | TOP_P_RANGE.start(), 49 | TOP_P_RANGE.end() 50 | ) 51 | })?; 52 | if !TOP_P_RANGE.contains(&value) { 53 | return Err(format!( 54 | "`{s}` must be a number between {} and {}", 55 | TOP_P_RANGE.start(), 56 | TOP_P_RANGE.end() 57 | )); 58 | } 59 | Ok(value) 60 | } 61 | 62 | /// Parse a single key-value pair 63 | pub fn parse_key_val( 64 | s: &str, 65 | ) -> Result<(T, U), Box> 66 | where 67 | T: std::str::FromStr, 68 | T::Err: std::error::Error + Send + Sync + 'static, 69 | U: std::str::FromStr, 70 | U::Err: std::error::Error + Send + Sync + 'static, 71 | { 72 | let pos = s 73 | .find('=') 74 | .ok_or_else(|| format!("Invalid key-value pair: {}", s))?; 75 | Ok((s[..pos].parse()?, s[pos + 1..].parse()?)) 76 | } 77 | 78 | /// Parses the top_k value. 79 | pub fn parse_top_k(s: &str) -> std::result::Result { 80 | let value = s.parse::().map_err(|_| { 81 | format!( 82 | "`{s}` must be a number between {} and {}", 83 | TOP_K_RANGE.start(), 84 | TOP_K_RANGE.end() 85 | ) 86 | })?; 87 | if !TOP_K_RANGE.contains(&value) { 88 | return Err(format!( 89 | "`{s}` must be a number between {} and {}", 90 | TOP_K_RANGE.start(), 91 | TOP_K_RANGE.end() 92 | )); 93 | } 94 | Ok(value) 95 | } 96 | 97 | /// Parses the repetition_penalty value. 98 | pub fn parse_repetition_penalty(s: &str) -> std::result::Result { 99 | let value = s.parse::().map_err(|_| { 100 | format!( 101 | "`{s}` must be a number between {} and {}", 102 | REPETITION_PENALTY_RANGE.start(), 103 | REPETITION_PENALTY_RANGE.end() 104 | ) 105 | })?; 106 | if !REPETITION_PENALTY_RANGE.contains(&value) { 107 | return Err(format!( 108 | "`{s}` must be a number between {} and {}", 109 | REPETITION_PENALTY_RANGE.start(), 110 | REPETITION_PENALTY_RANGE.end() 111 | )); 112 | } 113 | Ok(value) 114 | } 115 | 116 | /// Parses the repetition_penalty_range value. 
117 | pub fn parse_repetition_penalty_range(s: &str) -> std::result::Result { 118 | let value = s.parse::().map_err(|_| { 119 | format!( 120 | "`{s}` must be a number between {} and {}", 121 | REPETITION_PENALTY_RANGE_RANGE.start(), 122 | REPETITION_PENALTY_RANGE_RANGE.end() 123 | ) 124 | })?; 125 | if !REPETITION_PENALTY_RANGE_RANGE.contains(&value) { 126 | return Err(format!( 127 | "`{s}` must be a number between {} and {}", 128 | REPETITION_PENALTY_RANGE_RANGE.start(), 129 | REPETITION_PENALTY_RANGE_RANGE.end() 130 | )); 131 | } 132 | Ok(value) 133 | } 134 | 135 | /// Takes in a list of messages and returns two new lists, one with messages with `pin == true` or 136 | /// `role == crate::session::Role::System` and the other with messages without `pin = true` or `role == crate::session::Role::System`. 137 | pub fn split_messages( 138 | messages: &[crate::session::Message], 139 | ) -> (Vec, Vec) { 140 | let pinned: Vec = messages 141 | .iter() 142 | .filter(|m| m.pin || m.role == crate::session::Role::System) 143 | .cloned() 144 | .collect(); 145 | 146 | let unpinned: Vec = messages 147 | .iter() 148 | .filter(|m| !m.pin && m.role != crate::session::Role::System) 149 | .cloned() 150 | .collect(); 151 | 152 | (pinned, unpinned) 153 | } 154 | 155 | /// Filter the Session History message so that only the last `n` messages without `pin = true` and 156 | /// `role == crate::session::Role::System` are returned. 
157 | pub fn filter_history( 158 | messages: &[crate::session::Message], 159 | n: usize, 160 | ) -> Vec { 161 | let (mut pinned, mut unpinned) = split_messages(messages); 162 | let len = unpinned.len(); 163 | 164 | if len > n { 165 | unpinned = unpinned.drain(len - n..len).collect(); 166 | } 167 | 168 | pinned.append(&mut unpinned); 169 | 170 | pinned 171 | } 172 | -------------------------------------------------------------------------------- /crates/d/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "d" 3 | version = "0.0.0" 4 | homepage = "https://github.com/cloudbridgeuy/a" 5 | description = "A CLI tool to interact with ChatGPT as a Software Engineer" 6 | autobins = false 7 | 8 | authors.workspace = true 9 | edition.workspace = true 10 | license.workspace = true 11 | 12 | [[bin]] 13 | name = "d" 14 | path = "src/main.rs" 15 | 16 | [dependencies] 17 | clap = { version = "4.4.14", features = ["derive", "env"] } 18 | serde = { version = "1.0.195", features = ["derive"] } 19 | serde_json = "1.0.111" 20 | serde_yaml = "0.9.30" 21 | color-eyre = "0.6.2" 22 | tokio = "1.35.1" 23 | env_logger = "0.11.3" 24 | log = "0.4.20" 25 | openai = "1.0.0-alpha.13" 26 | bat = { version = "0.24.0", path = "../../lib/bat" } 27 | crossterm = "0.28.0" 28 | atty = "0.2" 29 | spinners = "4.1.1" 30 | rayon = "1.8.0" 31 | lazy_static = "1.4.0" 32 | thiserror = "1.0.56" 33 | bincode = "1.3.3" 34 | rusqlite = { version = "0.32.1", features = ["bundled"] } 35 | uuid = { version = "1.6.1", features = ["v4"] } 36 | -------------------------------------------------------------------------------- /crates/d/README.md: -------------------------------------------------------------------------------- 1 | # d 2 | -------------------------------------------------------------------------------- /crates/d/src/commands.rs: -------------------------------------------------------------------------------- 1 | pub mod chat; 2 | pub mod 
embeddings; 3 | pub mod sessions; 4 | pub mod vector; 5 | -------------------------------------------------------------------------------- /crates/d/src/commands/chat.rs: -------------------------------------------------------------------------------- 1 | use std::io::{Read, Write}; 2 | 3 | use clap::Parser; 4 | use color_eyre::eyre::{bail, Result}; 5 | use crossterm::{cursor, execute}; 6 | use openai::chat::{ChatCompletion, ChatCompletionDelta, ChatCompletionMessageRole}; 7 | use serde::{Deserialize, Serialize}; 8 | use tokio::sync::mpsc::Receiver; 9 | 10 | use crate::models::Model; 11 | use crate::sessions::Session; 12 | 13 | #[derive(Default, Clone, Parser, Debug, Serialize, Deserialize)] 14 | pub struct Options { 15 | /// The content of the message to be sent to the chatbot. You can also populate this value 16 | /// from stdin. If you pass a value here and pipe data from stdin, both will be sent to the 17 | /// API, stdin taking precedence. 18 | prompt: Option, 19 | /// ID of the model to use. 20 | #[clap(short, long, value_enum)] 21 | model: Option, 22 | /// Chat session name. Will be used to store previous session interactions. 23 | #[arg(long)] 24 | session: Option, 25 | /// DB collection where the new messages should be stored. 26 | #[arg(long)] 27 | collection: Option, 28 | /// The system message helps set the behavior of the assistant. 29 | #[arg(long)] 30 | system: Option, 31 | /// The temperature value to use for the session. 32 | #[arg(long)] 33 | temperature: Option, 34 | /// The top_p value to use for the session 35 | #[arg(long)] 36 | top_p: Option, 37 | /// The max_tokens value to use for the session 38 | #[arg(long)] 39 | max_tokens: Option, 40 | /// Don't perform the request and instead print the session to stdout. 
41 | #[arg(long)] 42 | dry_run: bool, 43 | } 44 | 45 | /// Reads from `stdin` 46 | fn read_stdin() -> Result { 47 | let mut input = String::new(); 48 | 49 | // Read the entire contents of stdin into the string 50 | match std::io::stdin().read_to_string(&mut input) { 51 | Ok(_) => Ok(input), 52 | Err(e) => { 53 | bail!("[read_stdin] Error: {e}"); 54 | } 55 | } 56 | } 57 | 58 | /// Runs the `openai` command 59 | pub async fn run(mut options: Options) -> Result<()> { 60 | let mut session = match options.session { 61 | Some(session) => Session::load(session)?, 62 | None => Session::new(), 63 | }; 64 | 65 | if let Some(collection) = options.collection.take() { 66 | session.collection = Some(collection); 67 | } 68 | 69 | if let Some(model) = options.model.take() { 70 | session.model = model; 71 | } 72 | 73 | if let Some(temperature) = options.temperature.take() { 74 | session.set_temperature(temperature)?; 75 | } 76 | 77 | if let Some(top_p) = options.top_p.take() { 78 | session.set_top_p(top_p)?; 79 | } 80 | 81 | if let Some(max_tokens) = options.max_tokens.take() { 82 | session.set_max_tokens(max_tokens)?; 83 | } 84 | 85 | if let Some(system) = options.system.take() { 86 | session.system(system); 87 | } 88 | 89 | if !atty::is(atty::Stream::Stdin) { 90 | session.push(read_stdin()?, ChatCompletionMessageRole::User) 91 | } 92 | 93 | match options.prompt { 94 | Some(prompt) => { 95 | if !prompt.is_empty() { 96 | session.push(prompt, ChatCompletionMessageRole::User) 97 | } 98 | } 99 | None => { 100 | if atty::is(atty::Stream::Stdin) { 101 | print!("User: "); 102 | 103 | std::io::stdout().flush()?; 104 | 105 | let mut user_message_content = String::new(); 106 | 107 | std::io::stdin().read_line(&mut user_message_content)?; 108 | 109 | session.push( 110 | user_message_content.to_string(), 111 | ChatCompletionMessageRole::User, 112 | ) 113 | } 114 | } 115 | } 116 | 117 | if options.dry_run { 118 | println!("{}", serde_yaml::to_string(&session)?); 119 | return Ok(()); 120 | } 
121 | 122 | let messages = session.completion_messages(); 123 | let chat_stream = ChatCompletionDelta::builder( 124 | options.model.unwrap_or(Model::GPT41106Preview).as_str(), 125 | messages, 126 | ) 127 | .temperature(session.get_temperature()) 128 | .top_p(session.get_top_p()) 129 | .max_tokens(session.get_max_tokens()) 130 | .create_stream() 131 | .await?; 132 | 133 | let chat_completion: ChatCompletion = listen_for_tokens(chat_stream).await?; 134 | let returned_message = chat_completion 135 | .choices 136 | .first() 137 | .expect("A response choice was expected") 138 | .message 139 | .clone(); 140 | 141 | session.push( 142 | returned_message.content.unwrap(), 143 | ChatCompletionMessageRole::Assistant, 144 | ); 145 | 146 | session.save().await?; 147 | 148 | Ok(()) 149 | } 150 | 151 | /// Handle streaming output 152 | async fn listen_for_tokens( 153 | mut chat_stream: Receiver, 154 | ) -> Result { 155 | let mut merged: Option = None; 156 | let mut previous_output = String::new(); 157 | let mut accumulated_content_bytes = Vec::new(); 158 | let mut sp: Option = None; 159 | 160 | if atty::is(atty::Stream::Stdout) { 161 | sp = Some(spinners::Spinner::new( 162 | spinners::Spinners::OrangeBluePulse, 163 | "Loading...".into(), 164 | )); 165 | } 166 | 167 | while let Some(delta) = chat_stream.recv().await { 168 | let choice = &delta.choices[0]; 169 | 170 | if let Some(content) = &choice.delta.content { 171 | if atty::is(atty::Stream::Stdout) && sp.is_some() { 172 | sp.take().unwrap().stop(); 173 | std::io::stdout().flush()?; 174 | } 175 | 176 | accumulated_content_bytes.extend_from_slice(content.as_bytes()); 177 | 178 | let output = crate::printer::CustomPrinter::new() 179 | .input_from_bytes(&accumulated_content_bytes) 180 | .print()?; 181 | 182 | let unprinted_lines = output 183 | .lines() 184 | .skip(if previous_output.lines().count() == 0 { 185 | 0 186 | } else { 187 | previous_output.lines().count() - 1 188 | }) 189 | .collect::>() 190 | .join("\n"); 191 | 192 | 
execute!(std::io::stdout(), cursor::MoveToColumn(0))?; 193 | print!("{unprinted_lines}"); 194 | std::io::stdout().flush()?; 195 | 196 | // Update the previous output 197 | previous_output = output; 198 | } 199 | 200 | if choice.finish_reason.is_some() { 201 | // The message being streamed has been fully received. 202 | println!(); 203 | } 204 | 205 | // Merge completion into accrued. 206 | match merged.as_mut() { 207 | Some(c) => { 208 | c.merge(delta).unwrap(); 209 | } 210 | None => merged = Some(delta), 211 | }; 212 | } 213 | 214 | std::io::stdout().flush()?; 215 | 216 | Ok(merged.unwrap().into()) 217 | } 218 | -------------------------------------------------------------------------------- /crates/d/src/commands/embeddings.rs: -------------------------------------------------------------------------------- 1 | use clap::Parser; 2 | use color_eyre::eyre::Result; 3 | use openai::embeddings::Embedding; 4 | use std::io::Write; 5 | 6 | use crate::constants::MODEL; 7 | 8 | #[derive(Default, Clone, Parser, Debug)] 9 | pub struct Options { 10 | /// Input text to get embeddings for. 
11 | input: String, 12 | } 13 | 14 | pub async fn run(options: Options) -> Result<()> { 15 | let mut sp: Option = None; 16 | 17 | if atty::is(atty::Stream::Stdout) { 18 | sp = Some(spinners::Spinner::new( 19 | spinners::Spinners::OrangeBluePulse, 20 | "Loading...".into(), 21 | )); 22 | } 23 | 24 | let embedding = Embedding::create(MODEL, &options.input, &String::default()).await?; 25 | 26 | if atty::is(atty::Stream::Stdout) && sp.is_some() { 27 | sp.take().unwrap().stop(); 28 | // Flush `stdout` 29 | std::io::stdout().flush()?; 30 | } 31 | 32 | // Concatenate all the values of embedding into a string separated by a comma 33 | let embedding = embedding 34 | .vec 35 | .iter() 36 | .map(|x| x.to_string()) 37 | .collect::>() 38 | .join(","); 39 | 40 | println!("{embedding}"); 41 | 42 | Ok(()) 43 | } 44 | -------------------------------------------------------------------------------- /crates/d/src/commands/sessions.rs: -------------------------------------------------------------------------------- 1 | use clap::{Parser, Subcommand}; 2 | use color_eyre::eyre::Result; 3 | 4 | #[derive(Debug, Subcommand)] 5 | pub enum Commands { 6 | /// List the available sessions 7 | #[clap(name = "list")] 8 | List, 9 | /// Reads a session or a session message 10 | #[clap(name = "read")] 11 | Read(ReadOptions), 12 | } 13 | 14 | #[derive(Default, Clone, Parser, Debug)] 15 | pub struct ReadOptions { 16 | /// Session name 17 | session: String, 18 | /// Session message id 19 | #[clap(short, long)] 20 | id: Option, 21 | } 22 | 23 | #[derive(Debug, Parser)] 24 | #[command(name = "sessions")] 25 | #[command(about = "Manage sessions")] 26 | pub struct Cli { 27 | #[command(subcommand)] 28 | pub command: Option, 29 | } 30 | 31 | /// Runs the `sessions` command 32 | pub async fn run(cli: Cli) -> Result<()> { 33 | match cli.command { 34 | Some(Commands::List) => list().await?, 35 | Some(Commands::Read(options)) => read(options).await?, 36 | None => { 37 | color_eyre::eyre::bail!("No subcommand 
provided. Use `d sessions help` to see the list of available subcommands.") 38 | } 39 | } 40 | 41 | Ok(()) 42 | } 43 | 44 | /// Runs the `list` command 45 | pub async fn list() -> Result<()> { 46 | let home = std::env::var("D_ROOT").unwrap_or(std::env::var("HOME")?); 47 | let path = format!("{home}/.d/sessions"); 48 | 49 | // List all the `.yaml` files inside of `path` without the extension. 50 | let sessions = std::fs::read_dir(path)? 51 | .filter_map(|entry| entry.ok()) 52 | .filter_map(|entry| entry.file_name().into_string().ok()) 53 | .filter(|entry| entry.ends_with(".yaml")) 54 | .map(|entry| entry.trim_end_matches(".yaml").to_string()) 55 | .collect::>(); 56 | 57 | println!("{}", serde_json::to_string_pretty(&sessions)?); 58 | 59 | Ok(()) 60 | } 61 | 62 | /// Runs the `read` command 63 | pub async fn read(options: ReadOptions) -> Result<()> { 64 | match options.id { 65 | Some(id) => print_message(&options.session, &id), 66 | None => print_session(&options.session), 67 | } 68 | } 69 | 70 | fn print_session(session: &str) -> Result<()> { 71 | println!( 72 | "{}", 73 | serde_json::to_string_pretty(&crate::sessions::Session::load(String::from(session))?)? 
74 | ); 75 | 76 | Ok(()) 77 | } 78 | 79 | fn print_message(name: &str, id: &str) -> Result<()> { 80 | let session = &crate::sessions::Session::load(String::from(name))?; 81 | 82 | // Find the message inside `session.messages` who's id is `id` 83 | let messages = session.messages(); 84 | let message = messages 85 | .iter() 86 | .find(|message| message.id == id) 87 | .ok_or_else(|| { 88 | color_eyre::eyre::eyre!("Message with id {} not found in session {}", id, name) 89 | })?; 90 | 91 | println!("{}", serde_json::to_string_pretty(&message)?); 92 | 93 | Ok(()) 94 | } 95 | -------------------------------------------------------------------------------- /crates/d/src/commands/vector.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | use std::ops::RangeInclusive; 3 | 4 | use clap::{Parser, Subcommand}; 5 | use color_eyre::eyre::eyre; 6 | use color_eyre::eyre::{bail, Result}; 7 | use uuid::Uuid; 8 | 9 | use crate::constants::DIMENSION; 10 | use crate::constants::DISTANCE; 11 | 12 | #[derive(Debug, Subcommand)] 13 | pub enum Commands { 14 | /// Creates a new vector db collection 15 | #[clap(name = "create")] 16 | Create(CreateOptions), 17 | /// Deletes a vector db collection 18 | #[clap(name = "delete")] 19 | Delete(GetOptions), 20 | /// Gets the list of available collections 21 | #[clap(name = "list")] 22 | List, 23 | /// Gets a vector db collection 24 | #[clap(name = "get")] 25 | Get(GetOptions), 26 | /// Queries a vector db collection 27 | #[clap(name = "query")] 28 | Query(QueryOptions), 29 | /// Inserts an embedding into a collection 30 | #[clap(name = "insert")] 31 | Insert(InsertOptions), 32 | } 33 | 34 | #[derive(Default, Clone, Parser, Debug)] 35 | pub struct CreateOptions { 36 | /// Name of the collection 37 | name: String, 38 | /// Collection dimension 39 | #[arg(long, default_value = "1536")] 40 | dimension: Option, 41 | /// Collection distance function to use 42 | #[clap(long, value_enum, 
default_value = "cosine")] 43 | distance: Option, 44 | } 45 | 46 | #[derive(Default, Clone, Parser, Debug)] 47 | pub struct GetOptions { 48 | /// Name of the collection 49 | name: String, 50 | } 51 | 52 | /// The range of values for the `score` option which goes from 0 to 1. 53 | const SCORE_RANGE: RangeInclusive = 0.0..=1.0; 54 | 55 | /// Validates an input that's between 0 and 1 56 | fn parse_score(s: &str) -> std::result::Result { 57 | let value = s 58 | .parse::() 59 | .map_err(|_| format!("`{s}` must be a number between {} and {}", 0.0, 1.0,))?; 60 | if !SCORE_RANGE.contains(&value) { 61 | return Err(format!( 62 | "`{s}` must be a number between {} and {}", 63 | 0.0, 1.0, 64 | )); 65 | } 66 | 67 | Ok(value) 68 | } 69 | 70 | #[derive(Default, Clone, Parser, Debug)] 71 | pub struct QueryOptions { 72 | /// Name of the collection 73 | name: String, 74 | /// The query string 75 | #[arg(value_delimiter = ',', allow_hyphen_values = true)] 76 | query: Vec, 77 | /// The number of results to return 78 | #[arg(long, default_value = "1")] 79 | k: Option, 80 | /// Filter values that score less than this value. 
81 | #[clap(long, value_parser = parse_score)] 82 | score: Option, 83 | } 84 | 85 | /// Parse a single key-value pair 86 | fn parse_key_val( 87 | s: &str, 88 | ) -> Result<(T, U), Box> 89 | where 90 | T: std::str::FromStr, 91 | T::Err: std::error::Error + Send + Sync + 'static, 92 | U: std::str::FromStr, 93 | U::Err: std::error::Error + Send + Sync + 'static, 94 | { 95 | let pos = s 96 | .find('=') 97 | .ok_or_else(|| format!("invalid KEY=value: no `=` found in `{s}`"))?; 98 | Ok((s[..pos].parse()?, s[pos + 1..].parse()?)) 99 | } 100 | 101 | #[derive(Default, Clone, Parser, Debug)] 102 | pub struct InsertOptions { 103 | /// Name of the collection 104 | name: String, 105 | /// Embedding Vector 106 | #[arg(value_delimiter = ',', allow_hyphen_values = true)] 107 | vector: Vec, 108 | /// Embedding Id 109 | #[arg(long)] 110 | id: Option, 111 | /// Embedding Metadata 112 | #[arg(long, value_parser = parse_key_val::, value_delimiter = ',')] 113 | metadata: Option>, 114 | } 115 | 116 | #[derive(Debug, Parser)] 117 | #[command(name = "vector")] 118 | #[command(about = "Manage vector collections")] 119 | pub struct Cli { 120 | #[command(subcommand)] 121 | pub command: Option, 122 | } 123 | 124 | /// Runs the `vector` subcommand 125 | pub async fn run(cli: Cli) -> Result<()> { 126 | match cli.command { 127 | Some(Commands::Create(options)) => create(options).await?, 128 | Some(Commands::Delete(options)) => delete(options).await?, 129 | Some(Commands::Get(options)) => get(options).await?, 130 | Some(Commands::Query(options)) => query(options).await?, 131 | Some(Commands::Insert(options)) => insert(options).await?, 132 | Some(Commands::List) => list().await?, 133 | None => { 134 | bail!("No subcommand provided. 
Use `d help` to see the list of available subcommands.") 135 | } 136 | } 137 | 138 | Ok(()) 139 | } 140 | 141 | pub async fn insert(options: InsertOptions) -> Result<()> { 142 | let mut db = crate::vector::from_store()?; 143 | let mut metadata: HashMap = HashMap::new(); 144 | 145 | if let Some(tuples) = options.metadata { 146 | for (key, value) in tuples { 147 | metadata.insert(key, value); 148 | } 149 | } 150 | 151 | let embedding = crate::vector::Embedding { 152 | id: options.id.unwrap_or_else(|| Uuid::new_v4().to_string()), 153 | vector: options.vector, 154 | metadata: Some(metadata), 155 | }; 156 | 157 | db.insert_into_collection(&options.name, embedding)?; 158 | 159 | println!("Inserted into {} collection", options.name); 160 | 161 | Ok(()) 162 | } 163 | 164 | pub async fn query(options: QueryOptions) -> Result<()> { 165 | let db = crate::vector::from_store()?; 166 | 167 | match db.get_collection(&options.name) { 168 | Some(collection) => { 169 | if collection.dimension != options.query.len() { 170 | return Err(eyre!( 171 | "Collection {} does not have the same dimension of the query", 172 | options.name 173 | )); 174 | } 175 | 176 | let instant = std::time::Instant::now(); 177 | let results = collection.get_similarity(&options.query, options.k.unwrap_or(1)); 178 | 179 | // Filter those results whose value is less than `options.score` it its `Some` 180 | let results = match options.score { 181 | Some(score) => results.into_iter().filter(|x| x.score >= score).collect(), 182 | None => results, 183 | }; 184 | 185 | log::info!("Query to {} took {:?}", &options.name, instant.elapsed()); 186 | 187 | println!("{}", serde_json::to_string_pretty(&results)?); 188 | 189 | Ok(()) 190 | } 191 | None => Err(eyre!("Collection {} does not exist", options.name)), 192 | } 193 | } 194 | 195 | pub async fn list() -> Result<()> { 196 | let db = crate::vector::from_store()?; 197 | 198 | println!("{:#?}", db.list_collections()); 199 | 200 | Ok(()) 201 | } 202 | 203 | pub async fn 
get(options: GetOptions) -> Result<()> { 204 | let db = crate::vector::from_store()?; 205 | 206 | let collection = db.get_collection(&options.name).unwrap(); 207 | 208 | println!("{}", serde_json::to_string_pretty(&collection)?); 209 | 210 | Ok(()) 211 | } 212 | 213 | pub async fn delete(options: GetOptions) -> Result<()> { 214 | let mut db = crate::vector::from_store()?; 215 | 216 | db.delete_collection(&options.name)?; 217 | 218 | println!("Deleted {} collection", options.name); 219 | 220 | Ok(()) 221 | } 222 | 223 | pub async fn create(options: CreateOptions) -> Result<()> { 224 | let mut db = crate::vector::from_store()?; 225 | 226 | let collection = db.create_collection( 227 | options.name, 228 | options.dimension.unwrap_or(DIMENSION), 229 | options.distance.unwrap_or(DISTANCE), 230 | )?; 231 | 232 | println!("{:#?}", collection); 233 | 234 | Ok(()) 235 | } 236 | -------------------------------------------------------------------------------- /crates/d/src/constants.rs: -------------------------------------------------------------------------------- 1 | pub const DIMENSION: usize = 1536; 2 | pub const DISTANCE: crate::similarity::Distance = crate::similarity::Distance::Cosine; 3 | pub const MODEL: &str = "text-embedding-ada-002"; 4 | -------------------------------------------------------------------------------- /crates/d/src/main.rs: -------------------------------------------------------------------------------- 1 | use clap::{Parser, Subcommand}; 2 | use color_eyre::eyre::{bail, eyre}; 3 | 4 | mod commands; 5 | mod constants; 6 | mod models; 7 | mod printer; 8 | mod sessions; 9 | mod shutdown; 10 | mod similarity; 11 | mod vector; 12 | 13 | #[derive(Debug, Subcommand)] 14 | pub enum Commands { 15 | /// OpenAI Chat AI API 16 | #[clap(name = "chat", alias = "c")] 17 | Chat(commands::chat::Options), 18 | /// OpenAI Embedding commands 19 | #[clap(name = "embeddings", alias = "e")] 20 | Embeddings(commands::embeddings::Options), 21 | /// Vector commands 22 | 
#[clap(name = "vector", alias = "v")] 23 | Vector(commands::vector::Cli), 24 | /// Read commands 25 | #[clap(name = "sessions", alias = "s")] 26 | Sessions(commands::sessions::Cli), 27 | } 28 | 29 | #[derive(Debug, Parser)] 30 | #[command(name = "d")] 31 | #[command(about = "Interact with LLMs through the terminal")] 32 | pub struct Cli { 33 | #[command(subcommand)] 34 | pub command: Option, 35 | } 36 | 37 | #[tokio::main] 38 | async fn main() -> color_eyre::eyre::Result<()> { 39 | color_eyre::install()?; 40 | env_logger::init(); 41 | 42 | // Load the OpenAI API Key from the OPENAI_API_KEY environment variable. 43 | openai::set_key(std::env::var("OPENAI_API_KEY")?); 44 | 45 | // Create the shutdown handler 46 | let shutdown = shutdown::Shutdown::new()?; 47 | 48 | // Run app in separate async task 49 | tokio::spawn(async { 50 | if let Err(e) = run().await { 51 | bail!("Application error: {}", e) 52 | } 53 | 54 | Ok(()) 55 | }); 56 | 57 | shutdown.handle().await; 58 | 59 | Ok(()) 60 | } 61 | 62 | async fn run() -> color_eyre::eyre::Result<()> { 63 | let result = match Cli::parse().command { 64 | Some(Commands::Chat(options)) => commands::chat::run(options).await, 65 | Some(Commands::Embeddings(options)) => commands::embeddings::run(options).await, 66 | Some(Commands::Vector(cli)) => commands::vector::run(cli).await, 67 | Some(Commands::Sessions(cli)) => commands::sessions::run(cli).await, 68 | None => Err(eyre!( 69 | "No subcommand provided. Use --help to see available subcommands." 
70 | )), 71 | }; 72 | 73 | match result { 74 | Ok(_) => std::process::exit(0), 75 | Err(e) => { 76 | eprintln!("{e}"); 77 | std::process::exit(1); 78 | } 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /crates/d/src/models.rs: -------------------------------------------------------------------------------- 1 | use clap::ValueEnum; 2 | use serde::{Deserialize, Serialize}; 3 | 4 | #[derive(ValueEnum, Debug, Default, Clone, Copy, Serialize, Deserialize)] 5 | #[clap(rename_all = "kebab-case")] 6 | #[serde(rename_all = "kebab-case")] 7 | pub enum Model { 8 | #[serde(rename = "gpt-4-1106-preview")] 9 | GPT41106Preview, 10 | #[serde(rename = "gpt-4o-mini")] 11 | GPT4OMini, 12 | #[default] 13 | #[serde(rename = "gpt-4o")] 14 | GPT4O, 15 | #[serde(rename = "gpt-4")] 16 | GPT4, 17 | #[serde(rename = "gpt-4-32k")] 18 | GPT432K, 19 | #[serde(rename = "gpt-3.5-turbo")] 20 | GPT35Turbo, 21 | #[serde(rename = "gpt-3.5-turbo-16k")] 22 | GPT35Turbo16K, 23 | #[serde(rename = "gpt-3.5-turbo-1106")] 24 | GPT35Turbo1106, 25 | } 26 | 27 | impl Model { 28 | pub fn as_str(&self) -> &'static str { 29 | match self { 30 | Model::GPT41106Preview => "gpt-4-1106-preview", 31 | Model::GPT4OMini => "gpt-4o-mini", 32 | Model::GPT4O => "gpt-4o", 33 | Model::GPT4 => "gpt-4", 34 | Model::GPT432K => "gpt-4-32k", 35 | Model::GPT35Turbo => "gpt-3.5-turbo", 36 | Model::GPT35Turbo16K => "gpt-3.5-turbo-16k", 37 | Model::GPT35Turbo1106 => "gpt-3.5-turbo-1106", 38 | } 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /crates/d/src/printer.rs: -------------------------------------------------------------------------------- 1 | use color_eyre::eyre::Result; 2 | use crossterm::terminal; 3 | 4 | // Markdown language constant string 5 | const THEME: &str = "tokyonight-storm"; 6 | const LANGUAGE: &str = "markdown"; 7 | 8 | pub struct CustomPrinter<'a> { 9 | inputs: Vec>, 10 | config: bat::config::Config<'a>, 11 | 
assets: bat::assets::HighlightingAssets, 12 | term_width: Option, 13 | } 14 | 15 | impl<'a> CustomPrinter<'a> { 16 | pub fn new() -> Self { 17 | let config = bat::config::Config { 18 | colored_output: true, 19 | true_color: true, 20 | language: Some(LANGUAGE), 21 | theme: THEME.to_string(), 22 | use_italic_text: true, 23 | wrapping_mode: bat::WrappingMode::Character, 24 | ..Default::default() 25 | }; 26 | 27 | CustomPrinter { 28 | inputs: vec![], 29 | config, 30 | assets: bat::assets::HighlightingAssets::from_binary(), 31 | term_width: None, 32 | } 33 | } 34 | 35 | /// Add a byte string as an input 36 | pub fn input_from_bytes(&mut self, content: &'a [u8]) -> &mut Self { 37 | self.input_from_reader(content) 38 | } 39 | 40 | /// Add a custom reader as an input 41 | pub fn input_from_reader(&mut self, reader: R) -> &mut Self { 42 | self.inputs 43 | .push(bat::input::Input::from_reader(Box::new(reader))); 44 | self 45 | } 46 | 47 | /// Custom print function that takes advantage of the fact that `bat` controllers can take a 48 | /// String as the output of the highlighted text. 
49 | pub fn print(&mut self) -> Result { 50 | self.config.term_width = self 51 | .term_width 52 | .unwrap_or_else(|| terminal::size().unwrap().0 as usize); 53 | let inputs = std::mem::take(&mut self.inputs); 54 | 55 | let mut output = String::new(); 56 | 57 | let controller = bat::controller::Controller::new(&self.config, &self.assets); 58 | controller.run(inputs, Some(&mut output))?; 59 | 60 | Ok(output) 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /crates/d/src/shutdown.rs: -------------------------------------------------------------------------------- 1 | use color_eyre::eyre::Result; 2 | use crossterm::cursor::Show; 3 | use crossterm::execute; 4 | use std::io::stdout; 5 | use std::{ 6 | error::Error, 7 | fmt, 8 | fmt::Display, 9 | future::Future, 10 | sync::atomic::{AtomicBool, Ordering}, 11 | }; 12 | use tokio::{signal, sync::broadcast}; 13 | 14 | #[derive(Debug, PartialEq, Eq)] 15 | pub struct AlreadyCreatedError; 16 | 17 | impl Error for AlreadyCreatedError {} 18 | 19 | impl Display for AlreadyCreatedError { 20 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 21 | f.write_str("shutdown handler already created") 22 | } 23 | } 24 | 25 | static CREATED: AtomicBool = AtomicBool::new(false); 26 | 27 | #[derive(Debug, Clone)] 28 | pub struct Shutdown { 29 | pub sender: broadcast::Sender<()>, 30 | } 31 | 32 | impl Shutdown { 33 | pub fn new() -> Result { 34 | if (CREATED).swap(true, Ordering::SeqCst) { 35 | log::error!("shutdown handler called twice"); 36 | return Err(AlreadyCreatedError); 37 | } 38 | 39 | let (tx, _) = broadcast::channel(1); 40 | let handle = register_handlers(); 41 | 42 | let tx_for_handle = tx.clone(); 43 | tokio::spawn(async move { 44 | log::debug!("Registered shutdown handlers"); 45 | handle.await; 46 | tx_for_handle.send(()).ok(); 47 | }); 48 | 49 | Ok(Self { sender: tx }) 50 | } 51 | 52 | pub fn handle(&self) -> impl Future + '_ { 53 | let mut rx = self.sender.subscribe(); 
54 | 55 | async move { 56 | let rx = rx.recv(); 57 | 58 | rx.await.unwrap(); 59 | } 60 | } 61 | } 62 | 63 | fn register_handlers() -> impl Future { 64 | let ctrl_c = async { 65 | signal::ctrl_c() 66 | .await 67 | .expect("failed to install Ctrl+C handler"); 68 | }; 69 | 70 | #[cfg(unix)] 71 | let terminate = async { 72 | signal::unix::signal(signal::unix::SignalKind::terminate()) 73 | .expect("failed to install signal handler") 74 | .recv() 75 | .await; 76 | }; 77 | 78 | async move { 79 | tokio::select! { 80 | _ = ctrl_c => { 81 | log::info!("Received Ctrl+C signal"); 82 | if let Err(e) = execute!(stdout(), Show) { 83 | eprintln!("Failed to restore cursor: {e}"); 84 | } 85 | }, 86 | _ = terminate => { 87 | log::info!("Received terminate signal"); 88 | if let Err(e) = execute!(stdout(), Show) { 89 | eprintln!("Failed to restore cursor: {e}"); 90 | } 91 | }, 92 | } 93 | } 94 | } 95 | -------------------------------------------------------------------------------- /crates/d/src/similarity.rs: -------------------------------------------------------------------------------- 1 | use clap::ValueEnum; 2 | use serde::{Deserialize, Serialize}; 3 | use std::cmp::Ordering; 4 | 5 | #[derive(ValueEnum, Default, Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] 6 | pub enum Distance { 7 | #[serde(rename = "euclidean")] 8 | Euclidean, 9 | #[default] 10 | #[serde(rename = "cosine")] 11 | Cosine, 12 | #[serde(rename = "dot")] 13 | DotProduct, 14 | } 15 | 16 | pub fn get_cache_attr(metric: Distance, vec: &[f32]) -> f32 { 17 | match metric { 18 | // Dot product doesn't allow any caching 19 | Distance::DotProduct | Distance::Euclidean => 0.0, 20 | // Precompute the magnitude of the vector 21 | Distance::Cosine => vec.iter().map(|&x| x.powi(2)).sum::().sqrt(), 22 | } 23 | } 24 | 25 | pub fn get_distance_fn(metric: Distance) -> impl Fn(&[f32], &[f32], f32) -> f32 { 26 | match metric { 27 | Distance::Euclidean => euclidian_distance, 28 | // We use dot product for cosine because 
/// Euclidean distance via the expansion sqrt(max(0, |a|² + |b|² − 2·a·b)).
///
/// `a_sum_squares` is the precomputed sum of squares of `a` (see
/// `get_cache_attr`); the `.max(0.0)` guards against tiny negative values
/// produced by floating-point rounding before taking the square root.
fn euclidian_distance(a: &[f32], b: &[f32], a_sum_squares: f32) -> f32 {
    let mut cross_terms = 0.0;
    let mut b_sum_squares = 0.0;

    for (i, j) in a.iter().zip(b) {
        cross_terms += i * j;
        b_sum_squares += j.powi(2);
    }

    2.0f32
        .mul_add(-cross_terms, a_sum_squares + b_sum_squares)
        .max(0.0)
        .sqrt()
}

/// Plain dot product. The third parameter exists only to match the common
/// distance-function signature returned by `get_distance_fn` and is ignored.
fn dot_product(a: &[f32], b: &[f32], _: f32) -> f32 {
    a.iter().zip(b).fold(0.0, |acc, (x, y)| acc + x * y)
}

/// Scale `vec` to unit length.
///
/// Vectors with (near-)zero magnitude are returned unchanged to avoid a
/// division by zero.
pub fn normalize(vec: &[f32]) -> Vec<f32> {
    let magnitude = (vec.iter().fold(0.0, |acc, &val| val.mul_add(val, acc))).sqrt();

    // `f32::EPSILON` replaces the soft-deprecated `std::f32::EPSILON` path.
    if magnitude > f32::EPSILON {
        vec.iter().map(|&val| val / magnitude).collect()
    } else {
        vec.to_vec()
    }
}

/// A similarity score paired with the index of the embedding that produced it.
///
/// Ordering is reversed on purpose (a higher score compares as "smaller") so
/// that a `BinaryHeap<ScoreIndex>` behaves as a min-heap over scores and
/// `into_sorted_vec` yields scores in descending order.
pub struct ScoreIndex {
    pub score: f32,
    pub index: usize,
}

impl PartialEq for ScoreIndex {
    fn eq(&self, other: &Self) -> bool {
        self.score.eq(&other.score)
    }
}

impl Eq for ScoreIndex {}

impl Ord for ScoreIndex {
    fn cmp(&self, other: &Self) -> Ordering {
        // NaN scores compare as equal rather than panicking.
        self.partial_cmp(other).unwrap_or(Ordering::Equal)
    }
}

impl PartialOrd for ScoreIndex {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        // The comparison is intentionally reversed here to make the heap a min-heap
        other.score.partial_cmp(&self.score)
    }
}
normalize, Distance, ScoreIndex}; 11 | 12 | lazy_static! { 13 | pub static ref STORE_PATH: PathBuf = PathBuf::from(std::env::var("D_DB_PATH").unwrap_or( 14 | format!("{}/.d.db", std::env::var("HOME").unwrap_or(".".to_string())) 15 | )); 16 | } 17 | 18 | #[allow(dead_code)] 19 | #[derive(Debug, thiserror::Error)] 20 | pub enum Error { 21 | #[error("Collection already exists")] 22 | UniqueViolation, 23 | 24 | #[error("Collection doesn't exist")] 25 | NotFound, 26 | 27 | #[error("The dimension of the vector doesn't match the dimension of the collection")] 28 | DimensionMismatch, 29 | } 30 | 31 | #[derive(Debug, serde::Serialize, serde::Deserialize)] 32 | pub struct Db { 33 | pub collections: HashMap, 34 | } 35 | 36 | #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] 37 | pub struct SimilarityResult { 38 | pub score: f32, 39 | pub embedding: Embedding, 40 | } 41 | 42 | #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] 43 | pub struct Collection { 44 | /// Dimension of the vectors in the collection 45 | pub dimension: usize, 46 | /// Distance metric used for querying 47 | pub distance: Distance, 48 | /// Embeddings in the collection 49 | #[serde(default)] 50 | pub embeddings: Vec, 51 | } 52 | 53 | impl Collection { 54 | pub fn get_similarity(&self, query: &[f32], k: usize) -> Vec { 55 | let memo_attr = get_cache_attr(self.distance, query); 56 | let distance_fn = get_distance_fn(self.distance); 57 | 58 | let scores = self 59 | .embeddings 60 | .par_iter() 61 | .enumerate() 62 | .map(|(index, embedding)| { 63 | let score = distance_fn(&embedding.vector, query, memo_attr); 64 | ScoreIndex { score, index } 65 | }) 66 | .collect::>(); 67 | 68 | let mut heap = BinaryHeap::new(); 69 | for score_index in scores { 70 | if heap.len() < k || score_index < *heap.peek().unwrap() { 71 | heap.push(score_index); 72 | 73 | if heap.len() > k { 74 | heap.pop(); 75 | } 76 | } 77 | } 78 | 79 | heap.into_sorted_vec() 80 | .into_iter() 81 | .map(|ScoreIndex { score, 
index }| SimilarityResult { 82 | score, 83 | embedding: self.embeddings[index].clone(), 84 | }) 85 | .collect() 86 | } 87 | } 88 | 89 | #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] 90 | pub struct Embedding { 91 | pub id: String, 92 | pub vector: Vec, 93 | pub metadata: Option>, 94 | } 95 | 96 | impl Db { 97 | pub fn new() -> Self { 98 | Self { 99 | collections: HashMap::new(), 100 | } 101 | } 102 | 103 | pub fn create_collection( 104 | &mut self, 105 | name: String, 106 | dimension: usize, 107 | distance: Distance, 108 | ) -> Result { 109 | if self.collections.contains_key(&name) { 110 | return Err(eyre!("Collection {} already exists", &name)); 111 | } 112 | 113 | let collection = Collection { 114 | dimension, 115 | distance, 116 | embeddings: Vec::new(), 117 | }; 118 | 119 | log::debug!("Creating collection {}: {:?}", &name, collection); 120 | self.collections.insert(name.clone(), collection.clone()); 121 | 122 | log::debug!("Created collection {}", &name); 123 | Ok(collection) 124 | } 125 | 126 | pub fn delete_collection(&mut self, name: &str) -> Result<(), Error> { 127 | if !self.collections.contains_key(name) { 128 | return Err(Error::NotFound); 129 | } 130 | 131 | self.collections.remove(name); 132 | 133 | Ok(()) 134 | } 135 | 136 | pub fn insert_into_collection( 137 | &mut self, 138 | collection_name: &str, 139 | mut embedding: Embedding, 140 | ) -> Result<(), Error> { 141 | let collection = self 142 | .collections 143 | .get_mut(collection_name) 144 | .ok_or(Error::NotFound)?; 145 | 146 | if collection.embeddings.iter().any(|e| e.id == embedding.id) { 147 | return Err(Error::UniqueViolation); 148 | } 149 | 150 | if embedding.vector.len() != collection.dimension { 151 | return Err(Error::DimensionMismatch); 152 | } 153 | 154 | // Normalize the vector if the distance metric is cosine, so we can use dot product later 155 | if collection.distance == Distance::Cosine { 156 | embedding.vector = normalize(&embedding.vector); 157 | } 158 | 159 | 
collection.embeddings.push(embedding); 160 | 161 | Ok(()) 162 | } 163 | 164 | pub fn list_collections(&self) -> Vec { 165 | // Get the keys of a HasMap 166 | self.collections.keys().cloned().collect() 167 | } 168 | 169 | pub fn get_collection(&self, name: &str) -> Option<&Collection> { 170 | self.collections.get(name) 171 | } 172 | 173 | fn load_from_store() -> color_eyre::eyre::Result { 174 | if !STORE_PATH.exists() { 175 | log::debug!("Creating database store"); 176 | fs::create_dir_all(STORE_PATH.parent().context("Invalid store path")?)?; 177 | 178 | return Ok(Self::new()); 179 | } 180 | 181 | log::debug!("Loading database from store"); 182 | let db = fs::read(STORE_PATH.as_path())?; 183 | Ok(bincode::deserialize(&db[..])?) 184 | } 185 | 186 | fn save_to_store(&self) -> color_eyre::eyre::Result<()> { 187 | let db = bincode::serialize(self)?; 188 | 189 | fs::write(STORE_PATH.as_path(), db)?; 190 | 191 | Ok(()) 192 | } 193 | } 194 | 195 | impl Drop for Db { 196 | fn drop(&mut self) { 197 | log::debug!("Saving database to store"); 198 | self.save_to_store().ok(); 199 | } 200 | } 201 | 202 | pub fn from_store() -> color_eyre::eyre::Result { 203 | Db::load_from_store() 204 | } 205 | -------------------------------------------------------------------------------- /crates/e/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "e" 3 | version = "0.0.0" 4 | homepage = "https://github.com/cloudbridgeuy/a" 5 | description = "Interact with multiple LLMs from the terminal" 6 | autobins = false 7 | 8 | authors.workspace = true 9 | edition.workspace = true 10 | license.workspace = true 11 | 12 | [[bin]] 13 | name = "e" 14 | path = "src/main.rs" 15 | 16 | [dependencies] 17 | es_stream = { version = "0.1.0", path = "../../lib/es_stream" } 18 | clap = { version = "4.5.16", features = ["derive", "string", "env"] } 19 | serde = { version = "1.0.209", features = ["derive"] } 20 | serde_json = "1.0.127" 21 | serde_yaml 
= "0.9.34" 22 | env_logger = "0.11.5" 23 | log = "0.4.22" 24 | bat = { version = "0.24.0", path = "../../lib/bat", features = ["os_str_bytes"] } 25 | thiserror = "1.0.56" 26 | tokio = { version = "1.40.0", features = ["macros", "rt-multi-thread"] } 27 | clap-stdin = "0.5.1" 28 | futures = "0.3.30" 29 | spinners = "4.1.1" 30 | atty = "0.2.14" 31 | crossterm = "0.28.0" 32 | config-file = "0.2.3" 33 | tera = "1.20.0" 34 | -------------------------------------------------------------------------------- /crates/e/src/anthropic.rs: -------------------------------------------------------------------------------- 1 | use es_stream::anthropic; 2 | 3 | use crate::prelude::*; 4 | 5 | const DEFAULT_URL: &str = "https://api.anthropic.com/v1"; 6 | const DEFAULT_MODEL: &str = "claude-3-5-sonnet-20240620"; 7 | const DEFAULT_ENV: &str = "ANTHROPIC_API_KEY"; 8 | 9 | pub async fn run(prompt: String, args: Args) -> Result<()> { 10 | let key = match args.globals.api_key { 11 | Some(key) => key, 12 | None => { 13 | let environment_variable = match args.globals.api_env { 14 | Some(env) => env, 15 | None => DEFAULT_ENV.to_string(), 16 | }; 17 | std::env::var(environment_variable)? 
18 | } 19 | }; 20 | log::info!("key: {}", key); 21 | 22 | let url = match args.globals.api_base_url { 23 | Some(url) => url, 24 | None => DEFAULT_URL.to_string(), 25 | }; 26 | log::info!("url: {}", url); 27 | 28 | let auth = anthropic::Auth::new(key, args.globals.api_version); 29 | 30 | log::info!("auth: {:#?}", auth); 31 | 32 | let client = anthropic::Client::new(auth, url); 33 | 34 | log::info!("client: {:#?}", client); 35 | 36 | let messages = vec![anthropic::Message { 37 | role: anthropic::Role::User, 38 | content: prompt, 39 | }]; 40 | 41 | let mut body = anthropic::MessageBody::new( 42 | args.globals 43 | .model 44 | .unwrap_or(DEFAULT_MODEL.to_string()) 45 | .as_ref(), 46 | messages, 47 | args.globals.max_tokens.unwrap_or(4096), 48 | ); 49 | 50 | body.system = args.globals.system; 51 | body.temperature = args.globals.temperature; 52 | body.top_p = args.globals.top_p; 53 | body.top_k = args.globals.top_k; 54 | 55 | log::info!("body: {:#?}", body); 56 | 57 | let stream = client.delta(&body)?; 58 | 59 | handle_stream( 60 | stream, 61 | args.globals.quiet.unwrap_or(false), 62 | args.globals.language, 63 | ) 64 | .await 65 | } 66 | -------------------------------------------------------------------------------- /crates/e/src/args.rs: -------------------------------------------------------------------------------- 1 | use clap::{Parser, ValueEnum}; 2 | use clap_stdin::MaybeStdin; 3 | use serde::{Deserialize, Serialize}; 4 | use serde_json::Value; 5 | use std::str::FromStr; 6 | 7 | use crate::prelude::*; 8 | 9 | #[derive(Debug, clap::Args)] 10 | pub struct Globals { 11 | /// Hidden prompt to support prompting from stdin and as an argument 12 | #[clap(default_value = "-", hide = true)] 13 | pub stdin: MaybeStdin, 14 | 15 | /// The user message prompt 16 | #[clap(default_value = "", hide = true)] 17 | pub prompt: MaybeStdin, 18 | 19 | /// The API provider to use. 
20 | #[clap(short, long, value_enum)] 21 | pub api: Option, 22 | 23 | /// The LLM Model to use 24 | #[clap(short, long)] 25 | pub model: Option, 26 | 27 | /// The maximum amount of tokens to return. 28 | #[clap(long)] 29 | pub max_tokens: Option, 30 | 31 | /// The minimum amount of tokens to return. 32 | #[clap(long)] 33 | pub min_tokens: Option, 34 | 35 | /// The environment variable to use to get the access token for the api. 36 | #[clap(long)] 37 | pub api_env: Option, 38 | 39 | /// The api version to use. 40 | #[clap(long)] 41 | pub api_version: Option, 42 | 43 | /// The api key to use (will override the value of the environment variable.) 44 | #[clap(long)] 45 | pub api_key: Option, 46 | 47 | /// The api base url. 48 | #[clap(long)] 49 | pub api_base_url: Option, 50 | 51 | /// Don't run the spinner 52 | #[clap(long)] 53 | pub quiet: Option, 54 | 55 | /// Add a system message to the request. 56 | #[clap(long)] 57 | pub system: Option, 58 | 59 | /// Temperature value. 60 | #[clap(long)] 61 | pub temperature: Option, 62 | 63 | /// Top-P value. 64 | #[clap(long)] 65 | pub top_p: Option, 66 | 67 | /// Top-K value. 68 | #[clap(long)] 69 | pub top_k: Option, 70 | 71 | /// Config file 72 | #[clap(long, default_value = "~/.config/e.toml")] 73 | pub config_file: String, 74 | 75 | /// Preset configuration 76 | #[clap(short, long)] 77 | pub preset: Option, 78 | 79 | /// Additional variables in JSON format 80 | #[clap(long, default_value="{}", value_parser = parse_json)] 81 | pub vars: Option, 82 | 83 | /// Suffix prompt 84 | #[clap(long)] 85 | pub suffix: Option, 86 | 87 | /// Language to use for syntax highlight 88 | #[clap(long, default_value = "markdown")] 89 | pub language: String, 90 | 91 | /// Prompt template to use 92 | #[clap(short, long)] 93 | pub template: Option, 94 | 95 | /// Prints the rendered template instead of calling the LLM. 
96 | #[clap(long, default_value = "false")] 97 | pub print_template: bool, 98 | } 99 | 100 | /// Custom parser function for JSON values 101 | fn parse_json(s: &str) -> std::result::Result { 102 | serde_json::from_str(s) 103 | } 104 | 105 | #[derive(ValueEnum, Debug, Default, Clone, Copy, Serialize, Deserialize)] 106 | #[serde(rename_all = "lowercase")] 107 | pub enum Api { 108 | OpenAi, 109 | #[default] 110 | Anthropic, 111 | Google, 112 | Mistral, 113 | MistralFim, 114 | } 115 | 116 | // From string to API enum 117 | impl FromStr for Api { 118 | type Err = Error; 119 | 120 | fn from_str(s: &str) -> std::result::Result { 121 | match s { 122 | "OpenAi" => Ok(Api::OpenAi), 123 | "openai" => Ok(Api::OpenAi), 124 | "Anthropic" => Ok(Api::Anthropic), 125 | "anthropic" => Ok(Api::Anthropic), 126 | "google" => Ok(Api::Google), 127 | "Google" => Ok(Api::Google), 128 | "gemini" => Ok(Api::Google), 129 | "Gemini" => Ok(Api::Google), 130 | "mistral" => Ok(Api::Mistral), 131 | "Mistral" => Ok(Api::Mistral), 132 | "mistral-fim" => Ok(Api::MistralFim), 133 | "mistral_fim" => Ok(Api::MistralFim), 134 | "Mistral-FIM" => Ok(Api::MistralFim), 135 | "Mistral-Fim" => Ok(Api::MistralFim), 136 | "MistralFim" => Ok(Api::MistralFim), 137 | "Mistral_FIM" => Ok(Api::MistralFim), 138 | "Mistral_Fim" => Ok(Api::MistralFim), 139 | "MistralFIM" => Ok(Api::MistralFim), 140 | _ => Err(Error::InvalidAPI), 141 | } 142 | } 143 | } 144 | 145 | #[derive(Debug, Parser)] 146 | #[command(name = "e", version = "0.1.0")] 147 | #[command(about = "Interact with LLMs through the terminal")] 148 | #[command( 149 | long_about = "This Rust-based CLI enables users to interact with various Large Language Models 150 | (LLMs) directly from the terminal. Through this tool, you can send prompts to different 151 | APIs, such as OpenAI, Anthropic, Google, Mistral, and Mistral FIM, and receive and handle 152 | responses from these models. 
153 | 154 | The tool offers extensive configuration options, allowing you 155 | to specify parameters like model type, maximum and minimum tokens, temperature, top-p 156 | sampling, system messages, and more. You can set these options via command line arguments 157 | or environment variables. Additionally, it supports preset configurations and prompt 158 | templates, enabling more advanced and customizable usage scenarios. 159 | 160 | The CLI can format and 161 | highlight the model's responses using syntax highlighting, making it easier to read the 162 | output in the terminal. It also includes functionality to handle streaming responses 163 | efficiently, ensuring a smooth user experience when interacting with the LLMs." 164 | )] 165 | pub struct Args { 166 | #[clap(flatten)] 167 | pub globals: Globals, 168 | } 169 | -------------------------------------------------------------------------------- /crates/e/src/config.rs: -------------------------------------------------------------------------------- 1 | use serde::Deserialize; 2 | use serde_json::Value; 3 | 4 | #[derive(Debug, Default, Deserialize)] 5 | pub struct Preset { 6 | pub name: String, 7 | 8 | // Api 9 | pub api: crate::args::Api, 10 | pub env: Option, 11 | pub key: Option, 12 | pub base_url: Option, 13 | 14 | // Model 15 | pub model: Option, 16 | 17 | // Model Configuration 18 | pub system: Option, 19 | pub max_tokens: Option, 20 | pub version: Option, 21 | pub temperature: Option, 22 | pub top_p: Option, 23 | pub top_k: Option, 24 | } 25 | 26 | #[derive(Debug, Default, Deserialize)] 27 | pub enum Role { 28 | Assistant, 29 | Model, 30 | #[default] 31 | User, 32 | Human, 33 | System, 34 | } 35 | 36 | #[derive(Debug, Default, Deserialize)] 37 | pub struct Template { 38 | pub name: String, 39 | pub description: Option, 40 | pub template: String, 41 | pub default_vars: Option, 42 | pub system: Option, 43 | } 44 | 45 | #[derive(Debug, Default, Deserialize)] 46 | pub struct Config { 47 | // Api 48 | pub api: 
/// Write `e` followed by its `source()` chain, one cause per line.
///
/// Output shape:
/// ```text
/// <error>
/// caused by:
///     0: <first cause>
///     1: <second cause>
/// ```
/// The index is right-aligned in a 5-character column.
///
/// IMPROVEMENT: the previous version called `e.source()` twice (once for the
/// `is_some` check, once for the already-bound `source`); reuse the binding.
pub(crate) fn format_error(
    e: &impl std::error::Error,
    f: &mut std::fmt::Formatter,
) -> std::fmt::Result {
    write!(f, "{e}")?;

    let mut source = e.source();

    if source.is_some() {
        writeln!(f, "\ncaused by:")?;
        let mut i: usize = 0;
        while let Some(inner) = source {
            writeln!(f, "{i: >5}: {inner}")?;
            source = inner.source();
            i += 1;
        }
    }

    Ok(())
}
std::fmt::Debug for Error { 49 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 50 | format_error(self, f) 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /crates/e/src/google.rs: -------------------------------------------------------------------------------- 1 | use es_stream::google; 2 | 3 | use crate::prelude::*; 4 | 5 | const DEFAULT_URL: &str = "https://generativelanguage.googleapis.com/v1beta"; 6 | const DEFAULT_MODEL: &str = "gemini-1.5-pro"; 7 | const DEFAULT_ENV: &str = "GOOGLE_API_KEY"; 8 | 9 | pub async fn run(prompt: String, args: Args) -> Result<()> { 10 | let key = match args.globals.api_key { 11 | Some(key) => key, 12 | None => { 13 | let environment_variable = match args.globals.api_env { 14 | Some(env) => env, 15 | None => DEFAULT_ENV.to_string(), 16 | }; 17 | std::env::var(environment_variable)? 18 | } 19 | }; 20 | log::info!("key: {}", key); 21 | 22 | let url = match args.globals.api_base_url { 23 | Some(url) => url, 24 | None => DEFAULT_URL.to_string(), 25 | }; 26 | log::info!("url: {}", url); 27 | 28 | let auth = google::Auth::new(key); 29 | log::info!("auth: {:#?}", auth); 30 | 31 | let client = google::Client::new(auth, url); 32 | log::info!("client: {:#?}", client); 33 | 34 | let contents = vec![google::Content { 35 | parts: vec![google::Part { text: prompt }], 36 | role: google::Role::User, 37 | }]; 38 | 39 | let mut body = google::MessageBody::new( 40 | args.globals 41 | .model 42 | .unwrap_or(DEFAULT_MODEL.to_string()) 43 | .as_ref(), 44 | contents, 45 | ); 46 | 47 | if let Some(system) = args.globals.system { 48 | let system_message = google::Content { 49 | parts: vec![google::Part { text: system }], 50 | role: google::Role::User, 51 | }; 52 | 53 | body.contents.insert(0, system_message); 54 | } 55 | 56 | body.generation_config = Some(google::GenerationConfig { 57 | max_output_tokens: Some(args.globals.max_tokens.unwrap_or(4096)), 58 | temperature: 
args.globals.temperature, 59 | top_p: args.globals.top_p, 60 | top_k: args.globals.top_k, 61 | ..Default::default() 62 | }); 63 | 64 | log::info!("body: {:#?}", body); 65 | 66 | let stream = client.delta(&body)?; 67 | 68 | handle_stream( 69 | stream, 70 | args.globals.quiet.unwrap_or(false), 71 | args.globals.language, 72 | ) 73 | .await 74 | } 75 | -------------------------------------------------------------------------------- /crates/e/src/main.rs: -------------------------------------------------------------------------------- 1 | use clap::Parser; 2 | use config_file::FromConfigFile; 3 | 4 | mod anthropic; 5 | mod args; 6 | mod config; 7 | mod error; 8 | mod google; 9 | mod mistral; 10 | mod mistral_fim; 11 | mod openai; 12 | mod prelude; 13 | mod printer; 14 | 15 | use crate::prelude::*; 16 | 17 | const SYSTEM_TEMPLATE: &str = "system"; 18 | const PROMPT_TEMPLATE: &str = "prompt"; 19 | 20 | #[tokio::main] 21 | async fn main() -> Result<()> { 22 | env_logger::init(); 23 | 24 | let mut args = Args::parse(); 25 | 26 | let mut api: Option = if let Some(api) = args.globals.api.clone() { 27 | Some(api.parse()?) 28 | } else { 29 | None 30 | }; 31 | 32 | let mut prompt = args.globals.prompt.to_string(); 33 | let mut stdin = args.globals.stdin.to_string(); 34 | 35 | // Turn them around if there's nothing coming from `stdin`. 36 | if prompt.is_empty() && !stdin.is_empty() { 37 | stdin = args.globals.prompt.to_string(); 38 | prompt = args.globals.stdin.to_string(); 39 | } 40 | 41 | log::info!("info: {:#?}", args.globals); 42 | 43 | let home = std::env::var("HOME")?; 44 | let path = args.globals.config_file.clone().replace('~', &home); 45 | 46 | log::info!("path: {:#?}", path); 47 | 48 | // Check if `path` exists 49 | let config = if !std::path::Path::new(&path).exists() { 50 | Config::default() 51 | } else { 52 | Config::from_config_file(path)? 
53 | }; 54 | 55 | log::info!("config: {:#?}", config); 56 | 57 | if let Some(preset) = args.globals.preset.clone() { 58 | let p = config 59 | .presets 60 | .unwrap_or_default() 61 | .into_iter() 62 | .find(|p| p.name == preset); 63 | 64 | if let Some(p) = p { 65 | api = Some(p.api); 66 | 67 | if args.globals.top_p.is_none() { 68 | args.globals.top_p = p.top_p; 69 | } 70 | if args.globals.top_k.is_none() { 71 | args.globals.top_k = p.top_k; 72 | } 73 | if args.globals.temperature.is_none() { 74 | args.globals.temperature = p.temperature; 75 | } 76 | if args.globals.system.is_none() { 77 | args.globals.system = p.system; 78 | } 79 | if args.globals.max_tokens.is_none() { 80 | args.globals.max_tokens = p.max_tokens; 81 | } 82 | if args.globals.api_version.is_none() { 83 | args.globals.api_version = p.version; 84 | } 85 | if args.globals.api_env.is_none() { 86 | args.globals.api_env = p.env; 87 | } 88 | if args.globals.api_key.is_none() { 89 | args.globals.api_key = p.key; 90 | } 91 | if args.globals.api_base_url.is_none() { 92 | args.globals.api_base_url = p.base_url; 93 | } 94 | if args.globals.model.is_none() { 95 | args.globals.model = p.model; 96 | } 97 | } 98 | }; 99 | 100 | if args.globals.top_p.is_none() { 101 | args.globals.top_p = config.top_p; 102 | } 103 | if args.globals.top_k.is_none() { 104 | args.globals.top_k = config.top_k; 105 | } 106 | if args.globals.temperature.is_none() { 107 | args.globals.temperature = config.temperature; 108 | } 109 | if args.globals.system.is_none() { 110 | args.globals.system = config.system; 111 | } 112 | if args.globals.max_tokens.is_none() { 113 | args.globals.max_tokens = config.max_tokens; 114 | } 115 | if args.globals.api_version.is_none() { 116 | args.globals.api_version = config.version; 117 | } 118 | if args.globals.api_env.is_none() { 119 | args.globals.api_env = config.env; 120 | } 121 | if args.globals.api_key.is_none() { 122 | args.globals.api_key = config.key; 123 | } 124 | if 
args.globals.api_base_url.is_none() { 125 | args.globals.api_base_url = config.base_url; 126 | } 127 | if args.globals.model.is_none() { 128 | args.globals.model = config.model; 129 | } 130 | if args.globals.quiet.is_none() { 131 | args.globals.quiet = config.quiet; 132 | } 133 | if api.is_none() { 134 | api = config.api; 135 | } 136 | 137 | log::info!("globals: {:#?}", args.globals); 138 | 139 | let prompt: String = if let Some(ref template) = args.globals.template { 140 | let t = config 141 | .templates 142 | .unwrap_or_default() 143 | .into_iter() 144 | .find(|t| t.name == *template); 145 | 146 | if t.is_none() { 147 | return Err(Error::TemplateNotFound); 148 | } 149 | 150 | let t = t.unwrap(); 151 | 152 | log::info!("template: {:#?}", t); 153 | 154 | let system = args.globals.system.clone().unwrap_or_default().to_string(); 155 | let suffix = args.globals.suffix.clone().unwrap_or_default().to_string(); 156 | let language = args.globals.language.clone(); 157 | 158 | let mut default_vars = t.default_vars.unwrap_or_default(); 159 | let vars = args.globals.vars.take().unwrap_or_default(); 160 | merge(&mut default_vars, vars); 161 | 162 | let mut value = serde_json::json!({ 163 | "prompt": prompt, 164 | "system": system, 165 | "stdin": stdin, 166 | "suffix": suffix, 167 | "language": language, 168 | }); 169 | 170 | merge(&mut value, default_vars); 171 | 172 | let context = tera::Context::from_value(value)?; 173 | 174 | log::info!("context: {:#?}", context); 175 | 176 | let mut tera = tera::Tera::default(); 177 | 178 | if let Some(system) = t.system { 179 | tera.add_raw_template(SYSTEM_TEMPLATE, &system)?; 180 | args.globals.system = Some(tera.render(SYSTEM_TEMPLATE, &context)?); 181 | } 182 | 183 | tera.add_raw_template(PROMPT_TEMPLATE, t.template.as_ref())?; 184 | 185 | tera.render(PROMPT_TEMPLATE, &context)? 
186 | } else if !stdin.is_empty() { 187 | format!("{}\n{}", stdin, prompt) 188 | } else { 189 | prompt 190 | }; 191 | 192 | if args.globals.print_template { 193 | if args.globals.system.is_some() { 194 | println!( 195 | "System Prompt:\n{}\n---\n", 196 | args.globals.system.clone().unwrap() 197 | ); 198 | } 199 | println!("Prompt:\n{}", prompt); 200 | return Ok(()); 201 | } 202 | 203 | match api { 204 | Some(Api::OpenAi) => openai::run(prompt, args).await, 205 | Some(Api::Anthropic) => anthropic::run(prompt, args).await, 206 | Some(Api::Google) => google::run(prompt, args).await, 207 | Some(Api::Mistral) => mistral::run(prompt, args).await, 208 | Some(Api::MistralFim) => mistral_fim::run(prompt, args).await, 209 | None => Err(Error::ApiNotSpecified), 210 | } 211 | } 212 | -------------------------------------------------------------------------------- /crates/e/src/mistral.rs: -------------------------------------------------------------------------------- 1 | use es_stream::mistral; 2 | 3 | use crate::prelude::*; 4 | 5 | const DEFAULT_URL: &str = "https://api.mistral.ai/v1"; 6 | const DEFAULT_MODEL: &str = "mistral-small-latest"; 7 | const DEFAULT_ENV: &str = "MISTRAL_API_KEY"; 8 | 9 | pub async fn run(prompt: String, args: Args) -> Result<()> { 10 | let key = match args.globals.api_key { 11 | Some(key) => key, 12 | None => { 13 | let environment_variable = match args.globals.api_env { 14 | Some(env) => env, 15 | None => DEFAULT_ENV.to_string(), 16 | }; 17 | std::env::var(environment_variable)? 
18 | } 19 | }; 20 | log::info!("key: {}", key); 21 | 22 | let url = match args.globals.api_base_url { 23 | Some(url) => url, 24 | None => DEFAULT_URL.to_string(), 25 | }; 26 | 27 | log::info!("url: {}", url); 28 | 29 | let auth = mistral::Auth::new(key); 30 | 31 | log::info!("auth: {:#?}", auth); 32 | 33 | let client = mistral::Client::new(auth, url); 34 | 35 | log::info!("client: {:#?}", client); 36 | 37 | let messages = vec![mistral::Message { 38 | role: mistral::Role::User, 39 | content: prompt, 40 | }]; 41 | 42 | let mut body = mistral::MessageBody::new( 43 | args.globals 44 | .model 45 | .unwrap_or(DEFAULT_MODEL.to_string()) 46 | .as_ref(), 47 | messages, 48 | ); 49 | 50 | if let Some(system) = args.globals.system { 51 | let system_message = mistral::Message { 52 | role: mistral::Role::System, 53 | content: system, 54 | }; 55 | 56 | body.messages.insert(0, system_message); 57 | } 58 | 59 | body.temperature = args.globals.temperature; 60 | body.top_p = args.globals.top_p; 61 | if let Some(max_tokens) = args.globals.max_tokens { 62 | body.max_tokens = Some(max_tokens); 63 | }; 64 | if let Some(min_tokens) = args.globals.min_tokens { 65 | body.min_tokens = Some(min_tokens); 66 | }; 67 | 68 | log::info!("body: {:#?}", body); 69 | 70 | let stream = client.delta(&body)?; 71 | 72 | handle_stream( 73 | stream, 74 | args.globals.quiet.unwrap_or(false), 75 | args.globals.language, 76 | ) 77 | .await 78 | } 79 | -------------------------------------------------------------------------------- /crates/e/src/mistral_fim.rs: -------------------------------------------------------------------------------- 1 | use es_stream::mistral_fim; 2 | 3 | use crate::prelude::*; 4 | 5 | const DEFAULT_URL: &str = "https://api.mistral.ai/v1"; 6 | const DEFAULT_MODEL: &str = "codestral-2405"; 7 | const DEFAULT_ENV: &str = "MISTRAL_API_KEY"; 8 | 9 | pub async fn run(prompt: String, args: Args) -> Result<()> { 10 | let key = match args.globals.api_key { 11 | Some(key) => key, 12 | None => { 
13 | let environment_variable = match args.globals.api_env { 14 | Some(env) => env, 15 | None => DEFAULT_ENV.to_string(), 16 | }; 17 | std::env::var(environment_variable)? 18 | } 19 | }; 20 | log::info!("key: {}", key); 21 | 22 | let url = match args.globals.api_base_url { 23 | Some(url) => url, 24 | None => DEFAULT_URL.to_string(), 25 | }; 26 | 27 | log::info!("url: {}", url); 28 | 29 | let auth = mistral_fim::Auth::new(key); 30 | 31 | log::info!("auth: {:#?}", auth); 32 | 33 | let client = mistral_fim::Client::new(auth, url); 34 | 35 | log::info!("client: {:#?}", client); 36 | 37 | let mut body = mistral_fim::MessageBody::new( 38 | args.globals 39 | .model 40 | .unwrap_or(DEFAULT_MODEL.to_string()) 41 | .as_ref(), 42 | prompt, 43 | args.globals.suffix, 44 | ); 45 | 46 | body.temperature = args.globals.temperature; 47 | body.top_p = args.globals.top_p; 48 | if let Some(max_tokens) = args.globals.max_tokens { 49 | body.max_tokens = Some(max_tokens); 50 | }; 51 | if let Some(min_tokens) = args.globals.min_tokens { 52 | body.min_tokens = Some(min_tokens); 53 | }; 54 | 55 | log::info!("body: {:#?}", body); 56 | 57 | let stream = client.delta(&body)?; 58 | 59 | handle_stream( 60 | stream, 61 | args.globals.quiet.unwrap_or(false), 62 | args.globals.language, 63 | ) 64 | .await 65 | } 66 | -------------------------------------------------------------------------------- /crates/e/src/openai.rs: -------------------------------------------------------------------------------- 1 | use es_stream::openai; 2 | 3 | use crate::prelude::*; 4 | 5 | const DEFAULT_URL: &str = "https://api.openai.com/v1"; 6 | const DEFAULT_MODEL: &str = "gpt-4o"; 7 | const DEFAULT_ENV: &str = "OPENAI_API_KEY"; 8 | 9 | pub async fn run(prompt: String, args: Args) -> Result<()> { 10 | let key = match args.globals.api_key { 11 | Some(key) => key, 12 | None => { 13 | let environment_variable = match args.globals.api_env { 14 | Some(env) => env, 15 | None => DEFAULT_ENV.to_string(), 16 | }; 17 | 
std::env::var(environment_variable)? 18 | } 19 | }; 20 | log::info!("key: {}", key); 21 | 22 | let url = match args.globals.api_base_url { 23 | Some(url) => url, 24 | None => DEFAULT_URL.to_string(), 25 | }; 26 | 27 | log::info!("url: {}", url); 28 | 29 | let auth = openai::Auth::new(key); 30 | 31 | log::info!("auth: {:#?}", auth); 32 | 33 | let client = openai::Client::new(auth, url); 34 | 35 | log::info!("client: {:#?}", client); 36 | 37 | let messages = vec![openai::Message { 38 | role: openai::Role::User, 39 | content: prompt, 40 | }]; 41 | 42 | let mut body = openai::MessageBody::new( 43 | args.globals 44 | .model 45 | .unwrap_or(DEFAULT_MODEL.to_string()) 46 | .as_ref(), 47 | messages, 48 | ); 49 | 50 | if let Some(system) = args.globals.system { 51 | let system_message = openai::Message { 52 | role: openai::Role::System, 53 | content: system, 54 | }; 55 | 56 | body.messages.insert(0, system_message); 57 | } 58 | 59 | body.temperature = args.globals.temperature; 60 | body.top_p = args.globals.top_p; 61 | if let Some(max_tokens) = args.globals.max_tokens { 62 | body.max_tokens = Some(max_tokens); 63 | }; 64 | 65 | log::info!("body: {:#?}", body); 66 | 67 | let stream = client.delta(&body)?; 68 | 69 | handle_stream( 70 | stream, 71 | args.globals.quiet.unwrap_or(false), 72 | args.globals.language, 73 | ) 74 | .await 75 | } 76 | -------------------------------------------------------------------------------- /crates/e/src/prelude.rs: -------------------------------------------------------------------------------- 1 | use futures::stream::{Stream, TryStreamExt}; 2 | use serde_json::Value; 3 | use std::io::Write; 4 | 5 | pub use crate::args::{Api, Args}; 6 | pub use crate::config::Config; 7 | pub use crate::error::Error; 8 | 9 | pub type Result = std::result::Result; 10 | 11 | pub async fn handle_stream( 12 | mut stream: impl Stream> 13 | + std::marker::Unpin, 14 | quiet: bool, 15 | language: String, 16 | ) -> Result<()> { 17 | let mut previous_output = 
String::new(); 18 | let mut accumulated_content_bytes: Vec = Vec::new(); 19 | 20 | let is_terminal = atty::is(atty::Stream::Stdout); 21 | 22 | let mut sp = if !quiet && is_terminal { 23 | Some(spinners::Spinner::new( 24 | spinners::Spinners::OrangeBluePulse, 25 | "Loading...".into(), 26 | )) 27 | } else { 28 | None 29 | }; 30 | 31 | while let Ok(Some(text)) = stream.try_next().await { 32 | if is_terminal && sp.is_some() { 33 | // TODO: Find a better way to clean the spinner from the terminal. 34 | sp.take().unwrap().stop(); 35 | std::io::stdout().flush()?; 36 | crossterm::execute!(std::io::stdout(), crossterm::cursor::MoveToColumn(0))?; 37 | print!(" "); 38 | crossterm::execute!(std::io::stdout(), crossterm::cursor::MoveToColumn(0))?; 39 | } 40 | 41 | if !is_terminal { 42 | // If not a terminal, print each instance of `text` directly to `stdout` 43 | print!("{}", text); 44 | std::io::stdout().flush()?; 45 | continue; 46 | } 47 | 48 | accumulated_content_bytes.extend_from_slice(text.as_bytes()); 49 | 50 | let output = crate::printer::CustomPrinter::new(&language)? 51 | .input_from_bytes(&accumulated_content_bytes) 52 | .print()?; 53 | 54 | let unprinted_lines = output 55 | .lines() 56 | .skip(if previous_output.lines().count() == 0 { 57 | 0 58 | } else { 59 | previous_output.lines().count() - 1 60 | }) 61 | .collect::>() 62 | .join("\n"); 63 | 64 | crossterm::execute!(std::io::stdout(), crossterm::cursor::MoveToColumn(0))?; 65 | print!("{unprinted_lines}"); 66 | std::io::stdout().flush()?; 67 | 68 | // Update the previous output 69 | previous_output = output; 70 | } 71 | 72 | Ok(()) 73 | } 74 | 75 | // Merges two JSON objects defined as `serde_json::Value`. 
76 | pub fn merge(a: &mut Value, b: Value) { 77 | if let Value::Object(a) = a { 78 | if let Value::Object(b) = b { 79 | for (k, v) in b { 80 | if v.is_null() { 81 | a.remove(&k); 82 | } else { 83 | merge(a.entry(k).or_insert(Value::Null), v); 84 | } 85 | } 86 | 87 | return; 88 | } 89 | } 90 | 91 | *a = b; 92 | } 93 | -------------------------------------------------------------------------------- /crates/e/src/printer.rs: -------------------------------------------------------------------------------- 1 | use crossterm::terminal; 2 | 3 | use crate::prelude::*; 4 | 5 | // Markdown language constant string 6 | const DEFAULT_THEME: &str = "tokyonight-storm"; 7 | 8 | pub struct CustomPrinter<'a> { 9 | inputs: Vec>, 10 | config: bat::config::Config<'a>, 11 | assets: bat::assets::HighlightingAssets, 12 | term_width: Option, 13 | } 14 | 15 | impl<'a> CustomPrinter<'a> { 16 | pub fn new(language: &'a str) -> Result { 17 | let theme = std::env::var("BAT_THEME").unwrap_or_else(|_| DEFAULT_THEME.to_string()); 18 | 19 | let config = bat::config::Config { 20 | colored_output: true, 21 | true_color: true, 22 | language: Some(language), 23 | theme, 24 | use_italic_text: true, 25 | wrapping_mode: bat::WrappingMode::Character, 26 | ..Default::default() 27 | }; 28 | 29 | Ok(CustomPrinter { 30 | inputs: vec![], 31 | config, 32 | assets: bat::assets::HighlightingAssets::from_binary(), 33 | term_width: None, 34 | }) 35 | } 36 | 37 | /// Add a byte string as an input 38 | pub fn input_from_bytes(&mut self, content: &'a [u8]) -> &mut Self { 39 | self.input_from_reader(content) 40 | } 41 | 42 | /// Add a custom reader as an input 43 | pub fn input_from_reader(&mut self, reader: R) -> &mut Self { 44 | self.inputs 45 | .push(bat::input::Input::from_reader(Box::new(reader))); 46 | self 47 | } 48 | 49 | /// Custom print function that takes advantage of the fact that `bat` controllers can take a 50 | /// String as the output of the highlighted text. 
/// Checks if a file exists.
///
/// Returns `true` when `filename` refers to any existing filesystem entry.
pub fn file_exists(filename: &str) -> bool {
    Path::new(filename).exists()
}

/// Checks if a directory exists.
///
/// Returns `true` only when `dir_name` exists and is a directory.
pub fn directory_exists(dir_name: &str) -> bool {
    Path::new(dir_name).is_dir()
}
/// Builds a path inside the user's HOME directory.
///
/// Appends `directory` to the value of the `HOME` environment variable;
/// when `HOME` is unset, returns `directory` unchanged.
pub fn get_home_directory(directory: &str) -> String {
    env::var("HOME")
        .map(|home| format!("{}{}", home, directory))
        .unwrap_or_else(|_| directory.to_string())
}
5 | -------------------------------------------------------------------------------- /crates/openai/src/client.rs: -------------------------------------------------------------------------------- 1 | use std::time::Duration; 2 | 3 | use crate::error; 4 | use log; 5 | use reqwest::header::{HeaderMap, HeaderValue}; 6 | use reqwest::{Client as ReqwestClient, Response as ReqwestResponse}; 7 | use reqwest_eventsource::EventSource; 8 | 9 | #[derive(Clone, Debug, Default)] 10 | pub struct Client { 11 | reqwest: ReqwestClient, 12 | base_url: String, 13 | headers: HeaderMap, 14 | } 15 | 16 | const OPEN_API_URL: &str = "https://api.openai.com"; 17 | 18 | fn create_headers(api_key: String) -> Result { 19 | let mut auth = String::from("Bearer "); 20 | auth.push_str(&api_key); 21 | 22 | let mut headers = HeaderMap::new(); 23 | let authorization = match HeaderValue::from_str(auth.as_str()) { 24 | Ok(x) => x, 25 | Err(e) => { 26 | return Err(error::OpenAi::RequestError { 27 | body: e.to_string(), 28 | }) 29 | } 30 | }; 31 | let content_type = match HeaderValue::from_str("application/json") { 32 | Ok(x) => x, 33 | Err(e) => { 34 | return Err(error::OpenAi::RequestError { 35 | body: e.to_string(), 36 | }) 37 | } 38 | }; 39 | 40 | headers.insert("Authorization", authorization); 41 | headers.insert("Content-Type", content_type); 42 | 43 | Ok(headers) 44 | } 45 | 46 | impl Client { 47 | /// Creates a new client. 
48 | pub fn new(api_key: String) -> Result { 49 | let reqwest = match ReqwestClient::builder() 50 | .timeout(Duration::from_secs(300)) 51 | .build() 52 | { 53 | Ok(x) => x, 54 | Err(e) => { 55 | return Err(error::OpenAi::RequestError { 56 | body: e.to_string(), 57 | }); 58 | } 59 | }; 60 | 61 | log::debug!("Created reqwest client"); 62 | 63 | let headers = match create_headers(api_key) { 64 | Ok(x) => x, 65 | Err(e) => { 66 | return Err(error::OpenAi::RequestError { 67 | body: e.to_string(), 68 | }) 69 | } 70 | }; 71 | 72 | log::debug!("Created headers"); 73 | 74 | Ok(Client { 75 | reqwest, 76 | headers, 77 | base_url: OPEN_API_URL.to_string(), 78 | }) 79 | } 80 | 81 | /// Changes the client's base_url 82 | pub fn set_base_url(&mut self, base_url: String) -> &mut Self { 83 | self.base_url = base_url; 84 | self 85 | } 86 | 87 | /// Change the OpenAi API key 88 | pub fn set_api_key(&mut self, api_key: String) -> Result<&mut Self, error::OpenAi> { 89 | let headers = match create_headers(api_key) { 90 | Ok(x) => x, 91 | Err(e) => { 92 | return Err(error::OpenAi::RequestError { 93 | body: e.to_string(), 94 | }) 95 | } 96 | }; 97 | 98 | self.headers = headers; 99 | Ok(self) 100 | } 101 | 102 | /// Makes a GET request to the OpenAi API. 103 | pub async fn get(&self, endpoint: &str) -> Result { 104 | let mut url = self.base_url.clone(); 105 | url.push_str(endpoint); 106 | 107 | log::debug!("GET: {}", url); 108 | 109 | let request = self.reqwest.get(url).headers(self.headers.clone()); 110 | 111 | match request.send().await { 112 | Ok(x) => Ok(x), 113 | Err(e) => { 114 | log::error!("Error: {}", e); 115 | Err(error::OpenAi::RequestError { 116 | body: e.to_string(), 117 | }) 118 | } 119 | } 120 | } 121 | 122 | /// Makes a POST request to the OpenAi API that returns a SSE stream. 
123 | pub async fn post_stream( 124 | &self, 125 | endpoint: &str, 126 | body: String, 127 | ) -> Result { 128 | let mut url = self.base_url.clone(); 129 | url.push_str(endpoint); 130 | 131 | log::debug!("POST: {}", url); 132 | 133 | let builder = self 134 | .reqwest 135 | .post(url) 136 | .headers(self.headers.clone()) 137 | .body(body); 138 | 139 | match EventSource::new(builder) { 140 | Ok(x) => Ok(x), 141 | Err(e) => Err(error::OpenAi::RequestError { 142 | body: e.to_string(), 143 | }), 144 | } 145 | } 146 | 147 | /// Makes a POST request to the OpenAi API. 148 | pub async fn post( 149 | &self, 150 | endpoint: &str, 151 | body: String, 152 | ) -> Result { 153 | let mut url = self.base_url.clone(); 154 | url.push_str(endpoint); 155 | 156 | log::debug!("POST: {}1", url); 157 | 158 | match self 159 | .reqwest 160 | .post(url) 161 | .headers(self.headers.clone()) 162 | .body(body) 163 | .send() 164 | .await 165 | { 166 | Ok(x) => Ok(x), 167 | Err(e) => Err(error::OpenAi::RequestError { 168 | body: e.to_string(), 169 | }), 170 | } 171 | } 172 | } 173 | -------------------------------------------------------------------------------- /crates/openai/src/edits.rs: -------------------------------------------------------------------------------- 1 | use log; 2 | use serde::{Deserialize, Serialize}; 3 | 4 | use crate::client::Client; 5 | use crate::error; 6 | 7 | #[derive(Debug, Serialize, Deserialize, Default)] 8 | pub struct EditsApi { 9 | #[serde(skip)] 10 | client: Client, 11 | // Edits Properties 12 | pub model: String, 13 | pub input: String, 14 | pub instruction: String, 15 | #[serde(skip_serializing_if = "Option::is_none")] 16 | pub n: Option, 17 | #[serde(skip_serializing_if = "Option::is_none")] 18 | temperature: Option, 19 | #[serde(skip_serializing_if = "Option::is_none")] 20 | top_p: Option, 21 | } 22 | 23 | #[derive(Debug, Serialize, Deserialize, Default)] 24 | pub struct Edit { 25 | pub object: String, 26 | pub created: u64, 27 | pub choices: Vec, 28 | pub 
usage: Usage, 29 | } 30 | 31 | #[derive(Debug, Serialize, Deserialize, Default)] 32 | pub struct Choice { 33 | pub text: String, 34 | pub index: u32, 35 | } 36 | 37 | #[derive(Debug, Serialize, Deserialize, Default)] 38 | pub struct Usage { 39 | pub prompt_tokens: u32, 40 | pub completion_tokens: u32, 41 | pub total_tokens: u32, 42 | } 43 | 44 | const DEFAULT_MODEL: &str = "code-davinci-edit-001"; 45 | 46 | impl EditsApi { 47 | pub fn new(api_key: String) -> Result { 48 | let client = match Client::new(api_key) { 49 | Ok(client) => client, 50 | Err(err) => { 51 | return Err(error::OpenAi::ClientError { 52 | body: err.to_string(), 53 | }); 54 | } 55 | }; 56 | 57 | log::debug!("Created OpenAi HTTP Client"); 58 | 59 | Ok(Self { 60 | client, 61 | model: DEFAULT_MODEL.to_string(), 62 | ..Default::default() 63 | }) 64 | } 65 | 66 | /// Gets the value of the temperature. 67 | pub fn get_temperature(self) -> Option { 68 | self.temperature 69 | } 70 | 71 | /// Sets the value of the temperature. 72 | pub fn set_temperature(&mut self, temperature: f32) -> Result<&mut Self, error::OpenAi> { 73 | if !(0.0..=2.0).contains(&temperature) { 74 | return Err(error::OpenAi::InvalidTemperature { temperature }); 75 | } 76 | self.temperature = Some(temperature); 77 | 78 | log::debug!("Set temperature to {}", temperature); 79 | 80 | Ok(self) 81 | } 82 | 83 | /// Gets the value of the top_p. 84 | pub fn get_top_p(self) -> Option { 85 | self.top_p 86 | } 87 | 88 | /// Sets the value of the top_p. 89 | pub fn set_top_p(&mut self, top_p: f32) -> Result<&mut Self, error::OpenAi> { 90 | if !(0.0..=2.0).contains(&top_p) { 91 | return Err(error::OpenAi::InvalidTopP { top_p }); 92 | } 93 | self.top_p = Some(top_p); 94 | 95 | log::debug!("Set top_p to {}", top_p); 96 | 97 | Ok(self) 98 | } 99 | 100 | /// Creates an edit from the provided parameters. 
101 | pub async fn create(&self) -> Result { 102 | let request = match serde_json::to_string(&self) { 103 | Ok(request) => request, 104 | Err(err) => { 105 | return Err(error::OpenAi::SerializationError { 106 | body: err.to_string(), 107 | }); 108 | } 109 | }; 110 | 111 | let body = match self.client.post("/edits", request).await { 112 | Ok(response) => match response.text().await { 113 | Ok(text) => text, 114 | Err(e) => { 115 | return Err(error::OpenAi::RequestError { 116 | body: e.to_string(), 117 | }) 118 | } 119 | }, 120 | Err(e) => { 121 | return Err(error::OpenAi::RequestError { 122 | body: e.to_string(), 123 | }) 124 | } 125 | }; 126 | 127 | let body: Edit = match serde_json::from_str(&body) { 128 | Ok(body) => body, 129 | Err(e) => { 130 | return Err(error::OpenAi::RequestError { 131 | body: e.to_string(), 132 | }) 133 | } 134 | }; 135 | 136 | Ok(body) 137 | } 138 | } 139 | -------------------------------------------------------------------------------- /crates/openai/src/error.rs: -------------------------------------------------------------------------------- 1 | use custom_error::custom_error; 2 | 3 | custom_error! 
{pub OpenAi 4 | ClientError{body: String} = "client error:\n{body}", 5 | DeserializationError{body: String} = "deserialization error: {body}", 6 | FileError{body: String} = "file error: {body}", 7 | InvalidBestOf = "'best_of' cannot be used with 'stream'", 8 | InvalidEcho = "'echo' cannot be used with 'suffix'", 9 | InvalidFrequencyPenalty{frequency_penalty: f32} = "frequency_penalty ({frequency_penalty}) must be between -2.0 and 2.0", 10 | InvalidLogProbs{logprobs: f32} = "logprob value ({logprobs}) must be between 0 and 5", 11 | InvalidPresencePenalty{presence_penalty: f32} = "presence_penalty value ({presence_penalty}) must be between -2.0 and 2.0", 12 | InvalidStop{stop: String} = "stop value ({stop}) must be either 'left' or 'right'", 13 | InvalidStream = "'stream' cannot be used with 'best_of'", 14 | InvalidSuffix = "'suffix' cannot be used with 'echo'", 15 | InvalidTemperature{temperature: f32} = "temperature value ({temperature}) must be between 0.0 and 2.0", 16 | InvalidTopP{top_p: f32} = "top_p value ({top_p}) must be between 0 and 1", 17 | ModelNotFound{model_name: String} = "model not found: {model_name}", 18 | NoChoices = "no chat choices", 19 | NoSession = "no session", 20 | RequestError{body: String} = "request error: {body}", 21 | SerializationError{body: String} = "serialization error: {body}", 22 | StreamError = "stream error", 23 | TrimError = "could not find a message to trim", 24 | UknownError = "unknown error", 25 | } 26 | -------------------------------------------------------------------------------- /crates/openai/src/lib.rs: -------------------------------------------------------------------------------- 1 | pub mod chats; 2 | pub mod client; 3 | pub mod completions; 4 | pub mod edits; 5 | pub mod error; 6 | pub mod models; 7 | pub mod utils; 8 | -------------------------------------------------------------------------------- /crates/openai/src/models.rs: -------------------------------------------------------------------------------- 1 | 
use serde::{Deserialize, Serialize}; 2 | 3 | use crate::client::Client; 4 | use crate::error; 5 | use log; 6 | 7 | /// ModelsAPI struct. 8 | #[derive(Debug, Default)] 9 | pub struct ModelsApi { 10 | client: Client, 11 | } 12 | 13 | /// OpenAi Completions Model. 14 | #[derive(Serialize, Deserialize, Debug)] 15 | pub struct Model { 16 | pub id: String, 17 | pub object: String, 18 | pub owned_by: String, 19 | pub created: i64, 20 | pub permission: Vec, 21 | pub root: String, 22 | pub parent: Option, 23 | } 24 | 25 | /// OpenAi Model permissions. 26 | #[derive(Serialize, Deserialize, Debug)] 27 | pub struct ModelPermission { 28 | pub id: String, 29 | pub object: String, 30 | pub created: i64, 31 | pub allow_create_engine: bool, 32 | pub allow_sampling: bool, 33 | pub allow_logprobs: bool, 34 | pub allow_search_indices: bool, 35 | pub allow_view: bool, 36 | pub allow_fine_tuning: bool, 37 | pub organization: String, 38 | pub group: Option, 39 | pub is_blocking: bool, 40 | } 41 | 42 | /// OpenAi Models Request Body 43 | #[derive(Serialize, Deserialize, Debug)] 44 | pub struct ModelsRequestBody { 45 | pub data: Vec, 46 | pub object: String, 47 | } 48 | 49 | impl ModelsApi { 50 | pub fn new(client: Client) -> Self { 51 | let client = Self { client }; 52 | 53 | log::debug!("Created ModelsApi"); 54 | 55 | client 56 | } 57 | 58 | pub async fn list(&self) -> Result, error::OpenAi> { 59 | let body = match self.client.get("/models").await { 60 | Ok(response) => match response.text().await { 61 | Ok(text) => text, 62 | Err(e) => { 63 | return Err(error::OpenAi::RequestError { 64 | body: e.to_string(), 65 | }) 66 | } 67 | }, 68 | Err(e) => { 69 | return Err(error::OpenAi::RequestError { 70 | body: e.to_string(), 71 | }) 72 | } 73 | }; 74 | 75 | let body: ModelsRequestBody = match serde_json::from_str(&body) { 76 | Ok(body) => body, 77 | Err(e) => { 78 | return Err(error::OpenAi::RequestError { 79 | body: e.to_string(), 80 | }) 81 | } 82 | }; 83 | 84 | Ok(body.data) 85 | } 86 | } 
/// Checks if a file exists.
///
/// Returns `true` when `filename` refers to any existing filesystem entry.
pub fn file_exists(filename: &str) -> bool {
    Path::new(filename).exists()
}

/// Checks if a directory exists.
///
/// Returns `true` only when `dir_name` exists and is a directory.
pub fn directory_exists(dir_name: &str) -> bool {
    Path::new(dir_name).is_dir()
}

/// Returns the sessions directory under the user's HOME directory, or a
/// `/tmp` fallback when the `HOME` environment variable is unset.
pub fn get_home_directory() -> String {
    env::var("HOME")
        .map(|home| home + "/.b/sessions")
        .unwrap_or_else(|_| String::from("/tmp/.b/sessions"))
}
indicatif::{ProgressBar, ProgressStyle}; 2 | use rand::seq::SliceRandom; 3 | use std::{sync::Arc, time::Duration}; 4 | use tokio::sync::Mutex; 5 | use tokio_stream::StreamExt; 6 | 7 | /// List of phrases used for the phrases spinner. 8 | static PHRASES: [&str; 113] = [ 9 | "Lubricating the hamsters...", 10 | "Winding up the rubber band...", 11 | "Giving the bits a stern talking to...", 12 | "Bribing the loading bar...", 13 | "Feeding carrots to the loading donkey...", 14 | "Pressing the 'Any' key...", 15 | "Herding the CPU sheep...", 16 | "Pruning redundant binary trees...", 17 | "Debugging the DeLorean's flux capacitor...", 18 | "Realigning the dilithium crystals...", 19 | "Untangling the interwebs...", 20 | "Petting the tribbles...", 21 | "Locating the missing Oxford comma...", 22 | "Rounding up missing semicolons...", 23 | "Tightening loose CSS floats...", 24 | "Plugging memory leaks...", 25 | "Updating Adobe Reader...", 26 | "Disinfecting droids...", 27 | "Defragmenting the hard drive...", 28 | "Polishing pixels...", 29 | "Resetting the odometer...", 30 | "Priming the warp core...", 31 | "Tuning the holodeck...", 32 | "Spinning up the flux inverter...", 33 | "Starting engines...", 34 | "Lubricating wheels...", 35 | "Corraling the penguins...", 36 | "Baking pixels...", 37 | "Putting the 'fun' in fundamental algorithms...", 38 | "Herding cats...", 39 | "Inserting witty loading message...", 40 | "Teaching monkeys to type...", 41 | "Untangling the wires...", 42 | "Locating sense of purpose...", 43 | "Spinning up the hamster wheels...", 44 | "Fluffing the pillows...", 45 | "Straightening the rug...", 46 | "Feeding the fish...", 47 | "Photon alignment...", 48 | "Reconfiguring quibits...", 49 | "Translating whims into action items...", 50 | "Activating sloths...", 51 | "Priming the pump...", 52 | "Twiddling the bits...", 53 | "Polishing the monocle...", 54 | "Herding the cats...", 55 | "Locating my marbles...", 56 | "Unsticking stuck pixels...", 57 | "Untangling the 
yarn...", 58 | "Alphabetizing the library...", 59 | "Dusting the cobwebs...", 60 | "Poking the angry badger...", 61 | "Re-hydrating the fish...", 62 | "Counting the grains of sand...", 63 | "Tuning the orchestra...", 64 | "Beating the high score...", 65 | "Photocopying the paperwork...", 66 | "Proofreading the dictionary...", 67 | "Translating to Pig Latin...", 68 | "De-wrinkling the fabric of space-time...", 69 | "Rounding up the unicorns...", 70 | "Sharpening the pencils...", 71 | "Milking the concrete cow...", 72 | "Herding Schrödinger's cats...", 73 | "Digitizing the analog...", 74 | "Retouching the masterpiece...", 75 | "Debugging the dreamweaver...", 76 | "Restocking the water cooler...", 77 | "Reloading the motivation...", 78 | "Calibrating the mood rings...", 79 | "Fluffing the pillows...", 80 | "Coiling the garden hose...", 81 | "Putting on my wizard hat...", 82 | "Baking your cookies...", 83 | "Locating my eye patch...", 84 | "Polishin' my monocle...", 85 | "Fetching my quill and parchment...", 86 | "Saddling the centaurs...", 87 | "Rounding up the hedgehogs...", 88 | "Tuning my banjo...", 89 | "Anthropomorphizing the mushrooms...", 90 | "Driving the snails...", 91 | "Turning the cranks and tightening the springs...", 92 | "Poking the bees...", 93 | "Rustling the jimmies...", 94 | "Counting the spiderwebs...", 95 | "Organizing my sock drawer...", 96 | "Auditioning the crickets...", 97 | "Testing the waters...", 98 | "Routing the pigeons...", 99 | "Pulling the levers...", 100 | "Winding the clock...", 101 | "Consulting the ancient scrolls...", 102 | "Polishing the pixels...", 103 | "Reticulating splines...", 104 | "Herding cats...", 105 | "Aligning the flux capacitor...", 106 | "Spinning up the hamster wheel...", 107 | "Baking the cookies...", 108 | "Calibrating the confetti cannons...", 109 | "Ironing out the electrons...", 110 | "Rebooting the sugar rush...", 111 | "Inserting witty message here...", 112 | "Activating imagination modules...", 113 | 
"Distracting you with this message...", 114 | "Gathering 1's and 0's...", 115 | "Locating sense of purpose...", 116 | "Brewing another pot of coffee...", 117 | "Start the reactor... (imagine heavy machinery noise)", 118 | "Raise the mizzenmast...", 119 | "Turn the crank...", 120 | "Wind up the gramophone...", 121 | "Ready the trebuchet...", 122 | ]; 123 | 124 | /// Struct to hold the spinner implementation. 125 | pub struct Spinner { 126 | progress_bar: ProgressBar, 127 | state: State, 128 | } 129 | 130 | /// The State represents the status of the spinner. 131 | #[derive(Debug, PartialEq)] 132 | enum State { 133 | /// The spinner is running 134 | Run, 135 | /// The spinner is stopped 136 | Stop, 137 | /// The spinner errored out 138 | Error, 139 | } 140 | 141 | impl Spinner { 142 | /// Creates a new Spinner 143 | pub fn new() -> Self { 144 | let progress_bar = ProgressBar::new_spinner(); 145 | progress_bar.enable_steady_tick(Duration::from_millis(100)); 146 | progress_bar.set_style( 147 | ProgressStyle::with_template("{spinner:.magenta} {msg}") 148 | .unwrap() 149 | .tick_strings(&["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]), 150 | ); 151 | 152 | Self { 153 | progress_bar, 154 | state: State::Run, 155 | } 156 | } 157 | 158 | /// Creates a new Spinner that will be constantly changing its loading message. 
159 | pub fn new_with_checky_messages(millis: u64) -> std::sync::Arc<Mutex<Spinner>> { 160 | let spinner_arc = Arc::new(Mutex::new(Spinner::new())); 161 | let spinner_clone = Arc::clone(&spinner_arc); 162 | 163 | let mut rng = rand::thread_rng(); 164 | let mut phrases = PHRASES.to_vec(); 165 | phrases.shuffle(&mut rng); 166 | 167 | tokio::spawn(async move { 168 | let mut stream = tokio_stream::iter(phrases.iter()); 169 | while let Some(phrase) = stream.next().await { 170 | let mut spinner = spinner_arc.lock().await; 171 | if spinner.state != State::Run { 172 | break; 173 | } 174 | spinner.message(phrase); 175 | tokio::time::sleep(Duration::from_millis(millis)).await; 176 | } 177 | }); 178 | 179 | spinner_clone 180 | } 181 | 182 | /// Prints a message along the spinner. 183 | pub fn message(&mut self, msg: &str) { 184 | if self.state == State::Run { 185 | self.progress_bar.set_message(msg.to_string()); 186 | } 187 | } 188 | 189 | /// Suspends the spinner to print a message, then resumes it. 190 | pub fn print(&mut self, msg: &str) { 191 | if self.state == State::Run { 192 | self.progress_bar.suspend(|| { 193 | print!("{}", msg); 194 | }); 195 | } 196 | } 197 | 198 | /// Stops the execution of the spinner. 199 | pub fn stop(&mut self) { 200 | if self.state == State::Run { 201 | self.state = State::Stop; 202 | self.progress_bar.finish_and_clear(); 203 | tracing::event!(tracing::Level::INFO, "Done"); 204 | } 205 | } 206 | 207 | /// Stops the execution of the spinner with an error. 
208 | pub fn error(&mut self, msg: &str) { 209 | if self.state == State::Run { 210 | self.state = State::Error; 211 | self.progress_bar.abandon_with_message(msg.to_string()); 212 | tracing::event!(tracing::Level::ERROR, "{}", msg); 213 | } 214 | } 215 | } 216 | 217 | impl Default for Spinner { 218 | fn default() -> Self { 219 | Self::new() 220 | } 221 | } 222 | -------------------------------------------------------------------------------- /lib/README.md: -------------------------------------------------------------------------------- 1 | # /Lib 2 | 3 | > Crates in this directory are published to crates.io and **must obey** semver. 4 | -------------------------------------------------------------------------------- /lib/es_stream/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "es_stream" 3 | version = "0.1.0" 4 | description = "A very simple Rust library to simplify streaming api interaction with LLMs, free from complex async operations and redundant dependencies." 
5 | license = "MIT" 6 | repository = "https://github.com/cloudbridgeuy/gpt/tree/main/lib/stream" 7 | edition = "2021" 8 | keywords = ["stream", "api", "llm", "api-bindings"] 9 | 10 | [lib] 11 | name = "es_stream" 12 | path = "src/lib.rs" 13 | 14 | [dependencies] 15 | ureq = { version = "2.10.1", features = ["json"] } 16 | serde = { version = "1.0.209", features = ["derive"] } 17 | serde_json = "1.0.127" 18 | log = "0.4.22" 19 | env_logger = "0.11.5" 20 | futures = "0.3.30" 21 | eventsource-client = "0.13.0" 22 | thiserror = "1.0.63" 23 | 24 | [dev-dependencies] 25 | anyhow = "1.0.86" 26 | tokio = { version = "1.40.0", features = ["macros", "rt-multi-thread"] } 27 | -------------------------------------------------------------------------------- /lib/es_stream/examples/stream_anthropic.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use es_stream::anthropic::{Auth, Client, Message, MessageBody, Role}; 3 | use futures::stream::TryStreamExt; 4 | use std::io::Write; 5 | 6 | #[tokio::main] 7 | async fn main() -> Result<()> { 8 | env_logger::init(); 9 | 10 | let key = std::env::var("ANTHROPIC_API_KEY")?; 11 | 12 | let auth = Auth::new(key, None); 13 | let client = Client::new(auth, "https://api.anthropic.com/v1"); 14 | 15 | let messages = vec![Message { 16 | role: Role::User, 17 | content: "What is the capital of the United States?".to_string(), 18 | }]; 19 | 20 | let body = MessageBody::new("claude-3-opus-20240229", messages, 300); 21 | 22 | // let mut stream = client.message_stream(&body)?; 23 | let mut stream = client.delta(&body)?; 24 | 25 | while let Ok(Some(text)) = stream.try_next().await { 26 | print!("{text}"); 27 | std::io::stdout().flush()?; 28 | } 29 | 30 | Ok(()) 31 | } 32 | -------------------------------------------------------------------------------- /lib/es_stream/examples/stream_copilot.rs: -------------------------------------------------------------------------------- 1 | use 
anyhow::Result; 2 | use es_stream::openai::{Auth, Client, Message, MessageBody, Role}; 3 | use futures::stream::TryStreamExt; 4 | use std::io::Write; 5 | 6 | #[tokio::main] 7 | async fn main() -> Result<()> { 8 | env_logger::init(); 9 | 10 | let key = std::env::var("COPILOT_API_KEY")?; 11 | 12 | let auth = Auth::new(key); 13 | let client = Client::new(auth, "https://api.githubcopilot.com"); 14 | 15 | let messages = vec![Message { 16 | role: Role::User, 17 | content: "What is the capital of the United States?".to_string(), 18 | }]; 19 | 20 | let body = MessageBody::new("gpt-4o", messages); 21 | 22 | // let mut stream = client.message_stream(&body)?; 23 | let mut stream = client.delta(&body)?; 24 | 25 | while let Ok(Some(text)) = stream.try_next().await { 26 | print!("{text}"); 27 | std::io::stdout().flush()?; 28 | } 29 | 30 | Ok(()) 31 | } 32 | -------------------------------------------------------------------------------- /lib/es_stream/examples/stream_google.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use es_stream::google::{Auth, Client, Content, MessageBody, Part, Role}; 3 | use futures::stream::TryStreamExt; 4 | use std::io::Write; 5 | 6 | #[tokio::main] 7 | async fn main() -> Result<()> { 8 | env_logger::init(); 9 | 10 | let key = std::env::var("GOOGLE_API_KEY")?; 11 | 12 | let auth = Auth::new(key); 13 | let client = Client::new(auth, "https://generativelanguage.googleapis.com/v1beta"); 14 | 15 | let messages = vec![Content { 16 | parts: vec![Part { 17 | text: "What is the capital of the United States?".to_string(), 18 | }], 19 | role: Role::User, 20 | }]; 21 | 22 | let body = MessageBody::new("gemini-1.5-flash", messages); 23 | 24 | // let mut stream = client.message_stream(&body)?; 25 | let mut stream = client.delta(&body)?; 26 | 27 | while let Ok(Some(text)) = stream.try_next().await { 28 | print!("{text}"); 29 | std::io::stdout().flush()?; 30 | } 31 | 32 | Ok(()) 33 | } 34 | 
-------------------------------------------------------------------------------- /lib/es_stream/examples/stream_mistral.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use es_stream::mistral::{Auth, Client, Message, MessageBody, Role}; 3 | use futures::stream::TryStreamExt; 4 | use std::io::Write; 5 | 6 | #[tokio::main] 7 | async fn main() -> Result<()> { 8 | env_logger::init(); 9 | 10 | let key = std::env::var("MISTRAL_API_KEY")?; 11 | 12 | let auth = Auth::new(key); 13 | let client = Client::new(auth, "https://api.mistral.ai/v1"); 14 | 15 | let messages = vec![Message { 16 | role: Role::User, 17 | content: "What is the capital of the United States?".to_string(), 18 | }]; 19 | 20 | let body = MessageBody::new("mistral-small-latest", messages); 21 | 22 | // let mut stream = client.message_stream(&body)?; 23 | let mut stream = client.delta(&body)?; 24 | 25 | while let Ok(Some(text)) = stream.try_next().await { 26 | print!("{text}"); 27 | std::io::stdout().flush()?; 28 | } 29 | 30 | Ok(()) 31 | } 32 | -------------------------------------------------------------------------------- /lib/es_stream/examples/stream_mistral_fim.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use es_stream::mistral_fim::{Auth, Client, MessageBody}; 3 | use futures::stream::TryStreamExt; 4 | use std::io::Write; 5 | 6 | #[tokio::main] 7 | async fn main() -> Result<()> { 8 | env_logger::init(); 9 | 10 | let key = std::env::var("MISTRAL_API_KEY")?; 11 | 12 | let auth = Auth::new(key); 13 | let client = Client::new(auth, "https://api.mistral.ai/v1"); 14 | 15 | let prompt = "def coin_problem_solved_with_dp".to_string(); 16 | let suffix = Some("return result".to_string()); 17 | 18 | let body = MessageBody::new("codestral-2405", prompt, suffix); 19 | 20 | // let mut stream = client.message_stream(&body)?; 21 | let mut stream = client.delta(&body)?; 22 | 23 | while let 
Ok(Some(text)) = stream.try_next().await { 24 | print!("{text}"); 25 | std::io::stdout().flush()?; 26 | } 27 | 28 | Ok(()) 29 | } 30 | -------------------------------------------------------------------------------- /lib/es_stream/examples/stream_openai.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use es_stream::openai::{Auth, Client, Message, MessageBody, Role}; 3 | use futures::stream::TryStreamExt; 4 | use std::io::Write; 5 | 6 | #[tokio::main] 7 | async fn main() -> Result<()> { 8 | env_logger::init(); 9 | 10 | let key = std::env::var("OPENAI_API_KEY")?; 11 | 12 | let auth = Auth::new(key); 13 | let client = Client::new(auth, "https://api.openai.com/v1"); 14 | 15 | let messages = vec![Message { 16 | role: Role::User, 17 | content: "What is the capital of the United States?".to_string(), 18 | }]; 19 | 20 | let body = MessageBody::new("gpt-4o", messages); 21 | 22 | // let mut stream = client.message_stream(&body)?; 23 | let mut stream = client.delta(&body)?; 24 | 25 | while let Ok(Some(text)) = stream.try_next().await { 26 | print!("{text}"); 27 | std::io::stdout().flush()?; 28 | } 29 | 30 | Ok(()) 31 | } 32 | -------------------------------------------------------------------------------- /lib/es_stream/src/error.rs: -------------------------------------------------------------------------------- 1 | use thiserror::Error; 2 | 3 | /// Error type returned from this library's functions 4 | #[derive(Debug, Error)] 5 | pub enum Error { 6 | /// An error when creating the SSE stream. 
7 | #[error("Eventsource Client error: {0}")] 8 | EventsourceClient(#[from] eventsource_client::Error), 9 | /// An Error returned by the API 10 | #[error("AuthError Error: {0}")] 11 | AuthError(String), 12 | /// An Error returned by the API 13 | #[error("API Error: {0}")] 14 | ApiError(String), 15 | /// An Error not related to the API 16 | #[error("Request Error: {0}")] 17 | RequestError(String), 18 | /// De/serialization error 19 | #[error("de/serialize error: {0}")] 20 | Serde(#[from] serde_json::error::Error), 21 | /// An Error occurred when performing an IO operation. 22 | #[error("io error: {0}")] 23 | IO(#[from] std::io::Error), 24 | } 25 | -------------------------------------------------------------------------------- /lib/es_stream/src/lib.rs: -------------------------------------------------------------------------------- 1 | pub mod anthropic; 2 | pub mod error; 3 | pub mod google; 4 | pub mod mistral; 5 | pub mod mistral_fim; 6 | pub mod openai; 7 | pub mod requests; 8 | -------------------------------------------------------------------------------- /lib/es_stream/src/mistral.rs: -------------------------------------------------------------------------------- 1 | use eventsource_client as es; 2 | use futures::stream::{Stream, StreamExt}; 3 | use serde::{Deserialize, Serialize}; 4 | use std::time::Duration; 5 | 6 | use crate::error::Error; 7 | use crate::requests::{Json, Requests}; 8 | 9 | // Chat Completion API 10 | const CHAT_API: &str = "/chat/completions"; 11 | 12 | #[derive(Debug, Serialize, Deserialize, Clone)] 13 | pub struct Message { 14 | pub role: Role, 15 | pub content: String, 16 | } 17 | 18 | #[derive(Debug, Serialize, Deserialize, Clone)] 19 | #[serde(rename_all = "lowercase")] 20 | pub enum Role { 21 | System, 22 | Assistant, 23 | User, 24 | } 25 | 26 | #[derive(Debug, Serialize, Deserialize, Default)] 27 | pub struct MessageBody { 28 | /// ID of the model to use. 
You can use the [List Available Models API](https://docs.mistral.ai/api/#tag/models/operation/list_models_v1_models_get) to see all of your available models, or see our [Model overview](https://docs.mistral.ai/models) for model descriptions. 29 | pub model: String, 30 | /// What sampling temperature to use, between 0.0 and 1.0. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend altering this or top_p but not both. 31 | #[serde(skip_serializing_if = "Option::is_none")] 32 | pub temperature: Option, 33 | /// Nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. We generally recommend altering this or temperature but not both. 34 | #[serde(skip_serializing_if = "Option::is_none")] 35 | pub top_p: Option, 36 | /// The maximum number of tokens to generate in the completion. The token count of your prompt plus max_tokens cannot exceed the model's context length. 37 | #[serde(skip_serializing_if = "Option::is_none")] 38 | pub max_tokens: Option, 39 | /// The minimum number of tokens to generate in the completion. 40 | #[serde(skip_serializing_if = "Option::is_none")] 41 | pub min_tokens: Option, 42 | /// Stop generation if this token is detected. Or if one of these tokens is detected when providing an array 43 | #[serde(skip_serializing_if = "Option::is_none")] 44 | pub stop: Option>, 45 | /// The prompt(s) to generate completions for, encoded as a list of dict with role and content. 46 | pub messages: Vec, 47 | /// Whether to stream back partial progress. If set, tokens will be sent as data-only server-side events as they become available, with the stream terminated by a data: [DONE] message. Otherwise, the server will hold the request open until the timeout or until completion, with the response containing the full result as JSON. 
48 | #[serde(skip_serializing_if = "Option::is_none")] 49 | pub stream: Option, 50 | /// The seed to use for random sampling. If set, different calls will generate deterministic results. 51 | #[serde(skip_serializing_if = "Option::is_none")] 52 | pub random_seed: Option, 53 | } 54 | 55 | impl MessageBody { 56 | /// Creates a new `MessageBody` 57 | #[must_use] 58 | pub fn new(model: &str, messages: Vec) -> Self { 59 | Self { 60 | model: model.into(), 61 | messages, 62 | stream: Some(true), 63 | ..Default::default() 64 | } 65 | } 66 | } 67 | 68 | #[derive(Debug, Serialize, Deserialize)] 69 | pub struct ChatCompletionChunk { 70 | pub id: String, 71 | pub object: String, 72 | pub created: u64, 73 | pub model: String, 74 | pub choices: Vec, 75 | pub usage: Option, 76 | } 77 | 78 | #[derive(Debug, Serialize, Deserialize)] 79 | pub struct Choice { 80 | pub index: u32, 81 | pub delta: Delta, 82 | pub finish_reason: Option, 83 | pub logprobs: Option, 84 | } 85 | 86 | #[derive(Debug, Serialize, Deserialize)] 87 | pub struct Delta { 88 | pub role: Option, 89 | pub content: String, 90 | } 91 | 92 | #[derive(Debug, Serialize, Deserialize)] 93 | pub struct Usage { 94 | pub prompt_tokens: u32, 95 | pub total_tokens: u32, 96 | pub completion_tokens: u32, 97 | } 98 | 99 | #[derive(Debug, Serialize, Deserialize, Clone)] 100 | pub struct Auth { 101 | pub api_key: String, 102 | } 103 | 104 | impl Auth { 105 | #[must_use] 106 | pub fn new(api_key: String) -> Self { 107 | Self { api_key } 108 | } 109 | 110 | pub fn from_env() -> Result { 111 | let api_key = match std::env::var("MISTRAL_API_KEY") { 112 | Ok(key) => key, 113 | Err(_) => return Err(Error::AuthError("MISTRAL_API_KEY not found".to_string())), 114 | }; 115 | Ok(Self { api_key }) 116 | } 117 | } 118 | 119 | #[derive(Debug, Clone)] 120 | pub struct Client { 121 | pub auth: Auth, 122 | pub api_url: String, 123 | } 124 | 125 | impl Client { 126 | pub fn new(auth: Auth, api_url: impl Into) -> Self { 127 | Self { 128 | auth, 129 | 
api_url: api_url.into(), 130 | } 131 | } 132 | } 133 | 134 | impl Client { 135 | pub fn delta<'a>( 136 | &'a self, 137 | message_body: &'a MessageBody, 138 | ) -> Result> + 'a, Error> { 139 | log::debug!("message_body: {:#?}", message_body); 140 | 141 | let request_body = match serde_json::to_value(message_body) { 142 | Ok(body) => body, 143 | Err(e) => return Err(Error::Serde(e)), 144 | }; 145 | log::debug!("request_body: {:#?}", request_body); 146 | 147 | let original_stream = match self.post_stream(CHAT_API.to_string(), request_body) { 148 | Ok(stream) => stream, 149 | Err(e) => return Err(Error::EventsourceClient(e)), 150 | }; 151 | 152 | let mapped_stream = original_stream.map(|item| { 153 | if item.is_err() { 154 | return Err(Error::EventsourceClient(item.err().unwrap())); 155 | } 156 | item.map(|event| match event { 157 | es::SSE::Connected(_) => String::default(), 158 | es::SSE::Event(ev) => match serde_json::from_str::(&ev.data) { 159 | Ok(chunk) => { 160 | if chunk.choices.is_empty() { 161 | String::default() 162 | } else { 163 | chunk.choices.first().unwrap().delta.content.clone() 164 | } 165 | } 166 | Err(_) => String::default(), 167 | }, 168 | es::SSE::Comment(comment) => { 169 | log::debug!("Comment: {:#?}", comment); 170 | String::default() 171 | } 172 | }) 173 | .map_err(Error::from) 174 | }); 175 | 176 | Ok(mapped_stream) 177 | } 178 | } 179 | 180 | impl Requests for Client { 181 | fn post_stream( 182 | &self, 183 | sub_url: String, 184 | body: Json, 185 | ) -> Result>, es::Error> { 186 | let authorization: &str = &format!("Bearer {}", self.auth.api_key); 187 | 188 | let client = es::ClientBuilder::for_url(&(self.api_url.clone() + &sub_url))? 189 | .header("content-type", "application/json")? 190 | .header("authorization", authorization)? 
191 | .method("POST".into()) 192 | .body(body.to_string()) 193 | .reconnect( 194 | es::ReconnectOptions::reconnect(true) 195 | .retry_initial(false) 196 | .delay(Duration::from_secs(1)) 197 | .backoff_factor(2) 198 | .delay_max(Duration::from_secs(60)) 199 | .build(), 200 | ) 201 | .build(); 202 | 203 | Ok(crate::requests::tail(&client)) 204 | } 205 | } 206 | -------------------------------------------------------------------------------- /lib/es_stream/src/mistral_fim.rs: -------------------------------------------------------------------------------- 1 | use eventsource_client as es; 2 | use futures::stream::{Stream, StreamExt}; 3 | use serde::{Deserialize, Serialize}; 4 | use std::time::Duration; 5 | 6 | use crate::error::Error; 7 | use crate::requests::{Json, Requests}; 8 | 9 | // Fill in the Middle Completion API 10 | const FIM_API: &str = "/fim/completions"; 11 | 12 | #[derive(Debug, Serialize, Deserialize, Default)] 13 | pub struct MessageBody { 14 | /// ID of the model to use. You can use the [List Available Models API](https://docs.mistral.ai/api/#tag/models/operation/list_models_v1_models_get) to see all of your available models, or see our [Model overview](https://docs.mistral.ai/models) for model descriptions. 15 | pub model: String, 16 | /// What sampling temperature to use, between 0.0 and 1.0. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend altering this or top_p but not both. 17 | #[serde(skip_serializing_if = "Option::is_none")] 18 | pub temperature: Option, 19 | /// Nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. We generally recommend altering this or temperature but not both. 
20 | #[serde(skip_serializing_if = "Option::is_none")] 21 | pub top_p: Option, 22 | /// The maximum number of tokens to generate in the completion. The token count of your prompt plus max_tokens cannot exceed the model's context length. 23 | #[serde(skip_serializing_if = "Option::is_none")] 24 | pub max_tokens: Option, 25 | /// The minimum number of tokens to generate in the completion. 26 | #[serde(skip_serializing_if = "Option::is_none")] 27 | pub min_tokens: Option, 28 | /// Stop generation if this token is detected. Or if one of these tokens is detected when providing an array 29 | #[serde(skip_serializing_if = "Option::is_none")] 30 | pub stop: Option>, 31 | /// Whether to stream back partial progress. If set, tokens will be sent as data-only server-side events as they become available, with the stream terminated by a data: [DONE] message. Otherwise, the server will hold the request open until the timeout or until completion, with the response containing the full result as JSON. 32 | #[serde(skip_serializing_if = "Option::is_none")] 33 | pub stream: Option, 34 | /// The seed to use for random sampling. If set, different calls will generate deterministic results. 35 | #[serde(skip_serializing_if = "Option::is_none")] 36 | pub random_seed: Option, 37 | /// The text/code to complete. 38 | pub prompt: String, 39 | /// Optional text/code that adds more context for the model. When given a prompt and a suffix the model will fill what is between them. When suffix is not provided, the model will simply execute completion starting with prompt. 
40 | #[serde(skip_serializing_if = "Option::is_none")] 41 | pub suffix: Option, 42 | } 43 | 44 | impl MessageBody { 45 | /// Creates a new `MessageBody` 46 | #[must_use] 47 | pub fn new(model: &str, prompt: String, suffix: Option) -> Self { 48 | Self { 49 | model: model.into(), 50 | prompt, 51 | suffix, 52 | stream: Some(true), 53 | ..Default::default() 54 | } 55 | } 56 | } 57 | 58 | #[derive(Serialize, Deserialize, Debug)] 59 | pub struct FimCompletionsChunk { 60 | pub id: String, 61 | pub object: String, 62 | pub created: u64, 63 | pub model: String, 64 | pub choices: Vec, 65 | } 66 | 67 | #[derive(Serialize, Deserialize, Debug)] 68 | pub struct Choice { 69 | pub index: u32, 70 | pub delta: Delta, 71 | pub finish_reason: Option, 72 | pub logprobs: Option, 73 | } 74 | 75 | #[derive(Serialize, Deserialize, Debug)] 76 | pub struct Delta { 77 | pub content: String, 78 | } 79 | 80 | pub use crate::mistral::Auth; 81 | 82 | #[derive(Debug, Clone)] 83 | pub struct Client { 84 | pub auth: Auth, 85 | pub api_url: String, 86 | } 87 | 88 | impl Client { 89 | pub fn new(auth: Auth, api_url: impl Into) -> Self { 90 | Self { 91 | auth, 92 | api_url: api_url.into(), 93 | } 94 | } 95 | } 96 | 97 | impl Client { 98 | pub fn delta<'a>( 99 | &'a self, 100 | message_body: &'a MessageBody, 101 | ) -> Result> + 'a, Error> { 102 | log::debug!("message_body: {:#?}", message_body); 103 | 104 | let request_body = match serde_json::to_value(message_body) { 105 | Ok(body) => body, 106 | Err(e) => return Err(Error::Serde(e)), 107 | }; 108 | log::debug!("request_body: {:#?}", request_body); 109 | 110 | let original_stream = match self.post_stream(FIM_API.to_string(), request_body) { 111 | Ok(stream) => stream, 112 | Err(e) => return Err(Error::EventsourceClient(e)), 113 | }; 114 | 115 | let mapped_stream = original_stream.map(|item| { 116 | if item.is_err() { 117 | return Err(Error::EventsourceClient(item.err().unwrap())); 118 | } 119 | item.map(|event| match event { 120 | 
es::SSE::Connected(_) => String::default(), 121 | es::SSE::Event(ev) => match serde_json::from_str::(&ev.data) { 122 | Ok(chunk) => { 123 | if chunk.choices.is_empty() { 124 | String::default() 125 | } else { 126 | chunk.choices.first().unwrap().delta.content.clone() 127 | } 128 | } 129 | Err(_) => String::default(), 130 | }, 131 | es::SSE::Comment(comment) => { 132 | log::debug!("Comment: {:#?}", comment); 133 | String::default() 134 | } 135 | }) 136 | .map_err(Error::from) 137 | }); 138 | 139 | Ok(mapped_stream) 140 | } 141 | } 142 | 143 | impl Requests for Client { 144 | fn post_stream( 145 | &self, 146 | sub_url: String, 147 | body: Json, 148 | ) -> Result>, es::Error> { 149 | let authorization: &str = &format!("Bearer {}", self.auth.api_key); 150 | 151 | let client = es::ClientBuilder::for_url(&(self.api_url.clone() + &sub_url))? 152 | .header("content-type", "application/json")? 153 | .header("authorization", authorization)? 154 | .method("POST".into()) 155 | .body(body.to_string()) 156 | .reconnect( 157 | es::ReconnectOptions::reconnect(true) 158 | .retry_initial(false) 159 | .delay(Duration::from_secs(1)) 160 | .backoff_factor(2) 161 | .delay_max(Duration::from_secs(60)) 162 | .build(), 163 | ) 164 | .build(); 165 | 166 | Ok(crate::requests::tail(&client)) 167 | } 168 | } 169 | -------------------------------------------------------------------------------- /lib/es_stream/src/requests.rs: -------------------------------------------------------------------------------- 1 | use eventsource_client as es; 2 | use futures::stream::Stream; 3 | 4 | pub type Json = serde_json::Value; 5 | 6 | pub trait Requests { 7 | /// # Errors 8 | /// 9 | /// Will return `Err` if: 10 | /// 11 | /// - The headers can't be loaded to the request. 12 | /// - The body can't be loaded to the request. 13 | /// - The POST request to start the stream fails. 14 | /// - The stream connection fails to reconnect. 15 | /// - A stream can't be created. 
16 | fn post_stream( 17 | &self, 18 | sub_url: String, 19 | body: Json, 20 | ) -> Result>, es::Error>; 21 | } 22 | 23 | pub(crate) fn tail(client: &impl es::Client) -> impl Stream> { 24 | client.stream() 25 | } 26 | -------------------------------------------------------------------------------- /lib/gpt_tokenizer/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "gpt_tokenizer" 3 | version = "0.1.0" 4 | description = "Rust BPE Encoder Decoder (Tokenizer) for GPT-2 / GPT-3" 5 | license = "MIT" 6 | readme = "README.md" 7 | repository = "https://github.com/cloudbridgeuy/a/tree/main/lib/tokenizer" 8 | edition = "2021" 9 | keywords = ["bpe", "tokenizer", "openai", "gpt3", "chatgpt"] 10 | 11 | [dependencies] 12 | serde_json = "1.0.93" 13 | regex = "1.7.1" 14 | 15 | -------------------------------------------------------------------------------- /lib/gpt_tokenizer/README.md: -------------------------------------------------------------------------------- 1 | # GPT-Tokenizer 2 | 3 | An implementation of the GPT-3 tokenizer created by converting the [`GPT-3-Encoder`](https://www.npmjs.com/package/gpt-3-encoder) 4 | JavaScript package to Rust (with the help of ChatGPT-4). You can use it to estimate the number of 5 | tokens that your prompt would approximately consume. You can also create your own custom `encoding` and 6 | `decoding` functions by providing your own `encoder.json` and `vocab.bpe` files. 7 | 8 | > As a rule of thumb, OpenAI suggest that 100 tokens equal 75 words. 9 | 10 | See how it works against the tokenizer published by OpenAI: 11 | 12 | [https://platform.openai.com/tokenizer](https://platform.openai.com/tokenizer) 13 | 14 | ```rust 15 | use tokenizer::DefaultTokenizer; 16 | 17 | fn main() { 18 | let tokenizer = DefaultTokenizer::new(); 19 | 20 | let text = r#"I'Many words map to one token, but some don't: indivisible. 
21 | 22 | Unicode characters like emojis may be split into many tokens containing the underlying bytes: 🤚🏾 23 | 24 | Sequences of characters commonly found next to each other may be grouped together: 1234567890"#; 25 | 26 | let encoded = &tokenizer.encode(text); 27 | let decoded = &tokenizer.decode(encoded); 28 | 29 | println!("Original text: {}", text); 30 | println!("Encoded text: {:#?}", encoded); 31 | println!("Decoded text: {}", decoded); 32 | 33 | println!("Text size: {}", text.len()); 34 | println!("Words: {}", text.split(" ").count()); 35 | println!("Rule of Thumb: {}", text.split(" ").count() * 4 / 3); 36 | println!("Tokens: {}", encoded.len()); 37 | } 38 | ``` 39 | 40 | See the [./examples](./examples) directory to see more examples of how to use it. 41 | -------------------------------------------------------------------------------- /lib/gpt_tokenizer/examples/custom_tokenizer.rs: -------------------------------------------------------------------------------- 1 | //! You need to load the `encoder.json` and `vocab.bpe` in order to use this crate. 2 | //! 3 | //! A default `encoder.json` and `vocab.bpe` comes included in the library through 4 | //! the `ENCODER_JSON` and `VOCAB_BPE` constants respectively. You may opt-out of 5 | //! these variables by bringing your own files. 6 | //! 7 | //! The following example shows how you need to process these files in order to create 8 | //! your `encode` and `decode` functions. 
9 | 10 | use std::collections::HashMap; 11 | use std::iter::FromIterator; 12 | 13 | use gpt_tokenizer::{bpe_ranks, bytes_to_unicode, decode, encode, ENCODER_JSON, VOCAB_BPE}; 14 | 15 | fn main() { 16 | let encoder: HashMap = serde_json::from_str(ENCODER_JSON).unwrap(); 17 | let decoder: HashMap = 18 | HashMap::from_iter(encoder.clone().into_iter().map(|(k, v)| (v, k))); 19 | 20 | let lines: Vec = VOCAB_BPE.lines().map(|line| line.to_owned()).collect(); 21 | let bpe_ranks = bpe_ranks(&lines); 22 | 23 | let byte_encoder = bytes_to_unicode(); 24 | let byte_decoder: HashMap = 25 | HashMap::from_iter(byte_encoder.clone().into_iter().map(|(k, v)| (v, k))); 26 | 27 | let text = r#"I'Many words map to one token, but some don't: indivisible. 28 | 29 | Unicode characters like emojis may be split into many tokens containing the underlying bytes: 🤚🏾 30 | 31 | Sequences of characters commonly found next to each other may be grouped together: 1234567890"#; 32 | let encoded = encode(text, &bpe_ranks, &encoder); 33 | let decoded = decode(&encoded, &decoder, &byte_decoder); 34 | 35 | println!("Byte encoder: {:?}", byte_encoder); 36 | // println!("BPE Rank: {:?}", bpe_ranks); 37 | 38 | println!("Original text: {}", text); 39 | println!("Encoded text: {:#?}", encoded); 40 | println!("Decoded text: {}", decoded); 41 | 42 | println!("Text size: {}", text.len()); 43 | println!("Words: {}", text.split(' ').count()); 44 | println!("Rule of Thumb: {}", text.split(' ').count() * 4 / 3); 45 | println!("Tokens: {}", encoded.len()); 46 | } 47 | -------------------------------------------------------------------------------- /lib/gpt_tokenizer/examples/default_tokenizer.rs: -------------------------------------------------------------------------------- 1 | //! The library comes with a DefaultTokenizer, which is a struct that loads the internal 2 | //! `encoder.json` and `vocab.bpe`. It simplifies the creation of the `encode` and `decode` 3 | //! functions. 
This is specially useful when you just want to estimate the number of tokens 4 | //! your prompt will consume. 5 | //! 6 | //! > As a rule of thumb, OpenAI suggest that 100 tokens equal 75 words. 7 | use gpt_tokenizer::Default; 8 | 9 | fn main() { 10 | let tokenizer = Default::new(); 11 | 12 | let text = r#"I'Many words map to one token, but some don't: indivisible. 13 | 14 | Unicode characters like emojis may be split into many tokens containing the underlying bytes: 🤚🏾 15 | 16 | Sequences of characters commonly found next to each other may be grouped together: 1234567890"#; 17 | 18 | let encoded = &tokenizer.encode(text); 19 | let decoded = &tokenizer.decode(encoded); 20 | 21 | println!("Original text: {}", text); 22 | println!("Encoded text: {:#?}", encoded); 23 | println!("Decoded text: {}", decoded); 24 | 25 | println!("Text size: {}", text.len()); 26 | println!("Words: {}", text.split(' ').count()); 27 | println!("Rule of Thumb: {}", text.split(' ').count() * 4 / 3); 28 | println!("Tokens: {}", encoded.len()); 29 | } 30 | -------------------------------------------------------------------------------- /lib/gpt_tokenizer/examples/regex.rs: -------------------------------------------------------------------------------- 1 | use regex::Regex; 2 | 3 | fn main() { 4 | // let re1 = Regex::new(r#"s[|']t|[|']re|[|']ve|[|']m|[|']ll|[|']d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"#).unwrap(); 5 | // let re2 = Regex::new(r#"s[|']t|[|']re|[|']ve|[|']m|[|']ll|[|']d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+$|\s+"#).unwrap(); 6 | // let re1 = Regex::new(r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+").unwrap(); 7 | let re1 = Regex::new(r"'s|'t|'re|'ve|'m|'ll|'d| ?\pL+| ?\pN+| ?[^\s\pL\pN]+|\s+$|\s+").unwrap(); 8 | 9 | let text = "I'm a string with some contractions like I'm, you're, and we'll, as well as some numbers like 123 and some punctuation like !?"; 10 | 11 | println!("Parsed text with re1"); 12 | for cap in 
re1.captures_iter(text) { 13 | println!("{:?}", cap); 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /lib/gpt_tokenizer/src/lib.rs: -------------------------------------------------------------------------------- 1 | //! This file includes code which was modified from https://github.com/openai/gpt-2 2 | //! and https://github.com/latitudegames/GPT-3-Encoder/blob/master/Encoder.js 3 | //! It was converted from JavaScript with the help of ChatGPT 4.0 4 | 5 | use std::collections::{HashMap, HashSet}; 6 | 7 | use regex::Regex; 8 | 9 | pub const ENCODER_JSON: &str = include_str!("encoder.json"); 10 | pub const VOCAB_BPE: &str = include_str!("vocab.bpe"); 11 | 12 | /// Default tokenizer that uses embedded encoder and vocab values to create the `encode` and 13 | /// `decode` functions. 14 | #[derive(Default)] 15 | pub struct Default { 16 | encoder: HashMap, 17 | decoder: HashMap, 18 | bpe_ranks: HashMap, usize>, 19 | byte_decoder: HashMap, 20 | } 21 | 22 | impl Default { 23 | /// Creates a new DefaultTokenizer. 24 | pub fn new() -> Self { 25 | let byte_encoder = bytes_to_unicode(); 26 | let lines: Vec = VOCAB_BPE.lines().map(|line| line.to_owned()).collect(); 27 | let encoder: HashMap = serde_json::from_str(ENCODER_JSON).unwrap(); 28 | 29 | Self { 30 | encoder: encoder.clone(), 31 | decoder: HashMap::from_iter(encoder.into_iter().map(|(k, v)| (v, k))), 32 | bpe_ranks: bpe_ranks(&lines), 33 | byte_decoder: HashMap::from_iter(byte_encoder.into_iter().map(|(k, v)| (v, k))), 34 | } 35 | } 36 | 37 | pub fn encode(&self, text: &str) -> Vec { 38 | encode(text, &self.bpe_ranks, &self.encoder) 39 | } 40 | 41 | pub fn decode(&self, encoded: &[u32]) -> String { 42 | decode(encoded, &self.decoder, &self.byte_decoder) 43 | } 44 | } 45 | 46 | /// Constructs the `bpe_ranks` hashmap from a `vocab.bpe` file provides as a list of lines. 
47 | pub fn bpe_ranks(lines: &[String]) -> HashMap, usize> { 48 | let bpe_merges: Vec> = lines 49 | .iter() 50 | .map(|x| x.split_whitespace().map(|s| s.to_owned()).collect()) 51 | .collect(); 52 | 53 | dict_zip(&bpe_merges, &(0..bpe_merges.len()).collect::>()) 54 | } 55 | 56 | /// Constructs a bytes to unicode HashMap. 57 | pub fn bytes_to_unicode() -> HashMap { 58 | let mut bs = range(ord('!'), ord('~') + 1) 59 | .iter() 60 | .chain(range(ord('¡'), ord('¬') + 1).iter()) 61 | .chain(range(ord('®'), ord('ÿ') + 1).iter()) 62 | .cloned() 63 | .collect::>(); 64 | 65 | let mut cs = bs.clone(); 66 | let mut n = 0; 67 | for b in 0..(2_u32.pow(8)) { 68 | if !bs.contains(&b) { 69 | bs.push(b); 70 | cs.push(2_u32.pow(8) + n); 71 | n += 1; 72 | } 73 | } 74 | 75 | let cs: Vec = cs.into_iter().map(chr).collect(); 76 | 77 | dict_zip(&bs, &cs) 78 | } 79 | 80 | /// Encodes a string using a custom bpe_ranks and encoder HashMaps. 81 | pub fn encode( 82 | text: &str, 83 | bpe_ranks: &HashMap, usize>, 84 | encoder: &HashMap, 85 | ) -> Vec { 86 | // I had to update this regex to makr it work in Rust, given that Rust doesn't support 87 | // look-around assertions. 88 | // 89 | // - `'s|'t|'re|'ve|'m|'ll|'d: This part of the regex matches common contractions in English, such as 's, 't, 're, 've, 'm, 'll, and 'd. 90 | // - `?\p{L}+: This part matches Unicode letters (L) with an optional space (?) before them. The plus sign (+) indicates one or more occurrences of the preceding element. 91 | // - `?\p{N}+: This part matches Unicode numbers (N) with an optional space (?) before them. The plus sign (+) indicates one or more occurrences of the preceding element. 92 | // - `?[^\s\p{L}\p{N}]+: This part matches any character that is not a whitespace (\s), letter (\p{L}), or number (\p{N}) with an optional space (?) before them. The plus sign (+) indicates one or more occurrences of the preceding element. 93 | // - `\s+: This part matches one or more whitespace characters (\s+). 
94 | let pat = Regex::new(r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+$|\s+") 95 | .unwrap(); 96 | let mut bpe_tokens = Vec::new(); 97 | 98 | for token in pat.find_iter(text) { 99 | let token = token.as_str(); 100 | let token = encode_str(token); 101 | let token = token 102 | .into_iter() 103 | .map(|x| chr(x.parse::().unwrap()).to_string()) 104 | .collect::>() 105 | .join(""); 106 | 107 | let new_tokens: Vec = bpe(&token, bpe_ranks) 108 | .split_whitespace() 109 | .map(|x| encoder[x]) 110 | .collect(); 111 | bpe_tokens.extend(new_tokens); 112 | } 113 | 114 | bpe_tokens 115 | } 116 | 117 | /// Decodes an encoded string using a custom decoder and byte decoder created from the encoder that 118 | /// encoded the original string. 119 | pub fn decode( 120 | tokens: &[u32], 121 | decoder: &HashMap, 122 | byte_decoder: &HashMap, 123 | ) -> String { 124 | let text: String = tokens.iter().map(|x| decoder[x].as_str()).collect(); 125 | let text: String = text 126 | .chars() 127 | .map(|x| chr(*byte_decoder.get(&x).unwrap())) 128 | .collect(); 129 | 130 | text 131 | } 132 | 133 | fn range(x: u32, y: u32) -> Vec { 134 | (x..y).collect() 135 | } 136 | 137 | fn ord(ch: char) -> u32 { 138 | ch as u32 139 | } 140 | 141 | fn chr(code: u32) -> char { 142 | char::from_u32(code).unwrap() 143 | } 144 | 145 | fn encode_str(s: &str) -> Vec { 146 | s.as_bytes().iter().map(|b| b.to_string()).collect() 147 | } 148 | 149 | fn dict_zip(x: &[T], y: &[U]) -> HashMap { 150 | let mut map = HashMap::new(); 151 | for (i, key) in x.iter().enumerate() { 152 | map.insert(key.clone(), y[i].clone()); 153 | } 154 | map 155 | } 156 | 157 | fn get_pairs(word: &[String]) -> HashSet> { 158 | let mut pairs = HashSet::new(); 159 | let mut prev_char = &word[0]; 160 | for ch in word.iter().skip(1) { 161 | pairs.insert(vec![prev_char.clone(), ch.clone()]); 162 | prev_char = ch; 163 | } 164 | pairs 165 | } 166 | 167 | fn bpe(token: &str, bpe_ranks: &HashMap, usize>) -> String { 168 | let 
byte_encoder = bytes_to_unicode(); 169 | 170 | let mut word = token 171 | .chars() 172 | .map(|c| byte_encoder[&(c as u32)].to_string()) 173 | .collect::>(); 174 | let mut pairs = get_pairs(&word); 175 | 176 | while !pairs.is_empty() { 177 | let min_pair_rank = pairs 178 | .iter() 179 | .map(|pair| bpe_ranks.get(pair).copied().unwrap_or(usize::MAX)) 180 | .min() 181 | .unwrap(); 182 | let bigram = pairs 183 | .iter() 184 | .find(|pair| bpe_ranks.get(*pair).copied().unwrap_or(usize::MAX) == min_pair_rank) 185 | .cloned() 186 | .unwrap(); 187 | 188 | if !bpe_ranks.contains_key(&bigram) { 189 | break; 190 | } 191 | 192 | let first = &bigram[0]; 193 | let second = &bigram[1]; 194 | let mut new_word = Vec::new(); 195 | let mut i = 0; 196 | 197 | while i < word.len() { 198 | let j = word[i..].iter().position(|x| x == first); 199 | if let Some(j) = j { 200 | new_word.extend_from_slice(&word[i..i + j]); 201 | i += j; 202 | 203 | if i < word.len() - 1 && &word[i + 1] == second { 204 | new_word.push(format!("{}{}", first, second)); 205 | i += 2; 206 | } else { 207 | new_word.push(word[i].clone()); 208 | i += 1; 209 | } 210 | } else { 211 | new_word.extend_from_slice(&word[i..]); 212 | break; 213 | } 214 | } 215 | 216 | word = new_word; 217 | pairs = get_pairs(&word); 218 | } 219 | 220 | word.join(" ") 221 | } 222 | -------------------------------------------------------------------------------- /xtask/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "xtask" 3 | version = "0.1.0" 4 | publish = false 5 | license = "MIT" 6 | edition = "2021" 7 | 8 | [dependencies] 9 | clap = { version = "4.1.8", features = ["derive"] } # A simple to use, efficient, and full-featured Command Line Argument Parser 10 | duct = "0.13.6" # A library for running child processes 11 | chrono = "0.4.24" # Date and time library for Rust 12 | bunt = "0.2.8" # Simple macros to write colored and formatted text to a terminal. 
Based on `termcolor`, thus als… 13 | -------------------------------------------------------------------------------- /xtask/src/cli.rs: -------------------------------------------------------------------------------- 1 | use clap::{Args, Parser, Subcommand}; 2 | 3 | #[derive(Debug, Parser)] 4 | #[command(name = "xtasks")] 5 | #[command(about = "Run project tasks using rust instead of scripts")] 6 | pub struct App { 7 | #[command(subcommand)] 8 | pub command: Option, 9 | } 10 | 11 | #[derive(Debug, Subcommand)] 12 | pub enum Commands { 13 | /// Runs one of the project binaries 14 | Run(RunArgs), 15 | /// Builds one of the project binaries 16 | Build(BuildArgs), 17 | /// Builds a binary and installs it at the given path 18 | Install(InstallArgs), 19 | /// Publishes a package to crates.io 20 | Publish(PublishArgs), 21 | /// Creates a new GitHub release 22 | Github(GithubArgs), 23 | } 24 | 25 | #[derive(Args, Debug)] 26 | pub struct RunArgs { 27 | /// Name of the binary to run. 28 | #[arg(short, long)] 29 | pub name: String, 30 | 31 | /// Arguments to pass to the binary. 32 | pub args: Option>, 33 | } 34 | 35 | #[derive(Args, Debug)] 36 | pub struct BuildArgs { 37 | /// Name of the binary to run. 38 | #[arg(short, long)] 39 | pub name: String, 40 | 41 | /// Release flag 42 | #[arg(short, long)] 43 | pub release: bool, 44 | } 45 | 46 | #[derive(Args, Debug)] 47 | pub struct PublishArgs { 48 | /// Name of the library to publish. 49 | #[arg(short, long)] 50 | pub name: String, 51 | 52 | /// Dry run flag. 53 | #[arg(short, long)] 54 | pub dry_run: bool, 55 | } 56 | 57 | #[derive(Args, Debug)] 58 | pub struct InstallArgs { 59 | /// Name of the binary to run. 60 | #[arg(short, long)] 61 | pub name: String, 62 | 63 | /// Path to install the binary to. 64 | #[arg(short, long)] 65 | pub path: String, 66 | } 67 | 68 | #[derive(Args, Debug)] 69 | pub struct GithubArgs { 70 | /// Name of the binary to run. 
71 | #[arg(short, long)] 72 | pub name: String, 73 | } 74 | -------------------------------------------------------------------------------- /xtask/src/main.rs: -------------------------------------------------------------------------------- 1 | //! See 2 | //! 3 | //! This binary defines various auxiliary build commands, which are not 4 | //! expressible with just `cargo`. 5 | //! 6 | //! The binary is integrated into the `cargo` command line by using an 7 | //! alias in `.cargo/config`. 8 | 9 | mod cli; 10 | mod scripts; 11 | mod utils; 12 | 13 | use clap::Parser; 14 | 15 | fn main() -> Result<(), Box> { 16 | let cli = cli::App::parse(); 17 | 18 | match &cli.command { 19 | Some(command) => match command { 20 | cli::Commands::Run(args) => scripts::run(args), 21 | cli::Commands::Build(args) => scripts::build(args), 22 | cli::Commands::Publish(args) => scripts::publish(args), 23 | cli::Commands::Github(args) => scripts::github(args), 24 | cli::Commands::Install(args) => scripts::install(args), 25 | }, 26 | None => { 27 | println!("No command specified."); 28 | std::process::exit(1); 29 | } 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /xtask/src/scripts.rs: -------------------------------------------------------------------------------- 1 | use crate::cli; 2 | use crate::utils; 3 | use bunt::println; 4 | use duct::cmd; 5 | use std::error::Error; 6 | 7 | pub fn run(args: &cli::RunArgs) -> Result<(), Box> { 8 | let mut arguments = vec!["run", "--bin", &args.name]; 9 | 10 | match &args.args { 11 | Some(args) => arguments.extend(args.iter().map(|s| s.as_str())), 12 | None => {} 13 | } 14 | 15 | cmd("cargo", arguments).read()?; 16 | 17 | Ok(()) 18 | } 19 | 20 | pub fn build(args: &cli::BuildArgs) -> Result<(), Box> { 21 | let mut arguments = vec!["build", "--bin", &args.name]; 22 | 23 | if args.release { 24 | arguments.push("--release"); 25 | } 26 | 27 | cmd("cargo", arguments).read()?; 28 | 29 | Ok(()) 30 | } 31 | 32 
| fn release(name: &str) -> Result<(), Box> { 33 | let buid_args = cli::BuildArgs { 34 | name: name.to_string(), 35 | release: true, 36 | }; 37 | 38 | build(&buid_args)?; 39 | 40 | Ok(()) 41 | } 42 | 43 | pub fn install(args: &cli::InstallArgs) -> Result<(), Box> { 44 | release(&args.name)?; 45 | 46 | let target_path = "target/release/".to_string() + &args.name; 47 | 48 | cmd!("cp", &target_path, &args.path).run()?; 49 | cmd!("chmod", "+x", &args.path).run()?; 50 | 51 | Ok(()) 52 | } 53 | 54 | pub fn publish(args: &cli::PublishArgs) -> Result<(), Box> { 55 | let mut arguments = vec!["publish", "--package", &args.name]; 56 | 57 | if args.dry_run { 58 | arguments.push("--dry-run"); 59 | } 60 | 61 | cmd("cargo", arguments).read()?; 62 | 63 | Ok(()) 64 | } 65 | 66 | pub fn github(args: &cli::GithubArgs) -> Result<(), Box> { 67 | release(&args.name)?; 68 | 69 | let version = utils::create_tag(); 70 | let target_path = "target/release/".to_string() + &args.name; 71 | let notes = "Release notes for ".to_string() + &version; 72 | 73 | println!("{$magenta}Creating {[yellow]} tag{/$}", &version); 74 | cmd!("git", "tag", "-a", &version, "-m", &version).run()?; 75 | println!("{$magenta}Pusing {[yellow]} tag{/$}", &version); 76 | cmd!("git", "push", "origin", &version).run()?; 77 | println!("{$magenta}Creating {[yellow]} release{/$}", &version); 78 | cmd!("gh", "release", "create", &version, "--title", &version, "--notes", ¬es).run()?; 79 | println!( 80 | "{$magenta}Uploading {[yellow]} release binary{/$}", 81 | &version 82 | ); 83 | cmd!( 84 | "gh", 85 | "release", 86 | "upload", 87 | &version, 88 | &target_path, 89 | "--clobber" 90 | ) 91 | .run()?; 92 | 93 | Ok(()) 94 | } 95 | -------------------------------------------------------------------------------- /xtask/src/utils.rs: -------------------------------------------------------------------------------- 1 | use chrono::{DateTime, Utc}; 2 | 3 | pub fn create_tag() -> String { 4 | let now: DateTime = Utc::now(); 5 | 
now.format("%Y-%m-%dT%H%M").to_string() 6 | } 7 | --------------------------------------------------------------------------------