├── crates ├── core │ ├── src │ │ └── lib.rs │ ├── .gitignore │ └── Cargo.toml └── providers │ ├── .gitignore │ ├── tests │ ├── openai.rs │ └── openai │ │ ├── streaming.rs │ │ └── generating.rs │ ├── src │ ├── openai │ │ ├── common │ │ │ ├── web_search_tool_call_item.rs │ │ │ ├── status.rs │ │ │ ├── reasoning_item.rs │ │ │ ├── function_tool_call_item.rs │ │ │ ├── truncation.rs │ │ │ ├── service_tier.rs │ │ │ ├── file_search_tool_item.rs │ │ │ ├── output_message_item.rs │ │ │ ├── reasoning.rs │ │ │ ├── tool_choice.rs │ │ │ ├── computer_tool_call_item.rs │ │ │ ├── text.rs │ │ │ └── tool.rs │ │ ├── response │ │ │ ├── incomplete_details.rs │ │ │ ├── response_error.rs │ │ │ ├── response_output.rs │ │ │ ├── usage.rs │ │ │ └── events │ │ │ │ └── streaming.rs │ │ ├── request │ │ │ ├── input_models │ │ │ │ ├── input_reference.rs │ │ │ │ ├── input_message.rs │ │ │ │ ├── item.rs │ │ │ │ └── common.rs │ │ │ ├── input.rs │ │ │ └── include.rs │ │ ├── errors.rs │ │ ├── client.rs │ │ ├── types.rs │ │ └── constants.rs │ ├── lib.rs │ └── utils │ │ ├── provider_strategy.rs │ │ └── errors.rs │ └── Cargo.toml ├── .env.example ├── NOTES.md ├── README.md ├── package.json ├── Cargo.toml ├── tsconfig.json ├── .gitignore ├── LICENSE └── bun.lock /crates/core/src/lib.rs: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /crates/core/.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | -------------------------------------------------------------------------------- /crates/providers/.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | -------------------------------------------------------------------------------- /.env.example: -------------------------------------------------------------------------------- 1 | # OpenAI API Configuration 2 | 
OPENAI_API_KEY=your_openai_api_key_here -------------------------------------------------------------------------------- /crates/providers/tests/openai.rs: -------------------------------------------------------------------------------- 1 | // Integration test entry point; the submodules live in tests/openai/. 2 | mod openai { 3 | mod generating; 4 | mod streaming; 5 | } 6 | -------------------------------------------------------------------------------- /NOTES.md: -------------------------------------------------------------------------------- 1 | ## TODO 2 | 3 | - [ ] Review optional fields and ensure they are skipped during serialization when None 4 | -------------------------------------------------------------------------------- /crates/providers/tests/openai/streaming.rs: -------------------------------------------------------------------------------- 1 | /// Placeholder for the streaming integration test. 2 | /// 3 | /// Streaming is not exercised yet, so the test is ignored instead of asserting a 4 | /// tautology; `cargo test` then reports the gap rather than a misleading pass. 5 | #[test] 6 | #[ignore = "streaming integration test not implemented yet"] 7 | fn test_it_works() { 8 | // TODO: call `ProviderStrategy::stream` and assert on the streamed items. 9 | } 10 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AI SDK 2 | 3 | A modular and composable AI SDK built in Rust. 
4 | 5 | ## Installation 6 | 7 | Add this to your `Cargo.toml`: 8 | 9 | ```toml 10 | [dependencies] 11 | ai-sdk = "0.0.1" 12 | ``` 13 | 14 | ## License 15 | 16 | MIT 17 | -------------------------------------------------------------------------------- /package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "ai-sdk", 3 | "module": "index.ts", 4 | "type": "module", 5 | "private": true, 6 | "scripts": { 7 | "build": "cargo build --release", 8 | "test": "cargo test --release", 9 | "lint": "cargo clippy --release", 10 | "format": "cargo fmt --all" 11 | }, 12 | "devDependencies": { 13 | "@types/bun": "latest" 14 | }, 15 | "peerDependencies": { 16 | "typescript": "^5" 17 | } 18 | } -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [workspace] 2 | members = [ 3 | "crates/core", 4 | "crates/providers", 5 | ] 6 | resolver = "2" 7 | 8 | [workspace.dependencies] 9 | serde = { version = "1.0.219", features = ["derive"] } 10 | serde_json = "1.0.140" 11 | tokio-stream = "0.1.17" 12 | reqwest = { version = "0.12.15", features = ["json"] } 13 | async-trait = "0.1.73" 14 | bytes = "1.4.0" 15 | futures = "0.3.28" 16 | tokio = { version = "1.45.0", features = ["full"] } 17 | dotenv = "0.15.0" -------------------------------------------------------------------------------- /crates/core/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "ai-sdk" 3 | version = "0.0.1" 4 | edition = "2021" 5 | description = "A modular and composable AI SDK built in Rust." 
6 | license = "MIT" 7 | repository = "https://github.com/cupolalabs/ai-sdk" 8 | homepage = "https://github.com/cupolalabs/ai-sdk" 9 | readme = "../../README.md" 10 | keywords = ["AI", "SDK"] 11 | categories = ["api-bindings", "science"] 12 | publish = true 13 | 14 | [dependencies] 15 | ai-providers = "0.0.1" -------------------------------------------------------------------------------- /crates/providers/src/openai/common/web_search_tool_call_item.rs: -------------------------------------------------------------------------------- 1 | use serde::{Deserialize, Serialize}; 2 | 3 | #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] 4 | pub struct WebSearchToolCallItem { 5 | pub id: String, 6 | pub status: String, 7 | } 8 | 9 | impl WebSearchToolCallItem { 10 | pub fn new(id: impl Into, status: impl Into) -> Self { 11 | Self { 12 | id: id.into(), 13 | status: status.into(), 14 | } 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /crates/providers/src/openai/response/incomplete_details.rs: -------------------------------------------------------------------------------- 1 | use serde::{Deserialize, Serialize}; 2 | 3 | #[derive(Debug, Default, PartialEq, Serialize, Deserialize)] 4 | pub struct IncompleteDetails { 5 | reason: String, 6 | } 7 | 8 | impl IncompleteDetails { 9 | pub fn new(reason: impl Into) -> Self { 10 | Self { 11 | reason: reason.into(), 12 | } 13 | } 14 | } 15 | 16 | #[cfg(test)] 17 | mod tests { 18 | use super::*; 19 | use serde_json::json; 20 | 21 | #[test] 22 | fn it_serializes_to_json() { 23 | let incomplete_details = IncompleteDetails::new("test_reason"); 24 | let json = serde_json::to_value(&incomplete_details).unwrap(); 25 | assert_eq!(json, json!({ "reason": "test_reason" })); 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | 
"compilerOptions": { 3 | // Environment setup & latest features 4 | "lib": ["esnext"], 5 | "target": "ESNext", 6 | "module": "ESNext", 7 | "moduleDetection": "force", 8 | "jsx": "react-jsx", 9 | "allowJs": true, 10 | 11 | // Bundler mode 12 | "moduleResolution": "bundler", 13 | "allowImportingTsExtensions": true, 14 | "verbatimModuleSyntax": true, 15 | "noEmit": true, 16 | 17 | // Best practices 18 | "strict": true, 19 | "skipLibCheck": true, 20 | "noFallthroughCasesInSwitch": true, 21 | "noUncheckedIndexedAccess": true, 22 | 23 | // Some stricter flags (disabled by default) 24 | "noUnusedLocals": false, 25 | "noUnusedParameters": false, 26 | "noPropertyAccessFromIndexSignature": false 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /crates/providers/tests/openai/generating.rs: -------------------------------------------------------------------------------- 1 | use ai_providers::{ 2 | openai::constants::OpenAIModelId, openai::request::input::Input, OpenAIProvider, OpenAIRequest, 3 | ProviderStrategy, 4 | }; 5 | 6 | #[tokio::test] 7 | async fn test_it_works() { 8 | // Load environment variables from .env file 9 | dotenv::dotenv().ok(); 10 | 11 | let api_key = 12 | std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY environment variable not set"); 13 | let provider = OpenAIProvider::new(api_key); 14 | 15 | let request = OpenAIRequest::new( 16 | OpenAIModelId::Gpt3_5Turbo, 17 | Input::Message("Who's Fred again..".into()), 18 | ); 19 | 20 | let result = provider.generate(&request).await.unwrap(); 21 | 22 | println!("{:?}", result); 23 | 24 | assert_eq!("test", "test"); 25 | } 26 | -------------------------------------------------------------------------------- /crates/providers/src/openai/common/status.rs: -------------------------------------------------------------------------------- 1 | use crate::openai::errors::ConversionError; 2 | use std::str::FromStr; 3 | 4 | use serde::{Deserialize, Serialize}; 5 | 6 | 
#[derive(Debug, PartialEq, Copy, Clone, Serialize, Deserialize)] 7 | #[serde(rename_all = "snake_case")] 8 | pub enum Status { 9 | InProgress, 10 | Completed, 11 | Incomplete, 12 | Failed, 13 | } 14 | 15 | impl FromStr for Status { 16 | type Err = ConversionError; 17 | 18 | fn from_str(s: &str) -> Result { 19 | match s { 20 | "in_progress" => Ok(Status::InProgress), 21 | "completed" => Ok(Status::Completed), 22 | "incomplete" => Ok(Status::Incomplete), 23 | "failed" => Ok(Status::Failed), 24 | _ => Err(ConversionError::FromStr(s.to_string())), 25 | } 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ### Rust ### 2 | # Generated by Cargo 3 | # will have compiled files and executables 4 | debug/ 5 | target/ 6 | 7 | # Remove Cargo.lock from gitignore if creating an executable, leave it for libraries 8 | # More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html 9 | Cargo.lock 10 | 11 | # These are backup files generated by rustfmt 12 | **/*.rs.bk 13 | 14 | # MSVC Windows builds of rustc generate these, which store debugging information 15 | *.pdb 16 | 17 | # Node modules 18 | node_modules/ 19 | 20 | # Coverage report 21 | tarpaulin-report.html 22 | 23 | # Environment variables 24 | .env 25 | .env.local 26 | .env.*.local 27 | .env.development 28 | .env.test 29 | .env.production 30 | .env.development.local 31 | .env.test.local 32 | .env.production.local 33 | 34 | # IDE specific files 35 | .idea/ 36 | .vscode/ 37 | *.swp 38 | *.swo 39 | .DS_Store 40 | 41 | -------------------------------------------------------------------------------- /crates/providers/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "ai-providers" 3 | version = "0.0.1" 4 | edition = "2021" 5 | description = "A comprehensive collection of AI provider 
implementations, schematics, and API integrations for various AI services" 6 | license = "MIT" 7 | repository = "https://github.com/cupolalabs/ai-sdk" 8 | homepage = "https://github.com/cupolalabs/ai-sdk" 9 | readme = "../../README.md" 10 | keywords = ["AI", "providers", "openai", "api", "sdk"] 11 | categories = ["api-bindings", "science"] 12 | publish = true 13 | 14 | [dependencies] 15 | serde.workspace = true 16 | serde_json.workspace = true 17 | tokio-stream.workspace = true 18 | reqwest = { workspace = true, features = ["stream", "json"] } 19 | async-trait.workspace = true 20 | bytes.workspace = true 21 | futures.workspace = true 22 | 23 | [dev-dependencies] 24 | tokio = { workspace = true, features = ["full", "test-util", "macros"] } 25 | dotenv.workspace = true -------------------------------------------------------------------------------- /crates/providers/src/openai/response/response_error.rs: -------------------------------------------------------------------------------- 1 | use serde::{Deserialize, Serialize}; 2 | 3 | #[derive(Debug, Default, PartialEq, Serialize, Deserialize)] 4 | pub struct ResponseError { 5 | code: String, 6 | message: String, 7 | } 8 | 9 | impl ResponseError { 10 | pub fn new(code: impl Into, message: impl Into) -> Self { 11 | Self { 12 | code: code.into(), 13 | message: message.into(), 14 | } 15 | } 16 | } 17 | 18 | #[cfg(test)] 19 | mod tests { 20 | use super::*; 21 | use serde_json::json; 22 | 23 | #[test] 24 | fn it_serializes_to_json() { 25 | let error = ResponseError::new("test_code", "test_message"); 26 | let json = serde_json::to_value(&error).unwrap(); 27 | assert_eq!( 28 | json, 29 | json!({ 30 | "code": "test_code", 31 | "message": "test_message" 32 | }) 33 | ); 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /crates/providers/src/openai/response/response_output.rs: -------------------------------------------------------------------------------- 1 | use 
crate::openai::common::{ 2 | computer_tool_call_item::ComputerToolCallItem, file_search_tool_item::FileSearchToolCallItem, 3 | function_tool_call_item::FunctionToolCallItem, output_message_item::OutputMessageItem, 4 | reasoning_item::ReasoningItem, web_search_tool_call_item::WebSearchToolCallItem, 5 | }; 6 | use serde::{Deserialize, Serialize}; 7 | 8 | /// One item of a response's output, discriminated by its `"type"` field. 9 | /// 10 | /// Uses serde's internally tagged representation: the `type` value selects the variant 11 | /// (see each `rename`), and the remaining keys of the same JSON object are deserialized 12 | /// into the wrapped item struct. 13 | #[derive(Debug, Serialize, Deserialize, PartialEq)] 14 | #[serde(tag = "type")] 15 | pub enum ResponseOutput { 16 | #[serde(rename = "message")] 17 | OutputMessage(OutputMessageItem), 18 | #[serde(rename = "file_search_call")] 19 | FileSearchToolCall(FileSearchToolCallItem), 20 | #[serde(rename = "computer_call")] 21 | ComputerToolCall(ComputerToolCallItem), 22 | #[serde(rename = "web_search_call")] 23 | WebSearchToolCall(WebSearchToolCallItem), 24 | #[serde(rename = "function_call")] 25 | FunctionToolCall(FunctionToolCallItem), 26 | #[serde(rename = "reasoning")] 27 | Reasoning(ReasoningItem), 28 | } 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 AI SDK 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /crates/providers/src/openai/common/reasoning_item.rs: -------------------------------------------------------------------------------- 1 | use crate::openai::common::status::Status; 2 | use crate::openai::request::input_models::item::Summary; 3 | use serde::{Deserialize, Serialize}; 4 | 5 | #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] 6 | pub struct ReasoningItem { 7 | pub id: String, 8 | pub summary: Vec, 9 | #[serde(skip_serializing_if = "Option::is_none")] 10 | pub encrypted_content: Option, 11 | #[serde(skip_serializing_if = "Option::is_none")] 12 | pub status: Option, 13 | } 14 | 15 | impl ReasoningItem { 16 | pub fn new(id: impl Into, summary: Vec) -> Self { 17 | Self { 18 | id: id.into(), 19 | summary, 20 | encrypted_content: None, 21 | status: None, 22 | } 23 | } 24 | 25 | pub fn encrypted_content(mut self, value: impl Into) -> Self { 26 | self.encrypted_content = Some(value.into()); 27 | self 28 | } 29 | 30 | pub fn status(mut self, value: Status) -> Self { 31 | self.status = Some(value); 32 | self 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /crates/providers/src/openai/response/usage.rs: -------------------------------------------------------------------------------- 1 | use serde::{Deserialize, Serialize}; 2 | 3 | #[derive(Debug, PartialEq, Serialize, Deserialize)] 4 | pub struct InputTokensDetails { 5 | pub 
cached_tokens: usize, 6 | } 7 | 8 | #[derive(Debug, PartialEq, Serialize, Deserialize)] 9 | pub struct OutputTokensDetails { 10 | pub reasoning_tokens: usize, 11 | } 12 | 13 | #[derive(Debug, PartialEq, Serialize, Deserialize)] 14 | pub struct Usage { 15 | pub input_tokens: usize, 16 | pub input_tokens_details: InputTokensDetails, 17 | pub output_tokens: usize, 18 | pub output_tokens_details: OutputTokensDetails, 19 | pub total_tokens: usize, 20 | } 21 | 22 | impl Usage { 23 | pub fn new( 24 | input_tokens: usize, 25 | input_tokens_details: InputTokensDetails, 26 | output_tokens: usize, 27 | output_tokens_details: OutputTokensDetails, 28 | total_tokens: usize, 29 | ) -> Self { 30 | Self { 31 | input_tokens, 32 | input_tokens_details, 33 | output_tokens, 34 | output_tokens_details, 35 | total_tokens, 36 | } 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /crates/providers/src/openai/request/input_models/input_reference.rs: -------------------------------------------------------------------------------- 1 | use serde::{Deserialize, Serialize}; 2 | 3 | #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] 4 | pub struct InputReference { 5 | pub id: String, 6 | #[serde(rename = "type")] 7 | #[serde(skip_serializing_if = "Option::is_none")] 8 | pub type_field: Option, 9 | } 10 | 11 | impl InputReference { 12 | pub fn new(id: impl Into) -> Self { 13 | Self { 14 | id: id.into(), 15 | type_field: None, 16 | } 17 | } 18 | 19 | pub fn insert_type(mut self) -> Self { 20 | self.type_field = Some("item_reference".to_string()); 21 | self 22 | } 23 | } 24 | 25 | #[cfg(test)] 26 | mod tests { 27 | use super::*; 28 | 29 | #[test] 30 | fn test_json_values() { 31 | let input_reference = InputReference::new("123").insert_type(); 32 | let json_value = serde_json::to_value(&input_reference).unwrap(); 33 | assert_eq!( 34 | json_value, 35 | serde_json::json!({ 36 | "id": "123", 37 | "type": "item_reference" 38 | }) 39 | ); 40 | } 41 
| } 42 | -------------------------------------------------------------------------------- /crates/providers/src/openai/common/function_tool_call_item.rs: -------------------------------------------------------------------------------- 1 | use serde::{Deserialize, Serialize}; 2 | 3 | use crate::openai::common::status::Status; 4 | 5 | #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] 6 | pub struct FunctionToolCallItem { 7 | pub arguments: String, 8 | pub call_id: String, 9 | pub name: String, 10 | #[serde(skip_serializing_if = "Option::is_none")] 11 | pub id: Option, 12 | #[serde(skip_serializing_if = "Option::is_none")] 13 | pub status: Option, 14 | } 15 | 16 | impl FunctionToolCallItem { 17 | pub fn new( 18 | arguments: impl Into, 19 | call_id: impl Into, 20 | name: impl Into, 21 | ) -> Self { 22 | Self { 23 | arguments: arguments.into(), 24 | call_id: call_id.into(), 25 | name: name.into(), 26 | id: None, 27 | status: None, 28 | } 29 | } 30 | 31 | pub fn id(mut self, value: impl Into) -> Self { 32 | self.id = Some(value.into()); 33 | self 34 | } 35 | 36 | pub fn status(mut self, value: Status) -> Self { 37 | self.status = Some(value); 38 | self 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /bun.lock: -------------------------------------------------------------------------------- 1 | { 2 | "lockfileVersion": 1, 3 | "workspaces": { 4 | "": { 5 | "name": "ai-sdk", 6 | "devDependencies": { 7 | "@types/bun": "latest", 8 | }, 9 | "peerDependencies": { 10 | "typescript": "^5", 11 | }, 12 | }, 13 | }, 14 | "packages": { 15 | "@types/bun": ["@types/bun@1.2.10", "", { "dependencies": { "bun-types": "1.2.10" } }, "sha512-eilv6WFM3M0c9ztJt7/g80BDusK98z/FrFwseZgT4bXCq2vPhXD4z8R3oddmAn+R/Nmz9vBn4kweJKmGTZj+lg=="], 16 | 17 | "@types/node": ["@types/node@22.15.2", "", { "dependencies": { "undici-types": "~6.21.0" } }, "sha512-uKXqKN9beGoMdBfcaTY1ecwz6ctxuJAcUlwE55938g0ZJ8lRxwAZqRz2AJ4pzpt5dHdTPMB863UZ0ESiFUcP7A=="], 
18 | 19 | "bun-types": ["bun-types@1.2.10", "", { "dependencies": { "@types/node": "*" } }, "sha512-b5ITZMnVdf3m1gMvJHG+gIfeJHiQPJak0f7925Hxu6ZN5VKA8AGy4GZ4lM+Xkn6jtWxg5S3ldWvfmXdvnkp3GQ=="], 20 | 21 | "typescript": ["typescript@5.8.3", "", { "bin": { "tsc": "bin/tsc", "tsserver": "bin/tsserver" } }, "sha512-p1diW6TqL9L07nNxvRMM7hMMw4c5XOo/1ibL4aAIGmSAt9slTE1Xgw5KWuof2uTOvCg9BY7ZRi+GaF+7sfgPeQ=="], 22 | 23 | "undici-types": ["undici-types@6.21.0", "", {}, "sha512-iwDZqg0QAGrg9Rav5H4n0M64c3mkR59cJ6wQp+7C4nI0gsmExaedaYLNO44eT4AtBBwjbTiGPMlt2Md0T9H9JQ=="], 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /crates/providers/src/openai/request/input.rs: -------------------------------------------------------------------------------- 1 | use crate::openai::request::input_models::{ 2 | input_message::InputMessage, input_reference::InputReference, item::Item, 3 | }; 4 | use serde::{Deserialize, Serialize}; 5 | 6 | /// One entry of a structured request input. 7 | /// 8 | /// Untagged: deserialization tries the variants in declaration order and keeps the first match. 9 | #[derive(Debug, PartialEq, Serialize, Deserialize)] 10 | #[serde(untagged)] 11 | pub enum InputItemList { 12 | InputMessage(InputMessage), 13 | Item(Item), 14 | ItemReference(InputReference), 15 | } 16 | 17 | /// Request input: either a list of entries or a single plain-text message. 18 | /// 19 | /// Untagged, so it serializes as a bare JSON array or a bare JSON string. 20 | #[derive(Debug, PartialEq, Serialize, Deserialize)] 21 | #[serde(untagged)] 22 | pub enum Input { 23 | Messages(Vec<InputItemList>), 24 | Message(String), 25 | } 26 | 27 | impl From<String> for Input { 28 | fn from(value: String) -> Self { 29 | Self::Message(value) 30 | } 31 | } 32 | 33 | impl From<Vec<InputItemList>> for Input { 34 | fn from(input_item_list: Vec<InputItemList>) -> Self { 35 | Self::Messages(input_item_list) 36 | } 37 | } 38 | 39 | /// Defaults to an empty text message. 40 | impl Default for Input { 41 | fn default() -> Self { 42 | Self::Message(String::new()) 43 | } 44 | } 45 | 46 | impl Input { 47 | /// Builds a plain-text input. 48 | pub fn from_text(value: impl Into<String>) -> Self { 49 | Self::Message(value.into()) 50 | } 51 | 52 | /// Builds a structured input from a list of entries. 53 | pub fn from_input_item_list(input_item_list: Vec<InputItemList>) -> Self { 54 | Self::Messages(input_item_list) 55 | } 56 | } 57 | 
--------------------------------------------------------------------------------
/crates/providers/src/lib.rs: -------------------------------------------------------------------------------- 1 | pub use crate::utils::provider_strategy::ProviderStrategy; 2 | pub use openai::client::OpenAIProvider; 3 | pub use openai::types::{OpenAIRequest, OpenAIResponse}; 4 | 5 | pub mod openai { 6 | pub mod client; 7 | pub mod constants; 8 | pub mod errors; 9 | pub mod types; 10 | pub mod common { 11 | pub mod computer_tool_call_item; 12 | pub mod file_search_tool_item; 13 | pub mod function_tool_call_item; 14 | pub mod output_message_item; 15 | pub mod reasoning; 16 | pub mod reasoning_item; 17 | pub mod service_tier; 18 | pub mod status; 19 | pub mod text; 20 | pub mod tool; 21 | pub mod tool_choice; 22 | pub mod truncation; 23 | pub mod web_search_tool_call_item; 24 | } 25 | pub mod request { 26 | pub mod include; 27 | pub mod input; 28 | pub mod input_models { 29 | pub mod common; 30 | pub mod input_message; 31 | pub mod input_reference; 32 | pub mod item; 33 | } 34 | } 35 | pub mod response { 36 | pub mod incomplete_details; 37 | pub mod response_error; 38 | pub mod response_output; 39 | pub mod usage; 40 | pub mod events { 41 | pub mod streaming; 42 | } 43 | } 44 | } 45 | 46 | pub mod utils { 47 | pub mod errors; 48 | pub mod provider_strategy; 49 | } 50 | -------------------------------------------------------------------------------- /crates/providers/src/utils/provider_strategy.rs: -------------------------------------------------------------------------------- 1 | use crate::utils::errors::ProviderError; 2 | use async_trait::async_trait; 3 | use serde::{de::Deserialize, Serialize}; 4 | use std::pin::Pin; 5 | use tokio_stream::Stream; 6 | 7 | /// Provider trait defines a strategy pattern interface for API providers. 
8 | /// 9 | /// This trait acts as the Strategy interface in the Strategy pattern: 10 | /// - Concrete providers implement this trait to provide specific API endpoints and authentication 11 | /// - Client code can use any provider that implements this trait interchangeably 12 | /// - Allows for runtime switching between different API providers without changing client code 13 | /// 14 | /// Each provider must implement methods to get the base URL and API key for their specific service. 15 | #[async_trait] 16 | pub trait ProviderStrategy { 17 | type GenerationRequest: Serialize + Send + Sync; 18 | type StreamingRequest: Serialize + Send + Sync; 19 | type GenerationResponse: for<'de> Deserialize<'de> + Send; 20 | type StreamingResponse: for<'de> Deserialize<'de> + Send; 21 | 22 | fn get_base_url(&self) -> String; 23 | fn get_api_key(&self) -> String; 24 | async fn generate( 25 | &self, 26 | request: &Self::GenerationRequest, 27 | ) -> Result; 28 | async fn stream( 29 | &self, 30 | request: &Self::StreamingRequest, 31 | ) -> Result< 32 | Pin> + Send>>, 33 | ProviderError, 34 | >; 35 | } 36 | -------------------------------------------------------------------------------- /crates/providers/src/openai/common/truncation.rs: -------------------------------------------------------------------------------- 1 | use std::str::FromStr; 2 | 3 | use serde::{Deserialize, Serialize}; 4 | 5 | use crate::openai::errors::ConversionError; 6 | 7 | #[derive(Debug, PartialEq, Serialize, Deserialize)] 8 | #[serde(rename_all = "lowercase")] 9 | pub enum Truncation { 10 | Auto, 11 | Disabled, 12 | } 13 | 14 | impl FromStr for Truncation { 15 | type Err = ConversionError; 16 | 17 | fn from_str(s: &str) -> Result { 18 | match s { 19 | "auto" => Ok(Truncation::Auto), 20 | "disabled" => Ok(Truncation::Disabled), 21 | _ => Err(ConversionError::FromStr(s.to_string())), 22 | } 23 | } 24 | } 25 | 26 | #[cfg(test)] 27 | mod tests { 28 | use super::*; 29 | use serde_json::json; 30 | 31 | #[test] 32 | 
fn test_from_str() { 33 | assert_eq!(Truncation::from_str("auto").unwrap(), Truncation::Auto); 34 | assert_eq!( 35 | Truncation::from_str("disabled").unwrap(), 36 | Truncation::Disabled 37 | ); 38 | } 39 | 40 | #[test] 41 | fn test_from_str_error() { 42 | assert!(Truncation::from_str("invalid").is_err()); 43 | } 44 | 45 | // test json representation 46 | #[test] 47 | fn test_json_representation() { 48 | let truncation = Truncation::Auto; 49 | let json = serde_json::to_value(&truncation).unwrap(); 50 | assert_eq!(json, json!("auto")); 51 | } 52 | 53 | #[test] 54 | fn test_json_representation_disabled() { 55 | let truncation = Truncation::Disabled; 56 | let json = serde_json::to_value(&truncation).unwrap(); 57 | assert_eq!(json, json!("disabled")); 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /crates/providers/src/openai/common/service_tier.rs: -------------------------------------------------------------------------------- 1 | use std::str::FromStr; 2 | 3 | use serde::{Deserialize, Serialize}; 4 | 5 | use crate::openai::errors::ConversionError; 6 | 7 | #[derive(Debug, PartialEq, Serialize, Deserialize)] 8 | #[serde(rename_all = "lowercase")] 9 | pub enum ServiceTier { 10 | Auto, 11 | Default, 12 | Flex, 13 | } 14 | 15 | impl FromStr for ServiceTier { 16 | type Err = ConversionError; 17 | 18 | fn from_str(s: &str) -> Result { 19 | match s { 20 | "auto" => Ok(ServiceTier::Auto), 21 | "default" => Ok(ServiceTier::Default), 22 | "flex" => Ok(ServiceTier::Flex), 23 | _ => Err(ConversionError::FromStr(s.to_string())), 24 | } 25 | } 26 | } 27 | 28 | #[cfg(test)] 29 | mod tests { 30 | use super::*; 31 | 32 | #[test] 33 | fn it_converts_str_to_service_tier() { 34 | let values = ["auto", "default", "flex"]; 35 | let expected_values = [ServiceTier::Auto, ServiceTier::Default, ServiceTier::Flex]; 36 | 37 | for (index, value) in values.iter().enumerate() { 38 | assert_eq!( 39 | ServiceTier::from_str(value).unwrap(), 40 | 
expected_values[index] 41 | ); 42 | } 43 | } 44 | 45 | #[test] 46 | fn it_returns_error_when_wrong_service_tier_is_given() { 47 | let value = "wrong"; 48 | 49 | assert_eq!( 50 | ServiceTier::from_str(value), 51 | Err(ConversionError::FromStr(value.to_string())) 52 | ); 53 | } 54 | 55 | #[test] 56 | fn test_json_values() { 57 | let service_tier = ServiceTier::Auto; 58 | let json_value = serde_json::to_value(&service_tier).unwrap(); 59 | assert_eq!(json_value, serde_json::json!("auto")); 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /crates/providers/src/openai/errors.rs: -------------------------------------------------------------------------------- 1 | use std::{error::Error, fmt::Display}; 2 | 3 | #[derive(Debug, PartialEq)] 4 | pub enum ConversionError { 5 | FromStr(String), 6 | TryFrom(String), 7 | } 8 | 9 | impl Display for ConversionError { 10 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 11 | match self { 12 | ConversionError::FromStr(msg) => write!(f, "Failed to convert from string: {}", msg), 13 | ConversionError::TryFrom(msg) => { 14 | write!(f, "Failed to convert from {} to target type", msg) 15 | } 16 | } 17 | } 18 | } 19 | 20 | impl Error for ConversionError {} 21 | 22 | #[derive(Debug, PartialEq)] 23 | pub enum InputError { 24 | InvalidToolType(String), 25 | InvalidRole(String), 26 | InvalidButton(String), 27 | ConversionError(ConversionError), 28 | InvalidModelId(String), 29 | } 30 | 31 | impl Display for InputError { 32 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 33 | match self { 34 | InputError::InvalidToolType(msg) => write!(f, "Invalid tool type: {}", msg), 35 | InputError::InvalidRole(msg) => { 36 | write!( 37 | f, 38 | "The role {} is not compatible with InputMessageItem", 39 | msg 40 | ) 41 | } 42 | InputError::InvalidButton(msg) => { 43 | write!(f, "Invalid button value: {}", msg) 44 | } 45 | InputError::ConversionError(err) => write!(f, 
"Conversion error: {}", err), 46 | InputError::InvalidModelId(err) => write!(f, "Invalid model id: {}", err), 47 | } 48 | } 49 | } 50 | 51 | impl Error for InputError {} 52 | 53 | impl From for InputError { 54 | fn from(error: ConversionError) -> Self { 55 | InputError::ConversionError(error) 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /crates/providers/src/utils/errors.rs: -------------------------------------------------------------------------------- 1 | use std::fmt; 2 | 3 | #[derive(Debug)] 4 | pub enum ProviderError { 5 | NetworkError(String), 6 | ApiError { status: u16, message: String }, 7 | DeserializationError(String), 8 | ValidationError(String), 9 | CapabilityError(String), 10 | NotSupported(String), 11 | InternalError(String), 12 | Other(String), 13 | } 14 | 15 | impl fmt::Display for ProviderError { 16 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 17 | match self { 18 | ProviderError::NetworkError(msg) => write!(f, "Network error: {}", msg), 19 | ProviderError::ApiError { status, message } => { 20 | write!(f, "API error (status {}): {}", status, message) 21 | } 22 | ProviderError::DeserializationError(msg) => { 23 | write!(f, "Deserialization error: {}", msg) 24 | } 25 | ProviderError::ValidationError(msg) => write!(f, "Validation error: {}", msg), 26 | ProviderError::CapabilityError(msg) => write!(f, "Capability error: {}", msg), 27 | ProviderError::NotSupported(msg) => write!(f, "Operation not supported: {}", msg), 28 | ProviderError::InternalError(msg) => write!(f, "Internal provider error: {}", msg), 29 | ProviderError::Other(msg) => write!(f, "An unexpected error occurred: {}", msg), 30 | } 31 | } 32 | } 33 | 34 | impl std::error::Error for ProviderError { 35 | fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { 36 | match self { 37 | ProviderError::NetworkError(_) => None, 38 | ProviderError::ApiError { .. 
} => None, 39 | ProviderError::DeserializationError(_) => None, 40 | ProviderError::ValidationError(_) => None, 41 | ProviderError::CapabilityError(_) => None, 42 | ProviderError::NotSupported(_) => None, 43 | ProviderError::InternalError(_) => None, 44 | ProviderError::Other(_) => None, 45 | } 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /crates/providers/src/openai/request/include.rs: -------------------------------------------------------------------------------- 1 | use std::str::FromStr; 2 | 3 | use serde::{Deserialize, Serialize}; 4 | 5 | use crate::openai::errors::ConversionError; 6 | 7 | #[derive(Debug, PartialEq, Serialize, Deserialize)] 8 | pub enum Include { 9 | #[serde(rename = "file_search_call.results")] 10 | FileSearchCallResults, 11 | #[serde(rename = "message.input_image.image_url")] 12 | MessageInputImageUrl, 13 | #[serde(rename = "computer_call_output.output.image_url")] 14 | ComputerCallOutputImageUrl, 15 | } 16 | 17 | impl FromStr for Include { 18 | type Err = ConversionError; 19 | 20 | fn from_str(s: &str) -> Result { 21 | match s { 22 | "file_search_call.results" => Ok(Include::FileSearchCallResults), 23 | "message.input_image.image_url" => Ok(Include::MessageInputImageUrl), 24 | "computer_call_output.output.image_url" => Ok(Include::ComputerCallOutputImageUrl), 25 | _ => Err(ConversionError::FromStr(s.to_string())), 26 | } 27 | } 28 | } 29 | 30 | #[cfg(test)] 31 | mod tests { 32 | use super::*; 33 | 34 | #[test] 35 | fn it_converts_str_to_include_value() { 36 | let values = [ 37 | "file_search_call.results", 38 | "message.input_image.image_url", 39 | "computer_call_output.output.image_url", 40 | ]; 41 | 42 | let expected = [ 43 | Include::FileSearchCallResults, 44 | Include::MessageInputImageUrl, 45 | Include::ComputerCallOutputImageUrl, 46 | ]; 47 | 48 | for (index, value) in values.iter().enumerate() { 49 | assert_eq!(Include::from_str(value).unwrap(), expected[index]); 50 | } 51 | } 
52 | 53 | #[test] 54 | fn it_converts_include_into_json() { 55 | for value in [ 56 | "file_search_call.results", 57 | "message.input_image.image_url", 58 | "computer_call_output.output.image_url", 59 | ] 60 | .iter() 61 | { 62 | let result = serde_json::to_value(Include::from_str(value).unwrap()).unwrap(); 63 | 64 | let expected = value.to_string(); 65 | 66 | assert_eq!(result, expected); 67 | } 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /crates/providers/src/openai/common/file_search_tool_item.rs: -------------------------------------------------------------------------------- 1 | use serde::{Deserialize, Serialize}; 2 | use std::collections::HashMap; 3 | use std::str::FromStr; 4 | 5 | use crate::openai::common::status::Status; 6 | use crate::openai::errors::ConversionError; 7 | 8 | #[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)] 9 | pub struct FileSearchToolCallResult { 10 | #[serde(skip_serializing_if = "Option::is_none")] 11 | pub attributes: Option>, 12 | #[serde(skip_serializing_if = "Option::is_none")] 13 | pub file_id: Option, 14 | #[serde(skip_serializing_if = "Option::is_none")] 15 | pub filename: Option, 16 | #[serde(skip_serializing_if = "Option::is_none")] 17 | pub score: Option, 18 | #[serde(skip_serializing_if = "Option::is_none")] 19 | pub text: Option, 20 | } 21 | 22 | impl FileSearchToolCallResult { 23 | pub fn new() -> Self { 24 | Self::default() 25 | } 26 | 27 | pub fn insert_attribute(mut self, key: String, value: String) -> Self { 28 | if self.attributes.is_none() { 29 | self.attributes = Some(HashMap::new()); 30 | } 31 | 32 | if let Some(attrs) = &mut self.attributes { 33 | attrs.insert(key, value); 34 | } 35 | 36 | self 37 | } 38 | 39 | pub fn file_id(mut self, value: impl Into) -> Self { 40 | self.file_id = Some(value.into()); 41 | self 42 | } 43 | 44 | pub fn filename(mut self, value: impl Into) -> Self { 45 | self.filename = Some(value.into()); 46 | self 47 | } 
48 | 49 | pub fn score(mut self, value: usize) -> Self { 50 | self.score = Some(value); 51 | self 52 | } 53 | 54 | pub fn text(mut self, value: impl Into) -> Self { 55 | self.text = Some(value.into()); 56 | self 57 | } 58 | } 59 | 60 | #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] 61 | pub struct FileSearchToolCallItem { 62 | pub id: String, 63 | pub queries: Vec, 64 | pub status: Status, 65 | #[serde(rename = "type")] 66 | pub results: Vec, 67 | } 68 | 69 | impl FileSearchToolCallItem { 70 | pub fn new(id: impl Into, status: impl AsRef) -> Result { 71 | Ok(Self { 72 | id: id.into(), 73 | queries: vec![], 74 | status: Status::from_str(status.as_ref())?, 75 | results: vec![], 76 | }) 77 | } 78 | 79 | pub fn extend_queries(mut self, queries: Vec>) -> Self { 80 | self.queries.extend(queries.into_iter().map(|q| q.into())); 81 | self 82 | } 83 | 84 | pub fn extend_results(mut self, results: Vec) -> Self { 85 | self.results.extend(results); 86 | self 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /crates/providers/src/openai/common/output_message_item.rs: -------------------------------------------------------------------------------- 1 | use crate::openai::common::status::Status; 2 | use crate::openai::errors::ConversionError; 3 | use crate::openai::request::input_models::common::Role; 4 | use serde::{Deserialize, Serialize}; 5 | use std::str::FromStr; 6 | 7 | #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] 8 | pub struct FileCitation { 9 | pub file_id: String, 10 | pub index: usize, 11 | #[serde(rename = "type")] 12 | pub type_field: String, 13 | } 14 | 15 | impl FileCitation { 16 | pub fn new(file_id: impl Into, index: usize) -> Self { 17 | Self { 18 | file_id: file_id.into(), 19 | index, 20 | type_field: "file_citation".to_string(), 21 | } 22 | } 23 | } 24 | 25 | #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] 26 | pub struct UrlCitation { 27 | pub end_index: String, 28 | pub 
start_index: String, 29 | pub title: String, 30 | #[serde(rename = "type")] 31 | pub type_field: String, 32 | pub url: String, 33 | } 34 | 35 | impl UrlCitation { 36 | pub fn new( 37 | end_index: impl Into, 38 | start_index: impl Into, 39 | title: impl Into, 40 | url: impl Into, 41 | ) -> Self { 42 | Self { 43 | end_index: end_index.into(), 44 | start_index: start_index.into(), 45 | title: title.into(), 46 | url: url.into(), 47 | type_field: "url_citation".to_string(), 48 | } 49 | } 50 | } 51 | 52 | #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] 53 | pub struct FilePath { 54 | pub file_id: String, 55 | pub index: usize, 56 | #[serde(rename = "type")] 57 | pub type_field: String, 58 | } 59 | 60 | impl FilePath { 61 | pub fn new(file_id: impl Into, index: usize) -> Self { 62 | Self { 63 | file_id: file_id.into(), 64 | index, 65 | type_field: "file_path".to_string(), 66 | } 67 | } 68 | } 69 | 70 | #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] 71 | #[serde(untagged)] 72 | pub enum Annotation { 73 | FileCitation(FileCitation), 74 | UrlCitation(UrlCitation), 75 | FilePath(FilePath), 76 | } 77 | 78 | #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] 79 | pub struct OutputText { 80 | pub annotations: Vec, 81 | pub text: String, 82 | #[serde(rename = "type")] 83 | pub type_field: String, 84 | } 85 | 86 | impl OutputText { 87 | pub fn new(text: impl Into) -> Self { 88 | Self { 89 | annotations: vec![], 90 | text: text.into(), 91 | type_field: "output_text".to_string(), 92 | } 93 | } 94 | 95 | pub fn extend_annotations(mut self, annotation: Vec) -> Self { 96 | self.annotations.extend(annotation); 97 | self 98 | } 99 | } 100 | 101 | #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] 102 | pub struct Refusal { 103 | pub refusal: String, 104 | #[serde(rename = "type")] 105 | pub type_field: String, 106 | } 107 | 108 | impl Refusal { 109 | pub fn new(refusal: impl Into) -> Self { 110 | Self { 111 | refusal: refusal.into(), 112 | 
type_field: "refusal".to_string(), 113 | } 114 | } 115 | } 116 | 117 | #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] 118 | #[serde(untagged)] 119 | pub enum OutputContent { 120 | OutputText(OutputText), 121 | Refusal(Refusal), 122 | } 123 | 124 | #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] 125 | pub struct OutputMessageItem { 126 | pub content: Vec, 127 | pub id: String, 128 | pub role: Role, 129 | pub status: Status, 130 | } 131 | 132 | impl OutputMessageItem { 133 | pub fn new(id: impl Into, status: impl AsRef) -> Result { 134 | Ok(Self { 135 | content: vec![], 136 | id: id.into(), 137 | role: Role::Assistant, 138 | status: Status::from_str(status.as_ref())?, 139 | }) 140 | } 141 | } 142 | -------------------------------------------------------------------------------- /crates/providers/src/openai/common/reasoning.rs: -------------------------------------------------------------------------------- 1 | use std::str::FromStr; 2 | 3 | use serde::{Deserialize, Serialize}; 4 | 5 | use crate::openai::errors::ConversionError; 6 | 7 | #[derive(Debug, PartialEq, Serialize, Deserialize)] 8 | #[serde(rename_all = "lowercase")] 9 | pub enum Effort { 10 | Low, 11 | Medium, 12 | High, 13 | } 14 | 15 | impl FromStr for Effort { 16 | type Err = ConversionError; 17 | 18 | fn from_str(s: &str) -> Result { 19 | match s { 20 | "low" => Ok(Effort::Low), 21 | "medium" => Ok(Effort::Medium), 22 | "high" => Ok(Effort::High), 23 | _ => Err(ConversionError::FromStr(s.to_string())), 24 | } 25 | } 26 | } 27 | 28 | #[derive(Debug, PartialEq, Serialize, Deserialize)] 29 | #[serde(rename_all = "lowercase")] 30 | pub enum Summary { 31 | Auto, 32 | Concise, 33 | Detailed, 34 | } 35 | 36 | impl FromStr for Summary { 37 | type Err = ConversionError; 38 | 39 | fn from_str(s: &str) -> Result { 40 | match s { 41 | "auto" => Ok(Summary::Auto), 42 | "concise" => Ok(Summary::Concise), 43 | "detailed" => Ok(Summary::Detailed), 44 | _ => 
Err(ConversionError::FromStr(s.to_string())), 45 | } 46 | } 47 | } 48 | 49 | #[derive(Debug, Default, PartialEq, Serialize, Deserialize)] 50 | pub struct Reasoning { 51 | #[serde(skip_serializing_if = "Option::is_none")] 52 | pub effort: Option, 53 | #[serde(skip_serializing_if = "Option::is_none")] 54 | pub summary: Option, 55 | } 56 | 57 | impl Reasoning { 58 | pub fn new() -> Self { 59 | Self::default() 60 | } 61 | 62 | pub fn effort(mut self, effort: &str) -> Self { 63 | self.effort = Some(Effort::from_str(effort).unwrap()); 64 | self 65 | } 66 | 67 | pub fn summary(mut self, summary: &str) -> Self { 68 | self.summary = Some(Summary::from_str(summary).unwrap()); 69 | self 70 | } 71 | } 72 | 73 | #[cfg(test)] 74 | mod tests { 75 | use super::*; 76 | 77 | #[test] 78 | fn it_creates_reasoning_instance() { 79 | let result = Reasoning::new().effort("low").summary("detailed"); 80 | 81 | assert_eq!(result.effort, Some(Effort::Low)); 82 | assert_eq!(result.summary, Some(Summary::Detailed)); 83 | } 84 | 85 | #[test] 86 | fn it_converts_string_into_effort() { 87 | let values = ["low", "medium", "high"]; 88 | let expected = [Effort::Low, Effort::Medium, Effort::High]; 89 | 90 | for (index, value) in values.iter().enumerate() { 91 | assert_eq!(Effort::from_str(value).unwrap(), expected[index]); 92 | } 93 | } 94 | 95 | #[test] 96 | fn it_reverts_when_invalid_effort_value_is_given() { 97 | let invalid_value = "invalid_value"; 98 | 99 | assert_eq!( 100 | Effort::from_str(invalid_value), 101 | Err(ConversionError::FromStr(invalid_value.to_string())) 102 | ); 103 | } 104 | 105 | #[test] 106 | fn it_converts_string_into_summary() { 107 | let values = ["auto", "concise", "detailed"]; 108 | let expected = [Summary::Auto, Summary::Concise, Summary::Detailed]; 109 | 110 | for (index, value) in values.iter().enumerate() { 111 | assert_eq!(Summary::from_str(value).unwrap(), expected[index]); 112 | } 113 | } 114 | 115 | #[test] 116 | fn it_reverts_when_invalid_summary_value_is_given() 
{ 117 | let invalid_value = "invalid_value"; 118 | 119 | assert_eq!( 120 | Summary::from_str(invalid_value), 121 | Err(ConversionError::FromStr(invalid_value.to_string())) 122 | ); 123 | } 124 | 125 | #[test] 126 | fn test_json_values() { 127 | let reasoning = Reasoning::new().effort("low").summary("detailed"); 128 | let json_value = serde_json::to_value(&reasoning).unwrap(); 129 | assert_eq!( 130 | json_value, 131 | serde_json::json!({ 132 | "effort": "low", 133 | "summary": "detailed" 134 | }) 135 | ); 136 | } 137 | } 138 | -------------------------------------------------------------------------------- /crates/providers/src/openai/client.rs: -------------------------------------------------------------------------------- 1 | use crate::openai::constants::OPENAI_API_URL; 2 | use crate::openai::response::events::streaming::OpenAIStreamingEvent; 3 | use crate::utils::{errors::ProviderError, provider_strategy::ProviderStrategy}; 4 | use async_trait::async_trait; 5 | use futures::stream::StreamExt; 6 | use std::pin::Pin; 7 | use tokio_stream::Stream; 8 | 9 | use super::types::{OpenAIRequest, OpenAIResponse}; 10 | 11 | pub struct OpenAIProvider { 12 | api_key: String, 13 | } 14 | 15 | impl OpenAIProvider { 16 | pub fn new(api_key: String) -> Self { 17 | OpenAIProvider { api_key } 18 | } 19 | } 20 | 21 | #[async_trait] 22 | impl ProviderStrategy for OpenAIProvider { 23 | type GenerationRequest = OpenAIRequest; 24 | type StreamingRequest = OpenAIRequest; 25 | type GenerationResponse = OpenAIResponse; 26 | type StreamingResponse = OpenAIStreamingEvent; 27 | 28 | fn get_base_url(&self) -> String { 29 | OPENAI_API_URL.to_string() 30 | } 31 | 32 | fn get_api_key(&self) -> String { 33 | self.api_key.clone() 34 | } 35 | 36 | async fn generate( 37 | &self, 38 | request: &Self::GenerationRequest, 39 | ) -> Result { 40 | let client = reqwest::Client::new(); 41 | let url = format!("{}/responses", self.get_base_url()); 42 | 43 | let response = client 44 | .post(&url) 45 | 
.header("Authorization", format!("Bearer {}", self.get_api_key())) 46 | .header("Content-Type", "application/json") 47 | .json(request) 48 | .send() 49 | .await 50 | .map_err(|e| ProviderError::NetworkError(e.to_string()))?; 51 | 52 | if !response.status().is_success() { 53 | let status = response.status().as_u16(); 54 | let error_message = response 55 | .text() 56 | .await 57 | .unwrap_or_else(|_| "Failed to read error response".to_string()); 58 | 59 | return Err(ProviderError::ApiError { 60 | status, 61 | message: error_message, 62 | }); 63 | } 64 | 65 | let response_bytes = response 66 | .bytes() 67 | .await 68 | .map_err(|e| ProviderError::NetworkError(e.to_string()))?; 69 | 70 | serde_json::from_slice(&response_bytes) 71 | .map_err(|e| ProviderError::DeserializationError(e.to_string())) 72 | } 73 | 74 | async fn stream( 75 | &self, 76 | request: &Self::StreamingRequest, 77 | ) -> Result< 78 | Pin> + Send>>, 79 | ProviderError, 80 | > { 81 | let client = reqwest::Client::new(); 82 | let url = format!("{}/responses", self.get_base_url()); 83 | 84 | let response = client 85 | .post(&url) 86 | .header("Authorization", format!("Bearer {}", self.get_api_key())) 87 | .header("Content-Type", "application/json") 88 | .header("Accept", "text/event-stream") 89 | .json(&request.wrap_for_streaming()) 90 | .send() 91 | .await 92 | .map_err(|e| ProviderError::NetworkError(e.to_string()))?; 93 | 94 | if !response.status().is_success() { 95 | let status = response.status().as_u16(); 96 | let error_message = response 97 | .text() 98 | .await 99 | .unwrap_or_else(|_| "Failed to read error response".to_string()); 100 | 101 | return Err(ProviderError::ApiError { 102 | status, 103 | message: error_message, 104 | }); 105 | } 106 | 107 | let stream = response.bytes_stream(); 108 | let parsed_stream = stream.map(|chunk_result| { 109 | chunk_result 110 | .map_err(|e| ProviderError::NetworkError(e.to_string())) 111 | .and_then(|chunk| { 112 | serde_json::from_slice(&chunk) 113 | 
.map_err(|e| ProviderError::DeserializationError(e.to_string())) 114 | }) 115 | }); 116 | 117 | Ok(Box::pin(parsed_stream)) 118 | } 119 | } 120 | -------------------------------------------------------------------------------- /crates/providers/src/openai/request/input_models/input_message.rs: -------------------------------------------------------------------------------- 1 | use crate::openai::errors::ConversionError; 2 | use crate::openai::request::input_models::common::{Content, Role}; 3 | use crate::openai::request::input_models::input_reference::InputReference; 4 | use crate::openai::request::input_models::item::Item; 5 | use serde::{Deserialize, Serialize}; 6 | use std::str::FromStr; 7 | 8 | #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] 9 | pub struct TextInput { 10 | pub role: Role, 11 | pub content: String, 12 | #[serde(rename = "type")] 13 | #[serde(skip_serializing_if = "Option::is_none")] 14 | pub type_field: Option, 15 | } 16 | 17 | impl TextInput { 18 | pub fn new(content: impl Into) -> Self { 19 | Self { 20 | role: Role::default(), 21 | content: content.into(), 22 | type_field: None, 23 | } 24 | } 25 | 26 | pub fn role(mut self, role: impl AsRef) -> Result { 27 | self.role = Role::from_str(role.as_ref())?; 28 | Ok(self) 29 | } 30 | 31 | pub fn insert_type(mut self) -> Self { 32 | self.type_field = Some("message".to_string()); 33 | self 34 | } 35 | } 36 | 37 | #[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Default)] 38 | pub struct InputItemContentList { 39 | pub role: Role, 40 | pub content: Vec, 41 | #[serde(rename = "type")] 42 | #[serde(skip_serializing_if = "Option::is_none")] 43 | pub type_field: Option, 44 | } 45 | 46 | impl InputItemContentList { 47 | pub fn new() -> Self { 48 | Self::default() 49 | } 50 | 51 | pub fn role(mut self, role: impl AsRef) -> Result { 52 | self.role = Role::from_str(role.as_ref())?; 53 | Ok(self) 54 | } 55 | 56 | pub fn insert_type(mut self) -> Self { 57 | self.type_field = 
Some("message".to_string()); 58 | self 59 | } 60 | } 61 | 62 | impl From for InputItemContentList { 63 | fn from(_item: Item) -> Self { 64 | Self { 65 | role: Role::default(), 66 | content: Vec::new(), 67 | type_field: Some("message".to_string()), 68 | } 69 | } 70 | } 71 | 72 | impl From for InputItemContentList { 73 | fn from(_reference: InputReference) -> Self { 74 | Self { 75 | role: Role::default(), 76 | content: Vec::new(), 77 | type_field: Some("message".to_string()), 78 | } 79 | } 80 | } 81 | 82 | #[derive(Debug, PartialEq, Serialize, Deserialize)] 83 | #[serde(untagged)] 84 | pub enum InputMessage { 85 | TextInput(TextInput), 86 | InputItemContentList(InputItemContentList), 87 | } 88 | 89 | impl From for InputMessage { 90 | fn from(text_input: TextInput) -> Self { 91 | InputMessage::TextInput(text_input) 92 | } 93 | } 94 | 95 | impl From for InputMessage { 96 | fn from(content_list: InputItemContentList) -> Self { 97 | InputMessage::InputItemContentList(content_list) 98 | } 99 | } 100 | 101 | #[cfg(test)] 102 | mod tests { 103 | use crate::openai::request::input_models::common::TextContent; 104 | 105 | use super::*; 106 | 107 | #[test] 108 | fn test_json_values() { 109 | let text_input = TextInput::new("Hello, world!"); 110 | let input_message: InputMessage = text_input.clone().into(); 111 | assert_eq!(input_message, InputMessage::TextInput(text_input)); 112 | 113 | let json_value = serde_json::to_value(&input_message).unwrap(); 114 | assert_eq!( 115 | json_value, 116 | serde_json::json!({ 117 | "role": "user", 118 | "content": "Hello, world!" 
119 | }) 120 | ); 121 | } 122 | 123 | #[test] 124 | fn test_json_values_input_item_content_list() { 125 | let mut input_item_content_list = InputItemContentList::new() 126 | .insert_type() 127 | .role("developer") 128 | .unwrap(); 129 | 130 | input_item_content_list 131 | .content 132 | .push(Content::Text(TextContent::new().text("Hello, world!"))); 133 | 134 | let input_message: InputMessage = input_item_content_list.clone().into(); 135 | assert_eq!( 136 | input_message, 137 | InputMessage::InputItemContentList(input_item_content_list) 138 | ); 139 | 140 | let json_value = serde_json::to_value(&input_message).unwrap(); 141 | assert_eq!( 142 | json_value, 143 | serde_json::json!({ 144 | "role": "developer", 145 | "content": [ 146 | { 147 | "type": "input_text", 148 | "text": "Hello, world!" 149 | } 150 | ], 151 | "type": "message" 152 | }) 153 | ); 154 | } 155 | } 156 | -------------------------------------------------------------------------------- /crates/providers/src/openai/common/tool_choice.rs: -------------------------------------------------------------------------------- 1 | use std::str::FromStr; 2 | 3 | use serde::{Deserialize, Serialize}; 4 | 5 | use crate::openai::errors::ConversionError; 6 | 7 | #[derive(Debug, PartialEq, Serialize, Deserialize)] 8 | #[serde(rename_all = "lowercase")] 9 | pub enum ToolChoiceMode { 10 | None, 11 | Auto, 12 | Required, 13 | } 14 | 15 | impl FromStr for ToolChoiceMode { 16 | type Err = ConversionError; 17 | 18 | fn from_str(s: &str) -> Result { 19 | match s { 20 | "none" => Ok(ToolChoiceMode::None), 21 | "auto" => Ok(ToolChoiceMode::Auto), 22 | "required" => Ok(ToolChoiceMode::Required), 23 | _ => Err(ConversionError::FromStr(s.to_string())), 24 | } 25 | } 26 | } 27 | 28 | #[derive(Debug, PartialEq, Serialize, Deserialize)] 29 | #[serde(rename = "snake_case")] 30 | enum HostedToolType { 31 | FileSearch, 32 | WebSearchPreview, 33 | ComputerUsePreview, 34 | } 35 | 36 | impl FromStr for HostedToolType { 37 | type Err = 
ConversionError; 38 | 39 | fn from_str(s: &str) -> Result { 40 | match s { 41 | "file_search" => Ok(HostedToolType::FileSearch), 42 | "web_search_preview" => Ok(HostedToolType::WebSearchPreview), 43 | "computer_use_preview" => Ok(HostedToolType::ComputerUsePreview), 44 | _ => Err(ConversionError::FromStr(s.to_string())), 45 | } 46 | } 47 | } 48 | 49 | #[derive(Debug, PartialEq, Serialize, Deserialize)] 50 | pub struct HostedToolChoice { 51 | #[serde(rename = "type")] 52 | type_field: HostedToolType, 53 | } 54 | 55 | impl HostedToolChoice { 56 | pub fn new(hosted_tool_type: &str) -> Self { 57 | Self { 58 | type_field: HostedToolType::from_str(hosted_tool_type).unwrap(), 59 | } 60 | } 61 | } 62 | 63 | #[derive(Debug, PartialEq, Serialize, Deserialize)] 64 | pub struct FunctionToolChoice { 65 | name: String, 66 | #[serde(rename = "type")] 67 | type_field: String, 68 | } 69 | 70 | impl FunctionToolChoice { 71 | pub fn new(name: impl Into) -> Self { 72 | Self { 73 | name: name.into(), 74 | type_field: "function".to_string(), 75 | } 76 | } 77 | } 78 | 79 | #[derive(Debug, PartialEq, Serialize, Deserialize)] 80 | #[serde(untagged)] 81 | pub enum ToolChoice { 82 | Mode(ToolChoiceMode), 83 | HostedTool(HostedToolChoice), 84 | FunctionTool(FunctionToolChoice), 85 | } 86 | 87 | impl From for ToolChoice { 88 | fn from(tool: ToolChoiceMode) -> Self { 89 | ToolChoice::Mode(tool) 90 | } 91 | } 92 | 93 | impl From for ToolChoice { 94 | fn from(tool: HostedToolChoice) -> Self { 95 | ToolChoice::HostedTool(tool) 96 | } 97 | } 98 | 99 | impl From for ToolChoice { 100 | fn from(tool: FunctionToolChoice) -> Self { 101 | ToolChoice::FunctionTool(tool) 102 | } 103 | } 104 | 105 | #[cfg(test)] 106 | mod tests { 107 | use super::*; 108 | 109 | #[test] 110 | fn it_builds_choice_mode() { 111 | let result: ToolChoice = ToolChoiceMode::from_str("auto").unwrap().into(); 112 | let expected = ToolChoice::Mode(ToolChoiceMode::Auto); 113 | 114 | assert_eq!(result, expected); 115 | } 116 | 117 | 
#[test] 118 | fn it_builds_hosted_tool() { 119 | let result: ToolChoice = HostedToolChoice::new("web_search_preview").into(); 120 | let expected = ToolChoice::HostedTool(HostedToolChoice { 121 | type_field: HostedToolType::from_str("web_search_preview").unwrap(), 122 | }); 123 | 124 | assert_eq!(result, expected); 125 | } 126 | 127 | #[test] 128 | fn it_builds_function_tool() { 129 | let result: ToolChoice = FunctionToolChoice::new("test name").into(); 130 | let expected = ToolChoice::FunctionTool(FunctionToolChoice { 131 | name: "test name".to_string(), 132 | type_field: "function".to_string(), 133 | }); 134 | 135 | assert_eq!(result, expected); 136 | } 137 | 138 | #[test] 139 | fn it_builds_tool_choice_from_str() { 140 | let test_cases = [ 141 | ( 142 | "auto", 143 | ToolChoice::Mode(ToolChoiceMode::from_str("auto").unwrap()), 144 | ), 145 | ( 146 | "none", 147 | ToolChoice::Mode(ToolChoiceMode::from_str("none").unwrap()), 148 | ), 149 | ( 150 | "required", 151 | ToolChoice::Mode(ToolChoiceMode::from_str("required").unwrap()), 152 | ), 153 | ]; 154 | 155 | for (input, expected) in test_cases { 156 | let result: ToolChoice = ToolChoiceMode::from_str(input).unwrap().into(); 157 | assert_eq!(result, expected); 158 | } 159 | } 160 | 161 | #[test] 162 | fn it_builds_hosted_tool_from_str() { 163 | let test_cases = [ 164 | ( 165 | "file_search", 166 | ToolChoice::HostedTool(HostedToolChoice { 167 | type_field: HostedToolType::from_str("file_search").unwrap(), 168 | }), 169 | ), 170 | ( 171 | "web_search_preview", 172 | ToolChoice::HostedTool(HostedToolChoice { 173 | type_field: HostedToolType::from_str("web_search_preview").unwrap(), 174 | }), 175 | ), 176 | ( 177 | "computer_use_preview", 178 | ToolChoice::HostedTool(HostedToolChoice { 179 | type_field: HostedToolType::from_str("computer_use_preview").unwrap(), 180 | }), 181 | ), 182 | ]; 183 | 184 | for (input, expected) in test_cases { 185 | let result: ToolChoice = HostedToolChoice::new(input).into(); 186 | 
assert_eq!(result, expected); 187 | } 188 | } 189 | 190 | #[test] 191 | fn it_returns_error_for_invalid_tool_choice_mode() { 192 | let result = ToolChoiceMode::from_str("invalid"); 193 | assert!(result.is_err()); 194 | } 195 | 196 | #[test] 197 | fn it_returns_error_for_invalid_hosted_tool_type() { 198 | let result = HostedToolType::from_str("invalid"); 199 | assert!(result.is_err()); 200 | } 201 | 202 | #[test] 203 | fn test_json_values() { 204 | let tool_choice = ToolChoice::Mode(ToolChoiceMode::Auto); 205 | let json_value = serde_json::to_value(&tool_choice).unwrap(); 206 | assert_eq!(json_value, serde_json::json!("auto")); 207 | } 208 | } 209 | -------------------------------------------------------------------------------- /crates/providers/src/openai/common/computer_tool_call_item.rs: -------------------------------------------------------------------------------- 1 | use serde::{Deserialize, Serialize}; 2 | 3 | use crate::openai::common::status::Status; 4 | use crate::openai::errors::InputError; 5 | 6 | #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] 7 | pub struct ClickAction { 8 | pub button: String, 9 | #[serde(rename = "type")] 10 | pub type_field: String, 11 | pub x: usize, 12 | pub y: usize, 13 | } 14 | 15 | impl ClickAction { 16 | const BUTTON: [&'static str; 5] = ["left", "right", "wheel", "back", "forward"]; 17 | 18 | pub fn new(button: impl Into, x: usize, y: usize) -> Result { 19 | let button_str = button.into(); 20 | if Self::BUTTON.contains(&button_str.as_str()) { 21 | Ok(Self { 22 | button: button_str, 23 | type_field: "click".to_string(), 24 | x, 25 | y, 26 | }) 27 | } else { 28 | Err(InputError::InvalidButton(button_str)) 29 | } 30 | } 31 | } 32 | 33 | #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] 34 | pub struct DoubleClickAction { 35 | #[serde(rename = "type")] 36 | pub type_field: String, 37 | pub x: usize, 38 | pub y: usize, 39 | } 40 | 41 | impl DoubleClickAction { 42 | pub fn new(x: usize, y: usize) -> Self { 43 
| Self { 44 | type_field: "double_click".to_string(), 45 | x, 46 | y, 47 | } 48 | } 49 | } 50 | 51 | #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] 52 | pub struct DragActionPath { 53 | pub x: usize, 54 | pub y: usize, 55 | } 56 | 57 | impl DragActionPath { 58 | pub fn new(x: usize, y: usize) -> Self { 59 | Self { x, y } 60 | } 61 | } 62 | 63 | #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] 64 | pub struct DragAction { 65 | #[serde(rename = "type")] 66 | pub type_field: String, 67 | pub path: Vec, 68 | } 69 | 70 | impl DragAction { 71 | pub fn new(path: Vec) -> Self { 72 | Self { 73 | type_field: "drag".to_string(), 74 | path, 75 | } 76 | } 77 | } 78 | 79 | #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] 80 | pub struct KeyPressAction { 81 | #[serde(rename = "type")] 82 | pub type_field: String, 83 | pub keys: Vec, 84 | } 85 | 86 | impl KeyPressAction { 87 | pub fn new(keys: Vec>) -> Self { 88 | Self { 89 | type_field: "keypress".to_string(), 90 | keys: keys.into_iter().map(|k| k.into()).collect(), 91 | } 92 | } 93 | } 94 | 95 | #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] 96 | pub struct MoveAction { 97 | #[serde(rename = "type")] 98 | pub type_field: String, 99 | pub x: usize, 100 | pub y: usize, 101 | } 102 | 103 | impl MoveAction { 104 | pub fn new(x: usize, y: usize) -> Self { 105 | Self { 106 | type_field: "move".to_string(), 107 | x, 108 | y, 109 | } 110 | } 111 | } 112 | 113 | #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] 114 | pub struct ScreenshotAction { 115 | #[serde(rename = "type")] 116 | pub type_field: String, 117 | } 118 | 119 | impl ScreenshotAction { 120 | pub fn new() -> Self { 121 | Self { 122 | type_field: "screenshot".to_string(), 123 | } 124 | } 125 | } 126 | 127 | impl Default for ScreenshotAction { 128 | fn default() -> Self { 129 | Self::new() 130 | } 131 | } 132 | 133 | #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] 134 | pub struct ScrollAction { 135 | 
#[serde(rename = "type")] 136 | pub type_field: String, 137 | pub scroll_x: usize, 138 | pub scroll_y: usize, 139 | pub x: usize, 140 | pub y: usize, 141 | } 142 | 143 | impl ScrollAction { 144 | pub fn new(scroll_x: usize, scroll_y: usize, x: usize, y: usize) -> Self { 145 | Self { 146 | type_field: "scroll".to_string(), 147 | scroll_x, 148 | scroll_y, 149 | x, 150 | y, 151 | } 152 | } 153 | } 154 | 155 | #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] 156 | pub struct TypeAction { 157 | #[serde(rename = "type")] 158 | pub type_field: String, 159 | pub text: String, 160 | } 161 | 162 | impl TypeAction { 163 | pub fn new(text: impl Into) -> Self { 164 | Self { 165 | type_field: "type".to_string(), 166 | text: text.into(), 167 | } 168 | } 169 | } 170 | 171 | #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] 172 | pub struct WaitAction { 173 | #[serde(rename = "type")] 174 | pub type_field: String, 175 | } 176 | 177 | impl WaitAction { 178 | pub fn new() -> Self { 179 | Self { 180 | type_field: "wait".to_string(), 181 | } 182 | } 183 | } 184 | 185 | impl Default for WaitAction { 186 | fn default() -> Self { 187 | Self::new() 188 | } 189 | } 190 | 191 | #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] 192 | #[serde(untagged)] 193 | pub enum ComputerToolAction { 194 | Click(ClickAction), 195 | DoubleClick(DoubleClickAction), 196 | Drag(DragAction), 197 | KeyPress(KeyPressAction), 198 | Move(MoveAction), 199 | Screenshot(ScreenshotAction), 200 | Scroll(ScrollAction), 201 | Type(TypeAction), 202 | Wait(WaitAction), 203 | } 204 | 205 | #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] 206 | pub struct PendingSafetyChecks { 207 | pub code: String, 208 | pub id: String, 209 | pub message: String, 210 | } 211 | 212 | impl PendingSafetyChecks { 213 | pub fn new(code: impl Into, id: impl Into, message: impl Into) -> Self { 214 | Self { 215 | code: code.into(), 216 | id: id.into(), 217 | message: message.into(), 218 | } 219 | } 220 | } 221 
| 222 | #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] 223 | pub struct ComputerToolCallItem { 224 | pub action: ComputerToolAction, 225 | pub call_id: String, 226 | pub id: String, 227 | pub pending_safety_checks: Vec, 228 | pub status: Status, 229 | } 230 | 231 | impl ComputerToolCallItem { 232 | pub fn new( 233 | action: ComputerToolAction, 234 | call_id: impl Into, 235 | id: impl Into, 236 | pending_safety_checks: Vec, 237 | status: Status, 238 | ) -> Self { 239 | Self { 240 | action, 241 | call_id: call_id.into(), 242 | id: id.into(), 243 | pending_safety_checks, 244 | status, 245 | } 246 | } 247 | } 248 | -------------------------------------------------------------------------------- /crates/providers/src/openai/response/events/streaming.rs: -------------------------------------------------------------------------------- 1 | use crate::openai::common::output_message_item::{Annotation, OutputContent}; 2 | use crate::openai::common::{ 3 | reasoning::Reasoning, service_tier::ServiceTier, status::Status, text::Text, tool::Tool, 4 | tool_choice::ToolChoice, truncation::Truncation, 5 | }; 6 | use crate::openai::response::{ 7 | incomplete_details::IncompleteDetails, response_error::ResponseError, 8 | response_output::ResponseOutput, usage::Usage, 9 | }; 10 | use serde::{Deserialize, Serialize}; 11 | use std::collections::HashMap; 12 | 13 | #[derive(Debug, Serialize)] 14 | #[serde(bound(deserialize = ""))] 15 | #[derive(Deserialize)] 16 | #[serde(tag = "type")] 17 | pub enum OpenAIStreamingEvent { 18 | #[serde(rename = "response.created")] 19 | Created { response: StreamingResponse }, 20 | #[serde(rename = "response.in_progress")] 21 | InProgress { response: StreamingResponse }, 22 | #[serde(rename = "response.completed")] 23 | Completed { response: StreamingResponse }, 24 | #[serde(rename = "response.failed")] 25 | Failed { response: StreamingResponse }, 26 | #[serde(rename = "response.incomplete")] 27 | Incomplete { response: StreamingResponse }, 28 | 
#[serde(rename = "response.output_item.added")] 29 | OutputItemAdded { 30 | output_index: usize, 31 | item: ResponseOutput, 32 | }, 33 | #[serde(rename = "response.output_item.done")] 34 | OutputItemDone { 35 | output_index: usize, 36 | item: ResponseOutput, 37 | }, 38 | #[serde(rename = "response.content_part.added")] 39 | ContentPartAdded { 40 | item_id: String, 41 | output_index: usize, 42 | content_index: usize, 43 | part: OutputContent, 44 | }, 45 | #[serde(rename = "response.content_part.done")] 46 | ContentPartDone { 47 | item_id: String, 48 | output_index: usize, 49 | content_index: usize, 50 | part: OutputContent, 51 | }, 52 | #[serde(rename = "response.output_text.delta")] 53 | OutputTextDelta { 54 | item_id: String, 55 | output_index: usize, 56 | content_index: usize, 57 | delta: String, 58 | }, 59 | #[serde(rename = "response.output_text.annotation.added")] 60 | OutputTextAnnotationAdded { 61 | item_id: String, 62 | output_index: usize, 63 | content_index: usize, 64 | annotation_index: usize, 65 | annotation: Annotation, 66 | }, 67 | #[serde(rename = "response.output_text.done")] 68 | OutputTextDone { 69 | item_id: String, 70 | output_index: usize, 71 | content_index: usize, 72 | text: String, 73 | }, 74 | #[serde(rename = "response.refusal.delta")] 75 | RefusalDelta { 76 | item_id: String, 77 | output_index: usize, 78 | content_index: usize, 79 | delta: String, 80 | }, 81 | #[serde(rename = "response.refusal.done")] 82 | RefusalDone { 83 | item_id: String, 84 | output_index: usize, 85 | content_index: usize, 86 | refusal: String, 87 | }, 88 | #[serde(rename = "response.function_call_arguments.delta")] 89 | FunctionCallArgumentsDelta { 90 | item_id: String, 91 | output_index: usize, 92 | delta: String, 93 | }, 94 | #[serde(rename = "response.function_call_arguments.done")] 95 | FunctionCallArgumentsDone { 96 | item_id: String, 97 | output_index: usize, 98 | arguments: String, 99 | }, 100 | #[serde(rename = "response.file_search_call.in_progress")] 101 | 
FileSearchCallInProgress { 102 | item_id: String, 103 | output_index: usize, 104 | }, 105 | #[serde(rename = "response.file_search_call.searching")] 106 | FileSearchCallSearching { 107 | item_id: String, 108 | output_index: usize, 109 | }, 110 | #[serde(rename = "response.file_search_call.completed")] 111 | FileSearchCallCompleted { 112 | item_id: String, 113 | output_index: usize, 114 | }, 115 | #[serde(rename = "response.web_search_call.in_progress")] 116 | WebSearchCallInProgress { 117 | item_id: String, 118 | output_index: usize, 119 | }, 120 | #[serde(rename = "response.web_search_call.searching")] 121 | WebSearchCallSearching { 122 | item_id: String, 123 | output_index: usize, 124 | }, 125 | #[serde(rename = "response.web_search_call.completed")] 126 | WebSearchCallCompleted { 127 | item_id: String, 128 | output_index: usize, 129 | }, 130 | #[serde(rename = "response.reasoning_summary_part.added")] 131 | ReasoningSummaryPartAdded { 132 | item_id: String, 133 | output_index: usize, 134 | part: ReasoningPart, 135 | summary_index: usize, 136 | }, 137 | #[serde(rename = "response.reasoning_summary_part.done")] 138 | ReasoningSummaryPartDone { 139 | item_id: String, 140 | output_index: usize, 141 | part: ReasoningPart, 142 | summary_index: usize, 143 | }, 144 | #[serde(rename = "response.reasoning_summary_text.delta")] 145 | ReasoningSummaryTextDelta { 146 | delta: String, 147 | item_id: String, 148 | output_index: usize, 149 | summary_index: usize, 150 | }, 151 | #[serde(rename = "response.reasoning_summary_text.done")] 152 | ReasoningSummaryTextDone { 153 | item_id: String, 154 | output_index: usize, 155 | summary_index: usize, 156 | text: String, 157 | }, 158 | #[serde(rename = "error")] 159 | Error { 160 | #[serde(skip_serializing_if = "Option::is_none")] 161 | code: Option, 162 | message: String, 163 | #[serde(skip_serializing_if = "Option::is_none")] 164 | param: Option, 165 | }, 166 | } 167 | 168 | #[derive(Debug, Serialize, Deserialize)] 169 | #[serde(tag 
= "type")] 170 | pub struct ReasoningPart { 171 | #[serde(rename = "type")] 172 | pub type_field: ReasoningPartType, 173 | pub text: String, 174 | } 175 | 176 | #[derive(Debug, Serialize, Deserialize)] 177 | #[serde(rename_all = "snake_case")] 178 | pub enum ReasoningPartType { 179 | #[serde(rename = "summary_text")] 180 | SummaryText, 181 | } 182 | 183 | #[derive(Debug, Serialize)] 184 | #[serde(bound(deserialize = ""))] 185 | #[derive(Deserialize)] 186 | pub struct StreamingResponse { 187 | pub created_at: u64, 188 | #[serde(skip_serializing_if = "Option::is_none")] 189 | pub error: Option, 190 | pub id: String, 191 | #[serde(skip_serializing_if = "Option::is_none")] 192 | pub incomplete_details: Option, 193 | #[serde(skip_serializing_if = "Option::is_none")] 194 | pub instructions: Option, 195 | pub metadata: HashMap, 196 | pub model: String, 197 | pub object: String, 198 | pub output: Vec, 199 | pub parallel_tool_calls: bool, 200 | #[serde(skip_serializing_if = "Option::is_none")] 201 | pub temperature: Option, 202 | pub tool_choice: ToolChoice, 203 | pub tools: Vec, 204 | #[serde(skip_serializing_if = "Option::is_none")] 205 | pub top_p: Option, 206 | #[serde(skip_serializing_if = "Option::is_none")] 207 | pub max_output_tokens: Option, 208 | #[serde(skip_serializing_if = "Option::is_none")] 209 | pub previous_response_id: Option, 210 | #[serde(skip_serializing_if = "Option::is_none")] 211 | pub reasoning: Option, 212 | #[serde(skip_serializing_if = "Option::is_none")] 213 | pub service_tier: Option, 214 | pub status: Status, 215 | pub text: Text, 216 | #[serde(skip_serializing_if = "Option::is_none")] 217 | pub truncation: Option, 218 | pub usage: Usage, 219 | pub user: String, 220 | } 221 | -------------------------------------------------------------------------------- /crates/providers/src/openai/types.rs: -------------------------------------------------------------------------------- 1 | use serde::{Deserialize, Serialize}; 2 | use 
std::collections::HashMap; 3 | 4 | use crate::openai::common::{ 5 | reasoning::Reasoning, service_tier::ServiceTier, status::Status, text::Text, tool::Tool, 6 | tool_choice::ToolChoice, truncation::Truncation, 7 | }; 8 | use crate::openai::constants::OpenAIModelId; 9 | use crate::openai::request::{include::Include, input::Input}; 10 | use crate::openai::response::{ 11 | incomplete_details::IncompleteDetails, response_error::ResponseError, 12 | response_output::ResponseOutput, usage::Usage, 13 | }; 14 | 15 | use serde_json::{json, Value}; 16 | 17 | #[derive(Debug, Default, PartialEq, Serialize, Deserialize)] 18 | pub struct OpenAIRequest { 19 | input: Input, 20 | model: OpenAIModelId, 21 | #[serde(skip_serializing_if = "Option::is_none")] 22 | include: Option>, 23 | #[serde(skip_serializing_if = "Option::is_none")] 24 | instructions: Option, 25 | #[serde(skip_serializing_if = "Option::is_none")] 26 | max_output_tokens: Option, 27 | #[serde(skip_serializing_if = "Option::is_none")] 28 | metadata: Option>, 29 | #[serde(skip_serializing_if = "Option::is_none")] 30 | parallel_tool_calls: Option, 31 | #[serde(skip_serializing_if = "Option::is_none")] 32 | previous_response_id: Option, 33 | #[serde(skip_serializing_if = "Option::is_none")] 34 | reasoning: Option, 35 | #[serde(skip_serializing_if = "Option::is_none")] 36 | service_tier: Option, 37 | #[serde(skip_serializing_if = "Option::is_none")] 38 | store: Option, 39 | #[serde(skip_serializing_if = "Option::is_none")] 40 | temperature: Option, 41 | #[serde(skip_serializing_if = "Option::is_none")] 42 | text: Option, 43 | #[serde(skip_serializing_if = "Option::is_none")] 44 | tool_choice: Option, 45 | #[serde(skip_serializing_if = "Option::is_none")] 46 | tools: Option>, 47 | #[serde(skip_serializing_if = "Option::is_none")] 48 | top_p: Option, 49 | #[serde(skip_serializing_if = "Option::is_none")] 50 | truncation: Option, 51 | #[serde(skip_serializing_if = "Option::is_none")] 52 | user: Option, 53 | } 54 | 55 | impl 
OpenAIRequest { 56 | pub fn new(model: OpenAIModelId, input: Input) -> Self { 57 | Self { 58 | model, 59 | input, 60 | ..Default::default() 61 | } 62 | } 63 | 64 | pub fn include(mut self, value: Include) -> Self { 65 | match self.include { 66 | Some(ref mut include) => include.push(value), 67 | None => self.include = Some(vec![value]), 68 | } 69 | 70 | self 71 | } 72 | 73 | pub fn instructions(mut self, value: impl Into) -> Self { 74 | self.instructions = Some(value.into()); 75 | self 76 | } 77 | 78 | pub fn max_output_tokens(mut self, value: usize) -> Self { 79 | self.max_output_tokens = Some(value); 80 | self 81 | } 82 | 83 | pub fn insert_metadata(mut self, key: String, value: String) -> Self { 84 | match self.metadata { 85 | Some(ref mut metadata) => { 86 | metadata.insert(key, value); 87 | } 88 | None => { 89 | self.metadata = Some({ 90 | let mut new_map: HashMap = HashMap::new(); 91 | new_map.insert(key, value); 92 | new_map 93 | }); 94 | } 95 | } 96 | 97 | self 98 | } 99 | 100 | pub fn parallel_tool_calls(mut self, value: bool) -> Self { 101 | self.parallel_tool_calls = Some(value); 102 | self 103 | } 104 | 105 | pub fn previous_response_id(mut self, value: impl Into) -> Self { 106 | self.previous_response_id = Some(value.into()); 107 | self 108 | } 109 | 110 | pub fn reasoning(mut self, value: Reasoning) -> Self { 111 | self.reasoning = Some(value); 112 | self 113 | } 114 | 115 | pub fn service_tier(mut self, value: ServiceTier) -> Self { 116 | self.service_tier = Some(value); 117 | self 118 | } 119 | 120 | pub fn store(mut self, value: bool) -> Self { 121 | self.store = Some(value); 122 | self 123 | } 124 | 125 | pub fn temperature(mut self, value: f32) -> Self { 126 | self.temperature = Some(value); 127 | self 128 | } 129 | 130 | pub fn text(mut self, value: Text) -> Self { 131 | self.text = Some(value); 132 | self 133 | } 134 | 135 | pub fn tool_choice(mut self, value: ToolChoice) -> Self { 136 | self.tool_choice = Some(value); 137 | self 138 | } 139 | 
140 | pub fn add_tool(mut self, value: Tool) -> Self { 141 | match self.tools { 142 | Some(ref mut tools) => tools.push(value), 143 | None => self.tools = Some(vec![value]), 144 | } 145 | self 146 | } 147 | 148 | pub fn top_p(mut self, value: f32) -> Self { 149 | self.top_p = Some(value); 150 | self 151 | } 152 | 153 | pub fn truncation(mut self, value: Truncation) -> Self { 154 | self.truncation = Some(value); 155 | self 156 | } 157 | 158 | pub fn user(mut self, value: impl Into) -> Self { 159 | self.user = Some(value.into()); 160 | self 161 | } 162 | 163 | pub fn wrap_for_streaming(&self) -> impl Serialize + '_ { 164 | struct Wrapper<'a> { 165 | inner: &'a OpenAIRequest, 166 | } 167 | 168 | impl Serialize for Wrapper<'_> { 169 | fn serialize(&self, serializer: S) -> Result 170 | where 171 | S: serde::Serializer, 172 | { 173 | let mut original = serde_json::to_value(self.inner) 174 | .map_err(|e| serde::ser::Error::custom(e.to_string()))?; 175 | 176 | if let Value::Object(ref mut map) = original { 177 | map.insert("stream".to_string(), json!(true)); 178 | 179 | map.serialize(serializer) 180 | } else { 181 | Err(serde::ser::Error::custom("Expected object")) 182 | } 183 | } 184 | } 185 | 186 | Wrapper { inner: self } 187 | } 188 | } 189 | 190 | #[derive(Debug, PartialEq, Serialize)] 191 | #[serde(bound(deserialize = ""))] 192 | #[derive(Deserialize)] 193 | pub struct OpenAIResponse { 194 | created_at: u64, 195 | #[serde(skip_serializing_if = "Option::is_none")] 196 | error: Option, 197 | id: String, 198 | #[serde(skip_serializing_if = "Option::is_none")] 199 | incomplete_details: Option, 200 | #[serde(skip_serializing_if = "Option::is_none")] 201 | instructions: Option, 202 | #[serde(skip_serializing_if = "Option::is_none")] 203 | max_output_tokens: Option, 204 | #[serde(skip_serializing_if = "Option::is_none")] 205 | metadata: Option>, 206 | model: String, 207 | object: String, 208 | output: Vec, 209 | parallel_tool_calls: bool, 210 | #[serde(skip_serializing_if = 
"Option::is_none")] 211 | previous_response_id: Option, 212 | #[serde(skip_serializing_if = "Option::is_none")] 213 | reasoning: Option, 214 | #[serde(skip_serializing_if = "Option::is_none")] 215 | service_tier: Option, 216 | status: Status, 217 | #[serde(skip_serializing_if = "Option::is_none")] 218 | temperature: Option, 219 | text: Text, 220 | tool_choice: ToolChoice, 221 | tools: Vec, 222 | #[serde(skip_serializing_if = "Option::is_none")] 223 | top_p: Option, 224 | #[serde(skip_serializing_if = "Option::is_none")] 225 | truncation: Option, 226 | usage: Usage, 227 | #[serde(skip_serializing_if = "Option::is_none")] 228 | user: Option, 229 | } 230 | -------------------------------------------------------------------------------- /crates/providers/src/openai/request/input_models/item.rs: -------------------------------------------------------------------------------- 1 | use crate::openai::errors::InputError; 2 | use crate::openai::request::input_models::common::{Content, Role}; 3 | 4 | use crate::openai::common::{ 5 | computer_tool_call_item::ComputerToolCallItem, file_search_tool_item::FileSearchToolCallItem, 6 | function_tool_call_item::FunctionToolCallItem, output_message_item::OutputMessageItem, 7 | reasoning_item::ReasoningItem, status::Status, 8 | web_search_tool_call_item::WebSearchToolCallItem, 9 | }; 10 | use std::str::FromStr; 11 | 12 | use serde::{Deserialize, Serialize}; 13 | 14 | #[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Default)] 15 | pub struct InputMessageItem { 16 | pub content: Vec, 17 | pub role: Role, 18 | #[serde(skip_serializing_if = "Option::is_none")] 19 | pub status: Option, 20 | #[serde(rename = "type")] 21 | #[serde(skip_serializing_if = "Option::is_none")] 22 | pub type_field: Option, 23 | } 24 | 25 | impl InputMessageItem { 26 | pub fn new() -> Self { 27 | Self::default() 28 | } 29 | 30 | pub fn role(mut self, role: impl AsRef) -> Result { 31 | if role.as_ref().eq("assistant") { 32 | 
Err(InputError::InvalidRole("assistant".to_string())) 33 | } else { 34 | self.role = Role::from_str(role.as_ref()).map_err(InputError::ConversionError)?; 35 | Ok(self) 36 | } 37 | } 38 | 39 | pub fn insert_type(mut self) -> Self { 40 | self.type_field = Some("message".to_string()); 41 | self 42 | } 43 | } 44 | 45 | #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] 46 | pub struct ComputerToolCallOutputItemOutput { 47 | #[serde(rename = "type")] 48 | pub type_field: String, 49 | #[serde(skip_serializing_if = "Option::is_none")] 50 | pub file_id: Option, 51 | #[serde(skip_serializing_if = "Option::is_none")] 52 | pub image_url: Option, 53 | } 54 | 55 | impl Default for ComputerToolCallOutputItemOutput { 56 | fn default() -> Self { 57 | Self { 58 | type_field: "computer_screenshot".to_string(), 59 | image_url: None, 60 | file_id: None, 61 | } 62 | } 63 | } 64 | 65 | impl ComputerToolCallOutputItemOutput { 66 | pub fn new() -> Self { 67 | Self::default() 68 | } 69 | 70 | pub fn file_id(mut self, value: impl Into) -> Self { 71 | self.file_id = Some(value.into()); 72 | self 73 | } 74 | 75 | pub fn image_url(mut self, value: impl Into) -> Self { 76 | self.image_url = Some(value.into()); 77 | self 78 | } 79 | } 80 | 81 | #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] 82 | pub struct AcknowledgedSafetyChecks { 83 | pub id: String, 84 | #[serde(skip_serializing_if = "Option::is_none")] 85 | pub code: Option, 86 | #[serde(skip_serializing_if = "Option::is_none")] 87 | pub message: Option, 88 | } 89 | 90 | impl AcknowledgedSafetyChecks { 91 | pub fn new(id: impl Into) -> Self { 92 | Self { 93 | id: id.into(), 94 | code: None, 95 | message: None, 96 | } 97 | } 98 | 99 | pub fn code(mut self, value: impl Into) -> Self { 100 | self.code = Some(value.into()); 101 | self 102 | } 103 | 104 | pub fn message(mut self, value: impl Into) -> Self { 105 | self.message = Some(value.into()); 106 | self 107 | } 108 | } 109 | 110 | #[derive(Clone, Debug, PartialEq, 
Serialize, Deserialize)] 111 | pub struct ComputerToolCallOutputItem { 112 | pub call_id: String, 113 | pub output: ComputerToolCallOutputItemOutput, 114 | #[serde(rename = "type")] 115 | pub type_field: String, 116 | pub acknowledged_safety_checks: Option>, 117 | #[serde(skip_serializing_if = "Option::is_none")] 118 | pub id: Option, 119 | #[serde(skip_serializing_if = "Option::is_none")] 120 | pub status: Option, 121 | } 122 | 123 | impl ComputerToolCallOutputItem { 124 | pub fn new(call_id: impl Into, output: ComputerToolCallOutputItemOutput) -> Self { 125 | Self { 126 | call_id: call_id.into(), 127 | output, 128 | type_field: "computer_call_output".to_string(), 129 | acknowledged_safety_checks: None, 130 | id: None, 131 | status: None, 132 | } 133 | } 134 | 135 | pub fn acknowledged_safety_checks(mut self, value: Vec) -> Self { 136 | self.acknowledged_safety_checks = Some(value); 137 | self 138 | } 139 | 140 | pub fn id(mut self, value: impl Into) -> Self { 141 | self.id = Some(value.into()); 142 | self 143 | } 144 | 145 | pub fn status(mut self, value: Status) -> Self { 146 | self.status = Some(value); 147 | self 148 | } 149 | } 150 | 151 | #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] 152 | pub struct FunctionToolCallOutputItem { 153 | pub call_id: String, 154 | pub output: String, 155 | #[serde(rename = "type")] 156 | pub type_field: String, 157 | #[serde(skip_serializing_if = "Option::is_none")] 158 | pub id: Option, 159 | #[serde(skip_serializing_if = "Option::is_none")] 160 | pub status: Option, 161 | } 162 | 163 | impl FunctionToolCallOutputItem { 164 | pub fn new(call_id: impl Into, output: impl Into) -> Self { 165 | Self { 166 | call_id: call_id.into(), 167 | output: output.into(), 168 | type_field: "function_call_output".to_string(), 169 | id: None, 170 | status: None, 171 | } 172 | } 173 | 174 | pub fn id(mut self, value: impl Into) -> Self { 175 | self.id = Some(value.into()); 176 | self 177 | } 178 | 179 | pub fn status(mut self, 
value: Status) -> Self { 180 | self.status = Some(value); 181 | self 182 | } 183 | } 184 | 185 | #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] 186 | pub struct Summary { 187 | pub text: String, 188 | #[serde(rename = "type")] 189 | pub type_field: String, 190 | } 191 | 192 | impl Summary { 193 | pub fn new(text: impl Into) -> Self { 194 | Self { 195 | text: text.into(), 196 | type_field: "summary_text".to_string(), 197 | } 198 | } 199 | } 200 | 201 | #[derive(Debug, PartialEq, Serialize, Deserialize)] 202 | #[serde(untagged)] 203 | pub enum Item { 204 | InputMessage(InputMessageItem), 205 | OutputMessage(OutputMessageItem), 206 | FileSearchToolCall(FileSearchToolCallItem), 207 | ComputerToolCall(ComputerToolCallItem), 208 | ComputerToolCallOutput(ComputerToolCallOutputItem), 209 | WebSearchToolCall(WebSearchToolCallItem), 210 | FunctionToolCall(FunctionToolCallItem), 211 | FunctionToolCallOutput(FunctionToolCallOutputItem), 212 | Reasoning(ReasoningItem), 213 | } 214 | 215 | impl From for Item { 216 | fn from(item: InputMessageItem) -> Self { 217 | Item::InputMessage(item) 218 | } 219 | } 220 | 221 | impl From for Item { 222 | fn from(item: OutputMessageItem) -> Self { 223 | Item::OutputMessage(item) 224 | } 225 | } 226 | 227 | impl From for Item { 228 | fn from(item: FileSearchToolCallItem) -> Self { 229 | Item::FileSearchToolCall(item) 230 | } 231 | } 232 | 233 | impl From for Item { 234 | fn from(item: ComputerToolCallItem) -> Self { 235 | Item::ComputerToolCall(item) 236 | } 237 | } 238 | 239 | impl From for Item { 240 | fn from(item: ComputerToolCallOutputItem) -> Self { 241 | Item::ComputerToolCallOutput(item) 242 | } 243 | } 244 | 245 | impl From for Item { 246 | fn from(item: WebSearchToolCallItem) -> Self { 247 | Item::WebSearchToolCall(item) 248 | } 249 | } 250 | 251 | impl From for Item { 252 | fn from(item: FunctionToolCallItem) -> Self { 253 | Item::FunctionToolCall(item) 254 | } 255 | } 256 | 257 | impl From for Item { 258 | fn from(item: 
FunctionToolCallOutputItem) -> Self { 259 | Item::FunctionToolCallOutput(item) 260 | } 261 | } 262 | 263 | impl From for Item { 264 | fn from(item: ReasoningItem) -> Self { 265 | Item::Reasoning(item) 266 | } 267 | } 268 | -------------------------------------------------------------------------------- /crates/providers/src/openai/common/text.rs: -------------------------------------------------------------------------------- 1 | use serde::{Deserialize, Serialize}; 2 | 3 | use crate::openai::errors::ConversionError; 4 | 5 | #[derive(Debug, PartialEq, Serialize, Deserialize)] 6 | #[serde(rename_all = "snake_case")] 7 | enum ResponseFormatType { 8 | Text, 9 | JsonSchema, 10 | JsonObject, 11 | } 12 | 13 | #[derive(Debug, PartialEq, Serialize, Deserialize)] 14 | pub struct TextFormat { 15 | #[serde(rename = "type")] 16 | type_field: ResponseFormatType, // always text 17 | } 18 | 19 | impl TextFormat { 20 | pub fn new() -> Self { 21 | Self { 22 | type_field: ResponseFormatType::Text, 23 | } 24 | } 25 | } 26 | 27 | impl Default for TextFormat { 28 | fn default() -> Self { 29 | Self::new() 30 | } 31 | } 32 | 33 | #[derive(Debug, PartialEq, Serialize, Deserialize)] 34 | pub struct JsonSchemaFormat { 35 | #[serde(rename = "type")] 36 | type_field: ResponseFormatType, // always json_schema 37 | name: String, 38 | schema: serde_json::Value, 39 | #[serde(skip_serializing_if = "Option::is_none")] 40 | description: Option, 41 | #[serde(skip_serializing_if = "Option::is_none")] 42 | strict: Option, 43 | } 44 | 45 | impl JsonSchemaFormat { 46 | pub fn new(name: impl Into, schema: serde_json::Value) -> Self { 47 | Self { 48 | type_field: ResponseFormatType::JsonSchema, 49 | name: name.into(), 50 | schema, 51 | description: None, 52 | strict: Some(false), 53 | } 54 | } 55 | 56 | pub fn description(mut self, value: impl Into) -> Self { 57 | self.description = Some(value.into()); 58 | self 59 | } 60 | 61 | pub fn strict(mut self) -> Self { 62 | self.strict = Some(true); 63 | self 64 
| } 65 | } 66 | 67 | #[derive(Debug, PartialEq, Serialize, Deserialize)] 68 | pub struct JsonObjectFormat { 69 | #[serde(rename = "type")] 70 | type_field: ResponseFormatType, // always json_object 71 | } 72 | 73 | impl JsonObjectFormat { 74 | pub fn new() -> Self { 75 | Self { 76 | type_field: ResponseFormatType::JsonObject, 77 | } 78 | } 79 | } 80 | 81 | impl Default for JsonObjectFormat { 82 | fn default() -> Self { 83 | Self::new() 84 | } 85 | } 86 | 87 | #[derive(Debug, PartialEq, Serialize, Deserialize)] 88 | #[serde(untagged)] 89 | pub enum ResponseFormat { 90 | Text(TextFormat), 91 | JsonSchema(JsonSchemaFormat), 92 | JsonObject(JsonObjectFormat), 93 | } 94 | 95 | impl std::fmt::Display for ResponseFormat { 96 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 97 | match self { 98 | ResponseFormat::Text(_) => write!(f, "text"), 99 | ResponseFormat::JsonSchema(_) => write!(f, "json_schema"), 100 | ResponseFormat::JsonObject(_) => write!(f, "json_object"), 101 | } 102 | } 103 | } 104 | 105 | impl From for ResponseFormat { 106 | fn from(text_format: TextFormat) -> Self { 107 | Self::Text(text_format) 108 | } 109 | } 110 | 111 | impl From for ResponseFormat { 112 | fn from(format: JsonSchemaFormat) -> Self { 113 | Self::JsonSchema(format) 114 | } 115 | } 116 | 117 | impl From for ResponseFormat { 118 | fn from(format: JsonObjectFormat) -> Self { 119 | Self::JsonObject(format) 120 | } 121 | } 122 | 123 | impl TryFrom for TextFormat { 124 | type Error = ConversionError; 125 | 126 | fn try_from(format: ResponseFormat) -> Result { 127 | match format { 128 | ResponseFormat::Text(inner) => Ok(inner), 129 | _ => Err(ConversionError::TryFrom("ResponseFormat".to_string())), 130 | } 131 | } 132 | } 133 | 134 | impl TryFrom for JsonSchemaFormat { 135 | type Error = ConversionError; 136 | 137 | fn try_from(format: ResponseFormat) -> Result { 138 | match format { 139 | ResponseFormat::JsonSchema(inner) => Ok(inner), 140 | _ => 
Err(ConversionError::TryFrom("ResponseFormat".to_string())), 141 | } 142 | } 143 | } 144 | 145 | impl TryFrom for JsonObjectFormat { 146 | type Error = ConversionError; 147 | 148 | fn try_from(format: ResponseFormat) -> Result { 149 | match format { 150 | ResponseFormat::JsonObject(inner) => Ok(inner), 151 | _ => Err(ConversionError::TryFrom("ResponseFormat".to_string())), 152 | } 153 | } 154 | } 155 | 156 | #[derive(Debug, PartialEq, Serialize, Deserialize)] 157 | pub struct Text { 158 | #[serde(skip_serializing_if = "Option::is_none")] 159 | format: Option, 160 | } 161 | 162 | impl Default for Text { 163 | fn default() -> Self { 164 | Self { 165 | format: Some(ResponseFormat::Text(TextFormat::default())), 166 | } 167 | } 168 | } 169 | 170 | impl Text { 171 | pub fn response_format(mut self, value: ResponseFormat) -> Self { 172 | self.format = Some(value); 173 | self 174 | } 175 | } 176 | 177 | #[cfg(test)] 178 | mod tests { 179 | use super::*; 180 | use serde_json::json; 181 | 182 | #[test] 183 | fn it_builds_text_response_format() { 184 | let result = Text::default().response_format(TextFormat::new().into()); 185 | 186 | assert_eq!( 187 | result, 188 | Text { 189 | format: Some(ResponseFormat::Text(TextFormat { 190 | type_field: ResponseFormatType::Text 191 | })) 192 | } 193 | ); 194 | } 195 | 196 | #[test] 197 | fn it_builds_json_schema_response_format() { 198 | let schema = json!({ 199 | "name": "Alice", 200 | "age": 30, 201 | "active": true, 202 | "friends": ["Bob", "Charlie"], 203 | "address": { 204 | "street": "123 Main St", 205 | "city": "Somewhere" 206 | } 207 | }); 208 | 209 | let response_format: ResponseFormat = JsonSchemaFormat::new("test", schema.clone()) 210 | .description("this is a description") 211 | .into(); 212 | 213 | let result = Text::default().response_format(response_format); 214 | 215 | let expected = Text { 216 | format: Some(ResponseFormat::JsonSchema(JsonSchemaFormat { 217 | type_field: ResponseFormatType::JsonSchema, 218 | name: 
"test".to_string(), 219 | schema: schema, 220 | description: Some("this is a description".to_string()), 221 | strict: Some(false), 222 | })), 223 | }; 224 | 225 | assert_eq!(result, expected); 226 | } 227 | 228 | #[test] 229 | fn it_builds_json_object_response_format() { 230 | let response_format: ResponseFormat = JsonObjectFormat::new().into(); 231 | let result = Text::default().response_format(response_format); 232 | 233 | let expected = Text { 234 | format: Some(ResponseFormat::JsonObject(JsonObjectFormat { 235 | type_field: ResponseFormatType::JsonObject, 236 | })), 237 | }; 238 | 239 | assert_eq!(result, expected); 240 | } 241 | 242 | #[test] 243 | fn test_json_values() { 244 | // Test default text format 245 | let text = Text::default(); 246 | let json_value = serde_json::to_value(&text).unwrap(); 247 | assert_eq!( 248 | json_value, 249 | serde_json::json!({ 250 | "format": { 251 | "type": "text" 252 | } 253 | }) 254 | ); 255 | 256 | // Test with JSON schema format 257 | let schema = json!({ 258 | "type": "object", 259 | "properties": { 260 | "name": { "type": "string" }, 261 | "age": { "type": "number" }, 262 | "active": { "type": "boolean" } 263 | }, 264 | "required": ["name", "age"] 265 | }); 266 | 267 | let json_schema_format = JsonSchemaFormat::new("user_data", schema.clone()) 268 | .description("User information schema") 269 | .strict(); 270 | let text_with_schema = Text::default().response_format(json_schema_format.into()); 271 | let json_value = serde_json::to_value(&text_with_schema).unwrap(); 272 | assert_eq!( 273 | json_value, 274 | serde_json::json!({ 275 | "format": { 276 | "type": "json_schema", 277 | "name": "user_data", 278 | "schema": { 279 | "type": "object", 280 | "properties": { 281 | "name": { "type": "string" }, 282 | "age": { "type": "number" }, 283 | "active": { "type": "boolean" } 284 | }, 285 | "required": ["name", "age"] 286 | }, 287 | "description": "User information schema", 288 | "strict": true 289 | } 290 | }) 291 | ); 292 | 293 
| // Test with JSON object format 294 | let text_with_json_object = Text::default().response_format(JsonObjectFormat::new().into()); 295 | let json_value = serde_json::to_value(&text_with_json_object).unwrap(); 296 | assert_eq!( 297 | json_value, 298 | serde_json::json!({ 299 | "format": { 300 | "type": "json_object" 301 | } 302 | }) 303 | ); 304 | } 305 | } 306 | -------------------------------------------------------------------------------- /crates/providers/src/openai/request/input_models/common.rs: -------------------------------------------------------------------------------- 1 | use crate::openai::errors::ConversionError; 2 | use std::str::FromStr; 3 | 4 | use serde::{Deserialize, Serialize}; 5 | 6 | #[derive(Debug, Default, PartialEq, Copy, Clone, Serialize, Deserialize)] 7 | #[serde(rename_all = "lowercase")] 8 | pub enum Role { 9 | #[default] 10 | User, 11 | Assistant, 12 | System, 13 | Developer, 14 | } 15 | 16 | impl FromStr for Role { 17 | type Err = ConversionError; 18 | 19 | fn from_str(s: &str) -> Result { 20 | match s { 21 | "user" => Ok(Role::User), 22 | "assistant" => Ok(Role::Assistant), 23 | "system" => Ok(Role::System), 24 | "developer" => Ok(Role::Developer), 25 | _ => Err(ConversionError::FromStr(s.to_string())), 26 | } 27 | } 28 | } 29 | 30 | #[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] 31 | #[serde(rename_all = "snake_case")] 32 | pub enum ContentType { 33 | InputText, 34 | InputImage, 35 | InputFile, 36 | } 37 | 38 | impl FromStr for ContentType { 39 | type Err = ConversionError; 40 | 41 | fn from_str(s: &str) -> Result { 42 | match s { 43 | "input_text" => Ok(ContentType::InputText), 44 | "input_image" => Ok(ContentType::InputImage), 45 | "input_file" => Ok(ContentType::InputFile), 46 | _ => Err(ConversionError::FromStr(s.to_string())), 47 | } 48 | } 49 | } 50 | 51 | #[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] 52 | pub struct TextContent { 53 | #[serde(rename = "type")] 54 | pub type_field: ContentType, // 
always InputText 55 | pub text: String, 56 | } 57 | 58 | impl Default for TextContent { 59 | fn default() -> Self { 60 | Self { 61 | type_field: ContentType::InputText, 62 | text: String::new(), 63 | } 64 | } 65 | } 66 | 67 | impl TextContent { 68 | pub fn new() -> Self { 69 | Self::default() 70 | } 71 | 72 | pub fn text(mut self, text: impl Into) -> Self { 73 | self.text = text.into(); 74 | self 75 | } 76 | } 77 | 78 | #[derive(Debug, Default, PartialEq, Clone, Serialize, Deserialize)] 79 | #[serde(rename_all = "lowercase")] 80 | pub enum ImageDetail { 81 | High, 82 | Low, 83 | #[default] 84 | Auto, 85 | } 86 | 87 | impl FromStr for ImageDetail { 88 | type Err = ConversionError; 89 | 90 | fn from_str(s: &str) -> Result { 91 | match s { 92 | "high" => Ok(ImageDetail::High), 93 | "low" => Ok(ImageDetail::Low), 94 | "auto" => Ok(ImageDetail::Auto), 95 | _ => Err(ConversionError::FromStr(s.to_string())), 96 | } 97 | } 98 | } 99 | 100 | #[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] 101 | pub struct ImageContent { 102 | #[serde(rename = "type")] 103 | pub type_field: ContentType, // always InputImage 104 | #[serde(skip_serializing_if = "Option::is_none")] 105 | pub image_url: Option, 106 | #[serde(skip_serializing_if = "Option::is_none")] 107 | pub file_id: Option, 108 | pub detail: ImageDetail, 109 | } 110 | 111 | impl Default for ImageContent { 112 | fn default() -> Self { 113 | Self { 114 | type_field: ContentType::InputImage, 115 | image_url: None, 116 | file_id: None, 117 | detail: ImageDetail::Auto, 118 | } 119 | } 120 | } 121 | 122 | impl ImageContent { 123 | pub fn new() -> Self { 124 | Self::default() 125 | } 126 | 127 | pub fn image_url(mut self, value: impl Into) -> Self { 128 | self.image_url = Some(value.into()); 129 | self 130 | } 131 | 132 | pub fn file_id(mut self, value: impl Into) -> Self { 133 | self.file_id = Some(value.into()); 134 | self 135 | } 136 | 137 | pub fn detail(mut self, value: impl AsRef) -> Result { 138 | self.detail = 
ImageDetail::from_str(value.as_ref())?; 139 | Ok(self) 140 | } 141 | } 142 | 143 | #[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] 144 | pub struct FileContent { 145 | #[serde(rename = "type")] 146 | pub type_field: ContentType, // always InputFile, 147 | #[serde(skip_serializing_if = "Option::is_none")] 148 | pub file_id: Option, 149 | #[serde(skip_serializing_if = "Option::is_none")] 150 | pub file_data: Option, 151 | #[serde(skip_serializing_if = "Option::is_none")] 152 | pub filename: Option, 153 | } 154 | 155 | impl Default for FileContent { 156 | fn default() -> Self { 157 | Self { 158 | type_field: ContentType::InputFile, 159 | file_id: None, 160 | file_data: None, 161 | filename: None, 162 | } 163 | } 164 | } 165 | 166 | impl FileContent { 167 | pub fn new() -> Self { 168 | Self::default() 169 | } 170 | 171 | pub fn file_id(mut self, value: impl Into) -> Self { 172 | self.file_id = Some(value.into()); 173 | self 174 | } 175 | 176 | pub fn file_data(mut self, value: impl Into) -> Self { 177 | self.file_data = Some(value.into()); 178 | self 179 | } 180 | 181 | pub fn filename(mut self, value: impl Into) -> Self { 182 | self.filename = Some(value.into()); 183 | self 184 | } 185 | } 186 | 187 | #[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] 188 | #[serde(untagged)] 189 | pub enum Content { 190 | Text(TextContent), 191 | Image(ImageContent), 192 | File(FileContent), 193 | } 194 | 195 | impl From for Content { 196 | fn from(text_content: TextContent) -> Self { 197 | Self::Text(text_content) 198 | } 199 | } 200 | 201 | impl From for Content { 202 | fn from(image_content: ImageContent) -> Self { 203 | Self::Image(image_content) 204 | } 205 | } 206 | 207 | impl From for Content { 208 | fn from(file_content: FileContent) -> Self { 209 | Self::File(file_content) 210 | } 211 | } 212 | 213 | #[cfg(test)] 214 | mod tests { 215 | use super::*; 216 | use serde_json::json; 217 | 218 | // let's compare the json output of the default values 219 | 
#[test] 220 | fn test_default_values() { 221 | let text_content = TextContent::default(); 222 | let image_content = ImageContent::default(); 223 | let file_content = FileContent::default(); 224 | 225 | let text_content_json = serde_json::to_value(&text_content).unwrap(); 226 | let image_content_json = serde_json::to_value(&image_content).unwrap(); 227 | let file_content_json = serde_json::to_value(&file_content).unwrap(); 228 | 229 | assert_eq!(text_content_json, json!({"type": "input_text", "text": ""})); 230 | assert_eq!( 231 | image_content_json, 232 | json!({"type": "input_image", "detail": "auto"}) 233 | ); 234 | assert_eq!(file_content_json, json!({"type": "input_file"})); 235 | } 236 | 237 | #[test] 238 | fn test_text_content() { 239 | let text = "Hello, world!"; 240 | let text_content = TextContent::new().text(text); 241 | let text_content_json = serde_json::to_value(&text_content).unwrap(); 242 | assert_eq!( 243 | text_content_json, 244 | json!({"type": "input_text", "text": text}) 245 | ); 246 | } 247 | 248 | #[test] 249 | fn test_image_content() { 250 | let image_url = "https://example.com/image.png"; 251 | let file_id = "1234567890"; 252 | let detail = "auto"; 253 | 254 | let image_content = ImageContent::new() 255 | .image_url(image_url) 256 | .file_id(file_id) 257 | .detail(detail) 258 | .unwrap(); 259 | 260 | let image_content_json = serde_json::to_value(&image_content).unwrap(); 261 | assert_eq!( 262 | image_content_json, 263 | json!({"type": "input_image", "image_url": image_url, "file_id": file_id, "detail": detail}) 264 | ); 265 | } 266 | 267 | #[test] 268 | fn test_file_content() { 269 | let file_id = "1234567890"; 270 | let file_data = ""; 271 | let filename = "image.png"; 272 | 273 | let file_content = FileContent::new() 274 | .file_id(file_id) 275 | .file_data(file_data) 276 | .filename(filename); 277 | 278 | let file_content_json = serde_json::to_value(&file_content).unwrap(); 279 | assert_eq!( 280 | file_content_json, 281 | json!({"type": 
"input_file", "file_id": file_id, "file_data": file_data, "filename": filename}) 282 | ); 283 | } 284 | 285 | #[test] 286 | fn test_role_from_str() { 287 | assert_eq!(Role::from_str("user").unwrap(), Role::User); 288 | assert_eq!(Role::from_str("assistant").unwrap(), Role::Assistant); 289 | assert_eq!(Role::from_str("system").unwrap(), Role::System); 290 | assert_eq!(Role::from_str("developer").unwrap(), Role::Developer); 291 | } 292 | 293 | #[test] 294 | fn test_image_detail_from_str() { 295 | assert_eq!(ImageDetail::from_str("high").unwrap(), ImageDetail::High); 296 | assert_eq!(ImageDetail::from_str("low").unwrap(), ImageDetail::Low); 297 | assert_eq!(ImageDetail::from_str("auto").unwrap(), ImageDetail::Auto); 298 | } 299 | 300 | #[test] 301 | fn test_from_specific_content_to_content() { 302 | let text = "Hello, world!"; 303 | let image_url = "https://example.com/image.png"; 304 | let file_id = "1234567890"; 305 | 306 | let text_content_builder = TextContent::new().text(text); 307 | let text_content: Content = text_content_builder.into(); 308 | 309 | let image_content_builder = ImageContent::new().image_url(image_url); 310 | let image_content: Content = image_content_builder.into(); 311 | 312 | let file_content_builder = FileContent::new().file_id(file_id); 313 | let file_content: Content = file_content_builder.into(); 314 | 315 | assert_eq!(text_content, Content::Text(TextContent::new().text(text))); 316 | assert_eq!( 317 | image_content, 318 | Content::Image(ImageContent::new().image_url(image_url)) 319 | ); 320 | assert_eq!( 321 | file_content, 322 | Content::File(FileContent::new().file_id(file_id)) 323 | ); 324 | } 325 | } 326 | -------------------------------------------------------------------------------- /crates/providers/src/openai/constants.rs: -------------------------------------------------------------------------------- 1 | use crate::openai::errors::InputError; 2 | use serde::{Deserialize, Serialize}; 3 | use std::str::FromStr; 4 | 5 | pub const 
OPENAI_API_URL: &str = "https://api.openai.com/v1"; 6 | 7 | #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] 8 | #[serde(into = "String", try_from = "&str")] 9 | pub enum OpenAIModelId { 10 | Gpt4, 11 | Gpt4Turbo, 12 | Gpt4TurboPreview, 13 | Gpt4_0125Preview, 14 | Gpt4_1106Preview, 15 | Gpt4_0613, 16 | Gpt4O, 17 | Gpt4O2024_05_13, 18 | Gpt4O2024_08_06, 19 | Gpt4O2024_11_20, 20 | Gpt4ORealtimePreview, 21 | Gpt4ORealtimePreview2024_10_01, 22 | Gpt4ORealtimePreview2024_12_17, 23 | Gpt4OAudioPreview, 24 | Gpt4OAudioPreview2024_10_01, 25 | Gpt4OAudioPreview2024_12_17, 26 | Gpt4OMini, 27 | Gpt4OMini2024_07_18, 28 | Gpt4OMiniRealtimePreview, 29 | Gpt4OMiniRealtimePreview2024_12_17, 30 | Gpt4OMiniAudioPreview, 31 | Gpt4OMiniAudioPreview2024_12_17, 32 | Gpt4OMiniSearchPreview, 33 | Gpt4OMiniSearchPreview2025_03_11, 34 | Gpt4OSearchPreview, 35 | Gpt4OSearchPreview2025_03_11, 36 | Gpt4OMiniTts, 37 | Gpt4OTranscribe, 38 | Gpt4OMiniTranscribe, 39 | Gpt4_5Preview, 40 | Gpt4_5Preview2025_02_27, 41 | Gpt4_1, 42 | Gpt4_1_2025_04_14, 43 | Gpt4_1Mini, 44 | Gpt4_1Mini2025_04_14, 45 | Gpt4_1Nano, 46 | Gpt4_1Nano2025_04_14, 47 | Gpt3_5Turbo, 48 | Gpt3_5Turbo0125, 49 | Gpt3_5Turbo1106, 50 | Gpt3_5Turbo16k, 51 | Gpt3_5TurboInstruct, 52 | Gpt3_5TurboInstruct0914, 53 | GptImage1, 54 | Tts1Hd, 55 | Tts1Hd1106, 56 | TextEmbeddingAda002, 57 | TextEmbedding3Small, 58 | TextEmbedding3Large, 59 | ChatGpt4oLatest, 60 | O1Preview, 61 | O1Preview2024_09_12, 62 | O1Mini, 63 | O1Mini2024_09_12, 64 | O1Pro, 65 | O1Pro2025_03_19, 66 | O3Mini, 67 | O3Mini2025_01_31, 68 | O4Mini, 69 | O4Mini2025_04_16, 70 | OmniModerationLatest, 71 | OmniModeration2024_09_26, 72 | CodexMiniLatest, 73 | } 74 | 75 | impl Default for OpenAIModelId { 76 | fn default() -> Self { 77 | Self::Gpt3_5Turbo 78 | } 79 | } 80 | 81 | impl OpenAIModelId { 82 | pub fn as_str(&self) -> &'static str { 83 | match self { 84 | Self::Gpt4 => "gpt-4", 85 | Self::Gpt4Turbo => "gpt-4-turbo", 86 | Self::Gpt4TurboPreview => 
"gpt-4-turbo-preview", 87 | Self::Gpt4_0125Preview => "gpt-4-0125-preview", 88 | Self::Gpt4_1106Preview => "gpt-4-1106-preview", 89 | Self::Gpt4_0613 => "gpt-4-0613", 90 | Self::Gpt4O => "gpt-4o", 91 | Self::Gpt4O2024_05_13 => "gpt-4o-2024-05-13", 92 | Self::Gpt4O2024_08_06 => "gpt-4o-2024-08-06", 93 | Self::Gpt4O2024_11_20 => "gpt-4o-2024-11-20", 94 | Self::Gpt4ORealtimePreview => "gpt-4o-realtime-preview", 95 | Self::Gpt4ORealtimePreview2024_10_01 => "gpt-4o-realtime-preview-2024-10-01", 96 | Self::Gpt4ORealtimePreview2024_12_17 => "gpt-4o-realtime-preview-2024-12-17", 97 | Self::Gpt4OAudioPreview => "gpt-4o-audio-preview", 98 | Self::Gpt4OAudioPreview2024_10_01 => "gpt-4o-audio-preview-2024-10-01", 99 | Self::Gpt4OAudioPreview2024_12_17 => "gpt-4o-audio-preview-2024-12-17", 100 | Self::Gpt4OMini => "gpt-4o-mini", 101 | Self::Gpt4OMini2024_07_18 => "gpt-4o-mini-2024-07-18", 102 | Self::Gpt4OMiniRealtimePreview => "gpt-4o-mini-realtime-preview", 103 | Self::Gpt4OMiniRealtimePreview2024_12_17 => "gpt-4o-mini-realtime-preview-2024-12-17", 104 | Self::Gpt4OMiniAudioPreview => "gpt-4o-mini-audio-preview", 105 | Self::Gpt4OMiniAudioPreview2024_12_17 => "gpt-4o-mini-audio-preview-2024-12-17", 106 | Self::Gpt4OMiniSearchPreview => "gpt-4o-mini-search-preview", 107 | Self::Gpt4OMiniSearchPreview2025_03_11 => "gpt-4o-mini-search-preview-2025-03-11", 108 | Self::Gpt4OSearchPreview => "gpt-4o-search-preview", 109 | Self::Gpt4OSearchPreview2025_03_11 => "gpt-4o-search-preview-2025-03-11", 110 | Self::Gpt4OMiniTts => "gpt-4o-mini-tts", 111 | Self::Gpt4OTranscribe => "gpt-4o-transcribe", 112 | Self::Gpt4OMiniTranscribe => "gpt-4o-mini-transcribe", 113 | Self::Gpt4_5Preview => "gpt-4.5-preview", 114 | Self::Gpt4_5Preview2025_02_27 => "gpt-4.5-preview-2025-02-27", 115 | Self::Gpt4_1 => "gpt-4.1", 116 | Self::Gpt4_1_2025_04_14 => "gpt-4.1-2025-04-14", 117 | Self::Gpt4_1Mini => "gpt-4.1-mini", 118 | Self::Gpt4_1Mini2025_04_14 => "gpt-4.1-mini-2025-04-14", 119 | Self::Gpt4_1Nano => 
"gpt-4.1-nano", 120 | Self::Gpt4_1Nano2025_04_14 => "gpt-4.1-nano-2025-04-14", 121 | Self::Gpt3_5Turbo => "gpt-3.5-turbo", 122 | Self::Gpt3_5Turbo0125 => "gpt-3.5-turbo-0125", 123 | Self::Gpt3_5Turbo1106 => "gpt-3.5-turbo-1106", 124 | Self::Gpt3_5Turbo16k => "gpt-3.5-turbo-16k", 125 | Self::Gpt3_5TurboInstruct => "gpt-3.5-turbo-instruct", 126 | Self::Gpt3_5TurboInstruct0914 => "gpt-3.5-turbo-instruct-0914", 127 | Self::GptImage1 => "gpt-image-1", 128 | Self::Tts1Hd => "tts-1-hd", 129 | Self::Tts1Hd1106 => "tts-1-hd-1106", 130 | Self::TextEmbeddingAda002 => "text-embedding-ada-002", 131 | Self::TextEmbedding3Small => "text-embedding-3-small", 132 | Self::TextEmbedding3Large => "text-embedding-3-large", 133 | Self::ChatGpt4oLatest => "chatgpt4o-latest", 134 | Self::O1Preview => "o1-preview", 135 | Self::O1Preview2024_09_12 => "o1-preview-2024-09-12", 136 | Self::O1Mini => "o1-mini", 137 | Self::O1Mini2024_09_12 => "o1-mini-2024-09-12", 138 | Self::O1Pro => "o1-pro", 139 | Self::O1Pro2025_03_19 => "o1-pro-2025-03-19", 140 | Self::O3Mini => "o3-mini", 141 | Self::O3Mini2025_01_31 => "o3-mini-2025-01-31", 142 | Self::O4Mini => "o4-mini", 143 | Self::O4Mini2025_04_16 => "o4-mini-2025-04-16", 144 | Self::OmniModerationLatest => "omni-moderation-latest", 145 | Self::OmniModeration2024_09_26 => "omni-moderation-2024-09-26", 146 | Self::CodexMiniLatest => "codex-mini-latest", 147 | } 148 | } 149 | } 150 | 151 | impl TryFrom<&str> for OpenAIModelId { 152 | type Error = InputError; 153 | 154 | fn try_from(value: &str) -> Result { 155 | OpenAIModelId::from_str(value) 156 | } 157 | } 158 | 159 | impl From for String { 160 | fn from(value: OpenAIModelId) -> Self { 161 | value.as_str().to_string() 162 | } 163 | } 164 | 165 | impl FromStr for OpenAIModelId { 166 | type Err = InputError; 167 | 168 | fn from_str(s: &str) -> Result { 169 | match s { 170 | "gpt-4" => Ok(Self::Gpt4), 171 | "gpt-4-turbo" => Ok(Self::Gpt4Turbo), 172 | "gpt-4-turbo-preview" => Ok(Self::Gpt4TurboPreview), 
173 | "gpt-4-0125-preview" => Ok(Self::Gpt4_0125Preview), 174 | "gpt-4-1106-preview" => Ok(Self::Gpt4_1106Preview), 175 | "gpt-4-0613" => Ok(Self::Gpt4_0613), 176 | "gpt-4o" => Ok(Self::Gpt4O), 177 | "gpt-4o-2024-05-13" => Ok(Self::Gpt4O2024_05_13), 178 | "gpt-4o-2024-08-06" => Ok(Self::Gpt4O2024_08_06), 179 | "gpt-4o-2024-11-20" => Ok(Self::Gpt4O2024_11_20), 180 | "gpt-4o-realtime-preview" => Ok(Self::Gpt4ORealtimePreview), 181 | "gpt-4o-realtime-preview-2024-10-01" => Ok(Self::Gpt4ORealtimePreview2024_10_01), 182 | "gpt-4o-realtime-preview-2024-12-17" => Ok(Self::Gpt4ORealtimePreview2024_12_17), 183 | "gpt-4o-audio-preview" => Ok(Self::Gpt4OAudioPreview), 184 | "gpt-4o-audio-preview-2024-10-01" => Ok(Self::Gpt4OAudioPreview2024_10_01), 185 | "gpt-4o-audio-preview-2024-12-17" => Ok(Self::Gpt4OAudioPreview2024_12_17), 186 | "gpt-4o-mini" => Ok(Self::Gpt4OMini), 187 | "gpt-4o-mini-2024-07-18" => Ok(Self::Gpt4OMini2024_07_18), 188 | "gpt-4o-mini-realtime-preview" => Ok(Self::Gpt4OMiniRealtimePreview), 189 | "gpt-4o-mini-realtime-preview-2024-12-17" => { 190 | Ok(Self::Gpt4OMiniRealtimePreview2024_12_17) 191 | } 192 | "gpt-4o-mini-audio-preview" => Ok(Self::Gpt4OMiniAudioPreview), 193 | "gpt-4o-mini-audio-preview-2024-12-17" => Ok(Self::Gpt4OMiniAudioPreview2024_12_17), 194 | "gpt-4o-mini-search-preview" => Ok(Self::Gpt4OMiniSearchPreview), 195 | "gpt-4o-mini-search-preview-2025-03-11" => Ok(Self::Gpt4OMiniSearchPreview2025_03_11), 196 | "gpt-4o-search-preview" => Ok(Self::Gpt4OSearchPreview), 197 | "gpt-4o-search-preview-2025-03-11" => Ok(Self::Gpt4OSearchPreview2025_03_11), 198 | "gpt-4o-mini-tts" => Ok(Self::Gpt4OMiniTts), 199 | "gpt-4o-transcribe" => Ok(Self::Gpt4OTranscribe), 200 | "gpt-4o-mini-transcribe" => Ok(Self::Gpt4OMiniTranscribe), 201 | "gpt-4.5-preview" => Ok(Self::Gpt4_5Preview), 202 | "gpt-4.5-preview-2025-02-27" => Ok(Self::Gpt4_5Preview2025_02_27), 203 | "gpt-4.1" => Ok(Self::Gpt4_1), 204 | "gpt-4.1-2025-04-14" => Ok(Self::Gpt4_1_2025_04_14), 205 | 
"gpt-4.1-mini" => Ok(Self::Gpt4_1Mini), 206 | "gpt-4.1-mini-2025-04-14" => Ok(Self::Gpt4_1Mini2025_04_14), 207 | "gpt-4.1-nano" => Ok(Self::Gpt4_1Nano), 208 | "gpt-4.1-nano-2025-04-14" => Ok(Self::Gpt4_1Nano2025_04_14), 209 | "gpt-3.5-turbo" => Ok(Self::Gpt3_5Turbo), 210 | "gpt-3.5-turbo-0125" => Ok(Self::Gpt3_5Turbo0125), 211 | "gpt-3.5-turbo-1106" => Ok(Self::Gpt3_5Turbo1106), 212 | "gpt-3.5-turbo-16k" => Ok(Self::Gpt3_5Turbo16k), 213 | "gpt-3.5-turbo-instruct" => Ok(Self::Gpt3_5TurboInstruct), 214 | "gpt-3.5-turbo-instruct-0914" => Ok(Self::Gpt3_5TurboInstruct0914), 215 | "gpt-image-1" => Ok(Self::GptImage1), 216 | "tts-1-hd" => Ok(Self::Tts1Hd), 217 | "tts-1-hd-1106" => Ok(Self::Tts1Hd1106), 218 | "text-embedding-ada-002" => Ok(Self::TextEmbeddingAda002), 219 | "text-embedding-3-small" => Ok(Self::TextEmbedding3Small), 220 | "text-embedding-3-large" => Ok(Self::TextEmbedding3Large), 221 | "chatgpt4o-latest" => Ok(Self::ChatGpt4oLatest), 222 | "o1-preview" => Ok(Self::O1Preview), 223 | "o1-preview-2024-09-12" => Ok(Self::O1Preview2024_09_12), 224 | "o1-mini" => Ok(Self::O1Mini), 225 | "o1-mini-2024-09-12" => Ok(Self::O1Mini2024_09_12), 226 | "o1-pro" => Ok(Self::O1Pro), 227 | "o1-pro-2025-03-19" => Ok(Self::O1Pro2025_03_19), 228 | "o3-mini" => Ok(Self::O3Mini), 229 | "o3-mini-2025-01-31" => Ok(Self::O3Mini2025_01_31), 230 | "o4-mini" => Ok(Self::O4Mini), 231 | "o4-mini-2025-04-16" => Ok(Self::O4Mini2025_04_16), 232 | "omni-moderation-latest" => Ok(Self::OmniModerationLatest), 233 | "omni-moderation-2024-09-26" => Ok(Self::OmniModeration2024_09_26), 234 | "codex-mini-latest" => Ok(Self::CodexMiniLatest), 235 | _ => Err(InputError::InvalidModelId(s.to_string())), 236 | } 237 | } 238 | } 239 | -------------------------------------------------------------------------------- /crates/providers/src/openai/common/tool.rs: -------------------------------------------------------------------------------- 1 | use std::str::FromStr; 2 | 3 | use serde::{Deserialize, 
Serialize}; 4 | 5 | use crate::openai::errors::ConversionError; 6 | 7 | #[derive(Debug, PartialEq, Serialize, Deserialize)] 8 | #[serde(rename_all = "lowercase")] 9 | pub enum ComparisonOperator { 10 | Eq, 11 | Ne, 12 | Gt, 13 | Gte, 14 | Lt, 15 | Lte, 16 | } 17 | 18 | impl FromStr for ComparisonOperator { 19 | type Err = ConversionError; 20 | 21 | fn from_str(s: &str) -> Result { 22 | match s { 23 | "eq" => Ok(ComparisonOperator::Eq), 24 | "ne" => Ok(ComparisonOperator::Ne), 25 | "gt" => Ok(ComparisonOperator::Gt), 26 | "gte" => Ok(ComparisonOperator::Gte), 27 | "lt" => Ok(ComparisonOperator::Lt), 28 | "lte" => Ok(ComparisonOperator::Lte), 29 | _ => Err(ConversionError::FromStr(s.to_string())), 30 | } 31 | } 32 | } 33 | 34 | #[derive(Debug, PartialEq, Serialize, Deserialize)] 35 | #[serde(untagged)] 36 | pub enum FilterValue { 37 | String(String), 38 | Boolean(bool), 39 | Number(f64), 40 | } 41 | 42 | impl FilterValue { 43 | pub fn string(filter: impl Into) -> Self { 44 | Self::String(filter.into()) 45 | } 46 | 47 | pub fn boolean(filter: bool) -> Self { 48 | Self::Boolean(filter) 49 | } 50 | 51 | pub fn number(filter: f64) -> Self { 52 | Self::Number(filter) 53 | } 54 | } 55 | 56 | impl From for FilterValue { 57 | fn from(value: String) -> Self { 58 | FilterValue::String(value) 59 | } 60 | } 61 | 62 | impl From<&str> for FilterValue { 63 | fn from(value: &str) -> Self { 64 | FilterValue::String(value.to_string()) 65 | } 66 | } 67 | 68 | impl From for FilterValue { 69 | fn from(value: bool) -> Self { 70 | FilterValue::Boolean(value) 71 | } 72 | } 73 | 74 | impl From for FilterValue { 75 | fn from(value: f64) -> Self { 76 | FilterValue::Number(value) 77 | } 78 | } 79 | 80 | #[derive(Debug, PartialEq, Serialize, Deserialize)] 81 | pub struct ComparisonFilter { 82 | key: String, 83 | #[serde(rename = "type")] 84 | type_field: ComparisonOperator, 85 | value: FilterValue, 86 | } 87 | 88 | impl ComparisonFilter { 89 | pub fn build>( 90 | key: impl Into, 91 | 
comparison_operator: impl AsRef, 92 | value: V, 93 | ) -> Self { 94 | Self { 95 | key: key.into(), 96 | type_field: ComparisonOperator::from_str(comparison_operator.as_ref()).unwrap(), 97 | value: value.into(), 98 | } 99 | } 100 | } 101 | 102 | #[derive(Debug, PartialEq, Serialize, Deserialize)] 103 | #[serde(rename_all = "lowercase")] 104 | pub enum CompoundOperator { 105 | And, 106 | Or, 107 | } 108 | 109 | impl FromStr for CompoundOperator { 110 | type Err = ConversionError; 111 | 112 | fn from_str(s: &str) -> Result { 113 | match s { 114 | "and" => Ok(CompoundOperator::And), 115 | "or" => Ok(CompoundOperator::Or), 116 | _ => Err(ConversionError::FromStr(s.to_string())), 117 | } 118 | } 119 | } 120 | 121 | #[derive(Debug, PartialEq, Serialize, Deserialize)] 122 | pub struct CompoundFilter { 123 | filters: Vec, 124 | #[serde(rename = "type")] 125 | type_field: CompoundOperator, 126 | } 127 | 128 | impl CompoundFilter { 129 | pub fn build(filters: Vec, compound_operator: impl AsRef) -> Self { 130 | Self { 131 | filters, 132 | type_field: CompoundOperator::from_str(compound_operator.as_ref()).unwrap(), 133 | } 134 | } 135 | } 136 | 137 | #[derive(Debug, PartialEq, Serialize, Deserialize)] 138 | #[serde(untagged)] 139 | pub enum FileSearchFilter { 140 | Comparison(ComparisonFilter), 141 | Compound(CompoundFilter), 142 | } 143 | 144 | impl FileSearchFilter { 145 | pub fn build_comparison_filter>( 146 | key: impl Into, 147 | comparison_operator: impl AsRef, 148 | value: V, 149 | ) -> Self { 150 | Self::Comparison(ComparisonFilter::build(key, comparison_operator, value)) 151 | } 152 | 153 | pub fn build_compound_filter( 154 | filters: Vec, 155 | compound_operator: impl AsRef, 156 | ) -> Self { 157 | Self::Compound(CompoundFilter::build(filters, compound_operator)) 158 | } 159 | } 160 | 161 | #[derive(Debug, PartialEq, Serialize, Deserialize)] 162 | pub struct RankingOptions { 163 | #[serde(skip_serializing_if = "Option::is_none")] 164 | ranker: Option, 165 | 
#[serde(skip_serializing_if = "Option::is_none")] 166 | score_threshold: Option, 167 | } 168 | 169 | impl RankingOptions { 170 | pub fn new() -> Self { 171 | Self { 172 | ranker: None, 173 | score_threshold: None, 174 | } 175 | } 176 | 177 | pub fn ranker(mut self, value: impl Into) -> Self { 178 | self.ranker = Some(value.into()); 179 | self 180 | } 181 | 182 | pub fn score_threshold(mut self, value: f32) -> Self { 183 | self.score_threshold = Some(value); 184 | self 185 | } 186 | } 187 | 188 | impl Default for RankingOptions { 189 | fn default() -> Self { 190 | Self::new() 191 | } 192 | } 193 | 194 | #[derive(Debug, PartialEq, Serialize, Deserialize)] 195 | pub struct FileSearchTool { 196 | #[serde(rename = "type")] 197 | type_field: String, 198 | vector_store_ids: Vec, 199 | #[serde(skip_serializing_if = "Option::is_none")] 200 | filters: Option, 201 | #[serde(skip_serializing_if = "Option::is_none")] 202 | max_num_results: Option, 203 | #[serde(skip_serializing_if = "Option::is_none")] 204 | ranking_options: Option, 205 | } 206 | 207 | impl FileSearchTool { 208 | pub fn new(vector_store_ids: Vec>) -> Self { 209 | Self { 210 | type_field: "file_search".to_string(), 211 | vector_store_ids: vector_store_ids.into_iter().map(|id| id.into()).collect(), 212 | filters: None, 213 | max_num_results: None, 214 | ranking_options: None, 215 | } 216 | } 217 | 218 | pub fn filters(mut self, filters: FileSearchFilter) -> Self { 219 | self.filters = Some(filters); 220 | self 221 | } 222 | 223 | pub fn max_num_results(mut self, value: u8) -> Self { 224 | self.max_num_results = Some(value); 225 | self 226 | } 227 | 228 | pub fn ranking_options(mut self, value: RankingOptions) -> Self { 229 | self.ranking_options = Some(value); 230 | self 231 | } 232 | } 233 | 234 | #[derive(Debug, PartialEq, Serialize, Deserialize)] 235 | pub struct FunctionTool { 236 | name: String, 237 | parameters: serde_json::Value, 238 | strict: bool, 239 | #[serde(rename = "type")] 240 | type_field: String, 
241 | #[serde(skip_serializing_if = "Option::is_none")] 242 | description: Option, 243 | } 244 | 245 | impl FunctionTool { 246 | pub fn new(name: impl Into, parameters: serde_json::Value) -> Self { 247 | Self { 248 | name: name.into(), 249 | parameters, 250 | strict: true, 251 | type_field: "function".to_string(), 252 | description: None, 253 | } 254 | } 255 | 256 | pub fn strict(mut self, value: bool) -> Self { 257 | self.strict = value; 258 | self 259 | } 260 | 261 | pub fn description(mut self, value: impl Into) -> Self { 262 | self.description = Some(value.into()); 263 | self 264 | } 265 | } 266 | 267 | #[derive(Debug, PartialEq, Serialize, Deserialize)] 268 | pub struct ComputerUseTool { 269 | display_height: f32, 270 | display_width: f32, 271 | environment: String, 272 | #[serde(rename = "type")] 273 | type_field: String, 274 | } 275 | 276 | impl ComputerUseTool { 277 | pub fn new(display_height: f32, display_width: f32, environment: impl Into) -> Self { 278 | Self { 279 | display_height, 280 | display_width, 281 | environment: environment.into(), 282 | type_field: "computer_use_preview".to_string(), 283 | } 284 | } 285 | } 286 | 287 | #[derive(Debug, PartialEq, Serialize, Deserialize)] 288 | #[serde(rename_all = "lowercase")] 289 | pub enum SearchContextSize { 290 | Low, 291 | Medium, 292 | High, 293 | } 294 | 295 | impl FromStr for SearchContextSize { 296 | type Err = ConversionError; 297 | 298 | fn from_str(s: &str) -> Result { 299 | match s { 300 | "low" => Ok(SearchContextSize::Low), 301 | "medium" => Ok(SearchContextSize::Medium), 302 | "high" => Ok(SearchContextSize::High), 303 | _ => Err(ConversionError::FromStr(s.to_string())), 304 | } 305 | } 306 | } 307 | 308 | #[derive(Debug, PartialEq, Serialize, Deserialize)] 309 | pub struct UserLocation { 310 | #[serde(rename = "type")] 311 | type_field: String, // NOTE: this is always "approximate" value 312 | #[serde(skip_serializing_if = "Option::is_none")] 313 | city: Option, 314 | 
#[serde(skip_serializing_if = "Option::is_none")] 315 | country: Option, // NOTE: this is ISO-3166 country code 316 | #[serde(skip_serializing_if = "Option::is_none")] 317 | region: Option, 318 | #[serde(skip_serializing_if = "Option::is_none")] 319 | timezone: Option, // NOTE: this is IANA timezone 320 | } 321 | 322 | impl UserLocation { 323 | pub fn new() -> Self { 324 | Self { 325 | type_field: "approximate".to_string(), 326 | city: None, 327 | country: None, 328 | region: None, 329 | timezone: None, 330 | } 331 | } 332 | 333 | pub fn city(mut self, value: impl Into) -> Self { 334 | self.city = Some(value.into()); 335 | self 336 | } 337 | 338 | pub fn country(mut self, value: impl Into) -> Self { 339 | self.country = Some(value.into()); 340 | self 341 | } 342 | 343 | pub fn region(mut self, value: impl Into) -> Self { 344 | self.region = Some(value.into()); 345 | self 346 | } 347 | 348 | pub fn timezone(mut self, value: impl Into) -> Self { 349 | self.timezone = Some(value.into()); 350 | self 351 | } 352 | } 353 | 354 | impl Default for UserLocation { 355 | fn default() -> Self { 356 | Self::new() 357 | } 358 | } 359 | 360 | #[derive(Debug, PartialEq, Serialize, Deserialize)] 361 | pub struct WebSearchTool { 362 | #[serde(rename = "type")] 363 | type_field: String, // NOTE: this is either web_search_preview or web_search_preview_2025_03_11C 364 | #[serde(skip_serializing_if = "Option::is_none")] 365 | search_context_size: Option, 366 | #[serde(skip_serializing_if = "Option::is_none")] 367 | user_location: Option, 368 | } 369 | 370 | impl WebSearchTool { 371 | pub fn new(type_field: impl Into) -> Self { 372 | Self { 373 | type_field: type_field.into(), 374 | search_context_size: None, 375 | user_location: None, 376 | } 377 | } 378 | 379 | pub fn search_context_size(mut self, value: SearchContextSize) -> Self { 380 | self.search_context_size = Some(value); 381 | self 382 | } 383 | 384 | pub fn user_location(mut self, value: UserLocation) -> Self { 385 | 
self.user_location = Some(value);
        self
    }
}

/// A tool definition: file search, function calling, computer use, or web search.
///
/// The enum is `#[serde(untagged)]`, so each variant serializes as its inner
/// struct with no wrapping variant key. For example, a [`FunctionTool`] becomes
/// `{"type": "function", "name": ..., ...}` (see `test_json_values` below).
///
/// When deserializing, serde tries the variants in declaration order and keeps
/// the first one whose shape matches, so the order of the variants matters.
/// NOTE(review): per the struct literal in the tests, `WebSearchTool`'s only
/// non-`Option` field is its `type` string. If it does not deny unknown fields,
/// it may accept payloads that fail to match the earlier variants. Confirm this
/// before reordering or adding variants.
#[derive(Debug, PartialEq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Tool {
    FileSearch(FileSearchTool),
    Function(FunctionTool),
    ComputerUse(ComputerUseTool),
    WebSearch(WebSearchTool),
}

/// Wraps a [`FileSearchTool`] in [`Tool::FileSearch`].
impl From<FileSearchTool> for Tool {
    fn from(tool: FileSearchTool) -> Self {
        Tool::FileSearch(tool)
    }
}

/// Extracts the [`FileSearchTool`] from a [`Tool::FileSearch`].
///
/// # Errors
///
/// Returns `ConversionError::TryFrom` when `tool` is any other variant.
impl TryFrom<Tool> for FileSearchTool {
    type Error = ConversionError;

    fn try_from(tool: Tool) -> Result<Self, Self::Error> {
        match tool {
            Tool::FileSearch(inner) => Ok(inner),
            _ => Err(ConversionError::TryFrom("Tool".to_string())),
        }
    }
}

/// Wraps a [`FunctionTool`] in [`Tool::Function`].
impl From<FunctionTool> for Tool {
    fn from(tool: FunctionTool) -> Self {
        Tool::Function(tool)
    }
}

/// Extracts the [`FunctionTool`] from a [`Tool::Function`].
///
/// # Errors
///
/// Returns `ConversionError::TryFrom` when `tool` is any other variant.
impl TryFrom<Tool> for FunctionTool {
    type Error = ConversionError;

    fn try_from(tool: Tool) -> Result<Self, Self::Error> {
        match tool {
            Tool::Function(inner) => Ok(inner),
            _ => Err(ConversionError::TryFrom("Tool".to_string())),
        }
    }
}

/// Wraps a [`ComputerUseTool`] in [`Tool::ComputerUse`].
impl From<ComputerUseTool> for Tool {
    fn from(tool: ComputerUseTool) -> Self {
        Tool::ComputerUse(tool)
    }
}

/// Extracts the [`ComputerUseTool`] from a [`Tool::ComputerUse`].
///
/// # Errors
///
/// Returns `ConversionError::TryFrom` when `tool` is any other variant.
impl TryFrom<Tool> for ComputerUseTool {
    type Error = ConversionError;

    fn try_from(tool: Tool) -> Result<Self, Self::Error> {
        match tool {
            Tool::ComputerUse(inner) => Ok(inner),
            _ => Err(ConversionError::TryFrom("Tool".to_string())),
        }
    }
}

/// Wraps a [`WebSearchTool`] in [`Tool::WebSearch`].
impl From<WebSearchTool> for Tool {
    fn from(tool: WebSearchTool) -> Self {
        Tool::WebSearch(tool)
    }
}

/// Extracts the [`WebSearchTool`] from a [`Tool::WebSearch`].
///
/// # Errors
///
/// Returns `ConversionError::TryFrom` when `tool` is any other variant.
impl TryFrom<Tool> for WebSearchTool {
    type Error = ConversionError;

    fn try_from(tool: Tool) -> Result<Self, Self::Error> {
        match tool {
            Tool::WebSearch(inner) => Ok(inner),
            _ => Err(ConversionError::TryFrom("Tool".to_string())),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn it_creates_file_search_tool_with_comparison_operator() {
        let vector_store_ids = vec![
            "id_1".to_string(),
            "id_2".to_string(),
            "id_3".to_string(),
            "id_4".to_string(),
        ];
        // Round-trip through `Tool` so that both `From` and `TryFrom` are exercised.
        let tool: Tool = FileSearchTool::new(vector_store_ids.clone()).into();
        let tool: Tool = FileSearchTool::try_from(tool)
            .unwrap()
            .ranking_options(
                RankingOptions::new()
                    .ranker("test_ranker")
                    .score_threshold(1.0),
            )
            .filters(FileSearchFilter::build_comparison_filter(
                "test_key",
                "eq",
                "test_value",
            ))
            .max_num_results(1)
            .into();

        let expected = Tool::FileSearch(FileSearchTool {
            type_field: "file_search".to_string(),
            vector_store_ids,
            ranking_options: Some(RankingOptions {
                ranker: Some("test_ranker".to_string()),
                score_threshold: Some(1.0),
            }),
            filters: Some(FileSearchFilter::Comparison(ComparisonFilter {
                key: "test_key".to_string(),
                type_field: ComparisonOperator::Eq,
                value: FilterValue::String("test_value".to_string()),
            })),
            max_num_results: Some(1),
        });

        assert_eq!(tool, expected);
    }

    #[test]
    fn it_creates_file_search_tool_with_compound_operator() {
        let vector_store_ids = vec![
            "id_1".to_string(),
            "id_2".to_string(),
            "id_3".to_string(),
            "id_4".to_string(),
        ];
        let tool: Tool = FileSearchTool::new(vector_store_ids.clone())
            .filters(FileSearchFilter::build_compound_filter(
                vec![FileSearchFilter::build_comparison_filter(
                    "test_key",
                    "eq",
                    "test_value",
                )],
                "and",
            ))
            .ranking_options(
                RankingOptions::new()
                    .ranker("test_ranker")
                    .score_threshold(1.0),
            )
            .into();

        let expected = Tool::FileSearch(FileSearchTool {
            type_field: "file_search".to_string(),
            vector_store_ids,
            ranking_options: Some(RankingOptions {
                ranker: Some("test_ranker".to_string()),
                score_threshold: Some(1.0),
            }),
            filters: Some(FileSearchFilter::Compound(CompoundFilter {
                type_field: CompoundOperator::And,
                filters: vec![FileSearchFilter::Comparison(ComparisonFilter {
                    key: "test_key".to_string(),
                    type_field: ComparisonOperator::Eq,
                    value: FilterValue::String("test_value".to_string()),
                })],
            })),
            max_num_results: None,
        });

        assert_eq!(tool, expected);
    }

    #[test]
    fn it_creates_function_tool() {
        let tool: Tool = FunctionTool::new(
            "function_tool_test",
            json!({
                "name": "test"
            }),
        )
        .description("this is description")
        .into();

        let expected = Tool::Function(FunctionTool {
            description: Some("this is description".to_string()),
            type_field: "function".to_string(),
            strict: true,
            parameters: json!({"name": "test"}),
            name: "function_tool_test".to_string(),
        });

        assert_eq!(tool, expected);
    }

    #[test]
    fn it_creates_computer_use_tool() {
        let tool: Tool = ComputerUseTool::new(64.0, 64.0, "test_environment").into();

        let expected = Tool::ComputerUse(ComputerUseTool {
            type_field: "computer_use_preview".to_string(),
            environment: "test_environment".to_string(),
            display_width: 64.0,
            display_height: 64.0,
        });

        assert_eq!(tool, expected);
    }

    #[test]
    fn it_creates_web_search_tool() {
        let tool: Tool = WebSearchTool::new("web_search_preview".to_string())
            .search_context_size(SearchContextSize::Low)
            .user_location(
                UserLocation::new()
                    .city("Istanbul")
                    .country("TR")
                    .region("Marmara")
                    .timezone("Europe/Istanbul"),
            )
            .into();

        let expected = Tool::WebSearch(WebSearchTool {
            user_location: Some(UserLocation {
                type_field: "approximate".to_string(),
                city: Some("Istanbul".to_string()),
                country: Some("TR".to_string()),
                region: Some("Marmara".to_string()),
                timezone: Some("Europe/Istanbul".to_string()),
            }),
            search_context_size: Some(SearchContextSize::Low),
            type_field: "web_search_preview".to_string(),
        });

        assert_eq!(tool, expected);
    }

    // Pins the exact JSON each tool serializes to. Because `Tool` is untagged,
    // no variant wrapper key appears in the output.
    #[test]
    fn test_json_values() {
        // FileSearchTool test
        let tool: Tool = FileSearchTool::new(vec!["id_1", "id_2"])
            .filters(FileSearchFilter::build_comparison_filter(
                "test_key",
                "eq",
                "test_value".to_string(),
            ))
            .max_num_results(1)
            .ranking_options(
                RankingOptions::new()
                    .ranker("test_ranker")
                    .score_threshold(1.0),
            )
            .into();
        let json_value = serde_json::to_value(&tool).unwrap();

        // A JSON object can hold only one "type" key. For a comparison filter it
        // carries the operator ("eq"), not the filter kind.
        assert_eq!(
            json_value,
            json!({
                "type": "file_search",
                "vector_store_ids": ["id_1", "id_2"],
                "filters": {
                    "key": "test_key",
                    "type": "eq",
                    "value": "test_value"
                },
                "max_num_results": 1,
                "ranking_options": {
                    "ranker": "test_ranker",
                    "score_threshold": 1.0
                }
            })
        );

        // FunctionTool test
        let tool: Tool = FunctionTool::new("test", json!({}))
            .description("this is description")
            .into();
        let json_value = serde_json::to_value(&tool).unwrap();

        assert_eq!(
            json_value,
            json!({
                "type": "function",
                "name": "test",
                "parameters": {},
                "strict": true,
                "description": "this is description"
            })
        );

        // ComputerUseTool test
        let tool: Tool = ComputerUseTool::new(64.0, 64.0, "test_environment").into();
        let json_value = serde_json::to_value(&tool).unwrap();

        assert_eq!(
            json_value,
            json!({
                "type": "computer_use_preview",
                "environment": "test_environment",
                "display_width": 64.0,
                "display_height": 64.0
            })
        );

        // WebSearchTool test with web_search_preview
        let tool: Tool = WebSearchTool::new("web_search_preview".to_string())
            .search_context_size(SearchContextSize::Low)
            .user_location(
                UserLocation::new()
                    .city("Istanbul")
                    .country("TR")
                    .region("Marmara")
                    .timezone("Europe/Istanbul"),
            )
            .into();
        let json_value = serde_json::to_value(&tool).unwrap();

        assert_eq!(
            json_value,
            json!({
                "type": "web_search_preview",
                "search_context_size": "low",
                "user_location": {
                    "type": "approximate",
                    "city": "Istanbul",
                    "country": "TR",
                    "region": "Marmara",
                    "timezone": "Europe/Istanbul"
                }
            })
        );

        // WebSearchTool test with the dated web_search_preview_2025_03_11 type
        let tool: Tool = WebSearchTool::new("web_search_preview_2025_03_11".to_string())
            .search_context_size(SearchContextSize::Low)
            .user_location(
                UserLocation::new()
                    .city("Istanbul")
                    .country("TR")
                    .region("Marmara")
                    .timezone("Europe/Istanbul"),
            )
            .into();
        let json_value = serde_json::to_value(&tool).unwrap();

        assert_eq!(
            json_value,
            json!({
                "type": "web_search_preview_2025_03_11",
                "search_context_size": "low",
                "user_location": {
                    "type": "approximate",
                    "city": "Istanbul",
                    "country": "TR",
                    "region": "Marmara",
                    "timezone": "Europe/Istanbul"
                }
            })
        );
    }
}

--------------------------------------------------------------------------------