├── .cursor └── rules │ └── go.mdc ├── .github └── workflows │ ├── gosec.yml │ ├── lint.yml │ ├── test.yml │ └── trivy.yml ├── .gitignore ├── LICENSE ├── README.md ├── Taskfile.yml ├── context.go ├── doc ├── README.md ├── examples.md ├── getting-started.md ├── history.md ├── images │ └── logo.png ├── mcp.md └── tools.md ├── errors.go ├── examples ├── README.md ├── basic │ └── main.go ├── chat │ └── main.go ├── embedding │ └── main.go ├── mcp │ └── main.go ├── query │ └── main.go ├── simple │ └── main.go └── tools │ └── main.go ├── exit.go ├── export_test.go ├── go.mod ├── go.sum ├── gollem.go ├── gollem_test.go ├── history.go ├── history_test.go ├── hook.go ├── llm.go ├── llm ├── claude │ ├── client.go │ ├── convert.go │ ├── convert_test.go │ ├── embedding.go │ └── export_test.go ├── gemini │ ├── client.go │ ├── convert.go │ ├── convert_test.go │ ├── embedding.go │ ├── embeding_test.go │ └── export_test.go └── openai │ ├── client.go │ ├── convert.go │ ├── convert_test.go │ ├── embedding.go │ ├── embedding_test.go │ └── export_test.go ├── llm_test.go ├── mcp ├── export_test.go ├── mcp.go └── mcp_test.go ├── mock └── mock_gen.go ├── session.go ├── tool.go ├── tool_test.go └── types.go /.cursor/rules/go.mdc: -------------------------------------------------------------------------------- 1 | --- 2 | description: 3 | globs: 4 | alwaysApply: true 5 | --- 6 | 7 | # Development rules 8 | 9 | ## Comment & Literals 10 | 11 | All comments and literals in source code MUST be in English. 12 | 13 | ## Error handling 14 | 15 | Use `github.com/m-mizutani/goerr/v2` as the error handling tool. Wrap errors as follows. 16 | 17 | ```go 18 | func someAction(tasks []task) error { 19 | for _, t := range tasks { 20 | if err := validateData(t.Data); err != nil { 21 | return goerr.Wrap(err, "failed to validate data", goerr.Value("name", t.Name)) 22 | } 23 | } 24 | // ....
25 | return nil 26 | } 27 | ``` 28 | 29 | # Testing 30 | 31 | If you need to run tests for checking, run only the tests that you modified or that the developer specified. 32 | 33 | ## Style for similar testing 34 | 35 | Use the following Helper Driven Testing style instead of the general Table Driven Test style. Do not use the Table Driven Test style. 36 | 37 | ```go 38 | type testCase struct { 39 | input string 40 | expected string 41 | } 42 | 43 | runTest := func(tc testCase) func(t *testing.T) { 44 | return func(t *testing.T) { 45 | actual := someFunc(tc.input) 46 | gt.Equal(t, tc.expected, actual) 47 | } 48 | } 49 | 50 | t.Run("success case", runTest(testCase{ 51 | input: "blue", 52 | expected: "BLUE", 53 | })) 54 | ``` 55 | 56 | ## Test framework 57 | 58 | Use the `github.com/m-mizutani/gt` package. 59 | 60 | `gt` is a test library that leverages Go generics so that the IDE and the compiler can check variable types. 61 | 62 | ```go 63 | color := "blue" 64 | 65 | // gt.Value(t, color).Equal(5) // <- Compile error 66 | 67 | gt.Value(t, color).Equal("orange") // <- Fail 68 | gt.Value(t, color).Equal("blue") // <- Pass 69 | ``` 70 | 71 | ```go 72 | colors := []string{"red", "blue"} 73 | 74 | // gt.Array(t, colors).Equal("red") // <- Compile error 75 | // gt.Array(t, colors).Equal([]int{1, 2}) // <- Compile error 76 | 77 | gt.Array(t, colors).Equal([]string{"red", "blue"}) // <- Pass 78 | gt.Array(t, colors).Has("orange") // <- Fail 79 | ``` 80 | 81 | ### Usage 82 | 83 | In many cases, a developer does not need to care about Go generics when using `gt`. However, a developer needs to explicitly choose the test type (`Value`, `Array`, `Map`, `Error`, etc.) to use the test functions specific to each type. 84 | 85 | See @reference for more detail. 86 | 87 | #### Value 88 | 89 | `Value` is the generic test type and has a minimal set of test methods.
90 | 91 | ```go 92 | type user struct { 93 | Name string 94 | } 95 | u1 := user{Name: "blue"} 96 | 97 | // gt.Value(t, u1).Equal(1) // Compile error 98 | // gt.Value(t, u1).Equal("blue") // Compile error 99 | // gt.Value(t, u1).Equal(&user{Name:"blue"}) // Compile error 100 | 101 | gt.Value(t, u1).Equal(user{Name:"blue"}) // Pass 102 | ``` 103 | 104 | #### Number 105 | 106 | Accepts only number types: `int`, `uint`, `int64`, `float64`, etc. 107 | 108 | ```go 109 | var f float64 = 12.5 110 | gt.Number(t, f). 111 | Equal(12.5). // Pass 112 | Greater(12). // Pass 113 | Less(10). // Fail 114 | GreaterOrEqual(12.5) // Pass 115 | ``` 116 | 117 | #### Array 118 | 119 | Accepts a slice of any element type, not only primitive types but also structs. 120 | 121 | ```go 122 | colors := []string{"red", "blue", "yellow"} 123 | 124 | gt.Array(t, colors). 125 | Equal([]string{"red", "blue", "yellow"}). // Pass 126 | Equal([]string{"red", "blue"}). // Fail 127 | // Equal([]int{1, 2}). // Compile error 128 | Contain([]string{"red", "blue"}). // Pass 129 | Has("yellow"). // Pass 130 | Length(3) // Pass 131 | 132 | gt.Array(t, colors).Must().Has("orange") // Fail and stop test 133 | ``` 134 | 135 | #### Map 136 | 137 | ```go 138 | colorMap := map[string]int{ 139 | "red": 1, 140 | "yellow": 2, 141 | "blue": 5, 142 | } 143 | 144 | gt.Map(t, colorMap). 145 | HasKey("blue"). // Pass 146 | HasValue(5). // Pass 147 | // HasValue("red"). // Compile error 148 | HasKeyValue("yellow", 2) // Pass 149 | 150 | gt.Map(t, colorMap).Must().HasKey("orange") // Fail and stop test 151 | ``` 152 | -------------------------------------------------------------------------------- /.github/workflows/gosec.yml: -------------------------------------------------------------------------------- 1 | name: gosec 2 | 3 | # Run workflow each time code is pushed to your repository. 4 | # No schedule is configured; add an `on.schedule` trigger to also run it periodically.
5 | on: 6 | push: 7 | 8 | jobs: 9 | tests: 10 | runs-on: ubuntu-latest 11 | permissions: 12 | security-events: write 13 | actions: read 14 | contents: read 15 | env: 16 | GO111MODULE: on 17 | steps: 18 | - name: Checkout Source 19 | uses: actions/checkout@v4 20 | - name: Run Gosec Security Scanner 21 | uses: securego/gosec@master 22 | with: 23 | # we let the report trigger content trigger a failure using the GitHub Security features. 24 | args: "-no-fail -fmt sarif -out results.sarif ./..." 25 | - name: Upload SARIF file 26 | uses: github/codeql-action/upload-sarif@v2 27 | with: 28 | # Path to SARIF file relative to the root of the repository 29 | sarif_file: results.sarif 30 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: lint 2 | on: 3 | push: 4 | 5 | jobs: 6 | golangci: 7 | name: lint 8 | permissions: 9 | security-events: write 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v4 13 | - name: Set up Go 14 | uses: actions/setup-go@v5 15 | with: 16 | go-version-file: "go.mod" 17 | - name: golangci-lint 18 | uses: golangci/golangci-lint-action@v6 19 | with: 20 | version: latest 21 | args: --timeout 10m ./... 22 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: test 2 | 3 | on: [push] 4 | 5 | jobs: 6 | testing: 7 | runs-on: ubuntu-latest 8 | 9 | steps: 10 | - name: Checkout upstream repo 11 | uses: actions/checkout@v3 12 | with: 13 | ref: ${{ github.head_ref }} 14 | - uses: actions/setup-go@v3 15 | with: 16 | go-version-file: "go.mod" 17 | - uses: google-github-actions/setup-gcloud@v0.5.0 18 | - run: go test ./... 
19 | -------------------------------------------------------------------------------- /.github/workflows/trivy.yml: -------------------------------------------------------------------------------- 1 | name: trivy 2 | 3 | on: 4 | push: 5 | schedule: 6 | - cron: "0 0 * * *" 7 | workflow_dispatch: 8 | 9 | jobs: 10 | scan: 11 | runs-on: ubuntu-latest 12 | permissions: 13 | security-events: write 14 | actions: read 15 | contents: read 16 | 17 | steps: 18 | - name: Checkout upstream repo 19 | uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 20 | with: 21 | ref: ${{ github.head_ref }} 22 | - id: scan 23 | name: Run Trivy vulnerability scanner in repo mode 24 | uses: aquasecurity/trivy-action@f3d98514b056d8c71a3552e8328c225bc7f6f353 # master 25 | with: 26 | scan-type: "fs" 27 | ignore-unfixed: true 28 | format: "sarif" 29 | output: "trivy-results.sarif" 30 | exit-code: 1 31 | env: 32 | TRIVY_DB_REPOSITORY: public.ecr.aws/aquasecurity/trivy-db 33 | 34 | - name: Upload Trivy scan results to GitHub Security tab 35 | if: failure() && steps.scan.outcome == 'failure' 36 | uses: github/codeql-action/upload-sarif@e8893c57a1f3a2b659b6b55564fdfdbbd2982911 # v3.24.0 37 | with: 38 | sarif_file: "trivy-results.sarif" 39 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # If you prefer the allow list template instead of the deny list, see community template: 2 | # https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore 3 | # 4 | # Binaries for programs and plugins 5 | *.exe 6 | *.exe~ 7 | *.dll 8 | *.so 9 | *.dylib 10 | 11 | # Test binary, built with `go test -c` 12 | *.test 13 | 14 | # Output of the go coverage tool, specifically when used with LiteIDE 15 | *.out 16 | 17 | # Dependency directories (remove the comment below to include it) 18 | # vendor/ 19 | 20 | # Go workspace file 21 | go.work 22 | 
go.work.sum 23 | 24 | # env file 25 | .env 26 | 27 | # VS code 28 | .vscode 29 | 30 | /tmp 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🤖 gollem [![Go Reference](https://pkg.go.dev/badge/github.com/m-mizutani/gollem.svg)](https://pkg.go.dev/github.com/m-mizutani/gollem) [![Test](https://github.com/m-mizutani/gollem/actions/workflows/test.yml/badge.svg)](https://github.com/m-mizutani/gollem/actions/workflows/test.yml) [![Lint](https://github.com/m-mizutani/gollem/actions/workflows/lint.yml/badge.svg)](https://github.com/m-mizutani/gollem/actions/workflows/lint.yml) [![Gosec](https://github.com/m-mizutani/gollem/actions/workflows/gosec.yml/badge.svg)](https://github.com/m-mizutani/gollem/actions/workflows/gosec.yml) [![Trivy](https://github.com/m-mizutani/gollem/actions/workflows/trivy.yml/badge.svg)](https://github.com/m-mizutani/gollem/actions/workflows/trivy.yml) 2 | 3 | GO for Large LanguagE Model (GOLLEM) 4 | 5 |

6 | 7 |

8 | 9 | 10 | `gollem` provides: 11 | - Common interface to query prompt to Large Language Model (LLM) services 12 | - GenerateContent: Generate text content from prompt 13 | - GenerateEmbedding: Generate embedding vector from text (OpenAI and Gemini) 14 | - Framework for building agentic applications of LLMs with 15 | - Tools by MCP (Model Context Protocol) server and your built-in tools 16 | - Portable conversational memory with history for stateless/distributed applications 17 | 18 | ## Supported LLMs 19 | 20 | - [x] Gemini (see [models](https://ai.google.dev/gemini-api/docs/models?hl=ja)) 21 | - [x] Anthropic (see [models](https://docs.anthropic.com/en/docs/about-claude/models/all-models)) 22 | - [x] OpenAI (see [models](https://platform.openai.com/docs/models)) 23 | 24 | ## Quick Start 25 | 26 | ### Install 27 | 28 | ```bash 29 | go get github.com/m-mizutani/gollem 30 | ``` 31 | 32 | ### Example 33 | 34 | #### Query to LLM 35 | 36 | ```go 37 | llmProvider := os.Args[1] 38 | model := os.Args[2] 39 | prompt := os.Args[3] 40 | 41 | var client gollem.LLMClient 42 | var err error 43 | 44 | switch llmProvider { 45 | case "gemini": 46 | client, err = gemini.New(ctx, os.Getenv("GEMINI_PROJECT_ID"), os.Getenv("GEMINI_LOCATION"), gemini.WithModel(model)) 47 | case "claude": 48 | client, err = claude.New(ctx, os.Getenv("ANTHROPIC_API_KEY"), claude.WithModel(model)) 49 | case "openai": 50 | client, err = openai.New(ctx, os.Getenv("OPENAI_API_KEY"), openai.WithModel(model)) 51 | } 52 | 53 | if err != nil { 54 | panic(err) 55 | } 56 | 57 | ssn, err := client.NewSession(ctx) 58 | if err != nil { 59 | panic(err) 60 | } 61 | 62 | result, err := ssn.GenerateContent(ctx, gollem.Text(prompt)) 63 | if err != nil { 64 | panic(err) 65 | } 66 | 67 | fmt.Println(result.Texts) 68 | ``` 69 | 70 | #### Generate embedding 71 | ```go 72 | // Create OpenAI client 73 | client, err := openai.New(ctx, os.Getenv("OPENAI_API_KEY")) 74 | if err != nil { 75 | panic(err) 76 | } 77 | 78 | embedding, 
err := client.GenerateEmbedding(ctx, 100, []string{"Hello, world!", "This is a test"}) 79 | if err != nil { 80 | panic(err) 81 | } 82 | 83 | // Print two embedding arrays, each containing 100-dimensional vectors 84 | fmt.Println("embedding:", embedding) 85 | ``` 86 | 87 | #### Agentic application with MCP server 88 | 89 | Here's a simple example of creating a custom tool and using it with an LLM: 90 | 91 | ```go 92 | func main() { 93 | ctx := context.Background() 94 | 95 | // Create OpenAI client 96 | client, err := openai.New(ctx, os.Getenv("OPENAI_API_KEY")) 97 | if err != nil { 98 | panic(err) 99 | } 100 | 101 | // Create MCP client with local server 102 | mcpLocal, err := mcp.NewStdio(ctx, "/path/to/mcp-server", []string{}, mcp.WithEnvVars([]string{"MCP_ENV=test"})) 103 | if err != nil { 104 | panic(err) 105 | } 106 | defer mcpLocal.Close() 107 | 108 | // Create MCP client with remote server 109 | mcpRemote, err := mcp.NewSSE(ctx, "http://localhost:8080") 110 | if err != nil { 111 | panic(err) 112 | } 113 | defer mcpRemote.Close() 114 | 115 | // Create gollem instance 116 | agent := gollem.New(client, 117 | gollem.WithToolSets(mcpLocal, mcpRemote), 118 | gollem.WithTools(&MyTool{}), 119 | gollem.WithMessageHook(func(ctx context.Context, msg string) error { 120 | fmt.Printf("🤖 %s\n", msg) 121 | return nil 122 | }), 123 | ) 124 | 125 | var history *gollem.History 126 | for { 127 | fmt.Print("> ") 128 | scanner := bufio.NewScanner(os.Stdin) 129 | scanner.Scan() 130 | 131 | newHistory, err := agent.Prompt(ctx, scanner.Text(), history) 132 | if err != nil { 133 | panic(err) 134 | } 135 | history = newHistory 136 | } 137 | } 138 | ``` 139 | 140 | See the full example in [examples/basic](https://github.com/m-mizutani/gollem/tree/main/examples/basic), and more examples in [examples](https://github.com/m-mizutani/gollem/tree/main/examples).
141 | 142 | ## Documentation 143 | 144 | For more details and examples, visit our [document](https://github.com/m-mizutani/gollem/tree/main/doc) and [godoc](https://pkg.go.dev/github.com/m-mizutani/gollem). 145 | 146 | ## License 147 | 148 | Apache 2.0 License. See [LICENSE](LICENSE) for details. 149 | -------------------------------------------------------------------------------- /Taskfile.yml: -------------------------------------------------------------------------------- 1 | # https://taskfile.dev 2 | 3 | version: '3' 4 | 5 | tasks: 6 | default: 7 | deps: 8 | - mock 9 | mock: 10 | desc: Generate mock files 11 | cmds: 12 | - go run github.com/matryer/moq@v0.5.3 13 | -out ./mock/mock_gen.go 14 | -pkg mock 15 | -rm -skip-ensure -stub 16 | . 17 | LLMClient Session Tool 18 | -------------------------------------------------------------------------------- /context.go: -------------------------------------------------------------------------------- 1 | package gollem 2 | 3 | import ( 4 | "context" 5 | "log/slog" 6 | ) 7 | // ctxLoggerKey is the context key under which the logger is stored; being an unexported type, it cannot collide with keys defined by other packages. 8 | type ctxLoggerKey struct{} 9 | // defaultLogger is the fallback returned by LoggerFromContext when ctx carries no logger; it uses slog.DiscardHandler, so all records are dropped. 10 | var defaultLogger = slog.New(slog.DiscardHandler) 11 | // ctxWithLogger returns a copy of ctx that carries logger, which LoggerFromContext can later retrieve. 12 | func ctxWithLogger(ctx context.Context, logger *slog.Logger) context.Context { 13 | return context.WithValue(ctx, ctxLoggerKey{}, logger) 14 | } 15 | // LoggerFromContext returns the *slog.Logger stored in ctx by ctxWithLogger, or defaultLogger (which discards all output) when none is stored. NOTE(review): a nil *slog.Logger stored via ctxWithLogger still passes the type assertion and is returned as nil; confirm callers never store nil. 16 | func LoggerFromContext(ctx context.Context) *slog.Logger { 17 | if logger, ok := ctx.Value(ctxLoggerKey{}).(*slog.Logger); ok { 18 | return logger 19 | } 20 | return defaultLogger 21 | } 22 | -------------------------------------------------------------------------------- /doc/README.md: -------------------------------------------------------------------------------- 1 | # gollem Documentation 2 | 3 | gollem is a Go framework for building applications with Large Language Models (LLMs). This documentation provides comprehensive guides and examples to help you get started and make the most of gollem.
4 | 5 | ## Documentation Index 6 | 7 | - [Getting Started](getting-started.md) - Quick start guide and basic usage 8 | - [Tools](tools.md) - Creating and using custom tools 9 | - [MCP Server Integration](mcp.md) - Integrating with Model Context Protocol servers 10 | - [History](history.md) - Managing conversation history and context 11 | - [Examples](examples.md) - Practical examples and use cases 12 | 13 | ## Key Features 14 | 15 | - **Multiple LLM Support**: Works with various LLM providers including OpenAI, Claude, and Gemini 16 | - **Custom Tools**: Create and integrate your own tools for LLMs to use 17 | - **MCP Integration**: Connect with external tools and resources through Model Context Protocol 18 | - **History Management**: Maintain conversation context across sessions 19 | 20 | 21 | -------------------------------------------------------------------------------- /doc/examples.md: -------------------------------------------------------------------------------- 1 | # Practical Examples 2 | 3 | This guide provides practical examples of using gollem in various scenarios. 
4 | 5 | ## Calculator Tool 6 | 7 | A more complex example using a calculator tool: 8 | 9 | ```go 10 | type CalculatorTool struct{} 11 | 12 | func (t *CalculatorTool) Spec() gollem.ToolSpec { 13 | return gollem.ToolSpec{ 14 | Name: "calculator", 15 | Description: "Performs basic arithmetic operations", 16 | Parameters: map[string]*gollem.Parameter{ 17 | "operation": { 18 | Name: "operation", 19 | Type: gollem.TypeString, 20 | Description: "The operation to perform (add, subtract, multiply, divide)", 21 | Required: true, 22 | }, 23 | "a": { 24 | Name: "a", 25 | Type: gollem.TypeNumber, 26 | Description: "First number", 27 | Required: true, 28 | }, 29 | "b": { 30 | Name: "b", 31 | Type: gollem.TypeNumber, 32 | Description: "Second number", 33 | Required: true, 34 | }, 35 | }, 36 | } 37 | } 38 | 39 | func (t *CalculatorTool) Run(ctx context.Context, args map[string]any) (map[string]any, error) { 40 | // Validate operation 41 | op, ok := args["operation"].(string) 42 | if !ok { 43 | return nil, errors.New("operation must be a string") 44 | } 45 | if op != "add" && op != "subtract" && op != "multiply" && op != "divide" { 46 | return nil, errors.New("invalid operation: must be one of add, subtract, multiply, divide") 47 | } 48 | 49 | // Validate first number 50 | a, ok := args["a"].(float64) 51 | if !ok { 52 | return nil, errors.New("first number must be a number") 53 | } 54 | 55 | // Validate second number 56 | b, ok := args["b"].(float64) 57 | if !ok { 58 | return nil, errors.New("second number must be a number") 59 | } 60 | 61 | // Validate division by zero 62 | if op == "divide" && b == 0 { 63 | return nil, errors.New("division by zero is not allowed") 64 | } 65 | 66 | var result float64 67 | switch op { 68 | case "add": 69 | result = a + b 70 | case "subtract": 71 | result = a - b 72 | case "multiply": 73 | result = a * b 74 | case "divide": 75 | result = a / b 76 | } 77 | 78 | return map[string]any{ 79 | "result": result, 80 | }, nil 81 | } 82 | ``` 83 | 84 | ## 
Weather Tool 85 | 86 | An example of a weather tool implemented directly in Go (to use tools provided by an MCP server instead, see [MCP Server Integration](mcp.md)): 87 | 88 | ```go 89 | type WeatherTool struct{} 90 | 91 | func (t *WeatherTool) Spec() gollem.ToolSpec { 92 | return gollem.ToolSpec{ 93 | Name: "weather", 94 | Description: "Gets weather information for a location", 95 | Parameters: map[string]*gollem.Parameter{ 96 | "location": { 97 | Type: gollem.TypeString, 98 | Description: "City name or coordinates", 99 | Required: true, 100 | }, 101 | }, 102 | } 103 | } 104 | 105 | func (t *WeatherTool) Run(ctx context.Context, args map[string]any) (map[string]any, error) { 106 | location, ok := args["location"].(string) 107 | if !ok { 108 | return nil, fmt.Errorf("location must be a string") 109 | } 110 | _ = location // TODO: replace the fixed values below with a real weather API call for location 111 | return map[string]any{ 112 | "temperature": 25.5, 113 | "condition": "sunny", 114 | }, nil 115 | } 116 | ``` 117 | 118 | ## Best Practices 119 | 120 | 1. **Error Handling**: Always handle errors properly 121 | 2. **Type Safety**: Use appropriate types and validate input 122 | 3. **Documentation**: Document your tools and examples 123 | 4. **Testing**: Write tests for your tools 124 | 5.
**Security**: Implement proper security measures 125 | 126 | ## Argument Validation 127 | 128 | Here's an example of proper argument validation: 129 | 130 | ```go 131 | func (t *CalculatorTool) Run(ctx context.Context, args map[string]any) (map[string]any, error) { 132 | // Validate operation 133 | op, ok := args["operation"].(string) 134 | if !ok { 135 | return nil, errors.New("operation must be a string") 136 | } 137 | if op != "add" && op != "subtract" && op != "multiply" && op != "divide" { 138 | return nil, errors.New("invalid operation: must be one of add, subtract, multiply, divide") 139 | } 140 | 141 | // Validate first number 142 | a, ok := args["a"].(float64) 143 | if !ok { 144 | return nil, errors.New("first number must be a number") 145 | } 146 | 147 | // Validate second number 148 | b, ok := args["b"].(float64) 149 | if !ok { 150 | return nil, errors.New("second number must be a number") 151 | } 152 | 153 | // Validate division by zero 154 | if op == "divide" && b == 0 { 155 | return nil, errors.New("division by zero is not allowed") 156 | } 157 | 158 | var result float64 159 | switch op { 160 | case "add": 161 | result = a + b 162 | case "subtract": 163 | result = a - b 164 | case "multiply": 165 | result = a * b 166 | case "divide": 167 | result = a / b 168 | } 169 | 170 | return map[string]any{ 171 | "result": result, 172 | }, nil 173 | } 174 | ``` 175 | 176 | ## Next Steps 177 | 178 | - Learn more about [tool creation](tools.md) 179 | - Explore [MCP server integration](mcp.md) 180 | - Check out the [getting started guide](getting-started.md) 181 | - Understand [history management](history.md) for conversation context 182 | - Review the [complete documentation](README.md) 183 | -------------------------------------------------------------------------------- /doc/getting-started.md: -------------------------------------------------------------------------------- 1 | # Getting Started with gollem 2 | 3 | gollem is a Go framework for building 
applications with Large Language Models (LLMs). This guide will help you get started with the framework. 4 | 5 | ## Installation 6 | 7 | Install gollem using Go modules: 8 | 9 | ```bash 10 | go get github.com/m-mizutani/gollem 11 | ``` 12 | 13 | ## Basic Usage 14 | 15 | Here's a simple example of how to use gollem with an OpenAI GPT model: 16 | 17 | ```go 18 | package main 19 | 20 | import ( 21 | "context" 22 | "fmt" 23 | "os" 24 | 25 | "github.com/m-mizutani/gollem" 26 | "github.com/m-mizutani/gollem/llm/openai" 27 | "github.com/m-mizutani/gollem/mcp" 28 | ) 29 | 30 | func main() { 31 | // Create OpenAI client 32 | client, err := openai.New(context.Background(), os.Getenv("OPENAI_API_KEY")) 33 | if err != nil { 34 | panic(err) 35 | } 36 | 37 | // Create MCP client 38 | mcpClient, err := mcp.NewSSE(context.Background(), "http://localhost:8080") 39 | if err != nil { 40 | panic(err) 41 | } 42 | defer mcpClient.Close() 43 | 44 | // Create gollem instance 45 | s := gollem.New(client, 46 | gollem.WithToolSets(mcpClient), 47 | gollem.WithMessageHook(func(ctx context.Context, msg string) error { 48 | fmt.Println(msg) 49 | return nil 50 | }), 51 | ) 52 | 53 | // Send a message to the LLM 54 | if _, err := s.Prompt(context.Background(), "Hello, how are you?"); err != nil { 55 | panic(err) 56 | } 57 | } 58 | ``` 59 | 60 | This code uses an OpenAI GPT model: it sends a fixed prompt to the LLM and prints each message the LLM returns through the message hook. The tools exposed by the MCP server at `http://localhost:8080` are made available to the LLM via `gollem.WithToolSets`. 61 | 62 | For information on how to integrate with Tools and MCP servers, please refer to [tools](tools.md) and [mcp](mcp.md) documents. 63 | 64 | ## Supported LLM Providers 65 | 66 | gollem supports multiple LLM providers: 67 | 68 | - Gemini 69 | - Anthropic (Claude) 70 | - OpenAI (GPT) 71 | 72 | Each provider has its own client implementation under the `llm` directory (`llm/gemini`, `llm/claude`, `llm/openai`). See the respective documentation for configuration options.
73 | 74 | ## Key Concepts 75 | 76 | 1. **LLM Client**: The interface to communicate with LLM providers 77 | 2. **Tools**: Custom functions that LLMs can use to perform actions (see [Tools](tools.md)) 78 | 3. **MCP Server**: External tool integration through Model Context Protocol (see [MCP Server Integration](mcp.md)) 79 | 4. **Natural Language Interface**: Interact with your application using natural language 80 | 5. **History Management**: Maintain conversation context across sessions (see [History](history.md)) 81 | 82 | ## Error Handling 83 | 84 | gollem provides robust error handling capabilities to help you build reliable applications: 85 | 86 | ### Error Types 87 | - **LLM Errors**: Errors from the LLM provider (e.g., rate limits, invalid requests) 88 | - **Tool Execution Errors**: Errors during tool execution 89 | - **MCP Server Errors**: Errors from MCP server communication 90 | 91 | ### Best Practices 92 | 1. **Graceful Degradation**: Implement fallback mechanisms when LLM or tools fail 93 | 2. **Retry Strategies**: Use exponential backoff for transient errors 94 | 3. **Error Logging**: Log errors with appropriate context for debugging 95 | 4. **User Feedback**: Provide clear error messages to end users 96 | 97 | Example of error handling: 98 | ```go 99 | newHistory, err := s.Prompt(ctx, userInput, gollem.WithHistory(history)) 100 | if err != nil { 101 | // Handle errors 102 | log.Printf("Error: %v", err) 103 | return nil, fmt.Errorf("failed to process request: %w", err) 104 | } 105 | ``` 106 | 107 | ## Context Management 108 | 109 | gollem provides a history-based context management system to maintain conversation state: 110 | 111 | ### History Object 112 | The `History` object maintains the conversation context, including: 113 | - Previous messages 114 | - Tool execution results 115 | - System prompts 116 | - Context metadata 117 | 118 | ### Best Practices 119 | 1. **Memory Management**: Clear old history when it exceeds size limits 120 | 2. 
**Context Persistence**: Save important context for future sessions 121 | 3. **Context Pruning**: Remove irrelevant information to maintain focus 122 | 123 | Example of history management: 124 | ```go 125 | // Initialize history 126 | var history *gollem.History 127 | 128 | // Process user input with history 129 | newHistory, err := s.Prompt(ctx, userInput, gollem.WithHistory(history)) 130 | if err != nil { 131 | return nil, err 132 | } 133 | 134 | // Update history 135 | history = newHistory 136 | ``` 137 | 138 | ## Next Steps 139 | 140 | - Learn how to create and use [custom tools](tools.md) 141 | - Explore [MCP server integration](mcp.md) 142 | - Check out [practical examples](examples.md) 143 | - Understand [history management](history.md) for conversation context 144 | - Review the [complete documentation](README.md) 145 | -------------------------------------------------------------------------------- /doc/history.md: -------------------------------------------------------------------------------- 1 | # History 2 | 3 | History represents a conversation history that can be used across different LLM sessions. It stores messages in a format specific to each LLM type (OpenAI, Claude, or Gemini). 4 | 5 | ## Version Management 6 | 7 | History includes version information to ensure compatibility: 8 | 9 | - Current version: 1 10 | - Version checking is performed when converting between formats 11 | - Version mismatch will result in an error 12 | - This helps maintain compatibility when the History structure changes in future updates 13 | 14 | ## Session Persistence 15 | 16 | History is essential for maintaining conversation context across stateless sessions.
Common use cases include: 17 | 18 | - Backend services handling stateless HTTP requests 19 | - When your backend service receives requests from different instances or after restarts 20 | - When you need to maintain conversation context across multiple API calls 21 | - Distributed systems 22 | - When sessions may be handled by different instances 23 | - When you need to load balance conversations across multiple servers 24 | - Long-running conversations 25 | - When conversations need to be resumed after service restarts 26 | - When implementing features like "continue previous conversation" 27 | 28 | ## Portability 29 | 30 | History can be easily serialized/deserialized using standard JSON marshaling. This enables: 31 | 32 | - Storing conversations in databases 33 | - Persist conversations for future reference 34 | - Implement conversation history features 35 | - Transferring conversations between services 36 | - Move conversations between different environments 37 | - Share conversations across microservices 38 | - Implementing conversation backup and restore features 39 | - Backup important conversations 40 | - Restore conversations after system failures 41 | 42 | ## LLM Type Compatibility 43 | 44 | Each History instance is tied to a specific LLM type (OpenAI, Claude, or Gemini). Important notes: 45 | 46 | - Direct conversion between different LLM types is not supported 47 | - Each LLM type has its own message format and capabilities 48 | 49 | ## Usage Guidelines 50 | 51 | 1. Get History from Prompt response: 52 | ```go 53 | // Create a new gollem instance 54 | g := gollem.New(llmClient) 55 | 56 | // Get response from Prompt 57 | history, err := g.Prompt(ctx, "What is the weather?") 58 | if err != nil { 59 | return nil, fmt.Errorf("failed to get prompt response: %w", err) 60 | } 61 | ``` 62 | 63 | 2. 
Store the History for future use: 64 | ```go 65 | // Store history in your database or storage 66 | jsonData, err := json.Marshal(history) 67 | if err != nil { 68 | return fmt.Errorf("failed to marshal history: %w", err) 69 | } 70 | ``` 71 | 72 | 3. Use stored History in a new session: 73 | ```go 74 | // Restore history 75 | var restoredHistory gollem.History 76 | if err := json.Unmarshal(jsonData, &restoredHistory); err != nil { 77 | return fmt.Errorf("failed to unmarshal history: %w", err) 78 | } 79 | 80 | // Use history in next Prompt call 81 | newHistory, err := g.Prompt(ctx, "What about tomorrow?", gollem.WithHistory(&restoredHistory)) 82 | if err != nil { 83 | return nil, fmt.Errorf("failed to get prompt response: %w", err) 84 | } 85 | ``` 86 | 87 | Note: The History returned from Prompt contains the complete conversation history, so there's no need to manage or track individual messages. Each Prompt response provides a new History instance that includes all previous messages. 88 | 89 | ## Best Practices 90 | 91 | 1. **Error Handling** 92 | - Always check for errors when converting between formats 93 | - Handle type mismatches gracefully (see `gollem.ErrLLMTypeMismatch`) 94 | - Check for version compatibility 95 | - Implement proper error handling for version mismatches (see `gollem.ErrHistoryVersionMismatch`) 96 | 97 | 2. **Storage Considerations** 98 | - Consider the size of your conversations 99 | - Implement cleanup strategies for old conversations 100 | 101 | 3. 
**Security** 102 | - Implement proper access controls 103 | 104 | ## Next Steps 105 | 106 | - Learn more about [tool creation](tools.md) 107 | - Explore [MCP server integration](mcp.md) 108 | - Check out [practical examples](examples.md) 109 | - Review the [getting started guide](getting-started.md) 110 | - Explore the [complete documentation](README.md) 111 | 112 | -------------------------------------------------------------------------------- /doc/images/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/m-mizutani/gollem/0c5d4dbbe006a021f126c22e2a60ec55e3f9a79e/doc/images/logo.png -------------------------------------------------------------------------------- /doc/mcp.md: -------------------------------------------------------------------------------- 1 | # MCP Server Integration 2 | 3 | gollem supports integration with MCP (Model Context Protocol) servers, allowing you to extend LLM capabilities with external tools and resources. 4 | 5 | ## What is MCP? 6 | 7 | MCP is a protocol that enables LLMs to interact with external tools and resources through a standardized interface. 
It provides a way to: 8 | 9 | - Define and expose custom tools 10 | - Manage external resources 11 | - Customize LLM prompts 12 | 13 | ## Connecting to an MCP Server 14 | 15 | To connect your gollem application to an MCP server, you can use either HTTP SSE or stdio transport: 16 | 17 | ```go 18 | // Using HTTP SSE transport 19 | sseClient, err := mcp.NewSSE(context.Background(), "http://localhost:8080") 20 | if err != nil { 21 | panic(err) 22 | } 23 | defer sseClient.Close() 24 | 25 | sseAgent := gollem.New(client, 26 | gollem.WithToolSets(sseClient), 27 | ) 28 | 29 | // Using stdio transport 30 | stdioClient, err := mcp.NewStdio(context.Background(), "/path/to/mcp/server", []string{"--arg1", "value1"}) 31 | if err != nil { 32 | panic(err) 33 | } 34 | defer stdioClient.Close() 35 | 36 | stdioAgent := gollem.New(client, 37 | gollem.WithToolSets(stdioClient), 38 | ) 39 | ``` 40 | 41 | ## Options 42 | 43 | 1. **Environment Variables**: Set environment variables for the MCP client 44 | ```go 45 | mcpClient, err := mcp.NewStdio(context.Background(), "/path/to/mcp/server", []string{}, 46 | mcp.WithEnvVars([]string{"MCP_ENV=test"}), 47 | ) 48 | ``` 49 | 50 | 2. 
**HTTP Headers**: Set custom HTTP headers for SSE transport 51 | ```go 52 | mcpClient, err := mcp.NewSSE(context.Background(), "http://localhost:8080", 53 | mcp.WithHeaders(map[string]string{ 54 | "Authorization": "Bearer token", 55 | }), 56 | ) 57 | ``` 58 | 59 | ## Next Steps 60 | 61 | - Learn more about [tool creation](tools.md) 62 | - Check out [practical examples](examples.md) of MCP integration 63 | - Review the [getting started guide](getting-started.md) for basic usage 64 | - Understand [history management](history.md) for conversation context 65 | - Explore the [complete documentation](README.md) -------------------------------------------------------------------------------- /doc/tools.md: -------------------------------------------------------------------------------- 1 | # Tools in gollem 2 | 3 | Tools are your own custom built-in functions that LLMs can use to perform specific actions in your application. This guide explains how to create and use tools with gollem. 4 | 5 | ## Creating a Tool 6 | 7 | To create a tool, you need to implement the `Tool` interface: 8 | 9 | ```go 10 | type Tool interface { 11 | Spec() ToolSpec 12 | Run(ctx context.Context, args map[string]any) (map[string]any, error) 13 | } 14 | ``` 15 | 16 | Here's an example of a simple tool: 17 | 18 | ```go 19 | type HelloTool struct{} 20 | 21 | func (t *HelloTool) Spec() gollem.ToolSpec { 22 | return gollem.ToolSpec{ 23 | Name: "hello", 24 | Description: "Returns a greeting", 25 | Parameters: map[string]*gollem.Parameter{ 26 | "name": { 27 | Type: gollem.TypeString, 28 | Description: "Name of the person to greet", 29 | }, 30 | }, 31 | Required: []string{"name"}, 32 | } 33 | } 34 | 35 | func (t *HelloTool) Run(ctx context.Context, args map[string]any) (map[string]any, error) { 36 | return map[string]any{ 37 | "message": fmt.Sprintf("Hello, %s!", args["name"]), 38 | }, nil 39 | } 40 | ``` 41 | 42 | ## Tool Specification 43 | 44 | The `ToolSpec` defines the tool's interface: 45 | 46 | - 
`Name`: Unique identifier for the tool 47 | - `Description`: Human-readable description of what the tool does 48 | - `Parameters`: Map of parameter names to their specifications 49 | - `Required`: List of required parameter names (For Object type) 50 | 51 | Each parameter specification includes: 52 | - `Type`: Parameter type (string, number, boolean, etc.) 53 | - `Description`: Human-readable description 54 | - `Title`: Optional user-friendly name for the parameter 55 | - `Required`: Optional boolean indicating if the parameter is required 56 | - `RequiredFields`: List of required field names when Type is Object 57 | - `Enum`: Optional list of allowed values 58 | - `Properties`: Map of properties when Type is Object 59 | - `Items`: Specification for array items when Type is Array 60 | - `Minimum`/`Maximum`: Number constraints 61 | - `MinLength`/`MaxLength`: String length constraints 62 | - `Pattern`: Regular expression pattern for string validation 63 | - `MinItems`/`MaxItems`: Array size constraints 64 | - `Default`: Default value for the parameter 65 | 66 | > [!CAUTION] 67 | > Note that not all parameters are supported by every LLM, as parameter support varies between different LLM providers. 68 | 69 | ## Using Tools 70 | 71 | To use tools with your LLM: 72 | 73 | ```go 74 | s := gollem.New(client, 75 | gollem.WithTools(&HelloTool{}), 76 | ) 77 | ``` 78 | 79 | You can add multiple tools: 80 | 81 | ```go 82 | s := gollem.New(client, 83 | gollem.WithTools(&HelloTool{}, &CalculatorTool{}, &WeatherTool{}), 84 | ) 85 | ``` 86 | 87 | ## Best Practices 88 | 89 | 1. **Clear Descriptions**: Provide clear and concise descriptions for tools and parameters to help the LLM understand their purpose and usage 90 | 2. **Validate Input**: Always validate that parameters passed by the LLM match the specified types in your tool specification 91 | 3. **Error Handling**: When errors occur, return clear error messages that explain both what went wrong and how to fix it. 
The Error() message will be passed directly to the LLM, so include actionable guidance 92 | 4. **Nested Results**: For nested tool results with multiple levels of maps, always use `map[string]any`. Other types may cause errors with some LLM SDKs like Gemini 93 | 94 | ## Next Steps 95 | 96 | - Learn about [MCP server integration](mcp.md) for external tool integration 97 | - Check out [practical examples](examples.md) of tool usage 98 | - Review the [getting started guide](getting-started.md) for basic usage 99 | - Understand [history management](history.md) for conversation context 100 | - Explore the [complete documentation](README.md) 101 | -------------------------------------------------------------------------------- /errors.go: -------------------------------------------------------------------------------- 1 | package gollem 2 | 3 | import "errors" 4 | 5 | var ( 6 | // ErrInvalidTool is returned when the tool validation of definition fails. 7 | ErrInvalidTool = errors.New("invalid tool specification") 8 | 9 | // ErrInvalidParameter is returned when the parameter validation of definition fails. 10 | ErrInvalidParameter = errors.New("invalid parameter") 11 | 12 | // ErrToolNameConflict is returned when the tool name is already used. 13 | ErrToolNameConflict = errors.New("tool name conflict") 14 | 15 | // ErrLoopLimitExceeded is returned when the session loop limit is exceeded. You can resume the session by calling the Prompt() method again. 16 | ErrLoopLimitExceeded = errors.New("loop limit exceeded") 17 | 18 | // ErrInvalidInputSchema is returned when the input schema from MCP is invalid or unsupported. 19 | ErrInvalidInputSchema = errors.New("invalid input schema") 20 | 21 | // ErrInvalidHistoryData is returned when the history data is invalid or unsupported. 22 | ErrInvalidHistoryData = errors.New("invalid history data") 23 | 24 | // ErrLLMTypeMismatch is returned when the LLM type is invalid or unsupported when loading history. 
25 | ErrLLMTypeMismatch = errors.New("llm type mismatch") 26 | 27 | // ErrHistoryVersionMismatch is returned when the history version is invalid or unsupported. 28 | ErrHistoryVersionMismatch = errors.New("history version mismatch") 29 | ) 30 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # Gollem Examples 2 | 3 | This directory contains various examples demonstrating the usage of Gollem. 4 | 5 | ## Basic Example 6 | [Basic Example](basic/main.go) - A simple example showing the basic usage of Gollem. 7 | 8 | ## Chat Example 9 | [Chat Example](chat/main.go) - An interactive chat example with streaming responses, a weather tool, and conversation history saved to a temporary file between turns. 10 | 11 | ## MCP Example 12 | [MCP Example](mcp/main.go) - An example showing how to use MCP (Model Context Protocol) with Gollem. 13 | 14 | ## Tools Example 15 | [Tools Example](tools/main.go) - An example demonstrating how to use tools with Gollem.
16 | -------------------------------------------------------------------------------- /examples/basic/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bufio" 5 | "context" 6 | "fmt" 7 | "os" 8 | 9 | "github.com/m-mizutani/gollem" 10 | "github.com/m-mizutani/gollem/llm/openai" 11 | "github.com/m-mizutani/gollem/mcp" 12 | ) 13 | 14 | type MyTool struct{} 15 | 16 | func (t *MyTool) Spec() gollem.ToolSpec { 17 | return gollem.ToolSpec{ 18 | Name: "my_tool", 19 | Description: "Returns a greeting", 20 | Parameters: map[string]*gollem.Parameter{ 21 | "name": { 22 | Type: gollem.TypeString, 23 | Description: "Name of the person to greet", 24 | }, 25 | }, 26 | } 27 | } 28 | 29 | func (t *MyTool) Run(ctx context.Context, args map[string]any) (map[string]any, error) { 30 | name, ok := args["name"].(string) 31 | if !ok { 32 | return nil, fmt.Errorf("name is required") 33 | } 34 | return map[string]any{"message": fmt.Sprintf("Hello, %s!", name)}, nil 35 | } 36 | func main() { 37 | ctx := context.Background() 38 | 39 | // Create OpenAI client 40 | client, err := openai.New(ctx, os.Getenv("OPENAI_API_KEY")) 41 | if err != nil { 42 | panic(err) 43 | } 44 | 45 | // Create MCP client with local server 46 | mcpLocal, err := mcp.NewStdio(ctx, "./mcp-server", []string{}, mcp.WithEnvVars([]string{"MCP_ENV=test"})) 47 | if err != nil { 48 | panic(err) 49 | } 50 | defer mcpLocal.Close() 51 | 52 | // Create MCP client with remote server 53 | mcpRemote, err := mcp.NewSSE(ctx, "http://localhost:8080") 54 | if err != nil { 55 | panic(err) 56 | } 57 | defer mcpRemote.Close() 58 | 59 | // Create gollem instance 60 | agent := gollem.New(client, 61 | // Not only MCP servers, 62 | gollem.WithToolSets(mcpLocal, mcpRemote), 63 | // But also you can use your own built-in tools 64 | gollem.WithTools(&MyTool{}), 65 | // You can customize the callback function for each message and tool call. 
66 | gollem.WithMessageHook(func(ctx context.Context, msg string) error { 67 | fmt.Printf("🤖 %s\n", msg) 68 | return nil 69 | }), 70 | ) 71 | 72 | var history *gollem.History 73 | for { 74 | fmt.Print("> ") 75 | scanner := bufio.NewScanner(os.Stdin) 76 | scanner.Scan() 77 | 78 | newHistory, err := agent.Prompt(ctx, scanner.Text(), gollem.WithHistory(history)) 79 | if err != nil { 80 | panic(err) 81 | } 82 | history = newHistory 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /examples/chat/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bufio" 5 | "context" 6 | "encoding/json" 7 | "fmt" 8 | "io" 9 | "os" 10 | "path/filepath" 11 | 12 | "github.com/m-mizutani/gollem" 13 | "github.com/m-mizutani/gollem/llm/gemini" 14 | ) 15 | 16 | // WeatherTool is a simple tool that returns a weather 17 | type WeatherTool struct{} 18 | 19 | func (t *WeatherTool) Spec() gollem.ToolSpec { 20 | return gollem.ToolSpec{ 21 | Name: "weather", 22 | Description: "Returns a weather", 23 | Parameters: map[string]*gollem.Parameter{ 24 | "city": { 25 | Type: gollem.TypeString, 26 | Description: "City name", 27 | }, 28 | }, 29 | Required: []string{"city"}, 30 | } 31 | } 32 | 33 | func (t *WeatherTool) Run(ctx context.Context, args map[string]any) (map[string]any, error) { 34 | city, ok := args["city"].(string) 35 | if !ok { 36 | return nil, fmt.Errorf("city is required") 37 | } 38 | 39 | return map[string]any{ 40 | "message": fmt.Sprintf("The weather in %s is sunny.", city), 41 | }, nil 42 | } 43 | 44 | func main() { 45 | ctx := context.Background() 46 | llmModel, err := gemini.New(ctx, os.Getenv("GEMINI_PROJECT_ID"), os.Getenv("GEMINI_LOCATION")) 47 | if err != nil { 48 | panic(err) 49 | } 50 | 51 | g := gollem.New(llmModel, 52 | gollem.WithResponseMode(gollem.ResponseModeStreaming), 53 | gollem.WithTools(&WeatherTool{}), 54 | gollem.WithMessageHook(func(ctx 
context.Context, msg string) error { 55 | fmt.Printf("%s", msg) 56 | return nil 57 | }), 58 | gollem.WithToolRequestHook(func(ctx context.Context, tool gollem.FunctionCall) error { 59 | fmt.Printf("⚡ Call: %s\n", tool.Name) 60 | return nil 61 | }), 62 | ) 63 | 64 | tmpFile, err := os.CreateTemp("", "gollem-chat-*.txt") 65 | if err != nil { 66 | panic(err) 67 | } 68 | if err := tmpFile.Close(); err != nil { 69 | panic(err) 70 | } 71 | println("history file:", tmpFile.Name()) 72 | 73 | for { 74 | history, err := loadHistory(tmpFile.Name()) 75 | if err != nil { 76 | panic(err) 77 | } 78 | 79 | fmt.Print("> ") 80 | scanner := bufio.NewScanner(os.Stdin) 81 | scanner.Scan() 82 | text := scanner.Text() 83 | 84 | fmt.Printf("🤖 ") 85 | newHistory, err := g.Prompt(ctx, text, gollem.WithHistory(history)) 86 | if err != nil { 87 | panic(err) 88 | } 89 | 90 | if err := dumpHistory(newHistory, tmpFile.Name()); err != nil { 91 | panic(err) 92 | } 93 | 94 | fmt.Printf("\n") 95 | } 96 | } 97 | 98 | func dumpHistory(history *gollem.History, path string) error { 99 | f, err := os.Create(filepath.Clean(path)) 100 | if err != nil { 101 | return err 102 | } 103 | defer f.Close() 104 | 105 | if err := json.NewEncoder(f).Encode(history); err != nil { 106 | return err 107 | } 108 | 109 | return nil 110 | } 111 | 112 | func loadHistory(path string) (*gollem.History, error) { 113 | if st, err := os.Stat(path); err != nil || st.Size() == 0 { 114 | return nil, err 115 | } 116 | 117 | f, err := os.Open(filepath.Clean(path)) 118 | if err != nil { 119 | return nil, err 120 | } 121 | defer f.Close() 122 | 123 | data, err := io.ReadAll(f) 124 | if err != nil { 125 | return nil, err 126 | } 127 | 128 | var history gollem.History 129 | if err := json.Unmarshal(data, &history); err != nil { 130 | return nil, err 131 | } 132 | 133 | return &history, nil 134 | } 135 | -------------------------------------------------------------------------------- /examples/embedding/main.go: 
-------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "os" 7 | 8 | "github.com/m-mizutani/gollem/llm/openai" 9 | ) 10 | 11 | func main() { 12 | ctx := context.Background() 13 | 14 | // Create OpenAI client 15 | client, err := openai.New(ctx, os.Getenv("OPENAI_API_KEY")) 16 | if err != nil { 17 | panic(err) 18 | } 19 | 20 | embedding, err := client.GenerateEmbedding(ctx, 100, []string{"Hello, world!", "This is a test"}) 21 | if err != nil { 22 | panic(err) 23 | } 24 | fmt.Println("embedding:", embedding) 25 | } 26 | -------------------------------------------------------------------------------- /examples/mcp/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "log" 6 | "os" 7 | 8 | "github.com/m-mizutani/gollem" 9 | "github.com/m-mizutani/gollem/llm/openai" 10 | "github.com/m-mizutani/gollem/mcp" 11 | ) 12 | 13 | func main() { 14 | ctx := context.Background() 15 | 16 | // Create OpenAI client 17 | client, err := openai.New(ctx, os.Getenv("OPENAI_API_KEY")) 18 | if err != nil { 19 | panic(err) 20 | } 21 | 22 | // Create MCP client (SSE) 23 | sseClient, err := mcp.NewSSE(ctx, "http://localhost:8080") 24 | if err != nil { 25 | log.Fatalf("Failed to create SSE client: %v", err) 26 | } 27 | defer sseClient.Close() 28 | 29 | // Create MCP client (Stdio) 30 | stdioClient, err := mcp.NewStdio(ctx, "./mcp-server", []string{}, mcp.WithEnvVars([]string{"MCP_ENV=test"})) 31 | if err != nil { 32 | log.Fatalf("Failed to create Stdio client: %v", err) 33 | } 34 | defer stdioClient.Close() 35 | 36 | // Create gollem instance with MCP tools 37 | g := gollem.New(client, 38 | gollem.WithToolSets(sseClient, stdioClient), 39 | ) 40 | 41 | // Send a message 42 | if _, err := g.Prompt(ctx, "Hello, I want to use MCP tools."); err != nil { 43 | panic(err) 44 | } 45 | } 46 | 
--------------------------------------------------------------------------------
/examples/query/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | 	"context"
5 | 	"fmt"
6 | 	"os"
7 |
8 | 	"github.com/m-mizutani/gollem"
9 | 	"github.com/m-mizutani/gollem/llm/claude"
10 | 	"github.com/m-mizutani/gollem/llm/gemini"
11 | 	"github.com/m-mizutani/gollem/llm/openai"
12 | )
13 |
14 | // main sends a single prompt to the selected LLM provider and prints the generated texts.
15 | //
16 | // Usage: go run main.go <gemini|claude|openai> <model> <prompt>
17 | func main() {
18 | 	ctx := context.Background()
19 |
20 | 	if len(os.Args) != 4 {
21 | 		fmt.Println("Usage: go run main.go <gemini|claude|openai> <model> <prompt>")
22 | 		os.Exit(1)
23 | 	}
24 |
25 | 	llmProvider := os.Args[1]
26 | 	model := os.Args[2]
27 | 	prompt := os.Args[3]
28 |
29 | 	var client gollem.LLMClient
30 | 	var err error
31 |
32 | 	switch llmProvider {
33 | 	case "gemini":
34 | 		client, err = gemini.New(ctx, os.Getenv("GEMINI_PROJECT_ID"), os.Getenv("GEMINI_LOCATION"), gemini.WithModel(model))
35 | 	case "claude":
36 | 		client, err = claude.New(ctx, os.Getenv("ANTHROPIC_API_KEY"), claude.WithModel(model))
37 | 	case "openai":
38 | 		client, err = openai.New(ctx, os.Getenv("OPENAI_API_KEY"), openai.WithModel(model))
39 | 	default:
40 | 		// Without this branch client would stay nil and client.NewSession below would panic.
41 | 		err = fmt.Errorf("unsupported LLM provider %q (expected gemini, claude or openai)", llmProvider)
42 | 	}
43 |
44 | 	if err != nil {
45 | 		panic(err)
46 | 	}
47 |
48 | 	ssn, err := client.NewSession(ctx)
49 | 	if err != nil {
50 | 		panic(err)
51 | 	}
52 |
53 | 	result, err := ssn.GenerateContent(ctx, gollem.Text(prompt))
54 | 	if err != nil {
55 | 		panic(err)
56 | 	}
57 |
58 | 	fmt.Println(result.Texts)
59 | }
60 |
--------------------------------------------------------------------------------
/examples/simple/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | 	"bufio"
5 | 	"context"
6 | 	"fmt"
7 | 	"os"
8 |
9 | 	"github.com/m-mizutani/gollem"
10 | 	"github.com/m-mizutani/gollem/llm/openai"
11 | 	"github.com/m-mizutani/gollem/mcp"
12 | )
13 |
14 | // main reads one prompt line from stdin and runs it through an OpenAI-backed agent that can use the tools of a local MCP server.
15 | func main() {
16 | 	ctx := context.Background()
17 |
18 | 	// Create OpenAI client
19 | 	client, err := openai.New(ctx, os.Getenv("OPENAI_API_KEY"))
20 | 	if err != nil {
21 | 		panic(err)
22 | 	}
23 |
24 | 	// Create MCP client with local server
25 | 	mcpLocal, err := mcp.NewStdio(ctx, "./mcp-server", []string{"arg1", "arg2"})
26 | 	if err != nil {
27 | 		panic(err)
28 | 	}
29 | 	defer mcpLocal.Close()
30 |
31 | 	// Create gollem instance
32 | 	agent := gollem.New(client,
33 | 		gollem.WithToolSets(mcpLocal),
34 | 		gollem.WithMessageHook(func(ctx context.Context, msg string) error {
35 | 			fmt.Printf("🤖 %s\n", msg)
36 | 			return nil
37 | 		}),
38 | 	)
39 |
40 | 	fmt.Print("> ")
41 | 	scanner := bufio.NewScanner(os.Stdin)
42 | 	if !scanner.Scan() {
43 | 		// Scan returns false at EOF (Err is nil) or on a read error; either way there is no prompt to send.
44 | 		if err := scanner.Err(); err != nil {
45 | 			panic(err)
46 | 		}
47 | 		return
48 | 	}
49 |
50 | 	if _, err = agent.Prompt(ctx, scanner.Text()); err != nil {
51 | 		panic(err)
52 | 	}
53 | }
54 |
--------------------------------------------------------------------------------
/examples/tools/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | 	"context"
5 | 	"fmt"
6 | 	"log"
7 | 	"os"
8 |
9 | 	"github.com/m-mizutani/gollem"
10 | 	"github.com/m-mizutani/gollem/llm/gemini"
11 | )
12 |
13 | // main asks a Gemini-backed agent to answer an arithmetic query using the add and multiply tools.
14 | func main() {
15 | 	ctx := context.Background()
16 |
17 | 	// Initialize Gemini client
18 | 	client, err := gemini.New(ctx, os.Getenv("GEMINI_PROJECT_ID"), os.Getenv("GEMINI_LOCATION"))
19 | 	if err != nil {
20 | 		log.Fatal(err)
21 | 	}
22 |
23 | 	// Register tools
24 | 	tools := []gollem.Tool{
25 | 		&AddTool{},
26 | 		&MultiplyTool{},
27 | 	}
28 |
29 | 	servant := gollem.New(client,
30 | 		gollem.WithTools(tools...),
31 | 		gollem.WithMessageHook(func(ctx context.Context, msg string) error {
32 | 			log.Printf("Response: %s", msg)
33 | 			return nil
34 | 		}),
35 | 		/*
36 | 			gollem.WithLogger(slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{
37 | 				Level: slog.LevelDebug,
38 | 			}))),
39 | 		*/
40 | 	)
41 |
42 | 	query := "Add 5 and 3, then multiply the result by 2"
43 | 	log.Printf("Query: %s", query)
44 | 	if _, err := servant.Prompt(ctx, query); err != nil {
45 | 		log.Fatal(err)
46 | 	}
47 | }
48 |
49 | // AddTool is a tool that adds two numbers
50 | type AddTool struct{}
51 |
52 | // Run adds the "a" and "b" arguments and returns the sum under the "result" key.
53 | func (t *AddTool) Run(ctx context.Context, args map[string]any) (map[string]any, error) {
54 | 	a, b, err := numberArgs(args)
55 | 	if err != nil {
56 | 		return nil, err
57 | 	}
58 | 	log.Printf("Add: %f + %f", a, b)
59 | 	return map[string]any{"result": a + b}, nil
60 | }
61 |
62 | // Spec declares the "add" tool with two numeric parameters, "a" and "b".
63 | func (t *AddTool) Spec() gollem.ToolSpec {
64 | 	return gollem.ToolSpec{
65 | 		Name:        "add",
66 | 		Description: "Adds two numbers together",
67 | 		Parameters: map[string]*gollem.Parameter{
68 | 			"a": {
69 | 				Type:        "number",
70 | 				Description: "First number",
71 | 			},
72 | 			"b": {
73 | 				Type:        "number",
74 | 				Description: "Second number",
75 | 			},
76 | 		},
77 | 	}
78 | }
79 |
80 | // MultiplyTool is a tool that multiplies two numbers
81 | type MultiplyTool struct{}
82 |
83 | // Run multiplies the "a" and "b" arguments and returns the product under the "result" key.
84 | func (t *MultiplyTool) Run(ctx context.Context, args map[string]any) (map[string]any, error) {
85 | 	a, b, err := numberArgs(args)
86 | 	if err != nil {
87 | 		return nil, err
88 | 	}
89 | 	log.Printf("Multiply: %f * %f", a, b)
90 | 	return map[string]any{"result": a * b}, nil
91 | }
92 |
93 | // Spec declares the "multiply" tool with two numeric parameters, "a" and "b".
94 | func (t *MultiplyTool) Spec() gollem.ToolSpec {
95 | 	return gollem.ToolSpec{
96 | 		Name:        "multiply",
97 | 		Description: "Multiplies two numbers together",
98 | 		Parameters: map[string]*gollem.Parameter{
99 | 			"a": {
100 | 				Type:        "number",
101 | 				Description: "First number",
102 | 			},
103 | 			"b": {
104 | 				Type:        "number",
105 | 				Description: "Second number",
106 | 			},
107 | 		},
108 | 	}
109 | }
110 |
111 | // numberArgs returns the "a" and "b" tool arguments as float64. A missing or non-numeric
112 | // argument is returned as an error rather than panicking on an unchecked type assertion.
113 | // NOTE(review): assumes numbers arrive as float64, the type encoding/json decodes JSON numbers into.
114 | func numberArgs(args map[string]any) (float64, float64, error) {
115 | 	a, ok := args["a"].(float64)
116 | 	if !ok {
117 | 		return 0, 0, fmt.Errorf("argument %q must be a number, got %T", "a", args["a"])
118 | 	}
119 | 	b, ok := args["b"].(float64)
120 | 	if !ok {
121 | 		return 0, 0, fmt.Errorf("argument %q must be a number, got %T", "b", args["b"])
122 | 	}
123 | 	return a, b, nil
124 | }
125 |
-------------------------------------------------------------------------------- /exit.go: -------------------------------------------------------------------------------- 1 | package gollem 2 | 3 | import "context" 4 | 5 | // ExitTool is a tool that can be used to exit the session. IsCompleted() is called before calling a method to generate content every loop. If IsCompleted() returns true, the session will be ended. 6 | type ExitTool interface { 7 | Tool 8 | IsCompleted() bool 9 | Response() string 10 | } 11 | 12 | // DefaultExitTool is the tool to stop the session loop.
This tool is used when the agent determines that the session should be ended. The tool name is "respond_to_user". 13 | type DefaultExitTool struct { 14 | isCompleted bool 15 | response string 16 | } 17 | 18 | func (t *DefaultExitTool) Spec() ToolSpec { 19 | return ToolSpec{ 20 | Name: "respond_to_user", 21 | Description: "Call this tool when you have gathered all necessary information, completed all required actions, and already provided the final answer to the user's original request. This signals that your work on the current request is finished.", 22 | Parameters: map[string]*Parameter{ 23 | /* 24 | "final_answer": { 25 | Type: "string", 26 | Description: "The comprehensive final answer or result for the user's request. If you already provided the final answer, you MUST omit this parameter.", 27 | }, 28 | */ 29 | }, 30 | } 31 | 32 | } 33 | 34 | func (t *DefaultExitTool) Run(ctx context.Context, args map[string]any) (map[string]any, error) { 35 | t.isCompleted = true 36 | 37 | if response, ok := args["final_answer"].(string); ok { 38 | t.response = response 39 | } 40 | 41 | return nil, nil 42 | } 43 | 44 | func (t *DefaultExitTool) IsCompleted() bool { 45 | return t.isCompleted 46 | } 47 | 48 | func (t *DefaultExitTool) Response() string { 49 | return t.response 50 | } 51 | -------------------------------------------------------------------------------- /export_test.go: -------------------------------------------------------------------------------- 1 | package gollem 2 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/m-mizutani/gollem 2 | 3 | go 1.24 4 | 5 | require ( 6 | cloud.google.com/go/aiplatform v1.69.0 7 | cloud.google.com/go/vertexai v0.13.3 8 | github.com/anthropics/anthropic-sdk-go v0.2.0-beta.3 9 | github.com/google/uuid v1.6.0 10 | github.com/m-mizutani/goerr/v2 v2.0.0-beta.2 11 | github.com/m-mizutani/gt 
v0.0.16 12 | github.com/mark3labs/mcp-go v0.23.1 13 | github.com/santhosh-tekuri/jsonschema/v6 v6.0.1 14 | github.com/sashabaranov/go-openai v1.38.2 15 | google.golang.org/api v0.211.0 16 | google.golang.org/protobuf v1.35.2 17 | ) 18 | 19 | require ( 20 | cloud.google.com/go v0.116.0 // indirect 21 | cloud.google.com/go/auth v0.12.1 // indirect 22 | cloud.google.com/go/auth/oauth2adapt v0.2.6 // indirect 23 | cloud.google.com/go/compute/metadata v0.5.2 // indirect 24 | cloud.google.com/go/iam v1.2.2 // indirect 25 | cloud.google.com/go/longrunning v0.6.2 // indirect 26 | github.com/dlclark/regexp2 v1.11.4 // indirect 27 | github.com/felixge/httpsnoop v1.0.4 // indirect 28 | github.com/go-logr/logr v1.4.2 // indirect 29 | github.com/go-logr/stdr v1.2.2 // indirect 30 | github.com/google/go-cmp v0.7.0 // indirect 31 | github.com/google/s2a-go v0.1.8 // indirect 32 | github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect 33 | github.com/googleapis/gax-go/v2 v2.14.0 // indirect 34 | github.com/rogpeppe/go-internal v1.14.1 // indirect 35 | github.com/spf13/cast v1.7.1 // indirect 36 | github.com/stretchr/testify v1.10.0 // indirect 37 | github.com/tidwall/gjson v1.18.0 // indirect 38 | github.com/tidwall/match v1.1.1 // indirect 39 | github.com/tidwall/pretty v1.2.1 // indirect 40 | github.com/tidwall/sjson v1.2.5 // indirect 41 | github.com/yosida95/uritemplate/v3 v3.0.2 // indirect 42 | go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0 // indirect 43 | go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 // indirect 44 | go.opentelemetry.io/otel v1.29.0 // indirect 45 | go.opentelemetry.io/otel/metric v1.29.0 // indirect 46 | go.opentelemetry.io/otel/trace v1.29.0 // indirect 47 | golang.org/x/crypto v0.36.0 // indirect 48 | golang.org/x/net v0.38.0 // indirect 49 | golang.org/x/oauth2 v0.24.0 // indirect 50 | golang.org/x/sync v0.13.0 // indirect 51 | golang.org/x/sys v0.31.0 // indirect 52 | 
golang.org/x/text v0.24.0 // indirect 53 | golang.org/x/time v0.8.0 // indirect 54 | google.golang.org/genproto v0.0.0-20241118233622-e639e219e697 // indirect 55 | google.golang.org/genproto/googleapis/api v0.0.0-20241118233622-e639e219e697 // indirect 56 | google.golang.org/genproto/googleapis/rpc v0.0.0-20241206012308-a4fef0638583 // indirect 57 | google.golang.org/grpc v1.67.3 // indirect 58 | ) 59 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | cloud.google.com/go v0.116.0 h1:B3fRrSDkLRt5qSHWe40ERJvhvnQwdZiHu0bJOpldweE= 2 | cloud.google.com/go v0.116.0/go.mod h1:cEPSRWPzZEswwdr9BxE6ChEn01dWlTaF05LiC2Xs70U= 3 | cloud.google.com/go/aiplatform v1.69.0 h1:XvBzK8e6/6ufbi/i129Vmn/gVqFwbNPmRQ89K+MGlgc= 4 | cloud.google.com/go/aiplatform v1.69.0/go.mod h1:nUsIqzS3khlnWvpjfJbP+2+h+VrFyYsTm7RNCAViiY8= 5 | cloud.google.com/go/auth v0.12.1 h1:n2Bj25BUMM0nvE9D2XLTiImanwZhO3DkfWSYS/SAJP4= 6 | cloud.google.com/go/auth v0.12.1/go.mod h1:BFMu+TNpF3DmvfBO9ClqTR/SiqVIm7LukKF9mbendF4= 7 | cloud.google.com/go/auth/oauth2adapt v0.2.6 h1:V6a6XDu2lTwPZWOawrAa9HUK+DB2zfJyTuciBG5hFkU= 8 | cloud.google.com/go/auth/oauth2adapt v0.2.6/go.mod h1:AlmsELtlEBnaNTL7jCj8VQFLy6mbZv0s4Q7NGBeQ5E8= 9 | cloud.google.com/go/compute/metadata v0.5.2 h1:UxK4uu/Tn+I3p2dYWTfiX4wva7aYlKixAHn3fyqngqo= 10 | cloud.google.com/go/compute/metadata v0.5.2/go.mod h1:C66sj2AluDcIqakBq/M8lw8/ybHgOZqin2obFxa/E5k= 11 | cloud.google.com/go/iam v1.2.2 h1:ozUSofHUGf/F4tCNy/mu9tHLTaxZFLOUiKzjcgWHGIA= 12 | cloud.google.com/go/iam v1.2.2/go.mod h1:0Ys8ccaZHdI1dEUilwzqng/6ps2YB6vRsjIe00/+6JY= 13 | cloud.google.com/go/longrunning v0.6.2 h1:xjDfh1pQcWPEvnfjZmwjKQEcHnpz6lHjfy7Fo0MK+hc= 14 | cloud.google.com/go/longrunning v0.6.2/go.mod h1:k/vIs83RN4bE3YCswdXC5PFfWVILjm3hpEUlSko4PiI= 15 | cloud.google.com/go/vertexai v0.13.3 h1:pbw1KfpdE8ZDrXxBKcIsS/j+EixyQRsyu6gxRkXq8/k= 16 | 
cloud.google.com/go/vertexai v0.13.3/go.mod h1:AxzUNrd36yhfOZedO+Y1v0ajVgGKOdv1njeQChL8IFY= 17 | github.com/anthropics/anthropic-sdk-go v0.2.0-beta.3 h1:b5t1ZJMvV/l99y4jbz7kRFdUp3BSDkI8EhSlHczivtw= 18 | github.com/anthropics/anthropic-sdk-go v0.2.0-beta.3/go.mod h1:AapDW22irxK2PSumZiQXYUFvsdQgkwIWlpESweWZI/c= 19 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 20 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 21 | github.com/dlclark/regexp2 v1.11.4 h1:rPYF9/LECdNymJufQKmri9gV604RvvABwgOA8un7yAo= 22 | github.com/dlclark/regexp2 v1.11.4/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= 23 | github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= 24 | github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= 25 | github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= 26 | github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= 27 | github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= 28 | github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= 29 | github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= 30 | github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= 31 | github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= 32 | github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= 33 | github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= 34 | github.com/google/s2a-go v0.1.8 h1:zZDs9gcbt9ZPLV0ndSyQk6Kacx2g/X+SKYovpnz3SMM= 35 | github.com/google/s2a-go v0.1.8/go.mod h1:6iNWHTpQ+nfNRN5E00MSdfDwVesa8hhS32PhPO8deJA= 36 | github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= 37 | github.com/google/uuid v1.6.0/go.mod 
h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= 38 | github.com/googleapis/enterprise-certificate-proxy v0.3.4 h1:XYIDZApgAnrN1c855gTgghdIA6Stxb52D5RnLI1SLyw= 39 | github.com/googleapis/enterprise-certificate-proxy v0.3.4/go.mod h1:YKe7cfqYXjKGpGvmSg28/fFvhNzinZQm8DGnaburhGA= 40 | github.com/googleapis/gax-go/v2 v2.14.0 h1:f+jMrjBPl+DL9nI4IQzLUxMq7XrAqFYB7hBPqMNIe8o= 41 | github.com/googleapis/gax-go/v2 v2.14.0/go.mod h1:lhBCnjdLrWRaPvLWhmc8IS24m9mr07qSYnHncrgo+zk= 42 | github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= 43 | github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= 44 | github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= 45 | github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= 46 | github.com/m-mizutani/goerr/v2 v2.0.0-beta.2 h1:3aA8RKFlWS6kX99+AJohdgIxiE5+Qwv6aIUD1QQXS2E= 47 | github.com/m-mizutani/goerr/v2 v2.0.0-beta.2/go.mod h1:K9+tBb0+0681yHpyHYqJmoijxn9qJf5grjRb9nMXevc= 48 | github.com/m-mizutani/gt v0.0.16 h1:bhJqMeqxojsgVBo9wqjPHk4s6o7mkaQLCcOus1hp1Vs= 49 | github.com/m-mizutani/gt v0.0.16/go.mod h1:0MPYSfGBLmYjTduzADVmIqD58ELQ5IfBFiK/f0FmB3k= 50 | github.com/mark3labs/mcp-go v0.23.1 h1:RzTzZ5kJ+HxwnutKA4rll8N/pKV6Wh5dhCmiJUu5S9I= 51 | github.com/mark3labs/mcp-go v0.23.1/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4= 52 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 53 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 54 | github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= 55 | github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= 56 | github.com/santhosh-tekuri/jsonschema/v6 v6.0.1 h1:PKK9DyHxif4LZo+uQSgXNqs0jj5+xZwwfKHgph2lxBw= 57 | github.com/santhosh-tekuri/jsonschema/v6 v6.0.1/go.mod h1:JXeL+ps8p7/KNMjDQk3TCwPpBy0wYklyWTfbkIzdIFU= 58 | github.com/sashabaranov/go-openai 
v1.38.2 h1:akrssjj+6DY3lWuDwHv6cBvJ8Z+FZDM9XEaaYFt0Auo= 59 | github.com/sashabaranov/go-openai v1.38.2/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= 60 | github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y= 61 | github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= 62 | github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= 63 | github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= 64 | github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= 65 | github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= 66 | github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= 67 | github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= 68 | github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= 69 | github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= 70 | github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= 71 | github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= 72 | github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= 73 | github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= 74 | github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= 75 | github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= 76 | go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0 h1:r6I7RJCN86bpD/FQwedZ0vSixDpwuWREjW9oRMsmqDc= 77 | go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0/go.mod h1:B9yO6b04uB80CzjedvewuqDhxJxi11s7/GtiGa8bAjI= 78 | go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 
h1:TT4fX+nBOA/+LUkobKGW1ydGcn+G3vRw9+g5HwCphpk= 79 | go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0/go.mod h1:L7UH0GbB0p47T4Rri3uHjbpCFYrVrwc1I25QhNPiGK8= 80 | go.opentelemetry.io/otel v1.29.0 h1:PdomN/Al4q/lN6iBJEN3AwPvUiHPMlt93c8bqTG5Llw= 81 | go.opentelemetry.io/otel v1.29.0/go.mod h1:N/WtXPs1CNCUEx+Agz5uouwCba+i+bJGFicT8SR4NP8= 82 | go.opentelemetry.io/otel/metric v1.29.0 h1:vPf/HFWTNkPu1aYeIsc98l4ktOQaL6LeSoeV2g+8YLc= 83 | go.opentelemetry.io/otel/metric v1.29.0/go.mod h1:auu/QWieFVWx+DmQOUMgj0F8LHWdgalxXqvp7BII/W8= 84 | go.opentelemetry.io/otel/trace v1.29.0 h1:J/8ZNK4XgR7a21DZUAsbF8pZ5Jcw1VhACmnYt39JTi4= 85 | go.opentelemetry.io/otel/trace v1.29.0/go.mod h1:eHl3w0sp3paPkYstJOmAimxhiFXPg+MMTlEh3nsQgWQ= 86 | golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= 87 | golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= 88 | golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= 89 | golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= 90 | golang.org/x/oauth2 v0.24.0 h1:KTBBxWqUa0ykRPLtV69rRto9TLXcqYkeswu48x/gvNE= 91 | golang.org/x/oauth2 v0.24.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= 92 | golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610= 93 | golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= 94 | golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= 95 | golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= 96 | golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0= 97 | golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU= 98 | golang.org/x/time v0.8.0 h1:9i3RxcPv3PZnitoVGMPDKZSq1xW1gK1Xy3ArNOGZfEg= 99 | golang.org/x/time v0.8.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= 100 | google.golang.org/api v0.211.0 h1:IUpLjq09jxBSV1lACO33CGY3jsRcbctfGzhj+ZSE/Bg= 101 | 
google.golang.org/api v0.211.0/go.mod h1:XOloB4MXFH4UTlQSGuNUxw0UT74qdENK8d6JNsXKLi0= 102 | google.golang.org/genproto v0.0.0-20241118233622-e639e219e697 h1:ToEetK57OidYuqD4Q5w+vfEnPvPpuTwedCNVohYJfNk= 103 | google.golang.org/genproto v0.0.0-20241118233622-e639e219e697/go.mod h1:JJrvXBWRZaFMxBufik1a4RpFw4HhgVtBBWQeQgUj2cc= 104 | google.golang.org/genproto/googleapis/api v0.0.0-20241118233622-e639e219e697 h1:pgr/4QbFyktUv9CtQ/Fq4gzEE6/Xs7iCXbktaGzLHbQ= 105 | google.golang.org/genproto/googleapis/api v0.0.0-20241118233622-e639e219e697/go.mod h1:+D9ySVjN8nY8YCVjc5O7PZDIdZporIDY3KaGfJunh88= 106 | google.golang.org/genproto/googleapis/rpc v0.0.0-20241206012308-a4fef0638583 h1:IfdSdTcLFy4lqUQrQJLkLt1PB+AsqVz6lwkWPzWEz10= 107 | google.golang.org/genproto/googleapis/rpc v0.0.0-20241206012308-a4fef0638583/go.mod h1:5uTbfoYQed2U9p3KIj2/Zzm02PYhndfdmML0qC3q3FU= 108 | google.golang.org/grpc v1.67.3 h1:OgPcDAFKHnH8X3O4WcO4XUc8GRDeKsKReqbQtiCj7N8= 109 | google.golang.org/grpc v1.67.3/go.mod h1:YGaHCc6Oap+FzBJTZLBzkGSYt/cvGPFTPxkn7QfSU8s= 110 | google.golang.org/protobuf v1.35.2 h1:8Ar7bF+apOIoThw1EdZl0p1oWvMqTHmpA2fRTyZO8io= 111 | google.golang.org/protobuf v1.35.2/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= 112 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 113 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 114 | -------------------------------------------------------------------------------- /history.go: -------------------------------------------------------------------------------- 1 | // Package gollem provides a unified interface for interacting with various LLM services. 
2 | package gollem
3 |
4 | import (
5 | 	"cloud.google.com/go/vertexai/genai"
6 | 	"github.com/anthropics/anthropic-sdk-go"
7 | 	"github.com/anthropics/anthropic-sdk-go/packages/param"
8 | 	"github.com/m-mizutani/goerr/v2"
9 | 	"github.com/sashabaranov/go-openai"
10 | )
11 |
12 | // llmType identifies which LLM provider's message format a History holds.
13 | type llmType string
14 |
15 | const (
16 | 	llmTypeOpenAI llmType = "OpenAI"
17 | 	llmTypeGemini llmType = "gemini"
18 | 	llmTypeClaude llmType = "claude"
19 | )
20 |
21 | // HistoryVersion is the History format version written by the NewHistoryFrom* constructors.
22 | // The To* conversion methods reject a History whose Version differs from it.
23 | const (
24 | 	HistoryVersion = 1
25 | )
26 |
27 | // History represents a conversation history that can be used across different LLM sessions.
28 | // It stores messages in a format specific to each LLM type (OpenAI, Claude, or Gemini); the
29 | // NewHistoryFrom* constructors populate only the field that matches LLType.
30 | //
31 | // For detailed documentation, see doc/history.md
32 | type History struct {
33 | 	LLType  llmType `json:"type"`
34 | 	Version int     `json:"version"`
35 |
36 | 	Claude []claudeMessage                `json:"claude,omitempty"`
37 | 	OpenAI []openai.ChatCompletionMessage `json:"OpenAI,omitempty"`
38 | 	Gemini []geminiMessage                `json:"gemini,omitempty"`
39 | }
40 |
41 | // ToCount returns the total number of stored messages. It is safe to call on a nil History.
42 | func (x *History) ToCount() int {
43 | 	if x == nil {
44 | 		return 0
45 | 	}
46 | 	return len(x.Claude) + len(x.OpenAI) + len(x.Gemini)
47 | }
48 |
49 | // ToGemini converts the history back into Gemini contents. It returns an error wrapping
50 | // ErrHistoryVersionMismatch or ErrLLMTypeMismatch if the history has another format version
51 | // or was recorded for another provider.
52 | func (x *History) ToGemini() ([]*genai.Content, error) {
53 | 	if x.Version != HistoryVersion {
54 | 		return nil, goerr.Wrap(ErrHistoryVersionMismatch, "history version is not supported", goerr.V("expected", HistoryVersion), goerr.V("actual", x.Version))
55 | 	}
56 | 	if x.LLType != llmTypeGemini {
57 | 		return nil, goerr.Wrap(ErrLLMTypeMismatch, "history is not gemini", goerr.V("expected", llmTypeGemini), goerr.V("actual", x.LLType))
58 | 	}
59 | 	return toGeminiMessages(x.Gemini)
60 | }
61 |
62 | // ToClaude converts the history back into Claude message params. It returns an error wrapping
63 | // ErrHistoryVersionMismatch or ErrLLMTypeMismatch if the history has another format version
64 | // or was recorded for another provider.
65 | func (x *History) ToClaude() ([]anthropic.MessageParam, error) {
66 | 	if x.Version != HistoryVersion {
67 | 		return nil, goerr.Wrap(ErrHistoryVersionMismatch, "history version is not supported", goerr.V("expected", HistoryVersion), goerr.V("actual", x.Version))
68 | 	}
69 | 	if x.LLType != llmTypeClaude {
70 | 		return nil, goerr.Wrap(ErrLLMTypeMismatch, "history is not claude", goerr.V("expected", llmTypeClaude), goerr.V("actual", x.LLType))
71 | 	}
72 | 	return toClaudeMessages(x.Claude)
73 | }
74 |
75 | // ToOpenAI returns the stored OpenAI chat messages. It returns an error wrapping
76 | // ErrHistoryVersionMismatch or ErrLLMTypeMismatch if the history has another format version
77 | // or was recorded for another provider.
78 | func (x *History) ToOpenAI() ([]openai.ChatCompletionMessage, error) {
79 | 	if x.Version != HistoryVersion {
80 | 		return nil, goerr.Wrap(ErrHistoryVersionMismatch, "history version is not supported", goerr.V("expected", HistoryVersion), goerr.V("actual", x.Version))
81 | 	}
82 | 	if x.LLType != llmTypeOpenAI {
83 | 		return nil, goerr.Wrap(ErrLLMTypeMismatch, "history is not OpenAI", goerr.V("expected", llmTypeOpenAI), goerr.V("actual", x.LLType))
84 | 	}
85 | 	return x.OpenAI, nil
86 | }
87 |
88 | // claudeMessage is the JSON-serializable form of anthropic.MessageParam.
89 | type claudeMessage struct {
90 | 	Role    anthropic.MessageParamRole `json:"role"`
91 | 	Content []claudeContentBlock       `json:"content"`
92 | }
93 |
94 | // claudeContentBlock is the JSON-serializable form of one Claude content block. Type is
95 | // "text", "image", "tool_use" or "tool_result" and selects which pointer field is set.
96 | type claudeContentBlock struct {
97 | 	Type       string             `json:"type"`
98 | 	Text       *string            `json:"text,omitempty"`
99 | 	Source     *claudeImageSource `json:"source,omitempty"`
100 | 	ToolUse    *claudeToolUse     `json:"tool_use,omitempty"`
101 | 	ToolResult *claudeToolResult  `json:"tool_result,omitempty"`
102 | }
103 |
104 | type claudeImageSource struct {
105 | 	Type      string `json:"type"`
106 | 	MediaType string `json:"media_type,omitempty"`
107 | 	Data      string `json:"data"`
108 | }
109 |
110 | type claudeToolUse struct {
111 | 	ID    string `json:"id"`
112 | 	Name  string `json:"name"`
113 | 	Input any    `json:"input"`
114 | 	Type  string `json:"type"`
115 | }
116 |
117 | // claudeToolResult stores a tool result; only a single text value is kept as Content.
118 | type claudeToolResult struct {
119 | 	ToolUseID string          `json:"tool_use_id"`
120 | 	Content   string          `json:"content"`
121 | 	IsError   param.Opt[bool] `json:"is_error"`
122 | }
123 |
124 | // geminiMessage is the JSON-serializable form of genai.Content.
125 | type geminiMessage struct {
126 | 	Role  string       `json:"role"`
127 | 	Parts []geminiPart `json:"parts"`
128 | }
129 |
130 | // geminiPart is the JSON-serializable form of one genai.Part. Type is "text", "blob",
131 | // "file_data", "function_call" or "function_response" and selects which fields are used.
132 | type geminiPart struct {
133 | 	Type     string                 `json:"type"`
134 | 	Text     string                 `json:"text,omitempty"`
135 | 	MIMEType string                 `json:"mime_type,omitempty"`
136 | 	Data     []byte                 `json:"data,omitempty"`
137 | 	FileURI  string                 `json:"file_uri,omitempty"`
138 | 	Name     string                 `json:"name,omitempty"`
139 | 	Args     map[string]interface{} `json:"args,omitempty"`
140 | 	Response map[string]interface{} `json:"response,omitempty"`
141 | }
142 |
143 | // NewHistoryFromOpenAI wraps OpenAI chat messages in a History. The slice is stored as-is, not copied.
144 | func NewHistoryFromOpenAI(messages []openai.ChatCompletionMessage) *History {
145 | 	return &History{
146 | 		LLType:  llmTypeOpenAI,
147 | 		Version: HistoryVersion,
148 | 		OpenAI:  messages,
149 | 	}
150 | }
151 |
152 | // NewHistoryFromClaude converts Claude message params into a History. Text, base64 image,
153 | // tool_use and tool_result blocks are kept; any other block (including a non-base64 image)
154 | // is stored as an empty placeholder, which toClaudeMessages skips on restore.
155 | func NewHistoryFromClaude(messages []anthropic.MessageParam) *History {
156 | 	claudeMessages := make([]claudeMessage, len(messages))
157 | 	for i, msg := range messages {
158 | 		content := make([]claudeContentBlock, len(msg.Content))
159 | 		for j, c := range msg.Content {
160 | 			switch {
161 | 			case c.OfRequestTextBlock != nil:
162 | 				content[j] = claudeContentBlock{
163 | 					Type: "text",
164 | 					Text: &c.OfRequestTextBlock.Text,
165 | 				}
166 | 			case c.OfRequestImageBlock != nil:
167 | 				if src := c.OfRequestImageBlock.Source.OfBase64ImageSource; src != nil {
168 | 					content[j] = claudeContentBlock{
169 | 						Type: "image",
170 | 						Source: &claudeImageSource{
171 | 							Type:      string(src.Type),
172 | 							MediaType: string(src.MediaType),
173 | 							Data:      src.Data,
174 | 						},
175 | 					}
176 | 				}
177 | 			case c.OfRequestToolUseBlock != nil:
178 | 				content[j] = claudeContentBlock{
179 | 					Type: "tool_use",
180 | 					ToolUse: &claudeToolUse{
181 | 						ID:    c.OfRequestToolUseBlock.ID,
182 | 						Name:  c.OfRequestToolUseBlock.Name,
183 | 						Input: c.OfRequestToolUseBlock.Input,
184 | 						Type:  string(c.OfRequestToolUseBlock.Type),
185 | 					},
186 | 				}
187 | 			case c.OfRequestToolResultBlock != nil:
188 | 				content[j] = claudeContentBlock{
189 | 					Type: "tool_result",
190 | 					ToolResult: &claudeToolResult{
191 | 						ToolUseID: c.OfRequestToolResultBlock.ToolUseID,
192 | 						Content:   claudeToolResultText(c.OfRequestToolResultBlock.Content),
193 | 						IsError:   c.OfRequestToolResultBlock.IsError,
194 | 					},
195 | 				}
196 | 			}
197 | 		}
198 | 		claudeMessages[i] = claudeMessage{
199 | 			Role:    msg.Role,
200 | 			Content: content,
201 | 		}
202 | 	}
203 |
204 | 	return &History{
205 | 		LLType:  llmTypeClaude,
206 | 		Version: HistoryVersion,
207 | 		Claude:  claudeMessages,
208 | 	}
209 | }
210 |
211 | // claudeToolResultText returns the text of the first text block in a tool result's content,
212 | // or "" if there is none. Indexing Content[0] directly would panic on an empty result and
213 | // dereference nil when the first block is not text (e.g. an image).
214 | func claudeToolResultText(blocks []anthropic.ToolResultBlockParamContentUnion) string {
215 | 	for _, b := range blocks {
216 | 		if b.OfRequestTextBlock != nil {
217 | 			return b.OfRequestTextBlock.Text
218 | 		}
219 | 	}
220 | 	return ""
221 | }
222 |
223 | // toClaudeMessages converts stored Claude messages back into message params. Blocks with an
224 | // unrecognized Type (such as placeholders left by NewHistoryFromClaude) and non-base64 images
225 | // are skipped; a block whose Type-specific field is missing is reported as an error.
226 | func toClaudeMessages(messages []claudeMessage) ([]anthropic.MessageParam, error) {
227 | 	converted := make([]anthropic.MessageParam, len(messages))
228 |
229 | 	for i, msg := range messages {
230 | 		content := make([]anthropic.ContentBlockParamUnion, 0, len(msg.Content))
231 | 		for _, c := range msg.Content {
232 | 			switch c.Type {
233 | 			case "text":
234 | 				if c.Text == nil {
235 | 					return nil, goerr.New("text block has no text field")
236 | 				}
237 | 				content = append(content, anthropic.ContentBlockParamUnion{
238 | 					OfRequestTextBlock: &anthropic.TextBlockParam{
239 | 						Text: *c.Text,
240 | 						Type: "text",
241 | 					},
242 | 				})
243 |
244 | 			case "image":
245 | 				if c.Source == nil {
246 | 					return nil, goerr.New("image block has no source field")
247 | 				}
248 | 				if c.Source.Type == "base64" {
249 | 					content = append(content, anthropic.ContentBlockParamUnion{
250 | 						OfRequestImageBlock: &anthropic.ImageBlockParam{
251 | 							Source: anthropic.ImageBlockParamSourceUnion{
252 | 								OfBase64ImageSource: &anthropic.Base64ImageSourceParam{
253 | 									Type:      "base64",
254 | 									MediaType: anthropic.Base64ImageSourceMediaType(c.Source.MediaType),
255 | 									Data:      c.Source.Data,
256 | 								},
257 | 							},
258 | 							Type: "image",
259 | 						},
260 | 					})
261 | 				}
262 |
263 | 			case "tool_use":
264 | 				if c.ToolUse == nil {
265 | 					return nil, goerr.New("tool_use block has no tool_use field")
266 | 				}
267 | 				content = append(content, anthropic.ContentBlockParamUnion{
268 | 					OfRequestToolUseBlock: &anthropic.ToolUseBlockParam{
269 | 						ID:    c.ToolUse.ID,
270 | 						Name:  c.ToolUse.Name,
271 | 						Input: c.ToolUse.Input,
272 | 						Type:  "tool_use",
273 | 					},
274 | 				})
275 |
276 | 			case "tool_result":
277 | 				if c.ToolResult == nil {
278 | 					return nil, goerr.New("tool_result block has no tool_result field")
279 | 				}
280 | 				content = append(content, anthropic.ContentBlockParamUnion{
281 | 					OfRequestToolResultBlock: &anthropic.ToolResultBlockParam{
282 | 						ToolUseID: c.ToolResult.ToolUseID,
283 | 						Content: []anthropic.ToolResultBlockParamContentUnion{
284 | 							{
285 | 								OfRequestTextBlock: &anthropic.TextBlockParam{
286 | 									Text: c.ToolResult.Content,
287 | 									Type: "text",
288 | 								},
289 | 							},
290 | 						},
291 | 						IsError: param.NewOpt(c.ToolResult.IsError.Value),
292 | 					},
293 | 				})
294 | 			}
295 | 		}
296 | 		converted[i] = anthropic.MessageParam{
297 | 			Role:    msg.Role,
298 | 			Content: content,
299 | 		}
300 | 	}
301 |
302 | 	return converted, nil
303 | }
304 |
305 | // NewHistoryFromGemini converts Gemini contents into a History. Text, Blob, FileData,
306 | // FunctionCall and FunctionResponse parts are kept; any other part type is stored as an
307 | // empty placeholder, which toGeminiMessages skips on restore.
308 | func NewHistoryFromGemini(messages []*genai.Content) *History {
309 | 	converted := make([]geminiMessage, len(messages))
310 | 	for i, msg := range messages {
311 | 		parts := make([]geminiPart, len(msg.Parts))
312 | 		for j, p := range msg.Parts {
313 | 			switch v := p.(type) {
314 | 			case genai.Text:
315 | 				parts[j] = geminiPart{
316 | 					Type: "text",
317 | 					Text: string(v),
318 | 				}
319 | 			case genai.Blob:
320 | 				parts[j] = geminiPart{
321 | 					Type:     "blob",
322 | 					MIMEType: v.MIMEType,
323 | 					Data:     v.Data,
324 | 				}
325 | 			case genai.FileData:
326 | 				parts[j] = geminiPart{
327 | 					Type:     "file_data",
328 | 					MIMEType: v.MIMEType,
329 | 					FileURI:  v.FileURI,
330 | 				}
331 | 			case genai.FunctionCall:
332 | 				parts[j] = geminiPart{
333 | 					Type: "function_call",
334 | 					Name: v.Name,
335 | 					Args: v.Args,
336 | 				}
337 | 			case genai.FunctionResponse:
338 | 				parts[j] = geminiPart{
339 | 					Type:     "function_response",
340 | 					Name:     v.Name,
341 | 					Response: v.Response,
342 | 				}
343 | 			}
344 | 		}
345 | 		converted[i] = geminiMessage{
346 | 			Role:  msg.Role,
347 | 			Parts: parts,
348 | 		}
349 | 	}
350 | 	return &History{
351 | 		LLType:  llmTypeGemini,
352 | 		Version: HistoryVersion,
353 | 		Gemini:  converted,
354 | 	}
355 | }
356 |
357 | // toGeminiMessages converts stored Gemini messages back into contents. Parts with an
358 | // unrecognized Type (such as placeholders left by NewHistoryFromGemini) are skipped: leaving
359 | // them as nil genai.Part entries would hand the Gemini client a nil interface value.
360 | func toGeminiMessages(messages []geminiMessage) ([]*genai.Content, error) {
361 | 	converted := make([]*genai.Content, len(messages))
362 | 	for i, msg := range messages {
363 | 		parts := make([]genai.Part, 0, len(msg.Parts))
364 | 		for _, p := range msg.Parts {
365 | 			switch p.Type {
366 | 			case "text":
367 | 				parts = append(parts, genai.Text(p.Text))
368 | 			case "blob":
369 | 				parts = append(parts, genai.Blob{
370 | 					MIMEType: p.MIMEType,
371 | 					Data:     p.Data,
372 | 				})
373 | 			case "file_data":
374 | 				parts = append(parts, genai.FileData{
375 | 					MIMEType: p.MIMEType,
376 | 					FileURI:  p.FileURI,
377 | 				})
378 | 			case "function_call":
379 | 				parts = append(parts, genai.FunctionCall{
380 | 					Name: p.Name,
381 | 					Args: p.Args,
382 | 				})
383 | 			case "function_response":
384 | 				parts = append(parts, genai.FunctionResponse{
385 | 					Name:     p.Name,
386 | 					Response: p.Response,
387 | 				})
388 | 			}
389 | 		}
390 | 		converted[i] = &genai.Content{
391 | 			Role:  msg.Role,
392 | 			Parts: parts,
393 | 		}
394 | 	}
395 | 	return converted, nil
396 | }
397 |
-------------------------------------------------------------------------------- /history_test.go: -------------------------------------------------------------------------------- 1 | package gollem_test 2 | 3 | import ( 4 | "encoding/json" 5 | "testing" 6 | 7 | "cloud.google.com/go/vertexai/genai" 8 | "github.com/anthropics/anthropic-sdk-go" 9 | "github.com/anthropics/anthropic-sdk-go/packages/param" 10 | "github.com/m-mizutani/gollem" 11 | "github.com/m-mizutani/gt" 12 | "github.com/sashabaranov/go-openai" 13 | ) 14 | 15 | func TestHistoryOpenAI(t *testing.T) { 16 | // Create OpenAI messages with various content types 17 | messages := []openai.ChatCompletionMessage{ 18 | { 19 | Role: "system", 20 | Content: "You are a helpful assistant.", 21 | }, 22 | { 23 | Role: "user", 24 | Content: "Hello", 25 | }, 26 | { 27 | Role: "assistant", 28 | Content: "Hi, how can I help you?", 29 | }, 30 | { 31 | Role: "user", 32 | Content: "What's the weather like?", 33 | }, 34 | { 35 | Role: "assistant", 36 | Content: "", 37 | FunctionCall: &openai.FunctionCall{ 38 | Name: "get_weather", 39 | Arguments: `{"location": "Tokyo"}`, 40 | }, 41 | }, 42 | { 43 | Role: "tool", 44 | Name: "get_weather", 45 | Content: `{"temperature": 25,
"condition": "sunny"}`, 46 | }, 47 | { 48 | Role: "assistant", 49 | Content: "The weather in Tokyo is sunny with a temperature of 25°C.", 50 | }, 51 | } 52 | 53 | // Create History object 54 | history := gollem.NewHistoryFromOpenAI(messages) 55 | 56 | // Convert to JSON 57 | data, err := json.Marshal(history) 58 | gt.NoError(t, err) 59 | 60 | // Restore from JSON 61 | var restored gollem.History 62 | gt.NoError(t, json.Unmarshal(data, &restored)) 63 | 64 | restoredMessages, err := restored.ToOpenAI() 65 | gt.NoError(t, err) 66 | 67 | gt.Equal(t, messages, restoredMessages) 68 | 69 | // Validate specific message types 70 | gt.Equal(t, "system", restoredMessages[0].Role) 71 | gt.Equal(t, "You are a helpful assistant.", restoredMessages[0].Content) 72 | 73 | gt.Equal(t, "assistant", restoredMessages[2].Role) 74 | gt.Equal(t, "Hi, how can I help you?", restoredMessages[2].Content) 75 | 76 | gt.Equal(t, "assistant", restoredMessages[4].Role) 77 | gt.Equal(t, "", restoredMessages[4].Content) 78 | gt.Equal(t, "get_weather", restoredMessages[4].FunctionCall.Name) 79 | gt.Equal(t, `{"location": "Tokyo"}`, restoredMessages[4].FunctionCall.Arguments) 80 | 81 | gt.Equal(t, "tool", restoredMessages[5].Role) 82 | gt.Equal(t, "get_weather", restoredMessages[5].Name) 83 | gt.Equal(t, `{"temperature": 25, "condition": "sunny"}`, restoredMessages[5].Content) 84 | } 85 | 86 | func TestHistoryClaude(t *testing.T) { 87 | // Create Claude messages with various content types 88 | messages := []anthropic.MessageParam{ 89 | { 90 | Role: anthropic.MessageParamRoleUser, 91 | Content: []anthropic.ContentBlockParamUnion{ 92 | { 93 | OfRequestTextBlock: &anthropic.TextBlockParam{ 94 | Text: "Hello", 95 | Type: "text", 96 | }, 97 | }, 98 | { 99 | OfRequestImageBlock: &anthropic.ImageBlockParam{ 100 | Source: anthropic.ImageBlockParamSourceUnion{ 101 | OfBase64ImageSource: &anthropic.Base64ImageSourceParam{ 102 | Type: "base64", 103 | MediaType: "image/jpeg", 104 | Data: "base64encodedimage", 105 
| }, 106 | }, 107 | Type: "image", 108 | }, 109 | }, 110 | }, 111 | }, 112 | { 113 | Role: anthropic.MessageParamRoleAssistant, 114 | Content: []anthropic.ContentBlockParamUnion{ 115 | { 116 | OfRequestTextBlock: &anthropic.TextBlockParam{ 117 | Text: "Hi, how can I help you?", 118 | Type: "text", 119 | }, 120 | }, 121 | }, 122 | }, 123 | { 124 | Role: anthropic.MessageParamRoleUser, 125 | Content: []anthropic.ContentBlockParamUnion{ 126 | { 127 | OfRequestTextBlock: &anthropic.TextBlockParam{ 128 | Text: "What's the weather like?", 129 | Type: "text", 130 | }, 131 | }, 132 | }, 133 | }, 134 | { 135 | Role: anthropic.MessageParamRoleAssistant, 136 | Content: []anthropic.ContentBlockParamUnion{ 137 | { 138 | OfRequestToolUseBlock: &anthropic.ToolUseBlockParam{ 139 | ID: "tool_1", 140 | Name: "get_weather", 141 | Input: `{"location": "Tokyo"}`, 142 | Type: "tool_use", 143 | }, 144 | }, 145 | }, 146 | }, 147 | { 148 | Role: anthropic.MessageParamRoleUser, 149 | Content: []anthropic.ContentBlockParamUnion{ 150 | { 151 | OfRequestToolResultBlock: &anthropic.ToolResultBlockParam{ 152 | ToolUseID: "tool_2", 153 | Content: []anthropic.ToolResultBlockParamContentUnion{ 154 | { 155 | OfRequestTextBlock: &anthropic.TextBlockParam{ 156 | Text: `{"temperature": 30, "condition": "cloudy"}`, 157 | Type: "text", 158 | }, 159 | }, 160 | }, 161 | IsError: param.NewOpt(false), 162 | Type: "tool_result", 163 | }, 164 | }, 165 | }, 166 | }, 167 | { 168 | Role: anthropic.MessageParamRoleAssistant, 169 | Content: []anthropic.ContentBlockParamUnion{ 170 | { 171 | OfRequestTextBlock: &anthropic.TextBlockParam{ 172 | Text: "Second message", 173 | Type: "text", 174 | }, 175 | }, 176 | }, 177 | }, 178 | { 179 | Role: anthropic.MessageParamRoleUser, 180 | Content: []anthropic.ContentBlockParamUnion{ 181 | { 182 | OfRequestToolResultBlock: &anthropic.ToolResultBlockParam{ 183 | ToolUseID: "tool_3", 184 | Content: []anthropic.ToolResultBlockParamContentUnion{ 185 | { 186 | OfRequestTextBlock: 
&anthropic.TextBlockParam{ 187 | Text: `{"temperature": 35, "condition": "rainy"}`, 188 | Type: "text", 189 | }, 190 | }, 191 | }, 192 | IsError: param.NewOpt(false), 193 | Type: "tool_result", 194 | }, 195 | }, 196 | }, 197 | }, 198 | { 199 | Role: anthropic.MessageParamRoleAssistant, 200 | Content: []anthropic.ContentBlockParamUnion{ 201 | { 202 | OfRequestTextBlock: &anthropic.TextBlockParam{ 203 | Text: "The weather in Tokyo is sunny with a temperature of 25°C.", 204 | Type: "text", 205 | }, 206 | }, 207 | }, 208 | }, 209 | } 210 | 211 | // Create History object 212 | history := gollem.NewHistoryFromClaude(messages) 213 | 214 | // Convert to JSON 215 | data, err := json.Marshal(history) 216 | gt.NoError(t, err) 217 | 218 | // Restore from JSON 219 | var restored gollem.History 220 | gt.NoError(t, json.Unmarshal(data, &restored)) 221 | 222 | restoredMessages, err := restored.ToClaude() 223 | gt.NoError(t, err) 224 | 225 | // Compare each message individually to make debugging easier 226 | for i := range messages { 227 | gt.Value(t, restoredMessages[i].Role).Equal(messages[i].Role) 228 | gt.Value(t, len(restoredMessages[i].Content)).Equal(len(messages[i].Content)) 229 | 230 | for j := range messages[i].Content { 231 | if messages[i].Content[j].OfRequestToolResultBlock != nil { 232 | gt.Value(t, restoredMessages[i].Content[j].OfRequestToolResultBlock.ToolUseID).Equal(messages[i].Content[j].OfRequestToolResultBlock.ToolUseID) 233 | gt.Value(t, restoredMessages[i].Content[j].OfRequestToolResultBlock.IsError).Equal(messages[i].Content[j].OfRequestToolResultBlock.IsError) 234 | gt.Value(t, len(restoredMessages[i].Content[j].OfRequestToolResultBlock.Content)).Equal(len(messages[i].Content[j].OfRequestToolResultBlock.Content)) 235 | gt.Value(t, restoredMessages[i].Content[j].OfRequestToolResultBlock.Content[0].OfRequestTextBlock.Text).Equal(messages[i].Content[j].OfRequestToolResultBlock.Content[0].OfRequestTextBlock.Text) 236 | } else { 237 | gt.Value(t, 
restoredMessages[i].Content[j]).Equal(messages[i].Content[j]) 238 | } 239 | } 240 | } 241 | } 242 | 243 | func TestHistoryGemini(t *testing.T) { 244 | // Create Gemini messages with various content types 245 | messages := []*genai.Content{ 246 | { 247 | Role: "user", 248 | Parts: []genai.Part{ 249 | genai.Text("Hello"), 250 | genai.Blob{ 251 | MIMEType: "image/jpeg", 252 | Data: []byte("fake image data"), 253 | }, 254 | genai.FileData{ 255 | MIMEType: "application/pdf", 256 | FileURI: "gs://bucket/file.pdf", 257 | }, 258 | }, 259 | }, 260 | { 261 | Role: "model", 262 | Parts: []genai.Part{ 263 | genai.Text("Hi, how can I help you?"), 264 | genai.FunctionCall{ 265 | Name: "test_function", 266 | Args: map[string]interface{}{ 267 | "param1": "value1", 268 | "param2": float64(123), 269 | }, 270 | }, 271 | }, 272 | }, 273 | { 274 | Role: "model", 275 | Parts: []genai.Part{ 276 | genai.Text("Function result"), 277 | genai.FunctionResponse{ 278 | Name: "test_function", 279 | Response: map[string]interface{}{ 280 | "status": "success", 281 | "result": "operation completed", 282 | }, 283 | }, 284 | }, 285 | }, 286 | } 287 | 288 | // Create History object 289 | history := gollem.NewHistoryFromGemini(messages) 290 | 291 | // Convert to JSON 292 | data, err := json.Marshal(history) 293 | gt.NoError(t, err) 294 | 295 | // Restore from JSON 296 | var restored gollem.History 297 | gt.NoError(t, json.Unmarshal(data, &restored)) 298 | 299 | restoredMessages, err := restored.ToGemini() 300 | gt.NoError(t, err) 301 | gt.Equal(t, messages, restoredMessages) 302 | 303 | // Validate specific message types 304 | gt.Equal(t, "user", restoredMessages[0].Role) 305 | gt.Equal(t, 3, len(restoredMessages[0].Parts)) 306 | gt.Equal(t, "Hello", restoredMessages[0].Parts[0].(genai.Text)) 307 | gt.Equal(t, "image/jpeg", restoredMessages[0].Parts[1].(genai.Blob).MIMEType) 308 | gt.Equal(t, "application/pdf", restoredMessages[0].Parts[2].(genai.FileData).MIMEType) 309 | gt.Equal(t, 
"gs://bucket/file.pdf", restoredMessages[0].Parts[2].(genai.FileData).FileURI) 310 | 311 | gt.Equal(t, "model", restoredMessages[1].Role) 312 | gt.Equal(t, 2, len(restoredMessages[1].Parts)) 313 | gt.Equal(t, "Hi, how can I help you?", restoredMessages[1].Parts[0].(genai.Text)) 314 | gt.Equal(t, "test_function", restoredMessages[1].Parts[1].(genai.FunctionCall).Name) 315 | gt.Equal(t, "value1", restoredMessages[1].Parts[1].(genai.FunctionCall).Args["param1"]) 316 | gt.Equal(t, float64(123), restoredMessages[1].Parts[1].(genai.FunctionCall).Args["param2"].(float64)) 317 | 318 | gt.Equal(t, "model", restoredMessages[2].Role) 319 | gt.Equal(t, 2, len(restoredMessages[2].Parts)) 320 | gt.Equal(t, "Function result", restoredMessages[2].Parts[0].(genai.Text)) 321 | gt.Equal(t, "test_function", restoredMessages[2].Parts[1].(genai.FunctionResponse).Name) 322 | gt.Equal(t, "success", restoredMessages[2].Parts[1].(genai.FunctionResponse).Response["status"]) 323 | gt.Equal(t, "operation completed", restoredMessages[2].Parts[1].(genai.FunctionResponse).Response["result"]) 324 | } 325 | -------------------------------------------------------------------------------- /hook.go: -------------------------------------------------------------------------------- 1 | package gollem 2 | 3 | import "context" 4 | 5 | type ( 6 | // LoopHook is a hook for the session loop. "loop" is the loop count, it's 0-indexed. "input" is the current input of the loop. If you want to abort the session loop, you can return an error. 7 | LoopHook func(ctx context.Context, loop int, input []Input) error 8 | 9 | // MessageHook is a hook for the message. If you want to display or record the message, you can use this hook. 10 | MessageHook func(ctx context.Context, msg string) error 11 | 12 | // ToolRequestHook is a hook for the tool request. If you want to display or record the tool request, you can use this hook. If you want to abort the tool execution, you can return an error. 
13 | ToolRequestHook func(ctx context.Context, tool FunctionCall) error 14 | 15 | // ToolResponseHook is a hook for the tool response. If you want to display or record the tool response, you can use this hook. If you want to abort the tool execution, you can return an error. 16 | ToolResponseHook func(ctx context.Context, tool FunctionCall, response map[string]any) error 17 | 18 | // ToolErrorHook is a hook for the tool error. If you want to record the tool error, you can use this hook. 19 | ToolErrorHook func(ctx context.Context, err error, tool FunctionCall) error 20 | ) 21 | 22 | func defaultLoopHook(ctx context.Context, loop int, input []Input) error { 23 | return nil 24 | } 25 | 26 | func defaultMessageHook(ctx context.Context, msg string) error { 27 | return nil 28 | } 29 | 30 | func defaultToolRequestHook(ctx context.Context, tool FunctionCall) error { 31 | return nil 32 | } 33 | 34 | func defaultToolResponseHook(ctx context.Context, tool FunctionCall, response map[string]any) error { 35 | return nil 36 | } 37 | 38 | func defaultToolErrorHook(ctx context.Context, err error, tool FunctionCall) error { 39 | return nil 40 | } 41 | -------------------------------------------------------------------------------- /llm.go: -------------------------------------------------------------------------------- 1 | package gollem 2 | 3 | import "context" 4 | 5 | // LLMClient is a client for each LLM service. 6 | type LLMClient interface { 7 | NewSession(ctx context.Context, options ...SessionOption) (Session, error) 8 | GenerateEmbedding(ctx context.Context, dimension int, input []string) ([][]float64, error) 9 | } 10 | 11 | type FunctionCall struct { 12 | ID string 13 | Name string 14 | Arguments map[string]any 15 | } 16 | 17 | // Response is a general response type for each gollem. 18 | type Response struct { 19 | Texts []string 20 | FunctionCalls []*FunctionCall 21 | 22 | // Error is an error that occurred during the generation for streaming response. 
23 | Error error 24 | } 25 | 26 | func (r *Response) HasData() bool { 27 | return len(r.Texts) > 0 || len(r.FunctionCalls) > 0 || r.Error != nil 28 | } 29 | 30 | type Input interface { 31 | isInput() restrictedValue 32 | } 33 | 34 | type restrictedValue struct{} 35 | 36 | // Text is a text input as prompt. 37 | // Usage: 38 | // input := gollem.Text("Hello, world!") 39 | type Text string 40 | 41 | func (t Text) isInput() restrictedValue { 42 | return restrictedValue{} 43 | } 44 | 45 | // FunctionResponse is a function response. 46 | // Usage: 47 | // 48 | // input := gollem.FunctionResponse{ 49 | // Name: "function_name", 50 | // Arguments: map[string]any{"key": "value"}, 51 | // } 52 | type FunctionResponse struct { 53 | ID string 54 | Name string 55 | Data map[string]any 56 | Error error 57 | } 58 | 59 | func (f FunctionResponse) isInput() restrictedValue { 60 | return restrictedValue{} 61 | } 62 | -------------------------------------------------------------------------------- /llm/claude/client.go: -------------------------------------------------------------------------------- 1 | package claude 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | "strings" 8 | 9 | "github.com/anthropics/anthropic-sdk-go" 10 | "github.com/anthropics/anthropic-sdk-go/option" 11 | "github.com/m-mizutani/goerr/v2" 12 | "github.com/m-mizutani/gollem" 13 | ) 14 | 15 | const ( 16 | DefaultEmbeddingModel = "claude-3-sonnet-20240229" 17 | ) 18 | 19 | // generationParameters represents the parameters for text generation. 20 | type generationParameters struct { 21 | // Temperature controls randomness in the output. 22 | // Higher values make the output more random, lower values make it more focused. 23 | Temperature float64 24 | 25 | // TopP controls diversity via nucleus sampling. 26 | // Higher values allow more diverse outputs. 27 | TopP float64 28 | 29 | // MaxTokens limits the number of tokens to generate. 
30 | MaxTokens int64 31 | } 32 | 33 | // Client is a client for the Claude API. 34 | // It provides methods to interact with Anthropic's Claude models. 35 | type Client struct { 36 | // client is the underlying Claude client. 37 | client *anthropic.Client 38 | 39 | // defaultModel is the model to use for chat completions. 40 | // It can be overridden using WithModel option. 41 | defaultModel string 42 | 43 | // embeddingModel is the model to use for embeddings. 44 | // It can be overridden using WithEmbeddingModel option. 45 | embeddingModel string 46 | 47 | // apiKey is the API key for authentication. 48 | apiKey string 49 | 50 | // generation parameters 51 | params generationParameters 52 | 53 | // systemPrompt is the system prompt to use for chat completions. 54 | systemPrompt string 55 | } 56 | 57 | // Option is a function that configures a Client. 58 | type Option func(*Client) 59 | 60 | // WithModel sets the default model to use for chat completions. 61 | // The model name should be a valid Claude model identifier. 62 | // Default: anthropic.ModelClaude3_5SonnetLatest 63 | func WithModel(modelName string) Option { 64 | return func(c *Client) { 65 | c.defaultModel = modelName 66 | } 67 | } 68 | 69 | // WithEmbeddingModel sets the embedding model to use for embeddings. 70 | // The model name should be a valid Claude model identifier. 71 | // Default: DefaultEmbeddingModel 72 | func WithEmbeddingModel(modelName string) Option { 73 | return func(c *Client) { 74 | c.embeddingModel = modelName 75 | } 76 | } 77 | 78 | // WithTemperature sets the temperature parameter for text generation. 79 | // Higher values make the output more random, lower values make it more focused. 80 | // Range: 0.0 to 1.0 81 | // Default: 0.7 82 | func WithTemperature(temp float64) Option { 83 | return func(c *Client) { 84 | c.params.Temperature = temp 85 | } 86 | } 87 | 88 | // WithTopP sets the top_p parameter for text generation. 89 | // Controls diversity via nucleus sampling. 
90 | // Range: 0.0 to 1.0 91 | // Default: 1.0 92 | func WithTopP(topP float64) Option { 93 | return func(c *Client) { 94 | c.params.TopP = topP 95 | } 96 | } 97 | 98 | // WithMaxTokens sets the maximum number of tokens to generate. 99 | // Default: 4096 100 | func WithMaxTokens(maxTokens int64) Option { 101 | return func(c *Client) { 102 | c.params.MaxTokens = maxTokens 103 | } 104 | } 105 | 106 | // WithSystemPrompt sets the system prompt to use for chat completions. 107 | func WithSystemPrompt(prompt string) Option { 108 | return func(c *Client) { 109 | c.systemPrompt = prompt 110 | } 111 | } 112 | 113 | // New creates a new client for the Claude API. 114 | // It requires an API key and can be configured with additional options. 115 | func New(ctx context.Context, apiKey string, options ...Option) (*Client, error) { 116 | client := &Client{ 117 | defaultModel: anthropic.ModelClaude3_5SonnetLatest, 118 | embeddingModel: DefaultEmbeddingModel, 119 | apiKey: apiKey, 120 | params: generationParameters{ 121 | Temperature: 0.7, 122 | TopP: 1.0, 123 | MaxTokens: 4096, 124 | }, 125 | } 126 | 127 | for _, option := range options { 128 | option(client) 129 | } 130 | 131 | newClient := anthropic.NewClient( 132 | option.WithAPIKey(apiKey), 133 | ) 134 | client.client = &newClient 135 | 136 | return client, nil 137 | } 138 | 139 | // Session is a session for the Claude chat. 140 | // It maintains the conversation state and handles message generation. 141 | type Session struct { 142 | // client is the underlying Claude client. 143 | client *anthropic.Client 144 | 145 | // defaultModel is the model to use for chat completions. 146 | defaultModel string 147 | 148 | // tools are the available tools for the session. 149 | tools []anthropic.ToolUnionParam 150 | 151 | // messages stores the conversation history. 
152 | messages []anthropic.MessageParam 153 | 154 | // generation parameters 155 | params generationParameters 156 | 157 | cfg gollem.SessionConfig 158 | } 159 | 160 | // NewSession creates a new session for the Claude API. 161 | // It converts the provided tools to Claude's tool format and initializes a new chat session. 162 | func (c *Client) NewSession(ctx context.Context, options ...gollem.SessionOption) (gollem.Session, error) { 163 | cfg := gollem.NewSessionConfig(options...) 164 | 165 | // Convert gollem.Tool to anthropic.ToolUnionParam 166 | claudeTools := make([]anthropic.ToolUnionParam, len(cfg.Tools())) 167 | for i, tool := range cfg.Tools() { 168 | claudeTools[i] = convertTool(tool) 169 | } 170 | 171 | var messages []anthropic.MessageParam 172 | if cfg.History() != nil { 173 | history, err := cfg.History().ToClaude() 174 | if err != nil { 175 | return nil, goerr.Wrap(err, "failed to convert history to anthropic.MessageParam") 176 | } 177 | messages = append(messages, history...) 
178 | } 179 | 180 | session := &Session{ 181 | client: c.client, 182 | defaultModel: c.defaultModel, 183 | tools: claudeTools, 184 | params: c.params, 185 | messages: messages, 186 | cfg: cfg, 187 | } 188 | 189 | return session, nil 190 | } 191 | 192 | func (s *Session) History() *gollem.History { 193 | return gollem.NewHistoryFromClaude(s.messages) 194 | } 195 | 196 | // convertInputs converts gollem.Input to Claude messages and tool results 197 | func (s *Session) convertInputs(input ...gollem.Input) ([]anthropic.MessageParam, []anthropic.ContentBlockParamUnion, error) { 198 | var toolResults []anthropic.ContentBlockParamUnion 199 | var messages []anthropic.MessageParam 200 | 201 | for _, in := range input { 202 | switch v := in.(type) { 203 | case gollem.Text: 204 | messages = append(messages, anthropic.NewUserMessage( 205 | anthropic.NewTextBlock(string(v)), 206 | )) 207 | 208 | case gollem.FunctionResponse: 209 | data, err := json.Marshal(v.Data) 210 | if err != nil { 211 | return nil, nil, goerr.Wrap(err, "failed to marshal function response") 212 | } 213 | response := string(data) 214 | if v.Error != nil { 215 | response = fmt.Sprintf(`Error message: %+v`, v.Error) 216 | } 217 | toolResults = append(toolResults, anthropic.NewToolResultBlock(v.ID, response, v.Error != nil)) 218 | 219 | default: 220 | return nil, nil, goerr.Wrap(gollem.ErrInvalidParameter, "invalid input") 221 | } 222 | } 223 | 224 | if len(toolResults) > 0 { 225 | messages = append(messages, anthropic.NewUserMessage(toolResults...)) 226 | } 227 | 228 | return messages, toolResults, nil 229 | } 230 | 231 | // createRequest creates a message request with the current session state 232 | func (s *Session) createRequest() anthropic.MessageNewParams { 233 | var systemPrompt []anthropic.TextBlockParam 234 | if s.cfg.SystemPrompt() != "" { 235 | systemPrompt = []anthropic.TextBlockParam{ 236 | { 237 | Text: s.cfg.SystemPrompt(), 238 | }, 239 | } 240 | } 241 | 242 | // Add content type instruction to 
system prompt 243 | if s.cfg.ContentType() == gollem.ContentTypeJSON { 244 | if len(systemPrompt) > 0 { 245 | systemPrompt[0].Text += "\nPlease format your response as valid JSON." 246 | } else { 247 | systemPrompt = []anthropic.TextBlockParam{ 248 | { 249 | Text: "Please format your response as valid JSON.", 250 | }, 251 | } 252 | } 253 | } 254 | 255 | return anthropic.MessageNewParams{ 256 | Model: s.defaultModel, 257 | MaxTokens: s.params.MaxTokens, 258 | Temperature: anthropic.Float(s.params.Temperature), 259 | TopP: anthropic.Float(s.params.TopP), 260 | Tools: s.tools, 261 | Messages: s.messages, 262 | System: systemPrompt, 263 | } 264 | } 265 | 266 | // processResponse converts Claude response to gollem.Response 267 | func processResponse(resp *anthropic.Message) *gollem.Response { 268 | if len(resp.Content) == 0 { 269 | return &gollem.Response{} 270 | } 271 | 272 | response := &gollem.Response{ 273 | Texts: make([]string, 0), 274 | FunctionCalls: make([]*gollem.FunctionCall, 0), 275 | } 276 | 277 | for _, content := range resp.Content { 278 | textBlock := content.AsResponseTextBlock() 279 | if textBlock.Type == "text" { 280 | response.Texts = append(response.Texts, textBlock.Text) 281 | } 282 | 283 | toolUseBlock := content.AsResponseToolUseBlock() 284 | if toolUseBlock.Type == "tool_use" { 285 | var args map[string]interface{} 286 | if err := json.Unmarshal([]byte(toolUseBlock.Input), &args); err != nil { 287 | response.Error = goerr.Wrap(err, "failed to unmarshal function arguments") 288 | return response 289 | } 290 | 291 | response.FunctionCalls = append(response.FunctionCalls, &gollem.FunctionCall{ 292 | ID: toolUseBlock.ID, 293 | Name: toolUseBlock.Name, 294 | Arguments: args, 295 | }) 296 | } 297 | } 298 | 299 | return response 300 | } 301 | 302 | // GenerateContent processes the input and generates a response. 303 | // It handles both text messages and function responses. 
304 | func (s *Session) GenerateContent(ctx context.Context, input ...gollem.Input) (*gollem.Response, error) { 305 | messages, _, err := s.convertInputs(input...) 306 | if err != nil { 307 | return nil, err 308 | } 309 | 310 | s.messages = append(s.messages, messages...) 311 | params := s.createRequest() 312 | 313 | resp, err := s.client.Messages.New(ctx, params) 314 | if err != nil { 315 | return nil, goerr.Wrap(err, "failed to create message") 316 | } 317 | 318 | // Add assistant's response to message history 319 | s.messages = append(s.messages, resp.ToParam()) 320 | 321 | return processResponse(resp), nil 322 | } 323 | 324 | // FunctionCallAccumulator accumulates function call information from stream 325 | type FunctionCallAccumulator struct { 326 | ID string 327 | Name string 328 | Arguments string 329 | } 330 | 331 | func newFunctionCallAccumulator() *FunctionCallAccumulator { 332 | return &FunctionCallAccumulator{ 333 | Arguments: "", 334 | } 335 | } 336 | 337 | func (a *FunctionCallAccumulator) accumulate() (*gollem.FunctionCall, error) { 338 | if a.ID == "" || a.Name == "" { 339 | return nil, goerr.Wrap(gollem.ErrInvalidParameter, "function call is not complete") 340 | } 341 | 342 | var args map[string]any 343 | if a.Arguments != "" { 344 | if err := json.Unmarshal([]byte(a.Arguments), &args); err != nil { 345 | return nil, goerr.Wrap(err, "failed to unmarshal function call arguments", goerr.V("accumulator", a)) 346 | } 347 | } 348 | 349 | return &gollem.FunctionCall{ 350 | ID: a.ID, 351 | Name: a.Name, 352 | Arguments: args, 353 | }, nil 354 | } 355 | 356 | // GenerateStream processes the input and generates a response stream. 357 | // It handles both text messages and function responses, and returns a channel for streaming responses. 358 | func (s *Session) GenerateStream(ctx context.Context, input ...gollem.Input) (<-chan *gollem.Response, error) { 359 | messages, _, err := s.convertInputs(input...) 
360 | if err != nil { 361 | return nil, err 362 | } 363 | 364 | s.messages = append(s.messages, messages...) 365 | params := s.createRequest() 366 | 367 | stream := s.client.Messages.NewStreaming(ctx, params) 368 | if stream == nil { 369 | return nil, goerr.New("failed to create message stream") 370 | } 371 | 372 | responseChan := make(chan *gollem.Response) 373 | 374 | // Accumulate text and tool calls for message history 375 | var textContent strings.Builder 376 | var toolCalls []anthropic.ContentBlockParamUnion 377 | acc := newFunctionCallAccumulator() 378 | 379 | go func() { 380 | defer close(responseChan) 381 | 382 | for { 383 | if !stream.Next() { 384 | // Add accumulated message to history when stream ends 385 | if textContent.Len() > 0 || len(toolCalls) > 0 { 386 | var content []anthropic.ContentBlockParamUnion 387 | if textContent.Len() > 0 { 388 | content = append(content, anthropic.NewTextBlock(textContent.String())) 389 | } 390 | content = append(content, toolCalls...) 391 | s.messages = append(s.messages, anthropic.NewAssistantMessage(content...)) 392 | } 393 | return 394 | } 395 | 396 | event := stream.Current() 397 | response := &gollem.Response{ 398 | Texts: make([]string, 0), 399 | FunctionCalls: make([]*gollem.FunctionCall, 0), 400 | } 401 | 402 | switch event.Type { 403 | case "content_block_delta": 404 | deltaEvent := event.AsContentBlockDeltaEvent() 405 | switch deltaEvent.Delta.Type { 406 | case "text_delta": 407 | textDelta := deltaEvent.Delta.AsTextContentBlockDelta() 408 | response.Texts = append(response.Texts, textDelta.Text) 409 | textContent.WriteString(textDelta.Text) 410 | case "input_json_delta": 411 | jsonDelta := deltaEvent.Delta.AsInputJSONContentBlockDelta() 412 | if jsonDelta.PartialJSON != "" { 413 | acc.Arguments += jsonDelta.PartialJSON 414 | } 415 | } 416 | case "content_block_start": 417 | startEvent := event.AsContentBlockStartEvent() 418 | if startEvent.ContentBlock.Type == "tool_use" { 419 | toolUseBlock := 
startEvent.ContentBlock.AsResponseToolUseBlock() 420 | acc.ID = toolUseBlock.ID 421 | acc.Name = toolUseBlock.Name 422 | } 423 | case "content_block_stop": 424 | if acc.ID != "" && acc.Name != "" { 425 | funcCall, err := acc.accumulate() 426 | if err != nil { 427 | response.Error = err 428 | responseChan <- response 429 | return 430 | } 431 | response.FunctionCalls = append(response.FunctionCalls, funcCall) 432 | toolCalls = append(toolCalls, anthropic.ContentBlockParamUnion{ 433 | OfRequestToolUseBlock: &anthropic.ToolUseBlockParam{ 434 | ID: funcCall.ID, 435 | Name: funcCall.Name, 436 | Input: funcCall.Arguments, 437 | Type: "tool_use", 438 | }, 439 | }) 440 | acc = newFunctionCallAccumulator() 441 | } 442 | } 443 | 444 | if response.HasData() { 445 | responseChan <- response 446 | } 447 | } 448 | }() 449 | 450 | return responseChan, nil 451 | } 452 | -------------------------------------------------------------------------------- /llm/claude/convert.go: -------------------------------------------------------------------------------- 1 | package claude 2 | 3 | import ( 4 | "github.com/anthropics/anthropic-sdk-go" 5 | "github.com/m-mizutani/gollem" 6 | ) 7 | 8 | func convertTool(tool gollem.Tool) anthropic.ToolUnionParam { 9 | spec := tool.Spec() 10 | schema := convertParametersToJSONSchema(spec.Parameters) 11 | 12 | return anthropic.ToolUnionParamOfTool( 13 | anthropic.ToolInputSchemaParam{ 14 | Properties: schema.Properties, 15 | }, 16 | spec.Name, 17 | ) 18 | } 19 | 20 | type jsonSchema struct { 21 | Type string `json:"type"` 22 | Properties map[string]jsonSchema `json:"properties,omitempty"` 23 | Required []string `json:"required,omitempty"` 24 | Items *jsonSchema `json:"items,omitempty"` 25 | Minimum *float64 `json:"minimum,omitempty"` 26 | Maximum *float64 `json:"maximum,omitempty"` 27 | MinLength *int `json:"minLength,omitempty"` 28 | MaxLength *int `json:"maxLength,omitempty"` 29 | Pattern string `json:"pattern,omitempty"` 30 | MinItems *int 
`json:"minItems,omitempty"` 31 | MaxItems *int `json:"maxItems,omitempty"` 32 | Default interface{} `json:"default,omitempty"` 33 | Enum []interface{} `json:"enum,omitempty"` 34 | Description string `json:"description,omitempty"` 35 | Title string `json:"title,omitempty"` 36 | } 37 | 38 | func convertParametersToJSONSchema(params map[string]*gollem.Parameter) jsonSchema { 39 | properties := make(map[string]jsonSchema) 40 | 41 | for name, param := range params { 42 | properties[name] = convertParameterToSchema(param) 43 | } 44 | 45 | return jsonSchema{ 46 | Type: "object", 47 | Properties: properties, 48 | } 49 | } 50 | 51 | // convertParameterToSchema converts gollem.Parameter to Claude schema 52 | func convertParameterToSchema(param *gollem.Parameter) jsonSchema { 53 | schema := jsonSchema{ 54 | Type: getClaudeType(param.Type), 55 | Description: param.Description, 56 | Title: param.Title, 57 | } 58 | 59 | if len(param.Enum) > 0 { 60 | enum := make([]interface{}, len(param.Enum)) 61 | for i, v := range param.Enum { 62 | enum[i] = v 63 | } 64 | schema.Enum = enum 65 | } 66 | 67 | if param.Properties != nil { 68 | properties := make(map[string]jsonSchema) 69 | for name, prop := range param.Properties { 70 | properties[name] = convertParameterToSchema(prop) 71 | } 72 | schema.Properties = properties 73 | if len(param.Required) > 0 { 74 | schema.Required = param.Required 75 | } 76 | } 77 | 78 | if param.Items != nil { 79 | items := convertParameterToSchema(param.Items) 80 | schema.Items = &items 81 | } 82 | 83 | // Add number constraints 84 | if param.Type == gollem.TypeNumber || param.Type == gollem.TypeInteger { 85 | if param.Minimum != nil { 86 | schema.Minimum = param.Minimum 87 | } 88 | if param.Maximum != nil { 89 | schema.Maximum = param.Maximum 90 | } 91 | } 92 | 93 | // Add string constraints 94 | if param.Type == gollem.TypeString { 95 | if param.MinLength != nil { 96 | schema.MinLength = param.MinLength 97 | } 98 | if param.MaxLength != nil { 99 | 
schema.MaxLength = param.MaxLength 100 | } 101 | if param.Pattern != "" { 102 | schema.Pattern = param.Pattern 103 | } 104 | } 105 | 106 | // Add array constraints 107 | if param.Type == gollem.TypeArray { 108 | if param.MinItems != nil { 109 | schema.MinItems = param.MinItems 110 | } 111 | if param.MaxItems != nil { 112 | schema.MaxItems = param.MaxItems 113 | } 114 | } 115 | 116 | // Add default value 117 | if param.Default != nil { 118 | schema.Default = param.Default 119 | } 120 | 121 | return schema 122 | } 123 | 124 | func getClaudeType(paramType gollem.ParameterType) string { 125 | switch paramType { 126 | case gollem.TypeString: 127 | return "string" 128 | case gollem.TypeNumber: 129 | return "number" 130 | case gollem.TypeInteger: 131 | return "integer" 132 | case gollem.TypeBoolean: 133 | return "boolean" 134 | case gollem.TypeArray: 135 | return "array" 136 | case gollem.TypeObject: 137 | return "object" 138 | default: 139 | return "string" 140 | } 141 | } 142 | -------------------------------------------------------------------------------- /llm/claude/convert_test.go: -------------------------------------------------------------------------------- 1 | package claude_test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/m-mizutani/gollem" 8 | "github.com/m-mizutani/gollem/llm/claude" 9 | "github.com/m-mizutani/gt" 10 | ) 11 | 12 | type complexTool struct{} 13 | 14 | func (t *complexTool) Spec() gollem.ToolSpec { 15 | return gollem.ToolSpec{ 16 | Name: "complex_tool", 17 | Description: "A tool with complex parameter structure", 18 | Required: []string{"user"}, 19 | Parameters: map[string]*gollem.Parameter{ 20 | "user": { 21 | Type: gollem.TypeObject, 22 | Required: []string{"name"}, 23 | Properties: map[string]*gollem.Parameter{ 24 | "name": { 25 | Type: gollem.TypeString, 26 | Description: "User's name", 27 | }, 28 | "address": { 29 | Type: gollem.TypeObject, 30 | Properties: map[string]*gollem.Parameter{ 31 | "street": { 32 | Type: 
gollem.TypeString, 33 | Description: "Street address", 34 | }, 35 | "city": { 36 | Type: gollem.TypeString, 37 | Description: "City name", 38 | }, 39 | }, 40 | Required: []string{"street"}, 41 | }, 42 | }, 43 | }, 44 | "items": { 45 | Type: gollem.TypeArray, 46 | Items: &gollem.Parameter{ 47 | Type: gollem.TypeObject, 48 | Properties: map[string]*gollem.Parameter{ 49 | "id": { 50 | Type: gollem.TypeString, 51 | Description: "Item ID", 52 | }, 53 | "quantity": { 54 | Type: gollem.TypeNumber, 55 | Description: "Item quantity", 56 | }, 57 | }, 58 | }, 59 | }, 60 | }, 61 | } 62 | } 63 | 64 | func (t *complexTool) Run(ctx context.Context, args map[string]any) (map[string]any, error) { 65 | return nil, nil 66 | } 67 | 68 | func TestConvertTool(t *testing.T) { 69 | tool := &complexTool{} 70 | claudeTool := claude.ConvertTool(tool) 71 | 72 | // Check basic properties 73 | gt.Equal(t, claudeTool.OfTool.Name, "complex_tool") 74 | 75 | // Check schema properties 76 | schemaProps := claudeTool.OfTool.InputSchema.Properties.(map[string]claude.JsonSchema) 77 | 78 | // Check user parameter 79 | user := schemaProps["user"] 80 | gt.Equal(t, user.Type, "object") 81 | 82 | userProps := user.Properties 83 | nameProps := userProps["name"] 84 | gt.Equal(t, nameProps.Type, "string") 85 | gt.Equal(t, nameProps.Description, "User's name") 86 | gt.Array(t, user.Required).Equal([]string{"name"}) 87 | 88 | addressProps := userProps["address"].Properties 89 | gt.Equal(t, addressProps["street"].Type, "string") 90 | gt.Equal(t, addressProps["city"].Type, "string") 91 | 92 | // Check items parameter 93 | itemsProp := schemaProps["items"] 94 | gt.Equal(t, itemsProp.Type, "array") 95 | 96 | itemsSchema := *itemsProp.Items 97 | itemsProps := itemsSchema.Properties 98 | gt.Equal(t, itemsProps["id"].Type, "string") 99 | gt.Equal(t, itemsProps["quantity"].Type, "number") 100 | } 101 | 102 | func TestConvertParameterToSchema(t *testing.T) { 103 | type testCase struct { 104 | name string 105 | schema 
*gollem.Parameter 106 | expected claude.JsonSchema 107 | } 108 | 109 | runTest := func(tc testCase) func(t *testing.T) { 110 | return func(t *testing.T) { 111 | actual := claude.ConvertParameterToSchema(tc.schema) 112 | gt.Value(t, actual).Equal(tc.expected) 113 | } 114 | } 115 | 116 | t.Run("number constraints", runTest(testCase{ 117 | name: "number constraints", 118 | schema: &gollem.Parameter{ 119 | Type: gollem.TypeNumber, 120 | Minimum: ptr(1.0), 121 | Maximum: ptr(10.0), 122 | }, 123 | expected: claude.JsonSchema{ 124 | Type: "number", 125 | Minimum: ptr(1.0), 126 | Maximum: ptr(10.0), 127 | }, 128 | })) 129 | 130 | t.Run("string constraints", runTest(testCase{ 131 | name: "string constraints", 132 | schema: &gollem.Parameter{ 133 | Type: gollem.TypeString, 134 | MinLength: ptr(1), 135 | MaxLength: ptr(10), 136 | Pattern: "^[a-z]+$", 137 | }, 138 | expected: claude.JsonSchema{ 139 | Type: "string", 140 | MinLength: ptr(1), 141 | MaxLength: ptr(10), 142 | Pattern: "^[a-z]+$", 143 | }, 144 | })) 145 | 146 | t.Run("array constraints", runTest(testCase{ 147 | name: "array constraints", 148 | schema: &gollem.Parameter{ 149 | Type: gollem.TypeArray, 150 | Items: &gollem.Parameter{Type: gollem.TypeString}, 151 | MinItems: ptr(1), 152 | MaxItems: ptr(10), 153 | }, 154 | expected: claude.JsonSchema{ 155 | Type: "array", 156 | MinItems: ptr(1), 157 | MaxItems: ptr(10), 158 | Items: &claude.JsonSchema{Type: "string"}, 159 | }, 160 | })) 161 | 162 | t.Run("default value", runTest(testCase{ 163 | name: "default value", 164 | schema: &gollem.Parameter{ 165 | Type: gollem.TypeString, 166 | Default: "default value", 167 | }, 168 | expected: claude.JsonSchema{ 169 | Type: "string", 170 | Default: "default value", 171 | }, 172 | })) 173 | } 174 | 175 | func ptr[T any](v T) *T { 176 | return &v 177 | } 178 | 179 | func TestConvertSchema(t *testing.T) { 180 | type testCase struct { 181 | name string 182 | schema *gollem.Parameter 183 | expected claude.JsonSchema 184 | } 185 | 
186 | runTest := func(tc testCase) func(t *testing.T) { 187 | return func(t *testing.T) { 188 | actual := claude.ConvertParameterToSchema(tc.schema) 189 | gt.Value(t, actual).Equal(tc.expected) 190 | } 191 | } 192 | 193 | t.Run("string type", runTest(testCase{ 194 | name: "string type", 195 | schema: &gollem.Parameter{ 196 | Type: gollem.TypeString, 197 | }, 198 | expected: claude.JsonSchema{ 199 | Type: "string", 200 | }, 201 | })) 202 | } 203 | -------------------------------------------------------------------------------- /llm/claude/embedding.go: -------------------------------------------------------------------------------- 1 | package claude 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/m-mizutani/goerr/v2" 7 | ) 8 | 9 | // GenerateEmbedding generates embeddings for the given input text. Claude does not support emmbedding generation directly. 10 | func (c *Client) GenerateEmbedding(ctx context.Context, dimension int, input []string) ([][]float64, error) { 11 | return nil, goerr.New("Claude does not support embedding generation") 12 | } 13 | -------------------------------------------------------------------------------- /llm/claude/export_test.go: -------------------------------------------------------------------------------- 1 | package claude 2 | 3 | // Export convert functions for testing 4 | var ( 5 | ConvertTool = convertTool 6 | ConvertParameterToSchema = convertParameterToSchema 7 | ) 8 | 9 | type JsonSchema = jsonSchema 10 | -------------------------------------------------------------------------------- /llm/gemini/client.go: -------------------------------------------------------------------------------- 1 | package gemini 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 7 | "cloud.google.com/go/vertexai/genai" 8 | "github.com/m-mizutani/goerr/v2" 9 | "github.com/m-mizutani/gollem" 10 | "google.golang.org/api/iterator" 11 | "google.golang.org/api/option" 12 | ) 13 | 14 | const ( 15 | DefaultModel = "gemini-2.0-flash" 16 | DefaultEmbeddingModel = 
"text-embedding-004" 17 | ) 18 | 19 | // Client is a client for the Gemini API. 20 | // It provides methods to interact with Google's Gemini models. 21 | type Client struct { 22 | projectID string 23 | location string 24 | 25 | // client is the underlying Gemini client. 26 | client *genai.Client 27 | 28 | // defaultModel is the model to use for chat completions. 29 | // It can be overridden using WithModel option. 30 | defaultModel string 31 | 32 | // embeddingModel is the model to use for embeddings. 33 | // It can be overridden using WithEmbeddingModel option. 34 | embeddingModel string 35 | 36 | // gcpOptions are additional options for Google Cloud Platform. 37 | // They can be set using WithGoogleCloudOptions. 38 | gcpOptions []option.ClientOption 39 | 40 | // generationConfig contains the default generation parameters 41 | generationConfig genai.GenerationConfig 42 | 43 | // systemPrompt is the system prompt to use for chat completions. 44 | systemPrompt string 45 | 46 | // contentType is the type of content to be generated. 47 | contentType gollem.ContentType 48 | } 49 | 50 | // Option is a function that configures a Client. 51 | type Option func(*Client) 52 | 53 | // WithModel sets the default model to use for chat completions. 54 | // The model name should be a valid Gemini model identifier. 55 | // Default: "gemini-2.0-flash" 56 | func WithModel(modelName string) Option { 57 | return func(c *Client) { 58 | c.defaultModel = modelName 59 | } 60 | } 61 | 62 | // WithGoogleCloudOptions sets additional options for Google Cloud Platform. 63 | // These options are passed to the underlying Gemini client. 64 | func WithGoogleCloudOptions(options ...option.ClientOption) Option { 65 | return func(c *Client) { 66 | c.gcpOptions = options 67 | } 68 | } 69 | 70 | // WithTemperature sets the temperature parameter for text generation. 71 | // Higher values make the output more random, lower values make it more focused. 
72 | // Range: 0.0 to 1.0 73 | // Default: 0.7 74 | func WithTemperature(temp float32) Option { 75 | return func(c *Client) { 76 | c.generationConfig.Temperature = &temp 77 | } 78 | } 79 | 80 | // WithTopP sets the top_p parameter for text generation. 81 | // Controls diversity via nucleus sampling. 82 | // Range: 0.0 to 1.0 83 | // Default: 1.0 84 | func WithTopP(topP float32) Option { 85 | return func(c *Client) { 86 | c.generationConfig.TopP = &topP 87 | } 88 | } 89 | 90 | // WithTopK sets the top_k parameter for text generation. 91 | // Controls diversity via top-k sampling. 92 | // Range: 1 to 40 93 | func WithTopK(topK int32) Option { 94 | return func(c *Client) { 95 | c.generationConfig.TopK = &topK 96 | } 97 | } 98 | 99 | // WithMaxTokens sets the maximum number of tokens to generate. 100 | func WithMaxTokens(maxTokens int32) Option { 101 | return func(c *Client) { 102 | c.generationConfig.MaxOutputTokens = &maxTokens 103 | } 104 | } 105 | 106 | // WithStopSequences sets the stop sequences for text generation. 107 | func WithStopSequences(stopSequences []string) Option { 108 | return func(c *Client) { 109 | c.generationConfig.StopSequences = stopSequences 110 | } 111 | } 112 | 113 | // WithSystemPrompt sets the system prompt to use for chat completions. 114 | func WithSystemPrompt(prompt string) Option { 115 | return func(c *Client) { 116 | c.systemPrompt = prompt 117 | } 118 | } 119 | 120 | // WithContentType sets the content type for text generation. 121 | // This determines the format of the generated content. 122 | func WithContentType(contentType gollem.ContentType) Option { 123 | return func(c *Client) { 124 | c.contentType = contentType 125 | } 126 | } 127 | 128 | // WithEmbeddingModel sets the model to use for embeddings. 
// Default: DefaultEmbeddingModel ("text-embedding-004").
func WithEmbeddingModel(modelName string) Option {
	return func(c *Client) {
		c.embeddingModel = modelName
	}
}

// New creates a new client for the Gemini API.
// It requires a project ID and location, and can be configured with additional options.
// Options are applied in order before the underlying genai client is created,
// so options such as WithGoogleCloudOptions take effect for that client.
func New(ctx context.Context, projectID, location string, options ...Option) (*Client, error) {
	if projectID == "" {
		return nil, goerr.New("projectID is required")
	}
	if location == "" {
		return nil, goerr.New("location is required")
	}

	client := &Client{
		projectID:      projectID,
		location:       location,
		defaultModel:   DefaultModel,
		embeddingModel: DefaultEmbeddingModel,
		contentType:    gollem.ContentTypeText,
	}

	// Named opt rather than option so the imported option package is not shadowed.
	for _, opt := range options {
		opt(client)
	}

	newClient, err := genai.NewClient(ctx, projectID, location, client.gcpOptions...)
	if err != nil {
		return nil, goerr.Wrap(err, "failed to create Gemini client",
			goerr.V("projectID", projectID),
			goerr.V("location", location),
		)
	}
	client.client = newClient

	return client, nil
}

// NewSession creates a new session for the Gemini API.
// It converts the provided tools to Gemini's function declarations and
// initializes a new chat session, optionally seeded with cfg.History().
//
// Session options take precedence. When they do not set a system prompt or a
// content type, the values configured on the Client (WithSystemPrompt,
// WithContentType) are used instead.
func (c *Client) NewSession(ctx context.Context, options ...gollem.SessionOption) (gollem.Session, error) {
	cfg := gollem.NewSessionConfig(options...)

	// Convert gollem.Tool to *genai.FunctionDeclaration
	genaiFunctions := make([]*genai.FunctionDeclaration, len(cfg.Tools()))
	for i, tool := range cfg.Tools() {
		genaiFunctions[i] = convertTool(tool)
	}

	var messages []*genai.Content
	if cfg.History() != nil {
		history, err := cfg.History().ToGemini()
		if err != nil {
			return nil, goerr.Wrap(err, "failed to convert history to gemini.Content")
		}
		messages = append(messages, history...)
	}

	model := c.client.GenerativeModel(c.defaultModel)
	model.GenerationConfig = c.generationConfig

	// Fall back to the client-level content type when the session config leaves
	// it at the zero value.
	// NOTE(review): assumes NewSessionConfig leaves ContentType at its zero value
	// when no session option sets it — confirm. If it defaults to a concrete
	// type, the client-level value is never consulted.
	contentType := cfg.ContentType()
	var unsetContentType gollem.ContentType
	if contentType == unsetContentType {
		contentType = c.contentType
	}
	switch contentType {
	case gollem.ContentTypeJSON:
		model.GenerationConfig.ResponseMIMEType = "application/json"
	case gollem.ContentTypeText:
		model.GenerationConfig.ResponseMIMEType = "text/plain"
	}

	systemPrompt := cfg.SystemPrompt()
	if systemPrompt == "" {
		systemPrompt = c.systemPrompt
	}
	if systemPrompt != "" {
		model.SystemInstruction = &genai.Content{
			Role:  "system",
			Parts: []genai.Part{genai.Text(systemPrompt)},
		}
	}

	if len(genaiFunctions) > 0 {
		model.Tools = []*genai.Tool{
			{
				FunctionDeclarations: genaiFunctions,
			},
		}
	}

	session := &Session{
		session: model.StartChat(),
	}
	if len(messages) > 0 {
		session.session.History = messages
	}

	return session, nil
}

// History returns the session's conversation, converted from the underlying
// Gemini chat history.
func (s *Session) History() *gollem.History {
	return gollem.NewHistoryFromGemini(s.session.History)
}

228 | // Session is a session for the Gemini chat. 229 | // It maintains the conversation state and handles message generation. 230 | type Session struct { 231 | // session is the underlying Gemini chat session.
232 | session *genai.ChatSession 233 | } 234 | 235 | // convertInputs converts gollem.Input to Gemini parts 236 | func (s *Session) convertInputs(input ...gollem.Input) ([]genai.Part, error) { 237 | parts := make([]genai.Part, len(input)) 238 | for i, in := range input { 239 | switch v := in.(type) { 240 | case gollem.Text: 241 | parts[i] = genai.Text(string(v)) 242 | case gollem.FunctionResponse: 243 | if v.Error != nil { 244 | parts[i] = genai.FunctionResponse{ 245 | Name: v.Name, 246 | Response: map[string]any{ 247 | "error_message": fmt.Sprintf("%+v", v.Error), 248 | }, 249 | } 250 | } else { 251 | parts[i] = genai.FunctionResponse{ 252 | Name: v.Name, 253 | Response: v.Data, 254 | } 255 | } 256 | default: 257 | return nil, goerr.Wrap(gollem.ErrInvalidParameter, "invalid input") 258 | } 259 | } 260 | return parts, nil 261 | } 262 | 263 | // processResponse converts Gemini response to gollem.Response 264 | func processResponse(resp *genai.GenerateContentResponse) *gollem.Response { 265 | if len(resp.Candidates) == 0 { 266 | return &gollem.Response{} 267 | } 268 | 269 | response := &gollem.Response{ 270 | Texts: make([]string, 0), 271 | FunctionCalls: make([]*gollem.FunctionCall, 0), 272 | } 273 | 274 | for _, candidate := range resp.Candidates { 275 | for _, part := range candidate.Content.Parts { 276 | switch v := part.(type) { 277 | case genai.Text: 278 | response.Texts = append(response.Texts, string(v)) 279 | case genai.FunctionCall: 280 | response.FunctionCalls = append(response.FunctionCalls, &gollem.FunctionCall{ 281 | Name: v.Name, 282 | Arguments: v.Args, 283 | }) 284 | } 285 | } 286 | } 287 | 288 | return response 289 | } 290 | 291 | // GenerateContent processes the input and generates a response. 292 | // It handles both text messages and function responses. 293 | func (s *Session) GenerateContent(ctx context.Context, input ...gollem.Input) (*gollem.Response, error) { 294 | parts, err := s.convertInputs(input...) 
295 | if err != nil { 296 | return nil, err 297 | } 298 | 299 | resp, err := s.session.SendMessage(ctx, parts...) 300 | if err != nil { 301 | return nil, goerr.Wrap(err, "failed to send message") 302 | } 303 | 304 | return processResponse(resp), nil 305 | } 306 | 307 | // GenerateStream processes the input and generates a response stream. 308 | // It handles both text messages and function responses, and returns a channel for streaming responses. 309 | func (s *Session) GenerateStream(ctx context.Context, input ...gollem.Input) (<-chan *gollem.Response, error) { 310 | parts, err := s.convertInputs(input...) 311 | if err != nil { 312 | return nil, err 313 | } 314 | 315 | iter := s.session.SendMessageStream(ctx, parts...) 316 | responseChan := make(chan *gollem.Response) 317 | 318 | go func() { 319 | defer close(responseChan) 320 | 321 | for { 322 | resp, err := iter.Next() 323 | if err != nil { 324 | if err == iterator.Done { 325 | return 326 | } 327 | responseChan <- &gollem.Response{ 328 | Error: goerr.Wrap(err, "failed to generate stream"), 329 | } 330 | return 331 | } 332 | 333 | responseChan <- processResponse(resp) 334 | } 335 | }() 336 | 337 | return responseChan, nil 338 | } 339 | -------------------------------------------------------------------------------- /llm/gemini/convert.go: -------------------------------------------------------------------------------- 1 | package gemini 2 | 3 | import ( 4 | "cloud.google.com/go/vertexai/genai" 5 | "github.com/m-mizutani/gollem" 6 | ) 7 | 8 | // convertTool converts gollem.Tool to Gemini tool 9 | func convertTool(tool gollem.Tool) *genai.FunctionDeclaration { 10 | spec := tool.Spec() 11 | parameters := &genai.Schema{ 12 | Type: genai.TypeObject, 13 | Properties: make(map[string]*genai.Schema), 14 | Required: spec.Required, 15 | } 16 | 17 | for name, param := range spec.Parameters { 18 | parameters.Properties[name] = convertParameterToSchema(param) 19 | } 20 | 21 | return &genai.FunctionDeclaration{ 22 | Name: 
spec.Name, 23 | Description: spec.Description, 24 | Parameters: parameters, 25 | } 26 | } 27 |
// convertParameterToSchema maps a gollem.Parameter onto a genai.Schema.
//
// Object properties and array items are converted recursively. Constraints
// are copied only for the parameter types they belong to:
//   - minimum/maximum for numbers and integers,
//   - length and pattern constraints for strings,
//   - item-count bounds for arrays.
//
// Constraints set on any other type are dropped. Default values have no
// counterpart here and are ignored.
func convertParameterToSchema(param *gollem.Parameter) *genai.Schema {
	out := &genai.Schema{
		Type:        getGeminiType(param.Type),
		Title:       param.Title,
		Description: param.Description,
	}

	if len(param.Enum) > 0 {
		out.Enum = param.Enum
	}

	if param.Items != nil {
		out.Items = convertParameterToSchema(param.Items)
	}

	if param.Properties != nil {
		props := make(map[string]*genai.Schema, len(param.Properties))
		for key, child := range param.Properties {
			props[key] = convertParameterToSchema(child)
		}
		out.Properties = props
		if len(param.Required) > 0 {
			out.Required = param.Required
		}
	}

	switch param.Type {
	case gollem.TypeNumber, gollem.TypeInteger:
		if param.Minimum != nil {
			out.Minimum = *param.Minimum
		}
		if param.Maximum != nil {
			out.Maximum = *param.Maximum
		}
	case gollem.TypeString:
		if param.MinLength != nil {
			out.MinLength = int64(*param.MinLength)
		}
		if param.MaxLength != nil {
			out.MaxLength = int64(*param.MaxLength)
		}
		if param.Pattern != "" {
			out.Pattern = param.Pattern
		}
	case gollem.TypeArray:
		if param.MinItems != nil {
			out.MinItems = int64(*param.MinItems)
		}
		if param.MaxItems != nil {
			out.MaxItems = int64(*param.MaxItems)
		}
	}

	return out
}

92 | func getGeminiType(paramType gollem.ParameterType) genai.Type { 93 | switch paramType { 94 | case gollem.TypeString: 95 | return genai.TypeString 96 | case gollem.TypeNumber: 97 | return genai.TypeNumber 98
| case gollem.TypeInteger: 99 | return genai.TypeInteger 100 | case gollem.TypeBoolean: 101 | return genai.TypeBoolean 102 | case gollem.TypeArray: 103 | return genai.TypeArray 104 | case gollem.TypeObject: 105 | return genai.TypeObject 106 | default: 107 | return genai.TypeString 108 | } 109 | } 110 | -------------------------------------------------------------------------------- /llm/gemini/convert_test.go: -------------------------------------------------------------------------------- 1 | package gemini_test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "cloud.google.com/go/vertexai/genai" 8 | "github.com/m-mizutani/gollem" 9 | "github.com/m-mizutani/gollem/llm/gemini" 10 | "github.com/m-mizutani/gt" 11 | ) 12 | 13 | type complexTool struct{} 14 | 15 | func (t *complexTool) Spec() gollem.ToolSpec { 16 | return gollem.ToolSpec{ 17 | Name: "complex_tool", 18 | Description: "A tool with complex parameter structure", 19 | Required: []string{"user", "items"}, 20 | Parameters: map[string]*gollem.Parameter{ 21 | "user": { 22 | Type: gollem.TypeObject, 23 | Required: []string{"name"}, 24 | Properties: map[string]*gollem.Parameter{ 25 | "name": { 26 | Type: gollem.TypeString, 27 | Description: "User's name", 28 | }, 29 | "address": { 30 | Type: gollem.TypeObject, 31 | Properties: map[string]*gollem.Parameter{ 32 | "street": { 33 | Type: gollem.TypeString, 34 | Description: "Street address", 35 | }, 36 | "city": { 37 | Type: gollem.TypeString, 38 | Description: "City name", 39 | }, 40 | }, 41 | }, 42 | }, 43 | }, 44 | "items": { 45 | Type: gollem.TypeArray, 46 | Items: &gollem.Parameter{ 47 | Type: gollem.TypeObject, 48 | Properties: map[string]*gollem.Parameter{ 49 | "id": { 50 | Type: gollem.TypeString, 51 | Description: "Item ID", 52 | }, 53 | "quantity": { 54 | Type: gollem.TypeNumber, 55 | Description: "Item quantity", 56 | }, 57 | }, 58 | }, 59 | }, 60 | }, 61 | } 62 | } 63 | 64 | func (t *complexTool) Run(ctx context.Context, args map[string]any) 
(map[string]any, error) { 65 | return nil, nil 66 | } 67 | 68 | func TestConvertTool(t *testing.T) { 69 | tool := &complexTool{} 70 | genaiTool := gemini.ConvertTool(tool) 71 | 72 | gt.Value(t, genaiTool.Name).Equal("complex_tool") 73 | gt.Value(t, genaiTool.Description).Equal("A tool with complex parameter structure") 74 | 75 | params := genaiTool.Parameters 76 | gt.Value(t, params.Type).Equal(genai.TypeObject) 77 | gt.Value(t, params.Required).Equal([]string{"user", "items"}) 78 | 79 | // Check user object 80 | user := params.Properties["user"] 81 | gt.Value(t, user.Type).Equal(genai.TypeObject) 82 | gt.Value(t, user.Properties["name"].Type).Equal(genai.TypeString) 83 | gt.Value(t, user.Properties["name"].Description).Equal("User's name") 84 | gt.Value(t, user.Required).Equal([]string{"name"}) 85 | 86 | // Check address object 87 | address := user.Properties["address"] 88 | gt.Value(t, address.Type).Equal(genai.TypeObject) 89 | gt.Value(t, address.Properties["street"].Type).Equal(genai.TypeString) 90 | gt.Value(t, address.Properties["city"].Type).Equal(genai.TypeString) 91 | 92 | // Check items array 93 | items := params.Properties["items"] 94 | gt.Value(t, items.Type).Equal(genai.TypeArray) 95 | gt.Value(t, items.Items.Type).Equal(genai.TypeObject) 96 | gt.Value(t, items.Items.Properties["id"].Type).Equal(genai.TypeString) 97 | gt.Value(t, items.Items.Properties["quantity"].Type).Equal(genai.TypeNumber) 98 | } 99 | 100 | func TestConvertParameterToSchema(t *testing.T) { 101 | t.Run("number constraints", func(t *testing.T) { 102 | p := &gollem.Parameter{ 103 | Type: gollem.TypeNumber, 104 | Minimum: ptr(1.0), 105 | Maximum: ptr(10.0), 106 | } 107 | schema := gemini.ConvertParameterToSchema(p) 108 | gt.Value(t, schema.Minimum).Equal(1.0) 109 | gt.Value(t, schema.Maximum).Equal(10.0) 110 | }) 111 | 112 | t.Run("string constraints", func(t *testing.T) { 113 | p := &gollem.Parameter{ 114 | Type: gollem.TypeString, 115 | MinLength: ptr(1), 116 | MaxLength: ptr(10), 
117 | Pattern: "^[a-z]+$", 118 | } 119 | schema := gemini.ConvertParameterToSchema(p) 120 | gt.Value(t, schema.MinLength).Equal(int64(1)) 121 | gt.Value(t, schema.MaxLength).Equal(int64(10)) 122 | gt.Value(t, schema.Pattern).Equal("^[a-z]+$") 123 | }) 124 | 125 | t.Run("array constraints", func(t *testing.T) { 126 | p := &gollem.Parameter{ 127 | Type: gollem.TypeArray, 128 | Items: &gollem.Parameter{Type: gollem.TypeString}, 129 | MinItems: ptr(1), 130 | MaxItems: ptr(10), 131 | } 132 | schema := gemini.ConvertParameterToSchema(p) 133 | gt.Value(t, schema.MinItems).Equal(int64(1)) 134 | gt.Value(t, schema.MaxItems).Equal(int64(10)) 135 | gt.Value(t, schema.Items.Type).Equal(genai.TypeString) 136 | }) 137 | } 138 | 139 | func ptr[T any](v T) *T { 140 | return &v 141 | } 142 | -------------------------------------------------------------------------------- /llm/gemini/embedding.go: -------------------------------------------------------------------------------- 1 | package gemini 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 7 | aiplatform "cloud.google.com/go/aiplatform/apiv1" 8 | "cloud.google.com/go/aiplatform/apiv1/aiplatformpb" 9 | "github.com/m-mizutani/goerr/v2" 10 | "google.golang.org/api/option" 11 | "google.golang.org/protobuf/types/known/structpb" 12 | ) 13 | 14 | // GenerateEmbedding generates embeddings for the given input text. 
15 | func (x *Client) GenerateEmbedding(ctx context.Context, dimension int, input []string) ([][]float64, error) { 16 | apiEndpoint := fmt.Sprintf("%s-aiplatform.googleapis.com:443", x.location) 17 | 18 | client, err := aiplatform.NewPredictionClient(ctx, option.WithEndpoint(apiEndpoint)) 19 | if err != nil { 20 | return nil, goerr.Wrap(err, "failed to create aiplatform client") 21 | } 22 | defer client.Close() 23 | 24 | endpoint := fmt.Sprintf("projects/%s/locations/%s/publishers/google/models/%s", x.projectID, x.location, x.embeddingModel) 25 | instances := make([]*structpb.Value, len(input)) 26 | 27 | for i, v := range input { 28 | instances[i] = structpb.NewStructValue(&structpb.Struct{ 29 | Fields: map[string]*structpb.Value{ 30 | "content": structpb.NewStringValue(v), 31 | "task_type": structpb.NewStringValue("QUESTION_ANSWERING"), 32 | }, 33 | }) 34 | } 35 | 36 | params := structpb.NewStructValue(&structpb.Struct{ 37 | Fields: map[string]*structpb.Value{ 38 | "outputDimensionality": structpb.NewNumberValue(float64(dimension)), 39 | }, 40 | }) 41 | 42 | req := &aiplatformpb.PredictRequest{ 43 | Endpoint: endpoint, 44 | Instances: instances, 45 | Parameters: params, 46 | } 47 | resp, err := client.Predict(ctx, req) 48 | if err != nil { 49 | return nil, goerr.Wrap(err, "failed to predict", 50 | goerr.V("endpoint", endpoint), 51 | goerr.V("dimensionality", dimension), 52 | ) 53 | } 54 | 55 | if len(resp.Predictions) == 0 { 56 | return nil, goerr.New("no predictions returned") 57 | } 58 | 59 | embeddings := make([][]float64, len(resp.Predictions)) 60 | for i, prediction := range resp.Predictions { 61 | values := prediction.GetStructValue().Fields["embeddings"].GetStructValue().Fields["values"].GetListValue().Values 62 | embedding := make([]float64, len(values)) 63 | for j, value := range values { 64 | embedding[j] = float64(value.GetNumberValue()) 65 | } 66 | embeddings[i] = embedding 67 | } 68 | 69 | return embeddings, nil 70 | } 71 | 
-------------------------------------------------------------------------------- /llm/gemini/embeding_test.go: -------------------------------------------------------------------------------- 1 | package gemini_test 2 | 3 | import ( 4 | "os" 5 | "testing" 6 | 7 | "github.com/m-mizutani/gollem/llm/gemini" 8 | "github.com/m-mizutani/gt" 9 | ) 10 | 11 | func TestGenerateEmbedding(t *testing.T) { 12 | projectID, ok := os.LookupEnv("TEST_GCP_PROJECT_ID") 13 | if !ok { 14 | t.Skip("TEST_GCP_PROJECT_ID is not set") 15 | } 16 | 17 | location, ok := os.LookupEnv("TEST_GCP_LOCATION") 18 | if !ok { 19 | t.Skip("TEST_GCP_LOCATION is not set") 20 | } 21 | 22 | ctx := t.Context() 23 | client, err := gemini.New(ctx, projectID, location) 24 | if err != nil { 25 | t.Fatalf("failed to create client: %v", err) 26 | } 27 | 28 | embeddings, err := client.GenerateEmbedding(ctx, 256, []string{"not, SANE", "Five timeless words"}) 29 | if err != nil { 30 | t.Fatalf("failed to generate embedding: %v", err) 31 | } 32 | 33 | gt.A(t, embeddings).Length(2). 34 | At(0, func(t testing.TB, v []float64) { 35 | gt.A(t, v).Longer(0) 36 | }). 
37 | At(1, func(t testing.TB, v []float64) { 38 | gt.A(t, v).Longer(0) 39 | }) 40 | } 41 | -------------------------------------------------------------------------------- /llm/gemini/export_test.go: -------------------------------------------------------------------------------- 1 | package gemini 2 | 3 | // Export convert functions for testing 4 | var ( 5 | ConvertTool = convertTool 6 | ConvertParameterToSchema = convertParameterToSchema 7 | ) 8 | -------------------------------------------------------------------------------- /llm/openai/convert.go: -------------------------------------------------------------------------------- 1 | package openai 2 | 3 | import ( 4 | "github.com/m-mizutani/gollem" 5 | "github.com/sashabaranov/go-openai" 6 | ) 7 | 8 | // convertTool converts gollem.Tool to openai.Tool 9 | func convertTool(tool gollem.Tool) openai.Tool { 10 | parameters := make(map[string]interface{}) 11 | properties := make(map[string]interface{}) 12 | spec := tool.Spec() 13 | 14 | for name, param := range spec.Parameters { 15 | properties[name] = convertParameterToSchema(param) 16 | } 17 | 18 | if len(properties) > 0 { 19 | parameters["type"] = "object" 20 | parameters["properties"] = properties 21 | if len(spec.Required) > 0 { 22 | parameters["required"] = spec.Required 23 | } 24 | } 25 | 26 | return openai.Tool{ 27 | Type: openai.ToolTypeFunction, 28 | Function: &openai.FunctionDefinition{ 29 | Name: spec.Name, 30 | Description: spec.Description, 31 | Parameters: parameters, 32 | }, 33 | } 34 | } 35 | 36 | // convertParameterToSchema converts gollem.Parameter to OpenAI schema 37 | func convertParameterToSchema(param *gollem.Parameter) map[string]interface{} { 38 | schema := map[string]interface{}{ 39 | "type": getOpenAIType(param.Type), 40 | "description": param.Description, 41 | "title": param.Title, 42 | } 43 | 44 | if len(param.Enum) > 0 { 45 | schema["enum"] = param.Enum 46 | } 47 | 48 | if param.Properties != nil { 49 | properties := 
make(map[string]interface{}) 50 | for name, prop := range param.Properties { 51 | properties[name] = convertParameterToSchema(prop) 52 | } 53 | schema["properties"] = properties 54 | if len(param.Required) > 0 { 55 | schema["required"] = param.Required 56 | } 57 | } 58 | 59 | if param.Items != nil { 60 | schema["items"] = convertParameterToSchema(param.Items) 61 | } 62 | 63 | // Add number constraints 64 | if param.Type == gollem.TypeNumber || param.Type == gollem.TypeInteger { 65 | if param.Minimum != nil { 66 | schema["minimum"] = *param.Minimum 67 | } 68 | if param.Maximum != nil { 69 | schema["maximum"] = *param.Maximum 70 | } 71 | } 72 | 73 | // Add string constraints 74 | if param.Type == gollem.TypeString { 75 | if param.MinLength != nil { 76 | schema["minLength"] = *param.MinLength 77 | } 78 | if param.MaxLength != nil { 79 | schema["maxLength"] = *param.MaxLength 80 | } 81 | if param.Pattern != "" { 82 | schema["pattern"] = param.Pattern 83 | } 84 | } 85 | 86 | // Add array constraints 87 | if param.Type == gollem.TypeArray { 88 | if param.MinItems != nil { 89 | schema["minItems"] = *param.MinItems 90 | } 91 | if param.MaxItems != nil { 92 | schema["maxItems"] = *param.MaxItems 93 | } 94 | } 95 | 96 | // Add default value 97 | if param.Default != nil { 98 | schema["default"] = param.Default 99 | } 100 | 101 | return schema 102 | } 103 | 104 | func getOpenAIType(paramType gollem.ParameterType) string { 105 | switch paramType { 106 | case gollem.TypeString: 107 | return "string" 108 | case gollem.TypeNumber: 109 | return "number" 110 | case gollem.TypeInteger: 111 | return "integer" 112 | case gollem.TypeBoolean: 113 | return "boolean" 114 | case gollem.TypeArray: 115 | return "array" 116 | case gollem.TypeObject: 117 | return "object" 118 | default: 119 | return "string" 120 | } 121 | } 122 | -------------------------------------------------------------------------------- /llm/openai/convert_test.go: 
-------------------------------------------------------------------------------- 1 | package openai 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/m-mizutani/gollem" 8 | "github.com/m-mizutani/gt" 9 | ) 10 | 11 | type complexTool struct{} 12 | 13 | func (t *complexTool) Spec() gollem.ToolSpec { 14 | return gollem.ToolSpec{ 15 | Name: "complex_tool", 16 | Description: "A tool with complex parameter structure", 17 | Parameters: map[string]*gollem.Parameter{ 18 | "user": { 19 | Type: gollem.TypeObject, 20 | Required: []string{"name"}, 21 | Properties: map[string]*gollem.Parameter{ 22 | "name": { 23 | Type: gollem.TypeString, 24 | Description: "User's name", 25 | }, 26 | "address": { 27 | Type: gollem.TypeObject, 28 | Properties: map[string]*gollem.Parameter{ 29 | "street": { 30 | Type: gollem.TypeString, 31 | Description: "Street address", 32 | }, 33 | "city": { 34 | Type: gollem.TypeString, 35 | Description: "City name", 36 | }, 37 | }, 38 | }, 39 | }, 40 | }, 41 | "items": { 42 | Type: gollem.TypeArray, 43 | Items: &gollem.Parameter{ 44 | Type: gollem.TypeObject, 45 | Properties: map[string]*gollem.Parameter{ 46 | "id": { 47 | Type: gollem.TypeString, 48 | Description: "Item ID", 49 | }, 50 | "quantity": { 51 | Type: gollem.TypeNumber, 52 | Description: "Item quantity", 53 | }, 54 | }, 55 | }, 56 | }, 57 | }, 58 | } 59 | } 60 | 61 | func (t *complexTool) Run(ctx context.Context, args map[string]any) (map[string]any, error) { 62 | return nil, nil 63 | } 64 | 65 | func TestConvertTool(t *testing.T) { 66 | tool := &complexTool{} 67 | openaiTool := ConvertTool(tool) 68 | 69 | gt.Value(t, openaiTool.Type).Equal("function") 70 | gt.Value(t, openaiTool.Function.Name).Equal("complex_tool") 71 | gt.Value(t, openaiTool.Function.Description).Equal("A tool with complex parameter structure") 72 | 73 | params := openaiTool.Function.Parameters.(map[string]interface{}) 74 | gt.Value(t, params["type"]).Equal("object") 75 | 76 | // Check user object 77 | user := 
params["properties"].(map[string]interface{})["user"].(map[string]interface{}) 78 | gt.Value(t, user["type"]).Equal("object") 79 | gt.Value(t, user["properties"].(map[string]interface{})["name"].(map[string]interface{})["type"]).Equal("string") 80 | gt.Value(t, user["properties"].(map[string]interface{})["name"].(map[string]interface{})["description"]).Equal("User's name") 81 | gt.Value(t, user["required"]).Equal([]string{"name"}) 82 | 83 | // Check address object 84 | address := user["properties"].(map[string]interface{})["address"].(map[string]interface{}) 85 | gt.Value(t, address["type"]).Equal("object") 86 | gt.Value(t, address["properties"].(map[string]interface{})["street"].(map[string]interface{})["type"]).Equal("string") 87 | gt.Value(t, address["properties"].(map[string]interface{})["city"].(map[string]interface{})["type"]).Equal("string") 88 | 89 | // Check items array 90 | items := params["properties"].(map[string]interface{})["items"].(map[string]interface{}) 91 | gt.Value(t, items["type"]).Equal("array") 92 | gt.Value(t, items["items"].(map[string]interface{})["type"]).Equal("object") 93 | gt.Value(t, items["items"].(map[string]interface{})["properties"].(map[string]interface{})["id"].(map[string]interface{})["type"]).Equal("string") 94 | gt.Value(t, items["items"].(map[string]interface{})["properties"].(map[string]interface{})["quantity"].(map[string]interface{})["type"]).Equal("number") 95 | } 96 | 97 | func TestConvertParameterToSchema(t *testing.T) { 98 | t.Run("number constraints", func(t *testing.T) { 99 | p := &gollem.Parameter{ 100 | Type: gollem.TypeNumber, 101 | Minimum: ptr(1.0), 102 | Maximum: ptr(10.0), 103 | } 104 | schema := convertParameterToSchema(p) 105 | gt.Value(t, schema["minimum"]).Equal(1.0) 106 | gt.Value(t, schema["maximum"]).Equal(10.0) 107 | }) 108 | 109 | t.Run("string constraints", func(t *testing.T) { 110 | p := &gollem.Parameter{ 111 | Type: gollem.TypeString, 112 | MinLength: ptr(1), 113 | MaxLength: ptr(10), 114 | 
Pattern: "^[a-z]+$", 115 | } 116 | schema := convertParameterToSchema(p) 117 | gt.Value(t, schema["minLength"]).Equal(1) 118 | gt.Value(t, schema["maxLength"]).Equal(10) 119 | gt.Value(t, schema["pattern"]).Equal("^[a-z]+$") 120 | }) 121 | 122 | t.Run("array constraints", func(t *testing.T) { 123 | p := &gollem.Parameter{ 124 | Type: gollem.TypeArray, 125 | Items: &gollem.Parameter{Type: gollem.TypeString}, 126 | MinItems: ptr(1), 127 | MaxItems: ptr(10), 128 | } 129 | schema := convertParameterToSchema(p) 130 | gt.Value(t, schema["minItems"]).Equal(1) 131 | gt.Value(t, schema["maxItems"]).Equal(10) 132 | gt.Value(t, schema["items"].(map[string]interface{})["type"]).Equal("string") 133 | }) 134 | 135 | t.Run("default value", func(t *testing.T) { 136 | p := &gollem.Parameter{ 137 | Type: gollem.TypeString, 138 | Default: "default value", 139 | } 140 | schema := convertParameterToSchema(p) 141 | gt.Value(t, schema["default"]).Equal("default value") 142 | }) 143 | } 144 | 145 | func ptr[T any](v T) *T { 146 | return &v 147 | } 148 | -------------------------------------------------------------------------------- /llm/openai/embedding.go: -------------------------------------------------------------------------------- 1 | package openai 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/m-mizutani/goerr/v2" 7 | "github.com/sashabaranov/go-openai" 8 | ) 9 | 10 | // GenerateEmbedding generates embeddings for the given input text. 
11 | func (c *Client) GenerateEmbedding(ctx context.Context, dimension int, input []string) ([][]float64, error) { 12 | /* 13 | AdaEmbeddingV2 EmbeddingModel = "text-embedding-ada-002" 14 | SmallEmbedding3 EmbeddingModel = "text-embedding-3-small" 15 | LargeEmbedding3 EmbeddingModel = "text-embedding-3-large" 16 | */ 17 | modelMap := map[string]openai.EmbeddingModel{ 18 | "text-embedding-ada-002": openai.AdaEmbeddingV2, 19 | "text-embedding-3-small": openai.SmallEmbedding3, 20 | "text-embedding-3-large": openai.LargeEmbedding3, 21 | } 22 | 23 | model, ok := modelMap[c.embeddingModel] 24 | if !ok { 25 | return nil, goerr.New("invalid or unsupported embedding model. See https://platform.openai.com/docs/guides/embeddings#embedding-models", goerr.V("model", c.embeddingModel)) 26 | } 27 | 28 | req := openai.EmbeddingRequest{ 29 | Input: input, 30 | Model: model, 31 | Dimensions: dimension, 32 | } 33 | 34 | resp, err := c.client.CreateEmbeddings(ctx, req) 35 | if err != nil { 36 | return nil, goerr.Wrap(err, "failed to create embedding") 37 | } 38 | 39 | if len(resp.Data) == 0 { 40 | return nil, goerr.New("no embedding data returned") 41 | } 42 | 43 | embeddings := make([][]float64, len(resp.Data)) 44 | for i, data := range resp.Data { 45 | embeddings[i] = make([]float64, len(data.Embedding)) 46 | for j, v := range data.Embedding { 47 | embeddings[i][j] = float64(v) 48 | } 49 | } 50 | 51 | return embeddings, nil 52 | } 53 | -------------------------------------------------------------------------------- /llm/openai/embedding_test.go: -------------------------------------------------------------------------------- 1 | package openai_test 2 | 3 | import ( 4 | "os" 5 | "testing" 6 | 7 | "github.com/m-mizutani/gollem/llm/openai" 8 | "github.com/m-mizutani/gt" 9 | ) 10 | 11 | func TestGenerateEmbedding(t *testing.T) { 12 | apiKey, ok := os.LookupEnv("TEST_OPENAI_API_KEY") 13 | if !ok { 14 | t.Skip("TEST_OPENAI_API_KEY is not set") 15 | } 16 | 17 | ctx := t.Context() 18 | 
client, err := openai.New(ctx, apiKey) 19 | if err != nil { 20 | t.Fatalf("failed to create client: %v", err) 21 | } 22 | 23 | embeddings, err := client.GenerateEmbedding(ctx, 256, []string{"not, SANE", "Five timeless words"}) 24 | if err != nil { 25 | t.Fatalf("failed to generate embedding: %v", err) 26 | } 27 | 28 | gt.A(t, embeddings).Length(2). 29 | At(0, func(t testing.TB, v []float64) { 30 | gt.A(t, v).Longer(0) 31 | }). 32 | At(1, func(t testing.TB, v []float64) { 33 | gt.A(t, v).Longer(0) 34 | }) 35 | } 36 | -------------------------------------------------------------------------------- /llm/openai/export_test.go: -------------------------------------------------------------------------------- 1 | package openai 2 | 3 | // Export convert functions for testing 4 | var ( 5 | ConvertTool = convertTool 6 | ConvertParameterToSchema = convertParameterToSchema 7 | ) 8 | -------------------------------------------------------------------------------- /llm_test.go: -------------------------------------------------------------------------------- 1 | package gollem_test 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "os" 7 | "testing" 8 | 9 | "github.com/m-mizutani/goerr/v2" 10 | "github.com/m-mizutani/gollem" 11 | "github.com/m-mizutani/gollem/llm/claude" 12 | "github.com/m-mizutani/gollem/llm/gemini" 13 | "github.com/m-mizutani/gollem/llm/openai" 14 | "github.com/m-mizutani/gt" 15 | ) 16 | 17 | // Sample tool implementation for testing 18 | type randomNumberTool struct{} 19 | 20 | func (t *randomNumberTool) Spec() gollem.ToolSpec { 21 | return gollem.ToolSpec{ 22 | Name: "random_number", 23 | Description: "A tool for generating random numbers within a specified range", 24 | Parameters: map[string]*gollem.Parameter{ 25 | "min": { 26 | Type: gollem.TypeNumber, 27 | Description: "Minimum value of the random number", 28 | }, 29 | "max": { 30 | Type: gollem.TypeNumber, 31 | Description: "Maximum value of the random number", 32 | }, 33 | }, 34 | Required: 
[]string{"min", "max"}, 35 | } 36 | } 37 | 38 | func (t *randomNumberTool) Run(ctx context.Context, args map[string]any) (map[string]any, error) { 39 | min, ok := args["min"].(float64) 40 | if !ok { 41 | return nil, goerr.New("min is required") 42 | } 43 | 44 | max, ok := args["max"].(float64) 45 | if !ok { 46 | return nil, goerr.New("max is required") 47 | } 48 | 49 | if min >= max { 50 | return nil, goerr.New("min must be less than max") 51 | } 52 | 53 | // Note: In real implementation, you would use a proper random number generator 54 | // This is just for testing purposes 55 | result := (min + max) / 2 56 | 57 | return map[string]any{"result": result}, nil 58 | } 59 | 60 | func testGenerateContent(t *testing.T, session gollem.Session) { 61 | ctx := t.Context() 62 | 63 | // Test case 1: Generate random number 64 | resp1, err := session.GenerateContent(ctx, gollem.Text("Please generate a random number between 1 and 10")) 65 | gt.NoError(t, err) 66 | gt.Array(t, resp1.FunctionCalls).Length(1).Required() 67 | gt.Value(t, resp1.FunctionCalls[0].Name).Equal("random_number") 68 | 69 | args := resp1.FunctionCalls[0].Arguments 70 | gt.Value(t, args["min"]).Equal(1.0) 71 | gt.Value(t, args["max"]).Equal(10.0) 72 | 73 | resp2, err := session.GenerateContent(ctx, gollem.FunctionResponse{ 74 | ID: resp1.FunctionCalls[0].ID, 75 | Name: "random_number", 76 | Data: map[string]any{"result": 5.5}, 77 | }) 78 | gt.NoError(t, err).Required() 79 | gt.Array(t, resp2.Texts).Length(1).Required() 80 | } 81 | 82 | func testGenerateStream(t *testing.T, session gollem.Session) { 83 | ctx := t.Context() 84 | 85 | t.Run("generate random number", func(t *testing.T) { 86 | stream, err := session.GenerateStream(ctx, gollem.Text("Please generate a random number between 1 and 10")) 87 | gt.NoError(t, err).Required() 88 | 89 | var id string 90 | for resp := range stream { 91 | gt.NoError(t, resp.Error).Required() 92 | 93 | if len(resp.FunctionCalls) > 0 { 94 | for _, functionCall := range 
resp.FunctionCalls { 95 | if functionCall.ID != "" { 96 | id = functionCall.ID 97 | } 98 | } 99 | } 100 | } 101 | 102 | stream, err = session.GenerateStream(ctx, gollem.FunctionResponse{ 103 | ID: id, 104 | Name: "random_number", 105 | Data: map[string]any{"result": 5.5}, 106 | }) 107 | gt.NoError(t, err).Required() 108 | for resp := range stream { 109 | gt.NoError(t, resp.Error).Required() 110 | } 111 | }) 112 | } 113 | 114 | func newGeminiClient(t *testing.T) gollem.LLMClient { 115 | var testProjectID, testLocation string 116 | v, ok := os.LookupEnv("TEST_GCP_PROJECT_ID") 117 | if !ok { 118 | t.Skip("TEST_GCP_PROJECT_ID is not set") 119 | } else { 120 | testProjectID = v 121 | } 122 | 123 | v, ok = os.LookupEnv("TEST_GCP_LOCATION") 124 | if !ok { 125 | t.Skip("TEST_GCP_LOCATION is not set") 126 | } else { 127 | testLocation = v 128 | } 129 | 130 | ctx := t.Context() 131 | client, err := gemini.New(ctx, testProjectID, testLocation) 132 | gt.NoError(t, err) 133 | return client 134 | } 135 | 136 | func newOpenAIClient(t *testing.T) gollem.LLMClient { 137 | apiKey, ok := os.LookupEnv("TEST_OPENAI_API_KEY") 138 | if !ok { 139 | t.Skip("TEST_OPENAI_API_KEY is not set") 140 | } 141 | 142 | ctx := t.Context() 143 | client, err := openai.New(ctx, apiKey) 144 | gt.NoError(t, err) 145 | return client 146 | } 147 | 148 | func newClaudeClient(t *testing.T) gollem.LLMClient { 149 | apiKey, ok := os.LookupEnv("TEST_CLAUDE_API_KEY") 150 | if !ok { 151 | t.Skip("TEST_CLAUDE_API_KEY is not set") 152 | } 153 | 154 | client, err := claude.New(context.Background(), apiKey) 155 | gt.NoError(t, err) 156 | return client 157 | } 158 | 159 | func TestGemini(t *testing.T) { 160 | client := newGeminiClient(t) 161 | 162 | // Setup tools 163 | tools := []gollem.Tool{&randomNumberTool{}} 164 | session, err := client.NewSession(t.Context(), gollem.WithSessionTools(tools...)) 165 | gt.NoError(t, err) 166 | 167 | t.Run("generate content", func(t *testing.T) { 168 | testGenerateContent(t, session) 
169 | }) 170 | t.Run("generate stream", func(t *testing.T) { 171 | testGenerateStream(t, session) 172 | }) 173 | } 174 | 175 | func TestOpenAI(t *testing.T) { 176 | client := newOpenAIClient(t) 177 | 178 | // Setup tools 179 | tools := []gollem.Tool{&randomNumberTool{}} 180 | session, err := client.NewSession(t.Context(), gollem.WithSessionTools(tools...)) 181 | gt.NoError(t, err) 182 | 183 | t.Run("generate content", func(t *testing.T) { 184 | testGenerateContent(t, session) 185 | }) 186 | t.Run("generate stream", func(t *testing.T) { 187 | testGenerateStream(t, session) 188 | }) 189 | } 190 | 191 | func TestClaude(t *testing.T) { 192 | client := newClaudeClient(t) 193 | 194 | session, err := client.NewSession(context.Background(), gollem.WithSessionTools(&randomNumberTool{})) 195 | gt.NoError(t, err) 196 | 197 | t.Run("generate content", func(t *testing.T) { 198 | testGenerateContent(t, session) 199 | }) 200 | t.Run("generate stream", func(t *testing.T) { 201 | testGenerateStream(t, session) 202 | }) 203 | } 204 | 205 | type weatherTool struct { 206 | name string 207 | } 208 | 209 | func (x *weatherTool) Spec() gollem.ToolSpec { 210 | return gollem.ToolSpec{ 211 | Name: x.name, 212 | Description: "get weather information of a region", 213 | Parameters: map[string]*gollem.Parameter{ 214 | "region": { 215 | Type: gollem.TypeString, 216 | Description: "Region name", 217 | }, 218 | }, 219 | } 220 | } 221 | 222 | func (t *weatherTool) Run(ctx context.Context, input map[string]any) (map[string]any, error) { 223 | return map[string]any{ 224 | "weather": "sunny", 225 | }, nil 226 | } 227 | 228 | func TestCallToolNameConvention(t *testing.T) { 229 | if _, ok := os.LookupEnv("TEST_FLAG_TOOL_NAME_CONVENTION"); !ok { 230 | t.Skip("TEST_FLAG_TOOL_NAME_CONVENTION is not set") 231 | } 232 | 233 | testFunc := func(t *testing.T, client gollem.LLMClient) { 234 | testCases := map[string]struct { 235 | name string 236 | isError bool 237 | }{ 238 | "low case is allowed": { 239 | 
name: "test", 240 | isError: false, 241 | }, 242 | "upper case is allowed": { 243 | name: "TEST", 244 | isError: false, 245 | }, 246 | "underscore is allowed": { 247 | name: "test_tool", 248 | isError: false, 249 | }, 250 | "number is allowed": { 251 | name: "test123", 252 | isError: false, 253 | }, 254 | "hyphen is allowed": { 255 | name: "test-tool", 256 | isError: false, 257 | }, 258 | /* 259 | SKIP: OpenAI, Claude does not allow dot in tool name, but Gemini allows it. 260 | "dot is not allowed": { 261 | name: "test.tool", 262 | isError: true, 263 | }, 264 | */ 265 | "comma is not allowed": { 266 | name: "test,tool", 267 | isError: true, 268 | }, 269 | "colon is not allowed": { 270 | name: "test:tool", 271 | isError: true, 272 | }, 273 | "space is not allowed": { 274 | name: "test tool", 275 | isError: true, 276 | }, 277 | } 278 | 279 | for name, tc := range testCases { 280 | t.Run(name, func(t *testing.T) { 281 | ctx := t.Context() 282 | tool := &weatherTool{name: tc.name} 283 | 284 | session, err := client.NewSession(ctx, gollem.WithSessionTools(tool)) 285 | gt.NoError(t, err) 286 | 287 | resp, err := session.GenerateContent(ctx, gollem.Text("What is the weather in Tokyo?")) 288 | if tc.isError { 289 | gt.Error(t, err) 290 | return 291 | } 292 | gt.NoError(t, err).Required() 293 | if len(resp.FunctionCalls) > 0 { 294 | gt.A(t, resp.FunctionCalls).Length(1).At(0, func(t testing.TB, v *gollem.FunctionCall) { 295 | gt.Equal(t, v.Name, tc.name) 296 | }) 297 | } 298 | }) 299 | } 300 | } 301 | 302 | t.Run("OpenAI", func(t *testing.T) { 303 | ctx := t.Context() 304 | apiKey, ok := os.LookupEnv("TEST_OPENAI_API_KEY") 305 | if !ok { 306 | t.Skip("TEST_OPENAI_API_KEY is not set") 307 | } 308 | 309 | client, err := openai.New(ctx, apiKey) 310 | gt.NoError(t, err) 311 | testFunc(t, client) 312 | }) 313 | 314 | t.Run("gemini", func(t *testing.T) { 315 | ctx := t.Context() 316 | projectID, ok := os.LookupEnv("TEST_GCP_PROJECT_ID") 317 | if !ok { 318 | 
t.Skip("TEST_GCP_PROJECT_ID is not set") 319 | } 320 | 321 | location, ok := os.LookupEnv("TEST_GCP_LOCATION") 322 | if !ok { 323 | t.Skip("TEST_GCP_LOCATION is not set") 324 | } 325 | 326 | client, err := gemini.New(ctx, projectID, location) 327 | gt.NoError(t, err) 328 | testFunc(t, client) 329 | }) 330 | 331 | t.Run("claude", func(t *testing.T) { 332 | ctx := t.Context() 333 | apiKey, ok := os.LookupEnv("TEST_CLAUDE_API_KEY") 334 | if !ok { 335 | t.Skip("TEST_CLAUDE_API_KEY is not set") 336 | } 337 | 338 | client, err := claude.New(ctx, apiKey) 339 | gt.NoError(t, err) 340 | testFunc(t, client) 341 | }) 342 | } 343 | 344 | func TestSessionHistory(t *testing.T) { 345 | testFn := func(t *testing.T, client gollem.LLMClient) { 346 | ctx := t.Context() 347 | session, err := client.NewSession(ctx, gollem.WithSessionTools(&weatherTool{name: "weather"})) 348 | gt.NoError(t, err).Required() 349 | 350 | resp1, err := session.GenerateContent(ctx, gollem.Text("What is the weather in Tokyo?")) 351 | gt.NoError(t, err).Required() 352 | gt.A(t, resp1.FunctionCalls).Length(1).At(0, func(t testing.TB, v *gollem.FunctionCall) { 353 | gt.Equal(t, v.Name, "weather") 354 | }) 355 | 356 | resp2, err := session.GenerateContent(ctx, gollem.FunctionResponse{ 357 | ID: resp1.FunctionCalls[0].ID, 358 | Name: "weather", 359 | Data: map[string]any{"weather": "sunny"}, 360 | }) 361 | gt.NoError(t, err).Required() 362 | gt.A(t, resp2.Texts).Length(1).At(0, func(t testing.TB, v string) { 363 | gt.S(t, v).Contains("sunny") 364 | }) 365 | 366 | history := session.History() 367 | rawData, err := json.Marshal(history) 368 | gt.NoError(t, err).Required() 369 | 370 | var restored gollem.History 371 | gt.NoError(t, json.Unmarshal(rawData, &restored)) 372 | 373 | newSession, err := client.NewSession(ctx, gollem.WithSessionHistory(&restored)) 374 | gt.NoError(t, err) 375 | 376 | resp3, err := newSession.GenerateContent(ctx, gollem.Text("Do you remember the weather in Tokyo?")) 377 | gt.NoError(t, 
err).Required() 378 | 379 | gt.A(t, resp3.Texts).Longer(0).At(0, func(t testing.TB, v string) { 380 | gt.S(t, v).Contains("sunny") 381 | }) 382 | } 383 | 384 | t.Run("OpenAI", func(t *testing.T) { 385 | client := newOpenAIClient(t) 386 | testFn(t, client) 387 | }) 388 | 389 | t.Run("gemini", func(t *testing.T) { 390 | client := newGeminiClient(t) 391 | testFn(t, client) 392 | }) 393 | 394 | t.Run("claude", func(t *testing.T) { 395 | client := newClaudeClient(t) 396 | testFn(t, client) 397 | }) 398 | } 399 | 400 | func TestExitTool(t *testing.T) { 401 | testFn := func(t *testing.T, newClient func(t *testing.T) gollem.LLMClient) { 402 | client := newClient(t) 403 | 404 | exitTool := &gollem.DefaultExitTool{} 405 | loopCount := 0 406 | exitToolCalled := false 407 | 408 | s := gollem.New(client, 409 | gollem.WithExitTool(exitTool), 410 | gollem.WithTools(&randomNumberTool{}), 411 | gollem.WithSystemPrompt("You are an assistant that can use tools. When asked to complete a task and end the session, you must use the finalize_task tool to properly end the session."), 412 | gollem.WithLoopHook(func(ctx context.Context, loop int, input []gollem.Input) error { 413 | loopCount++ 414 | t.Logf("Loop called: %d", loop) 415 | return nil 416 | }), 417 | gollem.WithMessageHook(func(ctx context.Context, msg string) error { 418 | t.Logf("[Message received]: %s", msg) 419 | return nil 420 | }), 421 | gollem.WithToolRequestHook(func(ctx context.Context, tool gollem.FunctionCall) error { 422 | t.Logf("[Tool request received]: %s (%v)", tool.Name, tool.Arguments) 423 | return nil 424 | }), 425 | gollem.WithLoopLimit(10), 426 | ) 427 | 428 | ctx := t.Context() 429 | _, err := s.Prompt(ctx, "Get a random number between 1 and 10") 430 | gt.NoError(t, err) 431 | 432 | t.Logf("Test completed: exitToolCalled=%v, isCompleted=%v, loopCount=%d", exitToolCalled, exitTool.IsCompleted(), loopCount) 433 | 434 | // Verify that the exit tool was called 435 | gt.True(t, exitTool.IsCompleted()) 436 | 437 
| // Verify that loops occurred (should be more than 0 but less than loop limit) 438 | gt.N(t, loopCount).Greater(0).Less(10) 439 | } 440 | 441 | t.Run("OpenAI", func(t *testing.T) { 442 | testFn(t, newOpenAIClient) 443 | }) 444 | 445 | t.Run("Gemini", func(t *testing.T) { 446 | testFn(t, newGeminiClient) 447 | }) 448 | 449 | t.Run("Claude", func(t *testing.T) { 450 | testFn(t, newClaudeClient) 451 | }) 452 | } 453 | -------------------------------------------------------------------------------- /mcp/export_test.go: -------------------------------------------------------------------------------- 1 | package mcp 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/mark3labs/mcp-go/mcp" 7 | ) 8 | 9 | func NewLocalMCPClient(path string) *Client { 10 | client := &Client{ 11 | path: path, 12 | } 13 | return client 14 | } 15 | 16 | func (x *Client) Start(ctx context.Context) error { 17 | return x.init(ctx) 18 | } 19 | 20 | func (x *Client) ListTools(ctx context.Context) ([]mcp.Tool, error) { 21 | return x.listTools(ctx) 22 | } 23 | 24 | func (x *Client) CallTool(ctx context.Context, name string, args map[string]any) (*mcp.CallToolResult, error) { 25 | return x.callTool(ctx, name, args) 26 | } 27 | 28 | var ( 29 | InputSchemaToParameter = inputSchemaToParameter 30 | MCPContentToMap = mcpContentToMap 31 | JSONSchemaToParameter = jsonSchemaToParameter 32 | ) 33 | -------------------------------------------------------------------------------- /mcp/mcp.go: -------------------------------------------------------------------------------- 1 | package mcp 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "encoding/json" 7 | "fmt" 8 | "sync" 9 | 10 | "github.com/m-mizutani/goerr/v2" 11 | "github.com/m-mizutani/gollem" 12 | "github.com/mark3labs/mcp-go/client" 13 | "github.com/mark3labs/mcp-go/client/transport" 14 | "github.com/mark3labs/mcp-go/mcp" 15 | "github.com/santhosh-tekuri/jsonschema/v6" 16 | ) 17 | 18 | type Client struct { 19 | // For local MCP server 20 | path string 21 | 
args []string 22 | envVars []string 23 | 24 | // For remote MCP server 25 | baseURL string 26 | headers map[string]string 27 | 28 | // Common client 29 | client *client.Client 30 | 31 | initResult *mcp.InitializeResult 32 | initMutex sync.Mutex 33 | } 34 | 35 | // Specs implements gollem.ToolSet interface 36 | func (c *Client) Specs(ctx context.Context) ([]gollem.ToolSpec, error) { 37 | logger := gollem.LoggerFromContext(ctx) 38 | 39 | tools, err := c.listTools(ctx) 40 | if err != nil { 41 | return nil, goerr.Wrap(err, "failed to list tools") 42 | } 43 | 44 | specs := make([]gollem.ToolSpec, len(tools)) 45 | toolNames := make([]string, len(tools)) 46 | 47 | for i, tool := range tools { 48 | toolNames[i] = tool.Name 49 | 50 | param, err := inputSchemaToParameter(tool.InputSchema) 51 | if err != nil { 52 | return nil, goerr.Wrap(err, 53 | "failed to convert input schema to parameter", 54 | goerr.V("tool.name", tool.Name), 55 | goerr.V("tool.inputSchema", tool.InputSchema), 56 | ) 57 | } 58 | 59 | specs[i] = gollem.ToolSpec{ 60 | Name: tool.Name, 61 | Description: tool.Description, 62 | Parameters: param.Properties, 63 | Required: param.Required, 64 | } 65 | } 66 | 67 | logger.Debug("found MCP tools", "names", toolNames) 68 | 69 | return specs, nil 70 | } 71 | 72 | // Run implements gollem.ToolSet interface 73 | func (c *Client) Run(ctx context.Context, name string, args map[string]any) (map[string]any, error) { 74 | logger := gollem.LoggerFromContext(ctx) 75 | 76 | logger.Debug("call MCP tool", "name", name, "args", args) 77 | 78 | resp, err := c.callTool(ctx, name, args) 79 | if err != nil { 80 | return nil, goerr.Wrap(err, "failed to call tool") 81 | } 82 | 83 | return mcpContentToMap(resp.Content), nil 84 | } 85 | 86 | // StdioOption is the option for the MCP client for local MCP executable server via stdio. 87 | type StdioOption func(*Client) 88 | 89 | // WithEnvVars sets the environment variables for the MCP client. 
It appends the environment variables to the existing ones. 90 | func WithEnvVars(envVars []string) StdioOption { 91 | return func(m *Client) { 92 | m.envVars = append(m.envVars, envVars...) 93 | } 94 | } 95 | 96 | // NewStdio creates a new MCP client for local MCP executable server via stdio. 97 | func NewStdio(ctx context.Context, path string, args []string, options ...StdioOption) (*Client, error) { 98 | client := &Client{ 99 | path: path, 100 | args: args, 101 | } 102 | for _, option := range options { 103 | option(client) 104 | } 105 | 106 | if err := client.init(ctx); err != nil { 107 | return nil, goerr.Wrap(err, "failed to initialize MCP client") 108 | } 109 | 110 | return client, nil 111 | } 112 | 113 | // SSEOption is the option for the MCP client for remote MCP server via HTTP SSE. 114 | type SSEOption func(*Client) 115 | 116 | // WithHeaders sets the headers for the MCP client. It replaces the existing headers setting. 117 | func WithHeaders(headers map[string]string) SSEOption { 118 | return func(m *Client) { 119 | m.headers = headers 120 | } 121 | } 122 | 123 | // NewSSE creates a new MCP client for remote MCP server via HTTP SSE. 124 | func NewSSE(ctx context.Context, baseURL string, options ...SSEOption) (*Client, error) { 125 | client := &Client{ 126 | baseURL: baseURL, 127 | } 128 | for _, option := range options { 129 | option(client) 130 | } 131 | 132 | if err := client.init(ctx); err != nil { 133 | return nil, goerr.Wrap(err, "failed to initialize MCP client") 134 | } 135 | 136 | return client, nil 137 | } 138 | 139 | func (c *Client) init(ctx context.Context) error { 140 | c.initMutex.Lock() 141 | defer c.initMutex.Unlock() 142 | 143 | logger := gollem.LoggerFromContext(ctx) 144 | 145 | if c.initResult != nil { 146 | return nil 147 | } 148 | 149 | var tp transport.Interface 150 | if c.path != "" { 151 | tp = transport.NewStdio(c.path, c.envVars, c.args...) 
152 | } 153 | 154 | if c.baseURL != "" { 155 | sse, err := transport.NewSSE(c.baseURL, transport.WithHeaders(c.headers)) 156 | if err != nil { 157 | return goerr.Wrap(err, "failed to create SSE transport") 158 | } 159 | tp = sse 160 | } 161 | 162 | if tp == nil { 163 | return goerr.New("no transport") 164 | } 165 | 166 | mcpClient := client.NewClient(tp) // kept local until Start and Initialize both succeed, so a failed init never leaves a half-started client in c.client 167 | logger.Debug("starting MCP client", "path", c.path, "url", c.baseURL) 168 | if err := mcpClient.Start(ctx); err != nil { 169 | return goerr.Wrap(err, "failed to start MCP client") 170 | } 171 | 172 | var initRequest mcp.InitializeRequest 173 | initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION 174 | initRequest.Params.ClientInfo = mcp.Implementation{ 175 | Name: "gollem", 176 | Version: "0.0.1", 177 | } 178 | logger.Debug("initializing MCP client") 179 | resp, err := mcpClient.Initialize(ctx, initRequest) 180 | if err != nil { 181 | _ = mcpClient.Close() // best-effort: stop the already-started transport (e.g. a stdio subprocess) so it does not leak, since constructors return nil on error; the Initialize error is the one reported 182 | return goerr.Wrap(err, "failed to initialize MCP client") 183 | } 184 | c.client, c.initResult = mcpClient, resp 185 | return nil 186 | } 187 | 188 | func (c *Client) listTools(ctx context.Context) ([]mcp.Tool, error) { 189 | // ListTools is thread safe 190 | resp, err := c.client.ListTools(ctx, mcp.ListToolsRequest{}) 191 | if err != nil { 192 | return nil, goerr.Wrap(err, "failed to list tools") 193 | } 194 | return resp.Tools, nil 195 | } 196 | 197 | func (c *Client) callTool(ctx context.Context, name string, args map[string]any) (*mcp.CallToolResult, error) { 198 | req := mcp.CallToolRequest{} 199 | req.Params.Name = name 200 | req.Params.Arguments = args 201 | resp, err := c.client.CallTool(ctx, req) 202 | if err != nil { 203 | return nil, goerr.Wrap(err, "failed to call tool") 204 | } 205 | return resp, nil 206 | } 207 | 208 | // Close releases the underlying MCP transport. It is a no-op when the client was never successfully initialized. 209 | func (c *Client) Close() error { 210 | if c.client == nil { 211 | return nil 212 | } 213 | if err := c.client.Close(); err != nil { 214 | return goerr.Wrap(err, "failed to close MCP client") 215 | } 216 | return nil 217 | } 218 | 219 | func inputSchemaToParameter(inputSchema
mcp.ToolInputSchema) (*gollem.Parameter, error) { 220 | parameters := map[string]*gollem.Parameter{} 221 | jsonSchema, err := json.Marshal(inputSchema) 222 | if err != nil { 223 | return nil, goerr.Wrap(err, "failed to marshal input schema") 224 | } 225 | 226 | rawSchema, err := jsonschema.UnmarshalJSON(bytes.NewReader(jsonSchema)) 227 | if err != nil { 228 | return nil, goerr.Wrap(err, "failed to unmarshal input schema") 229 | } 230 | 231 | c := jsonschema.NewCompiler() 232 | if err := c.AddResource("schema.json", rawSchema); err != nil { 233 | return nil, goerr.Wrap(err, "failed to add resource to compiler") 234 | } 235 | schema, err := c.Compile("schema.json") 236 | if err != nil { 237 | return nil, goerr.Wrap(err, "failed to compile input schema") 238 | } 239 | var schemaType []string // stays empty when the schema declares no "type" (schema.Types may then be nil, and calling ToStrings on it would panic); the check below rejects that case 240 | if schema.Types != nil { 241 | schemaType = schema.Types.ToStrings() 242 | } 243 | if len(schemaType) != 1 || schemaType[0] != "object" { 244 | return nil, goerr.Wrap(gollem.ErrInvalidTool, "invalid input schema", goerr.V("schema", schema)) 245 | } 246 | for name, property := range schema.Properties { 247 | parameters[name] = jsonSchemaToParameter(property) 248 | } 249 | return &gollem.Parameter{ 250 | Type: gollem.ParameterType(schemaType[0]), 251 | Title: schema.Title, 252 | Description: schema.Description, 253 | Required: schema.Required, 254 | Properties: parameters, 255 | }, nil 256 | } 257 | 258 | func jsonSchemaToParameter(schema *jsonschema.Schema) *gollem.Parameter { 259 | var enum []string 260 | if schema.Enum != nil { 261 | for _, v := range schema.Enum.Values { 262 | enum = append(enum, fmt.Sprintf("%v", v)) 263 | } 264 | } 265 | properties := map[string]*gollem.Parameter{} 266 | for name, property := range schema.Properties { 267 | properties[name] = jsonSchemaToParameter(property) 268 | } 269 | var items *gollem.Parameter 270 | if schema.Items != nil { 271 | switch v := schema.Items.(type) { 272 | case *jsonschema.Schema: 273 | items = jsonSchemaToParameter(v) 274 | } 275 | } 276 | if items == nil && schema.Items2020 != nil { // a schema compiled as draft 2020-12 (presumably the compiler default when no "$schema" is given — confirm) exposes "items" as Items2020, not Items 277 | items = jsonSchemaToParameter(schema.Items2020) 278 | } 279
| var minimum, maximum *float64 280 | if schema.Minimum != nil { 281 | min, _ := (*schema.Minimum).Float64() 282 | minimum = &min 283 | } 284 | if schema.Maximum != nil { 285 | max, _ := (*schema.Maximum).Float64() 286 | maximum = &max 287 | } 288 | var minLength, maxLength *int 289 | if schema.MinLength != nil { 290 | min := int(*schema.MinLength) 291 | minLength = &min 292 | } 293 | if schema.MaxLength != nil { 294 | max := int(*schema.MaxLength) 295 | maxLength = &max 296 | } 297 | var minItems, maxItems *int 298 | if schema.MinItems != nil { 299 | min := int(*schema.MinItems) 300 | minItems = &min 301 | } 302 | if schema.MaxItems != nil { 303 | max := int(*schema.MaxItems) 304 | maxItems = &max 305 | } 306 | var pattern string 307 | if schema.Pattern != nil { 308 | pattern = schema.Pattern.String() 309 | } 310 | var paramType gollem.ParameterType // left empty when the property declares no "type" (e.g. only anyOf/oneOf): schema.Types may then be nil or empty, and indexing ToStrings()[0] would panic 311 | if schema.Types != nil && len(schema.Types.ToStrings()) > 0 { 312 | paramType = gollem.ParameterType(schema.Types.ToStrings()[0]) 313 | } 314 | return &gollem.Parameter{ 315 | Type: paramType, 316 | Title: schema.Title, 317 | Description: schema.Description, 318 | Required: schema.Required, 319 | Enum: enum, 320 | Properties: properties, 321 | Items: items, 322 | Minimum: minimum, 323 | Maximum: maximum, 324 | MinLength: minLength, 325 | MaxLength: maxLength, 326 | Pattern: pattern, 327 | MinItems: minItems, 328 | MaxItems: maxItems, 329 | Default: schema.Default, // NOTE(review): if jsonschema exposes Default as *any, an absent default becomes a non-nil interface wrapping a nil pointer (serialized as null) — confirm and dereference 330 | } 331 | } 332 | 333 | func mcpContentToMap(contents []mcp.Content) map[string]any { 334 | if len(contents) == 0 { 335 | return nil 336 | } 337 | 338 | if len(contents) == 1 { 339 | if content, ok := contents[0].(mcp.TextContent); ok { 340 | var v any 341 | if err := json.Unmarshal([]byte(content.Text), &v); err == nil { 342 | if mapData, ok := v.(map[string]any); ok { 343 | return mapData 344 | } 345 | } 346 | return map[string]any{ 347 | "result": content.Text, 348 | } 349 | } 350 | return nil 351 | } 352 | 353 | result := map[string]any{} 354 | for i, c := range contents { 355 | if content, ok := c.(mcp.TextContent); ok { 356 | result[fmt.Sprintf("content_%d",
i+1)] = content.Text 357 | } 358 | } 359 | return result 360 | } 361 | -------------------------------------------------------------------------------- /mcp/mcp_test.go: -------------------------------------------------------------------------------- 1 | package mcp_test 2 | 3 | import ( 4 | "context" 5 | "os" 6 | "testing" 7 | 8 | "github.com/m-mizutani/gollem/mcp" 9 | "github.com/m-mizutani/gt" 10 | mcpgo "github.com/mark3labs/mcp-go/mcp" 11 | ) 12 | 13 | func TestMCPLocalDryRun(t *testing.T) { 14 | mcpExecPath, ok := os.LookupEnv("TEST_MCP_EXEC_PATH") 15 | if !ok { 16 | t.Skip("TEST_MCP_EXEC_PATH is not set") 17 | } 18 | 19 | client := mcp.NewLocalMCPClient(mcpExecPath) 20 | 21 | err := client.Start(context.Background()) 22 | gt.NoError(t, err) 23 | 24 | tools, err := client.ListTools(context.Background()) 25 | gt.NoError(t, err) 26 | gt.A(t, tools).Longer(0) 27 | 28 | parameter, err := mcp.InputSchemaToParameter(tools[0].InputSchema) 29 | gt.NoError(t, err) 30 | t.Log("parameter:", parameter) 31 | 32 | tool := tools[0] 33 | 34 | t.Log("tool:", tool) 35 | 36 | callTool, err := client.CallTool(context.Background(), tool.Name, map[string]any{ 37 | "length": 10, 38 | }) 39 | gt.NoError(t, err) 40 | 41 | t.Log("callTool:", callTool) 42 | } 43 | 44 | func TestMCPContentToMap(t *testing.T) { 45 | t.Run("when content is empty", func(t *testing.T) { 46 | result := mcp.MCPContentToMap([]mcpgo.Content{}) 47 | gt.Nil(t, result) 48 | }) 49 | 50 | t.Run("when text content is JSON", func(t *testing.T) { 51 | content := mcpgo.TextContent{Text: `{"key": "value"}`} 52 | result := mcp.MCPContentToMap([]mcpgo.Content{content}) 53 | gt.Equal(t, map[string]any{"key": "value"}, result) 54 | }) 55 | 56 | t.Run("when text content is not JSON", func(t *testing.T) { 57 | content := mcpgo.TextContent{Text: "plain text"} 58 | result := mcp.MCPContentToMap([]mcpgo.Content{content}) 59 | gt.Equal(t, map[string]any{"result": "plain text"}, result) 60 | }) 61 | 62 | t.Run("when multiple 
contents exist", func(t *testing.T) { 63 | contents := []mcpgo.Content{ 64 | mcpgo.TextContent{Text: "first"}, 65 | mcpgo.TextContent{Text: "second"}, 66 | } 67 | result := mcp.MCPContentToMap(contents) 68 | gt.Equal(t, map[string]any{ 69 | "content_1": "first", 70 | "content_2": "second", 71 | }, result) 72 | }) 73 | } 74 | -------------------------------------------------------------------------------- /mock/mock_gen.go: -------------------------------------------------------------------------------- 1 | // Code generated by moq; DO NOT EDIT. 2 | // github.com/matryer/moq 3 | 4 | package mock 5 | 6 | import ( 7 | "context" 8 | "github.com/m-mizutani/gollem" 9 | "sync" 10 | ) 11 | 12 | // LLMClientMock is a mock implementation of gollem.LLMClient. 13 | // 14 | // func TestSomethingThatUsesLLMClient(t *testing.T) { 15 | // 16 | // // make and configure a mocked gollem.LLMClient 17 | // mockedLLMClient := &LLMClientMock{ 18 | // GenerateEmbeddingFunc: func(ctx context.Context, dimension int, input []string) ([][]float64, error) { 19 | // panic("mock out the GenerateEmbedding method") 20 | // }, 21 | // NewSessionFunc: func(ctx context.Context, options ...gollem.SessionOption) (gollem.Session, error) { 22 | // panic("mock out the NewSession method") 23 | // }, 24 | // } 25 | // 26 | // // use mockedLLMClient in code that requires gollem.LLMClient 27 | // // and then make assertions. 28 | // 29 | // } 30 | type LLMClientMock struct { 31 | // GenerateEmbeddingFunc mocks the GenerateEmbedding method. 32 | GenerateEmbeddingFunc func(ctx context.Context, dimension int, input []string) ([][]float64, error) 33 | 34 | // NewSessionFunc mocks the NewSession method. 35 | NewSessionFunc func(ctx context.Context, options ...gollem.SessionOption) (gollem.Session, error) 36 | 37 | // calls tracks calls to the methods. 38 | calls struct { 39 | // GenerateEmbedding holds details about calls to the GenerateEmbedding method. 
40 | GenerateEmbedding []struct { 41 | // Ctx is the ctx argument value. 42 | Ctx context.Context 43 | // Dimension is the dimension argument value. 44 | Dimension int 45 | // Input is the input argument value. 46 | Input []string 47 | } 48 | // NewSession holds details about calls to the NewSession method. 49 | NewSession []struct { 50 | // Ctx is the ctx argument value. 51 | Ctx context.Context 52 | // Options is the options argument value. 53 | Options []gollem.SessionOption 54 | } 55 | } 56 | lockGenerateEmbedding sync.RWMutex 57 | lockNewSession sync.RWMutex 58 | } 59 | 60 | // GenerateEmbedding calls GenerateEmbeddingFunc. 61 | func (mock *LLMClientMock) GenerateEmbedding(ctx context.Context, dimension int, input []string) ([][]float64, error) { 62 | callInfo := struct { 63 | Ctx context.Context 64 | Dimension int 65 | Input []string 66 | }{ 67 | Ctx: ctx, 68 | Dimension: dimension, 69 | Input: input, 70 | } 71 | mock.lockGenerateEmbedding.Lock() 72 | mock.calls.GenerateEmbedding = append(mock.calls.GenerateEmbedding, callInfo) 73 | mock.lockGenerateEmbedding.Unlock() 74 | if mock.GenerateEmbeddingFunc == nil { 75 | var ( 76 | float64ssOut [][]float64 77 | errOut error 78 | ) 79 | return float64ssOut, errOut 80 | } 81 | return mock.GenerateEmbeddingFunc(ctx, dimension, input) 82 | } 83 | 84 | // GenerateEmbeddingCalls gets all the calls that were made to GenerateEmbedding. 85 | // Check the length with: 86 | // 87 | // len(mockedLLMClient.GenerateEmbeddingCalls()) 88 | func (mock *LLMClientMock) GenerateEmbeddingCalls() []struct { 89 | Ctx context.Context 90 | Dimension int 91 | Input []string 92 | } { 93 | var calls []struct { 94 | Ctx context.Context 95 | Dimension int 96 | Input []string 97 | } 98 | mock.lockGenerateEmbedding.RLock() 99 | calls = mock.calls.GenerateEmbedding 100 | mock.lockGenerateEmbedding.RUnlock() 101 | return calls 102 | } 103 | 104 | // NewSession calls NewSessionFunc. 
105 | func (mock *LLMClientMock) NewSession(ctx context.Context, options ...gollem.SessionOption) (gollem.Session, error) { 106 | callInfo := struct { 107 | Ctx context.Context 108 | Options []gollem.SessionOption 109 | }{ 110 | Ctx: ctx, 111 | Options: options, 112 | } 113 | mock.lockNewSession.Lock() 114 | mock.calls.NewSession = append(mock.calls.NewSession, callInfo) 115 | mock.lockNewSession.Unlock() 116 | if mock.NewSessionFunc == nil { 117 | var ( 118 | sessionOut gollem.Session 119 | errOut error 120 | ) 121 | return sessionOut, errOut 122 | } 123 | return mock.NewSessionFunc(ctx, options...) 124 | } 125 | 126 | // NewSessionCalls gets all the calls that were made to NewSession. 127 | // Check the length with: 128 | // 129 | // len(mockedLLMClient.NewSessionCalls()) 130 | func (mock *LLMClientMock) NewSessionCalls() []struct { 131 | Ctx context.Context 132 | Options []gollem.SessionOption 133 | } { 134 | var calls []struct { 135 | Ctx context.Context 136 | Options []gollem.SessionOption 137 | } 138 | mock.lockNewSession.RLock() 139 | calls = mock.calls.NewSession 140 | mock.lockNewSession.RUnlock() 141 | return calls 142 | } 143 | 144 | // SessionMock is a mock implementation of gollem.Session. 145 | // 146 | // func TestSomethingThatUsesSession(t *testing.T) { 147 | // 148 | // // make and configure a mocked gollem.Session 149 | // mockedSession := &SessionMock{ 150 | // GenerateContentFunc: func(ctx context.Context, input ...gollem.Input) (*gollem.Response, error) { 151 | // panic("mock out the GenerateContent method") 152 | // }, 153 | // GenerateStreamFunc: func(ctx context.Context, input ...gollem.Input) (<-chan *gollem.Response, error) { 154 | // panic("mock out the GenerateStream method") 155 | // }, 156 | // HistoryFunc: func() *gollem.History { 157 | // panic("mock out the History method") 158 | // }, 159 | // } 160 | // 161 | // // use mockedSession in code that requires gollem.Session 162 | // // and then make assertions. 
163 | // 164 | // } 165 | type SessionMock struct { 166 | // GenerateContentFunc mocks the GenerateContent method. 167 | GenerateContentFunc func(ctx context.Context, input ...gollem.Input) (*gollem.Response, error) 168 | 169 | // GenerateStreamFunc mocks the GenerateStream method. 170 | GenerateStreamFunc func(ctx context.Context, input ...gollem.Input) (<-chan *gollem.Response, error) 171 | 172 | // HistoryFunc mocks the History method. 173 | HistoryFunc func() *gollem.History 174 | 175 | // calls tracks calls to the methods. 176 | calls struct { 177 | // GenerateContent holds details about calls to the GenerateContent method. 178 | GenerateContent []struct { 179 | // Ctx is the ctx argument value. 180 | Ctx context.Context 181 | // Input is the input argument value. 182 | Input []gollem.Input 183 | } 184 | // GenerateStream holds details about calls to the GenerateStream method. 185 | GenerateStream []struct { 186 | // Ctx is the ctx argument value. 187 | Ctx context.Context 188 | // Input is the input argument value. 189 | Input []gollem.Input 190 | } 191 | // History holds details about calls to the History method. 192 | History []struct { 193 | } 194 | } 195 | lockGenerateContent sync.RWMutex 196 | lockGenerateStream sync.RWMutex 197 | lockHistory sync.RWMutex 198 | } 199 | 200 | // GenerateContent calls GenerateContentFunc. 201 | func (mock *SessionMock) GenerateContent(ctx context.Context, input ...gollem.Input) (*gollem.Response, error) { 202 | callInfo := struct { 203 | Ctx context.Context 204 | Input []gollem.Input 205 | }{ 206 | Ctx: ctx, 207 | Input: input, 208 | } 209 | mock.lockGenerateContent.Lock() 210 | mock.calls.GenerateContent = append(mock.calls.GenerateContent, callInfo) 211 | mock.lockGenerateContent.Unlock() 212 | if mock.GenerateContentFunc == nil { 213 | var ( 214 | responseOut *gollem.Response 215 | errOut error 216 | ) 217 | return responseOut, errOut 218 | } 219 | return mock.GenerateContentFunc(ctx, input...) 
220 | } 221 | 222 | // GenerateContentCalls gets all the calls that were made to GenerateContent. 223 | // Check the length with: 224 | // 225 | // len(mockedSession.GenerateContentCalls()) 226 | func (mock *SessionMock) GenerateContentCalls() []struct { 227 | Ctx context.Context 228 | Input []gollem.Input 229 | } { 230 | var calls []struct { 231 | Ctx context.Context 232 | Input []gollem.Input 233 | } 234 | mock.lockGenerateContent.RLock() 235 | calls = mock.calls.GenerateContent 236 | mock.lockGenerateContent.RUnlock() 237 | return calls 238 | } 239 | 240 | // GenerateStream calls GenerateStreamFunc. 241 | func (mock *SessionMock) GenerateStream(ctx context.Context, input ...gollem.Input) (<-chan *gollem.Response, error) { 242 | callInfo := struct { 243 | Ctx context.Context 244 | Input []gollem.Input 245 | }{ 246 | Ctx: ctx, 247 | Input: input, 248 | } 249 | mock.lockGenerateStream.Lock() 250 | mock.calls.GenerateStream = append(mock.calls.GenerateStream, callInfo) 251 | mock.lockGenerateStream.Unlock() 252 | if mock.GenerateStreamFunc == nil { 253 | var ( 254 | responseChOut <-chan *gollem.Response 255 | errOut error 256 | ) 257 | return responseChOut, errOut 258 | } 259 | return mock.GenerateStreamFunc(ctx, input...) 260 | } 261 | 262 | // GenerateStreamCalls gets all the calls that were made to GenerateStream. 263 | // Check the length with: 264 | // 265 | // len(mockedSession.GenerateStreamCalls()) 266 | func (mock *SessionMock) GenerateStreamCalls() []struct { 267 | Ctx context.Context 268 | Input []gollem.Input 269 | } { 270 | var calls []struct { 271 | Ctx context.Context 272 | Input []gollem.Input 273 | } 274 | mock.lockGenerateStream.RLock() 275 | calls = mock.calls.GenerateStream 276 | mock.lockGenerateStream.RUnlock() 277 | return calls 278 | } 279 | 280 | // History calls HistoryFunc. 
281 | func (mock *SessionMock) History() *gollem.History { 282 | callInfo := struct { 283 | }{} 284 | mock.lockHistory.Lock() 285 | mock.calls.History = append(mock.calls.History, callInfo) 286 | mock.lockHistory.Unlock() 287 | if mock.HistoryFunc == nil { 288 | var ( 289 | historyOut *gollem.History 290 | ) 291 | return historyOut 292 | } 293 | return mock.HistoryFunc() 294 | } 295 | 296 | // HistoryCalls gets all the calls that were made to History. 297 | // Check the length with: 298 | // 299 | // len(mockedSession.HistoryCalls()) 300 | func (mock *SessionMock) HistoryCalls() []struct { 301 | } { 302 | var calls []struct { 303 | } 304 | mock.lockHistory.RLock() 305 | calls = mock.calls.History 306 | mock.lockHistory.RUnlock() 307 | return calls 308 | } 309 | 310 | // ToolMock is a mock implementation of gollem.Tool. 311 | // 312 | // func TestSomethingThatUsesTool(t *testing.T) { 313 | // 314 | // // make and configure a mocked gollem.Tool 315 | // mockedTool := &ToolMock{ 316 | // RunFunc: func(ctx context.Context, args map[string]any) (map[string]any, error) { 317 | // panic("mock out the Run method") 318 | // }, 319 | // SpecFunc: func() gollem.ToolSpec { 320 | // panic("mock out the Spec method") 321 | // }, 322 | // } 323 | // 324 | // // use mockedTool in code that requires gollem.Tool 325 | // // and then make assertions. 326 | // 327 | // } 328 | type ToolMock struct { 329 | // RunFunc mocks the Run method. 330 | RunFunc func(ctx context.Context, args map[string]any) (map[string]any, error) 331 | 332 | // SpecFunc mocks the Spec method. 333 | SpecFunc func() gollem.ToolSpec 334 | 335 | // calls tracks calls to the methods. 336 | calls struct { 337 | // Run holds details about calls to the Run method. 338 | Run []struct { 339 | // Ctx is the ctx argument value. 340 | Ctx context.Context 341 | // Args is the args argument value. 342 | Args map[string]any 343 | } 344 | // Spec holds details about calls to the Spec method. 
345 | Spec []struct { 346 | } 347 | } 348 | lockRun sync.RWMutex 349 | lockSpec sync.RWMutex 350 | } 351 | 352 | // Run calls RunFunc. 353 | func (mock *ToolMock) Run(ctx context.Context, args map[string]any) (map[string]any, error) { 354 | callInfo := struct { 355 | Ctx context.Context 356 | Args map[string]any 357 | }{ 358 | Ctx: ctx, 359 | Args: args, 360 | } 361 | mock.lockRun.Lock() 362 | mock.calls.Run = append(mock.calls.Run, callInfo) 363 | mock.lockRun.Unlock() 364 | if mock.RunFunc == nil { 365 | var ( 366 | stringToVOut map[string]any 367 | errOut error 368 | ) 369 | return stringToVOut, errOut 370 | } 371 | return mock.RunFunc(ctx, args) 372 | } 373 | 374 | // RunCalls gets all the calls that were made to Run. 375 | // Check the length with: 376 | // 377 | // len(mockedTool.RunCalls()) 378 | func (mock *ToolMock) RunCalls() []struct { 379 | Ctx context.Context 380 | Args map[string]any 381 | } { 382 | var calls []struct { 383 | Ctx context.Context 384 | Args map[string]any 385 | } 386 | mock.lockRun.RLock() 387 | calls = mock.calls.Run 388 | mock.lockRun.RUnlock() 389 | return calls 390 | } 391 | 392 | // Spec calls SpecFunc. 393 | func (mock *ToolMock) Spec() gollem.ToolSpec { 394 | callInfo := struct { 395 | }{} 396 | mock.lockSpec.Lock() 397 | mock.calls.Spec = append(mock.calls.Spec, callInfo) 398 | mock.lockSpec.Unlock() 399 | if mock.SpecFunc == nil { 400 | var ( 401 | toolSpecOut gollem.ToolSpec 402 | ) 403 | return toolSpecOut 404 | } 405 | return mock.SpecFunc() 406 | } 407 | 408 | // SpecCalls gets all the calls that were made to Spec. 
409 | // Check the length with: 410 | // 411 | // len(mockedTool.SpecCalls()) 412 | func (mock *ToolMock) SpecCalls() []struct { 413 | } { 414 | var calls []struct { 415 | } 416 | mock.lockSpec.RLock() 417 | calls = mock.calls.Spec 418 | mock.lockSpec.RUnlock() 419 | return calls 420 | } 421 | -------------------------------------------------------------------------------- /session.go: -------------------------------------------------------------------------------- 1 | package gollem 2 | 3 | import "context" 4 | 5 | // Session is a session for the LLM. It generates content either as a single response (GenerateContent) or as a stream (GenerateStream). It's mainly used by the Prompt() method, but it can also be used for one-shot content generation. 6 | type Session interface { 7 | GenerateContent(ctx context.Context, input ...Input) (*Response, error) 8 | GenerateStream(ctx context.Context, input ...Input) (<-chan *Response, error) 9 | History() *History 10 | } 11 | 12 | // SessionConfig is the configuration for a new session. It's only needed by LLM client implementations. 13 | type SessionConfig struct { 14 | history *History 15 | contentType ContentType 16 | systemPrompt string 17 | tools []Tool 18 | } 19 | 20 | // History returns the history of the session. 21 | func (c *SessionConfig) History() *History { 22 | return c.history 23 | } 24 | 25 | // SystemPrompt returns the system prompt of the session. 26 | func (c *SessionConfig) SystemPrompt() string { 27 | return c.systemPrompt 28 | } 29 | 30 | // ContentType returns the content type of the session. 31 | func (c *SessionConfig) ContentType() ContentType { 32 | return c.contentType 33 | } 34 | 35 | // Tools returns the tools of the session. The returned slice is the config's internal slice (not a copy), so callers should not modify it. 36 | func (c *SessionConfig) Tools() []Tool { 37 | return c.tools 38 | } 39 | 40 | // NewSessionConfig creates a new session configuration by applying the given options in order. It's only needed by LLM client implementations.
41 | func NewSessionConfig(options ...SessionOption) SessionConfig { 42 | cfg := SessionConfig{} 43 | for _, option := range options { 44 | option(&cfg) 45 | } 46 | return cfg 47 | } 48 | 49 | // SessionOption is the option for the session configuration. It's only needed by LLM client implementations. 50 | type SessionOption func(cfg *SessionConfig) 51 | 52 | // WithSessionHistory sets the history for the session. 53 | // Usage: 54 | // session, err := llmClient.NewSession(ctx, gollem.WithSessionHistory(history)) 55 | func WithSessionHistory(history *History) SessionOption { 56 | return func(cfg *SessionConfig) { 57 | cfg.history = history 58 | } 59 | } 60 | 61 | // WithSessionContentType sets the content type for the session. 62 | // Usage: 63 | // session, err := llmClient.NewSession(ctx, gollem.WithSessionContentType(gollem.ContentTypeJSON)) 64 | func WithSessionContentType(contentType ContentType) SessionOption { 65 | return func(cfg *SessionConfig) { 66 | cfg.contentType = contentType 67 | } 68 | } 69 | 70 | // WithSessionTools adds tools to the session. Repeated calls append to the tools already set instead of replacing them. 71 | // Usage: 72 | // session, err := llmClient.NewSession(ctx, gollem.WithSessionTools(tools...)) 73 | func WithSessionTools(tools ...Tool) SessionOption { 74 | return func(cfg *SessionConfig) { 75 | cfg.tools = append(cfg.tools, tools...) 76 | } 77 | } 78 | 79 | // WithSessionSystemPrompt sets the system prompt for the session. 80 | // Usage: 81 | // session, err := llmClient.NewSession(ctx, gollem.WithSessionSystemPrompt("You are a helpful assistant.")) 82 | func WithSessionSystemPrompt(systemPrompt string) SessionOption { 83 | return func(cfg *SessionConfig) { 84 | cfg.systemPrompt = systemPrompt 85 | } 86 | } 87 | 88 | // ContentType represents the type of content to be generated by the LLM. 89 | type ContentType string 90 | 91 | const ( 92 | // ContentTypeText represents plain text content. 93 | ContentTypeText ContentType = "text" 94 | // ContentTypeJSON represents JSON content.
95 | ContentTypeJSON ContentType = "json" 96 | ) 97 | -------------------------------------------------------------------------------- /tool.go: -------------------------------------------------------------------------------- 1 | package gollem 2 | 3 | import ( 4 | "context" 5 | "regexp" 6 | 7 | "github.com/m-mizutani/goerr/v2" 8 | ) 9 | 10 | // ToolSpec is the specification of a tool. 11 | type ToolSpec struct { 12 | Name string 13 | Description string 14 | Parameters map[string]*Parameter 15 | Required []string 16 | } 17 | 18 | // Validate validates the tool specification. It checks that the tool has a name, 19 | // that every parameter passes Parameter.Validate, and that every name in Required 20 | // refers to a declared parameter. All returned errors wrap ErrInvalidTool; for an 21 | // invalid parameter, its name and the underlying validation error are attached as 22 | // the "name" and "error" values. 23 | func (s *ToolSpec) Validate() error { 24 | eb := goerr.NewBuilder(goerr.V("tool", s)) 25 | if s.Name == "" { 26 | return eb.Wrap(ErrInvalidTool, "name is required") 27 | } 28 | 29 | // Parameters is a map, so parameter names are unique by construction and no 30 | // duplicate check is needed; only the definitions themselves are validated. 31 | for name, param := range s.Parameters { 32 | if err := param.Validate(); err != nil { 33 | return eb.Wrap(ErrInvalidTool, "invalid parameter", goerr.V("name", name), goerr.V("error", err)) 34 | } 35 | } 36 | 37 | for _, req := range s.Required { 38 | if _, ok := s.Parameters[req]; !ok { 39 | return eb.Wrap(ErrInvalidTool, "required parameter not found", goerr.V("name", req)) 40 | } 41 | } 42 | 43 | return nil 44 | } 45 | 46 | // ParameterType is the type of a parameter. 47 | type ParameterType string 48 | 49 | const ( 50 | TypeString ParameterType = "string" 51 | TypeNumber ParameterType = "number" 52 | TypeInteger ParameterType = "integer" 53 | TypeBoolean ParameterType = "boolean" 54 | TypeArray ParameterType = "array" 55 | TypeObject ParameterType = "object" 56 | ) 57 | 58 | // Parameter is a parameter of a tool. 59 | type Parameter struct { 60 | // Title is a user-friendly title of the parameter. It's optional. 61 | Title string 62 | 63 | // Type is the type of the parameter. It's required.
64 | Type ParameterType 65 | 66 | // Description is the description of the parameter. It's optional. 67 | Description string 68 | 69 | // Required is the list of required field names when Type is Object. 70 | Required []string 71 | 72 | // Enum is the enum of the parameter. It's optional. 73 | Enum []string 74 | 75 | // Properties is the properties of the parameter. It's used for object type. 76 | Properties map[string]*Parameter 77 | 78 | // Items is the items of the parameter. It's used for array type. 79 | Items *Parameter 80 | 81 | // Number constraints 82 | Minimum *float64 83 | Maximum *float64 84 | 85 | // String constraints 86 | MinLength *int 87 | MaxLength *int 88 | Pattern string 89 | 90 | // Array constraints 91 | MinItems *int 92 | MaxItems *int 93 | 94 | // Default value 95 | Default any 96 | } 97 | 98 | // Validate validates the parameter definition, recursing into Properties and Items. 99 | // It checks that Type is a supported ParameterType, that an object parameter has 100 | // Properties covering every name in Required, that an array parameter has Items, 101 | // that min/max constraints are not inverted, and that Pattern compiles with Go's 102 | // regexp (RE2) syntax. A nil parameter (e.g. a nil entry in Properties) is reported 103 | // as invalid instead of causing a nil pointer dereference. All returned errors wrap 104 | // ErrInvalidParameter; a nested validation error is attached as the "error" value. 105 | func (p *Parameter) Validate() error { 106 | if p == nil { 107 | return goerr.Wrap(ErrInvalidParameter, "parameter is nil") 108 | } 109 | 110 | eb := goerr.NewBuilder(goerr.V("parameter", p)) 111 | 112 | // Type is required 113 | if p.Type == "" { 114 | return eb.Wrap(ErrInvalidParameter, "type is required") 115 | } 116 | 117 | // Validate parameter type 118 | switch p.Type { 119 | case TypeString, TypeNumber, TypeInteger, TypeBoolean, TypeArray, TypeObject: 120 | // Valid type 121 | default: 122 | return eb.Wrap(ErrInvalidParameter, "invalid parameter type", goerr.V("type", p.Type)) 123 | } 124 | 125 | // Properties is required for object type 126 | if p.Type == TypeObject { 127 | if p.Properties == nil { 128 | return eb.Wrap(ErrInvalidParameter, "properties is required for object type") 129 | } 130 | 131 | // Property names are map keys and therefore unique by construction, 132 | // so only the nested definitions need to be validated. 133 | for name, prop := range p.Properties { 134 | if err := prop.Validate(); err != nil { 135 | return eb.Wrap(ErrInvalidParameter, "invalid property", goerr.V("name", name), goerr.V("error", err)) 136 | } 137 | } 138 | 139 | // Validate required fields exist in properties 140 | for _, req := range p.Required { 141 | if _, ok := p.Properties[req]; !ok { 142 | return eb.Wrap(ErrInvalidParameter, "required field not found in properties", goerr.V("field", req)) 143 | } 144 | } 145 | }
146 | 147 | // Items is required for array type 148 | if p.Type == TypeArray { 149 | if p.Items == nil { 150 | return eb.Wrap(ErrInvalidParameter, "items is required for array type") 151 | } 152 | // Validate items 153 | if err := p.Items.Validate(); err != nil { 154 | return eb.Wrap(ErrInvalidParameter, "invalid items", goerr.V("error", err)) 155 | } 156 | } 157 | 158 | // Validate number constraints 159 | if p.Type == TypeNumber || p.Type == TypeInteger { 160 | if p.Minimum != nil && p.Maximum != nil && *p.Minimum > *p.Maximum { 161 | return eb.Wrap(ErrInvalidParameter, "minimum must be less than or equal to maximum") 162 | } 163 | } 164 | 165 | // Validate string constraints 166 | if p.Type == TypeString { 167 | if p.MinLength != nil && p.MaxLength != nil && *p.MinLength > *p.MaxLength { 168 | return eb.Wrap(ErrInvalidParameter, "minLength must be less than or equal to maxLength") 169 | } 170 | if p.Pattern != "" { 171 | if _, err := regexp.Compile(p.Pattern); err != nil { 172 | return eb.Wrap(ErrInvalidParameter, "invalid pattern", goerr.V("pattern", p.Pattern), goerr.V("error", err)) 173 | } 174 | } 175 | } 176 | 177 | // Validate array constraints 178 | if p.Type == TypeArray { 179 | if p.MinItems != nil && p.MaxItems != nil && *p.MinItems > *p.MaxItems { 180 | return eb.Wrap(ErrInvalidParameter, "minItems must be less than or equal to maxItems") 181 | } 182 | } 183 | 184 | return nil 185 | } 186 | 187 | // Tool is the specification and execution of an action that can be called by the LLM. 188 | type Tool interface { 189 | // Spec returns the specification of the tool. It's called when starting an LLM chat session in Prompt(). 190 | Spec() ToolSpec 191 | 192 | // Run is the execution of the tool. 193 | // It's called when receiving a tool call from the LLM. Even if the method returns an error, the tool execution is not aborted. The error is passed to the LLM as the tool response. If you want to abort the tool execution, you need to return an error from the callback function of WithToolErrorHook(). 194 | Run(ctx context.Context, args map[string]any) (map[string]any, error) 195 | } 196 | 197 | // ToolSet is a set of tools. 198 | // It's useful for providing a set of tools to the LLM. 199 | type ToolSet interface { 200 | // Specs returns the specifications of the tools. 201 | Specs(ctx context.Context) ([]ToolSpec, error) 202 | 203 | // Run executes the tool specified by name with the given arguments. 204 | // It's called when receiving a tool call from the LLM. 205 | Run(ctx context.Context, name string, args map[string]any) (map[string]any, error) 206 | } 207 | -------------------------------------------------------------------------------- /tool_test.go: -------------------------------------------------------------------------------- 1 | package gollem 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/m-mizutani/gt" 7 | ) 8 | 9 | func TestParameterValidation(t *testing.T) { 10 | t.Run("number constraints", func(t *testing.T) { 11 | t.Run("valid minimum and maximum", func(t *testing.T) { 12 | p := &Parameter{ 13 | Type: TypeNumber, 14 | Minimum: ptr(1.0), 15 | Maximum: ptr(10.0), 16 | } 17 | gt.NoError(t, p.Validate()) 18 | }) 19 | 20 | t.Run("invalid minimum and maximum", func(t *testing.T) { 21 | p := &Parameter{ 22 | Type: TypeNumber, 23 | Minimum: ptr(10.0), 24 | Maximum: ptr(1.0), 25 | } 26 | gt.Error(t, p.Validate()) 27 | }) 28 | }) 29 | 30 | t.Run("string constraints", func(t *testing.T) { 31 | t.Run("valid minLength and maxLength", func(t *testing.T) { 32 | p := &Parameter{ 33 | Type: TypeString, 34 | MinLength: ptr(1), 35 | MaxLength: ptr(10), 36 | } 37 | gt.NoError(t, p.Validate()) 38 | }) 39 |
40 | t.Run("invalid minLength and maxLength", func(t *testing.T) { 41 | p := &Parameter{ 42 | Type: TypeString, 43 | MinLength: ptr(10), 44 | MaxLength: ptr(1), 45 | } 46 | gt.Error(t, p.Validate()) 47 | }) 48 | 49 | t.Run("valid pattern", func(t *testing.T) { 50 | p := &Parameter{ 51 | Type: TypeString, 52 | Pattern: "^[a-z]+$", 53 | } 54 | gt.NoError(t, p.Validate()) 55 | }) 56 | 57 | t.Run("invalid pattern", func(t *testing.T) { 58 | p := &Parameter{ 59 | Type: TypeString, 60 | Pattern: "[invalid", 61 | } 62 | gt.Error(t, p.Validate()) 63 | }) 64 | }) 65 | 66 | t.Run("array constraints", func(t *testing.T) { 67 | t.Run("valid minItems and maxItems", func(t *testing.T) { 68 | p := &Parameter{ 69 | Type: TypeArray, 70 | Items: &Parameter{Type: TypeString}, 71 | MinItems: ptr(1), 72 | MaxItems: ptr(10), 73 | } 74 | gt.NoError(t, p.Validate()) 75 | }) 76 | 77 | t.Run("invalid minItems and maxItems", func(t *testing.T) { 78 | p := &Parameter{ 79 | Type: TypeArray, 80 | Items: &Parameter{Type: TypeString}, 81 | MinItems: ptr(10), 82 | MaxItems: ptr(1), 83 | } 84 | gt.Error(t, p.Validate()) 85 | }) 86 | }) 87 | 88 | t.Run("object constraints", func(t *testing.T) { 89 | t.Run("valid properties", func(t *testing.T) { 90 | p := &Parameter{ 91 | Type: TypeObject, 92 | Properties: map[string]*Parameter{ 93 | "name": { 94 | Type: TypeString, 95 | Description: "User name", 96 | }, 97 | "age": { 98 | Type: TypeNumber, 99 | Description: "User age", 100 | }, 101 | }, 102 | } 103 | gt.NoError(t, p.Validate()) 104 | }) 105 | 106 | t.Run("duplicate property names", func(t *testing.T) { 107 | p := &Parameter{ 108 | Type: TypeObject, 109 | Properties: make(map[string]*Parameter), 110 | } 111 | p.Properties["name"] = &Parameter{ 112 | Type: TypeString, 113 | Description: "User name", 114 | } 115 | p.Properties["name"] = &Parameter{ 116 | Type: TypeString, 117 | Description: "Duplicate name", 118 | } 119 | gt.NoError(t, p.Validate()) 120 | }) 121 | 122 | t.Run("invalid property type", func(t *testing.T) {
123 | p := &Parameter{ 124 | Type: TypeObject, 125 | Properties: map[string]*Parameter{ 126 | "name": { 127 | Type: "invalid", 128 | Description: "User name", 129 | }, 130 | }, 131 | } 132 | gt.Error(t, p.Validate()) 133 | }) 134 | }) 135 | } 136 | 137 | func ptr[T any](v T) *T { 138 | return &v 139 | } 140 | 141 | func TestToolSpecValidation(t *testing.T) { 142 | t.Run("valid tool spec", func(t *testing.T) { 143 | spec := ToolSpec{ 144 | Name: "test", 145 | Description: "test description", 146 | Parameters: map[string]*Parameter{ 147 | "param1": { 148 | Type: TypeString, 149 | Description: "test parameter", 150 | }, 151 | }, 152 | Required: []string{"param1"}, 153 | } 154 | gt.NoError(t, spec.Validate()) 155 | }) 156 | 157 | t.Run("empty name", func(t *testing.T) { 158 | spec := ToolSpec{ 159 | Description: "test description", 160 | Parameters: map[string]*Parameter{ 161 | "param1": { 162 | Type: TypeString, 163 | Description: "test parameter", 164 | }, 165 | }, 166 | } 167 | gt.Error(t, spec.Validate()) 168 | }) 169 | 170 | t.Run("invalid parameter type", func(t *testing.T) { 171 | spec := ToolSpec{ 172 | Name: "test", 173 | Description: "test description", 174 | Parameters: map[string]*Parameter{ 175 | "param1": { 176 | Type: "invalid", 177 | Description: "test parameter", 178 | }, 179 | }, 180 | } 181 | gt.Error(t, spec.Validate()) 182 | }) 183 | 184 | t.Run("required parameter not found", func(t *testing.T) { 185 | spec := ToolSpec{ 186 | Name: "test", 187 | Description: "test description", 188 | Parameters: map[string]*Parameter{ 189 | "param1": { 190 | Type: TypeString, 191 | Description: "test parameter", 192 | }, 193 | }, 194 | Required: []string{"param2"}, 195 | } 196 | gt.Error(t, spec.Validate()) 197 | }) 198 | 199 | t.Run("invalid parameter", func(t *testing.T) { 200 | spec := ToolSpec{ 201 | Name: "test", 202 | Description: "test description", 203 | Parameters: map[string]*Parameter{ 204 | "param1": { 205 | Type: TypeNumber, 206 | Minimum: ptr(10.0),
207 | Maximum: ptr(1.0), 208 | }, 209 | }, 210 | } 211 | gt.Error(t, spec.Validate()) 212 | }) 213 | 214 | t.Run("object parameter without properties", func(t *testing.T) { 215 | spec := ToolSpec{ 216 | Name: "test", 217 | Description: "test description", 218 | Parameters: map[string]*Parameter{ 219 | "param1": { 220 | Type: TypeObject, 221 | Description: "test parameter", 222 | }, 223 | }, 224 | } 225 | gt.Error(t, spec.Validate()) 226 | }) 227 | 228 | t.Run("array parameter without items", func(t *testing.T) { 229 | spec := ToolSpec{ 230 | Name: "test", 231 | Description: "test description", 232 | Parameters: map[string]*Parameter{ 233 | "param1": { 234 | Type: TypeArray, 235 | Description: "test parameter", 236 | }, 237 | }, 238 | } 239 | gt.Error(t, spec.Validate()) 240 | }) 241 | } 242 | 243 | func TestNilParameterValidation(t *testing.T) { 244 | t.Run("nil parameter", func(t *testing.T) { 245 | var p *Parameter 246 | gt.Error(t, p.Validate()) 247 | }) 248 | 249 | t.Run("nil property", func(t *testing.T) { 250 | p := &Parameter{ 251 | Type: TypeObject, 252 | Properties: map[string]*Parameter{"name": nil}, 253 | } 254 | gt.Error(t, p.Validate()) 255 | }) 256 | 257 | t.Run("nil tool parameter", func(t *testing.T) { 258 | spec := ToolSpec{ 259 | Name: "test", 260 | Parameters: map[string]*Parameter{"param1": nil}, 261 | } 262 | gt.Error(t, spec.Validate()) 263 | }) 264 | } 265 | -------------------------------------------------------------------------------- /types.go: -------------------------------------------------------------------------------- 1 | package gollem 2 | --------------------------------------------------------------------------------