├── llm_prompts ├── src │ ├── chat │ │ ├── mod.rs │ │ └── broker.rs │ ├── bin │ │ ├── main.rs │ │ ├── mistral_inline_edit.rs │ │ └── mistral_reranking.rs │ ├── reranking │ │ ├── mod.rs │ │ ├── openai.rs │ │ ├── mistral.rs │ │ ├── types.rs │ │ └── broker.rs │ ├── in_line_edit │ │ ├── storage │ │ │ └── mod.rs │ │ ├── mod.rs │ │ ├── doc_helpers.rs │ │ ├── broker.rs │ │ ├── types.rs │ │ ├── mistral.rs │ │ └── openai.rs │ ├── lib.rs │ └── answer_model │ │ └── mod.rs └── Cargo.toml ├── llm_client ├── src │ ├── tokenizer │ │ ├── mod.rs │ │ └── tokenizer.rs │ ├── lib.rs │ ├── format │ │ ├── mod.rs │ │ ├── tokenizer_config │ │ │ ├── mixtral.json │ │ │ ├── mistral.json │ │ │ └── deepseekcoder.json │ │ ├── types.rs │ │ ├── deepseekcoder.rs │ │ ├── mistral.rs │ │ └── mixtral.rs │ ├── clients │ │ ├── mod.rs │ │ ├── ollama.rs │ │ ├── codestory.rs │ │ ├── lmstudio.rs │ │ ├── togetherai.rs │ │ ├── openai.rs │ │ └── types.rs │ ├── config.rs │ ├── bin │ │ ├── codestory_provider.rs │ │ ├── openai_llm.rs │ │ └── mixtral_test.rs │ ├── sqlite.rs │ ├── provider.rs │ └── broker.rs ├── ai.codestory.llm_client.db ├── migrations │ ├── 20240118143148_llm_data_chat_row.sql │ └── 20240117185650_llm_data.sql ├── .sqlx │ ├── query-d37e4c0d93b0f19a43c5a44ac3632276bfe4bc92c8e22b52dda337d2be641b1c.json │ └── query-08c2149206ecf5f799ce8526b2fdc75a6e9dbd96776b176fab0b8ef5328a94a2.json └── Cargo.toml ├── Cargo.toml ├── .gitignore ├── README.md └── LICENSE /llm_prompts/src/chat/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod broker; 2 | -------------------------------------------------------------------------------- /llm_client/src/tokenizer/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod tokenizer; 2 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [workspace] 2 | members = [ 3 | "llm_client", 4 | "llm_prompts", 5 | ] -------------------------------------------------------------------------------- /llm_prompts/src/bin/main.rs: -------------------------------------------------------------------------------- 1 | fn main() { 2 | println!("Hello, world!"); 3 | } 4 | -------------------------------------------------------------------------------- /llm_prompts/src/reranking/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod broker; 2 | mod mistral; 3 | mod openai; 4 | pub mod types; 5 | -------------------------------------------------------------------------------- /llm_prompts/src/in_line_edit/storage/mod.rs: -------------------------------------------------------------------------------- 1 | //! 
Storage layer that persists data to a SQLite database backend 2 | -------------------------------------------------------------------------------- /llm_prompts/src/lib.rs: -------------------------------------------------------------------------------- 1 | pub mod answer_model; 2 | pub mod chat; 3 | pub mod in_line_edit; 4 | pub mod reranking; 5 | -------------------------------------------------------------------------------- /llm_client/ai.codestory.llm_client.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codestoryai/prompts/HEAD/llm_client/ai.codestory.llm_client.db -------------------------------------------------------------------------------- /llm_prompts/src/in_line_edit/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod broker; 2 | mod doc_helpers; 3 | pub mod mistral; 4 | pub mod openai; 5 | pub mod types; 6 | -------------------------------------------------------------------------------- /llm_client/migrations/20240118143148_llm_data_chat_row.sql: -------------------------------------------------------------------------------- 1 | -- Add migration script here 2 | ALTER TABLE llm_data ADD COLUMN chat_messages TEXT; -------------------------------------------------------------------------------- /llm_client/src/lib.rs: -------------------------------------------------------------------------------- 1 | pub mod broker; 2 | pub mod clients; 3 | pub mod config; 4 | pub mod format; 5 | pub mod provider; 6 | mod sqlite; 7 | pub mod tokenizer; 8 | -------------------------------------------------------------------------------- /llm_client/src/format/mod.rs: -------------------------------------------------------------------------------- 1 | //! Crate for formatting prompts for different LLMs 2 | 3 | pub mod deepseekcoder; 4 | pub mod mistral; 5 | pub mod mixtral; 6 | pub mod types; 7 | -------------------------------------------------------------------------------- /llm_client/src/clients/mod.rs: -------------------------------------------------------------------------------- 1 | //! Exposes all the clients we are interested in, along with a standardized format, 2 | //! so we can be happy while the provider client takes care of the details 3 | 4 | pub mod codestory; 5 | pub mod lmstudio; 6 | pub mod ollama; 7 | pub mod openai; 8 | pub mod togetherai; 9 | pub mod types; 10 | -------------------------------------------------------------------------------- /llm_client/src/config.rs: -------------------------------------------------------------------------------- 1 | //!
The configuration which will be passed to the llm broker 2 | 3 | use std::path::PathBuf; 4 | 5 | pub struct LLMBrokerConfiguration { 6 | pub data_dir: PathBuf, 7 | } 8 | 9 | impl LLMBrokerConfiguration { 10 | pub fn new(data_dir: PathBuf) -> Self { 11 | Self { data_dir } 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /llm_client/migrations/20240117185650_llm_data.sql: -------------------------------------------------------------------------------- 1 | -- Add migration script here 2 | CREATE TABLE llm_data ( 3 | id INTEGER PRIMARY KEY AUTOINCREMENT, 4 | created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, 5 | prompt TEXT, 6 | response TEXT, 7 | llm_type TEXT, 8 | temperature FLOAT, 9 | max_tokens INTEGER, 10 | event_type TEXT 11 | ); -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Generated by Cargo 2 | # will have compiled files and executables 3 | debug/ 4 | target/ 5 | 6 | # Remove Cargo.lock from gitignore if creating an executable, leave it for libraries 7 | # More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html 8 | Cargo.lock 9 | 10 | # These are backup files generated by rustfmt 11 | **/*.rs.bk 12 | 13 | # MSVC Windows builds of rustc generate these, which store debugging information 14 | *.pdb 15 | -------------------------------------------------------------------------------- /llm_client/.sqlx/query-d37e4c0d93b0f19a43c5a44ac3632276bfe4bc92c8e22b52dda337d2be641b1c.json: -------------------------------------------------------------------------------- 1 | { 2 | "db_name": "SQLite", 3 | "query": "\n INSERT INTO llm_data (prompt, response, llm_type, temperature, max_tokens, event_type)\n VALUES ($1, $2, $3, $4, $5, $6)\n ", 4 | "describe": { 5 | "columns": [], 6 | "parameters": { 7 | "Right": 6 8 | }, 9 | "nullable": [] 10 | }, 11 | "hash": "d37e4c0d93b0f19a43c5a44ac3632276bfe4bc92c8e22b52dda337d2be641b1c" 12 | } 13 | -------------------------------------------------------------------------------- /llm_client/.sqlx/query-08c2149206ecf5f799ce8526b2fdc75a6e9dbd96776b176fab0b8ef5328a94a2.json: -------------------------------------------------------------------------------- 1 | { 2 | "db_name": "SQLite", 3 | "query": "\n INSERT INTO llm_data (chat_messages, response, llm_type, temperature, max_tokens, event_type)\n VALUES ($1, $2, $3, $4, $5, $6)\n ", 4 | "describe": { 5 | "columns": [], 6 | "parameters": { 7 | "Right": 6 8 | }, 9 | "nullable": [] 10 | }, 11 | "hash": "08c2149206ecf5f799ce8526b2fdc75a6e9dbd96776b176fab0b8ef5328a94a2" 12 | } 13 | -------------------------------------------------------------------------------- /llm_prompts/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "llm_prompts" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 7 | 8 | [dependencies] 9 | llm_client = { path = "../llm_client" } 10 | thiserror = "1.0.56" 11 | futures = "0.3.28" 12 | serde_json = "1.0.107" 13 | serde = { version = "1.0.188", features = ["derive"] } 14 | sqlx = { version = "0.7.2", features = ["sqlite", "migrate", "runtime-tokio-rustls", "chrono", "uuid"]} 15 | tokio = { version = "1.32.0", features = ["full"] } -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | [![](https://dcbadge.vercel.app/api/server/JEjtaVwyeX)](https://discord.gg/JEjtaVwyeX) 2 | 3 | 4 | 5 | ``` 6 | 7 | ██████╗ ██████╗ ██████╗ ███████╗███████╗████████╗ ██████╗ ██████╗ ██╗ ██╗ 8 | ██╔════╝██╔═══██╗██╔══██╗██╔════╝██╔════╝╚══██╔══╝██╔═══██╗██╔══██╗╚██╗ ██╔╝ 9 | ██║ ██║ ██║██║ ██║█████╗ ███████╗ ██║ ██║ ██║██████╔╝ ╚████╔╝ 10 | ██║ ██║ ██║██║ ██║██╔══╝ ╚════██║ ██║ ██║ ██║██╔══██╗ ╚██╔╝ 11 | ╚██████╗╚██████╔╝██████╔╝███████╗███████║ ██║ ╚██████╔╝██║ ██║ ██║ 12 | ╚═════╝ ╚═════╝ ╚═════╝ ╚══════╝╚══════╝ ╚═╝ ╚═════╝ ╚═╝ ╚═╝ ╚═╝ 13 | 14 | ``` 15 | -------------------------------------------------------------------------------- /llm_client/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "llm_client" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 7 | 8 | [dependencies] 9 | async-trait = "0.1.77" 10 | anyhow = "1.0.75" 11 | reqwest = "0.11.23" 12 | serde = "1.0.195" 13 | serde_json = "1.0.111" 14 | eventsource-stream = "0.2.3" 15 | futures = "0.3.28" 16 | tokio = { version = "1.32.0", features = ["full"] } 17 | thiserror = "1.0.49" 18 | tokenizers = { version = "0.13.3", default-features = false, features = ["progressbar", "cli", "onig", "esaxx_fast"] } 19 | tiktoken-rs = "0.5.4" 20 | async-openai = "0.14.3" 21 | sqlx = { version = "0.7.2", features = ["sqlite", "migrate", "runtime-tokio-rustls", "chrono", "uuid"]} 22 | -------------------------------------------------------------------------------- /llm_prompts/src/chat/broker.rs: -------------------------------------------------------------------------------- 1 | use llm_client::clients::types::LLMType; 2 | 3 | use crate::answer_model::{AnswerModel, LLMAnswerModelBroker}; 4 | 5 | #[derive(thiserror::Error, Debug)] 6 | pub enum ChatModelBrokerErrors { 7 | #[error("The model {0} is not supported yet")] 8 | ModelNotSupported(LLMType), 9 | } 10 | 11 | pub struct LLMChatModelBroker { 12 | answer_model_broker: LLMAnswerModelBroker, 13 | } 14 | 15 | impl LLMChatModelBroker { 16 | pub fn init() -> Self { 17 | let answer_model_broker = LLMAnswerModelBroker::new(); 18 | Self { 19 | answer_model_broker, 20 | } 21 | } 22 | 23 | pub fn get_answer_model( 24 | &self, 25 | llm_type: &LLMType, 26 | ) -> Result<&AnswerModel, ChatModelBrokerErrors> { 27 | self.answer_model_broker 28 | .get_answer_model(llm_type) 29 | .ok_or(ChatModelBrokerErrors::ModelNotSupported(llm_type.clone())) 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /llm_client/src/bin/codestory_provider.rs: -------------------------------------------------------------------------------- 1 | //! 
Call the endpoints of codestory endpoint 2 | 3 | use llm_client::{ 4 | clients::{ 5 | codestory::CodeStoryClient, 6 | types::{LLMClient, LLMClientCompletionRequest, LLMClientMessage, LLMType}, 7 | }, 8 | provider::LLMProviderAPIKeys, 9 | }; 10 | 11 | #[tokio::main] 12 | async fn main() { 13 | let codestory_client = 14 | CodeStoryClient::new("https://codestory-provider-dot-anton-390822.ue.r.appspot.com"); 15 | let (sender, _receiver) = tokio::sync::mpsc::unbounded_channel(); 16 | let request = LLMClientCompletionRequest::new( 17 | LLMType::GPT3_5_16k, 18 | vec![ 19 | LLMClientMessage::system("you are a python expert".to_owned()), 20 | LLMClientMessage::user( 21 | "write me a big python function which does a lot of things".to_owned(), 22 | ), 23 | ], 24 | 1.0, 25 | None, 26 | ); 27 | let response = codestory_client 28 | .stream_completion(LLMProviderAPIKeys::CodeStory, request, sender) 29 | .await; 30 | println!("{:?}", response); 31 | } 32 | -------------------------------------------------------------------------------- /llm_client/src/bin/openai_llm.rs: -------------------------------------------------------------------------------- 1 | use async_openai::{ 2 | config::AzureConfig, 3 | types::{ChatCompletionRequestMessageArgs, CreateChatCompletionRequestArgs}, 4 | Client, 5 | }; 6 | use futures::StreamExt; 7 | use llm_client::clients::{ 8 | openai::OpenAIClient, 9 | types::{LLMClient, LLMClientCompletionRequest, LLMClientMessage}, 10 | }; 11 | use llm_client::provider::AzureConfig as ProviderAzureConfig; 12 | 13 | #[tokio::main] 14 | async fn main() { 15 | let openai_client = OpenAIClient::new(); 16 | let api_key = 17 | llm_client::provider::LLMProviderAPIKeys::OpenAIAzureConfig(ProviderAzureConfig { 18 | deployment_id: "some_deployment_id".to_string(), 19 | api_base: "some_base".to_owned(), 20 | api_key: "some_key".to_owned(), 21 | api_version: "some_version".to_owned(), 22 | }); 23 | let request = LLMClientCompletionRequest::new( 24 | llm_client::clients::types::LLMType::GPT3_5_16k, 25 | vec![LLMClientMessage::system("message".to_owned())], 26 | 1.0, 27 | None, 28 | ); 29 | let (sender, receiver) = tokio::sync::mpsc::unbounded_channel(); 30 | let response = openai_client 31 | .stream_completion(api_key, request, sender) 32 | .await; 33 | dbg!(&response); 34 | } 35 | -------------------------------------------------------------------------------- /llm_client/src/sqlite.rs: -------------------------------------------------------------------------------- 1 | use std::path::Path; 2 | 3 | use sqlx::SqlitePool; 4 | 5 | use crate::{clients::types::LLMClientError, config::LLMBrokerConfiguration}; 6 | 7 | pub async fn init(config: LLMBrokerConfiguration) -> Result { 8 | let data_dir = config.data_dir.to_string_lossy().to_owned(); 9 | 10 | match connect(&data_dir).await { 11 | Ok(pool) => Ok(pool), 12 | Err(_) => { 13 | reset(&data_dir)?; 14 | connect(&data_dir).await 15 | } 16 | } 17 | } 18 | 19 | async fn connect(data_dir: &str) -> Result { 20 | let url = format!("sqlite://{data_dir}/llm_data.data?mode=rwc"); 21 | let pool = SqlitePool::connect(&url) 22 | .await 23 | .map_err(|_| LLMClientError::TokioMpscSendError)?; 24 | 25 | if let Err(e) = sqlx::migrate!().run(&pool).await { 26 | // We manually close the pool here to ensure file handles are properly cleaned up on 27 | // Windows. 28 | pool.close().await; 29 | Err(e).map_err(|_e| LLMClientError::SqliteSetupError)? 
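// Propagating the migration failure here lets `init` fall back to `reset`, which
// moves the existing database file aside to a backup path before retrying `connect`.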
30 | } else { 31 | Ok(pool) 32 | } 33 | } 34 | 35 | fn reset(data_dir: &str) -> Result<(), LLMClientError> { 36 | let db_path = Path::new(data_dir).join("llm_data.data"); 37 | let bk_path = db_path.with_extension("llm_data.bk"); 38 | std::fs::rename(db_path, bk_path).map_err(|_| LLMClientError::SqliteSetupError) 39 | } 40 | -------------------------------------------------------------------------------- /llm_prompts/src/in_line_edit/doc_helpers.rs: -------------------------------------------------------------------------------- 1 | use super::types::InLineDocRequest; 2 | 3 | pub fn documentation_type(identifier_node: &InLineDocRequest) -> String { 4 | let language = identifier_node.language(); 5 | let is_identifier = identifier_node.is_identifier_node(); 6 | let comment_type = match language { 7 | "typescript" | "typescriptreact" => match is_identifier { 8 | true => "a TSDoc comment".to_owned(), 9 | false => "TSDoc comment".to_owned(), 10 | }, 11 | "javascript" | "javascriptreact" => match is_identifier { 12 | true => "a JSDoc comment".to_owned(), 13 | false => "JSDoc comment".to_owned(), 14 | }, 15 | "python" => "docstring".to_owned(), 16 | "rust" => "Rustdoc comment".to_owned(), 17 | _ => "documentation comment".to_owned(), 18 | }; 19 | comment_type 20 | } 21 | 22 | pub fn selection_type(identifier_node: &InLineDocRequest) -> String { 23 | let identifier_node_str = identifier_node.identifier_node_str(); 24 | match identifier_node_str { 25 | Some(identifier_node) => identifier_node.to_owned(), 26 | None => "the selection".to_owned(), 27 | } 28 | } 29 | 30 | pub fn document_symbol_metadata(identifier_node: &InLineDocRequest) -> String { 31 | let is_identifier = identifier_node.is_identifier_node(); 32 | let language = identifier_node.language(); 33 | let comment_type = documentation_type(identifier_node); 34 | let identifier_node_str = identifier_node.identifier_node_str(); 35 | match identifier_node_str { 36 | Some(identifier_node) => { 37 | format!("Please add {comment_type} for {identifier_node}") 38 | } 39 | None => format!("Please add {comment_type} for the selection"), 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /llm_client/src/format/tokenizer_config/mixtral.json: -------------------------------------------------------------------------------- 1 | { 2 | "add_bos_token": true, 3 | "add_eos_token": false, 4 | "added_tokens_decoder": { 5 | "0": { 6 | "content": "", 7 | "lstrip": false, 8 | "normalized": false, 9 | "rstrip": false, 10 | "single_word": false, 11 | "special": true 12 | }, 13 | "1": { 14 | "content": "", 15 | "lstrip": false, 16 | "normalized": false, 17 | "rstrip": false, 18 | "single_word": false, 19 | "special": true 20 | }, 21 | "2": { 22 | "content": "", 23 | "lstrip": false, 24 | "normalized": false, 25 | "rstrip": false, 26 | "single_word": false, 27 | "special": true 28 | } 29 | }, 30 | "additional_special_tokens": [], 31 | "bos_token": "", 32 | "clean_up_tokenization_spaces": false, 33 | "eos_token": "", 34 | "legacy": true, 35 | "model_max_length": 1000000000000000019884624838656, 36 | "pad_token": null, 37 | "sp_model_kwargs": {}, 38 | "spaces_between_special_tokens": false, 39 | "tokenizer_class": "LlamaTokenizer", 40 | "unk_token": "", 41 | "use_default_system_prompt": false, 42 | "chat_template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif 
%}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}" 43 | } -------------------------------------------------------------------------------- /llm_client/src/format/tokenizer_config/mistral.json: -------------------------------------------------------------------------------- 1 | { 2 | "add_bos_token": true, 3 | "add_eos_token": false, 4 | "added_tokens_decoder": { 5 | "0": { 6 | "content": "", 7 | "lstrip": false, 8 | "normalized": false, 9 | "rstrip": false, 10 | "single_word": false, 11 | "special": true 12 | }, 13 | "1": { 14 | "content": "", 15 | "lstrip": false, 16 | "normalized": false, 17 | "rstrip": false, 18 | "single_word": false, 19 | "special": true 20 | }, 21 | "2": { 22 | "content": "", 23 | "lstrip": false, 24 | "normalized": false, 25 | "rstrip": false, 26 | "single_word": false, 27 | "special": true 28 | } 29 | }, 30 | "additional_special_tokens": [], 31 | "bos_token": "", 32 | "chat_template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", 33 | "clean_up_tokenization_spaces": false, 34 | "eos_token": "", 35 | "legacy": true, 36 | "model_max_length": 1000000000000000019884624838656, 37 | "pad_token": null, 38 | "sp_model_kwargs": {}, 39 | "spaces_between_special_tokens": false, 40 | "tokenizer_class": "LlamaTokenizer", 41 | "unk_token": "", 42 | "use_default_system_prompt": false 43 | } -------------------------------------------------------------------------------- /llm_client/src/format/tokenizer_config/deepseekcoder.json: -------------------------------------------------------------------------------- 1 | { 2 | "add_bos_token": true, 3 | "add_eos_token": false, 4 | "bos_token": { 5 | "__type": "AddedToken", 6 | "content": "<|begin▁of▁sentence|>", 7 | "lstrip": false, 8 | "normalized": true, 9 | "rstrip": false, 10 | "single_word": false 11 | }, 12 | "clean_up_tokenization_spaces": false, 13 | "eos_token": { 14 | "__type": "AddedToken", 15 | "content": "<|EOT|>", 16 | "lstrip": false, 17 | "normalized": true, 18 | "rstrip": false, 19 | "single_word": false 20 | }, 21 | "legacy": true, 22 | "model_max_length": 16384, 23 | "pad_token": { 24 | "__type": "AddedToken", 25 | "content": "<|end▁of▁sentence|>", 26 | "lstrip": false, 27 | "normalized": true, 28 | "rstrip": false, 29 | "single_word": false 30 | }, 31 | "sp_model_kwargs": {}, 32 | "unk_token": null, 33 | "tokenizer_class": "LlamaTokenizerFast", 34 | "chat_template": "{% if not add_generation_prompt is defined %}\n{% set add_generation_prompt = false %}\n{% endif %}\n{%- set ns = namespace(found=false) -%}\n{%- for message in messages -%}\n {%- if message['role'] == 'system' -%}\n {%- set ns.found = true -%}\n {%- endif -%}\n{%- endfor -%}\n{{bos_token}}{%- if not ns.found -%}\n{{'You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. 
For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\\n'}}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n{{ message['content'] }}\n {%- else %}\n {%- if message['role'] == 'user' %}\n{{'### Instruction:\\n' + message['content'] + '\\n'}}\n {%- else %}\n{{'### Response:\\n' + message['content'] + '\\n<|EOT|>\\n'}}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{% if add_generation_prompt %}\n{{'### Response:'}}\n{% endif %}" 35 | } -------------------------------------------------------------------------------- /llm_client/src/format/types.rs: -------------------------------------------------------------------------------- 1 | use serde::{Deserialize, Serialize}; 2 | use std::collections::HashMap; 3 | use thiserror::Error; 4 | 5 | use crate::clients::types::LLMClientMessage; 6 | 7 | pub trait LLMFormatting { 8 | fn to_prompt(&self, messages: Vec) -> String; 9 | } 10 | 11 | pub struct DummyLLMFormatting {} 12 | 13 | impl DummyLLMFormatting { 14 | pub fn new() -> Self { 15 | Self {} 16 | } 17 | } 18 | 19 | impl LLMFormatting for DummyLLMFormatting { 20 | fn to_prompt(&self, messages: Vec) -> String { 21 | messages 22 | .into_iter() 23 | .map(|message| message.content().to_owned()) 24 | .collect::>() 25 | .join("\n") 26 | } 27 | } 28 | 29 | #[derive(Serialize, Deserialize, Debug)] 30 | pub struct TokenizerConfig { 31 | add_bos_token: bool, 32 | add_eos_token: bool, 33 | added_tokens_decoder: HashMap, 34 | additional_special_tokens: Vec, 35 | bos_token: String, 36 | chat_template: String, 37 | clean_up_tokenization_spaces: bool, 38 | eos_token: String, 39 | legacy: bool, 40 | model_max_length: u128, 41 | pad_token: Option, 42 | sp_model_kwargs: HashMap, 43 | spaces_between_special_tokens: bool, 44 | tokenizer_class: String, 45 | unk_token: String, 46 | use_default_system_prompt: bool, 47 | } 48 | 49 | impl TokenizerConfig { 50 | pub fn add_bos_token(&self) -> bool { 51 | self.add_bos_token 52 | } 53 | 54 | pub fn add_eos_token(&self) -> bool { 55 | self.add_eos_token 56 | } 57 | 58 | pub fn bos_token(&self) -> &str { 59 | &self.bos_token 60 | } 61 | 62 | pub fn eos_token(&self) -> &str { 63 | &self.eos_token 64 | } 65 | 66 | pub fn chat_template(&self) -> &str { 67 | &self.chat_template 68 | } 69 | } 70 | 71 | #[derive(Serialize, Deserialize, Debug)] 72 | pub struct AddedTokenDecoder { 73 | content: String, 74 | lstrip: bool, 75 | normalized: bool, 76 | rstrip: bool, 77 | single_word: bool, 78 | special: bool, 79 | } 80 | 81 | #[derive(Error, Debug)] 82 | pub enum TokenizerError { 83 | #[error("Failed to get response from LLM")] 84 | FailedToGetResponse, 85 | 86 | #[error("Reqwest error: {0}")] 87 | ReqwestError(#[from] reqwest::Error), 88 | 89 | #[error("serde failed: {0}")] 90 | SerdeError(#[from] serde_json::Error), 91 | } 92 | -------------------------------------------------------------------------------- /llm_client/src/bin/mixtral_test.rs: -------------------------------------------------------------------------------- 1 | use llm_client::clients::togetherai::TogetherAIClient; 2 | use llm_client::{ 3 | clients::types::{LLMClient, LLMClientCompletionRequest, LLMClientMessage}, 4 | provider::TogetherAIProvider, 5 | }; 6 | 7 | #[tokio::main] 8 | async fn main() { 9 | let togetherai = TogetherAIClient::new(); 10 | let api_key = llm_client::provider::LLMProviderAPIKeys::TogetherAI(TogetherAIProvider { 11 | api_key: "some_key".to_owned(), 12 | }); 13 | let message = 
r#"[INST] You are an expert software engineer. You have been given some code context below: 14 | 15 | Code Context above the selection: 16 | ```rust 17 | // FILEPATH: $/Users/skcd/scratch/dataset/commit_play/src/bin/tokenizers.rs 18 | // BEGIN: abpxx6d04wxr 19 | //! Tokenizers we are going to use 20 | 21 | use std::str::FromStr; 22 | 23 | use commit_play::llm::tokenizers::mistral::mistral_tokenizer; 24 | use tokenizers::Tokenizer; 25 | 26 | #[tokio::main] 27 | // END: abpxx6d04wxr 28 | ``` 29 | 30 | Your task is to rewrite the code below following the instruction: add tracing::info! calls after the call to encode and decode 31 | Code you have to edit: 32 | ```rust 33 | // FILEPATH: $/Users/skcd/scratch/dataset/commit_play/src/bin/tokenizers.rs 34 | // BEGIN: ed8c6549bwf9 35 | async fn main() { 36 | let tokenizer = Tokenizer::from_str(&mistral_tokenizer()).expect("tokenizer error"); 37 | 38 | let tokens = tokenizer.encode( 39 | "[INST] write a function which adds 2 numbers in python [/INST]", 40 | true, 41 | ); 42 | 43 | let result = tokenizer.decode( 44 | &[ 45 | 733, 16289, 28793, 3324, 264, 908, 690, 13633, 28705, 28750, 5551, 297, 21966, 733, 46 | 28748, 16289, 28793, 47 | ], 48 | false, 49 | ); 50 | 51 | dbg!(tokens.expect("to work").get_ids()); 52 | 53 | dbg!(result.expect("something")); 54 | } 55 | 56 | // END: ed8c6549bwf9 57 | ``` 58 | 59 | Rewrite the code without any explanation [/INST] 60 | ```rust 61 | // FILEPATH: /Users/skcd/scratch/dataset/commit_play/src/bin/tokenizers.rs 62 | // BEGIN: ed8c6549bwf9"#; 63 | let request = LLMClientCompletionRequest::new( 64 | llm_client::clients::types::LLMType::MistralInstruct, 65 | vec![LLMClientMessage::user(message.to_owned())], 66 | 1.0, 67 | None, 68 | ); 69 | let (sender, _receiver) = tokio::sync::mpsc::unbounded_channel(); 70 | let response = togetherai.stream_completion(api_key, request, sender).await; 71 | dbg!(&response); 72 | } 73 | -------------------------------------------------------------------------------- /llm_prompts/src/in_line_edit/broker.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | 3 | use llm_client::clients::types::LLMType; 4 | 5 | use super::{ 6 | mistral::MistralLineEditPrompt, 7 | openai::OpenAILineEditPrompt, 8 | types::{ 9 | InLineDocRequest, InLineEditPrompt, InLineEditPromptError, InLineEditRequest, 10 | InLineFixRequest, InLinePromptResponse, 11 | }, 12 | }; 13 | 14 | pub struct InLineEditPromptBroker { 15 | prompt_generators: HashMap>, 16 | } 17 | 18 | impl InLineEditPromptBroker { 19 | pub fn new() -> Self { 20 | let broker = Self { 21 | prompt_generators: HashMap::new(), 22 | }; 23 | broker 24 | .insert_prompt_generator(LLMType::GPT3_5_16k, Box::new(OpenAILineEditPrompt::new())) 25 | .insert_prompt_generator(LLMType::Gpt4, Box::new(OpenAILineEditPrompt::new())) 26 | .insert_prompt_generator(LLMType::Gpt4_32k, Box::new(OpenAILineEditPrompt::new())) 27 | .insert_prompt_generator( 28 | LLMType::MistralInstruct, 29 | Box::new(MistralLineEditPrompt::new()), 30 | ) 31 | .insert_prompt_generator(LLMType::Mixtral, Box::new(MistralLineEditPrompt::new())) 32 | } 33 | 34 | pub fn insert_prompt_generator( 35 | mut self, 36 | llm_type: LLMType, 37 | prompt_generator: Box, 38 | ) -> Self { 39 | self.prompt_generators.insert(llm_type, prompt_generator); 40 | self 41 | } 42 | 43 | fn get_prompt_generator( 44 | &self, 45 | llm_type: &LLMType, 46 | ) -> Result<&Box, InLineEditPromptError> { 47 | self.prompt_generators 48 | .get(llm_type) 49 | 
.ok_or(InLineEditPromptError::ModelNotSupported) 50 | } 51 | 52 | pub fn get_prompt( 53 | &self, 54 | llm_type: &LLMType, 55 | request: InLineEditRequest, 56 | ) -> Result { 57 | let prompt_generator = self.get_prompt_generator(llm_type)?; 58 | Ok(prompt_generator.inline_edit(request)) 59 | } 60 | 61 | pub fn get_fix_prompt( 62 | &self, 63 | llm_type: &LLMType, 64 | request: InLineFixRequest, 65 | ) -> Result { 66 | let prompt_generator = self.get_prompt_generator(llm_type)?; 67 | Ok(prompt_generator.inline_fix(request)) 68 | } 69 | 70 | pub fn get_doc_prompt( 71 | &self, 72 | llm_type: &LLMType, 73 | request: InLineDocRequest, 74 | ) -> Result { 75 | let prompt_generator = self.get_prompt_generator(llm_type)?; 76 | Ok(prompt_generator.inline_doc(request)) 77 | } 78 | } 79 | -------------------------------------------------------------------------------- /llm_prompts/src/answer_model/mod.rs: -------------------------------------------------------------------------------- 1 | //! We define all the properties for the model configuration related to answering 2 | //! a user question in the chat format here 3 | 4 | use std::collections::HashMap; 5 | 6 | use llm_client::clients::types::LLMType; 7 | 8 | #[derive(Debug)] 9 | pub struct AnswerModel { 10 | pub llm_type: LLMType, 11 | /// The number of tokens reserved for the answer 12 | pub answer_tokens: i64, 13 | 14 | /// The number of tokens reserved for the prompt 15 | pub prompt_tokens_limit: i64, 16 | 17 | /// The number of tokens reserved for history 18 | pub history_tokens_limit: i64, 19 | 20 | /// The total number of tokens reserved for the model 21 | pub total_tokens: i64, 22 | } 23 | 24 | // GPT-3.5-16k Turbo has 16,385 tokens 25 | pub const GPT_3_5_TURBO_16K: AnswerModel = AnswerModel { 26 | llm_type: LLMType::GPT3_5_16k, 27 | answer_tokens: 1024 * 2, 28 | prompt_tokens_limit: 2500 * 2, 29 | history_tokens_limit: 2048 * 2, 30 | total_tokens: 16385, 31 | }; 32 | 33 | // GPT-4 has 8,192 tokens 34 | pub const GPT_4: AnswerModel = AnswerModel { 35 | llm_type: LLMType::Gpt4, 36 | answer_tokens: 1024, 37 | // The prompt tokens limit for gpt4 are a bit higher so we can get more context 38 | // when required 39 | prompt_tokens_limit: 4500, 40 | history_tokens_limit: 2048, 41 | total_tokens: 8192, 42 | }; 43 | 44 | // GPT4-32k has 32,769 tokens 45 | pub const GPT_4_32K: AnswerModel = AnswerModel { 46 | llm_type: LLMType::Gpt4_32k, 47 | answer_tokens: 1024 * 4, 48 | prompt_tokens_limit: 2500 * 4, 49 | history_tokens_limit: 2048 * 4, 50 | total_tokens: 32769, 51 | }; 52 | 53 | // GPT4-Turbo has 128k tokens as input, but let's keep it capped at 32k tokens 54 | // as LLMs exhibit LIM issues which has been frequently documented 55 | pub const GPT_4_TURBO_128K: AnswerModel = AnswerModel { 56 | llm_type: LLMType::Gpt4Turbo, 57 | answer_tokens: 1024 * 4, 58 | prompt_tokens_limit: 2500 * 4, 59 | history_tokens_limit: 2048 * 4, 60 | total_tokens: 32769, 61 | }; 62 | 63 | // MistralInstruct has 8k tokens in total 64 | pub const MISTRAL_INSTRUCT: AnswerModel = AnswerModel { 65 | llm_type: LLMType::MistralInstruct, 66 | answer_tokens: 1024, 67 | prompt_tokens_limit: 4500, 68 | history_tokens_limit: 2048, 69 | total_tokens: 8000, 70 | }; 71 | 72 | // Mixtral has 32k tokens in total 73 | pub const MIXTRAL: AnswerModel = AnswerModel { 74 | llm_type: LLMType::Mixtral, 75 | answer_tokens: 1024, 76 | prompt_tokens_limit: 2500 * 4, 77 | history_tokens_limit: 1024 * 4, 78 | total_tokens: 32000, 79 | }; 80 | 81 | pub struct LLMAnswerModelBroker { 82 | pub models: 
HashMap, 83 | } 84 | 85 | impl LLMAnswerModelBroker { 86 | pub fn new() -> Self { 87 | let broker = Self { 88 | models: Default::default(), 89 | }; 90 | broker 91 | .add_answer_model(GPT_3_5_TURBO_16K) 92 | .add_answer_model(GPT_4) 93 | .add_answer_model(GPT_4_32K) 94 | .add_answer_model(GPT_4_TURBO_128K) 95 | .add_answer_model(MISTRAL_INSTRUCT) 96 | .add_answer_model(MIXTRAL) 97 | } 98 | 99 | pub fn add_answer_model(mut self, model: AnswerModel) -> Self { 100 | self.models.insert(model.llm_type.clone(), model); 101 | self 102 | } 103 | 104 | pub fn get_answer_model(&self, llm_type: &LLMType) -> Option<&AnswerModel> { 105 | self.models.get(llm_type) 106 | } 107 | } 108 | -------------------------------------------------------------------------------- /llm_client/src/format/deepseekcoder.rs: -------------------------------------------------------------------------------- 1 | use crate::clients::types::LLMClientMessage; 2 | 3 | use super::types::LLMFormatting; 4 | 5 | pub struct DeepSeekCoderFormatting {} 6 | 7 | impl DeepSeekCoderFormatting { 8 | pub fn new() -> Self { 9 | Self {} 10 | } 11 | } 12 | 13 | impl LLMFormatting for DeepSeekCoderFormatting { 14 | fn to_prompt(&self, messages: Vec) -> String { 15 | // we want to convert the message to deepseekcoder format 16 | // present here: https://huggingface.co/deepseek-ai/deepseek-coder-33b-instruct/blob/main/tokenizer_config.json#L34 17 | // {% if not add_generation_prompt is defined %} 18 | // {% set add_generation_prompt = false %} 19 | // {% endif %} 20 | // {%- set ns = namespace(found=false) -%} 21 | // {%- for message in messages -%} 22 | // {%- if message['role'] == 'system' -%} 23 | // {%- set ns.found = true -%} 24 | // {%- endif -%} 25 | // {%- endfor -%} 26 | // {{bos_token}}{%- if not ns.found -%} 27 | // {{'You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\n'}} 28 | // {%- endif %} 29 | // {%- for message in messages %} 30 | // {%- if message['role'] == 'system' %} 31 | // {{ message['content'] }} 32 | // {%- else %} 33 | // {%- if message['role'] == 'user' %} 34 | // {{'### Instruction:\n' + message['content'] + '\n'}} 35 | // {%- else %} 36 | // {{'### Response:\n' + message['content'] + '\n<|EOT|>\n'}} 37 | // {%- endif %} 38 | // {%- endif %} 39 | // {%- endfor %} 40 | // {% if add_generation_prompt %} 41 | // {{'### Response:'}} 42 | // {% endif %} 43 | let formatted_message = messages.into_iter().skip_while(|message| message.role().is_assistant()) 44 | .map(|message| { 45 | let content = message.content(); 46 | if message.role().is_system() { 47 | "You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. 
For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\n".to_owned() 48 | } else if message.role().is_user() { 49 | format!("### Instruction:\n{}\n", content) 50 | } else { 51 | format!("### Response:\n{}\n<|EOT|>\n", content) 52 | } 53 | }).collect::>().join(""); 54 | formatted_message 55 | } 56 | } 57 | 58 | #[cfg(test)] 59 | mod tests { 60 | use crate::clients::types::LLMClientMessage; 61 | 62 | use super::DeepSeekCoderFormatting; 63 | use super::LLMFormatting; 64 | 65 | #[test] 66 | fn test_formatting_works() { 67 | let messages = vec![ 68 | LLMClientMessage::system("system_message_not_show_up".to_owned()), 69 | LLMClientMessage::user("user_message1".to_owned()), 70 | LLMClientMessage::assistant("assistant_message1".to_owned()), 71 | LLMClientMessage::user("user_message2".to_owned()), 72 | ]; 73 | let deepseek_formatting = DeepSeekCoderFormatting::new(); 74 | assert_eq!(deepseek_formatting.to_prompt(messages), "You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\n### Instruction:\nuser_message1\n### Response:\nassistant_message1\n<|EOT|>\n### Instruction:\nuser_message2\n"); 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /llm_client/src/format/mistral.rs: -------------------------------------------------------------------------------- 1 | use crate::clients::types::LLMClientMessage; 2 | 3 | use super::types::{LLMFormatting, TokenizerConfig, TokenizerError}; 4 | 5 | pub struct MistralInstructFormatting { 6 | tokenizer_config: TokenizerConfig, 7 | } 8 | 9 | impl MistralInstructFormatting { 10 | pub fn new() -> Result { 11 | let config = include_str!("tokenizer_config/mistral.json"); 12 | let tokenizer_config = serde_json::from_str::(config)?; 13 | Ok(Self { tokenizer_config }) 14 | } 15 | } 16 | 17 | impl LLMFormatting for MistralInstructFormatting { 18 | fn to_prompt(&self, messages: Vec) -> String { 19 | // we want to convert the message to mistral format 20 | // persent here: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/blob/main/tokenizer_config.json#L31-L34 21 | // {{ bos_token }} 22 | // {% for message in messages %} 23 | // {% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %} 24 | // {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }} 25 | // {% endif %} 26 | // {% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }} 27 | // {% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }} 28 | // {% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}" 29 | // First the messages have to be alternating, if that's not enforced then we run into problems 30 | // but since thats the case, we can do something better, which is to to just send consecutive messages 31 | // from human and assistant together 32 | let formatted_message = messages 33 | .into_iter() 34 | .skip_while(|message| message.role().is_assistant()) 35 | .map(|message| { 36 | let content = message.content(); 37 | let eos_token = self.tokenizer_config.eos_token(); 38 | if message.role().is_system() || message.role().is_user() { 39 | format!("[INST] {content} [/INST]") 40 | } else if 
message.role().is_function() { 41 | // This will be formatted as a function call as well 42 | match message.get_function_call() { 43 | Some(function_call) => { 44 | let function_call = serde_json::to_string(function_call) 45 | .expect("serde deserialize to not fail"); 46 | format!("[INST] {function_call} [/INST]") 47 | } 48 | None => { 49 | // not entirely correct, we will make it better with more testing 50 | format!("[INST] {content} [/INST]") 51 | } 52 | } 53 | } else { 54 | // we are in an assistant message now, so we can have a function 55 | // call which we have to format 56 | match message.get_function_call() { 57 | Some(function_call) => { 58 | let function_call = serde_json::to_string(function_call) 59 | .expect("serde deserialize to not fail"); 60 | format!("{content}{function_call}{eos_token} ") 61 | } 62 | None => { 63 | format!("{content}{eos_token} ") 64 | } 65 | } 66 | } 67 | }) 68 | .collect::>() 69 | .join(""); 70 | format!("{formatted_message}") 71 | } 72 | } 73 | 74 | #[cfg(test)] 75 | mod tests { 76 | 77 | use crate::clients::types::LLMClientMessage; 78 | 79 | use super::LLMFormatting; 80 | use super::MistralInstructFormatting; 81 | 82 | #[test] 83 | fn test_formatting_works() { 84 | let messages = vec![ 85 | LLMClientMessage::user("user_msg1".to_owned()), 86 | LLMClientMessage::assistant("assistant_msg1".to_owned()), 87 | ]; 88 | let mistral_formatting = MistralInstructFormatting::new().unwrap(); 89 | assert_eq!( 90 | mistral_formatting.to_prompt(messages), 91 | "[INST] user_msg1 [/INST]assistant_msg1 ", 92 | ); 93 | } 94 | } 95 | -------------------------------------------------------------------------------- /llm_client/src/format/mixtral.rs: -------------------------------------------------------------------------------- 1 | use crate::clients::types::LLMClientMessage; 2 | 3 | use super::types::{LLMFormatting, TokenizerConfig, TokenizerError}; 4 | 5 | pub struct MixtralInstructFormatting { 6 | tokenizer_config: TokenizerConfig, 7 | } 8 | 9 | impl MixtralInstructFormatting { 10 | pub fn new() -> Result { 11 | let config = include_str!("tokenizer_config/mistral.json"); 12 | let tokenizer_config = serde_json::from_str::(config)?; 13 | Ok(Self { tokenizer_config }) 14 | } 15 | } 16 | 17 | impl LLMFormatting for MixtralInstructFormatting { 18 | fn to_prompt(&self, messages: Vec) -> String { 19 | // we want to convert the message to mistral format 20 | // persent here: https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1/blob/main/tokenizer_config.json 21 | // {{ bos_token }} 22 | // {% for message in messages %} 23 | // {% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %} 24 | // {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }} 25 | // {% endif %} 26 | // {% if message['role'] == 'user' %} 27 | // {{ '[INST] ' + message['content'] + ' [/INST]' }} 28 | // {% elif message['role'] == 'assistant' %} 29 | // {{ message['content'] + eos_token}} 30 | // {% else %} 31 | // {{ raise_exception('Only user and assistant roles are supported!') }} 32 | // {% endif %}{% endfor %} 33 | // First the messages have to be alternating, if that's not enforced then we run into problems 34 | // but since thats the case, we can do something better, which is to to just send consecutive messages 35 | // from human and assistant together 36 | // for handling function calls we will define that in 2 ways: 37 | // if the message is about a function return: then we proxy that as a user message 38 | // if the message is about a function call, 
then we keep it as an assistant message and push the json 39 | let formatted_message = messages 40 | .into_iter() 41 | .skip_while(|message| message.role().is_assistant()) 42 | .map(|message| { 43 | let content = message.content(); 44 | let eos_token = self.tokenizer_config.eos_token(); 45 | if message.role().is_system() || message.role().is_user() { 46 | format!("[INST] {content} [/INST]") 47 | } else if message.role().is_function() { 48 | // This will be formatted as a function call as well 49 | match message.get_function_call() { 50 | Some(function_call) => { 51 | let function_call = serde_json::to_string(function_call) 52 | .expect("serde deserialize to not fail"); 53 | format!("[INST] {function_call} [/INST]") 54 | } 55 | None => { 56 | // not entirely correct, we will make it better with more testing 57 | format!("[INST] {content} [/INST]") 58 | } 59 | } 60 | } else { 61 | // we are in an assistant message now, so we can have a function 62 | // call which we have to format 63 | match message.get_function_call() { 64 | Some(function_call) => { 65 | let function_call = serde_json::to_string(function_call) 66 | .expect("serde deserialize to not fail"); 67 | format!("{content}{function_call}{eos_token}") 68 | } 69 | None => { 70 | format!("{content}{eos_token}") 71 | } 72 | } 73 | } 74 | }) 75 | .collect::>() 76 | .join(""); 77 | format!("{formatted_message}") 78 | } 79 | } 80 | 81 | #[cfg(test)] 82 | mod tests { 83 | 84 | use crate::clients::types::LLMClientMessage; 85 | 86 | use super::LLMFormatting; 87 | use super::MixtralInstructFormatting; 88 | 89 | #[test] 90 | fn test_formatting_works() { 91 | let messages = vec![ 92 | LLMClientMessage::user("user_msg1".to_owned()), 93 | LLMClientMessage::assistant("assistant_msg1".to_owned()), 94 | ]; 95 | let mistral_formatting = MixtralInstructFormatting::new().unwrap(); 96 | assert_eq!( 97 | mistral_formatting.to_prompt(messages), 98 | "[INST] user_msg1 [/INST]assistant_msg1", 99 | ); 100 | } 101 | } 102 | -------------------------------------------------------------------------------- /llm_client/src/clients/ollama.rs: -------------------------------------------------------------------------------- 1 | //! 
Ollama client here so we can send requests to it 2 | 3 | use async_trait::async_trait; 4 | use tokio::sync::mpsc::UnboundedSender; 5 | 6 | use crate::provider::LLMProviderAPIKeys; 7 | 8 | use super::types::LLMClient; 9 | use super::types::LLMClientCompletionRequest; 10 | use super::types::LLMClientCompletionResponse; 11 | use super::types::LLMClientCompletionStringRequest; 12 | use super::types::LLMClientError; 13 | use super::types::LLMType; 14 | 15 | pub struct OllamaClient { 16 | pub client: reqwest::Client, 17 | pub base_url: String, 18 | } 19 | 20 | #[derive(serde::Deserialize, Debug, Clone)] 21 | struct OllamaResponse { 22 | model: String, 23 | response: String, 24 | done: bool, 25 | } 26 | 27 | impl LLMType { 28 | pub fn to_ollama_model(&self) -> Result { 29 | match self { 30 | LLMType::MistralInstruct => Ok("mistral".to_owned()), 31 | LLMType::Mixtral => Ok("mixtral".to_owned()), 32 | _ => Err(LLMClientError::UnSupportedModel), 33 | } 34 | } 35 | } 36 | 37 | #[derive(serde::Serialize)] 38 | struct OllamaClientRequest { 39 | prompt: String, 40 | model: String, 41 | temperature: f32, 42 | stream: bool, 43 | raw: bool, 44 | #[serde(skip_serializing_if = "Option::is_none")] 45 | frequency_penalty: Option, 46 | } 47 | 48 | impl OllamaClientRequest { 49 | pub fn from_request(request: LLMClientCompletionRequest) -> Result { 50 | Ok(Self { 51 | prompt: request 52 | .messages() 53 | .into_iter() 54 | .map(|message| message.content().to_owned()) 55 | .collect::>() 56 | .join("\n"), 57 | model: request.model().to_ollama_model()?, 58 | temperature: request.temperature(), 59 | stream: true, 60 | raw: true, 61 | frequency_penalty: request.frequency_penalty(), 62 | }) 63 | } 64 | 65 | pub fn from_string_request( 66 | request: LLMClientCompletionStringRequest, 67 | ) -> Result { 68 | Ok(Self { 69 | prompt: request.prompt().to_owned(), 70 | model: request.model().to_ollama_model()?, 71 | temperature: request.temperature(), 72 | stream: true, 73 | raw: true, 74 | frequency_penalty: None, 75 | }) 76 | } 77 | } 78 | 79 | impl OllamaClient { 80 | pub fn new() -> Self { 81 | // ollama always runs on the following url: 82 | // http://localhost:11434/ 83 | Self { 84 | client: reqwest::Client::new(), 85 | base_url: "http://localhost:11434".to_owned(), 86 | } 87 | } 88 | 89 | pub fn generation_endpoint(&self) -> String { 90 | format!("{}/api/generate", self.base_url) 91 | } 92 | } 93 | 94 | #[async_trait] 95 | impl LLMClient for OllamaClient { 96 | fn client(&self) -> &crate::provider::LLMProvider { 97 | &crate::provider::LLMProvider::Ollama 98 | } 99 | 100 | async fn stream_completion( 101 | &self, 102 | _api_key: LLMProviderAPIKeys, 103 | request: LLMClientCompletionRequest, 104 | sender: tokio::sync::mpsc::UnboundedSender, 105 | ) -> Result { 106 | let ollama_request = OllamaClientRequest::from_request(request)?; 107 | let mut response = self 108 | .client 109 | .post(self.generation_endpoint()) 110 | .json(&ollama_request) 111 | .send() 112 | .await?; 113 | 114 | let mut buffered_string = "".to_owned(); 115 | while let Some(chunk) = response.chunk().await? 
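// Each streamed chunk is a JSON-encoded OllamaResponse: the body below appends the
// new text to the buffer and forwards the running completion (plus the latest delta)
// over the channel.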
{ 116 | let value = serde_json::from_slice::(chunk.to_vec().as_slice())?; 117 | buffered_string.push_str(&value.response); 118 | sender.send(LLMClientCompletionResponse::new( 119 | buffered_string.to_owned(), 120 | Some(value.response), 121 | value.model, 122 | ))?; 123 | } 124 | Ok(buffered_string) 125 | } 126 | 127 | async fn completion( 128 | &self, 129 | api_key: LLMProviderAPIKeys, 130 | request: LLMClientCompletionRequest, 131 | ) -> Result { 132 | let (sender, _receiver) = tokio::sync::mpsc::unbounded_channel(); 133 | let result = self.stream_completion(api_key, request, sender).await?; 134 | Ok(result) 135 | } 136 | 137 | async fn stream_prompt_completion( 138 | &self, 139 | api_key: LLMProviderAPIKeys, 140 | request: LLMClientCompletionStringRequest, 141 | sender: UnboundedSender, 142 | ) -> Result { 143 | let ollama_request = OllamaClientRequest::from_string_request(request)?; 144 | let mut response = self 145 | .client 146 | .post(self.generation_endpoint()) 147 | .json(&ollama_request) 148 | .send() 149 | .await?; 150 | 151 | let mut buffered_string = "".to_owned(); 152 | while let Some(chunk) = response.chunk().await? { 153 | let value = serde_json::from_slice::(chunk.to_vec().as_slice())?; 154 | buffered_string.push_str(&value.response); 155 | sender.send(LLMClientCompletionResponse::new( 156 | buffered_string.to_owned(), 157 | Some(value.response), 158 | value.model, 159 | ))?; 160 | } 161 | Ok(buffered_string) 162 | } 163 | } 164 | -------------------------------------------------------------------------------- /llm_prompts/src/bin/mistral_inline_edit.rs: -------------------------------------------------------------------------------- 1 | //! We want to test the reranking with mistral 2 | 3 | use std::path::PathBuf; 4 | 5 | use llm_client::{ 6 | broker::LLMBroker, 7 | clients::types::{LLMClientCompletionStringRequest, LLMType}, 8 | config::LLMBrokerConfiguration, 9 | provider::{LLMProviderAPIKeys, TogetherAIProvider}, 10 | }; 11 | 12 | #[tokio::main] 13 | async fn main() { 14 | let prompt = r#"[INST] You are an expert software engineer. You have to perform edits in the selected code snippet following the user instruction in tags. 
15 | You have been given code context below: 16 | 17 | Code Context above the selection: 18 | ```rust 19 | // FILEPATH: $/Users/skcd/scratch/dataset/commit_play/src/language/types.rs 20 | // BEGIN: abpxx6d04wxr 21 | use std::{collections::HashMap, fmt::Debug, path::Display, sync::Arc}; 22 | 23 | use derivative::Derivative; 24 | 25 | use crate::git::commit::GitCommit; 26 | 27 | use super::config::TreeSitterLanguageParsing; 28 | 29 | // These are always 0 indexed 30 | #[derive( 31 | Debug, Clone, Copy, serde::Serialize, serde::Deserialize, PartialEq, Eq, std::hash::Hash, 32 | )] 33 | #[serde(rename_all = "camelCase")] 34 | pub struct Position { 35 | line: usize, 36 | character: usize, 37 | byte_offset: usize, 38 | } 39 | // END: abpxx6d04wxr 40 | ``` 41 | Code Context below the selection: 42 | ```rust 43 | // FILEPATH: $/Users/skcd/scratch/dataset/commit_play/src/language/types.rs 44 | // BEGIN: be15d9bcejpp 45 | impl Position { 46 | fn to_tree_sitter(&self) -> tree_sitter::Point { 47 | tree_sitter::Point::new(self.line, self.character) 48 | } 49 | 50 | pub fn from_tree_sitter_point(point: &tree_sitter::Point, byte_offset: usize) -> Self { 51 | Self { 52 | line: point.row, 53 | character: point.column, 54 | byte_offset, 55 | } 56 | } 57 | 58 | pub fn to_byte_offset(&self) -> usize { 59 | self.byte_offset 60 | } 61 | 62 | pub fn new(line: usize, character: usize, byte_offset: usize) -> Self { 63 | Self { 64 | line, 65 | character, 66 | byte_offset, 67 | } 68 | } 69 | 70 | pub fn line(&self) -> usize { 71 | self.line 72 | } 73 | 74 | pub fn column(&self) -> usize { 75 | self.character 76 | } 77 | 78 | pub fn set_byte_offset(&mut self, byte_offset: usize) { 79 | self.byte_offset = byte_offset; 80 | } 81 | 82 | pub fn from_byte(byte: usize, line_end_indices: &[u32]) -> Self { 83 | let line = line_end_indices 84 | .iter() 85 | .position(|&line_end_byte| (line_end_byte as usize) > byte) 86 | .unwrap_or(0); 87 | 88 | let column = line 89 | .checked_sub(1) 90 | .and_then(|idx| line_end_indices.get(idx)) 91 | .map(|&prev_line_end| byte.saturating_sub(prev_line_end as usize)) 92 | .unwrap_or(byte); 93 | 94 | Self::new(line, column, byte) 95 | } 96 | } 97 | 98 | #[derive( 99 | Debug, Clone, Copy, serde::Deserialize, serde::Serialize, PartialEq, Eq, std::hash::Hash, 100 | )] 101 | #[serde(rename_all = "camelCase")] 102 | pub struct Range { 103 | start_position: Position, 104 | end_position: Position, 105 | } 106 | 107 | impl Default for Range { 108 | fn default() -> Self { 109 | Self { 110 | start_position: Position::new(0, 0, 0), 111 | end_position: Position::new(0, 0, 0), 112 | } 113 | } 114 | } 115 | 116 | impl Range { 117 | pub fn new(start_position: Position, end_position: Position) -> Self { 118 | Self { 119 | start_position, 120 | end_position, 121 | } 122 | } 123 | // END: be15d9bcejpp 124 | ``` 125 | 126 | Code you have to edit: 127 | ```rust 128 | // FILEPATH: $/Users/skcd/scratch/dataset/commit_play/src/language/types.rs 129 | // BEGIN: ed8c6549bwf9 130 | impl Into for Position { 131 | fn into(self) -> tree_sitter::Point { 132 | self.to_tree_sitter() 133 | } 134 | } 135 | // END: ed8c6549bwf9 136 | ``` 137 | 138 | Rewrite the code enclosed in // BEGIN: ed8c6549bwf9 and // END: ed8c6549bwf9 following the user query without any explanation [/INST] 139 | The user has instructed me to perform the following edits on the selection: 140 | can you add comments all over the function body? 
141 | 142 | The edited code is: 143 | ```rust 144 | // FILEPATH: /Users/skcd/scratch/dataset/commit_play/src/language/types.rs 145 | // BEGIN: ed8c6549bwf9 146 | "#; 147 | let llm_broker = LLMBroker::new(LLMBrokerConfiguration::new(PathBuf::from( 148 | "/Users/skcd/Library/Application Support/ai.codestory.sidecar", 149 | ))) 150 | .await 151 | .expect("broker to startup"); 152 | 153 | let api_key = 154 | LLMProviderAPIKeys::TogetherAI(TogetherAIProvider::new("some_key_here".to_owned())); 155 | let request = LLMClientCompletionStringRequest::new( 156 | LLMType::MistralInstruct, 157 | prompt.to_owned(), 158 | 1.0, 159 | None, 160 | ); 161 | let (sender, _receiver) = tokio::sync::mpsc::unbounded_channel(); 162 | let metadata = vec![("event_type".to_owned(), "listwise_reranking".to_owned())] 163 | .into_iter() 164 | .collect(); 165 | let result = llm_broker 166 | .stream_string_completion(api_key.clone(), request, metadata, sender) 167 | .await; 168 | println!("Mistral:"); 169 | println!("{:?}", result); 170 | let mixtral_request = 171 | LLMClientCompletionStringRequest::new(LLMType::Mixtral, prompt.to_owned(), 0.7, None); 172 | let (sender, _receiver) = tokio::sync::mpsc::unbounded_channel(); 173 | let metadata = vec![("event_type".to_owned(), "listwise_reranking".to_owned())] 174 | .into_iter() 175 | .collect(); 176 | let result = llm_broker 177 | .stream_string_completion(api_key, mixtral_request, metadata, sender) 178 | .await; 179 | println!("Mixtral:"); 180 | println!("{:?}", result); 181 | } 182 | -------------------------------------------------------------------------------- /llm_prompts/src/in_line_edit/types.rs: -------------------------------------------------------------------------------- 1 | //! The various interfaces for prompt declaration we have for the in line agent 2 | //! chat. We take care to send the data here properly (after filtering/reranking etc) 3 | //! 
and let the LLM decide what we want to do with it 4 | 5 | use llm_client::clients::types::LLMClientMessage; 6 | 7 | pub enum InLineDocNode { 8 | /// This might just be a selection of code 9 | Selection, 10 | /// We might have a single symbol in the selection 11 | Node(String), 12 | } 13 | 14 | pub struct InLineDocRequest { 15 | in_range: String, 16 | is_identifier_node: InLineDocNode, 17 | language: String, 18 | file_path: String, 19 | } 20 | 21 | impl InLineDocRequest { 22 | pub fn new( 23 | in_range: String, 24 | is_identifier_node: InLineDocNode, 25 | language: String, 26 | file_path: String, 27 | ) -> Self { 28 | Self { 29 | in_range, 30 | is_identifier_node, 31 | language, 32 | file_path, 33 | } 34 | } 35 | 36 | pub fn file_path(&self) -> &str { 37 | &self.file_path 38 | } 39 | 40 | pub fn language(&self) -> &str { 41 | &self.language 42 | } 43 | 44 | pub fn identifier_node(&self) -> &InLineDocNode { 45 | &self.is_identifier_node 46 | } 47 | 48 | pub fn in_range(&self) -> &str { 49 | &self.in_range 50 | } 51 | 52 | pub fn is_identifier_node(&self) -> bool { 53 | match self.is_identifier_node { 54 | InLineDocNode::Node(_) => true, 55 | InLineDocNode::Selection => false, 56 | } 57 | } 58 | 59 | pub fn identifier_node_str(&self) -> Option<&str> { 60 | match self.is_identifier_node { 61 | InLineDocNode::Node(ref node) => Some(node), 62 | InLineDocNode::Selection => None, 63 | } 64 | } 65 | } 66 | 67 | pub struct InLineFixRequest { 68 | above: Option, 69 | below: Option, 70 | in_range: String, 71 | diagnostics_prompts: Vec, 72 | language: String, 73 | file_path: String, 74 | } 75 | 76 | impl InLineFixRequest { 77 | pub fn new( 78 | above: Option, 79 | below: Option, 80 | in_range: String, 81 | diagnostics_prompts: Vec, 82 | language: String, 83 | file_path: String, 84 | ) -> Self { 85 | Self { 86 | above, 87 | below, 88 | in_range, 89 | diagnostics_prompts, 90 | language, 91 | file_path, 92 | } 93 | } 94 | 95 | pub fn above(&self) -> Option<&String> { 96 | self.above.as_ref() 97 | } 98 | 99 | pub fn below(&self) -> Option<&String> { 100 | self.below.as_ref() 101 | } 102 | 103 | pub fn in_range(&self) -> &str { 104 | self.in_range.as_ref() 105 | } 106 | 107 | pub fn diagnostics_prompts(&self) -> &[String] { 108 | &self.diagnostics_prompts 109 | } 110 | 111 | pub fn language(&self) -> &str { 112 | &self.language 113 | } 114 | 115 | pub fn file_path(&self) -> &str { 116 | &self.file_path 117 | } 118 | } 119 | 120 | #[derive(Debug, serde::Serialize, serde::Deserialize)] 121 | pub struct InLineEditRequest { 122 | above: Option, 123 | below: Option, 124 | in_range: Option, 125 | user_query: String, 126 | file_path: String, 127 | /// The extra symbols or data which the user has passed as reference 128 | extra_data: Vec, 129 | language: String, 130 | } 131 | 132 | impl InLineEditRequest { 133 | pub fn above(&self) -> Option<&String> { 134 | self.above.as_ref() 135 | } 136 | 137 | pub fn below(&self) -> Option<&String> { 138 | self.below.as_ref() 139 | } 140 | 141 | pub fn in_range(&self) -> Option<&String> { 142 | self.in_range.as_ref() 143 | } 144 | 145 | pub fn user_query(&self) -> &str { 146 | &self.user_query 147 | } 148 | 149 | pub fn file_path(&self) -> &str { 150 | &self.file_path 151 | } 152 | 153 | pub fn extra_data(&self) -> &[String] { 154 | &self.extra_data 155 | } 156 | 157 | pub fn language(&self) -> &str { 158 | &self.language 159 | } 160 | } 161 | 162 | impl InLineEditRequest { 163 | pub fn new( 164 | above: Option, 165 | below: Option, 166 | in_range: Option, 167 | user_query: 
String, 168 | file_path: String, 169 | extra_data: Vec, 170 | language: String, 171 | ) -> Self { 172 | Self { 173 | above, 174 | below, 175 | in_range, 176 | user_query, 177 | file_path, 178 | extra_data, 179 | language, 180 | } 181 | } 182 | } 183 | 184 | /// We might end up calling the chat or the completion endpoint for a LLM, 185 | /// its important that we support both 186 | #[derive(Debug)] 187 | pub enum InLinePromptResponse { 188 | Completion(String), 189 | Chat(Vec), 190 | } 191 | 192 | impl InLinePromptResponse { 193 | pub fn completion(completion: String) -> Self { 194 | InLinePromptResponse::Completion(completion) 195 | } 196 | 197 | pub fn get_completion(self) -> Option { 198 | if let InLinePromptResponse::Completion(completion) = self { 199 | Some(completion) 200 | } else { 201 | None 202 | } 203 | } 204 | } 205 | 206 | /// Should we send context here as the above, below and in line context, or do we 207 | /// just send the data as it is? 208 | pub trait InLineEditPrompt { 209 | fn inline_edit(&self, request: InLineEditRequest) -> InLinePromptResponse; 210 | 211 | fn inline_fix(&self, request: InLineFixRequest) -> InLinePromptResponse; 212 | 213 | fn inline_doc(&self, request: InLineDocRequest) -> InLinePromptResponse; 214 | } 215 | 216 | /// The error type which we will return if we do not support that model yet 217 | #[derive(thiserror::Error, Debug)] 218 | pub enum InLineEditPromptError { 219 | #[error("Model not supported yet")] 220 | ModelNotSupported, 221 | } 222 | -------------------------------------------------------------------------------- /llm_prompts/src/in_line_edit/mistral.rs: -------------------------------------------------------------------------------- 1 | use super::doc_helpers::documentation_type; 2 | use super::doc_helpers::selection_type; 3 | use super::types::InLineDocRequest; 4 | use super::types::InLineEditPrompt; 5 | use super::types::InLineEditRequest; 6 | use super::types::InLineFixRequest; 7 | use super::types::InLinePromptResponse; 8 | 9 | pub struct MistralLineEditPrompt {} 10 | 11 | impl MistralLineEditPrompt { 12 | pub fn new() -> Self { 13 | Self {} 14 | } 15 | } 16 | 17 | impl MistralLineEditPrompt { 18 | fn extra_code_context(&self, extra_data: &[String]) -> String { 19 | if extra_data.is_empty() { 20 | String::new() 21 | } else { 22 | let extra_data_str = extra_data.join("\n"); 23 | let extra_data_prompt = format!( 24 | "The following code context has been provided to you: 25 | {extra_data_str}\n" 26 | ); 27 | extra_data_prompt 28 | } 29 | } 30 | 31 | /// We try to get the inline edit prompt for the code here, so we can ask 32 | /// the LLM to generate the prompt 33 | fn code_context(&self, above: Option<&String>, below: Option<&String>) -> String { 34 | // Do we have some code context above? 35 | let above = if let Some(above) = above { 36 | format!( 37 | r#"Code Context above the selection: 38 | {above} 39 | "# 40 | ) 41 | } else { 42 | String::new() 43 | }; 44 | 45 | // Do we have some code context below? 
46 | let below = if let Some(below) = below { 47 | format!( 48 | r#"Code Context below the selection: 49 | {below} 50 | "# 51 | ) 52 | } else { 53 | String::new() 54 | }; 55 | 56 | // We send this context to the LLM 57 | let code_context = format!(r#"{above}{below}"#,); 58 | code_context 59 | } 60 | } 61 | 62 | impl InLineEditPrompt for MistralLineEditPrompt { 63 | fn inline_edit(&self, request: InLineEditRequest) -> InLinePromptResponse { 64 | let extra_data_context = self.extra_code_context(request.extra_data()); 65 | let code_context = self.code_context(request.above(), request.below()); 66 | let user_query = request.user_query(); 67 | let language = request.language(); 68 | let file_path = request.file_path(); 69 | let (selection_context, extra_instruction) = if let Some(in_range_code_context) = 70 | request.in_range() 71 | { 72 | ( 73 | format!( 74 | r#"Your task is to rewrite the code below following the instruction: {user_query} 75 | Code you have to edit: 76 | {in_range_code_context}"# 77 | ), 78 | "Rewrite the code [/INST]", 79 | ) 80 | } else { 81 | ( 82 | format!(r#"Follow the user instruction and generate code: {user_query}"#), 83 | "Generate the code [/INST]", 84 | ) 85 | }; 86 | // Now we want to create the prompt for the LLM 87 | let prompt = format!( 88 | r#"[INST] You are an expert software engineer. You have been given some code context below: 89 | {extra_data_context} 90 | {code_context} 91 | {selection_context} 92 | 93 | {extra_instruction} 94 | ```{language} 95 | // FILEPATH: {file_path} 96 | // BEGIN: ed8c6549bwf9 97 | "# 98 | ); 99 | InLinePromptResponse::completion(prompt) 100 | } 101 | 102 | fn inline_fix(&self, request: InLineFixRequest) -> InLinePromptResponse { 103 | let code_context = self.code_context(request.above(), request.below()); 104 | let language = request.language(); 105 | let errors = request.diagnostics_prompts().join("\n"); 106 | let in_range_code_context = request.in_range(); 107 | let file_path = request.file_path(); 108 | let selection_context = format!( 109 | r#"Your task is to fix the errors in the code using the errors provided 110 | {errors} 111 | 112 | Code you have to edit: 113 | {in_range_code_context}"# 114 | ); 115 | let prompt = format!( 116 | r#"[INST] You are an expert software engineer. You have to fix the errors present in the code, the context is given below: 117 | {code_context} 118 | 119 | {selection_context} 120 | 121 | You have to fix the code below, generate the code without any explanation [/INST] 122 | ```{language} 123 | // FILEPATH: {file_path} 124 | // BEGIN: ed8c6549bwf9 125 | "# 126 | ); 127 | InLinePromptResponse::completion(prompt) 128 | } 129 | 130 | fn inline_doc(&self, request: InLineDocRequest) -> InLinePromptResponse { 131 | let comment_type = documentation_type(&request); 132 | let selection_type = selection_type(&request); 133 | let in_range = request.in_range(); 134 | let language = request.language(); 135 | let file_path = request.file_path(); 136 | let prompt = format!( 137 | r#"[INST] You are an expert software engineer. 
You have to generate {comment_type} for {selection_type}, the {selection_type} is given below: 138 | {in_range} 139 | 140 | Add {comment_type} and generate the selected code, do not for the // END marker [/INST] 141 | ```{language} 142 | // FILEPATH: {file_path} 143 | // BEGIN: ed8c6549bwf9 144 | "# 145 | ); 146 | InLinePromptResponse::Completion(prompt) 147 | } 148 | } 149 | 150 | #[cfg(test)] 151 | mod tests { 152 | 153 | use super::InLineEditPrompt; 154 | use super::InLineEditRequest; 155 | use super::MistralLineEditPrompt; 156 | 157 | #[test] 158 | fn test_inline_edit_prompt() { 159 | let prompt = MistralLineEditPrompt {}; 160 | let request = InLineEditRequest::new( 161 | Some("above_context".to_owned()), 162 | Some("below_context".to_owned()), 163 | Some("in_range_context".to_owned()), 164 | "user_query".to_owned(), 165 | "testing/path/something.rs".to_owned(), 166 | vec!["first_symbol".to_owned()], 167 | "rust".to_owned(), 168 | ); 169 | let prompt = prompt.inline_edit(request); 170 | let expected_output = r#"[INST] You are an expert software engineer. You have been given some code context below: 171 | The following code context has been provided to you: 172 | first_symbol 173 | 174 | Code Context above the selection: 175 | above_context 176 | Code Context below the selection: 177 | below_context 178 | 179 | Your task is to rewrite the code below following the instruction: user_query 180 | Code you have to edit: 181 | in_range_context 182 | 183 | Rewrite the code without any explanation [/INST] 184 | ```rust 185 | // FILEPATH: testing/path/something.rs 186 | // BEGIN: 187 | "#; 188 | assert_eq!( 189 | prompt.get_completion().expect("to have completion type"), 190 | expected_output 191 | ); 192 | } 193 | } 194 | -------------------------------------------------------------------------------- /llm_client/src/provider.rs: -------------------------------------------------------------------------------- 1 | //! Contains types for setting the provider for the LLM, we are going to support 2 | //! 3 things for now: 3 | //! - CodeStory 4 | //! - OpenAI 5 | //! - Ollama 6 | //! - Azure 7 | //! 
- together.ai
8 | 
9 | use crate::clients::types::LLMType;
10 | 
11 | #[derive(Debug, Clone, serde::Deserialize, serde::Serialize, Hash, PartialEq, Eq)]
12 | pub struct AzureOpenAIDeploymentId {
13 |     pub deployment_id: String,
14 | }
15 | 
16 | #[derive(Debug, Clone, serde::Deserialize, serde::Serialize, Hash, PartialEq, Eq)]
17 | pub struct CodeStoryLLMType {
18 |     // shoe horning the llm type here so we can provide the correct api keys
19 |     pub llm_type: Option<LLMType>,
20 | }
21 | 
22 | #[derive(Debug, Clone, serde::Deserialize, serde::Serialize, Hash, PartialEq, Eq)]
23 | pub enum LLMProvider {
24 |     OpenAI,
25 |     TogetherAI,
26 |     Ollama,
27 |     LMStudio,
28 |     CodeStory(CodeStoryLLMType),
29 |     Azure(AzureOpenAIDeploymentId),
30 | }
31 | 
32 | #[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
33 | pub enum LLMProviderAPIKeys {
34 |     OpenAI(OpenAIProvider),
35 |     TogetherAI(TogetherAIProvider),
36 |     Ollama(OllamaProvider),
37 |     OpenAIAzureConfig(AzureConfig),
38 |     LMStudio(LMStudioConfig),
39 |     CodeStory,
40 | }
41 | 
42 | impl LLMProviderAPIKeys {
43 |     pub fn is_openai(&self) -> bool {
44 |         matches!(self, LLMProviderAPIKeys::OpenAI(_))
45 |     }
46 | 
47 |     pub fn provider_type(&self) -> LLMProvider {
48 |         match self {
49 |             LLMProviderAPIKeys::OpenAI(_) => LLMProvider::OpenAI,
50 |             LLMProviderAPIKeys::TogetherAI(_) => LLMProvider::TogetherAI,
51 |             LLMProviderAPIKeys::Ollama(_) => LLMProvider::Ollama,
52 |             LLMProviderAPIKeys::OpenAIAzureConfig(_) => {
53 |                 LLMProvider::Azure(AzureOpenAIDeploymentId {
54 |                     deployment_id: "".to_owned(),
55 |                 })
56 |             }
57 |             LLMProviderAPIKeys::LMStudio(_) => LLMProvider::LMStudio,
58 |             LLMProviderAPIKeys::CodeStory => {
59 |                 LLMProvider::CodeStory(CodeStoryLLMType { llm_type: None })
60 |             }
61 |         }
62 |     }
63 | 
64 |     // Gets the relevant key from the llm provider
65 |     pub fn key(&self, llm_provider: &LLMProvider) -> Option<LLMProviderAPIKeys> {
66 |         match llm_provider {
67 |             LLMProvider::OpenAI => {
68 |                 if let LLMProviderAPIKeys::OpenAI(key) = self {
69 |                     Some(LLMProviderAPIKeys::OpenAI(key.clone()))
70 |                 } else {
71 |                     None
72 |                 }
73 |             }
74 |             LLMProvider::TogetherAI => {
75 |                 if let LLMProviderAPIKeys::TogetherAI(key) = self {
76 |                     Some(LLMProviderAPIKeys::TogetherAI(key.clone()))
77 |                 } else {
78 |                     None
79 |                 }
80 |             }
81 |             LLMProvider::Ollama => {
82 |                 if let LLMProviderAPIKeys::Ollama(key) = self {
83 |                     Some(LLMProviderAPIKeys::Ollama(key.clone()))
84 |                 } else {
85 |                     None
86 |                 }
87 |             }
88 |             LLMProvider::LMStudio => {
89 |                 if let LLMProviderAPIKeys::LMStudio(key) = self {
90 |                     Some(LLMProviderAPIKeys::LMStudio(key.clone()))
91 |                 } else {
92 |                     None
93 |                 }
94 |             }
95 |             // Azure is weird, so we have to copy the config which we get
96 |             // from the provider keys and then set its deployment id
97 |             // properly for the azure provider. If it's set to "" that means
98 |             // we do not have a deployment key and we should be returning
99 |             // quickly here.
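            // For illustration only (hypothetical caller code, assuming a
            // variable named `api_keys` of type LLMProviderAPIKeys):
            //     let together_key = api_keys.key(&LLMProvider::TogetherAI);
            // this returns Some(LLMProviderAPIKeys::TogetherAI(..)) only when the
            // stored keys are for the requested provider, and None for a mismatch.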
100 | // NOTE: We should change this to using the codestory configuration 101 | // and make calls appropriately, for now this is fine 102 | LLMProvider::Azure(deployment_id) => { 103 | if deployment_id.deployment_id == "" { 104 | return None; 105 | } 106 | if let LLMProviderAPIKeys::OpenAIAzureConfig(key) = self { 107 | let mut azure_config = key.clone(); 108 | azure_config.deployment_id = deployment_id.deployment_id.to_owned(); 109 | Some(LLMProviderAPIKeys::OpenAIAzureConfig(azure_config)) 110 | } else { 111 | None 112 | } 113 | } 114 | LLMProvider::CodeStory(_) => Some(LLMProviderAPIKeys::CodeStory), 115 | } 116 | } 117 | } 118 | 119 | #[derive(Debug, Clone, serde::Deserialize, serde::Serialize)] 120 | pub struct OpenAIProvider { 121 | pub api_key: String, 122 | } 123 | 124 | #[derive(Debug, Clone, serde::Deserialize, serde::Serialize)] 125 | pub struct TogetherAIProvider { 126 | pub api_key: String, 127 | } 128 | 129 | impl TogetherAIProvider { 130 | pub fn new(api_key: String) -> Self { 131 | Self { api_key } 132 | } 133 | } 134 | 135 | #[derive(Debug, Clone, serde::Deserialize, serde::Serialize)] 136 | pub struct OllamaProvider {} 137 | 138 | #[derive(Debug, Clone, serde::Deserialize, serde::Serialize)] 139 | pub struct AzureConfig { 140 | pub deployment_id: String, 141 | pub api_base: String, 142 | pub api_key: String, 143 | pub api_version: String, 144 | } 145 | 146 | #[derive(Debug, Clone, serde::Deserialize, serde::Serialize)] 147 | pub struct LMStudioConfig { 148 | pub api_base: String, 149 | } 150 | 151 | impl LMStudioConfig { 152 | pub fn api_base(&self) -> &str { 153 | &self.api_base 154 | } 155 | } 156 | 157 | #[derive(Debug, Clone, serde::Deserialize, serde::Serialize)] 158 | pub struct CodeStoryConfig { 159 | pub llm_type: LLMType, 160 | } 161 | 162 | #[cfg(test)] 163 | mod tests { 164 | use super::{AzureOpenAIDeploymentId, LLMProvider, LLMProviderAPIKeys}; 165 | 166 | #[test] 167 | fn test_reading_from_string_for_provider() { 168 | let provider = LLMProvider::Azure(AzureOpenAIDeploymentId { 169 | deployment_id: "testing".to_owned(), 170 | }); 171 | let string_provider = serde_json::to_string(&provider).expect("to work"); 172 | assert_eq!( 173 | string_provider, 174 | "{\"Azure\":{\"deployment_id\":\"testing\"}}" 175 | ); 176 | let provider = LLMProvider::Ollama; 177 | let string_provider = serde_json::to_string(&provider).expect("to work"); 178 | assert_eq!(string_provider, "\"Ollama\""); 179 | } 180 | 181 | #[test] 182 | fn test_reading_provider_keys() { 183 | let provider_keys = LLMProviderAPIKeys::OpenAI(super::OpenAIProvider { 184 | api_key: "testing".to_owned(), 185 | }); 186 | let string_provider_keys = serde_json::to_string(&provider_keys).expect("to work"); 187 | assert_eq!(string_provider_keys, "",); 188 | } 189 | } 190 | -------------------------------------------------------------------------------- /llm_client/src/clients/codestory.rs: -------------------------------------------------------------------------------- 1 | use async_openai::types::CreateChatCompletionStreamResponse; 2 | use async_trait::async_trait; 3 | use eventsource_stream::Eventsource; 4 | use futures::StreamExt; 5 | use tokio::sync::mpsc::UnboundedSender; 6 | 7 | use crate::provider::{LLMProvider, LLMProviderAPIKeys}; 8 | 9 | use super::types::{ 10 | LLMClient, LLMClientCompletionRequest, LLMClientCompletionResponse, 11 | LLMClientCompletionStringRequest, LLMClientError, LLMClientRole, LLMType, 12 | }; 13 | 14 | #[derive(serde::Serialize, serde::Deserialize, Debug)] 15 | struct 
LMStudioResponse { 16 | model: String, 17 | choices: Vec, 18 | } 19 | 20 | #[derive(serde::Serialize, serde::Deserialize, Debug)] 21 | struct Choice { 22 | text: String, 23 | } 24 | 25 | pub struct CodeStoryClient { 26 | client: reqwest::Client, 27 | api_base: String, 28 | } 29 | 30 | #[derive(serde::Serialize, serde::Deserialize, Debug, Clone)] 31 | struct CodeStoryMessage { 32 | role: String, 33 | content: String, 34 | } 35 | 36 | #[derive(serde::Serialize, serde::Deserialize, Debug, Clone)] 37 | struct CodeStoryRequestOptions { 38 | temperature: f32, 39 | } 40 | 41 | #[derive(serde::Serialize, serde::Deserialize, Debug, Clone)] 42 | struct CodeStoryRequest { 43 | messages: Vec, 44 | options: CodeStoryRequestOptions, 45 | } 46 | 47 | impl CodeStoryRequest { 48 | fn from_chat_request(request: LLMClientCompletionRequest) -> Self { 49 | Self { 50 | messages: request 51 | .messages() 52 | .into_iter() 53 | .map(|message| match message.role() { 54 | LLMClientRole::System => CodeStoryMessage { 55 | role: "system".to_owned(), 56 | content: message.content().to_owned(), 57 | }, 58 | LLMClientRole::User => CodeStoryMessage { 59 | role: "user".to_owned(), 60 | content: message.content().to_owned(), 61 | }, 62 | LLMClientRole::Function => CodeStoryMessage { 63 | role: "function".to_owned(), 64 | content: message.content().to_owned(), 65 | }, 66 | LLMClientRole::Assistant => CodeStoryMessage { 67 | role: "assistant".to_owned(), 68 | content: message.content().to_owned(), 69 | }, 70 | }) 71 | .collect(), 72 | options: CodeStoryRequestOptions { 73 | temperature: request.temperature(), 74 | }, 75 | } 76 | } 77 | } 78 | 79 | impl CodeStoryClient { 80 | pub fn new(api_base: &str) -> Self { 81 | Self { 82 | api_base: api_base.to_owned(), 83 | client: reqwest::Client::new(), 84 | } 85 | } 86 | 87 | pub fn gpt3_endpoint(&self, api_base: &str) -> String { 88 | format!("{api_base}/chat-3") 89 | } 90 | 91 | pub fn gpt4_endpoint(&self, api_base: &str) -> String { 92 | format!("{api_base}/chat-4") 93 | } 94 | 95 | pub fn model_name(&self, model: &LLMType) -> Result { 96 | match model { 97 | LLMType::GPT3_5_16k => Ok("gpt-3.5-turbo-16k-0613".to_owned()), 98 | LLMType::Gpt4 => Ok("gpt-4-0613".to_owned()), 99 | _ => Err(LLMClientError::UnSupportedModel), 100 | } 101 | } 102 | 103 | pub fn model_endpoint(&self, model: &LLMType) -> Result { 104 | match model { 105 | LLMType::GPT3_5_16k => Ok(self.gpt3_endpoint(&self.api_base)), 106 | LLMType::Gpt4 => Ok(self.gpt4_endpoint(&self.api_base)), 107 | _ => Err(LLMClientError::UnSupportedModel), 108 | } 109 | } 110 | } 111 | 112 | #[async_trait] 113 | impl LLMClient for CodeStoryClient { 114 | fn client(&self) -> &LLMProvider { 115 | &LLMProvider::LMStudio 116 | } 117 | 118 | async fn completion( 119 | &self, 120 | api_key: LLMProviderAPIKeys, 121 | request: LLMClientCompletionRequest, 122 | ) -> Result { 123 | let (sender, _receiver) = tokio::sync::mpsc::unbounded_channel(); 124 | self.stream_completion(api_key, request, sender).await 125 | } 126 | 127 | async fn stream_completion( 128 | &self, 129 | _api_key: LLMProviderAPIKeys, 130 | request: LLMClientCompletionRequest, 131 | sender: UnboundedSender, 132 | ) -> Result { 133 | let model = self.model_name(request.model())?; 134 | let endpoint = self.model_endpoint(request.model())?; 135 | 136 | let request = CodeStoryRequest::from_chat_request(request); 137 | let mut response_stream = self 138 | .client 139 | .post(endpoint) 140 | .json(&request) 141 | .send() 142 | .await? 
143 | .bytes_stream() 144 | .eventsource(); 145 | 146 | let mut buffered_stream = "".to_owned(); 147 | while let Some(event) = response_stream.next().await { 148 | match event { 149 | Ok(event) => { 150 | if &event.data == "[DONE]" { 151 | continue; 152 | } 153 | // we just proxy back the openai response back here 154 | let response = 155 | serde_json::from_str::(&event.data); 156 | match response { 157 | Ok(response) => { 158 | let delta = response 159 | .choices 160 | .get(0) 161 | .map(|choice| choice.delta.content.to_owned()) 162 | .flatten() 163 | .unwrap_or("".to_owned()); 164 | buffered_stream.push_str(&delta); 165 | println!("{}", &buffered_stream); 166 | sender.send(LLMClientCompletionResponse::new( 167 | buffered_stream.to_owned(), 168 | Some(delta), 169 | model.to_owned(), 170 | ))?; 171 | } 172 | Err(e) => { 173 | dbg!(e); 174 | } 175 | } 176 | } 177 | Err(e) => { 178 | dbg!(e); 179 | } 180 | } 181 | } 182 | Ok(buffered_stream) 183 | } 184 | 185 | async fn stream_prompt_completion( 186 | &self, 187 | api_key: LLMProviderAPIKeys, 188 | request: LLMClientCompletionStringRequest, 189 | sender: UnboundedSender, 190 | ) -> Result { 191 | Err(LLMClientError::UnSupportedModel) 192 | } 193 | } 194 | -------------------------------------------------------------------------------- /llm_prompts/src/reranking/openai.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | 3 | use llm_client::clients::types::{LLMClientCompletionRequest, LLMClientMessage}; 4 | 5 | use super::types::{ 6 | CodeSpan, CodeSpanDigest, ReRankCodeSpan, ReRankCodeSpanError, ReRankCodeSpanRequest, 7 | ReRankCodeSpanResponse, ReRankListWiseResponse, ReRankPointWisePrompt, ReRankStrategy, 8 | }; 9 | 10 | pub struct OpenAIReRank {} 11 | 12 | impl OpenAIReRank { 13 | pub fn new() -> Self { 14 | Self {} 15 | } 16 | } 17 | 18 | impl OpenAIReRank { 19 | pub fn pointwise_reranking(&self, request: ReRankCodeSpanRequest) -> ReRankCodeSpanResponse { 20 | let code_span_digests = CodeSpan::to_digests(request.code_spans().to_vec()); 21 | // Now we query the LLM for the pointwise reranking here 22 | let user_query = request.user_query().to_owned(); 23 | let prompts = code_span_digests 24 | .into_iter() 25 | .map(|code_span_digest| { 26 | let user_query = user_query.to_owned(); 27 | let hash = code_span_digest.hash(); 28 | let data = code_span_digest.data(); 29 | let prompt = format!(r#"You are an expert software developer responsible for helping detect whether the retrieved snippet of code is relevant to the query. For a given input, you need to output a single word: "Yes" or "No" indicating the retrieved snippet is relevant to the query. 30 | Query: Where is the client for OpenAI defined? 31 | Code Snippet: 32 | ```/Users/skcd/client/openai.rs 33 | pub struct OpenAIClient {{}} 34 | 35 | impl OpenAIClient {{ 36 | pub fn new() -> Self {{ 37 | Self {{}} 38 | }} 39 | ``` 40 | Relevant: Yes 41 | 42 | Query: Where do we handle the errors in the webview? 
43 | Snippet: 44 | ```/Users/skcd/algorithm/dfs.rs 45 | pub fn dfs(graph: &Graph, start: NodeId) -> Vec {{ 46 | let mut visited = HashSet::new(); 47 | let mut stack = vec![start]; 48 | let mut result = vec![]; 49 | while let Some(node) = stack.pop() {{ 50 | if visited.contains(&node) {{ 51 | continue; 52 | }} 53 | visited.insert(node); 54 | result.push(node); 55 | for neighbor in graph.neighbors(node) {{ 56 | stack.push(neighbor); 57 | }} 58 | }} 59 | result 60 | }} 61 | ``` 62 | Relevant: No 63 | 64 | Query: {user_query} 65 | Snippet: 66 | ```{hash} 67 | {data} 68 | ``` 69 | Relevant:"#); 70 | let llm_prompt = LLMClientCompletionRequest::from_messages( 71 | vec![LLMClientMessage::system(prompt)], 72 | request.llm_type().clone(), 73 | ); 74 | ReRankPointWisePrompt::new_message_request(llm_prompt, code_span_digest) 75 | }) 76 | .collect(); 77 | 78 | ReRankCodeSpanResponse::PointWise(prompts) 79 | } 80 | 81 | pub fn listwise_reranking(&self, request: ReRankCodeSpanRequest) -> ReRankCodeSpanResponse { 82 | // First we get the code spans which are present here cause they are important 83 | let code_spans = request.code_spans().to_vec(); 84 | let user_query = request.user_query().to_owned(); 85 | // Now we need to generate the prompt for this 86 | let code_span_digests = CodeSpan::to_digests(code_spans); 87 | let code_snippets = code_span_digests 88 | .iter() 89 | .map(|code_span_digest| { 90 | let identifier = code_span_digest.hash(); 91 | let data = code_span_digest.data(); 92 | let span_identifier = code_span_digest.get_span_identifier(); 93 | format!("{identifier}\n```\n{span_identifier}\n{data}\n```\n") 94 | }) 95 | .collect::>() 96 | .join("\n"); 97 | // Now we create the prompt for this reranking 98 | let prompt = format!( 99 | r#"You are an expert at ranking the code snippets for the user query. You have the order the list of code snippets from the most relevant to the least relevant. 
As an example 100 | 101 | add.rs::0 102 | ``` 103 | // FILEPATH: add.rs:0-2 104 | fn add(a: i32, b: i32) -> i32 {{ 105 | a + b 106 | }} 107 | ``` 108 | 109 | subtract.rs::0 110 | ``` 111 | // FILEPATH: subtract.rs:0-2 112 | fn subtract(a: i32, b: i32) -> i32 {{ 113 | a - b 114 | }} 115 | ``` 116 | 117 | 118 | And if you thought the code snippet add.rs::0 is more relevant than subtract.rs::0 then you would rank it as: 119 | 120 | add.rs::0 121 | subtract.rs::0 122 | 123 | 124 | The user query might contain a selection of line ranges in the following format: 125 | [#file:foo.rs:4-10](values:file:foo.rs:4-10) this means the line range from 4 to 10 is selected by the user in the file foo.rs 126 | 127 | The user has asked the following query: {user_query} 128 | 129 | {code_snippets} 130 | 131 | 132 | As a reminder the user query is: 133 | 134 | {user_query} 135 | 136 | 137 | The final reranking ordered from the most relevant to the least relevant is: 138 | "# 139 | ); 140 | let llm_prompt = LLMClientCompletionRequest::from_messages( 141 | vec![LLMClientMessage::system(prompt)], 142 | request.llm_type().clone(), 143 | ); 144 | ReRankCodeSpanResponse::listwise_message(llm_prompt, code_span_digests) 145 | } 146 | } 147 | 148 | impl ReRankCodeSpan for OpenAIReRank { 149 | fn rerank_prompt( 150 | &self, 151 | request: ReRankCodeSpanRequest, 152 | ) -> Result { 153 | Ok(match request.strategy() { 154 | ReRankStrategy::ListWise => self.listwise_reranking(request), 155 | ReRankStrategy::PointWise => { 156 | // We need to generate the prompt for this 157 | self.pointwise_reranking(request) 158 | } 159 | }) 160 | } 161 | 162 | fn parse_listwise_output( 163 | &self, 164 | llm_output: String, 165 | rerank_request: ReRankListWiseResponse, 166 | ) -> Result, ReRankCodeSpanError> { 167 | // In case of OpenAI things are a bit easier, since the list is properly formatted 168 | // almost always and we can just grab the ids from the list and rank the 169 | // code snippets based that. 
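        // For illustration, with a hypothetical model reply such as:
        //     subtract.rs::0
        //     add.rs::0
        // the loop below looks each id up in the hash map and moves those digests
        // to the front in that order; any digest the model never mentions is
        // appended afterwards so no code span is dropped from the result.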
170 | let mut output = llm_output.split("\n"); 171 | let mut code_spans_mapping: HashMap = rerank_request 172 | .code_span_digests 173 | .into_iter() 174 | .map(|code_span_digest| (code_span_digest.hash().to_owned(), code_span_digest)) 175 | .collect(); 176 | let mut reranked_code_snippets: Vec = vec![]; 177 | while let Some(line) = output.next() { 178 | let line_output = line.trim(); 179 | if line_output.contains("") { 180 | break; 181 | } 182 | let possible_id = line.trim(); 183 | if let Some(code_span) = code_spans_mapping.remove(possible_id) { 184 | reranked_code_snippets.push(code_span); 185 | } 186 | } 187 | // Add back the remaining code snippets to the list 188 | code_spans_mapping.into_iter().for_each(|(_, code_span)| { 189 | reranked_code_snippets.push(code_span); 190 | }); 191 | Ok(reranked_code_snippets) 192 | } 193 | } 194 | -------------------------------------------------------------------------------- /llm_client/src/clients/lmstudio.rs: -------------------------------------------------------------------------------- 1 | use async_trait::async_trait; 2 | use eventsource_stream::Eventsource; 3 | use futures::StreamExt; 4 | use tokio::sync::mpsc::UnboundedSender; 5 | 6 | use crate::provider::{LLMProvider, LLMProviderAPIKeys}; 7 | 8 | use super::types::{ 9 | LLMClient, LLMClientCompletionRequest, LLMClientCompletionResponse, 10 | LLMClientCompletionStringRequest, LLMClientError, LLMClientRole, 11 | }; 12 | 13 | #[derive(serde::Serialize, serde::Deserialize, Debug)] 14 | struct LMStudioResponse { 15 | model: String, 16 | choices: Vec, 17 | } 18 | 19 | #[derive(serde::Serialize, serde::Deserialize, Debug)] 20 | struct Choice { 21 | text: String, 22 | } 23 | 24 | pub struct LMStudioClient { 25 | client: reqwest::Client, 26 | } 27 | 28 | #[derive(serde::Serialize, serde::Deserialize, Debug, Clone)] 29 | struct LLMStudioMessage { 30 | role: String, 31 | content: String, 32 | } 33 | 34 | #[derive(serde::Serialize, serde::Deserialize, Debug, Clone)] 35 | struct LMStudioRequest { 36 | prompt: Option, 37 | #[serde(skip_serializing_if = "Option::is_none")] 38 | messages: Option>, 39 | temperature: f32, 40 | stream: bool, 41 | #[serde(skip_serializing_if = "Option::is_none")] 42 | frequency_penalty: Option, 43 | // set the max tokens to -1 so we get as much completion as possible 44 | max_tokens: i32, 45 | } 46 | 47 | impl LMStudioRequest { 48 | fn from_string_request(request: LLMClientCompletionStringRequest) -> Self { 49 | Self { 50 | prompt: Some(request.prompt().to_owned()), 51 | messages: None, 52 | temperature: request.temperature(), 53 | stream: true, 54 | frequency_penalty: request.frequency_penalty(), 55 | max_tokens: -1, 56 | } 57 | } 58 | 59 | fn from_chat_request(request: LLMClientCompletionRequest) -> Self { 60 | Self { 61 | prompt: None, 62 | messages: Some( 63 | request 64 | .messages() 65 | .into_iter() 66 | .map(|message| match message.role() { 67 | LLMClientRole::System => LLMStudioMessage { 68 | role: "system".to_owned(), 69 | content: message.content().to_owned(), 70 | }, 71 | LLMClientRole::User => LLMStudioMessage { 72 | role: "user".to_owned(), 73 | content: message.content().to_owned(), 74 | }, 75 | LLMClientRole::Function => LLMStudioMessage { 76 | role: "function".to_owned(), 77 | content: message.content().to_owned(), 78 | }, 79 | LLMClientRole::Assistant => LLMStudioMessage { 80 | role: "assistant".to_owned(), 81 | content: message.content().to_owned(), 82 | }, 83 | }) 84 | .collect(), 85 | ), 86 | temperature: request.temperature(), 87 | stream: true, 88 | 
frequency_penalty: request.frequency_penalty(), 89 | max_tokens: -1, 90 | } 91 | } 92 | } 93 | 94 | impl LMStudioClient { 95 | pub fn new() -> Self { 96 | Self { 97 | client: reqwest::Client::new(), 98 | } 99 | } 100 | 101 | pub fn completion_endpoint(&self, base_url: &str) -> String { 102 | format!("{}/v1/completions", base_url) 103 | } 104 | 105 | pub fn chat_endpoint(&self, base_url: &str) -> String { 106 | format!("{}/v1/chat/completions", base_url) 107 | } 108 | 109 | pub fn generate_base_url(&self, api_key: LLMProviderAPIKeys) -> Result { 110 | match api_key { 111 | LLMProviderAPIKeys::LMStudio(api_key) => Ok(api_key.api_base().to_owned()), 112 | _ => Err(LLMClientError::UnSupportedModel), 113 | } 114 | } 115 | } 116 | 117 | #[async_trait] 118 | impl LLMClient for LMStudioClient { 119 | fn client(&self) -> &LLMProvider { 120 | &LLMProvider::LMStudio 121 | } 122 | 123 | async fn completion( 124 | &self, 125 | api_key: LLMProviderAPIKeys, 126 | request: LLMClientCompletionRequest, 127 | ) -> Result { 128 | let (sender, _receiver) = tokio::sync::mpsc::unbounded_channel(); 129 | self.stream_completion(api_key, request, sender).await 130 | } 131 | 132 | async fn stream_completion( 133 | &self, 134 | api_key: LLMProviderAPIKeys, 135 | request: LLMClientCompletionRequest, 136 | sender: UnboundedSender, 137 | ) -> Result { 138 | let base_url = self.generate_base_url(api_key)?; 139 | let endpoint = self.chat_endpoint(&base_url); 140 | 141 | let request = LMStudioRequest::from_chat_request(request); 142 | let mut response_stream = self 143 | .client 144 | .post(endpoint) 145 | .json(&request) 146 | .send() 147 | .await? 148 | .bytes_stream() 149 | .eventsource(); 150 | 151 | let mut buffered_stream = "".to_owned(); 152 | while let Some(event) = response_stream.next().await { 153 | match event { 154 | Ok(event) => { 155 | if &event.data == "[DONE]" { 156 | continue; 157 | } 158 | let value = serde_json::from_str::(&event.data)?; 159 | buffered_stream = buffered_stream + &value.choices[0].text; 160 | sender.send(LLMClientCompletionResponse::new( 161 | buffered_stream.to_owned(), 162 | Some(value.choices[0].text.to_owned()), 163 | value.model, 164 | ))?; 165 | } 166 | Err(e) => { 167 | dbg!(e); 168 | } 169 | } 170 | } 171 | Ok(buffered_stream) 172 | } 173 | 174 | async fn stream_prompt_completion( 175 | &self, 176 | api_key: LLMProviderAPIKeys, 177 | request: LLMClientCompletionStringRequest, 178 | sender: UnboundedSender, 179 | ) -> Result { 180 | let base_url = self.generate_base_url(api_key)?; 181 | let endpoint = self.completion_endpoint(&base_url); 182 | 183 | let request = LMStudioRequest::from_string_request(request); 184 | let mut response_stream = self 185 | .client 186 | .post(endpoint) 187 | .json(&request) 188 | .send() 189 | .await? 
190 | .bytes_stream() 191 | .eventsource(); 192 | 193 | let mut buffered_stream = "".to_owned(); 194 | while let Some(event) = response_stream.next().await { 195 | match event { 196 | Ok(event) => { 197 | if &event.data == "[DONE]" { 198 | continue; 199 | } 200 | let value = serde_json::from_str::(&event.data)?; 201 | buffered_stream = buffered_stream + &value.choices[0].text; 202 | sender.send(LLMClientCompletionResponse::new( 203 | buffered_stream.to_owned(), 204 | Some(value.choices[0].text.to_owned()), 205 | value.model, 206 | ))?; 207 | } 208 | Err(e) => { 209 | dbg!(e); 210 | } 211 | } 212 | } 213 | Ok(buffered_stream) 214 | } 215 | } 216 | -------------------------------------------------------------------------------- /llm_client/src/clients/togetherai.rs: -------------------------------------------------------------------------------- 1 | use async_trait::async_trait; 2 | use eventsource_stream::Eventsource; 3 | use futures::StreamExt; 4 | use tokio::sync::mpsc::UnboundedSender; 5 | 6 | use crate::provider::LLMProviderAPIKeys; 7 | 8 | use super::types::LLMClient; 9 | use super::types::LLMClientCompletionRequest; 10 | use super::types::LLMClientCompletionResponse; 11 | use super::types::LLMClientCompletionStringRequest; 12 | use super::types::LLMClientError; 13 | use super::types::LLMType; 14 | 15 | pub struct TogetherAIClient { 16 | pub client: reqwest::Client, 17 | pub base_url: String, 18 | } 19 | 20 | #[derive(serde::Serialize, Debug, Clone)] 21 | struct TogetherAIRequest { 22 | prompt: String, 23 | model: String, 24 | temperature: f32, 25 | stream_tokens: bool, 26 | #[serde(skip_serializing_if = "Option::is_none")] 27 | frequency_penalty: Option, 28 | } 29 | 30 | #[derive(serde::Serialize, serde::Deserialize, Debug)] 31 | struct TogetherAIResponse { 32 | choices: Vec, 33 | // id: String, 34 | // token: Token, 35 | } 36 | 37 | #[derive(serde::Serialize, serde::Deserialize, Debug)] 38 | struct Choice { 39 | text: String, 40 | } 41 | 42 | #[derive(serde::Serialize, serde::Deserialize, Debug)] 43 | struct Token { 44 | id: i32, 45 | text: String, 46 | logprob: i32, 47 | special: bool, 48 | } 49 | 50 | impl TogetherAIRequest { 51 | pub fn from_request(request: LLMClientCompletionRequest) -> Self { 52 | Self { 53 | prompt: { 54 | if request.messages().len() == 1 { 55 | request.messages()[0].content().to_owned() 56 | } else { 57 | request 58 | .messages() 59 | .into_iter() 60 | .map(|message| message.content().to_owned()) 61 | .collect::>() 62 | .join("\n") 63 | } 64 | }, 65 | // TODO(skcd): Proper error handling here 66 | model: TogetherAIClient::model_str(request.model()).expect("to be present"), 67 | temperature: request.temperature(), 68 | stream_tokens: true, 69 | frequency_penalty: request.frequency_penalty(), 70 | } 71 | } 72 | 73 | pub fn from_string_request(request: LLMClientCompletionStringRequest) -> Self { 74 | Self { 75 | prompt: request.prompt().to_owned(), 76 | model: TogetherAIClient::model_str(request.model()).expect("to be present"), 77 | temperature: request.temperature(), 78 | stream_tokens: true, 79 | frequency_penalty: request.frequency_penalty(), 80 | } 81 | } 82 | } 83 | 84 | impl TogetherAIClient { 85 | pub fn new() -> Self { 86 | let client = reqwest::Client::new(); 87 | Self { 88 | client, 89 | base_url: "https://api.together.xyz".to_owned(), 90 | } 91 | } 92 | 93 | pub fn inference_endpoint(&self) -> String { 94 | format!("{}/inference", self.base_url) 95 | } 96 | 97 | pub fn completion_endpoint(&self) -> String { 98 | format!("{}/completions", 
self.base_url) 99 | } 100 | 101 | pub fn model_str(model: &LLMType) -> Option { 102 | match model { 103 | LLMType::Mixtral => Some("mistralai/Mixtral-8x7B-Instruct-v0.1".to_owned()), 104 | LLMType::MistralInstruct => Some("mistralai/Mistral-7B-Instruct-v0.1".to_owned()), 105 | LLMType::Custom(model) => Some(model.to_owned()), 106 | _ => None, 107 | } 108 | } 109 | 110 | fn generate_together_ai_bearer_key( 111 | &self, 112 | api_key: LLMProviderAPIKeys, 113 | ) -> Result { 114 | match api_key { 115 | LLMProviderAPIKeys::TogetherAI(api_key) => Ok(api_key.api_key), 116 | _ => Err(LLMClientError::WrongAPIKeyType), 117 | } 118 | } 119 | } 120 | 121 | #[async_trait] 122 | impl LLMClient for TogetherAIClient { 123 | fn client(&self) -> &crate::provider::LLMProvider { 124 | &crate::provider::LLMProvider::TogetherAI 125 | } 126 | 127 | async fn completion( 128 | &self, 129 | api_key: LLMProviderAPIKeys, 130 | request: LLMClientCompletionRequest, 131 | ) -> Result { 132 | let (sender, _receiver) = tokio::sync::mpsc::unbounded_channel(); 133 | self.stream_completion(api_key, request, sender).await 134 | } 135 | 136 | async fn stream_prompt_completion( 137 | &self, 138 | api_key: LLMProviderAPIKeys, 139 | request: LLMClientCompletionStringRequest, 140 | sender: UnboundedSender, 141 | ) -> Result { 142 | let model = TogetherAIClient::model_str(request.model()); 143 | if model.is_none() { 144 | return Err(LLMClientError::FailedToGetResponse); 145 | } 146 | let model = model.expect("is_none check above to work"); 147 | let together_ai_request = TogetherAIRequest::from_string_request(request); 148 | let mut response_stream = self 149 | .client 150 | .post(self.inference_endpoint()) 151 | .bearer_auth(self.generate_together_ai_bearer_key(api_key)?.to_owned()) 152 | .json(&together_ai_request) 153 | .send() 154 | .await? 155 | .bytes_stream() 156 | .eventsource(); 157 | 158 | let mut buffered_string = "".to_owned(); 159 | while let Some(event) = response_stream.next().await { 160 | match event { 161 | Ok(event) => { 162 | if &event.data == "[DONE]" { 163 | continue; 164 | } 165 | let value = serde_json::from_str::(&event.data)?; 166 | buffered_string = buffered_string + &value.choices[0].text; 167 | sender.send(LLMClientCompletionResponse::new( 168 | buffered_string.to_owned(), 169 | Some(value.choices[0].text.to_owned()), 170 | model.to_owned(), 171 | ))?; 172 | } 173 | Err(e) => { 174 | dbg!(e); 175 | } 176 | } 177 | } 178 | 179 | Ok(buffered_string) 180 | } 181 | 182 | async fn stream_completion( 183 | &self, 184 | api_key: LLMProviderAPIKeys, 185 | request: LLMClientCompletionRequest, 186 | sender: UnboundedSender, 187 | ) -> Result { 188 | let model = TogetherAIClient::model_str(request.model()); 189 | if model.is_none() { 190 | return Err(LLMClientError::FailedToGetResponse); 191 | } 192 | let model = model.expect("is_none check above to work"); 193 | let together_ai_request = TogetherAIRequest::from_request(request); 194 | let mut response_stream = self 195 | .client 196 | .post(self.inference_endpoint()) 197 | .bearer_auth(self.generate_together_ai_bearer_key(api_key)?.to_owned()) 198 | .json(&together_ai_request) 199 | .send() 200 | .await? 
201 | .bytes_stream() 202 | .eventsource(); 203 | 204 | let mut buffered_string = "".to_owned(); 205 | while let Some(event) = response_stream.next().await { 206 | match event { 207 | Ok(event) => { 208 | if &event.data == "[DONE]" { 209 | continue; 210 | } 211 | let value = serde_json::from_str::(&event.data)?; 212 | buffered_string = buffered_string + &value.choices[0].text; 213 | sender.send(LLMClientCompletionResponse::new( 214 | buffered_string.to_owned(), 215 | Some(value.choices[0].text.to_owned()), 216 | model.to_owned(), 217 | ))?; 218 | } 219 | Err(e) => { 220 | dbg!(e); 221 | } 222 | } 223 | } 224 | 225 | Ok(buffered_string) 226 | } 227 | } 228 | -------------------------------------------------------------------------------- /llm_prompts/src/reranking/mistral.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | 3 | use llm_client::clients::types::LLMClientCompletionStringRequest; 4 | 5 | use super::types::{ 6 | CodeSpan, CodeSpanDigest, ReRankCodeSpan, ReRankCodeSpanError, ReRankCodeSpanRequest, 7 | ReRankCodeSpanResponse, ReRankListWiseResponse, ReRankPointWisePrompt, ReRankStrategy, 8 | }; 9 | 10 | #[derive(Default)] 11 | pub struct MistralReRank {} 12 | 13 | impl MistralReRank { 14 | pub fn new() -> Self { 15 | Default::default() 16 | } 17 | } 18 | 19 | impl MistralReRank { 20 | pub fn pointwise_reranking(&self, request: ReRankCodeSpanRequest) -> ReRankCodeSpanResponse { 21 | let code_span_digests = CodeSpan::to_digests(request.code_spans().to_vec()); 22 | // Now we query the LLM for the pointwise reranking here 23 | let user_query = request.user_query().to_owned(); 24 | let prompts = code_span_digests 25 | .into_iter() 26 | .map(|code_span_digest| { 27 | let user_query = user_query.to_owned(); 28 | let hash = code_span_digest.hash(); 29 | let data = code_span_digest.data(); 30 | let prompt = format!(r#"[INST] You are an expert software developer responsible for helping detect whether the retrieved snippet of code is relevant to the query. For a given input, you need to output a single word: "Yes" or "No" indicating the retrieved snippet is relevant to the query. 31 | Query: Where is the client for OpenAI defined? 32 | Code Snippet: 33 | ```/Users/skcd/client/openai.rs 34 | pub struct OpenAIClient {{}} 35 | 36 | impl OpenAIClient {{ 37 | pub fn new() -> Self {{ 38 | Self {{}} 39 | }} 40 | ``` 41 | Relevant: Yes 42 | 43 | Query: Where do we handle the errors in the webview? 
44 | Snippet: 45 | ```/Users/skcd/algorithm/dfs.rs 46 | pub fn dfs(graph: &Graph, start: NodeId) -> Vec {{ 47 | let mut visited = HashSet::new(); 48 | let mut stack = vec![start]; 49 | let mut result = vec![]; 50 | while let Some(node) = stack.pop() {{ 51 | if visited.contains(&node) {{ 52 | continue; 53 | }} 54 | visited.insert(node); 55 | result.push(node); 56 | for neighbor in graph.neighbors(node) {{ 57 | stack.push(neighbor); 58 | }} 59 | }} 60 | result 61 | }} 62 | ``` 63 | Relevant: No 64 | 65 | Query: {user_query} 66 | Snippet: 67 | ```{hash} 68 | {data} 69 | ``` [/INST] 70 | Relevant: "#); 71 | let prompt = LLMClientCompletionStringRequest::new( 72 | request.llm_type().clone(), 73 | prompt, 74 | 0.0, 75 | None, 76 | ); 77 | ReRankPointWisePrompt::new_string_completion(prompt, code_span_digest) 78 | }) 79 | .collect(); 80 | 81 | ReRankCodeSpanResponse::PointWise(prompts) 82 | } 83 | 84 | pub fn listwise_reranking(&self, request: ReRankCodeSpanRequest) -> ReRankCodeSpanResponse { 85 | // First we get the code spans which are present here cause they are important 86 | let code_spans = request.code_spans().to_vec(); 87 | let user_query = request.user_query().to_owned(); 88 | // Now we need to generate the prompt for this 89 | let code_span_digests = CodeSpan::to_digests(code_spans); 90 | let code_snippets = code_span_digests 91 | .iter() 92 | .map(|code_span_digest| { 93 | let identifier = code_span_digest.hash(); 94 | let data = code_span_digest.data(); 95 | let span_identifier = code_span_digest.get_span_identifier(); 96 | format!( 97 | "\n{identifier}\n\n\n```\n{span_identifier}\n{data}\n```\n\n" 98 | ) 99 | }) 100 | .collect::>() 101 | .join("\n"); 102 | // Now we create the prompt for this reranking 103 | let prompt = format!( 104 | r#"[INST] You are an expert at ordering the code snippets from the most relevant to the least relevant for the user query. You have the order the list of code snippets from the most relevant to the least relevant. As an example 105 | 106 | 107 | subtract.rs::0 108 | 109 | 110 | ``` 111 | fn subtract(a: i32, b: i32) -> i32 {{ 112 | a - b 113 | }} 114 | ``` 115 | 116 | 117 | 118 | add.rs::0 119 | 120 | 121 | ``` 122 | fn add(a: i32, b: i32) -> i32 {{ 123 | a + b 124 | }} 125 | ``` 126 | 127 | 128 | 129 | And if you thought the code snippet with id add.rs::0 is more relevant than subtract.rs::0 then you would rank it as: 130 | 131 | 132 | add.rs::0 133 | 134 | 135 | subtract.rs::0 136 | 137 | 138 | 139 | Now for the actual query. 140 | The user has asked the following query: 141 | 142 | {user_query} 143 | 144 | 145 | The code snippets along with their ids are given below: 146 | 147 | {code_snippets} 148 | 149 | 150 | As a reminder the user question is: 151 | 152 | {user_query} 153 | 154 | You have to order all the code snippets from the most relevant to the least relevant to the user query, all the code snippet ids should be present in your final reordered list. Only output the ids of the code snippets. 155 | [/INST] 156 | 157 | "# 158 | ); 159 | let prompt = 160 | LLMClientCompletionStringRequest::new(request.llm_type().clone(), prompt, 0.0, None); 161 | ReRankCodeSpanResponse::listwise_completion(prompt, code_span_digests) 162 | } 163 | 164 | fn parse_listwise_output( 165 | &self, 166 | output: &str, 167 | code_span_digests: Vec, 168 | ) -> Result, ReRankCodeSpanError> { 169 | // The output is generally in the format of 170 | // 171 | // {id} 172 | // 173 | // ... 
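        // (a well-formed reply reduces to one bare snippet id per line, for
        // example a hypothetical "add.rs::0"; lines containing only the
        // surrounding tags are skipped by the parsing below)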
174 | // 175 | // This is not guaranteed with mistral instruct (but always by mixtral) 176 | // so we split the string on \n and ignore the values which are or 177 | // and only parse until we get the tag 178 | let mut output = output.split("\n"); 179 | let mut code_span_digests_mapping: HashMap = code_span_digests 180 | .into_iter() 181 | .map(|code_span_digest| (code_span_digest.hash().to_owned(), code_span_digest)) 182 | .collect(); 183 | let mut code_spans_reordered_list = Vec::new(); 184 | while let Some(line) = output.next() { 185 | if line.contains("") { 186 | break; 187 | } 188 | if line.contains("") || line.contains("") { 189 | continue; 190 | } 191 | let possible_id = line.trim(); 192 | if let Some(code_span) = code_span_digests_mapping.remove(possible_id) { 193 | code_spans_reordered_list.push(code_span) 194 | } 195 | } 196 | 197 | // Add all the remaining code spans to the end of the list 198 | code_span_digests_mapping 199 | .into_iter() 200 | .for_each(|(_, code_span)| { 201 | code_spans_reordered_list.push(code_span); 202 | }); 203 | 204 | // Now that we have the possible ids in the list, we get the list of ranked 205 | // code span digests in the same manner 206 | Ok(code_spans_reordered_list) 207 | } 208 | } 209 | 210 | impl ReRankCodeSpan for MistralReRank { 211 | fn rerank_prompt( 212 | &self, 213 | request: ReRankCodeSpanRequest, 214 | ) -> Result { 215 | Ok(match request.strategy() { 216 | ReRankStrategy::ListWise => self.listwise_reranking(request), 217 | ReRankStrategy::PointWise => { 218 | // We need to generate the prompt for this 219 | self.pointwise_reranking(request) 220 | } 221 | }) 222 | } 223 | 224 | fn parse_listwise_output( 225 | &self, 226 | llm_output: String, 227 | rerank_request: ReRankListWiseResponse, 228 | ) -> Result, ReRankCodeSpanError> { 229 | self.parse_listwise_output(&llm_output, rerank_request.code_span_digests) 230 | } 231 | } 232 | -------------------------------------------------------------------------------- /llm_client/src/broker.rs: -------------------------------------------------------------------------------- 1 | //! The llm client broker takes care of getting the right tokenizer formatter etc 2 | //! without us having to worry about the specifics, just pass in the message and the 3 | //! 
provider we take care of the rest 4 | 5 | use std::{collections::HashMap, sync::Arc}; 6 | 7 | use futures::future::Either; 8 | use sqlx::SqlitePool; 9 | 10 | use crate::{ 11 | clients::{ 12 | codestory::CodeStoryClient, 13 | lmstudio::LMStudioClient, 14 | ollama::OllamaClient, 15 | openai::OpenAIClient, 16 | togetherai::TogetherAIClient, 17 | types::{ 18 | LLMClient, LLMClientCompletionRequest, LLMClientCompletionResponse, 19 | LLMClientCompletionStringRequest, LLMClientError, 20 | }, 21 | }, 22 | config::LLMBrokerConfiguration, 23 | provider::{CodeStoryLLMType, LLMProvider, LLMProviderAPIKeys}, 24 | sqlite, 25 | }; 26 | 27 | pub type SqlDb = Arc; 28 | 29 | pub struct LLMBroker { 30 | pub providers: HashMap>, 31 | db: SqlDb, 32 | } 33 | 34 | pub type LLMBrokerResponse = Result; 35 | 36 | impl LLMBroker { 37 | pub async fn new(config: LLMBrokerConfiguration) -> Result { 38 | let sqlite = Arc::new(sqlite::init(config).await?); 39 | let broker = Self { 40 | providers: HashMap::new(), 41 | db: sqlite, 42 | }; 43 | Ok(broker 44 | .add_provider(LLMProvider::OpenAI, Box::new(OpenAIClient::new())) 45 | .add_provider(LLMProvider::Ollama, Box::new(OllamaClient::new())) 46 | .add_provider(LLMProvider::TogetherAI, Box::new(TogetherAIClient::new())) 47 | .add_provider( 48 | LLMProvider::CodeStory(CodeStoryLLMType { llm_type: None }), 49 | Box::new(CodeStoryClient::new( 50 | "https://codestory-provider-dot-anton-390822.ue.r.appspot.com", 51 | )), 52 | )) 53 | } 54 | 55 | pub fn add_provider( 56 | mut self, 57 | provider: LLMProvider, 58 | client: Box, 59 | ) -> Self { 60 | self.providers.insert(provider, client); 61 | self 62 | } 63 | 64 | pub async fn stream_answer( 65 | &self, 66 | api_key: LLMProviderAPIKeys, 67 | provider: LLMProvider, 68 | request: Either, 69 | metadata: HashMap, 70 | sender: tokio::sync::mpsc::UnboundedSender, 71 | ) -> LLMBrokerResponse { 72 | match request { 73 | Either::Left(request) => { 74 | self.stream_completion(api_key, request, provider, metadata, sender) 75 | .await 76 | } 77 | Either::Right(request) => { 78 | self.stream_string_completion(api_key, request, metadata, sender) 79 | .await 80 | } 81 | } 82 | } 83 | 84 | pub async fn stream_completion( 85 | &self, 86 | api_key: LLMProviderAPIKeys, 87 | request: LLMClientCompletionRequest, 88 | provider: LLMProvider, 89 | metadata: HashMap, 90 | sender: tokio::sync::mpsc::UnboundedSender, 91 | ) -> LLMBrokerResponse { 92 | let api_key = api_key 93 | .key(&provider) 94 | .ok_or(LLMClientError::UnSupportedModel)?; 95 | let provider_type = match &api_key { 96 | LLMProviderAPIKeys::Ollama(_) => LLMProvider::Ollama, 97 | LLMProviderAPIKeys::OpenAI(_) => LLMProvider::OpenAI, 98 | LLMProviderAPIKeys::OpenAIAzureConfig(_) => LLMProvider::OpenAI, 99 | LLMProviderAPIKeys::TogetherAI(_) => LLMProvider::TogetherAI, 100 | LLMProviderAPIKeys::LMStudio(_) => LLMProvider::LMStudio, 101 | LLMProviderAPIKeys::CodeStory => { 102 | LLMProvider::CodeStory(CodeStoryLLMType { llm_type: None }) 103 | } 104 | }; 105 | let provider = self.providers.get(&provider_type); 106 | if let Some(provider) = provider { 107 | let result = provider 108 | .stream_completion(api_key, request.clone(), sender) 109 | .await; 110 | if let Ok(result) = result.as_ref() { 111 | // we write the inputs to the DB so we can keep track of the inputs 112 | // and the result provided by the LLM 113 | let llm_type = request.model(); 114 | let temperature = request.temperature(); 115 | let str_metadata = serde_json::to_string(&metadata).unwrap_or_default(); 116 | let llm_type_str 
= serde_json::to_string(&llm_type)?; 117 | let messages = serde_json::to_string(&request.messages())?; 118 | let mut tx = self 119 | .db 120 | .begin() 121 | .await 122 | .map_err(|_e| LLMClientError::FailedToStoreInDB)?; 123 | let _ = sqlx::query! { 124 | r#" 125 | INSERT INTO llm_data (chat_messages, response, llm_type, temperature, max_tokens, event_type) 126 | VALUES ($1, $2, $3, $4, $5, $6) 127 | "#, 128 | messages, 129 | result, 130 | llm_type_str, 131 | temperature, 132 | -1, 133 | str_metadata, 134 | }.execute(&mut *tx).await?; 135 | let _ = tx 136 | .commit() 137 | .await 138 | .map_err(|_e| LLMClientError::FailedToStoreInDB)?; 139 | } 140 | result 141 | } else { 142 | Err(LLMClientError::UnSupportedModel) 143 | } 144 | } 145 | 146 | pub async fn stream_string_completion( 147 | &self, 148 | api_key: LLMProviderAPIKeys, 149 | request: LLMClientCompletionStringRequest, 150 | metadata: HashMap, 151 | sender: tokio::sync::mpsc::UnboundedSender, 152 | ) -> LLMBrokerResponse { 153 | let provider_type = match &api_key { 154 | LLMProviderAPIKeys::Ollama(_) => LLMProvider::Ollama, 155 | LLMProviderAPIKeys::OpenAI(_) => LLMProvider::OpenAI, 156 | LLMProviderAPIKeys::OpenAIAzureConfig(_) => LLMProvider::OpenAI, 157 | LLMProviderAPIKeys::TogetherAI(_) => LLMProvider::TogetherAI, 158 | LLMProviderAPIKeys::LMStudio(_) => LLMProvider::LMStudio, 159 | LLMProviderAPIKeys::CodeStory => { 160 | LLMProvider::CodeStory(CodeStoryLLMType { llm_type: None }) 161 | } 162 | }; 163 | let provider = self.providers.get(&provider_type); 164 | if let Some(provider) = provider { 165 | let result = provider 166 | .stream_prompt_completion(api_key, request.clone(), sender) 167 | .await; 168 | if let Ok(result) = result.as_ref() { 169 | // we write the inputs to the DB so we can keep track of the inputs 170 | // and the result provided by the LLM 171 | let llm_type = request.model(); 172 | let temperature = request.temperature(); 173 | let str_metadata = serde_json::to_string(&metadata).unwrap_or_default(); 174 | let llm_type_str = serde_json::to_string(&llm_type)?; 175 | let prompt = request.prompt(); 176 | let mut tx = self 177 | .db 178 | .begin() 179 | .await 180 | .map_err(|_e| LLMClientError::FailedToStoreInDB)?; 181 | let _ = sqlx::query! 
{ 182 | r#" 183 | INSERT INTO llm_data (prompt, response, llm_type, temperature, max_tokens, event_type) 184 | VALUES ($1, $2, $3, $4, $5, $6) 185 | "#, 186 | prompt, 187 | result, 188 | llm_type_str, 189 | temperature, 190 | -1, 191 | str_metadata, 192 | }.execute(&mut *tx).await?; 193 | let _ = tx 194 | .commit() 195 | .await 196 | .map_err(|_e| LLMClientError::FailedToStoreInDB)?; 197 | } 198 | result 199 | } else { 200 | Err(LLMClientError::UnSupportedModel) 201 | } 202 | } 203 | } 204 | -------------------------------------------------------------------------------- /llm_prompts/src/in_line_edit/openai.rs: -------------------------------------------------------------------------------- 1 | use llm_client::clients::types::LLMClientMessage; 2 | 3 | use crate::in_line_edit::doc_helpers::document_symbol_metadata; 4 | 5 | use super::types::InLineDocRequest; 6 | use super::types::InLineEditPrompt; 7 | use super::types::InLineEditRequest; 8 | use super::types::InLineFixRequest; 9 | use super::types::InLinePromptResponse; 10 | 11 | pub struct OpenAILineEditPrompt {} 12 | 13 | impl OpenAILineEditPrompt { 14 | pub fn new() -> Self { 15 | Self {} 16 | } 17 | } 18 | 19 | impl OpenAILineEditPrompt { 20 | fn system_message_inline_edit(&self, language: &str) -> String { 21 | format!( 22 | r#"You are an AI programming assistant. 23 | When asked for your name, you must respond with "Aide". 24 | Follow the user's requirements carefully & to the letter. 25 | - First think step-by-step - describe your plan for what to build in pseudocode, written out in great detail. 26 | - Then output the code in a single code block. 27 | - Minimize any other prose. 28 | - Each code block starts with ``` and // FILEPATH. 29 | - If you suggest to run a terminal command, use a code block that starts with ```bash. 30 | - You always answer with {language} code. 31 | - Modify the code or create new code. 32 | - Unless directed otherwise, the user is expecting for you to edit their selected code. 33 | - Make sure to ALWAYS INCLUDE the BEGIN and END markers in your generated code with // BEGIN and then // END which is present in the code selection given by the user 34 | You must decline to answer if the question is not related to a developer. 35 | If the question is related to a developer, you must respond with content related to a developer."# 36 | ) 37 | } 38 | 39 | fn system_message_fix(&self, language: &str) -> String { 40 | format!( 41 | r#"You are an AI programming assistant. 42 | When asked for your name, you must respond with "Aide". 43 | Follow the user's requirements carefully & to the letter. 44 | - First think step-by-step - describe your plan for what to build in pseudocode, written out in great detail. 45 | - Then output the code in a single code block. 46 | - Minimize any other prose. 47 | - Each code block starts with ``` and // FILEPATH. 48 | - If you suggest to run a terminal command, use a code block that starts with ```bash. 49 | - You always answer with {language} code. 50 | - Modify the code or create new code. 51 | - Unless directed otherwise, the user is expecting for you to edit their selected code. 52 | You must decline to answer if the question is not related to a developer. 53 | If the question is related to a developer, you must respond with content related to a developer."# 54 | ) 55 | } 56 | 57 | fn documentation_system_prompt(&self, language: &str, is_identifier_node: bool) -> String { 58 | if is_identifier_node { 59 | let system_prompt = format!( 60 | r#"You are an AI programming assistant. 
61 | When asked for your name, you must respond with "Aide". 62 | Follow the user's requirements carefully & to the letter. 63 | - Each code block must ALWAYS STARTS and include ```{language} and // FILEPATH 64 | - You always answer with {language} code. 65 | - When the user asks you to document something, you must answer in the form of a {language} code block. 66 | - Your documentation should not include just the name of the function, think about what the function is really doing. 67 | - When generating the documentation, be sure to understand what the function is doing and include that as part of the documentation and then generate the documentation. 68 | - DO NOT modify the code which you will be generating"# 69 | ); 70 | system_prompt.to_owned() 71 | } else { 72 | let system_prompt = format!( 73 | r#"You are an AI programming assistant. 74 | When asked for your name, you must respond with "Aide". 75 | Follow the user's requirements carefully & to the letter. 76 | - Each code block must ALWAYS STARTS and include ```{language} and // FILEPATH 77 | - You always answer with {language} code. 78 | - When the user asks you to document something, you must answer in the form of a {language} code block. 79 | - Your documentation should not include just the code selection, think about what the selection is really doing. 80 | - When generating the documentation, be sure to understand what the selection is doing and include that as part of the documentation and then generate the documentation. 81 | - DO NOT modify the code which you will be generating"# 82 | ); 83 | system_prompt.to_owned() 84 | } 85 | } 86 | 87 | fn above_selection(&self, above_context: Option<&String>) -> Option { 88 | if let Some(above_context) = above_context { 89 | Some(format!( 90 | r#"I have the following code above: 91 | {above_context}"# 92 | )) 93 | } else { 94 | None 95 | } 96 | } 97 | 98 | fn below_selection(&self, below_context: Option<&String>) -> Option { 99 | if let Some(below_context) = below_context { 100 | Some(format!( 101 | r#"I have the following code below: 102 | {below_context}"# 103 | )) 104 | } else { 105 | None 106 | } 107 | } 108 | } 109 | 110 | impl InLineEditPrompt for OpenAILineEditPrompt { 111 | fn inline_edit(&self, request: InLineEditRequest) -> InLinePromptResponse { 112 | // Here we create the messages for the openai, since we have flexibility 113 | // and the llms are in general smart we can just send the chat messages 114 | // instead of the completion(which has been deprecated) 115 | let above = request.above(); 116 | let below = request.below(); 117 | let in_range = request.in_range(); 118 | let language = request.language(); 119 | 120 | let mut messages = vec![]; 121 | messages.push(LLMClientMessage::system( 122 | self.system_message_inline_edit(language), 123 | )); 124 | if let Some(above) = self.above_selection(above) { 125 | messages.push(LLMClientMessage::user(above)); 126 | } 127 | if let Some(below) = self.below_selection(below) { 128 | messages.push(LLMClientMessage::user(below)); 129 | } 130 | if let Some(in_range) = in_range { 131 | messages.push(LLMClientMessage::user(in_range.to_owned())); 132 | } 133 | messages.push(LLMClientMessage::user(request.user_query().to_owned())); 134 | // Add an additional message about keeping the // FILEPATH and the markers 135 | messages.push(LLMClientMessage::system(format!( 136 | r#"Make sure to ALWAYS INCLUDE the BEGIN and END markers in your generated code with // BEGIN and then // END which is present in the code selection given by me"# 137 | 
))); 138 | InLinePromptResponse::Chat(messages) 139 | } 140 | 141 | fn inline_fix(&self, request: InLineFixRequest) -> InLinePromptResponse { 142 | let above = request.above(); 143 | let below = request.below(); 144 | let in_range = request.in_range(); 145 | let language = request.language(); 146 | 147 | let mut messages = vec![]; 148 | messages.push(LLMClientMessage::system(self.system_message_fix(language))); 149 | if let Some(above) = self.above_selection(above) { 150 | messages.push(LLMClientMessage::user(above)); 151 | } 152 | if let Some(below) = self.below_selection(below) { 153 | messages.push(LLMClientMessage::user(below)); 154 | } 155 | messages.push(LLMClientMessage::user(in_range.to_owned())); 156 | messages.extend( 157 | request 158 | .diagnostics_prompts() 159 | .into_iter() 160 | .map(|diagnostic_prompt| LLMClientMessage::user(diagnostic_prompt.to_owned())), 161 | ); 162 | messages.push( 163 | LLMClientMessage::user("Do not forget to include the // BEGIN and // END markers in your generated code. Only change the code inside of the selection, delimited by the markers: // BEGIN: ed8c6549bwf9 and // END: ed8c6549bwf9".to_owned()) 164 | ); 165 | InLinePromptResponse::Chat(messages) 166 | } 167 | 168 | fn inline_doc(&self, request: InLineDocRequest) -> InLinePromptResponse { 169 | let system_prompt = 170 | self.documentation_system_prompt(request.language(), request.is_identifier_node()); 171 | let mut messages = vec![]; 172 | messages.push(LLMClientMessage::system(system_prompt)); 173 | messages.push(LLMClientMessage::user(request.in_range().to_owned())); 174 | messages.push(LLMClientMessage::user(document_symbol_metadata(&request))); 175 | messages.push(LLMClientMessage::user("Do not forget to the include the // BEGIN and // END markers in your generated code. Only change the code provided to you in the selection".to_owned())); 176 | InLinePromptResponse::Chat(messages) 177 | } 178 | } 179 | -------------------------------------------------------------------------------- /llm_client/src/clients/openai.rs: -------------------------------------------------------------------------------- 1 | //! 
Client which can help us talk to openai 2 | 3 | use async_openai::{ 4 | config::{AzureConfig, OpenAIConfig}, 5 | types::{ 6 | ChatCompletionRequestMessage, ChatCompletionRequestMessageArgs, 7 | CreateChatCompletionRequestArgs, FunctionCall, Role, 8 | }, 9 | Client, 10 | }; 11 | use async_trait::async_trait; 12 | use futures::StreamExt; 13 | 14 | use crate::provider::LLMProviderAPIKeys; 15 | 16 | use super::types::{ 17 | LLMClient, LLMClientCompletionRequest, LLMClientCompletionResponse, LLMClientError, 18 | LLMClientMessage, LLMClientRole, LLMType, 19 | }; 20 | 21 | enum OpenAIClientType { 22 | AzureClient(Client), 23 | OpenAIClient(Client), 24 | } 25 | 26 | pub struct OpenAIClient {} 27 | 28 | impl OpenAIClient { 29 | pub fn new() -> Self { 30 | Self {} 31 | } 32 | 33 | pub fn model(&self, model: &LLMType) -> Option { 34 | match model { 35 | LLMType::GPT3_5_16k => Some("gpt-3.5-turbo-16k-0613".to_owned()), 36 | LLMType::Gpt4 => Some("gpt-4-0613".to_owned()), 37 | LLMType::Gpt4Turbo => Some("gpt-4-1106-preview".to_owned()), 38 | LLMType::Gpt4_32k => Some("gpt-4-32k-0613".to_owned()), 39 | _ => None, 40 | } 41 | } 42 | 43 | pub fn messages( 44 | &self, 45 | messages: &[LLMClientMessage], 46 | ) -> Result, LLMClientError> { 47 | let formatted_messages = messages 48 | .into_iter() 49 | .map(|message| { 50 | let role = message.role(); 51 | match role { 52 | LLMClientRole::User => ChatCompletionRequestMessageArgs::default() 53 | .role(Role::User) 54 | .content(message.content().to_owned()) 55 | .build() 56 | .map_err(|e| LLMClientError::OpenAPIError(e)), 57 | LLMClientRole::System => ChatCompletionRequestMessageArgs::default() 58 | .role(Role::System) 59 | .content(message.content().to_owned()) 60 | .build() 61 | .map_err(|e| LLMClientError::OpenAPIError(e)), 62 | // the assistant is the one which ends up calling the function, so we need to 63 | // handle the case where the function is called by the assistant here 64 | LLMClientRole::Assistant => match message.get_function_call() { 65 | Some(function_call) => ChatCompletionRequestMessageArgs::default() 66 | .role(Role::Function) 67 | .function_call(FunctionCall { 68 | name: function_call.name().to_owned(), 69 | arguments: function_call.arguments().to_owned(), 70 | }) 71 | .build() 72 | .map_err(|e| LLMClientError::OpenAPIError(e)), 73 | None => ChatCompletionRequestMessageArgs::default() 74 | .role(Role::Assistant) 75 | .content(message.content().to_owned()) 76 | .build() 77 | .map_err(|e| LLMClientError::OpenAPIError(e)), 78 | }, 79 | LLMClientRole::Function => match message.get_function_call() { 80 | Some(function_call) => ChatCompletionRequestMessageArgs::default() 81 | .role(Role::Function) 82 | .content(message.content().to_owned()) 83 | .function_call(FunctionCall { 84 | name: function_call.name().to_owned(), 85 | arguments: function_call.arguments().to_owned(), 86 | }) 87 | .build() 88 | .map_err(|e| LLMClientError::OpenAPIError(e)), 89 | None => Err(LLMClientError::FunctionCallNotPresent), 90 | }, 91 | } 92 | }) 93 | .collect::>(); 94 | formatted_messages 95 | .into_iter() 96 | .collect::, LLMClientError>>() 97 | } 98 | 99 | fn generate_openai_client( 100 | &self, 101 | api_key: LLMProviderAPIKeys, 102 | ) -> Result { 103 | match api_key { 104 | LLMProviderAPIKeys::OpenAI(api_key) => { 105 | let config = OpenAIConfig::new().with_api_key(api_key.api_key); 106 | Ok(OpenAIClientType::OpenAIClient(Client::with_config(config))) 107 | } 108 | LLMProviderAPIKeys::OpenAIAzureConfig(azure_config) => { 109 | let config = AzureConfig::new() 110 | 
.with_api_base(azure_config.api_base) 111 | .with_api_key(azure_config.api_key) 112 | .with_deployment_id(azure_config.deployment_id) 113 | .with_api_version(azure_config.api_version); 114 | Ok(OpenAIClientType::AzureClient(Client::with_config(config))) 115 | } 116 | _ => Err(LLMClientError::WrongAPIKeyType), 117 | } 118 | } 119 | } 120 | 121 | #[async_trait] 122 | impl LLMClient for OpenAIClient { 123 | fn client(&self) -> &crate::provider::LLMProvider { 124 | &crate::provider::LLMProvider::OpenAI 125 | } 126 | 127 | async fn stream_completion( 128 | &self, 129 | api_key: LLMProviderAPIKeys, 130 | request: LLMClientCompletionRequest, 131 | sender: tokio::sync::mpsc::UnboundedSender, 132 | ) -> Result { 133 | let model = self.model(request.model()); 134 | if model.is_none() { 135 | return Err(LLMClientError::UnSupportedModel); 136 | } 137 | let model = model.unwrap(); 138 | let messages = self.messages(request.messages())?; 139 | let mut request_builder_args = CreateChatCompletionRequestArgs::default(); 140 | let mut request_builder = request_builder_args 141 | .model(model.to_owned()) 142 | .messages(messages) 143 | .temperature(request.temperature()) 144 | .stream(true); 145 | if let Some(frequency_penalty) = request.frequency_penalty() { 146 | request_builder = request_builder.frequency_penalty(frequency_penalty); 147 | } 148 | let request = request_builder.build()?; 149 | let mut buffer = String::new(); 150 | let client = self.generate_openai_client(api_key)?; 151 | 152 | // TODO(skcd): Bad code :| we are repeating too many things but this 153 | // just works and we need it right now 154 | match client { 155 | OpenAIClientType::AzureClient(client) => { 156 | let stream_maybe = client.chat().create_stream(request).await; 157 | if stream_maybe.is_err() { 158 | return Err(LLMClientError::OpenAPIError(stream_maybe.err().unwrap())); 159 | } else { 160 | dbg!("no error here"); 161 | } 162 | let mut stream = stream_maybe.unwrap(); 163 | while let Some(response) = stream.next().await { 164 | match response { 165 | Ok(response) => { 166 | let delta = response 167 | .choices 168 | .get(0) 169 | .map(|choice| choice.delta.content.to_owned()) 170 | .flatten() 171 | .unwrap_or("".to_owned()); 172 | let _value = response 173 | .choices 174 | .get(0) 175 | .map(|choice| choice.delta.content.as_ref()) 176 | .flatten(); 177 | buffer.push_str(&delta); 178 | let _ = sender.send(LLMClientCompletionResponse::new( 179 | buffer.to_owned(), 180 | Some(delta), 181 | model.to_owned(), 182 | )); 183 | } 184 | Err(err) => { 185 | dbg!(err); 186 | break; 187 | } 188 | } 189 | } 190 | } 191 | OpenAIClientType::OpenAIClient(client) => { 192 | let mut stream = client.chat().create_stream(request).await?; 193 | while let Some(response) = stream.next().await { 194 | match response { 195 | Ok(response) => { 196 | let response = response 197 | .choices 198 | .get(0) 199 | .ok_or(LLMClientError::FailedToGetResponse)?; 200 | let text = response.delta.content.to_owned(); 201 | if let Some(text) = text { 202 | buffer.push_str(&text); 203 | let _ = sender.send(LLMClientCompletionResponse::new( 204 | buffer.to_owned(), 205 | Some(text), 206 | model.to_owned(), 207 | )); 208 | } 209 | } 210 | Err(err) => { 211 | dbg!(err); 212 | break; 213 | } 214 | } 215 | } 216 | } 217 | } 218 | Ok(buffer) 219 | } 220 | 221 | async fn completion( 222 | &self, 223 | api_key: LLMProviderAPIKeys, 224 | request: LLMClientCompletionRequest, 225 | ) -> Result { 226 | let (sender, _receiver) = tokio::sync::mpsc::unbounded_channel(); 227 | let result 
= self.stream_completion(api_key, request, sender).await?; 228 | Ok(result) 229 | } 230 | 231 | async fn stream_prompt_completion( 232 | &self, 233 | _api_key: LLMProviderAPIKeys, 234 | _request: super::types::LLMClientCompletionStringRequest, 235 | _sender: tokio::sync::mpsc::UnboundedSender, 236 | ) -> Result { 237 | Err(LLMClientError::OpenAIDoesNotSupportCompletion) 238 | } 239 | } 240 | -------------------------------------------------------------------------------- /llm_client/src/tokenizer/tokenizer.rs: -------------------------------------------------------------------------------- 1 | //! We are going to run the various tokenizers here, we also make sure to run 2 | //! the tokenizer in a different thread here, because its important that we 3 | //! don't block the main thread from working 4 | 5 | use std::collections::HashMap; 6 | use std::str::FromStr; 7 | 8 | use thiserror::Error; 9 | use tiktoken_rs::ChatCompletionRequestMessage; 10 | use tokenizers::Tokenizer; 11 | 12 | use crate::{ 13 | clients::types::{LLMClientMessage, LLMClientRole, LLMType}, 14 | format::{ 15 | deepseekcoder::DeepSeekCoderFormatting, 16 | mistral::MistralInstructFormatting, 17 | mixtral::MixtralInstructFormatting, 18 | types::{LLMFormatting, TokenizerError}, 19 | }, 20 | }; 21 | 22 | pub struct LLMTokenizer { 23 | pub tokenizers: HashMap, 24 | pub formatters: HashMap>, 25 | } 26 | 27 | #[derive(Error, Debug)] 28 | pub enum LLMTokenizerError { 29 | #[error("Tokenizer not found for model {0}")] 30 | TokenizerNotFound(LLMType), 31 | 32 | #[error("Tokenizer error: {0}")] 33 | TokenizerError(String), 34 | 35 | #[error("error from tokenizer crate: {0}")] 36 | TokenizerCrateError(#[from] tokenizers::Error), 37 | 38 | #[error("anyhow error: {0}")] 39 | AnyhowError(#[from] anyhow::Error), 40 | 41 | #[error("tokenizer error: {0}")] 42 | TokenizerErrorInternal(#[from] TokenizerError), 43 | } 44 | 45 | pub enum LLMTokenizerInput { 46 | Prompt(String), 47 | Messages(Vec), 48 | } 49 | 50 | impl LLMTokenizer { 51 | pub fn new() -> Result { 52 | let tokenizer = Self { 53 | tokenizers: HashMap::new(), 54 | formatters: HashMap::new(), 55 | }; 56 | let updated_tokenizer = tokenizer 57 | .add_llm_type( 58 | LLMType::Mixtral, 59 | Box::new(MixtralInstructFormatting::new()?), 60 | ) 61 | .add_llm_type( 62 | LLMType::MistralInstruct, 63 | Box::new(MistralInstructFormatting::new()?), 64 | ) 65 | .add_llm_type( 66 | LLMType::DeepSeekCoder, 67 | Box::new(DeepSeekCoderFormatting::new()), 68 | ); 69 | Ok(updated_tokenizer) 70 | } 71 | 72 | fn add_llm_type( 73 | mut self, 74 | llm_type: LLMType, 75 | formatter: Box, 76 | ) -> Self { 77 | // This can be falliable, since soe llms might have formatting support 78 | // and if they don't thats fine 79 | let _ = self.load_tokenizer(&llm_type); 80 | self.formatters.insert(llm_type, formatter); 81 | self 82 | } 83 | 84 | fn to_openai_tokenizer(&self, model: &LLMType) -> Option { 85 | match model { 86 | LLMType::GPT3_5_16k => Some("gpt-3.5-turbo-16k-0613".to_owned()), 87 | LLMType::Gpt4 => Some("gpt-4-0613".to_owned()), 88 | LLMType::Gpt4Turbo => Some("gpt-4-1106-preview".to_owned()), 89 | LLMType::Gpt4_32k => Some("gpt-4-32k-0613".to_owned()), 90 | _ => None, 91 | } 92 | } 93 | 94 | pub fn count_tokens( 95 | &self, 96 | model: &LLMType, 97 | input: LLMTokenizerInput, 98 | ) -> Result { 99 | match input { 100 | LLMTokenizerInput::Prompt(prompt) => self.count_tokens_using_tokenizer(model, &prompt), 101 | LLMTokenizerInput::Messages(messages) => { 102 | // we can't send messages directly to the 
tokenizer, we have to 103 | // either make it a message or its an openai prompt in which case 104 | // its fine 105 | // so we are going to return an error if its not openai 106 | if model.is_openai() { 107 | // we can use the openai tokenizer 108 | let model = self.to_openai_tokenizer(model); 109 | match model { 110 | Some(model) => Ok(tiktoken_rs::num_tokens_from_messages( 111 | &model, 112 | messages 113 | .into_iter() 114 | .map(|message| { 115 | let role = message.role(); 116 | let content = message.content(); 117 | match role { 118 | LLMClientRole::User => ChatCompletionRequestMessage { 119 | role: "user".to_owned(), 120 | content: Some(content.to_owned()), 121 | name: None, 122 | function_call: None, 123 | }, 124 | LLMClientRole::Assistant => ChatCompletionRequestMessage { 125 | role: "assistant".to_owned(), 126 | content: Some(content.to_owned()), 127 | name: None, 128 | function_call: None, 129 | }, 130 | LLMClientRole::System => ChatCompletionRequestMessage { 131 | role: "system".to_owned(), 132 | content: Some(content.to_owned()), 133 | name: None, 134 | function_call: None, 135 | }, 136 | LLMClientRole::Function => ChatCompletionRequestMessage { 137 | role: "function".to_owned(), 138 | content: Some(content.to_owned()), 139 | name: None, 140 | function_call: None, 141 | }, 142 | } 143 | }) 144 | .collect::>() 145 | .as_slice(), 146 | )?), 147 | None => Err(LLMTokenizerError::TokenizerError( 148 | "Only openai models are supported for messages".to_owned(), 149 | )), 150 | } 151 | } else { 152 | let prompt = self 153 | .formatters 154 | .get(model) 155 | .map(|formatter| formatter.to_prompt(messages)); 156 | match prompt { 157 | Some(prompt) => { 158 | let num_tokens = self.tokenizers.get(model).map(|tokenizer| { 159 | tokenizer 160 | .encode(prompt, false) 161 | .map(|encoding| encoding.len()) 162 | }); 163 | match num_tokens { 164 | Some(Ok(num_tokens)) => Ok(num_tokens), 165 | _ => Err(LLMTokenizerError::TokenizerError( 166 | "Failed to encode prompt".to_owned(), 167 | )), 168 | } 169 | } 170 | None => Err(LLMTokenizerError::TokenizerError( 171 | "No formatter found for model".to_owned(), 172 | )), 173 | } 174 | } 175 | } 176 | } 177 | } 178 | 179 | pub fn count_tokens_using_tokenizer( 180 | &self, 181 | model: &LLMType, 182 | prompt: &str, 183 | ) -> Result { 184 | // we have the custom tokenizers already loaded, if this is not the openai loop 185 | if !model.is_openai() { 186 | let tokenizer = self.tokenizers.get(model); 187 | match tokenizer { 188 | Some(tokenizer) => { 189 | // Now over here we will try to figure out how to pass the 190 | // values around 191 | let result = tokenizer.encode(prompt, false); 192 | match result { 193 | Ok(encoding) => Ok(encoding.len()), 194 | Err(e) => Err(LLMTokenizerError::TokenizerError(format!( 195 | "Failed to encode prompt: {}", 196 | e 197 | ))), 198 | } 199 | } 200 | None => { 201 | return Err(LLMTokenizerError::TokenizerNotFound(model.clone())); 202 | } 203 | } 204 | } else { 205 | // If we are using openai model, then we have to use the bpe config 206 | // and count the number of tokens 207 | let model = self.to_openai_tokenizer(model); 208 | if let None = model { 209 | return Err(LLMTokenizerError::TokenizerError( 210 | "OpenAI model not found".to_owned(), 211 | )); 212 | } 213 | let model = model.expect("if let None to hold"); 214 | let bpe = tiktoken_rs::get_bpe_from_model(&model)?; 215 | Ok(bpe.encode_ordinary(prompt).len()) 216 | } 217 | } 218 | 219 | pub fn load_tokenizer(&mut self, model: &LLMType) -> Result<(), 
LLMTokenizerError> { 220 | let tokenizer = match model { 221 | LLMType::MistralInstruct => { 222 | let config = include_str!("configs/mistral.json"); 223 | Some(Tokenizer::from_str(config)?) 224 | } 225 | LLMType::Mixtral => { 226 | let config = include_str!("configs/mixtral.json"); 227 | Some(Tokenizer::from_str(config)?) 228 | } 229 | LLMType::DeepSeekCoder => { 230 | let config = include_str!("configs/deepseekcoder.json"); 231 | Some(Tokenizer::from_str(config)?) 232 | } 233 | _ => None, 234 | }; 235 | if let Some(tokenizer) = tokenizer { 236 | self.tokenizers.insert(model.clone(), tokenizer); 237 | } 238 | Ok(()) 239 | } 240 | } 241 | 242 | #[cfg(test)] 243 | mod tests { 244 | use std::str::FromStr; 245 | use tokenizers::Tokenizer; 246 | 247 | #[test] 248 | fn test_loading_deepseek_tokenizer_works() { 249 | let tokenizer_file = include_str!("configs/deepseekcoder.json"); 250 | let _ = Tokenizer::from_str(tokenizer_file).unwrap(); 251 | } 252 | } 253 | -------------------------------------------------------------------------------- /llm_prompts/src/reranking/types.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | 3 | use futures::future::Either; 4 | use llm_client::{ 5 | clients::types::{ 6 | LLMClientCompletionRequest, LLMClientCompletionStringRequest, LLMClientError, LLMType, 7 | }, 8 | tokenizer::tokenizer::LLMTokenizerError, 9 | }; 10 | 11 | #[derive(Clone, Debug, PartialEq)] 12 | pub struct CodeSpan { 13 | file_path: String, 14 | start_line: u64, 15 | end_line: u64, 16 | data: String, 17 | } 18 | 19 | impl CodeSpan { 20 | pub fn to_prompt(&self) -> String { 21 | format!( 22 | // TODO(skcd): Pass the language here for more accurate token counting 23 | "FILEPATH: {}-{}:{}\n```language\n{}```", 24 | self.file_path, self.start_line, self.end_line, self.data 25 | ) 26 | } 27 | 28 | pub fn merge_consecutive_spans(code_spans: Vec) -> Vec { 29 | const CHUNK_MERGE_DISTANCE: usize = 0; 30 | let mut file_to_code_snippets: HashMap> = Default::default(); 31 | 32 | code_spans.into_iter().for_each(|code_span| { 33 | let file_path = code_span.file_path.clone(); 34 | let code_spans = file_to_code_snippets 35 | .entry(file_path) 36 | .or_insert_with(Vec::new); 37 | code_spans.push(code_span); 38 | }); 39 | 40 | // We want to sort the code snippets in increasing order of the start line 41 | file_to_code_snippets 42 | .iter_mut() 43 | .for_each(|(_, code_snippets)| { 44 | code_snippets.sort_by(|a, b| a.start_line.cmp(&b.start_line)); 45 | }); 46 | 47 | // Now we will merge chunks which are in the range of CHUNK_MERGE_DISTANCE 48 | let results = file_to_code_snippets 49 | .into_iter() 50 | .map(|(file_path, mut code_snippets)| { 51 | let mut final_code_snippets = Vec::new(); 52 | let mut current_code_snippet = code_snippets.remove(0); 53 | for code_snippet in code_snippets { 54 | if code_snippet.end_line - current_code_snippet.start_line 55 | <= CHUNK_MERGE_DISTANCE as u64 56 | { 57 | // We can merge these two code snippets 58 | current_code_snippet.end_line = code_snippet.end_line; 59 | current_code_snippet.data = 60 | format!("{}{}", current_code_snippet.data, code_snippet.data); 61 | } else { 62 | // We cannot merge these two code snippets 63 | final_code_snippets.push(current_code_snippet); 64 | current_code_snippet = code_snippet; 65 | } 66 | } 67 | final_code_snippets.push(current_code_snippet); 68 | final_code_snippets 69 | .into_iter() 70 | .map(|code_snippet| CodeSpan { 71 | file_path: file_path.clone(), 72 | data: 
code_snippet.data, 73 | start_line: code_snippet.start_line, 74 | end_line: code_snippet.end_line, 75 | }) 76 | .collect::>() 77 | }) 78 | .flatten() 79 | .collect::>(); 80 | results 81 | } 82 | } 83 | 84 | /// This is the digest of the code span, we create a unique id for the code span 85 | /// always and use that for passing it to the prompt 86 | #[derive(Clone)] 87 | pub struct CodeSpanDigest { 88 | code_span: CodeSpan, 89 | hash: String, 90 | } 91 | 92 | impl CodeSpanDigest { 93 | pub fn new(code_span: CodeSpan, file_path: &str, index: usize) -> Self { 94 | // TODO(skcd): Add proper error handling here 95 | let base_name = std::path::Path::new(file_path) 96 | .file_name() 97 | .unwrap() 98 | .to_str() 99 | .unwrap(); 100 | Self { 101 | code_span, 102 | hash: format!("{}::{}", base_name, index), 103 | } 104 | } 105 | 106 | pub fn hash(&self) -> &str { 107 | &self.hash 108 | } 109 | 110 | pub fn data(&self) -> &str { 111 | self.code_span.data() 112 | } 113 | 114 | pub fn file_path(&self) -> &str { 115 | self.code_span.file_path() 116 | } 117 | 118 | pub fn get_code_span(self) -> CodeSpan { 119 | self.code_span 120 | } 121 | 122 | pub fn get_span_identifier(&self) -> String { 123 | format!( 124 | "// FILEPATH: {}:{}-{}", 125 | self.file_path(), 126 | self.code_span.start_line(), 127 | self.code_span.end_line() 128 | ) 129 | } 130 | } 131 | 132 | impl CodeSpan { 133 | pub fn new(file_path: String, start_line: u64, end_line: u64, data: String) -> Self { 134 | Self { 135 | file_path, 136 | start_line, 137 | end_line, 138 | data, 139 | } 140 | } 141 | 142 | pub fn file_path(&self) -> &str { 143 | &self.file_path 144 | } 145 | 146 | pub fn start_line(&self) -> u64 { 147 | self.start_line 148 | } 149 | 150 | pub fn end_line(&self) -> u64 { 151 | self.end_line 152 | } 153 | 154 | pub fn data(&self) -> &str { 155 | &self.data 156 | } 157 | 158 | pub fn to_digests(code_spans: Vec) -> Vec { 159 | // Naming the digests should happen using the filepath and creating a 160 | // numbered alias on top of it. 
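        // Illustrative note (not part of the original source): with this scheme, three
        // spans from src/client.rs and one from src/add.rs would be assigned the ids
        // `client.rs::0`, `client.rs::1`, `client.rs::2` and `add.rs::0`, i.e. the file's
        // base name plus a per-file counter, matching the `{base_name}::{index}` hash
        // built by `CodeSpanDigest::new` below.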
161 | let mut file_paths_counter: HashMap = Default::default(); 162 | code_spans 163 | .into_iter() 164 | .map(|code_span| { 165 | let file_path = code_span.file_path().to_owned(); 166 | let mut index = 0; 167 | if let Some(value) = file_paths_counter.get_mut(&file_path) { 168 | *value += 1; 169 | index = *value; 170 | } else { 171 | file_paths_counter.insert(file_path.to_string(), 0); 172 | } 173 | CodeSpanDigest::new(code_span, &file_path, index) 174 | }) 175 | .collect() 176 | } 177 | } 178 | 179 | /// We support both listwise and pairwise reranking strategies 180 | /// Going further we will add more strategies to this, right now 181 | /// these are the best ones 182 | /// list wise reading material here: https://arxiv.org/pdf/2312.02724.pdf 183 | /// point wise reading material here: https://cookbook.openai.com/examples/search_reranking_with_cross-encoders 184 | #[derive(Clone)] 185 | pub enum ReRankStrategy { 186 | ListWise, 187 | // This works best with logits enabled, if logits are not provied by the 188 | // underlying infra, then this is not that great tbh 189 | PointWise, 190 | } 191 | 192 | pub struct ReRankCodeSpanRequest { 193 | user_query: String, 194 | answer_snippets: usize, 195 | answer_limit_tokens: i64, 196 | code_spans: Vec, 197 | strategy: ReRankStrategy, 198 | llm_type: LLMType, 199 | } 200 | 201 | impl ReRankCodeSpanRequest { 202 | pub fn new( 203 | user_query: String, 204 | answer_snippets: usize, 205 | answer_limit_tokens: i64, 206 | code_spans: Vec, 207 | strategy: ReRankStrategy, 208 | llm_type: LLMType, 209 | ) -> Self { 210 | Self { 211 | user_query, 212 | answer_snippets, 213 | answer_limit_tokens, 214 | code_spans, 215 | strategy, 216 | llm_type, 217 | } 218 | } 219 | 220 | pub fn user_query(&self) -> &str { 221 | &self.user_query 222 | } 223 | 224 | pub fn limit(&self) -> usize { 225 | self.answer_snippets 226 | } 227 | 228 | pub fn token_limit(&self) -> i64 { 229 | self.answer_limit_tokens 230 | } 231 | 232 | pub fn strategy(&self) -> &ReRankStrategy { 233 | &self.strategy 234 | } 235 | 236 | pub fn code_spans(&self) -> &[CodeSpan] { 237 | self.code_spans.as_slice() 238 | } 239 | 240 | pub fn llm_type(&self) -> &LLMType { 241 | &self.llm_type 242 | } 243 | } 244 | 245 | pub struct ReRankListWiseResponse { 246 | pub prompt: Either, 247 | pub code_span_digests: Vec, 248 | } 249 | 250 | pub struct ReRankPointWisePrompt { 251 | pub prompt: Either, 252 | pub code_span_digest: CodeSpanDigest, 253 | } 254 | 255 | impl ReRankPointWisePrompt { 256 | pub fn new_message_request( 257 | prompt: LLMClientCompletionRequest, 258 | code_span_digest: CodeSpanDigest, 259 | ) -> Self { 260 | Self { 261 | prompt: Either::Left(prompt), 262 | code_span_digest, 263 | } 264 | } 265 | 266 | pub fn new_string_completion( 267 | prompt: LLMClientCompletionStringRequest, 268 | code_span_digest: CodeSpanDigest, 269 | ) -> Self { 270 | Self { 271 | prompt: Either::Right(prompt), 272 | code_span_digest, 273 | } 274 | } 275 | } 276 | 277 | pub enum ReRankCodeSpanResponse { 278 | ListWise(ReRankListWiseResponse), 279 | PointWise(Vec), 280 | } 281 | 282 | impl ReRankCodeSpanResponse { 283 | pub fn listwise_message( 284 | request: LLMClientCompletionRequest, 285 | code_span_digests: Vec, 286 | ) -> Self { 287 | Self::ListWise(ReRankListWiseResponse { 288 | prompt: Either::Left(request), 289 | code_span_digests, 290 | }) 291 | } 292 | 293 | pub fn listwise_completion( 294 | request: LLMClientCompletionStringRequest, 295 | code_span_digests: Vec, 296 | ) -> Self { 297 | 
Self::ListWise(ReRankListWiseResponse { 298 | prompt: Either::Right(request), 299 | code_span_digests, 300 | }) 301 | } 302 | 303 | pub fn pointwise(prompts: Vec) -> Self { 304 | Self::PointWise(prompts) 305 | } 306 | } 307 | 308 | #[derive(thiserror::Error, Debug)] 309 | pub enum ReRankCodeSpanError { 310 | #[error("Model not found")] 311 | ModelNotFound, 312 | 313 | #[error("tokenizer errors: {0}")] 314 | TokenizerError(#[from] LLMTokenizerError), 315 | 316 | #[error("Wrong rerank strategy returned")] 317 | WrongReRankStrategy, 318 | 319 | #[error("LLMClientError: {0}")] 320 | LLMClientError(#[from] LLMClientError), 321 | } 322 | 323 | /// The rerank code span will take in a list of code spans and generate a prompt 324 | /// for it, but I do think reranking by itself is pretty interesting, should we 325 | /// make it its own trait? 326 | pub trait ReRankCodeSpan { 327 | fn rerank_prompt( 328 | &self, 329 | request: ReRankCodeSpanRequest, 330 | ) -> Result; 331 | 332 | fn parse_listwise_output( 333 | &self, 334 | llm_output: String, 335 | rerank_request: ReRankListWiseResponse, 336 | ) -> Result, ReRankCodeSpanError>; 337 | } 338 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
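The next file, llm_client/src/clients/types.rs, implements custom Serialize and Deserialize for LLMType: known model variants round-trip by their variant name, while any unrecognised string deserializes to LLMType::Custom. A minimal sketch of that round-trip, assuming serde_json is available (as it already is elsewhere in the crate) and using made-up model-name strings:

```rust
// Minimal sketch (not part of the repository): exercises the custom serde
// round-trip for LLMType as defined in llm_client/src/clients/types.rs.
use llm_client::clients::types::LLMType;

fn main() {
    // Known variants serialize to their variant name as a JSON string.
    let turbo = serde_json::to_string(&LLMType::Gpt4Turbo).unwrap();
    assert_eq!(turbo, "\"Gpt4Turbo\"");

    // Custom variants serialize to their inner string.
    let custom = serde_json::to_string(&LLMType::Custom("my-local-model".to_owned())).unwrap();
    assert_eq!(custom, "\"my-local-model\"");

    // Known names deserialize back to the concrete variant.
    let parsed: LLMType = serde_json::from_str("\"DeepSeekCoder\"").unwrap();
    assert_eq!(parsed, LLMType::DeepSeekCoder);

    // Anything unrecognised falls back to LLMType::Custom.
    let parsed: LLMType = serde_json::from_str("\"my-local-model\"").unwrap();
    assert_eq!(parsed, LLMType::Custom("my-local-model".to_owned()));
}
```

Serializing Custom variants as their bare string keeps stored model names human-readable and lets unrecognised names round-trip without any wrapper syntax.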
202 | -------------------------------------------------------------------------------- /llm_client/src/clients/types.rs: -------------------------------------------------------------------------------- 1 | use async_trait::async_trait; 2 | use serde::{ 3 | de::{self, Visitor}, 4 | Deserialize, Deserializer, Serialize, Serializer, 5 | }; 6 | use std::fmt; 7 | use thiserror::Error; 8 | use tokio::sync::mpsc::UnboundedSender; 9 | 10 | use crate::provider::{LLMProvider, LLMProviderAPIKeys}; 11 | 12 | #[derive(Debug, Clone, PartialEq, Hash, Eq)] 13 | pub enum LLMType { 14 | Mixtral, 15 | MistralInstruct, 16 | Gpt4, 17 | GPT3_5_16k, 18 | Gpt4_32k, 19 | Gpt4Turbo, 20 | DeepSeekCoder, 21 | Custom(String), 22 | } 23 | 24 | impl Serialize for LLMType { 25 | fn serialize(&self, serializer: S) -> Result 26 | where 27 | S: Serializer, 28 | { 29 | match self { 30 | LLMType::Custom(s) => serializer.serialize_str(s), 31 | _ => serializer.serialize_str(&format!("{:?}", self)), 32 | } 33 | } 34 | } 35 | 36 | impl<'de> Deserialize<'de> for LLMType { 37 | fn deserialize(deserializer: D) -> Result 38 | where 39 | D: Deserializer<'de>, 40 | { 41 | struct LLMTypeVisitor; 42 | 43 | impl<'de> Visitor<'de> for LLMTypeVisitor { 44 | type Value = LLMType; 45 | 46 | fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { 47 | formatter.write_str("a string representing an LLMType") 48 | } 49 | 50 | fn visit_str(self, value: &str) -> Result 51 | where 52 | E: de::Error, 53 | { 54 | match value { 55 | "Mixtral" => Ok(LLMType::Mixtral), 56 | "MistralInstruct" => Ok(LLMType::MistralInstruct), 57 | "Gpt4" => Ok(LLMType::Gpt4), 58 | "GPT3_5_16k" => Ok(LLMType::GPT3_5_16k), 59 | "Gpt4_32k" => Ok(LLMType::Gpt4_32k), 60 | "Gpt4Turbo" => Ok(LLMType::Gpt4Turbo), 61 | "DeepSeekCoder" => Ok(LLMType::DeepSeekCoder), 62 | _ => Ok(LLMType::Custom(value.to_string())), 63 | } 64 | } 65 | } 66 | 67 | deserializer.deserialize_string(LLMTypeVisitor) 68 | } 69 | } 70 | 71 | impl LLMType { 72 | pub fn is_openai(&self) -> bool { 73 | matches!( 74 | self, 75 | LLMType::Gpt4 | LLMType::GPT3_5_16k | LLMType::Gpt4_32k | LLMType::Gpt4Turbo 76 | ) 77 | } 78 | 79 | pub fn is_custom(&self) -> bool { 80 | matches!(self, LLMType::Custom(_)) 81 | } 82 | } 83 | 84 | impl fmt::Display for LLMType { 85 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 86 | match self { 87 | LLMType::Mixtral => write!(f, "Mixtral"), 88 | LLMType::MistralInstruct => write!(f, "MistralInstruct"), 89 | LLMType::Gpt4 => write!(f, "Gpt4"), 90 | LLMType::GPT3_5_16k => write!(f, "GPT3_5_16k"), 91 | LLMType::Gpt4_32k => write!(f, "Gpt4_32k"), 92 | LLMType::Gpt4Turbo => write!(f, "Gpt4Turbo"), 93 | LLMType::DeepSeekCoder => write!(f, "DeepSeekCoder"), 94 | LLMType::Custom(s) => write!(f, "Custom({})", s), 95 | } 96 | } 97 | } 98 | 99 | #[derive(serde::Serialize, serde::Deserialize, Debug, Clone)] 100 | pub enum LLMClientRole { 101 | System, 102 | User, 103 | Assistant, 104 | // function calling is weird, its only supported by openai right now 105 | // and not other LLMs, so we are going to make this work with the formatters 106 | // and still keep it as it is 107 | Function, 108 | } 109 | 110 | impl LLMClientRole { 111 | pub fn is_system(&self) -> bool { 112 | matches!(self, LLMClientRole::System) 113 | } 114 | 115 | pub fn is_user(&self) -> bool { 116 | matches!(self, LLMClientRole::User) 117 | } 118 | 119 | pub fn is_assistant(&self) -> bool { 120 | matches!(self, LLMClientRole::Assistant) 121 | } 122 | 123 | pub fn is_function(&self) -> bool { 124 | 
matches!(self, LLMClientRole::Function) 125 | } 126 | } 127 | 128 | #[derive(serde::Serialize, Debug, Clone)] 129 | pub struct LLMClientMessageFunctionCall { 130 | name: String, 131 | // arguments are generally given as a JSON string, so we keep it as a string 132 | // here, validate in the upper handlers for this 133 | arguments: String, 134 | } 135 | 136 | impl LLMClientMessageFunctionCall { 137 | pub fn name(&self) -> &str { 138 | &self.name 139 | } 140 | 141 | pub fn arguments(&self) -> &str { 142 | &self.arguments 143 | } 144 | } 145 | 146 | #[derive(serde::Serialize, Debug, Clone)] 147 | pub struct LLMClientMessageFunctionReturn { 148 | name: String, 149 | content: String, 150 | } 151 | 152 | impl LLMClientMessageFunctionReturn { 153 | pub fn name(&self) -> &str { 154 | &self.name 155 | } 156 | 157 | pub fn content(&self) -> &str { 158 | &self.content 159 | } 160 | } 161 | 162 | #[derive(serde::Serialize, Debug, Clone)] 163 | pub struct LLMClientMessage { 164 | role: LLMClientRole, 165 | message: String, 166 | function_call: Option, 167 | function_return: Option, 168 | } 169 | 170 | impl LLMClientMessage { 171 | pub fn new(role: LLMClientRole, message: String) -> Self { 172 | Self { 173 | role, 174 | message, 175 | function_call: None, 176 | function_return: None, 177 | } 178 | } 179 | 180 | pub fn function_call(name: String, arguments: String) -> Self { 181 | Self { 182 | role: LLMClientRole::Assistant, 183 | message: "".to_owned(), 184 | function_call: Some(LLMClientMessageFunctionCall { name, arguments }), 185 | function_return: None, 186 | } 187 | } 188 | 189 | pub fn function_return(name: String, content: String) -> Self { 190 | Self { 191 | role: LLMClientRole::Function, 192 | message: "".to_owned(), 193 | function_call: None, 194 | function_return: Some(LLMClientMessageFunctionReturn { name, content }), 195 | } 196 | } 197 | 198 | pub fn user(message: String) -> Self { 199 | Self::new(LLMClientRole::User, message) 200 | } 201 | 202 | pub fn assistant(message: String) -> Self { 203 | Self::new(LLMClientRole::Assistant, message) 204 | } 205 | 206 | pub fn system(message: String) -> Self { 207 | Self::new(LLMClientRole::System, message) 208 | } 209 | 210 | pub fn content(&self) -> &str { 211 | &self.message 212 | } 213 | 214 | pub fn function(message: String) -> Self { 215 | Self::new(LLMClientRole::Function, message) 216 | } 217 | 218 | pub fn role(&self) -> &LLMClientRole { 219 | &self.role 220 | } 221 | 222 | pub fn get_function_call(&self) -> Option<&LLMClientMessageFunctionCall> { 223 | self.function_call.as_ref() 224 | } 225 | 226 | pub fn get_function_return(&self) -> Option<&LLMClientMessageFunctionReturn> { 227 | self.function_return.as_ref() 228 | } 229 | } 230 | 231 | #[derive(Clone, Debug)] 232 | pub struct LLMClientCompletionRequest { 233 | model: LLMType, 234 | messages: Vec, 235 | temperature: f32, 236 | frequency_penalty: Option, 237 | } 238 | 239 | #[derive(Clone)] 240 | pub struct LLMClientCompletionStringRequest { 241 | model: LLMType, 242 | prompt: String, 243 | temperature: f32, 244 | frequency_penalty: Option, 245 | } 246 | 247 | impl LLMClientCompletionStringRequest { 248 | pub fn new( 249 | model: LLMType, 250 | prompt: String, 251 | temperature: f32, 252 | frequency_penalty: Option, 253 | ) -> Self { 254 | Self { 255 | model, 256 | prompt, 257 | temperature, 258 | frequency_penalty, 259 | } 260 | } 261 | 262 | pub fn model(&self) -> &LLMType { 263 | &self.model 264 | } 265 | 266 | pub fn temperature(&self) -> f32 { 267 | self.temperature 268 | } 269 | 270 
| pub fn frequency_penalty(&self) -> Option { 271 | self.frequency_penalty 272 | } 273 | 274 | pub fn prompt(&self) -> &str { 275 | &self.prompt 276 | } 277 | } 278 | 279 | impl LLMClientCompletionRequest { 280 | pub fn new( 281 | model: LLMType, 282 | messages: Vec, 283 | temperature: f32, 284 | frequency_penalty: Option, 285 | ) -> Self { 286 | Self { 287 | model, 288 | messages, 289 | temperature, 290 | frequency_penalty, 291 | } 292 | } 293 | 294 | pub fn from_messages(messages: Vec, model: LLMType) -> Self { 295 | Self::new(model, messages, 0.0, None) 296 | } 297 | 298 | pub fn set_temperature(mut self, temperature: f32) -> Self { 299 | self.temperature = temperature; 300 | self 301 | } 302 | 303 | pub fn messages(&self) -> &[LLMClientMessage] { 304 | self.messages.as_slice() 305 | } 306 | 307 | pub fn temperature(&self) -> f32 { 308 | self.temperature 309 | } 310 | 311 | pub fn frequency_penalty(&self) -> Option { 312 | self.frequency_penalty 313 | } 314 | 315 | pub fn model(&self) -> &LLMType { 316 | &self.model 317 | } 318 | } 319 | 320 | pub struct LLMClientCompletionResponse { 321 | answer_up_until_now: String, 322 | delta: Option, 323 | model: String, 324 | } 325 | 326 | impl LLMClientCompletionResponse { 327 | pub fn new(answer_up_until_now: String, delta: Option, model: String) -> Self { 328 | Self { 329 | answer_up_until_now, 330 | delta, 331 | model, 332 | } 333 | } 334 | 335 | pub fn answer_up_until_now(&self) -> &str { 336 | &self.answer_up_until_now 337 | } 338 | 339 | pub fn delta(&self) -> Option<&str> { 340 | self.delta.as_deref() 341 | } 342 | 343 | pub fn model(&self) -> &str { 344 | &self.model 345 | } 346 | } 347 | 348 | #[derive(Error, Debug)] 349 | pub enum LLMClientError { 350 | #[error("Failed to get response from LLM")] 351 | FailedToGetResponse, 352 | 353 | #[error("Reqwest error: {0}")] 354 | ReqwestError(#[from] reqwest::Error), 355 | 356 | #[error("serde failed: {0}")] 357 | SerdeError(#[from] serde_json::Error), 358 | 359 | #[error("send error over channel: {0}")] 360 | SendError(#[from] tokio::sync::mpsc::error::SendError), 361 | 362 | #[error("unsupported model")] 363 | UnSupportedModel, 364 | 365 | #[error("OpenAI api error: {0}")] 366 | OpenAPIError(#[from] async_openai::error::OpenAIError), 367 | 368 | #[error("Wrong api key type")] 369 | WrongAPIKeyType, 370 | 371 | #[error("OpenAI does not support completion")] 372 | OpenAIDoesNotSupportCompletion, 373 | 374 | #[error("Sqlite setup error")] 375 | SqliteSetupError, 376 | 377 | #[error("tokio mspc error")] 378 | TokioMpscSendError, 379 | 380 | #[error("Failed to store in sqlite DB")] 381 | FailedToStoreInDB, 382 | 383 | #[error("Sqlx erorr: {0}")] 384 | SqlxError(#[from] sqlx::Error), 385 | 386 | #[error("Function calling role but not function call present")] 387 | FunctionCallNotPresent, 388 | } 389 | 390 | #[async_trait] 391 | pub trait LLMClient { 392 | fn client(&self) -> &LLMProvider; 393 | 394 | async fn stream_completion( 395 | &self, 396 | api_key: LLMProviderAPIKeys, 397 | request: LLMClientCompletionRequest, 398 | sender: UnboundedSender, 399 | ) -> Result; 400 | 401 | async fn completion( 402 | &self, 403 | api_key: LLMProviderAPIKeys, 404 | request: LLMClientCompletionRequest, 405 | ) -> Result; 406 | 407 | async fn stream_prompt_completion( 408 | &self, 409 | api_key: LLMProviderAPIKeys, 410 | request: LLMClientCompletionStringRequest, 411 | sender: UnboundedSender, 412 | ) -> Result; 413 | } 414 | 415 | #[cfg(test)] 416 | mod tests { 417 | use super::LLMType; 418 | 419 | #[test] 420 | 
fn test_llm_type_from_string() { 421 | let llm_type = LLMType::Custom("skcd_testing".to_owned()); 422 | let str_llm_type = serde_json::to_string(&llm_type).expect("to work"); 423 | assert_eq!(str_llm_type, ""); 424 | } 425 | } 426 | -------------------------------------------------------------------------------- /llm_prompts/src/reranking/broker.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | cmp::{max, min}, 3 | collections::HashMap, 4 | sync::Arc, 5 | }; 6 | 7 | use futures::stream; 8 | use futures::StreamExt; 9 | use llm_client::{ 10 | broker::LLMBroker, 11 | clients::types::LLMType, 12 | provider::{LLMProvider, LLMProviderAPIKeys}, 13 | tokenizer::tokenizer::LLMTokenizer, 14 | }; 15 | 16 | use super::{ 17 | mistral::MistralReRank, 18 | openai::OpenAIReRank, 19 | types::{ 20 | CodeSpan, CodeSpanDigest, ReRankCodeSpan, ReRankCodeSpanError, ReRankCodeSpanRequest, 21 | ReRankCodeSpanResponse, ReRankListWiseResponse, ReRankStrategy, 22 | }, 23 | }; 24 | 25 | const SLIDING_WINDOW: i64 = 10; 26 | const TOP_K: i64 = 5; 27 | 28 | pub struct ReRankBroker { 29 | rerankers: HashMap>, 30 | } 31 | 32 | impl ReRankBroker { 33 | pub fn new() -> Self { 34 | let mut rerankers: HashMap> = HashMap::new(); 35 | rerankers.insert(LLMType::GPT3_5_16k, Box::new(OpenAIReRank::new())); 36 | rerankers.insert(LLMType::Gpt4, Box::new(OpenAIReRank::new())); 37 | rerankers.insert(LLMType::Gpt4_32k, Box::new(OpenAIReRank::new())); 38 | rerankers.insert(LLMType::MistralInstruct, Box::new(MistralReRank::new())); 39 | rerankers.insert(LLMType::Mixtral, Box::new(MistralReRank::new())); 40 | Self { rerankers } 41 | } 42 | 43 | pub fn rerank_prompt( 44 | &self, 45 | request: ReRankCodeSpanRequest, 46 | ) -> Result { 47 | let reranker = self.rerankers.get(&request.llm_type()).unwrap(); 48 | reranker.rerank_prompt(request) 49 | } 50 | 51 | fn measure_tokens( 52 | &self, 53 | llm_type: &LLMType, 54 | code_digests: &[CodeSpanDigest], 55 | tokenizer: Arc, 56 | ) -> Result { 57 | let total_tokens: usize = code_digests 58 | .into_iter() 59 | .map(|code_digest| { 60 | let file_path = code_digest.file_path(); 61 | let data = code_digest.data(); 62 | let prompt = format!( 63 | r#"FILEPATH: {file_path} 64 | ``` 65 | {data} 66 | ```"# 67 | ); 68 | tokenizer.count_tokens_using_tokenizer(llm_type, &prompt) 69 | }) 70 | .collect::>() 71 | .into_iter() 72 | .collect::, _>>()? 
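                // The collect above gathers the per-digest token counts into a single
                // Result, so `?` bails out on the first tokenizer error; the surviving
                // counts are summed below.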
73 | .into_iter() 74 | .sum(); 75 | Ok(total_tokens) 76 | } 77 | 78 | fn order_code_digests_listwise( 79 | &self, 80 | llm_type: &LLMType, 81 | response: String, 82 | rerank_list_request: ReRankListWiseResponse, 83 | ) -> Result, ReRankCodeSpanError> { 84 | if let Some(reranker) = self.rerankers.get(llm_type) { 85 | let mut reranked_code_spans = 86 | reranker.parse_listwise_output(response, rerank_list_request)?; 87 | reranked_code_spans.reverse(); 88 | // revers it here since we want the most relevant one to be at the right and not the left 89 | Ok(reranked_code_spans) 90 | } else { 91 | Err(ReRankCodeSpanError::ModelNotFound) 92 | } 93 | } 94 | 95 | pub async fn listwise_reranking( 96 | &self, 97 | api_keys: LLMProviderAPIKeys, 98 | request: ReRankCodeSpanRequest, 99 | provider: LLMProvider, 100 | client_broker: Arc, 101 | tokenizer: Arc, 102 | ) -> Result, ReRankCodeSpanError> { 103 | // We are given a list of code spans, we are going to do the following: 104 | // - implement a sliding window algorithm which goes over the snippets 105 | // and keeps ranking them until we have the list of top k snippets 106 | let code_spans = request.code_spans().to_vec(); 107 | let mut digests = CodeSpan::to_digests(code_spans); 108 | // First we check if we need to do a sliding window here by measuring 109 | // against the token limit we have 110 | if request.token_limit() 111 | >= self.measure_tokens(request.llm_type(), &digests, tokenizer)? as i64 112 | { 113 | return Ok(digests 114 | .into_iter() 115 | .map(|digest| digest.get_code_span()) 116 | .collect()); 117 | } 118 | let mut end_index: i64 = (min( 119 | SLIDING_WINDOW, 120 | digests.len().try_into().expect("conversion to not fail"), 121 | ) - 1) 122 | .try_into() 123 | .expect("conversion to work"); 124 | while end_index < digests.len() as i64 { 125 | // Now that we are in the window, we have to take the elements from 126 | // (end_index - SLIDING_WINDOW)::(end_index) 127 | // and rank them, once we have these ranked 128 | // we move our window forward by TOP_K and repeat the process 129 | let llm_type = request.llm_type().clone(); 130 | let index_start: usize = max(end_index - SLIDING_WINDOW, 0).try_into().unwrap(); 131 | let end_index_usize = end_index.try_into().expect("to work"); 132 | let code_spans = digests[index_start..=end_index_usize] 133 | .iter() 134 | .map(|digest| digest.clone().get_code_span()) 135 | .collect::>(); 136 | let request = ReRankCodeSpanRequest::new( 137 | request.user_query().to_owned(), 138 | request.limit(), 139 | request.token_limit(), 140 | code_spans, 141 | request.strategy().clone(), 142 | llm_type.clone(), 143 | ); 144 | let prompt = self.rerank_prompt(request)?; 145 | if let ReRankCodeSpanResponse::ListWise(listwise_request) = prompt { 146 | let prompt = listwise_request.prompt.to_owned(); 147 | let (sender, _receiver) = tokio::sync::mpsc::unbounded_channel(); 148 | let response = client_broker 149 | .stream_answer( 150 | api_keys.clone(), 151 | provider.clone(), 152 | prompt, 153 | vec![("event_type".to_owned(), "listwise_reranking".to_owned())] 154 | .into_iter() 155 | .collect(), 156 | sender, 157 | ) 158 | .await?; 159 | 160 | // We have the updated list 161 | let updated_list = 162 | self.order_code_digests_listwise(&llm_type, response, listwise_request)?; 163 | // Now we will in place replace the code spans from the digests from our start position 164 | // with the elements in this list 165 | for (index, code_span_digest) in updated_list.into_iter().enumerate() { 166 | let index_i64: i64 = 
index.try_into().expect("to work"); 167 | let new_index: usize = (max(end_index - SLIDING_WINDOW, 0) + index_i64) 168 | .try_into() 169 | .expect("to work"); 170 | digests[new_index] = code_span_digest; 171 | } 172 | 173 | // Now move the window forward 174 | end_index += TOP_K; 175 | } else { 176 | return Err(ReRankCodeSpanError::WrongReRankStrategy); 177 | } 178 | // let response = client_broker.stream_completion(api_key, request, metadata, sender) 179 | } 180 | 181 | // At the end of this iteration we have our updated list of answers 182 | 183 | // First reverse the list so its ordered from the most relevant to the least 184 | digests.reverse(); 185 | // Only take the request.limit() number of answers 186 | digests.truncate(request.limit()); 187 | // convert back to the code span 188 | Ok(digests 189 | .into_iter() 190 | .map(|digest| digest.get_code_span()) 191 | .collect()) 192 | } 193 | 194 | pub async fn pointwise_reranking( 195 | &self, 196 | api_keys: LLMProviderAPIKeys, 197 | provider: LLMProvider, 198 | request: ReRankCodeSpanRequest, 199 | client_broker: Arc, 200 | tokenizer: Arc, 201 | ) -> Result, ReRankCodeSpanError> { 202 | // This approach uses the logits generated for yes and no to get the final 203 | // answer, since we are not use if we can logits yet on various platforms 204 | // we assume 1.0 for yes if thats the case or 0.0 for no otherwise 205 | let code_spans = request.code_spans().to_vec(); 206 | let digests = CodeSpan::to_digests(code_spans); 207 | let answer_snippets = request.limit(); 208 | 209 | // We first measure if we are within the token limit 210 | if request.token_limit() 211 | >= self.measure_tokens(request.llm_type(), &digests, tokenizer)? as i64 212 | { 213 | return Ok(digests 214 | .into_iter() 215 | .map(|digest| digest.get_code_span()) 216 | .collect()); 217 | } 218 | 219 | let request = ReRankCodeSpanRequest::new( 220 | request.user_query().to_owned(), 221 | request.limit(), 222 | request.token_limit(), 223 | digests 224 | .into_iter() 225 | .map(|digest| digest.get_code_span()) 226 | .collect(), 227 | request.strategy().clone(), 228 | request.llm_type().clone(), 229 | ); 230 | 231 | let prompt = self.rerank_prompt(request)?; 232 | 233 | if let ReRankCodeSpanResponse::PointWise(pointwise_prompts) = prompt { 234 | let response_with_code_digests = stream::iter(pointwise_prompts.into_iter()) 235 | .map(|pointwise_prompt| async { 236 | let prompt = pointwise_prompt.prompt; 237 | let code_digest = pointwise_prompt.code_span_digest; 238 | let (sender, _receiver) = tokio::sync::mpsc::unbounded_channel(); 239 | client_broker 240 | .stream_answer( 241 | api_keys.clone(), 242 | provider.clone(), 243 | prompt, 244 | vec![("event_type".to_owned(), "pointwise_reranking".to_owned())] 245 | .into_iter() 246 | .collect(), 247 | sender, 248 | ) 249 | .await 250 | .map(|response| (response, code_digest)) 251 | }) 252 | .buffer_unordered(25) 253 | .filter_map(|response| { 254 | if let Ok((response, code_digest)) = response { 255 | if response.trim().to_lowercase() == "yes" { 256 | futures::future::ready(Some(code_digest)) 257 | } else { 258 | futures::future::ready(None) 259 | } 260 | } else { 261 | futures::future::ready(None) 262 | } 263 | }) 264 | .collect::>() 265 | .await; 266 | // Now we only keep the code spans from the start until the length 267 | // of the limit we have 268 | let mut response_with_code_digests = response_with_code_digests 269 | .into_iter() 270 | .map(|code_digest| code_digest.get_code_span()) 271 | .collect::>(); 272 | // Only keep until 
273 | response_with_code_digests.truncate(answer_snippets); 274 | return Ok(response_with_code_digests); 275 | } else { 276 | return Err(ReRankCodeSpanError::WrongReRankStrategy); 277 | } 278 | } 279 | 280 | pub async fn rerank( 281 | &self, 282 | api_keys: LLMProviderAPIKeys, 283 | provider: LLMProvider, 284 | request: ReRankCodeSpanRequest, 285 | // we need the broker here to get the right client 286 | client_broker: Arc<LLMBroker>, 287 | // we need the tokenizer here to count the tokens properly 288 | tokenizer_broker: Arc<LLMTokenizer>, 289 | ) -> Result<Vec<CodeSpan>, ReRankCodeSpanError> { 290 | let strategy = request.strategy(); 291 | match strategy { 292 | ReRankStrategy::ListWise => { 293 | self.listwise_reranking( 294 | api_keys, 295 | request, 296 | provider, 297 | client_broker, 298 | tokenizer_broker, 299 | ) 300 | .await 301 | } 302 | ReRankStrategy::PointWise => { 303 | // We need to generate the prompt for this 304 | self.pointwise_reranking( 305 | api_keys, 306 | provider, 307 | request, 308 | client_broker, 309 | tokenizer_broker, 310 | ) 311 | .await 312 | } 313 | } 314 | } 315 | } 316 | -------------------------------------------------------------------------------- /llm_prompts/src/bin/mistral_reranking.rs: -------------------------------------------------------------------------------- 1 | //! We want to test the reranking with mistral 2 | 3 | use std::path::PathBuf; 4 | 5 | use llm_client::{ 6 | broker::LLMBroker, 7 | clients::types::{LLMClientCompletionStringRequest, LLMType}, 8 | config::LLMBrokerConfiguration, 9 | provider::{LLMProviderAPIKeys, TogetherAIProvider}, 10 | }; 11 | 12 | #[tokio::main] 13 | async fn main() { 14 | let prompt = r#"[INST] You are an expert at ordering the code snippets from the most relevant to the least relevant for the user query. You have the order the list of code snippets from the most relevant to the least relevant. As an example 15 | 16 | 17 | subtract.rs::0 18 | 19 | 20 | ``` 21 | fn subtract(a: i32, b: i32) -> i32 { 22 | a - b 23 | } 24 | ``` 25 | 26 | 27 | 28 | add.rs::0 29 | 30 | 31 | ``` 32 | fn add(a: i32, b: i32) -> i32 { 33 | a + b 34 | } 35 | ``` 36 | 37 | 38 | 39 | And if you thought the code snippet with id add.rs::0 is more relevant than subtract.rs::0 then you would rank it as: 40 | 41 | 42 | add.rs::0 43 | 44 | 45 | subtract.rs::0 46 | 47 | 48 | 49 | Now for the actual query. 50 | The user has asked the following query: 51 | 52 | User: [#file:client.rs:1-635](values:file:client.rs:1-635) where do we initialize the language server?
53 | 54 | 55 | The code snippets along with their ids are given below: 56 | 57 | 58 | client.rs::4 59 | 60 | 61 | ``` 62 | pub struct LanguageServerRef(Arc>>); 63 | 64 | //FIXME: this is hacky, and prevents good error propogation, 65 | fn number_from_id(id: Option<&Value>) -> usize { 66 | let id = id.expect("response missing id field"); 67 | let id = match id { 68 | &Value::Number(ref n) => n.as_u64().expect("failed to take id as u64"), 69 | &Value::String(ref s) => { 70 | u64::from_str_radix(s, 10).expect("failed to convert string id to u64") 71 | } 72 | other => panic!("unexpected value for id field: {:?}", other), 73 | }; 74 | 75 | id as usize 76 | } 77 | 78 | fn fetch_ts_files_recursively(dir: &Path, files: &mut Vec) -> Result<()> { 79 | ``` 80 | 81 | 82 | 83 | client.rs::0 84 | 85 | 86 | ``` 87 | use anyhow::Result; 88 | use lsp_types::{ 89 | ClientCapabilities, CodeActionClientCapabilities, CodeLensClientCapabilities, 90 | DynamicRegistrationClientCapabilities, ExecuteCommandClientCapabilities, GotoCapability, 91 | GotoDefinitionParams, GotoDefinitionResponse, InitializeParams, ReferenceParams, 92 | RenameClientCapabilities, SignatureHelpClientCapabilities, TextDocumentClientCapabilities, 93 | WorkDoneProgressParams, WorkspaceClientCapabilities, WorkspaceFolder, 94 | }; 95 | use std::str::FromStr; 96 | use tokio::fs; 97 | use tokio::task::JoinHandle; 98 | use tracing::info; 99 | 100 | use std::collections::HashMap; 101 | use std::ffi::OsStr; 102 | use std::path::{Path, PathBuf}; 103 | use std::process::Stdio; 104 | use std::sync::{Arc, Mutex}; 105 | use tokio::io::{AsyncWriteExt, BufReader}; 106 | use tokio::process::{Child, ChildStdin, Command}; 107 | use tokio::sync::oneshot; 108 | use url::Url; 109 | 110 | use serde_json::value::Value; 111 | use serde_json::{self, json}; 112 | 113 | use jsonrpc_lite::{Error, Id, JsonRpc}; 114 | 115 | use super::parsing; 116 | 117 | // this to get around some type system pain related to callbacks. See: 118 | // https://doc.rust-lang.org/beta/book/trait-objects.html, 119 | // http://stackoverflow.com/questions/41081240/idiomatic-callbacks-in-rust 120 | trait Callable: Send { 121 | fn call(self: Box, result: Result); 122 | } 123 | 124 | impl)> Callable for F { 125 | fn call(self: Box, result: Result) { 126 | (*self)(result) 127 | } 128 | } 129 | 130 | type Callback = Box; 131 | 132 | /// Represents (and mediates communcation with) a Language Server. 133 | /// 134 | ``` 135 | 136 | 137 | 138 | client.rs::1 139 | 140 | 141 | ``` 142 | /// LanguageServer should only ever be instantiated or accessed through an instance of 143 | /// LanguageServerRef, which mediates access to a single shared LanguageServer through a Mutex. 144 | struct LanguageServer { 145 | peer: W, 146 | pending: HashMap, 147 | next_id: usize, 148 | } 149 | 150 | /// Generates a Language Server Protocol compliant message. 
151 | fn prepare_lsp_json(msg: &Value) -> Result { 152 | let request = serde_json::to_string(&msg)?; 153 | Ok(format!( 154 | "Content-Length: {}\r\n\r\n{}", 155 | request.len(), 156 | request 157 | )) 158 | } 159 | 160 | impl LanguageServer { 161 | ``` 162 | 163 | 164 | 165 | client.rs::2 166 | 167 | 168 | ``` 169 | async fn write(&mut self, msg: &str) { 170 | self.peer 171 | .write_all(msg.as_bytes()) 172 | .await 173 | .expect("error writing to stdin"); 174 | self.peer.flush().await.expect("error flushing child stdin"); 175 | } 176 | 177 | async fn send_request(&mut self, method: &str, params: &Value, completion: Callback) { 178 | let request = json!({ 179 | "jsonrpc": "2.0", 180 | "id": self.next_id, 181 | "method": method, 182 | "params": params 183 | }); 184 | 185 | self.pending.insert(self.next_id, completion); 186 | self.next_id += 1; 187 | self.send_rpc(&request).await; 188 | } 189 | 190 | async fn send_notification(&mut self, method: &str, params: &Value) { 191 | let notification = json!({ 192 | "jsonrpc": "2.0", 193 | "method": method, 194 | "params": params 195 | }); 196 | self.send_rpc(¬ification).await; 197 | } 198 | 199 | fn handle_response(&mut self, id: usize, result: Value) { 200 | let callback = self 201 | .pending 202 | .remove(&id) 203 | .expect(&format!("id {} missing from request table", id)); 204 | callback.call(Ok(result)); 205 | } 206 | 207 | fn handle_error(&mut self, id: usize, error: Error) { 208 | let callback = self 209 | .pending 210 | .remove(&id) 211 | .expect(&format!("id {} missing from request table", id)); 212 | callback.call(Err(error.data.unwrap_or(serde_json::Value::Null))); 213 | } 214 | 215 | async fn send_rpc(&mut self, rpc: &Value) { 216 | let rpc = match prepare_lsp_json(&rpc) { 217 | ``` 218 | 219 | 220 | 221 | client.rs::3 222 | 223 | 224 | ``` 225 | Ok(r) => r, 226 | Err(err) => panic!("error encoding rpc {:?}", err), 227 | }; 228 | self.write(&rpc).await; 229 | } 230 | } 231 | 232 | /// Access control and convenience wrapper around a shared LanguageServer instance. 233 | ``` 234 | 235 | 236 | 237 | client.rs::5 238 | 239 | 240 | ``` 241 | // If the path starts with `file://` then we need to remove it. Make sure the 3rd slash is not removed. 242 | // Example: `file:///Users/nareshr/github/codestory/ide/src` -> `/Users/nareshr/github/codestory/ide/src` 243 | // Don't use `strip_prefix` because it removes the 3rd slash. 
244 | // let dir = dir.to_str().unwrap().replace("file://", ""); 245 | 246 | match std::fs::read_dir(dir) { 247 | Ok(entries) => { 248 | println!("Successfully read directory: {:?}", dir); 249 | for entry in entries { 250 | match entry { 251 | Ok(entry) => { 252 | let path = entry.path(); 253 | if path.is_dir() { 254 | match fetch_ts_files_recursively(&path, files) { 255 | Ok(_) => (), 256 | Err(e) => eprintln!("Error reading directory: {:?}", e), 257 | } 258 | } else if path.extension() == Some(OsStr::new("ts")) { 259 | files.push(path.clone()); 260 | println!("Added file: {:?}", path); 261 | } 262 | } 263 | Err(e) => eprintln!("Error reading entry: {:?}", e), 264 | } 265 | } 266 | } 267 | Err(e) => eprintln!("Error reading directory: {:?}", e), 268 | } 269 | Ok(()) 270 | } 271 | 272 | impl LanguageServerRef { 273 | fn new(peer: W) -> Self { 274 | LanguageServerRef(Arc::new(Mutex::new(LanguageServer { 275 | peer: peer, 276 | ``` 277 | 278 | 279 | 280 | client.rs::6 281 | 282 | 283 | ``` 284 | pending: HashMap::new(), 285 | next_id: 1, 286 | }))) 287 | } 288 | 289 | fn handle_msg(&self, val: &str) { 290 | let parsed_value = JsonRpc::parse(val); 291 | if let Err(err) = parsed_value { 292 | println!("error parsing json: {:?}", err); 293 | return; 294 | } 295 | let parsed_value = parsed_value.expect("to be present"); 296 | let id = parsed_value.get_id(); 297 | let response = parsed_value.get_result(); 298 | let error = parsed_value.get_error(); 299 | match (id, response, error) { 300 | (Some(Id::Num(id)), Some(response), None) => { 301 | let mut inner = self.0.lock().unwrap(); 302 | inner.handle_response(id.try_into().unwrap(), response.clone()); 303 | } 304 | (Some(Id::Num(id)), None, Some(error)) => { 305 | let mut inner = self.0.lock().unwrap(); 306 | inner.handle_error(id.try_into().unwrap(), error.clone()); 307 | } 308 | (Some(Id::Num(id)), Some(response), Some(error)) => { 309 | panic!("We got both response and error.. what even??"); 310 | } 311 | _ => {} 312 | } 313 | } 314 | 315 | /// Sends a JSON-RPC request message with the provided method and parameters. 316 | /// `completion` should be a callback which will be executed with the server's response. 317 | pub async fn send_request(&self, method: &str, params: &Value, completion: CB) 318 | where 319 | CB: 'static + Send + FnOnce(Result), 320 | { 321 | let mut inner = self.0.lock().unwrap(); 322 | ``` 323 | 324 | 325 | 326 | client.rs::7 327 | 328 | 329 | ``` 330 | inner 331 | .send_request(method, params, Box::new(completion)) 332 | .await; 333 | } 334 | 335 | /// Sends a JSON-RPC notification message with the provided method and parameters. 
336 | pub async fn send_notification(&self, method: &str, params: &Value) { 337 | let mut inner = self.0.lock().unwrap(); 338 | inner.send_notification(method, params).await; 339 | } 340 | 341 | pub async fn initialize(&self, working_directory: &PathBuf) { 342 | info!( 343 | event_name = "initialize_lsp", 344 | event_type = "start", 345 | working_directory = ?working_directory 346 | ); 347 | 348 | ``` 349 | 350 | 351 | 352 | client.rs::8 353 | 354 | 355 | ``` 356 | let working_directory_path = 357 | format!("file://{}", working_directory.to_str().expect("to work")); 358 | let start = std::time::Instant::now(); 359 | 360 | let init_params = InitializeParams { 361 | process_id: None, // Super important to set it to NONE https://github.com/typescript-language-server/typescript-language-server/issues/262 362 | ``` 363 | 364 | 365 | 366 | client.rs::9 367 | 368 | 369 | ``` 370 | root_uri: Some(Url::parse(&working_directory_path).unwrap()), 371 | root_path: None, 372 | initialization_options: Some(serde_json::json!({ 373 | "hostInfo": "vscode", 374 | "maxTsServerMemory": 4096 * 2, 375 | "tsserver": { 376 | "logDirectory": "/tmp/tsserver", 377 | "logVerbosity": "verbose", 378 | "maxTsServerMemory": 4096 * 2, 379 | // sending the same path as the vscode extension 380 | // "path": "/Users/skcd/.aide/extensions/ms-vscode.vscode-typescript-next-5.3.20231102/node_modules/typescript/lib/tsserver.js" 381 | }, 382 | "preferences": { 383 | "providePrefixAndSuffixTextForRename": true, 384 | "allowRenameOfImportPath": true, 385 | "includePackageJsonAutoImports": "auto", 386 | "excludeLibrarySymbolsInNavTo": true 387 | } 388 | })), 389 | capabilities: ClientCapabilities { 390 | text_document: Some(TextDocumentClientCapabilities { 391 | declaration: Some(GotoCapability { 392 | dynamic_registration: Some(true), 393 | ``` 394 | 395 | 396 | 397 | As a reminder the user question is: [#file:client.rs:1-635](values:file:client.rs:1-635) where do we initialize the language server? 398 | You have to order all the code snippets from the most relevant to the least relevant to the user query, all the code snippet ids should be present in your final reordered list. Only output the ids of the code snippets. 
399 | [/INST] 400 | 401 | "#; 402 | let llm_broker = LLMBroker::new(LLMBrokerConfiguration::new(PathBuf::from( 403 | "/Users/skcd/Library/Application Support/ai.codestory.sidecar", 404 | ))) 405 | .await 406 | .expect("broker to startup"); 407 | 408 | let api_key = LLMProviderAPIKeys::TogetherAI(TogetherAIProvider::new("some_key".to_owned())); 409 | let request = LLMClientCompletionStringRequest::new( 410 | LLMType::MistralInstruct, 411 | prompt.to_owned(), 412 | 0.9, 413 | None, 414 | ); 415 | let (sender, _receiver) = tokio::sync::mpsc::unbounded_channel(); 416 | let metadata = vec![("event_type".to_owned(), "listwise_reranking".to_owned())] 417 | .into_iter() 418 | .collect(); 419 | let result = llm_broker 420 | .stream_string_completion(api_key.clone(), request, metadata, sender) 421 | .await; 422 | println!("Mistral:"); 423 | println!("{:?}", result); 424 | let mixtral_request = 425 | LLMClientCompletionStringRequest::new(LLMType::Mixtral, prompt.to_owned(), 0.7, None); 426 | let (sender, _receiver) = tokio::sync::mpsc::unbounded_channel(); 427 | let metadata = vec![("event_type".to_owned(), "listwise_reranking".to_owned())] 428 | .into_iter() 429 | .collect(); 430 | let result = llm_broker 431 | .stream_string_completion(api_key, mixtral_request, metadata, sender) 432 | .await; 433 | println!("Mixtral:"); 434 | println!("{:?}", result); 435 | } 436 | --------------------------------------------------------------------------------
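
The sliding-window loop in `listwise_reranking` above is the core of the listwise strategy, and the `mistral_reranking` binary just drives one such window through a Mistral/Mixtral prompt. The sketch below isolates that loop with plain strings so the index arithmetic is easier to follow. It is a minimal sketch under stated assumptions: the `SLIDING_WINDOW` and `TOP_K` values and the `rank_window` closure are illustrative stand-ins, not the crate's real constants or its LLM call, which goes through `client_broker.stream_answer` and `order_code_digests_listwise`.

```rust
// Minimal sketch of the sliding-window pass used by `listwise_reranking`.
// The constants and the `rank_window` closure are assumptions for illustration.

const SLIDING_WINDOW: usize = 10; // assumed window size, not the crate's constant
const TOP_K: usize = 5; // assumed step size, not the crate's constant

/// Repeatedly rank a window of items and slide the window forward by `TOP_K`.
/// `rank_window` must return the window's items ordered from least to most
/// relevant (same length as the window), mirroring how the broker reverses the
/// parsed listwise output so the most relevant entries accumulate at the end.
fn sliding_window_rerank<T: Clone>(items: &mut [T], rank_window: impl Fn(&[T]) -> Vec<T>) {
    if items.is_empty() {
        return;
    }
    let mut end = SLIDING_WINDOW.min(items.len()) - 1;
    while end < items.len() {
        let start = end.saturating_sub(SLIDING_WINDOW - 1);
        let ranked = rank_window(&items[start..=end]);
        // Write the ranked window back in place, like the broker overwriting
        // `digests[new_index]` with the reordered digests.
        for (offset, item) in ranked.into_iter().enumerate() {
            items[start + offset] = item;
        }
        end += TOP_K;
    }
    // The most relevant entries have bubbled towards the end; reverse so the
    // best candidates come first, then the caller can truncate to its limit.
    items.reverse();
}

fn main() {
    let mut snippets: Vec<String> = (0..25).map(|i| format!("snippet-{}", i)).collect();
    // Stand-in ranker that simply reverses the window; the real implementation
    // asks an LLM to order the snippets against the user query.
    sliding_window_rerank(&mut snippets, |window| {
        let mut ranked = window.to_vec();
        ranked.reverse();
        ranked
    });
    println!("top 5: {:?}", &snippets[..5]);
}
```

The final `reverse` plus truncation matches what `listwise_reranking` does with `digests.reverse()` and `digests.truncate(request.limit())`; everything else here is simplified scaffolding around that idea.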