├── .gitignore
├── src
│   ├── load_balancer
│   │   ├── mod.rs
│   │   ├── utils.rs
│   │   ├── tasks.rs
│   │   ├── tracker.rs
│   │   ├── strategies.rs
│   │   ├── builder.rs
│   │   └── manager.rs
│   ├── providers
│   │   ├── mod.rs
│   │   ├── types.rs
│   │   ├── instances.rs
│   │   ├── openai.rs
│   │   ├── mistral.rs
│   │   ├── anthropic.rs
│   │   ├── ollama.rs
│   │   ├── google.rs
│   │   └── model_discovery.rs
│   ├── constants.rs
│   ├── lib.rs
│   └── errors.rs
├── Cargo.toml
├── LICENSE
├── CHANGELOG.md
├── examples
│   └── task_routing.rs
└── README.md

/.gitignore:
--------------------------------------------------------------------------------
 1 | /target
 2 | Cargo.lock
 3 | CLAUDE.md
 4 | debug_folder
--------------------------------------------------------------------------------
/src/load_balancer/mod.rs:
--------------------------------------------------------------------------------
 1 | /// Load balancer module for distributing requests across multiple LLM providers
 2 | ///
 3 | /// This module contains components for:
 4 | /// - Managing provider instances with metrics tracking
 5 | /// - Selecting appropriate providers based on tasks and load
 6 | /// - Implementing different load balancing strategies
 7 | /// - Handling retries and fallbacks when providers fail
 8 | /// - Tracking token usage across providers
 9 |
10 | pub mod tracker;
11 | pub mod manager;
12 | pub mod strategies;
13 | pub mod tasks;
14 | pub mod builder;
15 | pub mod utils;
16 |
17 | pub use manager::{LlmManager, GenerationRequest, LlmManagerResponse};
18 | pub use tasks::TaskDefinition;
--------------------------------------------------------------------------------
/src/providers/mod.rs:
--------------------------------------------------------------------------------
 1 | /// Module for various LLM provider implementations
 2 | ///
 3 | /// This module contains implementations for different LLM providers:
 4 | /// - Anthropic (Claude models)
 5 | /// - OpenAI (GPT models)
 6 | /// - Mistral AI
 7 | /// - Google (Gemini models)
 8 | /// - Ollama
 9 | ///
10 | /// Each provider implements a common interface for generating text
11 | /// completions through their respective APIs.
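///
/// # Example (illustrative sketch)
///
/// A minimal sketch of driving a single provider through the shared `LlmInstance`
/// interface; the API key and model name are placeholders.
///
/// ```ignore
/// use flyllm::{create_instance, ProviderType, LlmRequest};
/// use flyllm::providers::Message;
///
/// async fn example() -> flyllm::LlmResult<()> {
///     // All providers are created through the same factory and used through
///     // the same trait object, regardless of the backing API.
///     let instance = create_instance(
///         ProviderType::OpenAI,
///         "your-api-key".to_string(),
///         "gpt-4-turbo".to_string(),
///         vec![],   // supported tasks
///         true,     // enabled
///         None,     // custom endpoint
///     );
///
///     let request = LlmRequest {
///         messages: vec![Message { role: "user".to_string(), content: "Hello!".to_string() }],
///         model: None,
///         max_tokens: Some(256),
///         temperature: Some(0.7),
///     };
///
///     let response = instance.generate(&request).await?;
///     println!("{}", response.content);
///     Ok(())
/// }
/// ```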
12 | 13 | pub mod anthropic; 14 | pub mod openai; 15 | pub mod types; 16 | pub mod instances; 17 | pub mod google; 18 | pub mod mistral; 19 | pub mod ollama; 20 | pub mod model_discovery; 21 | 22 | pub use model_discovery::ModelDiscovery; 23 | pub use types::{ProviderType, LlmRequest, LlmResponse, Message, TokenUsage, ModelInfo}; 24 | pub use instances::{LlmInstance, create_instance}; 25 | pub use anthropic::AnthropicInstance; 26 | pub use openai::OpenAIInstance; -------------------------------------------------------------------------------- /src/constants.rs: -------------------------------------------------------------------------------- 1 | /// Common constants used throughout the crate 2 | 3 | // General 4 | pub const DEFAULT_MAX_TOKENS: u32 = 1024; 5 | pub const DEFAULT_MAX_TRIES: usize = 5; 6 | 7 | // OpenAI 8 | pub const OPENAI_API_ENDPOINT: &str = "https://api.openai.com/v1/chat/completions"; 9 | 10 | // Anthropic 11 | pub const ANTHROPIC_API_ENDPOINT: &str = "https://api.anthropic.com/v1/messages"; 12 | pub const ANTHROPIC_API_VERSION: &str = "2023-06-01"; 13 | 14 | // Mistral 15 | pub const MISTRAL_API_ENDPOINT: &str = "https://api.mistral.ai/v1/chat/completions"; 16 | 17 | // Google 18 | pub const GOOGLE_API_ENDPOINT_PREFIX: &str = "https://generativelanguage.googleapis.com"; 19 | 20 | // Ollama 21 | pub const OLLAMA_API_ENDPOINT: &str = "http://localhost:11434/api/chat"; 22 | 23 | // Rate limiting 24 | pub const DEFAULT_RATE_LIMIT_WAIT_SECS: u64 = 2; 25 | pub const MAX_RATE_LIMIT_WAIT_SECS: u64 = 60; -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "flyllm" 3 | version = "0.3.1" 4 | edition = "2021" 5 | description = "A Rust library for unifying LLM backends as an abstraction layer with load balancing." 6 | authors = ["Pablo Rodríguez "] 7 | license = "MIT" 8 | repository = "https://github.com/rodmarkun/flyllm" 9 | readme = "README.md" 10 | keywords = ["llm", "ai", "openai", "anthropic", "load-balancing"] 11 | 12 | [dependencies] 13 | async-trait = "0.1.88" 14 | futures = "0.3.31" 15 | json = "0.12.4" 16 | reqwest = { version = "0.12.15", features = ["json"] } 17 | serde = { version = "1.0", features = ["derive"] } 18 | serde_json = "1.0.140" 19 | tokio = { version = "1", features = ["macros", "rt-multi-thread", "sync"] } 20 | log = "0.4" 21 | env_logger = "0.10" 22 | rand = "0.9.1" 23 | url = "2.5.4" 24 | 25 | [dev-dependencies] 26 | tokio = { version = "1", features = ["macros", "rt-multi-thread", "sync"] } 27 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Pablo Rodríguez Martín 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /src/load_balancer/utils.rs: -------------------------------------------------------------------------------- 1 | use std::fs::{File, create_dir_all}; 2 | use std::io::Write; 3 | use std::path::PathBuf; 4 | use crate::errors::LlmError; 5 | 6 | pub fn get_debug_path(debug_folder: &PathBuf, timestamp: u64, instance_id: usize, instance_provider: &str, instance_model: &str) -> PathBuf { 7 | let timestamp_folder = debug_folder.join(timestamp.to_string()); 8 | let instance_folder = timestamp_folder.join(format!("{}_{}_{}", instance_id, instance_provider, instance_model)); 9 | instance_folder.join("debug.json") 10 | } 11 | 12 | pub fn write_to_debug_file(file_path: &PathBuf, contents: &str) -> Result<(), LlmError> { 13 | // Create parent directories if they don't exist 14 | if let Some(parent) = file_path.parent() { 15 | create_dir_all(parent) 16 | .map_err(|e| LlmError::ConfigError(format!("Failed to create debug directories: {}", e)))?; 17 | } 18 | 19 | let mut file = File::create(file_path) 20 | .map_err(|e| LlmError::ConfigError(format!("Failed to create debug file: {}", e)))?; 21 | 22 | file.write_all(contents.as_bytes()) 23 | .map_err(|e| LlmError::ConfigError(format!("Failed to write to debug file: {}", e)))?; 24 | 25 | Ok(()) 26 | } -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | All notable changes to FlyLLM will be documented in this file. 
 4 |
 5 | ## [0.3.1] - 2025-08-25
 6 | ### Added
 7 | - `From<&str>` conversion for `ProviderType` (added by request)
 8 |
 9 | ## [0.3.0] - 2025-08-06
10 | ### Added
11 | - Refactored the internals of FlyLLM, making them much simpler to modify and understand
12 | - Added optional debugging to LlmManager, allowing the user to automatically store all requests and their metadata in JSON files
13 |
14 | ## [0.2.3] - 2025-06-06
15 | ### Added
16 | - Rate limiting that waits whenever all providers are overloaded
17 |
18 | ## [0.2.2] - 2025-05-19
19 | ### Added
20 | - Made the library entirely asynchronous, making it more suitable for use in async contexts
21 |
22 | ## [0.2.1] - 2025-05-12
23 | ### Added
24 | - Capability to list all available models from all providers
25 |
26 | ## [0.2.0] - 2025-04-30
27 | ### Added
28 | - Ollama provider support
29 | - Builder pattern for easier configuration
30 | - Additional basic routing strategies
31 | - Optional custom endpoint configuration for any provider
32 |
33 | ## [0.1.0] - 2025-04-27
34 | ### Added
35 | - Initial release
36 | - Multiple Provider Support (OpenAI, Anthropic, Google, Mistral)
37 | - Task-Based Routing
38 | - Load Balancing
39 | - Failure Handling
40 | - Parallel Processing
41 | - Custom Parameters
42 | - Usage Tracking
--------------------------------------------------------------------------------
/src/load_balancer/tasks.rs:
--------------------------------------------------------------------------------
 1 | use std::collections::HashMap;
 2 | use serde::{Serialize, Deserialize};
 3 | use serde_json::{json, Value};
 4 |
 5 | /// Definition of a task that can be routed to specific providers
 6 | ///
 7 | /// Tasks represent specialized capabilities or configurations that
 8 | /// certain providers might be better suited for. Each task can have
 9 | /// associated parameters that affect how the request is processed.
10 | #[derive(Debug, Clone, Serialize, Deserialize)]
11 | pub struct TaskDefinition {
12 |     pub name: String,
13 |     pub parameters: HashMap<String, Value>,
14 | }
15 |
16 | impl TaskDefinition {
17 |     /// Creates a new TaskDefinition with the given name.
18 |     pub fn new(name: impl Into<String>) -> Self {
19 |         TaskDefinition {
20 |             name: name.into(),
21 |             parameters: HashMap::new(),
22 |         }
23 |     }
24 |
25 |     /// Adds or updates a parameter for this task definition.
26 |     /// Accepts any value that can be converted into a serde_json::Value.
27 |     pub fn with_param(mut self, key: impl Into<String>, value: impl Into<Value>) -> Self {
28 |         self.parameters.insert(key.into(), value.into());
29 |         self
30 |     }
31 |
32 |     /// Sets the `max_tokens` parameter to a given value.
33 |     pub fn with_max_tokens(self, tokens: u32) -> Self {
34 |         self.with_param("max_tokens", json!(tokens))
35 |     }
36 |
37 |     /// Sets the `temperature` parameter to a given value.
38 |     pub fn with_temperature(self, temp: f32) -> Self {
39 |         self.with_param("temperature", json!(temp))
40 |     }
41 | }
--------------------------------------------------------------------------------
/src/lib.rs:
--------------------------------------------------------------------------------
 1 | //! FlyLLM is a Rust library that provides a load-balanced, multi-provider client for Large Language Models.
 2 | //!
 3 | //! It enables developers to seamlessly work with multiple LLM providers (OpenAI, Anthropic, Google, Mistral...)
 4 | //! through a unified API with request routing, load balancing, and failure handling.
 5 | //!
 6 | //! # Features
 7 | //!
 8 | //!
- **Multi-provider support**: Integrate with OpenAI, Anthropic, Google, and Mistral 9 | //! - **Load balancing**: Distribute requests across multiple providers 10 | //! - **Automatic retries**: Handle provider failures with configurable retry policies 11 | //! - **Task routing**: Route specific tasks to the most suitable providers 12 | //! - **Metrics tracking**: Monitor response times, error rates, and token usage 13 | //! 14 | //! # Example 15 | //! 16 | //! ``` 17 | //! use flyllm::{LlmManager, ProviderType, GenerationRequest, TaskDefinition}; 18 | //! 19 | //! async fn example() { 20 | //! // Create a manager 21 | //! let mut manager = LlmManager::new(); 22 | //! 23 | //! // Add providers 24 | //! manager.add_provider( 25 | //! ProviderType::OpenAI, 26 | //! "api-key".to_string(), 27 | //! "gpt-4-turbo".to_string(), 28 | //! vec![], 29 | //! true 30 | //! ); 31 | //! 32 | //! // Generate a response 33 | //! let request = GenerationRequest { 34 | //! prompt: "Explain Rust in one paragraph".to_string(), 35 | //! task: None, 36 | //! params: None, 37 | //! }; 38 | //! 39 | //! let responses = manager.generate_sequentially(vec![request]).await; 40 | //! println!("{}", responses[0].content); 41 | //! } 42 | //! ``` 43 | 44 | pub mod providers; 45 | pub mod errors; 46 | pub mod constants; 47 | pub mod load_balancer; 48 | 49 | pub use providers::{ 50 | ProviderType, 51 | LlmRequest, 52 | LlmResponse, 53 | LlmInstance, 54 | create_instance, 55 | AnthropicInstance, 56 | OpenAIInstance, 57 | ModelInfo, 58 | ModelDiscovery 59 | }; 60 | 61 | pub use errors::{LlmError, LlmResult}; 62 | 63 | pub use load_balancer::{LlmManager, GenerationRequest, LlmManagerResponse, TaskDefinition}; 64 | 65 | /// Initialize the logging system 66 | /// 67 | /// This should be called at the start of your application in case 68 | /// you want to activate the library's debug and info logging. 
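///
/// # Example (illustrative sketch)
///
/// ```no_run
/// fn main() {
///     // Enable logging before using the manager. Output is controlled by the
///     // `RUST_LOG` environment variable (e.g. `RUST_LOG=debug`), as usual for `env_logger`.
///     flyllm::use_logging();
/// }
/// ```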
69 | pub fn use_logging() { 70 | env_logger::init(); 71 | } -------------------------------------------------------------------------------- /src/providers/types.rs: -------------------------------------------------------------------------------- 1 | use serde::{Serialize, Deserialize}; 2 | 3 | /// Enum representing the different LLM providers supported 4 | #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Hash)] 5 | pub enum ProviderType { 6 | Anthropic, 7 | OpenAI, 8 | Mistral, 9 | Google, 10 | Ollama 11 | } 12 | 13 | /// Unified request structure used across all providers 14 | #[derive(Debug, Serialize, Deserialize, Clone)] 15 | pub struct LlmRequest { 16 | pub messages: Vec, 17 | pub model: Option, 18 | pub max_tokens: Option, 19 | pub temperature: Option, 20 | } 21 | 22 | /// Standard message format used across providers 23 | #[derive(Debug, Serialize, Deserialize, Clone)] 24 | pub struct Message { 25 | pub role: String, 26 | pub content: String, 27 | } 28 | 29 | /// Unified response structure returned by all providers 30 | #[derive(Debug, Serialize, Deserialize, Clone)] 31 | pub struct LlmResponse { 32 | pub content: String, 33 | pub model: String, 34 | pub usage: Option, 35 | } 36 | 37 | /// Token usage information returned by providers 38 | #[derive(Debug, Serialize, Deserialize, Clone)] 39 | pub struct TokenUsage { 40 | pub prompt_tokens: u32, 41 | pub completion_tokens: u32, 42 | pub total_tokens: u32, 43 | } 44 | 45 | impl Default for TokenUsage { 46 | fn default() -> Self { 47 | Self { 48 | prompt_tokens: 0, 49 | completion_tokens: 0, 50 | total_tokens: 0 51 | } 52 | } 53 | } 54 | 55 | /// Information about an LLM model 56 | #[derive(Debug, Serialize, Deserialize, Clone)] 57 | pub struct ModelInfo { 58 | pub name: String, 59 | pub provider: ProviderType, 60 | } 61 | 62 | /// Display implementation for ProviderType 63 | impl std::fmt::Display for ProviderType { 64 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 65 | match self { 66 | ProviderType::Anthropic => write!(f, "Anthropic"), 67 | ProviderType::OpenAI => write!(f, "OpenAI"), 68 | ProviderType::Mistral => write!(f, "Mistral"), 69 | ProviderType::Google => write!(f, "Google"), 70 | ProviderType::Ollama => write!(f, "Ollama"), 71 | } 72 | } 73 | } 74 | 75 | impl From<&str> for ProviderType { 76 | fn from(value: &str) -> Self { 77 | match value { 78 | "Anthropic" => ProviderType::Anthropic, 79 | "OpenAI" => ProviderType::OpenAI, 80 | "Mistral" => ProviderType::Mistral, 81 | "Google" => ProviderType::Google, 82 | "Ollama" => ProviderType::Ollama, 83 | _ => panic!("Unknown provider: {}", value), 84 | } 85 | } 86 | } 87 | -------------------------------------------------------------------------------- /src/errors.rs: -------------------------------------------------------------------------------- 1 | use std::error::Error; 2 | use std::fmt; 3 | use serde_json; 4 | 5 | /// Custom error types for LLM operations 6 | #[derive(Debug)] 7 | pub enum LlmError { 8 | /// Error from the HTTP client 9 | RequestError(reqwest::Error), 10 | /// Error from the API provider 11 | ApiError(String), 12 | /// Rate limiting error 13 | RateLimit(String), 14 | /// Parsing error 15 | ParseError(String), 16 | /// Provider is disabled 17 | ProviderDisabled(String), 18 | /// Configuration error 19 | ConfigError(String), 20 | } 21 | 22 | impl fmt::Display for LlmError { 23 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 24 | match self { 25 | LlmError::RequestError(err) => write!(f, "Request error: {}", 
err),
26 |             LlmError::ApiError(msg) => write!(f, "API error: {}", msg),
27 |             LlmError::RateLimit(msg) => write!(f, "Rate limit error: {}", msg),
28 |             LlmError::ParseError(msg) => write!(f, "Parse error: {}", msg),
29 |             LlmError::ProviderDisabled(provider) => write!(f, "Provider disabled: {}", provider),
30 |             LlmError::ConfigError(msg) => write!(f, "Configuration error: {}", msg),
31 |         }
32 |     }
33 | }
34 |
35 | impl Error for LlmError {
36 |     fn source(&self) -> Option<&(dyn Error + 'static)> {
37 |         match self {
38 |             LlmError::RequestError(err) => Some(err),
39 |             _ => None,
40 |         }
41 |     }
42 | }
43 |
44 | /// Convert reqwest errors to LlmError
45 | impl From<reqwest::Error> for LlmError {
46 |     fn from(err: reqwest::Error) -> Self {
47 |         LlmError::RequestError(err)
48 |     }
49 | }
50 |
51 | /// Convert serde_json errors to LlmError
52 | impl From<serde_json::Error> for LlmError {
53 |     fn from(err: serde_json::Error) -> Self {
54 |         LlmError::ParseError(err.to_string())
55 |     }
56 | }
57 |
58 | /// Result type alias for LLM operations
59 | pub type LlmResult<T> = Result<T, LlmError>;
60 |
61 | impl LlmError {
62 |     /// Returns RateLimit error for 429 status or rate limit keywords
63 |     pub fn from_api_response(status: reqwest::StatusCode, error_message: String) -> Self {
64 |         if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
65 |             return LlmError::RateLimit(error_message);
66 |         }
67 |
68 |         // Check error message for rate limit indicators
69 |         let msg_lower = error_message.to_lowercase();
70 |         if msg_lower.contains("rate limit")
71 |             || msg_lower.contains("too many requests")
72 |             || msg_lower.contains("quota exceeded")
73 |             || msg_lower.contains("overloaded")
74 |             || msg_lower.contains("throttle") {
75 |             return LlmError::RateLimit(error_message);
76 |         }
77 |
78 |         LlmError::ApiError(error_message)
79 |     }
80 | }
--------------------------------------------------------------------------------
/src/load_balancer/tracker.rs:
--------------------------------------------------------------------------------
 1 | use crate::providers::LlmInstance;
 2 | use crate::{LlmResponse, LlmResult};
 3 | use std::time::{Duration, Instant};
 4 | use std::sync::Arc;
 5 |
 6 | /// An LLM provider instance with associated metrics
 7 | pub struct InstanceTracker {
 8 |     pub instance: Arc<dyn LlmInstance + Send + Sync>,
 9 |     pub last_used: Instant,
10 |     pub response_times: Vec<Duration>,
11 |     pub request_count: usize,
12 |     pub error_count: usize,
13 | }
14 |
15 | impl InstanceTracker {
16 |     /// Create a new tracker for an LLM provider instance
17 |     ///
18 |     /// # Parameters
19 |     /// * `instance` - Shared reference to the provider implementation
20 |     ///   whose usage metrics this tracker records
21 |     pub fn new(instance: Arc<dyn LlmInstance + Send + Sync>) -> Self {
22 |         Self {
23 |             instance,
24 |             last_used: Instant::now(),
25 |             response_times: Vec::new(),
26 |             request_count: 0,
27 |             error_count: 0,
28 |         }
29 |     }
30 |
31 |     /// Record the result of a request for metrics tracking
32 |     ///
33 |     /// # Parameters
34 |     /// * `duration` - How long the request took
35 |     /// * `result` - The result of the request (success or error)
36 |     pub fn record_result(&mut self, duration: Duration, result: &LlmResult<LlmResponse>) {
37 |         self.last_used = Instant::now();
38 |         self.request_count += 1;
39 |
40 |         match result {
41 |             Ok(_) => {
42 |                 self.response_times.push(duration);
43 |                 if self.response_times.len() > 10 {
44 |                     self.response_times.remove(0);
45 |                 } }
46 |             Err(e) => {
47 |                 self.error_count += 1;
48 |             }
49 |         }
50 |     }
51 |
52 |     /// Calculate the average response time from recent requests
53 |     ///
54 |     /// # Returns
55 |     /// * Average duration, or zero if no requests recorded
56 |     pub fn
avg_response_time(&self) -> Duration { 57 | if self.response_times.is_empty() { 58 | return Duration::from_millis(0); 59 | } 60 | let total: Duration = self.response_times.iter().sum(); 61 | total / self.response_times.len().max(1) as u32 // Avoid division by zero 62 | } 63 | 64 | /// Calculate the error rate as a percentage 65 | /// 66 | /// # Returns 67 | /// * Error rate from 0.0 to 100.0, or 0.0 if no requests 68 | pub fn get_error_rate(&self) -> f64 { 69 | if self.request_count > 0 { 70 | (self.error_count as f64 / self.request_count as f64) * 100.0 71 | } else { 72 | 0.0 73 | } 74 | } 75 | 76 | /// Check if this instance is currently enabled 77 | /// 78 | /// # Returns 79 | /// * Whether this instance is enabled or not 80 | pub fn is_enabled(&self) -> bool { 81 | self.instance.is_enabled() 82 | } 83 | 84 | /// Check if this instance supports a specific task 85 | /// 86 | /// # Returns 87 | /// * Whether this instance supports this task or not 88 | pub fn supports_task(&self, task_name: &str) -> bool { 89 | self.instance.get_supported_tasks().contains_key(task_name) 90 | } 91 | } -------------------------------------------------------------------------------- /src/providers/instances.rs: -------------------------------------------------------------------------------- 1 | use crate::load_balancer::tasks::TaskDefinition; 2 | use crate::providers::types::{LlmRequest, LlmResponse, ProviderType}; 3 | use crate::providers::anthropic::AnthropicInstance; 4 | use crate::providers::openai::OpenAIInstance; 5 | use crate::providers::ollama::OllamaInstance; 6 | use crate::providers::google::GoogleInstance; 7 | use crate::providers::mistral::MistralInstance; 8 | use crate::errors::LlmResult; 9 | use std::collections::HashMap; 10 | use std::sync::Arc; 11 | 12 | use async_trait::async_trait; 13 | use std::time::Duration; 14 | use reqwest::Client; 15 | 16 | /// Common interface for all LLM instances 17 | /// 18 | /// This trait defines the interface that all LLM instances must implement 19 | /// to be compatible with the load balancer system. 
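///
/// # Example: a custom instance (illustrative sketch)
///
/// A hypothetical `EchoInstance` that simply returns the last message it was given;
/// it is not part of the crate, but shows the shape of an implementation.
///
/// ```ignore
/// use std::collections::HashMap;
/// use async_trait::async_trait;
/// use flyllm::{LlmInstance, LlmRequest, LlmResponse, LlmResult, TaskDefinition};
///
/// struct EchoInstance {
///     tasks: HashMap<String, TaskDefinition>,
/// }
///
/// #[async_trait]
/// impl LlmInstance for EchoInstance {
///     async fn generate(&self, request: &LlmRequest) -> LlmResult<LlmResponse> {
///         // Echo the content of the last message back to the caller.
///         let content = request.messages.last()
///             .map(|m| m.content.clone())
///             .unwrap_or_default();
///         Ok(LlmResponse { content, model: "echo".to_string(), usage: None })
///     }
///
///     fn get_name(&self) -> &str { "echo" }
///     fn get_model(&self) -> &str { "echo" }
///     fn get_supported_tasks(&self) -> &HashMap<String, TaskDefinition> { &self.tasks }
///     fn is_enabled(&self) -> bool { true }
/// }
/// ```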
20 | #[async_trait] 21 | pub trait LlmInstance { 22 | /// Generate a completion from the LLM instance 23 | async fn generate(&self, request: &LlmRequest) -> LlmResult; 24 | /// Get the name of this instance 25 | fn get_name(&self) -> &str; 26 | /// Get the currently configured model name 27 | fn get_model(&self) -> &str; 28 | /// Get the tasks this instance supports 29 | fn get_supported_tasks(&self) -> &HashMap; 30 | /// Check if this instance is enabled 31 | fn is_enabled(&self) -> bool; 32 | } 33 | 34 | /// Base instance implementation with common functionality 35 | /// 36 | /// Handles common properties and functionality shared across all instances: 37 | /// - HTTP client with timeout 38 | /// - API key storage 39 | /// - Model selection 40 | /// - Task support 41 | /// - Enable/disable status 42 | pub struct BaseInstance { 43 | name: String, 44 | client: Client, 45 | api_key: String, 46 | model: String, 47 | supported_tasks: HashMap, 48 | enabled: bool, 49 | } 50 | 51 | impl BaseInstance { 52 | /// Create a new Baseinstance with specified parameters 53 | /// 54 | /// # Parameters 55 | /// * `name` - instance name identifier 56 | /// * `api_key` - API key for authentication 57 | /// * `model` - Default model identifier to use 58 | /// * `supported_tasks` - Map of tasks this instance supports 59 | /// * `enabled` - Whether this instance is enabled 60 | pub fn new(name: String, api_key: String, model: String, supported_tasks: HashMap, enabled: bool) -> Self { 61 | let client = Client::builder() 62 | .timeout(Duration::from_secs(120)) 63 | .build() 64 | .expect("Failed to create HTTP client"); 65 | 66 | Self { name, client, api_key, model, supported_tasks, enabled } 67 | } 68 | 69 | /// Get the HTTP client instance 70 | pub fn client(&self) -> &Client { 71 | &self.client 72 | } 73 | 74 | /// Get the API key 75 | pub fn api_key(&self) -> &str { 76 | &self.api_key 77 | } 78 | 79 | /// Get the current model name 80 | pub fn model(&self) -> &str { 81 | &self.model 82 | } 83 | 84 | /// Check if this instance is enabled 85 | pub fn is_enabled(&self) -> bool { 86 | self.enabled 87 | } 88 | 89 | /// Get the instance name 90 | pub fn name(&self) -> &str { 91 | &self.name 92 | } 93 | 94 | /// Get the map of supported tasks 95 | pub fn supported_tasks(&self) -> &HashMap { 96 | &self.supported_tasks 97 | } 98 | } 99 | 100 | /// Factory function to create a instance instance based on type 101 | /// 102 | /// # Parameters 103 | /// * `instance_type` - Which instance type to create 104 | /// * `api_key` - API key for authentication 105 | /// * `model` - Default model identifier 106 | /// * `supported_tasks` - List of tasks this instance supports 107 | /// * `enabled` - Whether this instance should be enabled 108 | /// 109 | /// # Returns 110 | /// * Arc-wrapped trait object implementing Llminstance 111 | pub fn create_instance(instance_type: ProviderType, api_key: String, model: String, supported_tasks: Vec, enabled: bool, endpoint_url: Option) -> Arc { 112 | let supported_tasks: HashMap = supported_tasks 113 | .into_iter() 114 | .map(|task| (task.name.clone(), task)) 115 | .collect(); 116 | match instance_type { 117 | ProviderType::Anthropic => Arc::new(AnthropicInstance::new(api_key, model, supported_tasks, enabled)), 118 | ProviderType::OpenAI => Arc::new(OpenAIInstance::new(api_key, model, supported_tasks, enabled)), 119 | ProviderType::Mistral => Arc::new(MistralInstance::new(api_key, model, supported_tasks, enabled)), 120 | ProviderType::Google => Arc::new(GoogleInstance::new(api_key, model, 
supported_tasks, enabled)), 121 | ProviderType::Ollama => Arc::new(OllamaInstance::new(api_key, model, supported_tasks, enabled, endpoint_url)) 122 | } 123 | } -------------------------------------------------------------------------------- /src/providers/openai.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | 3 | use crate::load_balancer::tasks::TaskDefinition; 4 | use crate::providers::instances::{LlmInstance, BaseInstance}; 5 | use crate::providers::types::{LlmRequest, LlmResponse, TokenUsage, Message}; 6 | use crate::errors::{LlmError, LlmResult}; 7 | use crate::constants; 8 | 9 | use async_trait::async_trait; 10 | use reqwest::header; 11 | use serde::{Serialize, Deserialize}; 12 | 13 | /// Provider implementation for OpenAI's API (GPT models) 14 | pub struct OpenAIInstance { 15 | base: BaseInstance, 16 | } 17 | 18 | /// Request structure for OpenAI's chat completion API 19 | /// Maps to the format expected by OpenAI's API 20 | #[derive(Serialize)] 21 | struct OpenAIRequest { 22 | model: String, 23 | messages: Vec, 24 | #[serde(skip_serializing_if = "Option::is_none")] 25 | max_tokens: Option, 26 | #[serde(skip_serializing_if = "Option::is_none")] 27 | temperature: Option, 28 | } 29 | 30 | /// Response structure from OpenAI's chat completion API 31 | #[derive(Deserialize)] 32 | struct OpenAIResponse { 33 | choices: Vec, 34 | model: String, 35 | usage: Option, 36 | } 37 | 38 | /// Individual choice from OpenAI's response 39 | #[derive(Deserialize)] 40 | struct OpenAIChoice { 41 | message: Message, 42 | } 43 | 44 | /// Token usage information from OpenAI 45 | #[derive(Deserialize)] 46 | struct OpenAIUsage { 47 | prompt_tokens: u32, 48 | completion_tokens: u32, 49 | total_tokens: u32, 50 | } 51 | 52 | impl OpenAIInstance { 53 | /// Creates a new OpenAI provider instance 54 | /// 55 | /// # Parameters 56 | /// * `api_key` - OpenAI API key 57 | /// * `model` - Default model to use (e.g. 
"gpt-4-turbo") 58 | /// * `supported_tasks` - Map of tasks this provider supports 59 | /// * `enabled` - Whether this provider is enabled 60 | pub fn new(api_key: String, model: String, supported_tasks: HashMap, enabled: bool) -> Self { 61 | let base = BaseInstance::new("openai".to_string(), api_key, model, supported_tasks, enabled); 62 | Self { base } 63 | } 64 | } 65 | 66 | #[async_trait] 67 | impl LlmInstance for OpenAIInstance { 68 | /// Generates a completion using OpenAI's API 69 | /// 70 | /// # Parameters 71 | /// * `request` - The LLM request containing messages and parameters 72 | /// 73 | /// # Returns 74 | /// * `LlmResult` - The response from the model or an error 75 | async fn generate(&self, request: &LlmRequest) -> LlmResult { 76 | if !self.base.is_enabled() { 77 | return Err(LlmError::ProviderDisabled("OpenAI".to_string())); 78 | } 79 | 80 | let mut headers = header::HeaderMap::new(); 81 | headers.insert( 82 | header::AUTHORIZATION, 83 | header::HeaderValue::from_str(&format!("Bearer {}", self.base.api_key())) 84 | .map_err(|e| LlmError::ConfigError(format!("Invalid API key format: {}", e)))?, 85 | ); 86 | headers.insert( 87 | header::CONTENT_TYPE, 88 | header::HeaderValue::from_static("application/json"), 89 | ); 90 | 91 | let model = request.model.clone().unwrap_or_else(|| self.base.model().to_string()); 92 | 93 | let openai_request = OpenAIRequest { 94 | model, 95 | messages: request.messages.clone(), 96 | max_tokens: request.max_tokens, 97 | temperature: request.temperature, 98 | }; 99 | 100 | let response = self.base.client() 101 | .post(constants::OPENAI_API_ENDPOINT) 102 | .headers(headers) 103 | .json(&openai_request) 104 | .send() 105 | .await?; 106 | 107 | if !response.status().is_success() { 108 | let error_text = response.text().await 109 | .unwrap_or_else(|_| "Unknown error".to_string()); 110 | return Err(LlmError::ApiError(format!("OpenAI API error: {}", error_text))); 111 | } 112 | 113 | let openai_response: OpenAIResponse = response.json().await?; 114 | 115 | if openai_response.choices.is_empty() { 116 | return Err(LlmError::ApiError("No response from OpenAI".to_string())); 117 | } 118 | 119 | let usage = openai_response.usage.map(|u| TokenUsage { 120 | prompt_tokens: u.prompt_tokens, 121 | completion_tokens: u.completion_tokens, 122 | total_tokens: u.total_tokens, 123 | }); 124 | 125 | Ok(LlmResponse { 126 | content: openai_response.choices[0].message.content.clone(), 127 | model: openai_response.model, 128 | usage, 129 | }) 130 | } 131 | 132 | /// Returns provider name 133 | fn get_name(&self) -> &str { 134 | self.base.name() 135 | } 136 | 137 | /// Returns current model name 138 | fn get_model(&self) -> &str { 139 | self.base.model() 140 | } 141 | 142 | /// Returns supported tasks for this provider 143 | fn get_supported_tasks(&self) -> &HashMap { 144 | &self.base.supported_tasks() 145 | } 146 | 147 | /// Returns whether this provider is enabled 148 | fn is_enabled(&self) -> bool { 149 | self.base.is_enabled() 150 | } 151 | } -------------------------------------------------------------------------------- /src/load_balancer/strategies.rs: -------------------------------------------------------------------------------- 1 | use log::debug; 2 | use rand::Rng; 3 | 4 | use crate::load_balancer::tracker::InstanceTracker; 5 | 6 | /// Trait defining the interface for load balancing strategies 7 | /// 8 | /// Implementations of this trait determine how to select which LLM instance 9 | /// will handle a particular request based on instance metrics. 
10 | pub trait LoadBalancingStrategy { 11 | /// Select an instance from available candidates 12 | /// 13 | /// # Parameters 14 | /// * `trackers` - Array of (id, tracker) tuples for available instances 15 | /// 16 | /// # Returns 17 | /// * Index into the trackers array of the selected instance 18 | fn select_instance(&mut self, trackers: &[(usize, &InstanceTracker)]) -> usize; 19 | } 20 | 21 | /// Strategy that selects the instance that was used least recently 22 | /// 23 | /// This strategy helps distribute load by prioritizing instances 24 | /// that haven't been used in the longest time. 25 | pub struct LeastRecentlyUsedStrategy; 26 | 27 | impl LeastRecentlyUsedStrategy { 28 | /// Creates a new LeastRecentlyUsedStrategy 29 | pub fn new() -> Self { 30 | Self {} 31 | } 32 | } 33 | 34 | impl LoadBalancingStrategy for LeastRecentlyUsedStrategy { 35 | /// Select the instance with the oldest last_used timestamp 36 | /// 37 | /// # Parameters 38 | /// * `trackers` - Array of (id, tracker) tuples for available instances 39 | /// 40 | /// # Returns 41 | /// * Index into the trackers array of the least recently used instance 42 | /// 43 | /// # Panics 44 | /// Panics if `trackers` is empty 45 | fn select_instance(&mut self, trackers: &[(usize, &InstanceTracker)]) -> usize { 46 | if trackers.is_empty() { 47 | panic!("LoadBalancingStrategy::select_instance called with empty trackers slice"); 48 | } 49 | 50 | let mut oldest_index = 0; 51 | let mut oldest_time = trackers[0].1.last_used; 52 | 53 | for (i, (_id, tracker)) in trackers.iter().enumerate().skip(1) { 54 | if tracker.last_used < oldest_time { 55 | oldest_index = i; 56 | oldest_time = tracker.last_used; 57 | } 58 | } 59 | 60 | debug!( 61 | "LeastRecentlyUsedStrategy: Selected index {} (ID: {}) from {} eligible trackers with last_used: {:?}", 62 | oldest_index, trackers[oldest_index].0, trackers.len(), oldest_time 63 | ); 64 | 65 | oldest_index 66 | } 67 | } 68 | 69 | 70 | /// Strategy that selects the instance with the lowest average response time. 71 | #[derive(Debug, Default)] 72 | pub struct LowestLatencyStrategy; 73 | 74 | impl LowestLatencyStrategy { 75 | /// Creates a new LowestLatencyStrategy 76 | pub fn new() -> Self { 77 | Self {} 78 | } 79 | } 80 | 81 | impl LoadBalancingStrategy for LowestLatencyStrategy { 82 | /// Select the instance with the minimum `avg_response_time`. 83 | /// 84 | /// # Parameters 85 | /// * `trackers` - Array of (id, tracker) tuples for available instances. 86 | /// 87 | /// # Returns 88 | /// * Index into the trackers array of the fastest instance. 89 | /// 90 | /// # Panics 91 | /// * Panics if `trackers` is empty. 92 | fn select_instance(&mut self, trackers: &[(usize, &InstanceTracker)]) -> usize { 93 | if trackers.is_empty() { 94 | panic!("LowestLatencyStrategy::select_instance called with empty trackers slice"); 95 | } 96 | 97 | let mut best_index = 0; 98 | let mut lowest_time = trackers[0].1.avg_response_time(); 99 | 100 | for (i, (_id, tracker)) in trackers.iter().enumerate().skip(1) { 101 | let avg_time = tracker.avg_response_time(); 102 | if avg_time < lowest_time { 103 | best_index = i; 104 | lowest_time = avg_time; 105 | } 106 | } 107 | 108 | debug!( 109 | "LowestLatencyStrategy: Selected index {} (ID: {}) from {} eligible trackers with avg_response_time: {:?}", 110 | best_index, trackers[best_index].0, trackers.len(), lowest_time 111 | ); 112 | 113 | best_index 114 | } 115 | } 116 | 117 | /// Strategy that selects a random instance from the available pool. 
118 | #[derive(Debug, Default)] 119 | pub struct RandomStrategy; 120 | 121 | impl RandomStrategy { 122 | /// Creates a new RandomStrategy 123 | pub fn new() -> Self { 124 | Self {} 125 | } 126 | } 127 | 128 | impl LoadBalancingStrategy for RandomStrategy { 129 | /// Select a random instance. 130 | /// 131 | /// # Parameters 132 | /// * `trackers` - Array of (id, tracker) tuples for available instances. 133 | /// 134 | /// # Returns 135 | /// * Index into the trackers array of a randomly chosen instance. 136 | /// 137 | /// # Panics 138 | /// * Panics if `trackers` is empty. 139 | fn select_instance(&mut self, trackers: &[(usize, &InstanceTracker)]) -> usize { 140 | if trackers.is_empty() { 141 | panic!("RandomStrategy::select_instance called with empty trackers slice"); 142 | } 143 | 144 | let index = rand::rng().random_range(0..trackers.len()); 145 | 146 | debug!( 147 | "RandomStrategy: Selected random index {} (ID: {}) from {} eligible trackers", 148 | index, trackers[index].0, trackers.len() 149 | ); 150 | 151 | index 152 | } 153 | } -------------------------------------------------------------------------------- /src/providers/mistral.rs: -------------------------------------------------------------------------------- 1 | use crate::load_balancer::tasks::TaskDefinition; 2 | use crate::providers::instances::{LlmInstance, BaseInstance}; 3 | use crate::providers::types::{LlmRequest, LlmResponse, TokenUsage, Message}; 4 | use crate::errors::{LlmError, LlmResult}; 5 | use crate::constants; 6 | 7 | use std::collections::HashMap; 8 | use async_trait::async_trait; 9 | use reqwest::header; 10 | use serde::{Serialize, Deserialize}; 11 | 12 | /// Provider implementation for Mistral AI's API 13 | pub struct MistralInstance { 14 | base: BaseInstance, 15 | } 16 | 17 | /// Request structure for Mistral AI's chat completion API 18 | #[derive(Serialize)] 19 | struct MistralRequest { 20 | model: String, 21 | messages: Vec, 22 | #[serde(skip_serializing_if = "Option::is_none")] 23 | temperature: Option, 24 | #[serde(skip_serializing_if = "Option::is_none")] 25 | max_tokens: Option, 26 | } 27 | 28 | /// Response structure from Mistral AI's chat completion API 29 | #[derive(Deserialize, Debug)] 30 | struct MistralResponse { 31 | id: String, 32 | model: String, 33 | object: String, 34 | created: u64, 35 | choices: Vec, 36 | usage: Option, 37 | } 38 | 39 | /// Individual choice from Mistral's response 40 | #[derive(Deserialize, Debug)] 41 | struct MistralChoice { 42 | index: u32, // Removed underscore prefix 43 | message: Message, 44 | finish_reason: Option, 45 | } 46 | 47 | /// Token usage information from Mistral 48 | #[derive(Deserialize, Debug)] 49 | struct MistralUsage { 50 | prompt_tokens: u32, 51 | completion_tokens: u32, 52 | total_tokens: u32, 53 | } 54 | 55 | impl MistralInstance { 56 | /// Creates a new Mistral provider instance 57 | /// 58 | /// # Parameters 59 | /// * `api_key` - Mistral API key 60 | /// * `model` - Default model to use (e.g. 
"mistral-large") 61 | /// * `supported_tasks` - Map of tasks this provider supports 62 | /// * `enabled` - Whether this provider is enabled 63 | pub fn new(api_key: String, model: String, supported_tasks: HashMap, enabled: bool) -> Self { 64 | let base = BaseInstance::new("mistral".to_string(), api_key, model, supported_tasks, enabled); 65 | Self { base } 66 | } 67 | } 68 | 69 | #[async_trait] 70 | impl LlmInstance for MistralInstance { 71 | /// Generates a completion using Mistral AI's API 72 | /// 73 | /// # Parameters 74 | /// * `request` - The LLM request containing messages and parameters 75 | /// 76 | /// # Returns 77 | /// * `LlmResult` - The response from the model or an error 78 | async fn generate(&self, request: &LlmRequest) -> LlmResult { 79 | if !self.base.is_enabled() { 80 | return Err(LlmError::ProviderDisabled("Mistral".to_string())); 81 | } 82 | 83 | let mut headers = header::HeaderMap::new(); 84 | headers.insert( 85 | header::AUTHORIZATION, 86 | header::HeaderValue::from_str(&format!("Bearer {}", self.base.api_key())) 87 | .map_err(|e| LlmError::ConfigError(format!("Invalid API key format: {}", e)))?, 88 | ); 89 | headers.insert( 90 | header::CONTENT_TYPE, 91 | header::HeaderValue::from_static("application/json"), 92 | ); 93 | headers.insert( 94 | header::ACCEPT, 95 | header::HeaderValue::from_static("application/json"), 96 | ); 97 | 98 | let model = request.model.clone().unwrap_or_else(|| self.base.model().to_string()); 99 | 100 | if request.messages.is_empty() { 101 | return Err(LlmError::ApiError("Mistral requires at least one message".to_string())); 102 | } 103 | 104 | let mistral_request = MistralRequest { 105 | model, 106 | messages: request.messages.iter().map(|m| Message { 107 | role: match m.role.as_str() { 108 | "system" | "user" | "assistant" => m.role.clone(), 109 | _ => "user".to_string(), 110 | }, 111 | content: m.content.clone() 112 | }).collect(), 113 | temperature: request.temperature, 114 | max_tokens: request.max_tokens, 115 | }; 116 | 117 | let response = self.base.client() 118 | .post(constants::MISTRAL_API_ENDPOINT) 119 | .headers(headers) 120 | .json(&mistral_request) 121 | .send() 122 | .await?; 123 | 124 | if !response.status().is_success() { 125 | let status = response.status(); 126 | let error_text = response.text().await 127 | .unwrap_or_else(|_| "Unknown error reading response body".to_string()); 128 | return Err(LlmError::ApiError(format!( 129 | "Mistral API error ({}): {}", 130 | status, error_text 131 | ))); 132 | } 133 | 134 | // Debug: Log raw response body for inspection if needed 135 | let response_body = response.text().await 136 | .map_err(|e| LlmError::ApiError(format!("Failed to read Mistral response body: {}", e)))?; 137 | 138 | // Try to parse the response as JSON 139 | let mistral_response: MistralResponse = serde_json::from_str(&response_body) 140 | .map_err(|e| { 141 | // Provide more context in the error message 142 | LlmError::ApiError(format!( 143 | "Failed to parse Mistral JSON response: {}. Response body: {}", 144 | e, 145 | if response_body.len() > 200 { 146 | format!("{}... 
(truncated)", &response_body[..200]) 147 | } else { 148 | response_body.clone() 149 | } 150 | )) 151 | })?; 152 | 153 | if mistral_response.choices.is_empty() { 154 | return Err(LlmError::ApiError("No choices returned from Mistral".to_string())); 155 | } 156 | 157 | let choice = &mistral_response.choices[0]; 158 | 159 | let usage = mistral_response.usage.map(|u| TokenUsage { 160 | prompt_tokens: u.prompt_tokens, 161 | completion_tokens: u.completion_tokens, 162 | total_tokens: u.total_tokens, 163 | }); 164 | 165 | Ok(LlmResponse { 166 | content: choice.message.content.clone(), 167 | model: mistral_response.model, 168 | usage, 169 | }) 170 | } 171 | 172 | /// Returns provider name 173 | fn get_name(&self) -> &str { 174 | self.base.name() 175 | } 176 | 177 | /// Returns current model name 178 | fn get_model(&self) -> &str { 179 | self.base.model() 180 | } 181 | 182 | /// Returns supported tasks for this provider 183 | fn get_supported_tasks(&self) -> &HashMap { 184 | &self.base.supported_tasks() 185 | } 186 | 187 | /// Returns whether this provider is enabled 188 | fn is_enabled(&self) -> bool { 189 | self.base.is_enabled() 190 | } 191 | } -------------------------------------------------------------------------------- /src/providers/anthropic.rs: -------------------------------------------------------------------------------- 1 | use crate::load_balancer::tasks::TaskDefinition; 2 | use crate::providers::instances::{LlmInstance, BaseInstance}; 3 | use crate::providers::types::{LlmRequest, LlmResponse, TokenUsage}; 4 | use crate::errors::{LlmError, LlmResult}; 5 | use crate::constants; 6 | 7 | use async_trait::async_trait; 8 | use reqwest::header; 9 | use serde::{Serialize, Deserialize}; 10 | use std::collections::HashMap; 11 | 12 | /// Provider implementation for Anthropic's Claude API 13 | pub struct AnthropicInstance { 14 | base: BaseInstance, 15 | } 16 | 17 | /// Request structure for the Anthropic Claude API 18 | #[derive(Serialize)] 19 | struct AnthropicRequest { 20 | model: String, 21 | #[serde(skip_serializing_if = "Option::is_none")] 22 | system: Option, 23 | messages: Vec, 24 | max_tokens: u32, 25 | #[serde(skip_serializing_if = "Option::is_none")] 26 | temperature: Option, 27 | } 28 | 29 | /// Individual message structure for Anthropic's API 30 | #[derive(Serialize)] 31 | struct AnthropicMessage { 32 | role: String, 33 | content: String, 34 | } 35 | 36 | /// Response structure from Anthropic's Claude API 37 | #[derive(Deserialize)] 38 | struct AnthropicResponse { 39 | content: Vec, 40 | model: String, 41 | usage: Option, 42 | } 43 | 44 | /// Content block from Anthropic's response 45 | #[derive(Deserialize)] 46 | struct AnthropicContent { 47 | text: String, 48 | #[serde(rename = "type")] 49 | content_type: String, 50 | } 51 | 52 | /// Token usage information from Anthropic 53 | #[derive(Deserialize)] 54 | struct AnthropicUsage { 55 | input_tokens: u32, 56 | output_tokens: u32, 57 | } 58 | 59 | impl AnthropicInstance { 60 | /// Creates a new Anthropic provider instance 61 | /// 62 | /// # Parameters 63 | /// * `api_key` - Anthropic API key 64 | /// * `model` - Default model to use (e.g. 
"claude-3-opus-20240229") 65 | /// * `supported_tasks` - Map of tasks this provider supports 66 | /// * `enabled` - Whether this provider is enabled 67 | pub fn new(api_key: String, model: String, supported_tasks: HashMap, enabled: bool) -> Self { 68 | let base = BaseInstance::new("anthropic".to_string(), api_key, model, supported_tasks, enabled); 69 | Self { base } 70 | } 71 | } 72 | 73 | #[async_trait] 74 | impl LlmInstance for AnthropicInstance { 75 | /// Generates a completion using Anthropic's Claude API 76 | /// 77 | /// # Parameters 78 | /// * `request` - The LLM request containing messages and parameters 79 | /// 80 | /// # Returns 81 | /// * `LlmResult` - The response from the model or an error 82 | async fn generate(&self, request: &LlmRequest) -> LlmResult { 83 | if !self.base.is_enabled() { 84 | return Err(LlmError::ProviderDisabled("Anthropic".to_string())); 85 | } 86 | 87 | let mut headers = header::HeaderMap::new(); 88 | headers.insert( 89 | "x-api-key", 90 | header::HeaderValue::from_str(self.base.api_key()) 91 | .map_err(|e| LlmError::ConfigError(format!("Invalid API key format: {}", e)))?, 92 | ); 93 | headers.insert( 94 | header::CONTENT_TYPE, 95 | header::HeaderValue::from_static("application/json"), 96 | ); 97 | headers.insert( 98 | "anthropic-version", 99 | header::HeaderValue::from_static(constants::ANTHROPIC_API_VERSION), 100 | ); 101 | 102 | let model = request.model.clone().unwrap_or_else(|| self.base.model().to_string()); 103 | 104 | // Extract system message and regular messages 105 | let mut system_content = None; 106 | let mut regular_messages = Vec::new(); 107 | 108 | for msg in &request.messages { 109 | if msg.role == "system" { 110 | system_content = Some(msg.content.clone()); 111 | } else { 112 | regular_messages.push(AnthropicMessage { 113 | role: msg.role.clone(), 114 | content: msg.content.clone(), 115 | }); 116 | } 117 | } 118 | 119 | // Ensure we have at least one message 120 | if regular_messages.is_empty() && system_content.is_some() { 121 | regular_messages.push(AnthropicMessage { 122 | role: "user".to_string(), 123 | content: format!("Using this context: {}", system_content.unwrap()), 124 | }); 125 | system_content = None; 126 | } 127 | 128 | if regular_messages.is_empty() { 129 | return Err(LlmError::ApiError("Anthropic requires at least one message".to_string())); 130 | } 131 | 132 | let anthropic_request = AnthropicRequest { 133 | model, 134 | system: system_content, 135 | messages: regular_messages, 136 | max_tokens: request.max_tokens.unwrap_or(constants::DEFAULT_MAX_TOKENS), 137 | temperature: request.temperature, 138 | }; 139 | 140 | let response = self.base.client() 141 | .post(constants::ANTHROPIC_API_ENDPOINT) 142 | .headers(headers) 143 | .json(&anthropic_request) 144 | .send() 145 | .await?; 146 | 147 | if !response.status().is_success() { 148 | let error_text = response.text().await 149 | .unwrap_or_else(|_| "Unknown error".to_string()); 150 | return Err(LlmError::ApiError(format!("Anthropic API error: {}", error_text))); 151 | } 152 | 153 | let anthropic_response: AnthropicResponse = response.json().await?; 154 | 155 | if anthropic_response.content.is_empty() { 156 | return Err(LlmError::ApiError("No response from Anthropic".to_string())); 157 | } 158 | 159 | let usage = anthropic_response.usage.map(|u| TokenUsage { 160 | prompt_tokens: u.input_tokens, 161 | completion_tokens: u.output_tokens, 162 | total_tokens: u.input_tokens + u.output_tokens, 163 | }); 164 | 165 | let text = anthropic_response.content.iter() 166 | .filter(|c| 
c.content_type == "text") 167 | .map(|c| c.text.clone()) 168 | .collect::>() 169 | .join(""); 170 | 171 | Ok(LlmResponse { 172 | content: text, 173 | model: anthropic_response.model, 174 | usage, 175 | }) 176 | } 177 | 178 | /// Returns provider name 179 | fn get_name(&self) -> &str { 180 | self.base.name() 181 | } 182 | 183 | /// Returns current model name 184 | fn get_model(&self) -> &str { 185 | self.base.model() 186 | } 187 | 188 | /// Returns supported tasks for this provider 189 | fn get_supported_tasks(&self) -> &HashMap{ 190 | self.base.supported_tasks() 191 | } 192 | 193 | /// Returns whether this provider is enabled 194 | fn is_enabled(&self) -> bool { 195 | self.base.is_enabled() 196 | } 197 | } -------------------------------------------------------------------------------- /src/load_balancer/builder.rs: -------------------------------------------------------------------------------- 1 | use crate::errors::{LlmResult, LlmError}; 2 | use crate::load_balancer::strategies::{LoadBalancingStrategy, LeastRecentlyUsedStrategy}; 3 | use crate::load_balancer::tasks::TaskDefinition; 4 | use crate::{ProviderType, constants}; 5 | use std::collections::HashMap; 6 | use std::path::PathBuf; 7 | use log::debug; 8 | use super::LlmManager; 9 | 10 | /// Internal helper struct for Builder 11 | #[derive(Clone)] 12 | struct ProviderConfig { 13 | provider_type: ProviderType, 14 | api_key: String, 15 | model: String, 16 | supported_task_names: Vec, 17 | enabled: bool, 18 | custom_endpoint: Option, 19 | } 20 | 21 | /// LlmManager Builder 22 | pub struct LlmManagerBuilder { 23 | defined_tasks: HashMap, 24 | providers_to_build: Vec, 25 | strategy: Box, 26 | max_retries: usize, 27 | debug_folder: Option 28 | } 29 | 30 | impl LlmManagerBuilder { 31 | /// Creates a new builder with default settings. 32 | pub fn new() -> Self { 33 | LlmManagerBuilder { 34 | defined_tasks: HashMap::new(), 35 | providers_to_build: Vec::new(), 36 | strategy: Box::new(LeastRecentlyUsedStrategy::new()), // Default strategy 37 | max_retries: constants::DEFAULT_MAX_TRIES, // Default retries 38 | debug_folder: None 39 | } 40 | } 41 | 42 | /// Defines a task that providers can later reference by name. 43 | pub fn define_task(mut self, task_def: TaskDefinition) -> Self { 44 | self.defined_tasks.insert(task_def.name.clone(), task_def); 45 | self 46 | } 47 | 48 | /// Sets the load balancing strategy for the manager. 49 | pub fn strategy(mut self, strategy: Box) -> Self { 50 | self.strategy = strategy; 51 | self 52 | } 53 | 54 | /// Sets the maximum number of retries for failed requests. 55 | pub fn max_retries(mut self, retries: usize) -> Self { 56 | self.max_retries = retries; 57 | self 58 | } 59 | 60 | /// Begins configuring a new provider instance. 61 | /// Subsequent calls like `.supports()`, `.enabled()`, `.custom_endpoint()` will apply to this provider. 62 | pub fn add_instance( 63 | mut self, 64 | provider_type: ProviderType, 65 | model: impl Into, 66 | api_key: impl Into, 67 | ) -> Self { 68 | let config = ProviderConfig { 69 | provider_type, 70 | api_key: api_key.into(), 71 | model: model.into(), 72 | supported_task_names: Vec::new(), 73 | enabled: true, // Default to enabled 74 | custom_endpoint: None, 75 | }; 76 | self.providers_to_build.push(config); 77 | self // Return self to allow chaining provider configurations 78 | } 79 | 80 | /// Specifies that the *last added* provider supports the task with the given name. 81 | /// Panics if `add_instance` was not called before this. 
82 | pub fn supports(mut self, task_name: &str) -> Self { 83 | match self.providers_to_build.last_mut() { 84 | Some(last_provider) => { 85 | if !self.defined_tasks.contains_key(task_name) { 86 | // Optional: Warn or error early if task isn't defined yet 87 | log::warn!("Provider configured to support task '{}' which has not been defined yet with define_task().", task_name); 88 | } 89 | last_provider.supported_task_names.push(task_name.to_string()); 90 | } 91 | None => { 92 | panic!("'.supports()' called before '.add_instance()'"); 93 | } 94 | } 95 | self 96 | } 97 | 98 | /// Specifies that the *last added* provider supports multiple tasks by name. 99 | /// Panics if `add_provider` was not called before this. 100 | pub fn supports_many(mut self, task_names: &[&str]) -> Self { 101 | match self.providers_to_build.last_mut() { 102 | Some(last_provider) => { 103 | for task_name in task_names { 104 | if !self.defined_tasks.contains_key(*task_name) { 105 | log::warn!("Provider configured to support task '{}' which has not been defined yet with define_task().", task_name); 106 | } 107 | last_provider.supported_task_names.push(task_name.to_string()); 108 | } 109 | } 110 | None => { 111 | panic!("'.supports_many()' called before '.add_provider()'"); 112 | } 113 | } 114 | self 115 | } 116 | 117 | 118 | /// Sets the enabled status for the *last added* provider. 119 | /// Panics if `add_provider` was not called before this. 120 | pub fn enabled(mut self, enabled: bool) -> Self { 121 | match self.providers_to_build.last_mut() { 122 | Some(last_provider) => { 123 | last_provider.enabled = enabled; 124 | } 125 | None => { 126 | panic!("'.enabled()' called before '.add_provider()'"); 127 | } 128 | } 129 | self 130 | } 131 | 132 | pub fn debug_folder(mut self, path: impl Into) -> Self { 133 | self.debug_folder = Some(path.into()); 134 | self 135 | } 136 | 137 | /// Sets a custom endpoint for the *last added* provider. 138 | /// Panics if `add_provider` was not called before this. 139 | pub fn custom_endpoint(mut self, endpoint: impl Into) -> Self { 140 | match self.providers_to_build.last_mut() { 141 | Some(last_provider) => { 142 | last_provider.custom_endpoint = Some(endpoint.into()); 143 | } 144 | None => { 145 | panic!("'.custom_endpoint()' called before '.add_provider()'"); 146 | } 147 | } 148 | self 149 | } 150 | 151 | 152 | /// Consumes the builder and constructs the `LlmManager`. 153 | /// Returns an error if a referenced task was not defined. 
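///
/// # Example (illustrative sketch; API keys, model names and task names are placeholders)
///
/// ```ignore
/// use flyllm::{ProviderType, TaskDefinition};
/// use flyllm::load_balancer::builder::LlmManagerBuilder;
///
/// async fn build_manager() -> flyllm::LlmResult<flyllm::LlmManager> {
///     LlmManagerBuilder::new()
///         // Define tasks first, then reference them by name per provider.
///         .define_task(TaskDefinition::new("summary").with_max_tokens(500).with_temperature(0.3))
///         .define_task(TaskDefinition::new("code_generation"))
///         .add_instance(ProviderType::OpenAI, "gpt-4-turbo", "openai-api-key")
///         .supports_many(&["summary", "code_generation"])
///         .add_instance(ProviderType::Anthropic, "claude-3-opus-20240229", "anthropic-api-key")
///         .supports("summary")
///         .max_retries(3)
///         .build()
///         .await
/// }
/// ```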
154 | pub async fn build(self) -> LlmResult { 155 | let mut manager = LlmManager::new_with_strategy_and_retries(self.strategy, self.max_retries); 156 | 157 | // Set debug folder if specified 158 | manager.debug_folder = self.debug_folder; 159 | 160 | for provider_config in self.providers_to_build { 161 | // Resolve TaskDefinition structs from names 162 | let mut provider_tasks: Vec = Vec::new(); 163 | for task_name in &provider_config.supported_task_names { 164 | match self.defined_tasks.get(task_name) { 165 | Some(task_def) => provider_tasks.push(task_def.clone()), 166 | None => return Err(LlmError::ConfigError(format!( 167 | "Build failed: Task '{}' referenced by provider '{}' ({}) was not defined using define_task()", 168 | task_name, provider_config.provider_type, provider_config.model 169 | ))), 170 | } 171 | } 172 | 173 | manager.add_instance( 174 | provider_config.provider_type, 175 | provider_config.api_key, 176 | provider_config.model.clone(), 177 | provider_tasks, 178 | provider_config.enabled, 179 | provider_config.custom_endpoint, 180 | ).await; 181 | debug!("Built and added provider: {} ({})", provider_config.provider_type, provider_config.model); 182 | } 183 | 184 | // Check if the manager has instances 185 | let trackers = manager.trackers.lock().await; 186 | let is_empty = trackers.is_empty(); 187 | drop(trackers); 188 | 189 | if is_empty { 190 | log::warn!("LlmManager built with no provider instances."); 191 | } 192 | 193 | Ok(manager) 194 | } 195 | } -------------------------------------------------------------------------------- /src/providers/ollama.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | use crate::load_balancer::tasks::TaskDefinition; 3 | use crate::providers::instances::{LlmInstance, BaseInstance}; 4 | use crate::providers::types::{LlmRequest, LlmResponse, TokenUsage, Message}; 5 | use crate::errors::{LlmError, LlmResult}; 6 | use crate::constants; 7 | use async_trait::async_trait; 8 | use reqwest::header; 9 | use serde::{Serialize, Deserialize}; 10 | use url::Url; 11 | 12 | /// Provider implementation for Ollama (local LLMs) 13 | pub struct OllamaInstance { 14 | base: BaseInstance, 15 | // Specific URL for this provider instance 16 | endpoint_url: String, 17 | } 18 | 19 | /// Request structure for Ollama's chat API 20 | #[derive(Serialize)] 21 | struct OllamaRequest { 22 | model: String, 23 | messages: Vec, 24 | stream: bool, 25 | #[serde(skip_serializing_if = "Option::is_none")] 26 | options: Option, 27 | } 28 | 29 | #[derive(Serialize, Default)] 30 | struct OllamaOptions { 31 | #[serde(skip_serializing_if = "Option::is_none")] 32 | temperature: Option, 33 | #[serde(skip_serializing_if = "Option::is_none")] 34 | num_predict: Option, // Corresponds to max_tokens 35 | } 36 | 37 | /// Response structure from Ollama's chat API (non-streaming) 38 | #[derive(Deserialize, Debug)] 39 | struct OllamaResponse { 40 | model: String, 41 | created_at: String, 42 | message: Message, 43 | done: bool, // Should be true for non-streaming response 44 | #[serde(default)] // Use default (0) if not present 45 | prompt_eval_count: u32, 46 | #[serde(default)] // Use default (0) if not present 47 | eval_count: u32, // Corresponds roughly to completion tokens 48 | } 49 | 50 | impl OllamaInstance { 51 | /// Creates a new Ollama provider instance 52 | /// 53 | /// # Parameters 54 | /// * `api_key` - Unused for Ollama by default, but kept for consistency. 
Could be repurposed (e.g., for future auth or endpoint override). 55 | /// * `model` - Default model to use (e.g., "llama3") 56 | /// * `supported_tasks` - Map of tasks this provider supports 57 | /// * `enabled` - Whether this provider is enabled 58 | /// * `endpoint_url` - Optional base endpoint URL override. If None, uses the default from constants. 59 | pub fn new(api_key: String, model: String, supported_tasks: HashMap, enabled: bool, endpoint_url: Option) -> Self { 60 | // Determine the endpoint: use provided one or default 61 | let base_endpoint = endpoint_url.unwrap_or_else(|| constants::OLLAMA_API_ENDPOINT.to_string()); 62 | 63 | // Validate and ensure the path ends correctly 64 | let final_endpoint = match Url::parse(&base_endpoint) { 65 | Ok(mut url) => { 66 | if !url.path().ends_with("/api/chat") { 67 | if url.path() == "/" { 68 | url.set_path("api/chat"); 69 | } else { 70 | let current_path = url.path().trim_end_matches('/'); 71 | url.set_path(&format!("{}/api/chat", current_path)); 72 | } 73 | } 74 | url.to_string() 75 | } 76 | Err(_) => { 77 | eprintln!( 78 | "Warning: Invalid Ollama endpoint URL '{}' provided. Falling back to default: {}", 79 | base_endpoint, constants::OLLAMA_API_ENDPOINT 80 | ); 81 | constants::OLLAMA_API_ENDPOINT.to_string() 82 | } 83 | }; 84 | 85 | // Create BaseProvider with the actual API key (even if empty/unused) 86 | let base = BaseInstance::new("ollama".to_string(), api_key, model, supported_tasks, enabled); 87 | 88 | Self { 89 | base, 90 | endpoint_url: final_endpoint, 91 | } 92 | } 93 | } 94 | 95 | #[async_trait] 96 | impl LlmInstance for OllamaInstance { 97 | /// Generates a completion using Ollama's API 98 | /// 99 | /// # Parameters 100 | /// * `request` - The LLM request containing messages and parameters 101 | /// 102 | /// # Returns 103 | /// * `LlmResult` - The response from the model or an error 104 | async fn generate(&self, request: &LlmRequest) -> LlmResult { 105 | if !self.base.is_enabled() { 106 | return Err(LlmError::ProviderDisabled("Ollama".to_string())); 107 | } 108 | 109 | let mut headers = header::HeaderMap::new(); 110 | headers.insert( 111 | header::CONTENT_TYPE, 112 | header::HeaderValue::from_static("application/json"), 113 | ); 114 | 115 | // Add Authorization header if an API key is actually provided and non-empty 116 | if !self.base.api_key().is_empty() { 117 | match header::HeaderValue::from_str(&format!("Bearer {}", self.base.api_key())) { 118 | Ok(val) => { headers.insert(header::AUTHORIZATION, val); }, 119 | Err(e) => return Err(LlmError::ConfigError(format!("Invalid API key format for Ollama: {}", e))), 120 | } 121 | } 122 | 123 | let model = request.model.clone().unwrap_or_else(|| self.base.model().to_string()); 124 | 125 | // Map common parameters to Ollama options 126 | let mut options = OllamaOptions::default(); 127 | if request.temperature.is_some() { 128 | options.temperature = request.temperature; 129 | } 130 | if request.max_tokens.is_some() { 131 | options.num_predict = request.max_tokens; 132 | } 133 | 134 | let ollama_request = OllamaRequest { 135 | model, 136 | messages: request.messages.clone(), 137 | stream: false, 138 | options: if options.temperature.is_some() || options.num_predict.is_some() { Some(options) } else { None }, 139 | }; 140 | 141 | let response = self.base.client() 142 | .post(&self.endpoint_url) 143 | .headers(headers) 144 | .json(&ollama_request) 145 | .send() 146 | .await?; 147 | 148 | let response_status = response.status(); 149 | if !response_status.is_success() { 150 | let 
error_text = response.text().await 151 | .unwrap_or_else(|_| format!("Unknown error. Status: {}", response_status)); 152 | return Err(LlmError::ApiError(format!("Ollama API error: {}", error_text))); 153 | } 154 | 155 | let response_text = response.text().await?; 156 | if response_text.is_empty() { 157 | return Err(LlmError::ApiError("Received empty response body from Ollama".to_string())); 158 | } 159 | 160 | // Attempt to parse the JSON response 161 | let ollama_response: OllamaResponse = serde_json::from_str(&response_text) 162 | .map_err(|e| LlmError::ApiError(format!("Failed to parse Ollama JSON response: {}. Body: {}", e, response_text)))?; 163 | 164 | // Map Ollama token counts to unified format. 165 | // Note: Ollama's `eval_count` is often used for completion tokens. `prompt_eval_count` for prompt. 166 | // The exact definition might vary slightly depending on the model and Ollama version. 167 | let usage = Some(TokenUsage { 168 | prompt_tokens: ollama_response.prompt_eval_count, 169 | completion_tokens: ollama_response.eval_count, 170 | total_tokens: ollama_response.prompt_eval_count + ollama_response.eval_count, 171 | }); 172 | 173 | 174 | Ok(LlmResponse { 175 | content: ollama_response.message.content.clone(), 176 | model: ollama_response.model, 177 | usage, 178 | }) 179 | } 180 | 181 | /// Returns provider name 182 | fn get_name(&self) -> &str { 183 | self.base.name() 184 | } 185 | 186 | /// Returns current model name 187 | fn get_model(&self) -> &str { 188 | self.base.model() 189 | } 190 | 191 | /// Returns supported tasks for this provider 192 | fn get_supported_tasks(&self) -> &HashMap { 193 | &self.base.supported_tasks() 194 | } 195 | 196 | /// Returns whether this provider is enabled 197 | fn is_enabled(&self) -> bool { 198 | self.base.is_enabled() 199 | } 200 | } -------------------------------------------------------------------------------- /src/providers/google.rs: -------------------------------------------------------------------------------- 1 | use crate::load_balancer::tasks::TaskDefinition; 2 | use crate::providers::instances::{LlmInstance, BaseInstance}; 3 | use crate::providers::types::{LlmRequest, LlmResponse, TokenUsage, Message}; 4 | use crate::errors::{LlmError, LlmResult}; 5 | use crate::constants; 6 | 7 | use async_trait::async_trait; 8 | use reqwest::header; 9 | use serde::{Serialize, Deserialize}; 10 | use std::collections::HashMap; 11 | use log::debug; 12 | 13 | /// Provider implementation for Google's Gemini AI models 14 | pub struct GoogleInstance { 15 | base: BaseInstance, 16 | } 17 | 18 | /// Request structure for Google's Gemini API 19 | #[derive(Serialize)] 20 | struct GoogleGenerateContentRequest { 21 | contents: Vec, 22 | #[serde(skip_serializing_if = "Option::is_none")] 23 | #[serde(rename = "generationConfig")] 24 | generation_config: Option, 25 | } 26 | 27 | /// Content structure for Google's Gemini API messages 28 | #[derive(Serialize, Deserialize)] 29 | struct GoogleContent { 30 | role: String, 31 | parts: Vec, 32 | } 33 | 34 | /// Individual content part for Google's Gemini API 35 | #[derive(Serialize, Deserialize)] 36 | struct GooglePart { 37 | text: String, 38 | } 39 | 40 | /// Generation configuration for Google's Gemini API 41 | #[derive(Serialize, Default)] 42 | struct GoogleGenerationConfig { 43 | #[serde(skip_serializing_if = "Option::is_none")] 44 | temperature: Option, 45 | // #[serde(skip_serializing_if = "Option::is_none")] 46 | // top_k: Option, 47 | // #[serde(skip_serializing_if = "Option::is_none")] 48 | // top_p: Option, 
49 | #[serde(skip_serializing_if = "Option::is_none")] 50 | #[serde(rename = "maxOutputTokens")] 51 | max_output_tokens: Option, 52 | // #[serde(skip_serializing_if = "Option::is_none")] 53 | // stop_sequences: Option>, 54 | } 55 | 56 | /// Response structure from Google's Gemini API 57 | #[derive(Deserialize)] 58 | struct GoogleGenerateContentResponse { 59 | candidates: Vec, 60 | } 61 | 62 | /// Individual candidate from Google's Gemini API response 63 | #[derive(Deserialize)] 64 | struct GoogleCandidate { 65 | content: GoogleContent, 66 | #[serde(rename = "tokenCount")] 67 | #[serde(default)] 68 | token_count: u32, // Note: Google provides total token count here 69 | // safety_ratings: Vec, // We don't use this currently 70 | } 71 | 72 | 73 | impl GoogleInstance { 74 | /// Creates a new Google provider instance 75 | /// 76 | /// # Parameters 77 | /// * `api_key` - Google API key 78 | /// * `model` - Default model to use (e.g. "gemini-pro") 79 | /// * `supported_tasks` - Map of tasks this provider supports 80 | /// * `enabled` - Whether this provider is enabled 81 | pub fn new(api_key: String, model: String, supported_tasks: HashMap, enabled: bool) -> Self { 82 | let base = BaseInstance::new("google".to_string(), api_key, model, supported_tasks, enabled); 83 | Self { base } 84 | } 85 | 86 | /// Maps standard message format to Google's expected format 87 | /// 88 | /// This function handles several Google-specific requirements: 89 | /// - Converts "assistant" role to "model" role 90 | /// - Prepends system messages to the first user message 91 | /// - Validates that the first message is from the user 92 | /// 93 | /// # Parameters 94 | /// * `messages` - Array of messages in our standard format 95 | /// 96 | /// # Returns 97 | /// * `LlmResult>` - Mapped contents or an error 98 | fn map_messages_to_contents(messages: &[Message]) -> LlmResult> { 99 | let mut contents = Vec::new(); 100 | let mut system_prompt: Option = None; 101 | let mut first_user_message_index: Option = None; 102 | for (_, msg) in messages.iter().enumerate() { 103 | match msg.role.as_str() { 104 | "system" => { 105 | if system_prompt.is_some() { 106 | return Err(LlmError::ApiError("Multiple system messages are not supported by Google provider mapping.".to_string())); 107 | } 108 | system_prompt = Some(msg.content.clone()); 109 | } 110 | "user" | "model" | "assistant" => { 111 | let role = if msg.role == "assistant" { "model" } else { &msg.role }; 112 | if role == "user" && first_user_message_index.is_none() { 113 | first_user_message_index = Some(contents.len()); 114 | } 115 | contents.push(GoogleContent { 116 | role: role.to_string(), 117 | parts: vec![GooglePart { text: msg.content.clone() }], 118 | }); 119 | } 120 | _ => { 121 | log::warn!("Ignoring message with unknown role: {}", msg.role); 122 | } 123 | } 124 | } 125 | 126 | if let Some(sys_prompt) = &system_prompt { 127 | if let Some(user_idx) = first_user_message_index { 128 | if let Some(user_content) = contents.get_mut(user_idx) { 129 | if let Some(part) = user_content.parts.get_mut(0) { 130 | part.text = format!("{}\n\n{}", sys_prompt, part.text); 131 | } 132 | } else { 133 | return Err(LlmError::ApiError("System message provided but no user message found.".to_string())); 134 | } 135 | } else { 136 | return Err(LlmError::ApiError("System message provided but no user message found.".to_string())); 137 | } 138 | } 139 | 140 | if contents.is_empty() { 141 | return Err(LlmError::ApiError("No valid messages found for Google provider.".to_string())); 142 | } 143 | if 
contents[0].role != "user" { 144 | return Err(LlmError::ApiError(format!("Google chat must start with a 'user' role message, found '{}'.", contents[0].role))); 145 | } 146 | Ok(contents) 147 | } 148 | } 149 | 150 | #[async_trait] 151 | impl LlmInstance for GoogleInstance { 152 | /// Generates a completion using Google's Gemini API 153 | /// 154 | /// # Parameters 155 | /// * `request` - The LLM request containing messages and parameters 156 | /// 157 | /// # Returns 158 | /// * `LlmResult` - The response from the model or an error 159 | async fn generate(&self, request: &LlmRequest) -> LlmResult { 160 | if !self.base.is_enabled() { 161 | return Err(LlmError::ProviderDisabled("Google".to_string())); 162 | } 163 | 164 | let model_name = self.base.model(); 165 | let api_key = self.base.api_key(); 166 | 167 | let url = format!( 168 | "{}/v1beta/models/{}:generateContent?key={}", 169 | constants::GOOGLE_API_ENDPOINT_PREFIX, 170 | model_name, 171 | api_key 172 | ); 173 | 174 | let mut headers = header::HeaderMap::new(); 175 | headers.insert( 176 | header::CONTENT_TYPE, 177 | header::HeaderValue::from_static("application/json"), 178 | ); 179 | 180 | let contents = Self::map_messages_to_contents(&request.messages)?; 181 | 182 | let mut generation_config = GoogleGenerationConfig::default(); 183 | generation_config.temperature = request.temperature; 184 | generation_config.max_output_tokens = request.max_tokens; 185 | 186 | let google_request = GoogleGenerateContentRequest { 187 | contents, 188 | generation_config: Some(generation_config).filter(|gc| { 189 | gc.temperature.is_some() || gc.max_output_tokens.is_some() 190 | }), 191 | }; 192 | 193 | let response = self.base.client() 194 | .post(&url) 195 | .headers(headers) 196 | .json(&google_request) 197 | .send() 198 | .await?; 199 | 200 | if !response.status().is_success() { 201 | let status = response.status(); 202 | let error_json: Result = response.json().await; 203 | let error_details = match error_json { 204 | Ok(json) => json.get("error") 205 | .and_then(|e| e.get("message")) 206 | .and_then(|m| m.as_str()) 207 | .map(|s| s.to_string()) 208 | .unwrap_or_else(|| format!("Unknown error structure: {}", json)), 209 | Err(_) => "Failed to parse error response body".to_string(), 210 | }; 211 | 212 | return Err(LlmError::ApiError(format!( 213 | "Google API error ({}): {}", 214 | status, error_details 215 | ))); 216 | } 217 | 218 | let google_response: GoogleGenerateContentResponse = response.json().await 219 | .map_err(|e| LlmError::ApiError(format!("Failed to parse Google JSON response: {}", e)))?; 220 | 221 | 222 | if google_response.candidates.is_empty() { 223 | return Err(LlmError::ApiError("No candidates returned from Google. 
Content may have been blocked.".to_string())); 224 | } 225 | 226 | let candidate = &google_response.candidates[0]; 227 | 228 | let combined_content = candidate.content.parts.iter() 229 | .map(|part| part.text.clone()) 230 | .collect::>() 231 | .join(""); 232 | 233 | let usage = if candidate.token_count > 0 { 234 | // Simply use the token count as the total 235 | Some(TokenUsage { 236 | prompt_tokens: 0, 237 | completion_tokens: 0, 238 | total_tokens: candidate.token_count, 239 | }) 240 | } else { 241 | None 242 | }; 243 | 244 | debug!("Google usage: {:?}", usage); 245 | 246 | Ok(LlmResponse { 247 | content: combined_content, 248 | model: model_name.to_string(), 249 | usage, 250 | }) 251 | } 252 | 253 | /// Returns provider name 254 | fn get_name(&self) -> &str { 255 | self.base.name() 256 | } 257 | 258 | /// Returns current model name 259 | fn get_model(&self) -> &str { 260 | self.base.model() 261 | } 262 | 263 | /// Returns supported tasks for this provider 264 | fn get_supported_tasks(&self) -> &HashMap { 265 | &self.base.supported_tasks() 266 | } 267 | 268 | /// Returns whether this provider is enabled 269 | fn is_enabled(&self) -> bool { 270 | self.base.is_enabled() 271 | } 272 | } -------------------------------------------------------------------------------- /examples/task_routing.rs: -------------------------------------------------------------------------------- 1 | use flyllm::{ 2 | ProviderType, LlmManager, GenerationRequest, LlmManagerResponse, TaskDefinition, LlmResult, 3 | use_logging, ModelDiscovery, ModelInfo 4 | }; 5 | use std::env; 6 | use std::path::PathBuf; 7 | use std::time::Instant; 8 | use std::collections::HashMap; 9 | use futures::future::join_all; 10 | use log::info; 11 | 12 | #[tokio::main] 13 | async fn main() -> LlmResult<()> { 14 | env::set_var("RUST_LOG", "debug"); // Uncomment this to see debugging messages 15 | use_logging(); // Setup logging 16 | 17 | info!("Starting Task Routing Example"); 18 | 19 | // --- API Keys --- 20 | let anthropic_api_key = env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY not set"); 21 | let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); 22 | let mistral_api_key = env::var("MISTRAL_API_KEY").expect("MISTRAL_API_KEY not set"); 23 | let google_api_key = env::var("GOOGLE_API_KEY").expect("GOOGLE_API_KEY not set"); 24 | 25 | // --- Fetch and print available models --- 26 | print_available_models(&anthropic_api_key, &openai_api_key, &mistral_api_key, &google_api_key).await; 27 | 28 | // --- Configure Manager using Builder --- 29 | let manager = LlmManager::builder() 30 | // Define tasks centrally 31 | .define_task( 32 | TaskDefinition::new("summary") 33 | .with_max_tokens(500) // Use helper or with_param 34 | .with_param("temperature", 0.3) // Use generic method 35 | ) 36 | .define_task( 37 | TaskDefinition::new("creative_writing") 38 | .with_max_tokens(1500) 39 | .with_temperature(0.9) 40 | ) 41 | .define_task( 42 | TaskDefinition::new("code_generation") 43 | ) 44 | .define_task( 45 | TaskDefinition::new("short_poem") 46 | .with_max_tokens(100) 47 | .with_temperature(0.8) 48 | ) 49 | 50 | // Add providers and link tasks by name 51 | // .add_instance(ProviderType::Ollama, "llama2:7b", "") 52 | // .supports("summary") // Chain configuration for this provider 53 | // .supports("code_generation") 54 | // .custom_endpoint("http://localhost:11434/api/chat") // This is the default Ollama endpoint, but we can specify custom ones. 
55 | // // .enabled(true) // Optional, defaults to true 56 | 57 | .add_instance(ProviderType::Mistral, "mistral-large-latest", &mistral_api_key) 58 | .supports("summary") 59 | .supports("code_generation") 60 | 61 | .add_instance(ProviderType::Anthropic, "claude-3-sonnet-20240229", &anthropic_api_key) 62 | .supports("summary") 63 | .supports("creative_writing") 64 | .supports("code_generation") 65 | 66 | .add_instance(ProviderType::Anthropic, "claude-3-opus-20240229", &anthropic_api_key) 67 | .supports_many(&["creative_writing", "short_poem"]) // Example using supports_many 68 | 69 | .add_instance(ProviderType::Google, "gemini-2.0-flash", &google_api_key) 70 | .supports("short_poem") 71 | 72 | .add_instance(ProviderType::OpenAI, "gpt-3.5-turbo", &openai_api_key) 73 | .supports("summary") 74 | // Example: Add a disabled provider 75 | // .add_instance(ProviderType::OpenAI, "gpt-4", &openai_api_key) 76 | // .supports("creative_writing") 77 | // .supports("code_generation") 78 | // .enabled(false) // Explicitly disable 79 | 80 | // Adds a debug folder for debugging all requests made 81 | .debug_folder(PathBuf::from("debug_folder")) 82 | 83 | // Finalize the manager configuration 84 | .build().await?; // Added .await here 85 | 86 | // Get provider count asynchronously 87 | let provider_count = manager.get_provider_count().await; 88 | info!("LlmManager configured with {} providers.", provider_count); 89 | 90 | // --- Define Requests using Builder --- 91 | let requests = vec![ 92 | GenerationRequest::builder( 93 | "Summarize the following text: Climate change refers to long-term shifts...", 94 | ) 95 | .task("summary") 96 | .build(), 97 | 98 | GenerationRequest::builder("Write a short story about a robot discovering emotions.") 99 | .task("creative_writing") 100 | .build(), 101 | 102 | GenerationRequest::builder( 103 | "Write a Python function that calculates the Fibonacci sequence up to n terms.", 104 | ) 105 | .task("code_generation") 106 | .build(), 107 | 108 | // Example overriding parameters for a specific request 109 | GenerationRequest::builder("Write a VERY short poem about the rain.") 110 | .task("creative_writing") // Target creative writing task defaults... 111 | .max_tokens(50) // ...but override max_tokens just for this request 112 | // .param("temperature", 0.95) // Could override temperature too 113 | .build(), 114 | 115 | GenerationRequest::builder("Write a rust program to sum two input numbers via console.") 116 | .task("code_generation") 117 | .build(), 118 | 119 | GenerationRequest::builder("Craft a haiku about a silent dawn.") 120 | .task("short_poem") 121 | .build(), 122 | ]; 123 | info!("Defined {} requests using builder pattern.", requests.len()); 124 | 125 | 126 | // --- Run Requests (Sequential and Parallel) --- 127 | println!("\n=== Running requests sequentially... ==="); 128 | let sequential_start = Instant::now(); 129 | let sequential_results = manager.generate_sequentially(requests.clone()).await; 130 | let sequential_duration = sequential_start.elapsed(); 131 | println!("Sequential processing completed in {:?}", sequential_duration); 132 | print_results(&sequential_results); 133 | 134 | 135 | println!("\n=== Running requests in parallel... 
==="); 136 | let parallel_start = Instant::now(); 137 | let parallel_results = manager.batch_generate(requests).await; // Use original requests vec 138 | let parallel_duration = parallel_start.elapsed(); 139 | println!("Parallel processing completed in {:?}", parallel_duration); 140 | print_results(¶llel_results); 141 | 142 | 143 | info!("Task Routing Example Finished."); 144 | 145 | // --- Comparison --- 146 | println!("\n--- Comparison ---"); 147 | println!("Sequential Duration: {:?}", sequential_duration); 148 | println!("Parallel Duration: {:?}", parallel_duration); 149 | 150 | if parallel_duration < sequential_duration && parallel_duration.as_nanos() > 0 { 151 | let speedup = sequential_duration.as_secs_f64() / parallel_duration.as_secs_f64(); 152 | println!("Parallel execution was roughly {:.2}x faster.", speedup); 153 | } else if parallel_duration >= sequential_duration { 154 | println!("Parallel execution was not faster (or was equal) in this run."); 155 | } else { 156 | println!("Parallel execution finished too quickly to measure speedup reliably."); 157 | } 158 | 159 | // Print token usage asynchronously 160 | manager.print_token_usage().await; 161 | 162 | Ok(()) 163 | } 164 | 165 | /// Fetches models from all providers and prints them in a table format 166 | async fn print_available_models( 167 | anthropic_api_key: &str, 168 | openai_api_key: &str, 169 | mistral_api_key: &str, 170 | google_api_key: &str 171 | ) { 172 | println!("\n=== AVAILABLE MODELS ==="); 173 | 174 | // Clone the API keys for use in the spawned tasks 175 | let anthropic_key = anthropic_api_key.to_string(); 176 | let openai_key = openai_api_key.to_string(); 177 | let mistral_key = mistral_api_key.to_string(); 178 | let google_key = google_api_key.to_string(); 179 | 180 | // Fetch models from different providers in parallel 181 | let futures = vec![ 182 | tokio::spawn(async move { ModelDiscovery::list_anthropic_models(&anthropic_key).await }), 183 | tokio::spawn(async move { ModelDiscovery::list_openai_models(&openai_key).await }), 184 | tokio::spawn(async move { ModelDiscovery::list_mistral_models(&mistral_key).await }), 185 | tokio::spawn(async move { ModelDiscovery::list_google_models(&google_key).await }), 186 | tokio::spawn(async { ModelDiscovery::list_ollama_models(None).await }), 187 | ]; 188 | 189 | let results = join_all(futures).await; 190 | 191 | // Create a map to store models by provider 192 | let mut models_by_provider: HashMap> = HashMap::new(); 193 | 194 | // Define the provider order for each index 195 | let providers = [ 196 | ProviderType::Anthropic, 197 | ProviderType::OpenAI, 198 | ProviderType::Mistral, 199 | ProviderType::Google, 200 | ProviderType::Ollama 201 | ]; 202 | 203 | // Process results 204 | for (i, result) in results.into_iter().enumerate() { 205 | if i >= providers.len() { continue; } 206 | let provider = providers[i]; 207 | 208 | match result { 209 | Ok(Ok(models)) => { 210 | models_by_provider.insert(provider, models); 211 | }, 212 | Ok(Err(e)) => { 213 | println!("Error fetching {} models: {}", provider, e); 214 | }, 215 | Err(e) => { 216 | println!("Task error fetching {} models: {}", provider, e); 217 | } 218 | } 219 | } 220 | 221 | // Print models in a table format 222 | println!("\n{:<15} {:<40}", "PROVIDER", "MODEL NAME"); 223 | println!("{}", "=".repeat(55)); 224 | 225 | // Print models in the specified provider order 226 | for provider in providers.iter() { 227 | if let Some(models) = models_by_provider.get(provider) { 228 | for model in models { 229 | 
println!("{:<15} {:<40}", provider.to_string(), model.name); 230 | } 231 | // Add a separator between providers 232 | println!("{}", "-".repeat(55)); 233 | } 234 | } 235 | } 236 | 237 | fn print_results(results: &[LlmManagerResponse]) { 238 | println!("\n--- Request Results ---"); 239 | 240 | let task_labels = [ 241 | "Summary Request", 242 | "Creative Writing Request", 243 | "Code Generation Request", 244 | "Short Poem Request (Override)", 245 | "Rust Code Request", 246 | "Haiku Request" 247 | ]; 248 | 249 | for (i, result) in results.iter().enumerate() { 250 | let task_label = task_labels.get(i).map_or_else(|| "Unknown Task", |&name| name); 251 | println!("{}:", task_label); 252 | if result.success { 253 | let content_preview = result.content.chars().take(150).collect::(); 254 | let ellipsis = if result.content.chars().count() > 150 { "..." } else { "" }; 255 | println!(" Success: {}{}\n", content_preview, ellipsis); 256 | } else { 257 | println!(" Error: {}\n", result.error.as_ref().unwrap_or(&"Unknown error".to_string())); 258 | } 259 | } 260 | } -------------------------------------------------------------------------------- /src/providers/model_discovery.rs: -------------------------------------------------------------------------------- 1 | use crate::providers::types::{ModelInfo, ProviderType}; 2 | use crate::errors::{LlmError, LlmResult}; 3 | use crate::constants; 4 | use reqwest::{Client, header}; 5 | use serde::Deserialize; 6 | use std::time::Duration; 7 | 8 | /// Helper module for listing available models from providers 9 | /// without requiring a fully initialized provider instance 10 | pub struct ModelDiscovery; 11 | 12 | impl ModelDiscovery { 13 | /// Create a standardized HTTP client for model discovery 14 | fn create_client() -> Client { 15 | Client::builder() 16 | .timeout(Duration::from_secs(30)) 17 | .build() 18 | .expect("Failed to create HTTP client") 19 | } 20 | 21 | /// List available models from Anthropic 22 | /// 23 | /// # Parameters 24 | /// * `api_key` - Anthropic API key 25 | /// 26 | /// # Returns 27 | /// * Vector of ModelInfo structs containing model names 28 | pub async fn list_anthropic_models(api_key: &str) -> LlmResult> { 29 | let client = Self::create_client(); 30 | 31 | let mut headers = header::HeaderMap::new(); 32 | headers.insert( 33 | "x-api-key", 34 | header::HeaderValue::from_str(api_key) 35 | .map_err(|e| LlmError::ConfigError(format!("Invalid API key format for Anthropic: {}", e)))?, 36 | ); 37 | headers.insert( 38 | "anthropic-version", 39 | header::HeaderValue::from_static(constants::ANTHROPIC_API_VERSION), 40 | ); 41 | 42 | let models_endpoint = "https://api.anthropic.com/v1/models"; 43 | 44 | let response = client.get(models_endpoint) 45 | .headers(headers) 46 | .send() 47 | .await?; 48 | 49 | if !response.status().is_success() { 50 | let status = response.status(); 51 | let error_text = response.text().await 52 | .unwrap_or_else(|_| format!("Unknown error reading error response body, status: {}", status)); 53 | return Err(LlmError::ApiError(format!("Anthropic API error ({}): {}", status, error_text))); 54 | } 55 | 56 | let response_bytes = response.bytes().await?; 57 | 58 | #[derive(Deserialize, Debug)] 59 | struct AnthropicModelsResponse { 60 | data: Vec, 61 | } 62 | #[derive(Deserialize, Debug)] 63 | struct AnthropicModelInfo { 64 | id: String, 65 | display_name: String, 66 | } 67 | 68 | let anthropic_response: AnthropicModelsResponse = serde_json::from_slice(&response_bytes) 69 | .map_err(|e| { 70 | let snippet_len = 
std::cmp::min(response_bytes.len(), 256); 71 | let response_snippet = String::from_utf8_lossy(response_bytes.get(0..snippet_len).unwrap_or_default()); 72 | LlmError::ParseError(format!( 73 | "Error decoding Anthropic models JSON: {}. Response snippet: '{}'", 74 | e, 75 | response_snippet 76 | )) 77 | })?; 78 | 79 | let models = anthropic_response.data.into_iter() 80 | .map(|m| ModelInfo { 81 | name: m.id, 82 | provider: ProviderType::Anthropic, 83 | }) 84 | .collect(); 85 | 86 | Ok(models) 87 | } 88 | 89 | /// List available models from OpenAI 90 | /// 91 | /// # Parameters 92 | /// * `api_key` - OpenAI API key 93 | /// 94 | /// # Returns 95 | /// * Vector of ModelInfo structs containing model names 96 | pub async fn list_openai_models(api_key: &str) -> LlmResult> { 97 | let client = Self::create_client(); 98 | 99 | let mut headers = header::HeaderMap::new(); 100 | headers.insert( 101 | header::AUTHORIZATION, 102 | header::HeaderValue::from_str(&format!("Bearer {}", api_key)) 103 | .map_err(|e| LlmError::ConfigError(format!("Invalid API key format: {}", e)))?, 104 | ); 105 | 106 | let models_endpoint = "https://api.openai.com/v1/models"; 107 | 108 | let response = client.get(models_endpoint) 109 | .headers(headers) 110 | .send() 111 | .await?; 112 | 113 | if !response.status().is_success() { 114 | let error_text = response.text().await 115 | .unwrap_or_else(|_| "Unknown error".to_string()); 116 | return Err(LlmError::ApiError(format!("OpenAI API error: {}", error_text))); 117 | } 118 | 119 | #[derive(Deserialize)] 120 | struct OpenAIModelsResponse { 121 | data: Vec, 122 | } 123 | 124 | #[derive(Deserialize)] 125 | struct OpenAIModelInfo { 126 | id: String, 127 | } 128 | 129 | let openai_response: OpenAIModelsResponse = response.json().await?; 130 | 131 | let models = openai_response.data.into_iter() 132 | .filter(|m| m.id.starts_with("gpt-")) 133 | .map(|m| ModelInfo { 134 | name: m.id, 135 | provider: ProviderType::OpenAI, 136 | }) 137 | .collect(); 138 | 139 | Ok(models) 140 | } 141 | 142 | /// List available models from Mistral 143 | /// 144 | /// # Parameters 145 | /// * `api_key` - Mistral API key 146 | /// 147 | /// # Returns 148 | /// * Vector of ModelInfo structs containing model names 149 | pub async fn list_mistral_models(api_key: &str) -> LlmResult> { 150 | let client = Self::create_client(); 151 | 152 | let mut headers = header::HeaderMap::new(); 153 | headers.insert( 154 | header::AUTHORIZATION, 155 | header::HeaderValue::from_str(&format!("Bearer {}", api_key)) 156 | .map_err(|e| LlmError::ConfigError(format!("Invalid API key format: {}", e)))?, 157 | ); 158 | 159 | let models_endpoint = "https://api.mistral.ai/v1/models"; 160 | 161 | let response = client.get(models_endpoint) 162 | .headers(headers) 163 | .send() 164 | .await?; 165 | 166 | if !response.status().is_success() { 167 | let error_text = response.text().await 168 | .unwrap_or_else(|_| "Unknown error".to_string()); 169 | return Err(LlmError::ApiError(format!("Mistral API error: {}", error_text))); 170 | } 171 | 172 | #[derive(Deserialize)] 173 | struct MistralModelsResponse { 174 | data: Vec, 175 | } 176 | 177 | #[derive(Deserialize)] 178 | struct MistralModelInfo { 179 | id: String, 180 | } 181 | 182 | let mistral_response: MistralModelsResponse = response.json().await?; 183 | 184 | let models = mistral_response.data.into_iter() 185 | .map(|m| ModelInfo { 186 | name: m.id, 187 | provider: ProviderType::Mistral, 188 | }) 189 | .collect(); 190 | 191 | Ok(models) 192 | } 193 | 194 | /// List available models from 
Google 195 | /// 196 | /// # Parameters 197 | /// * `api_key` - Google API key 198 | /// 199 | /// # Returns 200 | /// * Vector of ModelInfo structs containing model names 201 | pub async fn list_google_models(api_key: &str) -> LlmResult> { 202 | let client = Self::create_client(); 203 | 204 | let models_endpoint = format!( 205 | "{}/v1beta/models?key={}", 206 | constants::GOOGLE_API_ENDPOINT_PREFIX, 207 | api_key 208 | ); 209 | 210 | let response = client.get(&models_endpoint) 211 | .send() 212 | .await?; 213 | 214 | if !response.status().is_success() { 215 | let error_text = response.text().await 216 | .unwrap_or_else(|_| "Unknown error".to_string()); 217 | return Err(LlmError::ApiError(format!("Google API error: {}", error_text))); 218 | } 219 | 220 | #[derive(Deserialize)] 221 | struct GoogleModelsResponse { 222 | models: Vec, 223 | } 224 | 225 | #[derive(Deserialize)] 226 | struct GoogleModelInfo { 227 | name: String, 228 | #[serde(default)] 229 | display_name: Option, 230 | } 231 | 232 | let google_response: GoogleModelsResponse = response.json().await?; 233 | 234 | let models = google_response.models.into_iter() 235 | .map(|m| { 236 | let name = m.display_name.unwrap_or_else(|| { 237 | m.name.split('/').last().unwrap_or(&m.name).to_string() 238 | }); 239 | 240 | ModelInfo { 241 | name, 242 | provider: ProviderType::Google, 243 | } 244 | }) 245 | .collect(); 246 | 247 | Ok(models) 248 | } 249 | 250 | /// List available models from Ollama 251 | /// 252 | /// # Parameters 253 | /// * `base_url` - Optional base URL for Ollama API, defaults to localhost 254 | /// 255 | /// # Returns 256 | /// * Vector of ModelInfo structs containing model names 257 | pub async fn list_ollama_models(base_url: Option<&str>) -> LlmResult> { 258 | let client = Self::create_client(); 259 | 260 | // Use provided base URL or default to localhost 261 | let base_url = base_url.unwrap_or("http://localhost:11434"); 262 | let models_endpoint = format!("{}/api/tags", base_url.trim_end_matches('/')); 263 | 264 | let response = client.get(&models_endpoint) 265 | .send() 266 | .await?; 267 | 268 | if !response.status().is_success() { 269 | let error_text = response.text().await 270 | .unwrap_or_else(|_| "Unknown error".to_string()); 271 | return Err(LlmError::ApiError(format!("Ollama API error: {}", error_text))); 272 | } 273 | 274 | #[derive(Deserialize)] 275 | struct OllamaModelsResponse { 276 | models: Vec, 277 | } 278 | 279 | #[derive(Deserialize)] 280 | struct OllamaModelInfo { 281 | name: String, 282 | } 283 | 284 | let ollama_response: OllamaModelsResponse = response.json().await?; 285 | 286 | let models = ollama_response.models.into_iter() 287 | .map(|m| ModelInfo { 288 | name: m.name, 289 | provider: ProviderType::Ollama, 290 | }) 291 | .collect(); 292 | 293 | Ok(models) 294 | } 295 | 296 | /// List all models from a specific provider 297 | /// 298 | /// # Parameters 299 | /// * `provider_type` - Type of provider to query 300 | /// * `api_key` - API key for authentication 301 | /// * `base_url` - Optional base URL (only used for Ollama) 302 | /// 303 | /// # Returns 304 | /// * Vector of ModelInfo structs containing model names 305 | pub async fn list_models( 306 | provider_type: ProviderType, 307 | api_key: &str, 308 | base_url: Option<&str> 309 | ) -> LlmResult> { 310 | match provider_type { 311 | ProviderType::Anthropic => Self::list_anthropic_models(api_key).await, 312 | ProviderType::OpenAI => Self::list_openai_models(api_key).await, 313 | ProviderType::Mistral => Self::list_mistral_models(api_key).await, 
314 | ProviderType::Google => Self::list_google_models(api_key).await, 315 | ProviderType::Ollama => Self::list_ollama_models(base_url).await, 316 | } 317 | } 318 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FlyLLM 2 | 3 | FlyLLM is a Rust library that provides a load-balanced, multi-provider client for Large Language Models. It enables developers to seamlessly work with multiple LLM providers (OpenAI, Anthropic, Google, Mistral...) through a unified API with request routing, load balancing, and failure handling. 4 | 5 |
6 | (FlyLLM logo image) 7 |
8 | 9 | ## Features 10 | 11 | - **Multiple Provider Support** 🌐: Unified interface for OpenAI, Anthropic, Google, Ollama and Mistral APIs 12 | - **Task-Based Routing** 🧭: Route requests to the most appropriate provider based on predefined tasks 13 | - **Load Balancing** ⚖️: Automatically distribute load across multiple provider instances 14 | - **Failure Handling** 🛡️: Retry mechanisms and automatic failover between providers 15 | - **Parallel Processing** ⚡: Process multiple requests concurrently for improved throughput 16 | - **Custom Parameters** 🔧: Set provider-specific parameters per task or request 17 | - **Usage Tracking** 📊: Monitor token consumption for cost management 18 | - **Debug Logging** 🔍: Optional request/response logging to JSON files for debugging and analysis 19 | - **Builder Pattern Configuration** ✨: Fluent and readable setup for tasks and providers 20 | 21 | ## Installation 22 | 23 | Add FlyLLM to your `Cargo.toml`: 24 | 25 | ```toml 26 | [dependencies] 27 | flyllm = "0.3.1" 28 | tokio = { version = "1", features = ["macros", "rt-multi-thread", "sync"] } # For async runtime 29 | ``` 30 | 31 | ## Architecture 32 | 33 | ![Open Escordia_2025-04-25_13-41-55](https://github.com/user-attachments/assets/a56e375b-0bca-4de6-a4d3-c000812105d5) 34 | 35 | The LLM manager (`LlmManager`) serves as the core component for orchestrating language model interactions in your application. It manages multiple LLM instances (`LlmInstance`), each defined by a model, API key, and supported tasks (`TaskDefinition`). 36 | 37 | When your application sends a generation request (`GenerationRequest`), the manager automatically selects an appropriate instance based on configurable strategies (Least Recently Used, Quickest Response Time, etc.) and returns the response generated by the LLM (`LlmResponse`). This design helps avoid rate limits by distributing requests across multiple instances (even of the same model) with different API keys. 38 | 39 | The manager handles failures gracefully by re-routing requests to alternative instances. It also supports parallel execution for significant performance improvements when processing multiple requests simultaneously! 40 | 41 | You can define default parameters (temperature, max_tokens) for each task while retaining the ability to override these settings in specific requests. The system also tracks token usage across all instances: 42 | 43 | ``` 44 | --- Token Usage Statistics --- 45 | ID Provider Model Prompt Tokens Completion Tokens Total Tokens 46 | ----------------------------------------------------------------------------------------------- 47 | 0 mistral mistral-small-latest 109 897 1006 48 | 1 anthropic claude-3-sonnet-20240229 133 1914 2047 49 | 2 anthropic claude-3-opus-20240229 51 529 580 50 | 3 google gemini-2.0-flash 0 0 0 51 | 4 openai gpt-3.5-turbo 312 1003 1315 52 | ``` 53 | 54 | ## Usage Examples 55 | 56 | The following sections show how to use FlyLLM. You can also check out the example in `examples/task_routing.rs`! You can activate FlyLLM's debug messages by setting the environment variable `RUST_LOG` to `"debug"`.
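For instance (a minimal sketch mirroring the top of `examples/task_routing.rs`), the debug output can also be enabled programmatically before initializing the logger:

```rust
use flyllm::use_logging;
use std::env;

fn main() {
    // Enable FlyLLM's debug messages for this process, then set up logging.
    env::set_var("RUST_LOG", "debug");
    use_logging();
}
```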
57 | 58 | ### Quick Start 59 | 60 | ```rust 61 | use flyllm::{ 62 | ProviderType, LlmManager, GenerationRequest, LlmManagerResponse, TaskDefinition, LlmResult, 63 | use_logging, // Helper to setup basic logging 64 | }; 65 | use std::env; // To read API keys from environment variables 66 | 67 | #[tokio::main] 68 | async fn main() -> LlmResult<()> { // Use LlmResult for error handling 69 | // Initialize logging (optional, requires log and env_logger crates) 70 | use_logging(); 71 | 72 | // Retrieve API key from environment 73 | let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); 74 | 75 | // Configure the LLM manager using the builder pattern 76 | let manager = LlmManager::builder() 77 | // Define a task with specific default parameters 78 | .define_task( 79 | TaskDefinition::new("summary") 80 | .with_max_tokens(500) // Set max tokens for this task 81 | .with_param("temperature", 0.3) // Set temperature for this task 82 | ) 83 | // Add a provider instance and specify the tasks it supports 84 | .add_instance( 85 | ProviderType::OpenAI, 86 | "gpt-3.5-turbo", 87 | &openai_api_key, // Pass the API key 88 | ) 89 | .supports("summary") // Link the provider to the "summary" task 90 | // Finalize the manager configuration 91 | .build().await?; // Use await and '?' for error propagation 92 | 93 | // Create a generation request using the builder pattern 94 | let request = GenerationRequest::builder( 95 | "Summarize the following text: Climate change refers to long-term shifts in temperatures..." 96 | ) 97 | .task("summary") // Specify the task for routing 98 | .build(); 99 | 100 | // Generate response sequentially (for a single request) 101 | // The Manager will automatically choose the configured OpenAI provider for the "summary" task. 102 | let responses = manager.generate_sequentially(vec![request]).await; 103 | 104 | // Handle the response 105 | if let Some(response) = responses.first() { 106 | if response.success { 107 | println!("Response: {}", response.content); 108 | } else { 109 | println!("Error: {}", response.error.as_ref().unwrap_or(&"Unknown error".to_string())); 110 | } 111 | } 112 | 113 | // Print token usage statistics 114 | manager.print_token_usage().await; 115 | 116 | Ok(()) 117 | } 118 | ``` 119 | 120 | ### Adding Multiple Providers 121 | 122 | Configure the LlmManager with various providers, each supporting different tasks. 
123 | 124 | ```rust 125 | use flyllm::{ProviderType, LlmManager, TaskDefinition, LlmResult}; 126 | use std::env; 127 | use std::path::PathBuf; 128 | 129 | async fn configure_manager() -> LlmResult { 130 | // --- API Keys --- 131 | let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); 132 | let anthropic_api_key = env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY not set"); 133 | let mistral_api_key = env::var("MISTRAL_API_KEY").expect("MISTRAL_API_KEY not set"); 134 | let google_api_key = env::var("GOOGLE_API_KEY").expect("GOOGLE_API_KEY not set"); 135 | // Ollama typically doesn't require an API key for local instances 136 | 137 | let manager = LlmManager::builder() 138 | // Define all tasks first 139 | .define_task(TaskDefinition::new("summary").with_max_tokens(500).with_param("temperature", 0.3)) 140 | .define_task(TaskDefinition::new("qa").with_max_tokens(1000)) 141 | .define_task(TaskDefinition::new("creative_writing").with_max_tokens(1500).with_temperature(0.9)) 142 | .define_task(TaskDefinition::new("code_generation").with_param("temperature", 0.1)) 143 | .define_task(TaskDefinition::new("translation")) // Task with default provider parameters 144 | 145 | // Add OpenAI provider supporting summary and QA 146 | .add_instance(ProviderType::OpenAI, "gpt-4-turbo", &openai_api_key) 147 | .supports_many(&["summary", "qa"]) // Assign multiple tasks 148 | 149 | // Add Anthropic provider supporting creative writing and code generation 150 | .add_instance(ProviderType::Anthropic, "claude-3-sonnet-20240229", &anthropic_api_key) 151 | .supports("creative_writing") 152 | .supports("code_generation") 153 | 154 | // Add Mistral provider supporting summary and translation 155 | .add_instance(ProviderType::Mistral, "mistral-large-latest", &mistral_api_key) 156 | .supports("summary") 157 | .supports("translation") 158 | 159 | // Add Google (Gemini) provider supporting QA and creative writing 160 | .add_instance(ProviderType::Google, "gemini-1.5-pro", &google_api_key) 161 | .supports("qa") 162 | .supports("creative_writing") 163 | 164 | // Add a local Ollama provider supporting summary and code generation 165 | .add_instance(ProviderType::Ollama, "llama3:8b", "") // API key often empty for local Ollama 166 | .supports("summary") 167 | .supports("code_generation") 168 | .custom_endpoint("http://localhost:11434/api/chat") // Optional: Specify if not default 169 | 170 | // Optional: Enable debug logging to JSON files 171 | .debug_folder(PathBuf::from("debug_logs")) // All request/response data will be logged here 172 | 173 | // Finalize configuration 174 | .build().await?; 175 | 176 | println!("LlmManager configured with multiple providers."); 177 | Ok(manager) 178 | } 179 | ``` 180 | 181 | ### Task-Based Routing 182 | 183 | Define tasks with specific default parameters and create requests targeting those tasks. FlyLLM routes the request to a provider configured to support that task. 
184 | 185 | ```rust 186 | use flyllm::{LlmManager, GenerationRequest, TaskDefinition, LlmResult}; 187 | use std::env; 188 | 189 | // Assume manager is configured as shown in "Adding Multiple Providers" 190 | async fn route_by_task(manager: LlmManager) -> LlmResult<()> { 191 | 192 | // Define tasks centrally in the builder (shown conceptually here) 193 | // LlmManager::builder() 194 | // .define_task( 195 | // TaskDefinition::new("summary") 196 | // .with_max_tokens(500) 197 | // .with_temperature(0.3) 198 | // ) 199 | // .define_task( 200 | // TaskDefinition::new("creative_writing") 201 | // .with_max_tokens(1500) 202 | // .with_temperature(0.9) 203 | // ) 204 | // // ... add providers supporting these tasks ... 205 | // .build()?; 206 | 207 | // Create requests with different tasks using the request builder 208 | let summary_request = GenerationRequest::builder( 209 | "Summarize the following article about renewable energy: ..." 210 | ) 211 | .task("summary") // This request will be routed to providers supporting "summary" 212 | .build(); 213 | 214 | let story_request = GenerationRequest::builder( 215 | "Write a short story about a futuristic city powered by algae." 216 | ) 217 | .task("creative_writing") // This request uses the "creative_writing" task defaults 218 | .build(); 219 | 220 | // Example: Override task defaults for a specific request 221 | let short_story_request = GenerationRequest::builder( 222 | "Write a VERY short story about a time traveler meeting a dinosaur." 223 | ) 224 | .task("creative_writing") // Based on "creative_writing" task... 225 | .max_tokens(200) // ...but override max_tokens for this specific request 226 | .param("temperature", 0.95) // Can override any parameter 227 | .build(); 228 | 229 | // Process requests (e.g., sequentially) 230 | let requests = vec![summary_request, story_request, short_story_request]; 231 | let results = manager.generate_sequentially(requests).await; 232 | 233 | // Handle results... 234 | for (i, result) in results.iter().enumerate() { 235 | println!("Request {}: Success = {}, Content/Error = {}", 236 | i + 1, 237 | result.success, 238 | if result.success { &result.content[..std::cmp::min(result.content.len(), 50)] } else { result.error.as_deref().unwrap_or("Unknown") } 239 | ); 240 | } 241 | 242 | Ok(()) 243 | } 244 | ``` 245 | 246 | ### Parallel Processing 247 | 248 | ```rust 249 | // Process in parallel 250 | let parallel_results = manager.batch_generate(requests).await; 251 | 252 | // Process each result 253 | for result in parallel_results { 254 | if result.success { 255 | println!("Success: {}", result.content); 256 | } else { 257 | println!("Error: {}", result.error.as_ref().unwrap_or(&"Unknown error".to_string())); 258 | } 259 | } 260 | ``` 261 | 262 | ### Debug Logging 263 | 264 | FlyLLM supports optional debug logging to help you analyze requests and responses. When enabled, it creates JSON files with detailed information about each generation call. 
265 | 266 | ```rust 267 | use flyllm::{ProviderType, LlmManager, GenerationRequest, TaskDefinition, LlmResult}; 268 | use std::path::PathBuf; 269 | 270 | async fn setup_with_debug_logging(api_key: &str) -> LlmResult<LlmManager> { 271 | let manager = LlmManager::builder() 272 | .define_task(TaskDefinition::new("summary").with_max_tokens(500)) 273 | .add_instance(ProviderType::OpenAI, "gpt-3.5-turbo", api_key) 274 | .supports("summary") 275 | 276 | // Enable debug logging - creates folder structure: debug_logs/timestamp/instance_id_provider_model/debug.json 277 | .debug_folder(PathBuf::from("debug_logs")) 278 | 279 | .build().await?; 280 | 281 | Ok(manager) 282 | } 283 | ``` 284 | 285 | The debug files contain structured JSON with: 286 | - **Metadata**: timestamp, instance details, request duration 287 | - **Input**: prompt, task, parameters used 288 | - **Output**: success status, generated content or error, token usage 289 | 290 | Example debug file structure: 291 | ```json 292 | [ 293 | { 294 | "metadata": { 295 | "timestamp": 1703123456, 296 | "instance_id": 0, 297 | "instance_name": "openai", 298 | "instance_model": "gpt-3.5-turbo", 299 | "duration_ms": 1250 300 | }, 301 | "input": { 302 | "prompt": "Summarize this text...", 303 | "task": "summary", 304 | "parameters": { 305 | "max_tokens": 500, 306 | "temperature": 0.3 307 | } 308 | }, 309 | "output": { 310 | "success": true, 311 | "content": "This text discusses...", 312 | "usage": { 313 | "prompt_tokens": 45, 314 | "completion_tokens": 123, 315 | "total_tokens": 168 316 | } 317 | } 318 | } 319 | ] 320 | ``` 321 | 322 | ## License 323 | 324 | This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. 325 | 326 | ## Contributing 327 | 328 | Contributions are always welcome! If you're interested in contributing to FlyLLM, please fork the repository and create a new branch for your changes. When you're done, submit a pull request to merge your changes into the main branch. 329 | 330 | ## Supporting FlyLLM 331 | 332 | If you want to support FlyLLM, you can: 333 | - **Star** :star: the project on GitHub! 334 | - **Donate** :coin: to my [Ko-fi](https://ko-fi.com/rodmarkun) page! 335 | - **Share** :heart: the project with your friends! 336 | -------------------------------------------------------------------------------- /src/load_balancer/manager.rs: -------------------------------------------------------------------------------- 1 | use crate::errors::{LlmError, LlmResult}; 2 | use crate::load_balancer::builder::LlmManagerBuilder; 3 | use crate::load_balancer::strategies; 4 | use crate::load_balancer::tasks::TaskDefinition; 5 | use crate::load_balancer::tracker::InstanceTracker; 6 | use crate::load_balancer::utils::{get_debug_path, write_to_debug_file}; 7 | use crate::providers::{LlmInstance, LlmRequest, Message, TokenUsage}; 8 | use crate::{constants, create_instance, ProviderType}; // TODO - ?
9 | use futures::future::join_all; 10 | use log::{debug, info, warn}; 11 | use serde::{Deserialize, Serialize}; 12 | use serde_json::{json, Value}; 13 | use std::collections::HashMap; 14 | use std::fs; 15 | use std::path::PathBuf; 16 | use std::sync::Arc; 17 | use std::time::{Instant, SystemTime, UNIX_EPOCH}; 18 | use tokio::sync::Mutex; 19 | 20 | /// User-facing request for LLM generation 21 | #[derive(Debug, Serialize, Deserialize, Clone)] 22 | pub struct GenerationRequest { 23 | pub prompt: String, // Prompt for the LLM 24 | pub task: Option, // Task to route for 25 | pub params: Option>, // Extra parameters 26 | } 27 | 28 | impl Default for GenerationRequest { 29 | fn default() -> Self { 30 | Self { 31 | prompt: String::new(), 32 | task: None, 33 | params: None, 34 | } 35 | } 36 | } 37 | 38 | impl GenerationRequest { 39 | // Standard Constructor 40 | pub fn new(prompt: String) -> Self { 41 | GenerationRequest { 42 | prompt, 43 | ..Default::default() 44 | } 45 | } 46 | 47 | /// Creates a builder for a GenerationRequest. 48 | pub fn builder(prompt: impl Into) -> GenerationRequest { 49 | GenerationRequest::new(prompt.into()) 50 | } 51 | 52 | /// Sets the target task for this request. 53 | pub fn task(mut self, name: impl Into) -> Self { 54 | self.task = Some(name.into()); 55 | self 56 | } 57 | 58 | /// Adds or overrides a parameter specifically for this request. 59 | pub fn param(mut self, key: impl Into, value: impl Into) -> Self { 60 | self.params 61 | .get_or_insert_with(HashMap::new) 62 | .insert(key.into(), value.into()); 63 | self 64 | } 65 | 66 | /// Sets max tokens for this generation in specific 67 | pub fn max_tokens(self, tokens: u32) -> Self { 68 | self.param("max_tokens", json!(tokens)) 69 | } 70 | 71 | /// Finalizes the GenerationRequest 72 | pub fn build(self) -> Self { 73 | self 74 | } 75 | } 76 | 77 | /// Internal request structure with additional retry information 78 | #[derive(Clone)] 79 | struct LlmManagerRequest { 80 | pub prompt: String, 81 | pub task: Option, 82 | pub params: Option>, 83 | pub attempts: usize, 84 | pub failed_instances: Vec, 85 | } 86 | 87 | impl LlmManagerRequest { 88 | /// Convert a user-facing GenerationRequest to internal format 89 | fn from_generation_request(request: GenerationRequest) -> Self { 90 | Self { 91 | prompt: request.prompt, 92 | task: request.task, 93 | params: request.params, 94 | attempts: 0, 95 | failed_instances: Vec::new(), 96 | } 97 | } 98 | } 99 | 100 | /// Response structure returned to users 101 | #[derive(Debug, Serialize, Deserialize, Clone)] 102 | pub struct LlmManagerResponse { 103 | pub content: String, 104 | pub success: bool, 105 | pub error: Option, 106 | } 107 | 108 | /// Main manager for LLM providers that handles load balancing and retries 109 | /// 110 | /// The LlmManager: 111 | /// - Manages multiple LLM instances ( from different providers) 112 | /// - Maps tasks to compatible instances 113 | /// - Routes requests to appropriate instances 114 | /// - Implements retries and fallbacks 115 | /// - Tracks performance metrics and token usage 116 | pub struct LlmManager { 117 | pub trackers: Arc>>, // Current instance trackers in the manager (contains the instances themselves) 118 | pub strategy: Arc>>, // Current strategy for load balancing being used 119 | pub tasks_to_instances: Arc>>>, // Map of which instances handle which tasks 120 | pub instance_counter: Mutex, // Used for giving unique IDs to each instance in this manager 121 | pub max_retries: usize, // Controls how many times a failed request will be 
tried before giving up 122 | pub total_usage: Mutex>, // Token usage of each instance 123 | pub debug_folder: Option, // Path where JSONs with debug inputs/outputs of each model will be stored 124 | pub creation_time: SystemTime 125 | } 126 | 127 | impl LlmManager { 128 | /// Create a new LlmManager with default settings 129 | pub fn new() -> Self { 130 | Self { 131 | trackers: Arc::new(Mutex::new(HashMap::new())), 132 | strategy: Arc::new(Mutex::new(Box::new( 133 | strategies::LeastRecentlyUsedStrategy::new(), 134 | ))), 135 | tasks_to_instances: Arc::new(Mutex::new(HashMap::new())), 136 | instance_counter: Mutex::new(0), 137 | max_retries: constants::DEFAULT_MAX_TRIES, 138 | total_usage: Mutex::new(HashMap::new()), 139 | debug_folder: None, 140 | creation_time: SystemTime::now() 141 | } 142 | } 143 | 144 | /// Creates a new builder to configure the LlmManager. 145 | pub fn builder() -> LlmManagerBuilder { 146 | LlmManagerBuilder::new() 147 | } 148 | 149 | /// Create a new LlmManager with a custom load balancing strategy 150 | /// 151 | /// # Parameters 152 | /// * `strategy` - The load balancing strategy to use 153 | pub fn new_with_strategy( 154 | strategy: Box, 155 | ) -> Self { 156 | Self { 157 | trackers: Arc::new(Mutex::new(HashMap::new())), 158 | strategy: Arc::new(Mutex::new(strategy)), 159 | tasks_to_instances: Arc::new(Mutex::new(HashMap::new())), 160 | instance_counter: Mutex::new(0), 161 | max_retries: constants::DEFAULT_MAX_TRIES, 162 | total_usage: Mutex::new(HashMap::new()), 163 | debug_folder: None, 164 | creation_time: SystemTime::now() 165 | } 166 | } 167 | 168 | /// Constructor used by the builder. 169 | pub fn new_with_strategy_and_retries( 170 | strategy: Box, 171 | max_retries: usize, 172 | ) -> Self { 173 | Self { 174 | trackers: Arc::new(Mutex::new(HashMap::new())), 175 | strategy: Arc::new(Mutex::new(strategy)), 176 | tasks_to_instances: Arc::new(Mutex::new(HashMap::new())), 177 | instance_counter: Mutex::new(0), 178 | max_retries, // Use passed value 179 | total_usage: Mutex::new(HashMap::new()), 180 | debug_folder: None, 181 | creation_time: SystemTime::now() 182 | } 183 | } 184 | 185 | /// Adds a new LLM instance by creating it from basic parameters 186 | /// 187 | /// # Parameters 188 | /// * `provider_type` - Which LLM provider to use (Anthropic, OpenAI, etc) 189 | /// * `api_key` - API key for the provider 190 | /// * `model` - Model identifier to use 191 | /// * `tasks` - List of tasks this provider supports 192 | /// * `enabled` - Whether this provider should be enabled 193 | /// * `custom_endpont` - Optional specification on where the requests for this instance should go 194 | pub async fn add_instance( 195 | &mut self, 196 | provider_type: ProviderType, 197 | api_key: String, 198 | model: String, 199 | tasks: Vec, 200 | enabled: bool, 201 | custom_endpoint: Option, 202 | ) { 203 | debug!("Creating provider with model {}", model); 204 | let instance = create_instance( 205 | provider_type, 206 | api_key, 207 | model.clone(), 208 | tasks.clone(), 209 | enabled, 210 | custom_endpoint, 211 | ); 212 | self.add_instance_to_manager(instance).await; 213 | info!( 214 | "Added Provider Instance ({}) - Model: {} - Supports Tasks: {:?}", 215 | provider_type, 216 | model, 217 | tasks.iter().map(|t| t.name.as_str()).collect::>() 218 | ); 219 | } 220 | 221 | /// Add a pre-created provider instance 222 | /// 223 | /// # Parameters 224 | /// * `provider` - The provider instance to add 225 | pub async fn add_instance_to_manager(&mut self, instance: Arc) { 226 | let id = { 
227 | let mut counter = self.instance_counter.lock().await; 228 | let current_id = *counter; 229 | *counter += 1; 230 | current_id 231 | }; 232 | 233 | let tracker = InstanceTracker::new(instance.clone()); 234 | debug!("Adding instance {} ({})", id, instance.get_name()); 235 | 236 | let supported_tasks_names: Vec = 237 | instance.get_supported_tasks().keys().cloned().collect(); 238 | 239 | { 240 | let mut task_map = self.tasks_to_instances.lock().await; 241 | for task_name in &supported_tasks_names { 242 | task_map 243 | .entry(task_name.clone()) 244 | .or_insert_with(Vec::new) 245 | .push(id); 246 | debug!("Added instance {} to task mapping for '{}'", id, task_name); 247 | } 248 | } 249 | 250 | { 251 | let mut trackers = self.trackers.lock().await; 252 | trackers.insert(id, tracker); 253 | } 254 | 255 | { 256 | let mut usage_map = self.total_usage.lock().await; 257 | usage_map.insert(id, TokenUsage::default()); // TODO - Implement default 258 | } 259 | } 260 | 261 | /// Set a new load balancing strategy 262 | /// 263 | /// # Parameters 264 | /// * `strategy` - The new load balancing strategy to use 265 | pub async fn set_strategy( 266 | &mut self, 267 | strategy: Box, 268 | ) { 269 | let mut current_strategy = self.strategy.lock().await; 270 | *current_strategy = strategy; 271 | } 272 | 273 | /// Process multiple requests sequentially 274 | /// 275 | /// # Parameters 276 | /// * `requests` - List of generation requests to process 277 | /// 278 | /// # Returns 279 | /// * List of responses in the same order as the requests 280 | pub async fn generate_sequentially( 281 | &self, 282 | requests: Vec, 283 | ) -> Vec { 284 | let mut responses = Vec::with_capacity(requests.len()); 285 | info!( 286 | "Entering generate_sequentially with {} requests", 287 | requests.len() 288 | ); 289 | 290 | for (index, request) in requests.into_iter().enumerate() { 291 | info!("Starting sequential request index: {}", index); 292 | let internal_request = LlmManagerRequest::from_generation_request(request); 293 | 294 | let response_result = self.generate_response(internal_request, None).await; 295 | info!( 296 | "Sequential request index {} completed generate_response call.", 297 | index 298 | ); 299 | 300 | let response = match response_result { 301 | Ok(content) => { 302 | info!("Sequential request index {} succeeded.", index); 303 | LlmManagerResponse { 304 | content, 305 | success: true, 306 | error: None, 307 | } 308 | } 309 | Err(e) => { 310 | warn!("Sequential request index {} failed: {}", index, e); 311 | LlmManagerResponse { 312 | content: String::new(), 313 | success: false, 314 | error: Some(e.to_string()), 315 | } 316 | } 317 | }; 318 | 319 | debug!("Pushing response for sequential request index {}", index); 320 | responses.push(response); 321 | info!("Finished processing sequential request index {}", index); 322 | } 323 | 324 | info!("Exiting generate_sequentially"); 325 | responses 326 | } 327 | 328 | /// Process multiple requests in parallel 329 | /// 330 | /// # Parameters 331 | /// * `requests` - List of generation requests to process 332 | /// 333 | /// # Returns 334 | /// * List of responses in the same order as the requests 335 | pub async fn batch_generate( 336 | &self, 337 | requests: Vec, 338 | ) -> Vec { 339 | info!("Entering batch_generate with {} requests", requests.len()); 340 | let internal_requests = requests 341 | .into_iter() 342 | .map(|request| LlmManagerRequest::from_generation_request(request)) 343 | .collect::>(); 344 | 345 | let futures = internal_requests 346 | .into_iter() 
347 | .enumerate() 348 | .map(|(index, request)| async move { 349 | info!("Starting parallel request index: {}", index); 350 | match self.generate_response(request, None).await { 351 | Ok(content) => { 352 | info!("Parallel request index {} succeeded.", index); 353 | LlmManagerResponse { 354 | content, 355 | success: true, 356 | error: None, 357 | } 358 | } 359 | Err(e) => { 360 | warn!("Parallel request index {} failed: {}", index, e); 361 | LlmManagerResponse { 362 | content: String::new(), 363 | success: false, 364 | error: Some(e.to_string()), 365 | } 366 | } 367 | } 368 | }) 369 | .collect::<Vec<_>>(); 370 | 371 | let results = join_all(futures).await; 372 | info!("Exiting batch_generate"); 373 | results 374 | } 375 | 376 | /// Core function to generate a response with retries 377 | /// 378 | /// # Parameters 379 | /// * `request` - The internal request with retry state 380 | /// * `max_attempts` - Optional override for maximum retry attempts 381 | /// 382 | /// # Returns 383 | /// * Result with either the generated content or an error 384 | async fn generate_response( 385 | &self, 386 | request: LlmManagerRequest, 387 | max_attempts: Option<usize>, 388 | ) -> LlmResult<String> { 389 | let start_time = Instant::now(); 390 | let mut attempts = request.attempts; 391 | let mut failed_instances = request.failed_instances.clone(); 392 | let prompt_preview = request.prompt.chars().take(50).collect::<String>(); 393 | let task = request.task.as_deref(); 394 | let request_params = request.params.clone(); 395 | let max_retries = max_attempts.unwrap_or(self.max_retries); 396 | 397 | info!( 398 | "generate_response called for task: {:?}, prompt: '{}...'", 399 | task, prompt_preview 400 | ); 401 | 402 | while attempts <= max_retries { 403 | debug!( 404 | "Attempt {} of {} for request (task: {:?})", 405 | attempts + 1, 406 | max_retries + 1, 407 | task 408 | ); 409 | 410 | let attempt_result = self 411 | .instance_selection( 412 | &request.prompt, 413 | task, 414 | request_params.clone(), 415 | &failed_instances, 416 | ) 417 | .await; 418 | 419 | match attempt_result { 420 | Ok((content, instance_id)) => { 421 | let duration = start_time.elapsed(); 422 | info!( 423 | "Request successful on attempt {} with instance {} after {:?}", 424 | attempts + 1, 425 | instance_id, 426 | duration 427 | ); 428 | return Ok(content); 429 | } 430 | Err((error, instance_id)) => { 431 | warn!( 432 | "Attempt {} failed with instance {}: {}", 433 | attempts + 1, 434 | instance_id, 435 | error 436 | ); 437 | 438 | // Check if this is a rate limit error 439 | if matches!(error, LlmError::RateLimit(_)) { 440 | warn!( 441 | "Rate limit detected for instance {}. Waiting before retry...", 442 | instance_id 443 | ); 444 | 445 | // Wait before retrying (exponential backoff) 446 | let wait_time = 447 | std::time::Duration::from_secs(2_u64.pow(attempts as u32).min(60)); 448 | tokio::time::sleep(wait_time).await; 449 | 450 | // Don't mark this instance as failed for rate limits 451 | // Just increment attempts and try again 452 | attempts += 1; 453 | } else { 454 | // For non-rate-limit errors, mark instance as failed 455 | failed_instances.push(instance_id); 456 | attempts += 1; 457 | } 458 | 459 | if attempts > max_retries { 460 | warn!( 461 | "Max retries ({}) reached for task: {:?}. 
Returning last error.", 462 | max_retries + 1, 463 | task 464 | ); 465 | return Err(error); 466 | } 467 | 468 | debug!( 469 | "Retrying with next eligible instance for task: {:?}...", 470 | task 471 | ); 472 | } 473 | } 474 | } 475 | 476 | warn!("Exited retry loop unexpectedly for task: {:?}", task); 477 | Err(LlmError::ConfigError( 478 | "No available providers after all retry attempts".to_string(), 479 | )) 480 | } 481 | 482 | /// Select an appropriate instance and execute the request 483 | /// 484 | /// This function: 485 | /// 1. Identifies instances that support the requested task 486 | /// 2. Filters out failed and disabled instances 487 | /// 3. Uses the load balancing strategy to select an instance 488 | /// 4. Merges task and request parameters 489 | /// 5. Executes the request against the selected provider 490 | /// 6. Updates metrics based on the result 491 | /// 492 | /// # Parameters 493 | /// * `prompt` - The prompt text to send 494 | /// * `task` - Optional task identifier 495 | /// * `request_params` - Optional request parameters 496 | /// * `failed_instances` - List of instance IDs that have failed 497 | /// 498 | /// # Returns 499 | /// * Success: (generated content, instance ID) 500 | /// * Error: (error, instance ID that failed) 501 | async fn instance_selection( 502 | &self, 503 | prompt: &str, 504 | task: Option<&str>, 505 | request_params: Option<HashMap<String, serde_json::Value>>, 506 | failed_instances: &[usize], 507 | ) -> Result<(String, usize), (LlmError, usize)> { 508 | debug!( 509 | "instance_selection: Starting selection for task: {:?}", 510 | task 511 | ); 512 | 513 | // 1. Get candidate instance IDs based on task (if any) 514 | let candidate_ids: Option<Vec<usize>> = match task { 515 | Some(task_name) => { 516 | let task_map = self.tasks_to_instances.lock().await; 517 | task_map.get(task_name).cloned() 518 | } 519 | None => None, // No specific task, consider all instances initially 520 | }; 521 | 522 | if task.is_some() && candidate_ids.is_none() { 523 | warn!("No instances found supporting task: '{}'", task.unwrap()); 524 | debug!("instance_selection returning Err (no task support)"); 525 | return Err(( 526 | LlmError::ConfigError(format!( 527 | "No providers available for task: {}", 528 | task.unwrap() 529 | )), 530 | 0, 531 | )); 532 | } 533 | 534 | // 2. 
Filter candidates by availability and collect all needed data in one go 535 | let eligible_instances_data: Vec<( 536 | usize, 537 | String, 538 | Arc<dyn LlmInstance>, 539 | Option<TaskDefinition>, 540 | )>; 541 | 542 | // Get eligible instance IDs for strategy selection 543 | let eligible_instance_ids: Vec<usize>; 544 | 545 | // Scope the lock to ensure it's dropped before strategy selection 546 | { 547 | let trackers_guard = self.trackers.lock().await; 548 | debug!("instance_selection: Acquired trackers lock (1st time)"); 549 | 550 | if trackers_guard.is_empty() { 551 | warn!("No LLM providers configured."); 552 | return Err(( 553 | LlmError::ConfigError("No LLM providers available".to_string()), 554 | 0, 555 | )); 556 | } 557 | 558 | // Extract all the data we need while holding the lock 559 | match candidate_ids { 560 | Some(ids) => { 561 | debug!( 562 | "Filtering instances for task '{}' using IDs: {:?}", 563 | task.unwrap(), 564 | ids 565 | ); 566 | eligible_instances_data = trackers_guard 567 | .iter() 568 | .filter(|(id, tracker)| { 569 | ids.contains(id) && tracker.is_enabled() && !failed_instances.contains(id) 570 | }) 571 | .map(|(id, tracker)| { 572 | let task_def = task 573 | .and_then(|t| tracker.instance.get_supported_tasks().get(t).cloned()); 574 | ( 575 | *id, 576 | tracker.instance.get_name().to_string(), 577 | tracker.instance.clone(), 578 | task_def, 579 | ) 580 | }) 581 | .collect(); 582 | debug!( 583 | "Found {} eligible instances for task '{}'", 584 | eligible_instances_data.len(), 585 | task.unwrap() 586 | ); 587 | } 588 | None => { 589 | debug!("No specific task. Filtering all enabled instances."); 590 | eligible_instances_data = trackers_guard 591 | .iter() 592 | .filter(|(id, tracker)| tracker.is_enabled() && !failed_instances.contains(id)) 593 | .map(|(id, tracker)| { 594 | let task_def = task 595 | .and_then(|t| tracker.instance.get_supported_tasks().get(t).cloned()); 596 | ( 597 | *id, 598 | tracker.instance.get_name().to_string(), 599 | tracker.instance.clone(), 600 | task_def, 601 | ) 602 | }) 603 | .collect(); 604 | debug!( 605 | "Found {} eligible instances (no task)", 606 | eligible_instances_data.len() 607 | ); 608 | } 609 | } 610 | 611 | // Extract just the IDs for strategy selection 612 | eligible_instance_ids = eligible_instances_data 613 | .iter() 614 | .map(|(id, _, _, _)| *id) 615 | .collect(); 616 | 617 | // No eligible instances check 618 | if eligible_instances_data.is_empty() { 619 | let error_msg = format!( 620 | "No enabled providers available{}{}", 621 | task.map_or_else(|| "".to_string(), |t| format!(" for task: '{}'", t)), 622 | if !failed_instances.is_empty() { 623 | format!(" (excluded {} failed instances)", failed_instances.len()) 624 | } else { 625 | "".to_string() 626 | } 627 | ); 628 | warn!("{}", error_msg); 629 | return Err((LlmError::ConfigError(error_msg), 0)); 630 | } 631 | } 632 | 633 | // 5. 
Select instance using strategy (need to re-acquire lock for metrics) 634 | let selected_instance_id = { 635 | let trackers_guard = self.trackers.lock().await; 636 | let mut strategy = self.strategy.lock().await; 637 | debug!("instance_selection: Acquired strategy and trackers locks"); 638 | 639 | // Build the trackers slice for the strategy 640 | let eligible_trackers: Vec<(usize, &InstanceTracker)> = eligible_instance_ids 641 | .iter() 642 | .filter_map(|id| { 643 | trackers_guard.get(id).map(|tracker| (*id, tracker)) 644 | }) 645 | .collect(); 646 | 647 | let selected_metric_index = strategy.select_instance(&eligible_trackers); 648 | let selected_id = eligible_trackers[selected_metric_index].0; 649 | 650 | debug!("instance_selection: Released strategy lock"); 651 | selected_id 652 | }; 653 | 654 | // Find the corresponding instance in our extracted data 655 | let selected_instance = eligible_instances_data 656 | .iter() 657 | .find(|(id, _, _, _)| *id == selected_instance_id) 658 | .expect("Selected instance ID from metrics not found in eligible list - LOGIC ERROR!"); 659 | 660 | // Unpack the tuple 661 | let (selected_id, selected_name, selected_provider_arc, task_def) = ( 662 | selected_instance.0, 663 | &selected_instance.1, 664 | &selected_instance.2, 665 | &selected_instance.3, 666 | ); 667 | 668 | debug!( 669 | "Selected instance {} ({}) for the request.", 670 | selected_id, selected_name 671 | ); 672 | 673 | // 6. Merge parameters 674 | let mut final_params = HashMap::new(); 675 | if let Some(task_def) = task_def { 676 | final_params.extend(task_def.parameters.clone()); 677 | debug!("Applied parameters from task for instance {}", selected_id); 678 | } 679 | 680 | if let Some(req_params) = request_params { 681 | final_params.extend(req_params); 682 | debug!( 683 | "Applied request-specific parameters for instance {}", 684 | selected_id 685 | ); 686 | } 687 | 688 | // Create and execute the request 689 | let max_tokens = final_params 690 | .get("max_tokens") 691 | .and_then(|v| v.as_u64()) 692 | .map(|v| v as u32); 693 | 694 | let temperature = final_params 695 | .get("temperature") 696 | .and_then(|v| v.as_f64()) 697 | .map(|v| v as f32); 698 | 699 | let request = LlmRequest { 700 | messages: vec![Message { 701 | role: "user".to_string(), 702 | content: prompt.to_string(), 703 | }], 704 | model: None, // Let provider use its configured model 705 | max_tokens, 706 | temperature, 707 | }; 708 | 709 | debug!( 710 | "Instance {} ({}) sending request to provider...", 711 | selected_id, selected_name 712 | ); 713 | let start_time = Instant::now(); 714 | let result = selected_provider_arc.generate(&request).await; 715 | let duration = start_time.elapsed(); 716 | info!( 717 | "Instance {} ({}) received result in {:?}", 718 | selected_id, selected_name, duration 719 | ); 720 | 721 | // Update metrics regardless of success or failure 722 | { 723 | debug!("instance_selection: Attempting to acquire trackers lock (2nd time) for metrics update"); 724 | let mut trackers_guard = self.trackers.lock().await; 725 | debug!("instance_selection: Acquired trackers lock (2nd time)"); 726 | if let Some((_id, instance_tracker)) = trackers_guard 727 | .iter_mut() 728 | .find(|(id, _tracker)| **id == selected_id) 729 | { 730 | debug!("Recording result for instance {}", selected_id); 731 | instance_tracker.record_result(duration, &result); 732 | debug!("Finished recording result for instance {}", selected_id); 733 | } else { 734 | warn!( 735 | "Instance {} not found for metric update after request 
completion.", 736 | selected_id 737 | ); 738 | } 739 | debug!("instance_selection: Releasing trackers lock (2nd time) after metrics update"); 740 | // Lock released when trackers_guard goes out of scope here 741 | } 742 | 743 | // Write debug information if debug folder is configured 744 | self.write_debug_info( 745 | selected_id, 746 | selected_name, 747 | &selected_provider_arc.get_model(), 748 | prompt, 749 | task, 750 | &final_params, 751 | &result, 752 | duration, 753 | ).await; 754 | 755 | // Return either content or error with the instance ID 756 | match result { 757 | Ok(response) => { 758 | if let Some(usage) = &response.usage { 759 | self.update_instance_usage(selected_id, usage).await; 760 | debug!( 761 | "Updated token usage for instance {}: {:?}", 762 | selected_id, usage 763 | ); 764 | } 765 | debug!( 766 | "instance_selection returning Ok for instance {}", 767 | selected_id 768 | ); 769 | Ok((response.content, selected_id)) 770 | } 771 | Err(e) => { 772 | debug!( 773 | "instance_selection returning Err for instance {}: {}", 774 | selected_id, e 775 | ); 776 | Err((e, selected_id)) 777 | } 778 | } 779 | } 780 | 781 | /// Write debug information for a request/response to the debug folder 782 | async fn write_debug_info( 783 | &self, 784 | instance_id: usize, 785 | instance_name: &str, 786 | instance_model: &str, 787 | prompt: &str, 788 | task: Option<&str>, 789 | final_params: &HashMap, 790 | result: &Result, 791 | duration: std::time::Duration, 792 | ) { 793 | if let Some(debug_folder) = &self.debug_folder { 794 | let timestamp = self.creation_time 795 | .duration_since(UNIX_EPOCH) 796 | .unwrap_or_default() 797 | .as_secs(); 798 | 799 | let debug_path = get_debug_path( 800 | debug_folder, 801 | timestamp, 802 | instance_id, 803 | instance_name, 804 | instance_model 805 | ); 806 | 807 | // Create the new generation entry 808 | let generation_entry = json!({ 809 | "metadata": { 810 | "timestamp": SystemTime::now().duration_since(UNIX_EPOCH).unwrap_or_default().as_secs(), 811 | "instance_id": instance_id, 812 | "instance_name": instance_name, 813 | "instance_model": instance_model, 814 | "duration_ms": duration.as_millis() 815 | }, 816 | "input": { 817 | "prompt": prompt, 818 | "task": task, 819 | "parameters": final_params 820 | }, 821 | "output": match result { 822 | Ok(response) => json!({ 823 | "success": true, 824 | "content": response.content, 825 | "usage": response.usage 826 | }), 827 | Err(error) => json!({ 828 | "success": false, 829 | "error": error.to_string() 830 | }) 831 | } 832 | }); 833 | 834 | // Read existing file or create new array 835 | let mut generations: Vec = if debug_path.exists() { 836 | match fs::read_to_string(&debug_path) { 837 | Ok(content) => { 838 | match serde_json::from_str::>(&content) { 839 | Ok(array) => array, 840 | Err(e) => { 841 | warn!("Failed to parse existing debug file as JSON array, creating new: {}", e); 842 | Vec::new() 843 | } 844 | } 845 | } 846 | Err(e) => { 847 | warn!("Failed to read existing debug file, creating new: {}", e); 848 | Vec::new() 849 | } 850 | } 851 | } else { 852 | Vec::new() 853 | }; 854 | 855 | // Append new generation 856 | generations.push(generation_entry); 857 | 858 | // Write updated array back to file 859 | let json_string = match serde_json::to_string_pretty(&generations) { 860 | Ok(s) => s, 861 | Err(e) => { 862 | warn!("Failed to serialize debug data: {}", e); 863 | return; 864 | } 865 | }; 866 | 867 | if let Err(e) = write_to_debug_file(&debug_path, &json_string) { 868 | warn!("Failed to write 
debug file: {}", e); 869 | } 870 | } 871 | } 872 | 873 | /// Update token usage for a specific instance 874 | /// 875 | /// # Parameters 876 | /// * `instance_id` - ID of the instance to update 877 | /// * `usage` - The token usage to add 878 | async fn update_instance_usage(&self, instance_id: usize, usage: &TokenUsage) { 879 | let mut usage_map = self.total_usage.lock().await; 880 | 881 | let instance_usage = usage_map.entry(instance_id).or_insert(TokenUsage { 882 | prompt_tokens: 0, 883 | completion_tokens: 0, 884 | total_tokens: 0, 885 | }); 886 | 887 | instance_usage.prompt_tokens += usage.prompt_tokens; 888 | instance_usage.completion_tokens += usage.completion_tokens; 889 | instance_usage.total_tokens += usage.total_tokens; 890 | 891 | debug!( 892 | "Updated usage for instance {}: current total is {} tokens", 893 | instance_id, instance_usage.total_tokens 894 | ); 895 | } 896 | 897 | /// Get token usage for a specific instance 898 | /// 899 | /// # Parameters 900 | /// * `instance_id` - ID of the instance to query 901 | /// 902 | /// # Returns 903 | /// * Token usage for the specified instance, if found 904 | pub async fn get_instance_usage(&self, instance_id: usize) -> Option<TokenUsage> { 905 | let usage_map = self.total_usage.lock().await; 906 | usage_map.get(&instance_id).cloned() 907 | } 908 | 909 | /// Get total token usage across all instances 910 | /// 911 | /// # Returns 912 | /// * Combined token usage statistics 913 | pub async fn get_total_usage(&self) -> TokenUsage { 914 | let usage_map = self.total_usage.lock().await; 915 | 916 | usage_map.values().fold( 917 | TokenUsage { 918 | prompt_tokens: 0, 919 | completion_tokens: 0, 920 | total_tokens: 0, 921 | }, 922 | |mut acc, usage| { 923 | acc.prompt_tokens += usage.prompt_tokens; 924 | acc.completion_tokens += usage.completion_tokens; 925 | acc.total_tokens += usage.total_tokens; 926 | acc 927 | }, 928 | ) 929 | } 930 | 931 | /// Get the number of configured provider instances 932 | /// 933 | /// # Returns 934 | /// * Number of provider instances in the manager 935 | pub async fn get_provider_count(&self) -> usize { 936 | let trackers = self.trackers.lock().await; 937 | trackers.len() 938 | } 939 | 940 | /// Print token usage statistics to console 941 | pub async fn print_token_usage(&self) { 942 | println!("\n--- Token Usage Statistics ---"); 943 | println!( 944 | "{:<5} {:<15} {:<30} {:<15} {:<15} {:<15}", 945 | "ID", "Provider", "Model", "Prompt Tokens", "Completion Tokens", "Total Tokens" 946 | ); 947 | println!("{}", "-".repeat(95)); 948 | 949 | let trackers = self.trackers.lock().await; 950 | let usage_map = self.total_usage.lock().await; 951 | 952 | // Print usage for each instance 953 | for (instance_id, tracker) in trackers.iter() { 954 | if let Some(usage) = usage_map.get(instance_id) { 955 | println!( 956 | "{:<5} {:<15} {:<30} {:<15} {:<15} {:<15}", 957 | instance_id, 958 | tracker.instance.get_name(), 959 | tracker.instance.get_model(), 960 | usage.prompt_tokens, 961 | usage.completion_tokens, 962 | usage.total_tokens 963 | ); 964 | } 965 | } 966 | } 967 | 968 | } 969 | --------------------------------------------------------------------------------
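Usage sketch for the manager API above: the snippet below wires up a single provider instance and prints the usage table. It is an illustrative sketch, not code from this repository; it assumes that `LlmManager` and `ProviderType` are re-exported from the crate root (lib.rs is not shown here), that `OpenAI` is the name of a `ProviderType` variant, and that the model identifier and `OPENAI_API_KEY` environment variable are placeholders chosen for the example.

use flyllm::{LlmManager, ProviderType}; // assumed crate-root re-exports

#[tokio::main]
async fn main() {
    // Manager with the default least-recently-used strategy and default retry count.
    let mut manager = LlmManager::new();

    // Register one instance: no TaskDefinitions, enabled, provider's default endpoint.
    manager
        .add_instance(
            ProviderType::OpenAI,                                         // assumed variant name
            std::env::var("OPENAI_API_KEY").expect("set OPENAI_API_KEY"), // API key
            "gpt-4o-mini".to_string(),                                    // placeholder model id
            vec![],                                                       // no task restrictions
            true,                                                         // enabled
            None,                                                         // no custom endpoint
        )
        .await;

    println!("Configured instances: {}", manager.get_provider_count().await);
    manager.print_token_usage().await;
}

Requests would then be issued through generate_sequentially or batch_generate; both take a Vec<GenerationRequest> and return one LlmManagerResponse per request, in the original order.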