├── .swiftlint.yml ├── .gitignore ├── Sources └── AgentSDK-Swift │ ├── AgentSDK_Swift.swift │ ├── RunContext.swift │ ├── Usage.swift │ ├── Examples │ └── HelloWorld.swift │ ├── Handoff.swift │ ├── Guardrail.swift │ ├── Models │ ├── ModelInterface.swift │ └── OpenAIModel.swift │ ├── Tool.swift │ ├── Agent.swift │ ├── Run.swift │ ├── AgentRunner.swift │ └── ModelSettings.swift ├── .github └── workflows │ └── swift.yml ├── Package.resolved ├── LICENSE ├── Package.swift ├── Tests └── AgentSDK-SwiftTests │ ├── OpenAIModelTests.swift │ └── AgentSDK_SwiftTests.swift ├── CLAUDE.md ├── Examples └── SimpleApp.swift └── README.md /.swiftlint.yml: -------------------------------------------------------------------------------- 1 | included: 2 | - Sources 3 | - Tests 4 | excluded: 5 | - Examples 6 | - .build 7 | - .vscode 8 | - .github 9 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | /.build 3 | /Packages 4 | xcuserdata/ 5 | DerivedData/ 6 | .swiftpm/configuration/registries.json 7 | .swiftpm/xcode/package.xcworkspace/contents.xcworkspacedata 8 | .netrc 9 | .vscode 10 | .vendor -------------------------------------------------------------------------------- /Sources/AgentSDK-Swift/AgentSDK_Swift.swift: -------------------------------------------------------------------------------- 1 | // AgentSDK-Swift main module file 2 | // Provides public exports for the entire package 3 | 4 | // Core components 5 | public typealias RunnerError = AgentRunner.RunnerError 6 | public typealias AgentName = String 7 | public typealias RunResult = Run.Result 8 | -------------------------------------------------------------------------------- /.github/workflows/swift.yml: -------------------------------------------------------------------------------- 1 | # This workflow will build a Swift project 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-swift 3 | 4 | name: Swift 5 | on: 6 | push: 7 | branches: "*" 8 | pull_request: 9 | branches: "*" 10 | jobs: 11 | build: 12 | runs-on: macos-latest 13 | steps: 14 | - uses: actions/checkout@v5 15 | - name: Install swiftlint 16 | run: | 17 | brew update 18 | brew install swiftlint 19 | - name: Lint 20 | run: swiftlint 21 | - name: Build 22 | run: swift build -v 23 | - name: Run tests 24 | run: swift test -v 25 | -------------------------------------------------------------------------------- /Package.resolved: -------------------------------------------------------------------------------- 1 | { 2 | "originHash" : "73f52251324c3f3cca923d24457b8701c869616cfa337a0282ff2c522d6f07d0", 3 | "pins" : [ 4 | { 5 | "identity" : "swift-http-types", 6 | "kind" : "remoteSourceControl", 7 | "location" : "https://github.com/apple/swift-http-types", 8 | "state" : { 9 | "revision" : "ef18d829e8b92d731ad27bb81583edd2094d1ce3", 10 | "version" : "1.3.1" 11 | } 12 | }, 13 | { 14 | "identity" : "swift-openapi-runtime", 15 | "kind" : "remoteSourceControl", 16 | "location" : "https://github.com/apple/swift-openapi-runtime", 17 | "state" : { 18 | "revision" : "23146bc8710ac5e57abb693113f02dc274cf39b6", 19 | "version" : "1.8.0" 20 | } 21 | } 22 | ], 23 | "version" : 3 24 | } 25 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2025 Fumito Ito 2 | 3 | Permission 
is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /Sources/AgentSDK-Swift/RunContext.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | 3 | /// Wraps the caller-provided context and aggregates usage for the duration of a run. 4 | public final class RunContext: @unchecked Sendable { 5 | /// The user provided context object. 6 | public let value: Context 7 | /// Accumulated model usage. 8 | public private(set) var usage: Usage 9 | 10 | /// Creates a new run context wrapper. 11 | /// - Parameters: 12 | /// - value: The underlying context value. 13 | /// - usage: Optional initial usage information. 14 | public init(value: Context, usage: Usage = Usage()) { 15 | self.value = value 16 | self.usage = usage 17 | } 18 | 19 | /// Updates the usage with information from a model response. 20 | /// - Parameter usage: The usage information returned by the model. 21 | public func recordUsage(_ usage: ModelResponse.Usage?) { 22 | var aggregated = self.usage 23 | aggregated.add(usage) 24 | self.usage = aggregated 25 | } 26 | 27 | /// Merges usage from another run context. 28 | /// - Parameter other: The run context to merge. 29 | public func mergeUsage(from other: RunContext) { 30 | var aggregated = usage 31 | aggregated.merge(other.usage) 32 | usage = aggregated 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /Package.swift: -------------------------------------------------------------------------------- 1 | // swift-tools-version: 6.0 2 | // The swift-tools-version declares the minimum version of Swift required to build this package. 3 | 4 | import PackageDescription 5 | 6 | let package = Package( 7 | name: "AgentSDK-Swift", 8 | platforms: [ 9 | .macOS(.v13), 10 | .iOS(.v16) 11 | ], 12 | products: [ 13 | // Products define the executables and libraries a package produces, making them visible to other packages. 14 | .library( 15 | name: "AgentSDK-Swift", 16 | targets: ["AgentSDK-Swift"]), 17 | .executable( 18 | name: "SimpleApp", 19 | targets: ["SimpleApp"]) 20 | ], 21 | dependencies: [ 22 | .package(url: "https://github.com/apple/swift-openapi-runtime", from: "1.0.0") 23 | ], 24 | targets: [ 25 | // Targets are the basic building blocks of a package, defining a module or a test suite. 26 | // Targets can depend on other targets in this package and products from dependencies. 
27 | .target( 28 | name: "AgentSDK-Swift", 29 | dependencies: [ 30 | .product(name: "OpenAPIRuntime", package: "swift-openapi-runtime") 31 | ]), 32 | .executableTarget( 33 | name: "SimpleApp", 34 | dependencies: ["AgentSDK-Swift"], 35 | path: "Examples" 36 | ), 37 | .testTarget( 38 | name: "AgentSDK-SwiftTests", 39 | dependencies: ["AgentSDK-Swift"] 40 | ) 41 | ] 42 | ) 43 | -------------------------------------------------------------------------------- /Tests/AgentSDK-SwiftTests/OpenAIModelTests.swift: -------------------------------------------------------------------------------- 1 | import Testing 2 | import Foundation 3 | #if canImport(FoundationNetworking) 4 | import FoundationNetworking 5 | #endif 6 | @testable import AgentSDK_Swift 7 | 8 | // MARK: - OpenAI Model Tests 9 | 10 | @Test func testOpenAIModelCreation() async throws { 11 | // Create an OpenAI model 12 | let apiKey = "test-api-key" 13 | _ = OpenAIModel(apiKey: apiKey) 14 | 15 | // Just check instantiation works - no assertions needed 16 | #expect(Bool(true)) 17 | } 18 | 19 | @Test func testCustomBaseURL() async throws { 20 | // Create an OpenAI model with custom base URL 21 | let customBaseURL = URL(string: "https://custom-openai-api.example.com/v1")! 22 | _ = OpenAIModel(apiKey: "test-key", apiBaseURL: customBaseURL) 23 | 24 | // Just check instantiation works 25 | #expect(Bool(true)) 26 | } 27 | 28 | @Test func testModelSettings() async throws { 29 | // This is mostly a compilation validation test 30 | _ = [ 31 | Message(role: .user, content: .text("Hello, how are you?")) 32 | ] 33 | 34 | let settings = ModelSettings( 35 | modelName: "test-model", 36 | temperature: 0.7, 37 | topP: 0.9, 38 | maxTokens: 1000 39 | ) 40 | 41 | // Check settings values 42 | #expect(settings.modelName == "test-model") 43 | #expect(settings.temperature == 0.7) 44 | #expect(settings.topP == 0.9) 45 | #expect(settings.maxTokens == 1000) 46 | } 47 | 48 | @Test func testMessageContent() async throws { 49 | // Create text message 50 | let textMessage = Message( 51 | role: .user, 52 | content: .text("Hello") 53 | ) 54 | 55 | // Check message content type 56 | if case .text(let content) = textMessage.content { 57 | #expect(content == "Hello") 58 | } else { 59 | #expect(Bool(false), "Should be text content") 60 | } 61 | 62 | // Check role 63 | #expect(textMessage.role == .user) 64 | } 65 | -------------------------------------------------------------------------------- /Sources/AgentSDK-Swift/Usage.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | 3 | /// Tracks token and request usage for a run. 4 | public struct Usage: Sendable { 5 | /// Total requests made to the model provider. 6 | public private(set) var requests: Int 7 | /// Total input tokens sent across all requests. 8 | public private(set) var inputTokens: Int 9 | /// Total output tokens received across all requests. 10 | public private(set) var outputTokens: Int 11 | /// Combined input and output tokens across all requests. 12 | public private(set) var totalTokens: Int 13 | 14 | /// Creates a new usage tracker. 15 | /// - Parameters: 16 | /// - requests: Initial number of requests. 17 | /// - inputTokens: Initial number of input tokens. 18 | /// - outputTokens: Initial number of output tokens. 19 | /// - totalTokens: Initial total tokens. 
20 | public init( 21 | requests: Int = 0, 22 | inputTokens: Int = 0, 23 | outputTokens: Int = 0, 24 | totalTokens: Int = 0 25 | ) { 26 | self.requests = requests 27 | self.inputTokens = inputTokens 28 | self.outputTokens = outputTokens 29 | self.totalTokens = totalTokens 30 | } 31 | 32 | /// Adds usage details from a model response. 33 | /// - Parameter usage: Usage information returned by the model response. 34 | public mutating func add(_ usage: ModelResponse.Usage?) { 35 | guard let usage = usage else { return } 36 | requests += 1 37 | inputTokens += usage.promptTokens 38 | outputTokens += usage.completionTokens 39 | totalTokens += usage.totalTokens 40 | } 41 | 42 | /// Merges another usage tracker into this one. 43 | /// - Parameter other: The usage to merge. 44 | public mutating func merge(_ other: Usage) { 45 | requests += other.requests 46 | inputTokens += other.inputTokens 47 | outputTokens += other.outputTokens 48 | totalTokens += other.totalTokens 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /CLAUDE.md: -------------------------------------------------------------------------------- 1 | # AgentSDK-Swift Guidelines 2 | 3 | ## Build & Test Commands 4 | ```bash 5 | # Build the project 6 | swift build 7 | 8 | # Run all tests 9 | swift test 10 | 11 | # Run a specific test 12 | swift test --filter AgentSDK_SwiftTests/testSpecificFunction 13 | 14 | # Run a group of tests 15 | swift test --filter AgentSDK_SwiftTests 16 | 17 | # Run example app (requires API key) 18 | export OPENAI_API_KEY=your_api_key_here 19 | swift run SimpleApp 20 | ``` 21 | 22 | ## Test Guidelines 23 | - Use `@Test` annotation to mark test functions 24 | - Use `#expect` assertions for validation 25 | - Follow AAA pattern (Arrange, Act, Assert) 26 | - Test each class & function independently 27 | - Use descriptive test names starting with "test" 28 | - Group tests with `// MARK: - Category` comments 29 | - Create separate test files for complex classes 30 | - Test public interfaces rather than implementation details 31 | - Use `@testable import` to access internal members 32 | - Handle platform-specific imports in test files 33 | 34 | ## Code Style Guidelines 35 | 36 | ### Formatting & Structure 37 | - Use 4-space indentation 38 | - PascalCase for types (classes, structs, enums) 39 | - camelCase for functions, variables, parameters 40 | - Explicit visibility modifiers (public, internal, private) 41 | - Triple-slash (///) for documentation comments 42 | 43 | ### Types & Error Handling 44 | - Protocol-oriented design with clear interfaces 45 | - Use generics for type safety 46 | - Structured error types with nested enums 47 | - Consistent do-catch blocks and error propagation 48 | - Include descriptive error messages 49 | 50 | ### Swift Practices 51 | - Minimal imports (Foundation first) 52 | - Use URLSession for networking (not AsyncHTTPClient) 53 | - Handle platform differences with conditional imports (`#if canImport(FoundationNetworking)`) 54 | - Async/await for asynchronous operations 55 | - Immutable data structures where possible 56 | - Dependency injection via initializers 57 | - Comprehensive unit tests for core functionality -------------------------------------------------------------------------------- /Examples/SimpleApp.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | import AgentSDK_Swift 3 | 4 | /// Simple demo app showing basic usage of AgentSDK-Swift 5 | @main 6 | struct SimpleApp { 7 
| /// Main entry point 8 | static func main() async throws { 9 | guard let apiKey = ProcessInfo.processInfo.environment["OPENAI_API_KEY"] else { 10 | print("Error: OPENAI_API_KEY environment variable not set") 11 | print("Please set the OPENAI_API_KEY environment variable to your OpenAI API key") 12 | exit(1) 13 | } 14 | 15 | print("🤖 AgentSDK-Swift Simple Demo") 16 | print("==============================") 17 | 18 | try await runSimpleAgent(apiKey: apiKey) 19 | } 20 | 21 | /// Runs a simple agent example 22 | /// - Parameter apiKey: OpenAI API key 23 | static func runSimpleAgent(apiKey: String) async throws { 24 | // Register models 25 | await ModelProvider.shared.registerOpenAIModels(apiKey: apiKey) 26 | 27 | // Create a tool that calculates a sum 28 | let calculateTool = Tool( 29 | name: "calculateSum", 30 | description: "Calculate the sum of two numbers", 31 | parameters: [ 32 | Tool.Parameter( 33 | name: "a", 34 | description: "First number", 35 | type: .number 36 | ), 37 | Tool.Parameter( 38 | name: "b", 39 | description: "Second number", 40 | type: .number 41 | ) 42 | ], 43 | execute: { parameters, _ in 44 | guard let a = parameters["a"] as? Double, 45 | let b = parameters["b"] as? Double else { 46 | return "Invalid numbers provided" 47 | } 48 | 49 | let sum = a + b 50 | return "The sum of \(a) and \(b) is \(sum)" 51 | } 52 | ) 53 | 54 | // Create agent with the calculation tool 55 | let agent = Agent( 56 | name: "CalculatorAssistant", 57 | instructions: """ 58 | You are a helpful assistant that can perform math calculations. 59 | When asked about calculations, use the calculateSum tool to add numbers together. 60 | """ 61 | ).addTool(calculateTool) 62 | 63 | // Input with streaming 64 | print("\nSending query: What is 42 + 17?") 65 | 66 | _ = try await AgentRunner.runStreamed( 67 | agent: agent, 68 | input: "What is 42 + 17?", 69 | context: () 70 | ) { content in 71 | print(content, terminator: "") 72 | } 73 | 74 | print("\n\nDemo complete! 👋\n") 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /Sources/AgentSDK-Swift/Examples/HelloWorld.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | 3 | /// Example showing basic agent usage with a tool 4 | public struct HelloWorldExample { 5 | /// Runs the hello world example 6 | /// - Parameter apiKey: The OpenAI API key 7 | public static func run(apiKey: String) async throws { 8 | // Register OpenAI models 9 | await ModelProvider.shared.registerOpenAIModels(apiKey: apiKey) 10 | 11 | // Create a tool that returns the current time 12 | let currentTimeTool = Tool( 13 | name: "getCurrentTime", 14 | description: "Get the current time", 15 | parameters: [], 16 | availability: .always, 17 | execute: { _, _ in 18 | let formatter = DateFormatter() 19 | formatter.timeStyle = .medium 20 | formatter.dateStyle = .medium 21 | return formatter.string(from: Date()) 22 | } 23 | ) 24 | 25 | // Create an agent with the current time tool 26 | let agent = Agent( 27 | name: "TimeAssistant", 28 | instructions: """ 29 | You are a helpful assistant that can tell users the current time. 30 | When asked about the time, use the getCurrentTime tool to provide an accurate response. 
31 | """ 32 | ).addTool(currentTimeTool) 33 | 34 | // Run the agent 35 | print("Running agent...") 36 | 37 | let result = try await AgentRunner.run( 38 | agent: agent, 39 | input: "What time is it right now?", 40 | context: () 41 | ) 42 | 43 | // Print the result 44 | print("Agent response:") 45 | print(result.finalOutput) 46 | } 47 | 48 | /// Runs the hello world example with streaming 49 | /// - Parameter apiKey: The OpenAI API key 50 | public static func runStreamed(apiKey: String) async throws { 51 | // Register OpenAI models 52 | await ModelProvider.shared.registerOpenAIModels(apiKey: apiKey) 53 | 54 | // Create a tool that returns the current time 55 | let currentTimeTool = Tool( 56 | name: "getCurrentTime", 57 | description: "Get the current time", 58 | parameters: [], 59 | execute: { _, _ in 60 | let formatter = DateFormatter() 61 | formatter.timeStyle = .medium 62 | formatter.dateStyle = .medium 63 | return formatter.string(from: Date()) 64 | } 65 | ) 66 | 67 | // Create an agent with the current time tool 68 | let agent = Agent( 69 | name: "TimeAssistant", 70 | instructions: """ 71 | You are a helpful assistant that can tell users the current time. 72 | When asked about the time, use the getCurrentTime tool to provide an accurate response. 73 | """ 74 | ).addTool(currentTimeTool) 75 | 76 | // Run the agent with streaming 77 | print("Running agent with streaming...") 78 | 79 | _ = try await AgentRunner.runStreamed( 80 | agent: agent, 81 | input: "What time is it right now?", 82 | context: () 83 | ) { content in 84 | // Print each content chunk as it arrives 85 | print(content, terminator: "") 86 | } 87 | 88 | // Print completion message 89 | print("\nAgent response complete.") 90 | } 91 | } 92 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AgentSDK-Swift 2 | 3 | A Swift implementation of the OpenAI Agents SDK, allowing you to build AI agent applications with tools, guardrails, and multi-agent workflows. 4 | 5 | ## Status 6 | 7 | 🚧 **Early Development** - This project is a Swift port of the [OpenAI Agents Python SDK](https://github.com/openai/openai-agents-py) and is currently in early development. 
8 | 9 | ## Features 10 | 11 | - 🤖 Create AI agents with custom instructions and tools 12 | - 🛠️ Define and use tools as Swift functions 13 | - 🔄 Hand off between multiple agents for complex workflows 14 | - 🛡️ Apply guardrails to ensure safe and high-quality outputs 15 | - 📊 Stream responses in real-time 16 | - 📝 Support for OpenAI's latest models 17 | 18 | ## Requirements 19 | 20 | - Swift 6.0+ 21 | - macOS 13.0+ / iOS 16.0+ 22 | - OpenAI API Key 23 | 24 | ## Installation 25 | 26 | ### Swift Package Manager 27 | 28 | Add the following to your `Package.swift` file: 29 | 30 | ```swift 31 | dependencies: [ 32 | .package(url: "https://github.com/fumito-ito/AgentSDK-Swift.git", from: "0.1.0") 33 | ] 34 | ``` 35 | 36 | ## Quick Start 37 | 38 | Here's a simple example to get you started: 39 | 40 | ```swift 41 | import AgentSDK_Swift 42 | 43 | // Register OpenAI models 44 | ModelProvider.shared.registerOpenAIModels(apiKey: "your-api-key-here") 45 | 46 | // Create a tool 47 | let weatherTool = Tool( 48 | name: "getWeather", 49 | description: "Get the current weather for a location", 50 | parameters: [ 51 | Tool.Parameter( 52 | name: "location", 53 | description: "The location to get weather for", 54 | type: .string 55 | ) 56 | ], 57 | execute: { params, _ in 58 | let location = params["location"] as? String ?? "Unknown" 59 | return "It's sunny and 72°F in \(location)" 60 | } 61 | ) 62 | 63 | // Create an agent 64 | let agent = Agent( 65 | name: "WeatherAssistant", 66 | instructions: "You are a helpful weather assistant." 67 | ).addTool(weatherTool) 68 | 69 | // Run the agent 70 | Task { 71 | do { 72 | let result = try await AgentRunner.run( 73 | agent: agent, 74 | input: "What's the weather like in San Francisco?", 75 | context: () 76 | ) 77 | 78 | print(result.finalOutput) 79 | } catch { 80 | print("Error: \(error)") 81 | } 82 | } 83 | ``` 84 | 85 | ## Advanced Usage 86 | 87 | ### Using Guardrails 88 | 89 | ```swift 90 | // Create a length guardrail 91 | let lengthGuardrail = InputLengthGuardrail(maxLength: 500) 92 | 93 | // Create an agent with the guardrail 94 | let agent = Agent( 95 | name: "AssistantWithGuardrails", 96 | instructions: "You are a helpful assistant." 
97 | ).addInputGuardrail(lengthGuardrail) 98 | ``` 99 | 100 | ### Multi-Agent Handoffs 101 | 102 | ```swift 103 | // Create specialized agents 104 | let weatherAgent = Agent(name: "WeatherAgent", instructions: "...") 105 | .addTool(weatherTool) 106 | 107 | let travelAgent = Agent(name: "TravelAgent", instructions: "...") 108 | .addTool(travelTool) 109 | 110 | // Create main agent with handoff to weather agent 111 | let mainAgent = Agent( 112 | name: "MainAgent", 113 | instructions: "You are a helpful assistant.", 114 | handoffs: [ 115 | Handoff.withKeywords( 116 | agent: weatherAgent, 117 | keywords: ["weather", "temperature", "forecast"] 118 | ), 119 | Handoff.withKeywords( 120 | agent: travelAgent, 121 | keywords: ["travel", "flight", "hotel", "booking"] 122 | ) 123 | ] 124 | ) 125 | ``` 126 | 127 | ### Streaming Responses 128 | 129 | ```swift 130 | // Run with streaming 131 | let result = try await AgentRunner.runStreamed( 132 | agent: agent, 133 | input: "Tell me about the weather in London", 134 | context: () 135 | ) { content in 136 | // Process each chunk as it arrives 137 | print(content, terminator: "") 138 | } 139 | ``` 140 | 141 | ## Running the Example 142 | 143 | The project includes a simple example application that demonstrates how to use the SDK: 144 | 145 | ```bash 146 | # Set your OpenAI API key 147 | export OPENAI_API_KEY=your_api_key_here 148 | 149 | # Run the example 150 | swift run SimpleApp 151 | ``` 152 | 153 | ## Documentation 154 | 155 | Documentation is currently in development. For now, please refer to the source code and examples for usage guidance. 156 | 157 | ## Contributing 158 | 159 | Contributions are welcome! Feel free to open issues or submit pull requests. 160 | 161 | ## License 162 | 163 | This project is licensed under the MIT License - see the LICENSE file for details. 
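
## Putting It Together

The snippet below is a minimal sketch that combines a custom handoff filter, an output guardrail, and the usage totals gathered during a run. It follows the same conventions as the Quick Start example; the agent names, regex pattern, and prompt are illustrative, and it assumes `AgentRunner.run` returns the `Run.Result` value (with `finalOutput` and `usage`) that `Run.swift` builds.

```swift
import AgentSDK_Swift

// Placeholder API key; register models before running any agent.
await ModelProvider.shared.registerOpenAIModels(apiKey: "your-api-key-here")

// Specialist agent that itinerary questions are routed to.
let travelAgent = Agent(
    name: "TravelAgent",
    instructions: "You help users plan trips."
)

// Route to the specialist with a custom filter instead of the keyword helper.
let travelHandoff = Handoff.withCustomFilter(agent: travelAgent) { input, _ in
    input.lowercased().contains("itinerary")
}

// Block outputs that appear to contain an email address (pattern is illustrative).
let emailGuardrail = try RegexContentGuardrail(
    pattern: "[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\\.[A-Za-z]{2,}",
    blockMatches: true
)

let assistant = Agent(
    name: "GuardedAssistant",
    instructions: "You are a helpful assistant.",
    handoffs: [travelHandoff]
).addOutputGuardrail(emailGuardrail)

let result = try await AgentRunner.run(
    agent: assistant,
    input: "Plan a two-day itinerary for Kyoto",
    context: ()
)

print(result.finalOutput)
// Token usage is aggregated on the run context and surfaced on the result.
print("Total tokens used: \(result.usage.totalTokens)")
```

Because `RegexContentGuardrail` and `InputLengthGuardrail` both declare a `Void` context in `Guardrail.swift`, they can be attached to any agent that is run with `context: ()`.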
164 | -------------------------------------------------------------------------------- /Sources/AgentSDK-Swift/Handoff.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | 3 | /// Represents a handoff from one agent to another 4 | public struct Handoff { 5 | /// The agent to hand off to 6 | public let agent: Agent 7 | 8 | /// The filter to determine whether to hand off 9 | public let filter: any HandoffFilter 10 | 11 | /// Creates a new handoff 12 | /// - Parameters: 13 | /// - agent: The agent to hand off to 14 | /// - filter: The filter to determine whether to hand off 15 | public init(agent: Agent, filter: any HandoffFilter) { 16 | self.agent = agent 17 | self.filter = filter 18 | } 19 | 20 | /// Creates a new handoff with a keyword filter 21 | /// - Parameters: 22 | /// - agent: The agent to hand off to 23 | /// - keywords: The keywords to trigger the handoff 24 | /// - caseSensitive: Whether the keyword matching is case sensitive 25 | /// - Returns: A new handoff 26 | public static func withKeywords( 27 | agent: Agent, 28 | keywords: [String], 29 | caseSensitive: Bool = false 30 | ) -> Handoff { 31 | let filter = KeywordHandoffFilter( 32 | keywords: keywords, 33 | caseSensitive: caseSensitive 34 | ) 35 | 36 | return Handoff(agent: agent, filter: filter) 37 | } 38 | 39 | /// Creates a new handoff with a custom filter function 40 | /// - Parameters: 41 | /// - agent: The agent to hand off to 42 | /// - filterFunction: The function to determine whether to hand off 43 | /// - Returns: A new handoff 44 | public static func withCustomFilter( 45 | agent: Agent, 46 | filterFunction: @escaping (String, Context) -> Bool 47 | ) -> Handoff { 48 | let filter = CustomHandoffFilter(filterFunction: filterFunction) 49 | return Handoff(agent: agent, filter: filter) 50 | } 51 | } 52 | 53 | /// Protocol for determining whether to hand off to another agent 54 | public protocol HandoffFilter { 55 | /// The context type for the filter 56 | associatedtype Context 57 | 58 | /// Determines whether to hand off to another agent 59 | /// - Parameters: 60 | /// - input: The input to check 61 | /// - context: The context for the check 62 | /// - Returns: True if the input should trigger a handoff, false otherwise 63 | func shouldHandoff(input: String, context: Context) -> Bool 64 | } 65 | 66 | /// A handoff filter that triggers on keywords 67 | public struct KeywordHandoffFilter: HandoffFilter { 68 | /// The keywords to trigger the handoff 69 | private let keywords: [String] 70 | 71 | /// Whether the keyword matching is case sensitive 72 | private let caseSensitive: Bool 73 | 74 | /// Creates a new keyword handoff filter 75 | /// - Parameters: 76 | /// - keywords: The keywords to trigger the handoff 77 | /// - caseSensitive: Whether the keyword matching is case sensitive 78 | public init(keywords: [String], caseSensitive: Bool = false) { 79 | self.keywords = keywords 80 | self.caseSensitive = caseSensitive 81 | } 82 | 83 | /// Determines whether to hand off to another agent based on keywords 84 | /// - Parameters: 85 | /// - input: The input to check 86 | /// - context: The context for the check 87 | /// - Returns: True if the input contains any of the keywords, false otherwise 88 | public func shouldHandoff(input: String, context: Context) -> Bool { 89 | let searchInput = caseSensitive ? input : input.lowercased() 90 | 91 | for keyword in keywords { 92 | let searchKeyword = caseSensitive ? 
keyword : keyword.lowercased() 93 | if searchInput.contains(searchKeyword) { 94 | return true 95 | } 96 | } 97 | 98 | return false 99 | } 100 | } 101 | 102 | /// A handoff filter that uses a custom function 103 | public struct CustomHandoffFilter: HandoffFilter { 104 | /// The function to determine whether to hand off 105 | private let filterFunction: (String, Context) -> Bool 106 | 107 | /// Creates a new custom handoff filter 108 | /// - Parameter filterFunction: The function to determine whether to hand off 109 | public init(filterFunction: @escaping (String, Context) -> Bool) { 110 | self.filterFunction = filterFunction 111 | } 112 | 113 | /// Determines whether to hand off to another agent using the custom function 114 | /// - Parameters: 115 | /// - input: The input to check 116 | /// - context: The context for the check 117 | /// - Returns: The result of the filter function 118 | public func shouldHandoff(input: String, context: Context) -> Bool { 119 | filterFunction(input, context) 120 | } 121 | } 122 | -------------------------------------------------------------------------------- /Sources/AgentSDK-Swift/Guardrail.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | 3 | /// Marker protocol for guardrail errors. 4 | public enum GuardrailError: Error, Sendable { 5 | case invalidInput(reason: String) 6 | case invalidOutput(reason: String) 7 | } 8 | 9 | /// Protocol for enforcing constraints on agent input. 10 | public protocol InputGuardrail: Sendable { 11 | associatedtype Context 12 | func validate(_ input: String, context: Context) throws -> String 13 | } 14 | 15 | /// Protocol for enforcing constraints on agent output. 16 | public protocol OutputGuardrail: Sendable { 17 | associatedtype Context 18 | func validate(_ output: String, context: Context) throws -> String 19 | } 20 | 21 | /// Type-erased wrapper for input guardrails. 22 | public struct AnyInputGuardrail: Sendable { 23 | private let validator: @Sendable (String, Context) throws -> String 24 | 25 | /// Creates a type-erased wrapper around a strongly typed guardrail. 26 | /// - Parameter guardrail: The concrete guardrail to wrap. 27 | public init(_ guardrail: G) where G.Context == Context { 28 | validator = guardrail.validate 29 | } 30 | 31 | /// Creates a type-erased wrapper from a validation closure. 32 | /// - Parameter validator: Closure that performs validation for the provided context. 33 | public init(_ validator: @Sendable @escaping (String, Context) throws -> String) { 34 | self.validator = validator 35 | } 36 | 37 | /// Validates input text using the wrapped guardrail logic. 38 | /// - Parameters: 39 | /// - input: The input text to validate. 40 | /// - context: The associated context forwarded to the guardrail. 41 | /// - Returns: The validated (and potentially transformed) input string. 42 | /// - Throws: `GuardrailError` if the guardrail fails validation. 43 | public func validate(_ input: String, context: Context) throws -> String { 44 | try validator(input, context) 45 | } 46 | } 47 | 48 | /// Type-erased wrapper for output guardrails. 49 | public struct AnyOutputGuardrail: Sendable { 50 | private let validator: @Sendable (String, Context) throws -> String 51 | 52 | /// Creates a type-erased wrapper around a strongly typed guardrail. 53 | /// - Parameter guardrail: The concrete guardrail to wrap. 
54 | public init(_ guardrail: G) where G.Context == Context { 55 | validator = guardrail.validate 56 | } 57 | 58 | /// Creates a type-erased wrapper from a validation closure. 59 | /// - Parameter validator: Closure that performs validation for the provided context. 60 | public init(_ validator: @Sendable @escaping (String, Context) throws -> String) { 61 | self.validator = validator 62 | } 63 | 64 | /// Validates output text using the wrapped guardrail logic. 65 | /// - Parameters: 66 | /// - output: The output text to validate. 67 | /// - context: The associated context forwarded to the guardrail. 68 | /// - Returns: The validated (and potentially transformed) output string. 69 | /// - Throws: `GuardrailError` if the guardrail fails validation. 70 | public func validate(_ output: String, context: Context) throws -> String { 71 | try validator(output, context) 72 | } 73 | } 74 | 75 | /// A guardrail that enforces constraints on input length. 76 | public struct InputLengthGuardrail: InputGuardrail { 77 | public typealias Context = Void 78 | 79 | private let maxLength: Int 80 | 81 | /// Creates a guardrail that enforces a maximum input length. 82 | /// - Parameter maxLength: The maximum number of characters permitted. 83 | public init(maxLength: Int) { 84 | self.maxLength = maxLength 85 | } 86 | 87 | /// Validates that the provided input does not exceed the configured maximum length. 88 | /// - Parameters: 89 | /// - input: The input text to validate. 90 | /// - context: The context supplied during validation (unused by default). 91 | /// - Returns: The original input if validation succeeds. 92 | /// - Throws: `GuardrailError.invalidInput` when the input is too long. 93 | public func validate(_ input: String, context: Context) throws -> String { 94 | if input.count > maxLength { 95 | throw GuardrailError.invalidInput( 96 | reason: "Input is too long. Maximum length is \(maxLength) characters." 97 | ) 98 | } 99 | return input 100 | } 101 | } 102 | 103 | /// A guardrail that enforces constraints on output content using a regular expression. 104 | public struct RegexContentGuardrail: OutputGuardrail { 105 | public typealias Context = Void 106 | 107 | private let regex: NSRegularExpression 108 | private let blockMatches: Bool 109 | 110 | /// Creates a guardrail that evaluates outputs against a regular expression. 111 | /// - Parameters: 112 | /// - pattern: The regex pattern used for validation. 113 | /// - blockMatches: When `true`, outputs matching the pattern are blocked; when `false`, 114 | /// outputs must contain a match. 115 | public init(pattern: String, blockMatches: Bool = true) throws { 116 | self.regex = try NSRegularExpression(pattern: pattern, options: []) 117 | self.blockMatches = blockMatches 118 | } 119 | 120 | /// Validates that the output satisfies the regex constraint. 121 | /// - Parameters: 122 | /// - output: The output text to validate. 123 | /// - context: Ignored placeholder context for protocol conformance. 124 | /// - Returns: The original output if validation succeeds. 125 | /// - Throws: `GuardrailError.invalidOutput` when the regex constraint is violated. 
126 | public func validate(_ output: String, context: Context) throws -> String { 127 | let range = NSRange(location: 0, length: output.utf16.count) 128 | let matches = regex.matches(in: output, options: [], range: range) 129 | 130 | if blockMatches && !matches.isEmpty { 131 | throw GuardrailError.invalidOutput(reason: "Output contains blocked content.") 132 | } else if !blockMatches && matches.isEmpty { 133 | throw GuardrailError.invalidOutput(reason: "Output does not contain required content.") 134 | } 135 | 136 | return output 137 | } 138 | } 139 | -------------------------------------------------------------------------------- /Sources/AgentSDK-Swift/Models/ModelInterface.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | 3 | /// Protocol defining the interface for language models 4 | public protocol ModelInterface: Sendable { 5 | /// Gets a response from the model 6 | /// - Parameters: 7 | /// - messages: The messages to send to the model 8 | /// - settings: The settings to use for the model call 9 | /// - Returns: The model response 10 | func getResponse(messages: [Message], settings: ModelSettings) async throws -> ModelResponse 11 | 12 | /// Gets a streamed response from the model 13 | /// - Parameters: 14 | /// - messages: The messages to send to the model 15 | /// - settings: The settings to use for the model call 16 | /// - callback: The callback to call for each streamed chunk 17 | func getStreamedResponse( 18 | messages: [Message], 19 | settings: ModelSettings, 20 | callback: @escaping (ModelStreamEvent) async -> Void 21 | ) async throws -> ModelResponse 22 | } 23 | 24 | /// Represents a message for a model 25 | public struct Message { 26 | /// The role of the message sender 27 | public let role: Role 28 | 29 | /// The content of the message 30 | public let content: MessageContent 31 | 32 | /// Creates a new message 33 | /// - Parameters: 34 | /// - role: The role of the message sender 35 | /// - content: The content of the message 36 | public init(role: Role, content: MessageContent) { 37 | self.role = role 38 | self.content = content 39 | } 40 | 41 | /// Creates a new user message with text content 42 | /// - Parameter text: The text content 43 | /// - Returns: A new user message 44 | public static func user(_ text: String) -> Message { 45 | Message(role: .user, content: .text(text)) 46 | } 47 | 48 | /// Creates a new assistant message with text content 49 | /// - Parameter text: The text content 50 | /// - Returns: A new assistant message 51 | public static func assistant(_ text: String) -> Message { 52 | Message(role: .assistant, content: .text(text)) 53 | } 54 | 55 | /// Creates a new system message with text content 56 | /// - Parameter text: The text content 57 | /// - Returns: A new system message 58 | public static func system(_ text: String) -> Message { 59 | Message(role: .system, content: .text(text)) 60 | } 61 | 62 | /// Represents the role of a message sender 63 | public enum Role: String { 64 | case system 65 | case user 66 | case assistant 67 | case tool 68 | } 69 | } 70 | 71 | /// Represents the content of a message 72 | public enum MessageContent { 73 | case text(String) 74 | case toolResults(ToolResult) 75 | 76 | /// Represents the result of a tool call 77 | public struct ToolResult { 78 | /// The ID of the tool call 79 | public let toolCallId: String 80 | 81 | /// The result of the tool call 82 | public let result: String 83 | 84 | /// Creates a new tool result 85 | /// - Parameters: 86 | /// - 
toolCallId: The ID of the tool call 87 | /// - result: The result of the tool call 88 | public init(toolCallId: String, result: String) { 89 | self.toolCallId = toolCallId 90 | self.result = result 91 | } 92 | } 93 | } 94 | 95 | /// Represents a response from a model 96 | public struct ModelResponse { 97 | /// The generated text content 98 | public let content: String 99 | 100 | /// The tool calls made by the model 101 | public let toolCalls: [ToolCall] 102 | 103 | /// Whether the response was flagged for moderation 104 | public let flagged: Bool 105 | 106 | /// The reason the response was flagged, if applicable 107 | public let flaggedReason: String? 108 | 109 | /// Usage statistics for the model call 110 | public let usage: Usage? 111 | 112 | /// Creates a new model response 113 | /// - Parameters: 114 | /// - content: The generated text content 115 | /// - toolCalls: The tool calls made by the model 116 | /// - flagged: Whether the response was flagged for moderation 117 | /// - flaggedReason: The reason the response was flagged, if applicable 118 | /// - usage: Usage statistics for the model call 119 | public init( 120 | content: String, 121 | toolCalls: [ToolCall] = [], 122 | flagged: Bool = false, 123 | flaggedReason: String? = nil, 124 | usage: Usage? = nil 125 | ) { 126 | self.content = content 127 | self.toolCalls = toolCalls 128 | self.flagged = flagged 129 | self.flaggedReason = flaggedReason 130 | self.usage = usage 131 | } 132 | 133 | /// Represents a tool call made by the model 134 | public struct ToolCall { 135 | /// The ID of the tool call 136 | public let id: String 137 | 138 | /// The name of the tool being called 139 | public let name: String 140 | 141 | /// The parameters for the tool call 142 | public let parameters: [String: Any] 143 | 144 | /// Creates a new tool call 145 | /// - Parameters: 146 | /// - id: The ID of the tool call 147 | /// - name: The name of the tool being called 148 | /// - parameters: The parameters for the tool call 149 | public init(id: String, name: String, parameters: [String: Any]) { 150 | self.id = id 151 | self.name = name 152 | self.parameters = parameters 153 | } 154 | } 155 | 156 | /// Represents usage statistics for a model call 157 | public struct Usage { 158 | /// The number of prompt tokens used 159 | public let promptTokens: Int 160 | 161 | /// The number of completion tokens used 162 | public let completionTokens: Int 163 | 164 | /// The total number of tokens used 165 | public let totalTokens: Int 166 | 167 | /// Creates a new usage statistics object 168 | /// - Parameters: 169 | /// - promptTokens: The number of prompt tokens used 170 | /// - completionTokens: The number of completion tokens used 171 | /// - totalTokens: The total number of tokens used 172 | public init(promptTokens: Int, completionTokens: Int, totalTokens: Int) { 173 | self.promptTokens = promptTokens 174 | self.completionTokens = completionTokens 175 | self.totalTokens = totalTokens 176 | } 177 | } 178 | } 179 | 180 | /// Represents an event from a streamed model response 181 | public enum ModelStreamEvent { 182 | /// A content chunk was received 183 | case content(String) 184 | 185 | /// A tool call was received 186 | case toolCall(ModelResponse.ToolCall) 187 | 188 | /// The stream has ended 189 | case end 190 | } 191 | 192 | /// Factory for creating model instances by name 193 | public actor ModelProvider { 194 | /// The shared instance of the model provider 195 | public static let shared = ModelProvider() 196 | 197 | /// Dictionary mapping model names to 
factory functions 198 | private var modelFactories: [String: () -> ModelInterface] = [:] 199 | 200 | private init() {} 201 | 202 | /// Registers a model factory with the provider 203 | /// - Parameters: 204 | /// - modelName: The name of the model 205 | /// - factory: The factory function for creating the model 206 | public func register(modelName: String, factory: @escaping () -> ModelInterface) { 207 | modelFactories[modelName] = factory 208 | } 209 | 210 | /// Gets a model by name 211 | /// - Parameter modelName: The name of the model 212 | /// - Returns: The model instance 213 | /// - Throws: An error if the model is not registered 214 | public func getModel(modelName: String) throws -> ModelInterface { 215 | guard let factory = modelFactories[modelName] else { 216 | throw ModelProviderError.modelNotFound(modelName: modelName) 217 | } 218 | 219 | return factory() 220 | } 221 | 222 | /// Errors that can occur when using the model provider 223 | public enum ModelProviderError: Error { 224 | /// The requested model was not found 225 | case modelNotFound(modelName: String) 226 | } 227 | } 228 | -------------------------------------------------------------------------------- /Sources/AgentSDK-Swift/Tool.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | 3 | /// Represents a tool that can be used by an agent to perform actions 4 | public struct Tool: Sendable { 5 | /// Controls the availability of the tool for a given run context. 6 | public enum Availability: Sendable { 7 | case always 8 | case disabled 9 | case whenEnabled(@Sendable (_ context: RunContext) async -> Bool) 10 | 11 | func resolve(for context: RunContext) async -> Bool { 12 | switch self { 13 | case .always: 14 | return true 15 | case .disabled: 16 | return false 17 | case .whenEnabled(let closure): 18 | return await closure(context) 19 | } 20 | } 21 | } 22 | 23 | /// The name of the tool 24 | public let name: String 25 | 26 | /// A description of what the tool does 27 | public let description: String 28 | 29 | /// The parameters required by the tool 30 | public let parameters: [Parameter] 31 | 32 | /// Availability strategy for the tool 33 | public let availability: Availability 34 | 35 | /// The function to execute when the tool is called 36 | private let executeClosure: @Sendable (ToolParameters, RunContext) async throws -> Any 37 | 38 | /// Creates a new tool 39 | /// - Parameters: 40 | /// - name: The name of the tool 41 | /// - description: A description of what the tool does 42 | /// - parameters: The parameters required by the tool 43 | /// - availability: Availability strategy for the tool 44 | /// - execute: The function to execute when the tool is called that receives the run context 45 | public init( 46 | name: String, 47 | description: String, 48 | parameters: [Parameter] = [], 49 | availability: Availability = .always, 50 | executeClosure: @Sendable @escaping (ToolParameters, RunContext) async throws -> Any 51 | ) { 52 | self.name = name 53 | self.description = description 54 | self.parameters = parameters 55 | self.availability = availability 56 | self.executeClosure = executeClosure 57 | } 58 | 59 | /// Convenience initializer mirroring the previous signature that only exposed context value. 
60 | /// - Parameters: 61 | /// - name: The name of the tool 62 | /// - description: A description of what the tool does 63 | /// - parameters: The parameters required by the tool 64 | /// - availability: Availability strategy for the tool 65 | /// - execute: The function to execute when the tool is called 66 | public init( 67 | name: String, 68 | description: String, 69 | parameters: [Parameter] = [], 70 | availability: Availability = .always, 71 | execute: @Sendable @escaping (ToolParameters, Context) async throws -> Any 72 | ) { 73 | self.init( 74 | name: name, 75 | description: description, 76 | parameters: parameters, 77 | availability: availability 78 | ) { params, runContext in 79 | try await execute(params, runContext.value) 80 | } 81 | } 82 | 83 | /// Executes the tool with the provided parameters and context 84 | /// - Parameters: 85 | /// - parameters: The parameters for the tool execution 86 | /// - runContext: The run context for the tool execution 87 | /// - Returns: The result of the tool execution 88 | public func invoke( 89 | parameters: ToolParameters, 90 | runContext: RunContext 91 | ) async throws -> Any { 92 | try await executeClosure(parameters, runContext) 93 | } 94 | 95 | /// Executes the tool with the provided parameters and context (backwards compatibility helper) 96 | /// - Parameters: 97 | /// - parameters: The parameters for the tool execution 98 | /// - context: The context for the tool execution 99 | /// - Returns: The result of the tool execution 100 | @available(*, deprecated, message: "Use invoke(parameters:runContext:) instead to access usage.") 101 | public func callAsFunction(_ parameters: ToolParameters, context: Context) async throws -> Any { 102 | let runContext = RunContext(value: context) 103 | return try await invoke(parameters: parameters, runContext: runContext) 104 | } 105 | 106 | /// Determines whether the tool is enabled for the provided run context. 107 | /// - Parameter runContext: The run context to check. 108 | /// - Returns: True if the tool may be invoked. 
109 | public func isEnabled(for runContext: RunContext) async -> Bool { 110 | await availability.resolve(for: runContext) 111 | } 112 | 113 | /// Represents a parameter for a tool 114 | public struct Parameter: Sendable { 115 | /// The name of the parameter 116 | public let name: String 117 | 118 | /// A description of the parameter 119 | public let description: String 120 | 121 | /// The type of the parameter 122 | public let type: ParameterType 123 | 124 | /// Whether the parameter is required 125 | public let required: Bool 126 | 127 | /// Creates a new parameter 128 | /// - Parameters: 129 | /// - name: The name of the parameter 130 | /// - description: A description of the parameter 131 | /// - type: The type of the parameter 132 | /// - required: Whether the parameter is required 133 | public init(name: String, description: String, type: ParameterType, required: Bool = true) { 134 | self.name = name 135 | self.description = description 136 | self.type = type 137 | self.required = required 138 | } 139 | } 140 | 141 | /// Represents the type of a parameter 142 | public enum ParameterType: Sendable { 143 | case string 144 | case number 145 | case boolean 146 | case array 147 | case object 148 | 149 | /// Returns the string representation of the type for OpenAI 150 | public var jsonType: String { 151 | switch self { 152 | case .string: return "string" 153 | case .number: return "number" 154 | case .boolean: return "boolean" 155 | case .array: return "array" 156 | case .object: return "object" 157 | } 158 | } 159 | } 160 | } 161 | 162 | /// Represents the parameters passed to a tool 163 | public typealias ToolParameters = [String: Any] 164 | 165 | /// Creates a function tool from a function 166 | /// - Parameters: 167 | /// - name: The name of the tool 168 | /// - description: A description of what the tool does 169 | /// - function: The function to execute when the tool is called 170 | /// - Returns: A new function tool 171 | public func functionTool( 172 | name: String, 173 | description: String, 174 | availability: Tool.Availability = .always, 175 | function: @Sendable @escaping (Input, RunContext) async throws -> Output 176 | ) -> Tool { 177 | Tool(name: name, description: description, availability: availability) { parameters, runContext in 178 | // Convert parameters dictionary to Input type 179 | let data = try JSONSerialization.data(withJSONObject: parameters) 180 | let input = try JSONDecoder().decode(Input.self, from: data) 181 | 182 | // Call the function with the decoded input and run context 183 | return try await function(input, runContext) 184 | } 185 | } 186 | 187 | /// Convenience overload preserving the previous signature using only the context value. 
188 | public func functionTool( 189 | name: String, 190 | description: String, 191 | availability: Tool.Availability = .always, 192 | function: @Sendable @escaping (Input, Context) async throws -> Output 193 | ) -> Tool { 194 | Tool(name: name, description: description, availability: availability) { parameters, runContext in 195 | // Convert parameters dictionary to Input type 196 | let data = try JSONSerialization.data(withJSONObject: parameters) 197 | let input = try JSONDecoder().decode(Input.self, from: data) 198 | 199 | // Call the function with the decoded input and context value 200 | return try await function(input, runContext.value) 201 | } 202 | } 203 | -------------------------------------------------------------------------------- /Sources/AgentSDK-Swift/Agent.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | 3 | /// Represents the instructions backing an agent. 4 | public enum AgentInstructions: Sendable { 5 | case literal(String) 6 | case dynamic(@Sendable (_ context: RunContext, _ agent: Agent) async throws -> String) 7 | } 8 | 9 | /// Represents an AI agent capable of interacting with tools, handling conversations, 10 | /// and producing outputs based on instructions. 11 | public final class Agent { 12 | /// Describes how tool usage should influence the run loop. 13 | public enum ToolUseBehavior: Sendable { 14 | case runLLMAgain 15 | case stopOnFirstTool 16 | case stopAtTools(Set) 17 | case custom(@Sendable (_ context: RunContext, _ toolResults: [ToolCallResult]) async throws -> ToolsToFinalOutputResult) 18 | } 19 | 20 | /// Wraps the outcome of processing tool calls. 21 | public struct ToolsToFinalOutputResult: Sendable { 22 | public let isFinalOutput: Bool 23 | public let finalOutput: String? 24 | 25 | public init(isFinalOutput: Bool, finalOutput: String?) { 26 | self.isFinalOutput = isFinalOutput 27 | self.finalOutput = finalOutput 28 | } 29 | } 30 | 31 | /// Captures the result of a tool invocation for decision making. 32 | public struct ToolCallResult: Sendable { 33 | public let id: String 34 | public let name: String 35 | public let output: String 36 | 37 | public init(id: String, name: String, output: String) { 38 | self.id = id 39 | self.name = name 40 | self.output = output 41 | } 42 | } 43 | 44 | /// The name of the agent, used for identification 45 | public let name: AgentName 46 | 47 | /// Optional description used when referenced from handoffs or other agents. 48 | public var handoffDescription: String? 49 | 50 | /// Instructions that guide the agent's behavior. 51 | public var instructions: AgentInstructions? 52 | 53 | /// Tools available to the agent 54 | public private(set) var tools: [Tool] 55 | 56 | /// Guardrails that enforce constraints on agent input 57 | public private(set) var inputGuardrails: [AnyInputGuardrail] 58 | 59 | /// Guardrails that enforce constraints on agent output 60 | public private(set) var outputGuardrails: [AnyOutputGuardrail] 61 | 62 | /// Handoffs for delegating work to other agents 63 | public private(set) var handoffs: [Handoff] 64 | 65 | /// Settings for the model used by this agent 66 | public var modelSettings: ModelSettings 67 | 68 | /// Determines how tool use is handled for this agent. 69 | public var toolUseBehavior: ToolUseBehavior 70 | 71 | /// Whether to reset the tool choice back to default after a call. 
72 | public var resetToolChoice: Bool 73 | 74 | /// Creates a new agent with the specified configuration 75 | /// - Parameters: 76 | /// - name: The name of the agent 77 | /// - instructions: Instructions for guiding agent behavior 78 | /// - handoffDescription: Optional description used for handoffs 79 | /// - tools: Optional array of tools available to the agent 80 | /// - inputGuardrails: Optional array of input guardrails for the agent 81 | /// - outputGuardrails: Optional array of output guardrails for the agent 82 | /// - handoffs: Optional array of handoffs for the agent 83 | /// - modelSettings: Optional model settings for the agent 84 | /// - toolUseBehavior: Strategy for handling tool calls 85 | /// - resetToolChoice: Whether tool choice should reset after invocation 86 | public init( 87 | name: AgentName, 88 | instructions: String, 89 | handoffDescription: String? = nil, 90 | tools: [Tool] = [], 91 | inputGuardrails: [AnyInputGuardrail] = [], 92 | outputGuardrails: [AnyOutputGuardrail] = [], 93 | handoffs: [Handoff] = [], 94 | modelSettings: ModelSettings = ModelSettings(), 95 | toolUseBehavior: ToolUseBehavior = .runLLMAgain, 96 | resetToolChoice: Bool = true 97 | ) { 98 | self.name = name 99 | self.instructions = .literal(instructions) 100 | self.handoffDescription = handoffDescription 101 | self.tools = tools 102 | self.inputGuardrails = inputGuardrails 103 | self.outputGuardrails = outputGuardrails 104 | self.handoffs = handoffs 105 | self.modelSettings = modelSettings 106 | self.toolUseBehavior = toolUseBehavior 107 | self.resetToolChoice = resetToolChoice 108 | } 109 | 110 | /// Alternate initializer for dynamic instructions while keeping other defaults. 111 | public init( 112 | name: AgentName, 113 | instructions: AgentInstructions?, 114 | handoffDescription: String? = nil, 115 | tools: [Tool] = [], 116 | inputGuardrails: [AnyInputGuardrail] = [], 117 | outputGuardrails: [AnyOutputGuardrail] = [], 118 | handoffs: [Handoff] = [], 119 | modelSettings: ModelSettings = ModelSettings(), 120 | toolUseBehavior: ToolUseBehavior = .runLLMAgain, 121 | resetToolChoice: Bool = true 122 | ) { 123 | self.name = name 124 | self.instructions = instructions 125 | self.handoffDescription = handoffDescription 126 | self.tools = tools 127 | self.inputGuardrails = inputGuardrails 128 | self.outputGuardrails = outputGuardrails 129 | self.handoffs = handoffs 130 | self.modelSettings = modelSettings 131 | self.toolUseBehavior = toolUseBehavior 132 | self.resetToolChoice = resetToolChoice 133 | } 134 | 135 | /// Adds a tool to the agent 136 | /// - Parameter tool: The tool to add 137 | /// - Returns: Self for method chaining 138 | @discardableResult 139 | public func addTool(_ tool: Tool) -> Self { 140 | tools.append(tool) 141 | return self 142 | } 143 | 144 | /// Adds multiple tools to the agent 145 | /// - Parameter tools: The tools to add 146 | /// - Returns: Self for method chaining 147 | @discardableResult 148 | public func addTools(_ tools: [Tool]) -> Self { 149 | self.tools.append(contentsOf: tools) 150 | return self 151 | } 152 | 153 | /// Adds an input guardrail to the agent. 154 | @discardableResult 155 | public func addInputGuardrail(_ guardrail: G) -> Self where G.Context == Context { 156 | inputGuardrails.append(AnyInputGuardrail(guardrail)) 157 | return self 158 | } 159 | 160 | /// Adds an output guardrail to the agent. 
161 | @discardableResult 162 | public func addOutputGuardrail(_ guardrail: G) -> Self where G.Context == Context { 163 | outputGuardrails.append(AnyOutputGuardrail(guardrail)) 164 | return self 165 | } 166 | 167 | /// Adds a handoff to the agent 168 | /// - Parameter handoff: The handoff to add 169 | /// - Returns: Self for method chaining 170 | @discardableResult 171 | public func addHandoff(_ handoff: Handoff) -> Self { 172 | handoffs.append(handoff) 173 | return self 174 | } 175 | 176 | /// Creates a copy of this agent 177 | /// - Returns: A new agent with the same configuration 178 | public func clone( 179 | name: AgentName? = nil, 180 | instructions: AgentInstructions? = nil, 181 | handoffDescription: String? = nil, 182 | tools: [Tool]? = nil, 183 | inputGuardrails: [AnyInputGuardrail]? = nil, 184 | outputGuardrails: [AnyOutputGuardrail]? = nil, 185 | handoffs: [Handoff]? = nil, 186 | modelSettings: ModelSettings? = nil, 187 | toolUseBehavior: ToolUseBehavior? = nil, 188 | resetToolChoice: Bool? = nil 189 | ) -> Agent { 190 | Agent( 191 | name: name ?? self.name, 192 | instructions: instructions ?? self.instructions, 193 | handoffDescription: handoffDescription ?? self.handoffDescription, 194 | tools: tools ?? self.tools, 195 | inputGuardrails: inputGuardrails ?? self.inputGuardrails, 196 | outputGuardrails: outputGuardrails ?? self.outputGuardrails, 197 | handoffs: handoffs ?? self.handoffs, 198 | modelSettings: modelSettings ?? self.modelSettings, 199 | toolUseBehavior: toolUseBehavior ?? self.toolUseBehavior, 200 | resetToolChoice: resetToolChoice ?? self.resetToolChoice 201 | ) 202 | } 203 | 204 | /// Resolves the active instructions based on the run context. 205 | /// - Parameter runContext: The context of the current run. 206 | /// - Returns: The resolved instructions, if any. 207 | public func resolveInstructions(runContext: RunContext) async throws -> String? { 208 | switch instructions { 209 | case .literal(let value): 210 | return value 211 | case .dynamic(let closure): 212 | return try await closure(runContext, self) 213 | case .none: 214 | return nil 215 | } 216 | } 217 | 218 | /// Returns the tools enabled for the provided run context. 219 | /// - Parameter runContext: The current run context. 220 | /// - Returns: Enabled tools ready for invocation. 
221 | public func enabledTools(for runContext: RunContext) async -> [Tool] { 222 | await withTaskGroup(of: (Int, Tool)?.self) { group in 223 | for (index, tool) in tools.enumerated() { 224 | group.addTask { 225 | if await tool.isEnabled(for: runContext) { 226 | return (index, tool) 227 | } 228 | return nil 229 | } 230 | } 231 | var enabled: [(Int, Tool)] = [] 232 | for await result in group { 233 | if let value = result { 234 | enabled.append(value) 235 | } 236 | } 237 | return enabled.sorted { $0.0 < $1.0 }.map { $0.1 } 238 | } 239 | } 240 | } 241 | -------------------------------------------------------------------------------- /Sources/AgentSDK-Swift/Run.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | 3 | /// Represents a single run of an agent 4 | public final class Run { 5 | /// The agent being run 6 | public let agent: Agent 7 | 8 | /// The input for the run 9 | public let input: String 10 | 11 | /// The run context wrapper containing the caller context and usage information 12 | public let runContext: RunContext 13 | 14 | /// The history of messages for the run 15 | public private(set) var messages: [Message] = [] 16 | 17 | /// The current state of the run 18 | public private(set) var state: State = .notStarted 19 | 20 | /// The model used for the run 21 | private let model: ModelInterface 22 | 23 | /// Maximum number of turns before giving up. Mirrors Python default. 24 | private let maxTurns: Int 25 | 26 | /// Creates a new run 27 | /// - Parameters: 28 | /// - agent: The agent to run 29 | /// - input: The input for the run 30 | /// - context: The context for the run 31 | /// - model: The model to use for the run 32 | /// - maxTurns: Maximum number of iterations allowed in the run loop 33 | public init( 34 | agent: Agent, 35 | input: String, 36 | context: Context, 37 | model: ModelInterface, 38 | maxTurns: Int = 10 39 | ) { 40 | self.agent = agent 41 | self.input = input 42 | self.model = model 43 | self.maxTurns = maxTurns 44 | self.runContext = RunContext(value: context) 45 | } 46 | 47 | /// Executes the run 48 | /// - Returns: The result of the run 49 | /// - Throws: RunError if there is a problem during execution 50 | public func execute() async throws -> Result { 51 | guard state == .notStarted else { 52 | throw RunError.invalidState("Run has already been started") 53 | } 54 | 55 | state = .running 56 | 57 | do { 58 | if let systemInstructions = try await agent.resolveInstructions(runContext: runContext) { 59 | messages.append(.system(systemInstructions)) 60 | } 61 | 62 | var validatedInput = input 63 | for guardrail in agent.inputGuardrails { 64 | do { 65 | validatedInput = try guardrail.validate(validatedInput, context: runContext.value) 66 | } catch let error as GuardrailError { 67 | state = .failed 68 | throw RunError.guardrailError(error) 69 | } 70 | } 71 | 72 | // Check for handoffs before running the agent 73 | for handoff in agent.handoffs { 74 | if handoff.filter.shouldHandoff(input: validatedInput, context: runContext.value) { 75 | let handoffRun = Run( 76 | agent: handoff.agent, 77 | input: validatedInput, 78 | context: runContext.value, 79 | model: model, 80 | maxTurns: maxTurns 81 | ) 82 | let result = try await handoffRun.execute() 83 | runContext.mergeUsage(from: handoffRun.runContext) 84 | messages = handoffRun.messages 85 | state = .completed 86 | return result 87 | } 88 | } 89 | 90 | messages.append(.user(validatedInput)) 91 | 92 | var currentTurn = 0 93 | while currentTurn < maxTurns { 94 | 
currentTurn += 1 95 | let enabledTools = await agent.enabledTools(for: runContext) 96 | let response = try await model.getResponse( 97 | messages: messages, 98 | settings: agent.modelSettings 99 | ) 100 | runContext.recordUsage(response.usage) 101 | messages.append(.assistant(response.content)) 102 | 103 | if response.toolCalls.isEmpty { 104 | var finalOutput = response.content 105 | for guardrail in agent.outputGuardrails { 106 | do { 107 | finalOutput = try guardrail.validate(finalOutput, context: runContext.value) 108 | } catch let error as GuardrailError { 109 | state = .failed 110 | throw RunError.guardrailError(error) 111 | } 112 | } 113 | state = .completed 114 | return Result(finalOutput: finalOutput, messages: messages, usage: runContext.usage) 115 | } 116 | 117 | let toolProcessing = try await processToolCalls( 118 | response.toolCalls, 119 | availableTools: enabledTools 120 | ) 121 | for message in toolProcessing.messageResults { 122 | messages.append(Message(role: .tool, content: .toolResults(message))) 123 | } 124 | 125 | if let finalFromTools = try await resolveToolBehavior( 126 | toolProcessing.callResults, 127 | behavior: agent.toolUseBehavior 128 | ) { 129 | var finalOutput = finalFromTools 130 | for guardrail in agent.outputGuardrails { 131 | do { 132 | finalOutput = try guardrail.validate(finalOutput, context: runContext.value) 133 | } catch let error as GuardrailError { 134 | state = .failed 135 | throw RunError.guardrailError(error) 136 | } 137 | } 138 | messages.append(.assistant(finalOutput)) 139 | state = .completed 140 | return Result(finalOutput: finalOutput, messages: messages, usage: runContext.usage) 141 | } 142 | 143 | // Continue loop with tool results appended to message history. 144 | } 145 | 146 | state = .failed 147 | throw RunError.maxTurnsExceeded(maxTurns) 148 | } catch let error as RunError { 149 | state = .failed 150 | throw error 151 | } catch { 152 | state = .failed 153 | throw RunError.executionError(error) 154 | } 155 | } 156 | 157 | private func processToolCalls( 158 | _ toolCalls: [ModelResponse.ToolCall], 159 | availableTools: [Tool] 160 | ) async throws -> ToolProcessingOutcome { 161 | let toolMap = Dictionary(uniqueKeysWithValues: availableTools.map { ($0.name, $0) }) 162 | var messageResults: [MessageContent.ToolResult] = [] 163 | var callResults: [Agent.ToolCallResult] = [] 164 | 165 | for toolCall in toolCalls { 166 | guard let tool = toolMap[toolCall.name] else { 167 | throw RunError.toolNotFound("Tool \(toolCall.name) not found") 168 | } 169 | 170 | do { 171 | let result = try await tool.invoke( 172 | parameters: toolCall.parameters, 173 | runContext: runContext 174 | ) 175 | let resultString = stringifyToolResult(result) 176 | let toolResult = MessageContent.ToolResult( 177 | toolCallId: toolCall.id, 178 | result: resultString 179 | ) 180 | messageResults.append(toolResult) 181 | callResults.append(Agent.ToolCallResult( 182 | id: toolCall.id, 183 | name: toolCall.name, 184 | output: resultString 185 | )) 186 | } catch { 187 | throw RunError.toolExecutionError(toolName: toolCall.name, error: error) 188 | } 189 | } 190 | 191 | return ToolProcessingOutcome( 192 | messageResults: messageResults, 193 | callResults: callResults 194 | ) 195 | } 196 | 197 | private func stringifyToolResult(_ result: Any) -> String { 198 | if let stringResult = result as? String { 199 | return stringResult 200 | } 201 | if let data = try? 
JSONSerialization.data(withJSONObject: result, options: [.prettyPrinted]), 202 | let jsonString = String(data: data, encoding: .utf8) { 203 | return jsonString 204 | } 205 | return String(describing: result) 206 | } 207 | 208 | private func resolveToolBehavior( 209 | _ toolResults: [Agent.ToolCallResult], 210 | behavior: Agent.ToolUseBehavior 211 | ) async throws -> String? { 212 | guard !toolResults.isEmpty else { return nil } 213 | switch behavior { 214 | case .runLLMAgain: 215 | return nil 216 | case .stopOnFirstTool: 217 | return toolResults.first?.output 218 | case .stopAtTools(let names): 219 | if let match = toolResults.first(where: { names.contains($0.name) }) { 220 | return match.output 221 | } 222 | return nil 223 | case .custom(let handler): 224 | let decision = try await handler(runContext, toolResults) 225 | return decision.isFinalOutput ? decision.finalOutput ?? toolResults.last?.output : nil 226 | } 227 | } 228 | 229 | /// Represents the result of a run 230 | public struct Result { 231 | /// The final output from the agent 232 | public let finalOutput: String 233 | 234 | /// The complete message history for the run 235 | public let messages: [Message] 236 | 237 | /// Aggregated usage information for the run 238 | public let usage: Usage 239 | } 240 | 241 | /// Represents the state of a run 242 | public enum State { 243 | case notStarted 244 | case running 245 | case completed 246 | case failed 247 | } 248 | 249 | /// Errors that can occur during a run 250 | public enum RunError: Error { 251 | case invalidState(String) 252 | case maxTurnsExceeded(Int) 253 | case guardrailError(GuardrailError) 254 | case toolNotFound(String) 255 | case toolExecutionError(toolName: String, error: Error) 256 | case executionError(Error) 257 | } 258 | 259 | private struct ToolProcessingOutcome { 260 | let messageResults: [MessageContent.ToolResult] 261 | let callResults: [Agent.ToolCallResult] 262 | } 263 | } 264 | -------------------------------------------------------------------------------- /Sources/AgentSDK-Swift/AgentRunner.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | 3 | /// Static class for running agents 4 | public struct AgentRunner { 5 | /// Executes an agent using the configured model provider. 6 | /// - Parameters: 7 | /// - agent: The agent to run. 8 | /// - input: The user input that starts the conversation. 9 | /// - context: Arbitrary state passed through the run. 10 | /// - Returns: The completed run result containing the final output, messages, and usage. 11 | /// - Throws: `RunnerError` when model lookup, execution, or guardrail evaluation fails. 12 | public static func run( 13 | agent: Agent, 14 | input: String, 15 | context: Context 16 | ) async throws -> Run.Result { 17 | do { 18 | let model = try await ModelProvider.shared.getModel(modelName: agent.modelSettings.modelName) 19 | let run = Run(agent: agent, input: input, context: context, model: model) 20 | return try await run.execute() 21 | } catch let error as ModelProvider.ModelProviderError { 22 | throw RunnerError.modelError(error) 23 | } catch let error as Run.RunError { 24 | throw RunnerError.runError(error) 25 | } catch { 26 | throw RunnerError.unknownError(error) 27 | } 28 | } 29 | 30 | /// Executes an agent while streaming intermediate output chunks to the supplied handler. 31 | /// - Parameters: 32 | /// - agent: The agent to run. 33 | /// - input: The user input that starts the conversation. 
34 | /// - context: Arbitrary state passed through the run. 35 | /// - streamHandler: Callback that receives streamed content chunks. 36 | /// - Returns: The completed run result containing the final output, messages, and usage. 37 | /// - Throws: `RunnerError` when model lookup, execution, or guardrail evaluation fails. 38 | public static func runStreamed( 39 | agent: Agent, 40 | input: String, 41 | context: Context, 42 | streamHandler: @escaping (String) async -> Void 43 | ) async throws -> Run.Result { 44 | do { 45 | let model = try await ModelProvider.shared.getModel(modelName: agent.modelSettings.modelName) 46 | let runContext = RunContext(value: context) 47 | return try await runStreamedInternal( 48 | agent: agent, 49 | input: input, 50 | runContext: runContext, 51 | model: model, 52 | settings: agent.modelSettings, 53 | streamHandler: streamHandler 54 | ) 55 | } catch let error as ModelProvider.ModelProviderError { 56 | throw RunnerError.modelError(error) 57 | } catch let error as Run.RunError { 58 | throw RunnerError.runError(error) 59 | } catch { 60 | throw RunnerError.unknownError(error) 61 | } 62 | } 63 | 64 | private static func runStreamedInternal( 65 | agent: Agent, 66 | input: String, 67 | runContext: RunContext, 68 | model: ModelInterface, 69 | settings: ModelSettings, 70 | streamHandler: @escaping (String) async -> Void, 71 | turn: Int = 0, 72 | maxTurns: Int = 10 73 | ) async throws -> Run.Result { 74 | guard turn < maxTurns else { 75 | throw Run.RunError.maxTurnsExceeded(maxTurns) 76 | } 77 | 78 | let systemInstructions = try await agent.resolveInstructions(runContext: runContext) 79 | 80 | var validatedInput = input 81 | for guardrail in agent.inputGuardrails { 82 | do { 83 | validatedInput = try guardrail.validate(validatedInput, context: runContext.value) 84 | } catch let error as GuardrailError { 85 | throw Run.RunError.guardrailError(error) 86 | } 87 | } 88 | 89 | for handoff in agent.handoffs { 90 | if handoff.filter.shouldHandoff(input: validatedInput, context: runContext.value) { 91 | let result = try await runStreamedInternal( 92 | agent: handoff.agent, 93 | input: validatedInput, 94 | runContext: runContext, 95 | model: model, 96 | settings: handoff.agent.modelSettings, 97 | streamHandler: streamHandler, 98 | turn: turn, 99 | maxTurns: maxTurns 100 | ) 101 | return result 102 | } 103 | } 104 | 105 | var messages: [Message] = [] 106 | if let systemInstructions { 107 | messages.append(.system(systemInstructions)) 108 | } 109 | messages.append(.user(validatedInput)) 110 | 111 | var currentTurn = turn 112 | while currentTurn < maxTurns { 113 | currentTurn += 1 114 | let enabledTools = await agent.enabledTools(for: runContext) 115 | var toolCalls: [ModelResponse.ToolCall] = [] 116 | let response = try await model.getStreamedResponse( 117 | messages: messages, 118 | settings: settings 119 | ) { event in 120 | switch event { 121 | case .content(let content): 122 | await streamHandler(content) 123 | case .toolCall(let toolCall): 124 | toolCalls.append(toolCall) 125 | case .end: 126 | break 127 | } 128 | } 129 | runContext.recordUsage(response.usage) 130 | messages.append(.assistant(response.content)) 131 | 132 | if response.toolCalls.isEmpty { 133 | var finalOutput = response.content 134 | for guardrail in agent.outputGuardrails { 135 | do { 136 | finalOutput = try guardrail.validate(finalOutput, context: runContext.value) 137 | } catch let error as GuardrailError { 138 | throw Run.RunError.guardrailError(error) 139 | } 140 | } 141 | return Run.Result( 142 | 
finalOutput: finalOutput, 143 | messages: messages, 144 | usage: runContext.usage 145 | ) 146 | } 147 | 148 | let processing = try await processToolCalls( 149 | toolCalls, 150 | enabledTools: enabledTools, 151 | runContext: runContext, 152 | streamHandler: streamHandler 153 | ) 154 | for toolMessage in processing.messageResults { 155 | messages.append(Message(role: .tool, content: .toolResults(toolMessage))) 156 | } 157 | 158 | if let finalFromTools = try await resolveToolBehavior( 159 | processing.callResults, 160 | behavior: agent.toolUseBehavior, 161 | runContext: runContext 162 | ) { 163 | var finalOutput = finalFromTools 164 | for guardrail in agent.outputGuardrails { 165 | do { 166 | finalOutput = try guardrail.validate(finalOutput, context: runContext.value) 167 | } catch let error as GuardrailError { 168 | throw Run.RunError.guardrailError(error) 169 | } 170 | } 171 | messages.append(.assistant(finalOutput)) 172 | await streamHandler(finalOutput) 173 | return Run.Result( 174 | finalOutput: finalOutput, 175 | messages: messages, 176 | usage: runContext.usage 177 | ) 178 | } 179 | } 180 | 181 | throw Run.RunError.maxTurnsExceeded(maxTurns) 182 | } 183 | 184 | private static func processToolCalls( 185 | _ toolCalls: [ModelResponse.ToolCall], 186 | enabledTools: [Tool], 187 | runContext: RunContext, 188 | streamHandler: @escaping (String) async -> Void 189 | ) async throws -> ToolProcessingOutcome { 190 | let toolMap = Dictionary(uniqueKeysWithValues: enabledTools.map { ($0.name, $0) }) 191 | var messageResults: [MessageContent.ToolResult] = [] 192 | var callResults: [Agent.ToolCallResult] = [] 193 | 194 | for toolCall in toolCalls { 195 | guard let tool = toolMap[toolCall.name] else { 196 | throw Run.RunError.toolNotFound("Tool \(toolCall.name) not found") 197 | } 198 | 199 | await streamHandler("\nExecuting tool: \(toolCall.name)...\n") 200 | do { 201 | let rawResult = try await tool.invoke( 202 | parameters: toolCall.parameters, 203 | runContext: runContext 204 | ) 205 | let resultString = stringifyToolResult(rawResult) 206 | await streamHandler("\nTool result: \(resultString)\n") 207 | let toolResult = MessageContent.ToolResult( 208 | toolCallId: toolCall.id, 209 | result: resultString 210 | ) 211 | messageResults.append(toolResult) 212 | callResults.append(Agent.ToolCallResult( 213 | id: toolCall.id, 214 | name: toolCall.name, 215 | output: resultString 216 | )) 217 | } catch { 218 | throw Run.RunError.toolExecutionError(toolName: toolCall.name, error: error) 219 | } 220 | } 221 | 222 | return ToolProcessingOutcome( 223 | messageResults: messageResults, 224 | callResults: callResults 225 | ) 226 | } 227 | 228 | private static func stringifyToolResult(_ result: Any) -> String { 229 | if let stringResult = result as? String { 230 | return stringResult 231 | } 232 | if let data = try? JSONSerialization.data(withJSONObject: result, options: [.prettyPrinted]), 233 | let jsonString = String(data: data, encoding: .utf8) { 234 | return jsonString 235 | } 236 | return String(describing: result) 237 | } 238 | 239 | private static func resolveToolBehavior( 240 | _ toolResults: [Agent.ToolCallResult], 241 | behavior: Agent.ToolUseBehavior, 242 | runContext: RunContext 243 | ) async throws -> String? 
{ 244 | guard !toolResults.isEmpty else { return nil } 245 | switch behavior { 246 | case .runLLMAgain: 247 | return nil 248 | case .stopOnFirstTool: 249 | return toolResults.first?.output 250 | case .stopAtTools(let names): 251 | if let match = toolResults.first(where: { names.contains($0.name) }) { 252 | return match.output 253 | } 254 | return nil 255 | case .custom(let handler): 256 | let decision = try await handler(runContext, toolResults) 257 | return decision.isFinalOutput ? decision.finalOutput ?? toolResults.last?.output : nil 258 | } 259 | } 260 | 261 | public enum RunnerError: Error { 262 | case modelError(ModelProvider.ModelProviderError) 263 | case runError(any Error) 264 | case guardrailError(GuardrailError) 265 | case toolNotFound(String) 266 | case toolExecutionError(toolName: String, error: Error) 267 | case unknownError(Error) 268 | } 269 | 270 | private struct ToolProcessingOutcome { 271 | let messageResults: [MessageContent.ToolResult] 272 | let callResults: [Agent.ToolCallResult] 273 | } 274 | } 275 | -------------------------------------------------------------------------------- /Sources/AgentSDK-Swift/ModelSettings.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | 3 | /// Settings for configuring model behavior 4 | public struct ModelSettings: Sendable { 5 | public enum ToolChoice: Sendable, Equatable { 6 | case auto 7 | case required 8 | case none 9 | case named(String) 10 | } 11 | 12 | public enum TruncationStrategy: String, Sendable, Equatable { 13 | case auto 14 | case disabled 15 | } 16 | 17 | public struct Reasoning: Sendable, Equatable { 18 | public enum Effort: String, Sendable { 19 | case minimal 20 | case low 21 | case medium 22 | case high 23 | } 24 | 25 | public var effort: Effort? 26 | 27 | public init(effort: Effort? = nil) { 28 | self.effort = effort 29 | } 30 | } 31 | 32 | public enum Verbosity: String, Sendable, Equatable { 33 | case low 34 | case medium 35 | case high 36 | } 37 | 38 | /// The name of the model to use 39 | public var modelName: String 40 | /// Temperature controls randomness (0.0 to 1.0) 41 | public var temperature: Double? 42 | /// Top-p controls diversity of output (0.0 to 1.0) 43 | public var topP: Double? 44 | /// Penalizes repeated tokens 45 | public var frequencyPenalty: Double? 46 | /// Encourages exploring new topics 47 | public var presencePenalty: Double? 48 | /// Configures how the model chooses tools 49 | public var toolChoice: ToolChoice? 50 | /// Enables multiple tool calls in a single turn when supported 51 | public var parallelToolCalls: Bool? 52 | /// Controls truncation behavior for long conversations 53 | public var truncation: TruncationStrategy? 54 | /// Maximum number of tokens to generate 55 | public var maxTokens: Int? 56 | /// Reasoning configuration for reasoning-capable models 57 | public var reasoning: Reasoning? 58 | /// Verbosity configuration for the response 59 | public var verbosity: Verbosity? 60 | /// Metadata forwarded to the model provider 61 | public var metadata: [String: String]? 62 | /// Whether to store the generated response server-side 63 | public var store: Bool? 64 | /// Whether to include usage statistics in responses 65 | public var includeUsage: Bool? 66 | /// Extra fields to include in the model response payload 67 | public var responseInclude: [String]? 68 | /// Number of top tokens to return log probabilities for 69 | public var topLogprobs: Int? 
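    // Illustrative usage of the layering helpers defined below (values here are made up):
    // `with(...)` copies the settings and applies only the non-nil overrides, while
    // `merged(with:)` overlays another instance on top of this one.
    //
    //     let base = ModelSettings(modelName: "gpt-4.1", temperature: 0.2)
    //     let perAgent = base.with(maxTokens: 1_024, toolChoice: .auto)
    //     let effective = base.merged(with: perAgent)
    //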
70 | /// Additional query parameters for the provider 71 | public var extraQuery: [String: String]? 72 | /// Additional body parameters for the provider 73 | public var extraBody: [String: String]? 74 | /// Additional headers for the provider 75 | public var extraHeaders: [String: String]? 76 | /// Response formats to use (e.g., JSON) 77 | public var responseFormat: ResponseFormat? 78 | /// Seeds for deterministic generation 79 | public var seed: Int? 80 | /// Additional model-specific parameters 81 | public var additionalParameters: [String: String] 82 | 83 | public init( 84 | modelName: String = "gpt-4.1", 85 | temperature: Double? = nil, 86 | topP: Double? = nil, 87 | frequencyPenalty: Double? = nil, 88 | presencePenalty: Double? = nil, 89 | toolChoice: ToolChoice? = nil, 90 | parallelToolCalls: Bool? = nil, 91 | truncation: TruncationStrategy? = nil, 92 | maxTokens: Int? = nil, 93 | reasoning: Reasoning? = nil, 94 | verbosity: Verbosity? = nil, 95 | metadata: [String: String]? = nil, 96 | store: Bool? = nil, 97 | includeUsage: Bool? = nil, 98 | responseInclude: [String]? = nil, 99 | topLogprobs: Int? = nil, 100 | extraQuery: [String: String]? = nil, 101 | extraBody: [String: String]? = nil, 102 | extraHeaders: [String: String]? = nil, 103 | responseFormat: ResponseFormat? = nil, 104 | seed: Int? = nil, 105 | additionalParameters: [String: String] = [:] 106 | ) { 107 | self.modelName = modelName 108 | self.temperature = temperature 109 | self.topP = topP 110 | self.frequencyPenalty = frequencyPenalty 111 | self.presencePenalty = presencePenalty 112 | self.toolChoice = toolChoice 113 | self.parallelToolCalls = parallelToolCalls 114 | self.truncation = truncation 115 | self.maxTokens = maxTokens 116 | self.reasoning = reasoning 117 | self.verbosity = verbosity 118 | self.metadata = metadata 119 | self.store = store 120 | self.includeUsage = includeUsage 121 | self.responseInclude = responseInclude 122 | self.topLogprobs = topLogprobs 123 | self.extraQuery = extraQuery 124 | self.extraBody = extraBody 125 | self.extraHeaders = extraHeaders 126 | self.responseFormat = responseFormat 127 | self.seed = seed 128 | self.additionalParameters = additionalParameters 129 | } 130 | 131 | /// Creates a copy of these settings with optional overrides 132 | public func with( 133 | modelName: String? = nil, 134 | temperature: Double? = nil, 135 | topP: Double? = nil, 136 | maxTokens: Int? = nil, 137 | responseFormat: ResponseFormat? = nil, 138 | seed: Int? = nil, 139 | additionalParameters: [String: String]? = nil, 140 | frequencyPenalty: Double? = nil, 141 | presencePenalty: Double? = nil, 142 | toolChoice: ToolChoice? = nil, 143 | parallelToolCalls: Bool? = nil, 144 | truncation: TruncationStrategy? = nil, 145 | reasoning: Reasoning? = nil, 146 | verbosity: Verbosity? = nil, 147 | metadata: [String: String]? = nil, 148 | store: Bool? = nil, 149 | includeUsage: Bool? = nil, 150 | responseInclude: [String]? = nil, 151 | topLogprobs: Int? = nil, 152 | extraQuery: [String: String]? = nil, 153 | extraBody: [String: String]? = nil, 154 | extraHeaders: [String: String]? 
= nil 155 | ) -> ModelSettings { 156 | var settings = self 157 | applyBasicOverrides( 158 | to: &settings, 159 | modelName: modelName, 160 | temperature: temperature, 161 | topP: topP, 162 | maxTokens: maxTokens, 163 | responseFormat: responseFormat, 164 | seed: seed, 165 | additionalParameters: additionalParameters 166 | ) 167 | applyAdvancedOverrides( 168 | to: &settings, 169 | frequencyPenalty: frequencyPenalty, 170 | presencePenalty: presencePenalty, 171 | toolChoice: toolChoice, 172 | parallelToolCalls: parallelToolCalls, 173 | truncation: truncation, 174 | reasoning: reasoning, 175 | verbosity: verbosity 176 | ) 177 | applyMetadataOverrides( 178 | to: &settings, 179 | overrides: MetadataOverrides( 180 | metadata: metadata, 181 | store: store, 182 | includeUsage: includeUsage, 183 | responseInclude: responseInclude, 184 | topLogprobs: topLogprobs, 185 | extraQuery: extraQuery, 186 | extraBody: extraBody, 187 | extraHeaders: extraHeaders 188 | ) 189 | ) 190 | return settings 191 | } 192 | 193 | private func applyBasicOverrides( 194 | to settings: inout ModelSettings, 195 | modelName: String?, 196 | temperature: Double?, 197 | topP: Double?, 198 | maxTokens: Int?, 199 | responseFormat: ResponseFormat?, 200 | seed: Int?, 201 | additionalParameters: [String: String]? 202 | ) { 203 | if let modelName { settings.modelName = modelName } 204 | if let temperature { settings.temperature = temperature } 205 | if let topP { settings.topP = topP } 206 | if let maxTokens { settings.maxTokens = maxTokens } 207 | if let responseFormat { settings.responseFormat = responseFormat } 208 | if let seed { settings.seed = seed } 209 | if let additionalParameters { settings.additionalParameters = additionalParameters } 210 | } 211 | 212 | private func applyAdvancedOverrides( 213 | to settings: inout ModelSettings, 214 | frequencyPenalty: Double?, 215 | presencePenalty: Double?, 216 | toolChoice: ToolChoice?, 217 | parallelToolCalls: Bool?, 218 | truncation: TruncationStrategy?, 219 | reasoning: Reasoning?, 220 | verbosity: Verbosity? 221 | ) { 222 | if let frequencyPenalty { settings.frequencyPenalty = frequencyPenalty } 223 | if let presencePenalty { settings.presencePenalty = presencePenalty } 224 | if let toolChoice { settings.toolChoice = toolChoice } 225 | if let parallelToolCalls { settings.parallelToolCalls = parallelToolCalls } 226 | if let truncation { settings.truncation = truncation } 227 | if let reasoning { settings.reasoning = reasoning } 228 | if let verbosity { settings.verbosity = verbosity } 229 | } 230 | 231 | private struct MetadataOverrides { 232 | let metadata: [String: String]? 233 | let store: Bool? 234 | let includeUsage: Bool? 235 | let responseInclude: [String]? 236 | let topLogprobs: Int? 237 | let extraQuery: [String: String]? 238 | let extraBody: [String: String]? 239 | let extraHeaders: [String: String]? 
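        // Mirrors the metadata-related stored properties on ModelSettings; every field is an
        // optional override and is only applied when non-nil (see applyMetadataOverrides below).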
240 | } 241 | 242 | private func applyMetadataOverrides(to settings: inout ModelSettings, overrides: MetadataOverrides) { 243 | if let metadata = overrides.metadata { settings.metadata = metadata } 244 | if let store = overrides.store { settings.store = store } 245 | if let includeUsage = overrides.includeUsage { settings.includeUsage = includeUsage } 246 | if let responseInclude = overrides.responseInclude { settings.responseInclude = responseInclude } 247 | if let topLogprobs = overrides.topLogprobs { settings.topLogprobs = topLogprobs } 248 | if let extraQuery = overrides.extraQuery { settings.extraQuery = extraQuery } 249 | if let extraBody = overrides.extraBody { settings.extraBody = extraBody } 250 | if let extraHeaders = overrides.extraHeaders { settings.extraHeaders = extraHeaders } 251 | } 252 | 253 | /// Produces a new settings object by overlaying non-nil values from another instance. 254 | public func merged(with override: ModelSettings?) -> ModelSettings { 255 | guard let override else { return self } 256 | var merged = self 257 | mergeBasicSettings(to: &merged, from: override) 258 | mergeAdvancedSettings(to: &merged, from: override) 259 | mergeDictionarySettings(to: &merged, from: override) 260 | return merged 261 | } 262 | 263 | private func mergeBasicSettings(to merged: inout ModelSettings, from override: ModelSettings) { 264 | if override.modelName != self.modelName { merged.modelName = override.modelName } 265 | if let value = override.temperature { merged.temperature = value } 266 | if let value = override.topP { merged.topP = value } 267 | if let value = override.maxTokens { merged.maxTokens = value } 268 | if let value = override.responseFormat { merged.responseFormat = value } 269 | if let value = override.seed { merged.seed = value } 270 | if let value = override.frequencyPenalty { merged.frequencyPenalty = value } 271 | if let value = override.presencePenalty { merged.presencePenalty = value } 272 | } 273 | 274 | private func mergeAdvancedSettings(to merged: inout ModelSettings, from override: ModelSettings) { 275 | if let value = override.toolChoice { merged.toolChoice = value } 276 | if let value = override.parallelToolCalls { merged.parallelToolCalls = value } 277 | if let value = override.truncation { merged.truncation = value } 278 | if let value = override.reasoning { merged.reasoning = value } 279 | if let value = override.verbosity { merged.verbosity = value } 280 | if let value = override.metadata { merged.metadata = value } 281 | if let value = override.store { merged.store = value } 282 | if let value = override.includeUsage { merged.includeUsage = value } 283 | if let value = override.responseInclude { merged.responseInclude = value } 284 | if let value = override.topLogprobs { merged.topLogprobs = value } 285 | } 286 | 287 | private func mergeDictionarySettings(to merged: inout ModelSettings, from override: ModelSettings) { 288 | if let value = override.extraQuery { 289 | merged.extraQuery = merged.extraQuery?.merging(value, uniquingKeysWith: { _, new in new }) ?? value 290 | } 291 | if let value = override.extraBody { 292 | merged.extraBody = merged.extraBody?.merging(value, uniquingKeysWith: { _, new in new }) ?? value 293 | } 294 | if let value = override.extraHeaders { 295 | merged.extraHeaders = merged.extraHeaders?.merging(value, uniquingKeysWith: { _, new in new }) ?? 
value 296 | } 297 | if !override.additionalParameters.isEmpty { 298 | merged.additionalParameters = merged.additionalParameters.merging(override.additionalParameters) { _, new in new } 299 | } 300 | } 301 | 302 | /// Generates a serialisable representation compatible with provider SDKs. 303 | public func toDictionaryRepresentation() -> [String: Any] { 304 | var dict: [String: Any] = [:] 305 | addBasicParametersToDictionary(&dict) 306 | addAdvancedParametersToDictionary(&dict) 307 | addMetadataParametersToDictionary(&dict) 308 | addExtraParametersToDictionary(&dict) 309 | return dict 310 | } 311 | 312 | private func addBasicParametersToDictionary(_ dict: inout [String: Any]) { 313 | if let temperature { dict["temperature"] = temperature } 314 | if let topP { dict["top_p"] = topP } 315 | if let frequencyPenalty { dict["frequency_penalty"] = frequencyPenalty } 316 | if let presencePenalty { dict["presence_penalty"] = presencePenalty } 317 | if let maxTokens { dict["max_tokens"] = maxTokens } 318 | if let responseFormat { dict["response_format"] = responseFormat.jsonValue } 319 | if let seed { dict["seed"] = seed } 320 | } 321 | 322 | private func addAdvancedParametersToDictionary(_ dict: inout [String: Any]) { 323 | if let toolChoice { 324 | switch toolChoice { 325 | case .auto: dict["tool_choice"] = "auto" 326 | case .required: dict["tool_choice"] = "required" 327 | case .none: dict["tool_choice"] = "none" 328 | case .named(let name): dict["tool_choice"] = name 329 | } 330 | } 331 | if let parallelToolCalls { dict["parallel_tool_calls"] = parallelToolCalls } 332 | if let truncation { dict["truncation"] = truncation.rawValue } 333 | if let reasoning { 334 | var reasoningDict: [String: Any] = [:] 335 | if let effort = reasoning.effort { reasoningDict["effort"] = effort.rawValue } 336 | dict["reasoning"] = reasoningDict 337 | } 338 | if let verbosity { dict["verbosity"] = verbosity.rawValue } 339 | } 340 | 341 | private func addMetadataParametersToDictionary(_ dict: inout [String: Any]) { 342 | if let metadata { dict["metadata"] = metadata } 343 | if let store { dict["store"] = store } 344 | if let includeUsage { dict["include_usage"] = includeUsage } 345 | if let responseInclude { dict["response_include"] = responseInclude } 346 | if let topLogprobs { dict["top_logprobs"] = topLogprobs } 347 | } 348 | 349 | private func addExtraParametersToDictionary(_ dict: inout [String: Any]) { 350 | if let extraQuery { dict["extra_query"] = extraQuery } 351 | if let extraBody { dict["extra_body"] = extraBody } 352 | if let extraHeaders { dict["extra_headers"] = extraHeaders } 353 | if !additionalParameters.isEmpty { 354 | for (key, value) in additionalParameters { 355 | dict[key] = value 356 | } 357 | } 358 | } 359 | 360 | public enum ResponseFormat: Sendable { 361 | case json 362 | case text 363 | 364 | public var jsonValue: String { 365 | switch self { 366 | case .json: return "json_object" 367 | case .text: return "text" 368 | } 369 | } 370 | } 371 | } 372 | -------------------------------------------------------------------------------- /Tests/AgentSDK-SwiftTests/AgentSDK_SwiftTests.swift: -------------------------------------------------------------------------------- 1 | import Testing 2 | @testable import AgentSDK_Swift 3 | 4 | // MARK: - Agent Tests 5 | 6 | @Test func testAgentCreation() async throws { 7 | // Create a simple agent 8 | let agent = Agent( 9 | name: "TestAgent", 10 | instructions: "You are a helpful assistant." 
11 | ) 12 | 13 | #expect(agent.name == "TestAgent") 14 | if case .literal(let instructions)? = agent.instructions { 15 | #expect(instructions == "You are a helpful assistant.") 16 | } else { 17 | #expect(Bool(false), "Instructions should be literal") 18 | } 19 | #expect(agent.tools.isEmpty) 20 | #expect(agent.inputGuardrails.isEmpty) 21 | #expect(agent.outputGuardrails.isEmpty) 22 | #expect(agent.handoffs.isEmpty) 23 | } 24 | 25 | @Test func testAgentCreationWithFullConfig() async throws { 26 | // Create tools 27 | let tool1 = Tool( 28 | name: "echo", 29 | description: "Echoes the input", 30 | execute: { params, _ in 31 | return params["text"] as? String ?? "No text provided" 32 | } 33 | ) 34 | 35 | let tool2 = Tool( 36 | name: "reverse", 37 | description: "Reverses the input", 38 | execute: { params, _ in 39 | let text = params["text"] as? String ?? "" 40 | return String(text.reversed()) 41 | } 42 | ) 43 | 44 | // Create guardrails 45 | let inputGuardrail = AnyInputGuardrail(InputLengthGuardrail(maxLength: 100)) 46 | 47 | // Create model settings 48 | let modelSettings = ModelSettings( 49 | modelName: "test-model", 50 | temperature: 0.7, 51 | topP: 0.9, 52 | maxTokens: 1000 53 | ) 54 | 55 | // Create agent with all components 56 | let agent = Agent( 57 | name: "FullConfigAgent", 58 | instructions: "You are a comprehensive test agent.", 59 | tools: [tool1, tool2], 60 | inputGuardrails: [inputGuardrail], 61 | modelSettings: modelSettings 62 | ) 63 | 64 | #expect(agent.name == "FullConfigAgent") 65 | if case .literal(let instructions)? = agent.instructions { 66 | #expect(instructions == "You are a comprehensive test agent.") 67 | } else { 68 | #expect(Bool(false), "Instructions should be literal") 69 | } 70 | #expect(agent.tools.count == 2) 71 | #expect(agent.tools[0].name == "echo") 72 | #expect(agent.tools[1].name == "reverse") 73 | #expect(agent.inputGuardrails.count == 1) 74 | #expect(agent.modelSettings.modelName == "test-model") 75 | #expect(agent.modelSettings.temperature == 0.7) 76 | #expect(agent.modelSettings.topP == 0.9) 77 | #expect(agent.modelSettings.maxTokens == 1000) 78 | } 79 | 80 | @Test func testAgentMethodChaining() async throws { 81 | // Create tools 82 | let tool1 = Tool( 83 | name: "echo", 84 | description: "Echoes the input", 85 | execute: { params, _ in 86 | return params["text"] as? String ?? "No text provided" 87 | } 88 | ) 89 | 90 | let tool2 = Tool( 91 | name: "reverse", 92 | description: "Reverses the input", 93 | execute: { params, _ in 94 | let text = params["text"] as? String ?? "" 95 | return String(text.reversed()) 96 | } 97 | ) 98 | 99 | // Create guardrails 100 | let inputGuardrail = InputLengthGuardrail(maxLength: 100) 101 | 102 | // Create agent with method chaining 103 | let agent = Agent(name: "ChainedAgent", instructions: "You are a method-chained agent.") 104 | .addTool(tool1) 105 | .addTool(tool2) 106 | .addInputGuardrail(inputGuardrail) 107 | 108 | #expect(agent.name == "ChainedAgent") 109 | #expect(agent.tools.count == 2) 110 | #expect(agent.inputGuardrails.count == 1) 111 | } 112 | 113 | @Test func testAgentClone() async throws { 114 | // Create initial agent 115 | let originalAgent = Agent( 116 | name: "OriginalAgent", 117 | instructions: "You are the original agent." 118 | ).addTool(Tool( 119 | name: "echo", 120 | description: "Echoes the input", 121 | execute: { params, _ in 122 | return params["text"] as? String ?? 
"" 123 | } 124 | )) 125 | 126 | // Clone the agent 127 | let clonedAgent = originalAgent.clone() 128 | 129 | // Verify the clone has the same properties 130 | #expect(clonedAgent.name == originalAgent.name) 131 | switch (clonedAgent.instructions, originalAgent.instructions) { 132 | case (.literal(let lhs)?, .literal(let rhs)?): 133 | #expect(lhs == rhs) 134 | case (.none, .none): 135 | break 136 | default: 137 | #expect(Bool(false), "Instructions mismatch") 138 | } 139 | #expect(clonedAgent.tools.count == originalAgent.tools.count) 140 | #expect(clonedAgent.tools[0].name == originalAgent.tools[0].name) 141 | 142 | // Verify that modifying the clone doesn't affect the original 143 | clonedAgent.addTool(Tool( 144 | name: "newTool", 145 | description: "A new tool", 146 | execute: { _, _ in return "result" } 147 | )) 148 | 149 | #expect(clonedAgent.tools.count == 2) 150 | #expect(originalAgent.tools.count == 1) 151 | } 152 | 153 | // MARK: - Tool Tests 154 | 155 | @Test func testToolCreation() async throws { 156 | // Create a simple tool 157 | let tool = Tool( 158 | name: "echo", 159 | description: "Echoes the input", 160 | parameters: [ 161 | Tool.Parameter( 162 | name: "text", 163 | description: "The text to echo", 164 | type: .string 165 | ) 166 | ], 167 | execute: { params, _ in 168 | return params["text"] as? String ?? "No text provided" 169 | } 170 | ) 171 | 172 | #expect(tool.name == "echo") 173 | #expect(tool.description == "Echoes the input") 174 | #expect(tool.parameters.count == 1) 175 | #expect(tool.parameters[0].name == "text") 176 | #expect(tool.parameters[0].description == "The text to echo") 177 | #expect(tool.parameters[0].type == .string) 178 | #expect(tool.parameters[0].required == true) 179 | } 180 | 181 | @Test func testToolParameterTypes() async throws { 182 | // Create a tool with different parameter types 183 | let tool = Tool( 184 | name: "multiTypeTest", 185 | description: "Tests different parameter types", 186 | parameters: [ 187 | Tool.Parameter(name: "stringParam", description: "A string", type: .string), 188 | Tool.Parameter(name: "numberParam", description: "A number", type: .number), 189 | Tool.Parameter(name: "boolParam", description: "A boolean", type: .boolean), 190 | Tool.Parameter(name: "arrayParam", description: "An array", type: .array), 191 | Tool.Parameter(name: "objectParam", description: "An object", type: .object), 192 | Tool.Parameter(name: "optionalParam", description: "Optional", type: .string, required: false) 193 | ], 194 | execute: { _, _ in return "result" } 195 | ) 196 | 197 | #expect(tool.parameters.count == 6) 198 | #expect(tool.parameters[0].type.jsonType == "string") 199 | #expect(tool.parameters[1].type.jsonType == "number") 200 | #expect(tool.parameters[2].type.jsonType == "boolean") 201 | #expect(tool.parameters[3].type.jsonType == "array") 202 | #expect(tool.parameters[4].type.jsonType == "object") 203 | #expect(tool.parameters[5].required == false) 204 | } 205 | 206 | @Test func testToolExecution() async throws { 207 | // Create a tool that performs an operation 208 | let calculator = Tool( 209 | name: "add", 210 | description: "Adds two numbers", 211 | parameters: [ 212 | Tool.Parameter(name: "paramA", description: "First number", type: .number), 213 | Tool.Parameter(name: "paramB", description: "Second number", type: .number) 214 | ], 215 | execute: { params, _ in 216 | // Integer numbers might be parsed as different numeric types 217 | // We convert everything to Int for consistency 218 | if let paramA = params["paramA"] as? 
Int, let paramB = params["paramB"] as? Int { 219 | return paramA + paramB 220 | } else if let paramA = params["paramA"] as? Double, let paramB = params["paramB"] as? Double { 221 | return Int(paramA + paramB) 222 | } else { 223 | return 0 224 | } 225 | } 226 | ) 227 | 228 | // Execute the tool 229 | let runContext = RunContext(value: ()) 230 | let result = try await calculator.invoke(parameters: ["paramA": 5, "paramB": 3], runContext: runContext) 231 | 232 | #expect(result as? Int == 8) 233 | } 234 | 235 | @Test func testAddingToolToAgent() async throws { 236 | // Create a simple agent 237 | let agent = Agent( 238 | name: "TestAgent", 239 | instructions: "You are a helpful assistant." 240 | ) 241 | 242 | // Create a simple tool 243 | let tool = Tool( 244 | name: "echo", 245 | description: "Echoes the input", 246 | execute: { params, _ in 247 | return params["text"] as? String ?? "No text provided" 248 | } 249 | ) 250 | 251 | // Add tool to agent 252 | let updatedAgent = agent.addTool(tool) 253 | 254 | #expect(updatedAgent.tools.count == 1) 255 | #expect(updatedAgent.tools[0].name == "echo") 256 | } 257 | 258 | @Test func testAddingMultipleToolsToAgent() async throws { 259 | // Create a simple agent 260 | let agent = Agent( 261 | name: "TestAgent", 262 | instructions: "You are a helpful assistant." 263 | ) 264 | 265 | // Create tools 266 | let tool1 = Tool(name: "tool1", description: "First tool", execute: { _, _ in return "1" }) 267 | let tool2 = Tool(name: "tool2", description: "Second tool", execute: { _, _ in return "2" }) 268 | let tool3 = Tool(name: "tool3", description: "Third tool", execute: { _, _ in return "3" }) 269 | 270 | // Add multiple tools at once 271 | let updatedAgent = agent.addTools([tool1, tool2, tool3]) 272 | 273 | #expect(updatedAgent.tools.count == 3) 274 | #expect(updatedAgent.tools[0].name == "tool1") 275 | #expect(updatedAgent.tools[1].name == "tool2") 276 | #expect(updatedAgent.tools[2].name == "tool3") 277 | } 278 | 279 | @Test func testTypedTool() async throws { 280 | // Define input and output using a simple struct 281 | struct AddInput: Codable { 282 | let inputA: Int 283 | let inputB: Int 284 | } 285 | 286 | // Create a tool with manual parameter handling 287 | let addTool = Tool( 288 | name: "add", 289 | description: "Adds two numbers", 290 | parameters: [ 291 | Tool.Parameter(name: "inputA", description: "First number", type: .number), 292 | Tool.Parameter(name: "inputB", description: "Second number", type: .number) 293 | ], 294 | execute: { params, _ in 295 | // Parse the parameters manually 296 | guard let inputA = params["inputA"] as? Int, 297 | let inputB = params["inputB"] as? Int else { 298 | return 0 299 | } 300 | return inputA + inputB 301 | } 302 | ) 303 | 304 | // Execute the tool 305 | let result = try await addTool.invoke(parameters: ["inputA": 10, "inputB": 20], runContext: RunContext(value: ())) 306 | 307 | #expect(result as? 
Int == 30) 308 | } 309 | 310 | // MARK: - Guardrail Tests 311 | 312 | @Test func testGuardrailValidation() async throws { 313 | // Create a simple input length guardrail 314 | let guardrail = InputLengthGuardrail(maxLength: 10) 315 | 316 | // Test valid input 317 | let validInput = "Hello" 318 | let processedInput = try guardrail.validate(validInput, context: ()) 319 | #expect(processedInput == validInput) 320 | 321 | // Test invalid input 322 | let invalidInput = "This is a very long input that exceeds the maximum length" 323 | do { 324 | _ = try guardrail.validate(invalidInput, context: ()) 325 | #expect(Bool(false), "Should have thrown an error") 326 | } catch let error as GuardrailError { 327 | switch error { 328 | case .invalidInput(let reason): 329 | #expect(reason.contains("Maximum length is 10")) 330 | default: 331 | #expect(Bool(false), "Wrong error type") 332 | } 333 | } 334 | } 335 | 336 | @Test func testRegexContentGuardrail() async throws { 337 | // Create a regex guardrail to block content containing "forbidden" 338 | let blockingGuardrail = try RegexContentGuardrail(pattern: "forbidden", blockMatches: true) 339 | 340 | // Test valid output (doesn't contain the blocked word) 341 | let validOutput = "This is an allowed message" 342 | let processedOutput = try blockingGuardrail.validate(validOutput, context: ()) 343 | #expect(processedOutput == validOutput) 344 | 345 | // Test invalid output (contains the blocked word) 346 | let invalidOutput = "This message contains forbidden content" 347 | do { 348 | _ = try blockingGuardrail.validate(invalidOutput, context: ()) 349 | #expect(Bool(false), "Should have thrown an error") 350 | } catch let error as GuardrailError { 351 | switch error { 352 | case .invalidOutput(let reason): 353 | #expect(reason.contains("blocked content")) 354 | default: 355 | #expect(Bool(false), "Wrong error type") 356 | } 357 | } 358 | 359 | // Create a regex guardrail to require content matching "required" 360 | let requiringGuardrail = try RegexContentGuardrail(pattern: "required", blockMatches: false) 361 | 362 | // Test valid output (contains the required word) 363 | let validRequiredOutput = "This message contains required content" 364 | let processedRequiredOutput = try requiringGuardrail.validate(validRequiredOutput, context: ()) 365 | #expect(processedRequiredOutput == validRequiredOutput) 366 | 367 | // Test invalid output (doesn't contain the required word) 368 | let invalidRequiredOutput = "This message doesn't have the necessary text" 369 | do { 370 | _ = try requiringGuardrail.validate(invalidRequiredOutput, context: ()) 371 | #expect(Bool(false), "Should have thrown an error") 372 | } catch let error as GuardrailError { 373 | switch error { 374 | case .invalidOutput(let reason): 375 | #expect(reason.contains("required content")) 376 | default: 377 | #expect(Bool(false), "Wrong error type") 378 | } 379 | } 380 | } 381 | 382 | // MARK: - Model Settings Tests 383 | 384 | @Test func testModelSettingsCreation() async throws { 385 | // Create model settings with all parameters 386 | let settings = ModelSettings( 387 | modelName: "test-model", 388 | temperature: 0.8, 389 | topP: 0.95, 390 | maxTokens: 2000, 391 | responseFormat: .json, 392 | seed: 12345, 393 | additionalParameters: ["custom": "value"] 394 | ) 395 | 396 | #expect(settings.modelName == "test-model") 397 | #expect(settings.temperature == 0.8) 398 | #expect(settings.topP == 0.95) 399 | #expect(settings.maxTokens == 2000) 400 | #expect(settings.responseFormat == .json) 401 | 
#expect(settings.seed == 12345) 402 | #expect(settings.additionalParameters["custom"] == "value") 403 | } 404 | 405 | @Test func testDefaultModelSettings() async throws { 406 | // Create model settings with defaults 407 | let settings = ModelSettings() 408 | 409 | #expect(settings.modelName == "gpt-4.1") 410 | #expect(settings.temperature == nil) 411 | #expect(settings.topP == nil) 412 | #expect(settings.maxTokens == nil) 413 | #expect(settings.responseFormat == nil) 414 | #expect(settings.seed == nil) 415 | #expect(settings.additionalParameters.isEmpty) 416 | } 417 | 418 | @Test func testResponseFormatJsonValue() async throws { 419 | // Test JSON value for text response format 420 | let textFormat = ModelSettings.ResponseFormat.text 421 | #expect(textFormat.jsonValue == "text") 422 | 423 | // Test JSON value for JSON response format 424 | let jsonFormat = ModelSettings.ResponseFormat.json 425 | #expect(jsonFormat.jsonValue == "json_object") 426 | } 427 | 428 | @Test func testUpdateModelSettings() async throws { 429 | // Create initial settings 430 | var settings = ModelSettings(modelName: "initial-model", temperature: 0.7) 431 | 432 | // Update settings 433 | settings.modelName = "updated-model" 434 | settings.temperature = 0.9 435 | settings.maxTokens = 500 436 | 437 | #expect(settings.modelName == "updated-model") 438 | #expect(settings.temperature == 0.9) 439 | #expect(settings.maxTokens == 500) 440 | } 441 | -------------------------------------------------------------------------------- /Sources/AgentSDK-Swift/Models/OpenAIModel.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | #if canImport(FoundationNetworking) 3 | import FoundationNetworking 4 | #endif 5 | import OpenAPIRuntime 6 | 7 | /// Implementation of ModelInterface for OpenAI models 8 | public final class OpenAIModel: ModelInterface { 9 | /// The API key for OpenAI 10 | private let apiKey: String 11 | 12 | /// The API base URL 13 | private let apiBaseURL: URL 14 | 15 | /// The URL session used for network requests 16 | private let urlSession: URLSession 17 | 18 | /// Creates a new OpenAI model 19 | /// - Parameters: 20 | /// - apiKey: The API key for OpenAI 21 | /// - apiBaseURL: The API base URL (defaults to OpenAI's API) 22 | /// - urlSession: Optional custom URL session 23 | public init( 24 | apiKey: String, 25 | apiBaseURL: URL = URL(string: "https://api.openai.com/v1")!, 26 | urlSession: URLSession? = nil 27 | ) { 28 | self.apiKey = apiKey 29 | self.apiBaseURL = apiBaseURL 30 | self.urlSession = urlSession ?? URLSession.shared 31 | } 32 | 33 | /// Gets a response from the model 34 | /// - Parameters: 35 | /// - messages: The messages to send to the model 36 | /// - settings: The settings to use for the model call 37 | /// - Returns: The model response 38 | public func getResponse(messages: [Message], settings: ModelSettings) async throws -> ModelResponse { 39 | let requestBody = try createRequestBody(messages: messages, settings: settings) 40 | 41 | let endpoint = "\(apiBaseURL)/chat/completions" 42 | 43 | // Create request 44 | var request = createURLRequest(url: endpoint) 45 | 46 | // Add request body 47 | let bodyData = try JSONEncoder().encode(requestBody) 48 | request.httpBody = bodyData 49 | 50 | // Send request 51 | let (data, response) = try await urlSession.data(for: request, delegate: nil) 52 | 53 | // Check response status 54 | guard let httpResponse = response as? 
HTTPURLResponse, 55 | httpResponse.statusCode == 200 else { 56 | let errorString = String(data: data, encoding: .utf8) ?? "Unknown error" 57 | let statusCode = (response as? HTTPURLResponse)?.statusCode ?? 0 58 | throw OpenAIModelError.requestFailed(statusCode: statusCode, message: errorString) 59 | } 60 | 61 | // Parse response 62 | let openAIResponse = try JSONDecoder().decode(ChatCompletionResponse.self, from: data) 63 | 64 | return try convertResponse(openAIResponse) 65 | } 66 | 67 | /// Gets a streamed response from the model 68 | /// - Parameters: 69 | /// - messages: The messages to send to the model 70 | /// - settings: The settings to use for the model call 71 | /// - callback: The callback to call for each streamed chunk 72 | public func getStreamedResponse( 73 | messages: [Message], 74 | settings: ModelSettings, 75 | callback: @escaping (ModelStreamEvent) async -> Void 76 | ) async throws -> ModelResponse { 77 | let data = try await performStreamRequest(messages: messages, settings: settings) 78 | return try await processStreamedData(data, callback: callback) 79 | } 80 | 81 | private func performStreamRequest(messages: [Message], settings: ModelSettings) async throws -> Data { 82 | var requestBody = try createRequestBody(messages: messages, settings: settings) 83 | requestBody.stream = true 84 | 85 | let endpoint = "\(apiBaseURL)/chat/completions" 86 | var request = createURLRequest(url: endpoint) 87 | 88 | let bodyData = try JSONEncoder().encode(requestBody) 89 | request.httpBody = bodyData 90 | 91 | let (data, response) = try await urlSession.data(for: request) 92 | 93 | guard let httpResponse = response as? HTTPURLResponse, 94 | httpResponse.statusCode == 200 else { 95 | let errorString = String(data: data, encoding: .utf8) ?? "Unknown error" 96 | let statusCode = (response as? HTTPURLResponse)?.statusCode ?? 
0 97 | throw OpenAIModelError.requestFailed(statusCode: statusCode, message: errorString) 98 | } 99 | 100 | return data 101 | } 102 | 103 | private func processStreamedData( 104 | _ data: Data, 105 | callback: @escaping (ModelStreamEvent) async -> Void 106 | ) async throws -> ModelResponse { 107 | var contentBuffer = "" 108 | var toolCalls: [ModelResponse.ToolCall] = [] 109 | 110 | guard let responseStr = String(data: data, encoding: .utf8) else { 111 | return ModelResponse(content: contentBuffer, toolCalls: toolCalls) 112 | } 113 | 114 | let lines = responseStr.split(separator: "\n") 115 | for line in lines { 116 | if line.hasPrefix("data: ") { 117 | let dataContent = line.dropFirst(6) 118 | 119 | if dataContent == "[DONE]" { 120 | await callback(.end) 121 | continue 122 | } 123 | 124 | try await processStreamChunk( 125 | dataContent, 126 | contentBuffer: &contentBuffer, 127 | toolCalls: &toolCalls, 128 | callback: callback 129 | ) 130 | } 131 | } 132 | 133 | return ModelResponse(content: contentBuffer, toolCalls: toolCalls) 134 | } 135 | 136 | private func processStreamChunk( 137 | _ dataContent: Substring, 138 | contentBuffer: inout String, 139 | toolCalls: inout [ModelResponse.ToolCall], 140 | callback: @escaping (ModelStreamEvent) async -> Void 141 | ) async throws { 142 | do { 143 | let chunkData = Data(dataContent.utf8) 144 | let chunkResponse = try JSONDecoder().decode(ChatCompletionChunk.self, from: chunkData) 145 | 146 | guard let choice = chunkResponse.choices.first else { return } 147 | 148 | if let content = choice.delta.content, !content.isEmpty { 149 | contentBuffer += content 150 | await callback(.content(content)) 151 | } 152 | 153 | if let toolCall = choice.delta.toolCalls?.first { 154 | await processToolCallDelta(toolCall, toolCalls: &toolCalls, callback: callback) 155 | } 156 | } catch { 157 | // Ignore partial JSON errors 158 | } 159 | } 160 | 161 | private func processToolCallDelta( 162 | _ toolCall: ChatCompletionChunk.ToolCall, 163 | toolCalls: inout [ModelResponse.ToolCall], 164 | callback: @escaping (ModelStreamEvent) async -> Void 165 | ) async { 166 | if let existingIndex = toolCalls.firstIndex(where: { $0.id == toolCall.id }) { 167 | updateExistingToolCall(at: existingIndex, with: toolCall, toolCalls: &toolCalls) 168 | } else if let id = toolCall.id, let function = toolCall.function, let name = function.name { 169 | let newToolCall = createNewToolCall(id: id, function: function, name: name) 170 | toolCalls.append(newToolCall) 171 | await callback(.toolCall(newToolCall)) 172 | } 173 | } 174 | 175 | private func updateExistingToolCall( 176 | at index: Int, 177 | with toolCall: ChatCompletionChunk.ToolCall, 178 | toolCalls: inout [ModelResponse.ToolCall] 179 | ) { 180 | let existingToolCall = toolCalls[index] 181 | var params = existingToolCall.parameters 182 | 183 | guard let function = toolCall.function else { return } 184 | 185 | if let name = function.name { 186 | toolCalls[index] = ModelResponse.ToolCall(id: existingToolCall.id, name: name, parameters: params) 187 | } 188 | 189 | if let arguments = function.arguments { 190 | do { 191 | if let jsonData = arguments.data(using: .utf8), 192 | let jsonParams = try JSONSerialization.jsonObject(with: jsonData) as? 
[String: Any] { 193 | for (key, value) in jsonParams { 194 | params[key] = value 195 | } 196 | toolCalls[index] = ModelResponse.ToolCall( 197 | id: existingToolCall.id, 198 | name: existingToolCall.name, 199 | parameters: params 200 | ) 201 | } 202 | } catch { 203 | // Ignore parsing errors for partial JSON 204 | } 205 | } 206 | } 207 | 208 | private func createNewToolCall(id: String, function: ChatCompletionChunk.Function, name: String) -> ModelResponse.ToolCall { 209 | var params: [String: Any] = [:] 210 | 211 | if let arguments = function.arguments { 212 | do { 213 | if let jsonData = arguments.data(using: .utf8), 214 | let jsonParams = try JSONSerialization.jsonObject(with: jsonData) as? [String: Any] { 215 | params = jsonParams 216 | } 217 | } catch { 218 | // Ignore parsing errors for partial JSON 219 | } 220 | } 221 | 222 | return ModelResponse.ToolCall(id: id, name: name, parameters: params) 223 | } 224 | 225 | /// Creates a URLRequest configured with the appropriate headers 226 | /// - Parameter url: The URL string for the request 227 | /// - Returns: A configured URLRequest 228 | private func createURLRequest(url: String) -> URLRequest { 229 | var request = URLRequest(url: URL(string: url)!) 230 | request.httpMethod = "POST" 231 | request.addValue("Bearer \(apiKey)", forHTTPHeaderField: "Authorization") 232 | request.addValue("application/json", forHTTPHeaderField: "Content-Type") 233 | request.timeoutInterval = 600 // 10 minute timeout 234 | return request 235 | } 236 | 237 | /// Creates a request body for the OpenAI API 238 | /// - Parameters: 239 | /// - messages: The messages to send to the model 240 | /// - settings: The settings to use for the model call 241 | /// - Returns: The request body 242 | private func createRequestBody(messages: [Message], settings: ModelSettings) throws -> ChatCompletionRequest { 243 | // Convert messages to OpenAI format 244 | let openAIMessages = messages.map { message -> ChatMessage in 245 | let role = message.role.rawValue 246 | 247 | switch message.content { 248 | case .text(let text): 249 | return ChatMessage(role: role, content: text) 250 | 251 | case .toolResults(let toolResult): 252 | return ChatMessage( 253 | role: role, 254 | toolCallId: toolResult.toolCallId, 255 | content: toolResult.result 256 | ) 257 | } 258 | } 259 | 260 | // Convert tools to OpenAI format 261 | let tools: [OpenAITool]? 
225 |     /// Creates a URLRequest configured with the appropriate headers
226 |     /// - Parameter url: The URL string for the request
227 |     /// - Returns: A configured URLRequest
228 |     private func createURLRequest(url: String) -> URLRequest {
229 |         var request = URLRequest(url: URL(string: url)!)
230 |         request.httpMethod = "POST"
231 |         request.addValue("Bearer \(apiKey)", forHTTPHeaderField: "Authorization")
232 |         request.addValue("application/json", forHTTPHeaderField: "Content-Type")
233 |         request.timeoutInterval = 600 // 10 minute timeout
234 |         return request
235 |     }
236 | 
237 |     /// Creates a request body for the OpenAI API
238 |     /// - Parameters:
239 |     ///   - messages: The messages to send to the model
240 |     ///   - settings: The settings to use for the model call
241 |     /// - Returns: The request body
242 |     private func createRequestBody(messages: [Message], settings: ModelSettings) throws -> ChatCompletionRequest {
243 |         // Convert messages to OpenAI format
244 |         let openAIMessages = messages.map { message -> ChatMessage in
245 |             let role = message.role.rawValue
246 | 
247 |             switch message.content {
248 |             case .text(let text):
249 |                 return ChatMessage(role: role, content: text)
250 | 
251 |             case .toolResults(let toolResult):
252 |                 return ChatMessage(
253 |                     role: role,
254 |                     toolCallId: toolResult.toolCallId,
255 |                     content: toolResult.result
256 |                 )
257 |             }
258 |         }
259 | 
260 |         // Tool conversion to the OpenAI format is not implemented yet, so no tools are sent
261 |         let tools: [OpenAITool]? = nil
262 | 
263 |         // Create request body
264 |         var request = ChatCompletionRequest(
265 |             model: settings.modelName,
266 |             messages: openAIMessages,
267 |             tools: tools
268 |         )
269 | 
270 |         // Add optional parameters from settings
271 |         if let temperature = settings.temperature {
272 |             request.temperature = temperature
273 |         }
274 | 
275 |         if let topP = settings.topP {
276 |             request.topP = topP
277 |         }
278 | 
279 |         if let maxTokens = settings.maxTokens {
280 |             request.maxTokens = maxTokens
281 |         }
282 | 
283 |         if let responseFormat = settings.responseFormat {
284 |             request.responseFormat = ["type": responseFormat.jsonValue]
285 |         }
286 | 
287 |         if let seed = settings.seed {
288 |             request.seed = seed
289 |         }
290 | 
291 |         // Note: settings.additionalParameters are not forwarded yet. ChatCompletionRequest
292 |         // is strongly typed, so arbitrary extra keys cannot be attached without a more
293 |         // flexible encoding step; the values are intentionally ignored here rather than
294 |         // being encoded incorrectly.
295 | 
296 |         return request
297 |     }
298 | 
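    // Illustrative (hedged) shape of an encoded request body produced above; the
    // exact key casing depends on how the JSONEncoder that serializes this request
    // is configured by the caller:
    //
    //   {
    //     "model": "gpt-4-turbo",
    //     "messages": [{"role": "user", "content": "Hello"}],
    //     "temperature": 0.7,
    //     "stream": false
    //   }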
299 |     /// Converts an OpenAI response to a ModelResponse
300 |     /// - Parameter response: The OpenAI response
301 |     /// - Returns: The converted ModelResponse
302 |     private func convertResponse(_ response: ChatCompletionResponse) throws -> ModelResponse {
303 |         guard let choice = response.choices.first else {
304 |             throw OpenAIModelError.emptyResponse
305 |         }
306 | 
307 |         // Get content
308 |         let content = choice.message.content ?? ""
309 | 
310 |         // Get tool calls if any
311 |         var toolCalls: [ModelResponse.ToolCall] = []
312 | 
313 |         if let openAIToolCalls = choice.message.toolCalls {
314 |             for toolCall in openAIToolCalls {
315 |                 do {
316 |                     let arguments = toolCall.function.arguments
317 |                     let argsData = arguments.data(using: .utf8) ?? Data()
318 |                     let params = try JSONSerialization.jsonObject(with: argsData) as? [String: Any] ?? [:]
319 | 
320 |                     toolCalls.append(ModelResponse.ToolCall(
321 |                         id: toolCall.id,
322 |                         name: toolCall.function.name,
323 |                         parameters: params
324 |                     ))
325 |                 } catch {
326 |                     throw OpenAIModelError.invalidToolCallArguments(error)
327 |                 }
328 |             }
329 |         }
330 | 
331 |         // Get usage statistics
332 |         var usage: ModelResponse.Usage?
333 |         if let responseUsage = response.usage {
334 |             usage = ModelResponse.Usage(
335 |                 promptTokens: responseUsage.promptTokens,
336 |                 completionTokens: responseUsage.completionTokens,
337 |                 totalTokens: responseUsage.totalTokens
338 |             )
339 |         }
340 | 
341 |         return ModelResponse(
342 |             content: content,
343 |             toolCalls: toolCalls,
344 |             usage: usage
345 |         )
346 |     }
347 | 
348 |     /// Errors that can occur when using the OpenAI model
349 |     public enum OpenAIModelError: Error {
350 |         case requestFailed(statusCode: Int, message: String)
351 |         case emptyResponse
352 |         case invalidToolCallArguments(Error)
353 |     }
354 | 
355 | }
356 | 
357 | // MARK: - OpenAI API Types
358 | 
359 | /// Request for the OpenAI chat completions API
360 | private struct ChatCompletionRequest: Encodable {
361 |     let model: String
362 |     let messages: [ChatMessage]
363 |     let tools: [OpenAITool]?
364 |     var temperature: Double?
365 |     var topP: Double?
366 |     var maxTokens: Int?
367 |     var responseFormat: [String: String]?
368 |     var seed: Int?
369 |     var stream: Bool = false
370 | }
371 | 
372 | /// Tool for the OpenAI chat completions API
373 | private struct OpenAITool: Encodable {
374 |     let type: String
375 |     let function: FunctionDefinition
376 | }
377 | 
378 | /// Message for the OpenAI chat completions API
379 | private struct ChatMessage: Encodable {
380 |     let role: String
381 |     var content: String?
382 |     var toolCallId: String?
383 | 
384 |     init(role: String, content: String) {
385 |         self.role = role
386 |         self.content = content
387 |     }
388 | 
389 |     init(role: String, toolCallId: String, content: String) {
390 |         self.role = role
391 |         self.toolCallId = toolCallId
392 |         self.content = content
393 |     }
394 | }
395 | 
396 | /// Function definition for the OpenAI chat completions API
397 | private struct FunctionDefinition: Encodable {
398 |     let name: String
399 |     let description: String
400 |     let parameters: [String: Any]
401 | 
402 |     enum CodingKeys: String, CodingKey {
403 |         case name, description, parameters
404 |     }
405 | 
406 |     func encode(to encoder: Encoder) throws {
407 |         var container = encoder.container(keyedBy: CodingKeys.self)
408 |         try container.encode(name, forKey: .name)
409 |         try container.encode(description, forKey: .description)
410 | 
411 |         // Encodes the parameters dictionary as a raw JSON string. The OpenAI API expects a JSON object here, so this needs a structured encoding once tools are actually sent.
412 |         let parametersData = try JSONSerialization.data(withJSONObject: parameters)
413 |         let parametersString = String(data: parametersData, encoding: .utf8) ?? "{}"
414 |         try container.encode(parametersString, forKey: .parameters)
415 |     }
416 | }
417 | 
418 | /// Response from the OpenAI chat completions API
419 | private struct ChatCompletionResponse: Decodable {
420 |     let id: String
421 |     let object: String
422 |     let created: Int
423 |     let model: String
424 |     let choices: [Choice]
425 |     let usage: Usage?
426 | 
427 |     struct Choice: Decodable {
428 |         let index: Int
429 |         let message: Message
430 |         let finishReason: String
431 |     }
432 | 
433 |     struct Message: Decodable {
434 |         let role: String
435 |         let content: String?
436 |         let toolCalls: [ToolCall]?
437 |     }
438 | 
439 |     struct ToolCall: Decodable {
440 |         let id: String
441 |         let type: String
442 |         let function: Function
443 |     }
444 | 
445 |     struct Function: Decodable {
446 |         let name: String
447 |         let arguments: String
448 |     }
449 | 
450 |     struct Usage: Decodable {
451 |         let promptTokens: Int
452 |         let completionTokens: Int
453 |         let totalTokens: Int
454 |     }
455 | }
456 | 
457 | /// Chunk response from the OpenAI chat completions API when streaming
458 | private struct ChatCompletionChunk: Decodable {
459 |     let id: String
460 |     let object: String
461 |     let created: Int
462 |     let model: String
463 |     let choices: [Choice]
464 | 
465 |     struct Choice: Decodable {
466 |         let index: Int
467 |         let delta: Delta
468 |         let finishReason: String?
469 |     }
470 | 
471 |     struct Delta: Decodable {
472 |         let role: String?
473 |         let content: String?
474 |         let toolCalls: [ToolCall]?
475 | 
476 |         enum CodingKeys: String, CodingKey {
477 |             case role, content
478 |             case toolCalls = "tool_calls"
479 |         }
480 |     }
481 | 
482 |     struct ToolCall: Decodable {
483 |         let id: String?
484 |         let type: String?
485 |         let function: Function?
486 | 
487 |         enum CodingKeys: String, CodingKey {
488 |             case id, type, function
489 |         }
490 |     }
491 | 
492 |     struct Function: Decodable {
493 |         let name: String?
494 |         let arguments: String?
495 | 
496 |         enum CodingKeys: String, CodingKey {
497 |             case name, arguments
498 |         }
499 |     }
500 | }
501 | 
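// Usage sketch for the extension below (hedged; how a ModelProvider instance is
// obtained is defined elsewhere in the SDK, not in this file):
//
//   let provider: ModelProvider = ...
//   provider.registerOpenAIModels(apiKey: ProcessInfo.processInfo.environment["OPENAI_API_KEY"] ?? "")
//
// After registration the provider can construct an OpenAIModel for the model names
// "gpt-4-turbo", "gpt-4", and "gpt-3.5-turbo" via the factories registered below.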
502 | /// Extension to register OpenAI models with the model provider
503 | public extension ModelProvider {
504 |     /// Registers OpenAI models with the model provider
505 |     /// - Parameter apiKey: The API key for OpenAI
506 |     func registerOpenAIModels(apiKey: String) {
507 |         // Register default OpenAI models
508 |         register(modelName: "gpt-4-turbo") {
509 |             OpenAIModel(apiKey: apiKey)
510 |         }
511 | 
512 |         register(modelName: "gpt-4") {
513 |             OpenAIModel(apiKey: apiKey)
514 |         }
515 | 
516 |         register(modelName: "gpt-3.5-turbo") {
517 |             OpenAIModel(apiKey: apiKey)
518 |         }
519 |     }
520 | }
521 | 
--------------------------------------------------------------------------------