├── examples ├── metallica.pdf ├── metallica.png ├── utils │ ├── mod.rs │ └── vertex.rs ├── deprecated_use_openai_completions.rs ├── use_completions_vertex.rs ├── deprecated_use_openai_assistant.rs ├── use_xai_tools.rs ├── use_mistral_tools.rs ├── use_openai_assistant.rs ├── use_openai_assistant_json.rs ├── use_openai_assistant_azure.rs ├── use_google_tools.rs ├── use_anthropic_tools.rs ├── use_openai_responses.rs └── use_completions.rs ├── .gitignore ├── src ├── deprecated │ └── mod.rs ├── assistants │ ├── mod.rs │ └── openai │ │ ├── mod.rs │ │ └── openai_vector_store.rs ├── apis │ ├── mod.rs │ ├── google.rs │ ├── anthropic.rs │ └── openai.rs ├── files │ ├── mod.rs │ ├── llm_files.rs │ ├── openai.rs │ └── anthropic.rs ├── lib.rs ├── llm_models │ ├── mod.rs │ ├── deepseek.rs │ ├── llm_model.rs │ ├── xai.rs │ ├── aws.rs │ ├── perplexity.rs │ ├── mistral.rs │ └── anthropic.rs ├── enums.rs ├── constants.rs └── completions.rs ├── LICENSE-APACHE ├── .github └── workflows │ └── ci.yaml ├── LICENSE-MIT ├── Cargo.toml └── README.md /examples/metallica.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neferdata/allms/HEAD/examples/metallica.pdf -------------------------------------------------------------------------------- /examples/metallica.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neferdata/allms/HEAD/examples/metallica.png -------------------------------------------------------------------------------- /examples/utils/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod vertex; 2 | 3 | // Re-export commonly used functions for convenience 4 | pub use vertex::get_vertex_token; 5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | Cargo.lock 2 | /gh-pages 3 | __pycache__ 4 | .DS_Store 5 | *.so 6 | *.out 7 | *.pyc 8 | *.pid 9 | *.sock 10 | *~ 11 | target/ 12 | .idea 13 | examples/secrets 14 | examples/data 15 | .env -------------------------------------------------------------------------------- /src/deprecated/mod.rs: -------------------------------------------------------------------------------- 1 | mod openai_assistant_deprecated; 2 | mod openai_completions_deprecated; 3 | 4 | pub use openai_assistant_deprecated::{OpenAIAssistant, OpenAIAssistantVersion, OpenAIFile}; 5 | pub use openai_completions_deprecated::{OpenAI, OpenAIModels}; 6 | -------------------------------------------------------------------------------- /src/assistants/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod openai; 2 | 3 | pub use crate::files::LLMFiles; 4 | pub use openai::{ 5 | OpenAIAssistant, OpenAIAssistantResource, OpenAIAssistantVersion, OpenAIFile, 6 | OpenAIVectorStore, OpenAIVectorStoreFileCounts, OpenAIVectorStoreStatus, 7 | }; 8 | -------------------------------------------------------------------------------- /src/apis/mod.rs: -------------------------------------------------------------------------------- 1 | mod anthropic; 2 | mod google; 3 | mod openai; 4 | 5 | pub use anthropic::AnthropicApiEndpoints; 6 | pub use google::GoogleApiEndpoints; 7 | pub use openai::{ 8 | OpenAIAssistantResource, OpenAIAssistantVersion, OpenAICompletionsAPI, OpenAiApiEndpoints, 9 | }; 10 | 
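The `apis` module above gathers the endpoint/version selectors (`AnthropicApiEndpoints`, `GoogleApiEndpoints`, `OpenAiApiEndpoints`, and related types) that the rest of the crate re-exports. As a rough, illustrative sketch of how the Anthropic selector defined in `src/apis/anthropic.rs` (shown further down in this listing) can be constructed:

    use allms::apis::AnthropicApiEndpoints;

    fn main() {
        // Default Messages API version; falls back to the crate constant ("2023-06-01")
        // unless overridden via the ANTHROPIC_MESSAGES_VERSION environment variable.
        let messages = AnthropicApiEndpoints::messages_default();
        println!("messages version: {}", messages.version());

        // Pin an explicit Files API version instead of relying on the default.
        let files = AnthropicApiEndpoints::files("files-api-2025-04-14".to_string());
        assert_eq!(files.version(), "files-api-2025-04-14");
    }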
-------------------------------------------------------------------------------- /src/files/mod.rs: -------------------------------------------------------------------------------- 1 | mod anthropic; 2 | mod llm_files; 3 | mod openai; 4 | 5 | /// Main trait for LLM files 6 | pub use llm_files::LLMFiles; 7 | 8 | /// OpenAI file implementation 9 | pub use openai::OpenAIFile; 10 | 11 | /// Anthropic file implementation 12 | pub use anthropic::AnthropicFile; 13 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | pub mod apis; 2 | pub mod assistants; 3 | mod completions; 4 | mod constants; 5 | mod domain; 6 | mod enums; 7 | pub mod files; 8 | pub mod llm_models; 9 | pub use llm_models as llm; 10 | mod utils; 11 | 12 | #[allow(deprecated)] 13 | mod deprecated; 14 | 15 | pub use crate::completions::{Completions, ThinkingLevel}; 16 | #[allow(deprecated)] 17 | pub use crate::deprecated::{ 18 | OpenAI, OpenAIAssistant, OpenAIAssistantVersion, OpenAIFile, OpenAIModels, 19 | }; 20 | -------------------------------------------------------------------------------- /LICENSE-APACHE: -------------------------------------------------------------------------------- 1 | Copyright [2023] [Neferdata] 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
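Taken together with `src/lib.rs` above, which re-exports `Completions` at the crate root and aliases `llm_models` as `llm`, a typical downstream call looks like the sketch below. This is illustrative only: the `TranslationResponse` type is borrowed from the bundled examples, and the two trailing `None` arguments to `Completions::new` are passed exactly as those examples do.

    use allms::{llm::OpenAIModels, Completions};
    use schemars::JsonSchema;
    use serde::{Deserialize, Serialize};

    #[derive(Deserialize, Serialize, JsonSchema, Debug, Clone)]
    struct TranslationResponse {
        pub spanish: String,
        pub french: String,
    }

    #[tokio::main]
    async fn main() -> anyhow::Result<()> {
        let api_key = std::env::var("OPENAI_API_KEY")?;
        // Build a completions client and deserialize the answer into the schema-derived type.
        let completions = Completions::new(OpenAIModels::Gpt4oMini, &api_key, None, None);
        let answer = completions
            .get_answer::<TranslationResponse>(
                "Translate to Spanish and French: Rust is best for working with LLMs",
            )
            .await?;
        println!("{:#?}", answer);
        Ok(())
    }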
-------------------------------------------------------------------------------- /.github/workflows/ci.yaml: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/actions-rs/meta/blob/master/recipes/quickstart.md 2 | 3 | on: 4 | push: 5 | paths: 6 | - 'examples/**' 7 | - 'src/**' 8 | - 'Cargo.toml' 9 | 10 | name: check 11 | 12 | env: 13 | ACTIX_PORT: 8080 14 | 15 | jobs: 16 | test: 17 | name: cargo test 18 | runs-on: ubuntu-latest 19 | steps: 20 | - uses: actions/checkout@v3 21 | - uses: dtolnay/rust-toolchain@1.89 22 | with: 23 | components: clippy, rustfmt 24 | - run: cargo clippy -- --deny warnings 25 | - run: cargo fmt --check 26 | - run: cargo test 27 | - run: cargo publish --dry-run 28 | -------------------------------------------------------------------------------- /examples/utils/vertex.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use yup_oauth2::{read_service_account_key, ServiceAccountAuthenticator}; 3 | 4 | // Reusable function for Vertex API authentication 5 | pub async fn get_vertex_token() -> Result { 6 | let service_account_key = read_service_account_key("secrets/gcp_sa_key.json") 7 | .await 8 | .unwrap(); 9 | 10 | let auth = ServiceAccountAuthenticator::builder(service_account_key) 11 | .build() 12 | .await 13 | .unwrap(); 14 | 15 | let google_token = auth 16 | .token(&["https://www.googleapis.com/auth/cloud-platform"]) 17 | .await 18 | .unwrap(); 19 | 20 | Ok(google_token.token().unwrap().to_string()) 21 | } 22 | -------------------------------------------------------------------------------- /src/llm_models/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod anthropic; 2 | pub mod aws; 3 | pub mod deepseek; 4 | pub mod google; 5 | pub mod llm_model; 6 | pub mod mistral; 7 | pub mod openai; 8 | pub mod perplexity; 9 | pub mod tools; 10 | pub mod xai; 11 | 12 | pub use anthropic::AnthropicModels; 13 | pub use aws::AwsBedrockModels; 14 | pub use deepseek::DeepSeekModels; 15 | pub use google::GoogleModels; 16 | pub use llm_model::LLMModel; 17 | pub use llm_model::LLMModel as LLM; 18 | pub use mistral::MistralModels; 19 | pub use openai::OpenAIModels; 20 | pub use perplexity::PerplexityModels; 21 | pub use tools::LLMTools; 22 | pub use xai::XAIModels; 23 | 24 | // Re-export structs for backwards compatibility 25 | pub use crate::apis::GoogleApiEndpoints; 26 | pub use crate::apis::{ 27 | OpenAIAssistantResource, OpenAIAssistantVersion, OpenAICompletionsAPI, OpenAiApiEndpoints, 28 | }; 29 | -------------------------------------------------------------------------------- /src/apis/google.rs: -------------------------------------------------------------------------------- 1 | use serde::{Deserialize, Serialize}; 2 | 3 | // Enum of supported Completions APIs 4 | #[derive(Deserialize, Serialize, Debug, Clone, Eq, PartialEq, Default)] 5 | pub enum GoogleApiEndpoints { 6 | #[default] 7 | GoogleStudio, 8 | GoogleVertex, 9 | } 10 | 11 | impl GoogleApiEndpoints { 12 | /// Parses a string into `GoogleApiEndpoints`. 13 | /// 14 | /// Supported formats (case-insensitive): 15 | /// - `"google-studio"` -> `GoogleApiEndpoints::GoogleStudio` 16 | /// - `"google-vertex"` -> `GoogleApiEndpoints::GoogleVertex` 17 | /// 18 | /// Returns default for others. 
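    ///
    /// Illustrative example:
    /// ```
    /// use allms::apis::GoogleApiEndpoints;
    ///
    /// assert_eq!(GoogleApiEndpoints::from_str("Google-Vertex"), GoogleApiEndpoints::GoogleVertex);
    /// assert_eq!(GoogleApiEndpoints::from_str("unknown"), GoogleApiEndpoints::GoogleStudio);
    /// ```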
19 | #[allow(clippy::should_implement_trait)] 20 | pub fn from_str(s: &str) -> Self { 21 | let s_lower = s.to_lowercase(); 22 | match s_lower.as_str() { 23 | "google-studio" => GoogleApiEndpoints::GoogleStudio, 24 | "google-vertex" => GoogleApiEndpoints::GoogleVertex, 25 | _ => GoogleApiEndpoints::default(), 26 | } 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /LICENSE-MIT: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Neferdata 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /examples/deprecated_use_openai_completions.rs: -------------------------------------------------------------------------------- 1 | use allms::OpenAI; 2 | use allms::OpenAIModels; 3 | use schemars::JsonSchema; 4 | use serde::Deserialize; 5 | use serde::Serialize; 6 | 7 | #[derive(Deserialize, Serialize, JsonSchema, Debug, Clone)] 8 | struct TranslationResponse { 9 | pub spanish: String, 10 | pub french: String, 11 | pub german: String, 12 | pub polish: String, 13 | } 14 | 15 | #[tokio::main] 16 | async fn main() { 17 | env_logger::init(); 18 | let api_key: String = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); 19 | let model = OpenAIModels::Gpt4; // Choose the model 20 | 21 | let open_ai = OpenAI::new(&api_key, model, None, None); 22 | 23 | // Example context and instructions 24 | let instructions = 25 | "Translate the following English sentence to all the languages in the response type: Rust is best for working with LLMs"; 26 | 27 | match open_ai 28 | .get_answer::(instructions) 29 | .await 30 | { 31 | Ok(response) => println!("Response: {:?}", response), 32 | Err(e) => eprintln!("Error: {:?}", e), 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "allms" 3 | version = "0.33.0" 4 | edition = "2021" 5 | authors = [ 6 | "Kamil Litman ", 7 | "Nate Dubetz ", 8 | "Chance Cyphers " 9 | ] 10 | keywords = ["openai", "anthropic", "mistral", "gemini", "assistant"] 11 | description = "One Library to rule them aLLMs" 12 | license = "MIT" 13 | repository = "https://github.com/neferdata/allms.git" 14 | readme = "README.md" 15 | categories = ["api-bindings", "development-tools", "parsing", "science", 
"text-processing"] 16 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 17 | 18 | [dependencies] 19 | anyhow = "1.0.60" 20 | aws-config = "1.5.4" 21 | aws-sdk-bedrockruntime = "1.40.0" 22 | env_logger = "0.9.0" 23 | jsonschema = "=0.15.2" 24 | log = "0.4.0" 25 | r2d2 = "0.8.10" 26 | r2d2_postgres = "0.18.1" 27 | regex = "1.11.1" 28 | serde = "1.0.140" 29 | serde_json = "1.0.82" 30 | tiktoken-rs = "0.4.4" 31 | schemars = "0.8.12" 32 | reqwest = { version = "0.11.11", features = ["json", "multipart", "stream"]} 33 | lazy_static = "1.4.0" 34 | base64 = "0.13.0" 35 | tokio = { version = "1.19.2", features = ["full"] } 36 | async-trait = "0.1.66" 37 | yup-oauth2 = "8.3.2" 38 | futures = "0.3" -------------------------------------------------------------------------------- /src/assistants/openai/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod openai_assistant; 2 | pub mod openai_vector_store; 3 | 4 | pub use openai_assistant::OpenAIAssistant; 5 | pub use openai_vector_store::{ 6 | OpenAIVectorStore, OpenAIVectorStoreFileCounts, OpenAIVectorStoreStatus, 7 | }; 8 | 9 | // Re-export structs for backwards compatibility 10 | pub use crate::apis::{OpenAIAssistantResource, OpenAIAssistantVersion}; 11 | pub use crate::files::{LLMFiles, OpenAIFile}; 12 | 13 | // Add inherent methods for OpenAIFile for backwards compatibility 14 | impl OpenAIFile { 15 | pub fn new(id: Option, api_key: &str) -> Self { 16 | ::new(id, api_key) 17 | } 18 | 19 | pub fn debug(self) -> Self { 20 | ::debug(self) 21 | } 22 | 23 | pub async fn upload(self, file_name: &str, file_bytes: Vec) -> anyhow::Result { 24 | ::upload(self, file_name, file_bytes).await 25 | } 26 | 27 | pub async fn delete(&self) -> anyhow::Result<()> { 28 | ::delete(self).await 29 | } 30 | 31 | pub fn get_id(&self) -> Option<&String> { 32 | ::get_id(self) 33 | } 34 | 35 | pub fn is_debug(&self) -> bool { 36 | ::is_debug(self) 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /src/apis/anthropic.rs: -------------------------------------------------------------------------------- 1 | use serde::{Deserialize, Serialize}; 2 | 3 | use crate::constants::{ANTHROPIC_FILES_VERSION, ANTHROPIC_MESSAGES_VERSION}; 4 | 5 | #[derive(Deserialize, Serialize, Debug, Clone, Eq, PartialEq)] 6 | pub enum AnthropicApiEndpoints { 7 | Messages { version: String }, 8 | Files { version: String }, 9 | } 10 | 11 | impl AnthropicApiEndpoints { 12 | pub fn messages_default() -> Self { 13 | AnthropicApiEndpoints::Messages { 14 | version: ANTHROPIC_MESSAGES_VERSION.to_string(), 15 | } 16 | } 17 | 18 | pub fn messages(version: String) -> Self { 19 | AnthropicApiEndpoints::Messages { version } 20 | } 21 | 22 | pub fn files_default() -> Self { 23 | AnthropicApiEndpoints::Files { 24 | version: ANTHROPIC_FILES_VERSION.to_string(), 25 | } 26 | } 27 | 28 | pub fn files(version: String) -> Self { 29 | AnthropicApiEndpoints::Files { version } 30 | } 31 | 32 | pub fn version(&self) -> String { 33 | match self { 34 | AnthropicApiEndpoints::Messages { version } => version.to_string(), 35 | AnthropicApiEndpoints::Files { version } => version.to_string(), 36 | } 37 | } 38 | 39 | pub fn version_static(&self) -> &'static str { 40 | match self { 41 | AnthropicApiEndpoints::Messages { .. } => ANTHROPIC_MESSAGES_VERSION.as_str(), 42 | AnthropicApiEndpoints::Files { .. 
} => ANTHROPIC_FILES_VERSION.as_str(), 43 | } 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /src/enums.rs: -------------------------------------------------------------------------------- 1 | use serde::{Deserialize, Serialize}; 2 | 3 | #[derive(Deserialize, Serialize, Debug, Clone)] 4 | pub enum OpenAIToolTypes { 5 | #[serde(rename(deserialize = "code_interpreter", serialize = "code_interpreter"))] 6 | CodeInterpreter, 7 | #[serde(rename(deserialize = "retrieval", serialize = "retrieval"))] 8 | Retrieval, 9 | #[serde(rename(deserialize = "file_search", serialize = "file_search"))] 10 | FileSearch, 11 | } 12 | 13 | #[derive(Deserialize, Serialize, Debug, Clone, Eq, PartialEq)] 14 | pub enum OpenAIAssistantRole { 15 | #[serde(rename(deserialize = "user", serialize = "user"))] 16 | User, 17 | #[serde(rename(deserialize = "assistant", serialize = "assistant"))] 18 | Assistant, 19 | } 20 | 21 | #[derive(Deserialize, Serialize, Debug, Clone)] 22 | pub enum OpenAIRunStatus { 23 | #[serde(rename(deserialize = "queued", serialize = "queued"))] 24 | Queued, 25 | #[serde(rename(deserialize = "in_progress", serialize = "in_progress"))] 26 | InProgress, 27 | #[serde(rename(deserialize = "requires_action", serialize = "requires_action"))] 28 | RequiresAction, 29 | #[serde(rename(deserialize = "cancelling", serialize = "cancelling"))] 30 | Cancelling, 31 | #[serde(rename(deserialize = "cancelled", serialize = "cancelled"))] 32 | Cancelled, 33 | #[serde(rename(deserialize = "failed", serialize = "failed"))] 34 | Failed, 35 | #[serde(rename(deserialize = "completed", serialize = "completed"))] 36 | Completed, 37 | #[serde(rename(deserialize = "expired", serialize = "expired"))] 38 | Expired, 39 | } 40 | -------------------------------------------------------------------------------- /src/files/llm_files.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use async_trait::async_trait; 3 | 4 | /// Trait for LLM file operations across different providers 5 | /// 6 | /// This trait provides a common interface for file operations such as 7 | /// creation, debugging, uploading, and deletion of files for use with 8 | /// LLM assistants and completions. 
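///
/// Implemented by provider-specific types such as `OpenAIFile` and `AnthropicFile`.
///
/// A minimal usage sketch (illustrative; assumes an async context and a valid provider API key):
/// ```no_run
/// # use allms::files::{LLMFiles, OpenAIFile};
/// # async fn demo(api_key: &str, bytes: Vec<u8>) -> anyhow::Result<()> {
/// let file = OpenAIFile::new(None, api_key).upload("metallica.pdf", bytes).await?;
/// println!("Uploaded file id: {:?}", file.get_id());
/// file.delete().await?;
/// # Ok(())
/// # }
/// ```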
9 | #[async_trait(?Send)] 10 | pub trait LLMFiles: Send + Sync + Sized { 11 | /// Create a new file instance 12 | /// 13 | /// # Arguments 14 | /// * `id` - Optional file ID (for existing files) 15 | /// * `api_key` - API key for the LLM provider 16 | fn new(id: Option, api_key: &str) -> Self; 17 | 18 | /// Enable debug mode for the file instance 19 | /// 20 | /// Returns the modified instance for method chaining 21 | fn debug(self) -> Self; 22 | 23 | /// Upload a file to the LLM provider 24 | /// 25 | /// # Arguments 26 | /// * `file_name` - Name of the file to upload 27 | /// * `file_bytes` - File content as bytes 28 | /// 29 | /// # Returns 30 | /// * `Result` - The file instance with updated ID on success 31 | async fn upload(self, file_name: &str, file_bytes: Vec) -> Result; 32 | 33 | /// Delete a file from the LLM provider 34 | /// 35 | /// # Returns 36 | /// * `Result<()>` - Success or error 37 | async fn delete(&self) -> Result<()>; 38 | 39 | /// Get the file ID if available 40 | /// 41 | /// # Returns 42 | /// * `Option<&String>` - The file ID if it exists 43 | fn get_id(&self) -> Option<&String>; 44 | 45 | /// Check if debug mode is enabled 46 | /// 47 | /// # Returns 48 | /// * `bool` - True if debug mode is enabled 49 | fn is_debug(&self) -> bool; 50 | } 51 | -------------------------------------------------------------------------------- /examples/use_completions_vertex.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use schemars::JsonSchema; 3 | use serde::Deserialize; 4 | use serde::Serialize; 5 | 6 | use allms::{llm::GoogleModels, Completions}; 7 | 8 | mod utils; 9 | use utils::get_vertex_token; 10 | 11 | #[derive(Deserialize, Serialize, JsonSchema, Debug, Clone)] 12 | struct TranslationResponse { 13 | pub spanish: String, 14 | pub french: String, 15 | pub german: String, 16 | pub polish: String, 17 | } 18 | 19 | #[tokio::main] 20 | async fn main() -> Result<()> { 21 | env_logger::init(); 22 | 23 | // Get Vertex API authentication token 24 | let google_token_str = get_vertex_token().await?; 25 | 26 | // Example context and instructions 27 | let instructions = 28 | "Translate the following English sentence to all the languages in the response type: Rust is best for working with LLMs"; 29 | 30 | // Get answer using Google GeminiPro via Vertex AI 31 | let model = GoogleModels::Gemini2_5FlashLite; 32 | 33 | // **Pre-requisite**: GeminiPro request through Vertex AI require `GOOGLE_PROJECT_ID` environment variable defined 34 | let gemini_completion = 35 | Completions::new(model, &google_token_str, None, None).version("google-vertex"); 36 | 37 | match gemini_completion 38 | .get_answer::(instructions) 39 | .await 40 | { 41 | Ok(response) => println!("Vertex Gemini response: {:#?}", response), 42 | Err(e) => eprintln!("Error: {:?}", e), 43 | } 44 | 45 | // Get answer using a fine-tuned model 46 | 47 | // Using a fine-tuned model requires addressing the endpoint directly 48 | // Replace env variable with the endpoint ID of the fine-tuned model 49 | let fine_tuned_endpoint_id: String = 50 | std::env::var("GOOGLE_VERTEX_ENDPOINT_ID").expect("GOOGLE_VERTEX_ENDPOINT_ID not set"); 51 | let model = GoogleModels::endpoint(&fine_tuned_endpoint_id); 52 | 53 | let gemini_completion = 54 | Completions::new(model, &google_token_str, None, None).version("google-vertex"); 55 | 56 | match gemini_completion 57 | .get_answer::(instructions) 58 | .await 59 | { 60 | Ok(response) => println!("Vertex Gemini response: {:#?}", response), 61 
| Err(e) => eprintln!("Error: {:?}", e), 62 | } 63 | 64 | Ok(()) 65 | } 66 | -------------------------------------------------------------------------------- /examples/deprecated_use_openai_assistant.rs: -------------------------------------------------------------------------------- 1 | use std::ffi::OsStr; 2 | use std::path::Path; 3 | 4 | use allms::OpenAIAssistant; 5 | use allms::OpenAIAssistantVersion; 6 | use allms::OpenAIFile; 7 | use allms::OpenAIModels; 8 | 9 | use anyhow::{anyhow, Result}; 10 | use schemars::JsonSchema; 11 | use serde::Deserialize; 12 | use serde::Serialize; 13 | 14 | #[derive(Deserialize, Serialize, Debug, Clone, JsonSchema)] 15 | pub struct ConcertInfo { 16 | dates: Vec, 17 | band: String, 18 | genre: String, 19 | venue: String, 20 | city: String, 21 | country: String, 22 | ticket_price: String, 23 | } 24 | 25 | #[tokio::main] 26 | async fn main() -> Result<()> { 27 | env_logger::init(); 28 | let api_key: String = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); 29 | // Read concert file 30 | let path = Path::new("metallica.pdf"); 31 | let bytes = std::fs::read(path)?; 32 | let file_name = path 33 | .file_name() 34 | .and_then(OsStr::to_str) 35 | .map(|s| s.to_string()) 36 | .ok_or_else(|| anyhow!("Failed to extract file name"))?; 37 | 38 | let openai_file = OpenAIFile::new(&file_name, bytes, &api_key, true).await?; 39 | 40 | let bands_genres = vec![ 41 | ("Metallica", "Metal"), 42 | ("The Beatles", "Rock"), 43 | ("Daft Punk", "Electronic"), 44 | ("Miles Davis", "Jazz"), 45 | ("Johnny Cash", "Country"), 46 | ]; 47 | 48 | // Extract concert information using Assistant API 49 | let concert_info = OpenAIAssistant::new(OpenAIModels::Gpt4o, &api_key, true) 50 | .await? 51 | // Constructor defaults to V1 52 | .version(OpenAIAssistantVersion::V2) 53 | .set_context( 54 | "bands_genres", 55 | &bands_genres 56 | ) 57 | .await? 58 | .get_answer::( 59 | "Extract the information requested in the response type from the attached concert information. 60 | The response should include the genre of the music the 'band' represents. 
61 | The mapping of bands to genres was provided in 'bands_genres' list in a previous message.", 62 | &[openai_file.id.clone()], 63 | ) 64 | .await?; 65 | 66 | println!("Concert Info: {:?}", concert_info); 67 | 68 | //Remove the file from OpenAI 69 | openai_file.delete_file().await?; 70 | 71 | Ok(()) 72 | } 73 | -------------------------------------------------------------------------------- /examples/use_xai_tools.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use schemars::JsonSchema; 3 | use serde::Deserialize; 4 | use serde::Serialize; 5 | 6 | use allms::{ 7 | llm::{ 8 | tools::{LLMTools, XAISearchSource, XAIWebSearchConfig}, 9 | XAIModels, 10 | }, 11 | Completions, 12 | }; 13 | 14 | // Example 1: Web search 15 | #[derive(Deserialize, Serialize, JsonSchema, Debug, Clone)] 16 | struct AINewsArticles { 17 | pub articles: Vec, 18 | } 19 | 20 | #[derive(Deserialize, Serialize, JsonSchema, Debug, Clone)] 21 | struct AINewsArticle { 22 | pub title: String, 23 | pub url: String, 24 | pub description: String, 25 | } 26 | 27 | #[tokio::main] 28 | async fn main() -> Result<()> { 29 | env_logger::init(); 30 | 31 | let xai_api_key: String = std::env::var("XAI_API_KEY").expect("XAI_API_KEY not set"); 32 | 33 | // Example 1: Web search example with multiple sources 34 | let web_search_config = XAIWebSearchConfig::new() 35 | .add_source( 36 | XAISearchSource::web() 37 | .with_allowed_sites(vec!["techcrunch.com".to_string(), "wired.com".to_string()]) 38 | .with_country("US".to_string()) 39 | .with_safe_search(true), 40 | ) 41 | .add_source( 42 | XAISearchSource::news() 43 | .with_excluded_sites(vec!["tabloid.com".to_string()]) 44 | .with_country("US".to_string()) 45 | .with_safe_search(true), 46 | ) 47 | .add_source( 48 | XAISearchSource::x() 49 | .with_included_handles(vec![ 50 | "openai".to_string(), 51 | "anthropic".to_string(), 52 | "googleai".to_string(), 53 | ]) 54 | .with_favorite_count(100) 55 | .with_view_count(1000), 56 | ) 57 | .max_search_results(10) 58 | .return_citations(true); 59 | 60 | let web_search_tool = LLMTools::XAIWebSearch(web_search_config); 61 | let xai_responses = 62 | Completions::new(XAIModels::Grok3Mini, &xai_api_key, None, None).add_tool(web_search_tool); 63 | 64 | match xai_responses 65 | .get_answer::("Find up to 5 most recent news items about Artificial Intelligence, Generative AI, and Large Language Models. 
66 | For each news item, provide the title, url, and a short description.") 67 | .await 68 | { 69 | Ok(response) => println!("AI news articles:\n{:#?}", response), 70 | Err(e) => eprintln!("Error: {:?}", e), 71 | } 72 | 73 | Ok(()) 74 | } 75 | -------------------------------------------------------------------------------- /examples/use_mistral_tools.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use schemars::JsonSchema; 3 | use serde::Deserialize; 4 | use serde::Serialize; 5 | 6 | use allms::{ 7 | llm::{ 8 | tools::{LLMTools, MistralCodeInterpreterConfig, MistralWebSearchConfig}, 9 | MistralModels, 10 | }, 11 | Completions, 12 | }; 13 | 14 | // Example 1: Web search 15 | #[derive(Deserialize, Serialize, JsonSchema, Debug, Clone)] 16 | struct AINewsArticles { 17 | pub articles: Vec, 18 | } 19 | 20 | #[derive(Deserialize, Serialize, JsonSchema, Debug, Clone)] 21 | struct AINewsArticle { 22 | pub title: String, 23 | pub url: String, 24 | pub description: String, 25 | } 26 | 27 | // Example 2: Code interpreter example 28 | #[derive(Deserialize, Serialize, Debug, Clone, JsonSchema)] 29 | pub struct CodeInterpreterResponse { 30 | pub problem: String, 31 | pub code: String, 32 | pub solution: String, 33 | } 34 | 35 | #[tokio::main] 36 | async fn main() -> Result<()> { 37 | env_logger::init(); 38 | 39 | let mistral_api_key: String = 40 | std::env::var("MISTRAL_API_KEY").expect("MISTRAL_API_KEY not set"); 41 | 42 | // Example 1: Web search example 43 | let web_search_tool = LLMTools::MistralWebSearch(MistralWebSearchConfig::new()); 44 | let mistral_responses = Completions::new( 45 | MistralModels::MistralMedium3_1, 46 | &mistral_api_key, 47 | None, 48 | None, 49 | ) 50 | .add_tool(web_search_tool); 51 | 52 | match mistral_responses 53 | .get_answer::("Find up to 5 most recent news items about Artificial Intelligence, Generative AI, and Large Language Models. 
54 | For each news item, provide the title, url, and a short description.") 55 | .await 56 | { 57 | Ok(response) => println!("AI news articles:\n{:#?}", response), 58 | Err(e) => eprintln!("Error: {:?}", e), 59 | } 60 | 61 | // Example 2: Code interpreter example 62 | let code_interpreter_tool = 63 | LLMTools::MistralCodeInterpreter(MistralCodeInterpreterConfig::new()); 64 | let mistral_responses = Completions::new( 65 | MistralModels::MistralMedium3_1, 66 | &mistral_api_key, 67 | None, 68 | None, 69 | ) 70 | .add_tool(code_interpreter_tool); 71 | 72 | match mistral_responses 73 | .get_answer::( 74 | "Calculate the mean and standard deviation of [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]", 75 | ) 76 | .await 77 | { 78 | Ok(response) => println!("Code interpreter response:\n{:#?}", response), 79 | Err(e) => eprintln!("Error: {:?}", e), 80 | } 81 | 82 | Ok(()) 83 | } 84 | -------------------------------------------------------------------------------- /examples/use_openai_assistant.rs: -------------------------------------------------------------------------------- 1 | use std::ffi::OsStr; 2 | use std::path::Path; 3 | 4 | use allms::assistants::{OpenAIAssistant, OpenAIAssistantVersion, OpenAIFile, OpenAIVectorStore}; 5 | use allms::llm::OpenAIModels; 6 | 7 | use anyhow::{anyhow, Result}; 8 | use schemars::JsonSchema; 9 | use serde::Deserialize; 10 | use serde::Serialize; 11 | 12 | #[derive(Deserialize, Serialize, Debug, Clone, JsonSchema)] 13 | pub struct ConcertInfo { 14 | dates: Vec, 15 | band: String, 16 | genre: String, 17 | venue: String, 18 | city: String, 19 | country: String, 20 | ticket_price: String, 21 | } 22 | 23 | #[tokio::main] 24 | async fn main() -> Result<()> { 25 | env_logger::init(); 26 | let api_key: String = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); 27 | // Read concert file 28 | let path = Path::new("metallica.pdf"); 29 | let bytes = std::fs::read(path)?; 30 | let file_name = path 31 | .file_name() 32 | .and_then(OsStr::to_str) 33 | .map(|s| s.to_string()) 34 | .ok_or_else(|| anyhow!("Failed to extract file name"))?; 35 | 36 | let openai_file = OpenAIFile::new(None, &api_key) 37 | .debug() 38 | .version(OpenAIAssistantVersion::V2) 39 | .upload(&file_name, bytes) 40 | .await?; 41 | 42 | let bands_genres = vec![ 43 | ("Metallica", "Metal"), 44 | ("The Beatles", "Rock"), 45 | ("Daft Punk", "Electronic"), 46 | ("Miles Davis", "Jazz"), 47 | ("Johnny Cash", "Country"), 48 | ]; 49 | 50 | // Create a Vector Store and assign the file to it 51 | let openai_vector_store = OpenAIVectorStore::new(None, "Concerts", &api_key) 52 | .debug() 53 | .version(OpenAIAssistantVersion::V2) 54 | .upload(&[openai_file.id.clone().unwrap_or_default()]) 55 | .await?; 56 | 57 | let status = openai_vector_store.status().await?; 58 | println!( 59 | ">>> Vector Store: {:?}; Status: {:?}", 60 | &openai_vector_store.id, &status 61 | ); 62 | 63 | let file_count = openai_vector_store.file_count().await?; 64 | println!( 65 | ">>> Vector Store: {:?}; File count: {:?}", 66 | &openai_vector_store.id, &file_count 67 | ); 68 | 69 | // Extract concert information using Assistant API 70 | let concert_info = OpenAIAssistant::new(OpenAIModels::Gpt4oMini, &api_key) 71 | .debug() 72 | // Constructor defaults to V1 73 | .version(OpenAIAssistantVersion::V2) 74 | .poll_interval(5) // Set response polling to every 5 sec 75 | .vector_store(openai_vector_store.clone()) 76 | .await? 77 | .set_context( 78 | "bands_genres", 79 | &bands_genres 80 | ) 81 | .await? 
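        // `get_answer` derives a JSON schema from the `ConcertInfo` response type, runs the
        // assistant thread against the attached vector store (polling every 5 seconds as
        // configured above), and deserializes the assistant's final reply into that struct.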
82 | .get_answer::( 83 | "Extract the information requested in the response type from the attached concert information. 84 | The response should include the genre of the music the 'band' represents. 85 | The mapping of bands to genres was provided in 'bands_genres' list in a previous message.", 86 | &[], 87 | ) 88 | .await?; 89 | 90 | println!(">>> Concert Info: {:#?}", concert_info); 91 | 92 | //Remove the file from OpenAI 93 | openai_file.delete().await?; 94 | 95 | // Delete the Vector Store 96 | openai_vector_store.delete().await?; 97 | 98 | Ok(()) 99 | } 100 | -------------------------------------------------------------------------------- /examples/use_openai_assistant_json.rs: -------------------------------------------------------------------------------- 1 | use std::ffi::OsStr; 2 | use std::path::Path; 3 | 4 | use allms::assistants::{OpenAIAssistant, OpenAIAssistantVersion, OpenAIFile, OpenAIVectorStore}; 5 | use allms::llm::OpenAIModels; 6 | 7 | use anyhow::{anyhow, Result}; 8 | 9 | const CONCERT_INFO_SCHEMA: &str = r#" 10 | { 11 | "type": "object", 12 | "properties": { 13 | "dates": { 14 | "type": "array", 15 | "items": { 16 | "type": "string" 17 | } 18 | }, 19 | "band": { 20 | "type": "string" 21 | }, 22 | "genre": { 23 | "type": "string" 24 | }, 25 | "venue": { 26 | "type": "string" 27 | }, 28 | "city": { 29 | "type": "string" 30 | }, 31 | "country": { 32 | "type": "string" 33 | }, 34 | "ticket_price": { 35 | "type": "string" 36 | } 37 | } 38 | } 39 | "#; 40 | 41 | #[tokio::main] 42 | async fn main() -> Result<()> { 43 | env_logger::init(); 44 | let api_key: String = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); 45 | // Read concert file 46 | let path = Path::new("metallica.pdf"); 47 | let bytes = std::fs::read(path)?; 48 | let file_name = path 49 | .file_name() 50 | .and_then(OsStr::to_str) 51 | .map(|s| s.to_string()) 52 | .ok_or_else(|| anyhow!("Failed to extract file name"))?; 53 | 54 | let openai_file = OpenAIFile::new(None, &api_key) 55 | .debug() 56 | .upload(&file_name, bytes) 57 | .await?; 58 | 59 | let bands_genres = vec![ 60 | ("Metallica", "Metal"), 61 | ("The Beatles", "Rock"), 62 | ("Daft Punk", "Electronic"), 63 | ("Miles Davis", "Jazz"), 64 | ("Johnny Cash", "Country"), 65 | ]; 66 | 67 | // Create a Vector Store and assign the file to it 68 | let openai_vector_store = OpenAIVectorStore::new(None, "Concerts", &api_key) 69 | .debug() 70 | .upload(&[openai_file.id.clone().unwrap_or_default()]) 71 | .await?; 72 | 73 | let status = openai_vector_store.status().await?; 74 | println!( 75 | "Vector Store: {:?}; Status: {:?}", 76 | &openai_vector_store.id, &status 77 | ); 78 | 79 | let file_count = openai_vector_store.file_count().await?; 80 | println!( 81 | "Vector Store: {:?}; File count: {:?}", 82 | &openai_vector_store.id, &file_count 83 | ); 84 | 85 | // Extract concert information using Assistant API 86 | let concert_info = OpenAIAssistant::new(OpenAIModels::Gpt4_1Mini, &api_key) 87 | .debug() 88 | // Constructor defaults to V1 89 | .version(OpenAIAssistantVersion::V2) 90 | .vector_store(openai_vector_store.clone()) 91 | .await? 92 | .set_context( 93 | "bands_genres", 94 | &bands_genres 95 | ) 96 | .await? 97 | .get_json_answer( 98 | "Extract the information requested in the response type from the attached concert information. 99 | The response should include the genre of the music the 'band' represents. 
100 | The mapping of bands to genres was provided in 'bands_genres' list in a previous message.", 101 | CONCERT_INFO_SCHEMA, 102 | &[], 103 | ) 104 | .await?; 105 | 106 | println!("Concert Info: {:#?}", concert_info); 107 | 108 | //Remove the file from OpenAI 109 | openai_file.delete().await?; 110 | 111 | // Delete the Vector Store 112 | openai_vector_store.delete().await?; 113 | 114 | Ok(()) 115 | } 116 | -------------------------------------------------------------------------------- /examples/use_openai_assistant_azure.rs: -------------------------------------------------------------------------------- 1 | use std::ffi::OsStr; 2 | use std::path::Path; 3 | 4 | use allms::assistants::{OpenAIAssistant, OpenAIAssistantVersion, OpenAIFile, OpenAIVectorStore}; 5 | use allms::llm::OpenAIModels; 6 | 7 | use anyhow::{anyhow, Result}; 8 | use schemars::JsonSchema; 9 | use serde::Deserialize; 10 | use serde::Serialize; 11 | 12 | #[derive(Deserialize, Serialize, Debug, Clone, JsonSchema)] 13 | pub struct ConcertInfo { 14 | dates: Vec, 15 | band: String, 16 | genre: String, 17 | venue: String, 18 | city: String, 19 | country: String, 20 | ticket_price: String, 21 | } 22 | 23 | #[tokio::main] 24 | async fn main() -> Result<()> { 25 | env_logger::init(); 26 | // Ensure `OPENAI_API_URL` is set to your Azure OpenAI resource endpoint 27 | let api_key: String = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); 28 | // Read concert file 29 | let path = Path::new("metallica.pdf"); 30 | let bytes = std::fs::read(path)?; 31 | let file_name = path 32 | .file_name() 33 | .and_then(OsStr::to_str) 34 | .map(|s| s.to_string()) 35 | .ok_or_else(|| anyhow!("Failed to extract file name"))?; 36 | 37 | // Set API version to Azure 38 | let openai_file = OpenAIFile::new(None, &api_key) 39 | .debug() 40 | .version(OpenAIAssistantVersion::AzureVersion { 41 | version: "2024-06-01".to_string(), 42 | }) 43 | .upload(&file_name, bytes) 44 | .await?; 45 | 46 | let bands_genres = vec![ 47 | ("Metallica", "Metal"), 48 | ("The Beatles", "Rock"), 49 | ("Daft Punk", "Electronic"), 50 | ("Miles Davis", "Jazz"), 51 | ("Johnny Cash", "Country"), 52 | ]; 53 | 54 | // Create a Vector Store and assign the file to it 55 | let openai_vector_store = OpenAIVectorStore::new(None, "Concerts", &api_key) 56 | .debug() 57 | .version(OpenAIAssistantVersion::AzureVersion { 58 | version: "2024-06-01".to_string(), 59 | }) 60 | .upload(&[openai_file.id.clone().unwrap_or_default()]) 61 | .await?; 62 | 63 | let status = openai_vector_store.status().await?; 64 | println!( 65 | ">>> Vector Store: {:?}; Status: {:?}", 66 | &openai_vector_store.id, &status 67 | ); 68 | 69 | let file_count = openai_vector_store.file_count().await?; 70 | println!( 71 | ">>> Vector Store: {:?}; File count: {:?}", 72 | &openai_vector_store.id, &file_count 73 | ); 74 | 75 | // Ensure model deployment name in Azure OpenAI Studio matches that of the used model `as_str` representation 76 | let concert_info = OpenAIAssistant::new(OpenAIModels::Gpt4oMini, &api_key) 77 | .debug() 78 | // Constructor defaults to V1 79 | .version(OpenAIAssistantVersion::AzureVersion { version: "2024-06-01".to_string(), }) 80 | .vector_store(openai_vector_store.clone()) 81 | .await? 82 | .set_context( 83 | "bands_genres", 84 | &bands_genres 85 | ) 86 | .await? 87 | .get_answer::( 88 | "Extract the information requested in the response type from the attached concert information. 89 | The response should include the genre of the music the 'band' represents. 
90 | The mapping of bands to genres was provided in 'bands_genres' list in a previous message.", 91 | &[], 92 | ) 93 | .await?; 94 | 95 | println!(">>> Concert Info: {:#?}", concert_info); 96 | 97 | //Remove the file from OpenAI 98 | openai_file.delete().await?; 99 | 100 | // Delete the Vector Store 101 | openai_vector_store.delete().await?; 102 | 103 | Ok(()) 104 | } 105 | -------------------------------------------------------------------------------- /examples/use_google_tools.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use schemars::JsonSchema; 3 | use serde::Deserialize; 4 | use serde::Serialize; 5 | 6 | use allms::{ 7 | llm::{ 8 | tools::{GeminiCodeInterpreterConfig, GeminiWebSearchConfig, LLMTools}, 9 | GoogleModels, 10 | }, 11 | Completions, ThinkingLevel, 12 | }; 13 | 14 | mod utils; 15 | use utils::get_vertex_token; 16 | 17 | // Example 1: Web search 18 | #[derive(Deserialize, Serialize, JsonSchema, Debug, Clone)] 19 | struct AINewsArticles { 20 | pub articles: Vec, 21 | } 22 | 23 | #[derive(Deserialize, Serialize, JsonSchema, Debug, Clone)] 24 | struct AINewsArticle { 25 | pub title: String, 26 | pub url: String, 27 | pub description: String, 28 | } 29 | 30 | // Example 2: Code interpreter example 31 | #[derive(Deserialize, Serialize, Debug, Clone, JsonSchema)] 32 | pub struct CodeInterpreterResponse { 33 | pub problem: String, 34 | pub code: String, 35 | pub solution: String, 36 | } 37 | 38 | #[tokio::main] 39 | async fn main() -> Result<()> { 40 | env_logger::init(); 41 | 42 | let google_api_key: String = 43 | std::env::var("GOOGLE_AI_STUDIO_API_KEY").expect("GOOGLE_AI_STUDIO_API_KEY not set"); 44 | let vertex_token = get_vertex_token().await?; 45 | 46 | // Example 1A: Web search example (with Studio API) 47 | let web_search_config = 48 | GeminiWebSearchConfig::new().add_source("https://www.artificialintelligence-news.com/"); 49 | 50 | let web_search_tool = LLMTools::GeminiWebSearch(web_search_config); 51 | let google_responses = Completions::new(GoogleModels::Gemini3Pro, &google_api_key, None, None) 52 | .add_tool(web_search_tool.clone()) 53 | .thinking_level(ThinkingLevel::Low); 54 | 55 | match google_responses 56 | .get_answer::("Find up to 5 most recent news items about Artificial Intelligence, Generative AI, and Large Language Models. 57 | For each news item, provide the title, url, and a short description.") 58 | .await 59 | { 60 | Ok(response) => println!("[AI Studio] AI news articles:\n{:#?}", response), 61 | Err(e) => eprintln!("[AI Studio] AI news articles error: {:?}", e), 62 | } 63 | 64 | // Example 1B: Web search example (with Vertex API) 65 | let google_responses_vertex = 66 | Completions::new(GoogleModels::Gemini2_5Flash, &vertex_token, None, None) 67 | .add_tool(web_search_tool) 68 | .version("google-vertex"); 69 | 70 | match google_responses_vertex 71 | .get_answer::("Find up to 5 most recent news items about Artificial Intelligence, Generative AI, and Large Language Models. 
72 | For each news item, provide the title, url, and a short description.") 73 | .await 74 | { 75 | Ok(response) => println!("[Vertex] AI news articles:\n{:#?}", response), 76 | Err(e) => eprintln!("[Vertex] AI news articles error: {:?}", e), 77 | } 78 | 79 | // Example 2A: Code interpreter example (with Studio API) 80 | let code_interpreter_tool = LLMTools::GeminiCodeInterpreter(GeminiCodeInterpreterConfig::new()); 81 | let google_responses = Completions::new(GoogleModels::Gemini3Pro, &google_api_key, None, None) 82 | .add_tool(code_interpreter_tool.clone()); 83 | 84 | match google_responses 85 | .get_answer::( 86 | "Calculate the mean and standard deviation of [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]", 87 | ) 88 | .await 89 | { 90 | Ok(response) => println!("[AI Studio] Code interpreter response:\n{:#?}", response), 91 | Err(e) => eprintln!("[AI Studio] Code interpreter error: {:?}", e), 92 | } 93 | 94 | // Example 2B: Code interpreter example (with Vertex API) 95 | let google_responses_vertex = 96 | Completions::new(GoogleModels::Gemini2_5Pro, &vertex_token, None, None) 97 | .add_tool(code_interpreter_tool) 98 | .version("google-vertex"); 99 | 100 | match google_responses_vertex 101 | .get_answer::( 102 | "Calculate the mean and standard deviation of [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]", 103 | ) 104 | .await 105 | { 106 | Ok(response) => println!("[Vertex] Code interpreter response:\n{:#?}", response), 107 | Err(e) => eprintln!("[Vertex] Code interpreter error: {:?}", e), 108 | } 109 | 110 | Ok(()) 111 | } 112 | -------------------------------------------------------------------------------- /examples/use_anthropic_tools.rs: -------------------------------------------------------------------------------- 1 | use anyhow::{anyhow, Result}; 2 | use schemars::JsonSchema; 3 | use serde::Deserialize; 4 | use serde::Serialize; 5 | use std::ffi::OsStr; 6 | use std::path::Path; 7 | 8 | use allms::{ 9 | files::{AnthropicFile, LLMFiles}, 10 | llm::{ 11 | tools::{ 12 | AnthropicCodeExecutionConfig, AnthropicFileSearchConfig, AnthropicWebSearchConfig, 13 | LLMTools, 14 | }, 15 | AnthropicModels, 16 | }, 17 | Completions, 18 | }; 19 | 20 | // Example 1: Web search 21 | #[derive(Deserialize, Serialize, JsonSchema, Debug, Clone)] 22 | struct AINewsArticles { 23 | pub articles: Vec, 24 | } 25 | 26 | #[derive(Deserialize, Serialize, JsonSchema, Debug, Clone)] 27 | struct AINewsArticle { 28 | pub title: String, 29 | pub url: String, 30 | pub description: String, 31 | } 32 | 33 | // Example 2: Code interpreter example 34 | #[derive(Deserialize, Serialize, Debug, Clone, JsonSchema)] 35 | pub struct CodeInterpreterResponse { 36 | pub problem: String, 37 | pub code: String, 38 | pub solution: String, 39 | } 40 | 41 | // Example 3: File search 42 | #[derive(Deserialize, Serialize, Debug, Clone, JsonSchema)] 43 | pub struct ConcertInfo { 44 | dates: Vec, 45 | band: String, 46 | genre: String, 47 | venue: String, 48 | city: String, 49 | country: String, 50 | ticket_price: String, 51 | } 52 | 53 | const BANDS_GENRES: &[(&str, &str)] = &[ 54 | ("Metallica", "Metal"), 55 | ("The Beatles", "Rock"), 56 | ("Daft Punk", "Electronic"), 57 | ("Miles Davis", "Jazz"), 58 | ("Johnny Cash", "Country"), 59 | ]; 60 | 61 | #[tokio::main] 62 | async fn main() -> Result<()> { 63 | env_logger::init(); 64 | 65 | let anthropic_api_key: String = 66 | std::env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY not set"); 67 | 68 | // Example 1: Web search example 69 | let web_search_tool = 
LLMTools::AnthropicWebSearch(AnthropicWebSearchConfig::new()); 70 | let anthropic_responses = Completions::new( 71 | AnthropicModels::Claude4_5Opus, 72 | &anthropic_api_key, 73 | None, 74 | None, 75 | ) 76 | .add_tool(web_search_tool); 77 | 78 | match anthropic_responses 79 | .get_answer::("Find up to 5 most recent news items about Artificial Intelligence, Generative AI, and Large Language Models. 80 | For each news item, provide the title, url, and a short description.") 81 | .await 82 | { 83 | Ok(response) => println!("AI news articles:\n{:#?}", response), 84 | Err(e) => eprintln!("Error: {:?}", e), 85 | } 86 | 87 | // Example 2: Code interpreter example 88 | 89 | let code_interpreter_tool = 90 | LLMTools::AnthropicCodeExecution(AnthropicCodeExecutionConfig::new()); 91 | let anthropic_responses = Completions::new( 92 | AnthropicModels::Claude4_5Opus, 93 | &anthropic_api_key, 94 | None, 95 | None, 96 | ) 97 | .add_tool(code_interpreter_tool); 98 | 99 | match anthropic_responses 100 | .get_answer::( 101 | "Calculate the mean and standard deviation of [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]", 102 | ) 103 | .await 104 | { 105 | Ok(response) => println!("Code interpreter response:\n{:#?}", response), 106 | Err(e) => eprintln!("Error: {:?}", e), 107 | } 108 | 109 | // Example 3: File search example 110 | 111 | // Read the concert file and upload it to Anthropic 112 | let path = Path::new("metallica.pdf"); 113 | let bytes = std::fs::read(path)?; 114 | let file_name = path 115 | .file_name() 116 | .and_then(OsStr::to_str) 117 | .map(|s| s.to_string()) 118 | .ok_or_else(|| anyhow!("Failed to extract file name"))?; 119 | 120 | let anthropic_file = AnthropicFile::new(None, &anthropic_api_key) 121 | .upload(&file_name, bytes) 122 | .await?; 123 | 124 | // Extract concert information using Anthropic API with file search tool 125 | let file_search_tool = LLMTools::AnthropicFileSearch(AnthropicFileSearchConfig::new( 126 | anthropic_file.id.clone().unwrap_or_default(), 127 | )); 128 | 129 | let anthropic_responses = Completions::new( 130 | AnthropicModels::Claude4_5Opus, 131 | &anthropic_api_key, 132 | None, 133 | None, 134 | ) 135 | .set_context("bands_genres", &BANDS_GENRES)? 136 | .add_tool(file_search_tool); 137 | 138 | match anthropic_responses 139 | .get_answer::("Extract the information requested in the response type from the attached concert information. 140 | The response should include the genre of the music the 'band' represents. 141 | The mapping of bands to genres was provided in 'bands_genres' list.") 142 | .await 143 | { 144 | Ok(response) => println!("Concert Info:\n{:#?}", response), 145 | Err(e) => eprintln!("Error: {:?}", e), 146 | } 147 | 148 | // Cleanup 149 | anthropic_file.delete().await?; 150 | 151 | Ok(()) 152 | } 153 | -------------------------------------------------------------------------------- /src/constants.rs: -------------------------------------------------------------------------------- 1 | use lazy_static::lazy_static; 2 | 3 | lazy_static! { 4 | pub(crate) static ref OPENAI_API_URL: String = 5 | std::env::var("OPENAI_API_URL").unwrap_or("https://api.openai.com".to_string()); 6 | } 7 | 8 | lazy_static! 
{ 9 | pub(crate) static ref ANTHROPIC_API_URL: String = std::env::var("ANTHROPIC_API_URL") 10 | .unwrap_or("https://api.anthropic.com/v1/complete".to_string()); 11 | pub(crate) static ref ANTHROPIC_MESSAGES_API_URL: String = 12 | std::env::var("ANTHROPIC_MESSAGES_API_URL") 13 | .unwrap_or("https://api.anthropic.com/v1/messages".to_string()); 14 | pub(crate) static ref ANTHROPIC_MESSAGES_VERSION: String = 15 | std::env::var("ANTHROPIC_MESSAGES_VERSION").unwrap_or("2023-06-01".to_string()); 16 | pub(crate) static ref ANTHROPIC_FILES_VERSION: String = 17 | std::env::var("ANTHROPIC_FILES_VERSION").unwrap_or("files-api-2025-04-14".to_string()); 18 | pub(crate) static ref ANTHROPIC_FILES_API_URL: String = 19 | std::env::var("ANTHROPIC_FILES_API_URL") 20 | .unwrap_or("https://api.anthropic.com/v1/files".to_string()); 21 | } 22 | 23 | lazy_static! { 24 | pub(crate) static ref MISTRAL_API_URL: String = std::env::var("MISTRAL_API_URL") 25 | .unwrap_or("https://api.mistral.ai/v1/chat/completions".to_string()); 26 | pub(crate) static ref MISTRAL_CONVERSATIONS_API_URL: String = 27 | std::env::var("MISTRAL_CONVERSATIONS_API_URL") 28 | .unwrap_or("https://api.mistral.ai/v1/conversations".to_string()); 29 | } 30 | 31 | lazy_static! { 32 | pub(crate) static ref GOOGLE_VERTEX_API_URL: String = { 33 | let region = std::env::var("GOOGLE_REGION").unwrap_or("us-central1".to_string()); 34 | let project_id = std::env::var("GOOGLE_PROJECT_ID").expect("PROJECT_ID not set"); 35 | 36 | format!("https://{}-aiplatform.googleapis.com/v1/projects/{}/locations/{}/publishers/google/models", 37 | region, project_id, region) 38 | }; 39 | pub(crate) static ref GOOGLE_VERTEX_ENDPOINT_API_URL: String = { 40 | let region = std::env::var("GOOGLE_REGION").unwrap_or("us-central1".to_string()); 41 | let project_id = std::env::var("GOOGLE_PROJECT_ID").expect("PROJECT_ID not set"); 42 | 43 | format!("https://{region}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{region}/endpoints") 44 | }; 45 | } 46 | 47 | lazy_static! { 48 | pub(crate) static ref GOOGLE_GEMINI_API_URL: String = std::env::var("GOOGLE_GEMINI_API_URL") 49 | .unwrap_or("https://generativelanguage.googleapis.com/v1/models".to_string()); 50 | pub(crate) static ref GOOGLE_GEMINI_BETA_API_URL: String = 51 | std::env::var("GOOGLE_GEMINI_BETA_API_URL") 52 | .unwrap_or("https://generativelanguage.googleapis.com/v1beta/models".to_string()); 53 | } 54 | 55 | lazy_static! { 56 | pub(crate) static ref PERPLEXITY_API_URL: String = std::env::var("PERPLEXITY_API_URL") 57 | .unwrap_or("https://api.perplexity.ai/chat/completions".to_string()); 58 | } 59 | 60 | lazy_static! { 61 | /// Docs: https://docs.aws.amazon.com/general/latest/gr/bedrock.html 62 | pub(crate) static ref AWS_REGION: String = std::env::var("AWS_REGION").unwrap_or("us-east-1".to_string()); 63 | pub(crate) static ref AWS_BEDROCK_API_URL: String = { 64 | format!("https://bedrock.{}.amazonaws.com", &*AWS_REGION) 65 | }; 66 | } 67 | 68 | lazy_static! { 69 | pub(crate) static ref AWS_ACCESS_KEY_ID: String = 70 | std::env::var("AWS_ACCESS_KEY_ID").expect("AWS_ACCESS_KEY_ID not set"); 71 | pub(crate) static ref AWS_SECRET_ACCESS_KEY: String = 72 | std::env::var("AWS_SECRET_ACCESS_KEY").expect("AWS_SECRET_ACCESS_KEY not set"); 73 | } 74 | 75 | lazy_static! { 76 | pub(crate) static ref DEEPSEEK_API_URL: String = std::env::var("DEEPSEEK_API_URL") 77 | .unwrap_or("https://api.deepseek.com/chat/completions".to_string()); 78 | } 79 | 80 | lazy_static! 
{ 81 | pub(crate) static ref XAI_API_URL: String = 82 | std::env::var("XAI_API_URL").unwrap_or("https://api.x.ai/v1/chat/completions".to_string()); 83 | } 84 | 85 | //Generic OpenAI instructions 86 | pub(crate) const OPENAI_BASE_INSTRUCTIONS: &str = r#"You are a computer function. You are expected to perform the following tasks: 87 | Step 1: Review and understand the 'instructions'. 88 | Step 2: Prepare a response by processing the provided data as per the 'instructions'. 89 | Step 3: Convert the response to a Json object. The Json object must match the schema provided as the `output json schema`. 90 | Step 4: Validate that the Json object matches the 'output json schema' and correct if needed. If you are not able to generate a valid Json, respond with "Error calculating the answer." 91 | Step 5: Respond ONLY with properly formatted Json object. No other words or text, only valid Json in the answer. 92 | "#; 93 | 94 | pub(crate) const OPENAI_FUNCTION_INSTRUCTIONS: &str = r#"You are a computer function. You are expected to perform the following tasks: 95 | Step 1: Review and understand the 'instructions'. 96 | Step 2: Prepare a response by processing the provided data as per the 'instructions'. 97 | Step 3: Convert the response to a Json object. The Json object must match the schema provided in the function definition. 98 | Step 4: Validate that the Json object matches the function properties and correct if needed. If you are not able to generate a valid Json, respond with "Error calculating the answer." 99 | Step 5: Respond ONLY with properly formatted Json object. No other words or text, only valid Json in the answer. 100 | "#; 101 | 102 | pub(crate) const OPENAI_ASSISTANT_INSTRUCTIONS: &str = r#"You are a computer function. You are expected to perform the following tasks: 103 | 1: Review and understand the content of user messages passed to you in the thread. 104 | 2: Review and consider any files the user provided attached to the messages. 105 | 3: Prepare response using your language model based on the user messages and provided files. 106 | 4: Respond ONLY with properly formatted data portion of a Json. No other words or text, only valid Json in your answers. 
107 | "#; 108 | 109 | pub(crate) const OPENAI_ASSISTANT_POLL_FREQ: usize = 10; 110 | 111 | pub(crate) const DEFAULT_AZURE_VERSION: &str = "2024-06-01"; 112 | -------------------------------------------------------------------------------- /src/llm_models/deepseek.rs: -------------------------------------------------------------------------------- 1 | use anyhow::{anyhow, Result}; 2 | use async_trait::async_trait; 3 | use log::info; 4 | use reqwest::{header, Client}; 5 | use serde::{Deserialize, Serialize}; 6 | use serde_json::{json, Value}; 7 | 8 | use crate::completions::ThinkingLevel; 9 | use crate::constants::DEEPSEEK_API_URL; 10 | use crate::domain::{DeepSeekAPICompletionsResponse, RateLimit}; 11 | use crate::llm_models::{LLMModel, LLMTools}; 12 | use crate::utils::map_to_range_f32; 13 | 14 | #[derive(Deserialize, Serialize, Debug, Clone, Eq, PartialEq)] 15 | //DeepSeek docs: https://api-docs.deepseek.com/quick_start/pricing 16 | pub enum DeepSeekModels { 17 | DeepSeekChat, 18 | DeepSeekReasoner, 19 | } 20 | 21 | #[async_trait(?Send)] 22 | impl LLMModel for DeepSeekModels { 23 | fn as_str(&self) -> &str { 24 | match self { 25 | DeepSeekModels::DeepSeekChat => "deepseek-chat", 26 | DeepSeekModels::DeepSeekReasoner => "deepseek-reasoner", 27 | } 28 | } 29 | 30 | fn try_from_str(name: &str) -> Option { 31 | match name.to_lowercase().as_str() { 32 | "deepseek-chat" => Some(DeepSeekModels::DeepSeekChat), 33 | "deepseek-reasoner" => Some(DeepSeekModels::DeepSeekReasoner), 34 | _ => None, 35 | } 36 | } 37 | 38 | fn default_max_tokens(&self) -> usize { 39 | match self { 40 | DeepSeekModels::DeepSeekChat => 8_192, 41 | DeepSeekModels::DeepSeekReasoner => 8_192, 42 | } 43 | } 44 | 45 | fn get_endpoint(&self) -> String { 46 | DEEPSEEK_API_URL.to_string() 47 | } 48 | 49 | /// This method prepares the body of the API call for different models 50 | fn get_body( 51 | &self, 52 | instructions: &str, 53 | json_schema: &Value, 54 | function_call: bool, 55 | max_tokens: &usize, 56 | temperature: &f32, 57 | _tools: Option<&[LLMTools]>, 58 | _thinking_level: Option<&ThinkingLevel>, 59 | ) -> serde_json::Value { 60 | //Prepare the 'messages' part of the body 61 | let base_instructions = self.get_base_instructions(Some(function_call)); 62 | let system_message = json!({ 63 | "role": "system", 64 | "content": base_instructions, 65 | }); 66 | let user_message = json!({ 67 | "role": "user", 68 | "content": format!( 69 | " 70 | {instructions} 71 | 72 | 73 | {json_schema} 74 | " 75 | ), 76 | }); 77 | json!({ 78 | "model": self.as_str(), 79 | "max_tokens": max_tokens, 80 | "temperature": temperature, 81 | "messages": vec![ 82 | system_message, 83 | user_message, 84 | ], 85 | }) 86 | } 87 | /// 88 | /// This function leverages DeepSeek API to perform any query as per the provided body. 89 | /// 90 | /// It returns a String the Response object that needs to be parsed based on the self.model. 
91 | /// 92 | async fn call_api( 93 | &self, 94 | api_key: &str, 95 | _version: Option, 96 | body: &serde_json::Value, 97 | debug: bool, 98 | _tools: Option<&[LLMTools]>, 99 | ) -> Result { 100 | //Get the API url 101 | let model_url = self.get_endpoint(); 102 | 103 | //Make the API call 104 | let client = Client::new(); 105 | 106 | //Send request 107 | let response = client 108 | .post(model_url) 109 | .header(header::CONTENT_TYPE, "application/json") 110 | .bearer_auth(api_key) 111 | .json(&body) 112 | .send() 113 | .await?; 114 | 115 | let response_status = response.status(); 116 | let response_text = response.text().await?; 117 | 118 | if debug { 119 | info!( 120 | "[debug] DeepSeek API response: [{}] {:#?}", 121 | &response_status, &response_text 122 | ); 123 | } 124 | 125 | Ok(response_text) 126 | } 127 | 128 | /// 129 | /// This method attempts to convert the provided API response text into the expected struct and extracts the data from the response 130 | /// 131 | fn get_data(&self, response_text: &str, _function_call: bool) -> Result { 132 | //Convert API response to struct representing expected response format 133 | let completions_response: DeepSeekAPICompletionsResponse = 134 | serde_json::from_str(response_text)?; 135 | 136 | //Parse the response and return the assistant content 137 | completions_response 138 | .choices 139 | .iter() 140 | .filter_map(|choice| choice.message.as_ref()) 141 | .find(|&message| message.role == Some("assistant".to_string())) 142 | .and_then(|message| { 143 | message 144 | .content 145 | .as_ref() 146 | .map(|content| self.sanitize_json_response(content)) 147 | }) 148 | .ok_or_else(|| anyhow!("Assistant role content not found")) 149 | } 150 | 151 | // This function allows to check the rate limits for different models 152 | fn get_rate_limit(&self) -> RateLimit { 153 | // DeepSeek documentation: https://api-docs.deepseek.com/quick_start/rate_limit 154 | // "DeepSeek API does NOT constrain user's rate limit. We will try out best to serve every request." 155 | RateLimit { 156 | tpm: 100_000_000, // i.e. 
very large number 157 | rpm: 100_000_000, 158 | } 159 | } 160 | 161 | // Accepts a [0-100] percentage range and returns the target temperature based on model ranges 162 | fn get_normalized_temperature(&self, relative_temp: u32) -> f32 { 163 | // Temperature range documentation: https://api-docs.deepseek.com/quick_start/parameter_settings 164 | let min = 0.0f32; 165 | let max = 1.5f32; 166 | map_to_range_f32(min, max, relative_temp) 167 | } 168 | } 169 | -------------------------------------------------------------------------------- /src/llm_models/llm_model.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use async_trait::async_trait; 3 | use serde_json::Value; 4 | 5 | use crate::completions::ThinkingLevel; 6 | use crate::constants::OPENAI_BASE_INSTRUCTIONS; 7 | use crate::domain::RateLimit; 8 | use crate::llm_models::LLMTools; 9 | use crate::utils::{map_to_range, remove_json_wrapper}; 10 | 11 | ///This trait defines functions that need to be implemented for an enum that represents an LLM Model from any of the API providers 12 | #[async_trait(?Send)] 13 | pub trait LLMModel { 14 | ///Converts each item in the model enum into its string representation 15 | fn as_str(&self) -> &str; 16 | ///Returns an instance of the enum based on the provided string representation of name 17 | fn try_from_str(name: &str) -> Option 18 | where 19 | Self: Sized; 20 | ///Returns max supported number of tokens for each of the variants of the enum 21 | fn default_max_tokens(&self) -> usize; 22 | ///Returns the url of the endpoint that should be called for each variant of the LLM Model enum 23 | fn get_endpoint(&self) -> String { 24 | self.get_version_endpoint(None) 25 | } 26 | ///Returns the url of the endpoint that should be called for each variant of the LLM Model enum 27 | ///It allows to specify which version of the endpoint to use 28 | fn get_version_endpoint(&self, _version: Option) -> String { 29 | self.get_endpoint() 30 | } 31 | ///Provides a list of base instructions that should be added to each prompt when using each of the models 32 | fn get_base_instructions(&self, _function_call: Option) -> String { 33 | OPENAI_BASE_INSTRUCTIONS.to_string() 34 | } 35 | ///Returns recommendation if function calling should be used for the specified model 36 | fn function_call_default(&self) -> bool { 37 | false 38 | } 39 | ///Constructs the body that should be attached to the API call for each of the LLM Models 40 | #[allow(clippy::too_many_arguments)] 41 | fn get_body( 42 | &self, 43 | instructions: &str, 44 | json_schema: &Value, 45 | function_call: bool, 46 | max_tokens: &usize, 47 | temperature: &f32, 48 | tools: Option<&[LLMTools]>, 49 | thinking_level: Option<&ThinkingLevel>, 50 | ) -> serde_json::Value { 51 | self.get_version_body( 52 | instructions, 53 | json_schema, 54 | function_call, 55 | max_tokens, 56 | temperature, 57 | None, 58 | tools, 59 | thinking_level, 60 | ) 61 | } 62 | /// An API-version-specific implementation of the body constructor 63 | #[allow(clippy::too_many_arguments)] 64 | fn get_version_body( 65 | &self, 66 | instructions: &str, 67 | json_schema: &Value, 68 | function_call: bool, 69 | max_tokens: &usize, 70 | temperature: &f32, 71 | _version: Option, 72 | tools: Option<&[LLMTools]>, 73 | thinking_level: Option<&ThinkingLevel>, 74 | ) -> serde_json::Value { 75 | self.get_body( 76 | instructions, 77 | json_schema, 78 | function_call, 79 | max_tokens, 80 | temperature, 81 | tools, 82 | thinking_level, 83 | ) 84 | } 85 | 
///Makes the call to the correct API for the selected model 86 | async fn call_api( 87 | &self, 88 | api_key: &str, 89 | version: Option, 90 | body: &serde_json::Value, 91 | debug: bool, 92 | tools: Option<&[LLMTools]>, 93 | ) -> Result; 94 | ///Based on the model type extracts the data portion of the API response 95 | fn get_data(&self, response_text: &str, function_call: bool) -> Result { 96 | self.get_version_data(response_text, function_call, None) 97 | } 98 | /// An API-version-specific implementation of the data extractor 99 | fn get_version_data( 100 | &self, 101 | response_text: &str, 102 | function_call: bool, 103 | _version: Option, 104 | ) -> Result { 105 | self.get_data(response_text, function_call) 106 | } 107 | /// This function sanitizes the text response from LLMs to clean up common formatting issues. 108 | /// The default implementation of the function removes the common ```json{}``` wrapper returned by most models 109 | fn sanitize_json_response(&self, json_response: &str) -> String { 110 | remove_json_wrapper(json_response) 111 | } 112 | ///Returns the rate limit accepted by the API depending on the used model 113 | ///If not explicitly defined it will assume 1B tokens or 100k transactions a minute 114 | fn get_rate_limit(&self) -> RateLimit { 115 | RateLimit { 116 | tpm: 100_000_000, 117 | rpm: 100_000, 118 | } 119 | } 120 | ///Based on the RateLimit for the model calculates how many requests can be send to the API 121 | fn get_max_requests(&self) -> usize { 122 | let rate_limit = self.get_rate_limit(); 123 | 124 | //Check max requests based on rpm 125 | let max_requests_from_rpm = rate_limit.rpm; 126 | 127 | //Double check max number of requests based on tpm 128 | //Assume we will use ~50% of allowed tokens per request (for prompt + response) 129 | let max_tokens_per_minute = rate_limit.tpm; 130 | let tpm_per_request = (self.default_max_tokens() as f64 * 0.5).ceil() as usize; 131 | //Then check how many requests we can process 132 | let max_requests_from_tpm = max_tokens_per_minute / tpm_per_request; 133 | 134 | //To be safe we go with smaller of the numbers 135 | std::cmp::min(max_requests_from_rpm, max_requests_from_tpm) 136 | } 137 | ///Returns the default temperature to be used by the model 138 | fn get_default_temperature(&self) -> f32 { 139 | 0f32 140 | } 141 | ///Returns the normalized temperature for the model 142 | //Input should be a 0-100 number representing the percentage of max temp for the model 143 | fn get_normalized_temperature(&self, relative_temp: u32) -> f32 { 144 | //Assuming 0-1 range for most models. Different ranges require model-specific implementations. 
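//Illustrative mapping (assuming `map_to_range` projects the 0-100 input linearly onto [min, max]):
//relative_temp = 0 -> 0.0, 50 -> 0.5, 100 -> 1.0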
145 | let min = 0u32; 146 | let max = 1u32; 147 | map_to_range(min, max, relative_temp) 148 | } 149 | } 150 | -------------------------------------------------------------------------------- /examples/use_openai_responses.rs: -------------------------------------------------------------------------------- 1 | use anyhow::{anyhow, Result}; 2 | use schemars::JsonSchema; 3 | use serde::Deserialize; 4 | use serde::Serialize; 5 | use std::ffi::OsStr; 6 | use std::path::Path; 7 | 8 | use allms::{ 9 | assistants::OpenAIVectorStore, 10 | files::OpenAIFile, 11 | llm::{ 12 | tools::{ 13 | LLMTools, OpenAICodeInterpreterConfig, OpenAIFileSearchConfig, OpenAIReasoningConfig, 14 | OpenAIWebSearchConfig, 15 | }, 16 | OpenAIModels, 17 | }, 18 | Completions, 19 | }; 20 | 21 | // Example 1: Basic translation example using reasoning model 22 | #[derive(Deserialize, Serialize, JsonSchema, Debug, Clone)] 23 | struct TranslationResponse { 24 | pub spanish: String, 25 | pub french: String, 26 | pub german: String, 27 | pub polish: String, 28 | } 29 | 30 | // Example 2: Web search 31 | #[derive(Deserialize, Serialize, JsonSchema, Debug, Clone)] 32 | struct AINewsArticles { 33 | pub articles: Vec, 34 | } 35 | 36 | #[derive(Deserialize, Serialize, JsonSchema, Debug, Clone)] 37 | struct AINewsArticle { 38 | pub title: String, 39 | pub url: String, 40 | pub description: String, 41 | } 42 | 43 | // Example 3: File search 44 | #[derive(Deserialize, Serialize, Debug, Clone, JsonSchema)] 45 | pub struct ConcertInfo { 46 | dates: Vec, 47 | band: String, 48 | genre: String, 49 | venue: String, 50 | city: String, 51 | country: String, 52 | ticket_price: String, 53 | } 54 | 55 | const BANDS_GENRES: &[(&str, &str)] = &[ 56 | ("Metallica", "Metal"), 57 | ("The Beatles", "Rock"), 58 | ("Daft Punk", "Electronic"), 59 | ("Miles Davis", "Jazz"), 60 | ("Johnny Cash", "Country"), 61 | ]; 62 | 63 | // Example 4: Code interpreter example 64 | #[derive(Deserialize, Serialize, Debug, Clone, JsonSchema)] 65 | pub struct CodeInterpreterResponse { 66 | pub problem: String, 67 | pub code: String, 68 | pub output: String, 69 | } 70 | 71 | #[tokio::main] 72 | async fn main() -> Result<()> { 73 | env_logger::init(); 74 | 75 | // Example 1: Basic translation example using reasoning model 76 | let instructions = 77 | "Translate the following English sentence to all the languages in the response type: Rust is best for working with LLMs"; 78 | 79 | let openai_api_key: String = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); 80 | 81 | let reasoning_tool = LLMTools::OpenAIReasoning(OpenAIReasoningConfig::default()); 82 | 83 | let openai_responses = Completions::new(OpenAIModels::Gpt5_2Pro, &openai_api_key, None, None) 84 | .add_tool(reasoning_tool) 85 | .version("openai_responses"); 86 | 87 | match openai_responses 88 | .get_answer::(instructions) 89 | .await 90 | { 91 | Ok(response) => println!("Translations:\n{:#?}", response), 92 | Err(e) => eprintln!("Error: {:?}", e), 93 | } 94 | 95 | // Example 2: Web search example 96 | let web_search_tool = LLMTools::OpenAIWebSearch(OpenAIWebSearchConfig::new()); 97 | let openai_responses = Completions::new(OpenAIModels::Gpt5_2, &openai_api_key, None, None) 98 | .version("openai_responses") 99 | .add_tool(web_search_tool); 100 | 101 | match openai_responses 102 | .get_answer::("Find up to 5 most recent news items about Artificial Intelligence, Generative AI, and Large Language Models. 
103 | For each news item, provide the title, url, and a short description.") 104 | .await 105 | { 106 | Ok(response) => println!("AI news articles:\n{:#?}", response), 107 | Err(e) => eprintln!("Error: {:?}", e), 108 | } 109 | 110 | // Example 3: File search example 111 | 112 | // Read the concert file and upload it to OpenAI 113 | let path = Path::new("metallica.pdf"); 114 | let bytes = std::fs::read(path)?; 115 | let file_name = path 116 | .file_name() 117 | .and_then(OsStr::to_str) 118 | .map(|s| s.to_string()) 119 | .ok_or_else(|| anyhow!("Failed to extract file name"))?; 120 | let openai_file = OpenAIFile::new(None, &openai_api_key) 121 | .upload(&file_name, bytes) 122 | .await?; 123 | let openai_vector_store = OpenAIVectorStore::new(None, "Concerts", &openai_api_key) 124 | .upload(&[openai_file.id.clone().unwrap_or_default()]) 125 | .await?; 126 | 127 | // Extract concert information using Responses API with file search tool 128 | let file_search_tool = 129 | LLMTools::OpenAIFileSearch(OpenAIFileSearchConfig::new(vec![openai_vector_store 130 | .id 131 | .clone() 132 | .unwrap_or_default()])); 133 | 134 | let openai_responses = Completions::new(OpenAIModels::Gpt5_2, &openai_api_key, None, None) 135 | .version("openai_responses") 136 | .set_context("bands_genres", &BANDS_GENRES)? 137 | .add_tool(file_search_tool); 138 | 139 | match openai_responses 140 | .get_answer::("Extract the information requested in the response type from the attached concert information. 141 | The response should include the genre of the music the 'band' represents. 142 | The mapping of bands to genres was provided in 'bands_genres' list.") 143 | .await 144 | { 145 | Ok(response) => println!("Concert Info:\n{:#?}", response), 146 | Err(e) => eprintln!("Error: {:?}", e), 147 | } 148 | 149 | // Cleanup 150 | openai_file.delete().await?; 151 | openai_vector_store.delete().await?; 152 | 153 | // Example 4: Code interpreter example 154 | 155 | let code_interpreter_tool = LLMTools::OpenAICodeInterpreter(OpenAICodeInterpreterConfig::new()); 156 | let openai_responses = Completions::new(OpenAIModels::Gpt5_1, &openai_api_key, None, None) 157 | .version("openai_responses") 158 | .set_context("Code Interpreter", &"You are a personal math tutor. When asked a math question, write and run code to answer the question.".to_string())? 159 | .add_tool(code_interpreter_tool); 160 | 161 | match openai_responses 162 | .get_answer::( 163 | "I need to solve the equation 3x + 11 = 14. 
Can you help me?", 164 | ) 165 | .await 166 | { 167 | Ok(response) => println!("Code interpreter response:\n{:#?}", response), 168 | Err(e) => eprintln!("Error: {:?}", e), 169 | } 170 | 171 | Ok(()) 172 | } 173 | -------------------------------------------------------------------------------- /src/files/openai.rs: -------------------------------------------------------------------------------- 1 | use anyhow::{anyhow, Context, Result}; 2 | use async_trait::async_trait; 3 | use log::{error, info}; 4 | use reqwest::{header, multipart, Client}; 5 | use serde::{Deserialize, Serialize}; 6 | 7 | use crate::assistants::{OpenAIAssistantResource, OpenAIAssistantVersion}; 8 | use crate::domain::AllmsError; 9 | use crate::files::LLMFiles; 10 | use crate::utils::get_mime_type; 11 | 12 | #[derive(Deserialize, Serialize, Debug, Clone)] 13 | pub struct OpenAIFile { 14 | pub id: Option, 15 | debug: bool, 16 | api_key: String, 17 | version: OpenAIAssistantVersion, 18 | } 19 | 20 | #[derive(Deserialize, Serialize, Debug, Clone)] 21 | pub struct OpenAIFileResp { 22 | id: String, 23 | } 24 | 25 | #[derive(Deserialize, Serialize, Debug, Clone)] 26 | pub struct OpenAIDFileDeleteResp { 27 | id: String, 28 | object: String, 29 | deleted: bool, 30 | } 31 | 32 | #[async_trait(?Send)] 33 | impl LLMFiles for OpenAIFile { 34 | /// Constructor 35 | fn new(id: Option, open_ai_key: &str) -> Self { 36 | OpenAIFile { 37 | id, 38 | debug: false, 39 | api_key: open_ai_key.to_string(), 40 | version: OpenAIAssistantVersion::V1, // Default to V1 41 | } 42 | } 43 | 44 | /// 45 | /// This method can be used to turn on debug mode for the OpenAIFile struct 46 | /// 47 | fn debug(mut self) -> Self { 48 | self.debug = true; 49 | self 50 | } 51 | 52 | /// 53 | /// This function uploads a file to OpenAI and assigns it for use with Assistant API 54 | /// 55 | async fn upload(mut self, file_name: &str, file_bytes: Vec) -> Result { 56 | let files_url = self.version.get_endpoint(&OpenAIAssistantResource::Files); 57 | 58 | // This API sends a form so content type is automatically set by multipart method 59 | let mut version_headers = self.version.get_headers(&self.api_key); 60 | version_headers.remove(header::CONTENT_TYPE); 61 | 62 | // Determine MIME type based on file extension 63 | let mime_type = get_mime_type(file_name).ok_or_else(|| anyhow!("Unsupported file type"))?; 64 | 65 | let form = multipart::Form::new().text("purpose", "assistants").part( 66 | "file", 67 | multipart::Part::bytes(file_bytes) 68 | .file_name(file_name.to_string()) 69 | .mime_str(mime_type) 70 | .context("Failed to set MIME type")?, 71 | ); 72 | 73 | //Make the API call 74 | let client = Client::new(); 75 | 76 | let response = client 77 | .post(files_url) 78 | .headers(version_headers) 79 | .multipart(form) 80 | .send() 81 | .await?; 82 | 83 | let response_status = response.status(); 84 | let response_text = response.text().await?; 85 | 86 | if self.debug { 87 | info!( 88 | "[debug] OpenAI Files status API response: [{}] {:#?}", 89 | &response_status, &response_text 90 | ); 91 | } 92 | 93 | //Deserialize the string response into the Message object to confirm if there were any errors 94 | let response_deser: OpenAIFileResp = 95 | serde_json::from_str(&response_text).map_err(|error| { 96 | let error = AllmsError { 97 | crate_name: "allms".to_string(), 98 | module: "assistants::openai_file".to_string(), 99 | error_message: format!("Files API response serialization error: {}", error), 100 | error_detail: response_text, 101 | }; 102 | error!("{:?}", error); 103 | 
anyhow!("{:?}", error) 104 | })?; 105 | 106 | self.id = Some(response_deser.id); 107 | 108 | Ok(self) 109 | } 110 | 111 | /// 112 | /// This function deletes a file from OpenAI 113 | /// 114 | async fn delete(&self) -> Result<()> { 115 | let file_id = if let Some(id) = &self.id { 116 | id 117 | } else { 118 | return Err(anyhow!( 119 | "[OpenAI][File API] Unable to delete file without an ID." 120 | )); 121 | }; 122 | 123 | let files_resource = OpenAIAssistantResource::File { 124 | file_id: file_id.to_string(), 125 | }; 126 | let files_url = self.version.get_endpoint(&files_resource); 127 | let version_headers = self.version.get_headers(&self.api_key); 128 | 129 | //Make the API call 130 | let client = Client::new(); 131 | 132 | let response = client 133 | .delete(files_url) 134 | .headers(version_headers) 135 | .send() 136 | .await?; 137 | 138 | let response_status = response.status(); 139 | let response_text = response.text().await?; 140 | 141 | if self.debug { 142 | info!( 143 | "[debug] OpenAI Files status API response: [{}] {:#?}", 144 | &response_status, &response_text 145 | ); 146 | } 147 | 148 | //Check if the file was successfully deleted 149 | serde_json::from_str::(&response_text) 150 | .map_err(|error| { 151 | let error = AllmsError { 152 | crate_name: "allms".to_string(), 153 | module: "assistants::openai_file".to_string(), 154 | error_message: format!( 155 | "Files Delete API response serialization error: {}", 156 | error 157 | ), 158 | error_detail: response_text, 159 | }; 160 | error!("{:?}", error); 161 | anyhow!("{:?}", error) 162 | }) 163 | .and_then(|response| match response.deleted { 164 | true => Ok(()), 165 | false => Err(anyhow!("[OpenAIAssistant] Failed to delete the file.")), 166 | }) 167 | } 168 | 169 | /// 170 | /// This function returns the ID of the file if it exists 171 | /// 172 | fn get_id(&self) -> Option<&String> { 173 | self.id.as_ref() 174 | } 175 | 176 | /// 177 | /// This function returns the debug mode of the file 178 | /// 179 | fn is_debug(&self) -> bool { 180 | self.debug 181 | } 182 | } 183 | 184 | impl OpenAIFile { 185 | /// 186 | /// This method can be used to set the version of Assistants API Beta 187 | /// Current default is V1 188 | /// 189 | pub fn version(mut self, version: OpenAIAssistantVersion) -> Self { 190 | // Files endpoint currently requires v1 so if v2 is selected we overwrite 191 | let version = match version { 192 | OpenAIAssistantVersion::V2 => OpenAIAssistantVersion::V1, 193 | _ => version, 194 | }; 195 | self.version = version; 196 | self 197 | } 198 | } 199 | -------------------------------------------------------------------------------- /examples/use_completions.rs: -------------------------------------------------------------------------------- 1 | use schemars::JsonSchema; 2 | use serde::Deserialize; 3 | use serde::Serialize; 4 | 5 | use allms::{ 6 | llm::{ 7 | AnthropicModels, AwsBedrockModels, DeepSeekModels, GoogleModels, LLMModel, MistralModels, 8 | OpenAIModels, PerplexityModels, XAIModels, 9 | }, 10 | Completions, 11 | }; 12 | 13 | #[derive(Deserialize, Serialize, JsonSchema, Debug, Clone)] 14 | struct TranslationResponse { 15 | pub spanish: String, 16 | pub french: String, 17 | pub german: String, 18 | pub polish: String, 19 | } 20 | 21 | #[tokio::main] 22 | async fn main() { 23 | env_logger::init(); 24 | 25 | // Example context and instructions 26 | let instructions = 27 | "Translate the following English sentence to all the languages in the response type: Rust is best for working with LLMs"; 28 | 29 | // Get 
answer using AWS Bedrock Converse 30 | // AWS Bedrock SDK requires `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` environment variables to be defined and matching your AWS account 31 | let model = AwsBedrockModels::try_from_str("amazon.nova-lite-v1:0") 32 | .unwrap_or(AwsBedrockModels::NovaLite); // Choose the model 33 | println!("AWS Bedrock model: {:#?}", model.as_str()); 34 | 35 | let aws_completion = Completions::new(model, "", None, None); 36 | 37 | match aws_completion 38 | .get_answer::(instructions) 39 | .await 40 | { 41 | Ok(response) => println!("AWS Bedrock response: {:#?}", response), 42 | Err(e) => eprintln!("Error: {:?}", e), 43 | } 44 | 45 | // Get answer using OpenAI Completions API 46 | let openai_api_key: String = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); 47 | let model = OpenAIModels::try_from_str("gpt-5.2").unwrap_or(OpenAIModels::Gpt5_2); // Choose the model 48 | println!("OpenAI model: {:#?}", model.as_str()); 49 | 50 | let openai_completion = Completions::new(model, &openai_api_key, None, None); 51 | 52 | match openai_completion 53 | .get_answer::(instructions) 54 | .await 55 | { 56 | Ok(response) => println!("OpenAI Completions API response: {:#?}", response), 57 | Err(e) => eprintln!("Error: {:?}", e), 58 | } 59 | 60 | // Get answer using OpenAI (on Azure) 61 | // Ensure `OPENAI_API_URL` is set to your Azure OpenAI resource endpoint 62 | let azure_openai_completion = 63 | Completions::new(OpenAIModels::Gpt5_2, &openai_api_key, None, None) 64 | .version("azure:2024-08-01-preview"); 65 | match azure_openai_completion 66 | .get_answer::(instructions) 67 | .await 68 | { 69 | Ok(response) => println!("Azure OpenAI response: {:#?}", response), 70 | Err(e) => eprintln!("Error: {:?}", e), 71 | } 72 | 73 | // Get answer using Anthropic 74 | let anthropic_api_key: String = 75 | std::env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY not set"); 76 | let model = AnthropicModels::try_from_str("claude-haiku-4-5") 77 | .unwrap_or(AnthropicModels::Claude4_5Haiku); // Choose the model 78 | println!("Anthropic model: {:#?}", model.as_str()); 79 | 80 | let anthropic_completion = Completions::new(model, &anthropic_api_key, None, None); 81 | 82 | match anthropic_completion 83 | .get_answer::(instructions) 84 | .await 85 | { 86 | Ok(response) => println!("Anthropic response: {:#?}", response), 87 | Err(e) => eprintln!("Error: {:?}", e), 88 | } 89 | 90 | // Get answer using Mistral 91 | let mistral_api_key: String = 92 | std::env::var("MISTRAL_API_KEY").expect("MISTRAL_API_KEY not set"); 93 | let model = MistralModels::try_from_str("mistral-medium-latest") 94 | .unwrap_or(MistralModels::MistralMedium3_1); // Choose the model 95 | println!("Mistral model: {:#?}", model.as_str()); 96 | 97 | let mistral_completion = Completions::new(model, &mistral_api_key, None, None); 98 | 99 | match mistral_completion 100 | .get_answer::(instructions) 101 | .await 102 | { 103 | Ok(response) => println!("Mistral response: {:#?}", response), 104 | Err(e) => eprintln!("Error: {:?}", e), 105 | } 106 | 107 | // Get answer using Google Studio 108 | let model = GoogleModels::try_from_str("gemini-2.5-flash-lite") 109 | .unwrap_or(GoogleModels::Gemini2_5FlashLite); // Choose the model 110 | println!("Google Gemini model: {:#?}", model.as_str()); 111 | 112 | let google_token_str: String = 113 | std::env::var("GOOGLE_AI_STUDIO_API_KEY").expect("GOOGLE_AI_STUDIO_API_KEY not set"); 114 | 115 | let gemini_completion = 116 | Completions::new(model, &google_token_str, None, 
None).version("google-studio"); 117 | 118 | match gemini_completion 119 | .get_answer::(instructions) 120 | .await 121 | { 122 | Ok(response) => println!("Gemini response: {:#?}", response), 123 | Err(e) => eprintln!("Error: {:?}", e), 124 | } 125 | 126 | // Get answer using Perplexity 127 | let model = PerplexityModels::try_from_str("sonar-pro").unwrap_or(PerplexityModels::Sonar); // Choose the model 128 | println!("Perplexity model: {:#?}", model.as_str()); 129 | 130 | let perplexity_token_str: String = 131 | std::env::var("PERPLEXITY_API_KEY").expect("PERPLEXITY_API_KEY not set"); 132 | 133 | let perplexity_completion = Completions::new(model, &perplexity_token_str, None, None); 134 | 135 | match perplexity_completion 136 | .get_answer::(instructions) 137 | .await 138 | { 139 | Ok(response) => println!("Perplexity response: {:#?}", response), 140 | Err(e) => eprintln!("Error: {:?}", e), 141 | } 142 | 143 | // Get answer using DeepSeek 144 | let model = 145 | DeepSeekModels::try_from_str("deepseek-chat").unwrap_or(DeepSeekModels::DeepSeekChat); // Choose the model 146 | println!("DeepSeek model: {:#?}", model.as_str()); 147 | 148 | let deepseek_token_str: String = 149 | std::env::var("DEEPSEEK_API_KEY").expect("DEEPSEEK_API_KEY not set"); 150 | 151 | let deepseek_completion = Completions::new(model, &deepseek_token_str, None, None); 152 | 153 | match deepseek_completion 154 | .get_answer::(instructions) 155 | .await 156 | { 157 | Ok(response) => println!("DeepSeek response: {:#?}", response), 158 | Err(e) => eprintln!("Error: {:?}", e), 159 | } 160 | 161 | // Get answer using xAI Grok 162 | let xai_api_key: String = std::env::var("XAI_API_KEY").expect("XAI_API_KEY not set"); 163 | let model = XAIModels::try_from_str("grok-3-mini").unwrap_or(XAIModels::Grok3Mini); // Choose the model 164 | println!("xAI Grok model: {:#?}", model.as_str()); 165 | 166 | let xai_completion = Completions::new(model, &xai_api_key, None, None); 167 | 168 | match xai_completion 169 | .get_answer::(instructions) 170 | .await 171 | { 172 | Ok(response) => println!("xAI Grok response: {:#?}", response), 173 | Err(e) => eprintln!("Error: {:?}", e), 174 | } 175 | } 176 | -------------------------------------------------------------------------------- /src/llm_models/xai.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use async_trait::async_trait; 3 | use log::info; 4 | use reqwest::{header, Client}; 5 | use serde::{Deserialize, Serialize}; 6 | use serde_json::Value; 7 | 8 | use crate::completions::ThinkingLevel; 9 | use crate::constants::XAI_API_URL; 10 | use crate::domain::{ 11 | XAIAssistantMessageRole, XAIChatMessage, XAIChatRequest, XAIChatResponse, XAIRole, 12 | }; 13 | use crate::llm_models::{LLMModel, LLMTools}; 14 | 15 | // API Docs: https://docs.x.ai/docs/models 16 | #[derive(Deserialize, Serialize, Debug, Clone, Eq, PartialEq)] 17 | pub enum XAIModels { 18 | Grok4, 19 | Grok3, 20 | Grok3Mini, 21 | Grok3Fast, 22 | Grok3MiniFast, 23 | } 24 | 25 | #[async_trait(?Send)] 26 | impl LLMModel for XAIModels { 27 | fn as_str(&self) -> &str { 28 | match self { 29 | XAIModels::Grok4 => "grok-4", 30 | XAIModels::Grok3 => "grok-3", 31 | XAIModels::Grok3Mini => "grok-3-mini", 32 | XAIModels::Grok3Fast => "grok-3-fast", 33 | XAIModels::Grok3MiniFast => "grok-3-mini-fast", 34 | } 35 | } 36 | 37 | // Docs: https://docs.x.ai/docs/models 38 | fn try_from_str(name: &str) -> Option { 39 | match name.to_lowercase().as_str() { 40 | "grok-4" => Some(XAIModels::Grok4), 
41 | "grok-4-latest" => Some(XAIModels::Grok4), 42 | "grok-4-0709" => Some(XAIModels::Grok4), 43 | "grok-3" => Some(XAIModels::Grok3), 44 | "grok-3-latest" => Some(XAIModels::Grok3), 45 | "grok-3-beta" => Some(XAIModels::Grok3), 46 | "grok-3-mini" => Some(XAIModels::Grok3Mini), 47 | "grok-3-mini-latest" => Some(XAIModels::Grok3Mini), 48 | "grok-3-mini-beta" => Some(XAIModels::Grok3Mini), 49 | "grok-3-fast" => Some(XAIModels::Grok3Fast), 50 | "grok-3-fast-latest" => Some(XAIModels::Grok3Fast), 51 | "grok-3-fast-beta" => Some(XAIModels::Grok3Fast), 52 | "grok-3-mini-fast" => Some(XAIModels::Grok3MiniFast), 53 | "grok-3-mini-fast-latest" => Some(XAIModels::Grok3MiniFast), 54 | "grok-3-mini-fast-beta" => Some(XAIModels::Grok3MiniFast), 55 | _ => None, 56 | } 57 | } 58 | 59 | fn default_max_tokens(&self) -> usize { 60 | // Docs: https://docs.x.ai/docs/models 61 | match self { 62 | XAIModels::Grok4 => 256_000, 63 | XAIModels::Grok3 => 131_072, 64 | XAIModels::Grok3Mini => 131_072, 65 | XAIModels::Grok3Fast => 131_072, 66 | XAIModels::Grok3MiniFast => 131_072, 67 | } 68 | } 69 | 70 | fn get_endpoint(&self) -> String { 71 | XAI_API_URL.to_string() 72 | } 73 | 74 | //This method prepares the body of the API call for different models 75 | fn get_body( 76 | &self, 77 | instructions: &str, 78 | json_schema: &Value, 79 | function_call: bool, 80 | max_tokens: &usize, 81 | temperature: &f32, 82 | tools: Option<&[LLMTools]>, 83 | _thinking_level: Option<&ThinkingLevel>, 84 | ) -> serde_json::Value { 85 | // Get system instructions 86 | let base_instructions = self.get_base_instructions(Some(function_call)); 87 | 88 | // Set the structured output schema 89 | 90 | // TODO: Using structured output with JSON Schema is not working. Attaching schema to instructions until fixed. 91 | // let response_format = Some(XAIResponseFormat{ 92 | // r#type: XAIResponseFormatType::JsonSchema, 93 | // json_schema: Some(json_schema.clone()), 94 | // }); 95 | 96 | let instructions = format!( 97 | "{} 98 | {:?}", 99 | instructions, json_schema, 100 | ); 101 | 102 | let search_parameters = tools.and_then(|tools| { 103 | tools.iter().find_map(|tool| match tool { 104 | LLMTools::XAIWebSearch(config) => Some(config.clone()), 105 | _ => None, 106 | }) 107 | }); 108 | 109 | // TODOs: 110 | // TextFile tool - currently only supports text files exposed as URL with instructions and not content 111 | 112 | let chat_request = XAIChatRequest { 113 | model: self.as_str().to_string(), 114 | max_completion_tokens: Some(*max_tokens), 115 | temperature: Some(*temperature), 116 | messages: vec![ 117 | XAIChatMessage::new(XAIRole::System, base_instructions), 118 | XAIChatMessage::new(XAIRole::User, instructions.to_string()), 119 | ], 120 | response_format: None, 121 | tools: None, 122 | search_parameters, 123 | }; 124 | 125 | serde_json::to_value(chat_request).unwrap_or_default() 126 | } 127 | 128 | /* 129 | * This function leverages xAI API to perform any query as per the provided body. 130 | * 131 | * It returns a String the Response object that needs to be parsed based on the self.model. 
132 | */ 133 | async fn call_api( 134 | &self, 135 | api_key: &str, 136 | _version: Option, 137 | body: &serde_json::Value, 138 | debug: bool, 139 | _tools: Option<&[LLMTools]>, 140 | ) -> Result { 141 | //Get the API url 142 | let model_url = self.get_endpoint(); 143 | 144 | //Make the API call 145 | let client = Client::new(); 146 | 147 | //Send request 148 | let response = client 149 | .post(model_url) 150 | .header(header::CONTENT_TYPE, "application/json") 151 | .bearer_auth(api_key) 152 | .json(&body) 153 | .send() 154 | .await?; 155 | 156 | let response_status = response.status(); 157 | let response_text = response.text().await?; 158 | 159 | if debug { 160 | info!( 161 | "[debug] xAI API response: [{}] {:#?}", 162 | &response_status, &response_text 163 | ); 164 | } 165 | 166 | Ok(response_text) 167 | } 168 | 169 | //This method attempts to convert the provided API response text into the expected struct and extracts the data from the response 170 | fn get_data(&self, response_text: &str, _function_call: bool) -> Result { 171 | //Convert API response to struct representing expected response format 172 | let messages_response: XAIChatResponse = serde_json::from_str(response_text)?; 173 | 174 | let assistant_response = messages_response 175 | .choices 176 | .iter() 177 | .map(|item| &item.message) 178 | .filter(|message| message.role == XAIAssistantMessageRole::Assistant) 179 | .filter_map(|message| { 180 | // Use content or reasoning_content if present 181 | message 182 | .content 183 | .as_ref() 184 | .or(message.reasoning_content.as_ref()) 185 | }) 186 | .fold(String::new(), |mut acc, content| { 187 | acc.push_str(&self.sanitize_json_response(content)); 188 | acc 189 | }); 190 | 191 | //Return completions text 192 | Ok(assistant_response) 193 | } 194 | } 195 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # allms: One Library to rule them aLLMs 2 | [![crates.io](https://img.shields.io/crates/v/allms.svg)](https://crates.io/crates/allms) 3 | [![docs.rs](https://docs.rs/allms/badge.svg)](https://docs.rs/allms) 4 | 5 | This Rust library is specialized in providing type-safe interactions with APIs of the following LLM providers: Anthropic, AWS Bedrock, Azure, DeepSeek, Google Gemini, Mistral, OpenAI, Perplexity, xAI. (More providers to be added in the future.) It's designed to simplify the process of experimenting with different models. It de-risks the process of migrating between providers reducing vendor lock-in issues. It also standardizes serialization of sending requests to LLM APIs and interpreting the responses, ensuring that the JSON data is handled in a type-safe manner. With allms you can focus on creating effective prompts and providing LLM with the right context, instead of worrying about differences in API implementations. 6 | 7 | ## Features 8 | 9 | - Support for various foundational LLM providers including Anthropic, AWS Bedrock, Azure, DeepSeek, Google Gemini, OpenAI, Mistral, and Perplexity. 10 | - Easy-to-use functions for chat/text completions and assistants. Use the same struct and methods regardless of which model you choose. 11 | - Automated response deserialization to custom types. 12 | - Standardized approach to providing context with support of function calling, tools, and file uploads. 13 | - Enhanced developer productivity with automated token calculations, rate limits and debug mode. 
14 | - Extensibility enabling easy adoption of other models with standardized trait. 15 | - Asynchronous support using Tokio. 16 | 17 | ### Foundational Models 18 | Anthropic: 19 | - APIs: Messages, Text Completions 20 | - Models: Claude Opus 4.5, Claude Sonnet 4.5, Claude Haiku 4.5, Claude Opus 4.1, Claude Sonnet 4, Claude Opus 4, Claude 3.7 Sonnet, Claude 3.5 Sonnet, Claude 3.5 Haiku, Claude 3 Opus, Claude 3 Sonnet, Claude 3 Haiku, Claude 2.0, Claude Instant 1.2 21 | - Tools: file search, web search, code interpreter, computer use 22 | 23 | AWS Bedrock: 24 | - APIs: Converse 25 | - Models: Nova Micro, Nova Lite, Nova Pro (additional models to be added) 26 | 27 | Azure OpenAI: 28 | - APIs: Completions, Responses, Assistants, Files, Vector Stores, Tools 29 | - API version can be set using `AzureVersion` variant 30 | - Models: as per model deployments in Azure OpenAI Studio 31 | - If using custom model deployment names please use the `Custom` variant of `OpenAIModels` 32 | 33 | DeepSeek: 34 | - APIs: Chat Completion 35 | - Models: DeepSeek-V3, DeepSeek-R1 36 | 37 | Google Gemini: 38 | - APIs: Chat Completions (including streaming) 39 | - Via Vertex AI or AI Studio 40 | - Models: Gemini 3 Pro (Preview), Gemini 2.5 Pro, Gemini 2.5 Flash, Gemini 2.5 Flash-Lite, Gemini 2.0 Flash, Gemini 2.0 Flash-Lite, Gemini 1.5 Pro, Gemini 1.5 Flash, Gemini 1.5 Flash-8B 41 | - Experimental models: Gemini 2.0 Pro, Gemini 2.0 Flash-Thinking 42 | - Fine-tuned models: in Vertex AI only, using endpoint constructor 43 | - Tools: Google Search, code execution 44 | 45 | Mistral: 46 | - APIs: Chat Completions 47 | - Models: 48 | - Multimodal: Mistral Large 2.1, Mistral Medium 3.1, Mistral Medium 3, Mistral Small 3.2, Mistral Small 3.1, Mistral Small 3, Mistral Small 2 49 | - Reasoning: Magistral Medium 1.2, Magistral Medium, Magistral Small 1.2 50 | - Other: Codestral 2508, Codestral 2, Ministral 3B, Ministral 8B 51 | - Legacy models: Mistral Large, Mistral Nemo, Mistral 7B, Mixtral 8x7B, Mixtral 8x22B, Mistral Medium, Mistral Small, Mistral Tiny 52 | - Tools: file web search, code interpreter 53 | 54 | OpenAI: 55 | - APIs: Chat Completions, Responses, Function Calling, Assistants (v1 & v2), Files, Vector Stores 56 | - Models: 57 | - Chat Completions & Responses only: GPT-5.2, GPT-5.2 Pro, Gpt-5.1, o1, o1 Preview, o1 Mini, o1 Pro, o3, o3 Mini, o4 Mini 58 | - Chat Completions, Responses & Assistants: GPT-5, GPT-5-mini, GPT-5-nano, GPT-4.5-Preview, GPT-4o, GPT-4, GPT-4 32k, GPT-4 Turbo, GPT-3.5 Turbo, GPT-3.5 Turbo 16k, fine-tuned models (via `Custom` variant) 59 | - Tools: file search, web search, code interpreter, computer use 60 | 61 | Perplexity: 62 | - APIs: Chat Completions 63 | - Models: Sonar, Sonar Pro, Sonar Reasoning 64 | - The following legacy models will be supported until February 22, 2025: Llama 3.1 Sonar Small, Llama 3.1 Sonar Large, Llama 3.1 Sonar Huge 65 | 66 | xAI: 67 | - APIs: Chat Completions 68 | - Models: Grok 4, Grok 3, Grok 3 Mini, Grok 3 Fast, Grok 3 Mini Fast 69 | - Tools: web search 70 | 71 | ### Prerequisites 72 | - Anthropic: API key (passed in model constructor) 73 | - AWS Bedrock: environment variables `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY` and `AWS_REGION` set as per AWS settings. 74 | - Azure OpenAI: environment variable `OPENAI_API_URL` set to your Azure OpenAI resource endpoint. 
Endpoint key passed in constructor 75 | - DeepSeek: API key (passed in model constructor) 76 | - Google AI Studio: API key (passed in model constructor) 77 | - Google Vertex AI: GCP service account key (used to obtain access token) + GCP project ID (set as environment variable) 78 | - Mistral: API key (passed in model constructor) 79 | - OpenAI: API key (passed in model constructor) 80 | - Perplexity: API key (passed in model constructor) 81 | - xAI: API key (passed in model constructor) 82 | 83 | ### Examples 84 | Explore the `examples` directory to see more use cases and how to use different LLM providers and endpoint types. 85 | 86 | Using `Completions` API with different foundational models: 87 | ``` 88 | let anthropic_answer = Completions::new(AnthropicModels::Claude4Sonnet, &API_KEY, None, None) 89 | .get_answer::(instructions) 90 | .await? 91 | 92 | let aws_bedrock_answer = Completions::new(AwsBedrockModels::NovaLite, "", None, None) 93 | .get_answer::(instructions) 94 | .await? 95 | 96 | let deepseek_answer = Completions::new(DeepSeekModels::DeepSeekReasoner, &API_KEY, None, None) 97 | .get_answer::(instructions) 98 | .await? 99 | 100 | let google_answer = Completions::new(GoogleModels::Gemini2_5Flash, &API_KEY, None, None) 101 | .get_answer::(instructions) 102 | .await? 103 | 104 | let mistral_answer = Completions::new(MistralModels::MistralMedium3, &API_KEY, None, None) 105 | .get_answer::(instructions) 106 | .await? 107 | 108 | let openai_answer = Completions::new(OpenAIModels::Gpt4_1Mini, &API_KEY, None, None) 109 | .get_answer::(instructions) 110 | .await? 111 | 112 | let openai_responses_answer = Completions::new(OpenAIModels::Gpt4_1Mini, &API_KEY, None, None) 113 | .version("openai_responses") 114 | .get_answer::(instructions) 115 | .await? 116 | 117 | let perplexity_answer = Completions::new(PerplexityModels::SonarPro, &API_KEY, None, None) 118 | .get_answer::(instructions) 119 | .await? 120 | 121 | let xai_answer = Completions::new(XAIModels::Grok3Mini, &API_KEY, None, None) 122 | .get_answer::(instructions) 123 | .await? 124 | ``` 125 | 126 | Example: 127 | ``` 128 | RUST_LOG=info RUST_BACKTRACE=1 cargo run --example use_completions 129 | ``` 130 | 131 | Using `Assistant` API to analyze your files with `File` and `VectorStore` capabilities: 132 | ``` 133 | // Create a File 134 | let openai_file = OpenAIFile::new(None, &API_KEY) 135 | .upload(&file_name, bytes) 136 | .await?; 137 | 138 | // Create a Vector Store 139 | let openai_vector_store = OpenAIVectorStore::new(None, "Name", &API_KEY) 140 | .upload(&[openai_file.id.clone().unwrap_or_default()]) 141 | .await?; 142 | 143 | // Extract data using Assistant 144 | let openai_answer = OpenAIAssistant::new(OpenAIModels::Gpt4o, &API_KEY) 145 | .version(OpenAIAssistantVersion::V2) 146 | .vector_store(openai_vector_store.clone()) 147 | .await? 148 | .get_answer::(instructions, &[]) 149 | .await?; 150 | ``` 151 | 152 | Example: 153 | ``` 154 | RUST_LOG=info RUST_BACKTRACE=1 cargo run --example use_openai_assistant 155 | ``` 156 | 157 | ## License 158 | This project is licensed under dual MIT/Apache-2.0 license. See the [LICENSE-MIT](LICENSE-MIT) and [LICENSE-APACHE](LICENSE-APACHE) files for details. 
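Attaching provider tools to a `Completions` request (a sketch based on `examples/use_openai_responses.rs`; `AINewsArticles` is a user-defined response type deriving `JsonSchema` and `Deserialize`):
```
let web_search_tool = LLMTools::OpenAIWebSearch(OpenAIWebSearchConfig::new());

let news_answer = Completions::new(OpenAIModels::Gpt5_2, &API_KEY, None, None)
    .version("openai_responses")
    .add_tool(web_search_tool)
    .get_answer::<AINewsArticles>(instructions)
    .await?
```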
159 | -------------------------------------------------------------------------------- /src/llm_models/aws.rs: -------------------------------------------------------------------------------- 1 | use anyhow::{anyhow, Result}; 2 | use async_trait::async_trait; 3 | use aws_config::BehaviorVersion; 4 | use aws_sdk_bedrockruntime::{ 5 | types::{ContentBlock, ConversationRole, InferenceConfiguration, Message, SystemContentBlock}, 6 | Client, 7 | }; 8 | use log::info; 9 | use serde::{Deserialize, Serialize}; 10 | use serde_json::Value; 11 | 12 | use crate::completions::ThinkingLevel; 13 | use crate::constants::{AWS_BEDROCK_API_URL, AWS_REGION}; 14 | use crate::domain::RateLimit; 15 | use crate::llm_models::{LLMModel, LLMTools}; 16 | 17 | #[derive(Serialize, Deserialize)] 18 | struct AwsBedrockRequestBody { 19 | instructions: String, 20 | json_schema: Value, 21 | max_tokens: i32, 22 | temperature: f32, 23 | } 24 | 25 | #[derive(Deserialize, Serialize, Debug, Clone, Eq, PartialEq)] 26 | // AWS Bedrock docs: https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html 27 | pub enum AwsBedrockModels { 28 | NovaPro, 29 | NovaLite, 30 | NovaMicro, 31 | } 32 | 33 | #[async_trait(?Send)] 34 | impl LLMModel for AwsBedrockModels { 35 | fn as_str(&self) -> &str { 36 | match self { 37 | AwsBedrockModels::NovaPro => "amazon.nova-pro-v1:0", 38 | AwsBedrockModels::NovaLite => "amazon.nova-lite-v1:0", 39 | AwsBedrockModels::NovaMicro => "amazon.nova-micro-v1:0", 40 | } 41 | } 42 | 43 | fn try_from_str(name: &str) -> Option { 44 | match name.to_lowercase().as_str() { 45 | "amazon.nova-pro-v1:0" => Some(AwsBedrockModels::NovaPro), 46 | "amazon.nova-lite-v1:0" => Some(AwsBedrockModels::NovaLite), 47 | "amazon.nova-micro-v1:0" => Some(AwsBedrockModels::NovaMicro), 48 | _ => None, 49 | } 50 | } 51 | 52 | fn default_max_tokens(&self) -> usize { 53 | match self { 54 | AwsBedrockModels::NovaPro => 5_120, 55 | AwsBedrockModels::NovaLite => 5_120, 56 | AwsBedrockModels::NovaMicro => 5_120, 57 | } 58 | } 59 | 60 | fn get_endpoint(&self) -> String { 61 | format!("{}/model/{}/converse", &*AWS_BEDROCK_API_URL, self.as_str()) 62 | } 63 | 64 | /// AWS Bedrock implementation leverages AWS Bedrock SKD, therefore data is only passed by this method to `call_api` method where the actual logic is implemented 65 | fn get_body( 66 | &self, 67 | instructions: &str, 68 | json_schema: &Value, 69 | _function_call: bool, 70 | max_tokens: &usize, 71 | temperature: &f32, 72 | _tools: Option<&[LLMTools]>, 73 | _thinking_level: Option<&ThinkingLevel>, 74 | ) -> serde_json::Value { 75 | let body = AwsBedrockRequestBody { 76 | instructions: instructions.to_string(), 77 | json_schema: json_schema.clone(), 78 | max_tokens: *max_tokens as i32, 79 | temperature: *temperature, 80 | }; 81 | 82 | // Return the body serialized as a JSON value 83 | serde_json::to_value(body).unwrap() 84 | } 85 | 86 | /// This function leverages AWS Bedrock SDK to perform any query as per the provided body. 
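/// Credentials are resolved by the AWS SDK's default provider chain, so `AWS_ACCESS_KEY_ID`
/// and `AWS_SECRET_ACCESS_KEY` must be available in the environment (or another supported
/// credential source); the target region is taken from the crate's `AWS_REGION` setting.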
87 | async fn call_api( 88 | &self, 89 | // AWS Bedrock SDK utilizes `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` environment variables for request authentication 90 | // Docs: https://docs.aws.amazon.com/sdk-for-rust/latest/dg/credproviders.html 91 | _api_key: &str, 92 | _version: Option, 93 | body: &serde_json::Value, 94 | debug: bool, 95 | _tools: Option<&[LLMTools]>, 96 | ) -> Result { 97 | let sdk_config = aws_config::defaults(BehaviorVersion::latest()) 98 | .region(&**AWS_REGION) 99 | .load() 100 | .await; 101 | let client = Client::new(&sdk_config); 102 | 103 | // Get request info from body 104 | let request_body_opt: Option = 105 | serde_json::from_value(body.clone()).ok(); 106 | let (instructions_opt, json_schema_opt, max_tokens_opt, temperature_opt) = request_body_opt 107 | .map_or_else( 108 | || (None, None, None, None), 109 | |request_body| { 110 | ( 111 | Some(request_body.instructions), 112 | Some(request_body.json_schema), 113 | Some(request_body.max_tokens), 114 | Some(request_body.temperature), 115 | ) 116 | }, 117 | ); 118 | 119 | // Get base instructions 120 | let base_instructions = self.get_base_instructions(None); 121 | 122 | let converse_builder = client 123 | .converse() 124 | .model_id(self.as_str()) 125 | .system(SystemContentBlock::Text(base_instructions)); 126 | 127 | // Add user instructions including the expected output schema if specifed 128 | let instructions = instructions_opt.unwrap_or_default(); 129 | let user_instructions = json_schema_opt 130 | .map(|schema| { 131 | format!( 132 | " 133 | {instructions} 134 | 135 | 136 | {schema} 137 | " 138 | ) 139 | }) 140 | .unwrap_or(instructions); 141 | let converse_builder = converse_builder.messages( 142 | Message::builder() 143 | .role(ConversationRole::User) 144 | .content(ContentBlock::Text(user_instructions)) 145 | .build() 146 | .map_err(|_| anyhow!("failed to build message"))?, 147 | ); 148 | 149 | // If specified add inference config 150 | let converse_builder = if max_tokens_opt.is_some() || temperature_opt.is_some() { 151 | let inference_config = InferenceConfiguration::builder() 152 | .set_max_tokens(max_tokens_opt) 153 | .set_temperature(temperature_opt) 154 | .build(); 155 | converse_builder.set_inference_config(Some(inference_config)) 156 | } else { 157 | converse_builder 158 | }; 159 | 160 | // Send request 161 | let converse_response = converse_builder.send().await?; 162 | 163 | if debug { 164 | info!( 165 | "[debug] AWS Bedrock API response: {:#?}", 166 | &converse_response 167 | ); 168 | } 169 | 170 | //Parse the response and return the assistant content 171 | let text = converse_response 172 | .output() 173 | .ok_or(anyhow!("no output"))? 174 | .as_message() 175 | .map_err(|_| anyhow!("output not a message"))? 176 | .content() 177 | .first() 178 | .ok_or(anyhow!("no content in message"))? 179 | .as_text() 180 | .map_err(|_| anyhow!("content is not text"))? 
181 | .to_string(); 182 | Ok(self.sanitize_json_response(&text)) 183 | } 184 | 185 | /// AWS Bedrock implementation leverages AWS Bedrock SDK, therefore data extraction is implemented directly in `call_api` method and this method only passes the data on 186 | fn get_data(&self, response_text: &str, _function_call: bool) -> Result { 187 | Ok(response_text.to_string()) 188 | } 189 | 190 | //This function allows to check the rate limits for different models 191 | fn get_rate_limit(&self) -> RateLimit { 192 | // Docs: https://docs.aws.amazon.com/general/latest/gr/bedrock.html 193 | match self { 194 | AwsBedrockModels::NovaPro => RateLimit { 195 | tpm: 400_000, 196 | rpm: 100, 197 | }, 198 | AwsBedrockModels::NovaLite | AwsBedrockModels::NovaMicro => RateLimit { 199 | tpm: 2_000_000, 200 | rpm: 1_000, 201 | }, 202 | } 203 | } 204 | } 205 | -------------------------------------------------------------------------------- /src/files/anthropic.rs: -------------------------------------------------------------------------------- 1 | use anyhow::{anyhow, Context, Result}; 2 | use async_trait::async_trait; 3 | use log::{error, info}; 4 | use reqwest::{multipart, Client}; 5 | use serde::{Deserialize, Serialize}; 6 | 7 | use crate::{ 8 | apis::AnthropicApiEndpoints, constants::ANTHROPIC_FILES_API_URL, domain::AllmsError, 9 | files::LLMFiles, utils::get_mime_type, 10 | }; 11 | 12 | #[derive(Deserialize, Serialize, Debug, Clone)] 13 | pub struct AnthropicFile { 14 | pub id: Option, 15 | debug: bool, 16 | api_key: String, 17 | } 18 | 19 | #[async_trait(?Send)] 20 | impl LLMFiles for AnthropicFile { 21 | /// Create a new file instance 22 | /// 23 | /// # Arguments 24 | /// * `id` - Optional file ID (for existing files) 25 | /// * `api_key` - API key for the LLM provider 26 | fn new(id: Option, api_key: &str) -> Self { 27 | Self { 28 | id, 29 | debug: false, 30 | api_key: api_key.to_string(), 31 | } 32 | } 33 | 34 | /// Enable debug mode for the file instance 35 | /// 36 | /// Returns the modified instance for method chaining 37 | fn debug(mut self) -> Self { 38 | self.debug = true; 39 | self 40 | } 41 | 42 | /// Upload a file to the LLM provider 43 | /// 44 | /// # Arguments 45 | /// * `file_name` - Name of the file to upload 46 | /// * `file_bytes` - File content as bytes 47 | /// 48 | /// # Returns 49 | /// * `Result` - The file instance with updated ID on success 50 | async fn upload(mut self, file_name: &str, file_bytes: Vec) -> Result { 51 | let files_url = ANTHROPIC_FILES_API_URL.to_string(); 52 | 53 | let mime_type = get_mime_type(file_name) 54 | .and_then(|mime_type| { 55 | if mime_type != "application/pdf" { 56 | None 57 | } else { 58 | Some(mime_type) 59 | } 60 | }) 61 | .ok_or_else(|| anyhow!("Only PDF files (application/pdf) are supported"))?; 62 | 63 | let form = multipart::Form::new().part( 64 | "file", 65 | multipart::Part::bytes(file_bytes) 66 | .file_name(file_name.to_string()) 67 | .mime_str(mime_type) 68 | .context("Failed to set MIME type")?, 69 | ); 70 | 71 | //Make the API call 72 | let client = Client::new(); 73 | 74 | // Build request with appropriate headers 75 | let response = client 76 | .post(files_url) 77 | // Anthropic-specific way of passing API key 78 | .header("x-api-key", self.api_key.clone()) 79 | // Required as per documentation 80 | .header( 81 | "anthropic-version", 82 | AnthropicApiEndpoints::messages_default().version(), 83 | ) 84 | .header( 85 | "anthropic-beta", 86 | AnthropicApiEndpoints::files_default().version(), 87 | ) 88 | // Specify the mime type of the file 89 
| .multipart(form) 90 | .send() 91 | .await?; 92 | 93 | let response_status = response.status(); 94 | let response_text = response.text().await?; 95 | 96 | if self.debug { 97 | info!( 98 | "[debug] Anthropic Files status API response: [{}] {:#?}", 99 | &response_status, &response_text 100 | ); 101 | } 102 | 103 | // Deserialize the string response into the AnthropicFileResp object to confirm if there were any errors 104 | let response_deser: AnthropicFileResp = 105 | serde_json::from_str(&response_text).map_err(|error| { 106 | let error = AllmsError { 107 | crate_name: "allms".to_string(), 108 | module: "files::anthropic".to_string(), 109 | error_message: format!( 110 | "Anthropic Files API response serialization error: {}", 111 | error 112 | ), 113 | error_detail: response_text, 114 | }; 115 | error!("{:?}", error); 116 | anyhow!("{:?}", error) 117 | })?; 118 | 119 | self.id = Some(response_deser.id); 120 | 121 | Ok(self) 122 | } 123 | 124 | /// Delete a file from the LLM provider 125 | /// 126 | /// # Returns 127 | /// * `Result<()>` - Success or error 128 | async fn delete(&self) -> Result<()> { 129 | let files_url = if let Some(id) = self.id.as_ref() { 130 | format!("{}/{}", &*ANTHROPIC_FILES_API_URL, id) 131 | } else { 132 | return Err(anyhow!("File ID is required to delete a file")); 133 | }; 134 | 135 | //Make the API call 136 | let client = Client::new(); 137 | 138 | // Build request with appropriate headers 139 | let response = client 140 | .delete(files_url) 141 | // Anthropic-specific way of passing API key 142 | .header("x-api-key", self.api_key.clone()) 143 | // Required as per documentation 144 | .header( 145 | "anthropic-version", 146 | AnthropicApiEndpoints::messages_default().version(), 147 | ) 148 | .header( 149 | "anthropic-beta", 150 | AnthropicApiEndpoints::files_default().version(), 151 | ) 152 | .send() 153 | .await?; 154 | 155 | let response_status = response.status(); 156 | let response_text = response.text().await?; 157 | 158 | if self.debug { 159 | info!( 160 | "[debug] Anthropic Files status API response: [{}] {:#?}", 161 | &response_status, &response_text 162 | ); 163 | } 164 | 165 | //Check if the file was successfully deleted 166 | serde_json::from_str::(&response_text) 167 | .map_err(|error| { 168 | let error = AllmsError { 169 | crate_name: "allms".to_string(), 170 | module: "files::anthropic".to_string(), 171 | error_message: format!( 172 | "Anthropic Files Delete API response serialization error: {}", 173 | error 174 | ), 175 | error_detail: response_text, 176 | }; 177 | error!("{:?}", error); 178 | anyhow!("{:?}", error) 179 | }) 180 | .map(|response| match response.result_type { 181 | AnthropicDeleteResultType::FileDeleted => (), 182 | }) 183 | } 184 | 185 | /// Get the file ID if available 186 | /// 187 | /// # Returns 188 | /// * `Option<&String>` - The file ID if it exists 189 | fn get_id(&self) -> Option<&String> { 190 | self.id.as_ref() 191 | } 192 | 193 | /// Check if debug mode is enabled 194 | /// 195 | /// # Returns 196 | /// * `bool` - True if debug mode is enabled 197 | fn is_debug(&self) -> bool { 198 | self.debug 199 | } 200 | } 201 | 202 | #[derive(Deserialize, Serialize, Debug, Clone)] 203 | struct AnthropicFileResp { 204 | id: String, 205 | created_at: String, 206 | filename: String, 207 | mime_type: String, 208 | size_bytes: usize, 209 | #[serde(rename = "type")] 210 | file_type: AnthropicFileType, 211 | downloadable: bool, 212 | } 213 | 214 | #[derive(Deserialize, Serialize, Debug, Clone, Default)] 215 | #[serde(rename_all = 
"snake_case")] 216 | enum AnthropicFileType { 217 | #[default] 218 | File, 219 | } 220 | 221 | #[derive(Deserialize, Serialize, Debug, Clone)] 222 | struct AnthropicFileDeleteResp { 223 | id: String, 224 | #[serde(rename = "type")] 225 | result_type: AnthropicDeleteResultType, 226 | } 227 | 228 | #[derive(Deserialize, Serialize, Debug, Clone, Default)] 229 | #[serde(rename_all = "snake_case")] 230 | enum AnthropicDeleteResultType { 231 | #[default] 232 | FileDeleted, 233 | } 234 | -------------------------------------------------------------------------------- /src/llm_models/perplexity.rs: -------------------------------------------------------------------------------- 1 | #![allow(deprecated)] 2 | 3 | use anyhow::{anyhow, Result}; 4 | use async_trait::async_trait; 5 | use log::info; 6 | use reqwest::{header, Client}; 7 | use serde::{Deserialize, Serialize}; 8 | use serde_json::{json, Value}; 9 | 10 | use crate::completions::ThinkingLevel; 11 | use crate::constants::PERPLEXITY_API_URL; 12 | use crate::domain::{PerplexityAPICompletionsResponse, RateLimit}; 13 | use crate::llm_models::{LLMModel, LLMTools}; 14 | use crate::utils::{map_to_range_f32, remove_json_wrapper, remove_think_reasoner_wrapper}; 15 | 16 | // Perplexity API Docs: https://docs.perplexity.ai/api-reference/chat-completions 17 | #[derive(Deserialize, Serialize, Debug, Clone, Eq, PartialEq)] 18 | pub enum PerplexityModels { 19 | SonarPro, 20 | Sonar, 21 | SonarReasoning, 22 | // Legacy models 23 | #[deprecated( 24 | since = "0.12.0", 25 | note = "`Llama3_1SonarSmall` is deprecated starting February 22, 2025, please use `Sonar` or `SonarPro` instead." 26 | )] 27 | Llama3_1SonarSmall, 28 | #[deprecated( 29 | since = "0.12.0", 30 | note = "`Llama3_1SonarLarge` is deprecated starting February 22, 2025, please use `Sonar` or `SonarPro` instead." 31 | )] 32 | Llama3_1SonarLarge, 33 | #[deprecated( 34 | since = "0.12.0", 35 | note = "`Llama3_1SonarHuge` is deprecated starting February 22, 2025, please use `Sonar` or `SonarPro` instead." 
36 | )] 37 | Llama3_1SonarHuge, 38 | } 39 | 40 | #[async_trait(?Send)] 41 | impl LLMModel for PerplexityModels { 42 | fn as_str(&self) -> &str { 43 | match self { 44 | PerplexityModels::SonarPro => "sonar-pro", 45 | PerplexityModels::Sonar => "sonar", 46 | PerplexityModels::SonarReasoning => "sonar-reasoning", 47 | // Legacy models 48 | #[allow(deprecated)] 49 | PerplexityModels::Llama3_1SonarSmall => "llama-3.1-sonar-small-128k-online", 50 | #[allow(deprecated)] 51 | PerplexityModels::Llama3_1SonarLarge => "llama-3.1-sonar-large-128k-online", 52 | #[allow(deprecated)] 53 | PerplexityModels::Llama3_1SonarHuge => "llama-3.1-sonar-huge-128k-online", 54 | } 55 | } 56 | 57 | fn try_from_str(name: &str) -> Option { 58 | match name.to_lowercase().as_str() { 59 | "sonar-pro" => Some(PerplexityModels::SonarPro), 60 | "sonar" => Some(PerplexityModels::Sonar), 61 | "sonar-reasoning" => Some(PerplexityModels::SonarReasoning), 62 | // Legacy models 63 | #[allow(deprecated)] 64 | "llama-3.1-sonar-small-128k-online" => Some(PerplexityModels::Llama3_1SonarSmall), 65 | #[allow(deprecated)] 66 | "llama-3.1-sonar-large-128k-online" => Some(PerplexityModels::Llama3_1SonarLarge), 67 | #[allow(deprecated)] 68 | "llama-3.1-sonar-huge-128k-online" => Some(PerplexityModels::Llama3_1SonarHuge), 69 | _ => None, 70 | } 71 | } 72 | 73 | // https://docs.perplexity.ai/guides/model-cards 74 | fn default_max_tokens(&self) -> usize { 75 | match self { 76 | // Docs: https://docs.perplexity.ai/guides/model-cards 77 | // FYI: sonar-pro has a max output token limit of 8k 78 | PerplexityModels::SonarPro => 200_000, 79 | PerplexityModels::Sonar => 127_072, 80 | PerplexityModels::SonarReasoning => 127_072, 81 | // Legacy models 82 | #[allow(deprecated)] 83 | PerplexityModels::Llama3_1SonarSmall 84 | | PerplexityModels::Llama3_1SonarLarge 85 | | PerplexityModels::Llama3_1SonarHuge => 127_072, 86 | } 87 | } 88 | 89 | fn get_endpoint(&self) -> String { 90 | PERPLEXITY_API_URL.to_string() 91 | } 92 | 93 | //This method prepares the body of the API call for different models 94 | fn get_body( 95 | &self, 96 | instructions: &str, 97 | json_schema: &Value, 98 | function_call: bool, 99 | // The total number of tokens requested in max_tokens plus the number of prompt tokens sent in messages must not exceed the context window token limit of model requested. 100 | // If left unspecified, then the model will generate tokens until either it reaches its stop token or the end of its context window. 101 | _max_tokens: &usize, 102 | temperature: &f32, 103 | _tools: Option<&[LLMTools]>, 104 | _thinking_level: Option<&ThinkingLevel>, 105 | ) -> serde_json::Value { 106 | //Prepare the 'messages' part of the body 107 | let base_instructions = self.get_base_instructions(Some(function_call)); 108 | let system_message = json!({ 109 | "role": "system", 110 | "content": base_instructions, 111 | }); 112 | let user_message = json!({ 113 | "role": "user", 114 | "content": format!( 115 | " 116 | {instructions} 117 | 118 | 119 | {json_schema} 120 | " 121 | ), 122 | }); 123 | json!({ 124 | "model": self.as_str(), 125 | "temperature": temperature, 126 | "messages": vec![ 127 | system_message, 128 | user_message, 129 | ], 130 | }) 131 | } 132 | /// 133 | /// This function leverages Perplexity API to perform any query as per the provided body. 134 | /// 135 | /// It returns a String the Response object that needs to be parsed based on the self.model. 
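/// Perplexity-specific notes: `get_body` intentionally leaves `max_tokens` unset (generation
/// is bounded by the model's stop token or context window), and for `SonarReasoning` the
/// reasoning ("think") wrapper is stripped from the reply by `sanitize_json_response`.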
136 | /// 137 | async fn call_api( 138 | &self, 139 | api_key: &str, 140 | _version: Option<String>, 141 | body: &serde_json::Value, 142 | debug: bool, 143 | _tools: Option<&[LLMTools]>, 144 | ) -> Result<String> { 145 | //Get the API url 146 | let model_url = self.get_endpoint(); 147 | 148 | //Make the API call 149 | let client = Client::new(); 150 | 151 | //Send request 152 | let response = client 153 | .post(model_url) 154 | .header(header::CONTENT_TYPE, "application/json") 155 | .bearer_auth(api_key) 156 | .json(&body) 157 | .send() 158 | .await?; 159 | 160 | let response_status = response.status(); 161 | let response_text = response.text().await?; 162 | 163 | if debug { 164 | info!( 165 | "[debug] Perplexity API response: [{}] {:#?}", 166 | &response_status, &response_text 167 | ); 168 | } 169 | 170 | Ok(response_text) 171 | } 172 | 173 | //This method attempts to convert the provided API response text into the expected struct and extracts the data from the response 174 | fn get_data(&self, response_text: &str, _function_call: bool) -> Result<String> { 175 | //Convert API response to struct representing expected response format 176 | let completions_response: PerplexityAPICompletionsResponse = 177 | serde_json::from_str(response_text)?; 178 | 179 | //Parse the response and return the assistant content 180 | completions_response 181 | .choices 182 | .iter() 183 | .filter_map(|choice| choice.message.as_ref()) 184 | .find(|&message| message.role == Some("assistant".to_string())) 185 | .and_then(|message| { 186 | message 187 | .content 188 | .as_ref() 189 | .map(|content| self.sanitize_json_response(content)) 190 | }) 191 | .ok_or_else(|| anyhow!("Assistant role content not found")) 192 | } 193 | 194 | /// This function sanitizes the text response from Perplexity models to clean up common formatting issues. 195 | /// Currently the function checks: 196 | /// * ```json{}``` wrapper around response 197 | /// * <think></think> wrapper (for SonarReasoning model only) 198 | fn sanitize_json_response(&self, json_response: &str) -> String { 199 | let no_json_text = remove_json_wrapper(json_response); 200 | if self == &PerplexityModels::SonarReasoning { 201 | remove_think_reasoner_wrapper(&no_json_text) 202 | } else { 203 | no_json_text 204 | } 205 | } 206 | 207 | // This function allows checking the rate limits for different models 208 | fn get_rate_limit(&self) -> RateLimit { 209 | //Perplexity documentation: https://docs.perplexity.ai/guides/rate-limits 210 | RateLimit { 211 | tpm: 50 * 127_072, // 50 requests per minute with max 127,072 context length 212 | rpm: 50, // 50 requests per minute 213 | } 214 | } 215 | 216 | // Accepts a [0-100] percentage range and returns the target temperature based on model ranges 217 | fn get_normalized_temperature(&self, relative_temp: u32) -> f32 { 218 | // Temperature range documentation: https://docs.perplexity.ai/api-reference/chat-completions 219 | // "The amount of randomness in the response, valued between 0 *inclusive* and 2 *exclusive*."
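// Worked example (hedged, assuming `map_to_range_f32` maps [0, 100] linearly onto [min, max]):
// a `relative_temp` of 0 yields 0.0, 50 yields roughly 1.0, and 100 yields just under 2.0.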
220 | let min = 0.0f32; 221 | let max = 1.99999f32; 222 | map_to_range_f32(min, max, relative_temp) 223 | } 224 | } 225 | -------------------------------------------------------------------------------- /src/completions.rs: -------------------------------------------------------------------------------- 1 | use anyhow::{anyhow, Result}; 2 | use log::{error, info, warn}; 3 | use schemars::JsonSchema; 4 | use serde::{de::DeserializeOwned, Deserialize, Serialize}; 5 | 6 | use crate::domain::{AllmsError, OpenAIDataResponse}; 7 | use crate::llm_models::{LLMModel, LLMTools}; 8 | use crate::utils::{get_tokenizer, get_type_schema}; 9 | 10 | /// Completions APIs take a list of messages as input and return a model-generated message as output. 11 | /// Although the Completions format is designed to make multi-turn conversations easy, 12 | /// it’s just as useful for single-turn tasks without any conversation. 13 | pub struct Completions<T: LLMModel> { 14 | model: T, 15 | //For prompt & response 16 | max_tokens: usize, 17 | temperature: f32, 18 | input_json: Option<String>, 19 | debug: bool, 20 | function_call: bool, 21 | api_key: String, 22 | version: Option<String>, 23 | tools: Option<Vec<LLMTools>>, 24 | thinking_level: Option<ThinkingLevel>, 25 | } 26 | 27 | impl<T: LLMModel> Completions<T> { 28 | /// Constructor for the Completions API 29 | pub fn new( 30 | model: T, 31 | api_key: &str, 32 | max_tokens: Option<usize>, 33 | temperature: Option<u32>, 34 | ) -> Self { 35 | let temperature = temperature 36 | .map(|temp| model.get_normalized_temperature(temp)) 37 | .unwrap_or(model.get_default_temperature()); 38 | Completions { 39 | //If no max tokens limit is provided we default to max allowed for the model 40 | max_tokens: max_tokens.unwrap_or_else(|| model.default_max_tokens()), 41 | function_call: model.function_call_default(), 42 | model, 43 | temperature, 44 | input_json: None, 45 | debug: false, 46 | api_key: api_key.to_string(), 47 | version: None, 48 | tools: None, 49 | thinking_level: None, 50 | } 51 | } 52 | 53 | /// 54 | /// This function turns on debug mode, which will log the prompt (via `info!`) when executing it. 55 | /// 56 | pub fn debug(mut self) -> Self { 57 | self.debug = true; 58 | self 59 | } 60 | 61 | /// 62 | /// This function turns on/off function calling mode when interacting with the LLM API.
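/// A hedged sketch of builder chaining (`model` is any `LLMModel` implementation and
/// `api_key` the matching provider key; both are placeholders here):
/// ```rust,ignore
/// let completions = Completions::new(model, &api_key, None, None)
///     .debug()
///     .function_calling(false);
/// ```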
63 | /// 64 | pub fn function_calling(mut self, function_call: bool) -> Self { 65 | self.function_call = function_call; 66 | self 67 | } 68 | 69 | /// 70 | /// This method can be used to define the model temperature used for the completion 71 | /// This method accepts a % target [0-100] of the acceptable range for the model 72 | /// 73 | pub fn temperature(mut self, temp_target: u32) -> Self { 74 | self.temperature = self.model.get_normalized_temperature(temp_target); 75 | self 76 | } 77 | 78 | /// 79 | /// This method can be used to define the model temperature used for the completion 80 | /// Using this method the temperature can be set directly without any validation of the range accepted by the model 81 | /// For a range-safe implementation please consider using the `Completions::temperature` method 82 | /// 83 | pub fn temperature_unchecked(mut self, temp: f32) -> Self { 84 | self.temperature = temp; 85 | self 86 | } 87 | 88 | /// 89 | /// This method can be used to set the version of Completions API to be used 90 | /// This is currently used for OpenAI models which can be run on OpenAI API or Azure API 91 | /// 92 | pub fn version(mut self, version: &str) -> Self { 93 | // TODO: We should use the model trait to check which versions are allowed 94 | self.version = Some(version.to_string()); 95 | self 96 | } 97 | 98 | /// 99 | /// This method can be used to inform the model to use a tool. 100 | /// Different models support different tool implementations. 101 | /// 102 | pub fn add_tool(mut self, tool: LLMTools) -> Self { 103 | self.tools = Some(match self.tools { 104 | Some(mut tools) => { 105 | tools.push(tool); 106 | tools 107 | } 108 | None => vec![tool], 109 | }); 110 | self 111 | } 112 | 113 | /// 114 | /// This method can be used to set the thinking level for the model 115 | /// This is currently used for Gemini 3 models 116 | /// 117 | pub fn thinking_level(mut self, thinking_level: ThinkingLevel) -> Self { 118 | self.thinking_level = Some(thinking_level); 119 | self 120 | } 121 | 122 | /// 123 | /// This method can be used to provide values that will be used as context for the prompt. 124 | /// Using this function you can provide multiple input values by calling it multiple times. New values will be appended and wrapped in a tag named after `input_name` 125 | /// It accepts any instance that implements the Serialize trait. 126 | /// 127 | pub fn set_context<U: Serialize>(mut self, input_name: &str, input_data: &U) -> Result<Self> { 128 | let input_json = if let Ok(json) = serde_json::to_string(&input_data) { 129 | json 130 | } else { 131 | return Err(anyhow!("Unable to serialize provided input data.")); 132 | }; 133 | let line_break = match self.input_json { 134 | Some(_) => "\n\n".to_string(), 135 | None => "".to_string(), 136 | }; 137 | let new_json = format!( 138 | "{}{}<{}>{}</{}>", 139 | self.input_json.unwrap_or_default(), 140 | line_break, 141 | input_name, 142 | input_json, 143 | input_name, 144 | ); 145 | self.input_json = Some(new_json); 146 | Ok(self) 147 | } 148 | 149 | /// 150 | /// This method is used to check how many tokens would most likely remain for the response 151 | /// This is accomplished by estimating number of tokens needed for system/base instructions, user prompt, and function components including schema definition.
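/// A usage sketch (hedged; `Answer` is a hypothetical response type deriving `JsonSchema`,
/// which the schema extraction requires):
/// ```rust,ignore
/// #[derive(schemars::JsonSchema)]
/// struct Answer { text: String }
///
/// let estimated_prompt_tokens = completions.check_prompt_tokens::<Answer>("Summarize the input")?;
/// ```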
152 | /// 153 | pub fn check_prompt_tokens( 154 | &self, 155 | instructions: &str, 156 | ) -> Result<usize> { 157 | //Output schema is extracted from the type parameter 158 | let schema = get_type_schema::()?; 159 | 160 | let context_text = self 161 | .input_json 162 | .as_ref() 163 | .map(|context| format!("\n\n{}", &context)) 164 | .unwrap_or_default(); 165 | 166 | let prompt = format!( 167 | "Instructions: 168 | {instructions}{context_text} 169 | 170 | Respond ONLY with the data portion of a valid Json object. No schema definition required. No other words.", 171 | ); 172 | 173 | let full_prompt = format!( 174 | "{}{}{}", 175 | //Base (system) instructions 176 | self.model.get_base_instructions(Some(self.function_call)), 177 | //Instructions & context data 178 | prompt, 179 | //Output schema 180 | schema 181 | ); 182 | 183 | //Check how many tokens are required for prompt 184 | let bpe = get_tokenizer(&self.model)?; 185 | let prompt_tokens = bpe.encode_with_special_tokens(&full_prompt).len(); 186 | 187 | //Assuming another 5% overhead for json formatting 188 | Ok((prompt_tokens as f64 * 1.05) as usize) 189 | } 190 | 191 | /// 192 | /// This method is used to submit a prompt to the selected LLM provider and process the response. 193 | /// When calling the function you need to specify the type parameter as the response will match the schema of that type. 194 | /// The prompt in this function is written in a way to instruct the model to behave like a computer function that calculates an output based on provided input and its language model. 195 | /// 196 | pub async fn get_answer( 197 | self, 198 | instructions: &str, 199 | ) -> Result { 200 | //Output schema is extracted from the type parameter 201 | let schema = get_type_schema::()?; 202 | let json_schema = serde_json::from_str(&schema)?; 203 | 204 | let context_text = self 205 | .input_json 206 | .as_ref() 207 | .map(|context| format!("\n\n{}", &context)) 208 | .unwrap_or_default(); 209 | 210 | let prompt = format!("{instructions}{context_text}"); 211 | 212 | //Validate how many tokens remain for the response (and how many are used for prompt) 213 | let prompt_tokens = self 214 | .check_prompt_tokens::(instructions) 215 | .unwrap_or_default(); 216 | 217 | if prompt_tokens >= self.max_tokens { 218 | return Err(anyhow!( 219 | "The provided prompt requires more tokens than allocated." 220 | )); 221 | } 222 | let response_tokens = self.max_tokens - prompt_tokens; 223 | 224 | //Throw a warning if after processing the prompt there might not be enough tokens for the response 225 | //This assumes response will be similar size as input.
Because this is not always correct this is a warning and not an error 226 | if prompt_tokens * 2 >= self.max_tokens { 227 | warn!( 228 | "{} tokens remaining for response: {} allocated, {} used for prompt", 229 | response_tokens, self.max_tokens, prompt_tokens, 230 | ); 231 | }; 232 | 233 | //Build the API body depending on the used model 234 | let model_body = self.model.get_version_body( 235 | &prompt, 236 | &json_schema, 237 | self.function_call, 238 | &response_tokens, 239 | &self.temperature, 240 | self.version.clone(), 241 | self.tools.as_deref(), 242 | self.thinking_level.as_ref(), 243 | ); 244 | 245 | //Display debug info if requested 246 | if self.debug { 247 | info!("[debug] Model body: {:#?}", model_body); 248 | info!( 249 | "[debug] Prompt accounts for approx {} tokens, leaving {} tokens for answer.", 250 | prompt_tokens, response_tokens, 251 | ); 252 | } 253 | 254 | let response_text = self 255 | .model 256 | .call_api( 257 | &self.api_key, 258 | self.version.clone(), 259 | &model_body, 260 | self.debug, 261 | self.tools.as_deref(), 262 | ) 263 | .await?; 264 | 265 | //Extract data from the returned response text based on the used model 266 | let response_string = self 267 | .model 268 | .get_version_data(&response_text, self.function_call, self.version) 269 | .map_err(|error| { 270 | let error = AllmsError { 271 | crate_name: "allms".to_string(), 272 | module: format!("assistants::completions::{}", self.model.as_str()), 273 | error_message: format!( 274 | "Completions API response serialization error: {}", 275 | error 276 | ), 277 | error_detail: response_text.to_string(), 278 | }; 279 | error!("{:?}", error); 280 | anyhow!("{:?}", error) 281 | })?; 282 | 283 | if self.debug { 284 | info!("[debug] Completions response data: {}", response_string); 285 | } 286 | //Deserialize the string response into the expected output type 287 | let response_deser: anyhow::Result = 288 | serde_json::from_str(&response_string).map_err(|error| { 289 | let error = AllmsError { 290 | crate_name: "allms".to_string(), 291 | module: format!("assistants::completions::{}", self.model.as_str()), 292 | error_message: format!( 293 | "Completions API response serialization error: {}", 294 | error 295 | ), 296 | error_detail: response_string, 297 | }; 298 | error!("{:?}", error); 299 | anyhow!("{:?}", error) 300 | }); 301 | // Sometimes openai responds with a json object that has a data property. If that's the case, we need to extract the data property and deserialize that. 302 | // TODO: This is OpenAI specific and should be implemented within the model. 
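// Illustration (hedged): a wrapped payload such as {"data": { ...expected fields... }} fails the
// direct deserialization above; the fallback below parses the response as `OpenAIDataResponse`
// and returns its `data` field instead.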
303 | if let Err(_e) = response_deser { 304 | let response_deser: OpenAIDataResponse = serde_json::from_str(&response_text) 305 | .map_err(|error| { 306 | let error = AllmsError { 307 | crate_name: "allms".to_string(), 308 | module: format!("assistants::completions::{}", self.model.as_str()), 309 | error_message: format!( 310 | "Completions API response serialization error: {}", 311 | error 312 | ), 313 | error_detail: response_text, 314 | }; 315 | error!("{:?}", error); 316 | anyhow!("{:?}", error) 317 | })?; 318 | Ok(response_deser.data) 319 | } else { 320 | response_deser 321 | } 322 | } 323 | } 324 | 325 | #[derive(Debug, Serialize, Deserialize, Default, PartialEq, Eq)] 326 | #[serde(rename_all = "snake_case")] 327 | pub enum ThinkingLevel { 328 | Low, 329 | #[default] 330 | High, 331 | } 332 | 333 | impl ThinkingLevel { 334 | pub fn as_str(&self) -> &str { 335 | match self { 336 | ThinkingLevel::Low => "low", 337 | ThinkingLevel::High => "high", 338 | } 339 | } 340 | 341 | pub fn try_from_str(s: &str) -> Option { 342 | match s { 343 | "low" => Some(ThinkingLevel::Low), 344 | "high" => Some(ThinkingLevel::High), 345 | _ => None, 346 | } 347 | } 348 | } 349 | -------------------------------------------------------------------------------- /src/assistants/openai/openai_vector_store.rs: -------------------------------------------------------------------------------- 1 | use anyhow::{anyhow, Result}; 2 | use log::{error, info}; 3 | use reqwest::Client; 4 | use serde::{Deserialize, Serialize}; 5 | use serde_json::json; 6 | 7 | use crate::assistants::{OpenAIAssistantResource, OpenAIAssistantVersion}; 8 | use crate::domain::AllmsError; 9 | 10 | #[derive(Deserialize, Serialize, Debug, Clone)] 11 | pub struct OpenAIVectorStore { 12 | pub id: Option, 13 | pub name: String, 14 | api_key: String, 15 | status: OpenAIVectorStoreStatus, 16 | debug: bool, 17 | version: OpenAIAssistantVersion, 18 | } 19 | 20 | impl OpenAIVectorStore { 21 | /// Constructor 22 | pub fn new(id: Option, name: &str, api_key: &str) -> Self { 23 | OpenAIVectorStore { 24 | id, 25 | name: name.to_string(), 26 | api_key: api_key.to_string(), 27 | status: OpenAIVectorStoreStatus::InProgress, 28 | debug: false, 29 | version: OpenAIAssistantVersion::V2, 30 | } 31 | } 32 | 33 | /// 34 | /// This method can be used to set turn on/off the debug mode 35 | /// 36 | pub fn debug(mut self) -> Self { 37 | self.debug = !self.debug; 38 | self 39 | } 40 | 41 | /// 42 | /// This method can be used to set the version of Assistants API Beta 43 | /// Current default is V2 44 | /// 45 | pub fn version(mut self, version: OpenAIAssistantVersion) -> Self { 46 | // VectorStores endpoint is only available for v2 so if v1 is selected we overwrite 47 | let version = match version { 48 | OpenAIAssistantVersion::V1 => OpenAIAssistantVersion::V2, 49 | _ => version, 50 | }; 51 | self.version = version; 52 | self 53 | } 54 | 55 | /* 56 | * This function creates a new Vector Store and updates the ID of the struct 57 | */ 58 | async fn create(&mut self, file_ids: Option>) -> Result<()> { 59 | let vector_store_url = self 60 | .version 61 | .get_endpoint(&OpenAIAssistantResource::VectorStores); 62 | 63 | //Make the API call 64 | let client = Client::new(); 65 | 66 | //Get the version-specific header 67 | let version_headers = self.version.get_headers(&self.api_key); 68 | 69 | let mut body = json!({ 70 | "name": self.name.clone(), 71 | }); 72 | if let Some(ids) = file_ids { 73 | body["file_ids"] = json!(ids.to_vec()); 74 | } 75 | 76 | let response = client 77 | 
.post(vector_store_url) 78 | .headers(version_headers) 79 | .json(&body) 80 | .send() 81 | .await?; 82 | 83 | let response_status = response.status(); 84 | let response_text = response.text().await?; 85 | 86 | if self.debug { 87 | info!( 88 | "[allms][OpenAI][VectorStore][debug] VectorStore Create API response: [{}] {:#?}", 89 | &response_status, &response_text 90 | ); 91 | } 92 | 93 | //Deserialize the string response into the Assistant object 94 | let response_deser: OpenAIVectorStoreResp = 95 | serde_json::from_str(&response_text).map_err(|error| { 96 | let error = AllmsError { 97 | crate_name: "allms".to_string(), 98 | module: "assistants::openai_vector_store".to_string(), 99 | error_message: format!( 100 | "VectorStore Create API response serialization error: {}", 101 | error 102 | ), 103 | error_detail: response_text, 104 | }; 105 | error!("{:?}", error); 106 | anyhow!("{:?}", error) 107 | })?; 108 | 109 | //Add correct ID & status to self 110 | self.id = Some(response_deser.id); 111 | self.status = response_deser.status; 112 | 113 | Ok(()) 114 | } 115 | 116 | /// 117 | /// This method uploads files to a Vector Store. If no ID was provided the method first creates the Vector Store 118 | /// 119 | pub async fn upload(&mut self, file_ids: &[String]) -> Result { 120 | // If the Vector Store was not yet created we do that first 121 | if self.id.is_none() { 122 | self.create(Some(file_ids.to_vec())).await?; 123 | } else { 124 | // If working with existing Vector Store we simply upload files 125 | self.assign_to_store(file_ids).await?; 126 | } 127 | Ok(self.clone()) 128 | } 129 | 130 | /* 131 | * This function assigns OpenAI Files to an existing Vector Store 132 | */ 133 | async fn assign_to_store(&self, file_ids: &[String]) -> Result<()> { 134 | // The function requires an ID of an existing vector store 135 | let vs_id = if let Some(id) = &self.id { 136 | id 137 | } else { 138 | return Err(anyhow!( 139 | "[allms][OpenAI][VectorStore][debug] Unable to assign files. No ID provided." 140 | )); 141 | }; 142 | 143 | // Construct the API url 144 | let vector_store_resource = OpenAIAssistantResource::VectorStoreFileBatches { 145 | vector_store_id: vs_id.to_string(), 146 | }; 147 | let url = self.version.get_endpoint(&vector_store_resource); 148 | 149 | //Get the version-specific header 150 | let version_headers = self.version.get_headers(&self.api_key); 151 | 152 | //Make the API call 153 | let client = Client::new(); 154 | 155 | let body = json!({ 156 | "file_ids": file_ids.to_vec(), 157 | }); 158 | 159 | let response = client 160 | .post(&url) 161 | .headers(version_headers) 162 | .json(&body) 163 | .send() 164 | .await?; 165 | 166 | let response_status = response.status(); 167 | let response_text = response.text().await?; 168 | 169 | if self.debug { 170 | info!( 171 | "[allms][OpenAI][VectorStore][debug] VectorStore Batch Upload API response: [{}] {:#?}", 172 | &response_status, &response_text 173 | ); 174 | } 175 | 176 | //Deserialize & validate the string response 177 | serde_json::from_str::(&response_text) 178 | .map_err(|error| { 179 | let error = AllmsError { 180 | crate_name: "allms".to_string(), 181 | module: "assistants::openai_vector_store".to_string(), 182 | error_message: format!( 183 | "VectorStore Batch Upload API response serialization error: {}", 184 | error 185 | ), 186 | error_detail: response_text, 187 | }; 188 | error!("{:?}", error); 189 | anyhow!("{:?}", error) 190 | }) 191 | .map(|_| Ok(()))? 
192 | } 193 | 194 | /// 195 | /// This method checks the status of a Vector Store 196 | /// 197 | pub async fn status(&self) -> Result { 198 | // Requires an ID of an existing vector store 199 | let vs_id = if let Some(id) = &self.id { 200 | id 201 | } else { 202 | return Err(anyhow!( 203 | "[allms][OpenAI][VectorStore][debug] Unable to check status. No ID provided." 204 | )); 205 | }; 206 | 207 | // Construct the API url 208 | let vector_store_resource = OpenAIAssistantResource::VectorStore { 209 | vector_store_id: vs_id.to_string(), 210 | }; 211 | let url = self.version.get_endpoint(&vector_store_resource); 212 | 213 | //Get the version-specific header 214 | let version_headers = self.version.get_headers(&self.api_key); 215 | 216 | //Make the API call 217 | let client = Client::new(); 218 | 219 | let response = client.get(&url).headers(version_headers).send().await?; 220 | 221 | let response_status = response.status(); 222 | let response_text = response.text().await?; 223 | 224 | if self.debug { 225 | info!( 226 | "[allms][OpenAI][VectorStore][debug] VectorStore Status API response: [{}] {:#?}", 227 | &response_status, &response_text 228 | ); 229 | } 230 | 231 | //Deserialize & validate the string response 232 | let response_deser: OpenAIVectorStoreResp = 233 | serde_json::from_str(&response_text).map_err(|error| { 234 | let error = AllmsError { 235 | crate_name: "allms".to_string(), 236 | module: "assistants::openai_vector_store".to_string(), 237 | error_message: format!( 238 | "VectorStore Status API response serialization error: {}", 239 | error 240 | ), 241 | error_detail: response_text, 242 | }; 243 | error!("{:?}", error); 244 | anyhow!("{:?}", error) 245 | })?; 246 | Ok(response_deser.status) 247 | } 248 | 249 | /// 250 | /// This method checks the counts of files added to a Vector Store and their statuses 251 | /// 252 | pub async fn file_count(&self) -> Result { 253 | // Requires an ID of an existing vector store 254 | let vs_id = if let Some(id) = &self.id { 255 | id 256 | } else { 257 | return Err(anyhow!( 258 | "[allms][OpenAI][VectorStore][debug] Unable to check status. No ID provided." 
259 | )); 260 | }; 261 | 262 | // Construct the API url 263 | let vector_store_resource = OpenAIAssistantResource::VectorStore { 264 | vector_store_id: vs_id.to_string(), 265 | }; 266 | let url = self.version.get_endpoint(&vector_store_resource); 267 | 268 | //Get the version-specific header 269 | let version_headers = self.version.get_headers(&self.api_key); 270 | 271 | //Make the API call 272 | let client = Client::new(); 273 | 274 | let response = client.get(&url).headers(version_headers).send().await?; 275 | 276 | let response_status = response.status(); 277 | let response_text = response.text().await?; 278 | 279 | if self.debug { 280 | info!( 281 | "[allms][OpenAI][VectorStore][debug] VectorStore Status API response: [{}] {:#?}", 282 | &response_status, &response_text 283 | ); 284 | } 285 | 286 | //Deserialize & validate the string response 287 | let response_deser: OpenAIVectorStoreResp = 288 | serde_json::from_str(&response_text).map_err(|error| { 289 | let error = AllmsError { 290 | crate_name: "allms".to_string(), 291 | module: "assistants::openai_vector_store".to_string(), 292 | error_message: format!( 293 | "VectorStore Status API response serialization error: {}", 294 | error 295 | ), 296 | error_detail: response_text, 297 | }; 298 | error!("{:?}", error); 299 | anyhow!("{:?}", error) 300 | })?; 301 | Ok(response_deser.file_counts) 302 | } 303 | 304 | /// 305 | /// This method can be used to delete a Vector Store 306 | /// 307 | pub async fn delete(&self) -> Result<()> { 308 | // Requires an ID of an existing vector store 309 | let vs_id = if let Some(id) = &self.id { 310 | id 311 | } else { 312 | return Err(anyhow!( 313 | "[allms][OpenAI][VectorStore][debug] Unable to delete. No ID provided." 314 | )); 315 | }; 316 | 317 | // Construct the API url 318 | let vector_store_resource = OpenAIAssistantResource::VectorStore { 319 | vector_store_id: vs_id.to_string(), 320 | }; 321 | let url = self.version.get_endpoint(&vector_store_resource); 322 | 323 | //Get the version-specific header 324 | let version_headers = self.version.get_headers(&self.api_key); 325 | 326 | //Make the API call 327 | let client = Client::new(); 328 | 329 | let response = client.delete(&url).headers(version_headers).send().await?; 330 | 331 | let response_status = response.status(); 332 | let response_text = response.text().await?; 333 | 334 | if self.debug { 335 | info!( 336 | "[allms][OpenAI][VectorStore][debug] VectorStore Delete API response: [{}] {:#?}", 337 | &response_status, &response_text 338 | ); 339 | } 340 | 341 | //Deserialize & validate the string response 342 | serde_json::from_str::(&response_text) 343 | .map_err(|error| { 344 | let error = AllmsError { 345 | crate_name: "allms".to_string(), 346 | module: "assistants::openai_vector_store".to_string(), 347 | error_message: format!( 348 | "VectorStore Delete API response serialization error: {}", 349 | error 350 | ), 351 | error_detail: response_text, 352 | }; 353 | error!("{:?}", error); 354 | anyhow!("{:?}", error) 355 | }) 356 | .and_then(|response| match response.deleted { 357 | true => Ok(()), 358 | false => Err(anyhow!( 359 | "[OpenAIAssistant] VectorStore Delete API failed to delete the store." 
360 | )), 361 | }) 362 | } 363 | } 364 | 365 | /****************************************************************************************** 366 | * 367 | * API Response serialization / deserialization structs 368 | * 369 | ******************************************************************************************/ 370 | #[derive(Deserialize, Serialize, Debug, Clone)] 371 | struct OpenAIVectorStoreResp { 372 | id: String, 373 | name: String, 374 | status: OpenAIVectorStoreStatus, 375 | created_at: i64, 376 | expires_at: Option, 377 | last_active_at: Option, 378 | file_counts: OpenAIVectorStoreFileCounts, 379 | } 380 | 381 | #[derive(Deserialize, Serialize, Debug, Clone)] 382 | pub struct OpenAIVectorStoreFileCounts { 383 | pub in_progress: i32, 384 | pub completed: i32, 385 | pub failed: i32, 386 | pub cancelled: i32, 387 | pub total: i32, 388 | } 389 | 390 | #[derive(Deserialize, Serialize, Debug, Clone, Eq, PartialEq)] 391 | pub enum OpenAIVectorStoreStatus { 392 | #[serde(rename(deserialize = "expired", serialize = "expired"))] 393 | Expired, 394 | #[serde(rename(deserialize = "in_progress", serialize = "in_progress"))] 395 | InProgress, 396 | #[serde(rename(deserialize = "completed", serialize = "completed"))] 397 | Completed, 398 | } 399 | 400 | #[derive(Deserialize, Serialize, Debug, Clone)] 401 | struct OpenAIVectorStoreFileBatchResp { 402 | id: String, 403 | vector_store_id: String, 404 | status: OpenAIVectorStoreFileBatchStatus, 405 | created_at: i64, 406 | file_counts: OpenAIVectorStoreFileCounts, 407 | } 408 | 409 | #[derive(Deserialize, Serialize, Debug, Clone)] 410 | pub enum OpenAIVectorStoreFileBatchStatus { 411 | #[serde(rename(deserialize = "in_progress", serialize = "in_progress"))] 412 | InProgress, 413 | #[serde(rename(deserialize = "completed", serialize = "completed"))] 414 | Completed, 415 | #[serde(rename(deserialize = "cancelled", serialize = "cancelled"))] 416 | Cancelled, 417 | #[serde(rename(deserialize = "failed", serialize = "failed"))] 418 | Failed, 419 | } 420 | 421 | #[derive(Deserialize, Serialize, Debug, Clone)] 422 | struct OpenAIVectorStoreDeleteResp { 423 | id: String, 424 | deleted: bool, 425 | } 426 | -------------------------------------------------------------------------------- /src/llm_models/mistral.rs: -------------------------------------------------------------------------------- 1 | use anyhow::{anyhow, Result}; 2 | use async_trait::async_trait; 3 | use log::info; 4 | use reqwest::{header, Client}; 5 | use serde::{Deserialize, Serialize}; 6 | use serde_json::{json, Value}; 7 | 8 | use crate::completions::ThinkingLevel; 9 | use crate::constants::{MISTRAL_API_URL, MISTRAL_CONVERSATIONS_API_URL}; 10 | use crate::domain::{ 11 | MistralAPICompletionsResponse, MistralAPIConversationsChunk, 12 | MistralAPIConversationsMessageOutputContent, MistralAPIConversationsOutput, 13 | MistralAPIConversationsResponse, RateLimit, 14 | }; 15 | use crate::llm_models::{ 16 | tools::{MistralCodeInterpreterConfig, MistralWebSearchConfig}, 17 | LLMModel, LLMTools, 18 | }; 19 | use crate::utils::has_values; 20 | 21 | #[derive(Deserialize, Serialize, Debug, Clone, Eq, PartialEq)] 22 | //Mistral docs: https://docs.mistral.ai/getting-started/models/models_overview/ 23 | pub enum MistralModels { 24 | // Frontier multimodal models 25 | MistralLarge2_1, 26 | MistralMedium3_1, 27 | MistralMedium3, 28 | MistralSmall3_2, 29 | MistralSmall3_1, 30 | MistralSmall3, 31 | MistralSmall2, 32 | // Frontier reasoning models 33 | MagistralMedium1_2, 34 | MagistralMedium, 35 | 
MagistralSmall1_2, 36 | // Other frontier models 37 | Codestral2508, 38 | Codestral2, 39 | Ministral3B, 40 | Ministral8B, 41 | // Legacy models 42 | MistralLarge, 43 | MistralNemo, 44 | Mistral7B, 45 | Mixtral8x7B, 46 | Mixtral8x22B, 47 | MistralTiny, 48 | MistralSmall, 49 | MistralMedium, 50 | } 51 | 52 | #[async_trait(?Send)] 53 | impl LLMModel for MistralModels { 54 | fn as_str(&self) -> &str { 55 | match self { 56 | // Frontier multimodal models 57 | MistralModels::MistralLarge2_1 => "mistral-large-latest", 58 | MistralModels::MistralMedium3_1 => "mistral-medium-latest", 59 | MistralModels::MistralMedium3 => "mistral-medium-2505", 60 | MistralModels::MistralSmall3_2 => "mistral-small-latest", 61 | MistralModels::MistralSmall3_1 => "mistral-small-2503", 62 | MistralModels::MistralSmall3 => "mistral-small-2501", 63 | MistralModels::MistralSmall2 => "mistral-small-2407", 64 | // Frontier reasoning models 65 | MistralModels::MagistralMedium1_2 => "magistral-medium-latest", 66 | MistralModels::MagistralMedium => "magistral-medium-2506", 67 | MistralModels::MagistralSmall1_2 => "magistral-small-latest", 68 | // Other frontier models 69 | MistralModels::Codestral2508 => "codestral-2508", 70 | MistralModels::Codestral2 => "codestral-2501", 71 | MistralModels::Ministral3B => "ministral-3b-2410", 72 | MistralModels::Ministral8B => "ministral-8b-2410", 73 | // Legacy 74 | MistralModels::MistralLarge => "mistral-large-latest", 75 | MistralModels::MistralNemo => "open-mistral-nemo", 76 | MistralModels::Mistral7B => "open-mistral-7b", 77 | MistralModels::Mixtral8x7B => "open-mixtral-8x7b", 78 | MistralModels::Mixtral8x22B => "open-mixtral-8x22b", 79 | MistralModels::MistralTiny => "mistral-tiny", 80 | MistralModels::MistralSmall => "mistral-small", 81 | MistralModels::MistralMedium => "mistral-medium", 82 | } 83 | } 84 | 85 | fn try_from_str(name: &str) -> Option { 86 | match name.to_lowercase().as_str() { 87 | // Frontier multimodal models 88 | "mistral-large-latest" => Some(MistralModels::MistralLarge2_1), 89 | "mistral-large-2411" => Some(MistralModels::MistralLarge2_1), 90 | "mistral-medium-latest" => Some(MistralModels::MistralMedium3_1), 91 | "mistral-medium-2508" => Some(MistralModels::MistralMedium3_1), 92 | "mistral-medium-2505" => Some(MistralModels::MistralMedium3), 93 | "mistral-small-latest" => Some(MistralModels::MistralSmall3_2), 94 | "mistral-small-2506" => Some(MistralModels::MistralSmall3_2), 95 | "mistral-small-2503" => Some(MistralModels::MistralSmall3_1), 96 | "mistral-small-2501" => Some(MistralModels::MistralSmall3), 97 | "mistral-small-2407" => Some(MistralModels::MistralSmall2), 98 | // Frontier reasoning models 99 | "magistral-medium-latest" => Some(MistralModels::MagistralMedium1_2), 100 | "magistral-medium-2506" => Some(MistralModels::MagistralMedium), 101 | "magistral-small-latest" => Some(MistralModels::MagistralSmall1_2), 102 | // Other frontier models 103 | "codestral-latest" => Some(MistralModels::Codestral2508), 104 | "codestral-2508" => Some(MistralModels::Codestral2508), 105 | "codestral-2501" => Some(MistralModels::Codestral2), 106 | "ministral-3b-2410" => Some(MistralModels::Ministral3B), 107 | "ministral-3b-latest" => Some(MistralModels::Ministral3B), 108 | "ministral-8b-2410" => Some(MistralModels::Ministral8B), 109 | "ministral-8b-latest" => Some(MistralModels::Ministral8B), 110 | // Legacy 111 | "open-mistral-nemo" => Some(MistralModels::MistralNemo), 112 | "open-mistral-7b" => Some(MistralModels::Mistral7B), 113 | "open-mixtral-8x7b" => 
Some(MistralModels::Mixtral8x7B), 114 | "open-mixtral-8x22b" => Some(MistralModels::Mixtral8x22B), 115 | "mistral-tiny" => Some(MistralModels::MistralTiny), 116 | "mistral-small" => Some(MistralModels::MistralSmall), 117 | "mistral-medium" => Some(MistralModels::MistralMedium), 118 | _ => None, 119 | } 120 | } 121 | 122 | fn default_max_tokens(&self) -> usize { 123 | match self { 124 | // Frontier multimodal models 125 | MistralModels::MistralLarge2_1 => 128_000, 126 | MistralModels::MistralMedium3_1 => 128_000, 127 | MistralModels::MistralMedium3 => 128_000, 128 | MistralModels::MistralSmall3_2 => 128_000, 129 | MistralModels::MistralSmall3_1 => 128_000, 130 | MistralModels::MistralSmall3 => 32_000, 131 | MistralModels::MistralSmall2 => 32_000, 132 | // Frontier reasoning models 133 | MistralModels::MagistralMedium1_2 => 128_000, 134 | MistralModels::MagistralMedium => 128_000, 135 | MistralModels::MagistralSmall1_2 => 128_000, 136 | // Other frontier models 137 | MistralModels::Codestral2508 => 256_000, 138 | MistralModels::Codestral2 => 256_000, 139 | MistralModels::Ministral3B => 128_000, 140 | MistralModels::Ministral8B => 128_000, 141 | // Legacy 142 | MistralModels::MistralLarge => 128_000, 143 | MistralModels::MistralNemo => 128_000, 144 | MistralModels::Mistral7B => 32_000, 145 | MistralModels::Mixtral8x7B => 32_000, 146 | MistralModels::Mixtral8x22B => 64_000, 147 | MistralModels::MistralTiny => 32_000, 148 | MistralModels::MistralSmall => 32_000, 149 | MistralModels::MistralMedium => 32_000, 150 | } 151 | } 152 | 153 | fn get_endpoint(&self) -> String { 154 | MISTRAL_API_URL.to_string() 155 | } 156 | 157 | // This method prepares the body of the API call for different models 158 | fn get_body( 159 | &self, 160 | instructions: &str, 161 | json_schema: &Value, 162 | function_call: bool, 163 | max_tokens: &usize, 164 | temperature: &f32, 165 | tools: Option<&[LLMTools]>, 166 | _thinking_level: Option<&ThinkingLevel>, 167 | ) -> serde_json::Value { 168 | if has_values(tools) { 169 | self.get_conversations_body( 170 | instructions, 171 | json_schema, 172 | function_call, 173 | max_tokens, 174 | temperature, 175 | tools, 176 | ) 177 | } else { 178 | self.get_chat_completions_body( 179 | instructions, 180 | json_schema, 181 | function_call, 182 | max_tokens, 183 | temperature, 184 | tools, 185 | ) 186 | } 187 | } 188 | /* 189 | * This function leverages Mistral API to perform any query as per the provided body. 190 | * 191 | * It returns a String the Response object that needs to be parsed based on the self.model. 
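 *
 * A usage sketch (hedged): attaching a supported tool via the `Completions` builder, e.g.
 *
 *   let completions = Completions::new(MistralModels::MistralSmall3_2, &api_key, None, None)
 *       .add_tool(LLMTools::MistralWebSearch(MistralWebSearchConfig::new()));
 *
 * routes the request to the Conversations API below; without tools the regular
 * chat-completions endpoint is used.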
192 | */ 193 | async fn call_api( 194 | &self, 195 | api_key: &str, 196 | _version: Option, 197 | body: &serde_json::Value, 198 | debug: bool, 199 | tools: Option<&[LLMTools]>, 200 | ) -> Result { 201 | //Get the API url 202 | let model_url = if has_values(tools) { 203 | // If tools are provided we need to use the conversations API 204 | MISTRAL_CONVERSATIONS_API_URL.to_string() 205 | } else { 206 | self.get_endpoint() 207 | }; 208 | 209 | //Make the API call 210 | let client = Client::new(); 211 | 212 | //Send request 213 | let response = client 214 | .post(model_url) 215 | .header(header::CONTENT_TYPE, "application/json") 216 | .bearer_auth(api_key) 217 | .json(&body) 218 | .send() 219 | .await?; 220 | 221 | let response_status = response.status(); 222 | let response_text = response.text().await?; 223 | 224 | if debug { 225 | info!( 226 | "[debug] Mistral API response: [{}] {:#?}", 227 | &response_status, &response_text 228 | ); 229 | } 230 | 231 | Ok(response_text) 232 | } 233 | 234 | //This method attempts to convert the provided API response text into the expected struct and extracts the data from the response 235 | fn get_data(&self, response_text: &str, _function_call: bool) -> Result { 236 | if let Ok(data) = self.get_chat_completions_data(response_text, _function_call) { 237 | Ok(data) 238 | } else { 239 | self.get_conversations_data(response_text, _function_call) 240 | } 241 | } 242 | 243 | //This function allows to check the rate limits for different models 244 | fn get_rate_limit(&self) -> RateLimit { 245 | //Mistral documentation: https://docs.mistral.ai/platform/pricing#rate-limits 246 | RateLimit { 247 | tpm: 2_000_000, 248 | rpm: 360, // 6 requests per second 249 | } 250 | } 251 | } 252 | 253 | impl MistralModels { 254 | fn get_supported_tools(&self) -> Vec { 255 | vec![ 256 | LLMTools::MistralWebSearch(MistralWebSearchConfig::new()), 257 | LLMTools::MistralCodeInterpreter(MistralCodeInterpreterConfig::new()), 258 | ] 259 | } 260 | 261 | fn get_chat_completions_body( 262 | &self, 263 | instructions: &str, 264 | json_schema: &Value, 265 | function_call: bool, 266 | max_tokens: &usize, 267 | temperature: &f32, 268 | _tools: Option<&[LLMTools]>, 269 | ) -> serde_json::Value { 270 | //Prepare the 'messages' part of the body 271 | let base_instructions = self.get_base_instructions(Some(function_call)); 272 | let system_message = json!({ 273 | "role": "system", 274 | "content": base_instructions, 275 | }); 276 | let user_message = json!({ 277 | "role": "user", 278 | "content": format!( 279 | " 280 | {instructions} 281 | 282 | 283 | {json_schema} 284 | " 285 | ), 286 | }); 287 | json!({ 288 | "model": self.as_str(), 289 | "max_tokens": max_tokens, 290 | "temperature": temperature, 291 | "messages": vec![ 292 | system_message, 293 | user_message, 294 | ], 295 | }) 296 | } 297 | 298 | fn get_conversations_body( 299 | &self, 300 | instructions: &str, 301 | json_schema: &Value, 302 | function_call: bool, 303 | max_tokens: &usize, 304 | temperature: &f32, 305 | tools: Option<&[LLMTools]>, 306 | ) -> serde_json::Value { 307 | // Prepare the inputs part of the body 308 | let base_instructions = self.get_base_instructions(Some(function_call)); 309 | let inputs = format!( 310 | "{base_instructions} 311 | 312 | {instructions} 313 | 314 | 315 | {json_schema} 316 | " 317 | ); 318 | // Prepare the completion args part of the body 319 | let completion_args = json!({ 320 | "max_tokens": max_tokens, 321 | "temperature": temperature, 322 | }); 323 | // Prepare the tools part of the body 324 | let 
tools = if let Some(tools) = tools { 325 | // Filter out the tools that are not supported by the model 326 | let supported_tools = self.get_supported_tools(); 327 | let tools = tools 328 | .iter() 329 | .filter(|tool| supported_tools.contains(tool)) 330 | .filter_map(|tool| tool.get_config_json()) 331 | .collect::>(); 332 | 333 | // If no tools are supported return None 334 | if tools.is_empty() { 335 | None 336 | } else { 337 | Some(tools) 338 | } 339 | } else { 340 | None 341 | }; 342 | 343 | // Prepare and return the body 344 | json!({ 345 | "model": self.as_str(), 346 | "inputs": inputs, 347 | "completion_args": completion_args, 348 | "tools": tools, 349 | }) 350 | } 351 | 352 | fn get_chat_completions_data( 353 | &self, 354 | response_text: &str, 355 | _function_call: bool, 356 | ) -> Result { 357 | //Convert API response to struct representing expected response format 358 | let completions_response: MistralAPICompletionsResponse = 359 | serde_json::from_str(response_text)?; 360 | 361 | //Parse the response and return the assistant content 362 | completions_response 363 | .choices 364 | .iter() 365 | .filter_map(|choice| choice.message.as_ref()) 366 | .find(|&message| message.role == Some("assistant".to_string())) 367 | .and_then(|message| { 368 | message 369 | .content 370 | .as_ref() 371 | .map(|content| self.sanitize_json_response(content)) 372 | }) 373 | .ok_or_else(|| anyhow!("Assistant role content not found")) 374 | } 375 | 376 | fn get_conversations_data(&self, response_text: &str, _function_call: bool) -> Result { 377 | //Convert API response to struct representing expected response format 378 | let conversations_response: MistralAPIConversationsResponse = 379 | serde_json::from_str(response_text)?; 380 | 381 | // Parse the response and return the assistant content 382 | let content_text = conversations_response 383 | .outputs 384 | .iter() 385 | .find_map(|output| { 386 | if let MistralAPIConversationsOutput::MistralAPIConversationsMessageOutput(message_output) = output { 387 | message_output.content.as_ref().map(|content| { 388 | match content { 389 | MistralAPIConversationsMessageOutputContent::MistralAPIConversationsMessageOutputContentString(s) => s.clone(), 390 | MistralAPIConversationsMessageOutputContent::MistralAPIConversationsMessageOutputContentChunks(chunks) => { 391 | chunks 392 | .iter() 393 | .map(|chunk| { 394 | match chunk { 395 | MistralAPIConversationsChunk::MistralAPIConversationsChunkText(text_chunk) => { 396 | text_chunk.text.clone() 397 | } 398 | } 399 | }) 400 | .collect::>() 401 | .join("") 402 | } 403 | } 404 | }) 405 | } else { 406 | None 407 | } 408 | }) 409 | .ok_or_else(|| anyhow!("Message output content not found"))?; 410 | 411 | let sanitized = self.sanitize_json_response(&content_text); 412 | Ok(sanitized) 413 | } 414 | } 415 | -------------------------------------------------------------------------------- /src/apis/openai.rs: -------------------------------------------------------------------------------- 1 | #![allow(deprecated)] 2 | 3 | use anyhow::{anyhow, Result}; 4 | use reqwest::header::{self, HeaderMap, HeaderValue}; 5 | use serde::{Deserialize, Serialize}; 6 | use serde_json::{json, Value}; 7 | use std::str::FromStr; 8 | 9 | use crate::constants::{DEFAULT_AZURE_VERSION, OPENAI_API_URL}; 10 | 11 | /// 12 | /// Enum of supported Completions and Responses APIs (non-Assistant APIs) 13 | /// 14 | #[derive(Deserialize, Serialize, Debug, Clone, Eq, PartialEq, Default)] 15 | pub enum OpenAiApiEndpoints { 16 | #[deprecated(note = "Use 
OpenAICompletions instead")] 17 | OpenAI, 18 | #[default] 19 | OpenAICompletions, 20 | OpenAIResponses, 21 | #[deprecated(note = "Use AzureCompletions instead")] 22 | Azure { 23 | version: String, 24 | }, 25 | AzureCompletions { 26 | version: String, 27 | }, 28 | AzureResponses { 29 | version: String, 30 | }, 31 | } 32 | 33 | /// Type alias for backward compatibility 34 | pub type OpenAICompletionsAPI = OpenAiApiEndpoints; 35 | 36 | impl OpenAiApiEndpoints { 37 | /// Default version of Azure set to `2025-01-01-preview` as of 5/9/2025 38 | pub fn default_azure_version() -> String { 39 | "2025-01-01-preview".to_string() 40 | } 41 | 42 | /// Parses a string into `OpenAiApiEndpoints`. 43 | /// 44 | /// Supported formats (case-insensitive): 45 | /// - `"OpenAI"` or `"openai_completions"` -> `OpenAiApiEndpoints::OpenAICompletions` 46 | /// - `"openai_responses"` -> `OpenAiApiEndpoints::OpenAIResponses` 47 | /// - `"azure:"` or `"azure_completions:"` -> `OpenAiApiEndpoints::AzureCompletions { version }` 48 | /// - `"azure_responses:"` -> `OpenAiApiEndpoints::AzureResponses { version }` 49 | /// 50 | /// Returns default for others. 51 | #[allow(clippy::should_implement_trait)] 52 | pub fn from_str(s: &str) -> Self { 53 | let s_lower = s.to_lowercase(); 54 | match s_lower.as_str() { 55 | "openai" | "openai_completions" => OpenAiApiEndpoints::OpenAICompletions, 56 | "openai_responses" => OpenAiApiEndpoints::OpenAIResponses, 57 | _ if s_lower.starts_with("azure") || s_lower.starts_with("azure_completions") => { 58 | let version = s_lower 59 | .strip_prefix("azure:") 60 | .or_else(|| s_lower.strip_prefix("azure_completions:")) 61 | .map(|v| v.trim().to_string()) 62 | .unwrap_or_else(OpenAICompletionsAPI::default_azure_version); 63 | 64 | OpenAICompletionsAPI::AzureCompletions { version } 65 | } 66 | _ if s_lower.starts_with("azure_responses") => { 67 | let version = s_lower 68 | .strip_prefix("azure_responses:") 69 | .map(|v| v.trim().to_string()) 70 | .unwrap_or_else(OpenAICompletionsAPI::default_azure_version); 71 | 72 | OpenAICompletionsAPI::AzureResponses { version } 73 | } 74 | _ => OpenAiApiEndpoints::default(), 75 | } 76 | } 77 | } 78 | 79 | /// 80 | /// OpenAI Assistant Version 81 | /// 82 | #[derive(Deserialize, Serialize, Debug, Clone, Eq, PartialEq)] 83 | pub enum OpenAIAssistantVersion { 84 | V1, 85 | V2, 86 | Azure, 87 | AzureVersion { version: String }, 88 | } 89 | 90 | impl OpenAIAssistantVersion { 91 | pub(crate) fn get_endpoint(&self, resource: &OpenAIAssistantResource) -> String { 92 | //OpenAI documentation: https://platform.openai.com/docs/models/model-endpoint-compatibility 93 | let trimmed_api_url = (*OPENAI_API_URL).trim_end_matches('/'); 94 | let base_url = match self { 95 | OpenAIAssistantVersion::V1 | OpenAIAssistantVersion::V2 => { 96 | format!("{trimmed_api_url}/v1") 97 | } 98 | OpenAIAssistantVersion::Azure | OpenAIAssistantVersion::AzureVersion { .. 
} => { 99 | format!("{trimmed_api_url}/openai") 100 | } 101 | }; 102 | 103 | let path = match resource { 104 | OpenAIAssistantResource::Assistants => format!("{base_url}/assistants"), 105 | OpenAIAssistantResource::Assistant { assistant_id } => { 106 | format!("{base_url}/assistants/{assistant_id}") 107 | } 108 | OpenAIAssistantResource::Threads => format!("{base_url}/threads"), 109 | OpenAIAssistantResource::Messages { thread_id } => { 110 | format!("{base_url}/threads/{thread_id}/messages") 111 | } 112 | OpenAIAssistantResource::Runs { thread_id } => { 113 | format!("{base_url}/threads/{thread_id}/runs") 114 | } 115 | OpenAIAssistantResource::Run { thread_id, run_id } => { 116 | format!("{base_url}/threads/{thread_id}/runs/{run_id}") 117 | } 118 | OpenAIAssistantResource::Files => format!("{base_url}/files"), 119 | OpenAIAssistantResource::File { file_id } => format!("{base_url}/files/{file_id}"), 120 | OpenAIAssistantResource::VectorStores => format!("{base_url}/vector_stores"), 121 | OpenAIAssistantResource::VectorStore { vector_store_id } => { 122 | format!("{base_url}/vector_stores/{vector_store_id}") 123 | } 124 | OpenAIAssistantResource::VectorStoreFileBatches { vector_store_id } => { 125 | format!("{base_url}/vector_stores/{vector_store_id}/file_batches") 126 | } 127 | }; 128 | 129 | // Add Azure version suffix if needed 130 | match self { 131 | OpenAIAssistantVersion::Azure => { 132 | format!("{path}?api-version={}", DEFAULT_AZURE_VERSION) 133 | } 134 | OpenAIAssistantVersion::AzureVersion { version } => { 135 | format!("{path}?api-version={version}") 136 | } 137 | _ => path, 138 | } 139 | } 140 | 141 | pub(crate) fn get_headers(&self, api_key: &str) -> HeaderMap { 142 | let mut headers = HeaderMap::new(); 143 | headers.insert( 144 | header::CONTENT_TYPE, 145 | HeaderValue::from_static("application/json"), 146 | ); 147 | 148 | match self { 149 | OpenAIAssistantVersion::V1 => { 150 | // Try to create the header value from the bearer token 151 | if let Ok(bearer_header) = HeaderValue::from_str(&format!("Bearer {api_key}")) { 152 | headers.insert("Authorization", bearer_header); 153 | } else { 154 | headers.insert( 155 | "Error", 156 | HeaderValue::from_static("Invalid Authorization Header"), 157 | ); 158 | }; 159 | headers.insert("OpenAI-Beta", HeaderValue::from_static("assistants=v1")); 160 | } 161 | OpenAIAssistantVersion::V2 => { 162 | // Try to create the header value from the bearer token 163 | if let Ok(bearer_header) = HeaderValue::from_str(&format!("Bearer {api_key}")) { 164 | headers.insert("Authorization", bearer_header); 165 | } else { 166 | headers.insert( 167 | "Error", 168 | HeaderValue::from_static("Invalid Authorization Header"), 169 | ); 170 | }; 171 | headers.insert("OpenAI-Beta", HeaderValue::from_static("assistants=v2")); 172 | } 173 | OpenAIAssistantVersion::Azure | OpenAIAssistantVersion::AzureVersion { .. } => { 174 | // Try to create the header value from the bearer token 175 | if let Ok(api_key_header) = HeaderValue::from_str(api_key) { 176 | headers.insert("api-key", api_key_header); 177 | } else { 178 | headers.insert( 179 | "Error", 180 | HeaderValue::from_static("Invalid Authorization Header"), 181 | ); 182 | }; 183 | } 184 | }; 185 | headers 186 | } 187 | 188 | pub(crate) fn get_tools_payload(&self) -> Value { 189 | match self { 190 | OpenAIAssistantVersion::V1 => json!([{ 191 | "type": "retrieval" 192 | }]), 193 | OpenAIAssistantVersion::V2 194 | | OpenAIAssistantVersion::Azure 195 | | OpenAIAssistantVersion::AzureVersion { .. 
} => json!([{ 196 | "type": "file_search" 197 | }]), 198 | } 199 | } 200 | 201 | pub(crate) fn add_message_attachments( 202 | &self, 203 | message_payload: &Value, 204 | file_ids: &[String], 205 | ) -> Value { 206 | let mut message_payload = message_payload.clone(); 207 | match self { 208 | OpenAIAssistantVersion::V1 => { 209 | message_payload["file_ids"] = json!(file_ids); 210 | } 211 | OpenAIAssistantVersion::V2 212 | | OpenAIAssistantVersion::Azure 213 | | OpenAIAssistantVersion::AzureVersion { .. } => { 214 | let file_search_json = json!({ 215 | "type": "file_search" 216 | }); 217 | let attachments_vec: Vec = file_ids 218 | .iter() 219 | .map(|file_id| { 220 | json!({ 221 | "file_id": file_id.to_string(), 222 | "tools": [file_search_json.clone()] 223 | }) 224 | }) 225 | .collect(); 226 | message_payload["attachments"] = json!(attachments_vec); 227 | } 228 | } 229 | message_payload 230 | } 231 | } 232 | 233 | impl FromStr for OpenAIAssistantVersion { 234 | type Err = anyhow::Error; 235 | 236 | /// Parses a string into `OpenAIAssistantVersion`. 237 | /// 238 | /// Supported formats (case-insensitive): 239 | /// - `"v1"` -> `OpenAIAssistantVersion::V1` 240 | /// - `"v2"` -> `OpenAIAssistantVersion::V2` 241 | /// - `"azure"` -> `OpenAIAssistantVersion::Azure` 242 | /// - `"azure:"` -> `OpenAIAssistantVersion::AzureVersion { version }` 243 | /// 244 | /// Returns an error for unrecognized formats. 245 | fn from_str(s: &str) -> Result { 246 | let s_lower = s.to_lowercase(); 247 | match s_lower.as_str() { 248 | "v1" => Ok(OpenAIAssistantVersion::V1), 249 | "v2" => Ok(OpenAIAssistantVersion::V2), 250 | _ if s_lower.starts_with("azure") => { 251 | // Check if the string contains a version after "azure:" 252 | if let Some(version) = s_lower.strip_prefix("azure:") { 253 | Ok(OpenAIAssistantVersion::AzureVersion { 254 | version: version.trim().to_string(), 255 | }) 256 | } else { 257 | // Backward compatibility: if it's just "azure", use a default version 258 | Ok(OpenAIAssistantVersion::Azure) 259 | } 260 | } 261 | _ => Err(anyhow!("Invalid version: {}", s)), 262 | } 263 | } 264 | } 265 | 266 | #[derive(Deserialize, Serialize, Debug, Clone, Eq, PartialEq)] 267 | pub enum OpenAIAssistantResource { 268 | Assistants, 269 | Assistant { assistant_id: String }, 270 | Threads, 271 | Messages { thread_id: String }, 272 | Runs { thread_id: String }, 273 | Run { thread_id: String, run_id: String }, 274 | Files, 275 | File { file_id: String }, 276 | VectorStores, 277 | VectorStore { vector_store_id: String }, 278 | VectorStoreFileBatches { vector_store_id: String }, 279 | } 280 | 281 | #[cfg(test)] 282 | mod tests { 283 | use super::*; 284 | 285 | const OPENAI_API_URL: &str = "https://api.openai.com"; 286 | const DEFAULT_AZURE_VERSION: &str = "2024-06-01"; 287 | 288 | #[test] 289 | fn test_v1_assistants_endpoint() { 290 | let version = OpenAIAssistantVersion::V1; 291 | let resource = OpenAIAssistantResource::Assistants; 292 | let expected_url = format!("{}/v1/assistants", OPENAI_API_URL); 293 | assert_eq!(version.get_endpoint(&resource), expected_url); 294 | } 295 | 296 | #[test] 297 | fn test_azure_assistant_endpoint() { 298 | let version = OpenAIAssistantVersion::AzureVersion { 299 | version: "2024-05-01-preview".to_string(), 300 | }; 301 | let resource = OpenAIAssistantResource::Assistant { 302 | assistant_id: "123".to_string(), 303 | }; 304 | let expected_url = format!( 305 | "{}/openai/assistants/123?api-version=2024-05-01-preview", 306 | OPENAI_API_URL 307 | ); 308 | 
assert_eq!(version.get_endpoint(&resource), expected_url); 309 | } 310 | 311 | #[test] 312 | fn test_default_azure_assistant_endpoint() { 313 | let version = OpenAIAssistantVersion::from_str("azure").unwrap(); 314 | let resource = OpenAIAssistantResource::Assistant { 315 | assistant_id: "123".to_string(), 316 | }; 317 | let expected_url = format!( 318 | "{}/openai/assistants/123?api-version={}", 319 | OPENAI_API_URL, DEFAULT_AZURE_VERSION 320 | ); 321 | assert_eq!(version.get_endpoint(&resource), expected_url); 322 | } 323 | 324 | #[test] 325 | fn test_v2_threads_endpoint() { 326 | let version = OpenAIAssistantVersion::V2; 327 | let resource = OpenAIAssistantResource::Threads; 328 | let expected_url = format!("{}/v1/threads", OPENAI_API_URL); 329 | assert_eq!(version.get_endpoint(&resource), expected_url); 330 | } 331 | 332 | #[test] 333 | fn test_azure_file_batches_endpoint() { 334 | let version = OpenAIAssistantVersion::AzureVersion { 335 | version: "2024-05-01-preview".to_string(), 336 | }; 337 | let resource = OpenAIAssistantResource::VectorStoreFileBatches { 338 | vector_store_id: "abc".to_string(), 339 | }; 340 | let expected_url = format!( 341 | "{}/openai/vector_stores/abc/file_batches?api-version=2024-05-01-preview", 342 | OPENAI_API_URL 343 | ); 344 | assert_eq!(version.get_endpoint(&resource), expected_url); 345 | } 346 | 347 | #[test] 348 | fn test_v1_run_endpoint() { 349 | let version = OpenAIAssistantVersion::V1; 350 | let resource = OpenAIAssistantResource::Run { 351 | thread_id: "xyz".to_string(), 352 | run_id: "456".to_string(), 353 | }; 354 | let expected_url = format!("{}/v1/threads/xyz/runs/456", OPENAI_API_URL); 355 | assert_eq!(version.get_endpoint(&resource), expected_url); 356 | } 357 | 358 | #[test] 359 | fn test_v1_tools_payload() { 360 | let version = OpenAIAssistantVersion::V1; 361 | let expected_payload: Value = json!([{ 362 | "type": "retrieval" 363 | }]); 364 | assert_eq!(version.get_tools_payload(), expected_payload); 365 | } 366 | 367 | #[test] 368 | fn test_v2_tools_payload() { 369 | let version = OpenAIAssistantVersion::V2; 370 | let expected_payload: Value = json!([{ 371 | "type": "file_search" 372 | }]); 373 | assert_eq!(version.get_tools_payload(), expected_payload); 374 | } 375 | 376 | #[test] 377 | fn test_azure_tools_payload() { 378 | let version = OpenAIAssistantVersion::AzureVersion { 379 | version: "2024-05-01-preview".to_string(), 380 | }; 381 | let expected_payload: Value = json!([{ 382 | "type": "file_search" 383 | }]); 384 | assert_eq!(version.get_tools_payload(), expected_payload); 385 | } 386 | 387 | // Deserializing from string 388 | #[test] 389 | fn test_v1_version() { 390 | let result = OpenAIAssistantVersion::from_str("v1"); 391 | assert_eq!(result.unwrap(), OpenAIAssistantVersion::V1); 392 | } 393 | 394 | #[test] 395 | fn test_v2_version() { 396 | let result = OpenAIAssistantVersion::from_str("v2"); 397 | assert_eq!(result.unwrap(), OpenAIAssistantVersion::V2); 398 | } 399 | 400 | #[test] 401 | fn test_azure_with_version() { 402 | let result = OpenAIAssistantVersion::from_str("azure:2024-09-01"); 403 | assert_eq!( 404 | result.unwrap(), 405 | OpenAIAssistantVersion::AzureVersion { 406 | version: "2024-09-01".to_string(), 407 | } 408 | ); 409 | } 410 | 411 | #[test] 412 | fn test_azure_with_spaces_in_version() { 413 | let result = OpenAIAssistantVersion::from_str("azure: 2024-09-01 "); 414 | assert_eq!( 415 | result.unwrap(), 416 | OpenAIAssistantVersion::AzureVersion { 417 | version: "2024-09-01".to_string(), // Spaces trimmed 418 | } 
419 | ); 420 | } 421 | 422 | #[test] 423 | fn test_azure_default_version() { 424 | let result = OpenAIAssistantVersion::from_str("azure"); 425 | assert_eq!(result.unwrap(), OpenAIAssistantVersion::Azure); 426 | } 427 | 428 | #[test] 429 | fn test_invalid_version() { 430 | let result = OpenAIAssistantVersion::from_str("invalid_version"); 431 | assert!(result.is_err()); 432 | assert_eq!( 433 | format!("{}", result.unwrap_err()), 434 | "Invalid version: invalid_version" 435 | ); 436 | } 437 | } 438 | -------------------------------------------------------------------------------- /src/llm_models/anthropic.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use async_trait::async_trait; 3 | use log::info; 4 | use reqwest::{header, Client}; 5 | use serde::{Deserialize, Serialize}; 6 | use serde_json::{json, Value}; 7 | 8 | use crate::apis::AnthropicApiEndpoints; 9 | use crate::completions::ThinkingLevel; 10 | use crate::constants::{ANTHROPIC_API_URL, ANTHROPIC_MESSAGES_API_URL}; 11 | use crate::domain::{AnthropicAPICompletionsResponse, AnthropicAPIMessagesResponse}; 12 | use crate::llm_models::{ 13 | tools::{ 14 | AnthropicCodeExecutionConfig, AnthropicComputerUseConfig, AnthropicFileSearchConfig, 15 | AnthropicWebSearchConfig, 16 | }, 17 | LLMModel, LLMTools, 18 | }; 19 | 20 | // API Docs: https://docs.anthropic.com/en/docs/about-claude/models/all-models 21 | #[derive(Deserialize, Serialize, Debug, Clone, Eq, PartialEq)] 22 | pub enum AnthropicModels { 23 | Claude4_5Opus, 24 | Claude4_5Sonnet, 25 | Claude4_5Haiku, 26 | Claude4_1Opus, 27 | Claude4Sonnet, 28 | Claude4Opus, 29 | Claude3_7Sonnet, 30 | Claude3_5Sonnet, 31 | Claude3_5Haiku, 32 | Claude3Opus, 33 | Claude3Sonnet, 34 | Claude3Haiku, 35 | // Legacy 36 | Claude2, 37 | ClaudeInstant1_2, 38 | } 39 | 40 | #[async_trait(?Send)] 41 | impl LLMModel for AnthropicModels { 42 | fn as_str(&self) -> &str { 43 | match self { 44 | AnthropicModels::Claude4_5Opus => "claude-opus-4-5", 45 | AnthropicModels::Claude4_5Sonnet => "claude-sonnet-4-5", 46 | AnthropicModels::Claude4_5Haiku => "claude-haiku-4-5", 47 | AnthropicModels::Claude4_1Opus => "claude-opus-4-1-20250805", 48 | AnthropicModels::Claude4Sonnet => "claude-sonnet-4-20250514", 49 | AnthropicModels::Claude4Opus => "claude-opus-4-20250514", 50 | AnthropicModels::Claude3_7Sonnet => "claude-3-7-sonnet-latest", 51 | AnthropicModels::Claude3_5Sonnet => "claude-3-5-sonnet-latest", 52 | AnthropicModels::Claude3_5Haiku => "claude-3-5-haiku-latest", 53 | AnthropicModels::Claude3Opus => "claude-3-opus-latest", 54 | AnthropicModels::Claude3Sonnet => "claude-3-sonnet-20240229", 55 | AnthropicModels::Claude3Haiku => "claude-3-haiku-20240307", 56 | // Legacy 57 | AnthropicModels::Claude2 => "claude-2.1", 58 | AnthropicModels::ClaudeInstant1_2 => "claude-instant-1.2", 59 | } 60 | } 61 | 62 | // Docs: https://docs.anthropic.com/en/docs/about-claude/models/overview#model-aliases 63 | fn try_from_str(name: &str) -> Option { 64 | match name.to_lowercase().as_str() { 65 | "claude-opus-4-5" => Some(AnthropicModels::Claude4_5Opus), 66 | "claude-opus-4-5-20251101" => Some(AnthropicModels::Claude4_5Opus), 67 | "claude-sonnet-4-5" => Some(AnthropicModels::Claude4_5Sonnet), 68 | "claude-sonnet-4-5-20250929" => Some(AnthropicModels::Claude4_5Sonnet), 69 | "claude-haiku-4-5" => Some(AnthropicModels::Claude4_5Haiku), 70 | "claude-haiku-4-5-20251001" => Some(AnthropicModels::Claude4_5Haiku), 71 | "claude-opus-4-1-20250805" => Some(AnthropicModels::Claude4_1Opus), 
72 | "claude-opus-4-1" => Some(AnthropicModels::Claude4_1Opus), 73 | "claude-sonnet-4-20250514" => Some(AnthropicModels::Claude4Sonnet), 74 | "claude-sonnet-4-0" => Some(AnthropicModels::Claude4Sonnet), 75 | "claude-opus-4-20250514" => Some(AnthropicModels::Claude4Opus), 76 | "claude-opus-4-0" => Some(AnthropicModels::Claude4Opus), 77 | "claude-3-7-sonnet-latest" => Some(AnthropicModels::Claude3_7Sonnet), 78 | "claude-3-5-sonnet-20240620" => Some(AnthropicModels::Claude3_5Sonnet), 79 | "claude-3-5-sonnet-latest" => Some(AnthropicModels::Claude3_5Sonnet), 80 | "claude-3-5-haiku-latest" => Some(AnthropicModels::Claude3_5Haiku), 81 | "claude-3-opus-20240229" => Some(AnthropicModels::Claude3Opus), 82 | "claude-3-opus-latest" => Some(AnthropicModels::Claude3Opus), 83 | "claude-3-sonnet-20240229" => Some(AnthropicModels::Claude3Sonnet), 84 | "claude-3-haiku-20240307" => Some(AnthropicModels::Claude3Haiku), 85 | // Legacy 86 | "claude-2.1" => Some(AnthropicModels::Claude2), 87 | "claude-instant-1.2" => Some(AnthropicModels::ClaudeInstant1_2), 88 | _ => None, 89 | } 90 | } 91 | 92 | fn default_max_tokens(&self) -> usize { 93 | // This is the max tokens allowed for response and not context as per documentation: https://docs.anthropic.com/en/docs/about-claude/models/overview#model-comparison-table 94 | match self { 95 | AnthropicModels::Claude4_5Opus => 64_000, 96 | AnthropicModels::Claude4_5Sonnet => 64_000, 97 | AnthropicModels::Claude4_5Haiku => 64_000, 98 | AnthropicModels::Claude4_1Opus => 32_000, 99 | AnthropicModels::Claude4Sonnet => 64_000, 100 | AnthropicModels::Claude4Opus => 32_000, 101 | AnthropicModels::Claude3_7Sonnet => 64_000, 102 | AnthropicModels::Claude3_5Sonnet => 8_192, 103 | AnthropicModels::Claude3_5Haiku => 8_192, 104 | AnthropicModels::Claude3Opus => 4_096, 105 | AnthropicModels::Claude3Sonnet => 4_096, 106 | AnthropicModels::Claude3Haiku => 4_096, 107 | // Legacy 108 | AnthropicModels::Claude2 => 4_096, 109 | AnthropicModels::ClaudeInstant1_2 => 4_096, 110 | } 111 | } 112 | 113 | fn get_endpoint(&self) -> String { 114 | match self { 115 | AnthropicModels::Claude4_5Opus 116 | | AnthropicModels::Claude4_5Sonnet 117 | | AnthropicModels::Claude4_5Haiku 118 | | AnthropicModels::Claude4_1Opus 119 | | AnthropicModels::Claude4Sonnet 120 | | AnthropicModels::Claude4Opus 121 | | AnthropicModels::Claude3_7Sonnet 122 | | AnthropicModels::Claude3_5Sonnet 123 | | AnthropicModels::Claude3_5Haiku 124 | | AnthropicModels::Claude3Opus 125 | | AnthropicModels::Claude3Sonnet 126 | | AnthropicModels::Claude3Haiku => ANTHROPIC_MESSAGES_API_URL.to_string(), 127 | // Legacy 128 | AnthropicModels::Claude2 | AnthropicModels::ClaudeInstant1_2 => { 129 | ANTHROPIC_API_URL.to_string() 130 | } 131 | } 132 | } 133 | 134 | //This method prepares the body of the API call for different models 135 | fn get_body( 136 | &self, 137 | instructions: &str, 138 | json_schema: &Value, 139 | function_call: bool, 140 | max_tokens: &usize, 141 | temperature: &f32, 142 | tools: Option<&[LLMTools]>, 143 | _thinking_level: Option<&ThinkingLevel>, 144 | ) -> serde_json::Value { 145 | let schema_string = serde_json::to_string(json_schema).unwrap_or_default(); 146 | let base_instructions = self.get_base_instructions(Some(function_call)); 147 | 148 | let completions_body = json!({ 149 | "model": self.as_str(), 150 | "max_tokens_to_sample": max_tokens, 151 | "temperature": temperature, 152 | "prompt": format!( 153 | "\n\nHuman: 154 | {base_instructions}\n\n 155 | Output Json schema:\n 156 | {schema_string}\n\n 157 | 
{instructions} 158 | \n\nAssistant:", 159 | ), 160 | }); 161 | 162 | let base_message = json!({ 163 | "role": "user", 164 | "content": format!( 165 | "{base_instructions}" 166 | ) 167 | }); 168 | 169 | let user_instructions = format!( 170 | " 171 | {instructions} 172 | 173 | 174 | {schema_string} 175 | " 176 | ); 177 | 178 | // The file search tool, if attached, is added to the body of the message 179 | // We check if the tool is added and if so use it to get the message content to be sent to the model 180 | let messages = if let Some(file_search_tool_config) = tools.and_then(|tools_inner| { 181 | tools_inner 182 | .iter() 183 | // Check if the tool is supported by the model 184 | .filter(|tool| { 185 | self.get_supported_tools().iter().any(|supported| { 186 | std::mem::discriminant(*tool) == std::mem::discriminant(supported) 187 | }) 188 | }) 189 | // Find the file search tool 190 | .find(|tool| matches!(tool, LLMTools::AnthropicFileSearch(_))) 191 | // Extract the file search tool config 192 | .and_then(|tool| { 193 | tool.get_config_json().and_then(|config_json| { 194 | serde_json::from_value::<AnthropicFileSearchConfig>(config_json).ok() 195 | }) 196 | }) 197 | }) { 198 | json!([ 199 | base_message, 200 | { 201 | "role": "user", 202 | "content": [ 203 | // Use the file search tool config to get the content to be sent to the model 204 | file_search_tool_config.content(), 205 | { 206 | "type": "text", 207 | "text": user_instructions 208 | } 209 | ] 210 | } 211 | ]) 212 | } else { 213 | json!([base_message, { 214 | "role": "user", 215 | "content": user_instructions 216 | }]) 217 | }; 218 | 219 | let mut message_body = json!({ 220 | "model": self.as_str(), 221 | "max_tokens": max_tokens, 222 | "temperature": temperature, 223 | "messages": messages, 224 | }); 225 | 226 | // Add tools if provided 227 | if let Some(tools_inner) = tools { 228 | let processed_tools: Vec<Value> = tools_inner 229 | .iter() 230 | // File search is handled separately 231 | .filter(|tool| !matches!(tool, LLMTools::AnthropicFileSearch(_))) 232 | .filter(|tool| { 233 | self.get_supported_tools().iter().any(|supported| { 234 | std::mem::discriminant(*tool) == std::mem::discriminant(supported) 235 | }) 236 | }) 237 | .filter_map(LLMTools::get_config_json) 238 | .collect::<Vec<Value>>(); 239 | 240 | // Only add tools if the processed vector is not empty 241 | if !processed_tools.is_empty() { 242 | message_body["tools"] = json!(processed_tools); 243 | } 244 | } 245 | 246 | match self { 247 | AnthropicModels::Claude4_5Opus 248 | | AnthropicModels::Claude4_5Sonnet 249 | | AnthropicModels::Claude4_5Haiku 250 | | AnthropicModels::Claude4_1Opus 251 | | AnthropicModels::Claude4Sonnet 252 | | AnthropicModels::Claude4Opus 253 | | AnthropicModels::Claude3_7Sonnet 254 | | AnthropicModels::Claude3_5Sonnet 255 | | AnthropicModels::Claude3_5Haiku 256 | | AnthropicModels::Claude3Opus 257 | | AnthropicModels::Claude3Sonnet 258 | | AnthropicModels::Claude3Haiku => message_body, 259 | // Legacy 260 | AnthropicModels::Claude2 | AnthropicModels::ClaudeInstant1_2 => completions_body, 261 | } 262 | } 263 | /* 264 | * This function leverages the Anthropic API to perform any query as per the provided body. 265 | * 266 | * It returns a String with the Response object that needs to be parsed based on self.model.
267 | */ 268 | async fn call_api( 269 | &self, 270 | api_key: &str, 271 | _version: Option<String>, 272 | body: &serde_json::Value, 273 | debug: bool, 274 | tools: Option<&[LLMTools]>, 275 | ) -> Result<String> { 276 | //Get the API url 277 | let model_url = self.get_endpoint(); 278 | 279 | //Make the API call 280 | let client = Client::new(); 281 | 282 | // Build request with base headers 283 | let mut request = client 284 | .post(model_url) 285 | .header(header::CONTENT_TYPE, "application/json") 286 | //Anthropic-specific way of passing API key 287 | .header("x-api-key", api_key) 288 | //Required as per documentation 289 | .header( 290 | "anthropic-version", 291 | AnthropicApiEndpoints::messages_default().version(), 292 | ); 293 | 294 | // Add tool-specific headers 295 | if let Some(tools_list) = tools { 296 | for tool in tools_list { 297 | if let Some((header_name, header_value)) = self.get_tool_header(tool) { 298 | request = request.header(header_name, header_value); 299 | } 300 | } 301 | } 302 | 303 | //Send request 304 | let response = request.json(&body).send().await?; 305 | 306 | let response_status = response.status(); 307 | let response_text = response.text().await?; 308 | 309 | if debug { 310 | info!( 311 | "[debug] Anthropic API response: [{}] {:#?}", 312 | &response_status, &response_text 313 | ); 314 | } 315 | 316 | Ok(response_text) 317 | } 318 | 319 | //This method attempts to convert the provided API response text into the expected struct and extracts the data from the response 320 | fn get_data(&self, response_text: &str, _function_call: bool) -> Result<String> { 321 | //Convert API response to struct representing expected response format 322 | match self { 323 | AnthropicModels::Claude4_5Opus 324 | | AnthropicModels::Claude4_5Sonnet 325 | | AnthropicModels::Claude4_5Haiku 326 | | AnthropicModels::Claude4_1Opus 327 | | AnthropicModels::Claude4Sonnet 328 | | AnthropicModels::Claude4Opus 329 | | AnthropicModels::Claude3_7Sonnet 330 | | AnthropicModels::Claude3_5Sonnet 331 | | AnthropicModels::Claude3_5Haiku 332 | | AnthropicModels::Claude3Opus 333 | | AnthropicModels::Claude3Sonnet 334 | | AnthropicModels::Claude3Haiku => { 335 | let messages_response: AnthropicAPIMessagesResponse = 336 | serde_json::from_str(response_text)?; 337 | 338 | let assistant_response = messages_response 339 | .content 340 | .iter() 341 | .filter(|item| item.content_type == "text") 342 | .filter_map(|item| item.text.clone()) 343 | // Sanitize the response to remove the json schema wrapper 344 | .map(|text| self.sanitize_json_response(&text)) 345 | .next_back() 346 | .ok_or(anyhow::anyhow!("No assistant response found"))?; 347 | 348 | //Return completions text 349 | Ok(assistant_response) 350 | } 351 | // Legacy 352 | AnthropicModels::Claude2 | AnthropicModels::ClaudeInstant1_2 => { 353 | let completions_response: AnthropicAPICompletionsResponse = 354 | serde_json::from_str(response_text)?; 355 | 356 | //Return completions text 357 | Ok(completions_response.completion) 358 | } 359 | } 360 | } 361 | } 362 | 363 | impl AnthropicModels { 364 | pub fn get_supported_tools(&self) -> Vec<LLMTools> { 365 | match self { 366 | AnthropicModels::Claude4_5Opus 367 | | AnthropicModels::Claude4_5Sonnet 368 | | AnthropicModels::Claude4_1Opus 369 | | AnthropicModels::Claude4Sonnet 370 | | AnthropicModels::Claude4Opus 371 | | AnthropicModels::Claude3_7Sonnet 372 | | AnthropicModels::Claude3_5Haiku => { 373 | vec![ 374 | LLMTools::AnthropicCodeExecution(AnthropicCodeExecutionConfig::new()), 375 |
LLMTools::AnthropicComputerUse(AnthropicComputerUseConfig::new(1920, 1080)), 376 | LLMTools::AnthropicFileSearch(AnthropicFileSearchConfig::new("".to_string())), 377 | LLMTools::AnthropicWebSearch(AnthropicWebSearchConfig::new()), 378 | ] 379 | } 380 | // As of 2025.10.16 Claude 4.5 Haiku does not seem to support file search 381 | AnthropicModels::Claude4_5Haiku => { 382 | vec![ 383 | LLMTools::AnthropicCodeExecution(AnthropicCodeExecutionConfig::new()), 384 | LLMTools::AnthropicComputerUse(AnthropicComputerUseConfig::new(1920, 1080)), 385 | LLMTools::AnthropicWebSearch(AnthropicWebSearchConfig::new()), 386 | ] 387 | } 388 | AnthropicModels::Claude3_5Sonnet => { 389 | vec![ 390 | LLMTools::AnthropicComputerUse(AnthropicComputerUseConfig::new(1920, 1080)), 391 | LLMTools::AnthropicFileSearch(AnthropicFileSearchConfig::new("".to_string())), 392 | ] 393 | } 394 | _ => vec![], 395 | } 396 | } 397 | 398 | /// Returns a tuple of (header_name, header_value) for a specific tool, or None if no header is needed 399 | pub fn get_tool_header(&self, tool: &LLMTools) -> Option<(&'static str, &'static str)> { 400 | match (self, tool) { 401 | ( 402 | AnthropicModels::Claude4_5Opus 403 | | AnthropicModels::Claude4_5Sonnet 404 | | AnthropicModels::Claude4_5Haiku 405 | | AnthropicModels::Claude4_1Opus 406 | | AnthropicModels::Claude4Sonnet 407 | | AnthropicModels::Claude4Opus 408 | | AnthropicModels::Claude3_7Sonnet 409 | | AnthropicModels::Claude3_5Haiku, 410 | LLMTools::AnthropicCodeExecution(_), 411 | ) => Some(("anthropic-beta", "code-execution-2025-08-25")), 412 | ( 413 | // As of 2025.10.16 it is unclear if computer use is supported for 4.5 models 414 | // https://docs.claude.com/en/docs/agents-and-tools/tool-use/computer-use-tool 415 | AnthropicModels::Claude4Sonnet 416 | | AnthropicModels::Claude4Opus 417 | | AnthropicModels::Claude3_7Sonnet, 418 | LLMTools::AnthropicComputerUse(_), 419 | ) => Some(("anthropic-beta", "computer-use-2025-01-24")), 420 | // Claude Opus 4.5 requires a different header for computer use 421 | // https://platform.claude.com/docs/en/agents-and-tools/tool-use/computer-use-tool 422 | (AnthropicModels::Claude4_5Opus, LLMTools::AnthropicComputerUse(_)) => { 423 | Some(("anthropic-beta", "computer-use-2025-11-24")) 424 | } 425 | (AnthropicModels::Claude3_5Sonnet, LLMTools::AnthropicComputerUse(_)) => { 426 | Some(("anthropic-beta", "computer-use-2024-10-22")) 427 | } 428 | ( 429 | AnthropicModels::Claude4_5Opus 430 | | AnthropicModels::Claude4_5Sonnet 431 | | AnthropicModels::Claude4_1Opus 432 | | AnthropicModels::Claude4Sonnet 433 | | AnthropicModels::Claude4Opus 434 | | AnthropicModels::Claude3_7Sonnet 435 | | AnthropicModels::Claude3_5Sonnet 436 | | AnthropicModels::Claude3_5Haiku, 437 | LLMTools::AnthropicFileSearch(_), 438 | ) => Some(( 439 | "anthropic-beta", 440 | AnthropicApiEndpoints::files_default().version_static(), 441 | )), 442 | _ => { 443 | // Return None for tools that don't require a header 444 | None 445 | } 446 | } 447 | } 448 | } 449 | --------------------------------------------------------------------------------
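Usage note: the AnthropicModels implementation above is driven through the crate's generic Completions interface re-exported from lib.rs. The snippet below is a minimal, hypothetical sketch of that flow; it assumes a Completions::new(model, api_key, max_tokens, temperature) constructor and a get_answer::<T>() method as used in the crate's examples, and the ConcertInfo response type and prompt are invented purely for illustration.

use allms::{llm::AnthropicModels, Completions};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};

// Hypothetical response type, used only to illustrate a schema-constrained answer
#[derive(Debug, Deserialize, Serialize, JsonSchema)]
struct ConcertInfo {
    band: String,
    venue: String,
    city: String,
    date: String,
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let api_key = std::env::var("ANTHROPIC_API_KEY")?;
    // None / None are assumed to fall back to the model's default max_tokens and temperature
    let concert_info = Completions::new(AnthropicModels::Claude4_5Sonnet, &api_key, None, None)
        .get_answer::<ConcertInfo>("Summarize the upcoming Metallica concert described in the attached text.")
        .await?;
    println!("{:#?}", concert_info);
    Ok(())
}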