├── .cursor
└── rules
│ └── go.mdc
├── .github
└── workflows
│ ├── gosec.yml
│ ├── lint.yml
│ ├── test.yml
│ └── trivy.yml
├── .gitignore
├── LICENSE
├── README.md
├── Taskfile.yml
├── context.go
├── doc
├── README.md
├── examples.md
├── getting-started.md
├── history.md
├── images
│ └── logo.png
├── mcp.md
└── tools.md
├── errors.go
├── examples
├── README.md
├── basic
│ └── main.go
├── chat
│ └── main.go
├── embedding
│ └── main.go
├── mcp
│ └── main.go
├── query
│ └── main.go
├── simple
│ └── main.go
└── tools
│ └── main.go
├── exit.go
├── export_test.go
├── go.mod
├── go.sum
├── gollem.go
├── gollem_test.go
├── history.go
├── history_test.go
├── hook.go
├── llm.go
├── llm
├── claude
│ ├── client.go
│ ├── convert.go
│ ├── convert_test.go
│ ├── embedding.go
│ └── export_test.go
├── gemini
│ ├── client.go
│ ├── convert.go
│ ├── convert_test.go
│ ├── embedding.go
│ ├── embeding_test.go
│ └── export_test.go
└── openai
│ ├── client.go
│ ├── convert.go
│ ├── convert_test.go
│ ├── embedding.go
│ ├── embedding_test.go
│ └── export_test.go
├── llm_test.go
├── mcp
├── export_test.go
├── mcp.go
└── mcp_test.go
├── mock
└── mock_gen.go
├── session.go
├── tool.go
├── tool_test.go
└── types.go
/.cursor/rules/go.mdc:
--------------------------------------------------------------------------------
1 | ---
2 | description:
3 | globs:
4 | alwaysApply: true
5 | ---
6 |
7 | # Development rules
8 |
9 | ## Comment & Literals
10 |
11 | All comments and literals in source code MUST be in English.
12 |
13 | ## Error handling
14 |
15 | Use `github.com/m-mizutani/goerr/v2` as the error handling tool. Wrap errors as follows.
16 |
17 | ```go
18 | func someAction(tasks []task) error {
19 | for _, t := range tasks {
20 | if err := validateData(t.Data); err != nil {
21 | return goerr.Wrap(err, "failed to validate data", goerr.Value("name", t.Name))
22 | }
23 | }
24 | // ....
25 | return nil
26 | }
27 | ```
28 |
29 | # Testing
30 |
31 | If you need to run tests to check your changes, run only the tests you modified or the ones specified by the developer.
32 |
33 | ## Style for similar testing
34 |
35 | Use the following Helper Driven Testing style instead of the general Table Driven Test style. Do not use the Table Driven Test style.
36 |
37 | ```go
38 | type testCase struct {
39 | input string
40 | expected string
41 | }
42 |
43 | runTest := func(tc testCase) func(t *testing.T) {
44 | return func(t *testing.T) {
45 | actual := someFunc(tc.input)
46 | gt.Equal(t, tc.expected, actual)
47 | }
48 | }
49 |
50 | t.Run("success case", runTest(testCase{
51 | input: "blue",
52 | expected: "BLUE",
53 | }))
54 | ```
55 |
56 | ## Test framework
57 |
58 | Use `github.com/m-mizutani/gt` package.
59 |
60 | `gt` is a test library that leverages Go generics so that type mismatches are caught by the IDE and the compiler.
61 |
62 | ```go
63 | color := "blue"
64 |
65 | // gt.Value(t, color).Equal(5) // <- Compile error
66 |
67 | gt.Value(t, color).Equal("orange") // <- Fail
68 | gt.Value(t, color).Equal("blue") // <- Pass
69 | ```
70 |
71 | ```go
72 | colors := []string{"red", "blue"}
73 |
74 | // gt.Array(t, colors).Equal("red") // <- Compile error
75 | // gt.Array(t, colors).Equal([]int{1, 2}) // <- Compile error
76 |
77 | gt.Array(t, colors).Equal([]string{"red", "blue"}) // <- Pass
78 | gt.Array(t, colors).Has("orange") // <- Fail
79 | ```
80 |
81 | ### Usage
82 |
83 | In many cases, a developer does not need to care about Go generics when using `gt`. However, a developer needs to choose the generic test type (`Value`, `Array`, `Map`, `Error`, etc.) explicitly to use the test functions specific to each type.
84 |
85 | See @reference for more detail.
86 |
87 | #### Value
88 |
89 | `Value` is the generic test type and provides a minimal set of test methods.
90 |
91 | ```go
92 | type user struct {
93 | Name string
94 | }
95 | u1 := user{Name: "blue"}
96 |
97 | // gt.Value(t, u1).Equal(1) // Compile error
98 | // gt.Value(t, u1).Equal("blue") // Compile error
99 | // gt.Value(t, u1).Equal(&user{Name:"blue"}) // Compile error
100 |
101 | gt.Value(t, u1).Equal(user{Name:"blue"}) // Pass
102 | ```
103 |
104 | #### Number
105 |
106 | Accepts only number types: `int`, `uint`, `int64`, `float64`, etc.
107 |
108 | ```go
109 | var f float64 = 12.5
110 | gt.Number(t, f).
111 | Equal(12.5). // Pass
112 | Greater(12). // Pass
113 | Less(10). // Fail
114 | GreaterOrEqual(12.5) // Pass
115 | ```
116 |
117 | #### Array
118 |
119 | Accepts a slice of any element type, including not only primitive types but also structs.
120 |
121 | ```go
122 | colors := []string{"red", "blue", "yellow"}
123 |
124 | gt.Array(t, colors).
125 | Equal([]string{"red", "blue", "yellow"}). // Pass
126 | Equal([]string{"red", "blue"}). // Fail
127 | // Equal([]int{1, 2}) // Compile error
128 | Contain([]string{"red", "blue"}). // Pass
129 | Has("yellow"). // Pass
130 | Length(3) // Pass
131 |
132 | gt.Array(t, colors).Must().Has("orange") // Fail and stop test
133 | ```
134 |
135 | #### Map
136 |
137 | ```go
138 | colorMap := map[string]int{
139 | "red": 1,
140 | "yellow": 2,
141 | "blue": 5,
142 | }
143 |
144 | gt.Map(t, colorMap).
145 | HasKey("blue"). // Pass
146 | HasValue(5). // Pass
147 | // HasValue("red") // Compile error
148 | HasKeyValue("yellow", 2) // Pass
149 |
150 | gt.Map(t, colorMap).Must().HasKey("orange") // Fail and stop test
151 | ```
152 |
--------------------------------------------------------------------------------
/.github/workflows/gosec.yml:
--------------------------------------------------------------------------------
1 | name: gosec
2 |
3 | # Run the workflow each time code is pushed to the repository.
4 | # (No scheduled run is configured; add a `schedule` trigger to run it periodically.)
5 | on:
6 | push:
7 |
8 | jobs:
9 | tests:
10 | runs-on: ubuntu-latest
11 | permissions:
12 | security-events: write
13 | actions: read
14 | contents: read
15 | env:
16 | GO111MODULE: on
17 | steps:
18 | - name: Checkout Source
19 | uses: actions/checkout@v4
20 | - name: Run Gosec Security Scanner
21 | uses: securego/gosec@master
22 | with:
23 | # Use -no-fail so gosec findings do not fail this step; they are surfaced through GitHub code scanning via the uploaded SARIF file instead.
24 | args: "-no-fail -fmt sarif -out results.sarif ./..."
25 | - name: Upload SARIF file
26 | uses: github/codeql-action/upload-sarif@v2
27 | with:
28 | # Path to SARIF file relative to the root of the repository
29 | sarif_file: results.sarif
30 |
--------------------------------------------------------------------------------
/.github/workflows/lint.yml:
--------------------------------------------------------------------------------
1 | name: lint
2 | on:
3 | push:
4 |
5 | jobs:
6 | golangci:
7 | name: lint
8 | permissions:
9 | security-events: write
10 | runs-on: ubuntu-latest
11 | steps:
12 | - uses: actions/checkout@v4
13 | - name: Set up Go
14 | uses: actions/setup-go@v5
15 | with:
16 | go-version-file: "go.mod"
17 | - name: golangci-lint
18 | uses: golangci/golangci-lint-action@v6
19 | with:
20 | version: latest
21 | args: --timeout 10m ./...
22 |
--------------------------------------------------------------------------------
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 | name: test
2 |
3 | on: [push]
4 |
5 | jobs:
6 | testing:
7 | runs-on: ubuntu-latest
8 |
9 | steps:
10 | - name: Checkout upstream repo
11 | uses: actions/checkout@v3
12 | with:
13 | ref: ${{ github.head_ref }}
14 | - uses: actions/setup-go@v3
15 | with:
16 | go-version-file: "go.mod"
17 | - uses: google-github-actions/setup-gcloud@v0.5.0
18 | - run: go test ./...
19 |
--------------------------------------------------------------------------------
/.github/workflows/trivy.yml:
--------------------------------------------------------------------------------
1 | name: trivy
2 |
3 | on:
4 | push:
5 | schedule:
6 | - cron: "0 0 * * *"
7 | workflow_dispatch:
8 |
9 | jobs:
10 | scan:
11 | runs-on: ubuntu-latest
12 | permissions:
13 | security-events: write
14 | actions: read
15 | contents: read
16 |
17 | steps:
18 | - name: Checkout upstream repo
19 | uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1
20 | with:
21 | ref: ${{ github.head_ref }}
22 | - id: scan
23 | name: Run Trivy vulnerability scanner in repo mode
24 | uses: aquasecurity/trivy-action@f3d98514b056d8c71a3552e8328c225bc7f6f353 # master
25 | with:
26 | scan-type: "fs"
27 | ignore-unfixed: true
28 | format: "sarif"
29 | output: "trivy-results.sarif"
30 | exit-code: 1
31 | env:
32 | TRIVY_DB_REPOSITORY: public.ecr.aws/aquasecurity/trivy-db
33 |
34 | - name: Upload Trivy scan results to GitHub Security tab
35 | if: failure() && steps.scan.outcome == 'failure'
36 | uses: github/codeql-action/upload-sarif@e8893c57a1f3a2b659b6b55564fdfdbbd2982911 # v3.24.0
37 | with:
38 | sarif_file: "trivy-results.sarif"
39 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # If you prefer the allow list template instead of the deny list, see community template:
2 | # https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore
3 | #
4 | # Binaries for programs and plugins
5 | *.exe
6 | *.exe~
7 | *.dll
8 | *.so
9 | *.dylib
10 |
11 | # Test binary, built with `go test -c`
12 | *.test
13 |
14 | # Output of the go coverage tool, specifically when used with LiteIDE
15 | *.out
16 |
17 | # Dependency directories (remove the comment below to include it)
18 | # vendor/
19 |
20 | # Go workspace file
21 | go.work
22 | go.work.sum
23 |
24 | # env file
25 | .env
26 |
27 | # VS code
28 | .vscode
29 |
30 | /tmp
31 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 🤖 gollem [](https://pkg.go.dev/github.com/m-mizutani/gollem) [](https://github.com/m-mizutani/gollem/actions/workflows/test.yml) [](https://github.com/m-mizutani/gollem/actions/workflows/lint.yml) [](https://github.com/m-mizutani/gollem/actions/workflows/gosec.yml) [](https://github.com/m-mizutani/gollem/actions/workflows/trivy.yml)
2 |
3 | GO for Large LanguagE Model (GOLLEM)
4 |
5 |
6 |
7 |
8 |
9 |
10 | `gollem` provides:
11 | - Common interface for sending prompts to Large Language Model (LLM) services
12 | - GenerateContent: Generate text content from prompt
13 | - GenerateEmbedding: Generate embedding vector from text (OpenAI and Gemini)
14 | - Framework for building agentic applications of LLMs with
15 | - Tools by MCP (Model Context Protocol) server and your built-in tools
16 | - Portable conversational memory with history for stateless/distributed applications
17 |
18 | ## Supported LLMs
19 |
20 | - [x] Gemini (see [models](https://ai.google.dev/gemini-api/docs/models))
21 | - [x] Anthropic (see [models](https://docs.anthropic.com/en/docs/about-claude/models/all-models))
22 | - [x] OpenAI (see [models](https://platform.openai.com/docs/models))
23 |
24 | ## Quick Start
25 |
26 | ### Install
27 |
28 | ```bash
29 | go get github.com/m-mizutani/gollem
30 | ```
31 |
32 | ### Example
33 |
34 | #### Query to LLM
35 |
36 | ```go
37 | llmProvider := os.Args[1]
38 | model := os.Args[2]
39 | prompt := os.Args[3]
40 |
41 | var client gollem.LLMClient
42 | var err error
43 |
44 | switch llmProvider {
45 | case "gemini":
46 | client, err = gemini.New(ctx, os.Getenv("GEMINI_PROJECT_ID"), os.Getenv("GEMINI_LOCATION"), gemini.WithModel(model))
47 | case "claude":
48 | client, err = claude.New(ctx, os.Getenv("ANTHROPIC_API_KEY"), claude.WithModel(model))
49 | case "openai":
50 | client, err = openai.New(ctx, os.Getenv("OPENAI_API_KEY"), openai.WithModel(model))
51 | }
52 |
53 | if err != nil {
54 | panic(err)
55 | }
56 |
57 | ssn, err := client.NewSession(ctx)
58 | if err != nil {
59 | panic(err)
60 | }
61 |
62 | result, err := ssn.GenerateContent(ctx, gollem.Text(prompt))
63 | if err != nil {
64 | panic(err)
65 | }
66 |
67 | fmt.Println(result.Texts)
68 | ```
69 |
70 | #### Generate embedding
71 | ```go
72 | // Create OpenAI client
73 | client, err := openai.New(ctx, os.Getenv("OPENAI_API_KEY"))
74 | if err != nil {
75 | panic(err)
76 | }
77 |
78 | embedding, err := client.GenerateEmbedding(ctx, 100, []string{"Hello, world!", "This is a test"})
79 | if err != nil {
80 | panic(err)
81 | }
82 |
83 | // Print two embeddings, each a 100-dimensional vector
84 | fmt.Println("embedding:", embedding)
85 | ```
86 |
87 | #### Agentic application with MCP server
88 |
89 | Here's a simple example of an interactive agent that combines local and remote MCP servers with a custom tool:
90 |
91 | ```go
92 | func main() {
93 | ctx := context.Background()
94 |
95 | // Create OpenAI client
96 | client, err := openai.New(ctx, os.Getenv("OPENAI_API_KEY"))
97 | if err != nil {
98 | panic(err)
99 | }
100 |
101 | // Create MCP client with local server
102 | mcpLocal, err := mcp.NewStdio(ctx, "/path/to/mcp-server", []string{}, mcp.WithEnvVars([]string{"MCP_ENV=test"}))
103 | if err != nil {
104 | panic(err)
105 | }
106 | defer mcpLocal.Close()
107 |
108 | // Create MCP client with remote server
109 | mcpRemote, err := mcp.NewSSE(ctx, "http://localhost:8080")
110 | if err != nil {
111 | panic(err)
112 | }
113 | defer mcpRemote.Close()
114 |
115 | // Create gollem instance
116 | agent := gollem.New(client,
117 | gollem.WithToolSets(mcpLocal, mcpRemote),
118 | gollem.WithTools(&MyTool{}),
119 | gollem.WithMessageHook(func(ctx context.Context, msg string) error {
120 | fmt.Printf("🤖 %s\n", msg)
121 | return nil
122 | }),
123 | )
124 |
125 | var history *gollem.History
126 | scanner := bufio.NewScanner(os.Stdin)
127 | for {
128 | fmt.Print("> ")
129 | if !scanner.Scan() { break } // stop on EOF or read error
130 |
131 | newHistory, err := agent.Prompt(ctx, scanner.Text(), history)
132 | if err != nil {
133 | panic(err)
134 | }
135 | history = newHistory
136 | }
137 | }
138 | ```
139 |
140 | See the full example in [examples/basic](https://github.com/m-mizutani/gollem/tree/main/examples/basic), and more examples in [examples](https://github.com/m-mizutani/gollem/tree/main/examples).
141 |
142 | ## Documentation
143 |
144 | For more details and examples, see our [documentation](https://github.com/m-mizutani/gollem/tree/main/doc) and [godoc](https://pkg.go.dev/github.com/m-mizutani/gollem).
145 |
146 | ## License
147 |
148 | Apache 2.0 License. See [LICENSE](LICENSE) for details.
149 |
--------------------------------------------------------------------------------
/Taskfile.yml:
--------------------------------------------------------------------------------
1 | # https://taskfile.dev
2 |
3 | version: '3'
4 |
5 | tasks:
6 | default:
7 | deps:
8 | - mock
9 | mock:
10 | desc: Generate mock files
11 | cmds:
12 | - go run github.com/matryer/moq@v0.5.3
13 | -out ./mock/mock_gen.go
14 | -pkg mock
15 | -rm -skip-ensure -stub
16 | .
17 | LLMClient Session Tool
18 |
--------------------------------------------------------------------------------
/context.go:
--------------------------------------------------------------------------------
1 | package gollem
2 |
3 | import (
4 | "context"
5 | "log/slog"
6 | )
7 |
// ctxLoggerKey is the unexported context key under which the logger is
// stored. Using a private struct type prevents collisions with context keys
// defined by other packages.
type ctxLoggerKey struct{}

// defaultLogger is returned by LoggerFromContext when no usable logger is
// attached to the context. It is backed by slog.DiscardHandler, so nothing is
// emitted unless a caller explicitly installs a logger.
var defaultLogger = slog.New(slog.DiscardHandler)

// ctxWithLogger returns a child of ctx that carries logger. The logger can be
// retrieved again with LoggerFromContext.
func ctxWithLogger(ctx context.Context, logger *slog.Logger) context.Context {
	return context.WithValue(ctx, ctxLoggerKey{}, logger)
}

// LoggerFromContext returns the logger stored in ctx by ctxWithLogger.
// If ctx carries no logger, or the stored logger is nil, it returns a logger
// that discards all output, so the result is always safe to call methods on.
func LoggerFromContext(ctx context.Context) *slog.Logger {
	// A nil *slog.Logger stored via ctxWithLogger still satisfies the type
	// assertion (ok == true); reject it so callers never receive nil.
	if logger, ok := ctx.Value(ctxLoggerKey{}).(*slog.Logger); ok && logger != nil {
		return logger
	}
	return defaultLogger
}
22 |
--------------------------------------------------------------------------------
/doc/README.md:
--------------------------------------------------------------------------------
1 | # gollem Documentation
2 |
3 | gollem is a Go framework for building applications with Large Language Models (LLMs). This documentation provides comprehensive guides and examples to help you get started and make the most of gollem.
4 |
5 | ## Documentation Index
6 |
7 | - [Getting Started](getting-started.md) - Quick start guide and basic usage
8 | - [Tools](tools.md) - Creating and using custom tools
9 | - [MCP Server Integration](mcp.md) - Integrating with Model Context Protocol servers
10 | - [History](history.md) - Managing conversation history and context
11 | - [Examples](examples.md) - Practical examples and use cases
12 |
13 | ## Key Features
14 |
15 | - **Multiple LLM Support**: Works with various LLM providers including OpenAI, Claude, and Gemini
16 | - **Custom Tools**: Create and integrate your own tools for LLMs to use
17 | - **MCP Integration**: Connect with external tools and resources through Model Context Protocol
18 | - **History Management**: Maintain conversation context across sessions
19 |
20 |
21 |
--------------------------------------------------------------------------------
/doc/examples.md:
--------------------------------------------------------------------------------
1 | # Practical Examples
2 |
3 | This guide provides practical examples of using gollem in various scenarios.
4 |
5 | ## Calculator Tool
6 |
7 | An example of a calculator tool that validates its arguments before performing basic arithmetic:
8 |
9 | ```go
10 | type CalculatorTool struct{}
11 |
12 | func (t *CalculatorTool) Spec() gollem.ToolSpec {
13 | return gollem.ToolSpec{
14 | Name: "calculator",
15 | Description: "Performs basic arithmetic operations",
16 | Parameters: map[string]*gollem.Parameter{
17 | "operation": {
18 | Name: "operation",
19 | Type: gollem.TypeString,
20 | Description: "The operation to perform (add, subtract, multiply, divide)",
21 | Required: true,
22 | },
23 | "a": {
24 | Name: "a",
25 | Type: gollem.TypeNumber,
26 | Description: "First number",
27 | Required: true,
28 | },
29 | "b": {
30 | Name: "b",
31 | Type: gollem.TypeNumber,
32 | Description: "Second number",
33 | Required: true,
34 | },
35 | },
36 | }
37 | }
38 |
39 | func (t *CalculatorTool) Run(ctx context.Context, args map[string]any) (map[string]any, error) {
40 | // Validate operation
41 | op, ok := args["operation"].(string)
42 | if !ok {
43 | return nil, errors.New("operation must be a string")
44 | }
45 | if op != "add" && op != "subtract" && op != "multiply" && op != "divide" {
46 | return nil, errors.New("invalid operation: must be one of add, subtract, multiply, divide")
47 | }
48 |
49 | // Validate first number
50 | a, ok := args["a"].(float64)
51 | if !ok {
52 | return nil, errors.New("first number must be a number")
53 | }
54 |
55 | // Validate second number
56 | b, ok := args["b"].(float64)
57 | if !ok {
58 | return nil, errors.New("second number must be a number")
59 | }
60 |
61 | // Validate division by zero
62 | if op == "divide" && b == 0 {
63 | return nil, errors.New("division by zero is not allowed")
64 | }
65 |
66 | var result float64
67 | switch op {
68 | case "add":
69 | result = a + b
70 | case "subtract":
71 | result = a - b
72 | case "multiply":
73 | result = a * b
74 | case "divide":
75 | result = a / b
76 | }
77 |
78 | return map[string]any{
79 | "result": result,
80 | }, nil
81 | }
82 | ```
83 |
84 | ## Weather Tool
85 |
86 | An example of a weather tool; the actual weather API call is left as a stub:
87 |
88 | ```go
89 | type WeatherTool struct{}
90 |
91 | func (t *WeatherTool) Spec() gollem.ToolSpec {
92 | return gollem.ToolSpec{
93 | Name: "weather",
94 | Description: "Gets weather information for a location",
95 | Parameters: map[string]*gollem.Parameter{
96 | "location": {
97 | Type: gollem.TypeString,
98 | Description: "City name or coordinates",
99 | Required: true,
100 | },
101 | },
102 | }
103 | }
104 |
105 | func (t *WeatherTool) Run(ctx context.Context, args map[string]any) (map[string]any, error) {
106 | location, ok := args["location"].(string)
107 | if !ok {
108 | return nil, fmt.Errorf("location must be a string")
109 | }
110 | _ = location // Implement the weather API call using location here
111 | return map[string]any{
112 | "temperature": 25.5,
113 | "condition": "sunny",
114 | }, nil
115 | }
116 | ```
117 |
118 | ## Best Practices
119 |
120 | 1. **Error Handling**: Always handle errors properly
121 | 2. **Type Safety**: Use appropriate types and validate input
122 | 3. **Documentation**: Document your tools and examples
123 | 4. **Testing**: Write tests for your tools
124 | 5. **Security**: Implement proper security measures
125 |
126 | ## Argument Validation
127 |
128 | The calculator tool's `Run` method from above, repeated here for reference, is an example of proper argument validation:
129 |
130 | ```go
131 | func (t *CalculatorTool) Run(ctx context.Context, args map[string]any) (map[string]any, error) {
132 | // Validate operation
133 | op, ok := args["operation"].(string)
134 | if !ok {
135 | return nil, errors.New("operation must be a string")
136 | }
137 | if op != "add" && op != "subtract" && op != "multiply" && op != "divide" {
138 | return nil, errors.New("invalid operation: must be one of add, subtract, multiply, divide")
139 | }
140 |
141 | // Validate first number
142 | a, ok := args["a"].(float64)
143 | if !ok {
144 | return nil, errors.New("first number must be a number")
145 | }
146 |
147 | // Validate second number
148 | b, ok := args["b"].(float64)
149 | if !ok {
150 | return nil, errors.New("second number must be a number")
151 | }
152 |
153 | // Validate division by zero
154 | if op == "divide" && b == 0 {
155 | return nil, errors.New("division by zero is not allowed")
156 | }
157 |
158 | var result float64
159 | switch op {
160 | case "add":
161 | result = a + b
162 | case "subtract":
163 | result = a - b
164 | case "multiply":
165 | result = a * b
166 | case "divide":
167 | result = a / b
168 | }
169 |
170 | return map[string]any{
171 | "result": result,
172 | }, nil
173 | }
174 | ```
175 |
176 | ## Next Steps
177 |
178 | - Learn more about [tool creation](tools.md)
179 | - Explore [MCP server integration](mcp.md)
180 | - Check out the [getting started guide](getting-started.md)
181 | - Understand [history management](history.md) for conversation context
182 | - Review the [complete documentation](README.md)
183 |
--------------------------------------------------------------------------------
/doc/getting-started.md:
--------------------------------------------------------------------------------
1 | # Getting Started with gollem
2 |
3 | gollem is a Go framework for building applications with Large Language Models (LLMs). This guide will help you get started with the framework.
4 |
5 | ## Installation
6 |
7 | Install gollem using Go modules:
8 |
9 | ```bash
10 | go get github.com/m-mizutani/gollem
11 | ```
12 |
13 | ## Basic Usage
14 |
15 | Here's a simple example of how to use gollem with an OpenAI model:
16 |
17 | ```go
18 | package main
19 |
20 | import (
21 | "context"
22 | "fmt"
23 | "os"
24 |
25 | "github.com/m-mizutani/gollem"
26 | "github.com/m-mizutani/gollem/llm/openai"
27 | "github.com/m-mizutani/gollem/mcp"
28 | )
29 |
30 | func main() {
31 | // Create OpenAI client
32 | client, err := openai.New(context.Background(), os.Getenv("OPENAI_API_KEY"))
33 | if err != nil {
34 | panic(err)
35 | }
36 |
37 | // Create MCP client
38 | mcpClient, err := mcp.NewSSE(context.Background(), "http://localhost:8080")
39 | if err != nil {
40 | panic(err)
41 | }
42 | defer mcpClient.Close()
43 |
44 | // Create gollem instance
45 | s := gollem.New(client,
46 | gollem.WithToolSets(mcpClient),
47 | gollem.WithMessageHook(func(ctx context.Context, msg string) error {
48 | fmt.Println(msg)
49 | return nil
50 | }),
51 | )
52 |
53 | // Send a message to the LLM
54 | if _, err := s.Prompt(context.Background(), "Hello, how are you?"); err != nil {
55 | panic(err)
56 | }
57 | }
58 | ```
59 |
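60 | This code sends a fixed message to an OpenAI model and prints each response message through the message hook. It also connects to an MCP server over SSE and exposes that server's tools to the LLM via `gollem.WithToolSets`; if you don't need external tools, remove the MCP client and that option, and the LLM will simply reply with a message.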
61 |
62 | For information on how to integrate with Tools and MCP servers, please refer to [tools](tools.md) and [mcp](mcp.md) documents.
63 |
64 | ## Supported LLM Providers
65 |
66 | gollem supports multiple LLM providers:
67 |
68 | - Gemini
69 | - Anthropic (Claude)
70 | - OpenAI (GPT)
71 |
72 | Each provider has its own client implementation in the `llm` package. See the respective documentation for configuration options.
73 |
74 | ## Key Concepts
75 |
76 | 1. **LLM Client**: The interface to communicate with LLM providers
77 | 2. **Tools**: Custom functions that LLMs can use to perform actions (see [Tools](tools.md))
78 | 3. **MCP Server**: External tool integration through Model Context Protocol (see [MCP Server Integration](mcp.md))
79 | 4. **Natural Language Interface**: Interact with your application using natural language
80 | 5. **History Management**: Maintain conversation context across sessions (see [History](history.md))
81 |
82 | ## Error Handling
83 |
84 | gollem provides robust error handling capabilities to help you build reliable applications:
85 |
86 | ### Error Types
87 | - **LLM Errors**: Errors from the LLM provider (e.g., rate limits, invalid requests)
88 | - **Tool Execution Errors**: Errors during tool execution
89 | - **MCP Server Errors**: Errors from MCP server communication
90 |
91 | ### Best Practices
92 | 1. **Graceful Degradation**: Implement fallback mechanisms when LLM or tools fail
93 | 2. **Retry Strategies**: Use exponential backoff for transient errors
94 | 3. **Error Logging**: Log errors with appropriate context for debugging
95 | 4. **User Feedback**: Provide clear error messages to end users
96 |
97 | Example of error handling:
98 | ```go
99 | newHistory, err := s.Prompt(ctx, userInput, gollem.WithHistory(history))
100 | if err != nil {
101 | // Handle errors
102 | log.Printf("Error: %v", err)
103 | return nil, fmt.Errorf("failed to process request: %w", err)
104 | }
105 | ```
106 |
107 | ## Context Management
108 |
109 | gollem provides a history-based context management system to maintain conversation state:
110 |
111 | ### History Object
112 | The `History` object maintains the conversation context, including:
113 | - Previous messages
114 | - Tool execution results
115 | - System prompts
116 | - Context metadata
117 |
118 | ### Best Practices
119 | 1. **Memory Management**: Clear old history when it exceeds size limits
120 | 2. **Context Persistence**: Save important context for future sessions
121 | 3. **Context Pruning**: Remove irrelevant information to maintain focus
122 |
123 | Example of history management:
124 | ```go
125 | // Initialize history
126 | var history *gollem.History
127 |
128 | // Process user input with history
129 | newHistory, err := s.Prompt(ctx, userInput, gollem.WithHistory(history))
130 | if err != nil {
131 | return nil, err
132 | }
133 |
134 | // Update history
135 | history = newHistory
136 | ```
137 |
138 | ## Next Steps
139 |
140 | - Learn how to create and use [custom tools](tools.md)
141 | - Explore [MCP server integration](mcp.md)
142 | - Check out [practical examples](examples.md)
143 | - Understand [history management](history.md) for conversation context
144 | - Review the [complete documentation](README.md)
145 |
--------------------------------------------------------------------------------
/doc/history.md:
--------------------------------------------------------------------------------
1 | # History
2 |
3 | History represents a conversation history that can be used across different LLM sessions. It stores messages in a format specific to each LLM type (OpenAI, Claude, or Gemini).
4 |
5 | ## Version Management
6 |
7 | History includes version information to ensure compatibility:
8 |
9 | - Current version: 1
10 | - Version checking is performed when converting between formats
11 | - Version mismatch will result in an error
12 | - This helps maintain compatibility when the History structure changes in future updates
13 |
14 | ## Session Persistence
15 |
16 | History is essential for maintaining conversation context across stateless sessions. Common use cases include:
17 |
18 | - Backend services handling stateless HTTP requests
19 | - When your backend service receives requests from different instances or after restarts
20 | - When you need to maintain conversation context across multiple API calls
21 | - Distributed systems
22 | - When sessions may be handled by different instances
23 | - When you need to load balance conversations across multiple servers
24 | - Long-running conversations
25 | - When conversations need to be resumed after service restarts
26 | - When implementing features like "continue previous conversation"
27 |
28 | ## Portability
29 |
30 | History can be easily serialized/deserialized using standard JSON marshaling. This enables:
31 |
32 | - Storing conversations in databases
33 | - Persist conversations for future reference
34 | - Implement conversation history features
35 | - Transferring conversations between services
36 | - Move conversations between different environments
37 | - Share conversations across microservices
38 | - Implementing conversation backup and restore features
39 | - Backup important conversations
40 | - Restore conversations after system failures
41 |
42 | ## LLM Type Compatibility
43 |
44 | Each History instance is tied to a specific LLM type (OpenAI, Claude, or Gemini). Important notes:
45 |
46 | - Direct conversion between different LLM types is not supported
47 | - Each LLM type has its own message format and capabilities
48 |
49 | ## Usage Guidelines
50 |
51 | 1. Get History from Prompt response:
52 | ```go
53 | // Create a new gollem instance
54 | g := gollem.New(llmClient)
55 |
56 | // Get response from Prompt
57 | history, err := g.Prompt(ctx, "What is the weather?")
58 | if err != nil {
59 | return nil, fmt.Errorf("failed to get prompt response: %w", err)
60 | }
61 | ```
62 |
63 | 2. Store the History for future use:
64 | ```go
65 | // Store history in your database or storage
66 | jsonData, err := json.Marshal(history)
67 | if err != nil {
68 | return fmt.Errorf("failed to marshal history: %w", err)
69 | }
70 | ```
71 |
72 | 3. Use stored History in a new session:
73 | ```go
74 | // Restore history
75 | var restoredHistory gollem.History
76 | if err := json.Unmarshal(jsonData, &restoredHistory); err != nil {
77 | return fmt.Errorf("failed to unmarshal history: %w", err)
78 | }
79 |
80 | // Use history in next Prompt call
81 | newHistory, err := g.Prompt(ctx, "What about tomorrow?", gollem.WithHistory(&restoredHistory))
82 | if err != nil {
83 | return nil, fmt.Errorf("failed to get prompt response: %w", err)
84 | }
85 | ```
86 |
87 | Note: The History returned from Prompt contains the complete conversation history, so there's no need to manage or track individual messages. Each Prompt response provides a new History instance that includes all previous messages.
88 |
89 | ## Best Practices
90 |
91 | 1. **Error Handling**
92 | - Always check for errors when converting between formats
93 | - Handle type mismatches gracefully
94 | - Check for version compatibility
95 | - Implement proper error handling for version mismatches
96 |
97 | 2. **Storage Considerations**
98 | - Consider the size of your conversations
99 | - Implement cleanup strategies for old conversations
100 |
101 | 3. **Security**
102 | - Implement proper access controls
103 |
104 | ## Next Steps
105 |
106 | - Learn more about [tool creation](tools.md)
107 | - Explore [MCP server integration](mcp.md)
108 | - Check out [practical examples](examples.md)
109 | - Review the [getting started guide](getting-started.md)
110 | - Explore the [complete documentation](README.md)
111 |
112 |
--------------------------------------------------------------------------------
/doc/images/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/m-mizutani/gollem/0c5d4dbbe006a021f126c22e2a60ec55e3f9a79e/doc/images/logo.png
--------------------------------------------------------------------------------
/doc/mcp.md:
--------------------------------------------------------------------------------
1 | # MCP Server Integration
2 |
3 | gollem supports integration with MCP (Model Context Protocol) servers, allowing you to extend LLM capabilities with external tools and resources.
4 |
5 | ## What is MCP?
6 |
7 | MCP is a protocol that enables LLMs to interact with external tools and resources through a standardized interface. It provides a way to:
8 |
9 | - Define and expose custom tools
10 | - Manage external resources
11 | - Customize LLM prompts
12 |
13 | ## Connecting to an MCP Server
14 |
15 | To connect your gollem application to an MCP server, you can use either HTTP SSE or stdio transport:
16 |
17 | ```go
18 | // Using HTTP SSE transport
19 | sseClient, err := mcp.NewSSE(context.Background(), "http://localhost:8080")
20 | if err != nil {
21 | panic(err)
22 | }
23 | defer sseClient.Close()
24 |
25 | sseAgent := gollem.New(client,
26 | gollem.WithToolSets(sseClient),
27 | )
28 |
29 | // Using stdio transport (both clients can also be passed to a single WithToolSets call)
30 | stdioClient, err := mcp.NewStdio(context.Background(), "/path/to/mcp/server", []string{"--arg1", "value1"})
31 | if err != nil {
32 | panic(err)
33 | }
34 | defer stdioClient.Close()
35 |
36 | stdioAgent := gollem.New(client,
37 | gollem.WithToolSets(stdioClient),
38 | )
39 | ```
40 |
41 | ## Options
42 |
43 | 1. **Environment Variables**: Set environment variables for the MCP client
44 | ```go
45 | mcpClient, err := mcp.NewStdio(context.Background(), "/path/to/mcp/server", []string{},
46 | mcp.WithEnvVars([]string{"MCP_ENV=test"}),
47 | )
48 | ```
49 |
50 | 2. **HTTP Headers**: Set custom HTTP headers for SSE transport
51 | ```go
52 | mcpClient, err := mcp.NewSSE(context.Background(), "http://localhost:8080",
53 | mcp.WithHeaders(map[string]string{
54 | "Authorization": "Bearer token",
55 | }),
56 | )
57 | ```
58 |
59 | ## Next Steps
60 |
61 | - Learn more about [tool creation](tools.md)
62 | - Check out [practical examples](examples.md) of MCP integration
63 | - Review the [getting started guide](getting-started.md) for basic usage
64 | - Understand [history management](history.md) for conversation context
65 | - Explore the [complete documentation](README.md)
--------------------------------------------------------------------------------
/doc/tools.md:
--------------------------------------------------------------------------------
1 | # Tools in gollem
2 |
3 | Tools are your own custom built-in functions that LLMs can use to perform specific actions in your application. This guide explains how to create and use tools with gollem.
4 |
5 | ## Creating a Tool
6 |
7 | To create a tool, you need to implement the `Tool` interface:
8 |
9 | ```go
10 | type Tool interface {
11 | Spec() ToolSpec
12 | Run(ctx context.Context, args map[string]any) (map[string]any, error)
13 | }
14 | ```
15 |
16 | Here's an example of a simple tool:
17 |
18 | ```go
19 | type HelloTool struct{}
20 |
21 | func (t *HelloTool) Spec() gollem.ToolSpec {
22 | return gollem.ToolSpec{
23 | Name: "hello",
24 | Description: "Returns a greeting",
25 | Parameters: map[string]*gollem.Parameter{
26 | "name": {
27 | Type: gollem.TypeString,
28 | Description: "Name of the person to greet",
29 | },
30 | },
31 | Required: []string{"name"},
32 | }
33 | }
34 |
35 | func (t *HelloTool) Run(ctx context.Context, args map[string]any) (map[string]any, error) {
36 | return map[string]any{
37 | "message": fmt.Sprintf("Hello, %s!", args["name"]),
38 | }, nil
39 | }
40 | ```
41 |
42 | ## Tool Specification
43 |
44 | The `ToolSpec` defines the tool's interface:
45 |
46 | - `Name`: Unique identifier for the tool
47 | - `Description`: Human-readable description of what the tool does
48 | - `Parameters`: Map of parameter names to their specifications
49 | - `Required`: List of required parameter names (For Object type)
50 |
51 | Each parameter specification includes:
52 | - `Type`: Parameter type (string, number, boolean, etc.)
53 | - `Description`: Human-readable description
54 | - `Title`: Optional user-friendly name for the parameter
55 | - `Required`: Optional boolean indicating if the parameter is required
56 | - `RequiredFields`: List of required field names when Type is Object
57 | - `Enum`: Optional list of allowed values
58 | - `Properties`: Map of properties when Type is Object
59 | - `Items`: Specification for array items when Type is Array
60 | - `Minimum`/`Maximum`: Number constraints
61 | - `MinLength`/`MaxLength`: String length constraints
62 | - `Pattern`: Regular expression pattern for string validation
63 | - `MinItems`/`MaxItems`: Array size constraints
64 | - `Default`: Default value for the parameter
65 |
66 | > [!CAUTION]
67 | > Note that not all parameters are supported by every LLM, as parameter support varies between different LLM providers.
68 |
69 | ## Using Tools
70 |
71 | To use tools with your LLM:
72 |
73 | ```go
74 | s := gollem.New(client,
75 | gollem.WithTools(&HelloTool{}),
76 | )
77 | ```
78 |
79 | You can add multiple tools:
80 |
81 | ```go
82 | s := gollem.New(client,
83 | gollem.WithTools(&HelloTool{}, &CalculatorTool{}, &WeatherTool{}),
84 | )
85 | ```
86 |
87 | ## Best Practices
88 |
89 | 1. **Clear Descriptions**: Provide clear and concise descriptions for tools and parameters to help the LLM understand their purpose and usage
90 | 2. **Validate Input**: Always validate that parameters passed by the LLM match the specified types in your tool specification
91 | 3. **Error Handling**: When errors occur, return clear error messages that explain both what went wrong and how to fix it. The Error() message will be passed directly to the LLM, so include actionable guidance
92 | 4. **Nested Results**: For nested tool results with multiple levels of maps, always use `map[string]any`. Other types may cause errors with some LLM SDKs like Gemini
93 |
94 | ## Next Steps
95 |
96 | - Learn about [MCP server integration](mcp.md) for external tool integration
97 | - Check out [practical examples](examples.md) of tool usage
98 | - Review the [getting started guide](getting-started.md) for basic usage
99 | - Understand [history management](history.md) for conversation context
100 | - Explore the [complete documentation](README.md)
101 |
--------------------------------------------------------------------------------
/errors.go:
--------------------------------------------------------------------------------
1 | package gollem
2 |
3 | import "errors"
4 |
5 | var ( // Sentinel errors of the gollem package; compare against them with errors.Is.
6 | // ErrInvalidTool is returned when validation of a tool definition (tool specification) fails.
7 | ErrInvalidTool = errors.New("invalid tool specification")
8 |
9 | // ErrInvalidParameter is returned when validation of a parameter definition fails.
10 | ErrInvalidParameter = errors.New("invalid parameter")
11 |
12 | // ErrToolNameConflict is returned when a tool name is already in use by another tool.
13 | ErrToolNameConflict = errors.New("tool name conflict")
14 |
15 | // ErrLoopLimitExceeded is returned when the session loop limit is exceeded. The session can be resumed by calling the Prompt() method again.
16 | ErrLoopLimitExceeded = errors.New("loop limit exceeded")
17 |
18 | // ErrInvalidInputSchema is returned when an input schema received from an MCP server is invalid or unsupported.
19 | ErrInvalidInputSchema = errors.New("invalid input schema")
20 |
21 | // ErrInvalidHistoryData is returned when history data is invalid or unsupported.
22 | ErrInvalidHistoryData = errors.New("invalid history data")
23 |
24 | // ErrLLMTypeMismatch is returned when loading a history whose LLM type is invalid or unsupported for the loading client.
25 | ErrLLMTypeMismatch = errors.New("llm type mismatch")
26 |
27 | // ErrHistoryVersionMismatch is returned when a history's version is invalid or unsupported.
28 | ErrHistoryVersionMismatch = errors.New("history version mismatch")
29 | )
30 |
--------------------------------------------------------------------------------
/examples/README.md:
--------------------------------------------------------------------------------
1 | # Gollem Examples
2 |
3 | This directory contains various examples demonstrating the usage of Gollem.
4 |
5 | ## Basic Example
6 | [Basic Example](basic/main.go) - A simple example showing the basic usage of Gollem.
7 |
8 | ## Chat Example
9 | [Chat Example](chat/main.go) - An example demonstrating chat functionality with Gollem.
10 |
11 | ## MCP Example
12 | [MCP Example](mcp/main.go) - An example showing how to use MCP (Model Context Protocol) with Gollem.
13 |
14 | ## Tools Example
15 | [Tools Example](tools/main.go) - An example demonstrating how to use tools with Gollem.
16 |
--------------------------------------------------------------------------------
/examples/basic/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "bufio"
5 | "context"
6 | "fmt"
7 | "os"
8 |
9 | "github.com/m-mizutani/gollem"
10 | "github.com/m-mizutani/gollem/llm/openai"
11 | "github.com/m-mizutani/gollem/mcp"
12 | )
13 |
14 | // MyTool is a minimal built-in tool that returns a greeting for the given name.
15 | type MyTool struct{}
16 |
17 | // Spec describes the tool to the LLM. "name" is marked as required because Run
18 | // cannot build a greeting without it.
19 | func (t *MyTool) Spec() gollem.ToolSpec {
20 | return gollem.ToolSpec{
21 | Name: "my_tool",
22 | Description: "Returns a greeting",
23 | Parameters: map[string]*gollem.Parameter{
24 | "name": {
25 | Type: gollem.TypeString,
26 | Description: "Name of the person to greet",
27 | },
28 | },
29 | Required: []string{"name"},
30 | }
31 | }
32 |
33 | // Run greets the person named by the "name" argument. Its error message is
34 | // passed back to the LLM, so it states what is missing.
35 | func (t *MyTool) Run(ctx context.Context, args map[string]any) (map[string]any, error) {
36 | name, ok := args["name"].(string)
37 | if !ok {
38 | return nil, fmt.Errorf("name is required")
39 | }
40 | return map[string]any{"message": fmt.Sprintf("Hello, %s!", name)}, nil
41 | }
42 |
43 | // main runs an interactive loop: each line read from stdin is sent to the
44 | // agent together with the conversation history accumulated so far.
45 | func main() {
46 | ctx := context.Background()
47 |
48 | // Create OpenAI client
49 | client, err := openai.New(ctx, os.Getenv("OPENAI_API_KEY"))
50 | if err != nil {
51 | panic(err)
52 | }
53 |
54 | // Create MCP client with local server
55 | mcpLocal, err := mcp.NewStdio(ctx, "./mcp-server", []string{}, mcp.WithEnvVars([]string{"MCP_ENV=test"}))
56 | if err != nil {
57 | panic(err)
58 | }
59 | defer mcpLocal.Close()
60 |
61 | // Create MCP client with remote server
62 | mcpRemote, err := mcp.NewSSE(ctx, "http://localhost:8080")
63 | if err != nil {
64 | panic(err)
65 | }
66 | defer mcpRemote.Close()
67 |
68 | // Create gollem instance
69 | agent := gollem.New(client,
70 | // Not only MCP servers,
71 | gollem.WithToolSets(mcpLocal, mcpRemote),
72 | // But also you can use your own built-in tools
73 | gollem.WithTools(&MyTool{}),
74 | // You can customize the callback function for each message and tool call.
75 | gollem.WithMessageHook(func(ctx context.Context, msg string) error {
76 | fmt.Printf("🤖 %s\n", msg)
77 | return nil
78 | }),
79 | )
80 |
81 | // Create the scanner once: a fresh bufio.Scanner per iteration would discard
82 | // any input the previous scanner had already buffered from stdin.
83 | scanner := bufio.NewScanner(os.Stdin)
84 | var history *gollem.History
85 | for {
86 | fmt.Print("> ")
87 | if !scanner.Scan() {
88 | break // EOF (e.g. Ctrl-D) or read error; don't loop sending empty prompts
89 | }
90 |
91 | newHistory, err := agent.Prompt(ctx, scanner.Text(), gollem.WithHistory(history))
92 | if err != nil {
93 | panic(err)
94 | }
95 | history = newHistory
96 | }
97 | if err := scanner.Err(); err != nil {
98 | panic(err)
99 | }
100 | }
101 |
--------------------------------------------------------------------------------
/examples/chat/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "bufio"
5 | "context"
6 | "encoding/json"
7 | "fmt"
8 | "io"
9 | "os"
10 | "path/filepath"
11 |
12 | "github.com/m-mizutani/gollem"
13 | "github.com/m-mizutani/gollem/llm/gemini"
14 | )
15 |
16 | // WeatherTool is a stub tool that always reports sunny weather for the requested city.
17 | type WeatherTool struct{}
18 |
19 | func (t *WeatherTool) Spec() gollem.ToolSpec {
20 | return gollem.ToolSpec{
21 | Name: "weather",
22 | Description: "Returns a weather",
23 | Parameters: map[string]*gollem.Parameter{
24 | "city": {
25 | Type: gollem.TypeString,
26 | Description: "City name",
27 | },
28 | },
29 | Required: []string{"city"},
30 | }
31 | }
32 |
33 | func (t *WeatherTool) Run(ctx context.Context, args map[string]any) (map[string]any, error) {
34 | city, ok := args["city"].(string)
35 | if !ok {
36 | return nil, fmt.Errorf("city is required")
37 | }
38 |
39 | return map[string]any{
40 | "message": fmt.Sprintf("The weather in %s is sunny.", city),
41 | }, nil
42 | }
43 |
44 | // main runs an interactive streaming chat with Gemini. The conversation
45 | // history is saved to a temporary file after every turn and loaded back
46 | // before the next one.
47 | func main() {
48 | ctx := context.Background()
49 | llmModel, err := gemini.New(ctx, os.Getenv("GEMINI_PROJECT_ID"), os.Getenv("GEMINI_LOCATION"))
50 | if err != nil {
51 | panic(err)
52 | }
53 |
54 | g := gollem.New(llmModel,
55 | gollem.WithResponseMode(gollem.ResponseModeStreaming),
56 | gollem.WithTools(&WeatherTool{}),
57 | gollem.WithMessageHook(func(ctx context.Context, msg string) error {
58 | fmt.Print(msg) // no newline: in streaming mode msg is presumably a partial chunk
59 | return nil
60 | }),
61 | gollem.WithToolRequestHook(func(ctx context.Context, tool gollem.FunctionCall) error {
62 | fmt.Printf("⚡ Call: %s\n", tool.Name)
63 | return nil
64 | }),
65 | )
66 |
67 | tmpFile, err := os.CreateTemp("", "gollem-chat-*.txt")
68 | if err != nil {
69 | panic(err)
70 | }
71 | if err := tmpFile.Close(); err != nil {
72 | panic(err)
73 | }
74 | println("history file:", tmpFile.Name())
75 |
76 | // Reuse one scanner for the whole session: a new bufio.Scanner per turn
77 | // would drop any stdin input the previous one had already buffered.
78 | scanner := bufio.NewScanner(os.Stdin)
79 | for {
80 | // History round-trips through the file every turn (loaded here, dumped
81 | // after the response), exercising the persistence path.
82 | history, err := loadHistory(tmpFile.Name())
83 | if err != nil {
84 | panic(err)
85 | }
86 |
87 | fmt.Print("> ")
88 | if !scanner.Scan() {
89 | break // EOF or read error; the error is checked after the loop
90 | }
91 | text := scanner.Text()
92 |
93 | fmt.Printf("🤖 ")
94 | newHistory, err := g.Prompt(ctx, text, gollem.WithHistory(history))
95 | if err != nil {
96 | panic(err)
97 | }
98 |
99 | if err := dumpHistory(newHistory, tmpFile.Name()); err != nil {
100 | panic(err)
101 | }
102 |
103 | fmt.Printf("\n")
104 | }
105 | if err := scanner.Err(); err != nil {
106 | panic(err)
107 | }
108 | }
109 |
110 | // dumpHistory writes history to path as JSON, replacing any previous content
111 | // (os.Create truncates). The Close error is returned too, because a failed
112 | // close can mean the data never fully reached the file.
113 | func dumpHistory(history *gollem.History, path string) error {
114 | f, err := os.Create(filepath.Clean(path))
115 | if err != nil {
116 | return err
117 | }
118 |
119 | if err := json.NewEncoder(f).Encode(history); err != nil {
120 | _ = f.Close() // the encode error is the one worth reporting here
121 | return err
122 | }
123 |
124 | return f.Close()
125 | }
126 |
127 | // loadHistory reads a JSON-encoded History from path. It returns (nil, nil)
128 | // when the file exists but is empty (no history yet); a Stat error such as a
129 | // missing file is returned as-is.
130 | func loadHistory(path string) (*gollem.History, error) {
131 | if st, err := os.Stat(path); err != nil || st.Size() == 0 {
132 | return nil, err
133 | }
134 |
135 | f, err := os.Open(filepath.Clean(path))
136 | if err != nil {
137 | return nil, err
138 | }
139 | defer f.Close()
140 |
141 | data, err := io.ReadAll(f)
142 | if err != nil {
143 | return nil, err
144 | }
145 |
146 | var history gollem.History
147 | if err := json.Unmarshal(data, &history); err != nil {
148 | return nil, err
149 | }
150 |
151 | return &history, nil
152 | }
153 |
--------------------------------------------------------------------------------
/examples/embedding/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "os"
7 |
8 | "github.com/m-mizutani/gollem/llm/openai"
9 | )
10 |
11 | // main requests embeddings for two sample sentences from OpenAI and prints them.
12 | func main() {
13 | ctx := context.Background()
14 |
15 | llm, err := openai.New(ctx, os.Getenv("OPENAI_API_KEY"))
16 | if err != nil {
17 | panic(err)
18 | }
19 |
20 | // NOTE(review): 100 is presumably the requested embedding dimension —
21 | // confirm against the openai client's GenerateEmbedding signature.
22 | inputs := []string{"Hello, world!", "This is a test"}
23 | vectors, err := llm.GenerateEmbedding(ctx, 100, inputs)
24 | if err != nil {
25 | panic(err)
26 | }
27 |
28 | fmt.Println("embedding:", vectors)
29 | }
30 |
--------------------------------------------------------------------------------
/examples/mcp/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "context"
5 | "log"
6 | "os"
7 |
8 | "github.com/m-mizutani/gollem"
9 | "github.com/m-mizutani/gollem/llm/openai"
10 | "github.com/m-mizutani/gollem/mcp"
11 | )
12 |
13 | // main connects an OpenAI-backed agent to one SSE and one stdio MCP server
14 | // and sends a single prompt, logging every message the LLM produces.
15 | func main() {
16 | ctx := context.Background()
17 |
18 | // Create OpenAI client
19 | client, err := openai.New(ctx, os.Getenv("OPENAI_API_KEY"))
20 | if err != nil {
21 | panic(err)
22 | }
23 |
24 | // Create MCP client (SSE)
25 | sseClient, err := mcp.NewSSE(ctx, "http://localhost:8080")
26 | if err != nil {
27 | log.Panicf("Failed to create SSE client: %v", err)
28 | }
29 | defer sseClient.Close()
30 |
31 | // Create MCP client (Stdio)
32 | stdioClient, err := mcp.NewStdio(ctx, "./mcp-server", []string{}, mcp.WithEnvVars([]string{"MCP_ENV=test"}))
33 | if err != nil {
34 | // log.Panicf (unlike log.Fatalf, which calls os.Exit) unwinds the stack,
35 | // so the deferred sseClient.Close() above still runs.
36 | log.Panicf("Failed to create Stdio client: %v", err)
37 | }
38 | defer stdioClient.Close()
39 |
40 | // Create gollem instance with MCP tools; the message hook makes the
41 | // LLM's reply visible.
42 | g := gollem.New(client,
43 | gollem.WithToolSets(sseClient, stdioClient),
44 | gollem.WithMessageHook(func(ctx context.Context, msg string) error {
45 | log.Printf("Response: %s", msg)
46 | return nil
47 | }),
48 | )
49 |
50 | // Send a message
51 | if _, err := g.Prompt(ctx, "Hello, I want to use MCP tools."); err != nil {
52 | panic(err)
53 | }
54 | }
55 |
--------------------------------------------------------------------------------
/examples/query/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "os"
7 |
8 | "github.com/m-mizutani/gollem"
9 | "github.com/m-mizutani/gollem/llm/claude"
10 | "github.com/m-mizutani/gollem/llm/gemini"
11 | "github.com/m-mizutani/gollem/llm/openai"
12 | )
13 |
14 | // main sends one prompt to the selected LLM provider and prints the
15 | // generated texts.
16 | func main() {
17 | ctx := context.Background()
18 |
19 | if len(os.Args) != 4 {
20 | fmt.Println("Usage: go run main.go <gemini|claude|openai> <model> <prompt>")
21 | os.Exit(1)
22 | }
23 |
24 | llmProvider := os.Args[1]
25 | model := os.Args[2]
26 | prompt := os.Args[3]
27 |
28 | var client gollem.LLMClient
29 | var err error
30 |
31 | switch llmProvider {
32 | case "gemini":
33 | client, err = gemini.New(ctx, os.Getenv("GEMINI_PROJECT_ID"), os.Getenv("GEMINI_LOCATION"), gemini.WithModel(model))
34 | case "claude":
35 | client, err = claude.New(ctx, os.Getenv("ANTHROPIC_API_KEY"), claude.WithModel(model))
36 | case "openai":
37 | client, err = openai.New(ctx, os.Getenv("OPENAI_API_KEY"), openai.WithModel(model))
38 | default:
39 | // Without this branch client stays nil and the NewSession call below
40 | // panics with a nil-pointer dereference instead of a useful message.
41 | fmt.Printf("Unsupported LLM provider %q: must be gemini, claude, or openai\n", llmProvider)
42 | os.Exit(1)
43 | }
44 |
45 | if err != nil {
46 | panic(err)
47 | }
48 |
49 | ssn, err := client.NewSession(ctx)
50 | if err != nil {
51 | panic(err)
52 | }
53 |
54 | result, err := ssn.GenerateContent(ctx, gollem.Text(prompt))
55 | if err != nil {
56 | panic(err)
57 | }
58 |
59 | fmt.Println(result.Texts)
60 | }
61 |
--------------------------------------------------------------------------------
/examples/simple/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "bufio"
5 | "context"
6 | "fmt"
7 | "os"
8 |
9 | "github.com/m-mizutani/gollem"
10 | "github.com/m-mizutani/gollem/llm/openai"
11 | "github.com/m-mizutani/gollem/mcp"
12 | )
13 |
14 | // main connects an OpenAI-backed agent to a local stdio MCP server, reads a
15 | // single prompt from stdin, and lets the agent answer it.
16 | func main() {
17 | ctx := context.Background()
18 |
19 | // Create OpenAI client
20 | client, err := openai.New(ctx, os.Getenv("OPENAI_API_KEY"))
21 | if err != nil {
22 | panic(err)
23 | }
24 |
25 | // Create MCP client with local server
26 | mcpLocal, err := mcp.NewStdio(ctx, "./mcp-server", []string{"arg1", "arg2"})
27 | if err != nil {
28 | panic(err)
29 | }
30 | defer mcpLocal.Close()
31 |
32 | // Create gollem instance
33 | agent := gollem.New(client,
34 | gollem.WithToolSets(mcpLocal),
35 | gollem.WithMessageHook(func(ctx context.Context, msg string) error {
36 | fmt.Printf("🤖 %s\n", msg)
37 | return nil
38 | }),
39 | )
40 |
41 | fmt.Print("> ")
42 | scanner := bufio.NewScanner(os.Stdin)
43 | if !scanner.Scan() {
44 | // Nothing was read (EOF or a read error): don't send an empty prompt.
45 | if err := scanner.Err(); err != nil {
46 | panic(err)
47 | }
48 | return
49 | }
50 |
51 | if _, err = agent.Prompt(ctx, scanner.Text()); err != nil {
52 | panic(err)
53 | }
54 | }
55 |
--------------------------------------------------------------------------------
/examples/tools/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "context"
5 | "log"
6 | "os"
7 |
8 | "github.com/m-mizutani/gollem"
9 | "github.com/m-mizutani/gollem/llm/gemini"
10 | )
11 |
12 | func main() {
13 | ctx := context.Background()
14 |
15 | // Initialize Gemini client
16 | client, err := gemini.New(ctx, os.Getenv("GEMINI_PROJECT_ID"), os.Getenv("GEMINI_LOCATION"))
17 | if err != nil {
18 | log.Fatal(err)
19 | }
20 |
21 | // Register tools
22 | tools := []gollem.Tool{
23 | &AddTool{},
24 | &MultiplyTool{},
25 | }
26 |
27 | servant := gollem.New(client,
28 | gollem.WithTools(tools...),
29 | gollem.WithMessageHook(func(ctx context.Context, msg string) error {
30 | log.Printf("Response: %s", msg)
31 | return nil
32 | }),
33 | /*
34 | gollem.WithLogger(slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{
35 | Level: slog.LevelDebug,
36 | }))),
37 | */
38 | )
39 |
40 | query := "Add 5 and 3, then multiply the result by 2"
41 | log.Printf("Query: %s", query)
42 | if _, err := servant.Prompt(ctx, query); err != nil {
43 | log.Fatal(err)
44 | }
45 | }
46 |
47 | // AddTool is a tool that adds two numbers
48 | type AddTool struct{}
49 |
50 | func (t *AddTool) Run(ctx context.Context, args map[string]any) (map[string]any, error) {
51 | a := args["a"].(float64)
52 | b := args["b"].(float64)
53 | log.Printf("Add: %f + %f", a, b)
54 | return map[string]any{"result": a + b}, nil
55 | }
56 |
57 | func (t *AddTool) Spec() gollem.ToolSpec {
58 | return gollem.ToolSpec{
59 | Name: "add",
60 | Description: "Adds two numbers together",
61 | Parameters: map[string]*gollem.Parameter{
62 | "a": {
63 | Type: "number",
64 | Description: "First number",
65 | },
66 | "b": {
67 | Type: "number",
68 | Description: "Second number",
69 | },
70 | },
71 | }
72 | }
73 |
// MultiplyTool is a tool that multiplies two numbers.
type MultiplyTool struct{}

// Run multiplies args["a"] by args["b"] and returns the product under the
// "result" key.
//
// Both arguments must be numbers decoded as float64. A missing or non-numeric
// argument is reported as an error instead of panicking, because the arguments
// come from the LLM and cannot be trusted to match the spec.
func (t *MultiplyTool) Run(ctx context.Context, args map[string]any) (map[string]any, error) {
	a, ok := args["a"].(float64)
	if !ok {
		return nil, fmt.Errorf("argument %q must be a number, got %T", "a", args["a"])
	}
	b, ok := args["b"].(float64)
	if !ok {
		return nil, fmt.Errorf("argument %q must be a number, got %T", "b", args["b"])
	}
	log.Printf("Multiply: %f * %f", a, b)
	return map[string]any{"result": a * b}, nil
}
83 |
84 | func (t *MultiplyTool) Spec() gollem.ToolSpec {
85 | return gollem.ToolSpec{
86 | Name: "multiply",
87 | Description: "Multiplies two numbers together",
88 | Parameters: map[string]*gollem.Parameter{
89 | "a": {
90 | Type: "number",
91 | Description: "First number",
92 | },
93 | "b": {
94 | Type: "number",
95 | Description: "Second number",
96 | },
97 | },
98 | }
99 | }
100 |
--------------------------------------------------------------------------------
/exit.go:
--------------------------------------------------------------------------------
1 | package gollem
2 |
3 | import "context"
4 |
5 | // ExitTool is a tool that can be used to exit the session. IsCompleted() is called before calling a method to generate content every loop. If IsCompleted() returns true, the session will be ended.
6 | type ExitTool interface {
7 | Tool
8 | IsCompleted() bool
9 | Response() string
10 | }
11 |
// DefaultExitTool is the tool to stop the session loop. This tool is used when the agent determines that the session should be ended. The tool name is "respond_to_user".
type DefaultExitTool struct {
	// isCompleted is set to true by Run and reported by IsCompleted.
	isCompleted bool
	// response holds the "final_answer" string from the most recent Run call
	// that supplied one; it is returned by Response.
	response string
}
17 |
18 | func (t *DefaultExitTool) Spec() ToolSpec {
19 | return ToolSpec{
20 | Name: "respond_to_user",
21 | Description: "Call this tool when you have gathered all necessary information, completed all required actions, and already provided the final answer to the user's original request. This signals that your work on the current request is finished.",
22 | Parameters: map[string]*Parameter{
23 | /*
24 | "final_answer": {
25 | Type: "string",
26 | Description: "The comprehensive final answer or result for the user's request. If you already provided the final answer, you MUST omit this parameter.",
27 | },
28 | */
29 | },
30 | }
31 |
32 | }
33 |
34 | func (t *DefaultExitTool) Run(ctx context.Context, args map[string]any) (map[string]any, error) {
35 | t.isCompleted = true
36 |
37 | if response, ok := args["final_answer"].(string); ok {
38 | t.response = response
39 | }
40 |
41 | return nil, nil
42 | }
43 |
44 | func (t *DefaultExitTool) IsCompleted() bool {
45 | return t.isCompleted
46 | }
47 |
48 | func (t *DefaultExitTool) Response() string {
49 | return t.response
50 | }
51 |
--------------------------------------------------------------------------------
/export_test.go:
--------------------------------------------------------------------------------
1 | package gollem
2 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/m-mizutani/gollem
2 |
3 | go 1.24
4 |
5 | require (
6 | cloud.google.com/go/aiplatform v1.69.0
7 | cloud.google.com/go/vertexai v0.13.3
8 | github.com/anthropics/anthropic-sdk-go v0.2.0-beta.3
9 | github.com/google/uuid v1.6.0
10 | github.com/m-mizutani/goerr/v2 v2.0.0-beta.2
11 | github.com/m-mizutani/gt v0.0.16
12 | github.com/mark3labs/mcp-go v0.23.1
13 | github.com/santhosh-tekuri/jsonschema/v6 v6.0.1
14 | github.com/sashabaranov/go-openai v1.38.2
15 | google.golang.org/api v0.211.0
16 | google.golang.org/protobuf v1.35.2
17 | )
18 |
19 | require (
20 | cloud.google.com/go v0.116.0 // indirect
21 | cloud.google.com/go/auth v0.12.1 // indirect
22 | cloud.google.com/go/auth/oauth2adapt v0.2.6 // indirect
23 | cloud.google.com/go/compute/metadata v0.5.2 // indirect
24 | cloud.google.com/go/iam v1.2.2 // indirect
25 | cloud.google.com/go/longrunning v0.6.2 // indirect
26 | github.com/dlclark/regexp2 v1.11.4 // indirect
27 | github.com/felixge/httpsnoop v1.0.4 // indirect
28 | github.com/go-logr/logr v1.4.2 // indirect
29 | github.com/go-logr/stdr v1.2.2 // indirect
30 | github.com/google/go-cmp v0.7.0 // indirect
31 | github.com/google/s2a-go v0.1.8 // indirect
32 | github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect
33 | github.com/googleapis/gax-go/v2 v2.14.0 // indirect
34 | github.com/rogpeppe/go-internal v1.14.1 // indirect
35 | github.com/spf13/cast v1.7.1 // indirect
36 | github.com/stretchr/testify v1.10.0 // indirect
37 | github.com/tidwall/gjson v1.18.0 // indirect
38 | github.com/tidwall/match v1.1.1 // indirect
39 | github.com/tidwall/pretty v1.2.1 // indirect
40 | github.com/tidwall/sjson v1.2.5 // indirect
41 | github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
42 | go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0 // indirect
43 | go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 // indirect
44 | go.opentelemetry.io/otel v1.29.0 // indirect
45 | go.opentelemetry.io/otel/metric v1.29.0 // indirect
46 | go.opentelemetry.io/otel/trace v1.29.0 // indirect
47 | golang.org/x/crypto v0.36.0 // indirect
48 | golang.org/x/net v0.38.0 // indirect
49 | golang.org/x/oauth2 v0.24.0 // indirect
50 | golang.org/x/sync v0.13.0 // indirect
51 | golang.org/x/sys v0.31.0 // indirect
52 | golang.org/x/text v0.24.0 // indirect
53 | golang.org/x/time v0.8.0 // indirect
54 | google.golang.org/genproto v0.0.0-20241118233622-e639e219e697 // indirect
55 | google.golang.org/genproto/googleapis/api v0.0.0-20241118233622-e639e219e697 // indirect
56 | google.golang.org/genproto/googleapis/rpc v0.0.0-20241206012308-a4fef0638583 // indirect
57 | google.golang.org/grpc v1.67.3 // indirect
58 | )
59 |
--------------------------------------------------------------------------------
/go.sum:
--------------------------------------------------------------------------------
1 | cloud.google.com/go v0.116.0 h1:B3fRrSDkLRt5qSHWe40ERJvhvnQwdZiHu0bJOpldweE=
2 | cloud.google.com/go v0.116.0/go.mod h1:cEPSRWPzZEswwdr9BxE6ChEn01dWlTaF05LiC2Xs70U=
3 | cloud.google.com/go/aiplatform v1.69.0 h1:XvBzK8e6/6ufbi/i129Vmn/gVqFwbNPmRQ89K+MGlgc=
4 | cloud.google.com/go/aiplatform v1.69.0/go.mod h1:nUsIqzS3khlnWvpjfJbP+2+h+VrFyYsTm7RNCAViiY8=
5 | cloud.google.com/go/auth v0.12.1 h1:n2Bj25BUMM0nvE9D2XLTiImanwZhO3DkfWSYS/SAJP4=
6 | cloud.google.com/go/auth v0.12.1/go.mod h1:BFMu+TNpF3DmvfBO9ClqTR/SiqVIm7LukKF9mbendF4=
7 | cloud.google.com/go/auth/oauth2adapt v0.2.6 h1:V6a6XDu2lTwPZWOawrAa9HUK+DB2zfJyTuciBG5hFkU=
8 | cloud.google.com/go/auth/oauth2adapt v0.2.6/go.mod h1:AlmsELtlEBnaNTL7jCj8VQFLy6mbZv0s4Q7NGBeQ5E8=
9 | cloud.google.com/go/compute/metadata v0.5.2 h1:UxK4uu/Tn+I3p2dYWTfiX4wva7aYlKixAHn3fyqngqo=
10 | cloud.google.com/go/compute/metadata v0.5.2/go.mod h1:C66sj2AluDcIqakBq/M8lw8/ybHgOZqin2obFxa/E5k=
11 | cloud.google.com/go/iam v1.2.2 h1:ozUSofHUGf/F4tCNy/mu9tHLTaxZFLOUiKzjcgWHGIA=
12 | cloud.google.com/go/iam v1.2.2/go.mod h1:0Ys8ccaZHdI1dEUilwzqng/6ps2YB6vRsjIe00/+6JY=
13 | cloud.google.com/go/longrunning v0.6.2 h1:xjDfh1pQcWPEvnfjZmwjKQEcHnpz6lHjfy7Fo0MK+hc=
14 | cloud.google.com/go/longrunning v0.6.2/go.mod h1:k/vIs83RN4bE3YCswdXC5PFfWVILjm3hpEUlSko4PiI=
15 | cloud.google.com/go/vertexai v0.13.3 h1:pbw1KfpdE8ZDrXxBKcIsS/j+EixyQRsyu6gxRkXq8/k=
16 | cloud.google.com/go/vertexai v0.13.3/go.mod h1:AxzUNrd36yhfOZedO+Y1v0ajVgGKOdv1njeQChL8IFY=
17 | github.com/anthropics/anthropic-sdk-go v0.2.0-beta.3 h1:b5t1ZJMvV/l99y4jbz7kRFdUp3BSDkI8EhSlHczivtw=
18 | github.com/anthropics/anthropic-sdk-go v0.2.0-beta.3/go.mod h1:AapDW22irxK2PSumZiQXYUFvsdQgkwIWlpESweWZI/c=
19 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
20 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
21 | github.com/dlclark/regexp2 v1.11.4 h1:rPYF9/LECdNymJufQKmri9gV604RvvABwgOA8un7yAo=
22 | github.com/dlclark/regexp2 v1.11.4/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
23 | github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
24 | github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
25 | github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
26 | github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
27 | github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
28 | github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
29 | github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
30 | github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
31 | github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
32 | github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
33 | github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
34 | github.com/google/s2a-go v0.1.8 h1:zZDs9gcbt9ZPLV0ndSyQk6Kacx2g/X+SKYovpnz3SMM=
35 | github.com/google/s2a-go v0.1.8/go.mod h1:6iNWHTpQ+nfNRN5E00MSdfDwVesa8hhS32PhPO8deJA=
36 | github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
37 | github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
38 | github.com/googleapis/enterprise-certificate-proxy v0.3.4 h1:XYIDZApgAnrN1c855gTgghdIA6Stxb52D5RnLI1SLyw=
39 | github.com/googleapis/enterprise-certificate-proxy v0.3.4/go.mod h1:YKe7cfqYXjKGpGvmSg28/fFvhNzinZQm8DGnaburhGA=
40 | github.com/googleapis/gax-go/v2 v2.14.0 h1:f+jMrjBPl+DL9nI4IQzLUxMq7XrAqFYB7hBPqMNIe8o=
41 | github.com/googleapis/gax-go/v2 v2.14.0/go.mod h1:lhBCnjdLrWRaPvLWhmc8IS24m9mr07qSYnHncrgo+zk=
42 | github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
43 | github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
44 | github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
45 | github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
46 | github.com/m-mizutani/goerr/v2 v2.0.0-beta.2 h1:3aA8RKFlWS6kX99+AJohdgIxiE5+Qwv6aIUD1QQXS2E=
47 | github.com/m-mizutani/goerr/v2 v2.0.0-beta.2/go.mod h1:K9+tBb0+0681yHpyHYqJmoijxn9qJf5grjRb9nMXevc=
48 | github.com/m-mizutani/gt v0.0.16 h1:bhJqMeqxojsgVBo9wqjPHk4s6o7mkaQLCcOus1hp1Vs=
49 | github.com/m-mizutani/gt v0.0.16/go.mod h1:0MPYSfGBLmYjTduzADVmIqD58ELQ5IfBFiK/f0FmB3k=
50 | github.com/mark3labs/mcp-go v0.23.1 h1:RzTzZ5kJ+HxwnutKA4rll8N/pKV6Wh5dhCmiJUu5S9I=
51 | github.com/mark3labs/mcp-go v0.23.1/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4=
52 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
53 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
54 | github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
55 | github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
56 | github.com/santhosh-tekuri/jsonschema/v6 v6.0.1 h1:PKK9DyHxif4LZo+uQSgXNqs0jj5+xZwwfKHgph2lxBw=
57 | github.com/santhosh-tekuri/jsonschema/v6 v6.0.1/go.mod h1:JXeL+ps8p7/KNMjDQk3TCwPpBy0wYklyWTfbkIzdIFU=
58 | github.com/sashabaranov/go-openai v1.38.2 h1:akrssjj+6DY3lWuDwHv6cBvJ8Z+FZDM9XEaaYFt0Auo=
59 | github.com/sashabaranov/go-openai v1.38.2/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
60 | github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y=
61 | github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
62 | github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
63 | github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
64 | github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
65 | github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
66 | github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
67 | github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
68 | github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
69 | github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
70 | github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
71 | github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
72 | github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
73 | github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
74 | github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
75 | github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
76 | go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0 h1:r6I7RJCN86bpD/FQwedZ0vSixDpwuWREjW9oRMsmqDc=
77 | go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0/go.mod h1:B9yO6b04uB80CzjedvewuqDhxJxi11s7/GtiGa8bAjI=
78 | go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 h1:TT4fX+nBOA/+LUkobKGW1ydGcn+G3vRw9+g5HwCphpk=
79 | go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0/go.mod h1:L7UH0GbB0p47T4Rri3uHjbpCFYrVrwc1I25QhNPiGK8=
80 | go.opentelemetry.io/otel v1.29.0 h1:PdomN/Al4q/lN6iBJEN3AwPvUiHPMlt93c8bqTG5Llw=
81 | go.opentelemetry.io/otel v1.29.0/go.mod h1:N/WtXPs1CNCUEx+Agz5uouwCba+i+bJGFicT8SR4NP8=
82 | go.opentelemetry.io/otel/metric v1.29.0 h1:vPf/HFWTNkPu1aYeIsc98l4ktOQaL6LeSoeV2g+8YLc=
83 | go.opentelemetry.io/otel/metric v1.29.0/go.mod h1:auu/QWieFVWx+DmQOUMgj0F8LHWdgalxXqvp7BII/W8=
84 | go.opentelemetry.io/otel/trace v1.29.0 h1:J/8ZNK4XgR7a21DZUAsbF8pZ5Jcw1VhACmnYt39JTi4=
85 | go.opentelemetry.io/otel/trace v1.29.0/go.mod h1:eHl3w0sp3paPkYstJOmAimxhiFXPg+MMTlEh3nsQgWQ=
86 | golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
87 | golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
88 | golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
89 | golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
90 | golang.org/x/oauth2 v0.24.0 h1:KTBBxWqUa0ykRPLtV69rRto9TLXcqYkeswu48x/gvNE=
91 | golang.org/x/oauth2 v0.24.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
92 | golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610=
93 | golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
94 | golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
95 | golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
96 | golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0=
97 | golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU=
98 | golang.org/x/time v0.8.0 h1:9i3RxcPv3PZnitoVGMPDKZSq1xW1gK1Xy3ArNOGZfEg=
99 | golang.org/x/time v0.8.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
100 | google.golang.org/api v0.211.0 h1:IUpLjq09jxBSV1lACO33CGY3jsRcbctfGzhj+ZSE/Bg=
101 | google.golang.org/api v0.211.0/go.mod h1:XOloB4MXFH4UTlQSGuNUxw0UT74qdENK8d6JNsXKLi0=
102 | google.golang.org/genproto v0.0.0-20241118233622-e639e219e697 h1:ToEetK57OidYuqD4Q5w+vfEnPvPpuTwedCNVohYJfNk=
103 | google.golang.org/genproto v0.0.0-20241118233622-e639e219e697/go.mod h1:JJrvXBWRZaFMxBufik1a4RpFw4HhgVtBBWQeQgUj2cc=
104 | google.golang.org/genproto/googleapis/api v0.0.0-20241118233622-e639e219e697 h1:pgr/4QbFyktUv9CtQ/Fq4gzEE6/Xs7iCXbktaGzLHbQ=
105 | google.golang.org/genproto/googleapis/api v0.0.0-20241118233622-e639e219e697/go.mod h1:+D9ySVjN8nY8YCVjc5O7PZDIdZporIDY3KaGfJunh88=
106 | google.golang.org/genproto/googleapis/rpc v0.0.0-20241206012308-a4fef0638583 h1:IfdSdTcLFy4lqUQrQJLkLt1PB+AsqVz6lwkWPzWEz10=
107 | google.golang.org/genproto/googleapis/rpc v0.0.0-20241206012308-a4fef0638583/go.mod h1:5uTbfoYQed2U9p3KIj2/Zzm02PYhndfdmML0qC3q3FU=
108 | google.golang.org/grpc v1.67.3 h1:OgPcDAFKHnH8X3O4WcO4XUc8GRDeKsKReqbQtiCj7N8=
109 | google.golang.org/grpc v1.67.3/go.mod h1:YGaHCc6Oap+FzBJTZLBzkGSYt/cvGPFTPxkn7QfSU8s=
110 | google.golang.org/protobuf v1.35.2 h1:8Ar7bF+apOIoThw1EdZl0p1oWvMqTHmpA2fRTyZO8io=
111 | google.golang.org/protobuf v1.35.2/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
112 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
113 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
114 |
--------------------------------------------------------------------------------
/history.go:
--------------------------------------------------------------------------------
1 | // Package gollem provides a unified interface for interacting with various LLM services.
2 | package gollem
3 |
4 | import (
5 | "cloud.google.com/go/vertexai/genai"
6 | "github.com/anthropics/anthropic-sdk-go"
7 | "github.com/anthropics/anthropic-sdk-go/packages/param"
8 | "github.com/m-mizutani/goerr/v2"
9 | "github.com/sashabaranov/go-openai"
10 | )
11 |
12 | // History represents a conversation history that can be used across different LLM sessions.
13 | // It stores messages in a format specific to each LLM type (OpenAI, Claude, or Gemini).
14 | //
15 | // For detailed documentation, see doc/history.md
16 | type llmType string
17 |
18 | const (
19 | llmTypeOpenAI llmType = "OpenAI"
20 | llmTypeGemini llmType = "gemini"
21 | llmTypeClaude llmType = "claude"
22 | )
23 |
24 | const (
25 | HistoryVersion = 1
26 | )
27 |
28 | type History struct {
29 | LLType llmType `json:"type"`
30 | Version int `json:"version"`
31 |
32 | Claude []claudeMessage `json:"claude,omitempty"`
33 | OpenAI []openai.ChatCompletionMessage `json:"OpenAI,omitempty"`
34 | Gemini []geminiMessage `json:"gemini,omitempty"`
35 | }
36 |
37 | func (x *History) ToCount() int {
38 | if x == nil {
39 | return 0
40 | }
41 | return len(x.Claude) + len(x.OpenAI) + len(x.Gemini)
42 | }
43 |
44 | func (x *History) ToGemini() ([]*genai.Content, error) {
45 | if x.Version != HistoryVersion {
46 | return nil, goerr.Wrap(ErrHistoryVersionMismatch, "history version is not supported", goerr.V("expected", HistoryVersion), goerr.V("actual", x.Version))
47 | }
48 | if x.LLType != llmTypeGemini {
49 | return nil, goerr.Wrap(ErrLLMTypeMismatch, "history is not gemini", goerr.V("expected", llmTypeGemini), goerr.V("actual", x.LLType))
50 | }
51 | return toGeminiMessages(x.Gemini)
52 | }
53 |
54 | func (x *History) ToClaude() ([]anthropic.MessageParam, error) {
55 | if x.Version != HistoryVersion {
56 | return nil, goerr.Wrap(ErrHistoryVersionMismatch, "history version is not supported", goerr.V("expected", HistoryVersion), goerr.V("actual", x.Version))
57 | }
58 | if x.LLType != llmTypeClaude {
59 | return nil, goerr.Wrap(ErrLLMTypeMismatch, "history is not claude", goerr.V("expected", llmTypeClaude), goerr.V("actual", x.LLType))
60 | }
61 | return toClaudeMessages(x.Claude)
62 | }
63 |
64 | func (x *History) ToOpenAI() ([]openai.ChatCompletionMessage, error) {
65 | if x.Version != HistoryVersion {
66 | return nil, goerr.Wrap(ErrHistoryVersionMismatch, "history version is not supported", goerr.V("expected", HistoryVersion), goerr.V("actual", x.Version))
67 | }
68 | if x.LLType != llmTypeOpenAI {
69 | return nil, goerr.Wrap(ErrLLMTypeMismatch, "history is not OpenAI", goerr.V("expected", llmTypeOpenAI), goerr.V("actual", x.LLType))
70 | }
71 | return x.OpenAI, nil
72 | }
73 |
// claudeMessage is the JSON-serializable form of an anthropic.MessageParam,
// stored in History.Claude.
type claudeMessage struct {
	Role    anthropic.MessageParamRole `json:"role"`
	Content []claudeContentBlock       `json:"content"`
}

// claudeContentBlock is a tagged union of Claude content blocks. Type selects
// which pointer field carries the payload:
// "text" -> Text, "image" -> Source, "tool_use" -> ToolUse,
// "tool_result" -> ToolResult.
type claudeContentBlock struct {
	Type       string             `json:"type"`
	Text       *string            `json:"text,omitempty"`
	Source     *claudeImageSource `json:"source,omitempty"`
	ToolUse    *claudeToolUse     `json:"tool_use,omitempty"`
	ToolResult *claudeToolResult  `json:"tool_result,omitempty"`
}

// claudeImageSource stores the source of an image block. toClaudeMessages
// only rebuilds sources whose Type is "base64".
type claudeImageSource struct {
	Type      string `json:"type"`
	MediaType string `json:"media_type,omitempty"`
	Data      string `json:"data"`
}

// claudeToolUse stores a tool_use block. Input is kept as `any`; after a JSON
// round trip it holds whatever encoding/json decodes the original value into.
type claudeToolUse struct {
	ID    string `json:"id"`
	Name  string `json:"name"`
	Input any    `json:"input"`
	Type  string `json:"type"`
}

// claudeToolResult stores a tool_result block with its content flattened to
// plain text.
type claudeToolResult struct {
	ToolUseID string          `json:"tool_use_id"`
	Content   string          `json:"content"`
	IsError   param.Opt[bool] `json:"is_error"`
}
105 |
// geminiMessage is the JSON-serializable form of a *genai.Content, stored in
// History.Gemini.
type geminiMessage struct {
	Role  string       `json:"role"`
	Parts []geminiPart `json:"parts"`
}

// geminiPart is a tagged union of Gemini parts. Type selects which fields
// are meaningful:
// "text" -> Text; "blob" -> MIMEType, Data; "file_data" -> MIMEType, FileURI;
// "function_call" -> Name, Args; "function_response" -> Name, Response.
type geminiPart struct {
	Type     string                 `json:"type"`
	Text     string                 `json:"text,omitempty"`
	MIMEType string                 `json:"mime_type,omitempty"`
	Data     []byte                 `json:"data,omitempty"`
	FileURI  string                 `json:"file_uri,omitempty"`
	Name     string                 `json:"name,omitempty"`
	Args     map[string]interface{} `json:"args,omitempty"`
	Response map[string]interface{} `json:"response,omitempty"`
}
121 |
122 | func NewHistoryFromOpenAI(messages []openai.ChatCompletionMessage) *History {
123 | return &History{
124 | LLType: llmTypeOpenAI,
125 | Version: HistoryVersion,
126 | OpenAI: messages,
127 | }
128 | }
129 |
130 | func NewHistoryFromClaude(messages []anthropic.MessageParam) *History {
131 | claudeMessages := make([]claudeMessage, len(messages))
132 | for i, msg := range messages {
133 | content := make([]claudeContentBlock, len(msg.Content))
134 | for j, c := range msg.Content {
135 | if c.OfRequestTextBlock != nil {
136 | content[j] = claudeContentBlock{
137 | Type: "text",
138 | Text: &c.OfRequestTextBlock.Text,
139 | }
140 | } else if c.OfRequestImageBlock != nil {
141 | if c.OfRequestImageBlock.Source.OfBase64ImageSource != nil {
142 | content[j] = claudeContentBlock{
143 | Type: "image",
144 | Source: &claudeImageSource{
145 | Type: string(c.OfRequestImageBlock.Source.OfBase64ImageSource.Type),
146 | MediaType: string(c.OfRequestImageBlock.Source.OfBase64ImageSource.MediaType),
147 | Data: c.OfRequestImageBlock.Source.OfBase64ImageSource.Data,
148 | },
149 | }
150 | }
151 | } else if c.OfRequestToolUseBlock != nil {
152 | content[j] = claudeContentBlock{
153 | Type: "tool_use",
154 | ToolUse: &claudeToolUse{
155 | ID: c.OfRequestToolUseBlock.ID,
156 | Name: c.OfRequestToolUseBlock.Name,
157 | Input: c.OfRequestToolUseBlock.Input,
158 | Type: string(c.OfRequestToolUseBlock.Type),
159 | },
160 | }
161 | } else if c.OfRequestToolResultBlock != nil {
162 | content[j] = claudeContentBlock{
163 | Type: "tool_result",
164 | ToolResult: &claudeToolResult{
165 | ToolUseID: c.OfRequestToolResultBlock.ToolUseID,
166 | Content: c.OfRequestToolResultBlock.Content[0].OfRequestTextBlock.Text,
167 | IsError: c.OfRequestToolResultBlock.IsError,
168 | },
169 | }
170 | }
171 | }
172 | claudeMessages[i] = claudeMessage{
173 | Role: msg.Role,
174 | Content: content,
175 | }
176 | }
177 |
178 | return &History{
179 | LLType: llmTypeClaude,
180 | Version: HistoryVersion,
181 | Claude: claudeMessages,
182 | }
183 | }
184 |
185 | func toClaudeMessages(messages []claudeMessage) ([]anthropic.MessageParam, error) {
186 | converted := make([]anthropic.MessageParam, len(messages))
187 |
188 | for i, msg := range messages {
189 | content := make([]anthropic.ContentBlockParamUnion, 0, len(msg.Content))
190 | for _, c := range msg.Content {
191 | switch c.Type {
192 | case "text":
193 | if c.Text == nil {
194 | return nil, goerr.New("text block has no text field")
195 | }
196 | content = append(content, anthropic.ContentBlockParamUnion{
197 | OfRequestTextBlock: &anthropic.TextBlockParam{
198 | Text: *c.Text,
199 | Type: "text",
200 | },
201 | })
202 |
203 | case "image":
204 | if c.Source == nil {
205 | return nil, goerr.New("image block has no source field")
206 | }
207 | if c.Source.Type == "base64" {
208 | content = append(content, anthropic.ContentBlockParamUnion{
209 | OfRequestImageBlock: &anthropic.ImageBlockParam{
210 | Source: anthropic.ImageBlockParamSourceUnion{
211 | OfBase64ImageSource: &anthropic.Base64ImageSourceParam{
212 | Type: "base64",
213 | MediaType: anthropic.Base64ImageSourceMediaType(c.Source.MediaType),
214 | Data: c.Source.Data,
215 | },
216 | },
217 | Type: "image",
218 | },
219 | })
220 | }
221 |
222 | case "tool_use":
223 | if c.ToolUse == nil {
224 | return nil, goerr.New("tool_use block has no tool_use field")
225 | }
226 | content = append(content, anthropic.ContentBlockParamUnion{
227 | OfRequestToolUseBlock: &anthropic.ToolUseBlockParam{
228 | ID: c.ToolUse.ID,
229 | Name: c.ToolUse.Name,
230 | Input: c.ToolUse.Input,
231 | Type: "tool_use",
232 | },
233 | })
234 |
235 | case "tool_result":
236 | if c.ToolResult == nil {
237 | return nil, goerr.New("tool_result block has no tool_result field")
238 | }
239 | content = append(content, anthropic.ContentBlockParamUnion{
240 | OfRequestToolResultBlock: &anthropic.ToolResultBlockParam{
241 | ToolUseID: c.ToolResult.ToolUseID,
242 | Content: []anthropic.ToolResultBlockParamContentUnion{
243 | {
244 | OfRequestTextBlock: &anthropic.TextBlockParam{
245 | Text: c.ToolResult.Content,
246 | Type: "text",
247 | },
248 | },
249 | },
250 | IsError: param.NewOpt(c.ToolResult.IsError.Value),
251 | },
252 | })
253 | }
254 | }
255 | converted[i] = anthropic.MessageParam{
256 | Role: msg.Role,
257 | Content: content,
258 | }
259 | }
260 |
261 | return converted, nil
262 | }
263 |
264 | func NewHistoryFromGemini(messages []*genai.Content) *History {
265 | converted := make([]geminiMessage, len(messages))
266 | for i, msg := range messages {
267 | parts := make([]geminiPart, len(msg.Parts))
268 | for j, p := range msg.Parts {
269 | switch v := p.(type) {
270 | case genai.Text:
271 | parts[j] = geminiPart{
272 | Type: "text",
273 | Text: string(v),
274 | }
275 | case genai.Blob:
276 | parts[j] = geminiPart{
277 | Type: "blob",
278 | MIMEType: v.MIMEType,
279 | Data: v.Data,
280 | }
281 | case genai.FileData:
282 | parts[j] = geminiPart{
283 | Type: "file_data",
284 | MIMEType: v.MIMEType,
285 | FileURI: v.FileURI,
286 | }
287 | case genai.FunctionCall:
288 | parts[j] = geminiPart{
289 | Type: "function_call",
290 | Name: v.Name,
291 | Args: v.Args,
292 | }
293 | case genai.FunctionResponse:
294 | parts[j] = geminiPart{
295 | Type: "function_response",
296 | Name: v.Name,
297 | Response: v.Response,
298 | }
299 | }
300 | }
301 | converted[i] = geminiMessage{
302 | Role: msg.Role,
303 | Parts: parts,
304 | }
305 | }
306 | return &History{
307 | LLType: llmTypeGemini,
308 | Version: HistoryVersion,
309 | Gemini: converted,
310 | }
311 | }
312 |
313 | func toGeminiMessages(messages []geminiMessage) ([]*genai.Content, error) {
314 | converted := make([]*genai.Content, len(messages))
315 | for i, msg := range messages {
316 | parts := make([]genai.Part, len(msg.Parts))
317 | for j, p := range msg.Parts {
318 | switch p.Type {
319 | case "text":
320 | parts[j] = genai.Text(p.Text)
321 | case "blob":
322 | parts[j] = genai.Blob{
323 | MIMEType: p.MIMEType,
324 | Data: p.Data,
325 | }
326 | case "file_data":
327 | parts[j] = genai.FileData{
328 | MIMEType: p.MIMEType,
329 | FileURI: p.FileURI,
330 | }
331 | case "function_call":
332 | parts[j] = genai.FunctionCall{
333 | Name: p.Name,
334 | Args: p.Args,
335 | }
336 | case "function_response":
337 | parts[j] = genai.FunctionResponse{
338 | Name: p.Name,
339 | Response: p.Response,
340 | }
341 | }
342 | }
343 | converted[i] = &genai.Content{
344 | Role: msg.Role,
345 | Parts: parts,
346 | }
347 | }
348 | return converted, nil
349 | }
350 |
--------------------------------------------------------------------------------
/history_test.go:
--------------------------------------------------------------------------------
1 | package gollem_test
2 |
3 | import (
4 | "encoding/json"
5 | "testing"
6 |
7 | "cloud.google.com/go/vertexai/genai"
8 | "github.com/anthropics/anthropic-sdk-go"
9 | "github.com/anthropics/anthropic-sdk-go/packages/param"
10 | "github.com/m-mizutani/gollem"
11 | "github.com/m-mizutani/gt"
12 | "github.com/sashabaranov/go-openai"
13 | )
14 |
15 | func TestHistoryOpenAI(t *testing.T) {
16 | // Create OpenAI messages with various content types
17 | messages := []openai.ChatCompletionMessage{
18 | {
19 | Role: "system",
20 | Content: "You are a helpful assistant.",
21 | },
22 | {
23 | Role: "user",
24 | Content: "Hello",
25 | },
26 | {
27 | Role: "assistant",
28 | Content: "Hi, how can I help you?",
29 | },
30 | {
31 | Role: "user",
32 | Content: "What's the weather like?",
33 | },
34 | {
35 | Role: "assistant",
36 | Content: "",
37 | FunctionCall: &openai.FunctionCall{
38 | Name: "get_weather",
39 | Arguments: `{"location": "Tokyo"}`,
40 | },
41 | },
42 | {
43 | Role: "tool",
44 | Name: "get_weather",
45 | Content: `{"temperature": 25, "condition": "sunny"}`,
46 | },
47 | {
48 | Role: "assistant",
49 | Content: "The weather in Tokyo is sunny with a temperature of 25°C.",
50 | },
51 | }
52 |
53 | // Create History object
54 | history := gollem.NewHistoryFromOpenAI(messages)
55 |
56 | // Convert to JSON
57 | data, err := json.Marshal(history)
58 | gt.NoError(t, err)
59 |
60 | // Restore from JSON
61 | var restored gollem.History
62 | gt.NoError(t, json.Unmarshal(data, &restored))
63 |
64 | restoredMessages, err := restored.ToOpenAI()
65 | gt.NoError(t, err)
66 |
67 | gt.Equal(t, messages, restoredMessages)
68 |
69 | // Validate specific message types
70 | gt.Equal(t, "system", restoredMessages[0].Role)
71 | gt.Equal(t, "You are a helpful assistant.", restoredMessages[0].Content)
72 |
73 | gt.Equal(t, "assistant", restoredMessages[2].Role)
74 | gt.Equal(t, "Hi, how can I help you?", restoredMessages[2].Content)
75 |
76 | gt.Equal(t, "assistant", restoredMessages[4].Role)
77 | gt.Equal(t, "", restoredMessages[4].Content)
78 | gt.Equal(t, "get_weather", restoredMessages[4].FunctionCall.Name)
79 | gt.Equal(t, `{"location": "Tokyo"}`, restoredMessages[4].FunctionCall.Arguments)
80 |
81 | gt.Equal(t, "tool", restoredMessages[5].Role)
82 | gt.Equal(t, "get_weather", restoredMessages[5].Name)
83 | gt.Equal(t, `{"temperature": 25, "condition": "sunny"}`, restoredMessages[5].Content)
84 | }
85 |
86 | func TestHistoryClaude(t *testing.T) {
87 | // Create Claude messages with various content types
88 | messages := []anthropic.MessageParam{
89 | {
90 | Role: anthropic.MessageParamRoleUser,
91 | Content: []anthropic.ContentBlockParamUnion{
92 | {
93 | OfRequestTextBlock: &anthropic.TextBlockParam{
94 | Text: "Hello",
95 | Type: "text",
96 | },
97 | },
98 | {
99 | OfRequestImageBlock: &anthropic.ImageBlockParam{
100 | Source: anthropic.ImageBlockParamSourceUnion{
101 | OfBase64ImageSource: &anthropic.Base64ImageSourceParam{
102 | Type: "base64",
103 | MediaType: "image/jpeg",
104 | Data: "base64encodedimage",
105 | },
106 | },
107 | Type: "image",
108 | },
109 | },
110 | },
111 | },
112 | {
113 | Role: anthropic.MessageParamRoleAssistant,
114 | Content: []anthropic.ContentBlockParamUnion{
115 | {
116 | OfRequestTextBlock: &anthropic.TextBlockParam{
117 | Text: "Hi, how can I help you?",
118 | Type: "text",
119 | },
120 | },
121 | },
122 | },
123 | {
124 | Role: anthropic.MessageParamRoleUser,
125 | Content: []anthropic.ContentBlockParamUnion{
126 | {
127 | OfRequestTextBlock: &anthropic.TextBlockParam{
128 | Text: "What's the weather like?",
129 | Type: "text",
130 | },
131 | },
132 | },
133 | },
134 | {
135 | Role: anthropic.MessageParamRoleAssistant,
136 | Content: []anthropic.ContentBlockParamUnion{
137 | {
138 | OfRequestToolUseBlock: &anthropic.ToolUseBlockParam{
139 | ID: "tool_1",
140 | Name: "get_weather",
141 | Input: `{"location": "Tokyo"}`,
142 | Type: "tool_use",
143 | },
144 | },
145 | },
146 | },
147 | {
148 | Role: anthropic.MessageParamRoleUser,
149 | Content: []anthropic.ContentBlockParamUnion{
150 | {
151 | OfRequestToolResultBlock: &anthropic.ToolResultBlockParam{
152 | ToolUseID: "tool_2",
153 | Content: []anthropic.ToolResultBlockParamContentUnion{
154 | {
155 | OfRequestTextBlock: &anthropic.TextBlockParam{
156 | Text: `{"temperature": 30, "condition": "cloudy"}`,
157 | Type: "text",
158 | },
159 | },
160 | },
161 | IsError: param.NewOpt(false),
162 | Type: "tool_result",
163 | },
164 | },
165 | },
166 | },
167 | {
168 | Role: anthropic.MessageParamRoleAssistant,
169 | Content: []anthropic.ContentBlockParamUnion{
170 | {
171 | OfRequestTextBlock: &anthropic.TextBlockParam{
172 | Text: "Second message",
173 | Type: "text",
174 | },
175 | },
176 | },
177 | },
178 | {
179 | Role: anthropic.MessageParamRoleUser,
180 | Content: []anthropic.ContentBlockParamUnion{
181 | {
182 | OfRequestToolResultBlock: &anthropic.ToolResultBlockParam{
183 | ToolUseID: "tool_3",
184 | Content: []anthropic.ToolResultBlockParamContentUnion{
185 | {
186 | OfRequestTextBlock: &anthropic.TextBlockParam{
187 | Text: `{"temperature": 35, "condition": "rainy"}`,
188 | Type: "text",
189 | },
190 | },
191 | },
192 | IsError: param.NewOpt(false),
193 | Type: "tool_result",
194 | },
195 | },
196 | },
197 | },
198 | {
199 | Role: anthropic.MessageParamRoleAssistant,
200 | Content: []anthropic.ContentBlockParamUnion{
201 | {
202 | OfRequestTextBlock: &anthropic.TextBlockParam{
203 | Text: "The weather in Tokyo is sunny with a temperature of 25°C.",
204 | Type: "text",
205 | },
206 | },
207 | },
208 | },
209 | }
210 |
211 | // Create History object
212 | history := gollem.NewHistoryFromClaude(messages)
213 |
214 | // Convert to JSON
215 | data, err := json.Marshal(history)
216 | gt.NoError(t, err)
217 |
218 | // Restore from JSON
219 | var restored gollem.History
220 | gt.NoError(t, json.Unmarshal(data, &restored))
221 |
222 | restoredMessages, err := restored.ToClaude()
223 | gt.NoError(t, err)
224 |
225 | // Compare each message individually to make debugging easier
226 | for i := range messages {
227 | gt.Value(t, restoredMessages[i].Role).Equal(messages[i].Role)
228 | gt.Value(t, len(restoredMessages[i].Content)).Equal(len(messages[i].Content))
229 |
230 | for j := range messages[i].Content {
231 | if messages[i].Content[j].OfRequestToolResultBlock != nil {
232 | gt.Value(t, restoredMessages[i].Content[j].OfRequestToolResultBlock.ToolUseID).Equal(messages[i].Content[j].OfRequestToolResultBlock.ToolUseID)
233 | gt.Value(t, restoredMessages[i].Content[j].OfRequestToolResultBlock.IsError).Equal(messages[i].Content[j].OfRequestToolResultBlock.IsError)
234 | gt.Value(t, len(restoredMessages[i].Content[j].OfRequestToolResultBlock.Content)).Equal(len(messages[i].Content[j].OfRequestToolResultBlock.Content))
235 | gt.Value(t, restoredMessages[i].Content[j].OfRequestToolResultBlock.Content[0].OfRequestTextBlock.Text).Equal(messages[i].Content[j].OfRequestToolResultBlock.Content[0].OfRequestTextBlock.Text)
236 | } else {
237 | gt.Value(t, restoredMessages[i].Content[j]).Equal(messages[i].Content[j])
238 | }
239 | }
240 | }
241 | }
242 |
243 | func TestHistoryGemini(t *testing.T) {
244 | // Create Gemini messages with various content types
245 | messages := []*genai.Content{
246 | {
247 | Role: "user",
248 | Parts: []genai.Part{
249 | genai.Text("Hello"),
250 | genai.Blob{
251 | MIMEType: "image/jpeg",
252 | Data: []byte("fake image data"),
253 | },
254 | genai.FileData{
255 | MIMEType: "application/pdf",
256 | FileURI: "gs://bucket/file.pdf",
257 | },
258 | },
259 | },
260 | {
261 | Role: "model",
262 | Parts: []genai.Part{
263 | genai.Text("Hi, how can I help you?"),
264 | genai.FunctionCall{
265 | Name: "test_function",
266 | Args: map[string]interface{}{
267 | "param1": "value1",
268 | "param2": float64(123),
269 | },
270 | },
271 | },
272 | },
273 | {
274 | Role: "model",
275 | Parts: []genai.Part{
276 | genai.Text("Function result"),
277 | genai.FunctionResponse{
278 | Name: "test_function",
279 | Response: map[string]interface{}{
280 | "status": "success",
281 | "result": "operation completed",
282 | },
283 | },
284 | },
285 | },
286 | }
287 |
288 | // Create History object
289 | history := gollem.NewHistoryFromGemini(messages)
290 |
291 | // Convert to JSON
292 | data, err := json.Marshal(history)
293 | gt.NoError(t, err)
294 |
295 | // Restore from JSON
296 | var restored gollem.History
297 | gt.NoError(t, json.Unmarshal(data, &restored))
298 |
299 | restoredMessages, err := restored.ToGemini()
300 | gt.NoError(t, err)
301 | gt.Equal(t, messages, restoredMessages)
302 |
303 | // Validate specific message types
304 | gt.Equal(t, "user", restoredMessages[0].Role)
305 | gt.Equal(t, 3, len(restoredMessages[0].Parts))
306 | gt.Equal(t, "Hello", restoredMessages[0].Parts[0].(genai.Text))
307 | gt.Equal(t, "image/jpeg", restoredMessages[0].Parts[1].(genai.Blob).MIMEType)
308 | gt.Equal(t, "application/pdf", restoredMessages[0].Parts[2].(genai.FileData).MIMEType)
309 | gt.Equal(t, "gs://bucket/file.pdf", restoredMessages[0].Parts[2].(genai.FileData).FileURI)
310 |
311 | gt.Equal(t, "model", restoredMessages[1].Role)
312 | gt.Equal(t, 2, len(restoredMessages[1].Parts))
313 | gt.Equal(t, "Hi, how can I help you?", restoredMessages[1].Parts[0].(genai.Text))
314 | gt.Equal(t, "test_function", restoredMessages[1].Parts[1].(genai.FunctionCall).Name)
315 | gt.Equal(t, "value1", restoredMessages[1].Parts[1].(genai.FunctionCall).Args["param1"])
316 | gt.Equal(t, float64(123), restoredMessages[1].Parts[1].(genai.FunctionCall).Args["param2"].(float64))
317 |
318 | gt.Equal(t, "model", restoredMessages[2].Role)
319 | gt.Equal(t, 2, len(restoredMessages[2].Parts))
320 | gt.Equal(t, "Function result", restoredMessages[2].Parts[0].(genai.Text))
321 | gt.Equal(t, "test_function", restoredMessages[2].Parts[1].(genai.FunctionResponse).Name)
322 | gt.Equal(t, "success", restoredMessages[2].Parts[1].(genai.FunctionResponse).Response["status"])
323 | gt.Equal(t, "operation completed", restoredMessages[2].Parts[1].(genai.FunctionResponse).Response["result"])
324 | }
325 |
--------------------------------------------------------------------------------
/hook.go:
--------------------------------------------------------------------------------
1 | package gollem
2 |
3 | import "context"
4 |
5 | type (
6 | // LoopHook is a hook for the session loop. "loop" is the loop count, it's 0-indexed. "input" is the current input of the loop. If you want to abort the session loop, you can return an error.
7 | LoopHook func(ctx context.Context, loop int, input []Input) error
8 |
9 | // MessageHook is a hook for the message. If you want to display or record the message, you can use this hook.
10 | MessageHook func(ctx context.Context, msg string) error
11 |
12 | // ToolRequestHook is a hook for the tool request. If you want to display or record the tool request, you can use this hook. If you want to abort the tool execution, you can return an error.
13 | ToolRequestHook func(ctx context.Context, tool FunctionCall) error
14 |
15 | // ToolResponseHook is a hook for the tool response. If you want to display or record the tool response, you can use this hook. If you want to abort the tool execution, you can return an error.
16 | ToolResponseHook func(ctx context.Context, tool FunctionCall, response map[string]any) error
17 |
18 | // ToolErrorHook is a hook for the tool error. If you want to record the tool error, you can use this hook.
19 | ToolErrorHook func(ctx context.Context, err error, tool FunctionCall) error
20 | )
21 |
22 | func defaultLoopHook(ctx context.Context, loop int, input []Input) error {
23 | return nil
24 | }
25 |
// defaultMessageHook is the no-op MessageHook: it discards the message and
// always returns nil.
func defaultMessageHook(_ context.Context, _ string) error { return nil }
29 |
30 | func defaultToolRequestHook(ctx context.Context, tool FunctionCall) error {
31 | return nil
32 | }
33 |
34 | func defaultToolResponseHook(ctx context.Context, tool FunctionCall, response map[string]any) error {
35 | return nil
36 | }
37 |
38 | func defaultToolErrorHook(ctx context.Context, err error, tool FunctionCall) error {
39 | return nil
40 | }
41 |
--------------------------------------------------------------------------------
/llm.go:
--------------------------------------------------------------------------------
1 | package gollem
2 |
3 | import "context"
4 |
5 | // LLMClient is a client for each LLM service.
6 | type LLMClient interface {
7 | NewSession(ctx context.Context, options ...SessionOption) (Session, error)
8 | GenerateEmbedding(ctx context.Context, dimension int, input []string) ([][]float64, error)
9 | }
10 |
// FunctionCall is a tool invocation requested by the LLM.
type FunctionCall struct {
	// ID is the call identifier reported by the provider and Name is the
	// name of the tool to invoke. A FunctionResponse answers the call by ID.
	ID, Name string

	// Arguments holds the call's arguments decoded from JSON.
	Arguments map[string]any
}
16 |
17 | // Response is a general response type for each gollem.
18 | type Response struct {
19 | Texts []string
20 | FunctionCalls []*FunctionCall
21 |
22 | // Error is an error that occurred during the generation for streaming response.
23 | Error error
24 | }
25 |
26 | func (r *Response) HasData() bool {
27 | return len(r.Texts) > 0 || len(r.FunctionCalls) > 0 || r.Error != nil
28 | }
29 |
// restrictedValue is the unexported return type of Input's marker method.
type restrictedValue struct{}

// Input is a value that can be sent to a session, such as Text or
// FunctionResponse. Its only method is unexported, so types outside this
// package cannot implement it.
type Input interface {
	isInput() restrictedValue
}
35 |
36 | // Text is a text input as prompt.
37 | // Usage:
38 | // input := gollem.Text("Hello, world!")
39 | type Text string
40 |
41 | func (t Text) isInput() restrictedValue {
42 | return restrictedValue{}
43 | }
44 |
45 | // FunctionResponse is a function response.
46 | // Usage:
47 | //
48 | // input := gollem.FunctionResponse{
49 | // Name: "function_name",
50 | // Arguments: map[string]any{"key": "value"},
51 | // }
52 | type FunctionResponse struct {
53 | ID string
54 | Name string
55 | Data map[string]any
56 | Error error
57 | }
58 |
59 | func (f FunctionResponse) isInput() restrictedValue {
60 | return restrictedValue{}
61 | }
62 |
--------------------------------------------------------------------------------
/llm/claude/client.go:
--------------------------------------------------------------------------------
1 | package claude
2 |
3 | import (
4 | "context"
5 | "encoding/json"
6 | "fmt"
7 | "strings"
8 |
9 | "github.com/anthropics/anthropic-sdk-go"
10 | "github.com/anthropics/anthropic-sdk-go/option"
11 | "github.com/m-mizutani/goerr/v2"
12 | "github.com/m-mizutani/gollem"
13 | )
14 |
// DefaultEmbeddingModel is the embedding model New configures unless
// WithEmbeddingModel overrides it.
const DefaultEmbeddingModel = "claude-3-sonnet-20240229"
18 |
// generationParameters groups the sampling settings that createRequest sends
// with every message request.
type generationParameters struct {
	// Temperature sets output randomness (higher is more random), and TopP
	// sets nucleus-sampling diversity (higher allows more diverse output).
	Temperature, TopP float64

	// MaxTokens caps the number of generated tokens.
	MaxTokens int64
}
32 |
33 | // Client is a client for the Claude API.
34 | // It provides methods to interact with Anthropic's Claude models.
35 | type Client struct {
36 | // client is the underlying Claude client.
37 | client *anthropic.Client
38 |
39 | // defaultModel is the model to use for chat completions.
40 | // It can be overridden using WithModel option.
41 | defaultModel string
42 |
43 | // embeddingModel is the model to use for embeddings.
44 | // It can be overridden using WithEmbeddingModel option.
45 | embeddingModel string
46 |
47 | // apiKey is the API key for authentication.
48 | apiKey string
49 |
50 | // generation parameters
51 | params generationParameters
52 |
53 | // systemPrompt is the system prompt to use for chat completions.
54 | systemPrompt string
55 | }
56 |
57 | // Option is a function that configures a Client.
58 | type Option func(*Client)
59 |
60 | // WithModel sets the default model to use for chat completions.
61 | // The model name should be a valid Claude model identifier.
62 | // Default: anthropic.ModelClaude3_5SonnetLatest
63 | func WithModel(modelName string) Option {
64 | return func(c *Client) {
65 | c.defaultModel = modelName
66 | }
67 | }
68 |
69 | // WithEmbeddingModel sets the embedding model to use for embeddings.
70 | // The model name should be a valid Claude model identifier.
71 | // Default: DefaultEmbeddingModel
72 | func WithEmbeddingModel(modelName string) Option {
73 | return func(c *Client) {
74 | c.embeddingModel = modelName
75 | }
76 | }
77 |
78 | // WithTemperature sets the temperature parameter for text generation.
79 | // Higher values make the output more random, lower values make it more focused.
80 | // Range: 0.0 to 1.0
81 | // Default: 0.7
82 | func WithTemperature(temp float64) Option {
83 | return func(c *Client) {
84 | c.params.Temperature = temp
85 | }
86 | }
87 |
88 | // WithTopP sets the top_p parameter for text generation.
89 | // Controls diversity via nucleus sampling.
90 | // Range: 0.0 to 1.0
91 | // Default: 1.0
92 | func WithTopP(topP float64) Option {
93 | return func(c *Client) {
94 | c.params.TopP = topP
95 | }
96 | }
97 |
98 | // WithMaxTokens sets the maximum number of tokens to generate.
99 | // Default: 4096
100 | func WithMaxTokens(maxTokens int64) Option {
101 | return func(c *Client) {
102 | c.params.MaxTokens = maxTokens
103 | }
104 | }
105 |
106 | // WithSystemPrompt sets the system prompt to use for chat completions.
107 | func WithSystemPrompt(prompt string) Option {
108 | return func(c *Client) {
109 | c.systemPrompt = prompt
110 | }
111 | }
112 |
113 | // New creates a new client for the Claude API.
114 | // It requires an API key and can be configured with additional options.
115 | func New(ctx context.Context, apiKey string, options ...Option) (*Client, error) {
116 | client := &Client{
117 | defaultModel: anthropic.ModelClaude3_5SonnetLatest,
118 | embeddingModel: DefaultEmbeddingModel,
119 | apiKey: apiKey,
120 | params: generationParameters{
121 | Temperature: 0.7,
122 | TopP: 1.0,
123 | MaxTokens: 4096,
124 | },
125 | }
126 |
127 | for _, option := range options {
128 | option(client)
129 | }
130 |
131 | newClient := anthropic.NewClient(
132 | option.WithAPIKey(apiKey),
133 | )
134 | client.client = &newClient
135 |
136 | return client, nil
137 | }
138 |
139 | // Session is a session for the Claude chat.
140 | // It maintains the conversation state and handles message generation.
141 | type Session struct {
142 | // client is the underlying Claude client.
143 | client *anthropic.Client
144 |
145 | // defaultModel is the model to use for chat completions.
146 | defaultModel string
147 |
148 | // tools are the available tools for the session.
149 | tools []anthropic.ToolUnionParam
150 |
151 | // messages stores the conversation history.
152 | messages []anthropic.MessageParam
153 |
154 | // generation parameters
155 | params generationParameters
156 |
157 | cfg gollem.SessionConfig
158 | }
159 |
160 | // NewSession creates a new session for the Claude API.
161 | // It converts the provided tools to Claude's tool format and initializes a new chat session.
162 | func (c *Client) NewSession(ctx context.Context, options ...gollem.SessionOption) (gollem.Session, error) {
163 | cfg := gollem.NewSessionConfig(options...)
164 |
165 | // Convert gollem.Tool to anthropic.ToolUnionParam
166 | claudeTools := make([]anthropic.ToolUnionParam, len(cfg.Tools()))
167 | for i, tool := range cfg.Tools() {
168 | claudeTools[i] = convertTool(tool)
169 | }
170 |
171 | var messages []anthropic.MessageParam
172 | if cfg.History() != nil {
173 | history, err := cfg.History().ToClaude()
174 | if err != nil {
175 | return nil, goerr.Wrap(err, "failed to convert history to anthropic.MessageParam")
176 | }
177 | messages = append(messages, history...)
178 | }
179 |
180 | session := &Session{
181 | client: c.client,
182 | defaultModel: c.defaultModel,
183 | tools: claudeTools,
184 | params: c.params,
185 | messages: messages,
186 | cfg: cfg,
187 | }
188 |
189 | return session, nil
190 | }
191 |
192 | func (s *Session) History() *gollem.History {
193 | return gollem.NewHistoryFromClaude(s.messages)
194 | }
195 |
196 | // convertInputs converts gollem.Input to Claude messages and tool results
197 | func (s *Session) convertInputs(input ...gollem.Input) ([]anthropic.MessageParam, []anthropic.ContentBlockParamUnion, error) {
198 | var toolResults []anthropic.ContentBlockParamUnion
199 | var messages []anthropic.MessageParam
200 |
201 | for _, in := range input {
202 | switch v := in.(type) {
203 | case gollem.Text:
204 | messages = append(messages, anthropic.NewUserMessage(
205 | anthropic.NewTextBlock(string(v)),
206 | ))
207 |
208 | case gollem.FunctionResponse:
209 | data, err := json.Marshal(v.Data)
210 | if err != nil {
211 | return nil, nil, goerr.Wrap(err, "failed to marshal function response")
212 | }
213 | response := string(data)
214 | if v.Error != nil {
215 | response = fmt.Sprintf(`Error message: %+v`, v.Error)
216 | }
217 | toolResults = append(toolResults, anthropic.NewToolResultBlock(v.ID, response, v.Error != nil))
218 |
219 | default:
220 | return nil, nil, goerr.Wrap(gollem.ErrInvalidParameter, "invalid input")
221 | }
222 | }
223 |
224 | if len(toolResults) > 0 {
225 | messages = append(messages, anthropic.NewUserMessage(toolResults...))
226 | }
227 |
228 | return messages, toolResults, nil
229 | }
230 |
231 | // createRequest creates a message request with the current session state
232 | func (s *Session) createRequest() anthropic.MessageNewParams {
233 | var systemPrompt []anthropic.TextBlockParam
234 | if s.cfg.SystemPrompt() != "" {
235 | systemPrompt = []anthropic.TextBlockParam{
236 | {
237 | Text: s.cfg.SystemPrompt(),
238 | },
239 | }
240 | }
241 |
242 | // Add content type instruction to system prompt
243 | if s.cfg.ContentType() == gollem.ContentTypeJSON {
244 | if len(systemPrompt) > 0 {
245 | systemPrompt[0].Text += "\nPlease format your response as valid JSON."
246 | } else {
247 | systemPrompt = []anthropic.TextBlockParam{
248 | {
249 | Text: "Please format your response as valid JSON.",
250 | },
251 | }
252 | }
253 | }
254 |
255 | return anthropic.MessageNewParams{
256 | Model: s.defaultModel,
257 | MaxTokens: s.params.MaxTokens,
258 | Temperature: anthropic.Float(s.params.Temperature),
259 | TopP: anthropic.Float(s.params.TopP),
260 | Tools: s.tools,
261 | Messages: s.messages,
262 | System: systemPrompt,
263 | }
264 | }
265 |
266 | // processResponse converts Claude response to gollem.Response
267 | func processResponse(resp *anthropic.Message) *gollem.Response {
268 | if len(resp.Content) == 0 {
269 | return &gollem.Response{}
270 | }
271 |
272 | response := &gollem.Response{
273 | Texts: make([]string, 0),
274 | FunctionCalls: make([]*gollem.FunctionCall, 0),
275 | }
276 |
277 | for _, content := range resp.Content {
278 | textBlock := content.AsResponseTextBlock()
279 | if textBlock.Type == "text" {
280 | response.Texts = append(response.Texts, textBlock.Text)
281 | }
282 |
283 | toolUseBlock := content.AsResponseToolUseBlock()
284 | if toolUseBlock.Type == "tool_use" {
285 | var args map[string]interface{}
286 | if err := json.Unmarshal([]byte(toolUseBlock.Input), &args); err != nil {
287 | response.Error = goerr.Wrap(err, "failed to unmarshal function arguments")
288 | return response
289 | }
290 |
291 | response.FunctionCalls = append(response.FunctionCalls, &gollem.FunctionCall{
292 | ID: toolUseBlock.ID,
293 | Name: toolUseBlock.Name,
294 | Arguments: args,
295 | })
296 | }
297 | }
298 |
299 | return response
300 | }
301 |
302 | // GenerateContent processes the input and generates a response.
303 | // It handles both text messages and function responses.
304 | func (s *Session) GenerateContent(ctx context.Context, input ...gollem.Input) (*gollem.Response, error) {
305 | messages, _, err := s.convertInputs(input...)
306 | if err != nil {
307 | return nil, err
308 | }
309 |
310 | s.messages = append(s.messages, messages...)
311 | params := s.createRequest()
312 |
313 | resp, err := s.client.Messages.New(ctx, params)
314 | if err != nil {
315 | return nil, goerr.Wrap(err, "failed to create message")
316 | }
317 |
318 | // Add assistant's response to message history
319 | s.messages = append(s.messages, resp.ToParam())
320 |
321 | return processResponse(resp), nil
322 | }
323 |
// FunctionCallAccumulator assembles one streamed tool_use block. ID and Name
// are taken from the content_block_start event, and Arguments is the
// concatenation of the partial JSON from input_json_delta events.
type FunctionCallAccumulator struct {
	ID, Name, Arguments string
}
330 |
331 | func newFunctionCallAccumulator() *FunctionCallAccumulator {
332 | return &FunctionCallAccumulator{
333 | Arguments: "",
334 | }
335 | }
336 |
337 | func (a *FunctionCallAccumulator) accumulate() (*gollem.FunctionCall, error) {
338 | if a.ID == "" || a.Name == "" {
339 | return nil, goerr.Wrap(gollem.ErrInvalidParameter, "function call is not complete")
340 | }
341 |
342 | var args map[string]any
343 | if a.Arguments != "" {
344 | if err := json.Unmarshal([]byte(a.Arguments), &args); err != nil {
345 | return nil, goerr.Wrap(err, "failed to unmarshal function call arguments", goerr.V("accumulator", a))
346 | }
347 | }
348 |
349 | return &gollem.FunctionCall{
350 | ID: a.ID,
351 | Name: a.Name,
352 | Arguments: args,
353 | }, nil
354 | }
355 |
356 | // GenerateStream processes the input and generates a response stream.
357 | // It handles both text messages and function responses, and returns a channel for streaming responses.
358 | func (s *Session) GenerateStream(ctx context.Context, input ...gollem.Input) (<-chan *gollem.Response, error) {
359 | messages, _, err := s.convertInputs(input...)
360 | if err != nil {
361 | return nil, err
362 | }
363 |
364 | s.messages = append(s.messages, messages...)
365 | params := s.createRequest()
366 |
367 | stream := s.client.Messages.NewStreaming(ctx, params)
368 | if stream == nil {
369 | return nil, goerr.New("failed to create message stream")
370 | }
371 |
372 | responseChan := make(chan *gollem.Response)
373 |
374 | // Accumulate text and tool calls for message history
375 | var textContent strings.Builder
376 | var toolCalls []anthropic.ContentBlockParamUnion
377 | acc := newFunctionCallAccumulator()
378 |
379 | go func() {
380 | defer close(responseChan)
381 |
382 | for {
383 | if !stream.Next() {
384 | // Add accumulated message to history when stream ends
385 | if textContent.Len() > 0 || len(toolCalls) > 0 {
386 | var content []anthropic.ContentBlockParamUnion
387 | if textContent.Len() > 0 {
388 | content = append(content, anthropic.NewTextBlock(textContent.String()))
389 | }
390 | content = append(content, toolCalls...)
391 | s.messages = append(s.messages, anthropic.NewAssistantMessage(content...))
392 | }
393 | return
394 | }
395 |
396 | event := stream.Current()
397 | response := &gollem.Response{
398 | Texts: make([]string, 0),
399 | FunctionCalls: make([]*gollem.FunctionCall, 0),
400 | }
401 |
402 | switch event.Type {
403 | case "content_block_delta":
404 | deltaEvent := event.AsContentBlockDeltaEvent()
405 | switch deltaEvent.Delta.Type {
406 | case "text_delta":
407 | textDelta := deltaEvent.Delta.AsTextContentBlockDelta()
408 | response.Texts = append(response.Texts, textDelta.Text)
409 | textContent.WriteString(textDelta.Text)
410 | case "input_json_delta":
411 | jsonDelta := deltaEvent.Delta.AsInputJSONContentBlockDelta()
412 | if jsonDelta.PartialJSON != "" {
413 | acc.Arguments += jsonDelta.PartialJSON
414 | }
415 | }
416 | case "content_block_start":
417 | startEvent := event.AsContentBlockStartEvent()
418 | if startEvent.ContentBlock.Type == "tool_use" {
419 | toolUseBlock := startEvent.ContentBlock.AsResponseToolUseBlock()
420 | acc.ID = toolUseBlock.ID
421 | acc.Name = toolUseBlock.Name
422 | }
423 | case "content_block_stop":
424 | if acc.ID != "" && acc.Name != "" {
425 | funcCall, err := acc.accumulate()
426 | if err != nil {
427 | response.Error = err
428 | responseChan <- response
429 | return
430 | }
431 | response.FunctionCalls = append(response.FunctionCalls, funcCall)
432 | toolCalls = append(toolCalls, anthropic.ContentBlockParamUnion{
433 | OfRequestToolUseBlock: &anthropic.ToolUseBlockParam{
434 | ID: funcCall.ID,
435 | Name: funcCall.Name,
436 | Input: funcCall.Arguments,
437 | Type: "tool_use",
438 | },
439 | })
440 | acc = newFunctionCallAccumulator()
441 | }
442 | }
443 |
444 | if response.HasData() {
445 | responseChan <- response
446 | }
447 | }
448 | }()
449 |
450 | return responseChan, nil
451 | }
452 |
--------------------------------------------------------------------------------
/llm/claude/convert.go:
--------------------------------------------------------------------------------
1 | package claude
2 |
3 | import (
4 | "github.com/anthropics/anthropic-sdk-go"
5 | "github.com/m-mizutani/gollem"
6 | )
7 |
8 | func convertTool(tool gollem.Tool) anthropic.ToolUnionParam {
9 | spec := tool.Spec()
10 | schema := convertParametersToJSONSchema(spec.Parameters)
11 |
12 | return anthropic.ToolUnionParamOfTool(
13 | anthropic.ToolInputSchemaParam{
14 | Properties: schema.Properties,
15 | },
16 | spec.Name,
17 | )
18 | }
19 |
// jsonSchema is the JSON Schema subset used to describe tool parameters to
// Claude. Optional fields are omitted from the encoded JSON when unset.
type jsonSchema struct {
	// Type is the JSON Schema type name ("string", "object", ...).
	Type string `json:"type"`
	// Properties and Required describe object members.
	Properties map[string]jsonSchema `json:"properties,omitempty"`
	Required   []string              `json:"required,omitempty"`
	// Items describes array elements.
	Items *jsonSchema `json:"items,omitempty"`
	// Minimum and Maximum bound numeric values.
	Minimum *float64 `json:"minimum,omitempty"`
	Maximum *float64 `json:"maximum,omitempty"`
	// MinLength, MaxLength and Pattern constrain strings.
	MinLength *int   `json:"minLength,omitempty"`
	MaxLength *int   `json:"maxLength,omitempty"`
	Pattern   string `json:"pattern,omitempty"`
	// MinItems and MaxItems bound array length.
	MinItems *int `json:"minItems,omitempty"`
	MaxItems *int `json:"maxItems,omitempty"`
	// Default and Enum give the default and allowed values.
	Default any   `json:"default,omitempty"`
	Enum    []any `json:"enum,omitempty"`
	// Description and Title are human-readable annotations.
	Description string `json:"description,omitempty"`
	Title       string `json:"title,omitempty"`
}
37 |
38 | func convertParametersToJSONSchema(params map[string]*gollem.Parameter) jsonSchema {
39 | properties := make(map[string]jsonSchema)
40 |
41 | for name, param := range params {
42 | properties[name] = convertParameterToSchema(param)
43 | }
44 |
45 | return jsonSchema{
46 | Type: "object",
47 | Properties: properties,
48 | }
49 | }
50 |
51 | // convertParameterToSchema converts gollem.Parameter to Claude schema
52 | func convertParameterToSchema(param *gollem.Parameter) jsonSchema {
53 | schema := jsonSchema{
54 | Type: getClaudeType(param.Type),
55 | Description: param.Description,
56 | Title: param.Title,
57 | }
58 |
59 | if len(param.Enum) > 0 {
60 | enum := make([]interface{}, len(param.Enum))
61 | for i, v := range param.Enum {
62 | enum[i] = v
63 | }
64 | schema.Enum = enum
65 | }
66 |
67 | if param.Properties != nil {
68 | properties := make(map[string]jsonSchema)
69 | for name, prop := range param.Properties {
70 | properties[name] = convertParameterToSchema(prop)
71 | }
72 | schema.Properties = properties
73 | if len(param.Required) > 0 {
74 | schema.Required = param.Required
75 | }
76 | }
77 |
78 | if param.Items != nil {
79 | items := convertParameterToSchema(param.Items)
80 | schema.Items = &items
81 | }
82 |
83 | // Add number constraints
84 | if param.Type == gollem.TypeNumber || param.Type == gollem.TypeInteger {
85 | if param.Minimum != nil {
86 | schema.Minimum = param.Minimum
87 | }
88 | if param.Maximum != nil {
89 | schema.Maximum = param.Maximum
90 | }
91 | }
92 |
93 | // Add string constraints
94 | if param.Type == gollem.TypeString {
95 | if param.MinLength != nil {
96 | schema.MinLength = param.MinLength
97 | }
98 | if param.MaxLength != nil {
99 | schema.MaxLength = param.MaxLength
100 | }
101 | if param.Pattern != "" {
102 | schema.Pattern = param.Pattern
103 | }
104 | }
105 |
106 | // Add array constraints
107 | if param.Type == gollem.TypeArray {
108 | if param.MinItems != nil {
109 | schema.MinItems = param.MinItems
110 | }
111 | if param.MaxItems != nil {
112 | schema.MaxItems = param.MaxItems
113 | }
114 | }
115 |
116 | // Add default value
117 | if param.Default != nil {
118 | schema.Default = param.Default
119 | }
120 |
121 | return schema
122 | }
123 |
124 | func getClaudeType(paramType gollem.ParameterType) string {
125 | switch paramType {
126 | case gollem.TypeString:
127 | return "string"
128 | case gollem.TypeNumber:
129 | return "number"
130 | case gollem.TypeInteger:
131 | return "integer"
132 | case gollem.TypeBoolean:
133 | return "boolean"
134 | case gollem.TypeArray:
135 | return "array"
136 | case gollem.TypeObject:
137 | return "object"
138 | default:
139 | return "string"
140 | }
141 | }
142 |
--------------------------------------------------------------------------------
/llm/claude/convert_test.go:
--------------------------------------------------------------------------------
1 | package claude_test
2 |
3 | import (
4 | "context"
5 | "testing"
6 |
7 | "github.com/m-mizutani/gollem"
8 | "github.com/m-mizutani/gollem/llm/claude"
9 | "github.com/m-mizutani/gt"
10 | )
11 |
12 | type complexTool struct{}
13 |
14 | func (t *complexTool) Spec() gollem.ToolSpec {
15 | return gollem.ToolSpec{
16 | Name: "complex_tool",
17 | Description: "A tool with complex parameter structure",
18 | Required: []string{"user"},
19 | Parameters: map[string]*gollem.Parameter{
20 | "user": {
21 | Type: gollem.TypeObject,
22 | Required: []string{"name"},
23 | Properties: map[string]*gollem.Parameter{
24 | "name": {
25 | Type: gollem.TypeString,
26 | Description: "User's name",
27 | },
28 | "address": {
29 | Type: gollem.TypeObject,
30 | Properties: map[string]*gollem.Parameter{
31 | "street": {
32 | Type: gollem.TypeString,
33 | Description: "Street address",
34 | },
35 | "city": {
36 | Type: gollem.TypeString,
37 | Description: "City name",
38 | },
39 | },
40 | Required: []string{"street"},
41 | },
42 | },
43 | },
44 | "items": {
45 | Type: gollem.TypeArray,
46 | Items: &gollem.Parameter{
47 | Type: gollem.TypeObject,
48 | Properties: map[string]*gollem.Parameter{
49 | "id": {
50 | Type: gollem.TypeString,
51 | Description: "Item ID",
52 | },
53 | "quantity": {
54 | Type: gollem.TypeNumber,
55 | Description: "Item quantity",
56 | },
57 | },
58 | },
59 | },
60 | },
61 | }
62 | }
63 |
64 | func (t *complexTool) Run(ctx context.Context, args map[string]any) (map[string]any, error) {
65 | return nil, nil
66 | }
67 |
68 | func TestConvertTool(t *testing.T) {
69 | tool := &complexTool{}
70 | claudeTool := claude.ConvertTool(tool)
71 |
72 | // Check basic properties
73 | gt.Equal(t, claudeTool.OfTool.Name, "complex_tool")
74 |
75 | // Check schema properties
76 | schemaProps := claudeTool.OfTool.InputSchema.Properties.(map[string]claude.JsonSchema)
77 |
78 | // Check user parameter
79 | user := schemaProps["user"]
80 | gt.Equal(t, user.Type, "object")
81 |
82 | userProps := user.Properties
83 | nameProps := userProps["name"]
84 | gt.Equal(t, nameProps.Type, "string")
85 | gt.Equal(t, nameProps.Description, "User's name")
86 | gt.Array(t, user.Required).Equal([]string{"name"})
87 |
88 | addressProps := userProps["address"].Properties
89 | gt.Equal(t, addressProps["street"].Type, "string")
90 | gt.Equal(t, addressProps["city"].Type, "string")
91 |
92 | // Check items parameter
93 | itemsProp := schemaProps["items"]
94 | gt.Equal(t, itemsProp.Type, "array")
95 |
96 | itemsSchema := *itemsProp.Items
97 | itemsProps := itemsSchema.Properties
98 | gt.Equal(t, itemsProps["id"].Type, "string")
99 | gt.Equal(t, itemsProps["quantity"].Type, "number")
100 | }
101 |
102 | func TestConvertParameterToSchema(t *testing.T) {
103 | type testCase struct {
104 | name string
105 | schema *gollem.Parameter
106 | expected claude.JsonSchema
107 | }
108 |
109 | runTest := func(tc testCase) func(t *testing.T) {
110 | return func(t *testing.T) {
111 | actual := claude.ConvertParameterToSchema(tc.schema)
112 | gt.Value(t, actual).Equal(tc.expected)
113 | }
114 | }
115 |
116 | t.Run("number constraints", runTest(testCase{
117 | name: "number constraints",
118 | schema: &gollem.Parameter{
119 | Type: gollem.TypeNumber,
120 | Minimum: ptr(1.0),
121 | Maximum: ptr(10.0),
122 | },
123 | expected: claude.JsonSchema{
124 | Type: "number",
125 | Minimum: ptr(1.0),
126 | Maximum: ptr(10.0),
127 | },
128 | }))
129 |
130 | t.Run("string constraints", runTest(testCase{
131 | name: "string constraints",
132 | schema: &gollem.Parameter{
133 | Type: gollem.TypeString,
134 | MinLength: ptr(1),
135 | MaxLength: ptr(10),
136 | Pattern: "^[a-z]+$",
137 | },
138 | expected: claude.JsonSchema{
139 | Type: "string",
140 | MinLength: ptr(1),
141 | MaxLength: ptr(10),
142 | Pattern: "^[a-z]+$",
143 | },
144 | }))
145 |
146 | t.Run("array constraints", runTest(testCase{
147 | name: "array constraints",
148 | schema: &gollem.Parameter{
149 | Type: gollem.TypeArray,
150 | Items: &gollem.Parameter{Type: gollem.TypeString},
151 | MinItems: ptr(1),
152 | MaxItems: ptr(10),
153 | },
154 | expected: claude.JsonSchema{
155 | Type: "array",
156 | MinItems: ptr(1),
157 | MaxItems: ptr(10),
158 | Items: &claude.JsonSchema{Type: "string"},
159 | },
160 | }))
161 |
162 | t.Run("default value", runTest(testCase{
163 | name: "default value",
164 | schema: &gollem.Parameter{
165 | Type: gollem.TypeString,
166 | Default: "default value",
167 | },
168 | expected: claude.JsonSchema{
169 | Type: "string",
170 | Default: "default value",
171 | },
172 | }))
173 | }
174 |
// ptr returns a pointer to a fresh copy of v, which is convenient for
// populating optional pointer fields in test fixtures.
func ptr[T any](v T) *T {
	p := new(T)
	*p = v
	return p
}
178 |
179 | func TestConvertSchema(t *testing.T) {
180 | type testCase struct {
181 | name string
182 | schema *gollem.Parameter
183 | expected claude.JsonSchema
184 | }
185 |
186 | runTest := func(tc testCase) func(t *testing.T) {
187 | return func(t *testing.T) {
188 | actual := claude.ConvertParameterToSchema(tc.schema)
189 | gt.Value(t, actual).Equal(tc.expected)
190 | }
191 | }
192 |
193 | t.Run("string type", runTest(testCase{
194 | name: "string type",
195 | schema: &gollem.Parameter{
196 | Type: gollem.TypeString,
197 | },
198 | expected: claude.JsonSchema{
199 | Type: "string",
200 | },
201 | }))
202 | }
203 |
--------------------------------------------------------------------------------
/llm/claude/embedding.go:
--------------------------------------------------------------------------------
1 | package claude
2 |
3 | import (
4 | "context"
5 |
6 | "github.com/m-mizutani/goerr/v2"
7 | )
8 |
9 | // GenerateEmbedding generates embeddings for the given input text. Claude does not support emmbedding generation directly.
10 | func (c *Client) GenerateEmbedding(ctx context.Context, dimension int, input []string) ([][]float64, error) {
11 | return nil, goerr.New("Claude does not support embedding generation")
12 | }
13 |
--------------------------------------------------------------------------------
/llm/claude/export_test.go:
--------------------------------------------------------------------------------
1 | package claude
2 |
3 | // Export convert functions for testing
4 | var (
5 | ConvertTool = convertTool
6 | ConvertParameterToSchema = convertParameterToSchema
7 | )
8 |
9 | type JsonSchema = jsonSchema
10 |
--------------------------------------------------------------------------------
/llm/gemini/client.go:
--------------------------------------------------------------------------------
1 | package gemini
2 |
3 | import (
4 | "context"
5 | "fmt"
6 |
7 | "cloud.google.com/go/vertexai/genai"
8 | "github.com/m-mizutani/goerr/v2"
9 | "github.com/m-mizutani/gollem"
10 | "google.golang.org/api/iterator"
11 | "google.golang.org/api/option"
12 | )
13 |
const (
	// DefaultModel is the Gemini model used for chat sessions unless
	// WithModel overrides it.
	DefaultModel = "gemini-2.0-flash"
	// DefaultEmbeddingModel is the Vertex AI model used by GenerateEmbedding
	// unless WithEmbeddingModel overrides it.
	DefaultEmbeddingModel = "text-embedding-004"
)
18 |
// Client is a client for the Gemini API.
// It provides methods to interact with Google's Gemini models.
// Create it with New; the zero value has no underlying genai client.
type Client struct {
	// projectID is the Google Cloud project used for Vertex AI requests.
	projectID string
	// location is the Google Cloud region used for Vertex AI requests; it also
	// selects the regional "<location>-aiplatform.googleapis.com" endpoint
	// used for embeddings.
	location string

	// client is the underlying Gemini client.
	client *genai.Client

	// defaultModel is the model to use for chat completions.
	// It can be overridden using WithModel option.
	defaultModel string

	// embeddingModel is the model to use for embeddings.
	// It can be overridden using WithEmbeddingModel option.
	embeddingModel string

	// gcpOptions are additional options for Google Cloud Platform.
	// They can be set using WithGoogleCloudOptions.
	gcpOptions []option.ClientOption

	// generationConfig contains the default generation parameters, populated
	// by WithTemperature, WithTopP, WithTopK, WithMaxTokens and
	// WithStopSequences, and copied into each new session's model.
	generationConfig genai.GenerationConfig

	// systemPrompt is the system prompt configured via WithSystemPrompt.
	systemPrompt string

	// contentType is the type of content to be generated, configured via
	// WithContentType (New initializes it to gollem.ContentTypeText).
	contentType gollem.ContentType
}
49 |
// Option is a function that configures a Client.
// New applies options in order before creating the underlying genai client,
// so a later option overrides an earlier one for the same setting.
type Option func(*Client)
52 |
53 | // WithModel sets the default model to use for chat completions.
54 | // The model name should be a valid Gemini model identifier.
55 | // Default: "gemini-2.0-flash"
56 | func WithModel(modelName string) Option {
57 | return func(c *Client) {
58 | c.defaultModel = modelName
59 | }
60 | }
61 |
62 | // WithGoogleCloudOptions sets additional options for Google Cloud Platform.
63 | // These options are passed to the underlying Gemini client.
64 | func WithGoogleCloudOptions(options ...option.ClientOption) Option {
65 | return func(c *Client) {
66 | c.gcpOptions = options
67 | }
68 | }
69 |
70 | // WithTemperature sets the temperature parameter for text generation.
71 | // Higher values make the output more random, lower values make it more focused.
72 | // Range: 0.0 to 1.0
73 | // Default: 0.7
74 | func WithTemperature(temp float32) Option {
75 | return func(c *Client) {
76 | c.generationConfig.Temperature = &temp
77 | }
78 | }
79 |
80 | // WithTopP sets the top_p parameter for text generation.
81 | // Controls diversity via nucleus sampling.
82 | // Range: 0.0 to 1.0
83 | // Default: 1.0
84 | func WithTopP(topP float32) Option {
85 | return func(c *Client) {
86 | c.generationConfig.TopP = &topP
87 | }
88 | }
89 |
90 | // WithTopK sets the top_k parameter for text generation.
91 | // Controls diversity via top-k sampling.
92 | // Range: 1 to 40
93 | func WithTopK(topK int32) Option {
94 | return func(c *Client) {
95 | c.generationConfig.TopK = &topK
96 | }
97 | }
98 |
99 | // WithMaxTokens sets the maximum number of tokens to generate.
100 | func WithMaxTokens(maxTokens int32) Option {
101 | return func(c *Client) {
102 | c.generationConfig.MaxOutputTokens = &maxTokens
103 | }
104 | }
105 |
106 | // WithStopSequences sets the stop sequences for text generation.
107 | func WithStopSequences(stopSequences []string) Option {
108 | return func(c *Client) {
109 | c.generationConfig.StopSequences = stopSequences
110 | }
111 | }
112 |
113 | // WithSystemPrompt sets the system prompt to use for chat completions.
114 | func WithSystemPrompt(prompt string) Option {
115 | return func(c *Client) {
116 | c.systemPrompt = prompt
117 | }
118 | }
119 |
120 | // WithContentType sets the content type for text generation.
121 | // This determines the format of the generated content.
122 | func WithContentType(contentType gollem.ContentType) Option {
123 | return func(c *Client) {
124 | c.contentType = contentType
125 | }
126 | }
127 |
128 | // WithEmbeddingModel sets the model to use for embeddings.
129 | // Default: "textembedding-gecko@latest"
130 | func WithEmbeddingModel(modelName string) Option {
131 | return func(c *Client) {
132 | c.embeddingModel = modelName
133 | }
134 | }
135 |
136 | // New creates a new client for the Gemini API.
137 | // It requires a project ID and location, and can be configured with additional options.
138 | func New(ctx context.Context, projectID, location string, options ...Option) (*Client, error) {
139 | if projectID == "" {
140 | return nil, goerr.New("projectID is required")
141 | }
142 | if location == "" {
143 | return nil, goerr.New("location is required")
144 | }
145 |
146 | client := &Client{
147 | projectID: projectID,
148 | location: location,
149 | defaultModel: DefaultModel,
150 | embeddingModel: DefaultEmbeddingModel,
151 | contentType: gollem.ContentTypeText,
152 | }
153 |
154 | for _, option := range options {
155 | option(client)
156 | }
157 |
158 | newClient, err := genai.NewClient(ctx, projectID, location, client.gcpOptions...)
159 | if err != nil {
160 | return nil, err
161 | }
162 |
163 | client.client = newClient
164 |
165 | return client, nil
166 | }
167 |
168 | // NewSession creates a new session for the Gemini API.
169 | // It converts the provided tools to Gemini's tool format and initializes a new chat session.
170 | func (c *Client) NewSession(ctx context.Context, options ...gollem.SessionOption) (gollem.Session, error) {
171 | cfg := gollem.NewSessionConfig(options...)
172 |
173 | // Convert gollem.Tool to *genai.Tool
174 | genaiFunctions := make([]*genai.FunctionDeclaration, len(cfg.Tools()))
175 | for i, tool := range cfg.Tools() {
176 | genaiFunctions[i] = convertTool(tool)
177 | }
178 |
179 | var messages []*genai.Content
180 |
181 | if cfg.History() != nil {
182 | history, err := cfg.History().ToGemini()
183 | if err != nil {
184 | return nil, goerr.Wrap(err, "failed to convert history to gemini.Content")
185 | }
186 | messages = append(messages, history...)
187 | }
188 |
189 | model := c.client.GenerativeModel(c.defaultModel)
190 | model.GenerationConfig = c.generationConfig
191 |
192 | switch cfg.ContentType() {
193 | case gollem.ContentTypeJSON:
194 | model.GenerationConfig.ResponseMIMEType = "application/json"
195 | case gollem.ContentTypeText:
196 | model.GenerationConfig.ResponseMIMEType = "text/plain"
197 | }
198 |
199 | if cfg.SystemPrompt() != "" {
200 | model.SystemInstruction = &genai.Content{
201 | Role: "system",
202 | Parts: []genai.Part{genai.Text(cfg.SystemPrompt())},
203 | }
204 | }
205 |
206 | if len(genaiFunctions) > 0 {
207 | model.Tools = []*genai.Tool{
208 | {
209 | FunctionDeclarations: genaiFunctions,
210 | },
211 | }
212 | }
213 |
214 | session := &Session{
215 | session: model.StartChat(),
216 | }
217 | if len(messages) > 0 {
218 | session.session.History = messages
219 | }
220 |
221 | return session, nil
222 | }
223 |
224 | func (s *Session) History() *gollem.History {
225 | return gollem.NewHistoryFromGemini(s.session.History)
226 | }
227 |
// Session is a session for the Gemini chat.
// It maintains the conversation state and handles message generation.
type Session struct {
	// session is the underlying Gemini chat session. Its History field holds
	// the conversation so far; NewSession seeds it from the session config.
	session *genai.ChatSession
}
234 |
235 | // convertInputs converts gollem.Input to Gemini parts
236 | func (s *Session) convertInputs(input ...gollem.Input) ([]genai.Part, error) {
237 | parts := make([]genai.Part, len(input))
238 | for i, in := range input {
239 | switch v := in.(type) {
240 | case gollem.Text:
241 | parts[i] = genai.Text(string(v))
242 | case gollem.FunctionResponse:
243 | if v.Error != nil {
244 | parts[i] = genai.FunctionResponse{
245 | Name: v.Name,
246 | Response: map[string]any{
247 | "error_message": fmt.Sprintf("%+v", v.Error),
248 | },
249 | }
250 | } else {
251 | parts[i] = genai.FunctionResponse{
252 | Name: v.Name,
253 | Response: v.Data,
254 | }
255 | }
256 | default:
257 | return nil, goerr.Wrap(gollem.ErrInvalidParameter, "invalid input")
258 | }
259 | }
260 | return parts, nil
261 | }
262 |
263 | // processResponse converts Gemini response to gollem.Response
264 | func processResponse(resp *genai.GenerateContentResponse) *gollem.Response {
265 | if len(resp.Candidates) == 0 {
266 | return &gollem.Response{}
267 | }
268 |
269 | response := &gollem.Response{
270 | Texts: make([]string, 0),
271 | FunctionCalls: make([]*gollem.FunctionCall, 0),
272 | }
273 |
274 | for _, candidate := range resp.Candidates {
275 | for _, part := range candidate.Content.Parts {
276 | switch v := part.(type) {
277 | case genai.Text:
278 | response.Texts = append(response.Texts, string(v))
279 | case genai.FunctionCall:
280 | response.FunctionCalls = append(response.FunctionCalls, &gollem.FunctionCall{
281 | Name: v.Name,
282 | Arguments: v.Args,
283 | })
284 | }
285 | }
286 | }
287 |
288 | return response
289 | }
290 |
291 | // GenerateContent processes the input and generates a response.
292 | // It handles both text messages and function responses.
293 | func (s *Session) GenerateContent(ctx context.Context, input ...gollem.Input) (*gollem.Response, error) {
294 | parts, err := s.convertInputs(input...)
295 | if err != nil {
296 | return nil, err
297 | }
298 |
299 | resp, err := s.session.SendMessage(ctx, parts...)
300 | if err != nil {
301 | return nil, goerr.Wrap(err, "failed to send message")
302 | }
303 |
304 | return processResponse(resp), nil
305 | }
306 |
307 | // GenerateStream processes the input and generates a response stream.
308 | // It handles both text messages and function responses, and returns a channel for streaming responses.
309 | func (s *Session) GenerateStream(ctx context.Context, input ...gollem.Input) (<-chan *gollem.Response, error) {
310 | parts, err := s.convertInputs(input...)
311 | if err != nil {
312 | return nil, err
313 | }
314 |
315 | iter := s.session.SendMessageStream(ctx, parts...)
316 | responseChan := make(chan *gollem.Response)
317 |
318 | go func() {
319 | defer close(responseChan)
320 |
321 | for {
322 | resp, err := iter.Next()
323 | if err != nil {
324 | if err == iterator.Done {
325 | return
326 | }
327 | responseChan <- &gollem.Response{
328 | Error: goerr.Wrap(err, "failed to generate stream"),
329 | }
330 | return
331 | }
332 |
333 | responseChan <- processResponse(resp)
334 | }
335 | }()
336 |
337 | return responseChan, nil
338 | }
339 |
--------------------------------------------------------------------------------
/llm/gemini/convert.go:
--------------------------------------------------------------------------------
1 | package gemini
2 |
3 | import (
4 | "cloud.google.com/go/vertexai/genai"
5 | "github.com/m-mizutani/gollem"
6 | )
7 |
8 | // convertTool converts gollem.Tool to Gemini tool
9 | func convertTool(tool gollem.Tool) *genai.FunctionDeclaration {
10 | spec := tool.Spec()
11 | parameters := &genai.Schema{
12 | Type: genai.TypeObject,
13 | Properties: make(map[string]*genai.Schema),
14 | Required: spec.Required,
15 | }
16 |
17 | for name, param := range spec.Parameters {
18 | parameters.Properties[name] = convertParameterToSchema(param)
19 | }
20 |
21 | return &genai.FunctionDeclaration{
22 | Name: spec.Name,
23 | Description: spec.Description,
24 | Parameters: parameters,
25 | }
26 | }
27 |
28 | // convertParameterToSchema converts gollem.Parameter to Gemini schema
29 | func convertParameterToSchema(param *gollem.Parameter) *genai.Schema {
30 | schema := &genai.Schema{
31 | Type: getGeminiType(param.Type),
32 | Description: param.Description,
33 | Title: param.Title,
34 | }
35 |
36 | if len(param.Enum) > 0 {
37 | schema.Enum = param.Enum
38 | }
39 |
40 | if param.Properties != nil {
41 | schema.Properties = make(map[string]*genai.Schema)
42 | for name, prop := range param.Properties {
43 | schema.Properties[name] = convertParameterToSchema(prop)
44 | }
45 | if len(param.Required) > 0 {
46 | schema.Required = param.Required
47 | }
48 | }
49 |
50 | if param.Items != nil {
51 | schema.Items = convertParameterToSchema(param.Items)
52 | }
53 |
54 | // Add number constraints
55 | if param.Type == gollem.TypeNumber || param.Type == gollem.TypeInteger {
56 | if param.Minimum != nil {
57 | schema.Minimum = *param.Minimum
58 | }
59 | if param.Maximum != nil {
60 | schema.Maximum = *param.Maximum
61 | }
62 | }
63 |
64 | // Add string constraints
65 | if param.Type == gollem.TypeString {
66 | if param.MinLength != nil {
67 | schema.MinLength = int64(*param.MinLength)
68 | }
69 | if param.MaxLength != nil {
70 | schema.MaxLength = int64(*param.MaxLength)
71 | }
72 | if param.Pattern != "" {
73 | schema.Pattern = param.Pattern
74 | }
75 | }
76 |
77 | // Add array constraints
78 | if param.Type == gollem.TypeArray {
79 | if param.MinItems != nil {
80 | schema.MinItems = int64(*param.MinItems)
81 | }
82 | if param.MaxItems != nil {
83 | schema.MaxItems = int64(*param.MaxItems)
84 | }
85 | }
86 |
87 | // No default value in Gemini
88 |
89 | return schema
90 | }
91 |
92 | func getGeminiType(paramType gollem.ParameterType) genai.Type {
93 | switch paramType {
94 | case gollem.TypeString:
95 | return genai.TypeString
96 | case gollem.TypeNumber:
97 | return genai.TypeNumber
98 | case gollem.TypeInteger:
99 | return genai.TypeInteger
100 | case gollem.TypeBoolean:
101 | return genai.TypeBoolean
102 | case gollem.TypeArray:
103 | return genai.TypeArray
104 | case gollem.TypeObject:
105 | return genai.TypeObject
106 | default:
107 | return genai.TypeString
108 | }
109 | }
110 |
--------------------------------------------------------------------------------
/llm/gemini/convert_test.go:
--------------------------------------------------------------------------------
1 | package gemini_test
2 |
3 | import (
4 | "context"
5 | "testing"
6 |
7 | "cloud.google.com/go/vertexai/genai"
8 | "github.com/m-mizutani/gollem"
9 | "github.com/m-mizutani/gollem/llm/gemini"
10 | "github.com/m-mizutani/gt"
11 | )
12 |
13 | type complexTool struct{}
14 |
15 | func (t *complexTool) Spec() gollem.ToolSpec {
16 | return gollem.ToolSpec{
17 | Name: "complex_tool",
18 | Description: "A tool with complex parameter structure",
19 | Required: []string{"user", "items"},
20 | Parameters: map[string]*gollem.Parameter{
21 | "user": {
22 | Type: gollem.TypeObject,
23 | Required: []string{"name"},
24 | Properties: map[string]*gollem.Parameter{
25 | "name": {
26 | Type: gollem.TypeString,
27 | Description: "User's name",
28 | },
29 | "address": {
30 | Type: gollem.TypeObject,
31 | Properties: map[string]*gollem.Parameter{
32 | "street": {
33 | Type: gollem.TypeString,
34 | Description: "Street address",
35 | },
36 | "city": {
37 | Type: gollem.TypeString,
38 | Description: "City name",
39 | },
40 | },
41 | },
42 | },
43 | },
44 | "items": {
45 | Type: gollem.TypeArray,
46 | Items: &gollem.Parameter{
47 | Type: gollem.TypeObject,
48 | Properties: map[string]*gollem.Parameter{
49 | "id": {
50 | Type: gollem.TypeString,
51 | Description: "Item ID",
52 | },
53 | "quantity": {
54 | Type: gollem.TypeNumber,
55 | Description: "Item quantity",
56 | },
57 | },
58 | },
59 | },
60 | },
61 | }
62 | }
63 |
64 | func (t *complexTool) Run(ctx context.Context, args map[string]any) (map[string]any, error) {
65 | return nil, nil
66 | }
67 |
68 | func TestConvertTool(t *testing.T) {
69 | tool := &complexTool{}
70 | genaiTool := gemini.ConvertTool(tool)
71 |
72 | gt.Value(t, genaiTool.Name).Equal("complex_tool")
73 | gt.Value(t, genaiTool.Description).Equal("A tool with complex parameter structure")
74 |
75 | params := genaiTool.Parameters
76 | gt.Value(t, params.Type).Equal(genai.TypeObject)
77 | gt.Value(t, params.Required).Equal([]string{"user", "items"})
78 |
79 | // Check user object
80 | user := params.Properties["user"]
81 | gt.Value(t, user.Type).Equal(genai.TypeObject)
82 | gt.Value(t, user.Properties["name"].Type).Equal(genai.TypeString)
83 | gt.Value(t, user.Properties["name"].Description).Equal("User's name")
84 | gt.Value(t, user.Required).Equal([]string{"name"})
85 |
86 | // Check address object
87 | address := user.Properties["address"]
88 | gt.Value(t, address.Type).Equal(genai.TypeObject)
89 | gt.Value(t, address.Properties["street"].Type).Equal(genai.TypeString)
90 | gt.Value(t, address.Properties["city"].Type).Equal(genai.TypeString)
91 |
92 | // Check items array
93 | items := params.Properties["items"]
94 | gt.Value(t, items.Type).Equal(genai.TypeArray)
95 | gt.Value(t, items.Items.Type).Equal(genai.TypeObject)
96 | gt.Value(t, items.Items.Properties["id"].Type).Equal(genai.TypeString)
97 | gt.Value(t, items.Items.Properties["quantity"].Type).Equal(genai.TypeNumber)
98 | }
99 |
100 | func TestConvertParameterToSchema(t *testing.T) {
101 | t.Run("number constraints", func(t *testing.T) {
102 | p := &gollem.Parameter{
103 | Type: gollem.TypeNumber,
104 | Minimum: ptr(1.0),
105 | Maximum: ptr(10.0),
106 | }
107 | schema := gemini.ConvertParameterToSchema(p)
108 | gt.Value(t, schema.Minimum).Equal(1.0)
109 | gt.Value(t, schema.Maximum).Equal(10.0)
110 | })
111 |
112 | t.Run("string constraints", func(t *testing.T) {
113 | p := &gollem.Parameter{
114 | Type: gollem.TypeString,
115 | MinLength: ptr(1),
116 | MaxLength: ptr(10),
117 | Pattern: "^[a-z]+$",
118 | }
119 | schema := gemini.ConvertParameterToSchema(p)
120 | gt.Value(t, schema.MinLength).Equal(int64(1))
121 | gt.Value(t, schema.MaxLength).Equal(int64(10))
122 | gt.Value(t, schema.Pattern).Equal("^[a-z]+$")
123 | })
124 |
125 | t.Run("array constraints", func(t *testing.T) {
126 | p := &gollem.Parameter{
127 | Type: gollem.TypeArray,
128 | Items: &gollem.Parameter{Type: gollem.TypeString},
129 | MinItems: ptr(1),
130 | MaxItems: ptr(10),
131 | }
132 | schema := gemini.ConvertParameterToSchema(p)
133 | gt.Value(t, schema.MinItems).Equal(int64(1))
134 | gt.Value(t, schema.MaxItems).Equal(int64(10))
135 | gt.Value(t, schema.Items.Type).Equal(genai.TypeString)
136 | })
137 | }
138 |
// ptr allocates a new T holding a copy of v and returns its address; used to
// fill optional pointer fields in test fixtures.
func ptr[T any](v T) *T {
	out := new(T)
	*out = v
	return out
}
142 |
--------------------------------------------------------------------------------
/llm/gemini/embedding.go:
--------------------------------------------------------------------------------
1 | package gemini
2 |
3 | import (
4 | "context"
5 | "fmt"
6 |
7 | aiplatform "cloud.google.com/go/aiplatform/apiv1"
8 | "cloud.google.com/go/aiplatform/apiv1/aiplatformpb"
9 | "github.com/m-mizutani/goerr/v2"
10 | "google.golang.org/api/option"
11 | "google.golang.org/protobuf/types/known/structpb"
12 | )
13 |
14 | // GenerateEmbedding generates embeddings for the given input text.
15 | func (x *Client) GenerateEmbedding(ctx context.Context, dimension int, input []string) ([][]float64, error) {
16 | apiEndpoint := fmt.Sprintf("%s-aiplatform.googleapis.com:443", x.location)
17 |
18 | client, err := aiplatform.NewPredictionClient(ctx, option.WithEndpoint(apiEndpoint))
19 | if err != nil {
20 | return nil, goerr.Wrap(err, "failed to create aiplatform client")
21 | }
22 | defer client.Close()
23 |
24 | endpoint := fmt.Sprintf("projects/%s/locations/%s/publishers/google/models/%s", x.projectID, x.location, x.embeddingModel)
25 | instances := make([]*structpb.Value, len(input))
26 |
27 | for i, v := range input {
28 | instances[i] = structpb.NewStructValue(&structpb.Struct{
29 | Fields: map[string]*structpb.Value{
30 | "content": structpb.NewStringValue(v),
31 | "task_type": structpb.NewStringValue("QUESTION_ANSWERING"),
32 | },
33 | })
34 | }
35 |
36 | params := structpb.NewStructValue(&structpb.Struct{
37 | Fields: map[string]*structpb.Value{
38 | "outputDimensionality": structpb.NewNumberValue(float64(dimension)),
39 | },
40 | })
41 |
42 | req := &aiplatformpb.PredictRequest{
43 | Endpoint: endpoint,
44 | Instances: instances,
45 | Parameters: params,
46 | }
47 | resp, err := client.Predict(ctx, req)
48 | if err != nil {
49 | return nil, goerr.Wrap(err, "failed to predict",
50 | goerr.V("endpoint", endpoint),
51 | goerr.V("dimensionality", dimension),
52 | )
53 | }
54 |
55 | if len(resp.Predictions) == 0 {
56 | return nil, goerr.New("no predictions returned")
57 | }
58 |
59 | embeddings := make([][]float64, len(resp.Predictions))
60 | for i, prediction := range resp.Predictions {
61 | values := prediction.GetStructValue().Fields["embeddings"].GetStructValue().Fields["values"].GetListValue().Values
62 | embedding := make([]float64, len(values))
63 | for j, value := range values {
64 | embedding[j] = float64(value.GetNumberValue())
65 | }
66 | embeddings[i] = embedding
67 | }
68 |
69 | return embeddings, nil
70 | }
71 |
--------------------------------------------------------------------------------
/llm/gemini/embeding_test.go:
--------------------------------------------------------------------------------
1 | package gemini_test
2 |
3 | import (
4 | "os"
5 | "testing"
6 |
7 | "github.com/m-mizutani/gollem/llm/gemini"
8 | "github.com/m-mizutani/gt"
9 | )
10 |
11 | func TestGenerateEmbedding(t *testing.T) {
12 | projectID, ok := os.LookupEnv("TEST_GCP_PROJECT_ID")
13 | if !ok {
14 | t.Skip("TEST_GCP_PROJECT_ID is not set")
15 | }
16 |
17 | location, ok := os.LookupEnv("TEST_GCP_LOCATION")
18 | if !ok {
19 | t.Skip("TEST_GCP_LOCATION is not set")
20 | }
21 |
22 | ctx := t.Context()
23 | client, err := gemini.New(ctx, projectID, location)
24 | if err != nil {
25 | t.Fatalf("failed to create client: %v", err)
26 | }
27 |
28 | embeddings, err := client.GenerateEmbedding(ctx, 256, []string{"not, SANE", "Five timeless words"})
29 | if err != nil {
30 | t.Fatalf("failed to generate embedding: %v", err)
31 | }
32 |
33 | gt.A(t, embeddings).Length(2).
34 | At(0, func(t testing.TB, v []float64) {
35 | gt.A(t, v).Longer(0)
36 | }).
37 | At(1, func(t testing.TB, v []float64) {
38 | gt.A(t, v).Longer(0)
39 | })
40 | }
41 |
--------------------------------------------------------------------------------
/llm/gemini/export_test.go:
--------------------------------------------------------------------------------
1 | package gemini
2 |
3 | // Export convert functions for testing
4 | var (
5 | ConvertTool = convertTool
6 | ConvertParameterToSchema = convertParameterToSchema
7 | )
8 |
--------------------------------------------------------------------------------
/llm/openai/convert.go:
--------------------------------------------------------------------------------
1 | package openai
2 |
3 | import (
4 | "github.com/m-mizutani/gollem"
5 | "github.com/sashabaranov/go-openai"
6 | )
7 |
8 | // convertTool converts gollem.Tool to openai.Tool
9 | func convertTool(tool gollem.Tool) openai.Tool {
10 | parameters := make(map[string]interface{})
11 | properties := make(map[string]interface{})
12 | spec := tool.Spec()
13 |
14 | for name, param := range spec.Parameters {
15 | properties[name] = convertParameterToSchema(param)
16 | }
17 |
18 | if len(properties) > 0 {
19 | parameters["type"] = "object"
20 | parameters["properties"] = properties
21 | if len(spec.Required) > 0 {
22 | parameters["required"] = spec.Required
23 | }
24 | }
25 |
26 | return openai.Tool{
27 | Type: openai.ToolTypeFunction,
28 | Function: &openai.FunctionDefinition{
29 | Name: spec.Name,
30 | Description: spec.Description,
31 | Parameters: parameters,
32 | },
33 | }
34 | }
35 |
36 | // convertParameterToSchema converts gollem.Parameter to OpenAI schema
37 | func convertParameterToSchema(param *gollem.Parameter) map[string]interface{} {
38 | schema := map[string]interface{}{
39 | "type": getOpenAIType(param.Type),
40 | "description": param.Description,
41 | "title": param.Title,
42 | }
43 |
44 | if len(param.Enum) > 0 {
45 | schema["enum"] = param.Enum
46 | }
47 |
48 | if param.Properties != nil {
49 | properties := make(map[string]interface{})
50 | for name, prop := range param.Properties {
51 | properties[name] = convertParameterToSchema(prop)
52 | }
53 | schema["properties"] = properties
54 | if len(param.Required) > 0 {
55 | schema["required"] = param.Required
56 | }
57 | }
58 |
59 | if param.Items != nil {
60 | schema["items"] = convertParameterToSchema(param.Items)
61 | }
62 |
63 | // Add number constraints
64 | if param.Type == gollem.TypeNumber || param.Type == gollem.TypeInteger {
65 | if param.Minimum != nil {
66 | schema["minimum"] = *param.Minimum
67 | }
68 | if param.Maximum != nil {
69 | schema["maximum"] = *param.Maximum
70 | }
71 | }
72 |
73 | // Add string constraints
74 | if param.Type == gollem.TypeString {
75 | if param.MinLength != nil {
76 | schema["minLength"] = *param.MinLength
77 | }
78 | if param.MaxLength != nil {
79 | schema["maxLength"] = *param.MaxLength
80 | }
81 | if param.Pattern != "" {
82 | schema["pattern"] = param.Pattern
83 | }
84 | }
85 |
86 | // Add array constraints
87 | if param.Type == gollem.TypeArray {
88 | if param.MinItems != nil {
89 | schema["minItems"] = *param.MinItems
90 | }
91 | if param.MaxItems != nil {
92 | schema["maxItems"] = *param.MaxItems
93 | }
94 | }
95 |
96 | // Add default value
97 | if param.Default != nil {
98 | schema["default"] = param.Default
99 | }
100 |
101 | return schema
102 | }
103 |
104 | func getOpenAIType(paramType gollem.ParameterType) string {
105 | switch paramType {
106 | case gollem.TypeString:
107 | return "string"
108 | case gollem.TypeNumber:
109 | return "number"
110 | case gollem.TypeInteger:
111 | return "integer"
112 | case gollem.TypeBoolean:
113 | return "boolean"
114 | case gollem.TypeArray:
115 | return "array"
116 | case gollem.TypeObject:
117 | return "object"
118 | default:
119 | return "string"
120 | }
121 | }
122 |
--------------------------------------------------------------------------------
/llm/openai/convert_test.go:
--------------------------------------------------------------------------------
1 | package openai
2 |
3 | import (
4 | "context"
5 | "testing"
6 |
7 | "github.com/m-mizutani/gollem"
8 | "github.com/m-mizutani/gt"
9 | )
10 |
11 | type complexTool struct{}
12 |
13 | func (t *complexTool) Spec() gollem.ToolSpec {
14 | return gollem.ToolSpec{
15 | Name: "complex_tool",
16 | Description: "A tool with complex parameter structure",
17 | Parameters: map[string]*gollem.Parameter{
18 | "user": {
19 | Type: gollem.TypeObject,
20 | Required: []string{"name"},
21 | Properties: map[string]*gollem.Parameter{
22 | "name": {
23 | Type: gollem.TypeString,
24 | Description: "User's name",
25 | },
26 | "address": {
27 | Type: gollem.TypeObject,
28 | Properties: map[string]*gollem.Parameter{
29 | "street": {
30 | Type: gollem.TypeString,
31 | Description: "Street address",
32 | },
33 | "city": {
34 | Type: gollem.TypeString,
35 | Description: "City name",
36 | },
37 | },
38 | },
39 | },
40 | },
41 | "items": {
42 | Type: gollem.TypeArray,
43 | Items: &gollem.Parameter{
44 | Type: gollem.TypeObject,
45 | Properties: map[string]*gollem.Parameter{
46 | "id": {
47 | Type: gollem.TypeString,
48 | Description: "Item ID",
49 | },
50 | "quantity": {
51 | Type: gollem.TypeNumber,
52 | Description: "Item quantity",
53 | },
54 | },
55 | },
56 | },
57 | },
58 | }
59 | }
60 |
61 | func (t *complexTool) Run(ctx context.Context, args map[string]any) (map[string]any, error) {
62 | return nil, nil
63 | }
64 |
65 | func TestConvertTool(t *testing.T) {
66 | tool := &complexTool{}
67 | openaiTool := ConvertTool(tool)
68 |
69 | gt.Value(t, openaiTool.Type).Equal("function")
70 | gt.Value(t, openaiTool.Function.Name).Equal("complex_tool")
71 | gt.Value(t, openaiTool.Function.Description).Equal("A tool with complex parameter structure")
72 |
73 | params := openaiTool.Function.Parameters.(map[string]interface{})
74 | gt.Value(t, params["type"]).Equal("object")
75 |
76 | // Check user object
77 | user := params["properties"].(map[string]interface{})["user"].(map[string]interface{})
78 | gt.Value(t, user["type"]).Equal("object")
79 | gt.Value(t, user["properties"].(map[string]interface{})["name"].(map[string]interface{})["type"]).Equal("string")
80 | gt.Value(t, user["properties"].(map[string]interface{})["name"].(map[string]interface{})["description"]).Equal("User's name")
81 | gt.Value(t, user["required"]).Equal([]string{"name"})
82 |
83 | // Check address object
84 | address := user["properties"].(map[string]interface{})["address"].(map[string]interface{})
85 | gt.Value(t, address["type"]).Equal("object")
86 | gt.Value(t, address["properties"].(map[string]interface{})["street"].(map[string]interface{})["type"]).Equal("string")
87 | gt.Value(t, address["properties"].(map[string]interface{})["city"].(map[string]interface{})["type"]).Equal("string")
88 |
89 | // Check items array
90 | items := params["properties"].(map[string]interface{})["items"].(map[string]interface{})
91 | gt.Value(t, items["type"]).Equal("array")
92 | gt.Value(t, items["items"].(map[string]interface{})["type"]).Equal("object")
93 | gt.Value(t, items["items"].(map[string]interface{})["properties"].(map[string]interface{})["id"].(map[string]interface{})["type"]).Equal("string")
94 | gt.Value(t, items["items"].(map[string]interface{})["properties"].(map[string]interface{})["quantity"].(map[string]interface{})["type"]).Equal("number")
95 | }
96 |
97 | func TestConvertParameterToSchema(t *testing.T) {
98 | t.Run("number constraints", func(t *testing.T) {
99 | p := &gollem.Parameter{
100 | Type: gollem.TypeNumber,
101 | Minimum: ptr(1.0),
102 | Maximum: ptr(10.0),
103 | }
104 | schema := convertParameterToSchema(p)
105 | gt.Value(t, schema["minimum"]).Equal(1.0)
106 | gt.Value(t, schema["maximum"]).Equal(10.0)
107 | })
108 |
109 | t.Run("string constraints", func(t *testing.T) {
110 | p := &gollem.Parameter{
111 | Type: gollem.TypeString,
112 | MinLength: ptr(1),
113 | MaxLength: ptr(10),
114 | Pattern: "^[a-z]+$",
115 | }
116 | schema := convertParameterToSchema(p)
117 | gt.Value(t, schema["minLength"]).Equal(1)
118 | gt.Value(t, schema["maxLength"]).Equal(10)
119 | gt.Value(t, schema["pattern"]).Equal("^[a-z]+$")
120 | })
121 |
122 | t.Run("array constraints", func(t *testing.T) {
123 | p := &gollem.Parameter{
124 | Type: gollem.TypeArray,
125 | Items: &gollem.Parameter{Type: gollem.TypeString},
126 | MinItems: ptr(1),
127 | MaxItems: ptr(10),
128 | }
129 | schema := convertParameterToSchema(p)
130 | gt.Value(t, schema["minItems"]).Equal(1)
131 | gt.Value(t, schema["maxItems"]).Equal(10)
132 | gt.Value(t, schema["items"].(map[string]interface{})["type"]).Equal("string")
133 | })
134 |
135 | t.Run("default value", func(t *testing.T) {
136 | p := &gollem.Parameter{
137 | Type: gollem.TypeString,
138 | Default: "default value",
139 | }
140 | schema := convertParameterToSchema(p)
141 | gt.Value(t, schema["default"]).Equal("default value")
142 | })
143 | }
144 |
// ptr returns a pointer to a fresh copy of v. It is used to populate the
// optional pointer fields of test fixtures from literals.
func ptr[T any](v T) *T {
	p := new(T)
	*p = v
	return p
}
148 |
--------------------------------------------------------------------------------
/llm/openai/embedding.go:
--------------------------------------------------------------------------------
1 | package openai
2 |
3 | import (
4 | "context"
5 |
6 | "github.com/m-mizutani/goerr/v2"
7 | "github.com/sashabaranov/go-openai"
8 | )
9 |
10 | // GenerateEmbedding generates embeddings for the given input text.
11 | func (c *Client) GenerateEmbedding(ctx context.Context, dimension int, input []string) ([][]float64, error) {
12 | /*
13 | AdaEmbeddingV2 EmbeddingModel = "text-embedding-ada-002"
14 | SmallEmbedding3 EmbeddingModel = "text-embedding-3-small"
15 | LargeEmbedding3 EmbeddingModel = "text-embedding-3-large"
16 | */
17 | modelMap := map[string]openai.EmbeddingModel{
18 | "text-embedding-ada-002": openai.AdaEmbeddingV2,
19 | "text-embedding-3-small": openai.SmallEmbedding3,
20 | "text-embedding-3-large": openai.LargeEmbedding3,
21 | }
22 |
23 | model, ok := modelMap[c.embeddingModel]
24 | if !ok {
25 | return nil, goerr.New("invalid or unsupported embedding model. See https://platform.openai.com/docs/guides/embeddings#embedding-models", goerr.V("model", c.embeddingModel))
26 | }
27 |
28 | req := openai.EmbeddingRequest{
29 | Input: input,
30 | Model: model,
31 | Dimensions: dimension,
32 | }
33 |
34 | resp, err := c.client.CreateEmbeddings(ctx, req)
35 | if err != nil {
36 | return nil, goerr.Wrap(err, "failed to create embedding")
37 | }
38 |
39 | if len(resp.Data) == 0 {
40 | return nil, goerr.New("no embedding data returned")
41 | }
42 |
43 | embeddings := make([][]float64, len(resp.Data))
44 | for i, data := range resp.Data {
45 | embeddings[i] = make([]float64, len(data.Embedding))
46 | for j, v := range data.Embedding {
47 | embeddings[i][j] = float64(v)
48 | }
49 | }
50 |
51 | return embeddings, nil
52 | }
53 |
--------------------------------------------------------------------------------
/llm/openai/embedding_test.go:
--------------------------------------------------------------------------------
1 | package openai_test
2 |
3 | import (
4 | "os"
5 | "testing"
6 |
7 | "github.com/m-mizutani/gollem/llm/openai"
8 | "github.com/m-mizutani/gt"
9 | )
10 |
11 | func TestGenerateEmbedding(t *testing.T) {
12 | apiKey, ok := os.LookupEnv("TEST_OPENAI_API_KEY")
13 | if !ok {
14 | t.Skip("TEST_OPENAI_API_KEY is not set")
15 | }
16 |
17 | ctx := t.Context()
18 | client, err := openai.New(ctx, apiKey)
19 | if err != nil {
20 | t.Fatalf("failed to create client: %v", err)
21 | }
22 |
23 | embeddings, err := client.GenerateEmbedding(ctx, 256, []string{"not, SANE", "Five timeless words"})
24 | if err != nil {
25 | t.Fatalf("failed to generate embedding: %v", err)
26 | }
27 |
28 | gt.A(t, embeddings).Length(2).
29 | At(0, func(t testing.TB, v []float64) {
30 | gt.A(t, v).Longer(0)
31 | }).
32 | At(1, func(t testing.TB, v []float64) {
33 | gt.A(t, v).Longer(0)
34 | })
35 | }
36 |
--------------------------------------------------------------------------------
/llm/openai/export_test.go:
--------------------------------------------------------------------------------
1 | package openai
2 |
3 | // Export convert functions for testing
4 | var (
5 | ConvertTool = convertTool
6 | ConvertParameterToSchema = convertParameterToSchema
7 | )
8 |
--------------------------------------------------------------------------------
/llm_test.go:
--------------------------------------------------------------------------------
1 | package gollem_test
2 |
3 | import (
4 | "context"
5 | "encoding/json"
6 | "os"
7 | "testing"
8 |
9 | "github.com/m-mizutani/goerr/v2"
10 | "github.com/m-mizutani/gollem"
11 | "github.com/m-mizutani/gollem/llm/claude"
12 | "github.com/m-mizutani/gollem/llm/gemini"
13 | "github.com/m-mizutani/gollem/llm/openai"
14 | "github.com/m-mizutani/gt"
15 | )
16 |
17 | // Sample tool implementation for testing
18 | type randomNumberTool struct{}
19 |
20 | func (t *randomNumberTool) Spec() gollem.ToolSpec {
21 | return gollem.ToolSpec{
22 | Name: "random_number",
23 | Description: "A tool for generating random numbers within a specified range",
24 | Parameters: map[string]*gollem.Parameter{
25 | "min": {
26 | Type: gollem.TypeNumber,
27 | Description: "Minimum value of the random number",
28 | },
29 | "max": {
30 | Type: gollem.TypeNumber,
31 | Description: "Maximum value of the random number",
32 | },
33 | },
34 | Required: []string{"min", "max"},
35 | }
36 | }
37 |
38 | func (t *randomNumberTool) Run(ctx context.Context, args map[string]any) (map[string]any, error) {
39 | min, ok := args["min"].(float64)
40 | if !ok {
41 | return nil, goerr.New("min is required")
42 | }
43 |
44 | max, ok := args["max"].(float64)
45 | if !ok {
46 | return nil, goerr.New("max is required")
47 | }
48 |
49 | if min >= max {
50 | return nil, goerr.New("min must be less than max")
51 | }
52 |
53 | // Note: In real implementation, you would use a proper random number generator
54 | // This is just for testing purposes
55 | result := (min + max) / 2
56 |
57 | return map[string]any{"result": result}, nil
58 | }
59 |
60 | func testGenerateContent(t *testing.T, session gollem.Session) {
61 | ctx := t.Context()
62 |
63 | // Test case 1: Generate random number
64 | resp1, err := session.GenerateContent(ctx, gollem.Text("Please generate a random number between 1 and 10"))
65 | gt.NoError(t, err)
66 | gt.Array(t, resp1.FunctionCalls).Length(1).Required()
67 | gt.Value(t, resp1.FunctionCalls[0].Name).Equal("random_number")
68 |
69 | args := resp1.FunctionCalls[0].Arguments
70 | gt.Value(t, args["min"]).Equal(1.0)
71 | gt.Value(t, args["max"]).Equal(10.0)
72 |
73 | resp2, err := session.GenerateContent(ctx, gollem.FunctionResponse{
74 | ID: resp1.FunctionCalls[0].ID,
75 | Name: "random_number",
76 | Data: map[string]any{"result": 5.5},
77 | })
78 | gt.NoError(t, err).Required()
79 | gt.Array(t, resp2.Texts).Length(1).Required()
80 | }
81 |
82 | func testGenerateStream(t *testing.T, session gollem.Session) {
83 | ctx := t.Context()
84 |
85 | t.Run("generate random number", func(t *testing.T) {
86 | stream, err := session.GenerateStream(ctx, gollem.Text("Please generate a random number between 1 and 10"))
87 | gt.NoError(t, err).Required()
88 |
89 | var id string
90 | for resp := range stream {
91 | gt.NoError(t, resp.Error).Required()
92 |
93 | if len(resp.FunctionCalls) > 0 {
94 | for _, functionCall := range resp.FunctionCalls {
95 | if functionCall.ID != "" {
96 | id = functionCall.ID
97 | }
98 | }
99 | }
100 | }
101 |
102 | stream, err = session.GenerateStream(ctx, gollem.FunctionResponse{
103 | ID: id,
104 | Name: "random_number",
105 | Data: map[string]any{"result": 5.5},
106 | })
107 | gt.NoError(t, err).Required()
108 | for resp := range stream {
109 | gt.NoError(t, resp.Error).Required()
110 | }
111 | })
112 | }
113 |
114 | func newGeminiClient(t *testing.T) gollem.LLMClient {
115 | var testProjectID, testLocation string
116 | v, ok := os.LookupEnv("TEST_GCP_PROJECT_ID")
117 | if !ok {
118 | t.Skip("TEST_GCP_PROJECT_ID is not set")
119 | } else {
120 | testProjectID = v
121 | }
122 |
123 | v, ok = os.LookupEnv("TEST_GCP_LOCATION")
124 | if !ok {
125 | t.Skip("TEST_GCP_LOCATION is not set")
126 | } else {
127 | testLocation = v
128 | }
129 |
130 | ctx := t.Context()
131 | client, err := gemini.New(ctx, testProjectID, testLocation)
132 | gt.NoError(t, err)
133 | return client
134 | }
135 |
136 | func newOpenAIClient(t *testing.T) gollem.LLMClient {
137 | apiKey, ok := os.LookupEnv("TEST_OPENAI_API_KEY")
138 | if !ok {
139 | t.Skip("TEST_OPENAI_API_KEY is not set")
140 | }
141 |
142 | ctx := t.Context()
143 | client, err := openai.New(ctx, apiKey)
144 | gt.NoError(t, err)
145 | return client
146 | }
147 |
148 | func newClaudeClient(t *testing.T) gollem.LLMClient {
149 | apiKey, ok := os.LookupEnv("TEST_CLAUDE_API_KEY")
150 | if !ok {
151 | t.Skip("TEST_CLAUDE_API_KEY is not set")
152 | }
153 |
154 | client, err := claude.New(context.Background(), apiKey)
155 | gt.NoError(t, err)
156 | return client
157 | }
158 |
159 | func TestGemini(t *testing.T) {
160 | client := newGeminiClient(t)
161 |
162 | // Setup tools
163 | tools := []gollem.Tool{&randomNumberTool{}}
164 | session, err := client.NewSession(t.Context(), gollem.WithSessionTools(tools...))
165 | gt.NoError(t, err)
166 |
167 | t.Run("generate content", func(t *testing.T) {
168 | testGenerateContent(t, session)
169 | })
170 | t.Run("generate stream", func(t *testing.T) {
171 | testGenerateStream(t, session)
172 | })
173 | }
174 |
175 | func TestOpenAI(t *testing.T) {
176 | client := newOpenAIClient(t)
177 |
178 | // Setup tools
179 | tools := []gollem.Tool{&randomNumberTool{}}
180 | session, err := client.NewSession(t.Context(), gollem.WithSessionTools(tools...))
181 | gt.NoError(t, err)
182 |
183 | t.Run("generate content", func(t *testing.T) {
184 | testGenerateContent(t, session)
185 | })
186 | t.Run("generate stream", func(t *testing.T) {
187 | testGenerateStream(t, session)
188 | })
189 | }
190 |
191 | func TestClaude(t *testing.T) {
192 | client := newClaudeClient(t)
193 |
194 | session, err := client.NewSession(context.Background(), gollem.WithSessionTools(&randomNumberTool{}))
195 | gt.NoError(t, err)
196 |
197 | t.Run("generate content", func(t *testing.T) {
198 | testGenerateContent(t, session)
199 | })
200 | t.Run("generate stream", func(t *testing.T) {
201 | testGenerateStream(t, session)
202 | })
203 | }
204 |
205 | type weatherTool struct {
206 | name string
207 | }
208 |
209 | func (x *weatherTool) Spec() gollem.ToolSpec {
210 | return gollem.ToolSpec{
211 | Name: x.name,
212 | Description: "get weather information of a region",
213 | Parameters: map[string]*gollem.Parameter{
214 | "region": {
215 | Type: gollem.TypeString,
216 | Description: "Region name",
217 | },
218 | },
219 | }
220 | }
221 |
222 | func (t *weatherTool) Run(ctx context.Context, input map[string]any) (map[string]any, error) {
223 | return map[string]any{
224 | "weather": "sunny",
225 | }, nil
226 | }
227 |
228 | func TestCallToolNameConvention(t *testing.T) {
229 | if _, ok := os.LookupEnv("TEST_FLAG_TOOL_NAME_CONVENTION"); !ok {
230 | t.Skip("TEST_FLAG_TOOL_NAME_CONVENTION is not set")
231 | }
232 |
233 | testFunc := func(t *testing.T, client gollem.LLMClient) {
234 | testCases := map[string]struct {
235 | name string
236 | isError bool
237 | }{
238 | "low case is allowed": {
239 | name: "test",
240 | isError: false,
241 | },
242 | "upper case is allowed": {
243 | name: "TEST",
244 | isError: false,
245 | },
246 | "underscore is allowed": {
247 | name: "test_tool",
248 | isError: false,
249 | },
250 | "number is allowed": {
251 | name: "test123",
252 | isError: false,
253 | },
254 | "hyphen is allowed": {
255 | name: "test-tool",
256 | isError: false,
257 | },
258 | /*
259 | SKIP: OpenAI, Claude does not allow dot in tool name, but Gemini allows it.
260 | "dot is not allowed": {
261 | name: "test.tool",
262 | isError: true,
263 | },
264 | */
265 | "comma is not allowed": {
266 | name: "test,tool",
267 | isError: true,
268 | },
269 | "colon is not allowed": {
270 | name: "test:tool",
271 | isError: true,
272 | },
273 | "space is not allowed": {
274 | name: "test tool",
275 | isError: true,
276 | },
277 | }
278 |
279 | for name, tc := range testCases {
280 | t.Run(name, func(t *testing.T) {
281 | ctx := t.Context()
282 | tool := &weatherTool{name: tc.name}
283 |
284 | session, err := client.NewSession(ctx, gollem.WithSessionTools(tool))
285 | gt.NoError(t, err)
286 |
287 | resp, err := session.GenerateContent(ctx, gollem.Text("What is the weather in Tokyo?"))
288 | if tc.isError {
289 | gt.Error(t, err)
290 | return
291 | }
292 | gt.NoError(t, err).Required()
293 | if len(resp.FunctionCalls) > 0 {
294 | gt.A(t, resp.FunctionCalls).Length(1).At(0, func(t testing.TB, v *gollem.FunctionCall) {
295 | gt.Equal(t, v.Name, tc.name)
296 | })
297 | }
298 | })
299 | }
300 | }
301 |
302 | t.Run("OpenAI", func(t *testing.T) {
303 | ctx := t.Context()
304 | apiKey, ok := os.LookupEnv("TEST_OPENAI_API_KEY")
305 | if !ok {
306 | t.Skip("TEST_OPENAI_API_KEY is not set")
307 | }
308 |
309 | client, err := openai.New(ctx, apiKey)
310 | gt.NoError(t, err)
311 | testFunc(t, client)
312 | })
313 |
314 | t.Run("gemini", func(t *testing.T) {
315 | ctx := t.Context()
316 | projectID, ok := os.LookupEnv("TEST_GCP_PROJECT_ID")
317 | if !ok {
318 | t.Skip("TEST_GCP_PROJECT_ID is not set")
319 | }
320 |
321 | location, ok := os.LookupEnv("TEST_GCP_LOCATION")
322 | if !ok {
323 | t.Skip("TEST_GCP_LOCATION is not set")
324 | }
325 |
326 | client, err := gemini.New(ctx, projectID, location)
327 | gt.NoError(t, err)
328 | testFunc(t, client)
329 | })
330 |
331 | t.Run("claude", func(t *testing.T) {
332 | ctx := t.Context()
333 | apiKey, ok := os.LookupEnv("TEST_CLAUDE_API_KEY")
334 | if !ok {
335 | t.Skip("TEST_CLAUDE_API_KEY is not set")
336 | }
337 |
338 | client, err := claude.New(ctx, apiKey)
339 | gt.NoError(t, err)
340 | testFunc(t, client)
341 | })
342 | }
343 |
344 | func TestSessionHistory(t *testing.T) {
345 | testFn := func(t *testing.T, client gollem.LLMClient) {
346 | ctx := t.Context()
347 | session, err := client.NewSession(ctx, gollem.WithSessionTools(&weatherTool{name: "weather"}))
348 | gt.NoError(t, err).Required()
349 |
350 | resp1, err := session.GenerateContent(ctx, gollem.Text("What is the weather in Tokyo?"))
351 | gt.NoError(t, err).Required()
352 | gt.A(t, resp1.FunctionCalls).Length(1).At(0, func(t testing.TB, v *gollem.FunctionCall) {
353 | gt.Equal(t, v.Name, "weather")
354 | })
355 |
356 | resp2, err := session.GenerateContent(ctx, gollem.FunctionResponse{
357 | ID: resp1.FunctionCalls[0].ID,
358 | Name: "weather",
359 | Data: map[string]any{"weather": "sunny"},
360 | })
361 | gt.NoError(t, err).Required()
362 | gt.A(t, resp2.Texts).Length(1).At(0, func(t testing.TB, v string) {
363 | gt.S(t, v).Contains("sunny")
364 | })
365 |
366 | history := session.History()
367 | rawData, err := json.Marshal(history)
368 | gt.NoError(t, err).Required()
369 |
370 | var restored gollem.History
371 | gt.NoError(t, json.Unmarshal(rawData, &restored))
372 |
373 | newSession, err := client.NewSession(ctx, gollem.WithSessionHistory(&restored))
374 | gt.NoError(t, err)
375 |
376 | resp3, err := newSession.GenerateContent(ctx, gollem.Text("Do you remember the weather in Tokyo?"))
377 | gt.NoError(t, err).Required()
378 |
379 | gt.A(t, resp3.Texts).Longer(0).At(0, func(t testing.TB, v string) {
380 | gt.S(t, v).Contains("sunny")
381 | })
382 | }
383 |
384 | t.Run("OpenAI", func(t *testing.T) {
385 | client := newOpenAIClient(t)
386 | testFn(t, client)
387 | })
388 |
389 | t.Run("gemini", func(t *testing.T) {
390 | client := newGeminiClient(t)
391 | testFn(t, client)
392 | })
393 |
394 | t.Run("claude", func(t *testing.T) {
395 | client := newClaudeClient(t)
396 | testFn(t, client)
397 | })
398 | }
399 |
400 | func TestExitTool(t *testing.T) {
401 | testFn := func(t *testing.T, newClient func(t *testing.T) gollem.LLMClient) {
402 | client := newClient(t)
403 |
404 | exitTool := &gollem.DefaultExitTool{}
405 | loopCount := 0
406 | exitToolCalled := false
407 |
408 | s := gollem.New(client,
409 | gollem.WithExitTool(exitTool),
410 | gollem.WithTools(&randomNumberTool{}),
411 | gollem.WithSystemPrompt("You are an assistant that can use tools. When asked to complete a task and end the session, you must use the finalize_task tool to properly end the session."),
412 | gollem.WithLoopHook(func(ctx context.Context, loop int, input []gollem.Input) error {
413 | loopCount++
414 | t.Logf("Loop called: %d", loop)
415 | return nil
416 | }),
417 | gollem.WithMessageHook(func(ctx context.Context, msg string) error {
418 | t.Logf("[Message received]: %s", msg)
419 | return nil
420 | }),
421 | gollem.WithToolRequestHook(func(ctx context.Context, tool gollem.FunctionCall) error {
422 | t.Logf("[Tool request received]: %s (%v)", tool.Name, tool.Arguments)
423 | return nil
424 | }),
425 | gollem.WithLoopLimit(10),
426 | )
427 |
428 | ctx := t.Context()
429 | _, err := s.Prompt(ctx, "Get a random number between 1 and 10")
430 | gt.NoError(t, err)
431 |
432 | t.Logf("Test completed: exitToolCalled=%v, isCompleted=%v, loopCount=%d", exitToolCalled, exitTool.IsCompleted(), loopCount)
433 |
434 | // Verify that the exit tool was called
435 | gt.True(t, exitTool.IsCompleted())
436 |
437 | // Verify that loops occurred (should be more than 0 but less than loop limit)
438 | gt.N(t, loopCount).Greater(0).Less(10)
439 | }
440 |
441 | t.Run("OpenAI", func(t *testing.T) {
442 | testFn(t, newOpenAIClient)
443 | })
444 |
445 | t.Run("Gemini", func(t *testing.T) {
446 | testFn(t, newGeminiClient)
447 | })
448 |
449 | t.Run("Claude", func(t *testing.T) {
450 | testFn(t, newClaudeClient)
451 | })
452 | }
453 |
--------------------------------------------------------------------------------
/mcp/export_test.go:
--------------------------------------------------------------------------------
1 | package mcp
2 |
3 | import (
4 | "context"
5 |
6 | "github.com/mark3labs/mcp-go/mcp"
7 | )
8 |
9 | func NewLocalMCPClient(path string) *Client {
10 | client := &Client{
11 | path: path,
12 | }
13 | return client
14 | }
15 |
16 | func (x *Client) Start(ctx context.Context) error {
17 | return x.init(ctx)
18 | }
19 |
20 | func (x *Client) ListTools(ctx context.Context) ([]mcp.Tool, error) {
21 | return x.listTools(ctx)
22 | }
23 |
24 | func (x *Client) CallTool(ctx context.Context, name string, args map[string]any) (*mcp.CallToolResult, error) {
25 | return x.callTool(ctx, name, args)
26 | }
27 |
28 | var (
29 | InputSchemaToParameter = inputSchemaToParameter
30 | MCPContentToMap = mcpContentToMap
31 | JSONSchemaToParameter = jsonSchemaToParameter
32 | )
33 |
--------------------------------------------------------------------------------
/mcp/mcp.go:
--------------------------------------------------------------------------------
1 | package mcp
2 |
3 | import (
4 | "bytes"
5 | "context"
6 | "encoding/json"
7 | "fmt"
8 | "sync"
9 |
10 | "github.com/m-mizutani/goerr/v2"
11 | "github.com/m-mizutani/gollem"
12 | "github.com/mark3labs/mcp-go/client"
13 | "github.com/mark3labs/mcp-go/client/transport"
14 | "github.com/mark3labs/mcp-go/mcp"
15 | "github.com/santhosh-tekuri/jsonschema/v6"
16 | )
17 |
18 | type Client struct {
19 | // For local MCP server
20 | path string
21 | args []string
22 | envVars []string
23 |
24 | // For remote MCP server
25 | baseURL string
26 | headers map[string]string
27 |
28 | // Common client
29 | client *client.Client
30 |
31 | initResult *mcp.InitializeResult
32 | initMutex sync.Mutex
33 | }
34 |
35 | // Specs implements gollem.ToolSet interface
36 | func (c *Client) Specs(ctx context.Context) ([]gollem.ToolSpec, error) {
37 | logger := gollem.LoggerFromContext(ctx)
38 |
39 | tools, err := c.listTools(ctx)
40 | if err != nil {
41 | return nil, goerr.Wrap(err, "failed to list tools")
42 | }
43 |
44 | specs := make([]gollem.ToolSpec, len(tools))
45 | toolNames := make([]string, len(tools))
46 |
47 | for i, tool := range tools {
48 | toolNames[i] = tool.Name
49 |
50 | param, err := inputSchemaToParameter(tool.InputSchema)
51 | if err != nil {
52 | return nil, goerr.Wrap(err,
53 | "failed to convert input schema to parameter",
54 | goerr.V("tool.name", tool.Name),
55 | goerr.V("tool.inputSchema", tool.InputSchema),
56 | )
57 | }
58 |
59 | specs[i] = gollem.ToolSpec{
60 | Name: tool.Name,
61 | Description: tool.Description,
62 | Parameters: param.Properties,
63 | Required: param.Required,
64 | }
65 | }
66 |
67 | logger.Debug("found MCP tools", "names", toolNames)
68 |
69 | return specs, nil
70 | }
71 |
72 | // Run implements gollem.ToolSet interface
73 | func (c *Client) Run(ctx context.Context, name string, args map[string]any) (map[string]any, error) {
74 | logger := gollem.LoggerFromContext(ctx)
75 |
76 | logger.Debug("call MCP tool", "name", name, "args", args)
77 |
78 | resp, err := c.callTool(ctx, name, args)
79 | if err != nil {
80 | return nil, goerr.Wrap(err, "failed to call tool")
81 | }
82 |
83 | return mcpContentToMap(resp.Content), nil
84 | }
85 |
86 | // StdioOption is the option for the MCP client for local MCP executable server via stdio.
87 | type StdioOption func(*Client)
88 |
89 | // WithEnvVars sets the environment variables for the MCP client. It appends the environment variables to the existing ones.
90 | func WithEnvVars(envVars []string) StdioOption {
91 | return func(m *Client) {
92 | m.envVars = append(m.envVars, envVars...)
93 | }
94 | }
95 |
96 | // NewStdio creates a new MCP client for local MCP executable server via stdio.
97 | func NewStdio(ctx context.Context, path string, args []string, options ...StdioOption) (*Client, error) {
98 | client := &Client{
99 | path: path,
100 | args: args,
101 | }
102 | for _, option := range options {
103 | option(client)
104 | }
105 |
106 | if err := client.init(ctx); err != nil {
107 | return nil, goerr.Wrap(err, "failed to initialize MCP client")
108 | }
109 |
110 | return client, nil
111 | }
112 |
113 | // SSEOption is the option for the MCP client for remote MCP server via HTTP SSE.
114 | type SSEOption func(*Client)
115 |
116 | // WithHeaders sets the headers for the MCP client. It replaces the existing headers setting.
117 | func WithHeaders(headers map[string]string) SSEOption {
118 | return func(m *Client) {
119 | m.headers = headers
120 | }
121 | }
122 |
123 | // NewSSE creates a new MCP client for remote MCP server via HTTP SSE.
124 | func NewSSE(ctx context.Context, baseURL string, options ...SSEOption) (*Client, error) {
125 | client := &Client{
126 | baseURL: baseURL,
127 | }
128 | for _, option := range options {
129 | option(client)
130 | }
131 |
132 | if err := client.init(ctx); err != nil {
133 | return nil, goerr.Wrap(err, "failed to initialize MCP client")
134 | }
135 |
136 | return client, nil
137 | }
138 |
139 | func (c *Client) init(ctx context.Context) error {
140 | c.initMutex.Lock()
141 | defer c.initMutex.Unlock()
142 |
143 | logger := gollem.LoggerFromContext(ctx)
144 |
145 | if c.initResult != nil {
146 | return nil
147 | }
148 |
149 | var tp transport.Interface
150 | if c.path != "" {
151 | tp = transport.NewStdio(c.path, c.envVars, c.args...)
152 | }
153 |
154 | if c.baseURL != "" {
155 | sse, err := transport.NewSSE(c.baseURL, transport.WithHeaders(c.headers))
156 | if err != nil {
157 | return goerr.Wrap(err, "failed to create SSE transport")
158 | }
159 | tp = sse
160 | }
161 |
162 | if tp == nil {
163 | return goerr.New("no transport")
164 | }
165 |
166 | c.client = client.NewClient(tp)
167 |
168 | logger.Debug("starting MCP client", "path", c.path, "url", c.baseURL)
169 | if err := c.client.Start(ctx); err != nil {
170 | return goerr.Wrap(err, "failed to start MCP client")
171 | }
172 |
173 | var initRequest mcp.InitializeRequest
174 | initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
175 | initRequest.Params.ClientInfo = mcp.Implementation{
176 | Name: "gollem",
177 | Version: "0.0.1",
178 | }
179 |
180 | logger.Debug("initializing MCP client")
181 | if resp, err := c.client.Initialize(ctx, initRequest); err != nil {
182 | return goerr.Wrap(err, "failed to initialize MCP client")
183 | } else {
184 | c.initResult = resp
185 | }
186 |
187 | return nil
188 | }
189 |
190 | func (c *Client) listTools(ctx context.Context) ([]mcp.Tool, error) {
191 | // ListTools is thread safe
192 | resp, err := c.client.ListTools(ctx, mcp.ListToolsRequest{})
193 | if err != nil {
194 | return nil, goerr.Wrap(err, "failed to list tools")
195 | }
196 |
197 | return resp.Tools, nil
198 | }
199 |
200 | func (c *Client) callTool(ctx context.Context, name string, args map[string]any) (*mcp.CallToolResult, error) {
201 | req := mcp.CallToolRequest{}
202 | req.Params.Name = name
203 | req.Params.Arguments = args
204 | resp, err := c.client.CallTool(ctx, req)
205 | if err != nil {
206 | return nil, goerr.Wrap(err, "failed to call tool")
207 | }
208 |
209 | return resp, nil
210 | }
211 |
212 | func (c *Client) Close() error {
213 | if err := c.client.Close(); err != nil {
214 | return goerr.Wrap(err, "failed to close MCP client")
215 | }
216 | return nil
217 | }
218 |
219 | func inputSchemaToParameter(inputSchema mcp.ToolInputSchema) (*gollem.Parameter, error) {
220 | parameters := map[string]*gollem.Parameter{}
221 | jsonSchema, err := json.Marshal(inputSchema)
222 | if err != nil {
223 | return nil, goerr.Wrap(err, "failed to marshal input schema")
224 | }
225 |
226 | rawSchema, err := jsonschema.UnmarshalJSON(bytes.NewReader(jsonSchema))
227 | if err != nil {
228 | return nil, goerr.Wrap(err, "failed to compile input schema")
229 | }
230 |
231 | c := jsonschema.NewCompiler()
232 | if err := c.AddResource("schema.json", rawSchema); err != nil {
233 | return nil, goerr.Wrap(err, "failed to add resource to compiler")
234 | }
235 | schema, err := c.Compile("schema.json")
236 | if err != nil {
237 | return nil, goerr.Wrap(err, "failed to compile input schema")
238 | }
239 |
240 | schemaType := schema.Types.ToStrings()
241 | if len(schemaType) != 1 || schemaType[0] != "object" {
242 | return nil, goerr.Wrap(gollem.ErrInvalidTool, "invalid input schema", goerr.V("schema", schema))
243 | }
244 |
245 | for name, property := range schema.Properties {
246 | parameters[name] = jsonSchemaToParameter(property)
247 | }
248 |
249 | return &gollem.Parameter{
250 | Type: gollem.ParameterType(schema.Types.ToStrings()[0]),
251 | Title: schema.Title,
252 | Description: schema.Description,
253 | Required: schema.Required,
254 | Properties: parameters,
255 | }, nil
256 | }
257 |
258 | func jsonSchemaToParameter(schema *jsonschema.Schema) *gollem.Parameter {
259 | var enum []string
260 | if schema.Enum != nil {
261 | for _, v := range schema.Enum.Values {
262 | enum = append(enum, fmt.Sprintf("%v", v))
263 | }
264 | }
265 |
266 | properties := map[string]*gollem.Parameter{}
267 | for name, property := range schema.Properties {
268 | properties[name] = jsonSchemaToParameter(property)
269 | }
270 |
271 | var items *gollem.Parameter
272 | if schema.Items != nil {
273 | switch v := schema.Items.(type) {
274 | case *jsonschema.Schema:
275 | items = jsonSchemaToParameter(v)
276 | }
277 | }
278 |
279 | var minimum, maximum *float64
280 | if schema.Minimum != nil {
281 | min, _ := (*schema.Minimum).Float64()
282 | minimum = &min
283 | }
284 | if schema.Maximum != nil {
285 | max, _ := (*schema.Maximum).Float64()
286 | maximum = &max
287 | }
288 |
289 | var minLength, maxLength *int
290 | if schema.MinLength != nil {
291 | min := int(*schema.MinLength)
292 | minLength = &min
293 | }
294 | if schema.MaxLength != nil {
295 | max := int(*schema.MaxLength)
296 | maxLength = &max
297 | }
298 |
299 | var minItems, maxItems *int
300 | if schema.MinItems != nil {
301 | min := int(*schema.MinItems)
302 | minItems = &min
303 | }
304 | if schema.MaxItems != nil {
305 | max := int(*schema.MaxItems)
306 | maxItems = &max
307 | }
308 |
309 | var pattern string
310 | if schema.Pattern != nil {
311 | pattern = schema.Pattern.String()
312 | }
313 |
314 | return &gollem.Parameter{
315 | Type: gollem.ParameterType(schema.Types.ToStrings()[0]),
316 | Title: schema.Title,
317 | Description: schema.Description,
318 | Required: schema.Required,
319 | Enum: enum,
320 | Properties: properties,
321 | Items: items,
322 | Minimum: minimum,
323 | Maximum: maximum,
324 | MinLength: minLength,
325 | MaxLength: maxLength,
326 | Pattern: pattern,
327 | MinItems: minItems,
328 | MaxItems: maxItems,
329 | Default: schema.Default,
330 | }
331 | }
332 |
333 | func mcpContentToMap(contents []mcp.Content) map[string]any {
334 | if len(contents) == 0 {
335 | return nil
336 | }
337 |
338 | if len(contents) == 1 {
339 | if content, ok := contents[0].(mcp.TextContent); ok {
340 | var v any
341 | if err := json.Unmarshal([]byte(content.Text), &v); err == nil {
342 | if mapData, ok := v.(map[string]any); ok {
343 | return mapData
344 | }
345 | }
346 | return map[string]any{
347 | "result": content.Text,
348 | }
349 | }
350 | return nil
351 | }
352 |
353 | result := map[string]any{}
354 | for i, c := range contents {
355 | if content, ok := c.(mcp.TextContent); ok {
356 | result[fmt.Sprintf("content_%d", i+1)] = content.Text
357 | }
358 | }
359 | return result
360 | }
361 |
--------------------------------------------------------------------------------
/mcp/mcp_test.go:
--------------------------------------------------------------------------------
1 | package mcp_test
2 |
3 | import (
4 | "context"
5 | "os"
6 | "testing"
7 |
8 | "github.com/m-mizutani/gollem/mcp"
9 | "github.com/m-mizutani/gt"
10 | mcpgo "github.com/mark3labs/mcp-go/mcp"
11 | )
12 |
13 | func TestMCPLocalDryRun(t *testing.T) {
14 | mcpExecPath, ok := os.LookupEnv("TEST_MCP_EXEC_PATH")
15 | if !ok {
16 | t.Skip("TEST_MCP_EXEC_PATH is not set")
17 | }
18 |
19 | client := mcp.NewLocalMCPClient(mcpExecPath)
20 |
21 | err := client.Start(context.Background())
22 | gt.NoError(t, err)
23 |
24 | tools, err := client.ListTools(context.Background())
25 | gt.NoError(t, err)
26 | gt.A(t, tools).Longer(0)
27 |
28 | parameter, err := mcp.InputSchemaToParameter(tools[0].InputSchema)
29 | gt.NoError(t, err)
30 | t.Log("parameter:", parameter)
31 |
32 | tool := tools[0]
33 |
34 | t.Log("tool:", tool)
35 |
36 | callTool, err := client.CallTool(context.Background(), tool.Name, map[string]any{
37 | "length": 10,
38 | })
39 | gt.NoError(t, err)
40 |
41 | t.Log("callTool:", callTool)
42 | }
43 |
44 | func TestMCPContentToMap(t *testing.T) {
45 | t.Run("when content is empty", func(t *testing.T) {
46 | result := mcp.MCPContentToMap([]mcpgo.Content{})
47 | gt.Nil(t, result)
48 | })
49 |
50 | t.Run("when text content is JSON", func(t *testing.T) {
51 | content := mcpgo.TextContent{Text: `{"key": "value"}`}
52 | result := mcp.MCPContentToMap([]mcpgo.Content{content})
53 | gt.Equal(t, map[string]any{"key": "value"}, result)
54 | })
55 |
56 | t.Run("when text content is not JSON", func(t *testing.T) {
57 | content := mcpgo.TextContent{Text: "plain text"}
58 | result := mcp.MCPContentToMap([]mcpgo.Content{content})
59 | gt.Equal(t, map[string]any{"result": "plain text"}, result)
60 | })
61 |
62 | t.Run("when multiple contents exist", func(t *testing.T) {
63 | contents := []mcpgo.Content{
64 | mcpgo.TextContent{Text: "first"},
65 | mcpgo.TextContent{Text: "second"},
66 | }
67 | result := mcp.MCPContentToMap(contents)
68 | gt.Equal(t, map[string]any{
69 | "content_1": "first",
70 | "content_2": "second",
71 | }, result)
72 | })
73 | }
74 |
--------------------------------------------------------------------------------
/mock/mock_gen.go:
--------------------------------------------------------------------------------
1 | // Code generated by moq; DO NOT EDIT.
2 | // github.com/matryer/moq
3 |
4 | package mock
5 |
6 | import (
7 | "context"
8 | "github.com/m-mizutani/gollem"
9 | "sync"
10 | )
11 |
// NOTE(review): moq output (see the file header) - regenerate with moq rather
// than editing by hand. If a *Func field is nil, the matching method still
// records the call but returns zero values (nil results and a nil error)
// instead of panicking. The *Calls accessors return the recorded slice itself
// (read under the per-method RWMutex), not a copy.
12 | // LLMClientMock is a mock implementation of gollem.LLMClient.
13 | //
14 | // func TestSomethingThatUsesLLMClient(t *testing.T) {
15 | //
16 | // // make and configure a mocked gollem.LLMClient
17 | // mockedLLMClient := &LLMClientMock{
18 | // GenerateEmbeddingFunc: func(ctx context.Context, dimension int, input []string) ([][]float64, error) {
19 | // panic("mock out the GenerateEmbedding method")
20 | // },
21 | // NewSessionFunc: func(ctx context.Context, options ...gollem.SessionOption) (gollem.Session, error) {
22 | // panic("mock out the NewSession method")
23 | // },
24 | // }
25 | //
26 | // // use mockedLLMClient in code that requires gollem.LLMClient
27 | // // and then make assertions.
28 | //
29 | // }
30 | type LLMClientMock struct {
31 | // GenerateEmbeddingFunc mocks the GenerateEmbedding method.
32 | GenerateEmbeddingFunc func(ctx context.Context, dimension int, input []string) ([][]float64, error)
33 |
34 | // NewSessionFunc mocks the NewSession method.
35 | NewSessionFunc func(ctx context.Context, options ...gollem.SessionOption) (gollem.Session, error)
36 |
37 | // calls tracks calls to the methods.
38 | calls struct {
39 | // GenerateEmbedding holds details about calls to the GenerateEmbedding method.
40 | GenerateEmbedding []struct {
41 | // Ctx is the ctx argument value.
42 | Ctx context.Context
43 | // Dimension is the dimension argument value.
44 | Dimension int
45 | // Input is the input argument value.
46 | Input []string
47 | }
48 | // NewSession holds details about calls to the NewSession method.
49 | NewSession []struct {
50 | // Ctx is the ctx argument value.
51 | Ctx context.Context
52 | // Options is the options argument value.
53 | Options []gollem.SessionOption
54 | }
55 | }
56 | lockGenerateEmbedding sync.RWMutex
57 | lockNewSession sync.RWMutex
58 | }
59 |
60 | // GenerateEmbedding calls GenerateEmbeddingFunc.
61 | func (mock *LLMClientMock) GenerateEmbedding(ctx context.Context, dimension int, input []string) ([][]float64, error) {
62 | callInfo := struct {
63 | Ctx context.Context
64 | Dimension int
65 | Input []string
66 | }{
67 | Ctx: ctx,
68 | Dimension: dimension,
69 | Input: input,
70 | }
71 | mock.lockGenerateEmbedding.Lock()
72 | mock.calls.GenerateEmbedding = append(mock.calls.GenerateEmbedding, callInfo)
73 | mock.lockGenerateEmbedding.Unlock()
74 | if mock.GenerateEmbeddingFunc == nil {
75 | var (
76 | float64ssOut [][]float64
77 | errOut error
78 | )
79 | return float64ssOut, errOut
80 | }
81 | return mock.GenerateEmbeddingFunc(ctx, dimension, input)
82 | }
83 |
84 | // GenerateEmbeddingCalls gets all the calls that were made to GenerateEmbedding.
85 | // Check the length with:
86 | //
87 | // len(mockedLLMClient.GenerateEmbeddingCalls())
88 | func (mock *LLMClientMock) GenerateEmbeddingCalls() []struct {
89 | Ctx context.Context
90 | Dimension int
91 | Input []string
92 | } {
93 | var calls []struct {
94 | Ctx context.Context
95 | Dimension int
96 | Input []string
97 | }
98 | mock.lockGenerateEmbedding.RLock()
99 | calls = mock.calls.GenerateEmbedding
100 | mock.lockGenerateEmbedding.RUnlock()
101 | return calls
102 | }
103 |
104 | // NewSession calls NewSessionFunc.
105 | func (mock *LLMClientMock) NewSession(ctx context.Context, options ...gollem.SessionOption) (gollem.Session, error) {
106 | callInfo := struct {
107 | Ctx context.Context
108 | Options []gollem.SessionOption
109 | }{
110 | Ctx: ctx,
111 | Options: options,
112 | }
113 | mock.lockNewSession.Lock()
114 | mock.calls.NewSession = append(mock.calls.NewSession, callInfo)
115 | mock.lockNewSession.Unlock()
116 | if mock.NewSessionFunc == nil {
117 | var (
118 | sessionOut gollem.Session
119 | errOut error
120 | )
121 | return sessionOut, errOut
122 | }
123 | return mock.NewSessionFunc(ctx, options...)
124 | }
125 |
126 | // NewSessionCalls gets all the calls that were made to NewSession.
127 | // Check the length with:
128 | //
129 | // len(mockedLLMClient.NewSessionCalls())
130 | func (mock *LLMClientMock) NewSessionCalls() []struct {
131 | Ctx context.Context
132 | Options []gollem.SessionOption
133 | } {
134 | var calls []struct {
135 | Ctx context.Context
136 | Options []gollem.SessionOption
137 | }
138 | mock.lockNewSession.RLock()
139 | calls = mock.calls.NewSession
140 | mock.lockNewSession.RUnlock()
141 | return calls
142 | }
143 |
// NOTE(review): moq output (see the file header) - regenerate with moq rather
// than editing by hand. If a *Func field is nil, the matching method still
// records the call but returns zero values (nil response/channel/history and
// a nil error) instead of panicking. The *Calls accessors return the recorded
// slice itself (read under the per-method RWMutex), not a copy.
144 | // SessionMock is a mock implementation of gollem.Session.
145 | //
146 | // func TestSomethingThatUsesSession(t *testing.T) {
147 | //
148 | // // make and configure a mocked gollem.Session
149 | // mockedSession := &SessionMock{
150 | // GenerateContentFunc: func(ctx context.Context, input ...gollem.Input) (*gollem.Response, error) {
151 | // panic("mock out the GenerateContent method")
152 | // },
153 | // GenerateStreamFunc: func(ctx context.Context, input ...gollem.Input) (<-chan *gollem.Response, error) {
154 | // panic("mock out the GenerateStream method")
155 | // },
156 | // HistoryFunc: func() *gollem.History {
157 | // panic("mock out the History method")
158 | // },
159 | // }
160 | //
161 | // // use mockedSession in code that requires gollem.Session
162 | // // and then make assertions.
163 | //
164 | // }
165 | type SessionMock struct {
166 | // GenerateContentFunc mocks the GenerateContent method.
167 | GenerateContentFunc func(ctx context.Context, input ...gollem.Input) (*gollem.Response, error)
168 |
169 | // GenerateStreamFunc mocks the GenerateStream method.
170 | GenerateStreamFunc func(ctx context.Context, input ...gollem.Input) (<-chan *gollem.Response, error)
171 |
172 | // HistoryFunc mocks the History method.
173 | HistoryFunc func() *gollem.History
174 |
175 | // calls tracks calls to the methods.
176 | calls struct {
177 | // GenerateContent holds details about calls to the GenerateContent method.
178 | GenerateContent []struct {
179 | // Ctx is the ctx argument value.
180 | Ctx context.Context
181 | // Input is the input argument value.
182 | Input []gollem.Input
183 | }
184 | // GenerateStream holds details about calls to the GenerateStream method.
185 | GenerateStream []struct {
186 | // Ctx is the ctx argument value.
187 | Ctx context.Context
188 | // Input is the input argument value.
189 | Input []gollem.Input
190 | }
191 | // History holds details about calls to the History method.
192 | History []struct {
193 | }
194 | }
195 | lockGenerateContent sync.RWMutex
196 | lockGenerateStream sync.RWMutex
197 | lockHistory sync.RWMutex
198 | }
199 |
200 | // GenerateContent calls GenerateContentFunc.
201 | func (mock *SessionMock) GenerateContent(ctx context.Context, input ...gollem.Input) (*gollem.Response, error) {
202 | callInfo := struct {
203 | Ctx context.Context
204 | Input []gollem.Input
205 | }{
206 | Ctx: ctx,
207 | Input: input,
208 | }
209 | mock.lockGenerateContent.Lock()
210 | mock.calls.GenerateContent = append(mock.calls.GenerateContent, callInfo)
211 | mock.lockGenerateContent.Unlock()
212 | if mock.GenerateContentFunc == nil {
213 | var (
214 | responseOut *gollem.Response
215 | errOut error
216 | )
217 | return responseOut, errOut
218 | }
219 | return mock.GenerateContentFunc(ctx, input...)
220 | }
221 |
222 | // GenerateContentCalls gets all the calls that were made to GenerateContent.
223 | // Check the length with:
224 | //
225 | // len(mockedSession.GenerateContentCalls())
226 | func (mock *SessionMock) GenerateContentCalls() []struct {
227 | Ctx context.Context
228 | Input []gollem.Input
229 | } {
230 | var calls []struct {
231 | Ctx context.Context
232 | Input []gollem.Input
233 | }
234 | mock.lockGenerateContent.RLock()
235 | calls = mock.calls.GenerateContent
236 | mock.lockGenerateContent.RUnlock()
237 | return calls
238 | }
239 |
240 | // GenerateStream calls GenerateStreamFunc.
241 | func (mock *SessionMock) GenerateStream(ctx context.Context, input ...gollem.Input) (<-chan *gollem.Response, error) {
242 | callInfo := struct {
243 | Ctx context.Context
244 | Input []gollem.Input
245 | }{
246 | Ctx: ctx,
247 | Input: input,
248 | }
249 | mock.lockGenerateStream.Lock()
250 | mock.calls.GenerateStream = append(mock.calls.GenerateStream, callInfo)
251 | mock.lockGenerateStream.Unlock()
252 | if mock.GenerateStreamFunc == nil {
253 | var (
254 | responseChOut <-chan *gollem.Response
255 | errOut error
256 | )
257 | return responseChOut, errOut
258 | }
259 | return mock.GenerateStreamFunc(ctx, input...)
260 | }
261 |
262 | // GenerateStreamCalls gets all the calls that were made to GenerateStream.
263 | // Check the length with:
264 | //
265 | // len(mockedSession.GenerateStreamCalls())
266 | func (mock *SessionMock) GenerateStreamCalls() []struct {
267 | Ctx context.Context
268 | Input []gollem.Input
269 | } {
270 | var calls []struct {
271 | Ctx context.Context
272 | Input []gollem.Input
273 | }
274 | mock.lockGenerateStream.RLock()
275 | calls = mock.calls.GenerateStream
276 | mock.lockGenerateStream.RUnlock()
277 | return calls
278 | }
279 |
280 | // History calls HistoryFunc.
281 | func (mock *SessionMock) History() *gollem.History {
282 | callInfo := struct {
283 | }{}
284 | mock.lockHistory.Lock()
285 | mock.calls.History = append(mock.calls.History, callInfo)
286 | mock.lockHistory.Unlock()
287 | if mock.HistoryFunc == nil {
288 | var (
289 | historyOut *gollem.History
290 | )
291 | return historyOut
292 | }
293 | return mock.HistoryFunc()
294 | }
295 |
296 | // HistoryCalls gets all the calls that were made to History.
297 | // Check the length with:
298 | //
299 | // len(mockedSession.HistoryCalls())
300 | func (mock *SessionMock) HistoryCalls() []struct {
301 | } {
302 | var calls []struct {
303 | }
304 | mock.lockHistory.RLock()
305 | calls = mock.calls.History
306 | mock.lockHistory.RUnlock()
307 | return calls
308 | }
309 |
// NOTE(review): moq output (see the file header) - regenerate with moq rather
// than editing by hand. If a *Func field is nil, the matching method still
// records the call but returns zero values instead of panicking: Run returns
// a nil map and nil error, Spec returns the zero ToolSpec. The *Calls
// accessors return the recorded slice itself (read under the per-method
// RWMutex), not a copy.
310 | // ToolMock is a mock implementation of gollem.Tool.
311 | //
312 | // func TestSomethingThatUsesTool(t *testing.T) {
313 | //
314 | // // make and configure a mocked gollem.Tool
315 | // mockedTool := &ToolMock{
316 | // RunFunc: func(ctx context.Context, args map[string]any) (map[string]any, error) {
317 | // panic("mock out the Run method")
318 | // },
319 | // SpecFunc: func() gollem.ToolSpec {
320 | // panic("mock out the Spec method")
321 | // },
322 | // }
323 | //
324 | // // use mockedTool in code that requires gollem.Tool
325 | // // and then make assertions.
326 | //
327 | // }
328 | type ToolMock struct {
329 | // RunFunc mocks the Run method.
330 | RunFunc func(ctx context.Context, args map[string]any) (map[string]any, error)
331 |
332 | // SpecFunc mocks the Spec method.
333 | SpecFunc func() gollem.ToolSpec
334 |
335 | // calls tracks calls to the methods.
336 | calls struct {
337 | // Run holds details about calls to the Run method.
338 | Run []struct {
339 | // Ctx is the ctx argument value.
340 | Ctx context.Context
341 | // Args is the args argument value.
342 | Args map[string]any
343 | }
344 | // Spec holds details about calls to the Spec method.
345 | Spec []struct {
346 | }
347 | }
348 | lockRun sync.RWMutex
349 | lockSpec sync.RWMutex
350 | }
351 |
352 | // Run calls RunFunc.
353 | func (mock *ToolMock) Run(ctx context.Context, args map[string]any) (map[string]any, error) {
354 | callInfo := struct {
355 | Ctx context.Context
356 | Args map[string]any
357 | }{
358 | Ctx: ctx,
359 | Args: args,
360 | }
361 | mock.lockRun.Lock()
362 | mock.calls.Run = append(mock.calls.Run, callInfo)
363 | mock.lockRun.Unlock()
364 | if mock.RunFunc == nil {
365 | var (
366 | stringToVOut map[string]any
367 | errOut error
368 | )
369 | return stringToVOut, errOut
370 | }
371 | return mock.RunFunc(ctx, args)
372 | }
373 |
374 | // RunCalls gets all the calls that were made to Run.
375 | // Check the length with:
376 | //
377 | // len(mockedTool.RunCalls())
378 | func (mock *ToolMock) RunCalls() []struct {
379 | Ctx context.Context
380 | Args map[string]any
381 | } {
382 | var calls []struct {
383 | Ctx context.Context
384 | Args map[string]any
385 | }
386 | mock.lockRun.RLock()
387 | calls = mock.calls.Run
388 | mock.lockRun.RUnlock()
389 | return calls
390 | }
391 |
392 | // Spec calls SpecFunc.
393 | func (mock *ToolMock) Spec() gollem.ToolSpec {
394 | callInfo := struct {
395 | }{}
396 | mock.lockSpec.Lock()
397 | mock.calls.Spec = append(mock.calls.Spec, callInfo)
398 | mock.lockSpec.Unlock()
399 | if mock.SpecFunc == nil {
400 | var (
401 | toolSpecOut gollem.ToolSpec
402 | )
403 | return toolSpecOut
404 | }
405 | return mock.SpecFunc()
406 | }
407 |
408 | // SpecCalls gets all the calls that were made to Spec.
409 | // Check the length with:
410 | //
411 | // len(mockedTool.SpecCalls())
412 | func (mock *ToolMock) SpecCalls() []struct {
413 | } {
414 | var calls []struct {
415 | }
416 | mock.lockSpec.RLock()
417 | calls = mock.calls.Spec
418 | mock.lockSpec.RUnlock()
419 | return calls
420 | }
421 |
--------------------------------------------------------------------------------
/session.go:
--------------------------------------------------------------------------------
1 | package gollem
2 |
3 | import "context"
4 |
5 | // Session is a session for the LLM. This can be called to generate content and stream. It's mainly used for Prompt() method, but it also can be used for one-shot content generation.
6 | type Session interface {
7 | GenerateContent(ctx context.Context, input ...Input) (*Response, error)
8 | GenerateStream(ctx context.Context, input ...Input) (<-chan *Response, error)
9 | History() *History
10 | }
11 |
12 | // SessionConfig is the configuration for the new session. This is required for only LLM client implementations.
13 | type SessionConfig struct {
14 | history *History
15 | contentType ContentType
16 | systemPrompt string
17 | tools []Tool
18 | }
19 |
20 | // History returns the history of the session.
21 | func (c *SessionConfig) History() *History {
22 | return c.history
23 | }
24 |
25 | // SystemPrompt returns the system prompt of the session.
26 | func (c *SessionConfig) SystemPrompt() string {
27 | return c.systemPrompt
28 | }
29 |
30 | // ContentType returns the content type of the session.
31 | func (c *SessionConfig) ContentType() ContentType {
32 | return c.contentType
33 | }
34 |
35 | // Tools returns the tools of the session.
36 | func (c *SessionConfig) Tools() []Tool {
37 | return c.tools
38 | }
39 |
40 | // NewSessionConfig creates a new session configuration. This is required for only LLM client implementations.
41 | func NewSessionConfig(options ...SessionOption) SessionConfig {
42 | cfg := SessionConfig{}
43 | for _, option := range options {
44 | option(&cfg)
45 | }
46 | return cfg
47 | }
48 |
49 | // SessionOption is the option for the session configuration. This is required for only LLM client implementations.
50 | type SessionOption func(cfg *SessionConfig)
51 |
52 | // WithSessionHistory sets the history for the session.
53 | // Usage:
54 | // session, err := llmClient.NewSession(ctx, gollem.WithSessionHistory(history))
55 | func WithSessionHistory(history *History) SessionOption {
56 | return func(cfg *SessionConfig) {
57 | cfg.history = history
58 | }
59 | }
60 |
61 | // WithSessionContentType sets the content type for the session.
62 | // Usage:
63 | // session, err := llmClient.NewSession(ctx, gollem.WithSessionContentType(gollem.ContentTypeJSON))
64 | func WithSessionContentType(contentType ContentType) SessionOption {
65 | return func(cfg *SessionConfig) {
66 | cfg.contentType = contentType
67 | }
68 | }
69 |
70 | // WithSessionTools sets the tools for the session.
71 | // Usage:
72 | // session, err := llmClient.NewSession(ctx, gollem.WithSessionTools(tools))
73 | func WithSessionTools(tools ...Tool) SessionOption {
74 | return func(cfg *SessionConfig) {
75 | cfg.tools = append(cfg.tools, tools...)
76 | }
77 | }
78 |
79 | // WithSessionSystemPrompt sets the system prompt for the session.
80 | // Usage:
81 | // session, err := llmClient.NewSession(ctx, gollem.WithSessionSystemPrompt("You are a helpful assistant."))
82 | func WithSessionSystemPrompt(systemPrompt string) SessionOption {
83 | return func(cfg *SessionConfig) {
84 | cfg.systemPrompt = systemPrompt
85 | }
86 | }
87 |
// ContentType selects the kind of content the LLM is asked to generate.
type ContentType string

// Supported content types.
const (
	// ContentTypeText requests plain text.
	ContentTypeText ContentType = "text"

	// ContentTypeJSON requests JSON.
	ContentTypeJSON ContentType = "json"
)
97 |
--------------------------------------------------------------------------------
/tool.go:
--------------------------------------------------------------------------------
1 | package gollem
2 |
3 | import (
4 | "context"
5 | "regexp"
6 |
7 | "github.com/m-mizutani/goerr/v2"
8 | )
9 |
10 | // ToolSpec is the specification of a tool.
11 | type ToolSpec struct {
12 | Name string
13 | Description string
14 | Parameters map[string]*Parameter
15 | Required []string
16 | }
17 |
18 | // Validate validates the tool specification.
19 | func (s *ToolSpec) Validate() error {
20 | eb := goerr.NewBuilder(goerr.V("tool", s))
21 | if s.Name == "" {
22 | return eb.Wrap(ErrInvalidTool, "name is required")
23 | }
24 |
25 | paramNames := make(map[string]struct{})
26 | for name, param := range s.Parameters {
27 | if _, ok := paramNames[name]; ok {
28 | return eb.Wrap(ErrInvalidTool, "duplicate parameter name", goerr.V("name", name))
29 | }
30 | paramNames[name] = struct{}{}
31 |
32 | if err := param.Validate(); err != nil {
33 | return eb.Wrap(ErrInvalidTool, "invalid parameter")
34 | }
35 | }
36 |
37 | for _, req := range s.Required {
38 | if _, ok := paramNames[req]; !ok {
39 | return eb.Wrap(ErrInvalidTool, "required parameter not found", goerr.V("name", req))
40 | }
41 | }
42 |
43 | return nil
44 | }
45 |
// ParameterType is the type of a tool parameter. The values are the JSON
// Schema type keywords.
type ParameterType string

// Supported parameter types.
const (
	// TypeString is a string value.
	TypeString ParameterType = "string"
	// TypeNumber is a numeric value, integral or not.
	TypeNumber ParameterType = "number"
	// TypeInteger is an integral numeric value.
	TypeInteger ParameterType = "integer"
	// TypeBoolean is a true/false value.
	TypeBoolean ParameterType = "boolean"
	// TypeArray is a list whose elements are described by Parameter.Items.
	TypeArray ParameterType = "array"
	// TypeObject is a map whose fields are described by Parameter.Properties.
	TypeObject ParameterType = "object"
)
57 |
58 | // Parameter is a parameter of a tool.
59 | type Parameter struct {
60 | // Title is the user friendly of the parameter. It's optional.
61 | Title string
62 |
63 | // Type is the type of the parameter. It's required.
64 | Type ParameterType
65 |
66 | // Description is the description of the parameter. It's optional.
67 | Description string
68 |
69 | // Required is the list of required field names when Type is Object.
70 | Required []string
71 |
72 | // Enum is the enum of the parameter. It's optional.
73 | Enum []string
74 |
75 | // Properties is the properties of the parameter. It's used for object type.
76 | Properties map[string]*Parameter
77 |
78 | // Items is the items of the parameter. It's used for array type.
79 | Items *Parameter
80 |
81 | // Number constraints
82 | Minimum *float64
83 | Maximum *float64
84 |
85 | // String constraints
86 | MinLength *int
87 | MaxLength *int
88 | Pattern string
89 |
90 | // Array constraints
91 | MinItems *int
92 | MaxItems *int
93 |
94 | // Default value
95 | Default any
96 | }
97 |
98 | // Validate validates the parameter.
99 | func (p *Parameter) Validate() error {
100 | eb := goerr.NewBuilder(goerr.V("parameter", p))
101 |
102 | // Type is required
103 | if p.Type == "" {
104 | return eb.Wrap(ErrInvalidParameter, "type is required")
105 | }
106 |
107 | // Validate parameter type
108 | switch p.Type {
109 | case TypeString, TypeNumber, TypeInteger, TypeBoolean, TypeArray, TypeObject:
110 | // Valid type
111 | default:
112 | return eb.Wrap(ErrInvalidParameter, "invalid parameter type", goerr.V("type", p.Type))
113 | }
114 |
115 | // Properties is required for object type
116 | if p.Type == TypeObject {
117 | if p.Properties == nil {
118 | return eb.Wrap(ErrInvalidParameter, "properties is required for object type")
119 | }
120 |
121 | // Check for duplicate property names
122 | propNames := make(map[string]struct{})
123 | for name := range p.Properties {
124 | if _, ok := propNames[name]; ok {
125 | return eb.Wrap(ErrInvalidParameter, "duplicate property name", goerr.V("name", name))
126 | }
127 | propNames[name] = struct{}{}
128 | }
129 |
130 | // Validate nested properties
131 | for _, prop := range p.Properties {
132 | if err := prop.Validate(); err != nil {
133 | return eb.Wrap(ErrInvalidParameter, "invalid property")
134 | }
135 | }
136 | // Validate required fields exist in properties
137 | for _, req := range p.Required {
138 | if _, ok := p.Properties[req]; !ok {
139 | return eb.Wrap(ErrInvalidParameter, "required field not found in properties", goerr.V("field", req))
140 | }
141 | }
142 | }
143 |
144 | // Items is required for array type
145 | if p.Type == TypeArray {
146 | if p.Items == nil {
147 | return eb.Wrap(ErrInvalidParameter, "items is required for array type")
148 | }
149 | // Validate items
150 | if err := p.Items.Validate(); err != nil {
151 | return eb.Wrap(ErrInvalidParameter, "invalid items")
152 | }
153 | }
154 |
155 | // Validate number constraints
156 | if p.Type == TypeNumber || p.Type == TypeInteger {
157 | if p.Minimum != nil && p.Maximum != nil && *p.Minimum > *p.Maximum {
158 | return eb.Wrap(ErrInvalidParameter, "minimum must be less than or equal to maximum")
159 | }
160 | }
161 |
162 | // Validate string constraints
163 | if p.Type == TypeString {
164 | if p.MinLength != nil && p.MaxLength != nil && *p.MinLength > *p.MaxLength {
165 | return eb.Wrap(ErrInvalidParameter, "minLength must be less than or equal to maxLength")
166 | }
167 | if p.Pattern != "" {
168 | if _, err := regexp.Compile(p.Pattern); err != nil {
169 | return eb.Wrap(ErrInvalidParameter, "invalid pattern", goerr.V("pattern", p.Pattern))
170 | }
171 | }
172 | }
173 |
174 | // Validate array constraints
175 | if p.Type == TypeArray {
176 | if p.MinItems != nil && p.MaxItems != nil && *p.MinItems > *p.MaxItems {
177 | return eb.Wrap(ErrInvalidParameter, "minItems must be less than or equal to maxItems")
178 | }
179 | }
180 |
181 | return nil
182 | }
183 |
184 | // Tool is specification and execution of an action that can be called by the LLM.
185 | type Tool interface {
186 | // Spec returns the specification of the tool. It's called when starting a LLM chat session in Prompt().
187 | Spec() ToolSpec
188 |
189 | // Run is the execution of the tool.
190 | // It's called when receiving a tool call from the LLM. Even if the method returns an error, the tool execution is not aborted. Error will be passed to LLM as a response. If you want to abort the tool execution, you need to return an error from the callback function of WithToolErrorHook().
191 | Run(ctx context.Context, args map[string]any) (map[string]any, error)
192 | }
193 |
194 | // ToolSet is a set of tools.
195 | // It's useful for providing a set of tools to the LLM.
196 | type ToolSet interface {
197 | // Specs returns the specifications of the tools.
198 | Specs(ctx context.Context) ([]ToolSpec, error)
199 |
200 | // Run is the execution of the tool.
201 | // It's called when receiving a tool call from the LLM.
202 | Run(ctx context.Context, name string, args map[string]any) (map[string]any, error)
203 | }
204 |
--------------------------------------------------------------------------------
/tool_test.go:
--------------------------------------------------------------------------------
1 | package gollem
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/m-mizutani/gt"
7 | )
8 |
9 | func TestParameterValidation(t *testing.T) {
10 | t.Run("number constraints", func(t *testing.T) {
11 | t.Run("valid minimum and maximum", func(t *testing.T) {
12 | p := &Parameter{
13 | Type: TypeNumber,
14 | Minimum: ptr(1.0),
15 | Maximum: ptr(10.0),
16 | }
17 | gt.NoError(t, p.Validate())
18 | })
19 |
20 | t.Run("invalid minimum and maximum", func(t *testing.T) {
21 | p := &Parameter{
22 | Type: TypeNumber,
23 | Minimum: ptr(10.0),
24 | Maximum: ptr(1.0),
25 | }
26 | gt.Error(t, p.Validate())
27 | })
28 | })
29 |
30 | t.Run("string constraints", func(t *testing.T) {
31 | t.Run("valid minLength and maxLength", func(t *testing.T) {
32 | p := &Parameter{
33 | Type: TypeString,
34 | MinLength: ptr(1),
35 | MaxLength: ptr(10),
36 | }
37 | gt.NoError(t, p.Validate())
38 | })
39 |
40 | t.Run("invalid minLength and maxLength", func(t *testing.T) {
41 | p := &Parameter{
42 | Type: TypeString,
43 | MinLength: ptr(10),
44 | MaxLength: ptr(1),
45 | }
46 | gt.Error(t, p.Validate())
47 | })
48 |
49 | t.Run("valid pattern", func(t *testing.T) {
50 | p := &Parameter{
51 | Type: TypeString,
52 | Pattern: "^[a-z]+$",
53 | }
54 | gt.NoError(t, p.Validate())
55 | })
56 |
57 | t.Run("invalid pattern", func(t *testing.T) {
58 | p := &Parameter{
59 | Type: TypeString,
60 | Pattern: "[invalid",
61 | }
62 | gt.Error(t, p.Validate())
63 | })
64 | })
65 |
66 | t.Run("array constraints", func(t *testing.T) {
67 | t.Run("valid minItems and maxItems", func(t *testing.T) {
68 | p := &Parameter{
69 | Type: TypeArray,
70 | Items: &Parameter{Type: TypeString},
71 | MinItems: ptr(1),
72 | MaxItems: ptr(10),
73 | }
74 | gt.NoError(t, p.Validate())
75 | })
76 |
77 | t.Run("invalid minItems and maxItems", func(t *testing.T) {
78 | p := &Parameter{
79 | Type: TypeArray,
80 | Items: &Parameter{Type: TypeString},
81 | MinItems: ptr(10),
82 | MaxItems: ptr(1),
83 | }
84 | gt.Error(t, p.Validate())
85 | })
86 | })
87 |
88 | t.Run("object constraints", func(t *testing.T) {
89 | t.Run("valid properties", func(t *testing.T) {
90 | p := &Parameter{
91 | Type: TypeObject,
92 | Properties: map[string]*Parameter{
93 | "name": {
94 | Type: TypeString,
95 | Description: "User name",
96 | },
97 | "age": {
98 | Type: TypeNumber,
99 | Description: "User age",
100 | },
101 | },
102 | }
103 | gt.NoError(t, p.Validate())
104 | })
105 |
106 | t.Run("duplicate property names", func(t *testing.T) {
107 | p := &Parameter{
108 | Type: TypeObject,
109 | Properties: make(map[string]*Parameter),
110 | }
111 | p.Properties["name"] = &Parameter{
112 | Type: TypeString,
113 | Description: "User name",
114 | }
115 | p.Properties["name"] = &Parameter{
116 | Type: TypeString,
117 | Description: "Duplicate name",
118 | }
119 | gt.NoError(t, p.Validate())
120 | })
121 |
122 | t.Run("invalid property type", func(t *testing.T) {
123 | p := &Parameter{
124 | Type: TypeObject,
125 | Properties: map[string]*Parameter{
126 | "name": {
127 | Type: "invalid",
128 | Description: "User name",
129 | },
130 | },
131 | }
132 | gt.Error(t, p.Validate())
133 | })
134 | })
135 | }
136 |
// ptr returns a pointer to a fresh copy of v, so tests can set optional
// pointer fields such as Parameter.Minimum inline.
func ptr[T any](v T) *T {
	p := new(T)
	*p = v
	return p
}
140 |
141 | func TestToolSpecValidation(t *testing.T) {
142 | t.Run("valid tool spec", func(t *testing.T) {
143 | spec := ToolSpec{
144 | Name: "test",
145 | Description: "test description",
146 | Parameters: map[string]*Parameter{
147 | "param1": {
148 | Type: TypeString,
149 | Description: "test parameter",
150 | },
151 | },
152 | Required: []string{"param1"},
153 | }
154 | gt.NoError(t, spec.Validate())
155 | })
156 |
157 | t.Run("empty name", func(t *testing.T) {
158 | spec := ToolSpec{
159 | Description: "test description",
160 | Parameters: map[string]*Parameter{
161 | "param1": {
162 | Type: TypeString,
163 | Description: "test parameter",
164 | },
165 | },
166 | }
167 | gt.Error(t, spec.Validate())
168 | })
169 |
170 | t.Run("invalid parameter type", func(t *testing.T) {
171 | spec := ToolSpec{
172 | Name: "test",
173 | Description: "test description",
174 | Parameters: map[string]*Parameter{
175 | "param1": {
176 | Type: "invalid",
177 | Description: "test parameter",
178 | },
179 | },
180 | }
181 | gt.Error(t, spec.Validate())
182 | })
183 |
184 | t.Run("required parameter not found", func(t *testing.T) {
185 | spec := ToolSpec{
186 | Name: "test",
187 | Description: "test description",
188 | Parameters: map[string]*Parameter{
189 | "param1": {
190 | Type: TypeString,
191 | Description: "test parameter",
192 | },
193 | },
194 | Required: []string{"param2"},
195 | }
196 | gt.Error(t, spec.Validate())
197 | })
198 |
199 | t.Run("invalid parameter", func(t *testing.T) {
200 | spec := ToolSpec{
201 | Name: "test",
202 | Description: "test description",
203 | Parameters: map[string]*Parameter{
204 | "param1": {
205 | Type: TypeNumber,
206 | Minimum: ptr(10.0),
207 | Maximum: ptr(1.0),
208 | },
209 | },
210 | }
211 | gt.Error(t, spec.Validate())
212 | })
213 |
214 | t.Run("object parameter without properties", func(t *testing.T) {
215 | spec := ToolSpec{
216 | Name: "test",
217 | Description: "test description",
218 | Parameters: map[string]*Parameter{
219 | "param1": {
220 | Type: TypeObject,
221 | Description: "test parameter",
222 | },
223 | },
224 | }
225 | gt.Error(t, spec.Validate())
226 | })
227 |
228 | t.Run("array parameter without items", func(t *testing.T) {
229 | spec := ToolSpec{
230 | Name: "test",
231 | Description: "test description",
232 | Parameters: map[string]*Parameter{
233 | "param1": {
234 | Type: TypeArray,
235 | Description: "test parameter",
236 | },
237 | },
238 | }
239 | gt.Error(t, spec.Validate())
240 | })
241 | }
242 |
--------------------------------------------------------------------------------
/types.go:
--------------------------------------------------------------------------------
1 | package gollem
2 |
--------------------------------------------------------------------------------