├── .github └── pull_request_template.md ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── ai.go ├── ai_test.go ├── aitesting ├── aitesting.go ├── aitesting_test.go └── types.go ├── api ├── embedding_model.go ├── embedding_options.go ├── err_ai_sdk.go ├── err_api_call.go ├── err_empty_response_body.go ├── err_invalid_argument.go ├── err_invalid_prompt.go ├── err_invalid_response_data.go ├── err_json_parse.go ├── err_load_api_key.go ├── err_load_setting.go ├── err_no_content_generated.go ├── err_no_such_model.go ├── err_too_many_embedding_values_for_call.go ├── err_type_validation.go ├── err_unsupported_functionality.go ├── image_model.go ├── image_model_call_options.go ├── image_model_call_warning.go ├── llm_call_options.go ├── llm_call_warning.go ├── llm_events.go ├── llm_finish_reason.go ├── llm_language_model.go ├── llm_logprobs.go ├── llm_object_generation_mode.go ├── llm_prompt.go ├── llm_provider_metadata.go ├── llm_source.go ├── llm_tool.go └── provider.go ├── builder ├── builder.go ├── builder_test.go └── convert.go ├── default.go ├── examples ├── README.md └── basic │ └── simple-text │ └── main.go ├── go.mod ├── go.sum ├── options.go ├── options_test.go └── provider ├── anthropic ├── codec │ ├── decode.go │ ├── decode_test.go │ ├── encode_params.go │ ├── encode_prompt.go │ ├── encode_prompt_test.go │ ├── encode_tools.go │ ├── encode_tools_test.go │ ├── merge_messages.go │ ├── merge_messages_test.go │ ├── metadata.go │ └── tools.go ├── constants.go ├── llm.go ├── llm_test.go ├── plan │ ├── api.md │ ├── llm.ts │ └── prompt.md └── tools.go ├── internal └── openrouter │ ├── client │ ├── finish_reason.go │ ├── logprob.go │ ├── prompt.go │ ├── prompt_json.go │ ├── prompt_test.go │ ├── response.go │ ├── settings.go │ └── settings_test.go │ ├── codec │ ├── decode_finish.go │ ├── decode_finish_test.go │ ├── decode_logprobs.go │ ├── decode_logprobs_test.go │ ├── encode_prompt.go │ └── encode_prompt_test.go │ ├── 
convert_to_openrouter_completion_prompt.go │ ├── convert_to_openrouter_completion_prompt_test.go │ ├── model │ ├── doc.go │ └── model.go │ ├── openrouter_chat_language_model.go │ ├── openrouter_chat_language_model_test.go │ ├── openrouter_error.go │ ├── openrouter_error_test.go │ ├── openrouter_provider.go │ └── ptr.go └── openai ├── constants.go ├── internal └── codec │ ├── decode.go │ ├── decode_stream.go │ ├── decode_stream_test.go │ ├── decode_test.go │ ├── encode.go │ ├── encode_prompt.go │ ├── encode_prompt_test.go │ ├── encode_tools.go │ ├── encode_tools_test.go │ ├── metadata.go │ └── tools.go ├── llm.go ├── llm_test.go ├── metadata.go └── tools.go /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | # ⚠️ Submit PRs in our opensource monorepo instead ⚠️ 2 | 3 | This repository is automatically published from our opensource monorepo: 4 | https://github.com/jetify-com/opensource 5 | 6 | If you want to contribute code changes to this project, please submit your 7 | PR via the monorepo. 8 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our 6 | community a harassment-free experience for everyone, regardless of age, body 7 | size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, socio-economic status, 9 | nationality, personal appearance, race, caste, color, religion, or sexual 10 | identity and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community. 
14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | - Demonstrating empathy and kindness toward other people 21 | - Being respectful of differing opinions, viewpoints, and experiences 22 | - Giving and gracefully accepting constructive feedback 23 | - Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | - Focusing on what is best not just for us as individuals, but for the overall 26 | community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | - The use of sexualized language or imagery, and sexual attention or advances of 31 | any kind 32 | - Trolling, insulting or derogatory comments, and personal or political attacks 33 | - Public or private harassment 34 | - Publishing others' private information, such as a physical or email address, 35 | without their explicit permission 36 | - Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards of 42 | acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies when 54 | an individual is officially representing the community in public spaces. 
55 | Examples of representing our community include using an official e-mail address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement. Use the 63 | "Report to repository admins" functionality on GitHub to report. 64 | 65 | All complaints will be reviewed and investigated promptly and fairly. 66 | 67 | All community leaders are obligated to respect the privacy and security of the 68 | reporter of any incident. 69 | 70 | ## Enforcement Guidelines 71 | 72 | Community leaders will follow these Community Impact Guidelines in determining 73 | the consequences for any action they deem in violation of this Code of Conduct: 74 | 75 | ### 1. Correction 76 | 77 | **Community Impact**: Use of inappropriate language or other behavior deemed 78 | unprofessional or unwelcome in the community. 79 | 80 | **Consequence**: A private, written warning from community leaders, providing 81 | clarity around the nature of the violation and an explanation of why the 82 | behavior was inappropriate. A public apology may be requested. 83 | 84 | ### 2. Warning 85 | 86 | **Community Impact**: A violation through a single incident or series of 87 | actions. 88 | 89 | **Consequence**: A warning with consequences for continued behavior. No 90 | interaction with the people involved, including unsolicited interaction with 91 | those enforcing the Code of Conduct, for a specified period of time. This 92 | includes avoiding interactions in community spaces as well as external channels 93 | like social media. Violating these terms may lead to a temporary or permanent 94 | ban. 95 | 96 | ### 3. Temporary Ban 97 | 98 | **Community Impact**: A serious violation of community standards, including 99 | sustained inappropriate behavior. 
100 | 101 | **Consequence**: A temporary ban from any sort of interaction or public 102 | communication with the community for a specified period of time. No public or 103 | private interaction with the people involved, including unsolicited interaction 104 | with those enforcing the Code of Conduct, is allowed during this period. 105 | Violating these terms may lead to a permanent ban. 106 | 107 | ### 4. Permanent Ban 108 | 109 | **Community Impact**: Demonstrating a pattern of violation of community 110 | standards, including sustained inappropriate behavior, harassment of an 111 | individual, or aggression toward or disparagement of classes of individuals. 112 | 113 | **Consequence**: A permanent ban from any sort of public interaction within the 114 | community. 115 | 116 | ## Attribution 117 | 118 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 119 | version 2.1, available at 120 | [https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1]. 121 | 122 | Community Impact Guidelines were inspired by 123 | [Mozilla's code of conduct enforcement ladder][mozilla coc]. 124 | 125 | For answers to common questions about this code of conduct, see the FAQ at 126 | [https://www.contributor-covenant.org/faq][faq]. Translations are available at 127 | [https://www.contributor-covenant.org/translations][translations]. 
128 | 129 | [homepage]: https://www.contributor-covenant.org 130 | [v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html 131 | [mozilla coc]: https://github.com/mozilla/diversity 132 | [faq]: https://www.contributor-covenant.org/faq 133 | [translations]: https://www.contributor-covenant.org/translations 134 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | When contributing to this repository, please describe the change you wish to 4 | make via a related issue, or a pull request. 5 | 6 | Please note we have a [code of conduct](CODE_OF_CONDUCT.md), please follow it in 7 | all your interactions with the project. 8 | 9 | ## Opening a Pull Request 10 | 11 | This project is published as a standalone repo from our 12 | [opensource monorepo](https://github.com/jetify-com/opensource). Pull requests 13 | should be sent to the monorepo instead, and they will automatically be published 14 | to this repo when merged. 15 | 16 | Contributions made to this project must be made under the terms of the 17 | [Apache 2 License](https://www.apache.org/licenses/LICENSE-2.0). 18 | By contributing to this project you agree to the terms stated in the 19 | [Community Contribution License](https://github.com/jetify-com/opensource/blob/main/CONTRIBUTING.md#community-contribution-license). 20 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AI SDK for Go 2 | 3 | ### Build powerful AI applications and agents using a unified API. 
4 | 5 | [![Version](https://img.shields.io/github/v/release/jetify-com/ai?color=green&label=version&sort=semver)](https://github.com/jetify-com/ai/releases) 6 | [![Go Reference](https://pkg.go.dev/badge/go.jetify.com/ai)](https://pkg.go.dev/go.jetify.com/ai) 7 | [![License](https://img.shields.io/github/license/jetify-com/ai)](LICENSE) 8 | [![Join Discord](https://img.shields.io/discord/903306922852245526?color=7389D8&label=discord&logo=discord&logoColor=ffffff&cacheSeconds=1800)](https://discord.gg/jetify) 9 | 10 | *Primary Author(s)*: [Daniel Loreto](https://github.com/loreto) 11 | 12 | ## Introduction 13 | 14 | Jetify's **AI SDK for Go** is a unified interface for interacting with multiple AI providers including OpenAI, Anthropic, and more. 15 | Inspired by [Vercel's AI SDK](https://github.com/vercel/ai) for TypeScript, we bring a similar developer experience to the Go ecosystem. 16 | 17 | It is maintained and developed by [Jetify](https://www.jetify.com). We are in the process of migrating our production code 18 | to use this SDK as the primary way our AI agents integrate with different LLM providers. 19 | 20 | ### The Problem 21 | 22 | Building AI applications in Go today means dealing with: 23 | - **Fragmented ecosystems** - Each provider has different APIs, authentication, and patterns 24 | - **Vendor lock-in** - Switching providers requires rewriting significant application code 25 | - **Poor Go developer experience** - Official Go SDKs are often auto-generated from OpenAPI specs, resulting in unidiomatic Go code 26 | - **Complex multi-modal handling** - Different providers handle images, files, and tools differently 27 | 28 | ### Our Solution 29 | 30 | The AI SDK provides a **unified interface** across multiple AI providers, with key advantages: 31 | 32 | 1. **Provider abstraction** - Common interfaces for language models, embeddings, and image generation 33 | 2.
**Go-first design** - Built specifically for Go developers with idiomatic patterns and strong typing 34 | 3. **Production-ready** - Comprehensive error handling, automatic retries, rate limiting, and robust provider failover 35 | 4. **Multi-modal by default** - First-class support for text, images, files, and structured outputs across all providers 36 | 5. **Extensible architecture** - Clean interfaces make it easy to add new providers while maintaining backward compatibility 37 | 38 | ## Features 39 | 40 | * [x] **Multi-Provider Support** – [OpenAI](provider/openai), [Anthropic](provider/anthropic), with more coming 41 | * [x] **Multi-Modal Inputs** – Text, images, and files in conversations 42 | * [x] **Tool Calling** – Function calling with parallel execution 43 | * [x] **Language Models** – Text generation with streaming support 44 | * [ ] **Embedding Models** – Text embeddings for semantic search 45 | * [ ] **Image Models** – Generate images from text prompts 46 | * [ ] **Structured Outputs** – JSON generation with schema validation 47 | 48 | ### Language Models 49 | 50 | * [x] Text generation (streaming & non-streaming) 51 | * [x] Multi-modal conversations (text + images + files) 52 | * [x] System messages and conversation history 53 | * [x] Tool/function calling with structured schemas 54 | * [ ] JSON output with schema validation 55 | 56 | ### Provider-Specific Features 57 | 58 | * [x] **OpenAI** - Web search, computer use, file search tools 59 | * [x] **Anthropic** - Claude's advanced reasoning and tool use 60 | 61 | ## Status 62 | 63 | - [x] Private Alpha: We are testing the SDK with a select group of developers. 64 | - [x] Public Alpha: Open to all developers, but breaking changes still expected. 65 | - [ ] Public Beta: Stable enough for most non-enterprise use cases. 66 | - [ ] General Availability (v1): Ready for production use at scale with guaranteed API stability. 67 | 68 | We are currently in **Public Alpha**.
The SDK functionality is stable but the API may have breaking changes. While in alpha, minor version bumps indicate breaking changes (`0.1.0` -> `0.2.0` would indicate a breaking change). Watch "releases" of this repo to get notified of major updates. 69 | 70 | ## Installation 71 | 72 | ```bash 73 | go get go.jetify.com/ai 74 | ``` 75 | 76 | ## Quickstart 77 | 78 | Get started with a simple text generation example: 79 | 80 | ```go 81 | package main 82 | 83 | import ( 84 | "context" 85 | "fmt" 86 | "log" 87 | 88 | "go.jetify.com/ai" 89 | "go.jetify.com/ai/provider/openai" 90 | ) 91 | 92 | func main() { 93 | // Set up your model 94 | model := openai.NewLanguageModel("gpt-4o") 95 | 96 | // Generate text 97 | response, err := ai.GenerateTextStr( 98 | context.Background(), 99 | "Explain quantum computing in simple terms", 100 | ai.WithModel(model), 101 | ai.WithMaxTokens(200), 102 | ) 103 | if err != nil { 104 | log.Fatal(err) 105 | } 106 | 107 | fmt.Println(response) // Do whatever you want with the response... 108 | } 109 | ``` 110 | 111 | For detailed examples, see our [examples directory](examples/). 112 | 113 | ## Documentation 114 | 115 | Comprehensive documentation is available: 116 | 117 | * **[API Reference](https://pkg.go.dev/go.jetify.com/ai)** - Complete Go package documentation 118 | * **[Examples](examples/)** - Real-world usage patterns 119 | 120 | ## Community & Support 121 | 122 | Join our community and get help: 123 | 124 | * **Discord** – [https://discord.gg/jetify](https://discord.gg/jetify) (best for quick questions & showcase) 125 | * **GitHub Discussions** – [Discussions](https://github.com/jetify-com/ai/discussions) (best for ideas & design questions) 126 | * **Issues** – [Bug reports & feature requests](https://github.com/jetify-com/ai/issues) 127 | 128 | ## Contributing 129 | 130 | We 💖 contributions! Please read [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines.
131 | 132 | ## License 133 | 134 | Licensed under the **Apache 2.0 License** – see [LICENSE](LICENSE) for details. 135 | -------------------------------------------------------------------------------- /ai.go: -------------------------------------------------------------------------------- 1 | package ai 2 | 3 | import ( 4 | "context" 5 | 6 | "go.jetify.com/ai/api" 7 | ) 8 | 9 | // GenerateText uses a language model to generate a text response from a given prompt. 10 | // 11 | // This function does not stream its output. 12 | // 13 | // It returns a [api.Response] containing the generated text, the results of 14 | // any tool calls, and additional information. 15 | // 16 | // A prompt is a sequence of [api.Message]s: 17 | // 18 | // GenerateText(ctx, []api.Message{ 19 | // &api.UserMessage{ 20 | // Content: []api.ContentBlock{ 21 | // &api.TextBlock{Text: "Show me a picture of a cat"}, 22 | // }, 23 | // }, 24 | // &api.AssistantMessage{ 25 | // Content: []api.ContentBlock{ 26 | // &api.TextBlock{Text: "Here is a picture of a cat"}, 27 | // &api.ImageBlock{URL: "https://example.com/cat.png"}, 28 | // }, 29 | // }, 30 | // }) 31 | // 32 | // The last argument can optionally be a series of [GenerateOption] arguments: 33 | // 34 | // GenerateText(ctx, messages, WithMaxTokens(100)) 35 | func GenerateText(ctx context.Context, prompt []api.Message, opts ...GenerateOption) (api.Response, error) { 36 | config := buildGenerateConfig(opts) 37 | return generate(ctx, prompt, config) 38 | } 39 | 40 | // GenerateTextStr uses a language model to generate a text response from a given string prompt. 41 | // 42 | // It is a convenience wrapper around GenerateText for simple string-based prompts.
43 | // 44 | // Example usage: 45 | // 46 | // GenerateTextStr(ctx, "Write a brief summary of the benefits of renewable energy") 47 | // 48 | // The function can optionally take [GenerateOption] arguments: 49 | // 50 | // GenerateTextStr(ctx, "Explain the key differences between REST and GraphQL APIs", WithMaxTokens(500)) 51 | // 52 | // The string prompt is automatically converted to a [api.UserMessage] before 53 | // being passed to GenerateText. 54 | func GenerateTextStr(ctx context.Context, prompt string, opts ...GenerateOption) (api.Response, error) { 55 | msg := api.UserMessage{ 56 | Content: []api.ContentBlock{api.TextBlock{Text: prompt}}, 57 | } 58 | return GenerateText(ctx, []api.Message{msg}, opts...) 59 | } 60 | // generate forwards the prompt and the resolved call options to opts.Model.Generate. NOTE(review): opts.Model is not nil-checked here; confirm buildGenerateConfig always supplies a model, otherwise this call will likely panic. 61 | func generate(ctx context.Context, prompt []api.Message, opts GenerateOptions) (api.Response, error) { 62 | return opts.Model.Generate(ctx, prompt, opts.CallOptions) 63 | } 64 | -------------------------------------------------------------------------------- /ai_test.go: -------------------------------------------------------------------------------- 1 | package ai 2 | -------------------------------------------------------------------------------- /aitesting/types.go: -------------------------------------------------------------------------------- 1 | package aitesting 2 | 3 | import "go.jetify.com/ai/api" 4 | 5 | // MockMetadataSource implements api.MetadataSource for testing 6 | type MockMetadataSource struct { 7 | ProviderMetadata *api.ProviderMetadata 8 | } 9 | // GetProviderMetadata returns the ProviderMetadata field unchanged. 10 | func (m *MockMetadataSource) GetProviderMetadata() *api.ProviderMetadata { 11 | return m.ProviderMetadata 12 | } 13 | 14 | // MockUnsupportedMessage implements api.Message but is not a known type 15 | type MockUnsupportedMessage struct{} 16 | // Role and GetProviderMetadata return fixed placeholder values: the role "unsupported" and nil metadata. 17 | func (m *MockUnsupportedMessage) Role() api.MessageRole { return "unsupported" } 18 | func (m *MockUnsupportedMessage) GetProviderMetadata() *api.ProviderMetadata { return nil } 19 | 20 | // MockUnsupportedBlock implements
api.ContentBlock for testing unsupported content types 21 | type MockUnsupportedBlock struct{} 22 | 23 | func (m *MockUnsupportedBlock) Type() api.ContentBlockType { return "unsupported" } 24 | func (m *MockUnsupportedBlock) GetProviderMetadata() *api.ProviderMetadata { return nil } 25 | -------------------------------------------------------------------------------- /api/embedding_model.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import ( 4 | "context" 5 | ) 6 | 7 | // Embedding is a vector, i.e. an array of numbers. 8 | // It is e.g. used to represent a text as a vector of word embeddings. 9 | type Embedding []float64 10 | 11 | // EmbeddingModel is a specification for an embedding model that implements the embedding model 12 | // interface version 1. 13 | // 14 | // T is the type of the values that the model can embed. 15 | // This will allow us to go beyond text embeddings in the future, 16 | // e.g. to support image embeddings. 17 | type EmbeddingModel[T any] interface { 18 | // SpecificationVersion returns which embedding model interface version is implemented. 19 | // This will allow us to evolve the embedding model interface and retain backwards 20 | // compatibility. The different implementation versions can be handled as a discriminated 21 | // union on our side. 22 | SpecificationVersion() string 23 | 24 | // ProviderName returns the name of the provider for logging purposes. 25 | ProviderName() string 26 | 27 | // ModelID returns the provider-specific model ID for logging purposes. 28 | ModelID() string 29 | 30 | // MaxEmbeddingsPerCall returns the limit of how many embeddings can be generated in a single API call. NOTE(review): the pointer return presumably means nil = no limit; confirm with implementations. 31 | MaxEmbeddingsPerCall() *int 32 | 33 | // SupportsParallelCalls reports whether the model can handle multiple embedding calls in parallel. 34 | SupportsParallelCalls() bool 35 | 36 | // DoEmbed generates a list of embeddings for the given input values.
37 | // 38 | // Naming: "do" prefix to prevent accidental direct usage of the method 39 | // by the user. NOTE(review): DoEmbed has no error return, so implementations cannot report failures through it; confirm this is intended. 40 | DoEmbed(ctx context.Context, values []T, opts ...EmbeddingOption) EmbeddingResponse 41 | } 42 | 43 | // EmbeddingResponse represents the response from generating embeddings. 44 | type EmbeddingResponse struct { 45 | // Embeddings are the generated embeddings. They are in the same order as the input values. 46 | Embeddings []Embedding 47 | 48 | // Usage contains token usage information. We only have input tokens for embeddings. 49 | Usage *EmbeddingUsage 50 | 51 | // RawResponse contains optional raw response information for debugging purposes. 52 | RawResponse *EmbeddingRawResponse 53 | } 54 | 55 | // EmbeddingUsage represents token usage information. 56 | type EmbeddingUsage struct { 57 | Tokens int // input tokens (embeddings only report input tokens) 58 | } 59 | 60 | // EmbeddingRawResponse contains raw response information for debugging. 61 | type EmbeddingRawResponse struct { 62 | // Headers are the response headers. 63 | Headers map[string]string 64 | } 65 | -------------------------------------------------------------------------------- /api/embedding_options.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | // EmbeddingOption is a functional option that modifies an *EmbeddingOptions value. 4 | type EmbeddingOption func(*EmbeddingOptions) 5 | 6 | // WithEmbeddingHeaders sets HTTP headers to be sent with the request. 7 | // Only applicable for HTTP-based providers. The map is stored by reference, not copied. 8 | func WithEmbeddingHeaders(headers map[string]string) EmbeddingOption { 9 | return func(o *EmbeddingOptions) { 10 | o.Headers = headers 11 | } 12 | } 13 | 14 | // EmbeddingOptions represents the options for generating embeddings. 15 | type EmbeddingOptions struct { 16 | // Headers are additional HTTP headers to be sent with the request. 17 | // Only applicable for HTTP-based providers.
18 | Headers map[string]string 19 | } 20 | -------------------------------------------------------------------------------- /api/err_ai_sdk.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | // AISDKError is the common base error type of the AI SDK; more specific error types embed a pointer to it. 4 | type AISDKError struct { 5 | // Name identifies the kind of error 6 | Name string 7 | 8 | // Message is the error message; Error returns it verbatim 9 | Message string 10 | 11 | // Cause is the underlying cause of the error, if any. When it holds an error value, Unwrap exposes it to errors.Is and errors.As 12 | Cause any 13 | } 14 | 15 | // Error implements the error interface by returning Message 16 | func (e *AISDKError) Error() string { 17 | return e.Message 18 | } 19 | 20 | // Unwrap returns Cause when it is an error value and nil otherwise (also when e is nil). 21 | // Because it is promoted through embedding, errors.Is and errors.As can reach the cause 22 | // of every error type that embeds *AISDKError, not just AISDKError itself. 23 | func (e *AISDKError) Unwrap() error { 24 | if e == nil { 25 | return nil 26 | } 27 | err, _ := e.Cause.(error) 28 | return err 29 | } 30 | 31 | // NewAISDKError creates an AI SDK Error. 32 | // Parameters: 33 | // - name: The name of the error. 34 | // - message: The error message. 35 | // - cause: The underlying cause of the error (may be nil or a non-error value). 36 | func NewAISDKError(name string, message string, cause any) *AISDKError { 37 | return &AISDKError{ 38 | Name: name, 39 | Message: message, 40 | Cause: cause, 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /api/err_api_call.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import ( 4 | "net/http" 5 | "net/url" 6 | ) 7 | 8 | // APICallError represents an error that occurred during an API call 9 | type APICallError struct { 10 | *AISDKError 11 | 12 | // URL of the API endpoint that was called 13 | URL *url.URL 14 | 15 | // Request contains the original HTTP request 16 | Request *http.Request 17 | 18 | // StatusCode is the HTTP status code of the response if any 19 | StatusCode int 20 | 21 | // Response contains the original HTTP response 22 | Response *http.Response 23 | 24 | // Data contains additional error data, if any 25 | Data any 26 | } 27 | 28 | // IsRetryable indicates whether the request can be retried 29 | // Returns true for status codes: 408 (timeout), 409 (conflict), 429 (too many requests), or
5xx (server errors; in fact any status >= 500) 30 | func (e *APICallError) IsRetryable() bool { 31 | return e.StatusCode == http.StatusRequestTimeout || e.StatusCode == http.StatusConflict || e.StatusCode == http.StatusTooManyRequests || e.StatusCode >= 500 32 | } 33 | // NOTE(review): APICallError has no constructor yet, so callers must set the embedded *AISDKError; Error() reads Message through that pointer and will panic if it is nil. 34 | // TODO: 35 | // - Consider providing a constructor that takes a request and response, 36 | // and initializes the fields from the request and response. 37 | // - Better approach to handling Data 38 | // - Should http.Request be a shallow copy with headers removed? (it might 39 | // otherwise expose sensitive information) 40 | -------------------------------------------------------------------------------- /api/err_empty_response_body.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | // EmptyResponseBodyError indicates that the response body is empty 4 | type EmptyResponseBodyError struct { 5 | *AISDKError 6 | } 7 | 8 | // NewEmptyResponseBodyError creates a new EmptyResponseBodyError instance 9 | // Parameters: 10 | // - message: The error message (optional, defaults to "Empty response body") 11 | func NewEmptyResponseBodyError(message string) *EmptyResponseBodyError { 12 | if message == "" { 13 | message = "Empty response body" 14 | } 15 | return &EmptyResponseBodyError{ 16 | AISDKError: NewAISDKError("AI_EmptyResponseBodyError", message, nil), 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /api/err_invalid_argument.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | // InvalidArgumentError indicates that a function argument is invalid 4 | type InvalidArgumentError struct { 5 | *AISDKError 6 | 7 | // Argument is the name of the invalid argument 8 | Argument string 9 | } 10 | 11 | // NewInvalidArgumentError creates a new InvalidArgumentError instance 12 | // Parameters: 13 | // - message: The error message 14 | // - argument: The name of the invalid
argument 15 | // - cause: The underlying cause of the error (optional) 16 | func NewInvalidArgumentError(message string, argument string, cause any) *InvalidArgumentError { 17 | return &InvalidArgumentError{ 18 | AISDKError: NewAISDKError("AI_InvalidArgumentError", message, cause), 19 | Argument: argument, 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /api/err_invalid_prompt.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import "fmt" 4 | 5 | // InvalidPromptError indicates that a prompt is invalid. 6 | // This error should be returned by providers when they cannot process a prompt. 7 | type InvalidPromptError struct { 8 | *AISDKError 9 | 10 | // Prompt is the invalid prompt that caused the error 11 | Prompt any 12 | } 13 | 14 | // NewInvalidPromptError creates a new InvalidPromptError instance 15 | // Parameters: 16 | // - prompt: The invalid prompt 17 | // - message: The error message describing why the prompt is invalid; the stored message is prefixed with "Invalid prompt: " 18 | // - cause: The underlying cause of the error (optional) 19 | func NewInvalidPromptError(prompt any, message string, cause any) *InvalidPromptError { 20 | fullMessage := fmt.Sprintf("Invalid prompt: %s", message) 21 | return &InvalidPromptError{ 22 | AISDKError: NewAISDKError("AI_InvalidPromptError", fullMessage, cause), 23 | Prompt: prompt, 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /api/err_invalid_response_data.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | ) 7 | 8 | // InvalidResponseDataError indicates that the server returned a response with invalid data content. 9 | // This should be returned by providers when they cannot parse the response from the API.
10 | type InvalidResponseDataError struct { 11 | *AISDKError 12 | 13 | // Data is the invalid response data that caused the error 14 | Data any 15 | } 16 | 17 | // NewInvalidResponseDataError creates a new InvalidResponseDataError instance 18 | // Parameters: 19 | // - data: The invalid response data 20 | // - message: The error message (optional; if empty, a message embedding data is generated, as JSON when data can be marshaled and with %v formatting otherwise) 21 | func NewInvalidResponseDataError(data any, message string) *InvalidResponseDataError { 22 | if message == "" { 23 | if dataJSON, err := json.Marshal(data); err == nil { 24 | message = fmt.Sprintf("Invalid response data: %s", string(dataJSON)) 25 | } else { 26 | // data is not JSON-encodable (e.g. it contains a channel or func); fall back 27 | // to %v so the message still describes the data instead of ending blank. 28 | message = fmt.Sprintf("Invalid response data: %v", data) 29 | } 30 | } 31 | return &InvalidResponseDataError{ 32 | AISDKError: NewAISDKError("AI_InvalidResponseDataError", message, nil), 33 | Data: data, 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /api/err_json_parse.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import "fmt" 4 | 5 | // JSONParseError indicates a failure in parsing JSON 6 | type JSONParseError struct { 7 | *AISDKError 8 | 9 | // Text is the string that failed to parse as JSON 10 | Text string 11 | } 12 | 13 | // NewJSONParseError creates a new JSONParseError instance 14 | // Parameters: 15 | // - text: The text that failed to parse as JSON 16 | // - cause: The underlying parsing error 17 | func NewJSONParseError(text string, cause any) *JSONParseError { 18 | message := fmt.Sprintf("JSON parsing failed: Text: %s.\nError message: %v", text, cause) 19 | return &JSONParseError{ 20 | AISDKError: NewAISDKError("AI_JSONParseError", message, cause), 21 | Text: text, 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /api/err_load_api_key.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | // LoadAPIKeyError indicates a failure in loading an API key 4 | type LoadAPIKeyError struct { 5 | *AISDKError 6
| } 7 | 8 | // NewLoadAPIKeyError creates a new LoadAPIKeyError instance 9 | // Parameters: 10 | // - message: The error message describing why the API key failed to load 11 | func NewLoadAPIKeyError(message string) *LoadAPIKeyError { 12 | return &LoadAPIKeyError{ 13 | AISDKError: NewAISDKError("AI_LoadAPIKeyError", message, nil), 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /api/err_load_setting.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | // LoadSettingError indicates a failure in loading a setting 4 | type LoadSettingError struct { 5 | *AISDKError 6 | } 7 | 8 | // NewLoadSettingError creates a new LoadSettingError instance 9 | // Parameters: 10 | // - message: The error message describing why the setting failed to load 11 | func NewLoadSettingError(message string) *LoadSettingError { 12 | return &LoadSettingError{ 13 | AISDKError: NewAISDKError("AI_LoadSettingError", message, nil), 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /api/err_no_content_generated.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | // NoContentGeneratedError is returned when the AI provider fails to generate any content 4 | type NoContentGeneratedError struct { 5 | *AISDKError 6 | } 7 | 8 | // NewNoContentGeneratedError creates a new NoContentGeneratedError instance 9 | // Parameters: 10 | // - message: The error message (optional, defaults to "No content generated.") 11 | func NewNoContentGeneratedError(message string) *NoContentGeneratedError { 12 | if message == "" { 13 | message = "No content generated."
14 | } 15 | return &NoContentGeneratedError{ 16 | AISDKError: NewAISDKError("AI_NoContentGeneratedError", message, nil), 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /api/err_no_such_model.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import ( 4 | "fmt" 5 | ) 6 | 7 | // ModelType represents the type of AI model 8 | type ModelType string 9 | 10 | const ( 11 | // LanguageModelType represents a language model type 12 | LanguageModelType ModelType = "languageModel" 13 | // TextEmbeddingModelType represents a text embedding model type 14 | TextEmbeddingModelType ModelType = "textEmbeddingModel" 15 | // ImageModelType represents an image model type 16 | ImageModelType ModelType = "imageModel" 17 | ) 18 | 19 | // NoSuchModelError indicates that the requested model does not exist 20 | type NoSuchModelError struct { 21 | *AISDKError 22 | 23 | // ModelID is the identifier of the model that was not found 24 | ModelID string 25 | 26 | // ModelType is the type of model that was requested 27 | ModelType ModelType 28 | } 29 | 30 | // NewNoSuchModelError creates a new NoSuchModelError instance 31 | // Parameters: 32 | // - modelID: The identifier of the model that was not found 33 | // - modelType: The type of model that was requested 34 | func NewNoSuchModelError(modelID string, modelType ModelType) *NoSuchModelError { 35 | message := fmt.Sprintf("No such %s: %s", modelType, modelID) 36 | return &NoSuchModelError{ 37 | AISDKError: NewAISDKError("AI_NoSuchModelError", message, nil), 38 | ModelID: modelID, 39 | ModelType: modelType, 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /api/err_too_many_embedding_values_for_call.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import "fmt" 4 | 5 | // TooManyEmbeddingValuesForCallError indicates that too many 
values were provided for a single embedding call 6 | type TooManyEmbeddingValuesForCallError struct { 7 | *AISDKError 8 | 9 | // Provider is the name of the AI provider 10 | Provider string 11 | 12 | // ModelID is the identifier of the model 13 | ModelID string 14 | 15 | // MaxEmbeddingsPerCall is the maximum number of embeddings allowed per call 16 | MaxEmbeddingsPerCall int 17 | 18 | // Values are the embedding values that were provided 19 | Values []any 20 | } 21 | 22 | // NewTooManyEmbeddingValuesForCallError creates a new TooManyEmbeddingValuesForCallError instance 23 | // Parameters: 24 | // - provider: The name of the AI provider 25 | // - modelID: The identifier of the model 26 | // - maxEmbeddingsPerCall: The maximum number of embeddings allowed per call 27 | // - values: The embedding values that were provided 28 | func NewTooManyEmbeddingValuesForCallError(provider string, modelID string, maxEmbeddingsPerCall int, values []any) *TooManyEmbeddingValuesForCallError { 29 | message := fmt.Sprintf( 30 | "Too many values for a single embedding call. 
The %s model \"%s\" can only embed up to %d values per call, but %d values were provided.", 31 | provider, 32 | modelID, 33 | maxEmbeddingsPerCall, 34 | len(values), 35 | ) 36 | 37 | return &TooManyEmbeddingValuesForCallError{ 38 | AISDKError: NewAISDKError("AI_TooManyEmbeddingValuesForCallError", message, nil), 39 | Provider: provider, 40 | ModelID: modelID, 41 | MaxEmbeddingsPerCall: maxEmbeddingsPerCall, 42 | Values: values, 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /api/err_type_validation.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import ( 4 | "encoding/json" 5 | "errors" 6 | "fmt" 7 | ) 8 | 9 | // TypeValidationError represents a type validation failure 10 | type TypeValidationError struct { 11 | *AISDKError 12 | 13 | // Value is the value that failed validation 14 | Value any 15 | } 16 | 17 | // NewTypeValidationError creates a new TypeValidationError instance 18 | // Parameters: 19 | // - value: The value that failed validation 20 | // - cause: The original error or cause of the validation failure 21 | func NewTypeValidationError(value any, cause any) *TypeValidationError { 22 | valueJSON, _ := json.Marshal(value) 23 | message := fmt.Sprintf("Type validation failed: Value: %s.\nError message: %v", string(valueJSON), cause) 24 | return &TypeValidationError{ 25 | AISDKError: NewAISDKError("AI_TypeValidationError", message, cause), 26 | Value: value, 27 | } 28 | } 29 | 30 | // WrapTypeValidationError wraps an error into a TypeValidationError. 31 | // If the cause is already a TypeValidationError with the same value, it returns the cause. 32 | // Otherwise, it creates a new TypeValidationError. 
33 | // Parameters: 34 | // - value: The value that failed validation 35 | // - cause: The original error or cause of the validation failure 36 | // 37 | // Returns a TypeValidationError instance 38 | func WrapTypeValidationError(value any, cause error) *TypeValidationError { 39 | var existingErr *TypeValidationError 40 | if errors.As(cause, &existingErr) && existingErr.Value == value { 41 | return existingErr 42 | } 43 | return NewTypeValidationError(value, cause) 44 | } 45 | -------------------------------------------------------------------------------- /api/err_unsupported_functionality.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import "fmt" 4 | 5 | // UnsupportedFunctionalityError indicates that a requested functionality is not supported 6 | type UnsupportedFunctionalityError struct { 7 | *AISDKError 8 | 9 | // Functionality is the name of the unsupported functionality 10 | Functionality string 11 | } 12 | 13 | // NewUnsupportedFunctionalityError creates a new UnsupportedFunctionalityError instance 14 | // Parameters: 15 | // - functionality: The name of the unsupported functionality 16 | // - message: The error message (optional, will be auto-generated if empty) 17 | func NewUnsupportedFunctionalityError(functionality string, message string) *UnsupportedFunctionalityError { 18 | if message == "" { 19 | message = fmt.Sprintf("'%s' functionality not supported.", functionality) 20 | } 21 | return &UnsupportedFunctionalityError{ 22 | AISDKError: NewAISDKError("AI_UnsupportedFunctionalityError", message, nil), 23 | Functionality: functionality, 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /api/image_model.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import ( 4 | "context" 5 | "time" 6 | ) 7 | 8 | // ImageModel is a specification for an image generation model that implements 9 | // 
the image model interface version 1. 10 | type ImageModel interface { 11 | // SpecificationVersion returns which image model interface version is implemented. 12 | // This will allow us to evolve the image model interface and retain backwards 13 | // compatibility. The different implementation versions can be handled as a 14 | // discriminated union on our side. 15 | SpecificationVersion() string 16 | 17 | // ProviderName returns the name of the provider for logging purposes. 18 | ProviderName() string 19 | 20 | // ModelID returns the provider-specific model ID for logging purposes. 21 | ModelID() string 22 | 23 | // MaxImagesPerCall returns the limit of how many images can be generated in a single API call. 24 | // If undefined, we will max generate one image per call. 25 | MaxImagesPerCall() *int 26 | 27 | // DoGenerate generates an array of images based on the given prompt. 28 | DoGenerate(ctx context.Context, prompt string, opts ...ImageCallOption) ImageResponse 29 | } 30 | 31 | // ImageResponse represents the response from generating images. 32 | type ImageResponse struct { 33 | // Images are the generated images as base64 encoded strings or binary data. 34 | // The images should be returned without any unnecessary conversion. 35 | // If the API returns base64 encoded strings, the images should be returned 36 | // as base64 encoded strings. If the API returns binary data, the images should 37 | // be returned as binary data. 38 | Images []ImageData 39 | 40 | // Warnings for the call, e.g. unsupported settings. 41 | Warnings []ImageCallWarning 42 | 43 | // Response contains information for telemetry and debugging purposes. 
44 | Response ImageResponseMetadata 45 | } 46 | 47 | // ImageData represents either a base64 encoded string or binary data for an image 48 | type ImageData interface { 49 | // IsImageData is a marker method to ensure type safety 50 | IsImageData() 51 | } 52 | 53 | // Base64Image represents an image as a base64 encoded string 54 | type Base64Image string 55 | 56 | func (Base64Image) IsImageData() {} 57 | 58 | // BinaryImage represents an image as binary data 59 | type BinaryImage []byte 60 | 61 | func (BinaryImage) IsImageData() {} 62 | 63 | // ImageResponseMetadata contains response information for telemetry and debugging purposes. 64 | type ImageResponseMetadata struct { 65 | // Timestamp is the timestamp for the start of the generated response. 66 | Timestamp time.Time 67 | 68 | // ModelID is the ID of the response model that was used to generate the response. 69 | ModelID string 70 | 71 | // Headers are the response headers. 72 | Headers map[string]string 73 | } 74 | -------------------------------------------------------------------------------- /api/image_model_call_options.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | // ImageCallOptions represents the options for generating images. 4 | type ImageCallOptions struct { 5 | // N is the number of images to generate. 6 | N int 7 | 8 | // Size of the images to generate. 9 | // Must have the format `{width}x{height}`. 10 | // nil will use the provider's default size. 11 | Size *string 12 | 13 | // AspectRatio of the images to generate. 14 | // Must have the format `{width}:{height}`. 15 | // nil will use the provider's default aspect ratio. 16 | AspectRatio *string 17 | 18 | // Seed for the image generation. 19 | // nil will use the provider's default seed. 20 | Seed *int 21 | 22 | // ProviderOptions are additional provider-specific options that are passed through to the provider 23 | // as body parameters. 
24 | // 25 | // The outer map is keyed by the provider name, and the inner 26 | // map is keyed by the provider-specific metadata key. The value can be any JSON-compatible value 27 | // (string, number, boolean, null, array, or object). 28 | // Example: 29 | // { 30 | // "openai": { 31 | // "style": "vivid", 32 | // "quality": 1, 33 | // "hd": true, 34 | // "metadata": { 35 | // "user": "test" 36 | // } 37 | // } 38 | // } 39 | ProviderOptions map[string]map[string]any 40 | 41 | // Headers are additional HTTP headers to be sent with the request. 42 | // Only applicable for HTTP-based providers. 43 | Headers map[string]string 44 | } 45 | 46 | // ImageCallOption is a function that modifies ImageCallOptions. 47 | type ImageCallOption func(*ImageCallOptions) 48 | 49 | // WithImageCount sets the number of images to generate. 50 | // N is the number of images to generate. 51 | func WithImageCount(n int) ImageCallOption { 52 | return func(o *ImageCallOptions) { 53 | o.N = n 54 | } 55 | } 56 | 57 | // WithImageSize sets the size of images to generate. 58 | // Must have the format `{width}x{height}`. 59 | // nil will use the provider's default size. 60 | func WithImageSize(size string) ImageCallOption { 61 | return func(o *ImageCallOptions) { 62 | o.Size = &size 63 | } 64 | } 65 | 66 | // WithImageAspectRatio sets the aspect ratio of images to generate. 67 | // Must have the format `{width}:{height}`. 68 | // nil will use the provider's default aspect ratio. 69 | func WithImageAspectRatio(ratio string) ImageCallOption { 70 | return func(o *ImageCallOptions) { 71 | o.AspectRatio = &ratio 72 | } 73 | } 74 | 75 | // WithImageSeed sets the seed for image generation. 76 | // nil will use the provider's default seed. 77 | func WithImageSeed(seed int) ImageCallOption { 78 | return func(o *ImageCallOptions) { 79 | o.Seed = &seed 80 | } 81 | } 82 | 83 | // WithImageProviderOptions sets provider-specific options that are passed through to the provider 84 | // as body parameters. 
85 | // 86 | // The outer map is keyed by the provider name, and the inner 87 | // map is keyed by the provider-specific metadata key. The value can be any JSON-compatible value 88 | // (string, number, boolean, null, array, or object). 89 | // Example: 90 | // 91 | // { 92 | // "openai": { 93 | // "style": "vivid", 94 | // "quality": 1, 95 | // "hd": true, 96 | // "metadata": { 97 | // "user": "test" 98 | // } 99 | // } 100 | // } 101 | func WithImageProviderOptions(options map[string]map[string]any) ImageCallOption { 102 | return func(o *ImageCallOptions) { 103 | o.ProviderOptions = options 104 | } 105 | } 106 | 107 | // WithImageHeaders sets additional HTTP headers to be sent with the request. 108 | // Only applicable for HTTP-based providers. 109 | func WithImageHeaders(headers map[string]string) ImageCallOption { 110 | return func(o *ImageCallOptions) { 111 | o.Headers = headers 112 | } 113 | } 114 | -------------------------------------------------------------------------------- /api/image_model_call_warning.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | // ImageCallWarningType represents the type of warning from the model provider. 4 | type ImageCallWarningType string 5 | 6 | const ( 7 | // ImageCallWarningTypeUnsupportedSetting indicates a warning about an unsupported setting. 8 | ImageCallWarningTypeUnsupportedSetting ImageCallWarningType = "unsupported-setting" 9 | 10 | // ImageCallWarningTypeOther indicates a generic warning. 11 | ImageCallWarningTypeOther ImageCallWarningType = "other" 12 | ) 13 | 14 | // ImageCallWarning represents a warning from the model provider for this call. 15 | // The call will proceed, but e.g. some settings might not be supported, 16 | // which can lead to suboptimal results. 
17 | type ImageCallWarning interface { 18 | // isImageCallWarning is a marker method to ensure type safety 19 | isImageCallWarning() 20 | } 21 | 22 | // UnsupportedSettingWarning represents a warning about an unsupported setting. 23 | type UnsupportedSettingWarning struct { 24 | // Type is always ImageCallWarningTypeUnsupportedSetting for this warning 25 | Type ImageCallWarningType `json:"type"` 26 | 27 | // Setting is the name of the unsupported setting from ImageCallOptions 28 | Setting string `json:"setting"` 29 | 30 | // Details provides additional information about why the setting is unsupported 31 | Details *string `json:"details,omitzero"` 32 | } 33 | 34 | func (UnsupportedSettingWarning) isImageCallWarning() {} 35 | 36 | // NewUnsupportedSettingWarning creates a new UnsupportedSettingWarning with the given setting and optional details 37 | func NewUnsupportedSettingWarning(setting string, details *string) UnsupportedSettingWarning { 38 | return UnsupportedSettingWarning{ 39 | Type: ImageCallWarningTypeUnsupportedSetting, 40 | Setting: setting, 41 | Details: details, 42 | } 43 | } 44 | 45 | // OtherWarning represents a generic warning with a message. 
type OtherWarning struct {
	// Type is always ImageCallWarningTypeOther for this warning.
	Type ImageCallWarningType `json:"type"`

	// Message describes the warning.
	Message string `json:"message"`
}

// isImageCallWarning marks OtherWarning as an ImageCallWarning.
func (OtherWarning) isImageCallWarning() {}

// NewOtherWarning creates a new OtherWarning with the given message and its
// Type set to ImageCallWarningTypeOther.
func NewOtherWarning(message string) OtherWarning {
	return OtherWarning{
		Type:    ImageCallWarningTypeOther,
		Message: message,
	}
}

--------------------------------------------------------------------------------
/api/llm_call_options.go:
--------------------------------------------------------------------------------
package api

import jsonschema "github.com/sashabaranov/go-openai/jsonschema"

// TODO: should we call it Config? Settings?
// We should think about the field name if it was being sent as JSON in a request.
// "Request" might also be a better name over "Call": "RequestOptions" or "RequestSettings"

// CallOptions represents the options for language model calls.
//
// Every field carries the omitzero or omitempty JSON option, so zero or empty
// values are left out of the JSON encoding.
type CallOptions struct {
	// MaxOutputTokens specifies the maximum number of tokens to generate.
	MaxOutputTokens int `json:"max_output_tokens,omitzero"`

	// Temperature controls randomness in the model's output.
	// It is recommended to set either Temperature or TopP, but not both.
	//
	// Unlike TopP, Temperature is a pointer: a nil Temperature is omitted from
	// JSON, while a pointer to 0 is encoded, so an explicit zero can be expressed.
	Temperature *float64 `json:"temperature,omitzero"`

	// StopSequences specifies sequences that will stop generation when produced.
	// Providers may have limits on the number of stop sequences.
	StopSequences []string `json:"stop_sequences,omitempty"`

	// TopP controls nucleus sampling.
	// It is recommended to set either Temperature or TopP, but not both.
	// NOTE(review): a zero TopP is omitted from JSON (omitzero), so an explicit
	// 0 cannot be distinguished from "unset".
	TopP float64 `json:"top_p,omitzero"`

	// TopK limits sampling to the top K options for each token.
	// Used to remove "long tail" low probability responses.
	// Recommended for advanced use cases only.
	TopK int `json:"top_k,omitzero"`

	// PresencePenalty affects the likelihood of the model repeating
	// information that is already in the prompt.
	PresencePenalty float64 `json:"presence_penalty,omitzero"`

	// FrequencyPenalty affects the likelihood of the model
	// repeatedly using the same words or phrases.
	FrequencyPenalty float64 `json:"frequency_penalty,omitzero"`

	// ResponseFormat specifies whether the output should be text or JSON.
	// For JSON output, a schema can optionally guide the model.
	ResponseFormat *ResponseFormat `json:"response_format,omitzero"`

	// Seed provides an integer seed for random sampling.
	// If supported by the model, calls will generate deterministic results.
	// NOTE(review): a seed of 0 is omitted from JSON (omitzero) and is therefore
	// indistinguishable from "no seed".
	Seed int `json:"seed,omitzero"`

	// Tools that are available for the model to use.
	Tools []ToolDefinition `json:"tools,omitempty"`

	// ToolChoice specifies how the model should select which tool to use.
	// Defaults to 'auto'.
	ToolChoice *ToolChoice `json:"tool_choice,omitzero"`

	// Headers specifies additional HTTP headers to send with the request.
	// Only applicable for HTTP-based providers.
	Headers map[string]string `json:"headers,omitempty"`

	// ProviderMetadata contains additional provider-specific metadata.
	// The metadata is passed through to the provider from the AI SDK and enables
	// provider-specific functionality that can be fully encapsulated in the provider.
	ProviderMetadata *ProviderMetadata `json:"provider_metadata,omitzero"`
}

// GetProviderMetadata returns the provider-specific metadata of the call
// options; it may be nil.
func (o CallOptions) GetProviderMetadata() *ProviderMetadata { return o.ProviderMetadata }

// ResponseFormat specifies the format of the model's response.
type ResponseFormat struct {
	// Type indicates the response format type ("text" or "json").
	Type string `json:"type"`

	// Schema optionally provides a JSON schema to guide the model's output.
	Schema *jsonschema.Definition `json:"schema,omitzero"`

	// Name optionally provides a name for the output to guide the model.
	Name string `json:"name,omitzero"`

	// Description optionally provides additional context to guide the model.
	Description string `json:"description,omitzero"`
}

--------------------------------------------------------------------------------
/api/llm_call_warning.go:
--------------------------------------------------------------------------------
package api

// CallWarning represents a warning from the model provider for a call.
// The call will proceed, but some settings might not be supported,
// which can lead to suboptimal results.
//
// Type is always encoded; the remaining fields carry omitzero and are only
// meaningful for the Type values noted on each field.
type CallWarning struct {
	// TODO: We might want to turn Type into an enum
	// OR we could make Warning an interface with different concrete types.

	// Type indicates the kind of warning: "unsupported-setting", "unsupported-tool", or "other".
	Type string `json:"type"`

	// TODO: These are usually called Configs or Options in Go ... should we update
	// the name of the field?

	// Setting contains the name of the unsupported setting when Type is "unsupported-setting".
	Setting string `json:"setting,omitzero"`

	// Tool contains the unsupported tool when Type is "unsupported-tool".
	Tool ToolDefinition `json:"tool,omitzero"`

	// Details provides additional information about the warning.
	Details string `json:"details,omitzero"`

	// Message contains a human-readable warning message.
	Message string `json:"message,omitzero"`
}

--------------------------------------------------------------------------------
/api/llm_events.go:
--------------------------------------------------------------------------------
package api

import (
	"encoding/json"
	"fmt"
	"time"
)

// EventType represents the different types of stream events.
// Each StreamEvent implementation returns one of the constants below from Type.
type EventType string

const (
	// EventTextDelta represents an incremental text response from the model.
	EventTextDelta EventType = "text-delta"

	// EventReasoning is an optional reasoning or intermediate explanation generated by the model.
	EventReasoning EventType = "reasoning"

	// EventReasoningSignature represents a signature that verifies reasoning content.
	EventReasoningSignature EventType = "reasoning-signature"

	// EventRedactedReasoning represents redacted reasoning data.
	EventRedactedReasoning EventType = "redacted-reasoning"

	// EventSource represents a citation that was used to generate the response.
	EventSource EventType = "source"

	// EventFile represents a file generated by the model.
	EventFile EventType = "file"

	// EventToolCall represents a completed tool call with all arguments provided.
	EventToolCall EventType = "tool-call"

	// EventToolCallDelta is an incremental update for tool call arguments.
	EventToolCallDelta EventType = "tool-call-delta"

	// EventResponseMetadata contains additional response metadata, such as timestamps or provider details.
	EventResponseMetadata EventType = "response-metadata"

	// EventFinish is the final part of the stream, providing the finish reason and usage statistics.
	EventFinish EventType = "finish"

	// EventError indicates that an error occurred during the stream.
	EventError EventType = "error"

	// TODO: How should we handle refusal events? Do we need an additional event type?
)

// StreamEvent represents a streamed incremental update of the language model output.
type StreamEvent interface {
	// Type returns the type of event being received.
	Type() EventType
}

// TextDeltaEvent represents an incremental text response from the model.
//
// Used to update a TextBlock incrementally.
type TextDeltaEvent struct {
	// TextDelta is a partial text response from the model.
	TextDelta string `json:"text_delta"`
}

// Type returns EventTextDelta.
func (b TextDeltaEvent) Type() EventType { return EventTextDelta }

// ReasoningEvent represents an incremental reasoning response from the model.
//
// Used to update the text of a ReasoningBlock.
type ReasoningEvent struct {
	// TextDelta is a partial reasoning text from the model.
	TextDelta string `json:"text_delta"`
}

// Type returns EventReasoning.
func (b ReasoningEvent) Type() EventType { return EventReasoning }

// ReasoningSignatureEvent represents an incremental signature update for reasoning text.
//
// Used to update the signature field of a ReasoningBlock.
type ReasoningSignatureEvent struct {
	// Signature is the cryptographic signature for verifying reasoning.
	Signature string `json:"signature"`
}

// Type returns EventReasoningSignature.
func (b ReasoningSignatureEvent) Type() EventType { return EventReasoningSignature }

// RedactedReasoningEvent represents an update to redacted reasoning data.
//
// Used to update the data field of a RedactedReasoningBlock.
type RedactedReasoningEvent struct {
	// Data contains redacted reasoning data.
	Data string `json:"data"`
}

// Type returns EventRedactedReasoning.
func (b RedactedReasoningEvent) Type() EventType { return EventRedactedReasoning }

// SourceEvent represents a source that was used to generate the response.
//
// Used to add a source to the response.
type SourceEvent struct {
	// Source contains information about the source.
	Source Source `json:"source"`
}

// Type returns EventSource.
func (b SourceEvent) Type() EventType { return EventSource }

// FileEvent represents a file generated by the model.
//
// Used to add a file to the response via a FileBlock.
type FileEvent struct {
	// MediaType is the IANA media type (mime type) of the file.
	MediaType string `json:"media_type"`
	// Data contains the generated file as a byte array; encoding/json encodes
	// a []byte as a base64 string.
	Data []byte `json:"data"`
}

// Type returns EventFile.
func (b FileEvent) Type() EventType { return EventFile }

// ToolCallEvent represents a complete tool call with all arguments.
//
// Used to add a tool call to the response via a ToolCallBlock.
type ToolCallEvent struct {
	// ToolCallID is the ID of the tool call. This ID is used to match the tool call with the tool result.
	ToolCallID string `json:"tool_call_id"`

	// ToolName is the name of the tool being invoked.
	ToolName string `json:"tool_name"`

	// Args contains the arguments of the tool call as a JSON payload matching
	// the tool's input schema.
	// Note that args are often generated by the language model and may be
	// malformed.
	// NOTE(review): json.RawMessage content is validated when encoding, so
	// json.Marshal of a ToolCallEvent with malformed Args returns an error.
	Args json.RawMessage `json:"args"`
}

// Type returns EventToolCall.
func (b ToolCallEvent) Type() EventType { return EventToolCall }

// TODO: Should the events above be unified with content blocks?

// ToolCallDeltaEvent represents a tool call with incremental arguments.
// Tool call deltas are only needed for object generation modes.
// The tool call deltas must be partial JSON.
type ToolCallDeltaEvent struct {
	// ToolCallID is the ID of the tool call.
	ToolCallID string `json:"tool_call_id"`

	// ToolName is the name of the tool being invoked.
	ToolName string `json:"tool_name"`

	// ArgsDelta is a partial JSON byte slice update for the tool call arguments.
	// It is encoded as a base64 string by encoding/json (presumably []byte rather
	// than json.RawMessage because partial JSON is not valid JSON on its own).
	ArgsDelta []byte `json:"args_delta"`
}

// Type returns EventToolCallDelta.
func (b ToolCallDeltaEvent) Type() EventType { return EventToolCallDelta }

// TODO: Stream Start (presumably a stream-start event type still to be added).

// ResponseMetadataEvent contains additional response metadata.
//
// It will be sent as soon as it is available, without having to wait for
// the FinishEvent.
type ResponseMetadataEvent struct {
	// ID for the generated response, if the provider sends one.
	ID string `json:"id,omitzero"`
	// Timestamp represents when the stream part was generated.
	Timestamp time.Time `json:"timestamp,omitzero"`
	// ModelID for the generated response, if the provider sends one.
	ModelID string `json:"model_id,omitzero"`
}

// Type returns EventResponseMetadata.
func (b ResponseMetadataEvent) Type() EventType {
	return EventResponseMetadata
}

// FinishEvent represents the final part of the stream.
//
// It will be sent once the stream has finished processing.
176 | type FinishEvent struct { 177 | // Usage contains token usage statistics 178 | Usage Usage `json:"usage,omitzero"` 179 | 180 | // FinishReason indicates why the model stopped generating 181 | FinishReason FinishReason `json:"finish_reason"` 182 | 183 | // ProviderMetadata contains provider-specific metadata 184 | ProviderMetadata *ProviderMetadata `json:"provider_metadata,omitzero"` 185 | } 186 | 187 | func (b FinishEvent) Type() EventType { return EventFinish } 188 | 189 | func (b FinishEvent) GetProviderMetadata() *ProviderMetadata { return b.ProviderMetadata } 190 | 191 | // ErrorEvent represents an error that occurred during streaming. 192 | type ErrorEvent struct { 193 | // Err contains any error messages or error objects encountered during the stream 194 | Err any `json:"error"` 195 | 196 | // TODO: We might want to make sure that the error field is always serializable as JSON, 197 | // or we might force it to be an error type defined by our AI SDK so that the shape is 198 | // known if transmitted over the network. 199 | } 200 | 201 | func (b ErrorEvent) Type() EventType { return EventError } 202 | func (b ErrorEvent) Error() string { return fmt.Sprintf("%v", b.Err) } 203 | -------------------------------------------------------------------------------- /api/llm_finish_reason.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | // FinishReason indicates why a language model finished generating a response. 
4 | type FinishReason string 5 | 6 | const ( 7 | // FinishReasonStop indicates the model generated a stop sequence 8 | FinishReasonStop FinishReason = "stop" 9 | 10 | // FinishReasonLength indicates the model reached the maximum number of tokens 11 | FinishReasonLength FinishReason = "length" 12 | 13 | // FinishReasonContentFilter indicates a content filter violation stopped the model 14 | FinishReasonContentFilter FinishReason = "content-filter" 15 | 16 | // FinishReasonToolCalls indicates the model triggered tool calls 17 | FinishReasonToolCalls FinishReason = "tool-calls" 18 | 19 | // FinishReasonError indicates the model stopped because of an error 20 | FinishReasonError FinishReason = "error" 21 | 22 | // FinishReasonOther indicates the model stopped for other reasons 23 | FinishReasonOther FinishReason = "other" 24 | 25 | // FinishReasonUnknown indicates the model has not transmitted a finish reason 26 | FinishReasonUnknown FinishReason = "unknown" 27 | ) 28 | -------------------------------------------------------------------------------- /api/llm_language_model.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import ( 4 | "context" 5 | "iter" 6 | "net/http" 7 | "time" 8 | ) 9 | 10 | // LanguageModel represents a language model. 11 | type LanguageModel interface { 12 | // ProviderName returns the name of the provider for logging purposes. 13 | ProviderName() string 14 | 15 | // ModelID returns the provider-specific model ID for logging purposes. 16 | ModelID() string 17 | 18 | // SupportedUrls returns URL patterns supported by the model, grouped by media type. 19 | // 20 | // The MediaType field contains media type patterns or full media types (e.g. "*/*" for everything, 21 | // "audio/*", "video/*", or "application/pdf"). The URLPatterns field contains arrays of regular 22 | // expression patterns that match URL paths. 23 | // 24 | // The matching is performed against lowercase URLs. 
25 | // 26 | // URLs that match these patterns are supported natively by the model and will not 27 | // be downloaded by the SDK. For non-matching URLs, the SDK will download the content 28 | // and pass it directly to the model. 29 | // 30 | // If nil or an empty slice is returned, the SDK will download all files. 31 | SupportedUrls() []SupportedURL 32 | 33 | // Generate generates a language model output (non-streaming). 34 | // 35 | // The prompt parameter is a standardized prompt type, not the user-facing prompt. 36 | // The AI SDK methods will map the user-facing prompt types such as chat or 37 | // instruction prompts to this format. 38 | Generate(ctx context.Context, prompt []Message, opts CallOptions) (Response, error) 39 | 40 | // Stream generates a language model output (streaming). 41 | // Returns a stream of events from the model. 42 | // 43 | // The prompt parameter is a standardized prompt type, not the user-facing prompt. 44 | // The AI SDK methods will map the user-facing prompt types such as chat or 45 | // instruction prompts to this format. 46 | Stream(ctx context.Context, prompt []Message, opts CallOptions) (StreamResponse, error) 47 | } 48 | 49 | // SupportedURL defines URL patterns supported for a specific media type 50 | type SupportedURL struct { 51 | // MediaType is the IANA media type (mime type) of the URL. 52 | // A simple '*' wildcard is supported for the mime type or subtype 53 | // (e.g., "application/pdf", "audio/*", "*/*"). 54 | MediaType string 55 | 56 | // TODO: change reasoning to support reasoning blocks as well 57 | // URLPatterns contains regex patterns for URL paths that match this media type 58 | URLPatterns []string 59 | } 60 | 61 | // Response represents the result of a non-streaming language model generation. 62 | type Response struct { 63 | // Content contains the ordered list of content blocks that the model generated. 
64 | Content []ContentBlock `json:"content"` 65 | 66 | // FinishReason contains an explanation for why the model finished generating. 67 | FinishReason FinishReason `json:"finish_reason"` 68 | 69 | // Usage contains information about the number of tokens used by the model. 70 | Usage Usage `json:"usage"` 71 | 72 | // Additional provider-specific metadata. They are passed through from the 73 | // provider to enable provider-specific functionality. 74 | ProviderMetadata *ProviderMetadata `json:"provider_metadata,omitzero"` 75 | 76 | // RequestInfo is optional request information for telemetry and debugging purposes. 77 | RequestInfo *RequestInfo `json:"request,omitzero"` 78 | 79 | // ResponseInfo is optional response information for telemetry and debugging purposes. 80 | ResponseInfo *ResponseInfo `json:"response,omitzero"` 81 | 82 | // Warnings is a list of warnings that occurred during the call, 83 | // e.g. unsupported settings. 84 | Warnings []CallWarning `json:"warnings,omitempty"` 85 | } 86 | 87 | func (r Response) GetProviderMetadata() *ProviderMetadata { return r.ProviderMetadata } 88 | 89 | // StreamResponse represents the result of a streaming language model call. 90 | type StreamResponse struct { 91 | // Stream is the sequence of events received from the model. 92 | // Iterating over events might block if we're waiting for the LLM to respond. 93 | Stream iter.Seq[StreamEvent] 94 | // TODO: For now we're always encoding errors as ErrorEvent. Is that the right 95 | // behavior or should we consider iter.Seq2[StreamEvent, error]? 96 | 97 | // RequestInfo is optional request information for telemetry and debugging purposes. 98 | RequestInfo *RequestInfo `json:"request,omitzero"` 99 | 100 | // ResponseInfo is optional response information for telemetry and debugging purposes. 101 | ResponseInfo *ResponseInfo `json:"response,omitzero"` 102 | } 103 | 104 | // Usage represents token usage statistics for a model call. 
105 | // 106 | // If a provider returns additional usage information besides the ones below, 107 | // that information is added to the provider metadata field. 108 | type Usage struct { 109 | // InputTokens is the number of tokens used by the input (prompt). 110 | InputTokens int `json:"input_tokens"` 111 | 112 | // OutputTokens is the number of tokens in the generated output (completion or tool call). 113 | OutputTokens int `json:"output_tokens"` 114 | 115 | // TotalTokens is the total number of tokens used as reported by the provider. 116 | // Note that this might be different from the sum of input tokens and output tokens 117 | // because it might include reasoning tokens or other overhead. 118 | TotalTokens int `json:"total_tokens"` 119 | 120 | // ReasoningTokens is the number of tokens used by model as part of the reasoning process. 121 | ReasoningTokens int `json:"reasoning_tokens,omitzero"` 122 | 123 | // CachedInputTokens is the number of input tokens that were cached from a previous call. 124 | CachedInputTokens int `json:"cached_input_tokens,omitzero"` 125 | } 126 | 127 | // IsZero returns true if all fields of the Usage struct are zero. 128 | func (u Usage) IsZero() bool { 129 | return u.InputTokens == 0 && 130 | u.OutputTokens == 0 && 131 | u.TotalTokens == 0 && 132 | u.ReasoningTokens == 0 && 133 | u.CachedInputTokens == 0 134 | } 135 | 136 | // RequestInfo contains optional request information for telemetry. 137 | type RequestInfo struct { 138 | // Body is the raw HTTP body that was sent to the provider 139 | Body []byte `json:"body,omitempty"` 140 | } 141 | 142 | // ResponseInfo contains optional response information for telemetry. 143 | type ResponseInfo struct { 144 | // ID for the generated response, if the provider sends one. 145 | ID string `json:"id,omitzero"` 146 | 147 | // Timestamp for the start of the generated response, if the provider sends one. 
148 | Timestamp time.Time `json:"timestamp,omitzero"` 149 | 150 | // ModelID of the model that was used to generate the response, if the provider sends one. 151 | ModelID string `json:"model_id,omitzero"` 152 | 153 | // Header contains a map of the HTTP response headers. 154 | Header http.Header 155 | 156 | // Body is the raw HTTP body that was returned by the provider. 157 | // Not provided for streaming responses. 158 | Body []byte `json:"body,omitempty"` 159 | 160 | // Status is a status code and message. e.g. "200 OK" 161 | Status string `json:"status,omitzero"` 162 | 163 | // StatusCode is a status code as integer. e.g. 200 164 | StatusCode int `json:"status_code,omitzero"` 165 | 166 | // TODO: consider adding a duration field 167 | } 168 | -------------------------------------------------------------------------------- /api/llm_logprobs.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | // TokenLogProb represents the log probability for a single token. 4 | type TokenLogProb struct { 5 | // Token is the text of the token 6 | Token string `json:"token"` 7 | 8 | // LogProb is the log probability of the token 9 | LogProb float64 `json:"logprob"` 10 | } 11 | 12 | // LogProb represents the log probability information for a token, 13 | // including its top alternative tokens. 
14 | type LogProb struct { 15 | // Token is the text of the token 16 | Token string `json:"token"` 17 | 18 | // LogProb is the log probability of the token 19 | LogProb float64 `json:"logprob"` 20 | 21 | // TopLogProbs contains the log probabilities of alternative tokens 22 | TopLogProbs []TokenLogProb `json:"top_logprobs"` 23 | } 24 | 25 | // LogProbs represents a sequence of token log probabilities 26 | type LogProbs []LogProb 27 | -------------------------------------------------------------------------------- /api/llm_object_generation_mode.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | // ObjectGenerationMode specifies how the model should generate structured objects. 4 | type ObjectGenerationMode string 5 | 6 | const ( 7 | // ObjectGenerationModeNone indicates no specific object generation mode (empty string) 8 | ObjectGenerationModeNone ObjectGenerationMode = "" 9 | 10 | // ObjectGenerationModeJSON indicates the model should generate JSON directly 11 | ObjectGenerationModeJSON ObjectGenerationMode = "json" 12 | 13 | // ObjectGenerationModeTool indicates the model should use tool calls to generate objects 14 | ObjectGenerationModeTool ObjectGenerationMode = "tool" 15 | ) 16 | -------------------------------------------------------------------------------- /api/llm_provider_metadata.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import ( 4 | "encoding/json" 5 | ) 6 | 7 | // ProviderMetadata provides access to provider-specific metadata structures. 8 | // It stores and retrieves strongly-typed metadata for specific providers. 9 | type ProviderMetadata struct { 10 | data map[string]any 11 | } 12 | 13 | // NewProviderMetadata creates a new ProviderMetadata with the given initial data. 14 | // If data is nil, an empty map will be created. 
15 | func NewProviderMetadata(data map[string]any) *ProviderMetadata { 16 | if data == nil { 17 | return &ProviderMetadata{data: make(map[string]any)} 18 | } 19 | return &ProviderMetadata{data: data} 20 | } 21 | 22 | // Get retrieves the metadata for a specific provider. 23 | // Returns the metadata and a boolean indicating whether the provider was found. 24 | func (p *ProviderMetadata) Get(provider string) (any, bool) { 25 | if p == nil || p.data == nil { 26 | return nil, false 27 | } 28 | 29 | metadata, exists := p.data[provider] 30 | return metadata, exists 31 | } 32 | 33 | // Set stores metadata for a specific provider. 34 | func (p *ProviderMetadata) Set(provider string, metadata any) { 35 | if p.data == nil { 36 | p.data = make(map[string]any) 37 | } 38 | p.data[provider] = metadata 39 | } 40 | 41 | // Has checks if metadata exists for a specific provider. 42 | func (p *ProviderMetadata) Has(provider string) bool { 43 | if p == nil || p.data == nil { 44 | return false 45 | } 46 | _, exists := p.data[provider] 47 | return exists 48 | } 49 | 50 | // MarshalJSON implements the json.Marshaler interface. 51 | // It serializes the underlying data map to JSON. 52 | func (p *ProviderMetadata) MarshalJSON() ([]byte, error) { 53 | if p == nil || p.data == nil { 54 | return json.Marshal(make(map[string]any)) 55 | } 56 | return json.Marshal(p.data) 57 | } 58 | 59 | // IsZero returns true if this ProviderMetadata is a zero value or contains no data. 60 | func (p *ProviderMetadata) IsZero() bool { 61 | return p == nil || len(p.data) == 0 62 | } 63 | 64 | type MetadataSource interface { 65 | GetProviderMetadata() *ProviderMetadata 66 | } 67 | 68 | // GetMetadata is a generic helper function to retrieve provider-specific 69 | // metadata as a pointer to the requested type. 70 | // 71 | // If the provider is not found or the type doesn't match, it returns nil. 72 | // 73 | // We recommend providers use this helper to expose predefined metadata functions. 
74 | func GetMetadata[T any](provider string, source MetadataSource) *T { 75 | if source == nil { 76 | return nil 77 | } 78 | 79 | pm := source.GetProviderMetadata() 80 | if pm == nil { 81 | return nil 82 | } 83 | 84 | metadata, ok := pm.Get(provider) 85 | if !ok { 86 | return nil 87 | } 88 | 89 | // First try to get the pointer: 90 | ptr, ok := metadata.(*T) 91 | if ok { 92 | return ptr 93 | } 94 | 95 | // If that fails, try the value: 96 | value, ok := metadata.(T) 97 | if ok { 98 | return &value 99 | } 100 | 101 | // We couldn't cast it to the right type, return nil: 102 | return nil 103 | } 104 | -------------------------------------------------------------------------------- /api/llm_source.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | // Source represents a source that has been used as input to generate the response. 4 | type Source struct { 5 | // SourceType indicates the type of source. Currently only "url" is supported. 6 | SourceType string `json:"source_type"` 7 | 8 | // ID is the unique identifier of the source. 9 | ID string `json:"id"` 10 | 11 | // URL is the URL of the source. 12 | URL string `json:"url"` 13 | 14 | // Title is the optional title of the source. 15 | Title string `json:"title,omitzero"` 16 | 17 | // ProviderMetadata contains additional provider-specific metadata. 18 | ProviderMetadata ProviderMetadata `json:"provider_metadata,omitzero"` 19 | } 20 | -------------------------------------------------------------------------------- /api/llm_tool.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | // TODO: This schema package is pretty small. 4 | // It might be best to just in line it into our AI SDK. 5 | import "github.com/sashabaranov/go-openai/jsonschema" 6 | 7 | // ToolChoice specifies how tools should be selected by the model. 
8 | type ToolChoice struct { 9 | // Type indicates how tools should be selected: 10 | // - "auto": tool selection is automatic (can be no tool) 11 | // - "none": no tool must be selected 12 | // - "required": one of the available tools must be selected 13 | // - "tool": a specific tool must be selected 14 | Type string `json:"type"` 15 | 16 | // ToolName specifies which tool to use when Type is "tool" 17 | ToolName string `json:"tool_name,omitzero"` 18 | } 19 | 20 | // ToolDefinition represents a tool that can be used in a language model call. 21 | // It can either be a user-defined tool or a built-in provider-defined tool. 22 | type ToolDefinition interface { 23 | // ToolType is the type of the tool. Either "function" or "provider-defined". 24 | ToolType() string 25 | } 26 | 27 | // FunctionTool represents a tool that has a name, description, and set of input arguments. 28 | // Note: this is not the user-facing tool definition. The AI SDK methods will 29 | // map the user-facing tool definitions to this format. 30 | type FunctionTool struct { 31 | // Name is the unique identifier for this tool within this model call 32 | Name string `json:"name"` 33 | 34 | // Description explains the tool's purpose. The language model uses this to understand 35 | // the tool's purpose and provide better completion suggestions. 36 | Description string `json:"description,omitzero"` 37 | 38 | // InputSchema defines the expected inputs. The language model uses this to understand 39 | // the tool's input requirements and provide matching suggestions. 40 | // InputSchema should be defined using a JSON schema. 41 | InputSchema *jsonschema.Definition `json:"input_schema,omitzero"` 42 | } 43 | 44 | var _ ToolDefinition = &FunctionTool{} 45 | 46 | // Type is the type of the tool (always "function") 47 | func (t FunctionTool) ToolType() string { return "function" } 48 | 49 | // ProviderDefinedTool represents a tool that has built-in support by the provider. 
50 | // Provider implementations will usually predefine these. 51 | type ProviderDefinedTool interface { 52 | // ToolType is the type of the tool. Always "provider-defined" for provider-defined tools. 53 | ToolType() string 54 | 55 | // ID is the tool identifier in format "." 56 | ID() string 57 | 58 | // Name returns the unique name used to identify this tool within the model's messages. 59 | // This is the name that will be used by the language model as the value of the ToolName field 60 | // in ToolCall blocks. 61 | Name() string 62 | 63 | // TODO: Consider adding a Validate method that checks if the arguments are valid and returns an error otherwise. 64 | // This would be used to validate the tool call arguments before sending them to the language model. 65 | } 66 | -------------------------------------------------------------------------------- /api/provider.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | // Provider is a provider for language and text embedding models. 4 | type Provider interface { 5 | // LanguageModel returns the language model with the given id. 6 | // The model id is then passed to the provider function to get the model. 7 | // 8 | // Parameters: 9 | // modelID: The id of the model to return. 10 | // 11 | // Returns: 12 | // The language model associated with the id 13 | // error of type NoSuchModelError if no such model exists 14 | LanguageModel(modelID string) (LanguageModel, error) 15 | 16 | // TextEmbeddingModel returns the text embedding model with the given id. 17 | // The model id is then passed to the provider function to get the model. 18 | // 19 | // Parameters: 20 | // modelID: The id of the model to return. 
21 | // 22 | // Returns: 23 | // The text embedding model associated with the id 24 | // error of type NoSuchModelError if no such model exists 25 | TextEmbeddingModel(modelID string) (EmbeddingModel[string], error) 26 | 27 | // ImageModel returns the image model with the given id. 28 | // The model id is then passed to the provider function to get the model. 29 | // This method is optional and may return nil if image models are not supported. 30 | // 31 | // Parameters: 32 | // modelID: The id of the model to return. 33 | // 34 | // Returns: 35 | // The image model associated with the id, or nil if image models are not supported 36 | // error of type NoSuchModelError if no such model exists and image models are supported 37 | ImageModel(modelID string) (*ImageModel, error) 38 | } 39 | -------------------------------------------------------------------------------- /builder/convert.go: -------------------------------------------------------------------------------- 1 | package builder 2 | 3 | import "go.jetify.com/ai/api" 4 | 5 | func StreamToResponse(stream *api.StreamResponse) (*api.Response, error) { 6 | if stream == nil { 7 | return nil, nil 8 | } 9 | 10 | builder := NewResponseBuilder() 11 | 12 | // Add any metadata from the stream response 13 | if err := builder.AddMetadata(stream); err != nil { 14 | return nil, err 15 | } 16 | 17 | // Process each event in the stream 18 | for event := range stream.Stream { 19 | if err := builder.AddEvent(event); err != nil { 20 | return nil, err 21 | } 22 | } 23 | 24 | // Build the final response 25 | resp, err := builder.Build() 26 | if err != nil { 27 | return nil, err 28 | } 29 | 30 | return &resp, nil 31 | } 32 | -------------------------------------------------------------------------------- /default.go: -------------------------------------------------------------------------------- 1 | package ai 2 | 3 | import ( 4 | "sync/atomic" 5 | 6 | "go.jetify.com/ai/api" 7 | "go.jetify.com/ai/provider/anthropic" 8 | ) 9 | 10 | 
var defaultLanguageModel atomic.Value 11 | 12 | func init() { 13 | model := anthropic.NewLanguageModel(anthropic.ModelClaude37Sonnet20250219) 14 | defaultLanguageModel.Store(model) 15 | } 16 | 17 | func SetDefaultLanguageModel(lm api.LanguageModel) { 18 | defaultLanguageModel.Store(lm) 19 | } 20 | 21 | func DefaultLanguageModel() api.LanguageModel { 22 | return defaultLanguageModel.Load().(api.LanguageModel) 23 | } 24 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # AI SDK Examples 2 | 3 | This directory contains practical examples showing how to use the AI SDK for Go in real applications. 4 | 5 | ## Prerequisites 6 | 7 | Set your API keys as environment variables: 8 | 9 | ```bash 10 | export OPENAI_API_KEY="your-openai-key-here" 11 | export ANTHROPIC_API_KEY="your-anthropic-key-here" 12 | ``` 13 | 14 | Get your API keys from: 15 | - **OpenAI**: [platform.openai.com/api-keys](https://platform.openai.com/api-keys) 16 | - **Anthropic**: [console.anthropic.com](https://console.anthropic.com/) 17 | 18 | ## Examples 19 | 20 | | Example | Description | 21 | |---------|-------------| 22 | | [**simple-text**](basic/simple-text/) | Generate text from a simple string prompt | 23 | 24 | ### More Examples Coming Soon 25 | 26 | - **Conversation** - Multi-message conversations with context 27 | - **Streaming** - Stream responses in real-time 28 | - **Multi-modal** - Working with images and files 29 | - **Tools** - Function calling and tool usage 30 | - **Advanced** - Production patterns and error handling 31 | - **Real-world** - Complete application examples 32 | 33 | 34 | ## How to Run 35 | 36 | Each example is a standalone Go program: 37 | 38 | ```bash 39 | cd basic/simple-text 40 | go run main.go 41 | ``` 42 | 43 | ## Need Help? 
44 | 45 | - [API Documentation](https://pkg.go.dev/go.jetify.com/ai) 46 | - [Discord Community](https://discord.gg/jetify) 47 | - [GitHub Issues](https://github.com/jetify-com/ai/issues) -------------------------------------------------------------------------------- /examples/basic/simple-text/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "log" 6 | 7 | "github.com/k0kubun/pp/v3" 8 | "go.jetify.com/ai" 9 | "go.jetify.com/ai/api" 10 | "go.jetify.com/ai/provider/openai" 11 | ) 12 | 13 | func example() error { 14 | // Create a model 15 | model := openai.NewLanguageModel("gpt-4o-mini") 16 | 17 | // Generate text 18 | response, err := ai.GenerateTextStr( 19 | context.Background(), 20 | "Explain what artificial intelligence is in simple terms", 21 | ai.WithModel(model), 22 | ai.WithMaxOutputTokens(100), 23 | ) 24 | if err != nil { 25 | return err 26 | } 27 | 28 | // Print the response: 29 | printResponse(response) 30 | 31 | return nil 32 | } 33 | 34 | func printResponse(response api.Response) { 35 | response.ProviderMetadata = nil 36 | response.Warnings = nil 37 | printer := pp.New() 38 | printer.SetOmitEmpty(true) 39 | printer.Print(response) 40 | } 41 | 42 | func main() { 43 | if err := example(); err != nil { 44 | log.Fatal(err) 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module go.jetify.com/ai 2 | 3 | go 1.24.0 4 | 5 | require ( 6 | github.com/anthropics/anthropic-sdk-go v0.2.0-alpha.13 7 | github.com/k0kubun/pp/v3 v3.4.1 8 | github.com/openai/openai-go v1.0.0 9 | github.com/sashabaranov/go-openai v1.40.0 10 | github.com/stretchr/testify v1.10.0 11 | go.jetify.com/pkg v0.0.0-20250602221042-f9c387a74b34 12 | go.jetify.com/sse v0.0.0-20250521180548-aeb6bc6de065 13 | ) 14 | 15 | require ( 16 | github.com/davecgh/go-spew v1.1.1 
// indirect 17 | github.com/kr/text v0.2.0 // indirect 18 | github.com/mattn/go-colorable v0.1.14 // indirect 19 | github.com/mattn/go-isatty v0.0.20 // indirect 20 | github.com/pmezard/go-difflib v1.0.0 // indirect 21 | github.com/tidwall/gjson v1.14.4 // indirect 22 | github.com/tidwall/match v1.1.1 // indirect 23 | github.com/tidwall/pretty v1.2.1 // indirect 24 | github.com/tidwall/sjson v1.2.5 // indirect 25 | golang.org/x/sys v0.31.0 // indirect 26 | golang.org/x/text v0.23.0 // indirect 27 | gopkg.in/dnaeon/go-vcr.v4 v4.0.2 // indirect 28 | gopkg.in/yaml.v3 v3.0.1 // indirect 29 | ) 30 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/anthropics/anthropic-sdk-go v0.2.0-alpha.13 h1:xXipLb6/J8hP0GqKPBqK9mBa8nO8KbJWNI4CGx3rYmY= 2 | github.com/anthropics/anthropic-sdk-go v0.2.0-alpha.13/go.mod h1:GJxtdOs9K4neo8Gg65CjJ7jNautmldGli5/OFNabOoo= 3 | github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= 4 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 5 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 6 | github.com/k0kubun/pp/v3 v3.4.1 h1:1WdFZDRRqe8UsR61N/2RoOZ3ziTEqgTPVqKrHeb779Y= 7 | github.com/k0kubun/pp/v3 v3.4.1/go.mod h1:+SiNiqKnBfw1Nkj82Lh5bIeKQOAkPy6Xw9CAZUZ8npI= 8 | github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= 9 | github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= 10 | github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= 11 | github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= 12 | github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= 13 | github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= 14 | github.com/mattn/go-isatty v0.0.20 
h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= 15 | github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= 16 | github.com/openai/openai-go v1.0.0 h1:KtP+VfrgzX9dHwHrLwHeyWmS0jjm16N+753Vi7OwEYg= 17 | github.com/openai/openai-go v1.0.0/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y= 18 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 19 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 20 | github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= 21 | github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= 22 | github.com/sashabaranov/go-openai v1.40.0 h1:Peg9Iag5mUJtPW00aYatlsn97YML0iNULiLNe74iPrU= 23 | github.com/sashabaranov/go-openai v1.40.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= 24 | github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= 25 | github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= 26 | github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= 27 | github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM= 28 | github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= 29 | github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= 30 | github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= 31 | github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= 32 | github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= 33 | github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= 34 | github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= 35 | github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= 36 | 
go.jetify.com/pkg v0.0.0-20250602221042-f9c387a74b34 h1:moNpfamWyihzycLXKQneb3DDPL5YbeXtmvmFYIn8VeM= 37 | go.jetify.com/pkg v0.0.0-20250602221042-f9c387a74b34/go.mod h1:RLeG6AllDXfrqSzPz77HMQSXRv2ZM2YggPzt6J4Pz9g= 38 | go.jetify.com/sse v0.0.0-20250521180548-aeb6bc6de065 h1:qIfcJxr3QZG+bNZTONXjENzrsR3SeM7rP+hvpe9RFOE= 39 | go.jetify.com/sse v0.0.0-20250521180548-aeb6bc6de065/go.mod h1:zFADPn3Z0aZJe3+PbArGMGwe3oTwHxPZIwNILoRCmU8= 40 | golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 41 | golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= 42 | golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= 43 | golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= 44 | golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= 45 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 46 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= 47 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= 48 | gopkg.in/dnaeon/go-vcr.v4 v4.0.2 h1:7T5VYf2ifyK01ETHbJPl5A6XTpUljD4Trw3GEDcdedk= 49 | gopkg.in/dnaeon/go-vcr.v4 v4.0.2/go.mod h1:65yxh9goQVrudqofKtHA4JNFWd6XZRkWfKN4YpMx7KI= 50 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 51 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 52 | -------------------------------------------------------------------------------- /options.go: -------------------------------------------------------------------------------- 1 | package ai 2 | 3 | import "go.jetify.com/ai/api" 4 | 5 | type GenerateOptions struct { 6 | CallOptions api.CallOptions 7 | Model api.LanguageModel 8 | } 9 | 10 | // GenerateOption is a function that modifies GenerateConfig. 
11 | type GenerateOption func(*GenerateOptions) 12 | 13 | // WithModel sets the language model to use for generation 14 | func WithModel(model api.LanguageModel) GenerateOption { 15 | return func(o *GenerateOptions) { 16 | o.Model = model 17 | } 18 | } 19 | 20 | // WithMaxOutputTokens specifies the maximum number of tokens to generate 21 | func WithMaxOutputTokens(maxTokens int) GenerateOption { 22 | return func(o *GenerateOptions) { 23 | o.CallOptions.MaxOutputTokens = maxTokens 24 | } 25 | } 26 | 27 | // WithTemperature controls randomness in the model's output. 28 | // It is recommended to set either Temperature or TopP, but not both. 29 | func WithTemperature(temperature float64) GenerateOption { 30 | return func(o *GenerateOptions) { 31 | o.CallOptions.Temperature = &temperature 32 | } 33 | } 34 | 35 | // WithStopSequences specifies sequences that will stop generation when produced. 36 | // Providers may have limits on the number of stop sequences. 37 | func WithStopSequences(stopSequences ...string) GenerateOption { 38 | return func(o *GenerateOptions) { 39 | o.CallOptions.StopSequences = stopSequences 40 | } 41 | } 42 | 43 | // WithTopP controls nucleus sampling. 44 | // It is recommended to set either Temperature or TopP, but not both. 45 | func WithTopP(topP float64) GenerateOption { 46 | return func(o *GenerateOptions) { 47 | o.CallOptions.TopP = topP 48 | } 49 | } 50 | 51 | // WithTopK limits sampling to the top K options for each token. 52 | // Used to remove "long tail" low probability responses. 53 | // Recommended for advanced use cases only. 
54 | func WithTopK(topK int) GenerateOption { 55 | return func(o *GenerateOptions) { 56 | o.CallOptions.TopK = topK 57 | } 58 | } 59 | 60 | // WithPresencePenalty affects the likelihood of the model repeating 61 | // information that is already in the prompt 62 | func WithPresencePenalty(penalty float64) GenerateOption { 63 | return func(o *GenerateOptions) { 64 | o.CallOptions.PresencePenalty = penalty 65 | } 66 | } 67 | 68 | // WithFrequencyPenalty affects the likelihood of the model 69 | // repeatedly using the same words or phrases 70 | func WithFrequencyPenalty(penalty float64) GenerateOption { 71 | return func(o *GenerateOptions) { 72 | o.CallOptions.FrequencyPenalty = penalty 73 | } 74 | } 75 | 76 | // WithResponseFormat specifies whether the output should be text or JSON. 77 | // For JSON output, a schema can optionally guide the model. 78 | func WithResponseFormat(format *api.ResponseFormat) GenerateOption { 79 | return func(o *GenerateOptions) { 80 | o.CallOptions.ResponseFormat = format 81 | } 82 | } 83 | 84 | // WithSeed provides an integer seed for random sampling. 85 | // If supported by the model, calls will generate deterministic results. 86 | func WithSeed(seed int) GenerateOption { 87 | return func(o *GenerateOptions) { 88 | o.CallOptions.Seed = seed 89 | } 90 | } 91 | 92 | // WithHeaders specifies additional HTTP headers to send with the request. 93 | // Only applicable for HTTP-based providers. 94 | func WithHeaders(headers map[string]string) GenerateOption { 95 | return func(o *GenerateOptions) { 96 | o.CallOptions.Headers = headers 97 | } 98 | } 99 | 100 | // WithTools specifies the tools available for the model to use during generation. 101 | func WithTools(tools ...api.ToolDefinition) GenerateOption { 102 | return func(o *GenerateOptions) { 103 | o.CallOptions.Tools = tools 104 | } 105 | } 106 | 107 | // WithToolChoice specifies how the model should select which tool to use. 
108 | func WithToolChoice(toolChoice *api.ToolChoice) GenerateOption { 109 | return func(o *GenerateOptions) { 110 | o.CallOptions.ToolChoice = toolChoice 111 | } 112 | } 113 | 114 | // WithProviderMetadata sets additional provider-specific metadata. 115 | // The metadata is passed through to the provider from the AI SDK and enables 116 | // provider-specific functionality that can be fully encapsulated in the provider. 117 | func WithProviderMetadata(providerName string, metadata any) GenerateOption { 118 | return func(o *GenerateOptions) { 119 | if o.CallOptions.ProviderMetadata == nil { 120 | o.CallOptions.ProviderMetadata = api.NewProviderMetadata(map[string]any{}) 121 | } 122 | o.CallOptions.ProviderMetadata.Set(providerName, metadata) 123 | } 124 | } 125 | 126 | // buildGenerateConfig combines multiple generate options into a single GenerateConfig struct. 127 | func buildGenerateConfig(opts []GenerateOption) GenerateOptions { 128 | config := GenerateOptions{ 129 | CallOptions: api.CallOptions{ 130 | ProviderMetadata: api.NewProviderMetadata(map[string]any{}), 131 | }, 132 | Model: DefaultLanguageModel(), 133 | } 134 | for _, opt := range opts { 135 | opt(&config) 136 | } 137 | return config 138 | } 139 | -------------------------------------------------------------------------------- /options_test.go: -------------------------------------------------------------------------------- 1 | package ai 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/sashabaranov/go-openai/jsonschema" 8 | "github.com/stretchr/testify/assert" 9 | "go.jetify.com/ai/api" 10 | "go.jetify.com/pkg/pointer" 11 | ) 12 | 13 | func TestCallOptionBuilders(t *testing.T) { 14 | tests := []struct { 15 | name string 16 | option GenerateOption 17 | expected GenerateOptions 18 | }{ 19 | { 20 | name: "WithMaxOutputTokens", 21 | option: WithMaxOutputTokens(100), 22 | expected: GenerateOptions{ 23 | CallOptions: api.CallOptions{MaxOutputTokens: 100}, 24 | }, 25 | }, 26 | { 27 | name: 
"WithTemperature", 28 | option: WithTemperature(0.7), 29 | expected: GenerateOptions{ 30 | CallOptions: api.CallOptions{Temperature: pointer.Float64(0.7)}, 31 | }, 32 | }, 33 | { 34 | name: "WithStopSequences", 35 | option: WithStopSequences("stop1", "stop2"), 36 | expected: GenerateOptions{ 37 | CallOptions: api.CallOptions{StopSequences: []string{"stop1", "stop2"}}, 38 | }, 39 | }, 40 | { 41 | name: "WithStopSequences_Empty", 42 | option: WithStopSequences(), 43 | expected: GenerateOptions{ 44 | CallOptions: api.CallOptions{StopSequences: nil}, 45 | }, 46 | }, 47 | { 48 | name: "WithTopP", 49 | option: WithTopP(0.9), 50 | expected: GenerateOptions{ 51 | CallOptions: api.CallOptions{TopP: 0.9}, 52 | }, 53 | }, 54 | { 55 | name: "WithTopK", 56 | option: WithTopK(40), 57 | expected: GenerateOptions{ 58 | CallOptions: api.CallOptions{TopK: 40}, 59 | }, 60 | }, 61 | { 62 | name: "WithPresencePenalty", 63 | option: WithPresencePenalty(1.0), 64 | expected: GenerateOptions{ 65 | CallOptions: api.CallOptions{PresencePenalty: 1.0}, 66 | }, 67 | }, 68 | { 69 | name: "WithFrequencyPenalty", 70 | option: WithFrequencyPenalty(1.5), 71 | expected: GenerateOptions{ 72 | CallOptions: api.CallOptions{FrequencyPenalty: 1.5}, 73 | }, 74 | }, 75 | { 76 | name: "WithResponseFormat", 77 | option: WithResponseFormat(&api.ResponseFormat{ 78 | Type: "json", 79 | Schema: &jsonschema.Definition{}, 80 | Name: "test", 81 | Description: "test desc", 82 | }), 83 | expected: GenerateOptions{ 84 | CallOptions: api.CallOptions{ 85 | ResponseFormat: &api.ResponseFormat{ 86 | Type: "json", 87 | Schema: &jsonschema.Definition{}, 88 | Name: "test", 89 | Description: "test desc", 90 | }, 91 | }, 92 | }, 93 | }, 94 | { 95 | name: "WithSeed", 96 | option: WithSeed(42), 97 | expected: GenerateOptions{ 98 | CallOptions: api.CallOptions{Seed: 42}, 99 | }, 100 | }, 101 | { 102 | name: "WithHeaders", 103 | option: WithHeaders(map[string]string{"key": "value"}), 104 | expected: GenerateOptions{ 105 | 
CallOptions: api.CallOptions{Headers: map[string]string{"key": "value"}}, 106 | }, 107 | }, 108 | { 109 | name: "WithTools", 110 | option: WithTools(api.FunctionTool{Name: "test-tool"}), 111 | expected: GenerateOptions{ 112 | CallOptions: api.CallOptions{ 113 | Tools: []api.ToolDefinition{api.FunctionTool{Name: "test-tool"}}, 114 | }, 115 | }, 116 | }, 117 | { 118 | name: "WithProviderMetadata_SingleProvider", 119 | option: WithProviderMetadata("test-provider", map[string]any{ 120 | "key": "value", 121 | }), 122 | expected: GenerateOptions{ 123 | CallOptions: api.CallOptions{ 124 | ProviderMetadata: api.NewProviderMetadata(map[string]any{ 125 | "test-provider": map[string]any{ 126 | "key": "value", 127 | }, 128 | }), 129 | }, 130 | }, 131 | }, 132 | { 133 | name: "WithProviderMetadata_MultipleProviders", 134 | option: func() GenerateOption { 135 | return func(o *GenerateOptions) { 136 | WithProviderMetadata("provider1", map[string]any{"key1": "value1"})(o) 137 | WithProviderMetadata("provider2", map[string]any{"key2": "value2"})(o) 138 | } 139 | }(), 140 | expected: GenerateOptions{ 141 | CallOptions: api.CallOptions{ 142 | ProviderMetadata: api.NewProviderMetadata(map[string]any{ 143 | "provider1": map[string]any{"key1": "value1"}, 144 | "provider2": map[string]any{"key2": "value2"}, 145 | }), 146 | }, 147 | }, 148 | }, 149 | { 150 | name: "WithModel", 151 | option: WithModel(&mockLanguageModel{name: "test-model"}), 152 | expected: GenerateOptions{ 153 | Model: &mockLanguageModel{name: "test-model"}, 154 | }, 155 | }, 156 | } 157 | 158 | for _, tt := range tests { 159 | t.Run(tt.name, func(t *testing.T) { 160 | opts := &GenerateOptions{} 161 | tt.option(opts) 162 | assert.Equal(t, tt.expected, *opts) 163 | }) 164 | } 165 | } 166 | 167 | func TestBuildCallOptions(t *testing.T) { 168 | tests := []struct { 169 | name string 170 | opts []GenerateOption 171 | expected GenerateOptions 172 | }{ 173 | { 174 | name: "Default options", 175 | opts: []GenerateOption{}, 176 | 
expected: GenerateOptions{ 177 | CallOptions: api.CallOptions{ 178 | ProviderMetadata: api.NewProviderMetadata(map[string]any{}), 179 | }, 180 | Model: DefaultLanguageModel(), 181 | }, 182 | }, 183 | { 184 | name: "Multiple options", 185 | opts: []GenerateOption{ 186 | WithMaxOutputTokens(100), 187 | WithTemperature(0.7), 188 | }, 189 | expected: GenerateOptions{ 190 | CallOptions: api.CallOptions{ 191 | MaxOutputTokens: 100, 192 | Temperature: pointer.Float64(0.7), 193 | ProviderMetadata: api.NewProviderMetadata(map[string]any{}), 194 | }, 195 | Model: DefaultLanguageModel(), 196 | }, 197 | }, 198 | { 199 | name: "With tools", 200 | opts: []GenerateOption{ 201 | WithTools(api.FunctionTool{Name: "test-tool"}), 202 | }, 203 | expected: GenerateOptions{ 204 | CallOptions: api.CallOptions{ 205 | Tools: []api.ToolDefinition{api.FunctionTool{Name: "test-tool"}}, 206 | ProviderMetadata: api.NewProviderMetadata(map[string]any{}), 207 | }, 208 | Model: DefaultLanguageModel(), 209 | }, 210 | }, 211 | } 212 | 213 | for _, tt := range tests { 214 | t.Run(tt.name, func(t *testing.T) { 215 | opts := buildGenerateConfig(tt.opts) 216 | assert.Equal(t, tt.expected, opts) 217 | }) 218 | } 219 | } 220 | 221 | // mockLanguageModel implements api.LanguageModel for testing 222 | type mockLanguageModel struct { 223 | name string 224 | } 225 | 226 | func (m *mockLanguageModel) Generate(ctx context.Context, prompt []api.Message, opts api.CallOptions) (api.Response, error) { 227 | return api.Response{}, nil 228 | } 229 | 230 | func (m *mockLanguageModel) Stream(ctx context.Context, prompt []api.Message, opts api.CallOptions) (api.StreamResponse, error) { 231 | return api.StreamResponse{}, nil 232 | } 233 | 234 | func (m *mockLanguageModel) ModelID() string { 235 | return m.name 236 | } 237 | 238 | func (m *mockLanguageModel) ProviderName() string { 239 | return "mock-provider" 240 | } 241 | 242 | func (m *mockLanguageModel) SupportedUrls() []api.SupportedURL { 243 | return nil 244 | } 245 
| -------------------------------------------------------------------------------- /provider/anthropic/codec/decode.go: -------------------------------------------------------------------------------- 1 | package codec 2 | 3 | import ( 4 | "encoding/json" 5 | "errors" 6 | 7 | "github.com/anthropics/anthropic-sdk-go" 8 | "go.jetify.com/ai/api" 9 | ) 10 | 11 | // DecodeResponse converts an Anthropic Message to the AI SDK Response type 12 | func DecodeResponse(msg *anthropic.BetaMessage) (api.Response, error) { 13 | if msg == nil { 14 | return api.Response{}, errors.New("nil message provided") 15 | } 16 | 17 | response := api.Response{ 18 | FinishReason: decodeFinishReason(msg.StopReason), 19 | Usage: decodeUsage(msg.Usage), 20 | ResponseInfo: decodeResponseInfo(msg), 21 | ProviderMetadata: decodeProviderMetadata(msg), 22 | } 23 | 24 | response.Content = decodeContent(msg.Content) 25 | 26 | return response, nil 27 | } 28 | 29 | // decodeResponseInfo extracts the response info from an Anthropic message 30 | func decodeResponseInfo(msg *anthropic.BetaMessage) *api.ResponseInfo { 31 | return &api.ResponseInfo{ 32 | ID: msg.ID, 33 | ModelID: msg.Model, 34 | } 35 | } 36 | 37 | // decodeProviderMetadata extracts Anthropic-specific metadata 38 | func decodeProviderMetadata(msg *anthropic.BetaMessage) *api.ProviderMetadata { 39 | return api.NewProviderMetadata(map[string]any{ 40 | "anthropic": &Metadata{ 41 | Usage: Usage{ 42 | InputTokens: msg.Usage.InputTokens, 43 | OutputTokens: msg.Usage.OutputTokens, 44 | CacheCreationInputTokens: msg.Usage.CacheCreationInputTokens, 45 | CacheReadInputTokens: msg.Usage.CacheReadInputTokens, 46 | }, 47 | }, 48 | }) 49 | } 50 | 51 | // decodeContent processes the content blocks from an Anthropic message 52 | // and returns an ordered slice of content blocks 53 | func decodeContent(blocks []anthropic.BetaContentBlock) []api.ContentBlock { 54 | content := make([]api.ContentBlock, 0) 55 | 56 | if blocks == nil { 57 | return content 58 | } 59 
| 60 | for _, block := range blocks { 61 | switch block.Type { 62 | case anthropic.BetaContentBlockTypeText: 63 | // Only add text block if it has content 64 | if block.Text != "" { 65 | content = append(content, &api.TextBlock{ 66 | Text: block.Text, 67 | }) 68 | } 69 | case anthropic.BetaContentBlockTypeToolUse: 70 | content = append(content, decodeToolUse(block)) 71 | case anthropic.BetaContentBlockTypeThinking, anthropic.BetaContentBlockTypeRedactedThinking: 72 | if reasoningBlock := decodeReasoning(block); reasoningBlock != nil { 73 | content = append(content, reasoningBlock) 74 | } 75 | } 76 | } 77 | 78 | return content 79 | } 80 | 81 | // decodeToolUse converts an Anthropic tool use block to an AI SDK ToolCallBlock 82 | func decodeToolUse(block anthropic.BetaContentBlock) *api.ToolCallBlock { 83 | var args string 84 | if block.Input != nil { 85 | rawArgs, err := json.Marshal(block.Input) 86 | if err == nil { 87 | args = string(rawArgs) 88 | } else { 89 | // If marshaling fails, use empty JSON object 90 | args = "{}" 91 | } 92 | } else { 93 | args = "{}" 94 | } 95 | 96 | return &api.ToolCallBlock{ 97 | ToolCallID: block.ID, 98 | ToolName: block.Name, 99 | Args: json.RawMessage(args), 100 | } 101 | } 102 | 103 | // decodeReasoning converts an Anthropic thinking block to an AI SDK ReasoningBlock 104 | func decodeReasoning(block anthropic.BetaContentBlock) api.Reasoning { 105 | if block.Type == anthropic.BetaContentBlockTypeThinking { 106 | // Check for nil or empty thinking text 107 | if block.Thinking == "" { 108 | return nil 109 | } 110 | return &api.ReasoningBlock{ 111 | Text: block.Thinking, 112 | Signature: block.Signature, 113 | } 114 | } else if block.Type == anthropic.BetaContentBlockTypeRedactedThinking { 115 | // Check for nil or empty data 116 | if block.Data == "" { 117 | return nil 118 | } 119 | return &api.RedactedReasoningBlock{ 120 | Data: block.Data, 121 | } 122 | } 123 | return nil 124 | } 125 | 126 | // decodeUsage converts Anthropic Usage to 
API SDK Usage 127 | func decodeUsage(usage anthropic.BetaUsage) api.Usage { 128 | return api.Usage{ 129 | InputTokens: int(usage.InputTokens), 130 | OutputTokens: int(usage.OutputTokens), 131 | TotalTokens: int(usage.InputTokens + usage.OutputTokens), 132 | CachedInputTokens: int(usage.CacheReadInputTokens), 133 | } 134 | } 135 | 136 | // decodeFinishReason converts an Anthropic stop reason to an AI SDK FinishReason type. 137 | // It handles nil/empty values by returning FinishReasonUnknown. 138 | func decodeFinishReason(finishReason anthropic.BetaMessageStopReason) api.FinishReason { 139 | switch finishReason { 140 | case anthropic.BetaMessageStopReasonEndTurn, anthropic.BetaMessageStopReasonStopSequence: 141 | return api.FinishReasonStop 142 | case anthropic.BetaMessageStopReasonToolUse: 143 | return api.FinishReasonToolCalls 144 | case anthropic.BetaMessageStopReasonMaxTokens: 145 | return api.FinishReasonLength 146 | default: 147 | return api.FinishReasonUnknown 148 | } 149 | } 150 | -------------------------------------------------------------------------------- /provider/anthropic/codec/encode_params.go: -------------------------------------------------------------------------------- 1 | package codec 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/anthropics/anthropic-sdk-go" 7 | "go.jetify.com/ai/api" 8 | ) 9 | 10 | func EncodeParams( 11 | prompt []api.Message, opts api.CallOptions, 12 | ) (anthropic.BetaMessageNewParams, []api.CallWarning, error) { 13 | anthropicPrompt, err := EncodePrompt(prompt) 14 | if err != nil { 15 | return anthropic.BetaMessageNewParams{}, []api.CallWarning{}, err 16 | } 17 | 18 | params, warnings, err := encodeCallOptions(opts) 19 | if err != nil { 20 | return anthropic.BetaMessageNewParams{}, warnings, err 21 | } 22 | 23 | if len(anthropicPrompt.System) > 0 { 24 | params.System = anthropic.F(anthropicPrompt.System) 25 | } 26 | if len(anthropicPrompt.Messages) > 0 { 27 | params.Messages = anthropic.F(anthropicPrompt.Messages) 28 | } 
29 | 30 | params.Betas = anthropic.F(append(params.Betas.Value, anthropicPrompt.Betas...)) 31 | 32 | return params, warnings, nil 33 | } 34 | 35 | func encodeCallOptions(opts api.CallOptions) (anthropic.BetaMessageNewParams, []api.CallWarning, error) { 36 | params := anthropic.BetaMessageNewParams{ 37 | MaxTokens: anthropic.F(int64(4096)), // Default max tokens 38 | } 39 | 40 | // Set basic parameters 41 | if opts.MaxOutputTokens > 0 { 42 | params.MaxTokens = anthropic.F(int64(opts.MaxOutputTokens)) 43 | } 44 | if opts.Temperature != nil { 45 | params.Temperature = anthropic.F(*opts.Temperature) 46 | } 47 | if opts.TopP > 0 { 48 | params.TopP = anthropic.F(opts.TopP) 49 | } 50 | if opts.TopK > 0 { 51 | params.TopK = anthropic.F(int64(opts.TopK)) 52 | } 53 | if len(opts.StopSequences) > 0 { 54 | params.StopSequences = anthropic.F(opts.StopSequences) 55 | } 56 | 57 | // Handle unsupported settings 58 | warnings := unsupportedWarnings(opts) 59 | 60 | // Handle thinking-specific configuration 61 | thinkingWarnings, err := encodeThinking(¶ms, opts) 62 | if err != nil { 63 | return params, warnings, err 64 | } 65 | warnings = append(warnings, thinkingWarnings...) 66 | 67 | // Handle tool configuration 68 | tools, err := EncodeTools(opts.Tools, opts.ToolChoice) 69 | if err != nil { 70 | return params, warnings, err 71 | } 72 | 73 | // Apply tool configuration to params 74 | params.Betas = anthropic.F(append(params.Betas.Value, tools.Betas...)) 75 | warnings = append(warnings, tools.Warnings...) 
76 | 77 | if len(tools.Tools) > 0 { 78 | params.Tools = anthropic.F(tools.Tools) 79 | } 80 | if len(tools.ToolChoice) > 0 { 81 | params.ToolChoice = anthropic.F(tools.ToolChoice[0]) 82 | } 83 | return params, warnings, nil 84 | } 85 | 86 | func unsupportedWarnings(opts api.CallOptions) []api.CallWarning { 87 | var warnings []api.CallWarning 88 | 89 | if opts.FrequencyPenalty != 0 { 90 | warnings = append(warnings, api.CallWarning{ 91 | Type: "unsupported-setting", 92 | Setting: "FrequencyPenalty", 93 | }) 94 | } 95 | 96 | if opts.PresencePenalty != 0 { 97 | warnings = append(warnings, api.CallWarning{ 98 | Type: "unsupported-setting", 99 | Setting: "PresencePenalty", 100 | }) 101 | } 102 | 103 | if opts.Seed != 0 { 104 | warnings = append(warnings, api.CallWarning{ 105 | Type: "unsupported-setting", 106 | Setting: "Seed", 107 | }) 108 | } 109 | 110 | if opts.ResponseFormat != nil && opts.ResponseFormat.Type != "text" { 111 | warnings = append(warnings, api.CallWarning{ 112 | Type: "unsupported-setting", 113 | Setting: "ResponseFormat", 114 | Details: "JSON response format is not supported.", 115 | }) 116 | } 117 | 118 | return warnings 119 | } 120 | 121 | func encodeThinking(params *anthropic.BetaMessageNewParams, opts api.CallOptions) ([]api.CallWarning, error) { 122 | var warnings []api.CallWarning 123 | 124 | metadata := GetMetadata(opts) 125 | thinkingEnabled := metadata != nil && metadata.Thinking.Enabled 126 | 127 | if !thinkingEnabled { 128 | return warnings, nil 129 | } 130 | 131 | if metadata.Thinking.BudgetTokens == 0 { 132 | return warnings, fmt.Errorf("thinking requires a budget") 133 | } 134 | 135 | // Configure thinking parameters 136 | params.Thinking = anthropic.F[anthropic.BetaThinkingConfigParamUnion]( 137 | anthropic.BetaThinkingConfigEnabledParam{ 138 | Type: anthropic.F(anthropic.BetaThinkingConfigEnabledTypeEnabled), 139 | BudgetTokens: anthropic.F(int64(metadata.Thinking.BudgetTokens)), 140 | }) 141 | 142 | // Adjust max tokens to account for 
thinking budget 143 | params.MaxTokens = anthropic.F(params.MaxTokens.Value + int64(metadata.Thinking.BudgetTokens)) 144 | 145 | // Add warnings for unsupported settings when thinking is enabled 146 | if opts.Temperature != nil { 147 | warnings = append(warnings, api.CallWarning{ 148 | Type: "unsupported-setting", 149 | Setting: "Temperature", 150 | Details: "Temperature is not supported when thinking is enabled", 151 | }) 152 | } 153 | 154 | if opts.TopK > 0 { 155 | warnings = append(warnings, api.CallWarning{ 156 | Type: "unsupported-setting", 157 | Setting: "TopK", 158 | Details: "TopK is not supported when thinking is enabled", 159 | }) 160 | } 161 | 162 | if opts.TopP > 0 { 163 | warnings = append(warnings, api.CallWarning{ 164 | Type: "unsupported-setting", 165 | Setting: "TopP", 166 | Details: "TopP is not supported when thinking is enabled", 167 | }) 168 | } 169 | 170 | return warnings, nil 171 | } 172 | -------------------------------------------------------------------------------- /provider/anthropic/codec/metadata.go: -------------------------------------------------------------------------------- 1 | package codec 2 | 3 | import ( 4 | "github.com/anthropics/anthropic-sdk-go" 5 | "go.jetify.com/ai/api" 6 | ) 7 | 8 | // For now we are using a single type for all metadata. 9 | // TODO: Decide if we will need different types for different metadata. 10 | type Metadata struct { 11 | // --- Used in requests --- 12 | CacheControl string `json:"cache_control,omitempty"` 13 | 14 | // --- Used in responses --- 15 | 16 | Thinking ThinkingConfig `json:"thinking,omitzero"` 17 | Usage Usage `json:"usage,omitempty"` 18 | } 19 | 20 | func GetMetadata(source api.MetadataSource) *Metadata { 21 | return api.GetMetadata[Metadata]("anthropic", source) 22 | } 23 | 24 | // Anthropic-specific Usage information like CacheCreationInputTokens 25 | // and CacheReadInputTokens. 
26 | type Usage anthropic.BetaUsage 27 | 28 | // ThinkingConfig represents the configuration for thinking behavior 29 | type ThinkingConfig struct { 30 | // Whether to enable extended thinking. 31 | // 32 | // See 33 | // [extended thinking](https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking) 34 | // for details. 35 | Enabled bool `json:"enabled,omitzero"` 36 | 37 | // Determines how many tokens Claude can use for its internal reasoning process. 38 | // Larger budgets can enable more thorough analysis for complex problems, improving 39 | // response quality. 40 | // 41 | // Must be ≥1024 and less than `max_tokens`. 42 | BudgetTokens int `json:"budgetTokens,omitzero"` 43 | } 44 | -------------------------------------------------------------------------------- /provider/anthropic/codec/tools.go: -------------------------------------------------------------------------------- 1 | package codec 2 | 3 | import ( 4 | "encoding/json" 5 | 6 | "go.jetify.com/ai/api" 7 | ) 8 | 9 | // ComputerAction is a ComputerToolCall action. 10 | type ComputerAction string 11 | 12 | const ( 13 | // ActionKey presses the key or key-combination specfied in 14 | // [ComputerToolCall.Text]. It supports xdotool's key syntax. 15 | ActionKey ComputerAction = "key" 16 | 17 | // ActionHoldKey holds down the key or key-combination in 18 | // [ComputerToolCall.Text] for a duration of 19 | // [ComputerToolCall.Duration]. It supports the same syntax as 20 | // [ActionKey]. 21 | ActionHoldKey ComputerAction = "hold_key" 22 | 23 | // ActionType types the string specified by [ComputerToolCall.Text] on 24 | // the keyboard. 25 | ActionType ComputerAction = "type" 26 | 27 | // ActionCursorPosition reports the current (x, y) pixel coordinates of 28 | // the cursor on the screen. 29 | ActionCursorPosition ComputerAction = "cursor_position" 30 | 31 | // ActionMouseMove moves the cursor to the pixel specified by 32 | // [ComputerToolCall.Coordinate]. 
33 | ActionMouseMove ComputerAction = "mouse_move" 34 | 35 | // ActionLeftMouseDown presses down the left mouse button without 36 | // releasing it. 37 | ActionLeftMouseDown ComputerAction = "left_mouse_down" 38 | 39 | // ActionLeftMouseUp releases the left mouse button. 40 | ActionLeftMouseUp ComputerAction = "left_mouse_up" 41 | 42 | // ActionLeftClick clicks the left mouse button at 43 | // [ComputerToolCall.Coordinate] while optionally holding down the keys 44 | // in [ComputerToolCall.Text]. 45 | ActionLeftClick ComputerAction = "left_click" 46 | 47 | // ActionLeftClickDrag clicks and drags the cursor from 48 | // [ComputerToolCall.StartCoordinate] to [ComputerToolCall.Coordinate]. 49 | ActionLeftClickDrag ComputerAction = "left_click_drag" 50 | 51 | // ActionRightClick clicks the right mouse button at 52 | // [ComputerToolCall.Coordinate]. 53 | ActionRightClick ComputerAction = "right_click" 54 | 55 | // ActionMiddleClick clicks the middle mouse button at 56 | // [ComputerToolCall.Coordinate]. 57 | ActionMiddleClick ComputerAction = "middle_click" 58 | 59 | // ActionDoubleClick double-clicks the left mouse button at 60 | // [ComputerToolCall.Coordinate]. 61 | ActionDoubleClick ComputerAction = "double_click" 62 | 63 | // ActionTripleClick triple-clicks the left mouse button at 64 | // [ComputerToolCall.Coordinate]. 65 | ActionTripleClick ComputerAction = "triple_click" 66 | 67 | // ActionScroll turns the mouse scroll wheel by 68 | // [ComputerToolCall.ScrollAmount] in the direction of 69 | // [ComputerToolCall.ScrollDirection] at [ComputerToolCall.Coordinate]. 70 | ActionScroll ComputerAction = "scroll" 71 | 72 | // ActionWait pauses execution for [ComputerToolCall.Duration]. 73 | ActionWait ComputerAction = "wait" 74 | 75 | // ActionScreenshot takes a screenshot. 76 | ActionScreenshot ComputerAction = "screenshot" 77 | ) 78 | 79 | // ScrollDirection is a direction to scroll the screen. 
80 | type ScrollDirection string 81 | 82 | const ( 83 | ScrollUp = "up" 84 | ScrollDown = "down" 85 | ScrollLeft = "left" 86 | ScrollRight = "right" 87 | ) 88 | 89 | // TODO(gcurtis): make ComputerToolCall implement json.Unmarshaler so it can 90 | // have better types for some of its fields: 91 | // 92 | // - dedicated coordinate type 93 | // - duration should be time.Duration 94 | // - use ints instead of json.Number while still being flexible about 95 | // accepting ints, floats, or number strings 96 | 97 | // ComputerToolCall contains the parameters of a call to [ComputerUseTool]. 98 | type ComputerToolCall struct { 99 | // Action is the action to perform. It is the only mandatory field. 100 | Action ComputerAction `json:"action"` 101 | 102 | // Text is a key, key-combination, or string literal to type on the 103 | // keyboard. It specifies individual key presses or key-combinations 104 | // using an xdotool-style syntax. Examples include "a", "Return", 105 | // "alt+Tab", "ctrl+s", "Up", "KP_0" (for numpad 0). The ActionType, 106 | // ActionKey, and ActionHoldKey actions require a non-empty Text value. 107 | // Click or scroll actions may optionally set Text to specify keys to 108 | // hold down keys during the click or scroll. 109 | Text string `json:"text,omitzero"` 110 | 111 | // Coordinate is a pair of (x, y) on-screen pixel coordinates for cursor 112 | // actions. (0, 0) is the top-left corner of the screen. The 113 | // ActionMouseMove and ActionLeftClickDrag actions require a coordinate. 114 | Coordinate [2]json.Number `json:"coordinate,omitzero"` 115 | 116 | // StartCoordinate is the starting point for mouse drag actions. 117 | StartCoordinate [2]json.Number `json:"start_coordinate,omitzero"` 118 | 119 | // Duration is the number of seconds to hold down keys or pause 120 | // execution. The ActionHoldKey and ActionWait actions require a 121 | // non-zero Duration. 
122 | Duration json.Number `json:"duration,omitzero"` 123 | 124 | // ScrollAmount is the number of mouse wheel "clicks" to scroll. 125 | // ActionScroll requires a non-zero ScrollAmount. 126 | ScrollAmount json.Number `json:"scroll_amount,omitzero"` 127 | 128 | // ScrollDirection is the direction to scroll. ActionScroll requires 129 | // a ScrollDirection. 130 | ScrollDirection ScrollDirection `json:"scroll_direction,omitzero"` 131 | } 132 | 133 | // ComputerUseTool is a built-in tool that can be used to control a computer. 134 | // It allows the model to use a mouse and keyboard and to take screenshots. 135 | // See the [computer use guide](https://docs.anthropic.com/en/docs/agents-and-tools/computer-use) for more details. 136 | type ComputerUseTool struct { 137 | // The version of the computer tool to use. 138 | // Optional field, defaults to the latest version. Possible values are: "20250124", "20241022". 139 | Version string `json:"version"` 140 | 141 | // The width of the display being controlled by the model in pixels. 142 | // Required field. We recommend setting it to 1280. 143 | DisplayWidthPx int `json:"display_width_px"` 144 | // The height of the display being controlled by the model in pixels. 145 | // Required field. We recommend setting it to 800. 146 | DisplayHeightPx int `json:"display_height_px"` 147 | 148 | // The display number to control (only relevant for X11 environments). 149 | // Optional field, if specified, the tool will be provided a display number in the tool definition. 150 | DisplayNumber int `json:"display_number,omitzero"` 151 | } 152 | 153 | var _ api.ProviderDefinedTool = &ComputerUseTool{} 154 | 155 | func (t *ComputerUseTool) ToolType() string { return "provider-defined" } 156 | 157 | func (t *ComputerUseTool) ID() string { 158 | return "anthropic.computer" 159 | } 160 | 161 | func (t *ComputerUseTool) Name() string { return "computer" } 162 | 163 | // BashTool is a built-in tool that can be used to run shell commands. 
164 | // See the [computer use guide](https://docs.anthropic.com/en/docs/agents-and-tools/computer-use) for more details. 165 | type BashTool struct { 166 | // The version of the bash tool to use. 167 | // Optional field, defaults to the latest version. Possible values are: "20250124", "20241022". 168 | Version string `json:"version"` 169 | } 170 | 171 | var _ api.ProviderDefinedTool = &BashTool{} 172 | 173 | func (t *BashTool) ToolType() string { return "provider-defined" } 174 | 175 | func (t *BashTool) ID() string { 176 | return "anthropic.bash" 177 | } 178 | 179 | func (t *BashTool) Name() string { return "bash" } 180 | 181 | // TextEditorTool is a built-in tool that can be used to view, create and edit text files. 182 | // See the [text editor guide](https://docs.anthropic.com/en/docs/build-with-claude/tool-use/text-editor-tool) for more details. 183 | type TextEditorTool struct { 184 | // The version of the text editor tool to use. 185 | // Optional field, defaults to the latest version. Possible values are: "20250124", "20241022". 186 | Version string `json:"version"` 187 | } 188 | 189 | var _ api.ProviderDefinedTool = &TextEditorTool{} 190 | 191 | func (t *TextEditorTool) ToolType() string { return "provider-defined" } 192 | 193 | func (t *TextEditorTool) ID() string { 194 | return "anthropic.text_editor" 195 | } 196 | 197 | func (t *TextEditorTool) Name() string { return "str_replace_editor" } 198 | 199 | // TODO: Add predefined tool call blocks for the different built-in tools. 200 | -------------------------------------------------------------------------------- /provider/anthropic/constants.go: -------------------------------------------------------------------------------- 1 | package anthropic 2 | 3 | // ProviderName is the name of the Anthropic provider. 
4 | const ProviderName = "anthropic" 5 | 6 | const ( 7 | // Claude 3.7 Models 8 | 9 | ModelClaude3_7SonnetLatest = "claude-3-7-sonnet-latest" 10 | ModelClaude37Sonnet20250219 = "claude-3-7-sonnet-20250219" 11 | 12 | // Claude 3.5 Models 13 | 14 | ModelClaude35HaikuLatest = "claude-3-5-haiku-latest" 15 | ModelClaude35Haiku20241022 = "claude-3-5-haiku-20241022" 16 | 17 | ModelClaude35SonnetLatest = "claude-3-5-sonnet-latest" 18 | ModelClaude35Sonnet20241022 = "claude-3-5-sonnet-20241022" 19 | ModelClaude35Sonnet20240620 = "claude-3-5-sonnet-20240620" 20 | 21 | // Claude 3.0 Models 22 | 23 | ModelClaude3OpusLatest = "claude-3-opus-latest" 24 | ModelClaude3Opus20240229 = "claude-3-opus-20240229" 25 | // Deprecated: Will reach end-of-life on July 21st, 2025. Please migrate to a newer 26 | // model. Visit https://docs.anthropic.com/en/docs/resources/model-deprecations for 27 | // more information. 28 | ModelClaude3Sonnet20240229 = "claude-3-sonnet-20240229" 29 | ModelClaude3Haiku20240307 = "claude-3-haiku-20240307" 30 | 31 | // Claude 2 Models 32 | 33 | // Deprecated: Will reach end-of-life on July 21st, 2025. Please migrate to a newer 34 | // model. Visit https://docs.anthropic.com/en/docs/resources/model-deprecations for 35 | // more information. 36 | ModelClaude21 = "claude-2.1" 37 | // Deprecated: Will reach end-of-life on July 21st, 2025. Please migrate to a newer 38 | // model. Visit https://docs.anthropic.com/en/docs/resources/model-deprecations for 39 | // more information. 40 | ModelClaude20 = "claude-2.0" 41 | ) 42 | -------------------------------------------------------------------------------- /provider/anthropic/llm.go: -------------------------------------------------------------------------------- 1 | package anthropic 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/anthropics/anthropic-sdk-go" 7 | "go.jetify.com/ai/api" 8 | "go.jetify.com/ai/provider/anthropic/codec" 9 | ) 10 | 11 | // ModelOption is a function type that modifies a LanguageModel. 
12 | type ModelOption func(*LanguageModel) 13 | 14 | // WithClient returns a ModelOption that sets the client. 15 | func WithClient(client *anthropic.Client) ModelOption { 16 | // TODO: Instead of only supporting an anthropic.Client, we can "flatten" 17 | // the options supported by the Anthropic SDK. 18 | return func(m *LanguageModel) { 19 | m.client = client 20 | } 21 | } 22 | 23 | // LanguageModel represents an Anthropic language model. 24 | type LanguageModel struct { 25 | modelID string 26 | client *anthropic.Client 27 | } 28 | 29 | var _ api.LanguageModel = &LanguageModel{} 30 | 31 | // NewLanguageModel creates a new Anthropic language model. 32 | func NewLanguageModel(modelID string, opts ...ModelOption) *LanguageModel { 33 | // Create model with default settings 34 | model := &LanguageModel{ 35 | modelID: modelID, 36 | client: anthropic.NewClient(), // Default client 37 | } 38 | 39 | // Apply options 40 | for _, opt := range opts { 41 | opt(model) 42 | } 43 | 44 | return model 45 | } 46 | 47 | func (m *LanguageModel) ProviderName() string { 48 | return ProviderName 49 | } 50 | 51 | func (m *LanguageModel) ModelID() string { 52 | return m.modelID 53 | } 54 | 55 | func (m *LanguageModel) SupportedUrls() []api.SupportedURL { 56 | // TODO: Make configurable via the constructor. 
57 | return []api.SupportedURL{ 58 | { 59 | MediaType: "image/*", 60 | URLPatterns: []string{ 61 | "^https?://.*", 62 | }, 63 | }, 64 | } 65 | } 66 | 67 | func (m *LanguageModel) Generate( 68 | ctx context.Context, prompt []api.Message, opts api.CallOptions, 69 | ) (api.Response, error) { 70 | params, warnings, err := codec.EncodeParams(prompt, opts) 71 | if err != nil { 72 | return api.Response{}, err 73 | } 74 | 75 | message, err := m.client.Beta.Messages.New(ctx, params) 76 | if err != nil { 77 | return api.Response{}, err 78 | } 79 | 80 | response, err := codec.DecodeResponse(message) 81 | if err != nil { 82 | return api.Response{}, err 83 | } 84 | 85 | response.Warnings = append(response.Warnings, warnings...) 86 | return response, nil 87 | } 88 | 89 | func (m *LanguageModel) Stream( 90 | ctx context.Context, prompt []api.Message, opts api.CallOptions, 91 | ) (api.StreamResponse, error) { 92 | return api.StreamResponse{}, api.NewUnsupportedFunctionalityError("streaming generation", "") 93 | } 94 | -------------------------------------------------------------------------------- /provider/anthropic/llm_test.go: -------------------------------------------------------------------------------- 1 | package anthropic 2 | 3 | import ( 4 | "net/http" 5 | "testing" 6 | 7 | "github.com/anthropics/anthropic-sdk-go" 8 | "github.com/anthropics/anthropic-sdk-go/option" 9 | "github.com/stretchr/testify/require" 10 | "go.jetify.com/ai/api" 11 | "go.jetify.com/pkg/httpmock" 12 | ) 13 | 14 | func TestGenerate(t *testing.T) { 15 | tests := []struct { 16 | name string 17 | prompt []api.Message 18 | exchanges []httpmock.Exchange 19 | expected *api.Response // Expected response for successful cases 20 | expectError string // Expected error message, empty means no error expected 21 | }{ 22 | { 23 | name: "successful generation with user message", 24 | prompt: []api.Message{ 25 | &api.UserMessage{ 26 | Content: api.ContentFromText("Hello, how are you?"), 27 | }, 28 | }, 29 | 
exchanges: []httpmock.Exchange{ 30 | { 31 | Request: httpmock.Request{ 32 | Method: http.MethodPost, 33 | Path: "/v1/messages", 34 | }, 35 | Response: httpmock.Response{ 36 | StatusCode: http.StatusOK, 37 | Body: &anthropic.Message{ 38 | Content: []anthropic.ContentBlock{ 39 | { 40 | Text: "I'm doing well, thank you for asking!", 41 | Type: anthropic.ContentBlockTypeText, 42 | }, 43 | }, 44 | Role: anthropic.MessageRoleAssistant, 45 | }, 46 | }, 47 | }, 48 | }, 49 | expected: &api.Response{ 50 | Content: []api.ContentBlock{ 51 | &api.TextBlock{ 52 | Text: "I'm doing well, thank you for asking!", 53 | }, 54 | }, 55 | }, 56 | }, 57 | { 58 | name: "successful generation with system message", 59 | prompt: []api.Message{ 60 | &api.SystemMessage{Content: "You are a helpful assistant"}, 61 | &api.UserMessage{ 62 | Content: api.ContentFromText("What's 2+2?"), 63 | }, 64 | }, 65 | exchanges: []httpmock.Exchange{ 66 | { 67 | Request: httpmock.Request{ 68 | Method: http.MethodPost, 69 | Path: "/v1/messages", 70 | }, 71 | Response: httpmock.Response{ 72 | StatusCode: http.StatusOK, 73 | Body: &anthropic.Message{ 74 | Content: []anthropic.ContentBlock{ 75 | { 76 | Text: "4", 77 | Type: anthropic.ContentBlockTypeText, 78 | }, 79 | }, 80 | Role: anthropic.MessageRoleAssistant, 81 | }, 82 | }, 83 | }, 84 | }, 85 | expected: &api.Response{ 86 | Content: []api.ContentBlock{ 87 | &api.TextBlock{ 88 | Text: "4", 89 | }, 90 | }, 91 | }, 92 | }, 93 | { 94 | name: "api error", 95 | prompt: []api.Message{ 96 | &api.UserMessage{ 97 | Content: api.ContentFromText("Hello"), 98 | }, 99 | }, 100 | exchanges: []httpmock.Exchange{ 101 | { 102 | Request: httpmock.Request{ 103 | Method: http.MethodPost, 104 | Path: "/v1/messages", 105 | }, 106 | Response: httpmock.Response{ 107 | StatusCode: http.StatusInternalServerError, 108 | Body: map[string]any{"error": "internal server error"}, 109 | }, 110 | }, 111 | }, 112 | expectError: "500 Internal Server Error", 113 | }, 114 | } 115 | 116 | for _, tt 
:= range tests { 117 | t.Run(tt.name, func(t *testing.T) { 118 | server := httpmock.NewServer(t, tt.exchanges) 119 | defer server.Close() 120 | 121 | // Create client with mock server URL and test API key 122 | client := anthropic.NewClient( 123 | option.WithBaseURL(server.BaseURL()), 124 | option.WithAPIKey("test-key"), 125 | option.WithMaxRetries(0), // Disable retries 126 | ) 127 | 128 | // Create model with mocked client 129 | model := NewLanguageModel("claude-3", WithClient(client)) 130 | 131 | // Call Generate with empty CallOptions 132 | resp, err := model.Generate(t.Context(), tt.prompt, api.CallOptions{}) 133 | 134 | if tt.expectError != "" { 135 | require.Error(t, err) 136 | require.Contains(t, err.Error(), tt.expectError) 137 | return 138 | } 139 | 140 | require.NoError(t, err) 141 | require.NotNil(t, resp) 142 | 143 | // For successful cases, verify response content matches expected 144 | if tt.expected != nil { 145 | require.Equal(t, tt.expected.Content, resp.Content) 146 | } 147 | }) 148 | } 149 | } 150 | -------------------------------------------------------------------------------- /provider/anthropic/plan/prompt.md: -------------------------------------------------------------------------------- 1 | We are working on an AI SDK that allows you to use a common interface against different LLM providers. 2 | 3 | We've already implemented an openrouter provider, and now we want to implement an Anthropic provider using the anthropic go client @https://github.com/anthropics/anthropic-sdk-go 4 | 5 | You can use go doc commands to understand the anthropic SDK. The import path is: 6 | ``` 7 | import ( 8 | "github.com/anthropics/anthropic-sdk-go" // imported as anthropic 9 | ) 10 | ``` 11 | 12 | We need to translate our AI SDK prompt types defined in @llm_prompt.go into the corresponding types from the anthropic sdk. 
13 | 14 | Implement the encoding functions in: @encode_prompt.go 15 | 16 | But first take a look at @encode_prompt.go to see how we did it for the openrouter case (we want to do something analougous) 17 | 18 | -------------------------------------------------------------------------------- /provider/anthropic/tools.go: -------------------------------------------------------------------------------- 1 | package anthropic 2 | 3 | import "go.jetify.com/ai/provider/anthropic/codec" 4 | 5 | type ( 6 | ComputerUseTool = codec.ComputerUseTool 7 | ComputerToolCall = codec.ComputerToolCall 8 | ) 9 | 10 | type BashTool = codec.BashTool 11 | 12 | type TextEditorTool = codec.TextEditorTool 13 | -------------------------------------------------------------------------------- /provider/internal/openrouter/client/finish_reason.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | const ( 4 | FinishReasonStop = "stop" 5 | FinishReasonLength = "length" 6 | FinishReasonContentFilter = "content_filter" 7 | FinishReasonFunctionCall = "function_call" 8 | FinishReasonToolCalls = "tool_calls" 9 | ) 10 | -------------------------------------------------------------------------------- /provider/internal/openrouter/client/logprob.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | // LogProbs represents the logprobs structure returned by the OpenRouter 4 | // chat API. 
5 | type LogProbs struct { 6 | Content []LogProb `json:"content,omitempty"` 7 | } 8 | 9 | // LogProb represents a single token's log probability information 10 | type LogProb struct { 11 | Token string `json:"token"` 12 | LogProb float64 `json:"logprob"` 13 | TopLogProbs []TopLogProb `json:"top_logprobs,omitempty"` 14 | } 15 | 16 | // TopLogProb represents a single top logprob entry 17 | type TopLogProb struct { 18 | Token string `json:"token"` 19 | LogProb float64 `json:"logprob"` 20 | } 21 | 22 | // CompletionLogProbs represents the logprobs structure returned by OpenRouter 23 | // completions API. 24 | type CompletionLogProbs struct { 25 | Tokens []string `json:"tokens"` 26 | TokenLogProbs []float64 `json:"token_logprobs"` 27 | TopLogProbs []map[string]float64 `json:"top_logprobs,omitempty"` 28 | } 29 | -------------------------------------------------------------------------------- /provider/internal/openrouter/client/prompt.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | // Role constants for message types 4 | const ( 5 | RoleSystem = "system" 6 | RoleUser = "user" 7 | RoleAssistant = "assistant" 8 | RoleTool = "tool" 9 | ) 10 | 11 | // Content type constants 12 | const ( 13 | ContentTypeText = "text" 14 | ContentTypeFunction = "function" 15 | ContentTypeImageURL = "image_url" 16 | ) 17 | 18 | // Prompt represents an array of chat messages in OpenRouter format. 19 | // This matches OpenRouter's chat completion API message format. 20 | type Prompt []Message 21 | 22 | // Message is the interface that all message types must implement 23 | type Message interface { 24 | // The role of the message. 
One of RoleSystem, RoleUser, RoleAssistant, or RoleTool 25 | Role() string 26 | } 27 | 28 | // SystemMessage represents a system message 29 | type SystemMessage struct { 30 | Content string `json:"content"` 31 | } 32 | 33 | var _ Message = &SystemMessage{} 34 | 35 | func (m *SystemMessage) Role() string { 36 | return RoleSystem 37 | } 38 | 39 | // AssistantMessage represents an assistant message 40 | type AssistantMessage struct { 41 | Content string `json:"content,omitempty"` 42 | // TODO: figure out whether reasoning is just for assistant, 43 | // or if it's a general field for all messages. 44 | // Coming from: https://github.com/OpenRouterTeam/ai-sdk-provider/blob/main/src/openrouter-chat-language-model.ts#L514 45 | Reasoning string `json:"reasoning,omitempty"` 46 | ToolCalls []ToolCall `json:"tool_calls,omitempty"` 47 | } 48 | 49 | var _ Message = &AssistantMessage{} 50 | 51 | func (m *AssistantMessage) Role() string { 52 | return RoleAssistant 53 | } 54 | 55 | // UserMessage represents a user message 56 | type UserMessage struct { 57 | Content UserMessageContent `json:"content"` 58 | } 59 | 60 | var _ Message = &UserMessage{} 61 | 62 | func (m *UserMessage) Role() string { 63 | return RoleUser 64 | } 65 | 66 | // UserMessageContent represents either a string or array of content parts. 67 | // When sending to OpenRouter's API: 68 | // - If Parts is non-nil, it will be sent as an array of content parts (e.g. 
for text + images) 69 | // - Otherwise, Text will be sent as a simple string content 70 | // Note: Text and Parts are mutually exclusive - only one should be set at a time 71 | type UserMessageContent struct { 72 | Text string // if content is a simple string 73 | Parts []ContentPart // if content is an array of parts 74 | } 75 | 76 | // ContentPart is the interface for different content part types 77 | type ContentPart interface { 78 | Type() string 79 | } 80 | 81 | // TextPart represents a text content part 82 | type TextPart struct { 83 | Text string `json:"text"` 84 | } 85 | 86 | var _ ContentPart = &TextPart{} 87 | 88 | func (p *TextPart) Type() string { 89 | return ContentTypeText 90 | } 91 | 92 | // ImagePart represents an image content part 93 | type ImagePart struct { 94 | ImageURL struct { 95 | URL string `json:"url"` 96 | } `json:"image_url"` 97 | } 98 | 99 | var _ ContentPart = &ImagePart{} 100 | 101 | func (p *ImagePart) Type() string { 102 | return ContentTypeImageURL 103 | } 104 | 105 | // ToolCall represents a tool call from the assistant 106 | type ToolCall struct { 107 | Type string `json:"type"` // always "function" 108 | ID string `json:"id"` 109 | Function struct { 110 | Name string `json:"name"` 111 | Arguments string `json:"arguments"` 112 | // TODO: Docs indicate there's an optional "description" field. 113 | // This differs from the TypeScript provider. 
114 | } `json:"function"` 115 | } 116 | 117 | // ToolMessage represents a tool message 118 | type ToolMessage struct { 119 | Content string `json:"content"` 120 | ToolCallID string `json:"tool_call_id"` 121 | } 122 | 123 | var _ Message = &ToolMessage{} 124 | 125 | func (m *ToolMessage) Role() string { 126 | return RoleTool 127 | } 128 | -------------------------------------------------------------------------------- /provider/internal/openrouter/client/prompt_json.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | ) 7 | 8 | // UnmarshalJSON for SystemMessage 9 | func (m *SystemMessage) UnmarshalJSON(data []byte) error { 10 | type Alias SystemMessage 11 | aux := struct { 12 | Role string `json:"role"` 13 | *Alias 14 | }{ 15 | Alias: (*Alias)(m), 16 | } 17 | if err := json.Unmarshal(data, &aux); err != nil { 18 | return err 19 | } 20 | if aux.Role != RoleSystem { 21 | return fmt.Errorf("invalid role for SystemMessage: %s", aux.Role) 22 | } 23 | return nil 24 | } 25 | 26 | // UnmarshalJSON for AssistantMessage 27 | func (m *AssistantMessage) UnmarshalJSON(data []byte) error { 28 | type Alias AssistantMessage 29 | aux := struct { 30 | Role string `json:"role"` 31 | *Alias 32 | }{ 33 | Alias: (*Alias)(m), 34 | } 35 | if err := json.Unmarshal(data, &aux); err != nil { 36 | return err 37 | } 38 | if aux.Role != RoleAssistant { 39 | return fmt.Errorf("invalid role for AssistantMessage: %s", aux.Role) 40 | } 41 | return nil 42 | } 43 | 44 | // UnmarshalJSON for UserMessage 45 | func (m *UserMessage) UnmarshalJSON(data []byte) error { 46 | type Alias UserMessage 47 | aux := struct { 48 | Role string `json:"role"` 49 | *Alias 50 | }{ 51 | Alias: (*Alias)(m), 52 | } 53 | if err := json.Unmarshal(data, &aux); err != nil { 54 | return err 55 | } 56 | if aux.Role != RoleUser { 57 | return fmt.Errorf("invalid role for UserMessage: %s", aux.Role) 58 | } 59 | return nil 60 | } 61 
| 62 | // UnmarshalJSON for ToolMessage 63 | func (m *ToolMessage) UnmarshalJSON(data []byte) error { 64 | type Alias ToolMessage 65 | aux := struct { 66 | Role string `json:"role"` 67 | *Alias 68 | }{ 69 | Alias: (*Alias)(m), 70 | } 71 | if err := json.Unmarshal(data, &aux); err != nil { 72 | return err 73 | } 74 | if aux.Role != RoleTool { 75 | return fmt.Errorf("invalid role for ToolMessage: %s", aux.Role) 76 | } 77 | return nil 78 | } 79 | 80 | // Add UnmarshalJSON methods for the content part types 81 | func (p *TextPart) UnmarshalJSON(data []byte) error { 82 | type Alias TextPart 83 | aux := struct { 84 | Type string `json:"type"` 85 | *Alias 86 | }{ 87 | Alias: (*Alias)(p), 88 | } 89 | if err := json.Unmarshal(data, &aux); err != nil { 90 | return err 91 | } 92 | if aux.Type != ContentTypeText { 93 | return fmt.Errorf("invalid type for TextPart: %s", aux.Type) 94 | } 95 | return nil 96 | } 97 | 98 | func (p *ImagePart) UnmarshalJSON(data []byte) error { 99 | type Alias ImagePart 100 | aux := struct { 101 | Type string `json:"type"` 102 | *Alias 103 | }{ 104 | Alias: (*Alias)(p), 105 | } 106 | if err := json.Unmarshal(data, &aux); err != nil { 107 | return err 108 | } 109 | if aux.Type != ContentTypeImageURL { 110 | return fmt.Errorf("invalid type for ImagePart: %s", aux.Type) 111 | } 112 | return nil 113 | } 114 | 115 | // MarshalJSON for SystemMessage 116 | func (m *SystemMessage) MarshalJSON() ([]byte, error) { 117 | type Alias SystemMessage 118 | return json.Marshal(struct { 119 | Role string `json:"role"` 120 | *Alias 121 | }{ 122 | Role: RoleSystem, 123 | Alias: (*Alias)(m), 124 | }) 125 | } 126 | 127 | // MarshalJSON for AssistantMessage 128 | func (m *AssistantMessage) MarshalJSON() ([]byte, error) { 129 | type Alias AssistantMessage 130 | return json.Marshal(struct { 131 | Role string `json:"role"` 132 | *Alias 133 | }{ 134 | Role: RoleAssistant, 135 | Alias: (*Alias)(m), 136 | }) 137 | } 138 | 139 | // MarshalJSON for UserMessage 140 | func (m 
*UserMessage) MarshalJSON() ([]byte, error) { 141 | type Alias UserMessage 142 | return json.Marshal(struct { 143 | Role string `json:"role"` 144 | *Alias 145 | }{ 146 | Role: RoleUser, 147 | Alias: (*Alias)(m), 148 | }) 149 | } 150 | 151 | // MarshalJSON for ToolMessage 152 | func (m *ToolMessage) MarshalJSON() ([]byte, error) { 153 | type Alias ToolMessage 154 | return json.Marshal(struct { 155 | Role string `json:"role"` 156 | *Alias 157 | }{ 158 | Role: RoleTool, 159 | Alias: (*Alias)(m), 160 | }) 161 | } 162 | 163 | // MarshalJSON for TextPart 164 | func (p *TextPart) MarshalJSON() ([]byte, error) { 165 | type Alias TextPart 166 | return json.Marshal(struct { 167 | Type string `json:"type"` 168 | *Alias 169 | }{ 170 | Type: ContentTypeText, 171 | Alias: (*Alias)(p), 172 | }) 173 | } 174 | 175 | // MarshalJSON for ImagePart 176 | func (p *ImagePart) MarshalJSON() ([]byte, error) { 177 | type Alias ImagePart 178 | return json.Marshal(struct { 179 | Type string `json:"type"` 180 | *Alias 181 | }{ 182 | Type: ContentTypeImageURL, 183 | Alias: (*Alias)(p), 184 | }) 185 | } 186 | 187 | // MarshalMessage marshals a Message interface into JSON bytes 188 | func MarshalMessage(msg Message) ([]byte, error) { 189 | return json.Marshal(msg) 190 | } 191 | 192 | // UnmarshalMessage unmarshals JSON bytes into the appropriate Message type 193 | func UnmarshalMessage(data []byte) (Message, error) { 194 | // First unmarshal just the role to determine the message type 195 | var roleCheck struct { 196 | Role string `json:"role"` 197 | } 198 | if err := json.Unmarshal(data, &roleCheck); err != nil { 199 | return nil, err 200 | } 201 | 202 | // Create and unmarshal into the appropriate type based on role 203 | var msg Message 204 | switch roleCheck.Role { 205 | case RoleSystem: 206 | msg = &SystemMessage{} 207 | case RoleUser: 208 | msg = &UserMessage{} 209 | case RoleAssistant: 210 | msg = &AssistantMessage{} 211 | case RoleTool: 212 | msg = &ToolMessage{} 213 | default: 214 | 
return nil, fmt.Errorf("unknown message role: %s", roleCheck.Role) 215 | } 216 | 217 | if err := json.Unmarshal(data, msg); err != nil { 218 | return nil, err 219 | } 220 | 221 | return msg, nil 222 | } 223 | 224 | // MarshalJSON implements custom JSON marshaling for user content 225 | func (c *UserMessageContent) MarshalJSON() ([]byte, error) { 226 | // If there are parts, marshal as an array of parts 227 | if c.Parts != nil { 228 | return json.Marshal(c.Parts) 229 | } 230 | 231 | // Otherwise marshal as string (empty string will be marshaled as "") 232 | return json.Marshal(c.Text) 233 | } 234 | 235 | // Add this function to handle unmarshaling of content parts 236 | func unmarshalContentPart(data []byte) (ContentPart, error) { 237 | var typeCheck struct { 238 | Type string `json:"type"` 239 | } 240 | if err := json.Unmarshal(data, &typeCheck); err != nil { 241 | return nil, err 242 | } 243 | 244 | switch typeCheck.Type { 245 | case ContentTypeText: 246 | var text TextPart 247 | if err := json.Unmarshal(data, &text); err != nil { 248 | return nil, err 249 | } 250 | return &text, nil 251 | case ContentTypeImageURL: 252 | var image ImagePart 253 | if err := json.Unmarshal(data, &image); err != nil { 254 | return nil, err 255 | } 256 | return &image, nil 257 | default: 258 | return nil, fmt.Errorf("unknown content part type: %s", typeCheck.Type) 259 | } 260 | } 261 | 262 | // UnmarshalJSON for UserMessageContent to handle both string and array cases 263 | func (c *UserMessageContent) UnmarshalJSON(data []byte) error { 264 | // Try unmarshaling as string first 265 | var text string 266 | if err := json.Unmarshal(data, &text); err == nil { 267 | c.Text = text 268 | c.Parts = nil 269 | return nil 270 | } 271 | 272 | // If that fails, try as array of content parts 273 | var rawParts []json.RawMessage 274 | if err := json.Unmarshal(data, &rawParts); err != nil { 275 | return err 276 | } 277 | 278 | c.Text = "" 279 | c.Parts = make([]ContentPart, len(rawParts)) 280 | for 
i, raw := range rawParts { 281 | part, err := unmarshalContentPart(raw) 282 | if err != nil { 283 | return err 284 | } 285 | c.Parts[i] = part 286 | } 287 | return nil 288 | } 289 | -------------------------------------------------------------------------------- /provider/internal/openrouter/client/prompt_test.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "encoding/json" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestMessageMarshaling(t *testing.T) { 11 | tests := []struct { 12 | name string 13 | message Message 14 | expected string 15 | }{ 16 | { 17 | name: "system message", 18 | message: &SystemMessage{ 19 | Content: "test content", 20 | }, 21 | expected: `{"role":"system","content":"test content"}`, 22 | }, 23 | { 24 | name: "user message with string content", 25 | message: &UserMessage{ 26 | Content: UserMessageContent{ 27 | Text: "test content", 28 | }, 29 | }, 30 | expected: `{"role":"user","content":"test content"}`, 31 | }, 32 | { 33 | name: "user message with content parts", 34 | message: &UserMessage{ 35 | Content: UserMessageContent{ 36 | Parts: []ContentPart{ 37 | &TextPart{ 38 | Text: "test text", 39 | }, 40 | &ImagePart{ 41 | ImageURL: struct { 42 | URL string `json:"url"` 43 | }{ 44 | URL: "http://example.com/image.jpg", 45 | }, 46 | }, 47 | }, 48 | }, 49 | }, 50 | expected: `{"role":"user","content":[{"type":"text","text":"test text"},{"type":"image_url","image_url":{"url":"http://example.com/image.jpg"}}]}`, 51 | }, 52 | { 53 | name: "assistant message with tool calls", 54 | message: &AssistantMessage{ 55 | Content: "test content", 56 | ToolCalls: []ToolCall{ 57 | { 58 | Type: "function", 59 | ID: "call_123", 60 | Function: struct { 61 | Name string `json:"name"` 62 | Arguments string `json:"arguments"` 63 | }{ 64 | Name: "test_function", 65 | Arguments: `{"arg":"value"}`, 66 | }, 67 | }, 68 | }, 69 | }, 70 | expected: 
`{"role":"assistant","content":"test content","tool_calls":[{"type":"function","id":"call_123","function":{"name":"test_function","arguments":"{\"arg\":\"value\"}"}}]}`, 71 | }, 72 | { 73 | name: "tool message", 74 | message: &ToolMessage{ 75 | Content: "test content", 76 | ToolCallID: "call_123", 77 | }, 78 | expected: `{"role":"tool","content":"test content","tool_call_id":"call_123"}`, 79 | }, 80 | } 81 | 82 | for _, tt := range tests { 83 | t.Run(tt.name, func(t *testing.T) { 84 | // Test direct marshaling matches expected JSON 85 | data, err := json.Marshal(tt.message) 86 | assert.NoError(t, err) 87 | assert.JSONEq(t, tt.expected, string(data)) 88 | 89 | // Unmarshaling should match the original struct 90 | unmarshaled, err := UnmarshalMessage(data) 91 | assert.NoError(t, err) 92 | assert.Equal(t, tt.message, unmarshaled) 93 | 94 | // Remarshaling should match the original JSON 95 | remarshaled, err := json.Marshal(unmarshaled) 96 | assert.NoError(t, err) 97 | assert.JSONEq(t, string(data), string(remarshaled)) 98 | }) 99 | } 100 | } 101 | 102 | func TestMessageUnmarshalingWithInvalidRoles(t *testing.T) { 103 | tests := []struct { 104 | name string 105 | json string 106 | targetMsg Message 107 | expectedErr string 108 | }{ 109 | { 110 | name: "system message with invalid role", 111 | json: `{"role":"user","content":"test content"}`, 112 | targetMsg: &SystemMessage{}, 113 | expectedErr: "invalid role for SystemMessage: user", 114 | }, 115 | { 116 | name: "user message with invalid role", 117 | json: `{"role":"system","content":"test content"}`, 118 | targetMsg: &UserMessage{}, 119 | expectedErr: "invalid role for UserMessage: system", 120 | }, 121 | { 122 | name: "assistant message with invalid role", 123 | json: `{"role":"user","content":"test content"}`, 124 | targetMsg: &AssistantMessage{}, 125 | expectedErr: "invalid role for AssistantMessage: user", 126 | }, 127 | { 128 | name: "tool message with invalid role", 129 | json: 
`{"role":"assistant","content":"test content","tool_call_id":"123"}`, 130 | targetMsg: &ToolMessage{}, 131 | expectedErr: "invalid role for ToolMessage: assistant", 132 | }, 133 | } 134 | 135 | for _, tt := range tests { 136 | t.Run(tt.name, func(t *testing.T) { 137 | err := json.Unmarshal([]byte(tt.json), tt.targetMsg) 138 | assert.Error(t, err) 139 | assert.Equal(t, tt.expectedErr, err.Error()) 140 | }) 141 | } 142 | } 143 | 144 | func TestContentPartUnmarshalingWithInvalidTypes(t *testing.T) { 145 | tests := []struct { 146 | name string 147 | json string 148 | targetPart ContentPart 149 | expectedErr string 150 | }{ 151 | { 152 | name: "text part with invalid type", 153 | json: `{"type":"image_url","text":"test content"}`, 154 | targetPart: &TextPart{}, 155 | expectedErr: "invalid type for TextPart: image_url", 156 | }, 157 | { 158 | name: "image part with invalid type", 159 | json: `{"type":"text","image_url":{"url":"http://example.com/image.jpg"}}`, 160 | targetPart: &ImagePart{}, 161 | expectedErr: "invalid type for ImagePart: text", 162 | }, 163 | } 164 | 165 | for _, tt := range tests { 166 | t.Run(tt.name, func(t *testing.T) { 167 | err := json.Unmarshal([]byte(tt.json), tt.targetPart) 168 | assert.Error(t, err) 169 | assert.Equal(t, tt.expectedErr, err.Error()) 170 | }) 171 | } 172 | } 173 | -------------------------------------------------------------------------------- /provider/internal/openrouter/client/response.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | // Response represents a response from the OpenRouter chat API 4 | type Response struct { 5 | ID string `json:"id"` 6 | Object string `json:"object"` 7 | Created int64 `json:"created"` 8 | Model string `json:"model"` 9 | Choices []Choice `json:"choices"` 10 | Usage *Usage `json:"usage,omitempty"` 11 | SystemFingerprint string `json:"system_fingerprint"` 12 | } 13 | 14 | // Choice represents a single completion choice in the response 
15 | type Choice struct { 16 | Index int `json:"index"` 17 | Message AssistantMessage `json:"message"` 18 | LogProbs LogProbs `json:"logprobs,omitempty"` 19 | FinishReason string `json:"finish_reason"` 20 | } 21 | 22 | // Usage represents token usage information in a response 23 | type Usage struct { 24 | PromptTokens int `json:"prompt_tokens"` 25 | TotalTokens int `json:"total_tokens"` 26 | CompletionTokens int `json:"completion_tokens"` 27 | } 28 | -------------------------------------------------------------------------------- /provider/internal/openrouter/client/settings.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "strconv" 7 | ) 8 | 9 | // ChatSettings holds configuration for OpenRouter chat completions. 10 | type ChatSettings struct { 11 | // LogitBias modifies the likelihood of specified tokens appearing in the completion. 12 | // Maps token IDs to bias values from -100 to 100. Values between -1 and 1 decrease 13 | // or increase likelihood of selection; values like -100 or 100 result in a ban or 14 | // exclusive selection of the token. 15 | // Example: {50256: -100} prevents the <|endoftext|> token from being generated. 16 | LogitBias map[int]float64 `json:"logit_bias,omitempty"` 17 | 18 | // Logprobs controls returning log probabilities of the tokens. 19 | // When Enabled is true, returns all logprobs if TopK is 0, 20 | // or returns top K logprobs if TopK > 0. 21 | // Note: Including logprobs increases response size and can slow down response times. 22 | Logprobs *LogprobSettings `json:"logprobs,omitempty"` 23 | 24 | // ParallelToolCalls enables parallel function calling during tool use. 25 | // Defaults to true if not set. 26 | ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` 27 | 28 | // User is a unique identifier representing the end-user, which helps OpenRouter 29 | // monitor and detect abuse. 
30 | User *string `json:"user,omitempty"` 31 | 32 | // Models is a list of model IDs to try in order if the primary model fails. 33 | // Example: ["anthropic/claude-2", "gryphe/mythomax-l2-13b"] 34 | Models []string `json:"models,omitempty"` 35 | 36 | // IncludeReasoning requests the model to return extra reasoning text in the response, 37 | // if the model supports it. 38 | IncludeReasoning *bool `json:"include_reasoning,omitempty"` 39 | } 40 | 41 | // CompletionSettings holds configuration for OpenRouter completions. 42 | type CompletionSettings struct { 43 | // Echo returns the prompt in addition to the completion 44 | Echo *bool `json:"echo,omitempty"` 45 | 46 | // LogitBias modifies the likelihood of specified tokens appearing in the completion. 47 | // Maps token IDs to bias values from -100 to 100. Values between -1 and 1 decrease 48 | // or increase likelihood of selection; values like -100 or 100 result in a ban or 49 | // exclusive selection of the token. 50 | // Example: {50256: -100} prevents the <|endoftext|> token from being generated. 51 | LogitBias map[int]float64 `json:"logit_bias,omitempty"` 52 | 53 | // Logprobs controls returning log probabilities of the tokens. 54 | // When Enabled is true, returns all logprobs if TopK is 0, 55 | // or returns top K logprobs if TopK > 0. 56 | // Note: Including logprobs increases response size and can slow down response times. 57 | Logprobs *LogprobSettings `json:"logprobs,omitempty"` 58 | 59 | // Suffix is appended after a completion of inserted text 60 | Suffix *string `json:"suffix,omitempty"` 61 | 62 | // User is a unique identifier representing the end-user, which helps OpenRouter 63 | // monitor and detect abuse. 64 | User *string `json:"user,omitempty"` 65 | 66 | // Models is a list of model IDs to try in order if the primary model fails. 
67 | // Example: ["openai/gpt-4", "anthropic/claude-2"] 68 | Models []string `json:"models,omitempty"` 69 | 70 | // IncludeReasoning requests the model to return extra reasoning text in the response, 71 | // if the model supports it. 72 | IncludeReasoning *bool `json:"include_reasoning,omitempty"` 73 | } 74 | 75 | // LogprobSettings represents the configuration for token log probabilities. 76 | // It can be configured either as a boolean flag or with a number for top-N logprobs. 77 | type LogprobSettings struct { 78 | // Enabled indicates if logprobs should be returned 79 | Enabled bool 80 | // TopK specifies how many top logprobs to return (if > 0) 81 | TopK int 82 | } 83 | 84 | // MarshalJSON implements custom JSON marshaling for LogprobSettings 85 | func (l *LogprobSettings) MarshalJSON() ([]byte, error) { 86 | if l == nil { 87 | return []byte("null"), nil 88 | } 89 | if l.TopK > 0 { 90 | return []byte(strconv.Itoa(l.TopK)), nil 91 | } 92 | return []byte(strconv.FormatBool(l.Enabled)), nil 93 | } 94 | 95 | // UnmarshalJSON implements custom JSON unmarshaling for LogprobSettings 96 | func (l *LogprobSettings) UnmarshalJSON(data []byte) error { 97 | if string(data) == "null" { 98 | l.Enabled = false 99 | l.TopK = 0 100 | return nil 101 | } 102 | 103 | // Try as boolean first 104 | var b bool 105 | if err := json.Unmarshal(data, &b); err == nil { 106 | l.Enabled = b 107 | l.TopK = 0 108 | return nil 109 | } 110 | 111 | // Try as number 112 | var n int 113 | if err := json.Unmarshal(data, &n); err == nil { 114 | l.Enabled = true 115 | l.TopK = n 116 | return nil 117 | } 118 | 119 | return fmt.Errorf("logprobs must be boolean or number, got %s", string(data)) 120 | } 121 | -------------------------------------------------------------------------------- /provider/internal/openrouter/client/settings_test.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "encoding/json" 5 | "testing" 6 | 7 | 
"github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestLogprobSettings_Marshal(t *testing.T) { 11 | tests := []struct { 12 | name string 13 | logprobs *LogprobSettings 14 | want string 15 | }{ 16 | { 17 | name: "nil logprobs", 18 | logprobs: nil, 19 | want: "null", 20 | }, 21 | { 22 | name: "enabled true", 23 | logprobs: &LogprobSettings{ 24 | Enabled: true, 25 | }, 26 | want: "true", 27 | }, 28 | { 29 | name: "enabled false", 30 | logprobs: &LogprobSettings{ 31 | Enabled: false, 32 | }, 33 | want: "false", 34 | }, 35 | { 36 | name: "top K", 37 | logprobs: &LogprobSettings{ 38 | Enabled: true, 39 | TopK: 5, 40 | }, 41 | want: "5", 42 | }, 43 | { 44 | name: "disabled with zero topK", 45 | logprobs: &LogprobSettings{ 46 | Enabled: false, 47 | TopK: 0, 48 | }, 49 | want: "false", 50 | }, 51 | { 52 | name: "enabled with zero topK", 53 | logprobs: &LogprobSettings{ 54 | Enabled: true, 55 | TopK: 0, 56 | }, 57 | want: "true", 58 | }, 59 | } 60 | 61 | for _, tt := range tests { 62 | t.Run(tt.name, func(t *testing.T) { 63 | // Test marshaling 64 | got, err := json.Marshal(tt.logprobs) 65 | assert.NoError(t, err) 66 | assert.Equal(t, tt.want, string(got)) 67 | 68 | // Test unmarshaling 69 | var l LogprobSettings 70 | err = json.Unmarshal([]byte(tt.want), &l) 71 | assert.NoError(t, err) 72 | if tt.logprobs != nil { 73 | assert.Equal(t, tt.logprobs.Enabled, l.Enabled) 74 | assert.Equal(t, tt.logprobs.TopK, l.TopK) 75 | } 76 | }) 77 | } 78 | } 79 | 80 | func TestLogprobSettings_UnmarshalErrors(t *testing.T) { 81 | tests := []struct { 82 | name string 83 | input string 84 | errSubstr string 85 | }{ 86 | { 87 | name: "invalid string", 88 | input: `"invalid"`, 89 | errSubstr: "logprobs must be boolean or number", 90 | }, 91 | } 92 | 93 | for _, tt := range tests { 94 | t.Run(tt.name, func(t *testing.T) { 95 | var l LogprobSettings 96 | err := json.Unmarshal([]byte(tt.input), &l) 97 | assert.Error(t, err) 98 | assert.Contains(t, err.Error(), tt.errSubstr) 99 | }) 100 | } 101 
| } 102 | -------------------------------------------------------------------------------- /provider/internal/openrouter/codec/decode_finish.go: -------------------------------------------------------------------------------- 1 | package codec 2 | 3 | import ( 4 | "go.jetify.com/ai/api" 5 | "go.jetify.com/ai/provider/internal/openrouter/client" 6 | ) 7 | 8 | // DecodeFinishReason converts an OpenRouter finish reason to an AI SDK FinishReason type. 9 | // It handles nil/empty values by returning FinishReasonUnknown. 10 | func DecodeFinishReason(finishReason string) api.FinishReason { 11 | switch finishReason { 12 | case client.FinishReasonStop: 13 | return api.FinishReasonStop 14 | case client.FinishReasonLength: 15 | return api.FinishReasonLength 16 | case client.FinishReasonContentFilter: 17 | return api.FinishReasonContentFilter 18 | case client.FinishReasonFunctionCall, client.FinishReasonToolCalls: 19 | return api.FinishReasonToolCalls 20 | default: 21 | return api.FinishReasonUnknown 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /provider/internal/openrouter/codec/decode_finish_test.go: -------------------------------------------------------------------------------- 1 | package codec 2 | 3 | import ( 4 | "testing" 5 | 6 | "go.jetify.com/ai/api" 7 | "go.jetify.com/ai/provider/internal/openrouter/client" 8 | ) 9 | 10 | func TestDecodeFinishReason(t *testing.T) { 11 | tests := []struct { 12 | name string 13 | input string 14 | expected api.FinishReason 15 | }{ 16 | { 17 | name: "stop reason", 18 | input: client.FinishReasonStop, 19 | expected: api.FinishReasonStop, 20 | }, 21 | { 22 | name: "length reason", 23 | input: client.FinishReasonLength, 24 | expected: api.FinishReasonLength, 25 | }, 26 | { 27 | name: "content filter reason", 28 | input: client.FinishReasonContentFilter, 29 | expected: api.FinishReasonContentFilter, 30 | }, 31 | { 32 | name: "function call reason", 33 | input: 
client.FinishReasonFunctionCall, 34 | expected: api.FinishReasonToolCalls, 35 | }, 36 | { 37 | name: "tool calls reason", 38 | input: client.FinishReasonToolCalls, 39 | expected: api.FinishReasonToolCalls, 40 | }, 41 | { 42 | name: "empty string", 43 | input: "", 44 | expected: api.FinishReasonUnknown, 45 | }, 46 | { 47 | name: "unknown reason", 48 | input: "something_else", 49 | expected: api.FinishReasonUnknown, 50 | }, 51 | } 52 | 53 | for _, tt := range tests { 54 | t.Run(tt.name, func(t *testing.T) { 55 | result := DecodeFinishReason(tt.input) 56 | if result != tt.expected { 57 | t.Errorf("DecodeFinishReason(%q) = %q, want %q", tt.input, result, tt.expected) 58 | } 59 | }) 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /provider/internal/openrouter/codec/decode_logprobs.go: -------------------------------------------------------------------------------- 1 | package codec 2 | 3 | import ( 4 | "go.jetify.com/ai/api" 5 | "go.jetify.com/ai/provider/internal/openrouter/client" 6 | ) 7 | 8 | // DecodeLogProbs converts OpenRouter's chat logprobs format to the SDK's LogProb format 9 | func DecodeLogProbs(logprobs *client.LogProbs) []api.LogProb { 10 | if logprobs == nil || logprobs.Content == nil { 11 | return []api.LogProb{} // Return empty slice instead of nil 12 | } 13 | 14 | result := make([]api.LogProb, len(logprobs.Content)) 15 | for i, item := range logprobs.Content { 16 | result[i] = decodeChatLogProb(item) 17 | } 18 | 19 | return result 20 | } 21 | 22 | // decodeChatLogProb converts a single OpenRouter LogProb to the SDK's LogProb format 23 | func decodeChatLogProb(item client.LogProb) api.LogProb { 24 | return api.LogProb{ 25 | Token: item.Token, 26 | LogProb: item.LogProb, 27 | TopLogProbs: decodeTokenLogProbs(item.TopLogProbs), 28 | } 29 | } 30 | 31 | // DecodeCompletionLogProbs converts OpenRouter's completion logprobs format to the SDK's LogProb format. 
32 | // It handles nil input by returning an empty slice. 33 | func DecodeCompletionLogProbs(logprobs *client.CompletionLogProbs) []api.LogProb { 34 | if logprobs == nil { 35 | return []api.LogProb{} 36 | } 37 | 38 | result := make([]api.LogProb, len(logprobs.Tokens)) 39 | for i, token := range logprobs.Tokens { 40 | result[i] = decodeCompletionLogProbForToken(logprobs, i, token) 41 | } 42 | 43 | return result 44 | } 45 | 46 | // decodeCompletionLogProbForToken converts a single token's logprob data into the SDK's LogProb format 47 | func decodeCompletionLogProbForToken(logprobs *client.CompletionLogProbs, index int, token string) api.LogProb { 48 | return api.LogProb{ 49 | Token: token, 50 | LogProb: getCompletionLogProbValue(logprobs.TokenLogProbs, index), 51 | TopLogProbs: decodeCompletionTopLogProbs(logprobs.TopLogProbs, index), 52 | } 53 | } 54 | 55 | // getCompletionLogProbValue safely gets the logprob value for a token, defaulting to 0 if out of bounds 56 | func getCompletionLogProbValue(logprobs []float64, index int) float64 { 57 | if index < len(logprobs) { 58 | return logprobs[index] 59 | } 60 | return 0 61 | } 62 | 63 | // decodeCompletionTopLogProbs converts the top logprobs map for a token into the SDK's TokenLogProb format 64 | func decodeCompletionTopLogProbs(topLogProbs []map[string]float64, index int) []api.TokenLogProb { 65 | if topLogProbs == nil || index >= len(topLogProbs) || topLogProbs[index] == nil { 66 | return []api.TokenLogProb{} // Always return empty slice instead of nil 67 | } 68 | 69 | topMap := topLogProbs[index] 70 | pairs := make([]client.TopLogProb, 0, len(topMap)) 71 | 72 | for token, logprob := range topMap { 73 | pairs = append(pairs, client.TopLogProb{ 74 | Token: token, 75 | LogProb: logprob, 76 | }) 77 | } 78 | 79 | return decodeTokenLogProbs(pairs) 80 | } 81 | 82 | // decodeTokenLogProbs converts a slice of token/logprob pairs to the SDK's TokenLogProb format 83 | func decodeTokenLogProbs(pairs []client.TopLogProb) 
[]api.TokenLogProb { 84 | result := make([]api.TokenLogProb, len(pairs)) 85 | for i, pair := range pairs { 86 | result[i] = api.TokenLogProb{ 87 | Token: pair.Token, 88 | LogProb: pair.LogProb, 89 | } 90 | } 91 | return result 92 | } 93 | -------------------------------------------------------------------------------- /provider/internal/openrouter/codec/decode_logprobs_test.go: -------------------------------------------------------------------------------- 1 | package codec 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | "go.jetify.com/ai/api" 8 | "go.jetify.com/ai/provider/internal/openrouter/client" 9 | ) 10 | 11 | func TestDecodeLogProbs(t *testing.T) { 12 | tests := []struct { 13 | name string 14 | input *client.LogProbs 15 | expected []api.LogProb 16 | }{ 17 | { 18 | name: "nil input", 19 | input: nil, 20 | expected: []api.LogProb{}, 21 | }, 22 | { 23 | name: "nil content", 24 | input: &client.LogProbs{Content: nil}, 25 | expected: []api.LogProb{}, 26 | }, 27 | { 28 | name: "single token with no top logprobs", 29 | input: &client.LogProbs{ 30 | Content: []client.LogProb{ 31 | { 32 | Token: "hello", 33 | LogProb: -0.5, 34 | TopLogProbs: nil, 35 | }, 36 | }, 37 | }, 38 | expected: []api.LogProb{ 39 | { 40 | Token: "hello", 41 | LogProb: -0.5, 42 | TopLogProbs: []api.TokenLogProb{}, 43 | }, 44 | }, 45 | }, 46 | { 47 | name: "single token with top logprobs", 48 | input: &client.LogProbs{ 49 | Content: []client.LogProb{ 50 | { 51 | Token: "hello", 52 | LogProb: -0.5, 53 | TopLogProbs: []client.TopLogProb{ 54 | {Token: "hello", LogProb: -0.5}, 55 | {Token: "hi", LogProb: -1.0}, 56 | }, 57 | }, 58 | }, 59 | }, 60 | expected: []api.LogProb{ 61 | { 62 | Token: "hello", 63 | LogProb: -0.5, 64 | TopLogProbs: []api.TokenLogProb{ 65 | {Token: "hello", LogProb: -0.5}, 66 | {Token: "hi", LogProb: -1.0}, 67 | }, 68 | }, 69 | }, 70 | }, 71 | { 72 | name: "multiple tokens with mixed top logprobs", 73 | input: &client.LogProbs{ 74 | Content: 
[]client.LogProb{ 75 | { 76 | Token: "hello", 77 | LogProb: -0.5, 78 | TopLogProbs: []client.TopLogProb{ 79 | {Token: "hello", LogProb: -0.5}, 80 | }, 81 | }, 82 | { 83 | Token: "world", 84 | LogProb: -0.3, 85 | TopLogProbs: nil, 86 | }, 87 | }, 88 | }, 89 | expected: []api.LogProb{ 90 | { 91 | Token: "hello", 92 | LogProb: -0.5, 93 | TopLogProbs: []api.TokenLogProb{ 94 | {Token: "hello", LogProb: -0.5}, 95 | }, 96 | }, 97 | { 98 | Token: "world", 99 | LogProb: -0.3, 100 | TopLogProbs: []api.TokenLogProb{}, 101 | }, 102 | }, 103 | }, 104 | } 105 | 106 | for _, tt := range tests { 107 | t.Run(tt.name, func(t *testing.T) { 108 | result := DecodeLogProbs(tt.input) 109 | assert.Equal(t, tt.expected, result) 110 | }) 111 | } 112 | } 113 | 114 | func TestDecodeCompletionLogProbs(t *testing.T) { 115 | tests := []struct { 116 | name string 117 | input *client.CompletionLogProbs 118 | expected []api.LogProb 119 | }{ 120 | { 121 | name: "nil input", 122 | input: nil, 123 | expected: []api.LogProb{}, 124 | }, 125 | { 126 | name: "empty tokens", 127 | input: &client.CompletionLogProbs{ 128 | Tokens: []string{}, 129 | TokenLogProbs: []float64{}, 130 | TopLogProbs: nil, 131 | }, 132 | expected: []api.LogProb{}, 133 | }, 134 | { 135 | name: "single token without top logprobs", 136 | input: &client.CompletionLogProbs{ 137 | Tokens: []string{"hello"}, 138 | TokenLogProbs: []float64{-0.5}, 139 | TopLogProbs: nil, 140 | }, 141 | expected: []api.LogProb{ 142 | { 143 | Token: "hello", 144 | LogProb: -0.5, 145 | TopLogProbs: []api.TokenLogProb{}, 146 | }, 147 | }, 148 | }, 149 | { 150 | name: "token with missing logprob defaults to 0", 151 | input: &client.CompletionLogProbs{ 152 | Tokens: []string{"hello"}, 153 | TokenLogProbs: []float64{}, 154 | TopLogProbs: nil, 155 | }, 156 | expected: []api.LogProb{ 157 | { 158 | Token: "hello", 159 | LogProb: 0, 160 | TopLogProbs: []api.TokenLogProb{}, 161 | }, 162 | }, 163 | }, 164 | { 165 | name: "token with top logprobs", 166 | input: 
&client.CompletionLogProbs{ 167 | Tokens: []string{"hello"}, 168 | TokenLogProbs: []float64{-0.5}, 169 | TopLogProbs: []map[string]float64{ 170 | { 171 | "hello": -0.5, 172 | "hi": -1.0, 173 | }, 174 | }, 175 | }, 176 | expected: []api.LogProb{ 177 | { 178 | Token: "hello", 179 | LogProb: -0.5, 180 | TopLogProbs: []api.TokenLogProb{ 181 | {Token: "hello", LogProb: -0.5}, 182 | {Token: "hi", LogProb: -1.0}, 183 | }, 184 | }, 185 | }, 186 | }, 187 | { 188 | name: "multiple tokens with mixed top logprobs", 189 | input: &client.CompletionLogProbs{ 190 | Tokens: []string{"hello", "world"}, 191 | TokenLogProbs: []float64{-0.5, -0.3}, 192 | TopLogProbs: []map[string]float64{ 193 | { 194 | "hello": -0.5, 195 | "hi": -1.0, 196 | }, 197 | nil, 198 | }, 199 | }, 200 | expected: []api.LogProb{ 201 | { 202 | Token: "hello", 203 | LogProb: -0.5, 204 | TopLogProbs: []api.TokenLogProb{ 205 | {Token: "hello", LogProb: -0.5}, 206 | {Token: "hi", LogProb: -1.0}, 207 | }, 208 | }, 209 | { 210 | Token: "world", 211 | LogProb: -0.3, 212 | TopLogProbs: []api.TokenLogProb{}, 213 | }, 214 | }, 215 | }, 216 | } 217 | 218 | for _, tt := range tests { 219 | t.Run(tt.name, func(t *testing.T) { 220 | result := DecodeCompletionLogProbs(tt.input) 221 | assert.Equal(t, len(tt.expected), len(result)) 222 | for i := range result { 223 | assert.Equal(t, tt.expected[i].Token, result[i].Token) 224 | assert.Equal(t, tt.expected[i].LogProb, result[i].LogProb) 225 | assert.ElementsMatch(t, tt.expected[i].TopLogProbs, result[i].TopLogProbs) 226 | } 227 | }) 228 | } 229 | } 230 | -------------------------------------------------------------------------------- /provider/internal/openrouter/codec/encode_prompt.go: -------------------------------------------------------------------------------- 1 | package codec 2 | 3 | // Functions that convert AI SDK types to OpenRouter types. 
4 | 5 | import ( 6 | "encoding/base64" 7 | "encoding/json" 8 | "fmt" 9 | 10 | "go.jetify.com/ai/api" 11 | "go.jetify.com/ai/provider/internal/openrouter/client" 12 | ) 13 | 14 | // EncodePrompt converts an AI SDK prompt into OpenRouter's chat message format 15 | func EncodePrompt(prompt []api.Message) (client.Prompt, error) { 16 | // Pre-allocate with extra space for potential tool message expansion 17 | messages := make(client.Prompt, 0, len(prompt)*2) 18 | 19 | for _, msg := range prompt { 20 | encodedMsgs, err := encodeMessage(msg) 21 | if err != nil { 22 | return nil, err 23 | } 24 | messages = append(messages, encodedMsgs...) 25 | } 26 | 27 | return messages, nil 28 | } 29 | 30 | func encodeMessage(msg api.Message) ([]client.Message, error) { 31 | switch msg := msg.(type) { 32 | case *api.SystemMessage: 33 | encoded := encodeSystemMessage(msg) 34 | return []client.Message{encoded}, nil 35 | case api.SystemMessage: 36 | encoded := encodeSystemMessage(&msg) 37 | return []client.Message{encoded}, nil 38 | case *api.UserMessage: 39 | encoded, err := encodeUserMessage(msg) 40 | if err != nil { 41 | return nil, err 42 | } 43 | return []client.Message{encoded}, nil 44 | case api.UserMessage: 45 | encoded, err := encodeUserMessage(&msg) 46 | if err != nil { 47 | return nil, err 48 | } 49 | return []client.Message{encoded}, nil 50 | case *api.AssistantMessage: 51 | encoded, err := encodeAssistantMessage(msg) 52 | if err != nil { 53 | return nil, err 54 | } 55 | return []client.Message{encoded}, nil 56 | case api.AssistantMessage: 57 | encoded, err := encodeAssistantMessage(&msg) 58 | if err != nil { 59 | return nil, err 60 | } 61 | return []client.Message{encoded}, nil 62 | case *api.ToolMessage: 63 | return encodeToolMessage(msg) 64 | case api.ToolMessage: 65 | return encodeToolMessage(&msg) 66 | default: 67 | // TODO: use a more specific error type from api. 
68 | return nil, fmt.Errorf("unsupported message type: %T", msg) 69 | } 70 | } 71 | 72 | func encodeSystemMessage(msg *api.SystemMessage) *client.SystemMessage { 73 | return &client.SystemMessage{ 74 | Content: msg.Content, 75 | } 76 | } 77 | 78 | func encodeUserMessage(msg *api.UserMessage) (*client.UserMessage, error) { 79 | // Special case: If there's exactly one text block, use a simpler format 80 | // This optimization avoids creating an unnecessary array of blocks 81 | if len(msg.Content) == 1 { 82 | if textBlock, ok := msg.Content[0].(*api.TextBlock); ok { 83 | return &client.UserMessage{ 84 | Content: client.UserMessageContent{ 85 | Text: textBlock.Text, 86 | }, 87 | }, nil 88 | } 89 | } 90 | 91 | // Otherwise encode all blocks 92 | parts, err := encodeUserContent(msg.Content) 93 | if err != nil { 94 | return nil, err 95 | } 96 | 97 | return &client.UserMessage{ 98 | Content: client.UserMessageContent{ 99 | Parts: parts, 100 | }, 101 | }, nil 102 | } 103 | 104 | func encodeUserContent(content []api.ContentBlock) ([]client.ContentPart, error) { 105 | parts := make([]client.ContentPart, 0, len(content)) 106 | for _, block := range content { 107 | encodedPart, err := encodeUserContentBlock(block) 108 | if err != nil { 109 | return nil, err 110 | } 111 | parts = append(parts, encodedPart) 112 | } 113 | return parts, nil 114 | } 115 | 116 | func encodeUserContentBlock(block api.ContentBlock) (client.ContentPart, error) { 117 | switch block := block.(type) { 118 | case *api.TextBlock: 119 | return encodeTextBlock(block), nil 120 | case api.TextBlock: 121 | return encodeTextBlock(&block), nil 122 | case *api.ImageBlock: 123 | return encodeImageBlock(block), nil 124 | case api.ImageBlock: 125 | return encodeImageBlock(&block), nil 126 | case *api.FileBlock: 127 | return encodeFileBlock(block), nil 128 | case api.FileBlock: 129 | return encodeFileBlock(&block), nil 130 | default: 131 | return nil, fmt.Errorf("unsupported content block type: %T", block) 132 | } 133 | 
} 134 | 135 | func encodeTextBlock(block *api.TextBlock) *client.TextPart { 136 | return &client.TextPart{ 137 | Text: block.Text, 138 | } 139 | } 140 | 141 | func encodeImageBlock(block *api.ImageBlock) *client.ImagePart { 142 | url := block.URL 143 | // If no URL is provided but we have raw image data, 144 | // convert it to a data URL (e.g., "...") 145 | if url == "" && block.Data != nil { 146 | mimeType := block.MediaType 147 | if mimeType == "" { 148 | mimeType = "image/jpeg" // Default to JPEG if no mime type specified 149 | } 150 | url = fmt.Sprintf("data:%s;base64,%s", 151 | mimeType, 152 | base64.StdEncoding.EncodeToString(block.Data), 153 | ) 154 | } 155 | 156 | imagePart := &client.ImagePart{} 157 | imagePart.ImageURL.URL = url 158 | return imagePart 159 | } 160 | 161 | func encodeFileBlock(block *api.FileBlock) *client.TextPart { 162 | text := block.URL 163 | if text == "" && block.Data != nil { 164 | if block.MediaType == "" { 165 | // If no mime type, treat the data as plain text 166 | text = string(block.Data) 167 | } else { 168 | // This extra functionality of encoding when a mime type is set is 169 | // not part of the TypeScript OpenRouter implementation. We've added it. 170 | // Double check that this is beneficial. 
171 | text = fmt.Sprintf("data:%s;base64,%s", 172 | block.MediaType, 173 | base64.StdEncoding.EncodeToString(block.Data), 174 | ) 175 | } 176 | } 177 | return &client.TextPart{ 178 | Text: text, 179 | } 180 | } 181 | 182 | func encodeAssistantMessage(msg *api.AssistantMessage) (*client.AssistantMessage, error) { 183 | text := "" 184 | toolCalls := []client.ToolCall{} 185 | 186 | // Combine all text parts into a single string and collect tool calls 187 | for _, part := range msg.Content { 188 | switch block := part.(type) { 189 | case *api.TextBlock: 190 | encoded := encodeTextBlock(block) 191 | text += encoded.Text // Concatenate all text parts 192 | case api.TextBlock: 193 | encoded := encodeTextBlock(&block) 194 | text += encoded.Text // Concatenate all text parts 195 | case *api.ToolCallBlock: 196 | toolCall, err := encodeToolCallBlock(block) 197 | if err != nil { 198 | return nil, err 199 | } 200 | toolCalls = append(toolCalls, toolCall) 201 | case api.ToolCallBlock: 202 | toolCall, err := encodeToolCallBlock(&block) 203 | if err != nil { 204 | return nil, err 205 | } 206 | toolCalls = append(toolCalls, toolCall) 207 | default: 208 | return nil, fmt.Errorf("unsupported assistant content block type: %T", block) 209 | } 210 | } 211 | 212 | return &client.AssistantMessage{ 213 | Content: text, 214 | ToolCalls: toolCalls, 215 | }, nil 216 | } 217 | 218 | func encodeToolCallBlock(block *api.ToolCallBlock) (client.ToolCall, error) { 219 | args, err := json.Marshal(block.Args) 220 | if err != nil { 221 | return client.ToolCall{}, fmt.Errorf("failed to marshal tool call args: %w", err) 222 | } 223 | 224 | toolCall := client.ToolCall{ 225 | Type: "function", 226 | ID: block.ToolCallID, 227 | } 228 | toolCall.Function.Name = block.ToolName 229 | toolCall.Function.Arguments = string(args) 230 | return toolCall, nil 231 | } 232 | 233 | func encodeToolMessage(msg *api.ToolMessage) ([]client.Message, error) { 234 | // Convert each tool result into a separate ToolMessage 235 
| // Note: OpenRouter expects tool messages to be flattened into separate messages 236 | messages := make([]client.Message, 0, len(msg.Content)) 237 | for _, result := range msg.Content { 238 | resultJSON, err := json.Marshal(result.Result) 239 | if err != nil { 240 | return nil, fmt.Errorf("failed to marshal tool result: %w", err) 241 | } 242 | messages = append(messages, &client.ToolMessage{ 243 | Content: string(resultJSON), 244 | ToolCallID: result.ToolCallID, 245 | }) 246 | } 247 | 248 | // If no results, return empty slice 249 | if len(messages) == 0 { 250 | return []client.Message{}, nil 251 | } 252 | 253 | return messages, nil 254 | } 255 | -------------------------------------------------------------------------------- /provider/internal/openrouter/codec/encode_prompt_test.go: -------------------------------------------------------------------------------- 1 | package codec 2 | 3 | import ( 4 | "encoding/json" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | "go.jetify.com/ai/aitesting" 9 | "go.jetify.com/ai/api" 10 | ) 11 | 12 | func TestEncodePrompt(t *testing.T) { 13 | tests := []struct { 14 | name string 15 | prompt []api.Message 16 | expected string // JSON string of expected output 17 | wantError bool 18 | }{ 19 | { 20 | name: "system message", 21 | prompt: []api.Message{ 22 | &api.SystemMessage{Content: "test system message"}, 23 | }, 24 | expected: `[{"role":"system","content":"test system message"}]`, 25 | }, 26 | { 27 | name: "user message with single text block", 28 | prompt: []api.Message{ 29 | &api.UserMessage{ 30 | Content: api.ContentFromText("hello"), 31 | }, 32 | }, 33 | expected: `[{"role":"user","content":"hello"}]`, 34 | }, 35 | { 36 | name: "user message with multiple blocks", 37 | prompt: []api.Message{ 38 | &api.UserMessage{ 39 | Content: []api.ContentBlock{ 40 | &api.TextBlock{Text: "hello"}, 41 | &api.ImageBlock{Data: []byte{0, 1, 2, 3}, MediaType: "image/png"}, 42 | &api.FileBlock{URL: "http://example.com/file.txt"}, 
43 | }, 44 | }, 45 | }, 46 | expected: `[{"role":"user","content":[ 47 | {"type":"text","text":"hello"}, 48 | {"type":"image_url","image_url":{"url":""}}, 49 | {"type":"text","text":"http://example.com/file.txt"} 50 | ]}]`, 51 | }, 52 | { 53 | name: "assistant message with text", 54 | prompt: []api.Message{ 55 | &api.AssistantMessage{ 56 | Content: []api.ContentBlock{ 57 | &api.TextBlock{Text: "hello"}, 58 | }, 59 | }, 60 | }, 61 | expected: `[{"role":"assistant","content":"hello"}]`, 62 | }, 63 | { 64 | name: "assistant message with tool calls", 65 | prompt: []api.Message{ 66 | &api.AssistantMessage{ 67 | Content: []api.ContentBlock{ 68 | &api.TextBlock{Text: "Using calculator"}, 69 | &api.ToolCallBlock{ 70 | ToolCallID: "call_123", 71 | ToolName: "calculator", 72 | Args: json.RawMessage(`{"x": 1, "y": 2}`), 73 | }, 74 | }, 75 | }, 76 | }, 77 | expected: `[{"role":"assistant","content":"Using calculator","tool_calls":[ 78 | {"type":"function","id":"call_123","function":{"name":"calculator","arguments":"{\"x\":1,\"y\":2}"}} 79 | ]}]`, 80 | }, 81 | { 82 | name: "tool message with multiple results", 83 | prompt: []api.Message{ 84 | &api.ToolMessage{ 85 | Content: []api.ToolResultBlock{ 86 | { 87 | ToolCallID: "call_123", 88 | ToolName: "calculator", 89 | Result: json.RawMessage(`{"result": 3}`), 90 | }, 91 | { 92 | ToolCallID: "call_456", 93 | ToolName: "calculator", 94 | Result: json.RawMessage(`{"result": 4}`), 95 | }, 96 | }, 97 | }, 98 | }, 99 | expected: `[ 100 | {"role":"tool","content":"{\"result\":3}","tool_call_id":"call_123"}, 101 | {"role":"tool","content":"{\"result\":4}","tool_call_id":"call_456"} 102 | ]`, 103 | }, 104 | { 105 | name: "user message with binary image data", 106 | prompt: []api.Message{ 107 | &api.UserMessage{ 108 | Content: []api.ContentBlock{ 109 | &api.TextBlock{Text: "hello"}, 110 | api.ImageBlockFromData([]byte{0, 1, 2, 3}, "image/png"), 111 | }, 112 | }, 113 | }, 114 | expected: `[{"role":"user","content":[ 115 | 
{"type":"text","text":"hello"}, 116 | {"type":"image_url","image_url":{"url":""}} 117 | ]}]`, 118 | }, 119 | { 120 | name: "user message with image URL", 121 | prompt: []api.Message{ 122 | &api.UserMessage{ 123 | Content: []api.ContentBlock{ 124 | &api.TextBlock{Text: "hello"}, 125 | api.ImageBlockFromURL("https://example.com/image.jpg"), 126 | }, 127 | }, 128 | }, 129 | expected: `[{"role":"user","content":[ 130 | {"type":"text","text":"hello"}, 131 | {"type":"image_url","image_url":{"url":"https://example.com/image.jpg"}} 132 | ]}]`, 133 | }, 134 | { 135 | name: "user message with file URL", 136 | prompt: []api.Message{ 137 | &api.UserMessage{ 138 | Content: []api.ContentBlock{ 139 | &api.TextBlock{Text: "hello"}, 140 | api.FileBlockFromURL("http://example.com/file.txt"), 141 | }, 142 | }, 143 | }, 144 | expected: `[{"role":"user","content":[ 145 | {"type":"text","text":"hello"}, 146 | {"type":"text","text":"http://example.com/file.txt"} 147 | ]}]`, 148 | }, 149 | { 150 | name: "user message with audio file data", 151 | prompt: []api.Message{ 152 | &api.UserMessage{ 153 | Content: []api.ContentBlock{ 154 | api.FileBlockFromData([]byte{0, 1, 2, 3}, "audio/wav"), 155 | }, 156 | }, 157 | }, 158 | expected: `[{"role":"user","content":[ 159 | {"type":"text","text":"data:audio/wav;base64,AAECAw=="} 160 | ]}]`, 161 | }, 162 | { 163 | name: "user message with image data and missing mime type", 164 | prompt: []api.Message{ 165 | &api.UserMessage{ 166 | Content: []api.ContentBlock{ 167 | &api.ImageBlock{ 168 | Data: []byte{0, 1, 2, 3}, 169 | // MimeType intentionally omitted 170 | }, 171 | }, 172 | }, 173 | }, 174 | expected: `[{"role":"user","content":[ 175 | {"type":"image_url","image_url":{"url":""}} 176 | ]}]`, 177 | }, 178 | } 179 | 180 | for _, tt := range tests { 181 | t.Run(tt.name, func(t *testing.T) { 182 | messages, err := EncodePrompt(tt.prompt) 183 | if tt.wantError { 184 | assert.Error(t, err) 185 | return 186 | } 187 | 188 | assert.NoError(t, err) 189 | 
data, err := json.Marshal(messages) 190 | assert.NoError(t, err) 191 | assert.JSONEq(t, tt.expected, string(data)) 192 | }) 193 | } 194 | } 195 | 196 | func TestEncodePrompt_Failures(t *testing.T) { 197 | tests := []struct { 198 | name string 199 | prompt []api.Message 200 | expectedError string 201 | }{ 202 | { 203 | name: "user message with unsupported block", 204 | prompt: []api.Message{ 205 | &api.UserMessage{ 206 | Content: []api.ContentBlock{ 207 | &aitesting.MockUnsupportedBlock{}, 208 | }, 209 | }, 210 | }, 211 | expectedError: "unsupported content block type", 212 | }, 213 | { 214 | name: "user message with tool call block", 215 | prompt: []api.Message{ 216 | &api.UserMessage{ 217 | Content: []api.ContentBlock{ 218 | &api.ToolCallBlock{ 219 | ToolCallID: "call_123", 220 | ToolName: "calculator", 221 | Args: json.RawMessage(`{"x": 1}`), 222 | }, 223 | }, 224 | }, 225 | }, 226 | expectedError: "unsupported content block type", 227 | }, 228 | { 229 | name: "assistant message with unsupported block", 230 | prompt: []api.Message{ 231 | &api.AssistantMessage{ 232 | Content: []api.ContentBlock{ 233 | &aitesting.MockUnsupportedBlock{}, 234 | }, 235 | }, 236 | }, 237 | expectedError: "unsupported assistant content block type", 238 | }, 239 | { 240 | name: "assistant message with file block", 241 | prompt: []api.Message{ 242 | &api.AssistantMessage{ 243 | Content: []api.ContentBlock{ 244 | api.FileBlockFromURL("http://example.com/file.txt"), 245 | }, 246 | }, 247 | }, 248 | expectedError: "unsupported assistant content block type", 249 | }, 250 | { 251 | name: "unsupported message type", 252 | prompt: []api.Message{ 253 | &aitesting.MockUnsupportedMessage{}, 254 | }, 255 | expectedError: "unsupported message type", 256 | }, 257 | } 258 | 259 | for _, tt := range tests { 260 | t.Run(tt.name, func(t *testing.T) { 261 | _, err := EncodePrompt(tt.prompt) 262 | assert.Error(t, err) 263 | assert.Contains(t, err.Error(), tt.expectedError) 264 | }) 265 | } 266 | } 267 | 
-------------------------------------------------------------------------------- /provider/internal/openrouter/convert_to_openrouter_completion_prompt.go: -------------------------------------------------------------------------------- 1 | package openrouter 2 | 3 | import ( 4 | "strings" 5 | 6 | "go.jetify.com/ai/api" 7 | ) 8 | 9 | type InputFormat string 10 | 11 | const ( 12 | InputFormatPrompt InputFormat = "prompt" 13 | InputFormatMessages InputFormat = "messages" 14 | ) 15 | 16 | type CompletionPromptOptions struct { 17 | Prompt []api.Message 18 | InputFormat InputFormat 19 | User string // defaults to "user" if empty 20 | Assistant string // defaults to "assistant" if empty 21 | } 22 | 23 | // ConvertToOpenRouterCompletionPrompt converts an AI SDK prompt into OpenRouter's completion format. 24 | // It returns the formatted prompt string and optional stop sequences. 25 | func ConvertToOpenRouterCompletionPrompt(opts CompletionPromptOptions) (string, []string, error) { 26 | if opts.User == "" { 27 | opts.User = "user" 28 | } 29 | if opts.Assistant == "" { 30 | opts.Assistant = "assistant" 31 | } 32 | 33 | // Handle direct prompt case 34 | if opts.InputFormat == InputFormatPrompt { 35 | if text, ok := isDirectUserPrompt(opts.Prompt); ok { 36 | return text, nil, nil 37 | } 38 | } 39 | 40 | var b strings.Builder 41 | 42 | // Handle system message prefix if present 43 | if len(opts.Prompt) > 0 { 44 | if sys, ok := opts.Prompt[0].(*api.SystemMessage); ok { 45 | b.WriteString(sys.Content) 46 | b.WriteString("\n\n") 47 | opts.Prompt = opts.Prompt[1:] 48 | } 49 | } 50 | 51 | // Process remaining messages 52 | for _, msg := range opts.Prompt { 53 | if err := writeMessage(&b, msg, opts.User, opts.Assistant); err != nil { 54 | return "", nil, err 55 | } 56 | } 57 | 58 | // Add final assistant prefix 59 | b.WriteString(opts.Assistant) 60 | b.WriteString(":\n") 61 | 62 | return b.String(), []string{"\n" + opts.User + ":"}, nil 63 | } 64 | 65 | // isDirectUserPrompt checks 
if the prompt is a single user message with a single text block 66 | func isDirectUserPrompt(prompt []api.Message) (string, bool) { 67 | if len(prompt) != 1 { 68 | return "", false 69 | } 70 | 71 | um, ok := prompt[0].(*api.UserMessage) 72 | if !ok { 73 | return "", false 74 | } 75 | 76 | if len(um.Content) != 1 { 77 | return "", false 78 | } 79 | 80 | textBlock, ok := um.Content[0].(*api.TextBlock) 81 | if !ok { 82 | return "", false 83 | } 84 | 85 | return textBlock.Text, true 86 | } 87 | 88 | // writeMessage formats and writes a single message to the string builder 89 | func writeMessage(b *strings.Builder, msg api.Message, user, assistant string) error { 90 | switch m := msg.(type) { 91 | case *api.SystemMessage: 92 | return api.NewInvalidPromptError(msg, "unexpected system message in prompt", nil) 93 | 94 | case *api.UserMessage: 95 | text, err := gatherUserText(m) 96 | if err != nil { 97 | return err 98 | } 99 | b.WriteString(user) 100 | b.WriteString(":\n") 101 | b.WriteString(text) 102 | b.WriteString("\n\n") 103 | 104 | case *api.AssistantMessage: 105 | text, err := gatherAssistantText(m) 106 | if err != nil { 107 | return err 108 | } 109 | b.WriteString(assistant) 110 | b.WriteString(":\n") 111 | b.WriteString(text) 112 | b.WriteString("\n\n") 113 | 114 | case *api.ToolMessage: 115 | return api.NewUnsupportedFunctionalityError("tool messages", "") 116 | 117 | default: 118 | return api.NewInvalidPromptError(msg, "unknown message type", nil) 119 | } 120 | 121 | return nil 122 | } 123 | 124 | // gatherUserText collects text from user message blocks, rejecting unsupported content 125 | func gatherUserText(msg *api.UserMessage) (string, error) { 126 | var b strings.Builder 127 | 128 | for _, block := range msg.Content { 129 | switch block := block.(type) { 130 | case *api.TextBlock: 131 | b.WriteString(block.Text) 132 | case *api.ImageBlock: 133 | return "", api.NewUnsupportedFunctionalityError("images", "") 134 | case *api.FileBlock: 135 | return "", 
api.NewUnsupportedFunctionalityError("file attachments", "") 136 | default: 137 | return "", api.NewUnsupportedFunctionalityError("unknown content type", "") 138 | } 139 | } 140 | 141 | return b.String(), nil 142 | } 143 | 144 | // gatherAssistantText collects text from assistant message blocks, rejecting tool calls 145 | func gatherAssistantText(msg *api.AssistantMessage) (string, error) { 146 | var b strings.Builder 147 | 148 | for _, block := range msg.Content { 149 | switch block := block.(type) { 150 | case *api.TextBlock: 151 | b.WriteString(block.Text) 152 | case *api.ToolCallBlock: 153 | return "", api.NewUnsupportedFunctionalityError("tool-call messages", "") 154 | default: 155 | return "", api.NewUnsupportedFunctionalityError("unknown content type", "") 156 | } 157 | } 158 | 159 | return b.String(), nil 160 | } 161 | -------------------------------------------------------------------------------- /provider/internal/openrouter/convert_to_openrouter_completion_prompt_test.go: -------------------------------------------------------------------------------- 1 | package openrouter 2 | 3 | import ( 4 | "encoding/json" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | "go.jetify.com/ai/api" 9 | ) 10 | 11 | func TestConvertToOpenRouterCompletionPrompt(t *testing.T) { 12 | t.Run("direct prompt", func(t *testing.T) { 13 | prompt := []api.Message{ 14 | &api.UserMessage{ 15 | Content: []api.ContentBlock{ 16 | &api.TextBlock{Text: "hello world"}, 17 | }, 18 | }, 19 | } 20 | 21 | text, stop, err := ConvertToOpenRouterCompletionPrompt(CompletionPromptOptions{ 22 | Prompt: prompt, 23 | InputFormat: InputFormatPrompt, 24 | }) 25 | 26 | assert.NoError(t, err) 27 | assert.Equal(t, "hello world", text) 28 | assert.Nil(t, stop) 29 | }) 30 | 31 | t.Run("system message prefix", func(t *testing.T) { 32 | prompt := []api.Message{ 33 | &api.SystemMessage{Content: "system instruction"}, 34 | &api.UserMessage{ 35 | Content: []api.ContentBlock{ 36 | &api.TextBlock{Text: 
"user message"}, 37 | }, 38 | }, 39 | } 40 | 41 | text, stop, err := ConvertToOpenRouterCompletionPrompt(CompletionPromptOptions{ 42 | Prompt: prompt, 43 | InputFormat: InputFormatMessages, 44 | }) 45 | 46 | assert.NoError(t, err) 47 | assert.Equal(t, "system instruction\n\nuser:\nuser message\n\nassistant:\n", text) 48 | assert.Equal(t, []string{"\nuser:"}, stop) 49 | }) 50 | 51 | t.Run("user and assistant messages", func(t *testing.T) { 52 | prompt := []api.Message{ 53 | &api.UserMessage{ 54 | Content: []api.ContentBlock{ 55 | &api.TextBlock{Text: "hello"}, 56 | }, 57 | }, 58 | &api.AssistantMessage{ 59 | Content: []api.ContentBlock{ 60 | &api.TextBlock{Text: "hi there"}, 61 | }, 62 | }, 63 | &api.UserMessage{ 64 | Content: []api.ContentBlock{ 65 | &api.TextBlock{Text: "how are you?"}, 66 | }, 67 | }, 68 | } 69 | 70 | text, stop, err := ConvertToOpenRouterCompletionPrompt(CompletionPromptOptions{ 71 | Prompt: prompt, 72 | InputFormat: InputFormatMessages, 73 | }) 74 | 75 | assert.NoError(t, err) 76 | assert.Equal(t, "user:\nhello\n\nassistant:\nhi there\n\nuser:\nhow are you?\n\nassistant:\n", text) 77 | assert.Equal(t, []string{"\nuser:"}, stop) 78 | }) 79 | 80 | t.Run("custom user and assistant labels", func(t *testing.T) { 81 | prompt := []api.Message{ 82 | &api.UserMessage{ 83 | Content: []api.ContentBlock{ 84 | &api.TextBlock{Text: "hello"}, 85 | }, 86 | }, 87 | &api.AssistantMessage{ 88 | Content: []api.ContentBlock{ 89 | &api.TextBlock{Text: "hi"}, 90 | }, 91 | }, 92 | } 93 | 94 | text, stop, err := ConvertToOpenRouterCompletionPrompt(CompletionPromptOptions{ 95 | Prompt: prompt, 96 | InputFormat: InputFormatMessages, 97 | User: "Human", 98 | Assistant: "AI", 99 | }) 100 | 101 | assert.NoError(t, err) 102 | assert.Equal(t, "Human:\nhello\n\nAI:\nhi\n\nAI:\n", text) 103 | assert.Equal(t, []string{"\nHuman:"}, stop) 104 | }) 105 | 106 | t.Run("unsupported image content", func(t *testing.T) { 107 | prompt := []api.Message{ 108 | &api.UserMessage{ 109 | 
Content: []api.ContentBlock{ 110 | api.ImageBlockFromURL("http://example.com/image.jpg"), 111 | }, 112 | }, 113 | } 114 | 115 | _, _, err := ConvertToOpenRouterCompletionPrompt(CompletionPromptOptions{ 116 | Prompt: prompt, 117 | InputFormat: InputFormatMessages, 118 | }) 119 | 120 | assert.Error(t, err) 121 | assert.IsType(t, &api.UnsupportedFunctionalityError{}, err) 122 | assert.Contains(t, err.Error(), "images") 123 | }) 124 | 125 | t.Run("unsupported file content", func(t *testing.T) { 126 | prompt := []api.Message{ 127 | &api.UserMessage{ 128 | Content: []api.ContentBlock{ 129 | api.FileBlockFromURL("http://example.com/doc.pdf"), 130 | }, 131 | }, 132 | } 133 | 134 | _, _, err := ConvertToOpenRouterCompletionPrompt(CompletionPromptOptions{ 135 | Prompt: prompt, 136 | InputFormat: InputFormatMessages, 137 | }) 138 | 139 | assert.Error(t, err) 140 | assert.IsType(t, &api.UnsupportedFunctionalityError{}, err) 141 | assert.Contains(t, err.Error(), "file attachments") 142 | }) 143 | 144 | t.Run("unsupported tool call", func(t *testing.T) { 145 | prompt := []api.Message{ 146 | &api.AssistantMessage{ 147 | Content: []api.ContentBlock{ 148 | &api.ToolCallBlock{ 149 | ToolCallID: "123", 150 | ToolName: "calculator", 151 | Args: json.RawMessage(`{"x": 1, "y": 2}`), 152 | }, 153 | }, 154 | }, 155 | } 156 | 157 | _, _, err := ConvertToOpenRouterCompletionPrompt(CompletionPromptOptions{ 158 | Prompt: prompt, 159 | InputFormat: InputFormatMessages, 160 | }) 161 | 162 | assert.Error(t, err) 163 | assert.IsType(t, &api.UnsupportedFunctionalityError{}, err) 164 | assert.Contains(t, err.Error(), "tool-call messages") 165 | }) 166 | 167 | t.Run("unexpected system message", func(t *testing.T) { 168 | prompt := []api.Message{ 169 | &api.UserMessage{ 170 | Content: []api.ContentBlock{ 171 | &api.TextBlock{Text: "hello"}, 172 | }, 173 | }, 174 | &api.SystemMessage{Content: "unexpected system"}, 175 | } 176 | 177 | _, _, err := 
ConvertToOpenRouterCompletionPrompt(CompletionPromptOptions{ 178 | Prompt: prompt, 179 | InputFormat: InputFormatMessages, 180 | }) 181 | 182 | assert.Error(t, err) 183 | assert.IsType(t, &api.InvalidPromptError{}, err) 184 | assert.Contains(t, err.Error(), "unexpected system message") 185 | }) 186 | } 187 | -------------------------------------------------------------------------------- /provider/internal/openrouter/model/doc.go: -------------------------------------------------------------------------------- 1 | // Package model provides constants for specifying which model to use when using 2 | // OpenRouter. 3 | // OpenRouter is a unified API that provides access to various AI models from different 4 | // providers including OpenAI, Anthropic, Google, Meta, and others. This package defines 5 | // the model identifiers needed to specify which model to use when 6 | // making requests through OpenRouter. 7 | // 8 | // Model identifiers follow the format "provider/model-name[:tag]", for example: 9 | // - openai/gpt-4 10 | // - anthropic/claude-3-opus 11 | // - meta-llama/llama-3-70b-instruct 12 | // 13 | // Some models have additional tags like ":free" or ":beta" that indicate special 14 | // versions or pricing tiers. 
15 | // 16 | // Example usage: 17 | // 18 | // import "github.com/your-org/aisdk/provider/openrouter/model" 19 | // 20 | // // Use a predefined model constant 21 | // modelID := model.O3MiniHigh 22 | // 23 | // For the most up-to-date list of available models and their capabilities, 24 | // see https://openrouter.ai/docs#models 25 | package model 26 | -------------------------------------------------------------------------------- /provider/internal/openrouter/openrouter_error.go: -------------------------------------------------------------------------------- 1 | package openrouter 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "net/http" 7 | 8 | "go.jetify.com/ai/api" 9 | ) 10 | 11 | // openRouterErrorData matches the JSON structure of OpenRouter error responses. 12 | type openRouterErrorData struct { 13 | Error struct { 14 | Message string `json:"message"` 15 | Type string `json:"type"` 16 | Param any `json:"param"` 17 | Code *string `json:"code"` 18 | } `json:"error"` 19 | } 20 | 21 | // parseOpenRouterErrorJSON attempts to unmarshal the body into openRouterErrorData. 22 | func parseOpenRouterErrorJSON(body []byte) (*openRouterErrorData, error) { 23 | var parsed openRouterErrorData 24 | if err := json.Unmarshal(body, &parsed); err != nil { 25 | return nil, api.NewJSONParseError(string(body), err) 26 | } 27 | return &parsed, nil 28 | } 29 | 30 | // OpenRouterFailedResponseHandler constructs an APICallError from a non-2xx OpenRouter response. 
31 | //
32 | // If rawBody decodes as an OpenRouter error payload, the decoded payload is
33 | // attached as Data and its message (when non-empty) becomes the error message.
34 | // Otherwise the message falls back to "<code> <reason>", e.g. "400 Bad Request",
35 | // and any decode error is kept as the cause. requestBody is currently unused.
36 | func OpenRouterFailedResponseHandler(resp *http.Response, rawBody []byte, requestBody any) error {
37 | 	apiErr := &api.APICallError{
38 | 		Request:    resp.Request,
39 | 		StatusCode: resp.StatusCode,
40 | 		Response:   resp,
41 | 	}
42 | 	// resp.Request is nil for responses not produced by an http.Client
43 | 	// (e.g. hand-built ones); guard instead of panicking on .URL.
44 | 	if resp.Request != nil {
45 | 		apiErr.URL = resp.Request.URL
46 | 	}
47 | 
48 | 	parsed, parseErr := parseOpenRouterErrorJSON(rawBody)
49 | 	if parseErr == nil {
50 | 		apiErr.Data = parsed
51 | 		if parsed.Error.Message != "" {
52 | 			apiErr.AISDKError = api.NewAISDKError("AI_APICallError", parsed.Error.Message, nil)
53 | 			return apiErr
54 | 		}
55 | 	}
56 | 
57 | 	// Fallback when the body is not an OpenRouter error payload (or carries no
58 | 	// message). resp.Status already contains the code (e.g. "400 Bad Request"),
59 | 	// so the reason phrase comes from http.StatusText to avoid "400 400 Bad Request".
60 | 	msg := fmt.Sprint(resp.StatusCode)
61 | 	if reason := http.StatusText(resp.StatusCode); reason != "" {
62 | 		msg += " " + reason
63 | 	}
64 | 	apiErr.AISDKError = api.NewAISDKError("AI_APICallError", msg, parseErr)
65 | 	return apiErr
66 | }
67 | 
--------------------------------------------------------------------------------
/provider/internal/openrouter/openrouter_error_test.go:
--------------------------------------------------------------------------------
1 | package openrouter
2 | 
3 | import (
4 | 	"bytes"
5 | 	"errors"
6 | 	"io"
7 | 	"net/http"
8 | 	"net/url"
9 | 	"strconv"
10 | 	"testing"
11 | 
12 | 	"github.com/stretchr/testify/assert"
13 | 	"go.jetify.com/ai/api"
14 | )
15 | 
16 | func TestParseOpenRouterErrorJSON(t *testing.T) {
17 | 	tests := []struct {
18 | 		name    string
19 | 		json    string
20 | 		want    *openRouterErrorData
21 | 		wantErr bool
22 | 	}{
23 | 		{
24 | 			name: "valid error json",
25 | 			json: `{
26 | 				"error": {
27 | 					"message": "invalid request",
28 | 					"type": "invalid_request_error",
29 | 					"param": "model",
30 | 					"code": "model_not_found"
31 | 				}
32 | 			}`,
33 | 			want: &openRouterErrorData{
34 | 				Error: struct {
35 | 					Message string  `json:"message"`
36 | 					Type    string  `json:"type"`
37 | 					Param   any     `json:"param"`
38 | 					Code    *string `json:"code"`
39 | 				}{
40 | 					Message: "invalid request",
41 | 					Type:    "invalid_request_error",
42 | 					Param:   "model",
43 | 					Code:    strPtr("model_not_found"),
44 | 				},
45 | 			},
46 | 		},
47 | 		{
48 | 			name: "null fields",
49 | 			json: `{
50 | 				"error": {
51 | 					"message": "rate limited",
52 | 					"type": "rate_limit_error",
53 | 					"param": null,
54 | 					"code": null
55 | 				}
56 | 			}`,
57 | 			want: &openRouterErrorData{
58 | 				Error: struct {
59 | 					Message string  `json:"message"`
60 | 					Type    string  `json:"type"`
61 | 					Param   any     `json:"param"`
62 | 					Code    *string `json:"code"`
63 | 				}{
64 | 					Message: "rate limited",
65 | 					Type:    "rate_limit_error",
66 | 					Param:   nil,
67 | 					Code:    nil,
68 | 				},
69 | 			},
70 | 		},
71 | 		{
72 | 			name:    "invalid json",
73 | 			json:    `{"error": {`,
74 | 			wantErr: true,
75 | 		},
76 | 		{
77 | 			name: "empty json",
78 | 			json: `{}`,
79 | 			want: &openRouterErrorData{},
80 | 		},
81 | 	}
82 | 
83 | 	for _, tt := range tests {
84 | 		t.Run(tt.name, func(t *testing.T) {
85 | 			got, err := parseOpenRouterErrorJSON([]byte(tt.json))
86 | 			if tt.wantErr {
87 | 				assert.Error(t, err)
88 | 				return
89 | 			}
90 | 			assert.NoError(t, err)
91 | 			assert.Equal(t, tt.want, got)
92 | 		})
93 | 	}
94 | }
95 | 
96 | func TestOpenRouterFailedResponseHandler(t *testing.T) {
97 | 	tests := []struct {
98 | 		name        string
99 | 		statusCode  int
100 | 		body        string
101 | 		wantMessage string
102 | 	}{
103 | 		{
104 | 			name:       "valid error response",
105 | 			statusCode: http.StatusBadRequest,
106 | 			body: `{
107 | 				"error": {
108 | 					"message": "invalid request",
109 | 					"type": "invalid_request_error",
110 | 					"param": "model",
111 | 					"code": "model_not_found"
112 | 				}
113 | 			}`,
114 | 			wantMessage: "invalid request",
115 | 		},
116 | 		{
117 | 			name:       "rate limit error",
118 | 			statusCode: http.StatusTooManyRequests,
119 | 			body: `{
120 | 				"error": {
121 | 					"message": "rate limited",
122 | 					"type": "rate_limit_error",
123 | 					"param": null,
124 | 					"code": null
125 | 				}
126 | 			}`,
127 | 			wantMessage: "rate limited",
128 | 		},
129 | 		{
130 | 			name:        "invalid json falls back to status",
131 | 			statusCode:  http.StatusBadRequest,
132 | 			body:        `{"error": {`,
133 | 			wantMessage: "400 Bad Request",
134 | 		},
135 | 		{
136 | 			name:        "empty response falls back to status",
137 | 			statusCode:  http.StatusInternalServerError,
138 | 			body:        "",
139 | 			wantMessage: "500 Internal Server Error",
140 | 		},
141 | 	}
142 | 
143 | 	for _, tt := range tests {
144 | 		t.Run(tt.name, func(t *testing.T) {
145 | 			// Create a mock response
146 | 			url := &url.URL{Scheme: "https", Host: "api.openrouter.ai", Path: "/api/v1/chat/completions"}
147 | 			req := &http.Request{URL: url}
148 | 			resp := &http.Response{
149 | 				StatusCode: tt.statusCode,
150 | 				// Mirror net/http, whose Status includes the code (e.g. "400 Bad Request").
151 | 				Status:     strconv.Itoa(tt.statusCode) + " " + http.StatusText(tt.statusCode),
152 | 				Request:    req,
153 | 				Body:       io.NopCloser(bytes.NewBufferString(tt.body)),
154 | 			}
155 | 
156 | 			// Call the handler
157 | 			err := OpenRouterFailedResponseHandler(resp, []byte(tt.body), nil)
158 | 
159 | 			// Use errors.As instead of type assertion
160 | 			var apiErr *api.APICallError
161 | 			assert.True(t, errors.As(err, &apiErr))
162 | 
163 | 			// Check the error details
164 | 			assert.Equal(t, tt.statusCode, apiErr.StatusCode)
165 | 			assert.Equal(t, url, apiErr.URL)
166 | 			assert.Equal(t, tt.wantMessage, apiErr.Error())
167 | 
168 | 			// Check retryable status
169 | 			isRetryable := apiErr.IsRetryable()
170 | 			if tt.statusCode == 429 || tt.statusCode >= 500 {
171 | 				assert.True(t, isRetryable)
172 | 			} else {
173 | 				assert.False(t, isRetryable)
174 | 			}
175 | 		})
176 | 	}
177 | }
178 | 
179 | func TestOpenRouterFailedResponseHandlerNilRequest(t *testing.T) {
180 | 	// A response built without a Request must not trigger a nil dereference.
181 | 	resp := &http.Response{StatusCode: http.StatusBadGateway, Status: "502 Bad Gateway"}
182 | 
183 | 	err := OpenRouterFailedResponseHandler(resp, nil, nil)
184 | 
185 | 	var apiErr *api.APICallError
186 | 	if !errors.As(err, &apiErr) {
187 | 		t.Fatalf("expected *api.APICallError, got %T", err)
188 | 	}
189 | 	assert.Nil(t, apiErr.URL)
190 | 	assert.Equal(t, http.StatusBadGateway, apiErr.StatusCode)
191 | 	assert.Equal(t, "502 Bad Gateway", apiErr.Error())
192 | }
193 | 
--------------------------------------------------------------------------------
/provider/internal/openrouter/openrouter_provider.go:
--------------------------------------------------------------------------------
1 | package openrouter
2 | 
3 | import (
4 | 	"bytes"
5 | 	"context"
6 | 	"encoding/json"
7 | 	"fmt"
8 | 	"io"
9 | 	"net/http"
10 | )
11 | 
12 | // OpenRouterProvider represents the OpenRouter API provider.
13 | type OpenRouterProvider struct {
14 | 	baseURL string
15 | 	apiKey  string
16 | 	client  *http.Client
17 | 	headers map[string]string
18 | }
19 | 
20 | // NewOpenRouterProvider creates a new OpenRouter provider.
21 | //
22 | // Requests go through http.DefaultClient unless WithClient is supplied; that
23 | // client has no timeout, so callers should bound requests with a context
24 | // deadline or a custom client. Options are applied in order.
25 | func NewOpenRouterProvider(baseURL string, apiKey string, opts ...ProviderOption) *OpenRouterProvider {
26 | 	p := &OpenRouterProvider{
27 | 		baseURL: baseURL,
28 | 		apiKey:  apiKey,
29 | 		client:  http.DefaultClient,
30 | 		headers: make(map[string]string),
31 | 	}
32 | 
33 | 	for _, opt := range opts {
34 | 		opt(p)
35 | 	}
36 | 
37 | 	return p
38 | }
39 | 
40 | // ProviderOption configures the OpenRouter provider.
41 | type ProviderOption func(*OpenRouterProvider)
42 | 
43 | // WithClient sets a custom HTTP client. A nil client is ignored, leaving the
44 | // current client in place rather than causing a nil dereference on the first
45 | // request.
46 | func WithClient(client *http.Client) ProviderOption {
47 | 	return func(p *OpenRouterProvider) {
48 | 		if client != nil {
49 | 			p.client = client
50 | 		}
51 | 	}
52 | }
53 | 
54 | // WithHeaders sets custom headers for API requests. Entries are copied into
55 | // the provider when the option is applied, and they are set after the default
56 | // Content-Type and Authorization headers, so they can override them.
57 | func WithHeaders(headers map[string]string) ProviderOption {
58 | 	return func(p *OpenRouterProvider) {
59 | 		for k, v := range headers {
60 | 			p.headers[k] = v
61 | 		}
62 | 	}
63 | }
64 | 
65 | // doJSONRequest sends a request to p.baseURL+path, JSON-encoding body when it
66 | // is non-nil. Headers are applied with increasing precedence: the defaults
67 | // (Content-Type, Authorization), then the provider headers, then extraHeaders.
68 | //
69 | // A non-2xx status is not reported as an error here; callers must check
70 | // resp.StatusCode and close resp.Body.
71 | func (p *OpenRouterProvider) doJSONRequest(ctx context.Context, method, path string, body any, extraHeaders map[string]string) (*http.Response, error) {
72 | 	var bodyReader io.Reader
73 | 	if body != nil {
74 | 		jsonBytes, err := json.Marshal(body)
75 | 		if err != nil {
76 | 			return nil, fmt.Errorf("marshal request body: %w", err)
77 | 		}
78 | 		bodyReader = bytes.NewReader(jsonBytes)
79 | 	}
80 | 
81 | 	req, err := http.NewRequestWithContext(ctx, method, p.baseURL+path, bodyReader)
82 | 	if err != nil {
83 | 		return nil, fmt.Errorf("create request: %w", err)
84 | 	}
85 | 
86 | 	// Later Set calls overwrite earlier ones, which yields the precedence
87 | 	// order documented above.
88 | 	req.Header.Set("Content-Type", "application/json")
89 | 	req.Header.Set("Authorization", "Bearer "+p.apiKey)
90 | 	for k, v := range p.headers {
91 | 		req.Header.Set(k, v)
92 | 	}
93 | 	for k, v := range extraHeaders {
94 | 		req.Header.Set(k, v)
95 | 	}
96 | 
97 | 	resp, err := p.client.Do(req)
98 | 	if err != nil {
99 | 		return nil, fmt.Errorf("do request: %w", err)
100 | 	}
101 | 
102 | 	return resp, nil
103 | }
104 | 
--------------------------------------------------------------------------------
/provider/internal/openrouter/ptr.go:
--------------------------------------------------------------------------------
1 | package openrouter
2 | 
3 | // strPtr returns a pointer to a copy of s (handy for optional *string fields).
4 | func strPtr(s string) *string {
5 | 	return &s
6 | }
7 | 
--------------------------------------------------------------------------------
/provider/openai/constants.go:
--------------------------------------------------------------------------------
1 | package openai
2 | 
3 | // Model IDs of OpenAI chat models, usable as the modelID argument of NewLanguageModel.
4 | const (
5 | 	ChatModelO3Mini                           = "o3-mini"
6 | 	ChatModelO3Mini2025_01_31                 = "o3-mini-2025-01-31"
7 | 	ChatModelO1                               = "o1"
8 | 	ChatModelO1_2024_12_17                    = "o1-2024-12-17"
9 | 	ChatModelO1Preview                        = "o1-preview"
10 | 	ChatModelO1Preview2024_09_12              = "o1-preview-2024-09-12"
11 | 	ChatModelO1Mini                           = "o1-mini"
12 | 	ChatModelO1Mini2024_09_12                 = "o1-mini-2024-09-12"
13 | 	ChatModelGPT4o                            = "gpt-4o"
14 | 	ChatModelGPT4o2024_11_20                  = "gpt-4o-2024-11-20"
15 | 	ChatModelGPT4o2024_08_06                  = "gpt-4o-2024-08-06"
16 | 	ChatModelGPT4o2024_05_13                  = "gpt-4o-2024-05-13"
17 | 	ChatModelGPT4oAudioPreview                = "gpt-4o-audio-preview"
18 | 	ChatModelGPT4oAudioPreview2024_10_01      = "gpt-4o-audio-preview-2024-10-01"
19 | 	ChatModelGPT4oAudioPreview2024_12_17      = "gpt-4o-audio-preview-2024-12-17"
20 | 	ChatModelGPT4oMiniAudioPreview            = "gpt-4o-mini-audio-preview"
21 | 	ChatModelGPT4oMiniAudioPreview2024_12_17  = "gpt-4o-mini-audio-preview-2024-12-17"
22 | 	ChatModelGPT4oSearchPreview               = "gpt-4o-search-preview"
23 | 	ChatModelGPT4oMiniSearchPreview           = "gpt-4o-mini-search-preview"
24 | 	ChatModelGPT4oSearchPreview2025_03_11     = "gpt-4o-search-preview-2025-03-11"
25 | 	ChatModelGPT4oMiniSearchPreview2025_03_11 = "gpt-4o-mini-search-preview-2025-03-11"
26 | 	ChatModelChatgpt4oLatest                  = "chatgpt-4o-latest"
27 | 	ChatModelGPT4oMini                        = "gpt-4o-mini"
28 | 	ChatModelGPT4oMini2024_07_18              = "gpt-4o-mini-2024-07-18"
29 | 	ChatModelGPT4Turbo                        = "gpt-4-turbo"
30 | 	ChatModelGPT4Turbo2024_04_09              = "gpt-4-turbo-2024-04-09"
31 | 	ChatModelGPT4_0125Preview                 = "gpt-4-0125-preview"
32 | 	ChatModelGPT4TurboPreview                 = "gpt-4-turbo-preview"
33 | 	ChatModelGPT4_1106Preview                 = "gpt-4-1106-preview"
34 | 	ChatModelGPT4VisionPreview                = "gpt-4-vision-preview"
35 | 	ChatModelGPT4                             = "gpt-4"
36 | 	ChatModelGPT4_0314                        = "gpt-4-0314"
37 | 	ChatModelGPT4_0613                        = "gpt-4-0613"
38 | 	ChatModelGPT4_32k                         = "gpt-4-32k"
39 | 	ChatModelGPT4_32k0314                     = "gpt-4-32k-0314"
40 | 	ChatModelGPT4_32k0613                     = "gpt-4-32k-0613"
41 | 	ChatModelGPT3_5Turbo                      = "gpt-3.5-turbo"
42 | 	ChatModelGPT3_5Turbo16k                   = "gpt-3.5-turbo-16k"
43 | 	ChatModelGPT3_5Turbo0301                  = "gpt-3.5-turbo-0301"
44 | 	ChatModelGPT3_5Turbo0613                  = "gpt-3.5-turbo-0613"
45 | 	ChatModelGPT3_5Turbo1106                  = "gpt-3.5-turbo-1106"
46 | 	ChatModelGPT3_5Turbo0125                  = "gpt-3.5-turbo-0125"
47 | 	ChatModelGPT3_5Turbo16k0613               = "gpt-3.5-turbo-16k-0613"
48 | )
49 | 
--------------------------------------------------------------------------------
/provider/openai/internal/codec/decode_stream_test.go:
--------------------------------------------------------------------------------
1 | package codec
2 | 
3 | import (
4 | 	"encoding/json"
5 | 	"testing"
6 | 	"time"
7 | 
8 | 	"github.com/openai/openai-go/responses"
9 | 	"github.com/stretchr/testify/assert"
10 | 	"github.com/stretchr/testify/require"
11 | 	"go.jetify.com/ai/api"
12 | )
13 | 
14 | func TestDecodeStreamEvents(t *testing.T) {
15 | 	tests := []struct {
16 | 		name       string
17 | 		eventJSONs []string
18 | 		want       []api.StreamEvent
19 | 	}{
20 | 		{
21 | 			name: "simple text stream",
22 | 			eventJSONs: []string{
23 | 				`{"type": "response.created", "response": {"id": "resp_123", "created_at": 1741269019, "model": "gpt-4"}}`,
24 | 				`{"type": "response.output_text.delta", "delta": "Hello world"}`,
25 | 				`{"type": "response.completed", "response": {"usage": {"input_tokens": 10, "output_tokens": 5}}}`,
26 | 			},
27 | 			want: []api.StreamEvent{
28 | 				&api.ResponseMetadataEvent{
29 | 					ID:        "resp_123",
30 | 					Timestamp: time.Date(2025, 3, 6, 13, 50, 19, 0, time.UTC),
31 | 					ModelID:   "gpt-4",
32 | 				},
33 | 				&api.TextDeltaEvent{
34 | 					TextDelta: "Hello world",
35 | 				},
36 | 				&api.FinishEvent{
37 | 					FinishReason: api.FinishReasonStop,
38 | 					Usage: api.Usage{
39 | 						InputTokens:  10,
40 | 						OutputTokens: 5,
41 | 						TotalTokens:  15,
42 | 					},
43 | 					ProviderMetadata: api.NewProviderMetadata(map[string]any{
44 | 						"openai": &Metadata{
45 | 							ResponseID: "resp_123",
46 | 							Usage: Usage{
47 | 								InputTokens:  10,
48 | 								OutputTokens: 5,
49 | 							},
50 | 						},
51 | 					}),
52 | 				},
53 | 			},
54 | 		},
55 | 		{
56 | 			name: "tool call stream",
57 | 			eventJSONs: []string{
58 | 				`{"type": "response.created", "response": {"id": "resp_456", "created_at": 1741269019, "model": "gpt-4"}}`,
59 | 				`{"type": "response.output_item.added", "output_index": 0, "item": {"type": "function_call", "call_id": "call_123", "name": "get_weather", "arguments": "{\"location\":\"New York\"}"}}`,
60 | 				`{"type": "response.completed", "response": {"usage": {"input_tokens": 15, "output_tokens": 8}}}`,
61 | 			},
62 | 			want: []api.StreamEvent{
63 | 				&api.ResponseMetadataEvent{
64 | 					ID:        "resp_456",
65 | 					Timestamp: time.Date(2025, 3, 6, 13, 50, 19, 0, time.UTC),
66 | 					ModelID:   "gpt-4",
67 | 				},
68 | 				&api.ToolCallDeltaEvent{
69 | 					ToolCallID: "call_123",
70 | 					ToolName:   "get_weather",
71 | 					ArgsDelta:  []byte(`{"location":"New York"}`),
72 | 				},
73 | 				&api.FinishEvent{
74 | 					FinishReason: api.FinishReasonToolCalls,
75 | 					Usage: api.Usage{
76 | 						InputTokens:  15,
77 | 						OutputTokens: 8,
78 | 						TotalTokens:  23,
79 | 					},
80 | 					ProviderMetadata: api.NewProviderMetadata(map[string]any{
81 | 						"openai": &Metadata{
82 | 							ResponseID: "resp_456",
83 | 							Usage: Usage{
84 | 								InputTokens:  15,
85 | 								OutputTokens: 8,
86 | 							},
87 | 						},
88 | 					}),
89 | 				},
90 | 			},
91 | 		},
92 | 	}
93 | 
94 | 	for _, testCase := range tests {
95 | 		t.Run(testCase.name, func(t *testing.T) {
96 | 			// Parse the JSON events
97 | 			var events []responses.ResponseStreamEventUnion
98 | 			for _, jsonStr := range testCase.eventJSONs {
99 | 				var event responses.ResponseStreamEventUnion
100 | 				err := json.Unmarshal([]byte(jsonStr), &event)
101 | 				require.NoError(t, err)
102 | 				events = append(events, event)
103 | 			}
104 | 
105 | 			// Create a mock stream
106 | 			stream := newMockStreamReader(events)
107 | 
108 | 			// Decode the stream
109 | 			result, err := DecodeStream(stream)
110 | 			require.NoError(t, err)
111 | 
112 | 			// Collect all events from the stream
113 | 			var got []api.StreamEvent
114 | 			for event := range result.Stream {
115 | 				got = append(got, event)
116 | 			}
117 | 
118 | 			// Compare events using deep equality
119 | 			assert.Equal(t, testCase.want, got)
120 | 		})
121 | 	}
122 | }
123 | 
124 | // mockStreamReader implements the StreamReader interface for testing
125 | type mockStreamReader struct {
126 | 	events []responses.ResponseStreamEventUnion
127 | 	index  int
128 | 	err    error
129 | }
130 | 
131 | // newMockStreamReader creates a new mock stream reader with the given events
132 | func newMockStreamReader(events []responses.ResponseStreamEventUnion) *mockStreamReader {
133 | 	return &mockStreamReader{
134 | 		events: events,
135 | 		index:  -1,
136 | 	}
137 | }
138 | 
139 | // Next advances to the next event, returning true if there is one, false otherwise
140 | func (m *mockStreamReader) Next() bool {
141 | 	m.index++
142 | 	return m.index < len(m.events)
143 | }
144 | 
145 | // Current returns the current event
146 | func (m *mockStreamReader) Current() responses.ResponseStreamEventUnion {
147 | 	return m.events[m.index]
148 | }
149 | 
150 | // Err returns any error that occurred while reading the stream
151 | func (m *mockStreamReader) Err() error {
152 | 	return m.err
153 | }
154 | 
--------------------------------------------------------------------------------
/provider/openai/internal/codec/metadata.go:
--------------------------------------------------------------------------------
1 | package codec
2 | 
3 | import (
4 | 	"go.jetify.com/ai/api"
5 | )
6 | 
7 | // For now we are using a single type for all metadata.
8 | // TODO: Decide if we will need different types for different metadata.
9 | type Metadata struct {
10 | 	// --- Used in requests ---
11 | 
12 | 	// ParallelToolCalls determines whether to allow the model to run tool calls in parallel.
13 | 	// When not specified (nil), OpenAI defaults this to true.
14 | 	ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"`
15 | 
16 | 	// PreviousResponseID is the unique ID of the previous response to the model. Use this to create
17 | 	// multi-turn conversations. Learn more about
18 | 	// [conversation state](https://platform.openai.com/docs/guides/conversation-state).
19 | 	PreviousResponseID string `json:"previous_response_id,omitempty"`
20 | 
21 | 	// Store determines whether to store the generated model response for later retrieval via API.
22 | 	// When not specified (nil), OpenAI defaults this to true.
23 | 	Store *bool `json:"store,omitempty"`
24 | 
25 | 	// User is a unique identifier representing your end-user, which can help OpenAI to monitor
26 | 	// and detect abuse.
27 | 	// [Learn more](https://platform.openai.com/docs/guides/safety-best-practices#end-user-ids)
28 | 	User string `json:"user,omitempty"`
29 | 
30 | 	// Instructions is a system (or developer) message that is inserted as the first item
31 | 	// in the model's context.
32 | 	//
33 | 	// When using along with `previous_response_id`, the instructions from a previous
34 | 	// response will not be carried over to the next response. This makes it simple to
35 | 	// swap out system (or developer) messages in new responses.
36 | 	Instructions string `json:"instructions,omitempty"`
37 | 
38 | 	// StrictSchemas determines whether JSON schema validation should be strict.
39 | 	// When true (default), the model will strictly follow the provided JSON schema.
40 | 	//
41 | 	// Whether to enable strict schema adherence when generating the output. If set to
42 | 	// true, the model will always follow the exact schema defined in the `schema`
43 | 	// field. Only a subset of JSON Schema is supported when `strict` is `true`. To
44 | 	// learn more, read the
45 | 	// [Structured Outputs guide](https://platform.openai.com/docs/guides/structured-outputs).
46 | 	StrictSchemas *bool `json:"strict_schemas,omitempty"`
47 | 
48 | 	// ReasoningEffort controls the amount of reasoning the model puts into its response.
49 | 	// Reducing reasoning effort can result in faster responses and fewer tokens used on
50 | 	// reasoning in a response.
51 | 	//
52 | 	// Only supported by reasoning models (O series).
53 | 	//
54 | 	// Read more about [reasoning models](https://platform.openai.com/docs/guides/reasoning)
55 | 	// for more information.
56 | 	//
57 | 	// Supported values are `low`, `medium`, and `high`.
58 | 	ReasoningEffort string `json:"reasoning_effort,omitempty"`
59 | 
60 | 	// ReasoningSummary indicates the level of detail that should be used when
61 | 	// summarizing the reasoning performed by the model. This can be useful for
62 | 	// debugging and understanding the model's reasoning process.
63 | 	//
64 | 	// Only supported by computer_use_preview.
65 | 	//
66 | 	// Supported values are `concise` and `detailed`.
67 | 	ReasoningSummary string `json:"reasoning_summary,omitempty"`
68 | 
69 | 	// --- Used in blocks ---
70 | 
71 | 	// ImageDetail indicates the level of detail that should be used when processing
72 | 	// and understanding the image that is being sent to the model.
73 | 	//
74 | 	// One of `high`, `low`, or `auto`. Defaults to `auto`.
75 | 	ImageDetail string `json:"image_detail,omitempty"`
76 | 
77 | 	// Filename is the custom filename to use when sending a file to the model.
78 | 	Filename string `json:"filename,omitempty"`
79 | 
80 | 	// --- Used in responses ---
81 | 
82 | 	// ResponseID is the unique ID of the response.
83 | 	ResponseID string `json:"response_id,omitempty"`
84 | 	// TODO: Decide whether to promote ID to a top-level field.
85 | 
86 | 	// Usage stores token usage: input and output token counts, plus a breakdown
87 | 	// of cached input tokens and reasoning output tokens.
88 | 	//
89 | 	// NOTE: encoding/json ignores omitempty on struct-typed fields, so a zero
90 | 	// Usage is still encoded as "usage":{}.
91 | 	Usage Usage `json:"usage,omitempty"`
92 | 
93 | 	// ComputerSafetyChecks is a list of pending safety checks for the computer call.
94 | 	ComputerSafetyChecks []ComputerSafetyCheck `json:"computer_safety_checks,omitempty"`
95 | }
96 | 
97 | // GetMetadata returns the OpenAI ("openai") provider metadata attached to
98 | // source, as looked up by api.GetMetadata.
99 | func GetMetadata(source api.MetadataSource) *Metadata {
100 | 	return api.GetMetadata[Metadata]("openai", source)
101 | }
102 | 
103 | // Usage stores token usage: input and output token counts, plus a breakdown
104 | // of cached input tokens and reasoning output tokens.
105 | type Usage struct {
106 | 	// The number of input tokens.
107 | 	InputTokens int `json:"input_tokens,omitempty"`
108 | 
109 | 	// The number of tokens that were retrieved from the cache.
110 | 	// [More on prompt caching](https://platform.openai.com/docs/guides/prompt-caching).
111 | 	InputCachedTokens int `json:"cached_tokens,omitempty"`
112 | 
113 | 	// The number of output tokens.
114 | 	OutputTokens int `json:"output_tokens,omitempty"`
115 | 
116 | 	// The number of reasoning tokens.
117 | 	OutputReasoningTokens int `json:"reasoning_tokens,omitempty"`
118 | }
119 | 
--------------------------------------------------------------------------------
/provider/openai/internal/codec/tools.go:
--------------------------------------------------------------------------------
1 | package codec
2 | 
3 | import "go.jetify.com/ai/api"
4 | 
5 | // FileSearchTool is a built-in tool that searches for relevant content from uploaded files.
6 | // Learn more about the [file search tool](https://platform.openai.com/docs/guides/tools-file-search).
7 | type FileSearchTool struct {
8 | 	// The IDs of the vector stores to search.
9 | 	VectorStoreIDs []string `json:"vector_store_ids,omitzero"`
10 | 
11 | 	// The maximum number of results to return. This number should be between 1 and 50
12 | 	// inclusive. If not provided, it's set to a default.
13 | 	MaxNumResults int `json:"max_num_results,omitzero"`
14 | 
15 | 	// TODO: Add filters and ranking options
16 | 	// // A filter to apply based on file attributes.
17 | 	// Filters X `json:"filters,omitzero"`
18 | 	// // Ranking options for search.
19 | 	// RankingOptions X `json:"ranking_options,omitzero"`
20 | }
21 | 
22 | var _ api.ProviderDefinedTool = &FileSearchTool{}
23 | 
24 | func (t *FileSearchTool) ToolType() string { return "provider-defined" }
25 | 
26 | func (t *FileSearchTool) ID() string {
27 | 	return "openai.file_search"
28 | }
29 | 
30 | func (t *FileSearchTool) Name() string { return "file_search" }
31 | 
32 | // FileSearchToolCall represents the results of a file search operation.
33 | // See the [file search guide](https://platform.openai.com/docs/guides/tools-file-search)
34 | // for more information.
35 | type FileSearchToolCall struct {
36 | 	// Queries contains the search terms used to find files
37 | 	Queries []string           `json:"queries"`
38 | 	// Results holds the matching files after executing the file search.
39 | 	Results []FileSearchResult `json:"results"`
40 | }
41 | 
42 | // FileSearchResult contains metadata and content for a single file match.
43 | type FileSearchResult struct {
44 | 	// FileID uniquely identifies the file
45 | 	FileID   string  `json:"file_id"`
46 | 	// Filename is the name of the matched file
47 | 	Filename string  `json:"filename"`
48 | 	// Score indicates the relevance of the match (0.0 to 1.0)
49 | 	Score    float64 `json:"score"`
50 | 	// Text contains the retrieved file content
51 | 	Text     string  `json:"text"`
52 | }
53 | 
54 | // WebSearchTool is a built-in tool that searches the web for relevant results to use in a response.
55 | // Learn more about the [web search tool](https://platform.openai.com/docs/guides/tools-web-search).
56 | type WebSearchTool struct {
57 | 	// High level guidance for the amount of context window space to use for the
58 | 	// search. One of `low`, `medium`, or `high`. `medium` is the default.
59 | 	SearchContextSize string                 `json:"search_context_size,omitempty"`
60 | 	// User location information for geographically relevant results
61 | 	UserLocation      *WebSearchUserLocation `json:"user_location,omitempty"`
62 | }
63 | 
64 | var _ api.ProviderDefinedTool = &WebSearchTool{}
65 | 
66 | func (t *WebSearchTool) ToolType() string { return "provider-defined" }
67 | 
68 | func (t *WebSearchTool) ID() string {
69 | 	return "openai.web_search_preview"
70 | }
71 | 
72 | func (t *WebSearchTool) Name() string { return "web_search_preview" }
73 | 
74 | // WebSearchUserLocation represents the user location information for a web search
75 | type WebSearchUserLocation struct {
76 | 	// Free text input for the city of the user, e.g. `San Francisco`.
77 | 	City     string `json:"city,omitzero"`
78 | 	// The two-letter [ISO country code](https://en.wikipedia.org/wiki/ISO_3166-1) of
79 | 	// the user, e.g. `US`.
80 | 	Country  string `json:"country,omitzero"`
81 | 	// Free text input for the region of the user, e.g. `California`.
82 | 	Region   string `json:"region,omitzero"`
83 | 	// The [IANA timezone](https://timeapi.io/documentation/iana-timezones) of the
84 | 	// user, e.g. `America/Los_Angeles`.
85 | 	Timezone string `json:"timezone,omitzero"`
86 | }
87 | 
88 | // ComputerUseTool is a built-in tool that controls a virtual computer. Learn more about the
89 | // [computer tool](https://platform.openai.com/docs/guides/tools-computer-use).
90 | //
91 | // The properties DisplayHeight, DisplayWidth, and Environment are required.
92 | type ComputerUseTool struct {
93 | 	// The height of the computer display.
94 | 	DisplayHeight int    `json:"display_height,omitempty"`
95 | 	// The width of the computer display.
96 | 	DisplayWidth  int    `json:"display_width,omitempty"`
97 | 	// The type of computer environment to control.
98 | 	//
99 | 	// Any of "mac", "windows", "ubuntu", "browser".
100 | 	Environment   string `json:"environment,omitempty"`
101 | }
102 | 
103 | var _ api.ProviderDefinedTool = &ComputerUseTool{}
104 | 
105 | func (t *ComputerUseTool) ToolType() string { return "provider-defined" }
106 | 
107 | func (t *ComputerUseTool) ID() string {
108 | 	return "openai.computer_use_preview"
109 | }
110 | 
111 | func (t *ComputerUseTool) Name() string { return "computer_use_preview" }
112 | 
113 | // ComputerToolCall represents a computer-based tool operation. See the
114 | // [computer use guide](https://platform.openai.com/docs/guides/tools-computer-use)
115 | // for more information.
116 | type ComputerToolCall struct {
117 | 	// Action represents the type of action to perform. Any of "click", "double_click",
118 | 	// "drag", "keypress", "move", "screenshot", "scroll", "type", or "wait".
119 | 	Action string
120 | 
121 | 	// Coordinates represents the screen coordinates to perform the action on, if
122 | 	// applicable.
123 | 	// Applies to these types of actions: "click", "double_click", "move".
124 | 	Coordinates ComputerCoordinates
125 | 
126 | 	// MouseButton indicates which mouse button to press for a "click" action. One of
127 | 	// "left", "right", "wheel", "back", or "forward". Assume "left" if not specified.
128 | 	MouseButton string
129 | 
130 | 	// DragPath is the path of coordinates to follow for a "drag" action. Coordinates
131 | 	// will appear as an array of coordinate objects, e.g.
132 | 	//
133 | 	//	[
134 | 	//	  { x: 100, y: 200 },
135 | 	//	  { x: 200, y: 300 }
136 | 	//	]
137 | 	DragPath []ComputerCoordinates
138 | 
139 | 	// Keys indicates the combination of keys the model is requesting to be pressed
140 | 	// for a "keypress" action. This is an array of strings, each representing a
141 | 	// key to be pressed simultaneously.
142 | 	Keys []string
143 | 
144 | 	// ScrollDistance indicates the distance to scroll in the x and y directions for
145 | 	// a "scroll" action.
146 | 	ScrollDistance ComputerCoordinates
147 | 
148 | 	// Text is the text that should be typed in a "type" action.
149 | 	Text string
150 | }
151 | 
152 | // ComputerCoordinates is an (X, Y) pair: a screen position, or, in
153 | // ComputerToolCall.ScrollDistance, the scroll distance along each axis.
154 | type ComputerCoordinates struct {
155 | 	X int
156 | 	Y int
157 | }
158 | 
159 | // ComputerSafetyCheck represents a pending safety check for the computer call.
160 | //
161 | // The properties ID, Code, Message are required.
162 | type ComputerSafetyCheck struct {
163 | 	// The ID of the pending safety check.
164 | 	ID      string
165 | 	// The type of the pending safety check.
166 | 	Code    string
167 | 	// Details about the pending safety check.
168 | 	Message string
169 | }
170 | 
--------------------------------------------------------------------------------
/provider/openai/llm.go:
--------------------------------------------------------------------------------
1 | package openai
2 | 
3 | import (
4 | 	"context"
5 | 
6 | 	"github.com/openai/openai-go"
7 | 	"go.jetify.com/ai/api"
8 | 	"go.jetify.com/ai/provider/openai/internal/codec"
9 | )
10 | 
11 | // ModelOption is a function type that modifies a LanguageModel.
12 | type ModelOption func(*LanguageModel)
13 | 
14 | // WithClient returns a ModelOption that sets the client.
15 | func WithClient(client openai.Client) ModelOption {
16 | 	// TODO: Instead of only supporting a single client, we can "flatten"
17 | 	// the options supported by the OpenAI SDK.
18 | 	return func(m *LanguageModel) {
19 | 		m.client = client
20 | 	}
21 | }
22 | 
23 | // LanguageModel represents an OpenAI language model.
24 | type LanguageModel struct {
25 | 	modelID string
26 | 	client  openai.Client
27 | }
28 | 
29 | var _ api.LanguageModel = &LanguageModel{}
30 | 
31 | // NewLanguageModel creates a new OpenAI language model.
32 | func NewLanguageModel(modelID string, opts ...ModelOption) *LanguageModel { 33 | // Create model with default settings 34 | model := &LanguageModel{ 35 | modelID: modelID, 36 | client: openai.NewClient(), // Default client 37 | } 38 | 39 | // Apply options 40 | for _, opt := range opts { 41 | opt(model) 42 | } 43 | 44 | return model 45 | } 46 | // ProviderName returns the provider name, "openai". 47 | func (m *LanguageModel) ProviderName() string { 48 | return "openai" 49 | } 50 | // ModelID returns the model ID the LanguageModel was constructed with. 51 | func (m *LanguageModel) ModelID() string { 52 | return m.modelID 53 | } 54 | // SupportedUrls reports the URLs the model accepts directly: image/* media referenced by any http or https URL. 55 | func (m *LanguageModel) SupportedUrls() []api.SupportedURL { 56 | // TODO: Make configurable via the constructor. 57 | return []api.SupportedURL{ 58 | { 59 | MediaType: "image/*", 60 | URLPatterns: []string{ 61 | "^https?://.*", 62 | }, 63 | }, 64 | } 65 | } 66 | // Generate encodes the prompt and options with codec.Encode, sends one non-streaming request through the OpenAI Responses API, and decodes the reply with codec.DecodeResponse. Warnings from encoding are appended after any warnings produced by decoding. 67 | func (m *LanguageModel) Generate( 68 | ctx context.Context, prompt []api.Message, opts api.CallOptions, 69 | ) (api.Response, error) { 70 | params, warnings, err := codec.Encode(m.modelID, prompt, opts) 71 | if err != nil { 72 | return api.Response{}, err 73 | } 74 | 75 | openaiResponse, err := m.client.Responses.New(ctx, params) 76 | if err != nil { 77 | return api.Response{}, err 78 | } 79 | 80 | response, err := codec.DecodeResponse(openaiResponse) 81 | if err != nil { 82 | return api.Response{}, err 83 | } 84 | 85 | response.Warnings = append(response.Warnings, warnings...)
86 | return response, nil 87 | } 88 | // Stream encodes the prompt and options with codec.Encode, opens a streaming request through the OpenAI Responses API, and wraps it with codec.DecodeStream. Unlike Generate, encoding warnings are currently discarded (see the TODO below). 89 | func (m *LanguageModel) Stream( 90 | ctx context.Context, prompt []api.Message, opts api.CallOptions, 91 | ) (api.StreamResponse, error) { 92 | // TODO: add warnings to the stream response by adding an initial StreamStart event 93 | // (it could happen inside of codec.Encode) 94 | params, _, err := codec.Encode(m.modelID, prompt, opts) 95 | if err != nil { 96 | return api.StreamResponse{}, err 97 | } 98 | 99 | stream := m.client.Responses.NewStreaming(ctx, params) 100 | response, err := codec.DecodeStream(stream) 101 | if err != nil { 102 | return api.StreamResponse{}, err 103 | } 104 | 105 | return response, nil 106 | } 107 | -------------------------------------------------------------------------------- /provider/openai/metadata.go: -------------------------------------------------------------------------------- 1 | package openai 2 | 3 | import "go.jetify.com/ai/provider/openai/internal/codec" 4 | // Metadata is an alias for codec.Metadata, re-exported so code outside provider/openai can name it (the codec package is internal). 5 | type Metadata = codec.Metadata 6 | -------------------------------------------------------------------------------- /provider/openai/tools.go: -------------------------------------------------------------------------------- 1 | package openai 2 | 3 | import "go.jetify.com/ai/provider/openai/internal/codec" 4 | // File search tool types, re-exported from the internal codec package. 5 | type ( 6 | FileSearchTool = codec.FileSearchTool 7 | FileSearchToolCall = codec.FileSearchToolCall 8 | FileSearchResult = codec.FileSearchResult 9 | ) 10 | // Web search tool types, re-exported from the internal codec package. 11 | type ( 12 | WebSearchTool = codec.WebSearchTool 13 | WebSearchUserLocation = codec.WebSearchUserLocation 14 | ) 15 | // Computer use tool types, re-exported from the internal codec package. NOTE(review): codec.ComputerCoordinates (the type of ComputerToolCall's coordinate fields) and codec.ComputerSafetyCheck have no alias here — confirm whether that is intentional. 16 | type ( 17 | ComputerUseTool = codec.ComputerUseTool 18 | ComputerToolCall = codec.ComputerToolCall 19 | ) 20 | --------------------------------------------------------------------------------