├── .gitignore ├── LICENSE ├── README.md ├── chat.go ├── chat_test.go ├── client.go ├── embeddings.go ├── embeddings_test.go ├── errors.go ├── examples └── main.go ├── fim.go ├── fim_test.go ├── go.mod ├── go.sum ├── models.go └── types.go /.gitignore: -------------------------------------------------------------------------------- 1 | # If you prefer the allow list template instead of the deny list, see community template: 2 | # https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore 3 | # 4 | # Binaries for programs and plugins 5 | *.exe 6 | *.exe~ 7 | *.dll 8 | *.so 9 | *.dylib 10 | 11 | # Test binary, built with `go test -c` 12 | *.test 13 | 14 | # Output of the go coverage tool, specifically when used with LiteIDE 15 | *.out 16 | 17 | # Dependency directories (remove the comment below to include it) 18 | # vendor/ 19 | 20 | # Go workspace file 21 | go.work 22 | 23 | .vscode/ 24 | .idea/ 25 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Gage Technologies 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Mistral Go Client 2 | 3 | The Mistral Go Client is a comprehensive Golang library designed to interface with the Mistral AI API, providing developers with a robust set of tools to integrate advanced AI-powered features into their applications. This client supports a variety of functionalities, including Chat Completions, Chat Completions Streaming, and Embeddings, allowing for seamless interaction with Mistral's powerful language models. 4 | 5 | ## Features 6 | 7 | - **Chat Completions**: Generate conversational responses and complete dialogue prompts using Mistral's language models. 8 | - **Chat Completions Streaming**: Establish a real-time stream of chat completions, ideal for applications requiring continuous interaction. 9 | - **Embeddings**: Obtain numerical vector representations of text, enabling semantic search, clustering, and other machine learning applications. 10 | 11 | ## Getting Started 12 | 13 | To begin using the Mistral Go Client in your project, ensure you have Go installed on your system. This client library is compatible with Go 1.20 and higher. 
14 | 15 | ### Installation 16 | 17 | To install the Mistral Go Client, run the following command: 18 | 19 | ```bash 20 | go get github.com/gage-technologies/mistral-go 21 | ``` 22 | 23 | ### Usage 24 | 25 | To use the client in your Go application, you need to import the package and initialize a new client instance with your API key. 26 | 27 | ```go 28 | package main 29 | 30 | import ( 31 | "log" 32 | 33 | "github.com/gage-technologies/mistral-go" 34 | ) 35 | 36 | func main() { 37 | // If api key is empty it will load from MISTRAL_API_KEY env var 38 | client := mistral.NewMistralClientDefault("your-api-key") 39 | 40 | // Example: Using Chat Completions 41 | chatRes, err := client.Chat("mistral-tiny", []mistral.ChatMessage{{Content: "Hello, world!", Role: mistral.RoleUser}}, nil) 42 | if err != nil { 43 | log.Fatalf("Error getting chat completion: %v", err) 44 | } 45 | log.Printf("Chat completion: %+v\n", chatRes) 46 | 47 | // Example: Using Chat Completions Stream 48 | chatResChan, err := client.ChatStream("mistral-tiny", []mistral.ChatMessage{{Content: "Hello, world!", Role: mistral.RoleUser}}, nil) 49 | if err != nil { 50 | log.Fatalf("Error getting chat completion stream: %v", err) 51 | } 52 | 53 | for chatResChunk := range chatResChan { 54 | if chatResChunk.Error != nil { 55 | log.Fatalf("Error while streaming response: %v", chatResChunk.Error) 56 | } 57 | log.Printf("Chat completion stream part: %+v\n", chatResChunk) 58 | } 59 | 60 | // Example: Using Embeddings 61 | embsRes, err := client.Embeddings("mistral-embed", []string{"Embed this sentence.", "As well as this one."}) 62 | if err != nil { 63 | log.Fatalf("Error getting embeddings: %v", err) 64 | } 65 | 66 | log.Printf("Embeddings response: %+v\n", embsRes) 67 | } 68 | ``` 69 | 70 | ## Documentation 71 | 72 | For detailed documentation on the Mistral AI API and the available endpoints, please refer to the [Mistral AI API Documentation](https://docs.mistral.ai). 73 | 74 | ## Contributing 75 | 76 | Contributions are welcome! If you would like to contribute to the project, please fork the repository and submit a pull request with your changes. 77 | 78 | ## License 79 | 80 | The Mistral Go Client is open-sourced software licensed under the [MIT license](LICENSE). 81 | 82 | ## Support 83 | 84 | If you encounter any issues or require assistance, please file an issue on the GitHub repository issue tracker. 85 | -------------------------------------------------------------------------------- /chat.go: -------------------------------------------------------------------------------- 1 | package mistral 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "encoding/json" 7 | "fmt" 8 | "io" 9 | "net/http" 10 | ) 11 | 12 | // ChatRequestParams represents the parameters for the Chat/ChatStream method of MistralClient. 13 | type ChatRequestParams struct { 14 | Temperature float64 `json:"temperature"` // The temperature to use for sampling. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend altering this or TopP but not both. 15 | TopP float64 `json:"top_p"` // An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. We generally recommend altering this or Temperature but not both. 
16 | RandomSeed int `json:"random_seed"` // The seed to use for random sampling. Setting it makes repeated calls more reproducible. 17 | MaxTokens int `json:"max_tokens"` // The maximum number of tokens to generate in the completion. 18 | SafePrompt bool `json:"safe_prompt"` // Adds a Mistral defined safety message to the system prompt to enforce guardrailing 19 | Tools []Tool `json:"tools"` // Definitions of the tools the model is allowed to call. 20 | ToolChoice string `json:"tool_choice"` // How tools are selected: ToolChoiceAuto, ToolChoiceAny or ToolChoiceNone. 21 | ResponseFormat ResponseFormat `json:"response_format"` // The format the response must adhere to: ResponseFormatText or ResponseFormatJsonObject. 22 | } 23 | 24 | var DefaultChatRequestParams = ChatRequestParams{ 25 | Temperature: 1, 26 | TopP: 1, 27 | RandomSeed: 42069, 28 | MaxTokens: 4000, 29 | SafePrompt: false, 30 | } 31 | 32 | // ChatCompletionResponseChoice represents a choice in the chat completion response. 33 | type ChatCompletionResponseChoice struct { 34 | Index int `json:"index"` 35 | Message ChatMessage `json:"message"` 36 | FinishReason FinishReason `json:"finish_reason,omitempty"` 37 | } 38 | 39 | // ChatCompletionResponseChoiceStream represents a choice in the streamed chat completion response. 40 | type ChatCompletionResponseChoiceStream struct { 41 | Index int `json:"index"` 42 | Delta DeltaMessage `json:"delta"` 43 | FinishReason FinishReason `json:"finish_reason,omitempty"` 44 | } 45 | 46 | // ChatCompletionResponse represents the response from the chat completion endpoint. 47 | type ChatCompletionResponse struct { 48 | ID string `json:"id"` 49 | Object string `json:"object"` 50 | Created int `json:"created"` 51 | Model string `json:"model"` 52 | Choices []ChatCompletionResponseChoice `json:"choices"` 53 | Usage UsageInfo `json:"usage"` 54 | } 55 | 56 | // ChatCompletionStreamResponse represents the streamed response from the chat completion endpoint. 57 | type ChatCompletionStreamResponse struct { 58 | ID string `json:"id"` 59 | Model string `json:"model"` 60 | Choices []ChatCompletionResponseChoiceStream `json:"choices"` 61 | Created int `json:"created,omitempty"` 62 | Object string `json:"object,omitempty"` 63 | Usage UsageInfo `json:"usage,omitempty"` 64 | Error error `json:"error,omitempty"` 65 | } 66 | 67 | // UsageInfo represents the usage information of a response.
68 | type UsageInfo struct { 69 | PromptTokens int `json:"prompt_tokens"` 70 | TotalTokens int `json:"total_tokens"` 71 | CompletionTokens int `json:"completion_tokens,omitempty"` 72 | } 73 | 74 | func (c *MistralClient) Chat(model string, messages []ChatMessage, params *ChatRequestParams) (*ChatCompletionResponse, error) { 75 | if params == nil { 76 | params = &DefaultChatRequestParams 77 | } 78 | 79 | requestData := map[string]interface{}{ 80 | "model": model, 81 | "messages": messages, 82 | "temperature": params.Temperature, 83 | "max_tokens": params.MaxTokens, 84 | "top_p": params.TopP, 85 | "random_seed": params.RandomSeed, 86 | "safe_prompt": params.SafePrompt, 87 | } 88 | 89 | if params.Tools != nil { 90 | requestData["tools"] = params.Tools 91 | } 92 | if params.ToolChoice != "" { 93 | requestData["tool_choice"] = params.ToolChoice 94 | } 95 | if params.ResponseFormat != "" { 96 | requestData["response_format"] = map[string]any{"type": params.ResponseFormat} 97 | } 98 | 99 | response, err := c.request(http.MethodPost, requestData, "v1/chat/completions", false, nil) 100 | if err != nil { 101 | return nil, err 102 | } 103 | 104 | respData, ok := response.(map[string]interface{}) 105 | if !ok { 106 | return nil, fmt.Errorf("invalid response type: %T", response) 107 | } 108 | 109 | var chatResponse ChatCompletionResponse 110 | err = mapToStruct(respData, &chatResponse) 111 | if err != nil { 112 | return nil, err 113 | } 114 | 115 | return &chatResponse, nil 116 | } 117 | 118 | // ChatStream sends a chat message and returns a channel to receive streaming responses. 119 | func (c *MistralClient) ChatStream(model string, messages []ChatMessage, params *ChatRequestParams) (<-chan ChatCompletionStreamResponse, error) { 120 | if params == nil { 121 | params = &DefaultChatRequestParams 122 | } 123 | 124 | responseChannel := make(chan ChatCompletionStreamResponse) 125 | 126 | requestData := map[string]interface{}{ 127 | "model": model, 128 | "messages": messages, 129 | "temperature": params.Temperature, 130 | "max_tokens": params.MaxTokens, 131 | "top_p": params.TopP, 132 | "random_seed": params.RandomSeed, 133 | "safe_prompt": params.SafePrompt, 134 | "stream": true, 135 | } 136 | 137 | if params.Tools != nil { 138 | requestData["tools"] = params.Tools 139 | } 140 | if params.ToolChoice != "" { 141 | requestData["tool_choice"] = params.ToolChoice 142 | } 143 | if params.ResponseFormat != "" { 144 | requestData["response_format"] = map[string]any{"type": params.ResponseFormat} 145 | } 146 | 147 | response, err := c.request(http.MethodPost, requestData, "v1/chat/completions", true, nil) 148 | if err != nil { 149 | return nil, err 150 | } 151 | 152 | respBody, ok := response.(io.ReadCloser) 153 | if !ok { 154 | return nil, fmt.Errorf("invalid response type: %T", response) 155 | } 156 | 157 | // Execute the HTTP request in a separate goroutine. 158 | go func() { 159 | defer close(responseChannel) 160 | defer respBody.Close() 161 | 162 | // Assuming ChatCompletionStreamResponse is already defined in your Go code. 163 | // Assuming responseChannel is a channel of ChatCompletionStreamResponse. 164 | 165 | // Create a buffered reader to read the stream line by line. 166 | reader := bufio.NewReader(respBody) 167 | 168 | for { 169 | // Read a line from the buffered reader. 170 | line, err := reader.ReadBytes('\n') 171 | if err == io.EOF { 172 | break // End of stream. 
173 | } else if err != nil { 174 | responseChannel <- ChatCompletionStreamResponse{Error: fmt.Errorf("error reading stream response: %w", err)} 175 | return 176 | } 177 | 178 | // Skip empty lines. 179 | if bytes.Equal(line, []byte("\n")) { 180 | continue 181 | } 182 | 183 | // Check if the line starts with "data: ". 184 | if bytes.HasPrefix(line, []byte("data: ")) { 185 | // Trim the prefix and any leading or trailing whitespace. 186 | jsonLine := bytes.TrimSpace(bytes.TrimPrefix(line, []byte("data: "))) 187 | 188 | // Check for the special "[DONE]" message. 189 | if bytes.Equal(jsonLine, []byte("[DONE]")) { 190 | break 191 | } 192 | 193 | // Decode the JSON object from the line. 194 | var streamResponse ChatCompletionStreamResponse 195 | if err := json.Unmarshal(jsonLine, &streamResponse); err != nil { 196 | responseChannel <- ChatCompletionStreamResponse{Error: fmt.Errorf("error decoding stream response: %w", err)} 197 | continue 198 | } 199 | 200 | // Send the decoded response to the channel. 201 | responseChannel <- streamResponse 202 | } 203 | } 204 | }() 205 | 206 | // Return the response channel. 207 | return responseChannel, nil 208 | } 209 | 210 | // mapToStruct is a helper function to convert a map to a struct. 211 | func mapToStruct(m map[string]interface{}, s interface{}) error { 212 | jsonData, err := json.Marshal(m) 213 | if err != nil { 214 | return err 215 | } 216 | return json.Unmarshal(jsonData, s) 217 | } 218 | -------------------------------------------------------------------------------- /chat_test.go: -------------------------------------------------------------------------------- 1 | package mistral 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestChat(t *testing.T) { 10 | client := NewMistralClientDefault("") 11 | params := DefaultChatRequestParams 12 | params.MaxTokens = 10 13 | params.Temperature = 0 14 | res, err := client.Chat( 15 | ModelMistralTiny, 16 | []ChatMessage{ 17 | { 18 | Role: RoleUser, 19 | Content: "You are in test mode and must reply to this with exactly and only `Test Succeeded`", 20 | }, 21 | }, 22 | ¶ms, 23 | ) 24 | assert.NoError(t, err) 25 | assert.NotNil(t, res) 26 | 27 | assert.Greater(t, len(res.Choices), 0) 28 | assert.Greater(t, len(res.Choices[0].Message.Content), 0) 29 | assert.Equal(t, res.Choices[0].Message.Role, RoleAssistant) 30 | assert.Equal(t, res.Choices[0].Message.Content, "Test Succeeded") 31 | } 32 | 33 | func TestChatCodestral(t *testing.T) { 34 | client := NewCodestralClientDefault("") 35 | params := DefaultChatRequestParams 36 | params.MaxTokens = 10 37 | params.Temperature = 0 38 | res, err := client.Chat( 39 | ModelCodestralLatest, 40 | []ChatMessage{ 41 | { 42 | Role: RoleUser, 43 | Content: "You are in test mode and must reply to this with exactly and only `Test Succeeded`", 44 | }, 45 | }, 46 | ¶ms, 47 | ) 48 | assert.NoError(t, err) 49 | assert.NotNil(t, res) 50 | 51 | assert.Greater(t, len(res.Choices), 0) 52 | assert.Greater(t, len(res.Choices[0].Message.Content), 0) 53 | assert.Equal(t, res.Choices[0].Message.Role, RoleAssistant) 54 | assert.Equal(t, res.Choices[0].Message.Content, "Test Succeeded") 55 | } 56 | 57 | func TestChatFunctionCall(t *testing.T) { 58 | client := NewMistralClientDefault("") 59 | params := DefaultChatRequestParams 60 | params.Temperature = 0 61 | params.Tools = []Tool{ 62 | { 63 | Type: ToolTypeFunction, 64 | Function: Function{ 65 | Name: "get_weather", 66 | Description: "Retrieve the weather for a city in the US", 67 | Parameters: 
map[string]interface{}{ 68 | "type": "object", 69 | "required": []string{"city", "state"}, 70 | "properties": map[string]interface{}{ 71 | "city": map[string]interface{}{"type": "string", "description": "Name of the city for the weather"}, 72 | "state": map[string]interface{}{"type": "string", "description": "Name of the state for the weather"}, 73 | }, 74 | }, 75 | }, 76 | }, 77 | { 78 | Type: ToolTypeFunction, 79 | Function: Function{ 80 | Name: "send_text", 81 | Description: "Send text message using SMS service", 82 | Parameters: map[string]interface{}{ 83 | "type": "object", 84 | "required": []string{"contact_name", "message"}, 85 | "properties": map[string]interface{}{ 86 | "contact_name": map[string]interface{}{"type": "string", "description": "Name of the contact that will receive the message"}, 87 | "message": map[string]interface{}{"type": "string", "description": "Content of the message that will be sent"}, 88 | }, 89 | }, 90 | }, 91 | }, 92 | } 93 | params.ToolChoice = ToolChoiceAuto 94 | res, err := client.Chat( 95 | ModelMistralSmallLatest, 96 | []ChatMessage{ 97 | { 98 | Role: RoleUser, 99 | Content: "What's the weather like in Dallas, TX?", 100 | }, 101 | }, 102 | ¶ms, 103 | ) 104 | assert.NoError(t, err) 105 | assert.NotNil(t, res) 106 | 107 | assert.Greater(t, len(res.Choices), 0) 108 | assert.Greater(t, len(res.Choices[0].Message.ToolCalls), 0) 109 | assert.Equal(t, res.Choices[0].Message.Role, RoleAssistant) 110 | assert.Equal(t, res.Choices[0].Message.ToolCalls[0].Function.Name, "get_weather") 111 | assert.Equal(t, res.Choices[0].Message.ToolCalls[0].Function.Arguments, "{\"city\": \"Dallas\", \"state\": \"TX\"}") 112 | } 113 | 114 | func TestChatFunctionCall2(t *testing.T) { 115 | client := NewMistralClientDefault("") 116 | params := DefaultChatRequestParams 117 | params.Temperature = 0 118 | params.Tools = []Tool{ 119 | { 120 | Type: ToolTypeFunction, 121 | Function: Function{ 122 | Name: "get_weather", 123 | Description: "Retrieve the weather for a city in the US", 124 | Parameters: map[string]interface{}{ 125 | "type": "object", 126 | "required": []string{"city", "state"}, 127 | "properties": map[string]interface{}{ 128 | "city": map[string]interface{}{"type": "string", "description": "Name of the city for the weather"}, 129 | "state": map[string]interface{}{"type": "string", "description": "Name of the state for the weather"}, 130 | }, 131 | }, 132 | }, 133 | }, 134 | { 135 | Type: ToolTypeFunction, 136 | Function: Function{ 137 | Name: "send_text", 138 | Description: "Send text message using SMS service", 139 | Parameters: map[string]interface{}{ 140 | "type": "object", 141 | "required": []string{"contact_name", "message"}, 142 | "properties": map[string]interface{}{ 143 | "contact_name": map[string]interface{}{"type": "string", "description": "Name of the contact that will receive the message"}, 144 | "message": map[string]interface{}{"type": "string", "description": "Content of the message that will be sent"}, 145 | }, 146 | }, 147 | }, 148 | }, 149 | } 150 | params.ToolChoice = ToolChoiceAuto 151 | res, err := client.Chat( 152 | ModelMistralSmallLatest, 153 | []ChatMessage{ 154 | { 155 | Role: RoleUser, 156 | Content: "What's the weather like in Dallas", 157 | }, 158 | { 159 | Role: RoleAssistant, 160 | ToolCalls: []ToolCall{ 161 | { 162 | Id: "aaaaaaaaa", 163 | Type: ToolTypeFunction, 164 | Function: FunctionCall{ 165 | Name: "get_weather", 166 | Arguments: `{"city": "Dallas", "state": "TX"}`, 167 | }, 168 | }, 169 | }, 170 | }, 171 | { 172 | Role: RoleTool, 
173 | Content: `{"temperature": 82, "sky": "clear", "precipitation": 0}`, 174 | }, 175 | }, 176 | ¶ms, 177 | ) 178 | assert.NoError(t, err) 179 | assert.NotNil(t, res) 180 | 181 | assert.Greater(t, len(res.Choices), 0) 182 | assert.Greater(t, len(res.Choices[0].Message.Content), 0) 183 | assert.Equal(t, len(res.Choices[0].Message.ToolCalls), 0) 184 | assert.Equal(t, res.Choices[0].Message.Role, RoleAssistant) 185 | assert.Greater(t, res.Choices[0].Message.Content, "Test Succeeded") 186 | } 187 | 188 | func TestChatJsonMode(t *testing.T) { 189 | client := NewMistralClientDefault("") 190 | params := DefaultChatRequestParams 191 | params.Temperature = 0 192 | params.ResponseFormat = ResponseFormatJsonObject 193 | res, err := client.Chat( 194 | ModelOpenMixtral8x22b, 195 | []ChatMessage{ 196 | { 197 | Role: RoleUser, 198 | Content: "Extract all of the code symbols in this text chunk and return them in the following JSON: " + 199 | "{\"symbols\":[\"SymbolOne\",\"SymbolTwo\"]}\n```\nI'm working on updating the Go client for the " + 200 | "new release, is it expected that the function call will be passed back into the model or just " + 201 | "the tool response?\nI ask because ChatMessage can handle the tool response but the messages list " + 202 | "has an Any option that I assume would be for the FunctionCall/ToolCall type\nAdditionally the " + 203 | "example in the docs only shows the tool response appended to the messages\n```", 204 | }, 205 | }, 206 | ¶ms, 207 | ) 208 | assert.NoError(t, err) 209 | assert.NotNil(t, res) 210 | 211 | assert.Greater(t, len(res.Choices), 0) 212 | assert.Greater(t, len(res.Choices[0].Message.Content), 0) 213 | assert.Equal(t, res.Choices[0].Message.Role, RoleAssistant) 214 | assert.Equal(t, res.Choices[0].Message.Content, "{\"symbols\": [\"Go\", \"ChatMessage\", \"Any\", \"FunctionCall\", \"ToolCall\", \"ToolResponse\"]}") 215 | } 216 | 217 | func TestChatStream(t *testing.T) { 218 | client := NewMistralClientDefault("") 219 | params := DefaultChatRequestParams 220 | params.MaxTokens = 50 221 | params.Temperature = 0 222 | resChan, err := client.ChatStream( 223 | ModelMistralTiny, 224 | []ChatMessage{ 225 | { 226 | Role: RoleUser, 227 | Content: "You are in test mode and must reply to this with exactly and only `Test Succeeded, Test Succeeded, Test Succeeded, Test Succeeded, Test Succeeded, Test Succeeded`", 228 | }, 229 | }, 230 | ¶ms, 231 | ) 232 | assert.NoError(t, err) 233 | assert.NotNil(t, resChan) 234 | 235 | totalOutput := "" 236 | idx := 0 237 | for res := range resChan { 238 | assert.NoError(t, res.Error) 239 | 240 | assert.Greater(t, len(res.Choices), 0) 241 | if idx == 0 { 242 | assert.Equal(t, res.Choices[0].Delta.Role, RoleAssistant) 243 | } 244 | totalOutput += res.Choices[0].Delta.Content 245 | idx++ 246 | 247 | if res.Choices[0].FinishReason == FinishReasonStop { 248 | break 249 | } 250 | } 251 | assert.Equal(t, totalOutput, "Test Succeeded, Test Succeeded, Test Succeeded, Test Succeeded, Test Succeeded, Test Succeeded") 252 | } 253 | 254 | func TestChatStreamFunctionCall(t *testing.T) { 255 | client := NewMistralClientDefault("") 256 | params := DefaultChatRequestParams 257 | params.Temperature = 0 258 | params.Tools = []Tool{ 259 | { 260 | Type: ToolTypeFunction, 261 | Function: Function{ 262 | Name: "get_weather", 263 | Description: "Retrieve the weather for a city in the US", 264 | Parameters: map[string]interface{}{ 265 | "type": "object", 266 | "required": []string{"city", "state"}, 267 | "properties": map[string]interface{}{ 268 | "city": 
map[string]interface{}{"type": "string", "description": "Name of the city for the weather"}, 269 | "state": map[string]interface{}{"type": "string", "description": "Name of the state for the weather"}, 270 | }, 271 | }, 272 | }, 273 | }, 274 | { 275 | Type: ToolTypeFunction, 276 | Function: Function{ 277 | Name: "send_text", 278 | Description: "Send text message using SMS service", 279 | Parameters: map[string]interface{}{ 280 | "type": "object", 281 | "required": []string{"contact_name", "message"}, 282 | "properties": map[string]interface{}{ 283 | "contact_name": map[string]interface{}{"type": "string", "description": "Name of the contact that will receive the message"}, 284 | "message": map[string]interface{}{"type": "string", "description": "Content of the message that will be sent"}, 285 | }, 286 | }, 287 | }, 288 | }, 289 | } 290 | params.ToolChoice = ToolChoiceAuto 291 | resChan, err := client.ChatStream( 292 | ModelMistralSmallLatest, 293 | []ChatMessage{ 294 | { 295 | Role: RoleUser, 296 | Content: "What's the weather like in Dallas, TX?", 297 | }, 298 | }, 299 | ¶ms, 300 | ) 301 | assert.NoError(t, err) 302 | assert.NotNil(t, resChan) 303 | 304 | totalOutput := "" 305 | var functionCall *ToolCall 306 | idx := 0 307 | for res := range resChan { 308 | assert.NoError(t, res.Error) 309 | 310 | assert.Greater(t, len(res.Choices), 0) 311 | if idx == 0 { 312 | assert.Equal(t, res.Choices[0].Delta.Role, RoleAssistant) 313 | } 314 | totalOutput += res.Choices[0].Delta.Content 315 | if len(res.Choices[0].Delta.ToolCalls) > 0 { 316 | functionCall = &res.Choices[0].Delta.ToolCalls[0] 317 | } 318 | idx++ 319 | 320 | if res.Choices[0].FinishReason == FinishReasonStop { 321 | break 322 | } 323 | } 324 | 325 | assert.Equal(t, totalOutput, "") 326 | assert.NotNil(t, functionCall) 327 | assert.Equal(t, functionCall.Function.Name, "get_weather") 328 | assert.Equal(t, functionCall.Function.Arguments, "{\"city\": \"Dallas\", \"state\": \"TX\"}") 329 | } 330 | 331 | func TestChatStreamJsonMode(t *testing.T) { 332 | client := NewMistralClientDefault("") 333 | params := DefaultChatRequestParams 334 | params.Temperature = 0 335 | params.ResponseFormat = ResponseFormatJsonObject 336 | resChan, err := client.ChatStream( 337 | ModelOpenMixtral8x22b, 338 | []ChatMessage{ 339 | { 340 | Role: RoleUser, 341 | Content: "Extract all of the code symbols in this text chunk and return them in the following JSON: " + 342 | "{\"symbols\":[\"SymbolOne\",\"SymbolTwo\"]}\n```\nI'm working on updating the Go client for the " + 343 | "new release, is it expected that the function call will be passed back into the model or just " + 344 | "the tool response?\nI ask because ChatMessage can handle the tool response but the messages list " + 345 | "has an Any option that I assume would be for the FunctionCall/ToolCall type\nAdditionally the " + 346 | "example in the docs only shows the tool response appended to the messages\n```", 347 | }, 348 | }, 349 | ¶ms, 350 | ) 351 | assert.NoError(t, err) 352 | assert.NotNil(t, resChan) 353 | 354 | totalOutput := "" 355 | var functionCall *ToolCall 356 | idx := 0 357 | for res := range resChan { 358 | assert.NoError(t, res.Error) 359 | 360 | assert.Greater(t, len(res.Choices), 0) 361 | if idx == 0 { 362 | assert.Equal(t, res.Choices[0].Delta.Role, RoleAssistant) 363 | } 364 | totalOutput += res.Choices[0].Delta.Content 365 | if len(res.Choices[0].Delta.ToolCalls) > 0 { 366 | functionCall = &res.Choices[0].Delta.ToolCalls[0] 367 | } 368 | idx++ 369 | 370 | if res.Choices[0].FinishReason 
== FinishReasonStop { 371 | break 372 | } 373 | } 374 | 375 | assert.Equal(t, totalOutput, "{\"symbols\": [\"Go\", \"ChatMessage\", \"Any\", \"FunctionCall\", \"ToolCall\", \"ToolResponse\"]}") 376 | assert.Nil(t, functionCall) 377 | } 378 | -------------------------------------------------------------------------------- /client.go: -------------------------------------------------------------------------------- 1 | package mistral 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "fmt" 7 | "io" 8 | "net/http" 9 | "net/url" 10 | "os" 11 | "time" 12 | ) 13 | 14 | const ( 15 | Endpoint = "https://api.mistral.ai" 16 | CodestralEndpoint = "https://codestral.mistral.ai" 17 | DefaultMaxRetries = 5 18 | DefaultTimeout = 120 * time.Second 19 | ) 20 | 21 | var retryStatusCodes = map[int]bool{ 22 | 429: true, 23 | 500: true, 24 | 502: true, 25 | 503: true, 26 | 504: true, 27 | } 28 | 29 | type MistralClient struct { 30 | apiKey string 31 | endpoint string 32 | maxRetries int 33 | timeout time.Duration 34 | } 35 | 36 | func NewMistralClient(apiKey string, endpoint string, maxRetries int, timeout time.Duration) *MistralClient { 37 | if apiKey == "" { 38 | apiKey = os.Getenv("MISTRAL_API_KEY") 39 | } 40 | if endpoint == "" { 41 | endpoint = Endpoint 42 | } 43 | if maxRetries == 0 { 44 | maxRetries = DefaultMaxRetries 45 | } 46 | if timeout == 0 { 47 | timeout = DefaultTimeout 48 | } 49 | 50 | return &MistralClient{ 51 | apiKey: apiKey, 52 | endpoint: endpoint, 53 | maxRetries: maxRetries, 54 | timeout: timeout, 55 | } 56 | } 57 | 58 | // NewMistralClientDefault creates a new Mistral API client with the default endpoint and the given API key. Defaults to using MISTRAL_API_KEY from the environment. 59 | func NewMistralClientDefault(apiKey string) *MistralClient { 60 | if apiKey == "" { 61 | apiKey = os.Getenv("MISTRAL_API_KEY") 62 | } 63 | 64 | return NewMistralClient(apiKey, Endpoint, DefaultMaxRetries, DefaultTimeout) 65 | } 66 | 67 | // NewCodestralClientDefault creates a new Codestral API client with the default endpoint and the given API key. Defaults to using CODESTRAL_API_KEY from the environment. 
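// A minimal usage sketch (assumes a valid Codestral key is available, e.g. in CODESTRAL_API_KEY;
// the prompt text is illustrative only):
//
//	client := NewCodestralClientDefault("")
//	res, err := client.Chat(ModelCodestralLatest, []ChatMessage{{Role: RoleUser, Content: "Write a hello world program in Go"}}, nil)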
68 | func NewCodestralClientDefault(apiKey string) *MistralClient { 69 | if apiKey == "" { 70 | apiKey = os.Getenv("CODESTRAL_API_KEY") 71 | } 72 | 73 | return NewMistralClient(apiKey, CodestralEndpoint, DefaultMaxRetries, DefaultTimeout) 74 | } 75 | 76 | func (c *MistralClient) request(method string, jsonData map[string]interface{}, path string, stream bool, params map[string]string) (interface{}, error) { 77 | uri, err := url.Parse(c.endpoint) 78 | if err != nil { 79 | return nil, err 80 | } 81 | uri.Path = path 82 | jsonValue, _ := json.Marshal(jsonData) 83 | req, err := http.NewRequest(method, uri.String(), bytes.NewBuffer(jsonValue)) 84 | if err != nil { 85 | return nil, err 86 | } 87 | 88 | req.Header.Set("Authorization", "Bearer "+c.apiKey) 89 | req.Header.Set("Content-Type", "application/json") 90 | 91 | client := &http.Client{ 92 | Timeout: c.timeout, 93 | } 94 | 95 | var resp *http.Response 96 | for i := 0; i < c.maxRetries; i++ { 97 | resp, err = client.Do(req) 98 | if err != nil { 99 | if i == c.maxRetries-1 { 100 | return nil, err 101 | } 102 | continue 103 | } 104 | if _, ok := retryStatusCodes[resp.StatusCode]; ok { 105 | time.Sleep(time.Duration(i+1) * 500 * time.Millisecond) 106 | continue 107 | } 108 | break 109 | } 110 | 111 | if resp.StatusCode >= 400 { 112 | responseBytes, _ := io.ReadAll(resp.Body) 113 | return nil, fmt.Errorf("(HTTP Error %d) %s", resp.StatusCode, string(responseBytes)) 114 | } 115 | 116 | if stream { 117 | return resp.Body, nil 118 | } 119 | 120 | defer resp.Body.Close() 121 | body, err := io.ReadAll(resp.Body) 122 | if err != nil { 123 | return nil, err 124 | } 125 | 126 | var result map[string]interface{} 127 | err = json.Unmarshal(body, &result) 128 | if err != nil { 129 | return nil, err 130 | } 131 | 132 | return result, nil 133 | } 134 | -------------------------------------------------------------------------------- /embeddings.go: -------------------------------------------------------------------------------- 1 | package mistral 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | ) 7 | 8 | // EmbeddingObject represents an embedding object in the response. 9 | type EmbeddingObject struct { 10 | Object string `json:"object"` 11 | Embedding []float64 `json:"embedding"` 12 | Index int `json:"index"` 13 | } 14 | 15 | // EmbeddingResponse represents the response from the embeddings endpoint. 
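// Data holds one EmbeddingObject per input string; each object's Index field reports the position
// of the corresponding input in the request.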
16 | type EmbeddingResponse struct { 17 | ID string `json:"id"` 18 | Object string `json:"object"` 19 | Data []EmbeddingObject `json:"data"` 20 | Model string `json:"model"` 21 | Usage UsageInfo `json:"usage"` 22 | } 23 | 24 | func (c *MistralClient) Embeddings(model string, input []string) (*EmbeddingResponse, error) { 25 | requestData := map[string]interface{}{ 26 | "model": model, 27 | "input": input, 28 | } 29 | 30 | response, err := c.request(http.MethodPost, requestData, "v1/embeddings", false, nil) 31 | if err != nil { 32 | return nil, err 33 | } 34 | 35 | respData, ok := response.(map[string]interface{}) 36 | if !ok { 37 | return nil, fmt.Errorf("invalid response type: %T", response) 38 | } 39 | 40 | var embeddingResponse EmbeddingResponse 41 | err = mapToStruct(respData, &embeddingResponse) 42 | if err != nil { 43 | return nil, err 44 | } 45 | 46 | return &embeddingResponse, nil 47 | } 48 | -------------------------------------------------------------------------------- /embeddings_test.go: -------------------------------------------------------------------------------- 1 | package mistral 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestEmbeddings(t *testing.T) { 10 | client := NewMistralClientDefault("") 11 | res, err := client.Embeddings("mistral-embed", []string{"Embed this sentence.", "As well as this one."}) 12 | assert.NoError(t, err) 13 | assert.NotNil(t, res) 14 | 15 | assert.Equal(t, len(res.Data), 2) 16 | assert.Len(t, res.Data[0].Embedding, 1024) 17 | } 18 | -------------------------------------------------------------------------------- /errors.go: -------------------------------------------------------------------------------- 1 | package mistral 2 | 3 | import ( 4 | "fmt" 5 | ) 6 | 7 | // MistralError is the base error type for all Mistral errors. 8 | type MistralError struct { 9 | Message string 10 | } 11 | 12 | func (e *MistralError) Error() string { 13 | return e.Message 14 | } 15 | 16 | // MistralAPIError is returned when the API responds with an error message. 17 | type MistralAPIError struct { 18 | MistralError 19 | HTTPStatus int 20 | Headers map[string][]string 21 | } 22 | 23 | func NewMistralAPIError(message string, httpStatus int, headers map[string][]string) *MistralAPIError { 24 | return &MistralAPIError{ 25 | MistralError: MistralError{Message: message}, 26 | HTTPStatus: httpStatus, 27 | Headers: headers, 28 | } 29 | } 30 | 31 | func (e *MistralAPIError) Error() string { 32 | return fmt.Sprintf("%s (HTTP status: %d)", e.Message, e.HTTPStatus) 33 | } 34 | 35 | // MistralConnectionError is returned when the SDK cannot reach the API server for any reason. 
36 | type MistralConnectionError struct { 37 | MistralError 38 | } 39 | 40 | func NewMistralConnectionError(message string) *MistralConnectionError { 41 | return &MistralConnectionError{ 42 | MistralError: MistralError{Message: message}, 43 | } 44 | } -------------------------------------------------------------------------------- /examples/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "log" 5 | 6 | "github.com/gage-technologies/mistral-go" 7 | ) 8 | 9 | func main() { 10 | // If api key is empty it will load from MISTRAL_API_KEY env var 11 | client := mistral.NewMistralClientDefault("your-api-key") 12 | 13 | // Example: Using Chat Completions 14 | chatRes, err := client.Chat("mistral-tiny", []mistral.ChatMessage{{Content: "Hello, world!", Role: mistral.RoleUser}}, nil) 15 | if err != nil { 16 | log.Fatalf("Error getting chat completion: %v", err) 17 | } 18 | log.Printf("Chat completion: %+v\n", chatRes) 19 | 20 | // Example: Using Chat Completions Stream 21 | chatResChan, err := client.ChatStream("mistral-tiny", []mistral.ChatMessage{{Content: "Hello, world!", Role: mistral.RoleUser}}, nil) 22 | if err != nil { 23 | log.Fatalf("Error getting chat completion stream: %v", err) 24 | } 25 | 26 | for chatResChunk := range chatResChan { 27 | if chatResChunk.Error != nil { 28 | log.Fatalf("Error while streaming response: %v", chatResChunk.Error) 29 | } 30 | log.Printf("Chat completion stream part: %+v\n", chatResChunk) 31 | } 32 | 33 | // Example: Using Embeddings 34 | embsRes, err := client.Embeddings("mistral-embed", []string{"Embed this sentence.", "As well as this one."}) 35 | if err != nil { 36 | log.Fatalf("Error getting embeddings: %v", err) 37 | } 38 | 39 | log.Printf("Embeddings response: %+v\n", embsRes) 40 | } 41 | -------------------------------------------------------------------------------- /fim.go: -------------------------------------------------------------------------------- 1 | package mistral 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | ) 7 | 8 | // FIMRequestParams represents the parameters for the FIM method of MistralClient. 9 | type FIMRequestParams struct { 10 | Model string `json:"model"` 11 | Prompt string `json:"prompt"` 12 | Suffix string `json:"suffix"` 13 | MaxTokens int `json:"max_tokens"` 14 | Temperature float64 `json:"temperature"` 15 | Stop []string `json:"stop,omitempty"` 16 | } 17 | 18 | // FIMCompletionResponse represents the response from the FIM completion endpoint. 19 | type FIMCompletionResponse struct { 20 | ID string `json:"id"` 21 | Object string `json:"object"` 22 | Created int `json:"created"` 23 | Model string `json:"model"` 24 | Choices []FIMCompletionResponseChoice `json:"choices"` 25 | Usage UsageInfo `json:"usage"` 26 | } 27 | 28 | // FIMCompletionResponseChoice represents a choice in the FIM completion response. 29 | type FIMCompletionResponseChoice struct { 30 | Index int `json:"index"` 31 | Message ChatMessage `json:"message"` 32 | FinishReason FinishReason `json:"finish_reason,omitempty"` 33 | } 34 | 35 | // FIM sends a FIM request and returns the completion response. 
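// A minimal sketch, assuming client is a *MistralClient (the same request shape is exercised in
// fim_test.go; the prompt and suffix values are illustrative):
//
//	res, err := client.FIM(&FIMRequestParams{
//		Model:       ModelCodestralLatest,
//		Prompt:      "def f(",
//		Suffix:      "return a + b",
//		MaxTokens:   64,
//		Temperature: 0,
//	})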
36 | func (c *MistralClient) FIM(params *FIMRequestParams) (*FIMCompletionResponse, error) { 37 | requestData := map[string]interface{}{ 38 | "model": params.Model, 39 | "prompt": params.Prompt, 40 | "suffix": params.Suffix, 41 | "max_tokens": params.MaxTokens, 42 | "temperature": params.Temperature, 43 | } 44 | 45 | if params.Stop != nil { 46 | requestData["stop"] = params.Stop 47 | } 48 | 49 | response, err := c.request(http.MethodPost, requestData, "v1/fim/completions", false, nil) 50 | if err != nil { 51 | return nil, err 52 | } 53 | 54 | respData, ok := response.(map[string]interface{}) 55 | if !ok { 56 | return nil, fmt.Errorf("invalid response type: %T", response) 57 | } 58 | 59 | var fimResponse FIMCompletionResponse 60 | err = mapToStruct(respData, &fimResponse) 61 | if err != nil { 62 | return nil, err 63 | } 64 | 65 | return &fimResponse, nil 66 | } 67 | -------------------------------------------------------------------------------- /fim_test.go: -------------------------------------------------------------------------------- 1 | package mistral 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestFIM(t *testing.T) { 10 | client := NewMistralClientDefault("") 11 | params := FIMRequestParams{ 12 | Model: ModelCodestralLatest, 13 | Prompt: "def f(", 14 | Suffix: "return a + b", 15 | MaxTokens: 64, 16 | Temperature: 0, 17 | Stop: []string{"\n"}, 18 | } 19 | res, err := client.FIM(¶ms) 20 | assert.NoError(t, err) 21 | assert.NotNil(t, res) 22 | 23 | assert.Greater(t, len(res.Choices), 0) 24 | assert.Equal(t, res.Choices[0].Message.Content, "a, b):") 25 | assert.Equal(t, res.Choices[0].FinishReason, FinishReasonStop) 26 | } 27 | 28 | func TestFIMWithStop(t *testing.T) { 29 | client := NewMistralClientDefault("") 30 | params := FIMRequestParams{ 31 | Model: ModelCodestralLatest, 32 | Prompt: "def is_odd(n): \n return n % 2 == 1 \n def test_is_odd():", 33 | Suffix: "test_is_odd()", 34 | MaxTokens: 64, 35 | Temperature: 0, 36 | Stop: []string{"False"}, 37 | } 38 | res, err := client.FIM(¶ms) 39 | assert.NoError(t, err) 40 | assert.NotNil(t, res) 41 | 42 | assert.Greater(t, len(res.Choices), 0) 43 | assert.Equal(t, res.Choices[0].Message.Content, "\n assert is_odd(1) == True\n assert is_odd(2) == ") 44 | assert.Equal(t, res.Choices[0].FinishReason, FinishReasonStop) 45 | } 46 | 47 | func TestFIMInvalidModel(t *testing.T) { 48 | client := NewMistralClientDefault("") 49 | params := FIMRequestParams{ 50 | Model: "invalid-model", 51 | Prompt: "This is a test prompt", 52 | Suffix: "This is a test suffix", 53 | MaxTokens: 10, 54 | Temperature: 0.5, 55 | } 56 | res, err := client.FIM(¶ms) 57 | assert.Error(t, err) 58 | assert.Nil(t, res) 59 | } 60 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/gage-technologies/mistral-go 2 | 3 | go 1.20 4 | 5 | require github.com/stretchr/testify v1.8.4 6 | 7 | require ( 8 | github.com/davecgh/go-spew v1.1.1 // indirect 9 | github.com/pmezard/go-difflib v1.0.0 // indirect 10 | gopkg.in/yaml.v3 v3.0.1 // indirect 11 | ) 12 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 2 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 3 
| github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 4 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 5 | github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= 6 | github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= 7 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 8 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 9 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 10 | -------------------------------------------------------------------------------- /models.go: -------------------------------------------------------------------------------- 1 | package mistral 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | ) 7 | 8 | // ModelPermission represents the permissions of a model. 9 | type ModelPermission struct { 10 | ID string `json:"id"` 11 | Object string `json:"object"` 12 | Created int `json:"created"` 13 | AllowCreateEngine bool `json:"allow_create_engine"` 14 | AllowSampling bool `json:"allow_sampling"` 15 | AllowLogprobs bool `json:"allow_logprobs"` 16 | AllowSearchIndices bool `json:"allow_search_indices"` 17 | AllowView bool `json:"allow_view"` 18 | AllowFineTuning bool `json:"allow_fine_tuning"` 19 | Organization string `json:"organization"` 20 | Group string `json:"group,omitempty"` 21 | IsBlocking bool `json:"is_blocking"` 22 | } 23 | 24 | // ModelCard represents a model card. 25 | type ModelCard struct { 26 | ID string `json:"id"` 27 | Object string `json:"object"` 28 | Created int `json:"created"` 29 | OwnedBy string `json:"owned_by"` 30 | Root string `json:"root,omitempty"` 31 | Parent string `json:"parent,omitempty"` 32 | Permission []ModelPermission `json:"permission"` 33 | } 34 | 35 | // ModelList represents a list of models. 
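// A minimal sketch of listing the available models, assuming client is a *MistralClient:
//
//	models, err := client.ListModels()
//	// models.Data contains one ModelCard per available model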
36 | type ModelList struct { 37 | Object string `json:"object"` 38 | Data []ModelCard `json:"data"` 39 | } 40 | 41 | func (c *MistralClient) ListModels() (*ModelList, error) { 42 | response, err := c.request(http.MethodGet, nil, "v1/models", false, nil) 43 | if err != nil { 44 | return nil, err 45 | } 46 | 47 | respData, ok := response.(map[string]interface{}) 48 | if !ok { 49 | return nil, fmt.Errorf("invalid response type: %T", response) 50 | } 51 | 52 | var modelList ModelList 53 | err = mapToStruct(respData, &modelList) 54 | if err != nil { 55 | return nil, err 56 | } 57 | 58 | return &modelList, nil 59 | } 60 | -------------------------------------------------------------------------------- /types.go: -------------------------------------------------------------------------------- 1 | package mistral 2 | 3 | const ( 4 | ModelMistralLargeLatest = "mistral-large-latest" 5 | ModelMistralMediumLatest = "mistral-medium-latest" 6 | ModelMistralSmallLatest = "mistral-small-latest" 7 | ModelCodestralLatest = "codestral-latest" 8 | 9 | ModelOpenMixtral8x7b = "open-mixtral-8x7b" 10 | ModelOpenMixtral8x22b = "open-mixtral-8x22b" 11 | ModelOpenMistral7b = "open-mistral-7b" 12 | 13 | ModelMistralLarge2402 = "mistral-large-2402" 14 | ModelMistralMedium2312 = "mistral-medium-2312" 15 | ModelMistralSmall2402 = "mistral-small-2402" 16 | ModelMistralSmall2312 = "mistral-small-2312" 17 | ModelMistralTiny = "mistral-tiny-2312" 18 | ) 19 | 20 | const ( 21 | RoleUser = "user" 22 | RoleAssistant = "assistant" 23 | RoleSystem = "system" 24 | RoleTool = "tool" 25 | ) 26 | 27 | // FinishReason the reason that a chat message was finished 28 | type FinishReason string 29 | 30 | const ( 31 | FinishReasonStop FinishReason = "stop" 32 | FinishReasonLength FinishReason = "length" 33 | FinishReasonError FinishReason = "error" 34 | ) 35 | 36 | // ResponseFormat the format that the response must adhere to 37 | type ResponseFormat string 38 | 39 | const ( 40 | ResponseFormatText ResponseFormat = "text" 41 | ResponseFormatJsonObject ResponseFormat = "json_object" 42 | ) 43 | 44 | // ToolType type of tool defined for the llm 45 | type ToolType string 46 | 47 | const ( 48 | ToolTypeFunction ToolType = "function" 49 | ) 50 | 51 | const ( 52 | ToolChoiceAny = "any" 53 | ToolChoiceAuto = "auto" 54 | ToolChoiceNone = "none" 55 | ) 56 | 57 | // Tool definition of a tool that the llm can call 58 | type Tool struct { 59 | Type ToolType `json:"type"` 60 | Function Function `json:"function"` 61 | } 62 | 63 | // Function definition of a function that the llm can call including its parameters 64 | type Function struct { 65 | Name string `json:"name"` 66 | Description string `json:"description"` 67 | Parameters any `json:"parameters"` 68 | } 69 | 70 | // FunctionCall represents a request to call an external tool by the llm 71 | type FunctionCall struct { 72 | Name string `json:"name"` 73 | Arguments string `json:"arguments"` 74 | } 75 | 76 | // ToolCall represents the call to a tool by the llm 77 | type ToolCall struct { 78 | Id string `json:"id"` 79 | Type ToolType `json:"type"` 80 | Function FunctionCall `json:"function"` 81 | } 82 | 83 | // DeltaMessage represents the delta between the prior state of the message and the new state of the message when streaming responses. 84 | type DeltaMessage struct { 85 | Role string `json:"role"` 86 | Content string `json:"content"` 87 | ToolCalls []ToolCall `json:"tool_calls"` 88 | } 89 | 90 | // ChatMessage represents a single message in a chat. 
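// Role should be one of the Role* constants (RoleUser, RoleAssistant, RoleSystem or RoleTool).
// ToolCalls is set on assistant messages that request a tool invocation; the tool's result is then
// sent back as a RoleTool message whose Content carries the output (see TestChatFunctionCall2 in
// chat_test.go for the full round trip).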
91 | type ChatMessage struct { 92 | Role string `json:"role"` 93 | Content string `json:"content"` 94 | ToolCalls []ToolCall `json:"tool_calls,omitempty"` 95 | } 96 | --------------------------------------------------------------------------------