├── .gitignore
├── LICENSE
├── README.md
├── examples
│   └── llms
│       ├── openai
│       │   └── openai_test.go
│       └── openaichat
│           └── openai_chat_test.go
├── go.mod
├── go.sum
└── llms
    ├── base.go
    ├── openai
    │   ├── openai.go
    │   └── types.go
    ├── openaichat
    │   ├── openai_chat.go
    │   └── types.go
    ├── shared
    │   ├── openai
    │   │   ├── errors.go
    │   │   └── http_client.go
    │   └── utils.go
    └── types.go

/.gitignore:
--------------------------------------------------------------------------------
1 | # Binaries for programs and plugins
2 | *.exe
3 | *.exe~
4 | *.dll
5 | *.so
6 | *.dylib
7 | 
8 | # Test binary, built with `go test -c`
9 | *.test
10 | 
11 | # Output of the go coverage tool, specifically when used with LiteIDE
12 | *.out
13 | 
14 | .idea
15 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2023 Speakeasy
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # langchain-go
2 | Go bindings inspired by [langchainjs](https://github.com/hwchase17/langchainjs). The library currently provides wrappers around the OpenAI completions and chat completions APIs; see `examples/` for runnable tests.
3 | 
--------------------------------------------------------------------------------
/examples/llms/openai/openai_test.go:
--------------------------------------------------------------------------------
1 | package openai
2 | 
3 | import (
4 |     "context"
5 |     "fmt"
6 |     openai "github.com/speakeasy-api/langchain-go/llms/openai"
7 |     "log"
8 |     "testing"
9 | )
10 | 
11 | // To execute: export OPENAI_API_KEY=...
12 | 13 | func TestBasicCompletion(t *testing.T) { 14 | llm, err := openai.New() 15 | if err != nil { 16 | log.Fatal(err) 17 | } 18 | completion, err := llm.Call(context.Background(), "Question, what kind of bear is best?", []string{}) 19 | if err != nil { 20 | log.Fatal(err) 21 | } 22 | 23 | fmt.Println(completion) 24 | } 25 | 26 | func TestBasicCompletionWithStop(t *testing.T) { 27 | llm, err := openai.New() 28 | if err != nil { 29 | log.Fatal(err) 30 | } 31 | completion, err := llm.Call(context.Background(), "Question, what kind of bear is best?", []string{"bear"}) 32 | if err != nil { 33 | log.Fatal(err) 34 | } 35 | 36 | fmt.Println(completion) 37 | } 38 | 39 | func TestBatchCompletion(t *testing.T) { 40 | llm, err := openai.New() 41 | if err != nil { 42 | log.Fatal(err) 43 | } 44 | completion, err := llm.Generate(context.Background(), []string{ 45 | "Question, what kind of bear is best?", 46 | "How tall is mount everest?", 47 | }, []string{}) 48 | if err != nil { 49 | log.Fatal(err) 50 | } 51 | 52 | fmt.Println(completion) 53 | } 54 | -------------------------------------------------------------------------------- /examples/llms/openaichat/openai_chat_test.go: -------------------------------------------------------------------------------- 1 | package openai 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "github.com/speakeasy-api/langchain-go/llms/openaichat" 7 | "log" 8 | "testing" 9 | ) 10 | 11 | // To Execute EXPORT OPENAI_API_KEY=... 12 | 13 | func TestFirstMessageChat(t *testing.T) { 14 | llm, err := openaichat.New() 15 | if err != nil { 16 | log.Fatal(err) 17 | } 18 | completion, err := llm.Call(context.Background(), "Hi, how are you?", []string{}) 19 | if err != nil { 20 | log.Fatal(err) 21 | } 22 | 23 | fmt.Println(completion) 24 | } 25 | 26 | func TestMultiMessageChat(t *testing.T) { 27 | llm, err := openaichat.New(openaichat.OpenAIChatInput{ 28 | PrefixMessages: []openaichat.ChatMessage{ 29 | { 30 | Content: "Mount Everest is the tallest mountain in the world.", 31 | Role: openaichat.ChatMessageRoleEnumAssistant, 32 | }, 33 | }, 34 | }) 35 | if err != nil { 36 | log.Fatal(err) 37 | } 38 | completion, err := llm.Call(context.Background(), "How tall is it?", []string{}) 39 | if err != nil { 40 | log.Fatal(err) 41 | } 42 | 43 | fmt.Println(completion) 44 | } 45 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/speakeasy-api/langchain-go 2 | 3 | go 1.20 4 | 5 | require github.com/speakeasy-sdks/openai-go-sdk v1.11.0 6 | 7 | require github.com/cenkalti/backoff/v4 v4.2.0 // indirect 8 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/cenkalti/backoff/v4 v4.2.0 h1:HN5dHm3WBOgndBH6E8V0q2jIYIR3s9yglV8k/+MN3u4= 2 | github.com/cenkalti/backoff/v4 v4.2.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= 3 | github.com/speakeasy-sdks/openai-go-sdk v1.11.0 h1:YVCv+4eV3Et51dkf59V4IWlqr0yvbAiHrcHeLMlQ8E8= 4 | github.com/speakeasy-sdks/openai-go-sdk v1.11.0/go.mod h1:BmObEkI364euE8MaBs9WDdbvk/sqr3o/XSExh3b4HL0= 5 | -------------------------------------------------------------------------------- /llms/base.go: -------------------------------------------------------------------------------- 1 | package llms 2 | 3 | import "context" 4 | 5 | type LLM interface { 6 | Generate(ctx context.Context, prompts []string, 
stop []string) (*LLMResult, error) 7 | Call(ctx context.Context, prompt string, stop []string) (string, error) 8 | Name() string 9 | } 10 | -------------------------------------------------------------------------------- /llms/openai/openai.go: -------------------------------------------------------------------------------- 1 | package openai 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | llms_shared "github.com/speakeasy-api/langchain-go/llms/shared" 7 | 8 | openai_shared "github.com/speakeasy-api/langchain-go/llms/shared/openai" 9 | gpt "github.com/speakeasy-sdks/openai-go-sdk" 10 | "github.com/speakeasy-sdks/openai-go-sdk/pkg/models/shared" 11 | "math" 12 | "net" 13 | "net/http" 14 | "os" 15 | "strings" 16 | "time" 17 | 18 | "github.com/speakeasy-api/langchain-go/llms" 19 | ) 20 | 21 | // Default Params for Open AI model 22 | const ( 23 | temperature float64 = 0.7 24 | maxTokens int64 = 256 25 | topP float64 = 1 26 | frequencyPenalty float64 = 0 27 | presencePenalty float64 = 0 28 | n int64 = 1 29 | bestOf int64 = 1 30 | modelName string = "text-davinci-003" 31 | batchSize int64 = 20 32 | maxRetries int = 3 33 | ) 34 | 35 | type OpenAI struct { 36 | temperature float64 37 | maxTokens int64 38 | topP float64 39 | frequencyPenalty float64 40 | presencePenalty float64 41 | n int64 42 | bestOf int64 43 | logitBias map[string]interface{} 44 | streaming bool // Streaming Unsupported Right Now 45 | modelName string 46 | modelKwargs map[string]interface{} 47 | maxRetries int 48 | batchSize int64 49 | stop []string 50 | timeout *time.Duration 51 | client *gpt.Gpt 52 | } 53 | 54 | func New(args ...OpenAIInput) (*OpenAI, error) { 55 | if len(args) > 1 { 56 | return nil, errors.New("more than one config argument not supported") 57 | } 58 | 59 | input := OpenAIInput{} 60 | if len(args) > 0 { 61 | input = args[0] 62 | } 63 | 64 | openai := OpenAI{ 65 | temperature: temperature, 66 | maxTokens: maxTokens, 67 | topP: topP, 68 | frequencyPenalty: frequencyPenalty, 69 | presencePenalty: presencePenalty, 70 | n: n, 71 | bestOf: bestOf, 72 | logitBias: input.LogitBias, 73 | streaming: input.Streaming, 74 | modelName: modelName, 75 | modelKwargs: input.ModelKwargs, 76 | batchSize: batchSize, 77 | stop: input.Stop, 78 | timeout: input.Timeout, 79 | maxRetries: maxRetries, 80 | } 81 | 82 | apiKey := os.Getenv("OPENAI_API_KEY") 83 | 84 | if input.OpenAIApiKey != nil { 85 | apiKey = *input.OpenAIApiKey 86 | } 87 | 88 | if apiKey == "" { 89 | return nil, errors.New("OpenAI API key not found") 90 | } 91 | 92 | if input.ModelName != nil { 93 | openai.modelName = *input.ModelName 94 | } 95 | 96 | if strings.HasPrefix(openai.modelName, "gpt-3.5-turbo") || strings.HasPrefix(openai.modelName, "gpt-4") { 97 | return nil, errors.New("use OpenAIChat for these models") 98 | } 99 | 100 | if input.Temperature != nil { 101 | openai.temperature = *input.Temperature 102 | } 103 | 104 | if input.MaxTokens != nil { 105 | openai.maxTokens = *input.MaxTokens 106 | } 107 | 108 | if input.TopP != nil { 109 | openai.topP = *input.TopP 110 | } 111 | 112 | if input.FrequencyPenalty != nil { 113 | openai.frequencyPenalty = *input.FrequencyPenalty 114 | } 115 | 116 | if input.PresencePenalty != nil { 117 | openai.presencePenalty = *input.PresencePenalty 118 | } 119 | 120 | if input.N != nil { 121 | openai.n = *input.N 122 | } 123 | 124 | if input.BestOf != nil { 125 | openai.bestOf = *input.BestOf 126 | } 127 | 128 | if input.BatchSize != nil { 129 | openai.batchSize = *input.BatchSize 130 | } 131 | 132 | if input.MaxRetries != nil { 
133 | openai.maxRetries = *input.MaxRetries 134 | } 135 | 136 | httpClient := openai_shared.OpenAIAuthenticatedClient(apiKey) 137 | 138 | if openai.timeout != nil { 139 | httpClient.Timeout = *openai.timeout 140 | } 141 | 142 | client := gpt.New(gpt.WithClient(&httpClient)) 143 | openai.client = client 144 | 145 | return &openai, nil 146 | } 147 | 148 | func (openai *OpenAI) Name() string { 149 | return "openai" 150 | } 151 | 152 | func (openai *OpenAI) Call(ctx context.Context, prompt string, stop []string) (string, error) { 153 | generations, err := openai.Generate(ctx, []string{prompt}, stop) 154 | if err != nil { 155 | return "", err 156 | } 157 | 158 | return generations.Generations[0][0].Text, nil 159 | } 160 | 161 | func (openai *OpenAI) Generate(ctx context.Context, prompts []string, stop []string) (*llms.LLMResult, error) { 162 | subPrompts := llms_shared.BatchSlice[string](prompts, openai.batchSize) 163 | maxTokens := openai.maxTokens 164 | var completionTokens, promptTokens, totalTokens int64 165 | var choices []shared.CreateCompletionResponseChoices 166 | 167 | if openai.maxTokens == -1 { 168 | if len(prompts) != 1 { 169 | return nil, errors.New("max_tokens set to -1 not supported for multiple inputs") 170 | } 171 | 172 | maxTokens = llms_shared.CalculateMaxTokens(prompts[0], openai.modelName) 173 | } 174 | 175 | if len(stop) == 0 { 176 | stop = openai.stop 177 | } 178 | 179 | for _, prompts := range subPrompts { 180 | data, err := openai.completionWithRetry(ctx, prompts, maxTokens, stop) 181 | if err != nil { 182 | return nil, err 183 | } 184 | 185 | choices = append(choices, data.Choices...) 186 | if data.Usage != nil { 187 | completionTokens += data.Usage.CompletionTokens 188 | promptTokens += data.Usage.PromptTokens 189 | totalTokens += data.Usage.TotalTokens 190 | } 191 | } 192 | var generations [][]llms.Generation 193 | batchedChoices := llms_shared.BatchSlice[shared.CreateCompletionResponseChoices](choices, openai.n) 194 | for _, batch := range batchedChoices { 195 | var generationBatch []llms.Generation 196 | for _, choice := range batch { 197 | generationBatch = append(generationBatch, llms.Generation{ 198 | Text: *choice.Text, 199 | GenerationInfo: map[string]interface{}{ 200 | "finishReason": choice.FinishReason, 201 | "logprobs": choice.Logprobs, 202 | }, 203 | }) 204 | } 205 | generations = append(generations, generationBatch) 206 | } 207 | 208 | return &llms.LLMResult{ 209 | Generations: generations, 210 | LLMOutput: map[string]interface{}{ 211 | "completionTokens": completionTokens, 212 | "promptTokens": promptTokens, 213 | "totalTokens": totalTokens, 214 | }, 215 | }, nil 216 | } 217 | 218 | func (openai *OpenAI) completionWithRetry(ctx context.Context, prompts []string, maxTokens int64, stop []string) (*shared.CreateCompletionResponse, error) { 219 | promptRequest := shared.CreateCreateCompletionRequestPromptArrayOfstr(prompts) 220 | request := shared.CreateCompletionRequest{ 221 | Model: openai.modelName, 222 | Prompt: &promptRequest, 223 | MaxTokens: &maxTokens, 224 | Temperature: &openai.temperature, 225 | TopP: &openai.topP, 226 | N: &openai.n, 227 | BestOf: &openai.bestOf, 228 | LogitBias: openai.logitBias, 229 | PresencePenalty: &openai.presencePenalty, 230 | FrequencyPenalty: &openai.frequencyPenalty, 231 | } 232 | if len(stop) != 0 { 233 | stopRequest := shared.CreateCreateCompletionRequestStopArrayOfstr(stop) 234 | request.Stop = &stopRequest 235 | } 236 | 237 | var finalResult *shared.CreateCompletionResponse 238 | var finalErr error 239 | 240 | // 
wait 2^x second between each retry starting with 241 | // max 10 seconds 242 | for i := 0; i < openai.maxRetries; i++ { 243 | lastTry := i == openai.maxRetries-1 244 | sleep := int(math.Min(math.Pow(2, float64(i)), float64(10))) 245 | res, err := openai.client.OpenAI.CreateCompletion(ctx, request) 246 | if err != nil { 247 | var netErr net.Error 248 | if errors.As(err, &netErr) { 249 | // retry on client timeout 250 | if netErr.Timeout() && !lastTry { 251 | time.Sleep(time.Duration(sleep) * time.Second) 252 | continue 253 | } 254 | } 255 | 256 | return nil, err 257 | } 258 | 259 | if res.StatusCode == http.StatusOK { 260 | finalResult = res.CreateCompletionResponse 261 | break 262 | } else { 263 | openAIError := openai_shared.CreateOpenAIError(res.StatusCode, res.RawResponse.Status) 264 | if lastTry || !openAIError.IsRetryable() { 265 | finalErr = openAIError 266 | break 267 | } 268 | } 269 | 270 | time.Sleep(time.Duration(sleep) * time.Second) 271 | } 272 | 273 | return finalResult, finalErr 274 | } 275 | -------------------------------------------------------------------------------- /llms/openai/types.go: -------------------------------------------------------------------------------- 1 | package openai 2 | 3 | import "time" 4 | 5 | type OpenAIInput struct { 6 | // Model name to use 7 | ModelName *string // TODO: Make into Enum 8 | // Holds any additional parameters that are valid to pass to https://platform.openai.com/docs/api-reference/completions/create 9 | ModelKwargs map[string]interface{} 10 | // Batch size to use when passing multiple documents to generate 11 | BatchSize *int64 12 | // List of stop words to use when generating 13 | Stop []string 14 | // Timeout to use when making a http request to OpenAI 15 | Timeout *time.Duration 16 | // Number of retry attempts for a single request to OpenAI 17 | MaxRetries *int 18 | // OpenAI API Key 19 | OpenAIApiKey *string 20 | ModelParams 21 | } 22 | 23 | type ModelParams struct { 24 | // Sampling temperature to use 25 | Temperature *float64 26 | // Maximum number of tokens to generate in the completion. -1 returns as many 27 | // tokens as possible given the prompt and the model's maximum context size. 28 | MaxTokens *int64 29 | // Total probability mass of tokens to consider at each step 30 | TopP *float64 31 | // Penalizes repeated tokens according to frequency 32 | FrequencyPenalty *float64 33 | // Penalizes repeated tokens 34 | PresencePenalty *float64 35 | // Number of completions to generate for each prompt 36 | N *int64 37 | // Generates `bestOf` completions server side and returns the "best" 38 | BestOf *int64 39 | // Dictionary used to adjust the probability of specific tokens being generated 40 | LogitBias map[string]interface{} 41 | // Whether to stream the results or not. 
Enabling disables tokenUsage reporting 42 | Streaming bool 43 | } 44 | -------------------------------------------------------------------------------- /llms/openaichat/openai_chat.go: -------------------------------------------------------------------------------- 1 | package openaichat 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | openai_shared "github.com/speakeasy-api/langchain-go/llms/shared/openai" 7 | gpt "github.com/speakeasy-sdks/openai-go-sdk" 8 | "github.com/speakeasy-sdks/openai-go-sdk/pkg/models/shared" 9 | "math" 10 | "net" 11 | "net/http" 12 | "os" 13 | "time" 14 | 15 | "github.com/speakeasy-api/langchain-go/llms" 16 | ) 17 | 18 | // Default Params for Open AI model 19 | const ( 20 | temperature float64 = 1 21 | topP float64 = 1 22 | frequencyPenalty float64 = 0 23 | presencePenalty float64 = 0 24 | n int64 = 1 25 | modelName string = "gpt-3.5-turbo" 26 | maxRetries int = 3 27 | ) 28 | 29 | type OpenAIChat struct { 30 | temperature float64 31 | maxTokens int64 32 | topP float64 33 | frequencyPenalty float64 34 | presencePenalty float64 35 | n int64 36 | logitBias map[string]interface{} 37 | streaming bool // Streaming Unsupported Right Now 38 | modelName string 39 | modelKwargs map[string]interface{} 40 | maxRetries int 41 | stop []string 42 | prefixMessages []ChatMessage 43 | timeout *time.Duration 44 | client *gpt.Gpt 45 | } 46 | 47 | func New(args ...OpenAIChatInput) (*OpenAIChat, error) { 48 | if len(args) > 1 { 49 | return nil, errors.New("more than one config argument not supported") 50 | } 51 | 52 | input := OpenAIChatInput{} 53 | if len(args) > 0 { 54 | input = args[0] 55 | } 56 | 57 | openai := OpenAIChat{ 58 | temperature: temperature, 59 | topP: topP, 60 | frequencyPenalty: frequencyPenalty, 61 | presencePenalty: presencePenalty, 62 | n: n, 63 | logitBias: input.LogitBias, 64 | streaming: input.Streaming, 65 | modelName: modelName, 66 | modelKwargs: input.ModelKwargs, 67 | stop: input.Stop, 68 | timeout: input.Timeout, 69 | maxRetries: maxRetries, 70 | prefixMessages: input.PrefixMessages, 71 | } 72 | 73 | apiKey := os.Getenv("OPENAI_API_KEY") 74 | 75 | if input.OpenAIApiKey != nil { 76 | apiKey = *input.OpenAIApiKey 77 | } 78 | 79 | if apiKey == "" { 80 | return nil, errors.New("OpenAI API key not found") 81 | } 82 | 83 | if input.Temperature != nil { 84 | openai.temperature = *input.Temperature 85 | } 86 | 87 | if input.MaxTokens != nil { 88 | openai.maxTokens = *input.MaxTokens 89 | } 90 | 91 | if input.TopP != nil { 92 | openai.topP = *input.TopP 93 | } 94 | 95 | if input.FrequencyPenalty != nil { 96 | openai.frequencyPenalty = *input.FrequencyPenalty 97 | } 98 | 99 | if input.PresencePenalty != nil { 100 | openai.presencePenalty = *input.PresencePenalty 101 | } 102 | 103 | if input.N != nil { 104 | openai.n = *input.N 105 | } 106 | 107 | if input.ModelName != nil { 108 | openai.modelName = *input.ModelName 109 | } 110 | 111 | if input.MaxRetries != nil { 112 | openai.maxRetries = *input.MaxRetries 113 | } 114 | 115 | httpClient := openai_shared.OpenAIAuthenticatedClient(apiKey) 116 | 117 | if openai.timeout != nil { 118 | httpClient.Timeout = *openai.timeout 119 | } 120 | 121 | client := gpt.New(gpt.WithClient(&httpClient)) 122 | openai.client = client 123 | 124 | return &openai, nil 125 | } 126 | 127 | func (openai *OpenAIChat) Name() string { 128 | return "openai-chat" 129 | } 130 | 131 | func (openai *OpenAIChat) Call(ctx context.Context, prompt string, stop []string) (string, error) { 132 | if len(stop) == 0 { 133 | stop = openai.stop 134 | } 135 | 136 
| data, err := openai.chatCompletionWithRetry(ctx, prompt, openai.maxTokens, stop)
137 |     if err != nil {
138 |         return "", err
139 |     }
140 | 
141 |     message := ""
142 |     if len(data.Choices) > 0 && data.Choices[0].Message != nil {
143 |         message = data.Choices[0].Message.Content
144 |     }
145 | 
146 |     return message, nil
147 | }
148 | 
149 | func (openai *OpenAIChat) Generate(ctx context.Context, prompts []string, stop []string) (*llms.LLMResult, error) {
150 |     // Batch generation is not implemented for OpenAIChat; fail loudly instead of returning a nil result.
151 |     return nil, errors.New("Generate is not implemented for OpenAIChat; use Call instead")
152 | }
153 | 
154 | func (openai *OpenAIChat) chatCompletionWithRetry(ctx context.Context, prompt string, maxTokens int64, stop []string) (*shared.CreateChatCompletionResponse, error) {
155 |     request := shared.CreateChatCompletionRequest{
156 |         Model:            openai.modelName,
157 |         Messages:         formatMessages(openai.prefixMessages, prompt),
158 |         Temperature:      &openai.temperature,
159 |         TopP:             &openai.topP,
160 |         N:                &openai.n,
161 |         LogitBias:        openai.logitBias,
162 |         PresencePenalty:  &openai.presencePenalty,
163 |         FrequencyPenalty: &openai.frequencyPenalty,
164 |     }
165 |     if openai.maxTokens != 0 {
166 |         request.MaxTokens = &openai.maxTokens
167 |     }
168 | 
169 |     if len(stop) != 0 {
170 |         stopRequest := shared.CreateCreateChatCompletionRequestStopArrayOfstr(stop)
171 |         request.Stop = &stopRequest
172 |     }
173 | 
174 |     var finalResult *shared.CreateChatCompletionResponse
175 |     var finalErr error
176 | 
177 |     // Retry with exponential backoff: wait 2^i seconds between attempts,
178 |     // capped at 10 seconds.
179 |     for i := 0; i < openai.maxRetries; i++ {
180 |         lastTry := i == openai.maxRetries-1
181 |         sleep := int(math.Min(math.Pow(2, float64(i)), float64(10)))
182 |         res, err := openai.client.OpenAI.CreateChatCompletion(ctx, request)
183 |         if err != nil {
184 |             var netErr net.Error
185 |             if errors.As(err, &netErr) {
186 |                 // retry on client timeout
187 |                 if netErr.Timeout() && !lastTry {
188 |                     time.Sleep(time.Duration(sleep) * time.Second)
189 |                     continue
190 |                 }
191 |             }
192 | 
193 |             return nil, err
194 |         }
195 | 
196 |         if res.StatusCode == http.StatusOK {
197 |             finalResult = res.CreateChatCompletionResponse
198 |             break
199 |         } else {
200 |             openAIError := openai_shared.CreateOpenAIError(res.StatusCode, res.RawResponse.Status)
201 |             if lastTry || !openAIError.IsRetryable() {
202 |                 finalErr = openAIError
203 |                 break
204 |             }
205 |         }
206 | 
207 |         time.Sleep(time.Duration(sleep) * time.Second)
208 |     }
209 | 
210 |     return finalResult, finalErr
211 | }
212 | 
213 | func formatMessages(previous []ChatMessage, message string) []shared.ChatCompletionRequestMessage {
214 |     var result []shared.ChatCompletionRequestMessage
215 |     for _, prev := range previous {
216 |         result = append(result, shared.ChatCompletionRequestMessage{
217 |             Content: prev.Content,
218 |             Role:    convertRoleEnum(prev.Role),
219 |         })
220 |     }
221 |     result = append(result, shared.ChatCompletionRequestMessage{
222 |         Content: message,
223 |         Role:    shared.ChatCompletionRequestMessageRoleEnumUser,
224 |     })
225 |     return result
226 | }
227 | 
228 | func convertRoleEnum(enum ChatMessageRoleEnum) shared.ChatCompletionRequestMessageRoleEnum {
229 |     switch enum {
230 |     case ChatMessageRoleEnumSystem:
231 |         return shared.ChatCompletionRequestMessageRoleEnumSystem
232 |     case ChatMessageRoleEnumUser:
233 |         return shared.ChatCompletionRequestMessageRoleEnumUser
234 |     case ChatMessageRoleEnumAssistant:
235 |         return shared.ChatCompletionRequestMessageRoleEnumAssistant
236 |     default:
237 |         return ""
238 |     }
239 | }
240 | 
-------------------------------------------------------------------------------- /llms/openaichat/types.go: -------------------------------------------------------------------------------- 1 | package openaichat 2 | 3 | import ( 4 | "github.com/speakeasy-api/langchain-go/llms/openai" 5 | ) 6 | 7 | type OpenAIChatInput struct { 8 | // ChatGPT messages to pass as a prefix to the prompt 9 | PrefixMessages []ChatMessage 10 | 11 | openai.OpenAIInput 12 | } 13 | 14 | type ChatMessage struct { 15 | Content string 16 | Role ChatMessageRoleEnum 17 | } 18 | 19 | type ChatMessageRoleEnum string 20 | 21 | const ( 22 | ChatMessageRoleEnumSystem ChatMessageRoleEnum = "system" 23 | ChatMessageRoleEnumUser ChatMessageRoleEnum = "user" 24 | ChatMessageRoleEnumAssistant ChatMessageRoleEnum = "assistant" 25 | ) 26 | -------------------------------------------------------------------------------- /llms/shared/openai/errors.go: -------------------------------------------------------------------------------- 1 | package openai 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | ) 7 | 8 | type OpenAIError struct { 9 | error 10 | 11 | statusCode int 12 | status string 13 | } 14 | 15 | func (e *OpenAIError) Error() string { 16 | return fmt.Sprintf("error in call to openai with status %s", e.status) 17 | } 18 | 19 | func (e *OpenAIError) GetStatusCode() int { 20 | return e.statusCode 21 | } 22 | 23 | func (e *OpenAIError) IsRetryable() bool { 24 | return e.statusCode == http.StatusTooManyRequests || e.statusCode == http.StatusInternalServerError 25 | } 26 | 27 | func CreateOpenAIError(statusCode int, status string) *OpenAIError { 28 | err := OpenAIError{ 29 | statusCode: statusCode, 30 | status: status, 31 | } 32 | 33 | return &err 34 | } 35 | -------------------------------------------------------------------------------- /llms/shared/openai/http_client.go: -------------------------------------------------------------------------------- 1 | package openai 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | ) 7 | 8 | type authorizeTransport struct { 9 | ApiKey string 10 | } 11 | 12 | func (t *authorizeTransport) RoundTrip(req *http.Request) (*http.Response, error) { 13 | req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", t.ApiKey)) 14 | return http.DefaultTransport.RoundTrip(req) 15 | } 16 | 17 | func OpenAIAuthenticatedClient(apiKey string) http.Client { 18 | return http.Client{Transport: &authorizeTransport{ApiKey: apiKey}} 19 | } 20 | -------------------------------------------------------------------------------- /llms/shared/utils.go: -------------------------------------------------------------------------------- 1 | package shared 2 | 3 | func BatchSlice[T any](slice []T, batchSize int64) [][]T { 4 | var chunks [][]T 5 | for i := int64(0); i < int64(len(slice)); i += batchSize { 6 | end := i + batchSize 7 | if end > int64(len(slice)) { 8 | end = int64(len(slice)) 9 | } 10 | 11 | chunks = append(chunks, slice[i:end]) 12 | } 13 | 14 | return chunks 15 | } 16 | 17 | // TODO: Implement Max Token Inference 18 | func CalculateMaxTokens(prompt string, modelName string) int64 { 19 | return 1 20 | } 21 | -------------------------------------------------------------------------------- /llms/types.go: -------------------------------------------------------------------------------- 1 | package llms 2 | 3 | type LLMResult struct { 4 | Generations [][]Generation 5 | LLMOutput map[string]interface{} 6 | } 7 | 8 | type Generation struct { 9 | Text string 10 | GenerationInfo map[string]interface{} 11 | } 12 | 
--------------------------------------------------------------------------------
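
For orientation, a minimal usage sketch follows; it is not a file in the repository. It mirrors the example tests above: `openai.New` falls back to the `OPENAI_API_KEY` environment variable when no `OpenAIApiKey` is passed, and the returned client is consumed through the `llms.LLM` interface from `llms/base.go`. The `ask` helper and the `main` wrapper are illustrative assumptions, not part of the library.

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/speakeasy-api/langchain-go/llms"
	"github.com/speakeasy-api/langchain-go/llms/openai"
)

// ask accepts any llms.LLM implementation, so callers are not tied to a
// specific model wrapper.
func ask(ctx context.Context, llm llms.LLM, prompt string) (string, error) {
	return llm.Call(ctx, prompt, []string{})
}

func main() {
	// With no OpenAIInput argument, New() uses the default model and reads
	// OPENAI_API_KEY from the environment.
	llm, err := openai.New()
	if err != nil {
		log.Fatal(err)
	}

	answer, err := ask(context.Background(), llm, "Question, what kind of bear is best?")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Printf("%s answered: %s\n", llm.Name(), answer)
}
```

The same `ask` helper works with a client from `openaichat.New()`, which implements the same interface; note that only `Call` is supported there, since its `Generate` is not implemented.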