├── .gitignore ├── LICENSE ├── README.md ├── client.go ├── count_token.go ├── count_token_test.go ├── embd.go ├── embd_test.go ├── examples └── text_generation │ └── main.go ├── generate_content.go ├── generate_content_test.go ├── go.mod ├── go.sum ├── helper.go ├── model_info.go ├── model_info_test.go ├── models ├── common.go ├── const.go ├── const_models.go ├── count_token.go ├── embd.go ├── error.go ├── generate_content.go ├── model_info.go ├── parts.go └── vertexRegion.go ├── stream_generate_content.go └── stream_generate_content_test.go /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | .vscode 3 | .env 4 | *.exe 5 | *.exe~ 6 | *.dll 7 | *.so 8 | *_priv.go -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Limit-LAB & KevinZonda 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Go Gemini - A Go SDK for Google Gemini LLM 2 | 3 | 4 | [![Go Reference](https://pkg.go.dev/badge/github.com/Limit-LAB/go-gemini.svg)](https://pkg.go.dev/github.com/Limit-LAB/go-gemini) 5 | 6 | This library provides unofficial Go clients for the [Gemini API](https://ai.google.dev/tutorials/rest_quickstart). We support: 7 | 8 | * Gemini Pro 9 | * Gemini Pro Vision 10 | 11 | ## Get Started 12 | 13 | ```go 14 | package main 15 | 16 | import ( 17 | "fmt" 18 | gemini "github.com/Limit-LAB/go-gemini" 19 | "github.com/Limit-LAB/go-gemini/models" 20 | ) 21 | 22 | func main() { 23 | cli := gemini.NewClient("") 24 | rst, err := cli.GenerateContent(models.GeminiPro, 25 | models.NewGenerateContentRequest( 26 | models.NewContent(models.RoleUser, models.NewTextPart("How are you?")), 27 | ), 28 | ) 29 | if err != nil { 30 | panic(err) 31 | } 32 | fmt.Println(rst.Candidates[0].Content.Parts[0].GetText()) 33 | } 34 | ``` 35 | -------------------------------------------------------------------------------- /client.go: -------------------------------------------------------------------------------- 1 | package gemini 2 | 3 | import ( 4 | "github.com/Limit-LAB/go-gemini/models" 5 | "net/http" 6 | ) 7 | 8 | type AuthBy string 9 | 10 | const ( 11 | AuthByHttpHeader AuthBy = "HttpHeader" 12 | AuthByUrlQuery AuthBy = "UrlQuery" 13 | ) 14 | 15 | type Client struct { 16 | hc *http.Client 17 | key string 18 | baseUrl string 19 | auth AuthBy // how the key is sent: AI Studio uses a URL query parameter, Vertex AI an Authorization header 20 | } 21 | 22 | func NewClient(key string) *Client { 23 | return
&Client{ 24 | hc: &http.Client{}, 25 | key: key, 26 | baseUrl: "https://generativelanguage.googleapis.com/v1beta/models/", 27 | auth: AuthByUrlQuery, 28 | } 29 | } 30 | 31 | func (c *Client) SetBaseUrl(url string) *Client { 32 | c.baseUrl = url 33 | return c 34 | } 35 | 36 | func (c *Client) SetAuthWay(auth AuthBy) *Client { 37 | c.auth = auth 38 | return c 39 | } 40 | 41 | func (c *Client) VertexAI(region models.VertexRegion, projId string) *Client { 42 | rg := string(region) 43 | c.auth = AuthByHttpHeader 44 | c.baseUrl = "https://" + rg + "-aiplatform.googleapis.com/v1/projects/" + projId + "/locations/" + rg + "/publishers/google/models/" 45 | return c 46 | } 47 | 48 | func (c *Client) AIStudio() *Client { 49 | c.auth = AuthByUrlQuery 50 | c.baseUrl = "https://generativelanguage.googleapis.com/v1beta/models/" 51 | return c 52 | } 53 | -------------------------------------------------------------------------------- /count_token.go: -------------------------------------------------------------------------------- 1 | package gemini 2 | 3 | import ( 4 | "github.com/Limit-LAB/go-gemini/models" 5 | ) 6 | 7 | func (c *Client) CountToken(model models.GeminiModel, req models.CountTokenRequest) (rst models.CountTokenResponse, err error) { 8 | url := c.url(string(model), "countTokens") 9 | rst, err = unjson[models.CountTokenResponse](c.post(url, req)) 10 | return 11 | } 12 | -------------------------------------------------------------------------------- /count_token_test.go: -------------------------------------------------------------------------------- 1 | package gemini_test 2 | 3 | import ( 4 | "github.com/Limit-LAB/go-gemini" 5 | "github.com/Limit-LAB/go-gemini/models" 6 | "github.com/joho/godotenv" 7 | "os" 8 | "testing" 9 | ) 10 | 11 | func TestCountToken(t *testing.T) { 12 | godotenv.Load() 13 | key := os.Getenv("GEMINI") 14 | cli := gemini.NewClient(key) 15 | rst, err := cli.CountToken(models.GeminiPro, models.CountTokenRequest{ 16 | Contents: []models.Content{
17 | models.NewContent(models.RoleNil, models.NewTextPart("Hello, world!")), 18 | }, 19 | }) 20 | if err != nil { 21 | t.Fatal(err) 22 | } 23 | jsonPrint(rst) 24 | } 25 | -------------------------------------------------------------------------------- /embd.go: -------------------------------------------------------------------------------- 1 | package gemini 2 | 3 | import ( 4 | "github.com/Limit-LAB/go-gemini/models" 5 | ) 6 | 7 | func (c *Client) EmbedContent(model models.EmbeddingModel, parts []models.Part) (models.EmbeddingValue, error) { 8 | url := c.url(string(model), "embedContent") 9 | req := models.EmbeddingContentRequest{ 10 | Model: "models/" + string(model), 11 | Content: models.Content{ 12 | Parts: parts, 13 | }, 14 | } 15 | rst, err := unjson[models.EmbeddingContentResponse](c.post(url, req)) 16 | return rst.Embedding, err 17 | } 18 | 19 | func (c *Client) BatchEmbedContent(model models.EmbeddingModel, parts [][]models.Part) ([]models.EmbeddingValue, error) { 20 | url := c.url(string(model), "batchEmbedContents") 21 | req := models.BatchEmbeddingContentsRequest{} 22 | modelStr := "models/" + string(model) 23 | for _, content := range parts { 24 | req.Requests = append(req.Requests, models.EmbeddingContentRequest{ 25 | Model: modelStr, 26 | Content: models.Content{ 27 | Parts: content, 28 | }, 29 | }) 30 | } 31 | rst, err := unjson[models.BatchEmbeddingContentsResponse](c.post(url, req)) 32 | return rst.Embeddings, err 33 | } 34 | -------------------------------------------------------------------------------- /embd_test.go: -------------------------------------------------------------------------------- 1 | package gemini_test 2 | 3 | import ( 4 | "github.com/Limit-LAB/go-gemini" 5 | "github.com/Limit-LAB/go-gemini/models" 6 | "github.com/joho/godotenv" 7 | "os" 8 | "testing" 9 | ) 10 | 11 | func TestEmbd(t *testing.T) { 12 | godotenv.Load() 13 | key := os.Getenv("GEMINI") 14 | cli := gemini.NewClient(key) 15 | emb, err := cli.EmbedContent( 16 |
models.Embedding001, 17 | models.NewParts().AppendPart(models.NewTextPart("Write a story about a magic backpack.")), 18 | ) 19 | if err != nil { 20 | t.Fatal(err) 21 | } 22 | jsonPrint(emb) 23 | } 24 | -------------------------------------------------------------------------------- /examples/text_generation/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | gemini "github.com/Limit-LAB/go-gemini" 6 | "github.com/Limit-LAB/go-gemini/models" 7 | ) 8 | 9 | func main() { 10 | cli := gemini.NewClient("") 11 | rst, err := cli.GenerateContent(models.GeminiPro, 12 | models.NewGenerateContentRequest( 13 | models.NewContent(models.RoleUser, models.NewTextPart("How are you?")), 14 | ), 15 | ) 16 | if err != nil { 17 | panic(err) 18 | } 19 | fmt.Println(rst.Candidates[0].Content.Parts[0].GetText()) 20 | } 21 | -------------------------------------------------------------------------------- /generate_content.go: -------------------------------------------------------------------------------- 1 | package gemini 2 | 3 | import ( 4 | "errors" 5 | "github.com/Limit-LAB/go-gemini/models" 6 | ) 7 | 8 | func (c *Client) GenerateContent(model models.GeminiModel, req *models.GenerateContentRequest) (models.GenerateContentResponse, error) { 9 | for _, content := range req.Contents { 10 | err := validateGenerateContentRequest(model, content) 11 | if err != nil { 12 | return models.GenerateContentResponse{}, err 13 | } 14 | } 15 | url := c.url(string(model), "generateContent") 16 | rst, err := unjson[models.GenerateContentResponse](c.post(url, req)) 17 | return rst, err 18 | } 19 | 20 | var ErrTextOnlyModel = errors.New("this model only supports text input") 21 | 22 | func validateGenerateContentRequest(model models.GeminiModel, req models.Content) error { 23 | if model != models.GeminiPro { 24 | return nil 25 | } 26 | for _, c := range req.Parts { 27 | if c.Text != nil { 28 | continue 29 | } 30 | return
ErrTextOnlyModel 31 | } 32 | return nil 33 | } 34 | -------------------------------------------------------------------------------- /generate_content_test.go: -------------------------------------------------------------------------------- 1 | package gemini_test 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "github.com/Limit-LAB/go-gemini" 7 | "github.com/Limit-LAB/go-gemini/models" 8 | "github.com/joho/godotenv" 9 | "os" 10 | "testing" 11 | ) 12 | 13 | func TestGenerateContent(t *testing.T) { 14 | godotenv.Load() 15 | key := os.Getenv("GEMINI") 16 | cli := gemini.NewClient(key) 17 | rst, err := cli.GenerateContent(models.GeminiPro, 18 | models.NewGenerateContentRequest( 19 | models.NewContent(models.RoleUser, models.NewTextPart("你好")), 20 | models.NewContent(models.RoleModel, models.NewTextPart("你好!很高兴为您服务。请问您需要什么帮助?")), 21 | models.NewContent(models.RoleUser, models.NewTextPart("你是谁")), 22 | ), 23 | ) 24 | 25 | if err != nil { 26 | t.Fatal(err) 27 | } 28 | jsonPrint(rst) 29 | fmt.Println(rst.Candidates[0].Content.Parts[0].GetText()) 30 | } 31 | 32 | func jsonPrint(v any) { 33 | bs, _ := json.MarshalIndent(v, "", " ") 34 | fmt.Println(string(bs)) 35 | } 36 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/Limit-LAB/go-gemini 2 | 3 | go 1.21 4 | 5 | require github.com/joho/godotenv v1.5.1 6 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= 2 | github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= 3 | -------------------------------------------------------------------------------- /helper.go: -------------------------------------------------------------------------------- 1 | package
gemini 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "errors" 7 | "fmt" 8 | "github.com/Limit-LAB/go-gemini/models" 9 | "io" 10 | "net/http" 11 | "net/url" 12 | ) 13 | 14 | // keyParam returns the query-string suffix that carries the API key; the key is escaped so it is always a valid query value. 15 | func (c *Client) keyParam() string { 16 | return "?key=" + url.QueryEscape(c.key) 17 | } 18 | 19 | func (c *Client) url(model, action string) string { 20 | if action == "" { 21 | return c.baseUrl + model 22 | } 23 | return c.baseUrl + model + ":" + action 24 | } 25 | 26 | func (c *Client) post(url string, body interface{}) (rst []byte, err error) { 27 | return c.simpleReq("POST", url, body) 28 | } 29 | 30 | func (c *Client) newReq(method string, url string, body any) (req *http.Request, err error) { 31 | var reader io.Reader = nil 32 | if body != nil { 33 | var bs []byte 34 | bs, err = json.Marshal(body) 35 | if err != nil { 36 | return 37 | } 38 | reader = bytes.NewReader(bs) 39 | } 40 | 41 | if c.auth == AuthByUrlQuery { 42 | url = url + c.keyParam() 43 | } 44 | req, err = http.NewRequest(method, url, reader) 45 | if err != nil { 46 | return 47 | } 48 | req.Header.Set("Content-Type", "application/json") 49 | if c.auth == AuthByHttpHeader { 50 | req.Header.Set("Authorization", "Bearer "+c.key) 51 | } 52 | return 53 | } 54 | 55 | // handleReq sends req and returns the raw response body. A non-2xx status is 56 | // reported as an error: the API's error message when the body carries one, 57 | // otherwise the HTTP status line. 58 | func (c *Client) handleReq(req *http.Request) (rst []byte, err error) { 59 | rsp, err := c.hc.Do(req) 60 | if err != nil { 61 | return 62 | } 63 | defer rsp.Body.Close() 64 | rst, err = io.ReadAll(rsp.Body) 65 | if err != nil || rsp.StatusCode < 300 { 66 | return 67 | } 68 | var eResp models.ErrorResponse 69 | if json.Unmarshal(rst, &eResp) == nil && eResp.Error != nil { 70 | err = errors.New(eResp.Error.Message) 71 | } else { 72 | err = fmt.Errorf("unexpected HTTP status %s", rsp.Status) 73 | } 74 | return 75 | } 76 | 77 | func (c *Client) simpleReq(method string, url string, body any) (rst []byte, err error) { 78 | req, err := c.newReq(method, url, body) 79 | if err != nil { 80 | return 81 | } 82 | return c.handleReq(req) 83 | } 84 | 85 | func (c *Client) get(url string) (rst []byte, err error) { 86 | return c.simpleReq("GET", url, nil) 87 | } 88 | 89 | func unjson[T any](bs []byte, e error) (rst T, err error) { 90 | if e != nil { 91 | err = e 92 | return 93 | } 94 | var eResp models.ErrorResponse 95 | err = json.Unmarshal(bs, &eResp) 96 | if err == nil && eResp.Error != nil { 97
| err = errors.New(eResp.Error.Message) 98 | return 99 | } 100 | err = json.Unmarshal(bs, &rst) 101 | return 102 | } 103 | 104 | func jout(v any) { 105 | bs, _ := json.MarshalIndent(v, "", " ") 106 | fmt.Println(string(bs)) 107 | } 108 | -------------------------------------------------------------------------------- /model_info.go: -------------------------------------------------------------------------------- 1 | package gemini 2 | 3 | import "github.com/Limit-LAB/go-gemini/models" 4 | 5 | func (c *Client) GetModelInfo(model models.GeminiModel) (models.ModelInfo, error) { 6 | url := c.url(string(model), "") 7 | return unjson[models.ModelInfo](c.get(url)) 8 | } 9 | 10 | func (c *Client) GetModelList() ([]models.ModelInfo, error) { 11 | lst, err := unjson[models.ModelListResponse](c.get(c.baseUrl)) 12 | return lst.Models, err 13 | } 14 | -------------------------------------------------------------------------------- /model_info_test.go: -------------------------------------------------------------------------------- 1 | package gemini_test 2 | 3 | import ( 4 | "github.com/Limit-LAB/go-gemini" 5 | "github.com/Limit-LAB/go-gemini/models" 6 | "github.com/joho/godotenv" 7 | "os" 8 | "testing" 9 | ) 10 | 11 | func TestGetModelInfo(t *testing.T) { 12 | godotenv.Load() 13 | key := os.Getenv("GEMINI") 14 | cli := gemini.NewClient(key) 15 | rst, err := cli.GetModelInfo(models.GeminiPro) 16 | if err != nil { 17 | t.Fatal(err) 18 | } 19 | jsonPrint(rst) 20 | } 21 | 22 | func TestGetModelList(t *testing.T) { 23 | godotenv.Load() 24 | key := os.Getenv("GEMINI") 25 | cli := gemini.NewClient(key) 26 | rst, err := cli.GetModelList() 27 | if err != nil { 28 | t.Fatal(err) 29 | } 30 | jsonPrint(rst) 31 | } 32 | -------------------------------------------------------------------------------- /models/common.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | type PromptFeedback struct { 4 | SafetyRatings []SafetyRating
`json:"safetyRatings"` 5 | } 6 | 7 | type SafetyRating struct { 8 | Category string `json:"category"` 9 | Probability string `json:"probability"` 10 | } 11 | 12 | type Content struct { 13 | Parts []Part `json:"parts"` 14 | Role Role `json:"role,omitempty"` 15 | } 16 | 17 | type GenerationConfig struct { 18 | StopSequences []string `json:"stopSequences,omitempty"` 19 | Temperature *float32 `json:"temperature,omitempty"` 20 | MaxOutputTokens *int `json:"maxOutputTokens,omitempty"` 21 | TopP *float32 `json:"topP,omitempty"` 22 | TopK *float32 `json:"topK,omitempty"` 23 | } 24 | 25 | func NewGenerationConfig() *GenerationConfig { 26 | return &GenerationConfig{} 27 | } 28 | 29 | func (g *GenerationConfig) WithStopSequences(stopSequences ...string) *GenerationConfig { 30 | g.StopSequences = stopSequences 31 | return g 32 | } 33 | 34 | func (g *GenerationConfig) WithTemperature(temperature float32) *GenerationConfig { 35 | g.Temperature = &temperature 36 | return g 37 | } 38 | 39 | func (g *GenerationConfig) WithMaxOutputTokens(maxOutputTokens int) *GenerationConfig { 40 | g.MaxOutputTokens = &maxOutputTokens 41 | return g 42 | } 43 | 44 | func (g *GenerationConfig) WithTopP(topP float32) *GenerationConfig { 45 | g.TopP = &topP 46 | return g 47 | } 48 | 49 | func (g *GenerationConfig) WithTopK(topK float32) *GenerationConfig { 50 | g.TopK = &topK 51 | return g 52 | } 53 | -------------------------------------------------------------------------------- /models/const.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | type Role string 4 | 5 | const ( 6 | RoleUser Role = "USER" 7 | RoleModel Role = "MODEL" 8 | RoleNil Role = "" 9 | ) 10 | 11 | func (r Role) ToOpenAIRole() string { 12 | switch r { 13 | case RoleUser: 14 | return "user" 15 | case RoleModel: 16 | return "assistant" 17 | default: 18 | return "" 19 | } 20 | } 21 | 22 | type MimeType string 23 | 24 | const ( 25 | MimeImagePng MimeType = "image/png" 26 |
MimeImageJpeg MimeType = "image/jpeg" 27 | MimeImageWebP MimeType = "image/webp" 28 | MimeImageHEIC MimeType = "image/heic" 29 | MimeImageHEIF MimeType = "image/heif" 30 | MimeVideoMov MimeType = "video/mov" 31 | MimeVideoMpeg MimeType = "video/mpeg" 32 | MimeVideoMp4 MimeType = "video/mp4" 33 | MimeVideoMpg MimeType = "video/mpg" 34 | MimeVideoAvi MimeType = "video/avi" 35 | MimeVideoWmv MimeType = "video/wmv" 36 | MimeVideoMpegps MimeType = "video/mpegps" 37 | MimeVideoFlv MimeType = "video/flv" 38 | ) 39 | -------------------------------------------------------------------------------- /models/const_models.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | type GeminiModel string 4 | 5 | const ( 6 | GeminiPro GeminiModel = "gemini-pro" 7 | GeminiProVision GeminiModel = "gemini-pro-vision" 8 | ) 9 | 10 | type EmbeddingModel string 11 | 12 | const ( 13 | Embedding001 EmbeddingModel = "embedding-001" 14 | ) 15 | -------------------------------------------------------------------------------- /models/count_token.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | type CountTokenRequest struct { 4 | Contents []Content `json:"contents"` 5 | } 6 | 7 | type CountTokenResponse struct { 8 | TotalTokens int `json:"totalTokens"` 9 | } 10 | -------------------------------------------------------------------------------- /models/embd.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | type EmbeddingContentRequest struct { 4 | Model string `json:"model"` 5 | Content Content `json:"content"` 6 | } 7 | 8 | type EmbeddingValue struct { 9 | Values []float32 `json:"values"` 10 | } 10 | 12 | type EmbeddingContentResponse struct { 13 | Embedding EmbeddingValue `json:"embedding"` 14 | } 15 | 16 | type BatchEmbeddingContentsRequest struct { 17 | Requests []EmbeddingContentRequest `json:"requests"` 18 | } 19
| 20 | type BatchEmbeddingContentsResponse struct { 21 | Embeddings []EmbeddingValue `json:"embeddings"` 22 | } 23 | -------------------------------------------------------------------------------- /models/error.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | type ErrorIn struct { 4 | Code int `json:"code"` 5 | Message string `json:"message"` 6 | Status string `json:"status"` 7 | } 8 | 9 | type ErrorResponse struct { 10 | Error *ErrorIn `json:"error,omitempty"` 11 | } 12 | -------------------------------------------------------------------------------- /models/generate_content.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | type GenerateContentRequest struct { 4 | Contents []Content `json:"contents"` 5 | GenerationConfig *GenerationConfig `json:"generationConfig,omitempty"` 6 | } 7 | 8 | type GenerateContentResponse struct { 9 | Candidates []GenerateContentCandidate `json:"candidates"` 10 | PromptFeedback PromptFeedback `json:"promptFeedback"` 11 | } 12 | 13 | type GenerateContentCandidate struct { 14 | Content Content `json:"content"` 15 | FinishReason string `json:"finishReason"` 16 | Index int `json:"index"` 17 | SafetyRatings []SafetyRating `json:"safetyRatings"` 18 | } 19 | 20 | func NewGenerateContentRequest(contents ...Content) *GenerateContentRequest { 21 | return &GenerateContentRequest{ 22 | Contents: contents, 23 | } 24 | } 25 | 26 | func NewGenerateContentRequestWithConfig(cfg *GenerationConfig, contents ...Content) *GenerateContentRequest { 27 | return &GenerateContentRequest{ 28 | Contents: contents, 29 | GenerationConfig: cfg, 30 | } 31 | } 32 | 33 | func (r *GenerateContentRequest) WithGenerationConfig(config GenerationConfig) *GenerateContentRequest { 34 | r.GenerationConfig = &config 35 | return r 36 | } 37 | 38 | func NewContent(role Role, parts ...Part) Content { 39 | return Content{ 40 | Role: role, 41 | Parts: parts,
42 | } 43 | } 44 | 45 | func (r *GenerateContentRequest) AppendContent(role Role, parts ...Part) *GenerateContentRequest { 46 | r.Contents = append(r.Contents, Content{ 47 | Role: role, 48 | Parts: parts, 49 | }) 50 | return r 51 | } 52 | -------------------------------------------------------------------------------- /models/model_info.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | type ModelListResponse struct { 4 | Models []ModelInfo `json:"models"` 5 | } 6 | 7 | type ModelInfo struct { 8 | Name string `json:"name"` 9 | Version string `json:"version"` 10 | DisplayName string `json:"displayName"` 11 | Description string `json:"description"` 12 | InputTokenLimit int `json:"inputTokenLimit"` 13 | OutputTokenLimit int `json:"outputTokenLimit"` 14 | SupportedGenerationMethods []string `json:"supportedGenerationMethods"` 15 | Temperature float32 `json:"temperature,omitempty"` 16 | TopP float32 `json:"topP,omitempty"` 17 | TopK int `json:"topK,omitempty"` 18 | } 19 | -------------------------------------------------------------------------------- /models/parts.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | type Part struct { 4 | Text *string `json:"text,omitempty"` 5 | InlineData *InlineData `json:"inline_data,omitempty"` 6 | } 7 | 8 | func (p Part) IsText() bool { 9 | return p.Text != nil 10 | } 11 | 12 | func (p Part) IsInlineData() bool { 13 | return p.InlineData != nil 14 | } 15 | 16 | func (p Part) GetText() string { 17 | if p.Text == nil { 18 | return "" 19 | } 20 | return *p.Text 21 | } 22 | 23 | func (p Part) GetInlineData() InlineData { 24 | if p.InlineData == nil { 25 | return InlineData{} 26 | } 27 | return *p.InlineData 28 | } 29 | 30 | type InlineData struct { 31 | MimeType string `json:"mime_type"` 32 | Data string `json:"data"` 33 | } 34 | 35 | type Parts []Part 36 | 37 | func NewParts(parts ...Part) Parts { 38 | return parts
39 | } 40 | 41 | func (p Parts) AppendPart(part ...Part) Parts { 42 | return append(p, part...) 43 | } 44 | 45 | func NewTextPart(text string) Part { 46 | return Part{ 47 | Text: &text, 48 | } 49 | } 50 | 51 | func NewInlineDataPart(mimeType, data string) Part { 52 | return Part{ 53 | InlineData: &InlineData{ 54 | MimeType: mimeType, 55 | Data: data, 56 | }, 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /models/vertexRegion.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | type VertexRegion string 4 | 5 | const ( 6 | VertexRegion_US_Iowa VertexRegion = "us-central1" 7 | VertexRegion_US_Las_Vegas VertexRegion = "us-west4" 8 | VertexRegion_CA_Montreal VertexRegion = "northamerica-northeast1" 9 | VertexRegion_US_Northern_Virginia VertexRegion = "us-east4" 10 | VertexRegion_US_Oregon VertexRegion = "us-west1" 11 | VertexRegion_KR_Seoul VertexRegion = "asia-northeast3" 12 | VertexRegion_SG_Singapore VertexRegion = "asia-southeast1" 13 | VertexRegion_JP_Tokyo VertexRegion = "asia-northeast1" 14 | ) 15 | 16 | const ( 17 | VertexRegion_US_Central_1 VertexRegion = "us-central1" 18 | VertexRegion_US_West_4 VertexRegion = "us-west4" 19 | // Deprecated: misspelled; use VertexRegion_NorthAmerica_NorthEast_1 instead. 20 | VertexRegion_NorthAmerica_NorthWast_1 VertexRegion = "northamerica-northeast1" 21 | VertexRegion_NorthAmerica_NorthEast_1 VertexRegion = "northamerica-northeast1" 22 | VertexRegion_US_East_4 VertexRegion = "us-east4" 23 | VertexRegion_US_West_1 VertexRegion = "us-west1" 24 | VertexRegion_Asia_NorthEast_3 VertexRegion = "asia-northeast3" 25 | VertexRegion_Asia_SouthEast_1 VertexRegion = "asia-southeast1" 26 | VertexRegion_Asia_NorthEast_1 VertexRegion = "asia-northeast1" 27 | ) 28 | -------------------------------------------------------------------------------- /stream_generate_content.go: -------------------------------------------------------------------------------- 1 | package gemini 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "encoding/json" 7 | "errors" 8 | "fmt" 9 | "github.com/Limit-LAB/go-gemini/models" 10 |
"io" 11 | "strings" 12 | ) 13 | 14 | // maxStreamChunkSize bounds a single streamed response element. One chunk can exceed bufio.Scanner's 4 KiB 15 | // starting buffer, so the cap is raised to keep long candidates from failing with bufio.ErrTooLong. 16 | const maxStreamChunkSize = 1024 * 1024 17 | 18 | // GenerateContentStream starts a streamed generateContent call. A non-2xx response always yields an error 19 | // (the API message when the body carries one, otherwise the HTTP status); on success the caller must Close 20 | // the returned streamer. 21 | func (c *Client) GenerateContentStream(model models.GeminiModel, req *models.GenerateContentRequest) (*GenerateContentStreamer, error) { 22 | for _, content := range req.Contents { 23 | err := validateGenerateContentRequest(model, content) 24 | if err != nil { 25 | return nil, err 26 | } 27 | } 28 | url := c.url(string(model), "streamGenerateContent") 29 | httpReq, err := c.newReq("POST", url, req) 30 | if err != nil { 31 | return nil, err 32 | } 33 | resp, err := c.hc.Do(httpReq) 34 | if err != nil { 35 | return nil, err 36 | } 37 | if resp.StatusCode >= 300 { 38 | defer resp.Body.Close() 39 | var bs []byte 40 | bs, err = io.ReadAll(resp.Body) 41 | if err != nil { 42 | return nil, err 43 | } 44 | // Stream errors usually arrive as a JSON array but may be a single object; fall back to the HTTP 45 | // status so a failed call never hands back a streamer over the body closed by the defer above 46 | // (and a missing "error" field cannot cause a nil dereference). 47 | var eResps []models.ErrorResponse 48 | if json.Unmarshal(bs, &eResps) == nil && len(eResps) > 0 && eResps[0].Error != nil { 49 | return nil, errors.New(eResps[0].Error.Message) 50 | } 51 | var eResp models.ErrorResponse 52 | if json.Unmarshal(bs, &eResp) == nil && eResp.Error != nil { 53 | return nil, errors.New(eResp.Error.Message) 54 | } 55 | return nil, fmt.Errorf("unexpected HTTP status %s", resp.Status) 56 | } 57 | return newStreamScanner(resp.Body), nil 58 | } 59 | 60 | type GenerateContentStreamer struct { 61 | buf *bufio.Scanner 62 | raw io.ReadCloser 63 | } 64 | 65 | func (r *GenerateContentStreamer) Close() error { 66 | return r.raw.Close() 67 | } 68 | 69 | func (r *GenerateContentStreamer) Receive() (models.GenerateContentResponse, error) { 70 | if !r.buf.Scan() { 71 | err := r.buf.Err() 72 | if err == nil { 73 | err = io.EOF 74 | } 75 | return models.GenerateContentResponse{}, err 76 | } 77 | txt := r.buf.Text() 78 | // remove head '[' and tail ']' 79 | txt = strings.TrimLeft(txt, "[,\r\n") 80 | txt = strings.TrimRight(txt, "],\r\n") 81 | var res models.GenerateContentResponse 82 | err := json.Unmarshal([]byte(txt), &res) 83 | return res, err 84 | } 85 | 86 | func newStreamScanner(eventStream io.ReadCloser) *GenerateContentStreamer { 87 | scanner := bufio.NewScanner(eventStream) 88 | 89 | scanner.Buffer(make([]byte, 4096), maxStreamChunkSize) 90 | 91 | split := func(data []byte, atEOF bool) (int, []byte, error) { 92 | if i :=
bytes.Index(data, []byte("}\n,\r\n")); i >= 0 { 93 | return i + 5, data[0 : i+1], nil 94 | } 95 | if atEOF && len(data) == 0 { 96 | return 0, nil, nil 97 | } 98 | if atEOF { 99 | return len(data), data, nil 100 | } 101 | return 0, nil, nil 102 | } 103 | scanner.Split(split) 104 | 105 | return &GenerateContentStreamer{ 106 | buf: scanner, 107 | raw: eventStream, 108 | } 109 | } 110 | -------------------------------------------------------------------------------- /stream_generate_content_test.go: -------------------------------------------------------------------------------- 1 | package gemini_test 2 | 3 | import ( 4 | "fmt" 5 | "github.com/Limit-LAB/go-gemini" 6 | "github.com/Limit-LAB/go-gemini/models" 7 | "github.com/joho/godotenv" 8 | "os" 9 | "testing" 10 | ) 11 | 12 | func TestGeminiStreamGenerateContent(t *testing.T) { 13 | godotenv.Load() 14 | key := os.Getenv("GEMINI") 15 | req := models.NewGenerateContentRequest(). 16 | AppendContent( 17 | models.RoleUser, 18 | models.NewTextPart("Hi! Use 20 words to describe yourself."), 19 | ) 20 | cli := gemini.NewClient(key) 21 | gcs, err := cli.GenerateContentStream(models.GeminiPro, req) 22 | if err != nil { 23 | t.Fatal(err) 24 | } 25 | defer gcs.Close() 26 | for { 27 | content, err := gcs.Receive() 28 | if err != nil { 29 | fmt.Println(err) 30 | break 31 | } 32 | fmt.Println("==============LOOP==============") 33 | jsonPrint(content) 34 | } 35 | } 36 | --------------------------------------------------------------------------------