├── go.mod ├── .github ├── FUNDING.yml ├── workflows │ ├── integration-tests.yml │ ├── pr.yml │ └── close-inactive-issues.yml ├── ISSUE_TEMPLATE │ ├── feature_request.md │ └── bug_report.md └── PULL_REQUEST_TEMPLATE.md ├── examples ├── README.md ├── completion │ └── main.go ├── images │ └── main.go ├── voice-to-text │ └── main.go ├── chatbot │ └── main.go └── completion-with-tool │ └── main.go ├── internal ├── marshaller.go ├── unmarshaler.go ├── test │ ├── failer.go │ ├── checks │ │ └── checks.go │ ├── server.go │ └── helpers.go ├── error_accumulator.go ├── error_accumulator_test.go ├── request_builder.go ├── form_builder_test.go ├── form_builder.go └── request_builder_test.go ├── .gitignore ├── common.go ├── openai_test.go ├── engines.go ├── stream.go ├── config_test.go ├── ratelimit.go ├── edits.go ├── engines_test.go ├── speech.go ├── jsonschema ├── validate.go ├── json.go ├── json_test.go └── validate_test.go ├── config.go ├── fine_tunes_test.go ├── stream_reader_test.go ├── speech_test.go ├── stream_reader.go ├── models.go ├── edits_test.go ├── error.go ├── models_test.go ├── fine_tuning_job_test.go ├── audio_test.go ├── files_test.go ├── chat_stream.go ├── moderation.go ├── CONTRIBUTING.md ├── files.go ├── audio_api_test.go ├── fine_tuning_job.go ├── thread_test.go ├── image_test.go ├── api_internal_test.go ├── thread.go ├── moderation_test.go ├── image.go ├── completion_test.go ├── files_api_test.go ├── messages.go ├── image_api_test.go ├── run_test.go ├── audio.go ├── fine_tunes.go ├── batch.go └── messages_test.go /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/sashabaranov/go-openai 2 | 3 | go 1.18 4 | -------------------------------------------------------------------------------- /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: [sashabaranov, vvatanabe] 4 | 
-------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | To run an example, pick a directory under `examples/` (e.g. `completion`, `chatbot`): 2 | 3 | ``` 4 | export OPENAI_API_KEY="your-api-key" 5 | go run ./examples/completion 6 | ``` 7 | -------------------------------------------------------------------------------- /internal/marshaller.go: -------------------------------------------------------------------------------- 1 | package openai 2 | 3 | import ( 4 | "encoding/json" 5 | ) 6 | 7 | type Marshaller interface { 8 | Marshal(value any) ([]byte, error) 9 | } 10 | 11 | type JSONMarshaller struct{} 12 | 13 | func (jm *JSONMarshaller) Marshal(value any) ([]byte, error) { 14 | return json.Marshal(value) 15 | } 16 | -------------------------------------------------------------------------------- /internal/unmarshaler.go: -------------------------------------------------------------------------------- 1 | package openai 2 | 3 | import ( 4 | "encoding/json" 5 | ) 6 | 7 | type Unmarshaler interface { 8 | Unmarshal(data []byte, v any) error 9 | } 10 | 11 | type JSONUnmarshaler struct{} 12 | 13 | func (jm *JSONUnmarshaler) Unmarshal(data []byte, v any) error { 14 | return json.Unmarshal(data, v) 15 | } 16 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Binaries for programs and plugins 2 | *.exe 3 | *.exe~ 4 | *.dll 5 | *.so 6 | *.dylib 7 | 8 | # Test binary, built with `go test -c` 9 | *.test 10 | 11 | # Output of the go coverage tool, specifically when used with LiteIDE 12 | *.out 13 | 14 | # Dependency directories (remove the comment below to include it) 15 | # vendor/ 16 | 17 | # Auth token for tests 18 | .openai-token 19 | .idea 20 | 21 | # Generated by tests 22 | test.mp3 -------------------------------------------------------------------------------- /internal/test/failer.go:
-------------------------------------------------------------------------------- 1 | package test 2 | 3 | import "errors" 4 | 5 | var ( 6 | ErrTestErrorAccumulatorWriteFailed = errors.New("test error accumulator failed") 7 | ) 8 | 9 | type FailingErrorBuffer struct{} 10 | 11 | func (b *FailingErrorBuffer) Write(_ []byte) (n int, err error) { 12 | return 0, ErrTestErrorAccumulatorWriteFailed 13 | } 14 | 15 | func (b *FailingErrorBuffer) Len() int { 16 | return 0 17 | } 18 | 19 | func (b *FailingErrorBuffer) Bytes() []byte { 20 | return []byte{} 21 | } 22 | -------------------------------------------------------------------------------- /.github/workflows/integration-tests.yml: -------------------------------------------------------------------------------- 1 | name: Integration tests 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | 8 | jobs: 9 | integration_tests: 10 | name: Run integration tests 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v4 14 | - name: Setup Go 15 | uses: actions/setup-go@v5 16 | with: 17 | go-version: '1.21' 18 | - name: Run integration tests 19 | env: 20 | OPENAI_TOKEN: ${{ secrets.OPENAI_TOKEN }} 21 | run: go test -v -tags=integration ./api_integration_test.go 22 | -------------------------------------------------------------------------------- /examples/completion/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "os" 7 | 8 | "github.com/sashabaranov/go-openai" 9 | ) 10 | 11 | func main() { 12 | client := openai.NewClient(os.Getenv("OPENAI_API_KEY")) 13 | resp, err := client.CreateCompletion( 14 | context.Background(), 15 | openai.CompletionRequest{ 16 | Model: openai.GPT3Babbage002, 17 | MaxTokens: 5, 18 | Prompt: "Lorem ipsum", 19 | }, 20 | ) 21 | if err != nil { 22 | fmt.Printf("Completion error: %v\n", err) 23 | return 24 | } 25 | fmt.Println(resp.Choices[0].Text) 26 | } 27 | 
-------------------------------------------------------------------------------- /.github/workflows/pr.yml: -------------------------------------------------------------------------------- 1 | name: Sanity check 2 | 3 | on: 4 | - push 5 | - pull_request 6 | 7 | jobs: 8 | prcheck: 9 | name: Sanity check 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v4 13 | - name: Setup Go 14 | uses: actions/setup-go@v5 15 | with: 16 | go-version: '1.21' 17 | - name: Run vet 18 | run: | 19 | go vet . 20 | - name: Run golangci-lint 21 | uses: golangci/golangci-lint-action@v4 22 | with: 23 | version: latest 24 | - name: Run tests 25 | run: go test -race -covermode=atomic -coverprofile=coverage.out -v . 26 | - name: Upload coverage reports to Codecov 27 | uses: codecov/codecov-action@v4 28 | -------------------------------------------------------------------------------- /examples/images/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "os" 7 | 8 | "github.com/sashabaranov/go-openai" 9 | ) 10 | 11 | func main() { 12 | client := openai.NewClient(os.Getenv("OPENAI_API_KEY")) 13 | 14 | respUrl, err := client.CreateImage( 15 | context.Background(), 16 | openai.ImageRequest{ 17 | Prompt: "Parrot on a skateboard performs a trick, cartoon style, natural light, high detail", 18 | Size: openai.CreateImageSize256x256, 19 | ResponseFormat: openai.CreateImageResponseFormatURL, 20 | N: 1, 21 | }, 22 | ) 23 | if err != nil { 24 | fmt.Printf("Image creation error: %v\n", err) 25 | return 26 | } 27 | fmt.Println(respUrl.Data[0].URL) 28 | } 29 | -------------------------------------------------------------------------------- /examples/voice-to-text/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "os" 8 | 9 | "github.com/sashabaranov/go-openai" 10 | ) 11 | 12 | 
func main() { 13 | if len(os.Args) < 2 { 14 | fmt.Println("please provide a filename to convert to text") 15 | return 16 | } 17 | if _, err := os.Stat(os.Args[1]); errors.Is(err, os.ErrNotExist) { 18 | fmt.Printf("file %s does not exist\n", os.Args[1]) 19 | return 20 | } 21 | 22 | client := openai.NewClient(os.Getenv("OPENAI_API_KEY")) 23 | resp, err := client.CreateTranscription( 24 | context.Background(), 25 | openai.AudioRequest{ 26 | Model: openai.Whisper1, 27 | FilePath: os.Args[1], 28 | }, 29 | ) 30 | if err != nil { 31 | fmt.Printf("Transcription error: %v\n", err) 32 | return 33 | } 34 | fmt.Println(resp.Text) 35 | } 36 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: enhancement 6 | assignees: '' 7 | 8 | --- 9 | 10 | Your issue may already be reported! 11 | Please search on the [issue tracker](https://github.com/sashabaranov/go-openai/issues) before creating one. 12 | 13 | **Is your feature request related to a problem? Please describe.** 14 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 15 | 16 | **Describe the solution you'd like** 17 | A clear and concise description of what you want to happen. 18 | 19 | **Describe alternatives you've considered** 20 | A clear and concise description of any alternative solutions or features you've considered. 21 | 22 | **Additional context** 23 | Add any other context or screenshots about the feature request here. 
24 | -------------------------------------------------------------------------------- /.github/workflows/close-inactive-issues.yml: -------------------------------------------------------------------------------- 1 | name: Close inactive issues 2 | on: 3 | schedule: 4 | - cron: "30 1 * * *" 5 | 6 | jobs: 7 | close-issues: 8 | runs-on: ubuntu-latest 9 | permissions: 10 | issues: write 11 | pull-requests: write 12 | steps: 13 | - uses: actions/stale@v9 14 | with: 15 | days-before-issue-stale: 30 16 | days-before-issue-close: 14 17 | stale-issue-label: "stale" 18 | exempt-issue-labels: 'bug,enhancement' 19 | stale-issue-message: "This issue is stale because it has been open for 30 days with no activity." 20 | close-issue-message: "This issue was closed because it has been inactive for 14 days since being marked as stale." 21 | days-before-pr-stale: -1 22 | days-before-pr-close: -1 23 | repo-token: ${{ secrets.GITHUB_TOKEN }} -------------------------------------------------------------------------------- /internal/error_accumulator.go: -------------------------------------------------------------------------------- 1 | package openai 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "io" 7 | ) 8 | 9 | type ErrorAccumulator interface { 10 | Write(p []byte) error 11 | Bytes() []byte 12 | } 13 | 14 | type errorBuffer interface { 15 | io.Writer 16 | Len() int 17 | Bytes() []byte 18 | } 19 | 20 | type DefaultErrorAccumulator struct { 21 | Buffer errorBuffer 22 | } 23 | 24 | func NewErrorAccumulator() ErrorAccumulator { 25 | return &DefaultErrorAccumulator{ 26 | Buffer: &bytes.Buffer{}, 27 | } 28 | } 29 | 30 | func (e *DefaultErrorAccumulator) Write(p []byte) error { 31 | _, err := e.Buffer.Write(p) 32 | if err != nil { 33 | return fmt.Errorf("error accumulator write error, %w", err) 34 | } 35 | return nil 36 | } 37 | 38 | func (e *DefaultErrorAccumulator) Bytes() (errBytes []byte) { 39 | if e.Buffer.Len() == 0 { 40 | return 41 | } 42 | errBytes = e.Buffer.Bytes() 43 | return 44 | 
} 45 | -------------------------------------------------------------------------------- /common.go: -------------------------------------------------------------------------------- 1 | package openai 2 | 3 | // common.go defines common types used throughout the OpenAI API. 4 | 5 | // Usage Represents the total token usage per request to OpenAI. 6 | type Usage struct { 7 | PromptTokens int `json:"prompt_tokens"` 8 | CompletionTokens int `json:"completion_tokens"` 9 | TotalTokens int `json:"total_tokens"` 10 | PromptTokensDetails *PromptTokensDetails `json:"prompt_tokens_details"` 11 | CompletionTokensDetails *CompletionTokensDetails `json:"completion_tokens_details"` 12 | } 13 | 14 | // CompletionTokensDetails Breakdown of tokens used in a completion. 15 | type CompletionTokensDetails struct { 16 | AudioTokens int `json:"audio_tokens"` 17 | ReasoningTokens int `json:"reasoning_tokens"` 18 | } 19 | 20 | // PromptTokensDetails Breakdown of tokens used in the prompt. 21 | type PromptTokensDetails struct { 22 | AudioTokens int `json:"audio_tokens"` 23 | CachedTokens int `json:"cached_tokens"` 24 | } 25 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: bug 6 | assignees: '' 7 | 8 | --- 9 | 10 | Your issue may already be reported! 11 | Please search on the [issue tracker](https://github.com/sashabaranov/go-openai/issues) before creating one. 12 | 13 | **Describe the bug** 14 | A clear and concise description of what the bug is. If it's an API-related bug, please provide relevant endpoint(s). 15 | 16 | **To Reproduce** 17 | Steps to reproduce the behavior, including any relevant code snippets. 18 | 19 | **Expected behavior** 20 | A clear and concise description of what you expected to happen. 
21 | 22 | **Screenshots/Logs** 23 | If applicable, add screenshots to help explain your problem. For non-graphical issues, please provide any relevant logs or stack traces. 24 | 25 | **Environment (please complete the following information):** 26 | - go-openai version: [e.g. v1.12.0] 27 | - Go version: [e.g. 1.18] 28 | - OpenAI API version: [e.g. v1] 29 | - OS: [e.g. Ubuntu 20.04] 30 | 31 | **Additional context** 32 | Add any other context about the problem here. 33 | -------------------------------------------------------------------------------- /examples/chatbot/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bufio" 5 | "context" 6 | "fmt" 7 | "os" 8 | 9 | "github.com/sashabaranov/go-openai" 10 | ) 11 | 12 | // main runs an interactive chat: each stdin line is sent as a user message 13 | // and the assistant's reply is appended to the conversation history. 14 | func main() { 15 | client := openai.NewClient(os.Getenv("OPENAI_API_KEY")) 16 | 17 | req := openai.ChatCompletionRequest{ 18 | Model: openai.GPT3Dot5Turbo, 19 | Messages: []openai.ChatCompletionMessage{ 20 | { 21 | Role: openai.ChatMessageRoleSystem, 22 | Content: "you are a helpful chatbot", 23 | }, 24 | }, 25 | } 26 | fmt.Println("Conversation") 27 | fmt.Println("---------------------") 28 | fmt.Print("> ") 29 | s := bufio.NewScanner(os.Stdin) 30 | for s.Scan() { 31 | req.Messages = append(req.Messages, openai.ChatCompletionMessage{ 32 | Role: openai.ChatMessageRoleUser, 33 | Content: s.Text(), 34 | }) 35 | resp, err := client.CreateChatCompletion(context.Background(), req) 36 | switch { 37 | case err != nil: 38 | fmt.Printf("ChatCompletion error: %v\n", err) 39 | // Drop the unanswered message so it is not resent with the next turn. 40 | req.Messages = req.Messages[:len(req.Messages)-1] 41 | case len(resp.Choices) == 0: 42 | fmt.Println("ChatCompletion error: response contained no choices") 43 | req.Messages = req.Messages[:len(req.Messages)-1] 44 | default: 45 | fmt.Printf("%s\n\n", resp.Choices[0].Message.Content) 46 | req.Messages = append(req.Messages, resp.Choices[0].Message) 47 | } 48 | // Re-prompt after every turn, including failed ones. 49 | fmt.Print("> ") 50 | } 51 | if err := s.Err(); err != nil { 52 | fmt.Printf("Input error: %v\n", err) 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /internal/error_accumulator_test.go: -------------------------------------------------------------------------------- 1 | package openai_test 2 | 3 | import ( 4 | "bytes" 5
| "errors" 6 | "testing" 7 | 8 | utils "github.com/sashabaranov/go-openai/internal" 9 | "github.com/sashabaranov/go-openai/internal/test" 10 | ) 11 | 12 | func TestErrorAccumulatorBytes(t *testing.T) { 13 | accumulator := &utils.DefaultErrorAccumulator{ 14 | Buffer: &bytes.Buffer{}, 15 | } 16 | 17 | errBytes := accumulator.Bytes() 18 | if len(errBytes) != 0 { 19 | t.Fatalf("Did not return nil with empty bytes: %s", string(errBytes)) 20 | } 21 | 22 | err := accumulator.Write([]byte("{}")) 23 | if err != nil { 24 | t.Fatalf("%+v", err) 25 | } 26 | 27 | errBytes = accumulator.Bytes() 28 | if len(errBytes) == 0 { 29 | t.Fatalf("Did not return error bytes when has error: %s", string(errBytes)) 30 | } 31 | } 32 | 33 | func TestErrorByteWriteErrors(t *testing.T) { 34 | accumulator := &utils.DefaultErrorAccumulator{ 35 | Buffer: &test.FailingErrorBuffer{}, 36 | } 37 | err := accumulator.Write([]byte("{")) 38 | if !errors.Is(err, test.ErrTestErrorAccumulatorWriteFailed) { 39 | t.Fatalf("Did not return error when write failed: %v", err) 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /internal/request_builder.go: -------------------------------------------------------------------------------- 1 | package openai 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "io" 7 | "net/http" 8 | ) 9 | 10 | type RequestBuilder interface { 11 | Build(ctx context.Context, method, url string, body any, header http.Header) (*http.Request, error) 12 | } 13 | 14 | type HTTPRequestBuilder struct { 15 | marshaller Marshaller 16 | } 17 | 18 | func NewRequestBuilder() *HTTPRequestBuilder { 19 | return &HTTPRequestBuilder{ 20 | marshaller: &JSONMarshaller{}, 21 | } 22 | } 23 | 24 | func (b *HTTPRequestBuilder) Build( 25 | ctx context.Context, 26 | method string, 27 | url string, 28 | body any, 29 | header http.Header, 30 | ) (req *http.Request, err error) { 31 | var bodyReader io.Reader 32 | if body != nil { 33 | if v, ok := body.(io.Reader); ok { 34 
| bodyReader = v 35 | } else { 36 | var reqBytes []byte 37 | reqBytes, err = b.marshaller.Marshal(body) 38 | if err != nil { 39 | return 40 | } 41 | bodyReader = bytes.NewBuffer(reqBytes) 42 | } 43 | } 44 | req, err = http.NewRequestWithContext(ctx, method, url, bodyReader) 45 | if err != nil { 46 | return 47 | } 48 | if header != nil { 49 | req.Header = header 50 | } 51 | return 52 | } 53 | -------------------------------------------------------------------------------- /internal/test/checks/checks.go: -------------------------------------------------------------------------------- 1 | package checks 2 | 3 | import ( 4 | "errors" 5 | "testing" 6 | ) 7 | 8 | // NoError reports a non-fatal test error if err is non-nil. 9 | func NoError(t *testing.T, err error, message ...string) { 10 | t.Helper() 11 | if err != nil { 12 | t.Error(err, message) 13 | } 14 | } 15 | 16 | // NoErrorF stops the test immediately if err is non-nil. 17 | func NoErrorF(t *testing.T, err error, message ...string) { 18 | t.Helper() 19 | if err != nil { 20 | t.Fatal(err, message) 21 | } 22 | } 23 | 24 | // HasError reports a non-fatal test error if err is nil. 25 | func HasError(t *testing.T, err error, message ...string) { 26 | t.Helper() 27 | if err == nil { 28 | t.Error(err, message) 29 | } 30 | } 31 | 32 | // ErrorIs stops the test unless errors.Is(err, target) holds. 33 | func ErrorIs(t *testing.T, err, target error, msg ...string) { 34 | t.Helper() 35 | if !errors.Is(err, target) { 36 | t.Fatal(msg) 37 | } 38 | } 39 | 40 | // ErrorIsF stops the test unless errors.Is(err, target) holds, formatting 41 | // the failure message with format and msg as its arguments. 42 | func ErrorIsF(t *testing.T, err, target error, format string, msg ...string) { 43 | t.Helper() 44 | if !errors.Is(err, target) { 45 | t.Fatalf(format, formatArgs(msg)...) 46 | } 47 | } 48 | 49 | // ErrorIsNot stops the test if errors.Is(err, target) holds. 50 | func ErrorIsNot(t *testing.T, err, target error, msg ...string) { 51 | t.Helper() 52 | if errors.Is(err, target) { 53 | t.Fatal(msg) 54 | } 55 | } 56 | 57 | // ErrorIsNotf stops the test if errors.Is(err, target) holds, formatting 58 | // the failure message with format and msg as its arguments. 59 | func ErrorIsNotf(t *testing.T, err, target error, format string, msg ...string) { 60 | t.Helper() 61 | if errors.Is(err, target) { 62 | t.Fatalf(format, formatArgs(msg)...) 63 | } 64 | } 65 | 66 | // formatArgs converts msg to []any so it can be spread into Fatalf's variadic 67 | // arguments; passing the []string itself supplies one slice argument, so the 68 | // first verb prints the whole slice and later verbs report %!v(MISSING). 69 | func formatArgs(msg []string) []any { 70 | args := make([]any, len(msg)) 71 | for i, m := range msg { 72 | args[i] = m 73 | } 74 | return args 75 | } 76 | -------------------------------------------------------------------------------- /openai_test.go: -------------------------------------------------------------------------------- 1 | package openai_test 2 | 3 | import ( 4
| "github.com/sashabaranov/go-openai" 5 | "github.com/sashabaranov/go-openai/internal/test" 6 | ) 7 | 8 | func setupOpenAITestServer() (client *openai.Client, server *test.ServerTest, teardown func()) { 9 | server = test.NewTestServer() 10 | ts := server.OpenAITestServer() 11 | ts.Start() 12 | teardown = ts.Close 13 | config := openai.DefaultConfig(test.GetTestToken()) 14 | config.BaseURL = ts.URL + "/v1" 15 | client = openai.NewClientWithConfig(config) 16 | return 17 | } 18 | 19 | func setupAzureTestServer() (client *openai.Client, server *test.ServerTest, teardown func()) { 20 | server = test.NewTestServer() 21 | ts := server.OpenAITestServer() 22 | ts.Start() 23 | teardown = ts.Close 24 | config := openai.DefaultAzureConfig(test.GetTestToken(), "https://dummylab.openai.azure.com/") 25 | config.BaseURL = ts.URL 26 | client = openai.NewClientWithConfig(config) 27 | return 28 | } 29 | 30 | // numTokens Returns the number of GPT-3 encoded tokens in the given text. 31 | // This function approximates based on the rule of thumb stated by OpenAI: 32 | // https://beta.openai.com/tokenizer 33 | // 34 | // TODO: implement an actual tokenizer for GPT-3 and Codex (once available) 35 | func numTokens(s string) int { 36 | return int(float32(len(s)) / 4) 37 | } 38 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | A similar PR may already be submitted! 2 | Please search among the [Pull request](https://github.com/sashabaranov/go-openai/pulls) before creating one. 3 | 4 | If your changes introduce breaking changes, please prefix the title of your pull request with "[BREAKING_CHANGES]". This allows for clear identification of such changes in the 'What's Changed' section on the release page, making it developer-friendly. 5 | 6 | Thanks for submitting a pull request! 
Please provide enough information so that others can review your pull request. 7 | 8 | **Describe the change** 9 | Please provide a clear and concise description of the changes you're proposing. Explain what problem it solves or what feature it adds. 10 | 11 | **Provide OpenAI documentation link** 12 | Provide a relevant API doc from https://platform.openai.com/docs/api-reference 13 | 14 | **Describe your solution** 15 | Describe how your changes address the problem or how they add the feature. This should include a brief description of your approach and any new libraries or dependencies you're using. 16 | 17 | **Tests** 18 | Briefly describe how you have tested these changes. If possible — please add integration tests. 19 | 20 | **Additional context** 21 | Add any other context or screenshots or logs about your pull request here. If the pull request relates to an open issue, please link to it. 22 | 23 | Issue: #XXXX 24 | -------------------------------------------------------------------------------- /engines.go: -------------------------------------------------------------------------------- 1 | package openai 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "net/http" 7 | ) 8 | 9 | // Engine struct represents engine from OpenAPI API. 10 | type Engine struct { 11 | ID string `json:"id"` 12 | Object string `json:"object"` 13 | Owner string `json:"owner"` 14 | Ready bool `json:"ready"` 15 | 16 | httpHeader 17 | } 18 | 19 | // EnginesList is a list of engines. 20 | type EnginesList struct { 21 | Engines []Engine `json:"data"` 22 | 23 | httpHeader 24 | } 25 | 26 | // ListEngines Lists the currently available engines, and provides basic 27 | // information about each option such as the owner and availability. 
28 | func (c *Client) ListEngines(ctx context.Context) (engines EnginesList, err error) { 29 | req, err := c.newRequest(ctx, http.MethodGet, c.fullURL("/engines")) 30 | if err != nil { 31 | return 32 | } 33 | 34 | err = c.sendRequest(req, &engines) 35 | return 36 | } 37 | 38 | // GetEngine Retrieves an engine instance, providing basic information about 39 | // the engine such as the owner and availability. 40 | func (c *Client) GetEngine( 41 | ctx context.Context, 42 | engineID string, 43 | ) (engine Engine, err error) { 44 | urlSuffix := fmt.Sprintf("/engines/%s", engineID) 45 | req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) 46 | if err != nil { 47 | return 48 | } 49 | 50 | err = c.sendRequest(req, &engine) 51 | return 52 | } 53 | -------------------------------------------------------------------------------- /stream.go: -------------------------------------------------------------------------------- 1 | package openai 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "net/http" 7 | ) 8 | 9 | var ( 10 | ErrTooManyEmptyStreamMessages = errors.New("stream has sent too many empty messages") 11 | ) 12 | 13 | type CompletionStream struct { 14 | *streamReader[CompletionResponse] 15 | } 16 | 17 | // CreateCompletionStream — API call to create a completion w/ streaming 18 | // support. It sets whether to stream back partial progress. If set, tokens will be 19 | // sent as data-only server-sent events as they become available, with the 20 | // stream terminated by a data: [DONE] message. 
21 | func (c *Client) CreateCompletionStream( 22 | ctx context.Context, 23 | request CompletionRequest, 24 | ) (stream *CompletionStream, err error) { 25 | urlSuffix := "/completions" 26 | if !checkEndpointSupportsModel(urlSuffix, request.Model) { 27 | err = ErrCompletionUnsupportedModel 28 | return 29 | } 30 | 31 | if !checkPromptType(request.Prompt) { 32 | err = ErrCompletionRequestPromptTypeNotSupported 33 | return 34 | } 35 | 36 | request.Stream = true 37 | req, err := c.newRequest( 38 | ctx, 39 | http.MethodPost, 40 | c.fullURL(urlSuffix, withModel(request.Model)), 41 | withBody(request), 42 | ) 43 | if err != nil { 44 | return nil, err 45 | } 46 | 47 | resp, err := sendRequestStream[CompletionResponse](c, req) 48 | if err != nil { 49 | return 50 | } 51 | stream = &CompletionStream{ 52 | streamReader: resp, 53 | } 54 | return 55 | } 56 | -------------------------------------------------------------------------------- /config_test.go: -------------------------------------------------------------------------------- 1 | package openai_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/sashabaranov/go-openai" 7 | ) 8 | 9 | func TestGetAzureDeploymentByModel(t *testing.T) { 10 | cases := []struct { 11 | Model string 12 | AzureModelMapperFunc func(model string) string 13 | Expect string 14 | }{ 15 | { 16 | Model: "gpt-3.5-turbo", 17 | Expect: "gpt-35-turbo", 18 | }, 19 | { 20 | Model: "gpt-3.5-turbo-0301", 21 | Expect: "gpt-35-turbo-0301", 22 | }, 23 | { 24 | Model: "text-embedding-ada-002", 25 | Expect: "text-embedding-ada-002", 26 | }, 27 | { 28 | Model: "", 29 | Expect: "", 30 | }, 31 | { 32 | Model: "models", 33 | Expect: "models", 34 | }, 35 | { 36 | Model: "gpt-3.5-turbo", 37 | Expect: "my-gpt35", 38 | AzureModelMapperFunc: func(model string) string { 39 | modelmapper := map[string]string{ 40 | "gpt-3.5-turbo": "my-gpt35", 41 | } 42 | if val, ok := modelmapper[model]; ok { 43 | return val 44 | } 45 | return model 46 | }, 47 | }, 48 | } 49 | 50 | for _, c 
:= range cases { 51 | t.Run(c.Model, func(t *testing.T) { 52 | conf := openai.DefaultAzureConfig("", "https://test.openai.azure.com/") 53 | if c.AzureModelMapperFunc != nil { 54 | conf.AzureModelMapperFunc = c.AzureModelMapperFunc 55 | } 56 | actual := conf.GetAzureDeploymentByModel(c.Model) 57 | if actual != c.Expect { 58 | t.Errorf("Expected %s, got %s", c.Expect, actual) 59 | } 60 | }) 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /ratelimit.go: -------------------------------------------------------------------------------- 1 | package openai 2 | 3 | import ( 4 | "net/http" 5 | "strconv" 6 | "time" 7 | ) 8 | 9 | // RateLimitHeaders struct represents Openai rate limits headers. 10 | type RateLimitHeaders struct { 11 | LimitRequests int `json:"x-ratelimit-limit-requests"` 12 | LimitTokens int `json:"x-ratelimit-limit-tokens"` 13 | RemainingRequests int `json:"x-ratelimit-remaining-requests"` 14 | RemainingTokens int `json:"x-ratelimit-remaining-tokens"` 15 | ResetRequests ResetTime `json:"x-ratelimit-reset-requests"` 16 | ResetTokens ResetTime `json:"x-ratelimit-reset-tokens"` 17 | } 18 | 19 | type ResetTime string 20 | 21 | func (r ResetTime) String() string { 22 | return string(r) 23 | } 24 | 25 | func (r ResetTime) Time() time.Time { 26 | d, _ := time.ParseDuration(string(r)) 27 | return time.Now().Add(d) 28 | } 29 | 30 | func newRateLimitHeaders(h http.Header) RateLimitHeaders { 31 | limitReq, _ := strconv.Atoi(h.Get("x-ratelimit-limit-requests")) 32 | limitTokens, _ := strconv.Atoi(h.Get("x-ratelimit-limit-tokens")) 33 | remainingReq, _ := strconv.Atoi(h.Get("x-ratelimit-remaining-requests")) 34 | remainingTokens, _ := strconv.Atoi(h.Get("x-ratelimit-remaining-tokens")) 35 | return RateLimitHeaders{ 36 | LimitRequests: limitReq, 37 | LimitTokens: limitTokens, 38 | RemainingRequests: remainingReq, 39 | RemainingTokens: remainingTokens, 40 | ResetRequests: ResetTime(h.Get("x-ratelimit-reset-requests")), 41 
| ResetTokens: ResetTime(h.Get("x-ratelimit-reset-tokens")), 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /internal/form_builder_test.go: -------------------------------------------------------------------------------- 1 | package openai //nolint:testpackage // testing private field 2 | 3 | import ( 4 | "github.com/sashabaranov/go-openai/internal/test" 5 | "github.com/sashabaranov/go-openai/internal/test/checks" 6 | 7 | "bytes" 8 | "errors" 9 | "os" 10 | "testing" 11 | ) 12 | 13 | type failingWriter struct { 14 | } 15 | 16 | var errMockFailingWriterError = errors.New("mock writer failed") 17 | 18 | func (*failingWriter) Write([]byte) (int, error) { 19 | return 0, errMockFailingWriterError 20 | } 21 | 22 | func TestFormBuilderWithFailingWriter(t *testing.T) { 23 | dir, cleanup := test.CreateTestDirectory(t) 24 | defer cleanup() 25 | 26 | file, err := os.CreateTemp(dir, "") 27 | if err != nil { 28 | t.Errorf("Error creating tmp file: %v", err) 29 | } 30 | defer file.Close() 31 | defer os.Remove(file.Name()) 32 | 33 | builder := NewFormBuilder(&failingWriter{}) 34 | err = builder.CreateFormFile("file", file) 35 | checks.ErrorIs(t, err, errMockFailingWriterError, "formbuilder should return error if writer fails") 36 | } 37 | 38 | func TestFormBuilderWithClosedFile(t *testing.T) { 39 | dir, cleanup := test.CreateTestDirectory(t) 40 | defer cleanup() 41 | 42 | file, err := os.CreateTemp(dir, "") 43 | if err != nil { 44 | t.Errorf("Error creating tmp file: %v", err) 45 | } 46 | file.Close() 47 | defer os.Remove(file.Name()) 48 | 49 | body := &bytes.Buffer{} 50 | builder := NewFormBuilder(body) 51 | err = builder.CreateFormFile("file", file) 52 | checks.HasError(t, err, "formbuilder should return error if file is closed") 53 | checks.ErrorIs(t, err, os.ErrClosed, "formbuilder should return error if file is closed") 54 | } 55 | -------------------------------------------------------------------------------- /edits.go: 
-------------------------------------------------------------------------------- 1 | package openai 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | ) 7 | 8 | // EditsRequest represents a request structure for Edits API. 9 | type EditsRequest struct { 10 | Model *string `json:"model,omitempty"` 11 | Input string `json:"input,omitempty"` 12 | Instruction string `json:"instruction,omitempty"` 13 | N int `json:"n,omitempty"` 14 | Temperature float32 `json:"temperature,omitempty"` 15 | TopP float32 `json:"top_p,omitempty"` 16 | } 17 | 18 | // EditsChoice represents one of possible edits. 19 | type EditsChoice struct { 20 | Text string `json:"text"` 21 | Index int `json:"index"` 22 | } 23 | 24 | // EditsResponse represents a response structure for Edits API. 25 | type EditsResponse struct { 26 | Object string `json:"object"` 27 | Created int64 `json:"created"` 28 | Usage Usage `json:"usage"` 29 | Choices []EditsChoice `json:"choices"` 30 | 31 | httpHeader 32 | } 33 | 34 | // Edits Perform an API call to the Edits endpoint. 35 | /* Deprecated: Users of the Edits API and its associated models (e.g., text-davinci-edit-001 or code-davinci-edit-001) 36 | will need to migrate to GPT-3.5 Turbo by January 4, 2024. 37 | You can use CreateChatCompletion or CreateChatCompletionStream instead.
38 | */ 39 | func (c *Client) Edits(ctx context.Context, request EditsRequest) (response EditsResponse, err error) { 40 | // request.Model is a *string: fmt.Sprint(request.Model) would hand withModel 41 | // the pointer address (or "<nil>") rather than the model name. 42 | var model string 43 | if request.Model != nil { 44 | model = *request.Model 45 | } 46 | req, err := c.newRequest( 47 | ctx, 48 | http.MethodPost, 49 | c.fullURL("/edits", withModel(model)), 50 | withBody(request), 51 | ) 52 | if err != nil { 53 | return 54 | } 55 | 56 | err = c.sendRequest(req, &response) 57 | return 58 | } 59 | -------------------------------------------------------------------------------- /engines_test.go: -------------------------------------------------------------------------------- 1 | package openai_test 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | "net/http" 8 | "testing" 9 | 10 | "github.com/sashabaranov/go-openai" 11 | "github.com/sashabaranov/go-openai/internal/test/checks" 12 | ) 13 | 14 | // TestGetEngine Tests the retrieve engine endpoint of the API using the mocked server. 15 | func TestGetEngine(t *testing.T) { 16 | client, server, teardown := setupOpenAITestServer() 17 | defer teardown() 18 | server.RegisterHandler("/v1/engines/text-davinci-003", func(w http.ResponseWriter, _ *http.Request) { 19 | resBytes, _ := json.Marshal(openai.Engine{}) 20 | fmt.Fprintln(w, string(resBytes)) 21 | }) 22 | _, err := client.GetEngine(context.Background(), "text-davinci-003") 23 | checks.NoError(t, err, "GetEngine error") 24 | } 25 | 26 | // TestListEngines Tests the list engines endpoint of the API using the mocked server.
27 | func TestListEngines(t *testing.T) { 28 | client, server, teardown := setupOpenAITestServer() 29 | defer teardown() 30 | server.RegisterHandler("/v1/engines", func(w http.ResponseWriter, _ *http.Request) { 31 | resBytes, _ := json.Marshal(openai.EnginesList{}) 32 | fmt.Fprintln(w, string(resBytes)) 33 | }) 34 | _, err := client.ListEngines(context.Background()) 35 | checks.NoError(t, err, "ListEngines error") 36 | } 37 | 38 | func TestListEnginesReturnError(t *testing.T) { 39 | client, server, teardown := setupOpenAITestServer() 40 | defer teardown() 41 | server.RegisterHandler("/v1/engines", func(w http.ResponseWriter, _ *http.Request) { 42 | w.WriteHeader(http.StatusTeapot) 43 | }) 44 | 45 | _, err := client.ListEngines(context.Background()) 46 | checks.HasError(t, err, "ListEngines did not fail") 47 | } 48 | -------------------------------------------------------------------------------- /internal/form_builder.go: -------------------------------------------------------------------------------- 1 | package openai 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "mime/multipart" 7 | "os" 8 | "path" 9 | ) 10 | 11 | type FormBuilder interface { 12 | CreateFormFile(fieldname string, file *os.File) error 13 | CreateFormFileReader(fieldname string, r io.Reader, filename string) error 14 | WriteField(fieldname, value string) error 15 | Close() error 16 | FormDataContentType() string 17 | } 18 | 19 | type DefaultFormBuilder struct { 20 | writer *multipart.Writer 21 | } 22 | 23 | func NewFormBuilder(body io.Writer) *DefaultFormBuilder { 24 | return &DefaultFormBuilder{ 25 | writer: multipart.NewWriter(body), 26 | } 27 | } 28 | 29 | func (fb *DefaultFormBuilder) CreateFormFile(fieldname string, file *os.File) error { 30 | return fb.createFormFile(fieldname, file, file.Name()) 31 | } 32 | 33 | func (fb *DefaultFormBuilder) CreateFormFileReader(fieldname string, r io.Reader, filename string) error { 34 | return fb.createFormFile(fieldname, r, path.Base(filename)) 35 | } 36 | 37 
| func (fb *DefaultFormBuilder) createFormFile(fieldname string, r io.Reader, filename string) error { 38 | if filename == "" { 39 | return fmt.Errorf("filename cannot be empty") 40 | } 41 | 42 | fieldWriter, err := fb.writer.CreateFormFile(fieldname, filename) 43 | if err != nil { 44 | return err 45 | } 46 | 47 | _, err = io.Copy(fieldWriter, r) 48 | if err != nil { 49 | return err 50 | } 51 | 52 | return nil 53 | } 54 | 55 | func (fb *DefaultFormBuilder) WriteField(fieldname, value string) error { 56 | return fb.writer.WriteField(fieldname, value) 57 | } 58 | 59 | func (fb *DefaultFormBuilder) Close() error { 60 | return fb.writer.Close() 61 | } 62 | 63 | func (fb *DefaultFormBuilder) FormDataContentType() string { 64 | return fb.writer.FormDataContentType() 65 | } 66 | -------------------------------------------------------------------------------- /speech.go: -------------------------------------------------------------------------------- 1 | package openai 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | ) 7 | 8 | type SpeechModel string 9 | 10 | const ( 11 | TTSModel1 SpeechModel = "tts-1" 12 | TTSModel1HD SpeechModel = "tts-1-hd" 13 | TTSModelCanary SpeechModel = "canary-tts" 14 | ) 15 | 16 | type SpeechVoice string 17 | 18 | const ( 19 | VoiceAlloy SpeechVoice = "alloy" 20 | VoiceEcho SpeechVoice = "echo" 21 | VoiceFable SpeechVoice = "fable" 22 | VoiceOnyx SpeechVoice = "onyx" 23 | VoiceNova SpeechVoice = "nova" 24 | VoiceShimmer SpeechVoice = "shimmer" 25 | ) 26 | 27 | type SpeechResponseFormat string 28 | 29 | const ( 30 | SpeechResponseFormatMp3 SpeechResponseFormat = "mp3" 31 | SpeechResponseFormatOpus SpeechResponseFormat = "opus" 32 | SpeechResponseFormatAac SpeechResponseFormat = "aac" 33 | SpeechResponseFormatFlac SpeechResponseFormat = "flac" 34 | SpeechResponseFormatWav SpeechResponseFormat = "wav" 35 | SpeechResponseFormatPcm SpeechResponseFormat = "pcm" 36 | ) 37 | 38 | type CreateSpeechRequest struct { 39 | Model SpeechModel `json:"model"` 40 
| Input string `json:"input"` 41 | Voice SpeechVoice `json:"voice"` 42 | ResponseFormat SpeechResponseFormat `json:"response_format,omitempty"` // Optional, default to mp3 43 | Speed float64 `json:"speed,omitempty"` // Optional, default to 1.0 44 | } 45 | 46 | func (c *Client) CreateSpeech(ctx context.Context, request CreateSpeechRequest) (response RawResponse, err error) { 47 | req, err := c.newRequest( 48 | ctx, 49 | http.MethodPost, 50 | c.fullURL("/audio/speech", withModel(string(request.Model))), 51 | withBody(request), 52 | withContentType("application/json"), 53 | ) 54 | if err != nil { 55 | return 56 | } 57 | 58 | return c.sendRequestRaw(req) 59 | } 60 | -------------------------------------------------------------------------------- /internal/test/server.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "log" 5 | "net/http" 6 | "net/http/httptest" 7 | "regexp" 8 | "strings" 9 | ) 10 | 11 | const testAPI = "this-is-my-secure-token-do-not-steal!!" 12 | 13 | func GetTestToken() string { 14 | return testAPI 15 | } 16 | 17 | type ServerTest struct { 18 | handlers map[string]handler 19 | } 20 | type handler func(w http.ResponseWriter, r *http.Request) 21 | 22 | func NewTestServer() *ServerTest { 23 | return &ServerTest{handlers: make(map[string]handler)} 24 | } 25 | 26 | func (ts *ServerTest) RegisterHandler(path string, handler handler) { 27 | // to make the registered paths friendlier to a regex match in the route handler 28 | // in OpenAITestServer 29 | path = strings.ReplaceAll(path, "*", ".*") 30 | ts.handlers[path] = handler 31 | } 32 | 33 | // OpenAITestServer Creates a mocked OpenAI server which can pretend to handle requests during testing. 
34 | func (ts *ServerTest) OpenAITestServer() *httptest.Server { 35 | return httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 36 | log.Printf("received a %s request at path %q\n", r.Method, r.URL.Path) 37 | 38 | // check auth 39 | if r.Header.Get("Authorization") != "Bearer "+GetTestToken() && r.Header.Get("api-key") != GetTestToken() { 40 | w.WriteHeader(http.StatusUnauthorized) 41 | return 42 | } 43 | 44 | // Handle /path/* routes. 45 | // Note: the * is converted to a .* in register handler for proper regex handling 46 | for route, handler := range ts.handlers { 47 | // Adding ^ and $ to make path matching deterministic since go map iteration isn't ordered 48 | pattern, _ := regexp.Compile("^" + route + "$") 49 | if pattern.MatchString(r.URL.Path) { 50 | handler(w, r) 51 | return 52 | } 53 | } 54 | http.Error(w, "the resource path doesn't exist", http.StatusNotFound) 55 | })) 56 | } 57 | -------------------------------------------------------------------------------- /internal/test/helpers.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "github.com/sashabaranov/go-openai/internal/test/checks" 5 | 6 | "net/http" 7 | "os" 8 | "testing" 9 | ) 10 | 11 | // CreateTestFile creates a fake file with "hello" as the content. 12 | func CreateTestFile(t *testing.T, path string) { 13 | file, err := os.Create(path) 14 | checks.NoError(t, err, "failed to create file") 15 | 16 | if _, err = file.WriteString("hello"); err != nil { 17 | t.Fatalf("failed to write to file %v", err) 18 | } 19 | file.Close() 20 | } 21 | 22 | // CreateTestDirectory creates a temporary folder which will be deleted when cleanup is called. 
23 | func CreateTestDirectory(t *testing.T) (path string, cleanup func()) { 24 | t.Helper() 25 | 26 | path, err := os.MkdirTemp(os.TempDir(), "") 27 | checks.NoError(t, err) 28 | 29 | return path, func() { os.RemoveAll(path) } 30 | } 31 | 32 | // TokenRoundTripper is a struct that implements the RoundTripper 33 | // interface, specifically to handle the authentication token by adding a token 34 | // to the request header. We need this because the API requires that each 35 | // request include a valid API token in the headers for authentication and 36 | // authorization. 37 | type TokenRoundTripper struct { 38 | Token string 39 | Fallback http.RoundTripper 40 | } 41 | 42 | // RoundTrip takes an *http.Request as input and returns an 43 | // *http.Response and an error. 44 | // 45 | // It is expected to use the provided request to create a connection to an HTTP 46 | // server and return the response, or an error if one occurred. The returned 47 | // Response should have its Body closed. If the RoundTrip method returns an 48 | // error, the Client's Get, Head, Post, and PostForm methods return the same 49 | // error. 
50 | func (t *TokenRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { 51 | req.Header.Set("Authorization", "Bearer "+t.Token) 52 | return t.Fallback.RoundTrip(req) 53 | } 54 | -------------------------------------------------------------------------------- /internal/request_builder_test.go: -------------------------------------------------------------------------------- 1 | package openai //nolint:testpackage // testing private field 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "errors" 7 | "net/http" 8 | "reflect" 9 | "testing" 10 | ) 11 | 12 | var errTestMarshallerFailed = errors.New("test marshaller failed") 13 | 14 | type failingMarshaller struct{} 15 | 16 | func (*failingMarshaller) Marshal(_ any) ([]byte, error) { 17 | return []byte{}, errTestMarshallerFailed 18 | } 19 | 20 | func TestRequestBuilderReturnsMarshallerErrors(t *testing.T) { 21 | builder := HTTPRequestBuilder{ 22 | marshaller: &failingMarshaller{}, 23 | } 24 | 25 | _, err := builder.Build(context.Background(), "", "", struct{}{}, nil) 26 | if !errors.Is(err, errTestMarshallerFailed) { 27 | t.Fatalf("Did not return error when marshaller failed: %v", err) 28 | } 29 | } 30 | 31 | func TestRequestBuilderReturnsRequest(t *testing.T) { 32 | b := NewRequestBuilder() 33 | var ( 34 | ctx = context.Background() 35 | method = http.MethodPost 36 | url = "/foo" 37 | request = map[string]string{"foo": "bar"} 38 | reqBytes, _ = b.marshaller.Marshal(request) 39 | want, _ = http.NewRequestWithContext(ctx, method, url, bytes.NewBuffer(reqBytes)) 40 | ) 41 | got, _ := b.Build(ctx, method, url, request, nil) 42 | if !reflect.DeepEqual(got.Body, want.Body) || 43 | !reflect.DeepEqual(got.URL, want.URL) || 44 | !reflect.DeepEqual(got.Method, want.Method) { 45 | t.Errorf("Build() got = %v, want %v", got, want) 46 | } 47 | } 48 | 49 | func TestRequestBuilderReturnsRequestWhenRequestOfArgsIsNil(t *testing.T) { 50 | var ( 51 | ctx = context.Background() 52 | method = http.MethodGet 53 | url = "/foo" 54 
| want, _ = http.NewRequestWithContext(ctx, method, url, nil) 55 | ) 56 | b := NewRequestBuilder() 57 | got, _ := b.Build(ctx, method, url, nil, nil) 58 | if !reflect.DeepEqual(got, want) { 59 | t.Errorf("Build() got = %v, want %v", got, want) 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /jsonschema/validate.go: -------------------------------------------------------------------------------- 1 | package jsonschema 2 | 3 | import ( 4 | "encoding/json" 5 | "errors" 6 | ) 7 | 8 | func VerifySchemaAndUnmarshal(schema Definition, content []byte, v any) error { 9 | var data any 10 | err := json.Unmarshal(content, &data) 11 | if err != nil { 12 | return err 13 | } 14 | if !Validate(schema, data) { 15 | return errors.New("data validation failed against the provided schema") 16 | } 17 | return json.Unmarshal(content, v) // v, not &v: lets json.Unmarshal reject non-pointer or nil targets instead of silently decoding into this function's local interface variable 18 | } 19 | 20 | func Validate(schema Definition, data any) bool { 21 | switch schema.Type { 22 | case Object: 23 | return validateObject(schema, data) 24 | case Array: 25 | return validateArray(schema, data) 26 | case String: 27 | _, ok := data.(string) 28 | return ok 29 | case Number: // float64 and int 30 | _, ok := data.(float64) 31 | if !ok { 32 | _, ok = data.(int) 33 | } 34 | return ok 35 | case Boolean: 36 | _, ok := data.(bool) 37 | return ok 38 | case Integer: 39 | // Golang unmarshals all numbers as float64, so we need to check if the float64 is an integer 40 | if num, ok := data.(float64); ok { 41 | return num == float64(int64(num)) 42 | } 43 | _, ok := data.(int) 44 | return ok 45 | case Null: 46 | return data == nil 47 | default: 48 | return false 49 | } 50 | } 51 | 52 | func validateObject(schema Definition, data any) bool { 53 | dataMap, ok := data.(map[string]any) 54 | if !ok { 55 | return false 56 | } 57 | for _, field := range schema.Required { 58 | if _, exists := dataMap[field]; !exists { 59 | return false 60 | } 61 | } 62 | for key, valueSchema := range schema.Properties { 63 |
value, exists := dataMap[key] 64 | if exists && !Validate(valueSchema, value) { 65 | return false 66 | } else if !exists && contains(schema.Required, key) { 67 | return false 68 | } 69 | } 70 | return true 71 | } 72 | 73 | // validateArray reports whether data is a JSON array ([]any) whose every 74 | // element validates against *schema.Items. A schema without Items places 75 | // no constraint on the elements. 76 | func validateArray(schema Definition, data any) bool { 77 | dataArray, ok := data.([]any) 78 | if !ok { 79 | return false 80 | } 81 | // Items is optional: dereferencing it unconditionally panics on a schema 82 | // such as Definition{Type: Array}. 83 | if schema.Items == nil { 84 | return true 85 | } 86 | for _, item := range dataArray { 87 | if !Validate(*schema.Items, item) { 88 | return false 89 | } 90 | } 91 | return true 92 | } 93 | 94 | func contains[S ~[]E, E comparable](s S, v E) bool { 95 | for i := range s { 96 | if v == s[i] { 97 | return true 98 | } 99 | } 100 | return false 101 | } 102 | -------------------------------------------------------------------------------- /config.go: -------------------------------------------------------------------------------- 1 | package openai 2 | 3 | import ( 4 | "net/http" 5 | "regexp" 6 | ) 7 | 8 | const ( 9 | openaiAPIURLv1 = "https://api.openai.com/v1" 10 | defaultEmptyMessagesLimit uint = 300 11 | 12 | azureAPIPrefix = "openai" 13 | azureDeploymentsPrefix = "deployments" 14 | ) 15 | 16 | type APIType string 17 | 18 | const ( 19 | APITypeOpenAI APIType = "OPEN_AI" 20 | APITypeAzure APIType = "AZURE" 21 | APITypeAzureAD APIType = "AZURE_AD" 22 | APITypeCloudflareAzure APIType = "CLOUDFLARE_AZURE" 23 | ) 24 | 25 | const AzureAPIKeyHeader = "api-key" 26 | 27 | const defaultAssistantVersion = "v2" // upgrade to v2 to support vector store 28 | 29 | type HTTPDoer interface { 30 | Do(req *http.Request) (*http.Response, error) 31 | } 32 | 33 | // ClientConfig is a configuration of a client.
34 | type ClientConfig struct { 35 | authToken string 36 | 37 | BaseURL string 38 | OrgID string 39 | APIType APIType 40 | APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD 41 | AssistantVersion string 42 | AzureModelMapperFunc func(model string) string // replace model to azure deployment name func 43 | HTTPClient HTTPDoer 44 | 45 | EmptyMessagesLimit uint 46 | } 47 | 48 | func DefaultConfig(authToken string) ClientConfig { 49 | return ClientConfig{ 50 | authToken: authToken, 51 | BaseURL: openaiAPIURLv1, 52 | APIType: APITypeOpenAI, 53 | AssistantVersion: defaultAssistantVersion, 54 | OrgID: "", 55 | 56 | HTTPClient: &http.Client{}, 57 | 58 | EmptyMessagesLimit: defaultEmptyMessagesLimit, 59 | } 60 | } 61 | 62 | func DefaultAzureConfig(apiKey, baseURL string) ClientConfig { 63 | return ClientConfig{ 64 | authToken: apiKey, 65 | BaseURL: baseURL, 66 | OrgID: "", 67 | APIType: APITypeAzure, 68 | APIVersion: "2023-05-15", 69 | AzureModelMapperFunc: func(model string) string { 70 | return regexp.MustCompile(`[.:]`).ReplaceAllString(model, "") 71 | }, 72 | 73 | HTTPClient: &http.Client{}, 74 | 75 | EmptyMessagesLimit: defaultEmptyMessagesLimit, 76 | } 77 | } 78 | 79 | func (ClientConfig) String() string { 80 | return "" 81 | } 82 | 83 | func (c ClientConfig) GetAzureDeploymentByModel(model string) string { 84 | if c.AzureModelMapperFunc != nil { 85 | return c.AzureModelMapperFunc(model) 86 | } 87 | 88 | return model 89 | } 90 | -------------------------------------------------------------------------------- /fine_tunes_test.go: -------------------------------------------------------------------------------- 1 | package openai_test 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | "net/http" 8 | "testing" 9 | 10 | "github.com/sashabaranov/go-openai" 11 | "github.com/sashabaranov/go-openai/internal/test/checks" 12 | ) 13 | 14 | const testFineTuneID = "fine-tune-id" 15 | 16 | // TestFineTunes Tests the fine tunes endpoint of the 
API using the mocked server. 17 | func TestFineTunes(t *testing.T) { 18 | client, server, teardown := setupOpenAITestServer() 19 | defer teardown() 20 | server.RegisterHandler( 21 | "/v1/fine-tunes", 22 | func(w http.ResponseWriter, r *http.Request) { 23 | var resBytes []byte 24 | if r.Method == http.MethodGet { 25 | resBytes, _ = json.Marshal(openai.FineTuneList{}) 26 | } else { 27 | resBytes, _ = json.Marshal(openai.FineTune{}) 28 | } 29 | fmt.Fprintln(w, string(resBytes)) 30 | }, 31 | ) 32 | 33 | server.RegisterHandler( 34 | "/v1/fine-tunes/"+testFineTuneID+"/cancel", 35 | func(w http.ResponseWriter, _ *http.Request) { 36 | resBytes, _ := json.Marshal(openai.FineTune{}) 37 | fmt.Fprintln(w, string(resBytes)) 38 | }, 39 | ) 40 | 41 | server.RegisterHandler( 42 | "/v1/fine-tunes/"+testFineTuneID, 43 | func(w http.ResponseWriter, r *http.Request) { 44 | var resBytes []byte 45 | if r.Method == http.MethodDelete { 46 | resBytes, _ = json.Marshal(openai.FineTuneDeleteResponse{}) 47 | } else { 48 | resBytes, _ = json.Marshal(openai.FineTune{}) 49 | } 50 | fmt.Fprintln(w, string(resBytes)) 51 | }, 52 | ) 53 | 54 | server.RegisterHandler( 55 | "/v1/fine-tunes/"+testFineTuneID+"/events", 56 | func(w http.ResponseWriter, _ *http.Request) { 57 | resBytes, _ := json.Marshal(openai.FineTuneEventList{}) 58 | fmt.Fprintln(w, string(resBytes)) 59 | }, 60 | ) 61 | 62 | ctx := context.Background() 63 | 64 | _, err := client.ListFineTunes(ctx) 65 | checks.NoError(t, err, "ListFineTunes error") 66 | 67 | _, err = client.CreateFineTune(ctx, openai.FineTuneRequest{}) 68 | checks.NoError(t, err, "CreateFineTune error") 69 | 70 | _, err = client.CancelFineTune(ctx, testFineTuneID) 71 | checks.NoError(t, err, "CancelFineTune error") 72 | 73 | _, err = client.GetFineTune(ctx, testFineTuneID) 74 | checks.NoError(t, err, "GetFineTune error") 75 | 76 | _, err = client.DeleteFineTune(ctx, testFineTuneID) 77 | checks.NoError(t, err, "DeleteFineTune error") 78 | 79 | _, err = 
client.ListFineTuneEvents(ctx, testFineTuneID) 80 | checks.NoError(t, err, "ListFineTuneEvents error") 81 | } 82 | -------------------------------------------------------------------------------- /stream_reader_test.go: -------------------------------------------------------------------------------- 1 | package openai //nolint:testpackage // testing private field 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "errors" 7 | "testing" 8 | 9 | utils "github.com/sashabaranov/go-openai/internal" 10 | "github.com/sashabaranov/go-openai/internal/test" 11 | "github.com/sashabaranov/go-openai/internal/test/checks" 12 | ) 13 | 14 | var errTestUnmarshalerFailed = errors.New("test unmarshaler failed") 15 | 16 | type failingUnMarshaller struct{} 17 | 18 | func (*failingUnMarshaller) Unmarshal(_ []byte, _ any) error { 19 | return errTestUnmarshalerFailed 20 | } 21 | 22 | func TestStreamReaderReturnsUnmarshalerErrors(t *testing.T) { 23 | stream := &streamReader[ChatCompletionStreamResponse]{ 24 | errAccumulator: utils.NewErrorAccumulator(), 25 | unmarshaler: &failingUnMarshaller{}, 26 | } 27 | 28 | respErr := stream.unmarshalError() 29 | if respErr != nil { 30 | t.Fatalf("Did not return nil with empty buffer: %v", respErr) 31 | } 32 | 33 | err := stream.errAccumulator.Write([]byte("{")) 34 | if err != nil { 35 | t.Fatalf("%+v", err) 36 | } 37 | 38 | respErr = stream.unmarshalError() 39 | if respErr != nil { 40 | t.Fatalf("Did not return nil when unmarshaler failed: %v", respErr) 41 | } 42 | } 43 | 44 | func TestStreamReaderReturnsErrTooManyEmptyStreamMessages(t *testing.T) { 45 | stream := &streamReader[ChatCompletionStreamResponse]{ 46 | emptyMessagesLimit: 3, 47 | reader: bufio.NewReader(bytes.NewReader([]byte("\n\n\n\n"))), 48 | errAccumulator: utils.NewErrorAccumulator(), 49 | unmarshaler: &utils.JSONUnmarshaler{}, 50 | } 51 | _, err := stream.Recv() 52 | checks.ErrorIs(t, err, ErrTooManyEmptyStreamMessages, "Did not return error when recv failed", err.Error()) 53 | } 54 | 55 | 
func TestStreamReaderReturnsErrTestErrorAccumulatorWriteFailed(t *testing.T) { 56 | stream := &streamReader[ChatCompletionStreamResponse]{ 57 | reader: bufio.NewReader(bytes.NewReader([]byte("\n"))), 58 | errAccumulator: &utils.DefaultErrorAccumulator{ 59 | Buffer: &test.FailingErrorBuffer{}, 60 | }, 61 | unmarshaler: &utils.JSONUnmarshaler{}, 62 | } 63 | _, err := stream.Recv() 64 | checks.ErrorIs(t, err, test.ErrTestErrorAccumulatorWriteFailed, "Did not return error when write failed", err.Error()) 65 | } 66 | 67 | func TestStreamReaderRecvRaw(t *testing.T) { 68 | stream := &streamReader[ChatCompletionStreamResponse]{ 69 | reader: bufio.NewReader(bytes.NewReader([]byte("data: {\"key\": \"value\"}\n"))), 70 | } 71 | rawLine, err := stream.RecvRaw() 72 | if err != nil { 73 | t.Fatalf("Did not return raw line: %v", err) 74 | } 75 | if !bytes.Equal(rawLine, []byte("{\"key\": \"value\"}")) { 76 | t.Fatalf("Did not return raw line: %v", string(rawLine)) 77 | } 78 | } 79 | -------------------------------------------------------------------------------- /speech_test.go: -------------------------------------------------------------------------------- 1 | package openai_test 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | "io" 8 | "mime" 9 | "net/http" 10 | "os" 11 | "path/filepath" 12 | "testing" 13 | 14 | "github.com/sashabaranov/go-openai" 15 | "github.com/sashabaranov/go-openai/internal/test" 16 | "github.com/sashabaranov/go-openai/internal/test/checks" 17 | ) 18 | 19 | func TestSpeechIntegration(t *testing.T) { 20 | client, server, teardown := setupOpenAITestServer() 21 | defer teardown() 22 | 23 | server.RegisterHandler("/v1/audio/speech", func(w http.ResponseWriter, r *http.Request) { 24 | dir, cleanup := test.CreateTestDirectory(t) 25 | path := filepath.Join(dir, "fake.mp3") 26 | test.CreateTestFile(t, path) 27 | defer cleanup() 28 | 29 | // audio endpoints only accept POST requests 30 | if r.Method != "POST" { 31 | http.Error(w, "method not 
allowed", http.StatusMethodNotAllowed) 32 | return 33 | } 34 | 35 | mediaType, _, err := mime.ParseMediaType(r.Header.Get("Content-Type")) 36 | if err != nil { 37 | http.Error(w, "failed to parse media type", http.StatusBadRequest) 38 | return 39 | } 40 | 41 | if mediaType != "application/json" { 42 | http.Error(w, "request is not json", http.StatusBadRequest) 43 | return 44 | } 45 | 46 | // Parse the JSON body of the request 47 | var params map[string]interface{} 48 | err = json.NewDecoder(r.Body).Decode(¶ms) 49 | if err != nil { 50 | http.Error(w, "failed to parse request body", http.StatusBadRequest) 51 | return 52 | } 53 | 54 | // Check if each required field is present in the parsed JSON object 55 | reqParams := []string{"model", "input", "voice"} 56 | for _, param := range reqParams { 57 | _, ok := params[param] 58 | if !ok { 59 | http.Error(w, fmt.Sprintf("no %s in params", param), http.StatusBadRequest) 60 | return 61 | } 62 | } 63 | 64 | // read audio file content 65 | audioFile, err := os.ReadFile(path) 66 | if err != nil { 67 | http.Error(w, "failed to read audio file", http.StatusInternalServerError) 68 | return 69 | } 70 | 71 | // write audio file content to response 72 | w.Header().Set("Content-Type", "audio/mpeg") 73 | w.Header().Set("Transfer-Encoding", "chunked") 74 | w.Header().Set("Connection", "keep-alive") 75 | _, err = w.Write(audioFile) 76 | if err != nil { 77 | http.Error(w, "failed to write body", http.StatusInternalServerError) 78 | return 79 | } 80 | }) 81 | 82 | t.Run("happy path", func(t *testing.T) { 83 | res, err := client.CreateSpeech(context.Background(), openai.CreateSpeechRequest{ 84 | Model: openai.TTSModel1, 85 | Input: "Hello!", 86 | Voice: openai.VoiceAlloy, 87 | }) 88 | checks.NoError(t, err, "CreateSpeech error") 89 | defer res.Close() 90 | 91 | buf, err := io.ReadAll(res) 92 | checks.NoError(t, err, "ReadAll error") 93 | 94 | // save buf to file as mp3 95 | err = os.WriteFile("test.mp3", buf, 0644) 96 | checks.NoError(t, 
err, "Create error") 97 | }) 98 | } 99 | -------------------------------------------------------------------------------- /stream_reader.go: -------------------------------------------------------------------------------- 1 | package openai 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "fmt" 7 | "io" 8 | "net/http" 9 | 10 | utils "github.com/sashabaranov/go-openai/internal" 11 | ) 12 | 13 | var ( 14 | headerData = []byte("data: ") 15 | errorPrefix = []byte(`data: {"error":`) 16 | ) 17 | 18 | type streamable interface { 19 | ChatCompletionStreamResponse | CompletionResponse 20 | } 21 | 22 | type streamReader[T streamable] struct { 23 | emptyMessagesLimit uint 24 | isFinished bool 25 | 26 | reader *bufio.Reader 27 | response *http.Response 28 | errAccumulator utils.ErrorAccumulator 29 | unmarshaler utils.Unmarshaler 30 | 31 | httpHeader 32 | } 33 | 34 | func (stream *streamReader[T]) Recv() (response T, err error) { 35 | rawLine, err := stream.RecvRaw() 36 | if err != nil { 37 | return 38 | } 39 | 40 | err = stream.unmarshaler.Unmarshal(rawLine, &response) 41 | if err != nil { 42 | return 43 | } 44 | return response, nil 45 | } 46 | 47 | func (stream *streamReader[T]) RecvRaw() ([]byte, error) { 48 | if stream.isFinished { 49 | return nil, io.EOF 50 | } 51 | 52 | return stream.processLines() 53 | } 54 | 55 | //nolint:gocognit 56 | func (stream *streamReader[T]) processLines() ([]byte, error) { 57 | var ( 58 | emptyMessagesCount uint 59 | hasErrorPrefix bool 60 | ) 61 | 62 | for { 63 | rawLine, readErr := stream.reader.ReadBytes('\n') 64 | if readErr != nil || hasErrorPrefix { 65 | respErr := stream.unmarshalError() 66 | if respErr != nil { 67 | return nil, fmt.Errorf("error, %w", respErr.Error) 68 | } 69 | return nil, readErr 70 | } 71 | 72 | noSpaceLine := bytes.TrimSpace(rawLine) 73 | if bytes.HasPrefix(noSpaceLine, errorPrefix) { 74 | hasErrorPrefix = true 75 | } 76 | if !bytes.HasPrefix(noSpaceLine, headerData) || hasErrorPrefix { 77 | if hasErrorPrefix { 78 | 
noSpaceLine = bytes.TrimPrefix(noSpaceLine, headerData) 79 | } 80 | writeErr := stream.errAccumulator.Write(noSpaceLine) 81 | if writeErr != nil { 82 | return nil, writeErr 83 | } 84 | emptyMessagesCount++ 85 | if emptyMessagesCount > stream.emptyMessagesLimit { 86 | return nil, ErrTooManyEmptyStreamMessages 87 | } 88 | 89 | continue 90 | } 91 | 92 | noPrefixLine := bytes.TrimPrefix(noSpaceLine, headerData) 93 | if string(noPrefixLine) == "[DONE]" { 94 | stream.isFinished = true 95 | return nil, io.EOF 96 | } 97 | 98 | return noPrefixLine, nil 99 | } 100 | } 101 | 102 | func (stream *streamReader[T]) unmarshalError() (errResp *ErrorResponse) { 103 | errBytes := stream.errAccumulator.Bytes() 104 | if len(errBytes) == 0 { 105 | return 106 | } 107 | 108 | err := stream.unmarshaler.Unmarshal(errBytes, &errResp) 109 | if err != nil { 110 | errResp = nil 111 | } 112 | 113 | return 114 | } 115 | 116 | func (stream *streamReader[T]) Close() error { 117 | return stream.response.Body.Close() 118 | } 119 | -------------------------------------------------------------------------------- /models.go: -------------------------------------------------------------------------------- 1 | package openai 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "net/http" 7 | ) 8 | 9 | // Model struct represents an OpenAPI model. 10 | type Model struct { 11 | CreatedAt int64 `json:"created"` 12 | ID string `json:"id"` 13 | Object string `json:"object"` 14 | OwnedBy string `json:"owned_by"` 15 | Permission []Permission `json:"permission"` 16 | Root string `json:"root"` 17 | Parent string `json:"parent"` 18 | 19 | httpHeader 20 | } 21 | 22 | // Permission struct represents an OpenAPI permission. 
23 | type Permission struct { 24 | CreatedAt int64 `json:"created"` 25 | ID string `json:"id"` 26 | Object string `json:"object"` 27 | AllowCreateEngine bool `json:"allow_create_engine"` 28 | AllowSampling bool `json:"allow_sampling"` 29 | AllowLogprobs bool `json:"allow_logprobs"` 30 | AllowSearchIndices bool `json:"allow_search_indices"` 31 | AllowView bool `json:"allow_view"` 32 | AllowFineTuning bool `json:"allow_fine_tuning"` 33 | Organization string `json:"organization"` 34 | Group interface{} `json:"group"` 35 | IsBlocking bool `json:"is_blocking"` 36 | } 37 | 38 | // FineTuneModelDeleteResponse represents the deletion status of a fine-tuned model. 39 | type FineTuneModelDeleteResponse struct { 40 | ID string `json:"id"` 41 | Object string `json:"object"` 42 | Deleted bool `json:"deleted"` 43 | 44 | httpHeader 45 | } 46 | 47 | // ModelsList is a list of models, including those that belong to the user or organization. 48 | type ModelsList struct { 49 | Models []Model `json:"data"` 50 | 51 | httpHeader 52 | } 53 | 54 | // ListModels Lists the currently available models, 55 | // and provides basic information about each model such as the model id and parent. 56 | func (c *Client) ListModels(ctx context.Context) (models ModelsList, err error) { 57 | req, err := c.newRequest(ctx, http.MethodGet, c.fullURL("/models")) 58 | if err != nil { 59 | return 60 | } 61 | 62 | err = c.sendRequest(req, &models) 63 | return 64 | } 65 | 66 | // GetModel Retrieves a model instance, providing basic information about 67 | // the model such as the owner and permissioning. 68 | func (c *Client) GetModel(ctx context.Context, modelID string) (model Model, err error) { 69 | urlSuffix := fmt.Sprintf("/models/%s", modelID) 70 | req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) 71 | if err != nil { 72 | return 73 | } 74 | 75 | err = c.sendRequest(req, &model) 76 | return 77 | } 78 | 79 | // DeleteFineTuneModel Deletes a fine-tune model. 
You must have the Owner 80 | // role in your organization to delete a model. 81 | func (c *Client) DeleteFineTuneModel(ctx context.Context, modelID string) ( 82 | response FineTuneModelDeleteResponse, err error) { 83 | req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL("/models/"+modelID)) 84 | if err != nil { 85 | return 86 | } 87 | 88 | err = c.sendRequest(req, &response) 89 | return 90 | } 91 | -------------------------------------------------------------------------------- /edits_test.go: -------------------------------------------------------------------------------- 1 | package openai_test 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | "io" 8 | "net/http" 9 | "testing" 10 | "time" 11 | 12 | "github.com/sashabaranov/go-openai" 13 | "github.com/sashabaranov/go-openai/internal/test/checks" 14 | ) 15 | 16 | // TestEdits Tests the edits endpoint of the API using the mocked server. 17 | func TestEdits(t *testing.T) { 18 | client, server, teardown := setupOpenAITestServer() 19 | defer teardown() 20 | server.RegisterHandler("/v1/edits", handleEditEndpoint) 21 | // create an edit request 22 | model := "ada" 23 | editReq := openai.EditsRequest{ 24 | Model: &model, 25 | Input: "Lorem ipsum dolor sit amet, consectetur adipiscing elit, " + 26 | "sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim" + 27 | " ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip" + 28 | " ex ea commodo consequat. Duis aute irure dolor in reprehe", 29 | Instruction: "test instruction", 30 | N: 3, 31 | } 32 | response, err := client.Edits(context.Background(), editReq) 33 | checks.NoError(t, err, "Edits error") 34 | if len(response.Choices) != editReq.N { 35 | t.Fatalf("edits does not properly return the correct number of choices") 36 | } 37 | } 38 | 39 | // handleEditEndpoint Handles the edit endpoint by the test server. 
40 | func handleEditEndpoint(w http.ResponseWriter, r *http.Request) { 41 | var err error 42 | var resBytes []byte 43 | 44 | // edits only accepts POST requests 45 | if r.Method != "POST" { 46 | http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) 47 | } 48 | var editReq openai.EditsRequest 49 | editReq, err = getEditBody(r) 50 | if err != nil { 51 | http.Error(w, "could not read request", http.StatusInternalServerError) 52 | return 53 | } 54 | // create a response 55 | res := openai.EditsResponse{ 56 | Object: "test-object", 57 | Created: time.Now().Unix(), 58 | } 59 | // edit and calculate token usage 60 | editString := "edited by mocked OpenAI server :)" 61 | inputTokens := numTokens(editReq.Input+editReq.Instruction) * editReq.N 62 | completionTokens := int(float32(len(editString))/4) * editReq.N 63 | for i := 0; i < editReq.N; i++ { 64 | // instruction will be hidden and only seen by OpenAI 65 | res.Choices = append(res.Choices, openai.EditsChoice{ 66 | Text: editReq.Input + editString, 67 | Index: i, 68 | }) 69 | } 70 | res.Usage = openai.Usage{ 71 | PromptTokens: inputTokens, 72 | CompletionTokens: completionTokens, 73 | TotalTokens: inputTokens + completionTokens, 74 | } 75 | resBytes, _ = json.Marshal(res) 76 | fmt.Fprint(w, string(resBytes)) 77 | } 78 | 79 | // getEditBody Returns the body of the request to create an edit. 
80 | func getEditBody(r *http.Request) (openai.EditsRequest, error) { 81 | edit := openai.EditsRequest{} 82 | // read the whole request body, then decode it as an EditsRequest 83 | reqBody, err := io.ReadAll(r.Body) 84 | if err != nil { 85 | return openai.EditsRequest{}, err 86 | } 87 | err = json.Unmarshal(reqBody, &edit) 88 | if err != nil { 89 | return openai.EditsRequest{}, err 90 | } 91 | return edit, nil 92 | } 93 | -------------------------------------------------------------------------------- /examples/completion-with-tool/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "os" 7 | 8 | "github.com/sashabaranov/go-openai" 9 | "github.com/sashabaranov/go-openai/jsonschema" 10 | ) 11 | 12 | func main() { 13 | ctx := context.Background() 14 | client := openai.NewClient(os.Getenv("OPENAI_API_KEY")) 15 | 16 | // describe the function & its inputs as a JSON Schema object 17 | params := jsonschema.Definition{ 18 | Type: jsonschema.Object, 19 | Properties: map[string]jsonschema.Definition{ 20 | "location": { 21 | Type: jsonschema.String, 22 | Description: "The city and state, e.g.
San Francisco, CA", 23 | }, 24 | "unit": { 25 | Type: jsonschema.String, 26 | Enum: []string{"celsius", "fahrenheit"}, 27 | }, 28 | }, 29 | Required: []string{"location"}, 30 | } 31 | f := openai.FunctionDefinition{ 32 | Name: "get_current_weather", 33 | Description: "Get the current weather in a given location", 34 | Parameters: params, 35 | } 36 | t := openai.Tool{ 37 | Type: openai.ToolTypeFunction, 38 | Function: &f, 39 | } 40 | 41 | // simulate user asking a question that requires the function 42 | dialogue := []openai.ChatCompletionMessage{ 43 | {Role: openai.ChatMessageRoleUser, Content: "What is the weather in Boston today?"}, 44 | } 45 | fmt.Printf("Asking OpenAI '%v' and providing it a '%v()' function...\n", 46 | dialogue[0].Content, f.Name) 47 | resp, err := client.CreateChatCompletion(ctx, 48 | openai.ChatCompletionRequest{ 49 | Model: openai.GPT4TurboPreview, 50 | Messages: dialogue, 51 | Tools: []openai.Tool{t}, 52 | }, 53 | ) 54 | if err != nil || len(resp.Choices) != 1 { 55 | fmt.Printf("Completion error: err:%v len(choices):%v\n", err, 56 | len(resp.Choices)) 57 | return 58 | } 59 | msg := resp.Choices[0].Message 60 | if len(msg.ToolCalls) != 1 { 61 | fmt.Printf("Completion error: len(toolcalls): %v\n", len(msg.ToolCalls)) 62 | return 63 | } 64 | 65 | // simulate calling the function & responding to OpenAI with a tool message tied to the call's ID 66 | dialogue = append(dialogue, msg) 67 | fmt.Printf("OpenAI called us back wanting to invoke our function '%v' with params '%v'\n", 68 | msg.ToolCalls[0].Function.Name, msg.ToolCalls[0].Function.Arguments) 69 | dialogue = append(dialogue, openai.ChatCompletionMessage{ 70 | Role: openai.ChatMessageRoleTool, 71 | Content: "Sunny and 80 degrees.", 72 | Name: msg.ToolCalls[0].Function.Name, 73 | ToolCallID: msg.ToolCalls[0].ID, 74 | }) 75 | fmt.Printf("Sending OpenAI our '%v()' function's response and requesting the reply to the original question...\n", 76 | f.Name) 77 | resp, err = client.CreateChatCompletion(ctx, 78 |
openai.ChatCompletionRequest{ 79 | Model: openai.GPT4TurboPreview, 80 | Messages: dialogue, 81 | Tools: []openai.Tool{t}, 82 | }, 83 | ) 84 | if err != nil || len(resp.Choices) != 1 { 85 | fmt.Printf("2nd completion error: err:%v len(choices):%v\n", err, 86 | len(resp.Choices)) 87 | return 88 | } 89 | 90 | // display OpenAI's response to the original question utilizing our function 91 | msg = resp.Choices[0].Message 92 | fmt.Printf("OpenAI answered the original request with: %v\n", 93 | msg.Content) 94 | } 95 | -------------------------------------------------------------------------------- /error.go: -------------------------------------------------------------------------------- 1 | package openai 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "strings" 7 | ) 8 | 9 | // APIError provides error information returned by the OpenAI API. 10 | // The InnerError field is only valid for Azure OpenAI Service. 11 | type APIError struct { 12 | Code any `json:"code,omitempty"` 13 | Message string `json:"message"` 14 | Param *string `json:"param,omitempty"` 15 | Type string `json:"type"` 16 | HTTPStatus string `json:"-"` 17 | HTTPStatusCode int `json:"-"` 18 | InnerError *InnerError `json:"innererror,omitempty"` 19 | } 20 | 21 | // InnerError carries Azure content-filtering details. Only valid for Azure OpenAI Service. 22 | type InnerError struct { 23 | Code string `json:"code,omitempty"` 24 | ContentFilterResults ContentFilterResults `json:"content_filter_result,omitempty"` 25 | } 26 | 27 | // RequestError provides information about generic request errors.
28 | type RequestError struct { 29 | HTTPStatus string 30 | HTTPStatusCode int 31 | Err error 32 | Body []byte 33 | } 34 | 35 | type ErrorResponse struct { 36 | Error *APIError `json:"error,omitempty"` 37 | } 38 | 39 | func (e *APIError) Error() string { 40 | if e.HTTPStatusCode > 0 { 41 | return fmt.Sprintf("error, status code: %d, status: %s, message: %s", e.HTTPStatusCode, e.HTTPStatus, e.Message) 42 | } 43 | 44 | return e.Message 45 | } 46 | 47 | func (e *APIError) UnmarshalJSON(data []byte) (err error) { 48 | var rawMap map[string]json.RawMessage 49 | err = json.Unmarshal(data, &rawMap) 50 | if err != nil { 51 | return 52 | } 53 | 54 | err = json.Unmarshal(rawMap["message"], &e.Message) 55 | if err != nil { 56 | // If the parameter field of a function call is invalid as a JSON schema, the 57 | // message arrives as an array of strings, which is joined below. refs: https://github.com/sashabaranov/go-openai/issues/381 58 | var messages []string 59 | err = json.Unmarshal(rawMap["message"], &messages) 60 | if err != nil { 61 | return 62 | } 63 | e.Message = strings.Join(messages, ", ") 64 | } 65 | 66 | // optional fields for azure openai; decoded only when the key is present 67 | // refs: https://github.com/sashabaranov/go-openai/issues/343 68 | if _, ok := rawMap["type"]; ok { 69 | err = json.Unmarshal(rawMap["type"], &e.Type) 70 | if err != nil { 71 | return 72 | } 73 | } 74 | 75 | if _, ok := rawMap["innererror"]; ok { 76 | err = json.Unmarshal(rawMap["innererror"], &e.InnerError) 77 | if err != nil { 78 | return 79 | } 80 | } 81 | 82 | // optional fields; a missing key leaves the field unchanged 83 | if _, ok := rawMap["param"]; ok { 84 | err = json.Unmarshal(rawMap["param"], &e.Param) 85 | if err != nil { 86 | return 87 | } 88 | } 89 | 90 | if _, ok := rawMap["code"]; !ok { 91 | return nil 92 | } 93 | 94 | // if the api returned a number, we need to force an integer 95 | // since the json package defaults to float64; any other code (e.g. a string) falls through to the generic decode below 96 | var intCode int 97 | err = json.Unmarshal(rawMap["code"], &intCode) 98 | if err == nil { 99 | e.Code = intCode 100 | return nil 101 | } 102 | 103 | return json.Unmarshal(rawMap["code"],
&e.Code) 104 | } 105 | 106 | func (e *RequestError) Error() string { 107 | return fmt.Sprintf( 108 | "error, status code: %d, status: %s, message: %s, body: %s", 109 | e.HTTPStatusCode, e.HTTPStatus, e.Err, e.Body, 110 | ) 111 | } 112 | 113 | func (e *RequestError) Unwrap() error { 114 | return e.Err 115 | } 116 | -------------------------------------------------------------------------------- /models_test.go: -------------------------------------------------------------------------------- 1 | package openai_test 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | "net/http" 8 | "os" 9 | "testing" 10 | "time" 11 | 12 | "github.com/sashabaranov/go-openai" 13 | "github.com/sashabaranov/go-openai/internal/test/checks" 14 | ) 15 | 16 | const testFineTuneModelID = "fine-tune-model-id" 17 | 18 | // TestListModels Tests the list models endpoint of the API using the mocked server. 19 | func TestListModels(t *testing.T) { 20 | client, server, teardown := setupOpenAITestServer() 21 | defer teardown() 22 | server.RegisterHandler("/v1/models", handleListModelsEndpoint) 23 | _, err := client.ListModels(context.Background()) 24 | checks.NoError(t, err, "ListModels error") 25 | } 26 | 27 | func TestAzureListModels(t *testing.T) { 28 | client, server, teardown := setupAzureTestServer() 29 | defer teardown() 30 | server.RegisterHandler("/openai/models", handleListModelsEndpoint) 31 | _, err := client.ListModels(context.Background()) 32 | checks.NoError(t, err, "ListModels error") 33 | } 34 | 35 | // handleListModelsEndpoint serves an empty ModelsList for the test server. 36 | func handleListModelsEndpoint(w http.ResponseWriter, _ *http.Request) { 37 | resBytes, _ := json.Marshal(openai.ModelsList{}) 38 | fmt.Fprintln(w, string(resBytes)) 39 | } 40 | 41 | // TestGetModel Tests the retrieve model endpoint of the API using the mocked server.
42 | func TestGetModel(t *testing.T) { 43 | client, server, teardown := setupOpenAITestServer() 44 | defer teardown() 45 | server.RegisterHandler("/v1/models/text-davinci-003", handleGetModelEndpoint) 46 | _, err := client.GetModel(context.Background(), "text-davinci-003") 47 | checks.NoError(t, err, "GetModel error") 48 | } 49 | 50 | func TestAzureGetModel(t *testing.T) { 51 | client, server, teardown := setupAzureTestServer() 52 | defer teardown() 53 | server.RegisterHandler("/openai/models/text-davinci-003", handleGetModelEndpoint) 54 | _, err := client.GetModel(context.Background(), "text-davinci-003") 55 | checks.NoError(t, err, "GetModel error") 56 | } 57 | 58 | // handleGetModelEndpoint serves an empty Model for the get model endpoint of the test server. 59 | func handleGetModelEndpoint(w http.ResponseWriter, _ *http.Request) { 60 | resBytes, _ := json.Marshal(openai.Model{}) 61 | fmt.Fprintln(w, string(resBytes)) 62 | } 63 | 64 | func TestGetModelReturnTimeoutError(t *testing.T) { 65 | client, server, teardown := setupOpenAITestServer() 66 | defer teardown() 67 | server.RegisterHandler("/v1/models/text-davinci-003", func(http.ResponseWriter, *http.Request) { 68 | time.Sleep(10 * time.Nanosecond) 69 | }) 70 | ctx := context.Background() 71 | ctx, cancel := context.WithTimeout(ctx, time.Nanosecond) 72 | defer cancel() 73 | 74 | _, err := client.GetModel(ctx, "text-davinci-003") 75 | if err == nil { 76 | t.Fatal("Did not return error") 77 | } 78 | if !os.IsTimeout(err) { 79 | t.Fatal("Did not return timeout error") 80 | } 81 | } 82 | 83 | func TestDeleteFineTuneModel(t *testing.T) { 84 | client, server, teardown := setupOpenAITestServer() 85 | defer teardown() 86 | server.RegisterHandler("/v1/models/"+testFineTuneModelID, handleDeleteFineTuneModelEndpoint) 87 | _, err := client.DeleteFineTuneModel(context.Background(), testFineTuneModelID) 88 | checks.NoError(t, err, "DeleteFineTuneModel error") 89 | } 90 | 91 | func handleDeleteFineTuneModelEndpoint(w http.ResponseWriter, _
*http.Request) { 92 | resBytes, _ := json.Marshal(openai.FineTuneModelDeleteResponse{}) 93 | fmt.Fprintln(w, string(resBytes)) 94 | } 95 | -------------------------------------------------------------------------------- /fine_tuning_job_test.go: -------------------------------------------------------------------------------- 1 | package openai_test 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | "net/http" 8 | "testing" 9 | 10 | "github.com/sashabaranov/go-openai" 11 | "github.com/sashabaranov/go-openai/internal/test/checks" 12 | ) 13 | 14 | const testFineTuninigJobID = "fine-tuning-job-id" 15 | 16 | // TestFineTuningJob Tests the fine tuning job endpoint of the API using the mocked server. 17 | func TestFineTuningJob(t *testing.T) { 18 | client, server, teardown := setupOpenAITestServer() 19 | defer teardown() 20 | server.RegisterHandler( 21 | "/v1/fine_tuning/jobs", 22 | func(w http.ResponseWriter, _ *http.Request) { 23 | resBytes, _ := json.Marshal(openai.FineTuningJob{ 24 | Object: "fine_tuning.job", 25 | ID: testFineTuninigJobID, 26 | Model: "davinci-002", 27 | CreatedAt: 1692661014, 28 | FinishedAt: 1692661190, 29 | FineTunedModel: "ft:davinci-002:my-org:custom_suffix:7q8mpxmy", 30 | OrganizationID: "org-123", 31 | ResultFiles: []string{"file-abc123"}, 32 | Status: "succeeded", 33 | ValidationFile: "", 34 | TrainingFile: "file-abc123", 35 | Hyperparameters: openai.Hyperparameters{ 36 | Epochs: "auto", 37 | LearningRateMultiplier: "auto", 38 | BatchSize: "auto", 39 | }, 40 | TrainedTokens: 5768, 41 | }) 42 | fmt.Fprintln(w, string(resBytes)) 43 | }, 44 | ) 45 | 46 | server.RegisterHandler( 47 | "/v1/fine_tuning/jobs/"+testFineTuninigJobID+"/cancel", 48 | func(w http.ResponseWriter, _ *http.Request) { 49 | resBytes, _ := json.Marshal(openai.FineTuningJob{}) 50 | fmt.Fprintln(w, string(resBytes)) 51 | }, 52 | ) 53 | 54 | server.RegisterHandler( 55 | "/v1/fine_tuning/jobs/"+testFineTuninigJobID, 56 | func(w http.ResponseWriter, _ *http.Request) {
57 | var resBytes []byte 58 | resBytes, _ = json.Marshal(openai.FineTuningJob{}) 59 | fmt.Fprintln(w, string(resBytes)) 60 | }, 61 | ) 62 | 63 | server.RegisterHandler( 64 | "/v1/fine_tuning/jobs/"+testFineTuninigJobID+"/events", 65 | func(w http.ResponseWriter, _ *http.Request) { 66 | resBytes, _ := json.Marshal(openai.FineTuningJobEventList{}) 67 | fmt.Fprintln(w, string(resBytes)) 68 | }, 69 | ) 70 | 71 | ctx := context.Background() 72 | 73 | _, err := client.CreateFineTuningJob(ctx, openai.FineTuningJobRequest{}) 74 | checks.NoError(t, err, "CreateFineTuningJob error") 75 | 76 | _, err = client.CancelFineTuningJob(ctx, testFineTuninigJobID) 77 | checks.NoError(t, err, "CancelFineTuningJob error") 78 | 79 | _, err = client.RetrieveFineTuningJob(ctx, testFineTuninigJobID) 80 | checks.NoError(t, err, "RetrieveFineTuningJob error") 81 | 82 | _, err = client.ListFineTuningJobEvents(ctx, testFineTuninigJobID) 83 | checks.NoError(t, err, "ListFineTuningJobEvents error") 84 | 85 | _, err = client.ListFineTuningJobEvents( 86 | ctx, 87 | testFineTuninigJobID, 88 | openai.ListFineTuningJobEventsWithAfter("last-event-id"), 89 | ) 90 | checks.NoError(t, err, "ListFineTuningJobEvents error") 91 | 92 | _, err = client.ListFineTuningJobEvents( 93 | ctx, 94 | testFineTuninigJobID, 95 | openai.ListFineTuningJobEventsWithLimit(10), 96 | ) 97 | checks.NoError(t, err, "ListFineTuningJobEvents error") 98 | 99 | _, err = client.ListFineTuningJobEvents( 100 | ctx, 101 | testFineTuninigJobID, 102 | openai.ListFineTuningJobEventsWithAfter("last-event-id"), 103 | openai.ListFineTuningJobEventsWithLimit(10), 104 | ) 105 | checks.NoError(t, err, "ListFineTuningJobEvents error") 106 | } 107 | -------------------------------------------------------------------------------- /audio_test.go: -------------------------------------------------------------------------------- 1 | package openai //nolint:testpackage // testing private field 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "io" 7 | "os" 8 |
"path/filepath" 9 | "testing" 10 | 11 | "github.com/sashabaranov/go-openai/internal/test" 12 | "github.com/sashabaranov/go-openai/internal/test/checks" 13 | ) 14 | 15 | func TestAudioWithFailingFormBuilder(t *testing.T) { 16 | dir, cleanup := test.CreateTestDirectory(t) 17 | defer cleanup() 18 | path := filepath.Join(dir, "fake.mp3") 19 | test.CreateTestFile(t, path) 20 | 21 | req := AudioRequest{ 22 | FilePath: path, 23 | Prompt: "test", 24 | Temperature: 0.5, 25 | Language: "en", 26 | Format: AudioResponseFormatSRT, 27 | TimestampGranularities: []TranscriptionTimestampGranularity{ 28 | TranscriptionTimestampGranularitySegment, 29 | TranscriptionTimestampGranularityWord, 30 | }, 31 | } 32 | 33 | mockFailedErr := fmt.Errorf("mock form builder fail") 34 | mockBuilder := &mockFormBuilder{} 35 | 36 | mockBuilder.mockCreateFormFile = func(string, *os.File) error { 37 | return mockFailedErr 38 | } 39 | err := audioMultipartForm(req, mockBuilder) 40 | checks.ErrorIs(t, err, mockFailedErr, "audioMultipartForm should return error if form builder fails") 41 | 42 | mockBuilder.mockCreateFormFile = func(string, *os.File) error { 43 | return nil 44 | } 45 | 46 | var failForField string 47 | mockBuilder.mockWriteField = func(fieldname, _ string) error { 48 | if fieldname == failForField { 49 | return mockFailedErr 50 | } 51 | return nil 52 | } 53 | 54 | failOn := []string{"model", "prompt", "temperature", "language", "response_format", "timestamp_granularities[]"} 55 | for _, failingField := range failOn { 56 | failForField = failingField 57 | mockFailedErr = fmt.Errorf("mock form builder fail on field %s", failingField) 58 | 59 | err = audioMultipartForm(req, mockBuilder) 60 | checks.ErrorIs(t, err, mockFailedErr, "audioMultipartForm should return error if form builder fails") 61 | } 62 | } 63 | 64 | func TestCreateFileField(t *testing.T) { 65 | t.Run("createFileField failing file", func(t *testing.T) { 66 | dir, cleanup := test.CreateTestDirectory(t) 67 | defer cleanup() 68 |
path := filepath.Join(dir, "fake.mp3") 69 | test.CreateTestFile(t, path) 70 | 71 | req := AudioRequest{ 72 | FilePath: path, 73 | } 74 | 75 | mockFailedErr := fmt.Errorf("mock form builder fail") 76 | mockBuilder := &mockFormBuilder{ 77 | mockCreateFormFile: func(string, *os.File) error { 78 | return mockFailedErr 79 | }, 80 | } 81 | 82 | err := createFileField(req, mockBuilder) 83 | checks.ErrorIs(t, err, mockFailedErr, "createFileField using a file should return error if form builder fails") 84 | }) 85 | 86 | t.Run("createFileField failing reader", func(t *testing.T) { 87 | req := AudioRequest{ 88 | FilePath: "test.wav", 89 | Reader: bytes.NewBuffer([]byte(`wav test contents`)), 90 | } 91 | 92 | mockFailedErr := fmt.Errorf("mock form builder fail") 93 | mockBuilder := &mockFormBuilder{ 94 | mockCreateFormFileReader: func(string, io.Reader, string) error { 95 | return mockFailedErr 96 | }, 97 | } 98 | 99 | err := createFileField(req, mockBuilder) 100 | checks.ErrorIs(t, err, mockFailedErr, "createFileField using a reader should return error if form builder fails") 101 | }) 102 | 103 | t.Run("createFileField failing open", func(t *testing.T) { 104 | req := AudioRequest{ 105 | FilePath: "non_existing_file.wav", 106 | } 107 | 108 | mockBuilder := &mockFormBuilder{} 109 | 110 | err := createFileField(req, mockBuilder) 111 | checks.HasError(t, err, "createFileField using file should return error when open file fails") 112 | }) 113 | } 114 | -------------------------------------------------------------------------------- /files_test.go: -------------------------------------------------------------------------------- 1 | package openai //nolint:testpackage // testing private field 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "io" 7 | "os" 8 | "testing" 9 | 10 | utils "github.com/sashabaranov/go-openai/internal" 11 | "github.com/sashabaranov/go-openai/internal/test/checks" 12 | ) 13 | 14 | func TestFileBytesUploadWithFailingFormBuilder(t *testing.T) { 15 | config :=
DefaultConfig("") 16 | config.BaseURL = "" 17 | client := NewClientWithConfig(config) 18 | mockBuilder := &mockFormBuilder{} 19 | client.createFormBuilder = func(io.Writer) utils.FormBuilder { 20 | return mockBuilder 21 | } 22 | 23 | ctx := context.Background() 24 | req := FileBytesRequest{ 25 | Name: "foo", 26 | Bytes: []byte("foo"), 27 | Purpose: PurposeAssistants, 28 | } 29 | 30 | mockError := fmt.Errorf("mockWriteField error") 31 | mockBuilder.mockWriteField = func(string, string) error { 32 | return mockError 33 | } 34 | _, err := client.CreateFileBytes(ctx, req) 35 | checks.ErrorIs(t, err, mockError, "CreateFile should return error if form builder fails") 36 | 37 | mockError = fmt.Errorf("mockCreateFormFile error") 38 | mockBuilder.mockWriteField = func(string, string) error { 39 | return nil 40 | } 41 | mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error { 42 | return mockError 43 | } 44 | _, err = client.CreateFileBytes(ctx, req) 45 | checks.ErrorIs(t, err, mockError, "CreateFile should return error if form builder fails") 46 | 47 | mockError = fmt.Errorf("mockClose error") 48 | mockBuilder.mockWriteField = func(string, string) error { 49 | return nil 50 | } 51 | mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error { 52 | return nil 53 | } 54 | mockBuilder.mockClose = func() error { 55 | return mockError 56 | } 57 | _, err = client.CreateFileBytes(ctx, req) 58 | checks.ErrorIs(t, err, mockError, "CreateFile should return error if form builder fails") 59 | } 60 | 61 | func TestFileUploadWithFailingFormBuilder(t *testing.T) { 62 | config := DefaultConfig("") 63 | config.BaseURL = "" 64 | client := NewClientWithConfig(config) 65 | mockBuilder := &mockFormBuilder{} 66 | client.createFormBuilder = func(io.Writer) utils.FormBuilder { 67 | return mockBuilder 68 | } 69 | 70 | ctx := context.Background() 71 | req := FileRequest{ 72 | FileName: "test.go", 73 | FilePath: "client.go", 74 | Purpose: "fine-tune", 75 | } 76
| 77 | mockError := fmt.Errorf("mockWriteField error") 78 | mockBuilder.mockWriteField = func(string, string) error { 79 | return mockError 80 | } 81 | _, err := client.CreateFile(ctx, req) 82 | checks.ErrorIs(t, err, mockError, "CreateFile should return error if form builder fails") 83 | 84 | mockError = fmt.Errorf("mockCreateFormFile error") 85 | mockBuilder.mockWriteField = func(string, string) error { 86 | return nil 87 | } 88 | mockBuilder.mockCreateFormFile = func(string, *os.File) error { 89 | return mockError 90 | } 91 | _, err = client.CreateFile(ctx, req) 92 | checks.ErrorIs(t, err, mockError, "CreateFile should return error if form builder fails") 93 | 94 | mockError = fmt.Errorf("mockClose error") 95 | mockBuilder.mockWriteField = func(string, string) error { 96 | return nil 97 | } 98 | mockBuilder.mockCreateFormFile = func(string, *os.File) error { 99 | return nil 100 | } 101 | mockBuilder.mockClose = func() error { 102 | return mockError 103 | } 104 | _, err = client.CreateFile(ctx, req) 105 | if err == nil { 106 | t.Fatal("CreateFile should return error if form builder fails") 107 | } 108 | checks.ErrorIs(t, err, mockError, "CreateFile should return error if form builder fails") 109 | } 110 | 111 | func TestFileUploadWithNonExistentPath(t *testing.T) { 112 | config := DefaultConfig("") 113 | config.BaseURL = "" 114 | client := NewClientWithConfig(config) 115 | 116 | ctx := context.Background() 117 | req := FileRequest{ 118 | FilePath: "some non existent file path/F616FD18-589E-44A8-BF0C-891EAE69C455", 119 | } 120 | 121 | _, err := client.CreateFile(ctx, req) 122 | checks.ErrorIs(t, err, os.ErrNotExist, "CreateFile should return error if file does not exist") 123 | } 124 | -------------------------------------------------------------------------------- /chat_stream.go: -------------------------------------------------------------------------------- 1 | package openai 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | ) 7 | 8 | type
ChatCompletionStreamChoiceDelta struct { 9 | Content string `json:"content,omitempty"` 10 | Role string `json:"role,omitempty"` 11 | FunctionCall *FunctionCall `json:"function_call,omitempty"` 12 | ToolCalls []ToolCall `json:"tool_calls,omitempty"` 13 | Refusal string `json:"refusal,omitempty"` 14 | } 15 | 16 | type ChatCompletionStreamChoiceLogprobs struct { 17 | Content []ChatCompletionTokenLogprob `json:"content,omitempty"` 18 | Refusal []ChatCompletionTokenLogprob `json:"refusal,omitempty"` 19 | } 20 | 21 | type ChatCompletionTokenLogprob struct { 22 | Token string `json:"token"` 23 | Bytes []int64 `json:"bytes,omitempty"` 24 | Logprob float64 `json:"logprob,omitempty"` // NOTE(review): omitempty drops a logprob of exactly 0 (probability 1) when this struct is re-encoded 25 | TopLogprobs []ChatCompletionTokenLogprobTopLogprob `json:"top_logprobs"` 26 | } 27 | 28 | type ChatCompletionTokenLogprobTopLogprob struct { 29 | Token string `json:"token"` 30 | Bytes []int64 `json:"bytes"` 31 | Logprob float64 `json:"logprob"` 32 | } 33 | 34 | type ChatCompletionStreamChoice struct { 35 | Index int `json:"index"` 36 | Delta ChatCompletionStreamChoiceDelta `json:"delta"` 37 | Logprobs *ChatCompletionStreamChoiceLogprobs `json:"logprobs,omitempty"` 38 | FinishReason FinishReason `json:"finish_reason"` 39 | ContentFilterResults ContentFilterResults `json:"content_filter_results,omitempty"` 40 | } 41 | 42 | type PromptFilterResult struct { 43 | Index int `json:"index"` 44 | ContentFilterResults ContentFilterResults `json:"content_filter_results,omitempty"` 45 | } 46 | 47 | type ChatCompletionStreamResponse struct { 48 | ID string `json:"id"` 49 | Object string `json:"object"` 50 | Created int64 `json:"created"` 51 | Model string `json:"model"` 52 | Choices []ChatCompletionStreamChoice `json:"choices"` 53 | SystemFingerprint string `json:"system_fingerprint"` 54 | PromptAnnotations []PromptAnnotation `json:"prompt_annotations,omitempty"` 55 | PromptFilterResults []PromptFilterResult `json:"prompt_filter_results,omitempty"` 56 | // An optional field that will only be present
when you set stream_options: {"include_usage": true} in your request. 57 | // When present, it contains a null value except for the last chunk which contains the token usage statistics 58 | // for the entire request. 59 | Usage *Usage `json:"usage,omitempty"` 60 | } 61 | 62 | // ChatCompletionStream yields ChatCompletionStreamResponse chunks of a streaming chat 63 | // completion; reading is delegated to the embedded generic streamReader. 64 | type ChatCompletionStream struct { 65 | *streamReader[ChatCompletionStreamResponse] 66 | } 67 | 68 | // CreateChatCompletionStream — API call to create a chat completion w/ streaming 69 | // support. It always sets request.Stream to true, so tokens will be 70 | // sent as data-only server-sent events as they become available, with the 71 | // stream terminated by a data: [DONE] message. 72 | func (c *Client) CreateChatCompletionStream( 73 | ctx context.Context, 74 | request ChatCompletionRequest, 75 | ) (stream *ChatCompletionStream, err error) { 76 | urlSuffix := chatCompletionsSuffix 77 | if !checkEndpointSupportsModel(urlSuffix, request.Model) { 78 | err = ErrChatCompletionInvalidModel 79 | return 80 | } 81 | 82 | request.Stream = true 83 | if err = validateRequestForO1Models(request); err != nil { 84 | return 85 | } 86 | 87 | req, err := c.newRequest( 88 | ctx, 89 | http.MethodPost, 90 | c.fullURL(urlSuffix, withModel(request.Model)), 91 | withBody(request), 92 | ) 93 | if err != nil { 94 | return nil, err 95 | } 96 | 97 | resp, err := sendRequestStream[ChatCompletionStreamResponse](c, req) 98 | if err != nil { 99 | return 100 | } 101 | stream = &ChatCompletionStream{ 102 | streamReader: resp, 103 | } 104 | return 105 | } 106 | -------------------------------------------------------------------------------- /moderation.go: -------------------------------------------------------------------------------- 1 | package openai 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "net/http" 7 | ) 8 | 9 | // The moderation endpoint is a tool you can use to check whether content
complies with OpenAI's usage policies. 10 | // Developers can thus identify content that our usage policies prohibit and take action, for instance by filtering it. 11 | 12 | // The default is text-moderation-latest which will be automatically upgraded over time. 13 | // This ensures you are always using our most accurate model. 14 | // If you use text-moderation-stable, we will provide advance notice before updating the model. 15 | // Accuracy of text-moderation-stable may be slightly lower than for text-moderation-latest. 16 | const ( 17 | ModerationOmniLatest = "omni-moderation-latest" 18 | ModerationOmni20240926 = "omni-moderation-2024-09-26" 19 | ModerationTextStable = "text-moderation-stable" 20 | ModerationTextLatest = "text-moderation-latest" 21 | // Deprecated: use ModerationTextStable and ModerationTextLatest instead. 22 | ModerationText001 = "text-moderation-001" 23 | ) 24 | 25 | var ( 26 | ErrModerationInvalidModel = errors.New("this model is not supported with moderation, please use text-moderation-stable or text-moderation-latest instead") //nolint:lll // NOTE(review): message predates the omni-moderation models, which validModerationModel also accepts 27 | ) 28 | // validModerationModel lists the model names Moderations accepts; an empty Model skips this check. 29 | var validModerationModel = map[string]struct{}{ 30 | ModerationOmniLatest: {}, 31 | ModerationOmni20240926: {}, 32 | ModerationTextStable: {}, 33 | ModerationTextLatest: {}, 34 | } 35 | 36 | // ModerationRequest represents a request structure for moderation API. 37 | type ModerationRequest struct { 38 | Input string `json:"input,omitempty"` 39 | Model string `json:"model,omitempty"` 40 | } 41 | 42 | // Result represents one of the possible moderation results. 43 | type Result struct { 44 | Categories ResultCategories `json:"categories"` 45 | CategoryScores ResultCategoryScores `json:"category_scores"` 46 | Flagged bool `json:"flagged"` 47 | } 48 | 49 | // ResultCategories represents Categories of Result.
50 | type ResultCategories struct { 51 | Hate bool `json:"hate"` 52 | HateThreatening bool `json:"hate/threatening"` 53 | Harassment bool `json:"harassment"` 54 | HarassmentThreatening bool `json:"harassment/threatening"` 55 | SelfHarm bool `json:"self-harm"` 56 | SelfHarmIntent bool `json:"self-harm/intent"` 57 | SelfHarmInstructions bool `json:"self-harm/instructions"` 58 | Sexual bool `json:"sexual"` 59 | SexualMinors bool `json:"sexual/minors"` 60 | Violence bool `json:"violence"` 61 | ViolenceGraphic bool `json:"violence/graphic"` 62 | } 63 | 64 | // ResultCategoryScores holds the per-category scores of a Result, keyed like ResultCategories. 65 | type ResultCategoryScores struct { 66 | Hate float32 `json:"hate"` 67 | HateThreatening float32 `json:"hate/threatening"` 68 | Harassment float32 `json:"harassment"` 69 | HarassmentThreatening float32 `json:"harassment/threatening"` 70 | SelfHarm float32 `json:"self-harm"` 71 | SelfHarmIntent float32 `json:"self-harm/intent"` 72 | SelfHarmInstructions float32 `json:"self-harm/instructions"` 73 | Sexual float32 `json:"sexual"` 74 | SexualMinors float32 `json:"sexual/minors"` 75 | Violence float32 `json:"violence"` 76 | ViolenceGraphic float32 `json:"violence/graphic"` 77 | } 78 | 79 | // ModerationResponse represents a response structure for moderation API. 80 | type ModerationResponse struct { 81 | ID string `json:"id"` 82 | Model string `json:"model"` 83 | Results []Result `json:"results"` 84 | 85 | httpHeader 86 | } 87 | 88 | // Moderations performs a moderation API call (POST /moderations) over request.Input. 89 | // In this client ModerationRequest.Input is a single string, not an array.
90 | func (c *Client) Moderations(ctx context.Context, request ModerationRequest) (response ModerationResponse, err error) { 91 | if _, ok := validModerationModel[request.Model]; len(request.Model) > 0 && !ok { // an empty Model skips validation and is omitted from the JSON body (omitempty) 92 | err = ErrModerationInvalidModel 93 | return 94 | } 95 | req, err := c.newRequest( 96 | ctx, 97 | http.MethodPost, 98 | c.fullURL("/moderations", withModel(request.Model)), 99 | withBody(&request), 100 | ) 101 | if err != nil { 102 | return 103 | } 104 | 105 | err = c.sendRequest(req, &response) 106 | return 107 | } 108 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | ## Overview 4 | Thank you for your interest in contributing to the "Go OpenAI" project! By following this guideline, we hope to ensure that your contributions are made smoothly and efficiently. The Go OpenAI project is licensed under the [Apache 2.0 License](https://github.com/sashabaranov/go-openai/blob/master/LICENSE), and we welcome contributions through GitHub pull requests. 5 | 6 | ## Reporting Bugs 7 | If you discover a bug, first check the [GitHub Issues page](https://github.com/sashabaranov/go-openai/issues) to see if the issue has already been reported. If you're reporting a new issue, please use the "Bug report" template and provide detailed information about the problem, including steps to reproduce it. 8 | 9 | ## Suggesting Features 10 | If you want to suggest a new feature or improvement, first check the [GitHub Issues page](https://github.com/sashabaranov/go-openai/issues) to ensure a similar suggestion hasn't already been made. Use the "Feature request" template to provide a detailed description of your suggestion.
11 | 12 | ## Reporting Vulnerabilities 13 | If you identify a security concern, please use the "Report a security vulnerability" template on the [GitHub Issues page](https://github.com/sashabaranov/go-openai/issues) to share the details. This report will only be viewable to repository maintainers. You will be credited if the advisory is published. 14 | 15 | ## Questions for Users 16 | If you have questions, please use [StackOverflow](https://stackoverflow.com/) or the [GitHub Discussions page](https://github.com/sashabaranov/go-openai/discussions). 17 | 18 | ## Contributing Code 19 | A similar pull request may already have been submitted, so please search the existing [pull requests](https://github.com/sashabaranov/go-openai/pulls) before creating one. 20 | 21 | ### Requirements for Merging a Pull Request 22 | 23 | The requirements to accept a pull request are as follows: 24 | 25 | - Features not provided by the OpenAI API will not be accepted. 26 | - The functionality of the feature must match that of the official OpenAI API. 27 | - All pull requests should be written in Go according to common conventions, formatted with `goimports`, and free of warnings from tools like `golangci-lint`. 28 | - Include tests and ensure all tests pass. 29 | - Maintain test coverage without any reduction. 30 | - All pull requests require approval from at least one Go OpenAI maintainer. 31 | 32 | **Note:** 33 | Pull requests in this repository are merged using squash merge. 34 | 35 | ### Creating a Pull Request 36 | - Fork the repository. 37 | - Create a new branch and commit your changes. 38 | - Push that branch to GitHub. 39 | - Start a new Pull Request on GitHub. (Please use the pull request template to provide detailed information.) 40 | 41 | **Note:** 42 | If your changes introduce breaking changes, please prefix your pull request title with "[BREAKING_CHANGES]". 43 | 44 | ### Code Style 45 | In this project, we adhere to the standard coding style of Go.
Your code should maintain consistency with the rest of the codebase. To achieve this, please format your code using tools like `goimports` and resolve any syntax or style issues with `golangci-lint`. 46 | 47 | **Run goimports:** 48 | ``` 49 | go install golang.org/x/tools/cmd/goimports@latest 50 | ``` 51 | 52 | ``` 53 | goimports -w . 54 | ``` 55 | 56 | **Run golangci-lint:** 57 | ``` 58 | go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest 59 | ``` 60 | 61 | ``` 62 | golangci-lint run --out-format=github-actions 63 | ``` 64 | 65 | ### Unit Test 66 | Please create or update tests relevant to your changes. Ensure all tests run successfully to verify that your modifications do not adversely affect other functionality. 67 | 68 | **Run test:** 69 | ``` 70 | go test -v ./... 71 | ``` 72 | 73 | ### Integration Test 74 | Integration tests run against the production version of the OpenAI API. These tests will verify that the library is properly coded against the actual behavior of the API, and will fail upon any incompatible change in the API. 75 | 76 | **Notes:** 77 | These tests send real network traffic to the OpenAI API and may reach rate limits. Temporary network problems may also cause the tests to fail. 78 | 79 | **Run integration test:** 80 | ``` 81 | OPENAI_TOKEN=XXX go test -v -tags=integration ./api_integration_test.go 82 | ``` 83 | 84 | If the `OPENAI_TOKEN` environment variable is not set, integration tests will be skipped. 85 | 86 | --- 87 | 88 | We wholeheartedly welcome your active participation. Let's build an amazing project together!
89 | -------------------------------------------------------------------------------- /files.go: -------------------------------------------------------------------------------- 1 | package openai 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "fmt" 7 | "net/http" 8 | "os" 9 | ) 10 | // FileRequest describes a local file to upload with CreateFile; FilePath is tagged json:"-" and is never serialized. 11 | type FileRequest struct { 12 | FileName string `json:"file"` 13 | FilePath string `json:"-"` 14 | Purpose string `json:"purpose"` 15 | } 16 | 17 | // PurposeType represents the purpose of the file when uploading. 18 | type PurposeType string 19 | // Known PurposeType values; CreateFileBytes sends the chosen value as the "purpose" form field. 20 | const ( 21 | PurposeFineTune PurposeType = "fine-tune" 22 | PurposeFineTuneResults PurposeType = "fine-tune-results" 23 | PurposeAssistants PurposeType = "assistants" 24 | PurposeAssistantsOutput PurposeType = "assistants_output" 25 | PurposeBatch PurposeType = "batch" 26 | ) 27 | 28 | // FileBytesRequest represents a file upload request. 29 | type FileBytesRequest struct { 30 | // the name of the uploaded file in OpenAI 31 | Name string 32 | // the bytes of the file 33 | Bytes []byte 34 | // the purpose of the file 35 | Purpose PurposeType 36 | } 37 | 38 | // File represents an OpenAI file object, as decoded by CreateFile, CreateFileBytes and GetFile. 39 | type File struct { 40 | Bytes int `json:"bytes"` 41 | CreatedAt int64 `json:"created_at"` 42 | ID string `json:"id"` 43 | FileName string `json:"filename"` 44 | Object string `json:"object"` 45 | Status string `json:"status"` 46 | Purpose string `json:"purpose"` 47 | StatusDetails string `json:"status_details"` 48 | 49 | httpHeader 50 | } 51 | 52 | // FilesList is a list of files that belong to the user or organization. 53 | type FilesList struct { 54 | Files []File `json:"data"` 55 | 56 | httpHeader 57 | } 58 | 59 | // CreateFileBytes uploads bytes directly to OpenAI without requiring a local file.
60 | func (c *Client) CreateFileBytes(ctx context.Context, request FileBytesRequest) (file File, err error) { 61 | var b bytes.Buffer 62 | reader := bytes.NewReader(request.Bytes) // wraps the caller's slice without copying it 63 | builder := c.createFormBuilder(&b) 64 | 65 | err = builder.WriteField("purpose", string(request.Purpose)) 66 | if err != nil { 67 | return 68 | } 69 | 70 | err = builder.CreateFormFileReader("file", reader, request.Name) 71 | if err != nil { 72 | return 73 | } 74 | 75 | err = builder.Close() 76 | if err != nil { 77 | return 78 | } 79 | 80 | req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/files"), 81 | withBody(&b), withContentType(builder.FormDataContentType())) 82 | if err != nil { 83 | return 84 | } 85 | 86 | err = c.sendRequest(req, &file) 87 | return 88 | } 89 | 90 | // CreateFile uploads the local file at request.FilePath to OpenAI with the given purpose. 91 | // FilePath must be a local file path; the multipart form is assembled in an in-memory buffer before the request is sent. 92 | func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File, err error) { 93 | var b bytes.Buffer 94 | builder := c.createFormBuilder(&b) 95 | 96 | err = builder.WriteField("purpose", request.Purpose) 97 | if err != nil { 98 | return 99 | } 100 | 101 | fileData, err := os.Open(request.FilePath) 102 | if err != nil { 103 | return 104 | } 105 | defer fileData.Close() // runs on return, i.e. after the request has been sent 106 | 107 | err = builder.CreateFormFile("file", fileData) 108 | if err != nil { 109 | return 110 | } 111 | 112 | err = builder.Close() 113 | if err != nil { 114 | return 115 | } 116 | 117 | req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/files"), 118 | withBody(&b), withContentType(builder.FormDataContentType())) 119 | if err != nil { 120 | return 121 | } 122 | 123 | err = c.sendRequest(req, &file) 124 | return 125 | } 126 | 127 | // DeleteFile deletes an existing file.
128 | func (c *Client) DeleteFile(ctx context.Context, fileID string) (err error) { 129 | req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL("/files/"+fileID)) 130 | if err != nil { 131 | return 132 | } 133 | 134 | err = c.sendRequest(req, nil) 135 | return 136 | } 137 | 138 | // ListFiles lists the currently available files, 139 | // and provides basic information about each file such as the file name and purpose. 140 | func (c *Client) ListFiles(ctx context.Context) (files FilesList, err error) { 141 | req, err := c.newRequest(ctx, http.MethodGet, c.fullURL("/files")) 142 | if err != nil { 143 | return 144 | } 145 | 146 | err = c.sendRequest(req, &files) 147 | return 148 | } 149 | 150 | // GetFile retrieves a file instance, providing basic information about the file 151 | // such as the file name and purpose. 152 | func (c *Client) GetFile(ctx context.Context, fileID string) (file File, err error) { 153 | urlSuffix := fmt.Sprintf("/files/%s", fileID) 154 | req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) 155 | if err != nil { 156 | return 157 | } 158 | 159 | err = c.sendRequest(req, &file) 160 | return 161 | } 162 | // GetFileContent downloads the raw content of the file with the given ID via GET /files/{fileID}/content. 163 | func (c *Client) GetFileContent(ctx context.Context, fileID string) (content RawResponse, err error) { 164 | urlSuffix := fmt.Sprintf("/files/%s/content", fileID) 165 | req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) 166 | if err != nil { 167 | return 168 | } 169 | 170 | return c.sendRequestRaw(req) 171 | } 172 | -------------------------------------------------------------------------------- /audio_api_test.go: -------------------------------------------------------------------------------- 1 | package openai_test 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "errors" 7 | "io" 8 | "mime" 9 | "mime/multipart" 10 | "net/http" 11 | "path/filepath" 12 | "strings" 13 | "testing" 14 | 15 | "github.com/sashabaranov/go-openai" 16 | "github.com/sashabaranov/go-openai/internal/test" 17 |
"github.com/sashabaranov/go-openai/internal/test/checks" 18 | ) 19 | 20 | // TestAudio Tests the transcription and translation endpoints of the API using the mocked server. 21 | func TestAudio(t *testing.T) { 22 | client, server, teardown := setupOpenAITestServer() 23 | defer teardown() 24 | server.RegisterHandler("/v1/audio/transcriptions", handleAudioEndpoint) 25 | server.RegisterHandler("/v1/audio/translations", handleAudioEndpoint) 26 | 27 | testcases := []struct { 28 | name string 29 | createFn func(context.Context, openai.AudioRequest) (openai.AudioResponse, error) 30 | }{ 31 | { 32 | "transcribe", 33 | client.CreateTranscription, 34 | }, 35 | { 36 | "translate", 37 | client.CreateTranslation, 38 | }, 39 | } 40 | 41 | ctx := context.Background() 42 | 43 | dir, cleanup := test.CreateTestDirectory(t) 44 | defer cleanup() 45 | 46 | for _, tc := range testcases { 47 | t.Run(tc.name, func(t *testing.T) { 48 | path := filepath.Join(dir, "fake.mp3") 49 | test.CreateTestFile(t, path) 50 | 51 | req := openai.AudioRequest{ 52 | FilePath: path, 53 | Model: "whisper-3", 54 | } 55 | _, err := tc.createFn(ctx, req) 56 | checks.NoError(t, err, "audio API error") 57 | }) 58 | 59 | t.Run(tc.name+" (with reader)", func(t *testing.T) { 60 | req := openai.AudioRequest{ 61 | FilePath: "fake.webm", 62 | Reader: bytes.NewBuffer([]byte(`some webm binary data`)), 63 | Model: "whisper-3", 64 | } 65 | _, err := tc.createFn(ctx, req) 66 | checks.NoError(t, err, "audio API error") 67 | }) 68 | } 69 | } 70 | 71 | func TestAudioWithOptionalArgs(t *testing.T) { 72 | client, server, teardown := setupOpenAITestServer() 73 | defer teardown() 74 | server.RegisterHandler("/v1/audio/transcriptions", handleAudioEndpoint) 75 | server.RegisterHandler("/v1/audio/translations", handleAudioEndpoint) 76 | 77 | testcases := []struct { 78 | name string 79 | createFn func(context.Context, openai.AudioRequest) (openai.AudioResponse, error) 80 | }{ 81 | { 82 | "transcribe", 83 | client.CreateTranscription,
}, 85 | { 86 | "translate", 87 | client.CreateTranslation, 88 | }, 89 | } 90 | 91 | ctx := context.Background() 92 | 93 | dir, cleanup := test.CreateTestDirectory(t) 94 | defer cleanup() 95 | 96 | for _, tc := range testcases { 97 | t.Run(tc.name, func(t *testing.T) { 98 | path := filepath.Join(dir, "fake.mp3") 99 | test.CreateTestFile(t, path) 100 | 101 | req := openai.AudioRequest{ 102 | FilePath: path, 103 | Model: "whisper-3", 104 | Prompt: "用简体中文", 105 | Temperature: 0.5, 106 | Language: "zh", 107 | Format: openai.AudioResponseFormatSRT, 108 | TimestampGranularities: []openai.TranscriptionTimestampGranularity{ 109 | openai.TranscriptionTimestampGranularitySegment, 110 | openai.TranscriptionTimestampGranularityWord, 111 | }, 112 | } 113 | _, err := tc.createFn(ctx, req) 114 | checks.NoError(t, err, "audio API error") 115 | }) 116 | } 117 | } 118 | 119 | // handleAudioEndpoint handles the audio transcription and translation endpoints of the test server: it accepts only a POST multipart request whose first part is non-empty, and replies with a fixed JSON body. 120 | func handleAudioEndpoint(w http.ResponseWriter, r *http.Request) { 121 | // audio endpoints only accept POST requests 122 | if r.Method != http.MethodPost { 123 | http.Error(w, "method not allowed", http.StatusMethodNotAllowed) 124 | return 125 | } 126 | 127 | mediaType, params, err := mime.ParseMediaType(r.Header.Get("Content-Type")) 128 | if err != nil { 129 | http.Error(w, "failed to parse media type", http.StatusBadRequest) 130 | return 131 | } 132 | 133 | if !strings.HasPrefix(mediaType, "multipart") { 134 | http.Error(w, "request is not multipart", http.StatusBadRequest) 135 | return 136 | } 137 | 138 | boundary, ok := params["boundary"] 139 | if !ok { 140 | http.Error(w, "no boundary in params", http.StatusBadRequest) 141 | return 142 | } 143 | 144 | fileData := &bytes.Buffer{} 145 | mr := multipart.NewReader(r.Body, boundary) 146 | part, err := mr.NextPart() 147 | if errors.Is(err, io.EOF) { // the form contains no parts at all 148 | http.Error(w, "no file in multipart body", http.StatusBadRequest) 149 | return 150 | } 151 | if err != nil { // any other error means the body is malformed and part must not be read 152 | http.Error(w, "error accessing file", http.StatusBadRequest) 153 | return 154 | } 155 | if _, err = io.Copy(fileData,
part); err != nil { 156 | http.Error(w, "failed to copy file", http.StatusInternalServerError) 157 | return 158 | } 159 | 160 | if fileData.Len() == 0 { 161 | http.Error(w, "received empty file data", http.StatusBadRequest) 162 | return 163 | } 164 | 165 | if _, err = w.Write([]byte(`{"body": "hello"}`)); err != nil { 166 | http.Error(w, "failed to write body", http.StatusInternalServerError) 167 | return 168 | } 169 | } 170 | -------------------------------------------------------------------------------- /fine_tuning_job.go: -------------------------------------------------------------------------------- 1 | package openai 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "net/http" 7 | "net/url" 8 | ) 9 | // FineTuningJob is a fine-tuning job as returned by CreateFineTuningJob, CancelFineTuningJob and RetrieveFineTuningJob. 10 | type FineTuningJob struct { 11 | ID string `json:"id"` 12 | Object string `json:"object"` 13 | CreatedAt int64 `json:"created_at"` 14 | FinishedAt int64 `json:"finished_at"` 15 | Model string `json:"model"` 16 | FineTunedModel string `json:"fine_tuned_model,omitempty"` 17 | OrganizationID string `json:"organization_id"` 18 | Status string `json:"status"` 19 | Hyperparameters Hyperparameters `json:"hyperparameters"` 20 | TrainingFile string `json:"training_file"` 21 | ValidationFile string `json:"validation_file,omitempty"` 22 | ResultFiles []string `json:"result_files"` 23 | TrainedTokens int `json:"trained_tokens"` 24 | 25 | httpHeader 26 | } 27 | // Hyperparameters are typed as any, presumably so that either numbers or strings such as "auto" can be sent — confirm against the API. 28 | type Hyperparameters struct { 29 | Epochs any `json:"n_epochs,omitempty"` 30 | LearningRateMultiplier any `json:"learning_rate_multiplier,omitempty"` 31 | BatchSize any `json:"batch_size,omitempty"` 32 | } 33 | // FineTuningJobRequest is the request body for CreateFineTuningJob. 34 | type FineTuningJobRequest struct { 35 | TrainingFile string `json:"training_file"` 36 | ValidationFile string `json:"validation_file,omitempty"` 37 | Model string `json:"model,omitempty"` 38 | Hyperparameters *Hyperparameters `json:"hyperparameters,omitempty"` 39 | Suffix string `json:"suffix,omitempty"` 40 | } 41 | // FineTuningJobEventList is one page of events returned by ListFineTuningJobEvents. 42 | type
FineTuningJobEventList struct { 43 | Object string `json:"object"` 44 | Data []FineTuneEvent `json:"data"` // NOTE(review): elements are FineTuneEvent (declared elsewhere), not FineTuningJobEvent below; changing the type would break callers, so confirm intent first 45 | HasMore bool `json:"has_more"` 46 | 47 | httpHeader 48 | } 49 | // FineTuningJobEvent describes a single fine-tuning job event. 50 | type FineTuningJobEvent struct { 51 | Object string `json:"object"` 52 | ID string `json:"id"` 53 | CreatedAt int `json:"created_at"` 54 | Level string `json:"level"` 55 | Message string `json:"message"` 56 | Data any `json:"data"` 57 | Type string `json:"type"` 58 | } 59 | 60 | // CreateFineTuningJob creates a fine tuning job. 61 | func (c *Client) CreateFineTuningJob( 62 | ctx context.Context, 63 | request FineTuningJobRequest, 64 | ) (response FineTuningJob, err error) { 65 | urlSuffix := "/fine_tuning/jobs" 66 | req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request)) 67 | if err != nil { 68 | return 69 | } 70 | 71 | err = c.sendRequest(req, &response) 72 | return 73 | } 74 | 75 | // CancelFineTuningJob cancels a fine tuning job. 76 | func (c *Client) CancelFineTuningJob(ctx context.Context, fineTuningJobID string) (response FineTuningJob, err error) { 77 | req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/fine_tuning/jobs/"+fineTuningJobID+"/cancel")) 78 | if err != nil { 79 | return 80 | } 81 | 82 | err = c.sendRequest(req, &response) 83 | return 84 | } 85 | 86 | // RetrieveFineTuningJob retrieves a fine tuning job.
87 | func (c *Client) RetrieveFineTuningJob( 88 | ctx context.Context, 89 | fineTuningJobID string, 90 | ) (response FineTuningJob, err error) { 91 | urlSuffix := fmt.Sprintf("/fine_tuning/jobs/%s", fineTuningJobID) 92 | req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) 93 | if err != nil { 94 | return 95 | } 96 | 97 | err = c.sendRequest(req, &response) 98 | return 99 | } 100 | // listFineTuningJobEventsParameters holds the optional query parameters of ListFineTuningJobEvents; a nil field means "not set". 101 | type listFineTuningJobEventsParameters struct { 102 | after *string 103 | limit *int 104 | } 105 | // ListFineTuningJobEventsParameter is a functional option for ListFineTuningJobEvents. 106 | type ListFineTuningJobEventsParameter func(*listFineTuningJobEventsParameters) 107 | // ListFineTuningJobEventsWithAfter sets the "after" query parameter. 108 | func ListFineTuningJobEventsWithAfter(after string) ListFineTuningJobEventsParameter { 109 | return func(args *listFineTuningJobEventsParameters) { 110 | args.after = &after 111 | } 112 | } 113 | // ListFineTuningJobEventsWithLimit sets the "limit" query parameter. 114 | func ListFineTuningJobEventsWithLimit(limit int) ListFineTuningJobEventsParameter { 115 | return func(args *listFineTuningJobEventsParameters) { 116 | args.limit = &limit 117 | } 118 | } 119 | 120 | // ListFineTuningJobEvents lists the events of a fine tuning job. 121 | func (c *Client) ListFineTuningJobEvents( 122 | ctx context.Context, 123 | fineTuningJobID string, 124 | setters ...ListFineTuningJobEventsParameter, 125 | ) (response FineTuningJobEventList, err error) { 126 | parameters := &listFineTuningJobEventsParameters{ 127 | after: nil, 128 | limit: nil, 129 | } 130 | 131 | for _, setter := range setters { 132 | setter(parameters) 133 | } 134 | 135 | urlValues := url.Values{} 136 | if parameters.after != nil { 137 | urlValues.Add("after", *parameters.after) 138 | } 139 | if parameters.limit != nil { 140 | urlValues.Add("limit", fmt.Sprintf("%d", *parameters.limit)) 141 | } 142 | 143 | encodedValues := "" 144 | if len(urlValues) > 0 { 145 | encodedValues = "?"
+ urlValues.Encode() 146 | } 147 | 148 | req, err := c.newRequest( 149 | ctx, 150 | http.MethodGet, 151 | c.fullURL("/fine_tuning/jobs/"+fineTuningJobID+"/events"+encodedValues), 152 | ) 153 | if err != nil { 154 | return 155 | } 156 | 157 | err = c.sendRequest(req, &response) 158 | return 159 | } 160 | -------------------------------------------------------------------------------- /jsonschema/json.go: -------------------------------------------------------------------------------- 1 | // Package jsonschema provides very simple functionality for representing a JSON schema as a 2 | // (nested) struct. This struct can be used with the chat completion "function call" feature. 3 | // For more complicated schemas, it is recommended to use a dedicated JSON schema library 4 | // and/or pass in the schema in []byte format. 5 | package jsonschema 6 | 7 | import ( 8 | "encoding/json" 9 | "fmt" 10 | "reflect" 11 | "strconv" 12 | "strings" 13 | ) 14 | // DataType is a JSON Schema type name, as stored in Definition.Type. 15 | type DataType string 16 | 17 | const ( 18 | Object DataType = "object" 19 | Number DataType = "number" 20 | Integer DataType = "integer" 21 | String DataType = "string" 22 | Array DataType = "array" 23 | Null DataType = "null" 24 | Boolean DataType = "boolean" 25 | ) 26 | 27 | // Definition is a struct for describing a JSON Schema. 28 | // It is fairly limited, and you may have better luck using a third-party library. 29 | type Definition struct { 30 | // Type specifies the data type of the schema. 31 | Type DataType `json:"type,omitempty"` 32 | // Description is the description of the schema. 33 | Description string `json:"description,omitempty"` 34 | // Enum is used to restrict a value to a fixed set of values. It must be an array with at least 35 | // one element, where each element is unique. You will probably only use this with strings. 36 | Enum []string `json:"enum,omitempty"` 37 | // Properties describes the properties of an object, if the schema type is Object.
38 | Properties map[string]Definition `json:"properties,omitempty"` 39 | // Required specifies which properties are required, if the schema type is Object. 40 | Required []string `json:"required,omitempty"` 41 | // Items specifies which data type an array contains, if the schema type is Array. 42 | Items *Definition `json:"items,omitempty"` 43 | // AdditionalProperties is used to control the handling of properties in an object 44 | // that are not explicitly defined in the properties section of the schema. example: 45 | // additionalProperties: true 46 | // additionalProperties: false 47 | // additionalProperties: jsonschema.Definition{Type: jsonschema.String} 48 | AdditionalProperties any `json:"additionalProperties,omitempty"` 49 | } 50 | // MarshalJSON encodes d with an empty "properties" object when Properties is nil; note that it assigns that empty map to d.Properties as a side effect. 51 | func (d *Definition) MarshalJSON() ([]byte, error) { 52 | if d.Properties == nil { 53 | d.Properties = make(map[string]Definition) 54 | } 55 | type Alias Definition 56 | return json.Marshal(struct { 57 | Alias 58 | }{ 59 | Alias: (Alias)(*d), 60 | }) 61 | } 62 | // Unmarshal forwards d, content and v to VerifySchemaAndUnmarshal (declared elsewhere), presumably validating content against d before decoding it into v. 63 | func (d *Definition) Unmarshal(content string, v any) error { 64 | return VerifySchemaAndUnmarshal(*d, []byte(content), v) 65 | } 66 | // GenerateSchemaForType derives a Definition from the dynamic type of v via reflection. A nil v has no dynamic type and causes a panic. 67 | func GenerateSchemaForType(v any) (*Definition, error) { 68 | return reflectSchema(reflect.TypeOf(v)) 69 | } 70 | // reflectSchema maps t to a Definition, recursing into element, pointee and struct field types; unsupported kinds yield an error. 71 | func reflectSchema(t reflect.Type) (*Definition, error) { 72 | var d Definition 73 | switch t.Kind() { 74 | case reflect.String: 75 | d.Type = String 76 | case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, 77 | reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: 78 | d.Type = Integer 79 | case reflect.Float32, reflect.Float64: 80 | d.Type = Number 81 | case reflect.Bool: 82 | d.Type = Boolean 83 | case reflect.Slice, reflect.Array: 84 | d.Type = Array 85 | items, err := reflectSchema(t.Elem()) 86 | if err != nil { 87 | return nil, err 88 | } 89 | d.Items = items 90 | case reflect.Struct: 91 | d.Type = Object 92 | d.AdditionalProperties = false 93 | object,
err := reflectSchemaObject(t) 94 | if err != nil { 95 | return nil, err 96 | } 97 | d = *object 98 | case reflect.Ptr: 99 | definition, err := reflectSchema(t.Elem()) 100 | if err != nil { 101 | return nil, err 102 | } 103 | d = *definition 104 | case reflect.Invalid, reflect.Uintptr, reflect.Complex64, reflect.Complex128, 105 | reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, 106 | reflect.UnsafePointer: 107 | return nil, fmt.Errorf("unsupported type: %s", t.Kind().String()) 108 | default: 109 | } 110 | return &d, nil 111 | } 112 | // reflectSchemaObject builds a closed (AdditionalProperties: false) Object definition from the exported fields of struct type t, honouring the json, description and required struct tags. 113 | func reflectSchemaObject(t reflect.Type) (*Definition, error) { 114 | var d = Definition{ 115 | Type: Object, 116 | AdditionalProperties: false, 117 | } 118 | properties := make(map[string]Definition) 119 | var requiredFields []string 120 | for i := 0; i < t.NumField(); i++ { 121 | field := t.Field(i) 122 | if !field.IsExported() || field.Tag.Get("json") == "-" { // encoding/json never encodes json:"-" fields 123 | continue 124 | } 125 | jsonTag, jsonOpts, _ := strings.Cut(field.Tag.Get("json"), ",") // split "name,opt1,opt2" the way encoding/json does 126 | var required = true 127 | if jsonTag == "" { // no json tag, or an options-only tag such as json:",omitempty" 128 | jsonTag = field.Name 129 | } 130 | if strings.Contains(","+jsonOpts+",", ",omitempty,") { 131 | required = false 132 | } 133 | 134 | item, err := reflectSchema(field.Type) 135 | if err != nil { 136 | return nil, err 137 | } 138 | description := field.Tag.Get("description") 139 | if description != "" { 140 | item.Description = description 141 | } 142 | properties[jsonTag] = *item 143 | 144 | if v, parseErr := strconv.ParseBool(field.Tag.Get("required")); parseErr == nil { 145 | required = v // a valid required tag overrides the json-derived default; a missing or unparsable tag leaves it unchanged 146 | } 147 | if required { 148 | requiredFields = append(requiredFields, jsonTag) 149 | } 150 | } 151 | d.Required = requiredFields 152 | d.Properties = properties 153 | return &d, nil 154 | } 155 | -------------------------------------------------------------------------------- /thread_test.go: -------------------------------------------------------------------------------- 1 | package openai_test 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 |
"fmt" 7 | "net/http" 8 | "testing" 9 | 10 | openai "github.com/sashabaranov/go-openai" 11 | "github.com/sashabaranov/go-openai/internal/test/checks" 12 | ) 13 | 14 | // TestThread Tests the thread endpoint of the API using the mocked server. 15 | func TestThread(t *testing.T) { 16 | threadID := "thread_abc123" 17 | client, server, teardown := setupOpenAITestServer() 18 | defer teardown() 19 | 20 | server.RegisterHandler( 21 | "/v1/threads/"+threadID, 22 | func(w http.ResponseWriter, r *http.Request) { 23 | switch r.Method { 24 | case http.MethodGet: 25 | resBytes, _ := json.Marshal(openai.Thread{ 26 | ID: threadID, 27 | Object: "thread", 28 | CreatedAt: 1234567890, 29 | }) 30 | fmt.Fprintln(w, string(resBytes)) 31 | case http.MethodPost: 32 | var request openai.ThreadRequest 33 | err := json.NewDecoder(r.Body).Decode(&request) 34 | checks.NoError(t, err, "Decode error") 35 | 36 | resBytes, _ := json.Marshal(openai.Thread{ 37 | ID: threadID, 38 | Object: "thread", 39 | CreatedAt: 1234567890, 40 | }) 41 | fmt.Fprintln(w, string(resBytes)) 42 | case http.MethodDelete: 43 | fmt.Fprintln(w, `{ 44 | "id": "thread_abc123", 45 | "object": "thread.deleted", 46 | "deleted": true 47 | }`) 48 | } 49 | }, 50 | ) 51 | 52 | server.RegisterHandler( 53 | "/v1/threads", 54 | func(w http.ResponseWriter, r *http.Request) { 55 | if r.Method == http.MethodPost { 56 | var request openai.ModifyThreadRequest 57 | err := json.NewDecoder(r.Body).Decode(&request) 58 | checks.NoError(t, err, "Decode error") 59 | 60 | resBytes, _ := json.Marshal(openai.Thread{ 61 | ID: threadID, 62 | Object: "thread", 63 | CreatedAt: 1234567890, 64 | Metadata: request.Metadata, 65 | }) 66 | fmt.Fprintln(w, string(resBytes)) 67 | } 68 | }, 69 | ) 70 | 71 | ctx := context.Background() 72 | 73 | _, err := client.CreateThread(ctx, openai.ThreadRequest{ 74 | Messages: []openai.ThreadMessage{ 75 | { 76 | Role: openai.ThreadMessageRoleUser, 77 | Content: "Hello, World!", 78 | }, 79 | }, 80 | }) 81 | checks.NoError(t,
err, "CreateThread error") 82 | 83 | _, err = client.RetrieveThread(ctx, threadID) 84 | checks.NoError(t, err, "RetrieveThread error") 85 | 86 | _, err = client.ModifyThread(ctx, threadID, openai.ModifyThreadRequest{ 87 | Metadata: map[string]interface{}{ 88 | "key": "value", 89 | }, 90 | }) 91 | checks.NoError(t, err, "ModifyThread error") 92 | 93 | _, err = client.DeleteThread(ctx, threadID) 94 | checks.NoError(t, err, "DeleteThread error") 95 | } 96 | 97 | // TestAzureThread Tests the thread endpoint of the API using the Azure mocked server. 98 | func TestAzureThread(t *testing.T) { 99 | threadID := "thread_abc123" 100 | client, server, teardown := setupAzureTestServer() 101 | defer teardown() 102 | 103 | server.RegisterHandler( 104 | "/openai/threads/"+threadID, 105 | func(w http.ResponseWriter, r *http.Request) { 106 | switch r.Method { 107 | case http.MethodGet: 108 | resBytes, _ := json.Marshal(openai.Thread{ 109 | ID: threadID, 110 | Object: "thread", 111 | CreatedAt: 1234567890, 112 | }) 113 | fmt.Fprintln(w, string(resBytes)) 114 | case http.MethodPost: 115 | var request openai.ThreadRequest 116 | err := json.NewDecoder(r.Body).Decode(&request) 117 | checks.NoError(t, err, "Decode error") 118 | 119 | resBytes, _ := json.Marshal(openai.Thread{ 120 | ID: threadID, 121 | Object: "thread", 122 | CreatedAt: 1234567890, 123 | }) 124 | fmt.Fprintln(w, string(resBytes)) 125 | case http.MethodDelete: 126 | fmt.Fprintln(w, `{ 127 | "id": "thread_abc123", 128 | "object": "thread.deleted", 129 | "deleted": true 130 | }`) 131 | } 132 | }, 133 | ) 134 | 135 | server.RegisterHandler( 136 | "/openai/threads", 137 | func(w http.ResponseWriter, r *http.Request) { 138 | if r.Method == http.MethodPost { 139 | var request openai.ModifyThreadRequest 140 | err := json.NewDecoder(r.Body).Decode(&request) 141 | checks.NoError(t, err, "Decode error") 142 | 143 | resBytes, _ := json.Marshal(openai.Thread{ 144 | ID: threadID, 145 | Object: "thread", 146 | CreatedAt: 1234567890, 147 |
Metadata: request.Metadata, 148 | }) 149 | fmt.Fprintln(w, string(resBytes)) 150 | } 151 | }, 152 | ) 153 | 154 | ctx := context.Background() 155 | 156 | _, err := client.CreateThread(ctx, openai.ThreadRequest{ 157 | Messages: []openai.ThreadMessage{ 158 | { 159 | Role: openai.ThreadMessageRoleUser, 160 | Content: "Hello, World!", 161 | }, 162 | }, 163 | }) 164 | checks.NoError(t, err, "CreateThread error") 165 | 166 | _, err = client.RetrieveThread(ctx, threadID) 167 | checks.NoError(t, err, "RetrieveThread error") 168 | 169 | _, err = client.ModifyThread(ctx, threadID, openai.ModifyThreadRequest{ 170 | Metadata: map[string]interface{}{ 171 | "key": "value", 172 | }, 173 | }) 174 | checks.NoError(t, err, "ModifyThread error") 175 | 176 | _, err = client.DeleteThread(ctx, threadID) 177 | checks.NoError(t, err, "DeleteThread error") 178 | } 179 | -------------------------------------------------------------------------------- /image_test.go: -------------------------------------------------------------------------------- 1 | package openai //nolint:testpackage // testing private field 2 | 3 | import ( 4 | utils "github.com/sashabaranov/go-openai/internal" 5 | "github.com/sashabaranov/go-openai/internal/test/checks" 6 | 7 | "context" 8 | "fmt" 9 | "io" 10 | "os" 11 | "testing" 12 | ) 13 | 14 | type mockFormBuilder struct { 15 | mockCreateFormFile func(string, *os.File) error 16 | mockCreateFormFileReader func(string, io.Reader, string) error 17 | mockWriteField func(string, string) error 18 | mockClose func() error 19 | } 20 | 21 | func (fb *mockFormBuilder) CreateFormFile(fieldname string, file *os.File) error { 22 | return fb.mockCreateFormFile(fieldname, file) 23 | } 24 | 25 | func (fb *mockFormBuilder) CreateFormFileReader(fieldname string, r io.Reader, filename string) error { 26 | return fb.mockCreateFormFileReader(fieldname, r, filename) 27 | } 28 | 29 | func (fb *mockFormBuilder) WriteField(fieldname, value string) error { 30 | return
fb.mockWriteField(fieldname, value) 31 | } 32 | 33 | func (fb *mockFormBuilder) Close() error { 34 | return fb.mockClose() 35 | } 36 | 37 | func (fb *mockFormBuilder) FormDataContentType() string { 38 | return "" 39 | } 40 | 41 | func TestImageFormBuilderFailures(t *testing.T) { 42 | config := DefaultConfig("") 43 | config.BaseURL = "" 44 | client := NewClientWithConfig(config) 45 | 46 | mockBuilder := &mockFormBuilder{} 47 | client.createFormBuilder = func(io.Writer) utils.FormBuilder { 48 | return mockBuilder 49 | } 50 | ctx := context.Background() 51 | 52 | req := ImageEditRequest{ 53 | Mask: &os.File{}, 54 | } 55 | 56 | mockFailedErr := fmt.Errorf("mock form builder fail") 57 | mockBuilder.mockCreateFormFile = func(string, *os.File) error { 58 | return mockFailedErr 59 | } 60 | _, err := client.CreateEditImage(ctx, req) 61 | checks.ErrorIs(t, err, mockFailedErr, "CreateEditImage should return error if form builder fails") 62 | 63 | mockBuilder.mockCreateFormFile = func(name string, _ *os.File) error { 64 | if name == "mask" { 65 | return mockFailedErr 66 | } 67 | return nil 68 | } 69 | _, err = client.CreateEditImage(ctx, req) 70 | checks.ErrorIs(t, err, mockFailedErr, "CreateEditImage should return error if form builder fails") 71 | 72 | mockBuilder.mockCreateFormFile = func(string, *os.File) error { 73 | return nil 74 | } 75 | 76 | var failForField string 77 | mockBuilder.mockWriteField = func(fieldname, _ string) error { 78 | if fieldname == failForField { 79 | return mockFailedErr 80 | } 81 | return nil 82 | } 83 | 84 | failForField = "prompt" 85 | _, err = client.CreateEditImage(ctx, req) 86 | checks.ErrorIs(t, err, mockFailedErr, "CreateEditImage should return error if form builder fails") 87 | 88 | failForField = "n" 89 | _, err = client.CreateEditImage(ctx, req) 90 | checks.ErrorIs(t, err, mockFailedErr, "CreateEditImage should return error if form builder fails") 91 | 92 | failForField = "size" 93 | _, err = client.CreateEditImage(ctx, req) 94 | checks.ErrorIs(t, err,
mockFailedErr, "CreateEditImage should return error if form builder fails") 95 | 96 | failForField = "response_format" 97 | _, err = client.CreateEditImage(ctx, req) 98 | checks.ErrorIs(t, err, mockFailedErr, "CreateEditImage should return error if form builder fails") 99 | 100 | failForField = "" 101 | mockBuilder.mockClose = func() error { 102 | return mockFailedErr 103 | } 104 | _, err = client.CreateEditImage(ctx, req) 105 | checks.ErrorIs(t, err, mockFailedErr, "CreateEditImage should return error if form builder fails") 106 | } 107 | 108 | func TestVariImageFormBuilderFailures(t *testing.T) { 109 | config := DefaultConfig("") 110 | config.BaseURL = "" 111 | client := NewClientWithConfig(config) 112 | 113 | mockBuilder := &mockFormBuilder{} 114 | client.createFormBuilder = func(io.Writer) utils.FormBuilder { 115 | return mockBuilder 116 | } 117 | ctx := context.Background() 118 | 119 | req := ImageVariRequest{} 120 | 121 | mockFailedErr := fmt.Errorf("mock form builder fail") 122 | mockBuilder.mockCreateFormFile = func(string, *os.File) error { 123 | return mockFailedErr 124 | } 125 | _, err := client.CreateVariImage(ctx, req) 126 | checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails") 127 | 128 | mockBuilder.mockCreateFormFile = func(string, *os.File) error { 129 | return nil 130 | } 131 | 132 | var failForField string 133 | mockBuilder.mockWriteField = func(fieldname, _ string) error { 134 | if fieldname == failForField { 135 | return mockFailedErr 136 | } 137 | return nil 138 | } 139 | 140 | failForField = "n" 141 | _, err = client.CreateVariImage(ctx, req) 142 | checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails") 143 | 144 | failForField = "size" 145 | _, err = client.CreateVariImage(ctx, req) 146 | checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails") 147 | 148 | failForField = "response_format" 149 | _, err =
client.CreateVariImage(ctx, req) 150 | checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails") 151 | 152 | failForField = "" 153 | mockBuilder.mockClose = func() error { 154 | return mockFailedErr 155 | } 156 | _, err = client.CreateVariImage(ctx, req) 157 | checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails") 158 | } 159 | -------------------------------------------------------------------------------- /jsonschema/json_test.go: -------------------------------------------------------------------------------- 1 | package jsonschema_test 2 | 3 | import ( 4 | "encoding/json" 5 | "reflect" 6 | "testing" 7 | 8 | "github.com/sashabaranov/go-openai/jsonschema" 9 | ) 10 | 11 | func TestDefinition_MarshalJSON(t *testing.T) { 12 | tests := []struct { 13 | name string 14 | def jsonschema.Definition 15 | want string 16 | }{ 17 | { 18 | name: "Test with empty Definition", 19 | def: jsonschema.Definition{}, 20 | want: `{"properties":{}}`, 21 | }, 22 | { 23 | name: "Test with Definition properties set", 24 | def: jsonschema.Definition{ 25 | Type: jsonschema.String, 26 | Description: "A string type", 27 | Properties: map[string]jsonschema.Definition{ 28 | "name": { 29 | Type: jsonschema.String, 30 | }, 31 | }, 32 | }, 33 | want: `{ 34 | "type":"string", 35 | "description":"A string type", 36 | "properties":{ 37 | "name":{ 38 | "type":"string", 39 | "properties":{} 40 | } 41 | } 42 | }`, 43 | }, 44 | { 45 | name: "Test with nested Definition properties", 46 | def: jsonschema.Definition{ 47 | Type: jsonschema.Object, 48 | Properties: map[string]jsonschema.Definition{ 49 | "user": { 50 | Type: jsonschema.Object, 51 | Properties: map[string]jsonschema.Definition{ 52 | "name": { 53 | Type: jsonschema.String, 54 | }, 55 | "age": { 56 | Type: jsonschema.Integer, 57 | }, 58 | }, 59 | }, 60 | }, 61 | }, 62 | want: `{ 63 | "type":"object", 64 | "properties":{ 65 | "user":{ 66 | "type":"object", 67 |
"properties":{ 68 | "name":{ 69 | "type":"string", 70 | "properties":{} 71 | }, 72 | "age":{ 73 | "type":"integer", 74 | "properties":{} 75 | } 76 | } 77 | } 78 | } 79 | }`, 80 | }, 81 | { 82 | name: "Test with complex nested Definition", 83 | def: jsonschema.Definition{ 84 | Type: jsonschema.Object, 85 | Properties: map[string]jsonschema.Definition{ 86 | "user": { 87 | Type: jsonschema.Object, 88 | Properties: map[string]jsonschema.Definition{ 89 | "name": { 90 | Type: jsonschema.String, 91 | }, 92 | "age": { 93 | Type: jsonschema.Integer, 94 | }, 95 | "address": { 96 | Type: jsonschema.Object, 97 | Properties: map[string]jsonschema.Definition{ 98 | "city": { 99 | Type: jsonschema.String, 100 | }, 101 | "country": { 102 | Type: jsonschema.String, 103 | }, 104 | }, 105 | }, 106 | }, 107 | }, 108 | }, 109 | }, 110 | want: `{ 111 | "type":"object", 112 | "properties":{ 113 | "user":{ 114 | "type":"object", 115 | "properties":{ 116 | "name":{ 117 | "type":"string", 118 | "properties":{} 119 | }, 120 | "age":{ 121 | "type":"integer", 122 | "properties":{} 123 | }, 124 | "address":{ 125 | "type":"object", 126 | "properties":{ 127 | "city":{ 128 | "type":"string", 129 | "properties":{} 130 | }, 131 | "country":{ 132 | "type":"string", 133 | "properties":{} 134 | } 135 | } 136 | } 137 | } 138 | } 139 | } 140 | }`, 141 | }, 142 | { 143 | name: "Test with Array type Definition", 144 | def: jsonschema.Definition{ 145 | Type: jsonschema.Array, 146 | Items: &jsonschema.Definition{ 147 | Type: jsonschema.String, 148 | }, 149 | Properties: map[string]jsonschema.Definition{ 150 | "name": { 151 | Type: jsonschema.String, 152 | }, 153 | }, 154 | }, 155 | want: `{ 156 | "type":"array", 157 | "items":{ 158 | "type":"string", 159 | "properties":{ 160 | 161 | } 162 | }, 163 | "properties":{ 164 | "name":{ 165 | "type":"string", 166 | "properties":{} 167 | } 168 | } 169 | }`, 170 | }, 171 | } 172 | 173 | for _, tt := range tests { 174 | t.Run(tt.name, func(t *testing.T) { 175 | 
wantBytes := []byte(tt.want) 176 | var want map[string]interface{} 177 | err := json.Unmarshal(wantBytes, &want) 178 | if err != nil { 179 | t.Errorf("Failed to Unmarshal JSON: error = %v", err) 180 | return 181 | } 182 | 183 | got := structToMap(t, tt.def) 184 | gotPtr := structToMap(t, &tt.def) 185 | 186 | if !reflect.DeepEqual(got, want) { 187 | t.Errorf("MarshalJSON() got = %v, want %v", got, want) 188 | } 189 | if !reflect.DeepEqual(gotPtr, want) { 190 | t.Errorf("MarshalJSON() gotPtr = %v, want %v", gotPtr, want) 191 | } 192 | }) 193 | } 194 | } 195 | 196 | func structToMap(t *testing.T, v any) map[string]any { 197 | t.Helper() 198 | gotBytes, err := json.Marshal(v) 199 | if err != nil { 200 | t.Errorf("Failed to Marshal JSON: error = %v", err) 201 | return nil 202 | } 203 | 204 | var got map[string]interface{} 205 | err = json.Unmarshal(gotBytes, &got) 206 | if err != nil { 207 | t.Errorf("Failed to Unmarshal JSON: error = %v", err) 208 | return nil 209 | } 210 | return got 211 | } 212 | -------------------------------------------------------------------------------- /api_internal_test.go: -------------------------------------------------------------------------------- 1 | package openai 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | ) 7 | 8 | func TestOpenAIFullURL(t *testing.T) { 9 | cases := []struct { 10 | Name string 11 | Suffix string 12 | Expect string 13 | }{ 14 | { 15 | "ChatCompletionsURL", 16 | "/chat/completions", 17 | "https://api.openai.com/v1/chat/completions", 18 | }, 19 | { 20 | "CompletionsURL", 21 | "/completions", 22 | "https://api.openai.com/v1/completions", 23 | }, 24 | } 25 | 26 | for _, c := range cases { 27 | t.Run(c.Name, func(t *testing.T) { 28 | az := DefaultConfig("dummy") 29 | cli := NewClientWithConfig(az) 30 | actual := cli.fullURL(c.Suffix) 31 | if actual != c.Expect { 32 | t.Errorf("Expected %s, got %s", c.Expect, actual) 33 | } 34 | t.Logf("Full URL: %s", actual) 35 | }) 36 | } 37 | } 38 | 39 | func 
TestRequestAuthHeader(t *testing.T) { 40 | cases := []struct { 41 | Name string 42 | APIType APIType 43 | HeaderKey string 44 | Token string 45 | OrgID string 46 | Expect string 47 | }{ 48 | { 49 | "OpenAIDefault", 50 | "", 51 | "Authorization", 52 | "dummy-token-openai", 53 | "", 54 | "Bearer dummy-token-openai", 55 | }, 56 | { 57 | "OpenAIOrg", 58 | APITypeOpenAI, 59 | "Authorization", 60 | "dummy-token-openai", 61 | "dummy-org-openai", 62 | "Bearer dummy-token-openai", 63 | }, 64 | { 65 | "OpenAI", 66 | APITypeOpenAI, 67 | "Authorization", 68 | "dummy-token-openai", 69 | "", 70 | "Bearer dummy-token-openai", 71 | }, 72 | { 73 | "AzureAD", 74 | APITypeAzureAD, 75 | "Authorization", 76 | "dummy-token-azure", 77 | "", 78 | "Bearer dummy-token-azure", 79 | }, 80 | { 81 | "Azure", 82 | APITypeAzure, 83 | AzureAPIKeyHeader, 84 | "dummy-api-key-here", 85 | "", 86 | "dummy-api-key-here", 87 | }, 88 | } 89 | 90 | for _, c := range cases { 91 | t.Run(c.Name, func(t *testing.T) { 92 | az := DefaultConfig(c.Token) 93 | az.APIType = c.APIType 94 | az.OrgID = c.OrgID 95 | 96 | cli := NewClientWithConfig(az) 97 | req, err := cli.newRequest(context.Background(), "POST", "/chat/completions") 98 | if err != nil { 99 | t.Errorf("Failed to create request: %v", err) 100 | } 101 | actual := req.Header.Get(c.HeaderKey) 102 | if actual != c.Expect { 103 | t.Errorf("Expected %s, got %s", c.Expect, actual) 104 | } 105 | t.Logf("%s: %s", c.HeaderKey, actual) 106 | }) 107 | } 108 | } 109 | 110 | func TestAzureFullURL(t *testing.T) { 111 | cases := []struct { 112 | Name string 113 | BaseURL string 114 | AzureModelMapper map[string]string 115 | Suffix string 116 | Model string 117 | Expect string 118 | }{ 119 | { 120 | "AzureBaseURLWithSlashAutoStrip", 121 | "https://httpbin.org/", 122 | nil, 123 | "/chat/completions", 124 | "chatgpt-demo", 125 | "https://httpbin.org/" + 126 | "openai/deployments/chatgpt-demo" + 127 | "/chat/completions?api-version=2023-05-15", 128 | }, 129 | { 130 | 
"AzureBaseURLWithoutSlashOK", 131 | "https://httpbin.org", 132 | nil, 133 | "/chat/completions", 134 | "chatgpt-demo", 135 | "https://httpbin.org/" + 136 | "openai/deployments/chatgpt-demo" + 137 | "/chat/completions?api-version=2023-05-15", 138 | }, 139 | { 140 | "", 141 | "https://httpbin.org", 142 | nil, 143 | "/assistants?limit=10", 144 | "chatgpt-demo", 145 | "https://httpbin.org/openai/assistants?api-version=2023-05-15&limit=10", 146 | }, 147 | } 148 | 149 | for _, c := range cases { 150 | t.Run(c.Name, func(t *testing.T) { 151 | az := DefaultAzureConfig("dummy", c.BaseURL) 152 | cli := NewClientWithConfig(az) 153 | // /openai/deployments/{engine}/chat/completions?api-version={api_version} 154 | actual := cli.fullURL(c.Suffix, withModel(c.Model)) 155 | if actual != c.Expect { 156 | t.Errorf("Expected %s, got %s", c.Expect, actual) 157 | } 158 | t.Logf("Full URL: %s", actual) 159 | }) 160 | } 161 | } 162 | 163 | func TestCloudflareAzureFullURL(t *testing.T) { 164 | cases := []struct { 165 | Name string 166 | BaseURL string 167 | Suffix string 168 | Expect string 169 | }{ 170 | { 171 | "CloudflareAzureBaseURLWithSlashAutoStrip", 172 | "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/", 173 | "/chat/completions", 174 | "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/" + 175 | "chat/completions?api-version=2023-05-15", 176 | }, 177 | { 178 | "", 179 | "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo", 180 | "/assistants?limit=10", 181 | "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo" + 182 | "/assistants?api-version=2023-05-15&limit=10", 183 | }, 184 | } 185 | 186 | for _, c := range cases { 187 | t.Run(c.Name, func(t *testing.T) { 188 | az := DefaultAzureConfig("dummy", c.BaseURL) 189 | az.APIType = 
APITypeCloudflareAzure 190 | 191 | cli := NewClientWithConfig(az) 192 | 193 | actual := cli.fullURL(c.Suffix) 194 | if actual != c.Expect { 195 | t.Errorf("Expected %s, got %s", c.Expect, actual) 196 | } 197 | t.Logf("Full URL: %s", actual) 198 | }) 199 | } 200 | } 201 | -------------------------------------------------------------------------------- /thread.go: -------------------------------------------------------------------------------- 1 | package openai 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | ) 7 | 8 | const ( 9 | threadsSuffix = "/threads" 10 | ) 11 | 12 | type Thread struct { 13 | ID string `json:"id"` 14 | Object string `json:"object"` 15 | CreatedAt int64 `json:"created_at"` 16 | Metadata map[string]any `json:"metadata"` 17 | ToolResources ToolResources `json:"tool_resources,omitempty"` 18 | 19 | httpHeader 20 | } 21 | 22 | type ThreadRequest struct { 23 | Messages []ThreadMessage `json:"messages,omitempty"` 24 | Metadata map[string]any `json:"metadata,omitempty"` 25 | ToolResources *ToolResourcesRequest `json:"tool_resources,omitempty"` 26 | } 27 | 28 | type ToolResources struct { 29 | CodeInterpreter *CodeInterpreterToolResources `json:"code_interpreter,omitempty"` 30 | FileSearch *FileSearchToolResources `json:"file_search,omitempty"` 31 | } 32 | 33 | type CodeInterpreterToolResources struct { 34 | FileIDs []string `json:"file_ids,omitempty"` 35 | } 36 | 37 | type FileSearchToolResources struct { 38 | VectorStoreIDs []string `json:"vector_store_ids,omitempty"` 39 | } 40 | 41 | type ToolResourcesRequest struct { 42 | CodeInterpreter *CodeInterpreterToolResourcesRequest `json:"code_interpreter,omitempty"` 43 | FileSearch *FileSearchToolResourcesRequest `json:"file_search,omitempty"` 44 | } 45 | 46 | type CodeInterpreterToolResourcesRequest struct { 47 | FileIDs []string `json:"file_ids,omitempty"` 48 | } 49 | 50 | type FileSearchToolResourcesRequest struct { 51 | VectorStoreIDs []string `json:"vector_store_ids,omitempty"` 52 | VectorStores 
[]VectorStoreToolResources `json:"vector_stores,omitempty"` 53 | } 54 | 55 | type VectorStoreToolResources struct { 56 | FileIDs []string `json:"file_ids,omitempty"` 57 | ChunkingStrategy *ChunkingStrategy `json:"chunking_strategy,omitempty"` 58 | Metadata map[string]any `json:"metadata,omitempty"` 59 | } 60 | 61 | type ChunkingStrategy struct { 62 | Type ChunkingStrategyType `json:"type"` 63 | Static *StaticChunkingStrategy `json:"static,omitempty"` 64 | } 65 | 66 | type StaticChunkingStrategy struct { 67 | MaxChunkSizeTokens int `json:"max_chunk_size_tokens"` 68 | ChunkOverlapTokens int `json:"chunk_overlap_tokens"` 69 | } 70 | 71 | type ChunkingStrategyType string 72 | 73 | const ( 74 | ChunkingStrategyTypeAuto ChunkingStrategyType = "auto" 75 | ChunkingStrategyTypeStatic ChunkingStrategyType = "static" 76 | ) 77 | 78 | type ModifyThreadRequest struct { 79 | Metadata map[string]any `json:"metadata"` 80 | ToolResources *ToolResources `json:"tool_resources,omitempty"` 81 | } 82 | 83 | type ThreadMessageRole string 84 | 85 | const ( 86 | ThreadMessageRoleAssistant ThreadMessageRole = "assistant" 87 | ThreadMessageRoleUser ThreadMessageRole = "user" 88 | ) 89 | 90 | type ThreadMessage struct { 91 | Role ThreadMessageRole `json:"role"` 92 | Content string `json:"content"` 93 | FileIDs []string `json:"file_ids,omitempty"` 94 | Attachments []ThreadAttachment `json:"attachments,omitempty"` 95 | Metadata map[string]any `json:"metadata,omitempty"` 96 | } 97 | 98 | type ThreadAttachment struct { 99 | FileID string `json:"file_id"` 100 | Tools []ThreadAttachmentTool `json:"tools"` 101 | } 102 | 103 | type ThreadAttachmentTool struct { 104 | Type string `json:"type"` 105 | } 106 | 107 | type ThreadDeleteResponse struct { 108 | ID string `json:"id"` 109 | Object string `json:"object"` 110 | Deleted bool `json:"deleted"` 111 | 112 | httpHeader 113 | } 114 | 115 | // CreateThread creates a new thread. 
116 | func (c *Client) CreateThread(ctx context.Context, request ThreadRequest) (response Thread, err error) { 117 | req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(threadsSuffix), withBody(request), 118 | withBetaAssistantVersion(c.config.AssistantVersion)) 119 | if err != nil { 120 | return 121 | } 122 | 123 | err = c.sendRequest(req, &response) 124 | return 125 | } 126 | 127 | // RetrieveThread retrieves a thread. 128 | func (c *Client) RetrieveThread(ctx context.Context, threadID string) (response Thread, err error) { 129 | urlSuffix := threadsSuffix + "/" + threadID 130 | req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), 131 | withBetaAssistantVersion(c.config.AssistantVersion)) 132 | if err != nil { 133 | return 134 | } 135 | 136 | err = c.sendRequest(req, &response) 137 | return 138 | } 139 | 140 | // ModifyThread modifies a thread. 141 | func (c *Client) ModifyThread( 142 | ctx context.Context, 143 | threadID string, 144 | request ModifyThreadRequest, 145 | ) (response Thread, err error) { 146 | urlSuffix := threadsSuffix + "/" + threadID 147 | req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request), 148 | withBetaAssistantVersion(c.config.AssistantVersion)) 149 | if err != nil { 150 | return 151 | } 152 | 153 | err = c.sendRequest(req, &response) 154 | return 155 | } 156 | 157 | // DeleteThread deletes a thread. 
158 | func (c *Client) DeleteThread( 159 | ctx context.Context, 160 | threadID string, 161 | ) (response ThreadDeleteResponse, err error) { 162 | urlSuffix := threadsSuffix + "/" + threadID 163 | req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix), 164 | withBetaAssistantVersion(c.config.AssistantVersion)) 165 | if err != nil { 166 | return 167 | } 168 | 169 | err = c.sendRequest(req, &response) 170 | return 171 | } 172 | -------------------------------------------------------------------------------- /moderation_test.go: -------------------------------------------------------------------------------- 1 | package openai_test 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | "io" 8 | "net/http" 9 | "strconv" 10 | "strings" 11 | "testing" 12 | "time" 13 | 14 | "github.com/sashabaranov/go-openai" 15 | "github.com/sashabaranov/go-openai/internal/test/checks" 16 | ) 17 | 18 | // TestModeration Tests the moderations endpoint of the API using the mocked server. 19 | func TestModerations(t *testing.T) { 20 | client, server, teardown := setupOpenAITestServer() 21 | defer teardown() 22 | server.RegisterHandler("/v1/moderations", handleModerationEndpoint) 23 | _, err := client.Moderations(context.Background(), openai.ModerationRequest{ 24 | Model: openai.ModerationTextStable, 25 | Input: "I want to kill them.", 26 | }) 27 | checks.NoError(t, err, "Moderation error") 28 | } 29 | 30 | // TestModerationsWithIncorrectModel Tests passing valid and invalid models to moderations endpoint. 
31 | func TestModerationsWithDifferentModelOptions(t *testing.T) { 32 | var modelOptions []struct { 33 | model string 34 | expect error 35 | } 36 | modelOptions = append(modelOptions, 37 | getModerationModelTestOption(openai.GPT3Dot5Turbo, openai.ErrModerationInvalidModel), 38 | getModerationModelTestOption(openai.ModerationTextStable, nil), 39 | getModerationModelTestOption(openai.ModerationTextLatest, nil), 40 | getModerationModelTestOption(openai.ModerationOmni20240926, nil), 41 | getModerationModelTestOption(openai.ModerationOmniLatest, nil), 42 | getModerationModelTestOption("", nil), 43 | ) 44 | client, server, teardown := setupOpenAITestServer() 45 | defer teardown() 46 | server.RegisterHandler("/v1/moderations", handleModerationEndpoint) 47 | for _, modelTest := range modelOptions { 48 | _, err := client.Moderations(context.Background(), openai.ModerationRequest{ 49 | Model: modelTest.model, 50 | Input: "I want to kill them.", 51 | }) 52 | checks.ErrorIs(t, err, modelTest.expect, 53 | fmt.Sprintf("Moderations(..) expects err: %v, actual err:%v", modelTest.expect, err)) 54 | } 55 | } 56 | 57 | func getModerationModelTestOption(model string, expect error) struct { 58 | model string 59 | expect error 60 | } { 61 | return struct { 62 | model string 63 | expect error 64 | }{model: model, expect: expect} 65 | } 66 | 67 | // handleModerationEndpoint Handles the moderation endpoint by the test server. 
68 | func handleModerationEndpoint(w http.ResponseWriter, r *http.Request) { 69 | var err error 70 | var resBytes []byte 71 | 72 | // completions only accepts POST requests 73 | if r.Method != "POST" { 74 | http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) 75 | } 76 | var moderationReq openai.ModerationRequest 77 | if moderationReq, err = getModerationBody(r); err != nil { 78 | http.Error(w, "could not read request", http.StatusInternalServerError) 79 | return 80 | } 81 | 82 | resCat := openai.ResultCategories{} 83 | resCatScore := openai.ResultCategoryScores{} 84 | switch { 85 | case strings.Contains(moderationReq.Input, "hate"): 86 | resCat = openai.ResultCategories{Hate: true} 87 | resCatScore = openai.ResultCategoryScores{Hate: 1} 88 | 89 | case strings.Contains(moderationReq.Input, "hate more"): 90 | resCat = openai.ResultCategories{HateThreatening: true} 91 | resCatScore = openai.ResultCategoryScores{HateThreatening: 1} 92 | 93 | case strings.Contains(moderationReq.Input, "harass"): 94 | resCat = openai.ResultCategories{Harassment: true} 95 | resCatScore = openai.ResultCategoryScores{Harassment: 1} 96 | 97 | case strings.Contains(moderationReq.Input, "harass hard"): 98 | resCat = openai.ResultCategories{Harassment: true} 99 | resCatScore = openai.ResultCategoryScores{HarassmentThreatening: 1} 100 | 101 | case strings.Contains(moderationReq.Input, "suicide"): 102 | resCat = openai.ResultCategories{SelfHarm: true} 103 | resCatScore = openai.ResultCategoryScores{SelfHarm: 1} 104 | 105 | case strings.Contains(moderationReq.Input, "wanna suicide"): 106 | resCat = openai.ResultCategories{SelfHarmIntent: true} 107 | resCatScore = openai.ResultCategoryScores{SelfHarm: 1} 108 | 109 | case strings.Contains(moderationReq.Input, "drink bleach"): 110 | resCat = openai.ResultCategories{SelfHarmInstructions: true} 111 | resCatScore = openai.ResultCategoryScores{SelfHarmInstructions: 1} 112 | 113 | case strings.Contains(moderationReq.Input, "porn"): 114 | 
resCat = openai.ResultCategories{Sexual: true} 115 | resCatScore = openai.ResultCategoryScores{Sexual: 1} 116 | 117 | case strings.Contains(moderationReq.Input, "child porn"): 118 | resCat = openai.ResultCategories{SexualMinors: true} 119 | resCatScore = openai.ResultCategoryScores{SexualMinors: 1} 120 | 121 | case strings.Contains(moderationReq.Input, "kill"): 122 | resCat = openai.ResultCategories{Violence: true} 123 | resCatScore = openai.ResultCategoryScores{Violence: 1} 124 | 125 | case strings.Contains(moderationReq.Input, "corpse"): 126 | resCat = openai.ResultCategories{ViolenceGraphic: true} 127 | resCatScore = openai.ResultCategoryScores{ViolenceGraphic: 1} 128 | } 129 | 130 | result := openai.Result{Categories: resCat, CategoryScores: resCatScore, Flagged: true} 131 | 132 | res := openai.ModerationResponse{ 133 | ID: strconv.Itoa(int(time.Now().Unix())), 134 | Model: moderationReq.Model, 135 | } 136 | res.Results = append(res.Results, result) 137 | 138 | resBytes, _ = json.Marshal(res) 139 | fmt.Fprintln(w, string(resBytes)) 140 | } 141 | 142 | // getModerationBody Returns the body of the request to do a moderation. 143 | func getModerationBody(r *http.Request) (openai.ModerationRequest, error) { 144 | moderation := openai.ModerationRequest{} 145 | // read the request body 146 | reqBody, err := io.ReadAll(r.Body) 147 | if err != nil { 148 | return openai.ModerationRequest{}, err 149 | } 150 | err = json.Unmarshal(reqBody, &moderation) 151 | if err != nil { 152 | return openai.ModerationRequest{}, err 153 | } 154 | return moderation, nil 155 | } 156 | -------------------------------------------------------------------------------- /image.go: -------------------------------------------------------------------------------- 1 | package openai 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "net/http" 7 | "os" 8 | "strconv" 9 | ) 10 | 11 | // Image sizes defined by the OpenAI API. 
const (
	CreateImageSize256x256   = "256x256"
	CreateImageSize512x512   = "512x512"
	CreateImageSize1024x1024 = "1024x1024"
	// The following sizes are supported by dall-e-3 only.
	CreateImageSize1792x1024 = "1792x1024"
	CreateImageSize1024x1792 = "1024x1792"
)

// Values for the ResponseFormat field of the image request types.
// NOTE(review): presumably "url" fills ImageResponseDataInner.URL and
// "b64_json" fills ImageResponseDataInner.B64JSON — confirm against the API docs.
const (
	CreateImageResponseFormatURL     = "url"
	CreateImageResponseFormatB64JSON = "b64_json"
)

// Model names for the Model field of the image request types.
const (
	CreateImageModelDallE2 = "dall-e-2"
	CreateImageModelDallE3 = "dall-e-3"
)

// Values for ImageRequest.Quality.
const (
	CreateImageQualityHD       = "hd"
	CreateImageQualityStandard = "standard"
)

// Values for ImageRequest.Style.
const (
	CreateImageStyleVivid   = "vivid"
	CreateImageStyleNatural = "natural"
)

// ImageRequest represents the request structure for the image API.
// CreateImage passes it as the request body via withBody; every field is
// tagged omitempty, so zero values are left out when it is JSON-encoded.
type ImageRequest struct {
	Prompt         string `json:"prompt,omitempty"`
	Model          string `json:"model,omitempty"`
	N              int    `json:"n,omitempty"`
	Quality        string `json:"quality,omitempty"`
	Size           string `json:"size,omitempty"`
	Style          string `json:"style,omitempty"`
	ResponseFormat string `json:"response_format,omitempty"`
	User           string `json:"user,omitempty"`
}

// ImageResponse represents a response structure for image API.
// It is the decoded result of CreateImage, CreateEditImage and CreateVariImage.
type ImageResponse struct {
	Created int64                    `json:"created,omitempty"`
	Data    []ImageResponseDataInner `json:"data,omitempty"`

	// httpHeader is defined elsewhere in this package.
	// NOTE(review): presumably filled with the HTTP response headers by sendRequest.
	httpHeader
}

// ImageResponseDataInner represents a response data structure for image API:
// one element of ImageResponse.Data.
type ImageResponseDataInner struct {
	URL           string `json:"url,omitempty"`
	B64JSON       string `json:"b64_json,omitempty"`
	RevisedPrompt string `json:"revised_prompt,omitempty"`
}

// CreateImage POSTs request to the /images/generations endpoint and decodes the
// reply into an ImageResponse. request.Model is also handed to fullURL via withModel.
69 | func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (response ImageResponse, err error) { 70 | urlSuffix := "/images/generations" 71 | req, err := c.newRequest( 72 | ctx, 73 | http.MethodPost, 74 | c.fullURL(urlSuffix, withModel(request.Model)), 75 | withBody(request), 76 | ) 77 | if err != nil { 78 | return 79 | } 80 | 81 | err = c.sendRequest(req, &response) 82 | return 83 | } 84 | 85 | // ImageEditRequest represents the request structure for the image API. 86 | type ImageEditRequest struct { 87 | Image *os.File `json:"image,omitempty"` 88 | Mask *os.File `json:"mask,omitempty"` 89 | Prompt string `json:"prompt,omitempty"` 90 | Model string `json:"model,omitempty"` 91 | N int `json:"n,omitempty"` 92 | Size string `json:"size,omitempty"` 93 | ResponseFormat string `json:"response_format,omitempty"` 94 | } 95 | 96 | // CreateEditImage - API call to create an image. This is the main endpoint of the DALL-E API. 97 | func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest) (response ImageResponse, err error) { 98 | body := &bytes.Buffer{} 99 | builder := c.createFormBuilder(body) 100 | 101 | // image 102 | err = builder.CreateFormFile("image", request.Image) 103 | if err != nil { 104 | return 105 | } 106 | 107 | // mask, it is optional 108 | if request.Mask != nil { 109 | err = builder.CreateFormFile("mask", request.Mask) 110 | if err != nil { 111 | return 112 | } 113 | } 114 | 115 | err = builder.WriteField("prompt", request.Prompt) 116 | if err != nil { 117 | return 118 | } 119 | 120 | err = builder.WriteField("n", strconv.Itoa(request.N)) 121 | if err != nil { 122 | return 123 | } 124 | 125 | err = builder.WriteField("size", request.Size) 126 | if err != nil { 127 | return 128 | } 129 | 130 | err = builder.WriteField("response_format", request.ResponseFormat) 131 | if err != nil { 132 | return 133 | } 134 | 135 | err = builder.Close() 136 | if err != nil { 137 | return 138 | } 139 | 140 | req, err := c.newRequest( 141 
| ctx, 142 | http.MethodPost, 143 | c.fullURL("/images/edits", withModel(request.Model)), 144 | withBody(body), 145 | withContentType(builder.FormDataContentType()), 146 | ) 147 | if err != nil { 148 | return 149 | } 150 | 151 | err = c.sendRequest(req, &response) 152 | return 153 | } 154 | 155 | // ImageVariRequest represents the request structure for the image API. 156 | type ImageVariRequest struct { 157 | Image *os.File `json:"image,omitempty"` 158 | Model string `json:"model,omitempty"` 159 | N int `json:"n,omitempty"` 160 | Size string `json:"size,omitempty"` 161 | ResponseFormat string `json:"response_format,omitempty"` 162 | } 163 | 164 | // CreateVariImage - API call to create an image variation. This is the main endpoint of the DALL-E API. 165 | // Use abbreviations(vari for variation) because ci-lint has a single-line length limit ... 166 | func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest) (response ImageResponse, err error) { 167 | body := &bytes.Buffer{} 168 | builder := c.createFormBuilder(body) 169 | 170 | // image 171 | err = builder.CreateFormFile("image", request.Image) 172 | if err != nil { 173 | return 174 | } 175 | 176 | err = builder.WriteField("n", strconv.Itoa(request.N)) 177 | if err != nil { 178 | return 179 | } 180 | 181 | err = builder.WriteField("size", request.Size) 182 | if err != nil { 183 | return 184 | } 185 | 186 | err = builder.WriteField("response_format", request.ResponseFormat) 187 | if err != nil { 188 | return 189 | } 190 | 191 | err = builder.Close() 192 | if err != nil { 193 | return 194 | } 195 | 196 | req, err := c.newRequest( 197 | ctx, 198 | http.MethodPost, 199 | c.fullURL("/images/variations", withModel(request.Model)), 200 | withBody(body), 201 | withContentType(builder.FormDataContentType()), 202 | ) 203 | if err != nil { 204 | return 205 | } 206 | 207 | err = c.sendRequest(req, &response) 208 | return 209 | } 210 | 
-------------------------------------------------------------------------------- /jsonschema/validate_test.go: -------------------------------------------------------------------------------- 1 | package jsonschema_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/sashabaranov/go-openai/jsonschema" 7 | ) 8 | 9 | func Test_Validate(t *testing.T) { 10 | type args struct { 11 | data any 12 | schema jsonschema.Definition 13 | } 14 | tests := []struct { 15 | name string 16 | args args 17 | want bool 18 | }{ 19 | // string integer number boolean 20 | {"", args{data: "ABC", schema: jsonschema.Definition{Type: jsonschema.String}}, true}, 21 | {"", args{data: 123, schema: jsonschema.Definition{Type: jsonschema.String}}, false}, 22 | {"", args{data: 123, schema: jsonschema.Definition{Type: jsonschema.Integer}}, true}, 23 | {"", args{data: 123.4, schema: jsonschema.Definition{Type: jsonschema.Integer}}, false}, 24 | {"", args{data: "ABC", schema: jsonschema.Definition{Type: jsonschema.Number}}, false}, 25 | {"", args{data: 123, schema: jsonschema.Definition{Type: jsonschema.Number}}, true}, 26 | {"", args{data: false, schema: jsonschema.Definition{Type: jsonschema.Boolean}}, true}, 27 | {"", args{data: 123, schema: jsonschema.Definition{Type: jsonschema.Boolean}}, false}, 28 | {"", args{data: nil, schema: jsonschema.Definition{Type: jsonschema.Null}}, true}, 29 | {"", args{data: 0, schema: jsonschema.Definition{Type: jsonschema.Null}}, false}, 30 | // array 31 | {"", args{data: []any{"a", "b", "c"}, schema: jsonschema.Definition{ 32 | Type: jsonschema.Array, Items: &jsonschema.Definition{Type: jsonschema.String}}, 33 | }, true}, 34 | {"", args{data: []any{1, 2, 3}, schema: jsonschema.Definition{ 35 | Type: jsonschema.Array, Items: &jsonschema.Definition{Type: jsonschema.String}}, 36 | }, false}, 37 | {"", args{data: []any{1, 2, 3}, schema: jsonschema.Definition{ 38 | Type: jsonschema.Array, Items: &jsonschema.Definition{Type: jsonschema.Integer}}, 39 | }, true}, 40 | 
{"", args{data: []any{1, 2, 3.4}, schema: jsonschema.Definition{ 41 | Type: jsonschema.Array, Items: &jsonschema.Definition{Type: jsonschema.Integer}}, 42 | }, false}, 43 | // object 44 | {"", args{data: map[string]any{ 45 | "string": "abc", 46 | "integer": 123, 47 | "number": 123.4, 48 | "boolean": false, 49 | "array": []any{1, 2, 3}, 50 | }, schema: jsonschema.Definition{Type: jsonschema.Object, Properties: map[string]jsonschema.Definition{ 51 | "string": {Type: jsonschema.String}, 52 | "integer": {Type: jsonschema.Integer}, 53 | "number": {Type: jsonschema.Number}, 54 | "boolean": {Type: jsonschema.Boolean}, 55 | "array": {Type: jsonschema.Array, Items: &jsonschema.Definition{Type: jsonschema.Number}}, 56 | }, 57 | Required: []string{"string"}, 58 | }}, true}, 59 | {"", args{data: map[string]any{ 60 | "integer": 123, 61 | "number": 123.4, 62 | "boolean": false, 63 | "array": []any{1, 2, 3}, 64 | }, schema: jsonschema.Definition{Type: jsonschema.Object, Properties: map[string]jsonschema.Definition{ 65 | "string": {Type: jsonschema.String}, 66 | "integer": {Type: jsonschema.Integer}, 67 | "number": {Type: jsonschema.Number}, 68 | "boolean": {Type: jsonschema.Boolean}, 69 | "array": {Type: jsonschema.Array, Items: &jsonschema.Definition{Type: jsonschema.Number}}, 70 | }, 71 | Required: []string{"string"}, 72 | }}, false}, 73 | } 74 | for _, tt := range tests { 75 | t.Run(tt.name, func(t *testing.T) { 76 | if got := jsonschema.Validate(tt.args.schema, tt.args.data); got != tt.want { 77 | t.Errorf("Validate() = %v, want %v", got, tt.want) 78 | } 79 | }) 80 | } 81 | } 82 | 83 | func TestUnmarshal(t *testing.T) { 84 | type args struct { 85 | schema jsonschema.Definition 86 | content []byte 87 | v any 88 | } 89 | tests := []struct { 90 | name string 91 | args args 92 | wantErr bool 93 | }{ 94 | {"", args{ 95 | schema: jsonschema.Definition{ 96 | Type: jsonschema.Object, 97 | Properties: map[string]jsonschema.Definition{ 98 | "string": {Type: jsonschema.String}, 99 | 
"number": {Type: jsonschema.Number}, 100 | }, 101 | }, 102 | content: []byte(`{"string":"abc","number":123.4}`), 103 | v: &struct { 104 | String string `json:"string"` 105 | Number float64 `json:"number"` 106 | }{}, 107 | }, false}, 108 | {"", args{ 109 | schema: jsonschema.Definition{ 110 | Type: jsonschema.Object, 111 | Properties: map[string]jsonschema.Definition{ 112 | "string": {Type: jsonschema.String}, 113 | "number": {Type: jsonschema.Number}, 114 | }, 115 | Required: []string{"string", "number"}, 116 | }, 117 | content: []byte(`{"string":"abc"}`), 118 | v: struct { 119 | String string `json:"string"` 120 | Number float64 `json:"number"` 121 | }{}, 122 | }, true}, 123 | {"validate integer", args{ 124 | schema: jsonschema.Definition{ 125 | Type: jsonschema.Object, 126 | Properties: map[string]jsonschema.Definition{ 127 | "string": {Type: jsonschema.String}, 128 | "integer": {Type: jsonschema.Integer}, 129 | }, 130 | Required: []string{"string", "integer"}, 131 | }, 132 | content: []byte(`{"string":"abc","integer":123}`), 133 | v: &struct { 134 | String string `json:"string"` 135 | Integer int `json:"integer"` 136 | }{}, 137 | }, false}, 138 | {"validate integer failed", args{ 139 | schema: jsonschema.Definition{ 140 | Type: jsonschema.Object, 141 | Properties: map[string]jsonschema.Definition{ 142 | "string": {Type: jsonschema.String}, 143 | "integer": {Type: jsonschema.Integer}, 144 | }, 145 | Required: []string{"string", "integer"}, 146 | }, 147 | content: []byte(`{"string":"abc","integer":123.4}`), 148 | v: &struct { 149 | String string `json:"string"` 150 | Integer int `json:"integer"` 151 | }{}, 152 | }, true}, 153 | } 154 | for _, tt := range tests { 155 | t.Run(tt.name, func(t *testing.T) { 156 | err := jsonschema.VerifySchemaAndUnmarshal(tt.args.schema, tt.args.content, tt.args.v) 157 | if (err != nil) != tt.wantErr { 158 | t.Errorf("Unmarshal() error = %v, wantErr %v", err, tt.wantErr) 159 | } else if err == nil { 160 | t.Logf("Unmarshal() v = 
%+v\n", tt.args.v) 161 | } 162 | }) 163 | } 164 | } 165 | -------------------------------------------------------------------------------- /completion_test.go: -------------------------------------------------------------------------------- 1 | package openai_test 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "errors" 7 | "fmt" 8 | "io" 9 | "net/http" 10 | "strconv" 11 | "strings" 12 | "testing" 13 | "time" 14 | 15 | "github.com/sashabaranov/go-openai" 16 | "github.com/sashabaranov/go-openai/internal/test/checks" 17 | ) 18 | 19 | func TestCompletionsWrongModel(t *testing.T) { 20 | config := openai.DefaultConfig("whatever") 21 | config.BaseURL = "http://localhost/v1" 22 | client := openai.NewClientWithConfig(config) 23 | 24 | _, err := client.CreateCompletion( 25 | context.Background(), 26 | openai.CompletionRequest{ 27 | MaxTokens: 5, 28 | Model: openai.GPT3Dot5Turbo, 29 | }, 30 | ) 31 | if !errors.Is(err, openai.ErrCompletionUnsupportedModel) { 32 | t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel, but returned: %v", err) 33 | } 34 | } 35 | 36 | func TestCompletionWithStream(t *testing.T) { 37 | config := openai.DefaultConfig("whatever") 38 | client := openai.NewClientWithConfig(config) 39 | 40 | ctx := context.Background() 41 | req := openai.CompletionRequest{Stream: true} 42 | _, err := client.CreateCompletion(ctx, req) 43 | if !errors.Is(err, openai.ErrCompletionStreamNotSupported) { 44 | t.Fatalf("CreateCompletion didn't return ErrCompletionStreamNotSupported") 45 | } 46 | } 47 | 48 | // TestCompletions Tests the completions endpoint of the API using the mocked server. 
49 | func TestCompletions(t *testing.T) { 50 | client, server, teardown := setupOpenAITestServer() 51 | defer teardown() 52 | server.RegisterHandler("/v1/completions", handleCompletionEndpoint) 53 | req := openai.CompletionRequest{ 54 | MaxTokens: 5, 55 | Model: "ada", 56 | Prompt: "Lorem ipsum", 57 | } 58 | _, err := client.CreateCompletion(context.Background(), req) 59 | checks.NoError(t, err, "CreateCompletion error") 60 | } 61 | 62 | // TestMultiplePromptsCompletionsWrong Tests the completions endpoint of the API using the mocked server 63 | // where the completions requests has a list of prompts with wrong type. 64 | func TestMultiplePromptsCompletionsWrong(t *testing.T) { 65 | client, server, teardown := setupOpenAITestServer() 66 | defer teardown() 67 | server.RegisterHandler("/v1/completions", handleCompletionEndpoint) 68 | req := openai.CompletionRequest{ 69 | MaxTokens: 5, 70 | Model: "ada", 71 | Prompt: []interface{}{"Lorem ipsum", 9}, 72 | } 73 | _, err := client.CreateCompletion(context.Background(), req) 74 | if !errors.Is(err, openai.ErrCompletionRequestPromptTypeNotSupported) { 75 | t.Fatalf("CreateCompletion should return ErrCompletionRequestPromptTypeNotSupported, but returned: %v", err) 76 | } 77 | } 78 | 79 | // TestMultiplePromptsCompletions Tests the completions endpoint of the API using the mocked server 80 | // where the completions requests has a list of prompts. 81 | func TestMultiplePromptsCompletions(t *testing.T) { 82 | client, server, teardown := setupOpenAITestServer() 83 | defer teardown() 84 | server.RegisterHandler("/v1/completions", handleCompletionEndpoint) 85 | req := openai.CompletionRequest{ 86 | MaxTokens: 5, 87 | Model: "ada", 88 | Prompt: []interface{}{"Lorem ipsum", "Lorem ipsum"}, 89 | } 90 | _, err := client.CreateCompletion(context.Background(), req) 91 | checks.NoError(t, err, "CreateCompletion error") 92 | } 93 | 94 | // handleCompletionEndpoint Handles the completion endpoint by the test server. 
95 | func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) { 96 | var err error 97 | var resBytes []byte 98 | 99 | // completions only accepts POST requests 100 | if r.Method != "POST" { 101 | http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) 102 | } 103 | var completionReq openai.CompletionRequest 104 | if completionReq, err = getCompletionBody(r); err != nil { 105 | http.Error(w, "could not read request", http.StatusInternalServerError) 106 | return 107 | } 108 | res := openai.CompletionResponse{ 109 | ID: strconv.Itoa(int(time.Now().Unix())), 110 | Object: "test-object", 111 | Created: time.Now().Unix(), 112 | // would be nice to validate Model during testing, but 113 | // this may not be possible with how much upkeep 114 | // would be required / wouldn't make much sense 115 | Model: completionReq.Model, 116 | } 117 | // create completions 118 | n := completionReq.N 119 | if n == 0 { 120 | n = 1 121 | } 122 | // Handle different types of prompts: single string or list of strings 123 | prompts := []string{} 124 | switch v := completionReq.Prompt.(type) { 125 | case string: 126 | prompts = append(prompts, v) 127 | case []interface{}: 128 | for _, item := range v { 129 | if str, ok := item.(string); ok { 130 | prompts = append(prompts, str) 131 | } 132 | } 133 | default: 134 | http.Error(w, "Invalid prompt type", http.StatusBadRequest) 135 | return 136 | } 137 | 138 | for i := 0; i < n; i++ { 139 | for _, prompt := range prompts { 140 | // Generate a random string of length completionReq.MaxTokens 141 | completionStr := strings.Repeat("a", completionReq.MaxTokens) 142 | if completionReq.Echo { 143 | completionStr = prompt + completionStr 144 | } 145 | 146 | res.Choices = append(res.Choices, openai.CompletionChoice{ 147 | Text: completionStr, 148 | Index: len(res.Choices), 149 | }) 150 | } 151 | } 152 | 153 | inputTokens := 0 154 | for _, prompt := range prompts { 155 | inputTokens += numTokens(prompt) 156 | } 157 | inputTokens *= n 
158 | completionTokens := completionReq.MaxTokens * len(prompts) * n 159 | res.Usage = openai.Usage{ 160 | PromptTokens: inputTokens, 161 | CompletionTokens: completionTokens, 162 | TotalTokens: inputTokens + completionTokens, 163 | } 164 | 165 | // Serialize the response and send it back 166 | resBytes, _ = json.Marshal(res) 167 | fmt.Fprintln(w, string(resBytes)) 168 | } 169 | 170 | // getCompletionBody Returns the body of the request to create a completion. 171 | func getCompletionBody(r *http.Request) (openai.CompletionRequest, error) { 172 | completion := openai.CompletionRequest{} 173 | // read the request body 174 | reqBody, err := io.ReadAll(r.Body) 175 | if err != nil { 176 | return openai.CompletionRequest{}, err 177 | } 178 | err = json.Unmarshal(reqBody, &completion) 179 | if err != nil { 180 | return openai.CompletionRequest{}, err 181 | } 182 | return completion, nil 183 | } 184 | -------------------------------------------------------------------------------- /files_api_test.go: -------------------------------------------------------------------------------- 1 | package openai_test 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "errors" 7 | "fmt" 8 | "io" 9 | "net/http" 10 | "os" 11 | "strconv" 12 | "testing" 13 | "time" 14 | 15 | "github.com/sashabaranov/go-openai" 16 | "github.com/sashabaranov/go-openai/internal/test/checks" 17 | ) 18 | 19 | func TestFileBytesUpload(t *testing.T) { 20 | client, server, teardown := setupOpenAITestServer() 21 | defer teardown() 22 | server.RegisterHandler("/v1/files", handleCreateFile) 23 | req := openai.FileBytesRequest{ 24 | Name: "foo", 25 | Bytes: []byte("foo"), 26 | Purpose: openai.PurposeFineTune, 27 | } 28 | _, err := client.CreateFileBytes(context.Background(), req) 29 | checks.NoError(t, err, "CreateFile error") 30 | } 31 | 32 | func TestFileUpload(t *testing.T) { 33 | client, server, teardown := setupOpenAITestServer() 34 | defer teardown() 35 | server.RegisterHandler("/v1/files", handleCreateFile) 
36 | req := openai.FileRequest{ 37 | FileName: "test.go", 38 | FilePath: "client.go", 39 | Purpose: "fine-tune", 40 | } 41 | _, err := client.CreateFile(context.Background(), req) 42 | checks.NoError(t, err, "CreateFile error") 43 | } 44 | 45 | // handleCreateFile Handles the images endpoint by the test server. 46 | func handleCreateFile(w http.ResponseWriter, r *http.Request) { 47 | var err error 48 | var resBytes []byte 49 | 50 | // edits only accepts POST requests 51 | if r.Method != "POST" { 52 | http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) 53 | } 54 | err = r.ParseMultipartForm(1024 * 1024 * 1024) 55 | if err != nil { 56 | http.Error(w, "file is more than 1GB", http.StatusInternalServerError) 57 | return 58 | } 59 | 60 | values := r.Form 61 | var purpose string 62 | for key, value := range values { 63 | if key == "purpose" { 64 | purpose = value[0] 65 | } 66 | } 67 | file, header, err := r.FormFile("file") 68 | if err != nil { 69 | return 70 | } 71 | defer file.Close() 72 | 73 | fileReq := openai.File{ 74 | Bytes: int(header.Size), 75 | ID: strconv.Itoa(int(time.Now().Unix())), 76 | FileName: header.Filename, 77 | Purpose: purpose, 78 | CreatedAt: time.Now().Unix(), 79 | Object: "test-objecct", 80 | } 81 | 82 | resBytes, _ = json.Marshal(fileReq) 83 | fmt.Fprint(w, string(resBytes)) 84 | } 85 | 86 | func TestDeleteFile(t *testing.T) { 87 | client, server, teardown := setupOpenAITestServer() 88 | defer teardown() 89 | server.RegisterHandler("/v1/files/deadbeef", func(http.ResponseWriter, *http.Request) {}) 90 | err := client.DeleteFile(context.Background(), "deadbeef") 91 | checks.NoError(t, err, "DeleteFile error") 92 | } 93 | 94 | func TestListFile(t *testing.T) { 95 | client, server, teardown := setupOpenAITestServer() 96 | defer teardown() 97 | server.RegisterHandler("/v1/files", func(w http.ResponseWriter, _ *http.Request) { 98 | resBytes, _ := json.Marshal(openai.FilesList{}) 99 | fmt.Fprintln(w, string(resBytes)) 100 | }) 101 | _, err 
:= client.ListFiles(context.Background()) 102 | checks.NoError(t, err, "ListFiles error") 103 | } 104 | 105 | func TestGetFile(t *testing.T) { 106 | client, server, teardown := setupOpenAITestServer() 107 | defer teardown() 108 | server.RegisterHandler("/v1/files/deadbeef", func(w http.ResponseWriter, _ *http.Request) { 109 | resBytes, _ := json.Marshal(openai.File{}) 110 | fmt.Fprintln(w, string(resBytes)) 111 | }) 112 | _, err := client.GetFile(context.Background(), "deadbeef") 113 | checks.NoError(t, err, "GetFile error") 114 | } 115 | 116 | func TestGetFileContent(t *testing.T) { 117 | wantRespJsonl := `{"prompt": "foo", "completion": "foo"} 118 | {"prompt": "bar", "completion": "bar"} 119 | {"prompt": "baz", "completion": "baz"} 120 | ` 121 | client, server, teardown := setupOpenAITestServer() 122 | defer teardown() 123 | server.RegisterHandler("/v1/files/deadbeef/content", func(w http.ResponseWriter, r *http.Request) { 124 | // edits only accepts GET requests 125 | if r.Method != http.MethodGet { 126 | http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) 127 | } 128 | fmt.Fprint(w, wantRespJsonl) 129 | }) 130 | 131 | content, err := client.GetFileContent(context.Background(), "deadbeef") 132 | checks.NoError(t, err, "GetFileContent error") 133 | defer content.Close() 134 | 135 | actual, _ := io.ReadAll(content) 136 | if string(actual) != wantRespJsonl { 137 | t.Errorf("Expected %s, got %s", wantRespJsonl, string(actual)) 138 | } 139 | } 140 | 141 | func TestGetFileContentReturnError(t *testing.T) { 142 | wantMessage := "To help mitigate abuse, downloading of fine-tune training files is disabled for free accounts." 
143 | wantType := "invalid_request_error" 144 | wantErrorResp := `{ 145 | "error": { 146 | "message": "` + wantMessage + `", 147 | "type": "` + wantType + `", 148 | "param": null, 149 | "code": null 150 | } 151 | }` 152 | client, server, teardown := setupOpenAITestServer() 153 | defer teardown() 154 | server.RegisterHandler("/v1/files/deadbeef/content", func(w http.ResponseWriter, _ *http.Request) { 155 | w.Header().Set("Content-Type", "application/json") 156 | w.WriteHeader(http.StatusBadRequest) 157 | fmt.Fprint(w, wantErrorResp) 158 | }) 159 | 160 | _, err := client.GetFileContent(context.Background(), "deadbeef") 161 | if err == nil { 162 | t.Fatal("Did not return error") 163 | } 164 | 165 | apiErr := &openai.APIError{} 166 | if !errors.As(err, &apiErr) { 167 | t.Fatalf("Did not return APIError: %+v\n", apiErr) 168 | } 169 | if apiErr.Message != wantMessage { 170 | t.Fatalf("Expected %s Message, got = %s\n", wantMessage, apiErr.Message) 171 | return 172 | } 173 | if apiErr.Type != wantType { 174 | t.Fatalf("Expected %s Type, got = %s\n", wantType, apiErr.Type) 175 | return 176 | } 177 | } 178 | 179 | func TestGetFileContentReturnTimeoutError(t *testing.T) { 180 | client, server, teardown := setupOpenAITestServer() 181 | defer teardown() 182 | server.RegisterHandler("/v1/files/deadbeef/content", func(http.ResponseWriter, *http.Request) { 183 | time.Sleep(10 * time.Nanosecond) 184 | }) 185 | ctx := context.Background() 186 | ctx, cancel := context.WithTimeout(ctx, time.Nanosecond) 187 | defer cancel() 188 | 189 | _, err := client.GetFileContent(ctx, "deadbeef") 190 | if err == nil { 191 | t.Fatal("Did not return error") 192 | } 193 | if !os.IsTimeout(err) { 194 | t.Fatal("Did not return timeout error") 195 | } 196 | } 197 | -------------------------------------------------------------------------------- /messages.go: -------------------------------------------------------------------------------- 1 | package openai 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 
"net/http" 7 | "net/url" 8 | ) 9 | 10 | const ( 11 | messagesSuffix = "messages" 12 | ) 13 | 14 | type Message struct { 15 | ID string `json:"id"` 16 | Object string `json:"object"` 17 | CreatedAt int `json:"created_at"` 18 | ThreadID string `json:"thread_id"` 19 | Role string `json:"role"` 20 | Content []MessageContent `json:"content"` 21 | FileIds []string `json:"file_ids"` //nolint:revive //backwards-compatibility 22 | AssistantID *string `json:"assistant_id,omitempty"` 23 | RunID *string `json:"run_id,omitempty"` 24 | Metadata map[string]any `json:"metadata"` 25 | 26 | httpHeader 27 | } 28 | 29 | type MessagesList struct { 30 | Messages []Message `json:"data"` 31 | 32 | Object string `json:"object"` 33 | FirstID *string `json:"first_id"` 34 | LastID *string `json:"last_id"` 35 | HasMore bool `json:"has_more"` 36 | 37 | httpHeader 38 | } 39 | 40 | type MessageContent struct { 41 | Type string `json:"type"` 42 | Text *MessageText `json:"text,omitempty"` 43 | ImageFile *ImageFile `json:"image_file,omitempty"` 44 | } 45 | type MessageText struct { 46 | Value string `json:"value"` 47 | Annotations []any `json:"annotations"` 48 | } 49 | 50 | type ImageFile struct { 51 | FileID string `json:"file_id"` 52 | } 53 | 54 | type MessageRequest struct { 55 | Role string `json:"role"` 56 | Content string `json:"content"` 57 | FileIds []string `json:"file_ids,omitempty"` //nolint:revive // backwards-compatibility 58 | Metadata map[string]any `json:"metadata,omitempty"` 59 | Attachments []ThreadAttachment `json:"attachments,omitempty"` 60 | } 61 | 62 | type MessageFile struct { 63 | ID string `json:"id"` 64 | Object string `json:"object"` 65 | CreatedAt int `json:"created_at"` 66 | MessageID string `json:"message_id"` 67 | 68 | httpHeader 69 | } 70 | 71 | type MessageFilesList struct { 72 | MessageFiles []MessageFile `json:"data"` 73 | 74 | httpHeader 75 | } 76 | 77 | type MessageDeletionStatus struct { 78 | ID string `json:"id"` 79 | Object string `json:"object"` 80 | Deleted 
bool `json:"deleted"` 81 | 82 | httpHeader 83 | } 84 | 85 | // CreateMessage creates a new message. 86 | func (c *Client) CreateMessage(ctx context.Context, threadID string, request MessageRequest) (msg Message, err error) { 87 | urlSuffix := fmt.Sprintf("/threads/%s/%s", threadID, messagesSuffix) 88 | req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request), 89 | withBetaAssistantVersion(c.config.AssistantVersion)) 90 | if err != nil { 91 | return 92 | } 93 | 94 | err = c.sendRequest(req, &msg) 95 | return 96 | } 97 | 98 | // ListMessage fetches all messages in the thread. 99 | func (c *Client) ListMessage(ctx context.Context, threadID string, 100 | limit *int, 101 | order *string, 102 | after *string, 103 | before *string, 104 | runID *string, 105 | ) (messages MessagesList, err error) { 106 | urlValues := url.Values{} 107 | if limit != nil { 108 | urlValues.Add("limit", fmt.Sprintf("%d", *limit)) 109 | } 110 | if order != nil { 111 | urlValues.Add("order", *order) 112 | } 113 | if after != nil { 114 | urlValues.Add("after", *after) 115 | } 116 | if before != nil { 117 | urlValues.Add("before", *before) 118 | } 119 | if runID != nil { 120 | urlValues.Add("run_id", *runID) 121 | } 122 | 123 | encodedValues := "" 124 | if len(urlValues) > 0 { 125 | encodedValues = "?" + urlValues.Encode() 126 | } 127 | 128 | urlSuffix := fmt.Sprintf("/threads/%s/%s%s", threadID, messagesSuffix, encodedValues) 129 | req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), 130 | withBetaAssistantVersion(c.config.AssistantVersion)) 131 | if err != nil { 132 | return 133 | } 134 | 135 | err = c.sendRequest(req, &messages) 136 | return 137 | } 138 | 139 | // RetrieveMessage retrieves a Message. 
140 | func (c *Client) RetrieveMessage( 141 | ctx context.Context, 142 | threadID, messageID string, 143 | ) (msg Message, err error) { 144 | urlSuffix := fmt.Sprintf("/threads/%s/%s/%s", threadID, messagesSuffix, messageID) 145 | req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), 146 | withBetaAssistantVersion(c.config.AssistantVersion)) 147 | if err != nil { 148 | return 149 | } 150 | 151 | err = c.sendRequest(req, &msg) 152 | return 153 | } 154 | 155 | // ModifyMessage modifies a message. 156 | func (c *Client) ModifyMessage( 157 | ctx context.Context, 158 | threadID, messageID string, 159 | metadata map[string]string, 160 | ) (msg Message, err error) { 161 | urlSuffix := fmt.Sprintf("/threads/%s/%s/%s", threadID, messagesSuffix, messageID) 162 | req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), 163 | withBody(map[string]any{"metadata": metadata}), withBetaAssistantVersion(c.config.AssistantVersion)) 164 | if err != nil { 165 | return 166 | } 167 | 168 | err = c.sendRequest(req, &msg) 169 | return 170 | } 171 | 172 | // RetrieveMessageFile fetches a message file. 173 | func (c *Client) RetrieveMessageFile( 174 | ctx context.Context, 175 | threadID, messageID, fileID string, 176 | ) (file MessageFile, err error) { 177 | urlSuffix := fmt.Sprintf("/threads/%s/%s/%s/files/%s", threadID, messagesSuffix, messageID, fileID) 178 | req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), 179 | withBetaAssistantVersion(c.config.AssistantVersion)) 180 | if err != nil { 181 | return 182 | } 183 | 184 | err = c.sendRequest(req, &file) 185 | return 186 | } 187 | 188 | // ListMessageFiles fetches all files attached to a message. 
189 | func (c *Client) ListMessageFiles( 190 | ctx context.Context, 191 | threadID, messageID string, 192 | ) (files MessageFilesList, err error) { 193 | urlSuffix := fmt.Sprintf("/threads/%s/%s/%s/files", threadID, messagesSuffix, messageID) 194 | req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), 195 | withBetaAssistantVersion(c.config.AssistantVersion)) 196 | if err != nil { 197 | return 198 | } 199 | 200 | err = c.sendRequest(req, &files) 201 | return 202 | } 203 | 204 | // DeleteMessage deletes a message.. 205 | func (c *Client) DeleteMessage( 206 | ctx context.Context, 207 | threadID, messageID string, 208 | ) (status MessageDeletionStatus, err error) { 209 | urlSuffix := fmt.Sprintf("/threads/%s/%s/%s", threadID, messagesSuffix, messageID) 210 | req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix), 211 | withBetaAssistantVersion(c.config.AssistantVersion)) 212 | if err != nil { 213 | return 214 | } 215 | 216 | err = c.sendRequest(req, &status) 217 | return 218 | } 219 | -------------------------------------------------------------------------------- /image_api_test.go: -------------------------------------------------------------------------------- 1 | package openai_test 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | "io" 8 | "net/http" 9 | "os" 10 | "testing" 11 | "time" 12 | 13 | "github.com/sashabaranov/go-openai" 14 | "github.com/sashabaranov/go-openai/internal/test/checks" 15 | ) 16 | 17 | func TestImages(t *testing.T) { 18 | client, server, teardown := setupOpenAITestServer() 19 | defer teardown() 20 | server.RegisterHandler("/v1/images/generations", handleImageEndpoint) 21 | _, err := client.CreateImage(context.Background(), openai.ImageRequest{ 22 | Prompt: "Lorem ipsum", 23 | Model: openai.CreateImageModelDallE3, 24 | N: 1, 25 | Quality: openai.CreateImageQualityHD, 26 | Size: openai.CreateImageSize1024x1024, 27 | Style: openai.CreateImageStyleVivid, 28 | ResponseFormat: 
openai.CreateImageResponseFormatURL, 29 | User: "user", 30 | }) 31 | checks.NoError(t, err, "CreateImage error") 32 | } 33 | 34 | // handleImageEndpoint Handles the images endpoint by the test server. 35 | func handleImageEndpoint(w http.ResponseWriter, r *http.Request) { 36 | var err error 37 | var resBytes []byte 38 | 39 | // images only accepts POST requests 40 | if r.Method != "POST" { 41 | http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) 42 | } 43 | var imageReq openai.ImageRequest 44 | if imageReq, err = getImageBody(r); err != nil { 45 | http.Error(w, "could not read request", http.StatusInternalServerError) 46 | return 47 | } 48 | res := openai.ImageResponse{ 49 | Created: time.Now().Unix(), 50 | } 51 | for i := 0; i < imageReq.N; i++ { 52 | imageData := openai.ImageResponseDataInner{} 53 | switch imageReq.ResponseFormat { 54 | case openai.CreateImageResponseFormatURL, "": 55 | imageData.URL = "https://example.com/image.png" 56 | case openai.CreateImageResponseFormatB64JSON: 57 | // This decodes to "{}" in base64. 58 | imageData.B64JSON = "e30K" 59 | default: 60 | http.Error(w, "invalid response format", http.StatusBadRequest) 61 | return 62 | } 63 | res.Data = append(res.Data, imageData) 64 | } 65 | resBytes, _ = json.Marshal(res) 66 | fmt.Fprintln(w, string(resBytes)) 67 | } 68 | 69 | // getImageBody Returns the body of the request to create a image. 
70 | func getImageBody(r *http.Request) (openai.ImageRequest, error) { 71 | image := openai.ImageRequest{} 72 | // read the request body 73 | reqBody, err := io.ReadAll(r.Body) 74 | if err != nil { 75 | return openai.ImageRequest{}, err 76 | } 77 | err = json.Unmarshal(reqBody, &image) 78 | if err != nil { 79 | return openai.ImageRequest{}, err 80 | } 81 | return image, nil 82 | } 83 | 84 | func TestImageEdit(t *testing.T) { 85 | client, server, teardown := setupOpenAITestServer() 86 | defer teardown() 87 | server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint) 88 | 89 | origin, err := os.Create("image.png") 90 | if err != nil { 91 | t.Error("open origin file error") 92 | return 93 | } 94 | 95 | mask, err := os.Create("mask.png") 96 | if err != nil { 97 | t.Error("open mask file error") 98 | return 99 | } 100 | 101 | defer func() { 102 | mask.Close() 103 | origin.Close() 104 | os.Remove("mask.png") 105 | os.Remove("image.png") 106 | }() 107 | 108 | _, err = client.CreateEditImage(context.Background(), openai.ImageEditRequest{ 109 | Image: origin, 110 | Mask: mask, 111 | Prompt: "There is a turtle in the pool", 112 | N: 3, 113 | Size: openai.CreateImageSize1024x1024, 114 | ResponseFormat: openai.CreateImageResponseFormatURL, 115 | }) 116 | checks.NoError(t, err, "CreateImage error") 117 | } 118 | 119 | func TestImageEditWithoutMask(t *testing.T) { 120 | client, server, teardown := setupOpenAITestServer() 121 | defer teardown() 122 | server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint) 123 | 124 | origin, err := os.Create("image.png") 125 | if err != nil { 126 | t.Error("open origin file error") 127 | return 128 | } 129 | 130 | defer func() { 131 | origin.Close() 132 | os.Remove("image.png") 133 | }() 134 | 135 | _, err = client.CreateEditImage(context.Background(), openai.ImageEditRequest{ 136 | Image: origin, 137 | Prompt: "There is a turtle in the pool", 138 | N: 3, 139 | Size: openai.CreateImageSize1024x1024, 140 | ResponseFormat: 
openai.CreateImageResponseFormatURL, 141 | }) 142 | checks.NoError(t, err, "CreateImage error") 143 | } 144 | 145 | // handleEditImageEndpoint Handles the images endpoint by the test server. 146 | func handleEditImageEndpoint(w http.ResponseWriter, r *http.Request) { 147 | var resBytes []byte 148 | 149 | // images only accepts POST requests 150 | if r.Method != "POST" { 151 | http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) 152 | } 153 | 154 | responses := openai.ImageResponse{ 155 | Created: time.Now().Unix(), 156 | Data: []openai.ImageResponseDataInner{ 157 | { 158 | URL: "test-url1", 159 | B64JSON: "", 160 | }, 161 | { 162 | URL: "test-url2", 163 | B64JSON: "", 164 | }, 165 | { 166 | URL: "test-url3", 167 | B64JSON: "", 168 | }, 169 | }, 170 | } 171 | 172 | resBytes, _ = json.Marshal(responses) 173 | fmt.Fprintln(w, string(resBytes)) 174 | } 175 | 176 | func TestImageVariation(t *testing.T) { 177 | client, server, teardown := setupOpenAITestServer() 178 | defer teardown() 179 | server.RegisterHandler("/v1/images/variations", handleVariateImageEndpoint) 180 | 181 | origin, err := os.Create("image.png") 182 | if err != nil { 183 | t.Error("open origin file error") 184 | return 185 | } 186 | 187 | defer func() { 188 | origin.Close() 189 | os.Remove("image.png") 190 | }() 191 | 192 | _, err = client.CreateVariImage(context.Background(), openai.ImageVariRequest{ 193 | Image: origin, 194 | N: 3, 195 | Size: openai.CreateImageSize1024x1024, 196 | ResponseFormat: openai.CreateImageResponseFormatURL, 197 | }) 198 | checks.NoError(t, err, "CreateImage error") 199 | } 200 | 201 | // handleVariateImageEndpoint Handles the images endpoint by the test server. 
202 | func handleVariateImageEndpoint(w http.ResponseWriter, r *http.Request) { 203 | var resBytes []byte 204 | 205 | // images only accepts POST requests 206 | if r.Method != "POST" { 207 | http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) 208 | } 209 | 210 | responses := openai.ImageResponse{ 211 | Created: time.Now().Unix(), 212 | Data: []openai.ImageResponseDataInner{ 213 | { 214 | URL: "test-url1", 215 | B64JSON: "", 216 | }, 217 | { 218 | URL: "test-url2", 219 | B64JSON: "", 220 | }, 221 | { 222 | URL: "test-url3", 223 | B64JSON: "", 224 | }, 225 | }, 226 | } 227 | 228 | resBytes, _ = json.Marshal(responses) 229 | fmt.Fprintln(w, string(resBytes)) 230 | } 231 | -------------------------------------------------------------------------------- /run_test.go: -------------------------------------------------------------------------------- 1 | package openai_test 2 | 3 | import ( 4 | "context" 5 | 6 | openai "github.com/sashabaranov/go-openai" 7 | "github.com/sashabaranov/go-openai/internal/test/checks" 8 | 9 | "encoding/json" 10 | "fmt" 11 | "net/http" 12 | "testing" 13 | ) 14 | 15 | // TestAssistant Tests the assistant endpoint of the API using the mocked server. 
16 | func TestRun(t *testing.T) { 17 | assistantID := "asst_abc123" 18 | threadID := "thread_abc123" 19 | runID := "run_abc123" 20 | stepID := "step_abc123" 21 | limit := 20 22 | order := "desc" 23 | after := "asst_abc122" 24 | before := "asst_abc124" 25 | 26 | client, server, teardown := setupOpenAITestServer() 27 | defer teardown() 28 | 29 | server.RegisterHandler( 30 | "/v1/threads/"+threadID+"/runs/"+runID+"/steps/"+stepID, 31 | func(w http.ResponseWriter, r *http.Request) { 32 | if r.Method == http.MethodGet { 33 | resBytes, _ := json.Marshal(openai.RunStep{ 34 | ID: runID, 35 | Object: "run", 36 | CreatedAt: 1234567890, 37 | Status: openai.RunStepStatusCompleted, 38 | }) 39 | fmt.Fprintln(w, string(resBytes)) 40 | } 41 | }, 42 | ) 43 | 44 | server.RegisterHandler( 45 | "/v1/threads/"+threadID+"/runs/"+runID+"/steps", 46 | func(w http.ResponseWriter, r *http.Request) { 47 | if r.Method == http.MethodGet { 48 | resBytes, _ := json.Marshal(openai.RunStepList{ 49 | RunSteps: []openai.RunStep{ 50 | { 51 | ID: runID, 52 | Object: "run", 53 | CreatedAt: 1234567890, 54 | Status: openai.RunStepStatusCompleted, 55 | }, 56 | }, 57 | }) 58 | fmt.Fprintln(w, string(resBytes)) 59 | } 60 | }, 61 | ) 62 | 63 | server.RegisterHandler( 64 | "/v1/threads/"+threadID+"/runs/"+runID+"/cancel", 65 | func(w http.ResponseWriter, r *http.Request) { 66 | if r.Method == http.MethodPost { 67 | resBytes, _ := json.Marshal(openai.Run{ 68 | ID: runID, 69 | Object: "run", 70 | CreatedAt: 1234567890, 71 | Status: openai.RunStatusCancelling, 72 | }) 73 | fmt.Fprintln(w, string(resBytes)) 74 | } 75 | }, 76 | ) 77 | 78 | server.RegisterHandler( 79 | "/v1/threads/"+threadID+"/runs/"+runID+"/submit_tool_outputs", 80 | func(w http.ResponseWriter, r *http.Request) { 81 | if r.Method == http.MethodPost { 82 | resBytes, _ := json.Marshal(openai.Run{ 83 | ID: runID, 84 | Object: "run", 85 | CreatedAt: 1234567890, 86 | Status: openai.RunStatusCancelling, 87 | }) 88 | fmt.Fprintln(w, string(resBytes)) 89 
| } 90 | }, 91 | ) 92 | 93 | server.RegisterHandler( 94 | "/v1/threads/"+threadID+"/runs/"+runID, 95 | func(w http.ResponseWriter, r *http.Request) { 96 | if r.Method == http.MethodGet { 97 | resBytes, _ := json.Marshal(openai.Run{ 98 | ID: runID, 99 | Object: "run", 100 | CreatedAt: 1234567890, 101 | Status: openai.RunStatusQueued, 102 | }) 103 | fmt.Fprintln(w, string(resBytes)) 104 | } else if r.Method == http.MethodPost { 105 | var request openai.RunModifyRequest 106 | err := json.NewDecoder(r.Body).Decode(&request) 107 | checks.NoError(t, err, "Decode error") 108 | 109 | resBytes, _ := json.Marshal(openai.Run{ 110 | ID: runID, 111 | Object: "run", 112 | CreatedAt: 1234567890, 113 | Status: openai.RunStatusQueued, 114 | Metadata: request.Metadata, 115 | }) 116 | fmt.Fprintln(w, string(resBytes)) 117 | } 118 | }, 119 | ) 120 | 121 | server.RegisterHandler( 122 | "/v1/threads/"+threadID+"/runs", 123 | func(w http.ResponseWriter, r *http.Request) { 124 | if r.Method == http.MethodPost { 125 | var request openai.RunRequest 126 | err := json.NewDecoder(r.Body).Decode(&request) 127 | checks.NoError(t, err, "Decode error") 128 | 129 | resBytes, _ := json.Marshal(openai.Run{ 130 | ID: runID, 131 | Object: "run", 132 | CreatedAt: 1234567890, 133 | Status: openai.RunStatusQueued, 134 | }) 135 | fmt.Fprintln(w, string(resBytes)) 136 | } else if r.Method == http.MethodGet { 137 | resBytes, _ := json.Marshal(openai.RunList{ 138 | Runs: []openai.Run{ 139 | { 140 | ID: runID, 141 | Object: "run", 142 | CreatedAt: 1234567890, 143 | Status: openai.RunStatusQueued, 144 | }, 145 | }, 146 | }) 147 | fmt.Fprintln(w, string(resBytes)) 148 | } 149 | }, 150 | ) 151 | 152 | server.RegisterHandler( 153 | "/v1/threads/runs", 154 | func(w http.ResponseWriter, r *http.Request) { 155 | if r.Method == http.MethodPost { 156 | var request openai.CreateThreadAndRunRequest 157 | err := json.NewDecoder(r.Body).Decode(&request) 158 | checks.NoError(t, err, "Decode error") 159 | 160 | resBytes, _ 
:= json.Marshal(openai.Run{ 161 | ID: runID, 162 | Object: "run", 163 | CreatedAt: 1234567890, 164 | Status: openai.RunStatusQueued, 165 | }) 166 | fmt.Fprintln(w, string(resBytes)) 167 | } 168 | }, 169 | ) 170 | 171 | ctx := context.Background() 172 | 173 | _, err := client.CreateRun(ctx, threadID, openai.RunRequest{ 174 | AssistantID: assistantID, 175 | }) 176 | checks.NoError(t, err, "CreateRun error") 177 | 178 | _, err = client.RetrieveRun(ctx, threadID, runID) 179 | checks.NoError(t, err, "RetrieveRun error") 180 | 181 | _, err = client.ModifyRun(ctx, threadID, runID, openai.RunModifyRequest{ 182 | Metadata: map[string]any{ 183 | "key": "value", 184 | }, 185 | }) 186 | checks.NoError(t, err, "ModifyRun error") 187 | 188 | _, err = client.ListRuns( 189 | ctx, 190 | threadID, 191 | openai.Pagination{ 192 | Limit: &limit, 193 | Order: &order, 194 | After: &after, 195 | Before: &before, 196 | }, 197 | ) 198 | checks.NoError(t, err, "ListRuns error") 199 | 200 | _, err = client.SubmitToolOutputs(ctx, threadID, runID, 201 | openai.SubmitToolOutputsRequest{}) 202 | checks.NoError(t, err, "SubmitToolOutputs error") 203 | 204 | _, err = client.CancelRun(ctx, threadID, runID) 205 | checks.NoError(t, err, "CancelRun error") 206 | 207 | _, err = client.CreateThreadAndRun(ctx, openai.CreateThreadAndRunRequest{ 208 | RunRequest: openai.RunRequest{ 209 | AssistantID: assistantID, 210 | }, 211 | Thread: openai.ThreadRequest{ 212 | Messages: []openai.ThreadMessage{ 213 | { 214 | Role: openai.ThreadMessageRoleUser, 215 | Content: "Hello, World!", 216 | }, 217 | }, 218 | }, 219 | }) 220 | checks.NoError(t, err, "CreateThreadAndRun error") 221 | 222 | _, err = client.RetrieveRunStep(ctx, threadID, runID, stepID) 223 | checks.NoError(t, err, "RetrieveRunStep error") 224 | 225 | _, err = client.ListRunSteps( 226 | ctx, 227 | threadID, 228 | runID, 229 | openai.Pagination{ 230 | Limit: &limit, 231 | Order: &order, 232 | After: &after, 233 | Before: &before, 234 | }, 235 | ) 236 | 
checks.NoError(t, err, "ListRunSteps error") 237 | } 238 | -------------------------------------------------------------------------------- /audio.go: -------------------------------------------------------------------------------- 1 | package openai 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "fmt" 7 | "io" 8 | "net/http" 9 | "os" 10 | 11 | utils "github.com/sashabaranov/go-openai/internal" 12 | ) 13 | 14 | // Whisper Defines the models provided by OpenAI to use when processing audio with OpenAI. 15 | const ( 16 | Whisper1 = "whisper-1" 17 | ) 18 | 19 | // Response formats; Whisper uses AudioResponseFormatJSON by default. 20 | type AudioResponseFormat string 21 | 22 | const ( 23 | AudioResponseFormatJSON AudioResponseFormat = "json" 24 | AudioResponseFormatText AudioResponseFormat = "text" 25 | AudioResponseFormatSRT AudioResponseFormat = "srt" 26 | AudioResponseFormatVerboseJSON AudioResponseFormat = "verbose_json" 27 | AudioResponseFormatVTT AudioResponseFormat = "vtt" 28 | ) 29 | 30 | type TranscriptionTimestampGranularity string 31 | 32 | const ( 33 | TranscriptionTimestampGranularityWord TranscriptionTimestampGranularity = "word" 34 | TranscriptionTimestampGranularitySegment TranscriptionTimestampGranularity = "segment" 35 | ) 36 | 37 | // AudioRequest represents a request structure for audio API. 38 | type AudioRequest struct { 39 | Model string 40 | 41 | // FilePath is either an existing file in your filesystem or a filename representing the contents of Reader. 42 | FilePath string 43 | 44 | // Reader is an optional io.Reader when you do not want to use an existing file. 45 | Reader io.Reader 46 | 47 | Prompt string 48 | Temperature float32 49 | Language string // Only for transcription. 50 | Format AudioResponseFormat 51 | TimestampGranularities []TranscriptionTimestampGranularity // Only for transcription. 52 | } 53 | 54 | // AudioResponse represents a response structure for audio API. 
55 | type AudioResponse struct { 56 | Task string `json:"task"` 57 | Language string `json:"language"` 58 | Duration float64 `json:"duration"` 59 | Segments []struct { 60 | ID int `json:"id"` 61 | Seek int `json:"seek"` 62 | Start float64 `json:"start"` 63 | End float64 `json:"end"` 64 | Text string `json:"text"` 65 | Tokens []int `json:"tokens"` 66 | Temperature float64 `json:"temperature"` 67 | AvgLogprob float64 `json:"avg_logprob"` 68 | CompressionRatio float64 `json:"compression_ratio"` 69 | NoSpeechProb float64 `json:"no_speech_prob"` 70 | Transient bool `json:"transient"` 71 | } `json:"segments"` 72 | Words []struct { 73 | Word string `json:"word"` 74 | Start float64 `json:"start"` 75 | End float64 `json:"end"` 76 | } `json:"words"` 77 | Text string `json:"text"` 78 | 79 | httpHeader 80 | } 81 | 82 | type audioTextResponse struct { 83 | Text string `json:"text"` 84 | 85 | httpHeader 86 | } 87 | 88 | func (r *audioTextResponse) ToAudioResponse() AudioResponse { 89 | return AudioResponse{ 90 | Text: r.Text, 91 | httpHeader: r.httpHeader, 92 | } 93 | } 94 | 95 | // CreateTranscription — API call to create a transcription. Returns transcribed text. 96 | func (c *Client) CreateTranscription( 97 | ctx context.Context, 98 | request AudioRequest, 99 | ) (response AudioResponse, err error) { 100 | return c.callAudioAPI(ctx, request, "transcriptions") 101 | } 102 | 103 | // CreateTranslation — API call to translate audio into English. 104 | func (c *Client) CreateTranslation( 105 | ctx context.Context, 106 | request AudioRequest, 107 | ) (response AudioResponse, err error) { 108 | return c.callAudioAPI(ctx, request, "translations") 109 | } 110 | 111 | // callAudioAPI — API call to an audio endpoint. 
112 | func (c *Client) callAudioAPI( 113 | ctx context.Context, 114 | request AudioRequest, 115 | endpointSuffix string, 116 | ) (response AudioResponse, err error) { 117 | var formBody bytes.Buffer 118 | builder := c.createFormBuilder(&formBody) 119 | 120 | if err = audioMultipartForm(request, builder); err != nil { 121 | return AudioResponse{}, err 122 | } 123 | 124 | urlSuffix := fmt.Sprintf("/audio/%s", endpointSuffix) 125 | req, err := c.newRequest( 126 | ctx, 127 | http.MethodPost, 128 | c.fullURL(urlSuffix, withModel(request.Model)), 129 | withBody(&formBody), 130 | withContentType(builder.FormDataContentType()), 131 | ) 132 | if err != nil { 133 | return AudioResponse{}, err 134 | } 135 | 136 | if request.HasJSONResponse() { 137 | err = c.sendRequest(req, &response) 138 | } else { 139 | var textResponse audioTextResponse 140 | err = c.sendRequest(req, &textResponse) 141 | response = textResponse.ToAudioResponse() 142 | } 143 | if err != nil { 144 | return AudioResponse{}, err 145 | } 146 | return 147 | } 148 | 149 | // HasJSONResponse returns true if the response format is JSON. 150 | func (r AudioRequest) HasJSONResponse() bool { 151 | return r.Format == "" || r.Format == AudioResponseFormatJSON || r.Format == AudioResponseFormatVerboseJSON 152 | } 153 | 154 | // audioMultipartForm creates a form with audio file contents and the name of the model to use for 155 | // audio processing. 
156 | func audioMultipartForm(request AudioRequest, b utils.FormBuilder) error { 157 | err := createFileField(request, b) 158 | if err != nil { 159 | return err 160 | } 161 | 162 | err = b.WriteField("model", request.Model) 163 | if err != nil { 164 | return fmt.Errorf("writing model name: %w", err) 165 | } 166 | 167 | // Create a form field for the prompt (if provided) 168 | if request.Prompt != "" { 169 | err = b.WriteField("prompt", request.Prompt) 170 | if err != nil { 171 | return fmt.Errorf("writing prompt: %w", err) 172 | } 173 | } 174 | 175 | // Create a form field for the format (if provided) 176 | if request.Format != "" { 177 | err = b.WriteField("response_format", string(request.Format)) 178 | if err != nil { 179 | return fmt.Errorf("writing format: %w", err) 180 | } 181 | } 182 | 183 | // Create a form field for the temperature (if provided) 184 | if request.Temperature != 0 { 185 | err = b.WriteField("temperature", fmt.Sprintf("%.2f", request.Temperature)) 186 | if err != nil { 187 | return fmt.Errorf("writing temperature: %w", err) 188 | } 189 | } 190 | 191 | // Create a form field for the language (if provided) 192 | if request.Language != "" { 193 | err = b.WriteField("language", request.Language) 194 | if err != nil { 195 | return fmt.Errorf("writing language: %w", err) 196 | } 197 | } 198 | 199 | if len(request.TimestampGranularities) > 0 { 200 | for _, tg := range request.TimestampGranularities { 201 | err = b.WriteField("timestamp_granularities[]", string(tg)) 202 | if err != nil { 203 | return fmt.Errorf("writing timestamp_granularities[]: %w", err) 204 | } 205 | } 206 | } 207 | 208 | // Close the multipart writer 209 | return b.Close() 210 | } 211 | 212 | // createFileField creates the "file" form field from either an existing file or by using the reader. 
213 | func createFileField(request AudioRequest, b utils.FormBuilder) error { 214 | if request.Reader != nil { 215 | err := b.CreateFormFileReader("file", request.Reader, request.FilePath) 216 | if err != nil { 217 | return fmt.Errorf("creating form using reader: %w", err) 218 | } 219 | return nil 220 | } 221 | 222 | f, err := os.Open(request.FilePath) 223 | if err != nil { 224 | return fmt.Errorf("opening audio file: %w", err) 225 | } 226 | defer f.Close() 227 | 228 | err = b.CreateFormFile("file", f) 229 | if err != nil { 230 | return fmt.Errorf("creating form file: %w", err) 231 | } 232 | 233 | return nil 234 | } 235 | -------------------------------------------------------------------------------- /fine_tunes.go: -------------------------------------------------------------------------------- 1 | package openai 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "net/http" 7 | ) 8 | 9 | // Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. 10 | // This API will be officially deprecated on January 4th, 2024. 11 | // OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. 
12 | type FineTuneRequest struct { 13 | TrainingFile string `json:"training_file"` 14 | ValidationFile string `json:"validation_file,omitempty"` 15 | Model string `json:"model,omitempty"` 16 | Epochs int `json:"n_epochs,omitempty"` 17 | BatchSize int `json:"batch_size,omitempty"` 18 | LearningRateMultiplier float32 `json:"learning_rate_multiplier,omitempty"` 19 | PromptLossRate float32 `json:"prompt_loss_rate,omitempty"` 20 | ComputeClassificationMetrics bool `json:"compute_classification_metrics,omitempty"` 21 | ClassificationClasses int `json:"classification_n_classes,omitempty"` 22 | ClassificationPositiveClass string `json:"classification_positive_class,omitempty"` 23 | ClassificationBetas []float32 `json:"classification_betas,omitempty"` 24 | Suffix string `json:"suffix,omitempty"` 25 | } 26 | 27 | // Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. 28 | // This API will be officially deprecated on January 4th, 2024. 29 | // OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. 30 | type FineTune struct { 31 | ID string `json:"id"` 32 | Object string `json:"object"` 33 | Model string `json:"model"` 34 | CreatedAt int64 `json:"created_at"` 35 | FineTuneEventList []FineTuneEvent `json:"events,omitempty"` 36 | FineTunedModel string `json:"fine_tuned_model"` 37 | HyperParams FineTuneHyperParams `json:"hyperparams"` 38 | OrganizationID string `json:"organization_id"` 39 | ResultFiles []File `json:"result_files"` 40 | Status string `json:"status"` 41 | ValidationFiles []File `json:"validation_files"` 42 | TrainingFiles []File `json:"training_files"` 43 | UpdatedAt int64 `json:"updated_at"` 44 | 45 | httpHeader 46 | } 47 | 48 | // Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. 49 | // This API will be officially deprecated on January 4th, 2024. 
// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go.
type FineTuneEvent struct {
	Object    string `json:"object"`
	CreatedAt int64  `json:"created_at"`
	Level     string `json:"level"`
	Message   string `json:"message"`
}

// FineTuneHyperParams holds the hyperparameters reported in FineTune.HyperParams.
//
// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API.
// This API will be officially deprecated on January 4th, 2024.
// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go.
type FineTuneHyperParams struct {
	BatchSize              int     `json:"batch_size"`
	LearningRateMultiplier float64 `json:"learning_rate_multiplier"`
	Epochs                 int     `json:"n_epochs"`
	PromptLossWeight       float64 `json:"prompt_loss_weight"`
}

// FineTuneList is the response of ListFineTunes.
//
// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API.
// This API will be officially deprecated on January 4th, 2024.
// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go.
type FineTuneList struct {
	Object string     `json:"object"`
	Data   []FineTune `json:"data"`

	httpHeader
}

// FineTuneEventList is the response of ListFineTuneEvents.
//
// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API.
// This API will be officially deprecated on January 4th, 2024.
// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go.
type FineTuneEventList struct {
	Object string          `json:"object"`
	Data   []FineTuneEvent `json:"data"`

	httpHeader
}

// FineTuneDeleteResponse is the response of DeleteFineTune.
//
// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API.
// This API will be officially deprecated on January 4th, 2024.
// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go.
91 | type FineTuneDeleteResponse struct { 92 | ID string `json:"id"` 93 | Object string `json:"object"` 94 | Deleted bool `json:"deleted"` 95 | 96 | httpHeader 97 | } 98 | 99 | // Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. 100 | // This API will be officially deprecated on January 4th, 2024. 101 | // OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. 102 | func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (response FineTune, err error) { 103 | urlSuffix := "/fine-tunes" 104 | req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request)) 105 | if err != nil { 106 | return 107 | } 108 | 109 | err = c.sendRequest(req, &response) 110 | return 111 | } 112 | 113 | // CancelFineTune cancel a fine-tune job. 114 | // Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. 115 | // This API will be officially deprecated on January 4th, 2024. 116 | // OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. 117 | func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) { 118 | req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel")) //nolint:lll //this method is deprecated 119 | if err != nil { 120 | return 121 | } 122 | 123 | err = c.sendRequest(req, &response) 124 | return 125 | } 126 | 127 | // Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. 128 | // This API will be officially deprecated on January 4th, 2024. 129 | // OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. 
130 | func (c *Client) ListFineTunes(ctx context.Context) (response FineTuneList, err error) { 131 | req, err := c.newRequest(ctx, http.MethodGet, c.fullURL("/fine-tunes")) 132 | if err != nil { 133 | return 134 | } 135 | 136 | err = c.sendRequest(req, &response) 137 | return 138 | } 139 | 140 | // Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. 141 | // This API will be officially deprecated on January 4th, 2024. 142 | // OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. 143 | func (c *Client) GetFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) { 144 | urlSuffix := fmt.Sprintf("/fine-tunes/%s", fineTuneID) 145 | req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) 146 | if err != nil { 147 | return 148 | } 149 | 150 | err = c.sendRequest(req, &response) 151 | return 152 | } 153 | 154 | // Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. 155 | // This API will be officially deprecated on January 4th, 2024. 156 | // OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. 157 | func (c *Client) DeleteFineTune(ctx context.Context, fineTuneID string) (response FineTuneDeleteResponse, err error) { 158 | req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL("/fine-tunes/"+fineTuneID)) 159 | if err != nil { 160 | return 161 | } 162 | 163 | err = c.sendRequest(req, &response) 164 | return 165 | } 166 | 167 | // Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. 168 | // This API will be officially deprecated on January 4th, 2024. 169 | // OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. 
170 | func (c *Client) ListFineTuneEvents(ctx context.Context, fineTuneID string) (response FineTuneEventList, err error) { 171 | req, err := c.newRequest(ctx, http.MethodGet, c.fullURL("/fine-tunes/"+fineTuneID+"/events")) 172 | if err != nil { 173 | return 174 | } 175 | 176 | err = c.sendRequest(req, &response) 177 | return 178 | } 179 | -------------------------------------------------------------------------------- /batch.go: -------------------------------------------------------------------------------- 1 | package openai 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "encoding/json" 7 | "fmt" 8 | "net/http" 9 | "net/url" 10 | ) 11 | 12 | const batchesSuffix = "/batches" 13 | 14 | type BatchEndpoint string 15 | 16 | const ( 17 | BatchEndpointChatCompletions BatchEndpoint = "/v1/chat/completions" 18 | BatchEndpointCompletions BatchEndpoint = "/v1/completions" 19 | BatchEndpointEmbeddings BatchEndpoint = "/v1/embeddings" 20 | ) 21 | 22 | type BatchLineItem interface { 23 | MarshalBatchLineItem() []byte 24 | } 25 | 26 | type BatchChatCompletionRequest struct { 27 | CustomID string `json:"custom_id"` 28 | Body ChatCompletionRequest `json:"body"` 29 | Method string `json:"method"` 30 | URL BatchEndpoint `json:"url"` 31 | } 32 | 33 | func (r BatchChatCompletionRequest) MarshalBatchLineItem() []byte { 34 | marshal, _ := json.Marshal(r) 35 | return marshal 36 | } 37 | 38 | type BatchCompletionRequest struct { 39 | CustomID string `json:"custom_id"` 40 | Body CompletionRequest `json:"body"` 41 | Method string `json:"method"` 42 | URL BatchEndpoint `json:"url"` 43 | } 44 | 45 | func (r BatchCompletionRequest) MarshalBatchLineItem() []byte { 46 | marshal, _ := json.Marshal(r) 47 | return marshal 48 | } 49 | 50 | type BatchEmbeddingRequest struct { 51 | CustomID string `json:"custom_id"` 52 | Body EmbeddingRequest `json:"body"` 53 | Method string `json:"method"` 54 | URL BatchEndpoint `json:"url"` 55 | } 56 | 57 | func (r BatchEmbeddingRequest) MarshalBatchLineItem() 
[]byte { 58 | marshal, _ := json.Marshal(r) 59 | return marshal 60 | } 61 | 62 | type Batch struct { 63 | ID string `json:"id"` 64 | Object string `json:"object"` 65 | Endpoint BatchEndpoint `json:"endpoint"` 66 | Errors *struct { 67 | Object string `json:"object,omitempty"` 68 | Data []struct { 69 | Code string `json:"code,omitempty"` 70 | Message string `json:"message,omitempty"` 71 | Param *string `json:"param,omitempty"` 72 | Line *int `json:"line,omitempty"` 73 | } `json:"data"` 74 | } `json:"errors"` 75 | InputFileID string `json:"input_file_id"` 76 | CompletionWindow string `json:"completion_window"` 77 | Status string `json:"status"` 78 | OutputFileID *string `json:"output_file_id"` 79 | ErrorFileID *string `json:"error_file_id"` 80 | CreatedAt int `json:"created_at"` 81 | InProgressAt *int `json:"in_progress_at"` 82 | ExpiresAt *int `json:"expires_at"` 83 | FinalizingAt *int `json:"finalizing_at"` 84 | CompletedAt *int `json:"completed_at"` 85 | FailedAt *int `json:"failed_at"` 86 | ExpiredAt *int `json:"expired_at"` 87 | CancellingAt *int `json:"cancelling_at"` 88 | CancelledAt *int `json:"cancelled_at"` 89 | RequestCounts BatchRequestCounts `json:"request_counts"` 90 | Metadata map[string]any `json:"metadata"` 91 | } 92 | 93 | type BatchRequestCounts struct { 94 | Total int `json:"total"` 95 | Completed int `json:"completed"` 96 | Failed int `json:"failed"` 97 | } 98 | 99 | type CreateBatchRequest struct { 100 | InputFileID string `json:"input_file_id"` 101 | Endpoint BatchEndpoint `json:"endpoint"` 102 | CompletionWindow string `json:"completion_window"` 103 | Metadata map[string]any `json:"metadata"` 104 | } 105 | 106 | type BatchResponse struct { 107 | httpHeader 108 | Batch 109 | } 110 | 111 | // CreateBatch — API call to Create batch. 
112 | func (c *Client) CreateBatch( 113 | ctx context.Context, 114 | request CreateBatchRequest, 115 | ) (response BatchResponse, err error) { 116 | if request.CompletionWindow == "" { 117 | request.CompletionWindow = "24h" 118 | } 119 | 120 | req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(batchesSuffix), withBody(request)) 121 | if err != nil { 122 | return 123 | } 124 | 125 | err = c.sendRequest(req, &response) 126 | return 127 | } 128 | 129 | type UploadBatchFileRequest struct { 130 | FileName string 131 | Lines []BatchLineItem 132 | } 133 | 134 | func (r *UploadBatchFileRequest) MarshalJSONL() []byte { 135 | buff := bytes.Buffer{} 136 | for i, line := range r.Lines { 137 | if i != 0 { 138 | buff.Write([]byte("\n")) 139 | } 140 | buff.Write(line.MarshalBatchLineItem()) 141 | } 142 | return buff.Bytes() 143 | } 144 | 145 | func (r *UploadBatchFileRequest) AddChatCompletion(customerID string, body ChatCompletionRequest) { 146 | r.Lines = append(r.Lines, BatchChatCompletionRequest{ 147 | CustomID: customerID, 148 | Body: body, 149 | Method: "POST", 150 | URL: BatchEndpointChatCompletions, 151 | }) 152 | } 153 | 154 | func (r *UploadBatchFileRequest) AddCompletion(customerID string, body CompletionRequest) { 155 | r.Lines = append(r.Lines, BatchCompletionRequest{ 156 | CustomID: customerID, 157 | Body: body, 158 | Method: "POST", 159 | URL: BatchEndpointCompletions, 160 | }) 161 | } 162 | 163 | func (r *UploadBatchFileRequest) AddEmbedding(customerID string, body EmbeddingRequest) { 164 | r.Lines = append(r.Lines, BatchEmbeddingRequest{ 165 | CustomID: customerID, 166 | Body: body, 167 | Method: "POST", 168 | URL: BatchEndpointEmbeddings, 169 | }) 170 | } 171 | 172 | // UploadBatchFile — upload batch file. 
173 | func (c *Client) UploadBatchFile(ctx context.Context, request UploadBatchFileRequest) (File, error) { 174 | if request.FileName == "" { 175 | request.FileName = "@batchinput.jsonl" 176 | } 177 | return c.CreateFileBytes(ctx, FileBytesRequest{ 178 | Name: request.FileName, 179 | Bytes: request.MarshalJSONL(), 180 | Purpose: PurposeBatch, 181 | }) 182 | } 183 | 184 | type CreateBatchWithUploadFileRequest struct { 185 | Endpoint BatchEndpoint `json:"endpoint"` 186 | CompletionWindow string `json:"completion_window"` 187 | Metadata map[string]any `json:"metadata"` 188 | UploadBatchFileRequest 189 | } 190 | 191 | // CreateBatchWithUploadFile — API call to Create batch with upload file. 192 | func (c *Client) CreateBatchWithUploadFile( 193 | ctx context.Context, 194 | request CreateBatchWithUploadFileRequest, 195 | ) (response BatchResponse, err error) { 196 | var file File 197 | file, err = c.UploadBatchFile(ctx, UploadBatchFileRequest{ 198 | FileName: request.FileName, 199 | Lines: request.Lines, 200 | }) 201 | if err != nil { 202 | return 203 | } 204 | return c.CreateBatch(ctx, CreateBatchRequest{ 205 | InputFileID: file.ID, 206 | Endpoint: request.Endpoint, 207 | CompletionWindow: request.CompletionWindow, 208 | Metadata: request.Metadata, 209 | }) 210 | } 211 | 212 | // RetrieveBatch — API call to Retrieve batch. 213 | func (c *Client) RetrieveBatch( 214 | ctx context.Context, 215 | batchID string, 216 | ) (response BatchResponse, err error) { 217 | urlSuffix := fmt.Sprintf("%s/%s", batchesSuffix, batchID) 218 | req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) 219 | if err != nil { 220 | return 221 | } 222 | err = c.sendRequest(req, &response) 223 | return 224 | } 225 | 226 | // CancelBatch — API call to Cancel batch. 
227 | func (c *Client) CancelBatch( 228 | ctx context.Context, 229 | batchID string, 230 | ) (response BatchResponse, err error) { 231 | urlSuffix := fmt.Sprintf("%s/%s/cancel", batchesSuffix, batchID) 232 | req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix)) 233 | if err != nil { 234 | return 235 | } 236 | err = c.sendRequest(req, &response) 237 | return 238 | } 239 | 240 | type ListBatchResponse struct { 241 | httpHeader 242 | Object string `json:"object"` 243 | Data []Batch `json:"data"` 244 | FirstID string `json:"first_id"` 245 | LastID string `json:"last_id"` 246 | HasMore bool `json:"has_more"` 247 | } 248 | 249 | // ListBatch API call to List batch. 250 | func (c *Client) ListBatch(ctx context.Context, after *string, limit *int) (response ListBatchResponse, err error) { 251 | urlValues := url.Values{} 252 | if limit != nil { 253 | urlValues.Add("limit", fmt.Sprintf("%d", *limit)) 254 | } 255 | if after != nil { 256 | urlValues.Add("after", *after) 257 | } 258 | encodedValues := "" 259 | if len(urlValues) > 0 { 260 | encodedValues = "?" 
+ urlValues.Encode() 261 | } 262 | 263 | urlSuffix := fmt.Sprintf("%s%s", batchesSuffix, encodedValues) 264 | req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) 265 | if err != nil { 266 | return 267 | } 268 | 269 | err = c.sendRequest(req, &response) 270 | return 271 | } 272 | -------------------------------------------------------------------------------- /messages_test.go: -------------------------------------------------------------------------------- 1 | package openai_test 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | "net/http" 8 | "testing" 9 | 10 | "github.com/sashabaranov/go-openai" 11 | "github.com/sashabaranov/go-openai/internal/test" 12 | "github.com/sashabaranov/go-openai/internal/test/checks" 13 | ) 14 | 15 | var emptyStr = "" 16 | 17 | func setupServerForTestMessage(t *testing.T, server *test.ServerTest) { 18 | threadID := "thread_abc123" 19 | messageID := "msg_abc123" 20 | fileID := "file_abc123" 21 | 22 | server.RegisterHandler( 23 | "/v1/threads/"+threadID+"/messages/"+messageID+"/files/"+fileID, 24 | func(w http.ResponseWriter, r *http.Request) { 25 | switch r.Method { 26 | case http.MethodGet: 27 | resBytes, _ := json.Marshal( 28 | openai.MessageFile{ 29 | ID: fileID, 30 | Object: "thread.message.file", 31 | CreatedAt: 1699061776, 32 | MessageID: messageID, 33 | }) 34 | fmt.Fprintln(w, string(resBytes)) 35 | default: 36 | t.Fatalf("unsupported messages http method: %s", r.Method) 37 | } 38 | }, 39 | ) 40 | 41 | server.RegisterHandler( 42 | "/v1/threads/"+threadID+"/messages/"+messageID+"/files", 43 | func(w http.ResponseWriter, r *http.Request) { 44 | switch r.Method { 45 | case http.MethodGet: 46 | resBytes, _ := json.Marshal( 47 | openai.MessageFilesList{MessageFiles: []openai.MessageFile{{ 48 | ID: fileID, 49 | Object: "thread.message.file", 50 | CreatedAt: 0, 51 | MessageID: messageID, 52 | }}}) 53 | fmt.Fprintln(w, string(resBytes)) 54 | default: 55 | t.Fatalf("unsupported messages http method: %s", 
r.Method) 56 | } 57 | }, 58 | ) 59 | 60 | server.RegisterHandler( 61 | "/v1/threads/"+threadID+"/messages/"+messageID, 62 | func(w http.ResponseWriter, r *http.Request) { 63 | switch r.Method { 64 | case http.MethodPost: 65 | metadata := map[string]any{} 66 | err := json.NewDecoder(r.Body).Decode(&metadata) 67 | checks.NoError(t, err, "unable to decode metadata in modify message call") 68 | payload, ok := metadata["metadata"].(map[string]any) 69 | if !ok { 70 | t.Fatalf("metadata payload improperly wrapped %+v", metadata) 71 | } 72 | 73 | resBytes, _ := json.Marshal( 74 | openai.Message{ 75 | ID: messageID, 76 | Object: "thread.message", 77 | CreatedAt: 1234567890, 78 | ThreadID: threadID, 79 | Role: "user", 80 | Content: []openai.MessageContent{{ 81 | Type: "text", 82 | Text: &openai.MessageText{ 83 | Value: "How does AI work?", 84 | Annotations: nil, 85 | }, 86 | }}, 87 | FileIds: nil, 88 | AssistantID: &emptyStr, 89 | RunID: &emptyStr, 90 | Metadata: payload, 91 | }) 92 | 93 | fmt.Fprintln(w, string(resBytes)) 94 | case http.MethodGet: 95 | resBytes, _ := json.Marshal( 96 | openai.Message{ 97 | ID: messageID, 98 | Object: "thread.message", 99 | CreatedAt: 1234567890, 100 | ThreadID: threadID, 101 | Role: "user", 102 | Content: []openai.MessageContent{{ 103 | Type: "text", 104 | Text: &openai.MessageText{ 105 | Value: "How does AI work?", 106 | Annotations: nil, 107 | }, 108 | }}, 109 | FileIds: nil, 110 | AssistantID: &emptyStr, 111 | RunID: &emptyStr, 112 | Metadata: nil, 113 | }) 114 | fmt.Fprintln(w, string(resBytes)) 115 | case http.MethodDelete: 116 | resBytes, _ := json.Marshal(openai.MessageDeletionStatus{ 117 | ID: messageID, 118 | Object: "thread.message.deleted", 119 | Deleted: true, 120 | }) 121 | fmt.Fprintln(w, string(resBytes)) 122 | default: 123 | t.Fatalf("unsupported messages http method: %s", r.Method) 124 | } 125 | }, 126 | ) 127 | 128 | server.RegisterHandler( 129 | "/v1/threads/"+threadID+"/messages", 130 | func(w http.ResponseWriter, r 
*http.Request) { 131 | switch r.Method { 132 | case http.MethodPost: 133 | resBytes, _ := json.Marshal(openai.Message{ 134 | ID: messageID, 135 | Object: "thread.message", 136 | CreatedAt: 1234567890, 137 | ThreadID: threadID, 138 | Role: "user", 139 | Content: []openai.MessageContent{{ 140 | Type: "text", 141 | Text: &openai.MessageText{ 142 | Value: "How does AI work?", 143 | Annotations: nil, 144 | }, 145 | }}, 146 | FileIds: nil, 147 | AssistantID: &emptyStr, 148 | RunID: &emptyStr, 149 | Metadata: nil, 150 | }) 151 | fmt.Fprintln(w, string(resBytes)) 152 | case http.MethodGet: 153 | resBytes, _ := json.Marshal(openai.MessagesList{ 154 | Object: "list", 155 | Messages: []openai.Message{{ 156 | ID: messageID, 157 | Object: "thread.message", 158 | CreatedAt: 1234567890, 159 | ThreadID: threadID, 160 | Role: "user", 161 | Content: []openai.MessageContent{{ 162 | Type: "text", 163 | Text: &openai.MessageText{ 164 | Value: "How does AI work?", 165 | Annotations: nil, 166 | }, 167 | }}, 168 | FileIds: nil, 169 | AssistantID: &emptyStr, 170 | RunID: &emptyStr, 171 | Metadata: nil, 172 | }}, 173 | FirstID: &messageID, 174 | LastID: &messageID, 175 | HasMore: false, 176 | }) 177 | fmt.Fprintln(w, string(resBytes)) 178 | default: 179 | t.Fatalf("unsupported messages http method: %s", r.Method) 180 | } 181 | }, 182 | ) 183 | } 184 | 185 | // TestMessages Tests the messages endpoint of the API using the mocked server. 
186 | func TestMessages(t *testing.T) { 187 | threadID := "thread_abc123" 188 | messageID := "msg_abc123" 189 | fileID := "file_abc123" 190 | 191 | client, server, teardown := setupOpenAITestServer() 192 | defer teardown() 193 | 194 | setupServerForTestMessage(t, server) 195 | ctx := context.Background() 196 | 197 | // static assertion of return type 198 | var msg openai.Message 199 | msg, err := client.CreateMessage(ctx, threadID, openai.MessageRequest{ 200 | Role: "user", 201 | Content: "How does AI work?", 202 | FileIds: nil, 203 | Metadata: nil, 204 | }) 205 | checks.NoError(t, err, "CreateMessage error") 206 | if msg.ID != messageID { 207 | t.Fatalf("unexpected message id: '%s'", msg.ID) 208 | } 209 | 210 | var msgs openai.MessagesList 211 | msgs, err = client.ListMessage(ctx, threadID, nil, nil, nil, nil, nil) 212 | checks.NoError(t, err, "ListMessages error") 213 | if len(msgs.Messages) != 1 { 214 | t.Fatalf("unexpected length of fetched messages") 215 | } 216 | 217 | // with pagination options set 218 | limit := 1 219 | order := "desc" 220 | after := "obj_foo" 221 | before := "obj_bar" 222 | runID := "run_abc123" 223 | msgs, err = client.ListMessage(ctx, threadID, &limit, &order, &after, &before, &runID) 224 | checks.NoError(t, err, "ListMessages error") 225 | if len(msgs.Messages) != 1 { 226 | t.Fatalf("unexpected length of fetched messages") 227 | } 228 | 229 | msg, err = client.RetrieveMessage(ctx, threadID, messageID) 230 | checks.NoError(t, err, "RetrieveMessage error") 231 | if msg.ID != messageID { 232 | t.Fatalf("unexpected message id: '%s'", msg.ID) 233 | } 234 | 235 | msg, err = client.ModifyMessage(ctx, threadID, messageID, 236 | map[string]string{ 237 | "foo": "bar", 238 | }) 239 | checks.NoError(t, err, "ModifyMessage error") 240 | if msg.Metadata["foo"] != "bar" { 241 | t.Fatalf("expected message metadata to get modified") 242 | } 243 | 244 | msgDel, err := client.DeleteMessage(ctx, threadID, messageID) 245 | checks.NoError(t, err, 
"DeleteMessage error") 246 | if msgDel.ID != messageID { 247 | t.Fatalf("unexpected message id: '%s'", msg.ID) 248 | } 249 | if !msgDel.Deleted { 250 | t.Fatalf("expected deleted is true") 251 | } 252 | _, err = client.DeleteMessage(ctx, threadID, "not_exist_id") 253 | checks.HasError(t, err, "DeleteMessage error") 254 | 255 | // message files 256 | var msgFile openai.MessageFile 257 | msgFile, err = client.RetrieveMessageFile(ctx, threadID, messageID, fileID) 258 | checks.NoError(t, err, "RetrieveMessageFile error") 259 | if msgFile.ID != fileID { 260 | t.Fatalf("unexpected message file id: '%s'", msgFile.ID) 261 | } 262 | 263 | var msgFiles openai.MessageFilesList 264 | msgFiles, err = client.ListMessageFiles(ctx, threadID, messageID) 265 | checks.NoError(t, err, "RetrieveMessageFile error") 266 | if len(msgFiles.MessageFiles) != 1 { 267 | t.Fatalf("unexpected count of message files: %d", len(msgFiles.MessageFiles)) 268 | } 269 | if msgFiles.MessageFiles[0].ID != fileID { 270 | t.Fatalf("unexpected message file id: '%s' in list message files", msgFiles.MessageFiles[0].ID) 271 | } 272 | } 273 | --------------------------------------------------------------------------------