├── .github └── workflows │ └── go.yml ├── .gitignore ├── LICENSE.md ├── README.md ├── agents ├── agent.go ├── agent_test.go ├── agent_tool.go └── task.go ├── engines ├── engine.go ├── mocks │ └── engine.go ├── openai.go └── prompt.go ├── evaluation ├── agent_evaluator.go ├── evaluator.go ├── evaluator_test.go └── llm_evaluator.go ├── go-llm-cli └── main.go ├── go.mod ├── go.sum ├── memory ├── buffer_memory.go ├── buffer_memory_test.go ├── memory.go ├── mocks │ └── memory.go ├── summarised_memory.go ├── summarised_memory_test.go └── vectorstore.go ├── prebuilt ├── code_refactor.go ├── git_assistant.go ├── trade_assistant.go └── unit_test_writer.go └── tools ├── ask_user.go ├── ask_user_test.go ├── bash.go ├── bash_test.go ├── generic_tool.go ├── google_search.go ├── google_search_test.go ├── isolated_python_repl.go ├── isolated_python_repl_test.go ├── json_autofixer.go ├── json_autofixer_test.go ├── key_value_store.go ├── key_value_store_test.go ├── mocks └── tool.go ├── python_repl.go ├── python_repl_test.go ├── testdata ├── expected_parsed_results.json ├── golang_wikipedia_article.html └── mock_search_results.html ├── tool.go ├── utils.go ├── utils_test.go ├── webpage_summary.go ├── webpage_summary_test.go ├── wolfram_alpha.go └── wolfram_alpha_test.go /.github/workflows/go.yml: -------------------------------------------------------------------------------- 1 | # This workflow will build a golang project 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-go 3 | 4 | name: Go 5 | 6 | on: 7 | push: 8 | branches: [ "master" ] 9 | pull_request: 10 | branches: [ "master" ] 11 | 12 | jobs: 13 | 14 | build: 15 | runs-on: ubuntu-latest 16 | steps: 17 | - uses: actions/checkout@v3 18 | 19 | - name: Set up Go 20 | uses: actions/setup-go@v3 21 | with: 22 | go-version: 1.23 23 | 24 | - name: Build 25 | run: go build -v ./... 26 | 27 | - name: Test 28 | run: go test -v ./... 
29 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .env 2 | .vscode 3 | **/.venv 4 | output.json 5 | log.txt 6 | .idea 7 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Nathan Liebmann 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Go LLM 2 | [![Go](https://github.com/natexcvi/go-llm/actions/workflows/go.yml/badge.svg)](https://github.com/natexcvi/go-llm/actions/workflows/go.yml) 3 | 4 | Integrate the power of large language models (LLMs) into your Go application. 5 | 6 | This project aims to abstract away much of the plumbing (free text to structured data, contextual memory, tool wrapping, retry logic, etc.) so you can focus on the business logic of your agent. 7 | 8 | ```mermaid 9 | graph LR 10 | subgraph Input 11 | A[Structured Input] --> B[Compiled Task] 12 | end 13 | subgraph LLM-Based Agent 14 | C[Task Template] --> B[Compiled Task] 15 | B --> D((Agent)) 16 | D --"Reasoning"--> D 17 | D --"Action"--> E[Environment] 18 | E --"Observation"--> D 19 | D --"Answer"--> G[Output Validators] 20 | G --> D 21 | end 22 | subgraph Output 23 | G --"Answer"--> F[Structured Output] 24 | end 25 | ``` 26 | 27 | ## Usage Example 28 | ```go 29 | package main 30 | 31 | import ( 32 | "encoding/json" 33 | "fmt" 34 | "os" 35 | 36 | "github.com/natexcvi/go-llm/agents" 37 | "github.com/natexcvi/go-llm/engines" 38 | "github.com/natexcvi/go-llm/memory" 39 | "github.com/natexcvi/go-llm/tools" 40 | ) 41 | 42 | type CodeBaseRefactorRequest struct { 43 | Dir string 44 | Goal string 45 | } 46 | 47 | func (req CodeBaseRefactorRequest) Encode() string { 48 | return fmt.Sprintf(`{"dir": "%s", "goal": "%s"}`, req.Dir, req.Goal) 49 | } 50 | 51 | func (req CodeBaseRefactorRequest) Schema() string { 52 | return `{"dir": "path to code base", "goal": "refactoring goal"}` 53 | } 54 | 55 | type CodeBaseRefactorResponse struct { 56 | RefactoredFiles map[string]string `json:"refactored_files"` 57 | } 58 | 59 | func (resp CodeBaseRefactorResponse) Encode() string { 60 | marshalled, err := json.Marshal(resp) 61 | if err != nil { 62 |
panic(err) 63 | } 64 | return string(marshalled) 65 | } 66 | 67 | func (resp CodeBaseRefactorResponse) Schema() string { 68 | return `{"refactored_files": {"path": "description of changes"}}` 69 | } 70 | func main() { 71 | engine := engines.NewGPTEngine(os.Getenv("OPENAI_TOKEN"), "gpt-3.5-turbo-0613") 72 | task := &agents.Task[CodeBaseRefactorRequest, CodeBaseRefactorResponse]{ 73 | Description: "You will be given access to a code base, and instructions for refactoring. " + 74 | "Your task is to refactor the code base to meet the given goal.", 75 | Examples: []agents.Example[CodeBaseRefactorRequest, CodeBaseRefactorResponse]{ 76 | { 77 | Input: CodeBaseRefactorRequest{ 78 | Dir: "/Users/nate/code/base", 79 | Goal: "Handle errors gracefully", 80 | }, 81 | Answer: CodeBaseRefactorResponse{ 82 | RefactoredFiles: map[string]string{ 83 | "/Users/nate/code/base/main.py": "added try/except block", 84 | }, 85 | }, 86 | IntermediarySteps: []*engines.ChatMessage{ 87 | (&agents.ChainAgentThought{ 88 | Content: "I should scan the code base for functions that might error.", 89 | }).Encode(engine), 90 | (&agents.ChainAgentAction{ 91 | Tool: tools.NewBashTerminal(), 92 | Args: json.RawMessage(`{"command": "ls /Users/nate/code/base"}`), 93 | }).Encode(engine), 94 | (&agents.ChainAgentObservation{ 95 | Content: "main.py", 96 | ToolName: tools.NewBashTerminal().Name(), 97 | }).Encode(engine), 98 | (&agents.ChainAgentThought{ 99 | Content: "Now I should read the code file.", 100 | }).Encode(engine), 101 | (&agents.ChainAgentAction{ 102 | Tool: tools.NewBashTerminal(), 103 | Args: json.RawMessage(`{"command": "cat /Users/nate/code/base/main.py"}`), 104 | }).Encode(engine), 105 | (&agents.ChainAgentObservation{ 106 | Content: "def main():\n\tfunc_that_might_error()", 107 | ToolName: tools.NewBashTerminal().Name(), 108 | }).Encode(engine), 109 | (&agents.ChainAgentThought{ 110 | Content: "I should refactor the code to handle errors gracefully.", 111 | }).Encode(engine), 112 | (&agents.ChainAgentAction{ 113 | Tool: tools.NewBashTerminal(), 114 | Args:
json.RawMessage(`{"command": "echo 'def main():\n\ttry:\n\t\tfunc_that_might_error()\n\texcept Exception as e:\n\t\tprint(\"Error: %s\", e)' > /Users/nate/code/base/main.py"}`), 115 | }).Encode(engine), 116 | }, 117 | }, 118 | }, 119 | AnswerParser: func(msg string) (CodeBaseRefactorResponse, error) { 120 | var res CodeBaseRefactorResponse 121 | if err := json.Unmarshal([]byte(msg), &res); err != nil { 122 | return CodeBaseRefactorResponse{}, err 123 | } 124 | return res, nil 125 | }, 126 | } 127 | agent := agents.NewChainAgent(engine, task, memory.NewBufferedMemory(0)).WithMaxSolutionAttempts(12).WithTools(tools.NewPythonREPL(), tools.NewBashTerminal()) 128 | res, err := agent.Run(CodeBaseRefactorRequest{ 129 | Dir: "/Users/nate/Git/go-llm/tools", 130 | Goal: "Write unit tests for the bash.go file, following the example of python_repl_test.go.", 131 | }) 132 | ... 133 | } 134 | ``` 135 | > **Note** 136 | > 137 | > Fun fact: the `tools/bash_test.go` file was written by this very agent, and helped find a bug! 138 | 139 | ## Components 140 | ### Engines 141 | Connectors to LLM engines. Currently only OpenAI's GPT chat completion API is supported. 142 | ### Tools 143 | Tools that can provide agents with the ability to perform actions interacting with the outside world. 144 | Currently available tools are: 145 | - `PythonREPL` - a tool that allows agents to execute Python code in a REPL. 146 | - `IsolatedPythonREPL` - a tool that allows agents to execute Python code in a REPL, but in a Docker container. 147 | - `BashTerminal` - a tool that allows agents to execute bash commands in a terminal. 148 | - `GoogleSearch` - a tool that allows agents to search Google. 149 | - `WebpageSummary` - an LLM-based tool that allows agents to get a summary of a webpage. 150 | - `WolframAlpha` - a tool that allows agents to query WolframAlpha's short answer API.
151 | - `KeyValueStore` - a tool for storing and retrieving information. The agent can use this tool to reuse long pieces of information by reference, removing duplication and therefore reducing context size. 152 | - `AskUser` - an interactivity tool that lets the agent ask a human operator for clarifications when needed. 153 | - `JSONAutoFixer` - a meta tool that is enabled by default. When the arguments to any tool are provided in a form that is not valid JSON, this tool attempts to fix the payload using a separate LLM chain. 154 | - `GenericAgentTool` - lets an agent run another agent, with predetermined tools, dynamically providing it with its task and input and collecting its final answer. 155 | 156 | > **Warning** 157 | > 158 | > The `BashTerminal` and regular `PythonREPL` tools let the agent run arbitrary commands on your machine; use them at your own risk. It may be a good idea to use the built-in support for action confirmation callbacks (see the `WithActionConfirmation` method on the `ChainAgent` type). 159 | 160 | #### Model-Native Function Calls 161 | `go-llm` tools support the [new OpenAI function call interface](https://openai.com/blog/function-calling-and-other-api-updates?ref=upstract.com) transparently, for model variants that support it. 162 | 163 | ### Memory 164 | A memory system that allows agents to store and retrieve information. 165 | Currently available memory systems are: 166 | - `BufferMemory` - which provides each step of the agent with a fixed buffer of recent messages from the conversation history. 167 | - `SummarisedMemory` - which provides each step of the agent with a summary of the conversation history, powered by an LLM. 168 | 169 | ### Agents 170 | Agents are the main component of the library. Agents can perform complex tasks that involve iterative interactions with the outside world. 171 | 172 | ### Prebuilt (WIP) 173 | A collection of ready-made agents that can be easily integrated with your application.
174 | 175 | ### Evaluation (WIP) 176 | A collection of evaluation tools for agents and engines. 177 | #### Example 178 | ```go 179 | package main 180 | 181 | import ( 182 | "fmt" 183 | "os" 184 | 185 | "github.com/natexcvi/go-llm/engines" 186 | "github.com/natexcvi/go-llm/evaluation" 187 | ) 188 | 189 | func goodness(_ *engines.ChatPrompt, _ *engines.ChatMessage, err error) float64 { 190 | if err != nil { 191 | return 0 192 | } 193 | 194 | return 1 195 | } 196 | 197 | func main() { 198 | engine := engines.NewGPTEngine(os.Getenv("OPENAI_TOKEN"), "gpt-3.5-turbo-0613") 199 | engineRunner := evaluation.NewLLMRunner(engine) 200 | 201 | evaluator := evaluation.NewEvaluator(engineRunner, &evaluation.Options[*engines.ChatPrompt, *engines.ChatMessage]{ 202 | GoodnessFunction: goodness, 203 | Repetitions: 5, 204 | }) 205 | 206 | testPack := []*engines.ChatPrompt{ 207 | { 208 | History: []*engines.ChatMessage{ 209 | { 210 | Text: "Hello, how are you?", 211 | }, 212 | { 213 | Text: "I'm trying to understand how this works.", 214 | }, 215 | }, 216 | }, 217 | { 218 | History: []*engines.ChatMessage{ 219 | { 220 | Text: "Could you please explain it to me?", 221 | }, 222 | }, 223 | }, 224 | } 225 | 226 | results := evaluator.Evaluate(testPack) 227 | fmt.Println("Goodness level of the first prompt:", results[0]) 228 | fmt.Println("Goodness level of the second prompt:", results[1]) 229 | } 230 | ``` -------------------------------------------------------------------------------- /agents/agent.go: -------------------------------------------------------------------------------- 1 | package agents 2 | 3 | import ( 4 | "encoding/json" 5 | "errors" 6 | "fmt" 7 | "regexp" 8 | "strings" 9 | 10 | log "github.com/sirupsen/logrus" 11 | 12 | "github.com/hashicorp/go-multierror" 13 | "github.com/natexcvi/go-llm/engines" 14 | "github.com/natexcvi/go-llm/memory" 15 | toolsPkg "github.com/natexcvi/go-llm/tools" 16 | "golang.org/x/exp/maps" 17 | ) 18 | // Message codes and formats of the text-based (non-native) agent protocol, where each step is rendered as "Code: content". 19 | const ( 20 | ThoughtCode = "Thought" 21 |
ActionCode = "Action" 22 | AnswerCode = "Answer" 23 | ErrorCode = "Error" 24 | ObservationCode = "Observation" 25 | EndMarker = "" 26 | MessageFormat = "%s: %s" 27 | MessagePrefix = "%s: " 28 | ) 29 | // actionRegex splits "tool(args)"; operationRegex and operationRegexWithoutEnd split "Code: content". Group names are looked up via SubexpIndex, so they must stay "tool"/"args" and "code"/"content". 30 | var ( 31 | actionRegex = regexp.MustCompile(`^(?P<tool>.*?)\((?P<args>[\s\S]*)\)`) 32 | operationRegex = regexp.MustCompile(`(?P<code>[A-Za-z]+):\s*(?P<content>[\s\S]*)(?:)`) 33 | operationRegexWithoutEnd = regexp.MustCompile(`(?P<code>[A-Za-z]+):\s*(?P<content>[\s\S]*)`) 34 | ) 35 | // errNativeFunctionsUnsupported is returned by setNativeLLMFunctions when the engine does not implement engines.LLMWithFunctionCalls. 36 | var ( 37 | errNativeFunctionsUnsupported = errors.New("native functions are not supported for this LLM") 38 | ) 39 | // Agent maps an input of type T to an output of type S. 40 | //go:generate mockgen -source=agent.go -destination=mocks/agent.go -package=mocks 41 | type Agent[T any, S any] interface { 42 | Run(input T) (S, error) 43 | } 44 | // ChainAgentThought is a free-form reasoning step in an agent transcript. 45 | type ChainAgentThought struct { 46 | Content string 47 | } 48 | // Encode renders the thought as a user-role "Thought: ..." message; the engine argument is ignored. NOTE(review): the user role for the agent's own thought looks unusual - confirm it is intended. 49 | func (a *ChainAgentThought) Encode(_ engines.LLM) *engines.ChatMessage { 50 | return &engines.ChatMessage{ 51 | Role: engines.ConvRoleUser, 52 | Text: fmt.Sprintf(MessageFormat, ThoughtCode, a.Content), 53 | } 54 | } 55 | // ParseChainAgentThought strips a leading "Thought: " prefix, if present, from the message text. 56 | func ParseChainAgentThought(thought *engines.ChatMessage) *ChainAgentThought { 57 | return &ChainAgentThought{ 58 | Content: strings.TrimPrefix(thought.Text, fmt.Sprintf(MessagePrefix, ThoughtCode)), 59 | } 60 | } 61 | // ChainAgentAction is a single tool invocation together with its raw JSON arguments. 62 | type ChainAgentAction struct { 63 | Tool toolsPkg.Tool 64 | Args json.RawMessage 65 | } 66 | // Encode renders the action as a native function call when targetEngine supports function calls, and as "Action: tool(compact args)" text otherwise. 67 | func (a *ChainAgentAction) Encode(targetEngine engines.LLM) *engines.ChatMessage { 68 | if _, ok := targetEngine.(engines.LLMWithFunctionCalls); ok { 69 | return &engines.ChatMessage{ 70 | Role: engines.ConvRoleAssistant, 71 | FunctionCall: &engines.FunctionCall{ 72 | Name: a.Tool.Name(), 73 | Args: string(a.Args), 74 | }, 75 | } 76 | } 77 | msgText := fmt.Sprintf(MessageFormat, ActionCode, fmt.Sprintf("%s(%s)", a.Tool.Name(), a.Tool.CompactArgs(a.Args))) 78 | return &engines.ChatMessage{ 79 | Role: engines.ConvRoleAssistant, 80 | Text: msgText, 81 | } 82 | } 83 | // parseNativeFunctionCall resolves msg.FunctionCall to a registered tool. Unlike the text path of ParseChainAgentAction, it does not run ActionArgPreprocessors. 84 | func (a *ChainAgent[T, S]) parseNativeFunctionCall(msg
*engines.ChatMessage) (*ChainAgentAction, error) { 85 | if msg.FunctionCall == nil { 86 | return nil, errors.New("no function call found") 87 | } 88 | tool, ok := a.Tools[msg.FunctionCall.Name] 89 | if !ok { 90 | return nil, fmt.Errorf("tool %q not found. Available tools: %s", msg.FunctionCall.Name, strings.Join(maps.Keys(a.Tools), ", ")) 91 | } 92 | return &ChainAgentAction{ 93 | Tool: tool, 94 | Args: []byte(msg.FunctionCall.Args), 95 | }, nil 96 | } 97 | // ParseChainAgentAction parses a "tool(args)" call (or delegates a native function call) into an action, passing text-path args through every ActionArgPreprocessor in order. 98 | func (a *ChainAgent[T, S]) ParseChainAgentAction(msg *engines.ChatMessage) (*ChainAgentAction, error) { 99 | if msg.FunctionCall != nil { 100 | return a.parseNativeFunctionCall(msg) 101 | } 102 | matches := actionRegex.FindStringSubmatch(msg.Text) 103 | if len(matches) != 3 { 104 | return nil, fmt.Errorf("invalid action format: message must start with `%s: ` and the action call itself must match regex %q", ActionCode, actionRegex.String()) 105 | } 106 | toolName := matches[actionRegex.SubexpIndex("tool")] 107 | toolArgs := matches[actionRegex.SubexpIndex("args")] 108 | // Unknown tool names are reported back to the model along with the registered names. 109 | tool, ok := a.Tools[toolName] 110 | if !ok { 111 | return nil, fmt.Errorf("tool %q not found.
Available tools: %s", toolName, strings.Join(maps.Keys(a.Tools), ", ")) 112 | } 113 | // Preprocessors (by default the JSON auto-fixer installed by NewChainAgent) may rewrite the raw args before the tool sees them. 114 | jsonArgs := json.RawMessage(toolArgs) 115 | for _, processor := range a.ActionArgPreprocessors { 116 | var err error 117 | jsonArgs, err = processor.Process(jsonArgs) 118 | if err != nil { 119 | return nil, fmt.Errorf("error while preprocessing action args: %w", err) 120 | } 121 | } 122 | 123 | return &ChainAgentAction{ 124 | Tool: tool, 125 | Args: jsonArgs, 126 | }, nil 127 | } 128 | // ChainAgentAnswer wraps the agent's parsed final answer. 129 | type ChainAgentAnswer[S any] struct { 130 | Content S 131 | } 132 | // parseChainAgentAnswer parses answer.Text with the task's AnswerParser. 133 | func (a *ChainAgent[T, S]) parseChainAgentAnswer(answer *engines.ChatMessage) (*ChainAgentAnswer[S], error) { 134 | output, err := a.Task.AnswerParser(answer.Text) 135 | if err != nil { 136 | return nil, err 137 | } 138 | return &ChainAgentAnswer[S]{ 139 | Content: output, 140 | }, nil 141 | } 142 | // ChainAgent alternates between model responses and tool executions until the model produces an Answer that passes every output validator. 143 | type ChainAgent[T Representable, S Representable] struct { 144 | Engine engines.LLM 145 | Task *Task[T, S] 146 | Tools map[string]toolsPkg.Tool // keyed by Tool.Name() 147 | InputValidators []func(T) error 148 | OutputValidators []func(S) error 149 | MaxSolutionAttempts int // 0 means unlimited 150 | MaxRestarts int // additional full runs Run performs after a failed one 151 | Memory memory.Memory 152 | ActionConfirmation func(action *ChainAgentAction) bool // optional; returning false cancels the action 153 | ActionArgPreprocessors []toolsPkg.PreprocessingTool 154 | nativeFunctionSpecs []engines.FunctionSpecs // accumulated by WithTools for function-calling engines 155 | } 156 | // ChainAgentMessage is a transcript element that can be encoded for a specific engine. 157 | type ChainAgentMessage interface { 158 | Encode(targetEngine engines.LLM) *engines.ChatMessage 159 | } 160 | // ChainAgentError reports a failure (a tool error or a cancelled action) back to the model. 161 | type ChainAgentError struct { 162 | Content string 163 | ToolName string 164 | } 165 | // Encode renders the error as a function-role message for function-calling engines, and as a system "Error: ..." message otherwise. 166 | func (a *ChainAgentError) Encode(targetEngine engines.LLM) *engines.ChatMessage { 167 | if _, ok := targetEngine.(engines.LLMWithFunctionCalls); ok { 168 | return &engines.ChatMessage{ 169 | Role: engines.ConvRoleFunction, 170 | Name: a.ToolName, 171 | Text: fmt.Sprintf("An error has occured: %s", a.Content), 172 | } 173 | } 174 | return &engines.ChatMessage{ 175 | Role: engines.ConvRoleSystem, 176 | Text:
fmt.Sprintf(MessageFormat, ErrorCode, a.Content), 177 | } 178 | } 179 | // ChainAgentObservation carries a tool's output back to the model. 180 | type ChainAgentObservation struct { 181 | Content string 182 | ToolName string 183 | } 184 | // Encode renders the observation as a function-role message for function-calling engines, and as a system "Observation: ..." message otherwise. 185 | func (a *ChainAgentObservation) Encode(targetEngine engines.LLM) *engines.ChatMessage { 186 | if _, ok := targetEngine.(engines.LLMWithFunctionCalls); ok { 187 | return &engines.ChatMessage{ 188 | Role: engines.ConvRoleFunction, 189 | Name: a.ToolName, 190 | Text: a.Content, 191 | } 192 | } 193 | return &engines.ChatMessage{ 194 | Role: engines.ConvRoleSystem, 195 | Text: fmt.Sprintf(MessageFormat, ObservationCode, a.Content), 196 | } 197 | } 198 | // executeAction asks ActionConfirmation (when set) for approval, runs the tool, and wraps its output or error for the model. 199 | func (a *ChainAgent[T, S]) executeAction(action *ChainAgentAction) (obs ChainAgentMessage) { 200 | if a.ActionConfirmation != nil && !a.ActionConfirmation(action) { 201 | return &ChainAgentError{ 202 | Content: "action cancelled by the user", 203 | ToolName: action.Tool.Name(), 204 | } 205 | } 206 | actionOutput, err := action.Tool.Execute(action.Args) 207 | if err != nil { 208 | return &ChainAgentError{ 209 | Content: err.Error(), 210 | ToolName: action.Tool.Name(), 211 | } 212 | } 213 | return &ChainAgentObservation{ 214 | Content: string(actionOutput), 215 | ToolName: action.Tool.Name(), 216 | } 217 | } 218 | // processFunctionCallMessage handles a native function call; it only produces follow-up messages, never an answer. 219 | func (a *ChainAgent[T, S]) processFunctionCallMessage(response *engines.ChatMessage) (nextMessages []*engines.ChatMessage, answer *ChainAgentAnswer[S]) { 220 | action, err := a.parseNativeFunctionCall(response) 221 | if err != nil { 222 | nextMessages = append(nextMessages, &engines.ChatMessage{ 223 | Role: engines.ConvRoleFunction, 224 | Name: response.FunctionCall.Name, 225 | Text: fmt.Sprintf(MessageFormat, ErrorCode, err.Error()), 226 | }) 227 | return 228 | } 229 | nextMessages = append(nextMessages, a.executeAction(action).Encode(a.Engine)) 230 | return 231 | } 232 | // parseResponse handles one model response: it executes an Action, or parses and validates an Answer, returning the follow-up messages plus the answer once one is accepted. 233 | func (a *ChainAgent[T, S]) parseResponse(response *engines.ChatMessage) (nextMessages []*engines.ChatMessage, answer *ChainAgentAnswer[S]) { 234 | if
response.FunctionCall != nil { 235 | return a.processFunctionCallMessage(response) 236 | } 237 | var exp *regexp.Regexp 238 | var ops [][]string 239 | for _, candidateExp := range []*regexp.Regexp{operationRegex, operationRegexWithoutEnd} { // NOTE(review): the content groups are greedy, so a response yields at most one op (the first `Code:`) and later ops are folded into its content - confirm intended 240 | ops = candidateExp.FindAllStringSubmatch(response.Text, -1) 241 | if len(ops) > 0 { 242 | exp = candidateExp 243 | break 244 | } 245 | } 246 | if len(ops) == 0 { 247 | return // consider the message a thought anyway 248 | } 249 | for _, op := range ops { 250 | opCode := op[exp.SubexpIndex("code")] 251 | opContent := op[exp.SubexpIndex("content")] 252 | switch opCode { 253 | case ThoughtCode: 254 | break 255 | case ActionCode: 256 | action, err := a.ParseChainAgentAction(&engines.ChatMessage{ 257 | Role: engines.ConvRoleAssistant, 258 | Text: opContent, 259 | }) 260 | if err != nil { 261 | nextMessages = append(nextMessages, &engines.ChatMessage{ 262 | Role: engines.ConvRoleSystem, 263 | Text: fmt.Sprintf(MessageFormat, ErrorCode, err.Error()), 264 | }) 265 | break 266 | } 267 | obs := a.executeAction(action) 268 | nextMessages = append(nextMessages, obs.Encode(a.Engine)) 269 | case AnswerCode: 270 | answer, err := a.parseChainAgentAnswer(&engines.ChatMessage{ 271 | Role: engines.ConvRoleAssistant, 272 | Text: opContent, 273 | }) 274 | if err != nil { 275 | nextMessages = append(nextMessages, &engines.ChatMessage{ 276 | Role: engines.ConvRoleSystem, 277 | Text: fmt.Sprintf(MessageFormat, ErrorCode, err.Error()), 278 | }) 279 | break 280 | } 281 | err = a.validateAnswer(answer.Content) 282 | if err != nil { 283 | nextMessages = append(nextMessages, &engines.ChatMessage{ 284 | Role: engines.ConvRoleSystem, 285 | Text: fmt.Sprintf(MessageFormat, ErrorCode, err.Error()), 286 | }) 287 | break 288 | } 289 | return nextMessages, answer 290 | default: 291 | break // consider the message a thought 292 | } 293 | } 294 | return nextMessages, nil 295 | } 296 | // validateAnswer runs every OutputValidator and merges their errors; a nil result means the answer is accepted. 297 | func (a *ChainAgent[T, S]) validateAnswer(answer S) error { 298 | var
answerErr *multierror.Error 299 | for _, validator := range a.OutputValidators { 300 | answerErr = multierror.Append(answerErr, validator(answer)) 301 | } 302 | return answerErr.ErrorOrNil() 303 | } 304 | // logMessages writes each message at debug level, showing function calls as name(args). 305 | func (a *ChainAgent[T, S]) logMessages(msg ...*engines.ChatMessage) { 306 | for _, m := range msg { 307 | if m.FunctionCall != nil { 308 | log.Debugf("[%s] [function_call] %s(%s)", m.Role, m.FunctionCall.Name, m.FunctionCall.Args) 309 | continue 310 | } 311 | log.Debugf("[%s] %s", m.Role, m.Text) 312 | } 313 | } 314 | // chat uses ChatWithFunctions with the registered native specs when the engine supports it, and Chat otherwise. 315 | func (a *ChainAgent[T, S]) chat(prompt *engines.ChatPrompt) (*engines.ChatMessage, error) { 316 | if engine, ok := a.Engine.(engines.LLMWithFunctionCalls); ok { 317 | return engine.ChatWithFunctions(prompt, a.nativeFunctionSpecs) 318 | } 319 | return a.Engine.Chat(prompt) 320 | } 321 | // run performs a single attempt: validate the input, send the compiled task, then loop response -> parse/act -> prompt until an answer is accepted or an error occurs. 322 | func (a *ChainAgent[T, S]) run(input T) (output S, err error) { 323 | var inputErr *multierror.Error 324 | for _, validator := range a.InputValidators { 325 | inputErr = multierror.Append(inputErr, validator(input)) 326 | } 327 | if inputErr.ErrorOrNil() != nil { 328 | return output, fmt.Errorf("invalid input: %w", inputErr) 329 | } 330 | visibleTools := a.Tools 331 | if _, ok := a.Engine.(engines.LLMWithFunctionCalls); ok && len(a.nativeFunctionSpecs) >= len(a.Tools) { // hide tools from the text prompt only when every tool is exposed natively; otherwise fall back to listing them 332 | visibleTools = map[string]toolsPkg.Tool{} 333 | } 334 | taskPrompt := a.Task.Compile(input, visibleTools) 335 | a.logMessages(taskPrompt.History...)
336 | err = a.Memory.AddPrompt(taskPrompt) 337 | if err != nil { 338 | return output, fmt.Errorf("failed to add prompt to memory: %w", err) 339 | } 340 | response, err := a.chat(taskPrompt) 341 | if err != nil { 342 | return output, fmt.Errorf("failed to predict response: %w", err) 343 | } 344 | a.logMessages(response) 345 | err = a.Memory.Add(response) 346 | if err != nil { 347 | return output, fmt.Errorf("failed to add response to memory: %w", err) 348 | } 349 | stepsExecuted := 0 350 | for { 351 | nextMessages, answer := a.parseResponse(response) 352 | a.logMessages(nextMessages...) 353 | if answer != nil { 354 | return answer.Content, nil 355 | } 356 | prompt, err := a.Memory.PromptWithContext(nextMessages...) 357 | if err != nil { 358 | return output, fmt.Errorf("failed to generate prompt: %w", err) 359 | } 360 | if a.MaxSolutionAttempts > 0 && stepsExecuted > a.MaxSolutionAttempts { // NOTE(review): `>` allows MaxSolutionAttempts+1 follow-up calls after the initial response - confirm intended 361 | return output, errors.New("max solution attempts reached") 362 | } 363 | response, err = a.chat(prompt) 364 | if err != nil { 365 | return output, fmt.Errorf("failed to predict response: %w", err) 366 | } 367 | a.logMessages(response) 368 | err = a.Memory.Add(response) 369 | if err != nil { 370 | return output, fmt.Errorf("failed to add response to memory: %w", err) 371 | } 372 | stepsExecuted++ 373 | } 374 | } 375 | // Run executes the agent, retrying a failed attempt up to MaxRestarts more times and returning the last error if none succeeds. NOTE(review): a.Memory is not reset between attempts. 376 | func (a *ChainAgent[T, S]) Run(input T) (output S, err error) { 377 | for i := 0; i <= a.MaxRestarts; i++ { 378 | output, err = a.run(input) 379 | if err == nil { 380 | return output, nil 381 | } 382 | } 383 | return output, err 384 | } 385 | // NewChainAgent returns an agent with no tools or validators and the JSON auto-fixer registered as the default action-argument preprocessor. 386 | func NewChainAgent[T Representable, S Representable](engine engines.LLM, task *Task[T, S], memory memory.Memory) *ChainAgent[T, S] { 387 | return &ChainAgent[T, S]{ 388 | Engine: engine, 389 | Task: task, 390 | Tools: map[string]toolsPkg.Tool{}, 391 | Memory: memory, 392 | ActionArgPreprocessors: []toolsPkg.PreprocessingTool{ 393 | toolsPkg.NewJSONAutoFixer(engine, 3), 394 | }, 395 | } 396 | } 397 | // setNativeLLMFunctions converts tools to native function specs and appends them to those already registered; it returns errNativeFunctionsUnsupported for engines without function calling. 398 | func
(a *ChainAgent[T, S]) setNativeLLMFunctions(tools ...toolsPkg.Tool) (err error) { 399 | _, ok := a.Engine.(engines.LLMWithFunctionCalls) 400 | if !ok { 401 | return errNativeFunctionsUnsupported 402 | } 403 | functions := make([]engines.FunctionSpecs, len(tools)) 404 | for i, tool := range tools { 405 | function, err := toolsPkg.ConvertToNativeFunctionSpecs(tool) 406 | if err != nil { 407 | return fmt.Errorf("failed to convert tool to native function specs: %w", err) 408 | } 409 | functions[i] = function 410 | } 411 | a.nativeFunctionSpecs = append(a.nativeFunctionSpecs, functions...) // append rather than overwrite, so repeated WithTools calls keep earlier tools callable 412 | return nil 413 | } 414 | // WithTools registers tools (also installing any PreprocessingTool among them as an argument preprocessor); for function-calling engines it also registers their native specs, falling back to prompt-listed tools if conversion fails. 415 | func (a *ChainAgent[T, S]) WithTools(tools ...toolsPkg.Tool) *ChainAgent[T, S] { 416 | err := a.setNativeLLMFunctions(tools...) 417 | if err != nil && !errors.Is(err, errNativeFunctionsUnsupported) { 418 | log.Warnf("failed to set native LLM functions, using fallback: %v", err) 419 | } 420 | for _, tool := range tools { 421 | a.Tools[tool.Name()] = tool 422 | if preprocessor, ok := tool.(toolsPkg.PreprocessingTool); ok { 423 | a.ActionArgPreprocessors = append(a.ActionArgPreprocessors, preprocessor) 424 | } 425 | } 426 | return a 427 | } 428 | // WithInputValidators adds validators that run against the input at the start of every attempt. 429 | func (a *ChainAgent[T, S]) WithInputValidators(validators ...func(T) error) *ChainAgent[T, S] { 430 | a.InputValidators = append(a.InputValidators, validators...) 431 | return a 432 | } 433 | // WithOutputValidators adds validators that every candidate answer must pass. 434 | func (a *ChainAgent[T, S]) WithOutputValidators(validators ...func(S) error) *ChainAgent[T, S] { 435 | a.OutputValidators = append(a.OutputValidators, validators...)
436 | return a 437 | } 438 | // WithMaxSolutionAttempts bounds the follow-up model calls per attempt; 0 (the zero value) means unlimited. 439 | func (a *ChainAgent[T, S]) WithMaxSolutionAttempts(max int) *ChainAgent[T, S] { 440 | a.MaxSolutionAttempts = max 441 | return a 442 | } 443 | // WithActionConfirmation sets a callback that must approve every action before it runs. 444 | func (a *ChainAgent[T, S]) WithActionConfirmation(actionConfirmationProvider func(*ChainAgentAction) bool) *ChainAgent[T, S] { 445 | a.ActionConfirmation = actionConfirmationProvider 446 | return a 447 | } 448 | // WithRestarts sets how many additional attempts Run makes after a failed one. 449 | func (a *ChainAgent[T, S]) WithRestarts(maxRestarts int) *ChainAgent[T, S] { 450 | a.MaxRestarts = maxRestarts 451 | return a 452 | } 453 | -------------------------------------------------------------------------------- /agents/agent_test.go: -------------------------------------------------------------------------------- 1 | package agents 2 | 3 | import ( 4 | "encoding/json" 5 | "errors" 6 | "testing" 7 | 8 | "github.com/golang/mock/gomock" 9 | "github.com/natexcvi/go-llm/engines" 10 | memorymocks "github.com/natexcvi/go-llm/memory/mocks" 11 | "github.com/natexcvi/go-llm/tools" 12 | toolmocks "github.com/natexcvi/go-llm/tools/mocks" 13 | "github.com/stretchr/testify/assert" 14 | "github.com/stretchr/testify/require" 15 | ) 16 | 17 | type MockEngine struct { 18 | Responses []*engines.ChatMessage 19 | Functions []engines.FunctionSpecs 20 | } 21 | 22 | func (engine *MockEngine) Chat(prompt *engines.ChatPrompt) (*engines.ChatMessage, error) { 23 | if len(engine.Responses) == 0 { 24 | return nil, errors.New("no more responses") 25 | } 26 | response := engine.Responses[0] 27 | engine.Responses = engine.Responses[1:] 28 | return response, nil 29 | } 30 | 31 | func (engine *MockEngine) SetFunctions(funcs ...engines.FunctionSpecs) { 32 | engine.Functions = funcs 33 | } 34 | 35 | type Str string 36 | 37 | func (s *Str) Encode() string { 38 | return string(*s) 39 | } 40 | 41 | func (s *Str) Schema() string { 42 | return "" 43 | } 44 | 45 | func (s *Str) UnmarshalJSON(data []byte) error { 46 | var str string 47 | if err := json.Unmarshal(data, &str); err != nil { 48 | return err
49 | } 50 | *s = Str(str) 51 | return nil 52 | } 53 | 54 | func newStr(str string) *Str { 55 | s := Str(str) 56 | return &s 57 | } 58 | 59 | func newMockMemory(t *testing.T) *memorymocks.MockMemory { 60 | ctrl := gomock.NewController(t) 61 | buffer := make([]*engines.ChatMessage, 0) 62 | mock := memorymocks.NewMockMemory(ctrl) 63 | mock.EXPECT().Add(gomock.Any()).AnyTimes().DoAndReturn(func(msg *engines.ChatMessage) error { 64 | buffer = append(buffer, msg) 65 | return nil 66 | }) 67 | mock.EXPECT().AddPrompt(gomock.Any()).AnyTimes().DoAndReturn(func(prompt *engines.ChatPrompt) error { 68 | buffer = append(buffer, prompt.History...) 69 | return nil 70 | }) 71 | mock.EXPECT().PromptWithContext(gomock.Any()).AnyTimes().DoAndReturn(func(nextMessages ...*engines.ChatMessage) (*engines.ChatPrompt, error) { 72 | buffer = append(buffer, nextMessages...) 73 | return &engines.ChatPrompt{ 74 | History: buffer, 75 | }, nil 76 | }) 77 | return mock 78 | } 79 | 80 | func newMockTool(t *testing.T, name, description string, argSchema json.RawMessage, impl func(json.RawMessage) (json.RawMessage, error)) *toolmocks.MockTool { 81 | ctrl := gomock.NewController(t) 82 | mock := toolmocks.NewMockTool(ctrl) 83 | mock.EXPECT().Execute(gomock.Any()).AnyTimes().DoAndReturn(impl) 84 | mock.EXPECT().Name().AnyTimes().Return(name) 85 | mock.EXPECT().Description().AnyTimes().Return(description) 86 | mock.EXPECT().ArgsSchema().AnyTimes().Return(argSchema) 87 | return mock 88 | } 89 | 90 | func TestChainAgent(t *testing.T) { 91 | testCases := []struct { 92 | name string 93 | agent *ChainAgent[*Str, *Str] 94 | input *Str 95 | output *Str 96 | }{ 97 | { 98 | name: "simple", 99 | agent: &ChainAgent[*Str, *Str]{ 100 | Engine: &MockEngine{ 101 | Responses: []*engines.ChatMessage{ 102 | { 103 | Role: engines.ConvRoleAssistant, 104 | Text: `Action: echo("world")`, 105 | }, 106 | { 107 | Role: engines.ConvRoleAssistant, 108 | Text: `Answer: "Hello world"`, 109 | }, 110 | }, 111 | }, 112 | Task:
&Task[*Str, *Str]{ 113 | Description: "Say hello to an entity you find yourself", 114 | AnswerParser: func(text string) (*Str, error) { 115 | var output string 116 | err := json.Unmarshal([]byte(text), &output) 117 | require.NoError(t, err) 118 | return newStr(output), nil 119 | }, 120 | }, 121 | Memory: newMockMemory(t), 122 | Tools: map[string]tools.Tool{ 123 | "echo": newMockTool( 124 | t, 125 | "echo", 126 | "echoes the input", 127 | json.RawMessage(`"the string to echo"`), 128 | func(args json.RawMessage) (json.RawMessage, error) { 129 | return args, nil 130 | }, 131 | ), 132 | }, 133 | OutputValidators: []func(*Str) error{ 134 | func(output *Str) error { 135 | if *output == "" { 136 | return errors.New("output is empty") 137 | } 138 | return nil 139 | }, 140 | }, 141 | }, 142 | input: newStr("hello"), 143 | output: newStr("Hello world"), 144 | }, 145 | { 146 | name: "simple with native LLM functions", 147 | agent: &ChainAgent[*Str, *Str]{ 148 | Engine: &MockEngine{ 149 | Responses: []*engines.ChatMessage{ 150 | { 151 | Role: engines.ConvRoleAssistant, 152 | Text: "", 153 | FunctionCall: &engines.FunctionCall{ 154 | Name: "echo", 155 | Args: `{"msg": "world"}`, 156 | }, 157 | }, 158 | { 159 | Role: engines.ConvRoleAssistant, 160 | Text: `Answer: "Hello world"`, 161 | }, 162 | }, 163 | }, 164 | Task: &Task[*Str, *Str]{ 165 | Description: "Say hello to an entity you find yourself", 166 | AnswerParser: func(text string) (*Str, error) { 167 | var output string 168 | err := json.Unmarshal([]byte(text), &output) 169 | require.NoError(t, err) 170 | return newStr(output), nil 171 | }, 172 | }, 173 | Memory: newMockMemory(t), 174 | Tools: map[string]tools.Tool{ 175 | "echo": newMockTool( 176 | t, 177 | "echo", 178 | "echoes the input", 179 | json.RawMessage(`{"msg": "the string to echo"}`), 180 | func(args json.RawMessage) (json.RawMessage, error) { 181 | var parsedArgs struct { 182 | Msg string `json:"msg"` 183 | } 184 | err := json.Unmarshal(args, &parsedArgs) 185 |
require.NoError(t, err) 186 | return json.Marshal(parsedArgs.Msg) 187 | }, 188 | ), 189 | }, 190 | OutputValidators: []func(*Str) error{ 191 | func(output *Str) error { 192 | if *output == "" { 193 | return errors.New("output is empty") 194 | } 195 | return nil 196 | }, 197 | }, 198 | }, 199 | input: newStr("hello"), 200 | output: newStr("Hello world"), 201 | }, 202 | { 203 | name: "empty output makes validator fail", 204 | agent: &ChainAgent[*Str, *Str]{ 205 | Engine: &MockEngine{ 206 | Responses: []*engines.ChatMessage{ 207 | { 208 | Role: engines.ConvRoleAssistant, 209 | Text: `Action: echo("world")`, 210 | }, 211 | { 212 | Role: engines.ConvRoleAssistant, 213 | Text: `Answer: ""`, 214 | }, 215 | { 216 | Role: engines.ConvRoleAssistant, 217 | Text: `Thought: That's right, the output is empty. I'll fix it`, 218 | }, 219 | { 220 | Role: engines.ConvRoleAssistant, 221 | Text: `Answer: "Hello world"`, 222 | }, 223 | }, 224 | }, 225 | Task: &Task[*Str, *Str]{ 226 | Description: "Say hello to an entity you find yourself", 227 | AnswerParser: func(text string) (*Str, error) { 228 | var output string 229 | err := json.Unmarshal([]byte(text), &output) 230 | require.NoError(t, err) 231 | return newStr(output), nil 232 | }, 233 | }, 234 | Memory: newMockMemory(t), 235 | Tools: map[string]tools.Tool{ 236 | "echo": newMockTool( 237 | t, 238 | "echo", 239 | "echoes the input", 240 | json.RawMessage(`"the string to echo"`), 241 | func(args json.RawMessage) (json.RawMessage, error) { 242 | return args, nil 243 | }, 244 | ), 245 | }, 246 | OutputValidators: []func(*Str) error{ 247 | func(output *Str) error { 248 | if *output == "" { 249 | return errors.New("output is empty") 250 | } 251 | return nil 252 | }, 253 | }, 254 | }, 255 | input: newStr("hello"), 256 | output: newStr("Hello world"), 257 | }, 258 | } 259 | 260 | for _, testCase := range testCases { 261 | t.Run(testCase.name, func(t *testing.T) { 262 | output, err := testCase.agent.Run(testCase.input) 263 |
require.NoError(t, err) 264 | assert.Equal(t, *testCase.output, *output) 265 | }) 266 | } 267 | } 268 | -------------------------------------------------------------------------------- /agents/agent_tool.go: -------------------------------------------------------------------------------- 1 | package agents 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "strings" 7 | 8 | "github.com/natexcvi/go-llm/engines" 9 | "github.com/natexcvi/go-llm/memory" 10 | "github.com/natexcvi/go-llm/tools" 11 | "github.com/samber/lo" 12 | ) 13 | 14 | type genericRequest struct { 15 | TaskDescription string `json:"task"` 16 | Input string `json:"input"` 17 | } 18 | 19 | func (r genericRequest) Encode() string { 20 | return r.Input 21 | } 22 | 23 | func (r genericRequest) Schema() string { 24 | return `{"task": "a description of the task you want to give the agent, including helpful examples.", "input": "the specific input on which the agent should act."}` 25 | } 26 | 27 | type genericResponse struct { 28 | output string 29 | } 30 | 31 | func (r genericResponse) Encode() string { 32 | return r.output 33 | } 34 | 35 | func (r genericResponse) Schema() string { 36 | return "" 37 | } 38 | 39 | type GenericAgentTool struct { 40 | engine engines.LLM 41 | tools []tools.Tool 42 | } 43 | 44 | func (ga *GenericAgentTool) Name() string { 45 | return "smart_agent" 46 | } 47 | 48 | func (ga *GenericAgentTool) Description() string { 49 | return "A smart agent you can delegate tasks to. Use for relatively larger tasks."
+ 50 | lo.If( 51 | len(ga.tools) > 0, 52 | " The agent will have access to the following tools: "+ 53 | strings.Join(lo.Map(ga.tools, func(tool tools.Tool, _ int) string { 54 | return tool.Name() 55 | }), ", ")+".", 56 | ).Else("") 57 | } 58 | 59 | func (ga *GenericAgentTool) Execute(args json.RawMessage) (json.RawMessage, error) { 60 | var request genericRequest 61 | err := json.Unmarshal(args, &request) 62 | if err != nil { 63 | return nil, fmt.Errorf("invalid arguments: %s", err.Error()) 64 | } 65 | task := &Task[genericRequest, genericResponse]{ 66 | Description: request.TaskDescription, 67 | Examples: []Example[genericRequest, genericResponse]{}, 68 | AnswerParser: func(res string) (genericResponse, error) { 69 | return genericResponse{res}, nil 70 | }, 71 | } 72 | agent := NewChainAgent(ga.engine, task, memory.NewBufferedMemory(10)).WithTools(ga.tools...) 73 | response, err := agent.Run(request) 74 | if err != nil { 75 | return nil, fmt.Errorf("error running agent: %s", err.Error()) 76 | } 77 | return json.Marshal(response.output) 78 | } 79 | 80 | func (ga *GenericAgentTool) ArgsSchema() json.RawMessage { 81 | return []byte(`{"task": "a description of the task you want to give the agent, including helpful examples.", "input": "the specific input on which the agent should act."}`) 82 | } 83 | 84 | func (ga *GenericAgentTool) CompactArgs(args json.RawMessage) json.RawMessage { 85 | return args 86 | } 87 | 88 | func NewGenericAgentTool(engine engines.LLM, tools []tools.Tool) *GenericAgentTool { 89 | return &GenericAgentTool{engine: engine, tools: tools} 90 | } 91 | -------------------------------------------------------------------------------- /agents/task.go: -------------------------------------------------------------------------------- 1 | package agents 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | 7 | "github.com/natexcvi/go-llm/engines" 8 | "github.com/natexcvi/go-llm/tools" 9 | "github.com/samber/lo" 10 | ) 11 | 12 | type Representable interface { 13 | 
Encode() string 14 | Schema() string 15 | } 16 | 17 | type Example[T Representable, S Representable] struct { 18 | Input T 19 | IntermediarySteps []*engines.ChatMessage 20 | Answer S 21 | } 22 | 23 | type Task[T Representable, S Representable] struct { 24 | Description string 25 | Examples []Example[T, S] 26 | AnswerParser func(string) (S, error) 27 | } 28 | 29 | func (task *Task[T, S]) Compile(input T, tools map[string]tools.Tool) *engines.ChatPrompt { 30 | answerSchema := lo.IfF( 31 | task.Examples != nil && len(task.Examples) > 0, 32 | func() string { return task.Examples[0].Answer.Schema() }, 33 | ).Else("") 34 | 35 | prompt := &engines.ChatPrompt{ 36 | History: []*engines.ChatMessage{ 37 | { 38 | Role: engines.ConvRoleSystem, 39 | Text: fmt.Sprintf("You are a smart, autonomous agent given the task below. "+ 40 | "You will be given input from the user in the following format:\n\n"+ 41 | "%s\n\n Complete the task step-by-step, "+ 42 | "reasoning about your solution steps by sending thought messages in format "+ 43 | "`%s: (your reflection)%s`. When you are ready to return your response, "+ 44 | "send an answer message in format `%s: %s%s`. Remember: you are on your own - "+ 45 | "do not ask for any clarifications, except by using appropriate tools "+ 46 | "that enable interaction with the user. 
You should determine when you are "+ 47 | "done with the task, and return your answer.", 48 | input.Schema(), ThoughtCode, EndMarker, AnswerCode, answerSchema, EndMarker), 49 | }, 50 | }, 51 | } 52 | task.enrichPromptWithTools(tools, prompt) 53 | prompt.History = append(prompt.History, &engines.ChatMessage{ 54 | Role: engines.ConvRoleUser, 55 | Text: task.Description, 56 | }) 57 | task.enrichPromptWithExamples(prompt) 58 | prompt.History = append(prompt.History, &engines.ChatMessage{ 59 | Role: engines.ConvRoleUser, 60 | Text: input.Encode(), 61 | }) 62 | return prompt 63 | } 64 | 65 | func (task *Task[T, S]) enrichPromptWithExamples(prompt *engines.ChatPrompt) { 66 | if len(task.Examples) == 0 { 67 | return 68 | } 69 | for _, example := range task.Examples { 70 | prompt.History = append(prompt.History, &engines.ChatMessage{ 71 | Role: engines.ConvRoleUser, 72 | Text: example.Input.Encode(), 73 | }) 74 | for _, step := range example.IntermediarySteps { 75 | prompt.History = append(prompt.History, step) 76 | } 77 | answerRepresentation := example.Answer.Encode() 78 | prompt.History = append(prompt.History, &engines.ChatMessage{ 79 | Role: engines.ConvRoleAssistant, 80 | Text: fmt.Sprintf(MessageFormat, AnswerCode, answerRepresentation), 81 | }) 82 | } 83 | } 84 | 85 | func (*Task[T, S]) enrichPromptWithTools(tools map[string]tools.Tool, prompt *engines.ChatPrompt) { 86 | if len(tools) == 0 { 87 | return 88 | } 89 | toolsList := make([]string, 0, len(tools)) 90 | for name, tool := range tools { 91 | toolsList = append(toolsList, fmt.Sprintf("%s(%s) # %s", name, tool.ArgsSchema(), tool.Description())) 92 | } 93 | prompt.History = append(prompt.History, &engines.ChatMessage{ 94 | Role: engines.ConvRoleSystem, 95 | Text: fmt.Sprintf("Here are some tools you can use. To use a tool, "+ 96 | "send a message in the form of `%s: tool_name(args)%s`, "+ 97 | "where `args` is a valid one-line JSON representation of the arguments"+ 98 | " to the tool, as specified for it. 
You will get "+ 99 | "the output in "+ 100 | "a message beginning with `%s: `, or an error message beginning "+ 101 | "with `%s: `.\n\nTools:\n%s", 102 | ActionCode, EndMarker, ObservationCode, ErrorCode, strings.Join(toolsList, "\n")), 103 | }) 104 | } 105 | -------------------------------------------------------------------------------- /engines/engine.go: -------------------------------------------------------------------------------- 1 | package engines 2 | 3 | //go:generate mockgen -source=engine.go -destination=mocks/engine.go -package=mocks 4 | type LLM interface { 5 | Chat(prompt *ChatPrompt) (*ChatMessage, error) 6 | } 7 | 8 | type LLMWithFunctionCalls interface { 9 | LLM 10 | ChatWithFunctions(prompt *ChatPrompt, functions []FunctionSpecs) (*ChatMessage, error) 11 | } 12 | 13 | type ParameterSpecs struct { 14 | Type string `json:"type"` 15 | Description string `json:"description,omitempty"` 16 | Properties map[string]*ParameterSpecs `json:"properties,omitempty"` 17 | Required []string `json:"required,omitempty"` 18 | Items *ParameterSpecs `json:"items,omitempty"` 19 | Enum []any `json:"enum,omitempty"` 20 | } 21 | 22 | type FunctionSpecs struct { 23 | Name string `json:"name"` 24 | Description string `json:"description"` 25 | Parameters *ParameterSpecs `json:"parameters"` 26 | } 27 | -------------------------------------------------------------------------------- /engines/mocks/engine.go: -------------------------------------------------------------------------------- 1 | // Code generated by MockGen. DO NOT EDIT. 2 | // Source: engine.go 3 | 4 | // Package mocks is a generated GoMock package. 5 | package mocks 6 | 7 | import ( 8 | reflect "reflect" 9 | 10 | gomock "github.com/golang/mock/gomock" 11 | engines "github.com/natexcvi/go-llm/engines" 12 | ) 13 | 14 | // MockLLM is a mock of LLM interface. 
15 | type MockLLM struct { 16 | ctrl *gomock.Controller 17 | recorder *MockLLMMockRecorder 18 | } 19 | 20 | // MockLLMMockRecorder is the mock recorder for MockLLM. 21 | type MockLLMMockRecorder struct { 22 | mock *MockLLM 23 | } 24 | 25 | // NewMockLLM creates a new mock instance. 26 | func NewMockLLM(ctrl *gomock.Controller) *MockLLM { 27 | mock := &MockLLM{ctrl: ctrl} 28 | mock.recorder = &MockLLMMockRecorder{mock} 29 | return mock 30 | } 31 | 32 | // EXPECT returns an object that allows the caller to indicate expected use. 33 | func (m *MockLLM) EXPECT() *MockLLMMockRecorder { 34 | return m.recorder 35 | } 36 | 37 | // Chat mocks base method. 38 | func (m *MockLLM) Chat(prompt *engines.ChatPrompt) (*engines.ChatMessage, error) { 39 | m.ctrl.T.Helper() 40 | ret := m.ctrl.Call(m, "Chat", prompt) 41 | ret0, _ := ret[0].(*engines.ChatMessage) 42 | ret1, _ := ret[1].(error) 43 | return ret0, ret1 44 | } 45 | 46 | // Chat indicates an expected call of Chat. 47 | func (mr *MockLLMMockRecorder) Chat(prompt interface{}) *gomock.Call { 48 | mr.mock.ctrl.T.Helper() 49 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Chat", reflect.TypeOf((*MockLLM)(nil).Chat), prompt) 50 | } 51 | 52 | // MockLLMWithFunctionCalls is a mock of LLMWithFunctionCalls interface. 53 | type MockLLMWithFunctionCalls struct { 54 | ctrl *gomock.Controller 55 | recorder *MockLLMWithFunctionCallsMockRecorder 56 | } 57 | 58 | // MockLLMWithFunctionCallsMockRecorder is the mock recorder for MockLLMWithFunctionCalls. 59 | type MockLLMWithFunctionCallsMockRecorder struct { 60 | mock *MockLLMWithFunctionCalls 61 | } 62 | 63 | // NewMockLLMWithFunctionCalls creates a new mock instance. 64 | func NewMockLLMWithFunctionCalls(ctrl *gomock.Controller) *MockLLMWithFunctionCalls { 65 | mock := &MockLLMWithFunctionCalls{ctrl: ctrl} 66 | mock.recorder = &MockLLMWithFunctionCallsMockRecorder{mock} 67 | return mock 68 | } 69 | 70 | // EXPECT returns an object that allows the caller to indicate expected use. 
71 | func (m *MockLLMWithFunctionCalls) EXPECT() *MockLLMWithFunctionCallsMockRecorder { 72 | return m.recorder 73 | } 74 | 75 | // Chat mocks base method. 76 | func (m *MockLLMWithFunctionCalls) Chat(prompt *engines.ChatPrompt) (*engines.ChatMessage, error) { 77 | m.ctrl.T.Helper() 78 | ret := m.ctrl.Call(m, "Chat", prompt) 79 | ret0, _ := ret[0].(*engines.ChatMessage) 80 | ret1, _ := ret[1].(error) 81 | return ret0, ret1 82 | } 83 | 84 | // Chat indicates an expected call of Chat. 85 | func (mr *MockLLMWithFunctionCallsMockRecorder) Chat(prompt interface{}) *gomock.Call { 86 | mr.mock.ctrl.T.Helper() 87 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Chat", reflect.TypeOf((*MockLLMWithFunctionCalls)(nil).Chat), prompt) 88 | } 89 | 90 | // ChatWithFunctions mocks base method. 91 | func (m *MockLLMWithFunctionCalls) ChatWithFunctions(prompt *engines.ChatPrompt, functions []engines.FunctionSpecs) (*engines.ChatMessage, error) { 92 | m.ctrl.T.Helper() 93 | ret := m.ctrl.Call(m, "ChatWithFunctions", prompt, functions) 94 | ret0, _ := ret[0].(*engines.ChatMessage) 95 | ret1, _ := ret[1].(error) 96 | return ret0, ret1 97 | } 98 | 99 | // ChatWithFunctions indicates an expected call of ChatWithFunctions. 
100 | func (mr *MockLLMWithFunctionCallsMockRecorder) ChatWithFunctions(prompt, functions interface{}) *gomock.Call { 101 | mr.mock.ctrl.T.Helper() 102 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ChatWithFunctions", reflect.TypeOf((*MockLLMWithFunctionCalls)(nil).ChatWithFunctions), prompt, functions) 103 | } 104 | -------------------------------------------------------------------------------- /engines/openai.go: -------------------------------------------------------------------------------- 1 | package engines 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "fmt" 7 | "io" 8 | "net/http" 9 | ) 10 | 11 | var ErrTokenLimitExceeded = fmt.Errorf("token limit exceeded") 12 | var OpenAIBaseURL = "https://api.openai.com" 13 | 14 | type GPT struct { 15 | APIToken string 16 | Model string 17 | PromptTokensUsed int 18 | CompletionTokensUsed int 19 | PromptTokenLimit int 20 | CompletionTokenLimit int 21 | TotalTokenLimit int 22 | Temperature float64 23 | } 24 | 25 | type ChatCompletionRequest struct { 26 | Model string `json:"model"` 27 | Temperature float64 `json:"temperature,omitempty"` 28 | Messages []*ChatMessage `json:"messages"` 29 | Functions []FunctionSpecs `json:"functions,omitempty"` 30 | } 31 | 32 | type ChatCompletionResponse struct { 33 | Choices []struct { 34 | Message *ChatMessage `json:"message"` 35 | } `json:"choices"` 36 | Usage struct { 37 | PromptTokensUsed int `json:"prompt_tokens"` 38 | CompletionTokensUsed int `json:"completion_tokens"` 39 | } `json:"usage"` 40 | } 41 | 42 | func (gpt *GPT) chat(prompt *ChatPrompt, functions []FunctionSpecs) (*ChatMessage, error) { 43 | if gpt.isLimitExceeded() { 44 | return nil, ErrTokenLimitExceeded 45 | } 46 | completionRequest := ChatCompletionRequest{ 47 | Model: gpt.Model, 48 | Messages: prompt.History, 49 | Temperature: gpt.Temperature, 50 | } 51 | if len(functions) > 0 { 52 | completionRequest.Functions = functions 53 | } 54 | bodyJSON, err := json.Marshal(completionRequest) 55 | if err != 
nil { 56 | return nil, err 57 | } 58 | req, err := http.NewRequest( 59 | "POST", 60 | fmt.Sprintf("%s/v1/chat/completions", OpenAIBaseURL), 61 | bytes.NewBuffer([]byte(bodyJSON)), 62 | ) 63 | if err != nil { 64 | return nil, err 65 | } 66 | req.Header.Add("Authorization", "Bearer "+gpt.APIToken) 67 | req.Header.Add("Content-Type", "application/json") 68 | res, err := http.DefaultClient.Do(req) 69 | if err != nil { 70 | return nil, err 71 | } 72 | defer res.Body.Close() 73 | return gpt.parseResponseBody(res.Body) 74 | } 75 | 76 | func (gpt *GPT) ChatWithFunctions(prompt *ChatPrompt, functions []FunctionSpecs) (*ChatMessage, error) { 77 | return gpt.chat(prompt, functions) 78 | } 79 | 80 | func (gpt *GPT) Chat(prompt *ChatPrompt) (*ChatMessage, error) { 81 | return gpt.chat(prompt, nil) 82 | } 83 | 84 | func (gpt *GPT) isLimitExceeded() bool { 85 | return gpt.PromptTokenLimit > 0 && gpt.PromptTokensUsed > gpt.PromptTokenLimit || 86 | gpt.CompletionTokenLimit > 0 && gpt.CompletionTokensUsed > gpt.CompletionTokenLimit || 87 | gpt.TotalTokenLimit > 0 && gpt.PromptTokensUsed+gpt.CompletionTokensUsed > gpt.TotalTokenLimit 88 | } 89 | 90 | func (gpt *GPT) parseResponseBody(body io.Reader) (*ChatMessage, error) { 91 | var buf bytes.Buffer 92 | tee := io.TeeReader(body, &buf) 93 | var response ChatCompletionResponse 94 | err := json.NewDecoder(tee).Decode(&response) 95 | if err != nil { 96 | return nil, err 97 | } 98 | gpt.PromptTokensUsed += response.Usage.PromptTokensUsed 99 | gpt.CompletionTokensUsed += response.Usage.CompletionTokensUsed 100 | if len(response.Choices) == 0 { 101 | return nil, fmt.Errorf("no choices in response: %s", buf.String()) 102 | } 103 | if response.Choices[0].Message.FunctionCall == nil && response.Choices[0].Message.Text == "" { 104 | return nil, fmt.Errorf("no content in response: %s", buf.String()) 105 | } 106 | return response.Choices[0].Message, nil 107 | } 108 | 109 | func NewGPTEngine(apiToken string, model string) *GPT { 110 | return &GPT{ 
111 | APIToken: apiToken, 112 | Model: model, 113 | Temperature: 1, 114 | } 115 | } 116 | 117 | func (gpt *GPT) WithPromptTokenLimit(limit int) *GPT { 118 | gpt.PromptTokenLimit = limit 119 | return gpt 120 | } 121 | 122 | func (gpt *GPT) WithCompletionTokenLimit(limit int) *GPT { 123 | gpt.CompletionTokenLimit = limit 124 | return gpt 125 | } 126 | 127 | func (gpt *GPT) WithTotalTokenLimit(limit int) *GPT { 128 | gpt.TotalTokenLimit = limit 129 | return gpt 130 | } 131 | 132 | func (gpt *GPT) WithTemperature(temperature float64) *GPT { 133 | gpt.Temperature = temperature 134 | return gpt 135 | } 136 | -------------------------------------------------------------------------------- /engines/prompt.go: -------------------------------------------------------------------------------- 1 | package engines 2 | 3 | type ConvRole string 4 | 5 | const ( 6 | ConvRoleUser ConvRole = "user" 7 | ConvRoleSystem ConvRole = "system" 8 | ConvRoleAssistant ConvRole = "assistant" 9 | ConvRoleFunction ConvRole = "function" 10 | ) 11 | 12 | type ChatMessage struct { 13 | Role ConvRole `json:"role"` 14 | Text string `json:"content"` 15 | FunctionCall *FunctionCall `json:"function_call,omitempty"` 16 | Name string `json:"name,omitempty"` 17 | } 18 | 19 | type FunctionCall struct { 20 | Name string `json:"name"` 21 | Args string `json:"arguments"` 22 | } 23 | 24 | type ChatPrompt struct { 25 | History []*ChatMessage 26 | } 27 | -------------------------------------------------------------------------------- /evaluation/agent_evaluator.go: -------------------------------------------------------------------------------- 1 | package evaluation 2 | 3 | import ( 4 | "github.com/natexcvi/go-llm/agents" 5 | ) 6 | 7 | type agentRunner[Input, Output any] struct { 8 | agent agents.Agent[Input, Output] 9 | } 10 | 11 | // Returns a new agent runner that can be used to evaluate the output. 
12 | func NewAgentRunner[Input, Output any](agent agents.Agent[Input, Output]) Runner[Input, Output] { 13 | return &agentRunner[Input, Output]{ 14 | agent: agent, 15 | } 16 | } 17 | 18 | func (t *agentRunner[Input, Output]) Run(input Input) (Output, error) { 19 | return t.agent.Run(input) 20 | } 21 | -------------------------------------------------------------------------------- /evaluation/evaluator.go: -------------------------------------------------------------------------------- 1 | package evaluation 2 | 3 | import ( 4 | "fmt" 5 | "github.com/samber/mo" 6 | ) 7 | 8 | // GoodnessFunction is a function that takes an input, an output and an error (if one occurred) and returns a float64 9 | // which represents the goodness score of the output. 10 | type GoodnessFunction[Input, Output any] func(input Input, output Output, err error) float64 11 | 12 | // Options is a struct that contains the options for the evaluator. 13 | type Options[Input, Output any] struct { 14 | // The goodness function that will be used to evaluate the output. 15 | GoodnessFunction GoodnessFunction[Input, Output] 16 | // The number of times the test will be repeated. The goodness level of each output will be 17 | // averaged. 18 | Repetitions int 19 | } 20 | 21 | // Runner is an interface that represents a test runner that will be used to evaluate the output. 22 | // It takes an input and returns an output and an error. 23 | type Runner[Input, Output any] interface { 24 | Run(input Input) (Output, error) 25 | } 26 | 27 | // Evaluator is a struct that runs the tests and evaluates the outputs. 28 | type Evaluator[Input, Output any] struct { 29 | options *Options[Input, Output] 30 | runner Runner[Input, Output] 31 | } 32 | 33 | // Creates a new `Evaluator` with the provided configuration. 
34 | func NewEvaluator[Input, Output any](runner Runner[Input, Output], options *Options[Input, Output]) *Evaluator[Input, Output] { 35 | return &Evaluator[Input, Output]{ 36 | options: options, 37 | runner: runner, 38 | } 39 | } 40 | 41 | // Runs the tests and evaluates the outputs. The function receives a test pack 42 | // which is a slice of inputs and returns a slice of float64 which represents the goodness level 43 | // of each respective output. 44 | func (e *Evaluator[Input, Output]) Evaluate(testPack []Input) []float64 { 45 | repetitionChannels := make([]chan []float64, e.options.Repetitions) 46 | 47 | for i := 0; i < e.options.Repetitions; i++ { 48 | repetitionChannels[i] = make(chan []float64) 49 | go func(i int) { 50 | report, err := e.evaluate(testPack) 51 | if err != nil { 52 | repetitionChannels[i] <- nil 53 | return 54 | } 55 | repetitionChannels[i] <- report 56 | }(i) 57 | } 58 | 59 | responses := make([][]float64, e.options.Repetitions) 60 | for i := 0; i < e.options.Repetitions; i++ { 61 | responses[i] = <-repetitionChannels[i] 62 | } 63 | 64 | report := make([]float64, len(testPack)) 65 | for i := 0; i < len(testPack); i++ { 66 | sum := 0.0 67 | for j := 0; j < e.options.Repetitions; j++ { 68 | sum += responses[j][i] 69 | } 70 | report[i] = sum / float64(e.options.Repetitions) 71 | } 72 | 73 | return report 74 | } 75 | 76 | func (e *Evaluator[Input, Output]) evaluate(testPack []Input) ([]float64, error) { 77 | responses, err := e.test(testPack) 78 | if err != nil { 79 | return nil, fmt.Errorf("failed to test: %w", err) 80 | } 81 | 82 | report := make([]float64, len(testPack)) 83 | for i, response := range responses { 84 | res, resErr := response.Get() 85 | report[i] = e.options.GoodnessFunction(testPack[i], res, resErr) 86 | } 87 | 88 | return report, nil 89 | } 90 | 91 | func (e *Evaluator[Input, Output]) test(testPack []Input) ([]mo.Result[Output], error) { 92 | responses := make([]mo.Result[Output], len(testPack)) 93 | 94 | for i, test := 
range testPack { 95 | response, err := e.runner.Run(test) 96 | if err != nil { 97 | responses[i] = mo.Err[Output](err) 98 | } else { 99 | responses[i] = mo.Ok(response) 100 | } 101 | } 102 | 103 | return responses, nil 104 | } 105 | -------------------------------------------------------------------------------- /evaluation/evaluator_test.go: -------------------------------------------------------------------------------- 1 | package evaluation 2 | 3 | import ( 4 | "errors" 5 | "github.com/golang/mock/gomock" 6 | "github.com/natexcvi/go-llm/engines" 7 | "github.com/natexcvi/go-llm/engines/mocks" 8 | "github.com/samber/lo" 9 | "github.com/stretchr/testify/assert" 10 | "math" 11 | "strings" 12 | "testing" 13 | ) 14 | 15 | func createMockEchoLLM(t *testing.T) engines.LLM { 16 | t.Helper() 17 | ctrl := gomock.NewController(t) 18 | mock := mocks.NewMockLLM(ctrl) 19 | mock.EXPECT().Chat(gomock.Any()).DoAndReturn(func(prompt *engines.ChatPrompt) (*engines.ChatMessage, error) { 20 | return &engines.ChatMessage{ 21 | Text: prompt.History[0].Text, 22 | }, nil 23 | }).AnyTimes() 24 | return mock 25 | } 26 | 27 | func createMockIncrementalLLM(t *testing.T) engines.LLM { 28 | t.Helper() 29 | ctrl := gomock.NewController(t) 30 | mock := mocks.NewMockLLM(ctrl) 31 | counters := make(map[string]int) 32 | mock.EXPECT().Chat(gomock.Any()).DoAndReturn(func(prompt *engines.ChatPrompt) (*engines.ChatMessage, error) { 33 | counters[prompt.History[0].Text]++ 34 | return &engines.ChatMessage{ 35 | Text: strings.Repeat(prompt.History[0].Text, counters[prompt.History[0].Text]), 36 | }, nil 37 | }).AnyTimes() 38 | return mock 39 | } 40 | 41 | func createMockExponentialLLM(t *testing.T) engines.LLM { 42 | t.Helper() 43 | ctrl := gomock.NewController(t) 44 | mock := mocks.NewMockLLM(ctrl) 45 | counters := make(map[string]int) 46 | mock.EXPECT().Chat(gomock.Any()).DoAndReturn(func(prompt *engines.ChatPrompt) (*engines.ChatMessage, error) { 47 | counters[prompt.History[0].Text]++ 48 | return 
&engines.ChatMessage{ 49 | Text: strings.Repeat(prompt.History[0].Text, int(math.Pow(float64(len(prompt.History[0].Text)), float64(counters[prompt.History[0].Text]+1)))), 50 | }, nil 51 | }).AnyTimes() 52 | return mock 53 | } 54 | 55 | func createMockOddErrorLLM(t *testing.T) engines.LLM { 56 | t.Helper() 57 | ctrl := gomock.NewController(t) 58 | mock := mocks.NewMockLLM(ctrl) 59 | counters := make(map[string]int) 60 | mock.EXPECT().Chat(gomock.Any()).DoAndReturn(func(prompt *engines.ChatPrompt) (*engines.ChatMessage, error) { 61 | counters[prompt.History[0].Text]++ 62 | if counters[prompt.History[0].Text]%2 == 1 { 63 | return nil, errors.New("error") 64 | } 65 | return &engines.ChatMessage{ 66 | Text: "OK!", 67 | }, nil 68 | }).AnyTimes() 69 | return mock 70 | } 71 | 72 | func TestLLMEvaluator(t *testing.T) { 73 | tests := []struct { 74 | name string 75 | options *Options[*engines.ChatPrompt, *engines.ChatMessage] 76 | engine engines.LLM 77 | testPack []*engines.ChatPrompt 78 | want []float64 79 | }{ 80 | { 81 | name: "Test echo engine with response length goodness and 1 repetition", 82 | options: &Options[*engines.ChatPrompt, *engines.ChatMessage]{ 83 | GoodnessFunction: func(_ *engines.ChatPrompt, response *engines.ChatMessage, _ error) float64 { 84 | return float64(len(response.Text)) 85 | }, 86 | Repetitions: 1, 87 | }, 88 | engine: createMockEchoLLM(t), 89 | testPack: []*engines.ChatPrompt{ 90 | { 91 | History: []*engines.ChatMessage{ 92 | { 93 | Text: "Hello", 94 | }, 95 | }, 96 | }, 97 | { 98 | History: []*engines.ChatMessage{ 99 | { 100 | Text: "Hello Hello", 101 | }, 102 | }, 103 | }, 104 | { 105 | History: []*engines.ChatMessage{ 106 | { 107 | Text: "Hello Hello Hello Hello", 108 | }, 109 | }, 110 | }, 111 | { 112 | History: []*engines.ChatMessage{ 113 | { 114 | Text: "Hello Hello Hello Hello Hello Hello", 115 | }, 116 | }, 117 | }, 118 | }, 119 | want: []float64{5, 11, 23, 35}, 120 | }, 121 | { 122 | name: "Test echo engine with response length goodness 
and 5 repetitions", 123 | options: &Options[*engines.ChatPrompt, *engines.ChatMessage]{ 124 | GoodnessFunction: func(_ *engines.ChatPrompt, response *engines.ChatMessage, _ error) float64 { 125 | return float64(len(response.Text)) 126 | }, 127 | Repetitions: 5, 128 | }, 129 | engine: createMockEchoLLM(t), 130 | testPack: []*engines.ChatPrompt{ 131 | { 132 | History: []*engines.ChatMessage{ 133 | { 134 | Text: "Hello", 135 | }, 136 | }, 137 | }, 138 | { 139 | History: []*engines.ChatMessage{ 140 | { 141 | Text: "Hello Hello", 142 | }, 143 | }, 144 | }, 145 | { 146 | History: []*engines.ChatMessage{ 147 | { 148 | Text: "Hello Hello Hello Hello", 149 | }, 150 | }, 151 | }, 152 | { 153 | History: []*engines.ChatMessage{ 154 | { 155 | Text: "Hello Hello Hello Hello Hello Hello", 156 | }, 157 | }, 158 | }, 159 | }, 160 | want: []float64{5, 11, 23, 35}, 161 | }, 162 | { 163 | name: "Test incremental engine with response length goodness and 5 repetitions", 164 | options: &Options[*engines.ChatPrompt, *engines.ChatMessage]{ 165 | GoodnessFunction: func(_ *engines.ChatPrompt, response *engines.ChatMessage, _ error) float64 { 166 | return float64(len(response.Text)) 167 | }, 168 | Repetitions: 5, 169 | }, 170 | engine: createMockIncrementalLLM(t), 171 | testPack: []*engines.ChatPrompt{ 172 | { 173 | History: []*engines.ChatMessage{ 174 | { 175 | Text: "a", 176 | }, 177 | }, 178 | }, 179 | { 180 | History: []*engines.ChatMessage{ 181 | { 182 | Text: "aa", 183 | }, 184 | }, 185 | }, 186 | { 187 | History: []*engines.ChatMessage{ 188 | { 189 | Text: "aaa", 190 | }, 191 | }, 192 | }, 193 | { 194 | History: []*engines.ChatMessage{ 195 | { 196 | Text: "aaaa", 197 | }, 198 | }, 199 | }, 200 | }, 201 | want: []float64{3, 6, 9, 12}, 202 | }, 203 | { 204 | name: "Test exponential engine with response length goodness and 4 repetitions", 205 | options: &Options[*engines.ChatPrompt, *engines.ChatMessage]{ 206 | GoodnessFunction: func(_ *engines.ChatPrompt, response *engines.ChatMessage, _ 
error) float64 { 207 | return float64(len(response.Text)) 208 | }, 209 | Repetitions: 4, 210 | }, 211 | engine: createMockExponentialLLM(t), 212 | testPack: []*engines.ChatPrompt{ 213 | { 214 | History: []*engines.ChatMessage{ 215 | { 216 | Text: "a", 217 | }, 218 | }, 219 | }, 220 | { 221 | History: []*engines.ChatMessage{ 222 | { 223 | Text: "aa", 224 | }, 225 | }, 226 | }, 227 | { 228 | History: []*engines.ChatMessage{ 229 | { 230 | Text: "aaa", 231 | }, 232 | }, 233 | }, 234 | { 235 | History: []*engines.ChatMessage{ 236 | { 237 | Text: "aaaa", 238 | }, 239 | }, 240 | }, 241 | }, 242 | want: []float64{1, 30, 270, 1360}, 243 | }, 244 | { 245 | name: "Test error engine with dummy error goodness and 4 repetitions", 246 | options: &Options[*engines.ChatPrompt, *engines.ChatMessage]{ 247 | GoodnessFunction: func(_ *engines.ChatPrompt, _ *engines.ChatMessage, err error) float64 { 248 | return lo.If(err == nil, 100.0).Else(0.0) 249 | }, 250 | Repetitions: 4, 251 | }, 252 | engine: createMockOddErrorLLM(t), 253 | testPack: []*engines.ChatPrompt{ 254 | { 255 | History: []*engines.ChatMessage{ 256 | { 257 | Text: "a", 258 | }, 259 | }, 260 | }, 261 | { 262 | History: []*engines.ChatMessage{ 263 | { 264 | Text: "aa", 265 | }, 266 | }, 267 | }, 268 | }, 269 | want: []float64{50, 50}, 270 | }, 271 | } 272 | 273 | for _, tt := range tests { 274 | t.Run(tt.name, func(t *testing.T) { 275 | runner := NewLLMRunner(tt.engine) 276 | evaluator := NewEvaluator(runner, tt.options) 277 | 278 | got := evaluator.Evaluate(tt.testPack) 279 | 280 | assert.Equal(t, tt.want, got) 281 | }) 282 | } 283 | } 284 | -------------------------------------------------------------------------------- /evaluation/llm_evaluator.go: -------------------------------------------------------------------------------- 1 | package evaluation 2 | 3 | import "github.com/natexcvi/go-llm/engines" 4 | 5 | type llmRunner struct { 6 | llm engines.LLM 7 | } 8 | 9 | // Returns a new llm runner that can be used to 
evaluate the output.
10 | func NewLLMRunner(llm engines.LLM) Runner[*engines.ChatPrompt, *engines.ChatMessage] {
11 | 	return &llmRunner{
12 | 		llm: llm,
13 | 	}
14 | }
15 |
16 | func (t *llmRunner) Run(input *engines.ChatPrompt) (*engines.ChatMessage, error) {
17 | 	return t.llm.Chat(input)
18 | }
19 |
--------------------------------------------------------------------------------
/go-llm-cli/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | 	"encoding/json"
5 | 	"fmt"
6 | 	"os"
7 | 	"os/exec"
8 | 	"time"
9 |
10 | 	"github.com/AlecAivazis/survey/v2"
11 | 	"github.com/briandowns/spinner"
12 | 	"github.com/natexcvi/go-llm/agents"
13 | 	"github.com/natexcvi/go-llm/engines"
14 | 	"github.com/natexcvi/go-llm/prebuilt"
15 | 	"github.com/natexcvi/go-llm/tools"
16 | 	"github.com/samber/lo"
17 | 	log "github.com/sirupsen/logrus"
18 | 	"github.com/spf13/cobra"
19 | )
20 |
21 | var (
22 | 	// tokenLimit is bound to the --token-limit flag.
23 | 	// NOTE(review): it is never read in this file, so the flag currently
24 | 	// has no effect on any command.
25 | 	tokenLimit int
26 | 	// gptModel is bound to the --gpt-model flag and selects the model for
27 | 	// every engine created below.
28 | 	gptModel string
29 | )
30 |
31 | // rootCmd is the top-level command; subcommands are attached in main.
32 | var rootCmd = &cobra.Command{
33 | 	Use:   "go-llm",
34 | 	Short: "A CLI for using the prebuilt agents.",
35 | }
36 |
37 | // requireEnv returns the value of the environment variable name. If it is
38 | // unset or empty, requireEnv logs an error and returns ok == false so the
39 | // calling command can abort.
40 | func requireEnv(name string) (value string, ok bool) {
41 | 	value = os.Getenv(name)
42 | 	if value == "" {
43 | 		log.Errorf("%s environment variable not set", name)
44 | 		return "", false
45 | 	}
46 | 	return value, true
47 | }
48 |
49 | // codeRefactorAgent runs the prebuilt code refactoring agent against a
50 | // directory with a free-text refactoring goal.
51 | var codeRefactorAgent = &cobra.Command{
52 | 	Use:   "code-refactor CODE_BASE_DIR GOAL",
53 | 	Short: "A code refactoring assistant.",
54 | 	Run: func(cmd *cobra.Command, args []string) {
55 | 		codeBaseDir := args[0]
56 | 		goal := args[1]
57 | 		if err := validateDirectory(codeBaseDir); err != nil {
58 | 			log.Error(err)
59 | 			return
60 | 		}
61 | 		apiKey, ok := requireEnv("OPENAI_API_KEY")
62 | 		if !ok {
63 | 			return
64 | 		}
65 | 		engine := engines.NewGPTEngine(apiKey, gptModel)
66 | 		agent := prebuilt.NewCodeRefactorAgent(engine)
67 | 		res, err := agent.Run(prebuilt.CodeBaseRefactorRequest{
68 | 			Dir:  codeBaseDir,
69 | 			Goal: goal,
70 | 		})
71 | 		if err != nil {
72 | 			log.Error(err)
73 | 			return
74 | 		}
75 | 		log.Info(res)
76 | 	},
77 | 	Args:      cobra.ExactArgs(2),
78 | 	ValidArgs: []string{"code-base-dir", "goal"},
79 | }
80 |
81 | // tradeAssistantAgent runs the prebuilt trade assistant on one or more
82 | // stocks. It requires both OPENAI_API_KEY and WOLFRAM_APP_ID.
83 | var tradeAssistantAgent = &cobra.Command{
84 | 	Use:   "trade-assistant STOCK...",
85 | 	Short: "A stock trading assistant.",
86 | 	Run: func(cmd *cobra.Command, args []string) {
87 | 		apiKey, ok := requireEnv("OPENAI_API_KEY")
88 | 		if !ok {
89 | 			return
90 | 		}
91 | 		wolframAppID, ok := requireEnv("WOLFRAM_APP_ID")
92 | 		if !ok {
93 | 			return
94 | 		}
95 | 		engine := engines.NewGPTEngine(apiKey, gptModel)
96 | 		agent := prebuilt.NewTradeAssistantAgent(engine, wolframAppID)
97 | 		res, err := agent.Run(prebuilt.TradeAssistantRequest{
98 | 			Stocks: args,
99 | 		})
100 | 		if err != nil {
101 | 			log.Error(err)
102 | 			return
103 | 		}
104 | 		log.Info(res)
105 | 	},
106 | 	Args: cobra.MinimumNArgs(1),
107 | }
108 |
109 | // readFile returns the contents of filePath, terminating the process via
110 | // log.Fatal if the file cannot be read.
111 | func readFile(filePath string) (content string) {
112 | 	contentBytes, err := os.ReadFile(filePath)
113 | 	if err != nil {
114 | 		log.Fatal(err)
115 | 	}
116 | 	return string(contentBytes)
117 | }
118 |
119 | // unitTestWriter runs the prebuilt unit test writer on a source file, using
120 | // a second file as an example of the desired test style.
121 | var unitTestWriter = &cobra.Command{
122 | 	Use:   "unit-test-writer SOURCE_FILE EXAMPLE_FILE",
123 | 	Short: "A tool for writing unit tests.",
124 | 	Long: `A tool for writing unit tests.
125 | Example usage:
126 | go-llm unit-test-writer source.py example.py
127 | Where source.py is where the source code
128 | to be tested is located, and example.py
129 | is an example unit test file.
130 | `,
131 | 	Run: func(cmd *cobra.Command, args []string) {
132 | 		apiKey, ok := requireEnv("OPENAI_API_KEY")
133 | 		if !ok {
134 | 			return
135 | 		}
136 | 		sourceFilePath := args[0]
137 | 		exampleFilePath := args[1]
138 | 		engine := engines.NewGPTEngine(apiKey, gptModel)
139 | 		// The code-validation callback accepts any generated code.
140 | 		agent, err := prebuilt.NewUnitTestWriter(engine, func(code string) error {
141 | 			return nil
142 | 		})
143 | 		if err != nil {
144 | 			log.Error(err)
145 | 			return
146 | 		}
147 | 		res, err := agent.Run(prebuilt.UnitTestWriterRequest{
148 | 			SourceFile:  readFile(sourceFilePath),
149 | 			ExampleFile: readFile(exampleFilePath),
150 | 		})
151 | 		if err != nil {
152 | 			log.Error(err)
153 | 			return
154 | 		}
155 | 		log.Info(res)
156 | 	},
157 | 	Args: cobra.ExactArgs(2),
158 | }
159 |
160 | // gitStatus returns the output of `git status` in the current working
161 | // directory.
162 | func gitStatus() (string, error) {
163 | 	cmd := exec.Command("git", "status")
164 | 	out, err := cmd.Output()
165 | 	if err != nil {
166 | 		return "", fmt.Errorf("git status failed: %w", err)
167 | 	}
168 | 	return string(out), nil
169 | }
170 |
171 | // gitAssistantCmd runs the prebuilt git assistant on a natural-language
172 | // instruction. Each git command the agent wants to run is shown to the
173 | // user, and the user's yes/no answer is reported back to the agent.
174 | var gitAssistantCmd = &cobra.Command{
175 | 	Use:   "git-assistant INSTRUCTION",
176 | 	Short: "A git assistant.",
177 | 	Run: func(cmd *cobra.Command, args []string) {
178 | 		log.SetLevel(log.InfoLevel)
179 | 		// Check configuration and collect the repository state before the
180 | 		// spinner starts, so early failures do not leave it spinning.
181 | 		apiKey, ok := requireEnv("OPENAI_API_KEY")
182 | 		if !ok {
183 | 			return
184 | 		}
185 | 		status, err := gitStatus()
186 | 		if err != nil {
187 | 			log.Error(err)
188 | 			return
189 | 		}
190 | 		s := spinner.New(spinner.CharSets[14], 100*time.Millisecond)
191 | 		s.Suffix = " Just a moment..."
192 | 		s.Start()
193 | 		// Stop does nothing on an inactive spinner, so this is safe
194 | 		// alongside the explicit Stop below.
195 | 		defer s.Stop()
196 | 		engine := engines.NewGPTEngine(apiKey, gptModel).WithTemperature(0)
197 | 		agent := prebuilt.NewGitAssistantAgent(engine, func(action *agents.ChainAgentAction) bool {
198 | 			// Only git commands need the user's approval.
199 | 			if action.Tool.Name() != "git" {
200 | 				return true
201 | 			}
202 | 			// Pause the spinner while prompting; resume it on every return.
203 | 			s.Stop()
204 | 			defer s.Start()
205 | 			var command struct {
206 | 				Command string `json:"command"`
207 | 				Reason  string `json:"reason"`
208 | 			}
209 | 			if err := json.Unmarshal(action.Args, &command); err != nil {
210 | 				log.Errorf("could not decode git command arguments: %v", err)
211 | 				return false
212 | 			}
213 | 			prompt := &survey.Confirm{
214 | 				Message: fmt.Sprintf("Run %q%s?", command.Command, lo.If(
215 | 					command.Reason != "",
216 | 					fmt.Sprintf(" in order to %s", command.Reason),
217 | 				).Else("")),
218 | 			}
219 | 			shouldRun := false
220 | 			// A failed or interrupted prompt declines the command.
221 | 			if err := survey.AskOne(prompt, &shouldRun); err != nil {
222 | 				log.Error(err)
223 | 				return false
224 | 			}
225 | 			return shouldRun
226 | 		}, tools.NewAskUser().WithCustomQuestionHandler(func(question string) (string, error) {
227 | 			s.Stop()
228 | 			defer s.Start()
229 | 			prompt := survey.Input{
230 | 				Message: question,
231 | 			}
232 | 			var response string
233 | 			if err := survey.AskOne(&prompt, &response); err != nil {
234 | 				return "", fmt.Errorf("asking user: %w", err)
235 | 			}
236 | 			return response, nil
237 | 		}))
238 | 		res, err := agent.Run(prebuilt.GitAssistantRequest{
239 | 			Instruction: args[0],
240 | 			GitStatus:   status,
241 | 			CurrentDate: time.Now().Format(time.RFC3339),
242 | 		})
243 | 		s.Stop()
244 | 		if err != nil {
245 | 			log.Error(err)
246 | 			return
247 | 		}
248 | 		fmt.Println(res.Summary)
249 | 	},
250 | 	Args: cobra.ExactArgs(1),
251 | }
252 |
253 | // validateDirectory returns an error if dir cannot be stat'ed or is not a
254 | // directory.
255 | func validateDirectory(dir string) error {
256 | 	dirInfo, err := os.Stat(dir)
257 | 	if err != nil {
258 | 		return err
259 | 	}
260 | 	if !dirInfo.IsDir() {
261 | 		return fmt.Errorf("%s is not a directory", dir)
262 | 	}
263 | 	return nil
264 | }
265 |
266 | // init registers the persistent flags shared by every subcommand.
267 | func init() {
268 | 	rootCmd.PersistentFlags().IntVar(&tokenLimit, "token-limit", 0, "stop after using this many tokens")
269 | 	rootCmd.PersistentFlags().StringVar(&gptModel, "gpt-model", "gpt-3.5-turbo-0613", "the GPT model to use")
270 | }
271 |
272 | // main attaches the subcommands to rootCmd and runs the CLI.
273 | func main() {
274 | 	log.SetLevel(log.DebugLevel)
275 | 	rootCmd.AddCommand(codeRefactorAgent)
276 | 	rootCmd.AddCommand(tradeAssistantAgent)
277 | 	rootCmd.AddCommand(unitTestWriter)
278 | 	rootCmd.AddCommand(gitAssistantCmd)
279 |
280 | 	if err := rootCmd.Execute(); err != nil {
281 | 		log.Fatal(err)
282 | 	}
283 | }
284 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/natexcvi/go-llm
2 |
3 | go 1.19
4 |
5 | require (
6 | 	github.com/golang/mock v1.6.0
7 | 	github.com/hashicorp/go-multierror v1.1.1
8 | )
9 |
10 | require (
11 | 	github.com/andybalholm/cascadia v1.3.1 // indirect
12 | 	github.com/danwakefield/fnmatch v0.0.0-20160403171240-cbb64ac3d964 // indirect
13 | 	github.com/davecgh/go-spew v1.1.1 // indirect
14 | 	github.com/fatih/color v1.7.0 // indirect
15 | 	github.com/go-stack/stack v1.8.1 // indirect
16 | 	github.com/inconshreveable/mousetrap v1.1.0 // indirect
17 | 	github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect
18 | 	github.com/mattn/go-colorable v0.1.2 // indirect
19 | 	github.com/mattn/go-isatty v0.0.8 // indirect
20 | 	github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b // indirect
21 | 	github.com/pmezard/go-difflib v1.0.0 // indirect
22 | 	github.com/samber/mo v1.8.0 // indirect
23 | 	github.com/spf13/pflag v1.0.5 // indirect
24 | 	golang.org/x/net v0.7.0 // indirect
25 | 	golang.org/x/sys v0.5.0 // indirect
26 | 	golang.org/x/term v0.5.0 // indirect
27 | 	golang.org/x/text v0.7.0 // indirect
28 | 	gopkg.in/square/go-jose.v2 v2.6.0 // indirect
29 | 	gopkg.in/yaml.v3 v3.0.1 // indirect
30 | )
31 |
32 | require (
33 | 	github.com/AlecAivazis/survey/v2 v2.3.7
34 |
github.com/PuerkitoBio/goquery v1.8.1 35 | github.com/briandowns/spinner v1.23.0 36 | github.com/hashicorp/errwrap v1.0.0 // indirect 37 | github.com/playwright-community/playwright-go v0.2000.1 38 | github.com/samber/lo v1.38.1 39 | github.com/sirupsen/logrus v1.9.0 40 | github.com/spf13/cobra v1.7.0 41 | github.com/stretchr/testify v1.8.2 42 | golang.org/x/exp v0.0.0-20230321023759-10a507213a29 43 | ) 44 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/AlecAivazis/survey/v2 v2.3.7 h1:6I/u8FvytdGsgonrYsVn2t8t4QiRnh6QSTqkkhIiSjQ= 2 | github.com/AlecAivazis/survey/v2 v2.3.7/go.mod h1:xUTIdE4KCOIjsBAE1JYsUPoCqYdZ1reCfTwbto0Fduo= 3 | github.com/Netflix/go-expect v0.0.0-20220104043353-73e0943537d2 h1:+vx7roKuyA63nhn5WAunQHLTznkw5W8b1Xc0dNjp83s= 4 | github.com/Netflix/go-expect v0.0.0-20220104043353-73e0943537d2/go.mod h1:HBCaDeC1lPdgDeDbhX8XFpy1jqjK0IBG8W5K+xYqA0w= 5 | github.com/PuerkitoBio/goquery v1.8.1 h1:uQxhNlArOIdbrH1tr0UXwdVFgDcZDrZVdcpygAcwmWM= 6 | github.com/PuerkitoBio/goquery v1.8.1/go.mod h1:Q8ICL1kNUJ2sXGoAhPGUdYDJvgQgHzJsnnd3H7Ho5jQ= 7 | github.com/andybalholm/cascadia v1.3.1 h1:nhxRkql1kdYCc8Snf7D5/D3spOX+dBgjA6u8x004T2c= 8 | github.com/andybalholm/cascadia v1.3.1/go.mod h1:R4bJ1UQfqADjvDa4P6HZHLh/3OxWWEqc0Sk8XGwHqvA= 9 | github.com/briandowns/spinner v1.23.0 h1:alDF2guRWqa/FOZZYWjlMIx2L6H0wyewPxo/CH4Pt2A= 10 | github.com/briandowns/spinner v1.23.0/go.mod h1:rPG4gmXeN3wQV/TsAY4w8lPdIM6RX3yqeBQJSrbXjuE= 11 | github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= 12 | github.com/creack/pty v1.1.17 h1:QeVUsEDNrLBW4tMgZHvxy18sKtr6VI492kBhUfhDJNI= 13 | github.com/creack/pty v1.1.17/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= 14 | github.com/danwakefield/fnmatch v0.0.0-20160403171240-cbb64ac3d964 h1:y5HC9v93H5EPKqaS1UYVg1uYah5Xf51mBfIoWehClUQ= 15 | 
github.com/danwakefield/fnmatch v0.0.0-20160403171240-cbb64ac3d964/go.mod h1:Xd9hchkHSWYkEqJwUGisez3G1QY8Ryz0sdWrLPMGjLk= 16 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 17 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 18 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 19 | github.com/fatih/color v1.7.0 h1:DkWD4oS2D8LGGgTQ6IvwJJXSL5Vp2ffcQg58nFV38Ys= 20 | github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= 21 | github.com/go-stack/stack v1.8.1 h1:ntEHSVwIt7PNXNpgPmVfMrNhLtgjlmnZha2kOpuRiDw= 22 | github.com/go-stack/stack v1.8.1/go.mod h1:dcoOX6HbPZSZptuspn9bctJ+N/CnF5gGygcUP3XYfe4= 23 | github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= 24 | github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= 25 | github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= 26 | github.com/h2non/filetype v1.1.1/go.mod h1:319b3zT68BvV+WRj7cwy856M2ehB3HqNOt6sy1HndBY= 27 | github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA= 28 | github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= 29 | github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= 30 | github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= 31 | github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec h1:qv2VnGeEQHchGaZ/u7lxST/RaJw+cv273q79D81Xbog= 32 | github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec/go.mod h1:Q48J4R4DvxnHolD5P8pOtXigYlRuPLGl6moFx3ulM68= 33 | github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= 34 | github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= 35 | github.com/kballard/go-shellquote 
v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs= 36 | github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8= 37 | github.com/mattn/go-colorable v0.1.2 h1:/bC9yWikZXAL9uJdulbSfyVNIR3n3trXl+v8+1sx8mU= 38 | github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= 39 | github.com/mattn/go-isatty v0.0.8 h1:HLtExJ+uU2HOZ+wI0Tt5DtUDrx8yhUqDcp7fYERX4CE= 40 | github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= 41 | github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b h1:j7+1HpAFS1zy5+Q4qx1fWh90gTKwiN4QCGoY9TWyyO4= 42 | github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE= 43 | github.com/playwright-community/playwright-go v0.2000.1 h1:2JViSHpJQ/UL/PO1Gg6gXV5IcXAAsoBJ3KG9L3wKXto= 44 | github.com/playwright-community/playwright-go v0.2000.1/go.mod h1:1y9cM9b9dVHnuRWzED1KLM7FtbwTJC8ibDjI6MNqewU= 45 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 46 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 47 | github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= 48 | github.com/samber/lo v1.38.1 h1:j2XEAqXKb09Am4ebOg31SpvzUTTs6EN3VfgeLUhPdXM= 49 | github.com/samber/lo v1.38.1/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA= 50 | github.com/samber/mo v1.8.0 h1:vYjHTfg14JF9tD2NLhpoUsRi9bjyRoYwa4+do0nvbVw= 51 | github.com/samber/mo v1.8.0/go.mod h1:BfkrCPuYzVG3ZljnZB783WIJIGk1mcZr9c9CPf8tAxs= 52 | github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0= 53 | github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= 54 | github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I= 55 | github.com/spf13/cobra v1.7.0/go.mod 
h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0= 56 | github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= 57 | github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= 58 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 59 | github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= 60 | github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= 61 | github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 62 | github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 63 | github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 64 | github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= 65 | github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= 66 | github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= 67 | github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= 68 | github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= 69 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 70 | golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= 71 | golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= 72 | golang.org/x/exp v0.0.0-20230321023759-10a507213a29 h1:ooxPy7fPvB4kwsA2h+iBNHkAbp/4JxTSwCmvdjEYmug= 73 | golang.org/x/exp v0.0.0-20230321023759-10a507213a29/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= 74 | golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= 75 | golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod 
h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= 76 | golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= 77 | golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= 78 | golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= 79 | golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= 80 | golang.org/x/net v0.0.0-20210916014120-12bc252f5db8/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= 81 | golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= 82 | golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g= 83 | golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= 84 | golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 85 | golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 86 | golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 87 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 88 | golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 89 | golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 90 | golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 91 | golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 92 | golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 93 | golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod 
h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 94 | golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 95 | golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 96 | golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 97 | golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 98 | golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU= 99 | golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 100 | golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= 101 | golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= 102 | golang.org/x/term v0.5.0 h1:n2a8QNdAb0sZNpU9R1ALUXBbY+w51fCQDN+7EdxNBsY= 103 | golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= 104 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 105 | golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= 106 | golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= 107 | golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= 108 | golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= 109 | golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo= 110 | golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= 111 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 112 | golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= 113 | golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= 114 | golang.org/x/tools 
v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
115 | golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
116 | golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
117 | golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
118 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
119 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
120 | gopkg.in/square/go-jose.v2 v2.6.0 h1:NGk74WTnPKBNUhNzQX7PYcTLUjoq7mzKk2OKbvwk2iI=
121 | gopkg.in/square/go-jose.v2 v2.6.0/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=
122 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
123 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
124 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
125 |
--------------------------------------------------------------------------------
/memory/buffer_memory.go:
--------------------------------------------------------------------------------
1 | package memory
2 |
3 | import "github.com/natexcvi/go-llm/engines"
4 |
5 | // BufferMemory is a Memory that keeps the raw chat messages it is given,
6 | // retaining only the most recent MaxHistory of them.
7 | type BufferMemory struct {
8 | 	// MaxHistory caps the number of buffered messages; values <= 0 leave
9 | 	// the buffer unbounded.
10 | 	MaxHistory int
11 | 	// Buffer holds the retained messages, oldest first.
12 | 	Buffer []*engines.ChatMessage
13 | }
14 |
15 | // reduceBuffer drops the oldest messages until at most MaxHistory remain.
16 | // AddPrompt and PromptWithContext can append several messages at once, so
17 | // removing a single message is not enough to restore the limit.
18 | func (memory *BufferMemory) reduceBuffer() {
19 | 	if memory.MaxHistory > 0 && len(memory.Buffer) > memory.MaxHistory {
20 | 		memory.Buffer = memory.Buffer[len(memory.Buffer)-memory.MaxHistory:]
21 | 	}
22 | }
23 |
24 | // Add appends msg to the buffer, evicting the oldest messages if needed.
25 | // It never returns an error.
26 | func (memory *BufferMemory) Add(msg *engines.ChatMessage) error {
27 | 	memory.Buffer = append(memory.Buffer, msg)
28 | 	memory.reduceBuffer()
29 | 	return nil
30 | }
31 |
32 | // AddPrompt appends every message in prompt.History to the buffer,
33 | // evicting the oldest messages if needed. It never returns an error.
34 | func (memory *BufferMemory) AddPrompt(prompt *engines.ChatPrompt) error {
35 | 	memory.Buffer = append(memory.Buffer, prompt.History...)
36 | 	memory.reduceBuffer()
37 | 	return nil
38 | }
39 |
40 | // PromptWithContext appends nextMessages to the buffer, trims it, and
41 | // returns a prompt whose history is a copy of the buffer, so callers may
42 | // modify the returned history without corrupting the buffer.
43 | func (memory *BufferMemory) PromptWithContext(nextMessages ...*engines.ChatMessage) (*engines.ChatPrompt, error) {
44 | 	memory.Buffer = append(memory.Buffer, nextMessages...)
45 | 	memory.reduceBuffer()
46 | 	// Appending to a nil slice keeps a nil history for an empty buffer.
47 | 	history := append([]*engines.ChatMessage(nil), memory.Buffer...)
48 | 	return &engines.ChatPrompt{
49 | 		History: history,
50 | 	}, nil
51 | }
52 |
53 | // NewBufferedMemory returns an empty BufferMemory that keeps at most
54 | // maxHistory messages (unbounded when maxHistory <= 0).
55 | func NewBufferedMemory(maxHistory int) *BufferMemory {
56 | 	return &BufferMemory{
57 | 		MaxHistory: maxHistory,
58 | 	}
59 | }
60 |
--------------------------------------------------------------------------------
/memory/buffer_memory_test.go:
--------------------------------------------------------------------------------
1 | package memory
2 |
3 | import (
4 | 	"testing"
5 |
6 | 	"github.com/natexcvi/go-llm/engines"
7 | 	"github.com/stretchr/testify/assert"
8 | 	"github.com/stretchr/testify/require"
9 | )
10 |
11 | func TestBufferMemory(t *testing.T) {
12 | 	testCases := []struct {
13 | 		name             string
14 | 		max              int
15 | 		existingMessages []*engines.ChatMessage
16 | 		newMessages      []*engines.ChatMessage
17 | 		expected         []*engines.ChatMessage
18 | 	}{
19 | 		{
20 | 			name: "no max",
21 | 			max:  0,
22 | 			existingMessages: []*engines.ChatMessage{
23 | 				{
24 | 					Text: "hello",
25 | 				},
26 | 				{
27 | 					Text: "world",
28 | 				},
29 | 			},
30 | 			newMessages: []*engines.ChatMessage{
31 | 				{
32 | 					Text: "!",
33 | 				},
34 | 			},
35 | 			expected: []*engines.ChatMessage{
36 | 				{
37 | 					Text: "hello",
38 | 				},
39 | 				{
40 | 					Text: "world",
41 | 				},
42 | 				{
43 | 					Text: "!",
44 | 				},
45 | 			},
46 | 		},
47 | 		{
48 | 			name: "max 2",
49 | 			max:  2,
50 | 			existingMessages: []*engines.ChatMessage{
51 | 				{
52 | 					Text: "hello",
53 | 				},
54 | 				{
55 | 					Text: "world",
56 | 				},
57 | 			},
58 | 			newMessages: []*engines.ChatMessage{
59 | 				{
60 | 					Text: "!",
61 | 				},
62 | 			},
63 | 			expected: []*engines.ChatMessage{
64 | 				{
65 | 					Text: "world",
66 | 				},
67 | 				{
68 | 					Text: "!",
69 | 				},
70 | 			},
71 | 		},
72 | 		{
73 | 			// Several messages appended at once must all be trimmed.
74 | 			name: "max 2, several new messages",
75 | 			max:  2,
76 | 			existingMessages: []*engines.ChatMessage{
77 | 				{
78 | 					Text: "hello",
79 | 				},
80 | 			},
81 | 			newMessages: []*engines.ChatMessage{
82 | 				{
83 | 					Text: "a",
84 | 				},
85 | 				{
86 | 					Text: "b",
87 | 				},
88 | 				{
89 | 					Text: "c",
90 | 				},
91 | 			},
92 | 			expected: []*engines.ChatMessage{
93 | 				{
94 | 					Text: "b",
95 | 				},
96 | 				{
97 | 					Text: "c",
98 | 				},
99 | 			},
100 | 		},
101 | 	}
102 | 	for _, tc := range testCases {
103 | 		t.Run(tc.name, func(t *testing.T) {
104 | 			memory := NewBufferedMemory(tc.max)
105 | 			for _, msg := range tc.existingMessages {
106 | 				require.NoError(t, memory.Add(msg))
107 | 			}
108 | 			assert.Equal(t, len(tc.existingMessages), len(memory.Buffer), "expected %d messages in buffer, got %d", len(tc.existingMessages), len(memory.Buffer))
109 | 			prompt, err := memory.PromptWithContext(tc.newMessages...)
110 | 			require.NoError(t, err)
111 | 			assert.Equal(t, tc.expected, prompt.History)
112 | 		})
113 | 	}
114 | }
115 |
--------------------------------------------------------------------------------
/memory/memory.go:
--------------------------------------------------------------------------------
1 | package memory
2 |
3 | import "github.com/natexcvi/go-llm/engines"
4 |
5 | //go:generate mockgen -source=memory.go -destination=mocks/memory.go -package=mocks
6 |
7 | // Memory stores the messages of a conversation and assembles the prompt
8 | // sent to an LLM from them.
9 | type Memory interface {
10 | 	// Add records a single message.
11 | 	Add(msg *engines.ChatMessage) error
12 | 	// AddPrompt records an initial prompt; how it is retained is
13 | 	// implementation specific.
14 | 	AddPrompt(prompt *engines.ChatPrompt) error
15 | 	// PromptWithContext records nextMessages and returns a prompt that
16 | 	// combines them with the remembered context.
17 | 	PromptWithContext(nextMessages ...*engines.ChatMessage) (*engines.ChatPrompt, error)
18 | }
19 |
--------------------------------------------------------------------------------
/memory/mocks/memory.go:
--------------------------------------------------------------------------------
1 | // Code generated by MockGen. DO NOT EDIT.
2 | // Source: memory.go
3 |
4 | // Package mocks is a generated GoMock package.
5 | package mocks
6 |
7 | import (
8 | 	reflect "reflect"
9 |
10 | 	gomock "github.com/golang/mock/gomock"
11 | 	engines "github.com/natexcvi/go-llm/engines"
12 | )
13 |
14 | // MockMemory is a mock of Memory interface.
15 | type MockMemory struct {
16 | 	ctrl     *gomock.Controller
17 | 	recorder *MockMemoryMockRecorder
18 | }
19 |
20 | // MockMemoryMockRecorder is the mock recorder for MockMemory.
21 | type MockMemoryMockRecorder struct {
22 | 	mock *MockMemory
23 | }
24 |
25 | // NewMockMemory creates a new mock instance.
26 | func NewMockMemory(ctrl *gomock.Controller) *MockMemory {
27 | 	mock := &MockMemory{ctrl: ctrl}
28 | 	mock.recorder = &MockMemoryMockRecorder{mock}
29 | 	return mock
30 | }
31 |
32 | // EXPECT returns an object that allows the caller to indicate expected use.
33 | func (m *MockMemory) EXPECT() *MockMemoryMockRecorder {
34 | 	return m.recorder
35 | }
36 |
37 | // Add mocks base method.
38 | func (m *MockMemory) Add(msg *engines.ChatMessage) error {
39 | 	m.ctrl.T.Helper()
40 | 	ret := m.ctrl.Call(m, "Add", msg)
41 | 	ret0, _ := ret[0].(error)
42 | 	return ret0
43 | }
44 |
45 | // Add indicates an expected call of Add.
46 | func (mr *MockMemoryMockRecorder) Add(msg interface{}) *gomock.Call {
47 | 	mr.mock.ctrl.T.Helper()
48 | 	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockMemory)(nil).Add), msg)
49 | }
50 |
51 | // AddPrompt mocks base method.
52 | func (m *MockMemory) AddPrompt(prompt *engines.ChatPrompt) error {
53 | 	m.ctrl.T.Helper()
54 | 	ret := m.ctrl.Call(m, "AddPrompt", prompt)
55 | 	ret0, _ := ret[0].(error)
56 | 	return ret0
57 | }
58 |
59 | // AddPrompt indicates an expected call of AddPrompt.
60 | func (mr *MockMemoryMockRecorder) AddPrompt(prompt interface{}) *gomock.Call {
61 | 	mr.mock.ctrl.T.Helper()
62 | 	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddPrompt", reflect.TypeOf((*MockMemory)(nil).AddPrompt), prompt)
63 | }
64 |
65 | // PromptWithContext mocks base method.
66 | func (m *MockMemory) PromptWithContext(nextMessages ...*engines.ChatMessage) (*engines.ChatPrompt, error) {
67 | 	m.ctrl.T.Helper()
68 | 	varargs := []interface{}{}
69 | 	for _, a := range nextMessages {
70 | 		varargs = append(varargs, a)
71 | 	}
72 | 	ret := m.ctrl.Call(m, "PromptWithContext", varargs...)
73 | 	ret0, _ := ret[0].(*engines.ChatPrompt)
74 | 	ret1, _ := ret[1].(error)
75 | 	return ret0, ret1
76 | }
77 |
78 | // PromptWithContext indicates an expected call of PromptWithContext.
79 | func (mr *MockMemoryMockRecorder) PromptWithContext(nextMessages ...interface{}) *gomock.Call {
80 | 	mr.mock.ctrl.T.Helper()
81 | 	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PromptWithContext", reflect.TypeOf((*MockMemory)(nil).PromptWithContext), nextMessages...)
82 | }
83 |
--------------------------------------------------------------------------------
/memory/summarised_memory.go:
--------------------------------------------------------------------------------
1 | package memory
2 |
3 | import (
4 | 	"fmt"
5 |
6 | 	"github.com/natexcvi/go-llm/engines"
7 | 	log "github.com/sirupsen/logrus"
8 | )
9 |
10 | // SummarisedMemory is a Memory that keeps a window of recent messages plus
11 | // a running summary ("memory state") that model rewrites on every Add.
12 | // AddPrompt stores an original prompt that is prepended to every prompt.
13 | type SummarisedMemory struct {
14 | 	recentMessageLimit int // values <= 0 disable trimming
15 | 	recentMessages     []*engines.ChatMessage
16 | 	originalPrompt     *engines.ChatPrompt
17 | 	memoryState        string
18 | 	model              engines.LLM
19 | }
20 |
21 | // reduceBuffer drops the oldest recent messages until at most
22 | // recentMessageLimit remain. PromptWithContext appends messages without
23 | // trimming, so more than one message may have to go.
24 | func (memory *SummarisedMemory) reduceBuffer() {
25 | 	if memory.recentMessageLimit > 0 && len(memory.recentMessages) > memory.recentMessageLimit {
26 | 		memory.recentMessages = memory.recentMessages[len(memory.recentMessages)-memory.recentMessageLimit:]
27 | 	}
28 | }
29 |
30 | // updateMemoryState asks model to fold msg into the current memory state
31 | // and stores the model's reply as the new memory state.
32 | func (memory *SummarisedMemory) updateMemoryState(msg ...*engines.ChatMessage) error {
33 | 	// Prompt layout: system instructions, current state, one user message
34 | 	// per new message, then a closing system instruction.
35 | 	prompt := engines.ChatPrompt{
36 | 		History: []*engines.ChatMessage{
37 | 			{
38 | 				Role: engines.ConvRoleSystem,
39 | 				Text: "You are a smart memory manager. The user sends you two or more messages: " +
40 | 					"one with the current memory state, and the rest with new messages " +
41 | 					"sent to their conversation with a smart, LLM based assistant. You should " +
42 | 					"update the memory state to reflect the new messages' content. " +
43 | 					"Your goal is for the memory state to be as compact as possible, " +
44 | 					"while still providing the smart assistant with all the information " +
45 | 					"it needs for completing its task. Specifically, you should make sure " +
46 | 					"you specify actions the assistant has taken and their results, as well as " +
47 | 					"intentions of the assistant and its action plan. Do not include any other text " +
48 | 					"in your response. Remember, the smart assistant will read it and use it " +
49 | 					"as its context for further action, so try to be helpful, as if " +
50 | 					"the assistant has just asked you what has been happening in the conversation " +
51 | 					"so far.",
52 | 			},
53 | 			{
54 | 				Role: engines.ConvRoleUser,
55 | 				Text: "The current memory state is:\n\n" + memory.memoryState,
56 | 			},
57 | 		},
58 | 	}
59 | 	for _, m := range msg {
60 | 		prompt.History = append(prompt.History, &engines.ChatMessage{
61 | 			Role: engines.ConvRoleUser,
62 | 			Text: fmt.Sprintf("New message:\n\nRole: %s\nContent: %s", m.Role, m.Text),
63 | 		})
64 | 	}
65 | 	prompt.History = append(prompt.History, &engines.ChatMessage{
66 | 		Role: engines.ConvRoleSystem,
67 | 		Text: "Please update the memory state to reflect the new messages. " +
68 | 			"Do not forget to give proper weight to the current memory state.",
69 | 	})
70 | 	updatedMemState, err := memory.model.Chat(&prompt)
71 | 	if err != nil {
72 | 		// Add wraps this with the "failed to update memory state" context.
73 | 		return fmt.Errorf("chat with memory model: %w", err)
74 | 	}
75 | 	memory.memoryState = updatedMemState.Text
76 | 	log.Debugf("Updated memory state: %s", memory.memoryState)
77 | 	return nil
78 | }
79 |
80 | // Add appends msg to the recent-message window (trimming it to
81 | // recentMessageLimit) and folds msg into the memory state via the model.
82 | func (memory *SummarisedMemory) Add(msg *engines.ChatMessage) error {
83 | 	memory.recentMessages = append(memory.recentMessages, msg)
84 | 	memory.reduceBuffer()
85 | 	if err := memory.updateMemoryState(msg); err != nil {
86 | 		return fmt.Errorf("failed to update memory state: %w", err)
87 | 	}
88 | 	return nil
89 | }
90 |
91 | // AddPrompt stores prompt, replacing any earlier one; its history is
92 | // prepended to every prompt built by PromptWithContext.
93 | func (memory *SummarisedMemory) AddPrompt(prompt *engines.ChatPrompt) error {
94 | 	memory.originalPrompt = prompt
95 | 	return nil
96 | }
97 |
98 | // PromptWithContext appends nextMessages to the recent-message window and
99 | // returns a prompt made of, in order: the stored original prompt's history
100 | // (if any), a system message carrying the memory state, and the recent
101 | // messages.
102 | //
103 | // NOTE(review): nextMessages are not summarised into the memory state, and
104 | // the window is only trimmed on the next Add, so with a positive limit they
105 | // can leave the window without ever being reflected in the memory state.
106 | func (memory *SummarisedMemory) PromptWithContext(nextMessages ...*engines.ChatMessage) (*engines.ChatPrompt, error) {
107 | 	memory.recentMessages = append(memory.recentMessages, nextMessages...)
108 | 	// Room for the original prompt, the memory-state message and the window.
109 | 	capacity := len(memory.recentMessages) + 1
110 | 	if memory.originalPrompt != nil {
111 | 		capacity += len(memory.originalPrompt.History)
112 | 	}
113 | 	promptMessages := make([]*engines.ChatMessage, 0, capacity)
114 | 	if memory.originalPrompt != nil {
115 | 		promptMessages = append(promptMessages, memory.originalPrompt.History...)
116 | 	}
117 | 	promptMessages = append(promptMessages, &engines.ChatMessage{
118 | 		Role: engines.ConvRoleSystem,
119 | 		Text: fmt.Sprintf("Memory state:\n\n%s", memory.memoryState),
120 | 	})
121 | 	promptMessages = append(promptMessages, memory.recentMessages...)
122 | 	return &engines.ChatPrompt{
123 | 		History: promptMessages,
124 | 	}, nil
125 | }
126 |
127 | // NewSummarisedMemory returns an empty SummarisedMemory whose recent-message
128 | // window is trimmed to recentMessageLimit on each Add (no trimming when the
129 | // limit is <= 0) and whose memory state is maintained by model.
130 | func NewSummarisedMemory(recentMessageLimit int, model engines.LLM) *SummarisedMemory {
131 | 	return &SummarisedMemory{
132 | 		recentMessageLimit: recentMessageLimit,
133 | 		model:              model,
134 | 	}
135 | }
136 |
--------------------------------------------------------------------------------
/memory/summarised_memory_test.go:
--------------------------------------------------------------------------------
1 | package memory
2 |
3 | import (
4 | 	"fmt"
5 | 	"strings"
6 | 	"testing"
7 |
8 | 	"github.com/golang/mock/gomock"
9 | 	"github.com/natexcvi/go-llm/engines"
10 | 	enginemocks "github.com/natexcvi/go-llm/engines/mocks"
11 | 	"github.com/samber/lo"
12 | 	"github.com/stretchr/testify/assert"
13 | 	"github.com/stretchr/testify/require"
14 | )
15 |
16 | func TestSummarisedMemory(t *testing.T) {
17 | 	testCases := []struct {
18 | 		name               string
19 | 		recentMessageLimit int
20 | 		existingMessages   []*engines.ChatMessage
21 | 		newMessages        []*engines.ChatMessage
22 | 		currentMemoryState string
23 | 		expected           string
24 | 	}{
25 | 		{
26 | 			name:               "update memory state",
27 | 			recentMessageLimit: 1,
28 | 			existingMessages: []*engines.ChatMessage{
29 | 				{
30 | 					Role: engines.ConvRoleUser,
31 | 					Text: "hello",
32 | 				},
33 | 			},
34 | 			newMessages: []*engines.ChatMessage{
35 | 				{
36 | 					Role: engines.ConvRoleUser,
37 | 					Text: "world",
38 | 				},
39 | 				{
40 | 					Role: engines.ConvRoleSystem,
41 | 					Text: "the result of an action",
42 | 				},
43 | 			},
44 | 			currentMemoryState: "",
45 | 			expected:           "hello",
46 | 		},
47 | 		{
48 | 			name:               "memory state is empty",
49 | 			recentMessageLimit: 3,
50 | 			existingMessages:   []*engines.ChatMessage{},
51 | 			newMessages: []*engines.ChatMessage{
52 | 				{
53 | 					Role: engines.ConvRoleUser,
54 | 					Text: "Message 1",
55 | 				},
56 | 				{
57 | 					Role: engines.ConvRoleUser,
58 | 					Text: "Message 2",
59 | 				},
60 | 				{
61 | 					Role: engines.ConvRoleSystem,
62 | 					Text: "the result of an action",
63 | 				},
64 | 			},
65 | 			currentMemoryState: "",
66 | 			expected:           "",
67 | 		},
68 | 		{
69 | 			name:               "update memory state with capacity",
70 | 			recentMessageLimit: 1,
71 | 			existingMessages: []*engines.ChatMessage{
72 | 				{
73 | 					Role: engines.ConvRoleUser,
74 | 					Text: "hello",
75 | 				},
76 | 				{
77 | 					Role: engines.ConvRoleUser,
78 | 					Text: "world",
79 | 				},
80 | 			},
81 | 			newMessages: []*engines.ChatMessage{
82 | 				{
83 | 					Role: engines.ConvRoleSystem,
84 | 					Text: "OBS: the result of an action",
85 | 				},
86 | 			},
87 | 			currentMemoryState: "",
88 | 			expected:           "world",
89 | 		},
90 | 	}
91 |
92 | 	for _, tc := range testCases {
93 | 		t.Run(tc.name, func(t *testing.T) {
94 | 			ctrl := gomock.NewController(t)
95 | 			engineMock := enginemocks.NewMockLLM(ctrl)
96 | 			memState := tc.currentMemoryState
97 | 			engineMock.EXPECT().Chat(gomock.Any()).AnyTimes().DoAndReturn(func(prompt *engines.ChatPrompt) (*engines.ChatMessage, error) {
98 | 				newMessages := prompt.History[2 : len(prompt.History)-1]
99 | 				newMessagesEnc := strings.Join(lo.Map(newMessages, func(msg *engines.ChatMessage, _ int) string {
100 | 					return msg.Text
101 | 				}), "\n")
102 | 				memState = memState + "\n" + newMessagesEnc
103 | 				return &engines.ChatMessage{
104 | 					Role: engines.ConvRoleAssistant,
105 | 					Text: memState,
106 | 				}, nil
107 | 			})
108 | 			memory := NewSummarisedMemory(tc.recentMessageLimit, engineMock)
109 | 			for _, msg := range tc.existingMessages {
110 | 				require.NoError(t, memory.Add(msg))
111 | 			}
112 |
113 | 			prompt, err := memory.PromptWithContext(tc.newMessages...)
114 | 			require.NoError(t, err)
115 | 			numExpectedPromptMsgs := lo.Min([]int{tc.recentMessageLimit, len(tc.existingMessages)}) + len(tc.newMessages) + 1
116 | 			assert.Len(t, prompt.History, numExpectedPromptMsgs, "history length")
117 | 			memoryStateMsg, ok := lo.Find(prompt.History, func(msg *engines.ChatMessage) bool {
118 | 				return msg.Role == engines.ConvRoleSystem && strings.HasPrefix(msg.Text, "Memory state:")
119 | 			})
120 | 			require.True(t, ok, "memory state message not found")
121 | 			assert.True(t, strings.HasSuffix(memoryStateMsg.Text, tc.expected), fmt.Sprintf("expected ends with: %s, got: %s", tc.expected, memoryStateMsg.Text))
122 | 		})
123 | 	}
124 | }
125 |
126 | // TestSummarisedMemoryTrimsRecentMessagesOnAdd checks that Add trims the
127 | // recent-message window back to its limit even after PromptWithContext has
128 | // pushed it several messages over.
129 | func TestSummarisedMemoryTrimsRecentMessagesOnAdd(t *testing.T) {
130 | 	ctrl := gomock.NewController(t)
131 | 	engineMock := enginemocks.NewMockLLM(ctrl)
132 | 	engineMock.EXPECT().Chat(gomock.Any()).AnyTimes().Return(&engines.ChatMessage{
133 | 		Role: engines.ConvRoleAssistant,
134 | 		Text: "state",
135 | 	}, nil)
136 | 	memory := NewSummarisedMemory(1, engineMock)
137 | 	_, err := memory.PromptWithContext(
138 | 		&engines.ChatMessage{Role: engines.ConvRoleUser, Text: "a"},
139 | 		&engines.ChatMessage{Role: engines.ConvRoleUser, Text: "b"},
140 | 		&engines.ChatMessage{Role: engines.ConvRoleUser, Text: "c"},
141 | 	)
142 | 	require.NoError(t, err)
143 | 	require.NoError(t, memory.Add(&engines.ChatMessage{Role: engines.ConvRoleUser, Text: "d"}))
144 |
145 | 	prompt, err := memory.PromptWithContext()
146 | 	require.NoError(t, err)
147 | 	// One memory-state message plus the single retained recent message.
148 | 	require.Len(t, prompt.History, 2)
149 | 	assert.Equal(t, "d", prompt.History[1].Text)
150 | }
151 |
--------------------------------------------------------------------------------
/memory/vectorstore.go:
--------------------------------------------------------------------------------
1 | package memory
2 |
3 | // TextEmbedder maps a text to a vector of float64 values.
4 | type TextEmbedder interface {
5 | 	Embed(text string) []float64
6 | }
7 |
8 | // Vectorstore stores string values under vector keys. FindNearest
9 | // presumably returns up to k values whose keys are closest to key; the
10 | // distance metric is left to implementations (TODO confirm).
11 | type Vectorstore interface {
12 | 	Store(key []float64, value string) error
13 | 	FindNearest(key []float64, k int) ([]string, error)
14 | }
15 |
16 | // VectorstoreMemory is a placeholder for a Memory backed by embedder and
17 | // store. It does not implement Memory yet.
18 | type VectorstoreMemory struct {
19 | 	embedder TextEmbedder
20 | 	store    Vectorstore
21 | }
22 |
23 | // TODO: implement Memory for VectorstoreMemory
24 |
--------------------------------------------------------------------------------
/prebuilt/code_refactor.go:
--------------------------------------------------------------------------------
1 | package prebuilt
2 |
3 | import (
4 | 	"encoding/json"
5 | 	"fmt"
6 |
7 | 	"github.com/natexcvi/go-llm/agents"
8 | 	"github.com/natexcvi/go-llm/engines"
9 | 	"github.com/natexcvi/go-llm/memory"
10 | 	"github.com/natexcvi/go-llm/tools"
11 | )
12 |
13 | type CodeBaseRefactorRequest struct {
14 | 	Dir  string
15 | 	Goal string
16 | }
17 |
18 | func (req CodeBaseRefactorRequest) Encode() string {
19 | 	return fmt.Sprintf(`{"dir": "%s", "goal": "%s"}`, req.Dir, req.Goal) // NOTE(review): Dir/Goal are not JSON-escaped; quotes or backslashes (e.g. Windows paths) yield invalid JSON
20 | }
21 |
22 | func (req CodeBaseRefactorRequest) Schema() string {
23 | 	return
`{"dir": "path to code base", "goal": "refactoring goal"}` 24 | } 25 | 26 | type CodeBaseRefactorResponse struct { 27 | RefactoredFiles map[string]string `json:"refactored_files"` 28 | } 29 | 30 | func (resp CodeBaseRefactorResponse) Encode() string { 31 | marshalled, err := json.Marshal(resp.RefactoredFiles) 32 | if err != nil { 33 | panic(err) 34 | } 35 | return string(marshalled) 36 | } 37 | 38 | func (resp CodeBaseRefactorResponse) Schema() string { 39 | return `{"refactored_files": {"path": "description of changes"}}` 40 | } 41 | 42 | func NewCodeRefactorAgent(engine engines.LLM) agents.Agent[CodeBaseRefactorRequest, CodeBaseRefactorResponse] { 43 | task := &agents.Task[CodeBaseRefactorRequest, CodeBaseRefactorResponse]{ 44 | Description: "You will be given access to a code base, and instructions for refactoring. " + 45 | "Your task is to refactor the code base to meet the given goal.", 46 | Examples: []agents.Example[CodeBaseRefactorRequest, CodeBaseRefactorResponse]{ 47 | { 48 | Input: CodeBaseRefactorRequest{ 49 | Dir: "/Users/nate/code/base", 50 | Goal: "Handle errors gracefully", 51 | }, 52 | Answer: CodeBaseRefactorResponse{ 53 | RefactoredFiles: map[string]string{ 54 | "/Users/nate/code/base/main.py": "added try/except block", 55 | }, 56 | }, 57 | IntermediarySteps: []*engines.ChatMessage{ 58 | (&agents.ChainAgentThought{ 59 | Content: "I should scan the code base for functions that might error.", 60 | }).Encode(engine), 61 | (&agents.ChainAgentAction{ 62 | Tool: agents.NewGenericAgentTool(nil, nil), 63 | Args: json.RawMessage(`{"task": "scan code base for functions that might error", "input": "/Users/nate/code/base"}`), 64 | }).Encode(engine), 65 | (&agents.ChainAgentObservation{ 66 | Content: "main.py", 67 | ToolName: "smart_agent", 68 | }).Encode(engine), 69 | (&agents.ChainAgentThought{ 70 | Content: "Now I should handle each function that might error.", 71 | }).Encode(engine), 72 | (&agents.ChainAgentAction{ 73 | Tool: 
agents.NewGenericAgentTool(nil, nil), 74 | Args: json.RawMessage(`{"task": "fix any function that has unhandled exceptions in the file you will be given.", "input": "/Users/nate/code/base/main.py"}`), 75 | }).Encode(engine), 76 | (&agents.ChainAgentObservation{ 77 | Content: "Okay, I've fixed the errors in main.py by wrapping a block with try/except.", 78 | ToolName: "smart_agent", 79 | }).Encode(engine), 80 | }, 81 | }, 82 | }, 83 | AnswerParser: func(msg string) (CodeBaseRefactorResponse, error) { 84 | var res CodeBaseRefactorResponse 85 | if err := json.Unmarshal([]byte(msg), &res); err == nil && res.RefactoredFiles != nil { 86 | return res, nil 87 | } 88 | var rawRes map[string]string 89 | if err := json.Unmarshal([]byte(msg), &rawRes); err != nil { 90 | return CodeBaseRefactorResponse{}, fmt.Errorf("invalid response: %s", err.Error()) 91 | } 92 | return CodeBaseRefactorResponse{ 93 | RefactoredFiles: rawRes, 94 | }, nil 95 | }, 96 | } 97 | agent := agents.NewChainAgent(engine, task, memory.NewBufferedMemory(10)).WithMaxSolutionAttempts(12).WithTools( 98 | tools.NewPythonREPL(), 99 | tools.NewBashTerminal(), 100 | tools.NewAskUser(), 101 | agents.NewGenericAgentTool(engine, []tools.Tool{tools.NewBashTerminal(), tools.NewPythonREPL()}), 102 | ) 103 | return agent 104 | } 105 | -------------------------------------------------------------------------------- /prebuilt/git_assistant.go: -------------------------------------------------------------------------------- 1 | package prebuilt 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "strings" 7 | 8 | "github.com/natexcvi/go-llm/agents" 9 | "github.com/natexcvi/go-llm/engines" 10 | "github.com/natexcvi/go-llm/memory" 11 | "github.com/natexcvi/go-llm/tools" 12 | ) 13 | 14 | var ( 15 | gitTool = tools.NewGenericTool( 16 | "git", 17 | "A tool for executing git commands.", 18 | json.RawMessage(`{"command": "the git command to execute", "reason": "explain why you are executing this command, e.g. 
'add a file to the staging area''"}`), 19 | func(args json.RawMessage) (json.RawMessage, error) { 20 | var command struct { 21 | Command string `json:"command"` 22 | Reason string `json:"reason"` 23 | } 24 | err := json.Unmarshal(args, &command) 25 | if err != nil { 26 | return nil, err 27 | } 28 | if strings.HasPrefix(command.Command, "git ") { 29 | command.Command = command.Command[4:] 30 | } 31 | out, err := tools.NewBashTerminal().Execute([]byte(fmt.Sprintf(`{"command": "git %s"}`, command.Command))) 32 | if err != nil { 33 | return nil, err 34 | } 35 | return json.Marshal(string(out)) 36 | }, 37 | ) 38 | ) 39 | 40 | type GitAssistantRequest struct { 41 | Instruction string 42 | GitStatus string 43 | CurrentDate string 44 | } 45 | 46 | func (req GitAssistantRequest) Encode() string { 47 | marshaled, err := json.Marshal(req) 48 | if err != nil { 49 | panic(err) 50 | } 51 | return string(marshaled) 52 | } 53 | 54 | func (req GitAssistantRequest) Schema() string { 55 | return `{"instruction": "a description of what the user wants to do", "git_status": "output of git status"}` 56 | } 57 | 58 | type GitAssistantResponse struct { 59 | Summary string `json:"summary"` 60 | } 61 | 62 | func (resp GitAssistantResponse) Encode() string { 63 | marshaled, err := json.Marshal(resp) 64 | if err != nil { 65 | panic(err) 66 | } 67 | return string(marshaled) 68 | } 69 | 70 | func (resp GitAssistantResponse) Schema() string { 71 | return `{"summary": "a summary of the git operations performed"}` 72 | } 73 | 74 | func NewGitAssistantAgent(engine engines.LLM, actionConfirmationHook func(action *agents.ChainAgentAction) bool, additionalTools ...tools.Tool) agents.Agent[GitAssistantRequest, GitAssistantResponse] { 75 | task := &agents.Task[GitAssistantRequest, GitAssistantResponse]{ 76 | Description: "You will be given an instruction for some operation " + 77 | "to be performed with git. Your task is to perform the operation, " + 78 | "and explain why it was performed. 
Sometimes more than one operation " + 79 | "will be required to complete the task, but make sure to use as few command " + 80 | "as possible.", 81 | Examples: []agents.Example[GitAssistantRequest, GitAssistantResponse]{ 82 | { 83 | Input: GitAssistantRequest{ 84 | Instruction: "I added a try/except block to main.py, and now I want to push the changes to GitHub.", 85 | }, 86 | Answer: GitAssistantResponse{ 87 | Summary: "I pushed the changes to GitHub.", 88 | }, 89 | IntermediarySteps: []*engines.ChatMessage{ 90 | (&agents.ChainAgentAction{ 91 | Tool: gitTool, 92 | Args: []byte(`{"command": "push", "reason": "push the changes to GitHub"}`), 93 | }).Encode(engine), 94 | }, 95 | }, 96 | }, 97 | AnswerParser: func(msg string) (GitAssistantResponse, error) { 98 | var resp GitAssistantResponse 99 | if err := json.Unmarshal([]byte(msg), &resp); err != nil { 100 | resp.Summary = msg 101 | } 102 | return resp, nil 103 | }, 104 | } 105 | additionalTools = append(additionalTools, gitTool) 106 | agent := agents.NewChainAgent(engine, task, memory.NewBufferedMemory(10)).WithMaxSolutionAttempts(15).WithTools( 107 | additionalTools..., 108 | ).WithActionConfirmation(actionConfirmationHook) 109 | return agent 110 | } 111 | -------------------------------------------------------------------------------- /prebuilt/trade_assistant.go: -------------------------------------------------------------------------------- 1 | package prebuilt 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | 7 | "github.com/natexcvi/go-llm/agents" 8 | "github.com/natexcvi/go-llm/engines" 9 | "github.com/natexcvi/go-llm/memory" 10 | "github.com/natexcvi/go-llm/tools" 11 | ) 12 | 13 | type TradeAssistantRequest struct { 14 | Stocks []string `json:"stocks"` 15 | } 16 | 17 | func (r TradeAssistantRequest) Encode() string { 18 | return fmt.Sprintf(`{"stocks": %s}`, r.Stocks) 19 | } 20 | 21 | func (r TradeAssistantRequest) Schema() string { 22 | return `{"stocks": "list of stock tickers"}` 23 | } 24 | 25 | type 
Recommendation string 26 | 27 | const ( 28 | RecommendationBuy Recommendation = "buy" 29 | RecommendationSell Recommendation = "sell" 30 | RecommendationHold Recommendation = "hold" 31 | ) 32 | 33 | type TradeAssistantResponse struct { 34 | Recommendations map[string]Recommendation 35 | } 36 | 37 | func (r TradeAssistantResponse) Encode() string { 38 | marshalled, err := json.Marshal(r.Recommendations) 39 | if err != nil { 40 | panic(err) 41 | } 42 | return string(marshalled) 43 | } 44 | 45 | func (r TradeAssistantResponse) Schema() string { 46 | return `{"recommendations": {"ticker": "buy/sell/hold"}}` 47 | } 48 | 49 | func NewTradeAssistantAgent(engine engines.LLM, wolframAlphaAppID string) agents.Agent[TradeAssistantRequest, TradeAssistantResponse] { 50 | task := &agents.Task[TradeAssistantRequest, TradeAssistantResponse]{ 51 | Description: "You will be given a list of stocks. " + 52 | "Your task is to recommend whether to buy, sell, or hold each stock.", 53 | Examples: []agents.Example[TradeAssistantRequest, TradeAssistantResponse]{ 54 | { 55 | Input: TradeAssistantRequest{ 56 | Stocks: []string{"AAPL", "MSFT", "GOOG"}, 57 | }, 58 | Answer: TradeAssistantResponse{ 59 | Recommendations: map[string]Recommendation{ 60 | "AAPL": RecommendationBuy, 61 | "MSFT": RecommendationSell, 62 | "GOOG": RecommendationHold, 63 | }, 64 | }, 65 | IntermediarySteps: []*engines.ChatMessage{ 66 | (&agents.ChainAgentThought{ 67 | Content: "I should look up the stock price for each stock.", 68 | }).Encode(engine), 69 | (&agents.ChainAgentAction{ 70 | Tool: tools.NewWolframAlpha(wolframAlphaAppID), 71 | Args: json.RawMessage(`{"query": "stock price of AAPL"}`), 72 | }).Encode(engine), 73 | (&agents.ChainAgentObservation{ 74 | Content: "AAPL is currently trading at $100.00", 75 | ToolName: "wolfram_alpha", 76 | }).Encode(engine), 77 | }, 78 | }, 79 | }, 80 | AnswerParser: func(text string) (TradeAssistantResponse, error) { 81 | var recommendations map[string]Recommendation 82 | if err 
:= json.Unmarshal([]byte(text), &recommendations); err != nil { 83 | return TradeAssistantResponse{}, err 84 | } 85 | return TradeAssistantResponse{ 86 | Recommendations: recommendations, 87 | }, nil 88 | }, 89 | } 90 | return agents.NewChainAgent(engine, task, memory.NewSummarisedMemory(3, engine)).WithMaxSolutionAttempts(12).WithTools( 91 | tools.NewGoogleSearch(), 92 | tools.NewIsolatedPythonREPL(), 93 | tools.NewWolframAlpha(wolframAlphaAppID), 94 | tools.NewWebpageSummary(engine), 95 | ) 96 | } 97 | -------------------------------------------------------------------------------- /prebuilt/unit_test_writer.go: -------------------------------------------------------------------------------- 1 | package prebuilt 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | 7 | "github.com/natexcvi/go-llm/agents" 8 | "github.com/natexcvi/go-llm/engines" 9 | "github.com/natexcvi/go-llm/memory" 10 | ) 11 | 12 | type UnitTestWriterRequest struct { 13 | SourceFile string `json:"source_file"` 14 | ExampleFile string `json:"example_file"` 15 | } 16 | 17 | func (r UnitTestWriterRequest) Encode() string { 18 | return fmt.Sprintf(`Write unit tests for the following file based on the example: {"source_file": %q, "example_file": %q}`, r.SourceFile, r.ExampleFile) 19 | } 20 | 21 | func (r UnitTestWriterRequest) Schema() string { 22 | return `{"source_file": "source code file", "example_file": "example unit test file"}` 23 | } 24 | 25 | type UnitTestWriterResponse struct { 26 | UnitTestFile string `json:"unit_test_file"` 27 | } 28 | 29 | func (r UnitTestWriterResponse) Encode() string { 30 | marshalled, err := json.Marshal(r.UnitTestFile) 31 | if err != nil { 32 | panic(err) 33 | } 34 | return string(marshalled) 35 | } 36 | 37 | func (r UnitTestWriterResponse) Schema() string { 38 | return `{"unit_test_file": "unit test file"}` 39 | } 40 | 41 | func NewUnitTestWriter(engine engines.LLM, codeValidator func(code string) error) (agents.Agent[UnitTestWriterRequest, UnitTestWriterResponse], 
error) { 42 | task := &agents.Task[UnitTestWriterRequest, UnitTestWriterResponse]{ 43 | Description: "You are a coding assistant that specialises in writing " + 44 | "unit tests. You will be given a source code file and an example unit test file. " + 45 | "Your task is to write unit tests for the source code file, following " + 46 | "the patterns and conventions you see in the example unit test file. " + 47 | "Your final answer should be just the content of the unit test file, " + 48 | "and nothing else. " + 49 | "For this task, no intermediary steps are required and in most cases you can " + 50 | "Reply with your final answer immediately.", 51 | Examples: []agents.Example[UnitTestWriterRequest, UnitTestWriterResponse]{ 52 | { 53 | Input: UnitTestWriterRequest{ 54 | SourceFile: "def add(a, b):\n return a + b\n", 55 | ExampleFile: "from example import multiply\ndef test_multiply():" + 56 | "\n assert multiply(2, 3) == 6\n", 57 | }, 58 | Answer: UnitTestWriterResponse{ 59 | UnitTestFile: "from example import add\ndef test_add():" + 60 | "\n assert add(4, -4) == 0\n", 61 | }, 62 | IntermediarySteps: []*engines.ChatMessage{ 63 | (&agents.ChainAgentThought{ 64 | Content: "I now know what tests to write.", 65 | }).Encode(engine), 66 | }, 67 | }, 68 | }, 69 | AnswerParser: func(answer string) (UnitTestWriterResponse, error) { 70 | var response UnitTestWriterResponse 71 | if err := json.Unmarshal([]byte(answer), &response); err == nil { 72 | return response, nil 73 | } 74 | var responseString string 75 | if err := json.Unmarshal([]byte(answer), &responseString); err == nil { 76 | return UnitTestWriterResponse{ 77 | UnitTestFile: responseString, 78 | }, nil 79 | } 80 | return UnitTestWriterResponse{ 81 | UnitTestFile: answer, 82 | }, nil 83 | }, 84 | } 85 | agent := agents.NewChainAgent(engine, task, memory.NewBufferedMemory(0)) 86 | if codeValidator != nil { 87 | agent = agent.WithOutputValidators(func(utwr UnitTestWriterResponse) error { 88 | err := 
codeValidator(utwr.UnitTestFile) 89 | if err != nil { 90 | return fmt.Errorf("your unit test file is not valid: %s", err) 91 | } 92 | return nil 93 | }) 94 | } 95 | return agent, nil 96 | } 97 | -------------------------------------------------------------------------------- /tools/ask_user.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "encoding/json" 5 | "errors" 6 | "fmt" 7 | "io" 8 | "os" 9 | ) 10 | 11 | type AskUser struct { 12 | source io.Reader 13 | questionHandler func(question string) (string, error) 14 | } 15 | 16 | func (b *AskUser) Execute(args json.RawMessage) (json.RawMessage, error) { 17 | var command struct { 18 | Question string `json:"question"` 19 | } 20 | err := json.Unmarshal(args, &command) 21 | if err != nil { 22 | return nil, err 23 | } 24 | var answer string 25 | if b.questionHandler != nil { 26 | answer, err = b.questionHandler(command.Question) 27 | } else { 28 | fmt.Println(command.Question) 29 | answer, err = b.readUserInput() 30 | } 31 | if err != nil { 32 | if errors.Is(err, io.EOF) { 33 | return nil, fmt.Errorf("the user did not provide an answer") 34 | } 35 | return nil, fmt.Errorf("error while reading user input: %s", err.Error()) 36 | } 37 | var response struct { 38 | Answer string `json:"answer"` 39 | } 40 | response.Answer = answer 41 | return json.Marshal(response) 42 | } 43 | 44 | func (b *AskUser) readUserInput() (string, error) { 45 | var answer string 46 | n, err := fmt.Fscanln(b.source, &answer) 47 | if err != nil { 48 | return "", err 49 | } 50 | if n == 0 { 51 | return "", fmt.Errorf("no input") 52 | } 53 | return answer, nil 54 | } 55 | 56 | func (b *AskUser) Name() string { 57 | return "ask_user" 58 | } 59 | 60 | func (b *AskUser) Description() string { 61 | return "A tool for asking the user a question. Use this tool " + 62 | "if you need any kind of input or help from the user." 
63 | } 64 | 65 | func (b *AskUser) ArgsSchema() json.RawMessage { 66 | return json.RawMessage(`{"question": "the question to ask the user"}`) 67 | } 68 | 69 | func (b *AskUser) CompactArgs(args json.RawMessage) json.RawMessage { 70 | return args 71 | } 72 | 73 | func NewAskUser() *AskUser { 74 | return &AskUser{ 75 | source: os.Stdin, 76 | } 77 | } 78 | 79 | func NewAskUserWithSource(source io.Reader) *AskUser { 80 | return &AskUser{ 81 | source: source, 82 | } 83 | } 84 | 85 | func (b *AskUser) WithCustomQuestionHandler(handler func(question string) (string, error)) *AskUser { 86 | b.questionHandler = handler 87 | return b 88 | } 89 | -------------------------------------------------------------------------------- /tools/ask_user_test.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "strings" 7 | "testing" 8 | 9 | "github.com/stretchr/testify/assert" 10 | "github.com/stretchr/testify/require" 11 | ) 12 | 13 | func mockStream(text string) io.Reader { 14 | return strings.NewReader(text) 15 | } 16 | 17 | func TestAskUser(t *testing.T) { 18 | testCases := []struct { 19 | name string 20 | question string 21 | mockResponse string 22 | expectedOutput string 23 | expectedError error 24 | }{ 25 | { 26 | name: "simple", 27 | question: "What is your name?", 28 | mockResponse: "John", 29 | expectedOutput: `{"answer": "John"}`, 30 | }, 31 | { 32 | name: "empty", 33 | question: "What is your name?", 34 | mockResponse: "", 35 | expectedError: fmt.Errorf("the user did not provide an answer"), 36 | }, 37 | { 38 | name: "reads only one line", 39 | question: "What is your name?", 40 | mockResponse: "John\nAlex", 41 | expectedOutput: `{"answer": "John"}`, 42 | }, 43 | } 44 | for _, tc := range testCases { 45 | t.Run(tc.name, func(t *testing.T) { 46 | askUser := NewAskUserWithSource(mockStream(tc.mockResponse)) 47 | output, err := askUser.Execute([]byte(`{"question": "` + tc.question + `"}`)) 
48 | if tc.expectedError != nil { 49 | require.EqualError(t, err, tc.expectedError.Error()) 50 | return 51 | } 52 | require.NoError(t, err) 53 | assert.JSONEq(t, tc.expectedOutput, string(output)) 54 | }) 55 | } 56 | 57 | } 58 | -------------------------------------------------------------------------------- /tools/bash.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "os/exec" 7 | ) 8 | 9 | type BashTerminal struct { 10 | } 11 | 12 | func (b *BashTerminal) Execute(args json.RawMessage) (json.RawMessage, error) { 13 | var command struct { 14 | Command string `json:"command"` 15 | } 16 | err := json.Unmarshal(args, &command) 17 | if err != nil { 18 | return nil, err 19 | } 20 | out, err := exec.Command("bash", "-c", command.Command).Output() 21 | if err != nil { 22 | if exitError, ok := err.(*exec.ExitError); ok { 23 | return nil, fmt.Errorf("bash exited with code %d: %s", exitError.ExitCode(), string(exitError.Stderr)) 24 | } 25 | return nil, err 26 | } 27 | return json.Marshal(string(out)) 28 | } 29 | 30 | func (b *BashTerminal) Name() string { 31 | return "bash" 32 | } 33 | 34 | func (b *BashTerminal) Description() string { 35 | return "A tool for executing bash commands. Important! This tool is not sandboxed, so it can do anything on the host machine." 
36 | } 37 | 38 | func (b *BashTerminal) ArgsSchema() json.RawMessage { 39 | return json.RawMessage(`{"command": "the bash command to execute"}`) 40 | } 41 | 42 | func (b *BashTerminal) CompactArgs(args json.RawMessage) json.RawMessage { 43 | return args 44 | } 45 | 46 | func NewBashTerminal() *BashTerminal { 47 | return &BashTerminal{} 48 | } 49 | -------------------------------------------------------------------------------- /tools/bash_test.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | "github.com/stretchr/testify/require" 10 | ) 11 | 12 | func TestBash(t *testing.T) { 13 | testCases := []struct { 14 | name string 15 | bash *BashTerminal 16 | input json.RawMessage 17 | output json.RawMessage 18 | expErr error 19 | }{ 20 | { 21 | name: "simple", 22 | bash: NewBashTerminal(), 23 | input: json.RawMessage(`{"command": "echo hello, world"}`), 24 | output: json.RawMessage(`"hello, world\n"`), 25 | }, 26 | { 27 | name: "error", 28 | bash: NewBashTerminal(), 29 | input: json.RawMessage(`{"command": "cat no-such-file"}`), 30 | expErr: fmt.Errorf("bash exited with code 1: cat: no-such-file: No such file or directory\n"), 31 | }, 32 | } 33 | for _, testCase := range testCases { 34 | t.Run(testCase.name, func(t *testing.T) { 35 | output, err := testCase.bash.Execute(testCase.input) 36 | if testCase.expErr != nil { 37 | require.EqualError(t, err, testCase.expErr.Error()) 38 | return 39 | } 40 | require.NoError(t, err) 41 | assert.JSONEq(t, string(testCase.output), string(output)) 42 | }) 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /tools/generic_tool.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "encoding/json" 5 | ) 6 | 7 | type GenericTool struct { 8 | name string 9 | 
description string 10 | argSchema json.RawMessage 11 | handler func(args json.RawMessage) (json.RawMessage, error) 12 | } 13 | 14 | func (b *GenericTool) Execute(args json.RawMessage) (json.RawMessage, error) { 15 | return b.handler(args) 16 | } 17 | 18 | func (b *GenericTool) Name() string { 19 | return b.name 20 | } 21 | 22 | func (b *GenericTool) Description() string { 23 | return b.description 24 | } 25 | 26 | func (b *GenericTool) ArgsSchema() json.RawMessage { 27 | return b.argSchema 28 | } 29 | 30 | func (b *GenericTool) CompactArgs(args json.RawMessage) json.RawMessage { 31 | return args 32 | } 33 | 34 | func NewGenericTool(name, description string, argSchema json.RawMessage, handler func(args json.RawMessage) (json.RawMessage, error)) *GenericTool { 35 | return &GenericTool{ 36 | name: name, 37 | description: description, 38 | argSchema: argSchema, 39 | handler: handler, 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /tools/google_search.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | 8 | "github.com/playwright-community/playwright-go" 9 | ) 10 | 11 | type SearchResult struct { 12 | Title string 13 | URL string 14 | } 15 | 16 | type WebSearch struct { 17 | ServiceURL string 18 | } 19 | 20 | func NewWebSearch(serviceURL string) *WebSearch { 21 | return &WebSearch{ 22 | ServiceURL: serviceURL, 23 | } 24 | } 25 | 26 | func NewGoogleSearch() *WebSearch { 27 | return &WebSearch{ 28 | ServiceURL: "https://google.com/", 29 | } 30 | } 31 | 32 | func (ws *WebSearch) Execute(args json.RawMessage) (json.RawMessage, error) { 33 | var query struct { 34 | Query string `json:"query"` 35 | } 36 | err := json.Unmarshal(args, &query) 37 | if err != nil { 38 | return nil, err 39 | } 40 | results, err := ws.search(query.Query) 41 | if err != nil { 42 | return nil, err 43 | } 44 | return 
json.Marshal(results) 45 | } 46 | 47 | func (ws *WebSearch) Name() string { 48 | return "google_search" 49 | } 50 | 51 | func (ws *WebSearch) Description() string { 52 | return fmt.Sprintf("A tool for searching the web using Google.") 53 | } 54 | 55 | func (ws *WebSearch) ArgsSchema() json.RawMessage { 56 | return json.RawMessage(`{"query": "the search query"}`) 57 | } 58 | 59 | func (ws *WebSearch) isPlaywrightInstalled() bool { 60 | _, err := playwright.Run() 61 | return err == nil 62 | } 63 | 64 | func (ws *WebSearch) search(query string) (searchResults []SearchResult, err error) { 65 | _, cancel := context.WithCancel(context.Background()) 66 | defer cancel() 67 | if !ws.isPlaywrightInstalled() { 68 | err = playwright.Install() 69 | if err != nil { 70 | return nil, fmt.Errorf("could not install playwright: %v", err) 71 | } 72 | } 73 | pw, err := playwright.Run() 74 | if err != nil { 75 | return nil, fmt.Errorf("could not start playwright: %v", err) 76 | } 77 | defer pw.Stop() 78 | 79 | // Launch a new Chromium browser context 80 | browser, err := pw.Chromium.Launch(playwright.BrowserTypeLaunchOptions{ 81 | SlowMo: playwright.Float(100), 82 | Headless: playwright.Bool(true), 83 | }) 84 | if err != nil { 85 | return nil, fmt.Errorf("could not launch browser: %v", err) 86 | } 87 | defer browser.Close() 88 | 89 | // Create a new browser page 90 | page, err := browser.NewPage() 91 | if err != nil { 92 | return nil, fmt.Errorf("could not create page: %v", err) 93 | } 94 | defer page.Close() 95 | 96 | // Navigate to Google and search for query 97 | searchURL := fmt.Sprintf("%s/search?q=%s", ws.ServiceURL, query) 98 | if _, err := page.Goto(searchURL); err != nil { 99 | return nil, fmt.Errorf("could not navigate to google: %v", err) 100 | } 101 | 102 | // Wait for the search results to load 103 | if _, err := page.WaitForSelector("#search"); err != nil { 104 | return nil, fmt.Errorf("could not find search results: %v", err) 105 | } 106 | 107 | // scroll to the bottom of 
the page 108 | if _, err := page.Evaluate(`() => { 109 | window.scrollBy(0, window.innerHeight); 110 | }`); err != nil { 111 | return nil, fmt.Errorf("could not scroll to bottom of page: %v", err) 112 | } 113 | 114 | // wait for last result to load 115 | if _, err := page.WaitForSelector("#search .g:last-child"); err != nil { 116 | return nil, fmt.Errorf("could not find last search result: %v", err) 117 | } 118 | 119 | // wait for page to load completely 120 | page.WaitForLoadState(string(*playwright.LoadStateDomcontentloaded)) 121 | 122 | // Parse the search results 123 | results, err := page.QuerySelectorAll("#search .g") 124 | if err != nil { 125 | return nil, fmt.Errorf("could not query search results: %v", err) 126 | } 127 | for _, result := range results { 128 | title, err := result.QuerySelector(".LC20lb.DKV0Md") 129 | if err != nil { 130 | continue 131 | } 132 | titleText, err := title.InnerText() 133 | if err != nil { 134 | continue 135 | } 136 | 137 | urlElem, err := result.QuerySelector("a") 138 | if err != nil { 139 | continue 140 | } 141 | url, err := urlElem.GetAttribute("href") 142 | if err != nil { 143 | continue 144 | } 145 | searchResults = append(searchResults, SearchResult{ 146 | Title: titleText, 147 | URL: url, 148 | }) 149 | } 150 | return searchResults, nil 151 | } 152 | 153 | func (ws *WebSearch) CompactArgs(args json.RawMessage) json.RawMessage { 154 | return args 155 | } 156 | -------------------------------------------------------------------------------- /tools/google_search_test.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | "io/ioutil" 8 | "log" 9 | "net/http" 10 | "os" 11 | "strings" 12 | "testing" 13 | 14 | "github.com/stretchr/testify/require" 15 | ) 16 | 17 | func mockSearchService(t *testing.T) { 18 | t.Helper() 19 | // Set up a mock server to respond to requests with a mock Google search results page 20 | 
mockServer := http.NewServeMux() 21 | mockServer.HandleFunc("/search", func(w http.ResponseWriter, r *http.Request) { 22 | // Load the sample search results page from a local file 23 | file, err := ioutil.ReadFile("testdata/mock_search_results.html") 24 | if err != nil { 25 | http.Error(w, "Failed to load mock search results page", http.StatusInternalServerError) 26 | return 27 | } 28 | 29 | // Replace the search query in the mock search results page with the query from the request 30 | query := r.FormValue("q") 31 | fileString := string(file) 32 | fileString = strings.ReplaceAll(fileString, "{{QUERY}}", query) 33 | 34 | // Set the content type and write the mock search results page to the response 35 | w.Header().Set("Content-Type", "text/html") 36 | w.Write([]byte(fileString)) 37 | }) 38 | 39 | server := &http.Server{ 40 | Addr: ":8080", 41 | Handler: mockServer, 42 | } 43 | 44 | // Start the mock server on port 8080 45 | go func() { 46 | if err := server.ListenAndServe(); err != nil { 47 | if err == http.ErrServerClosed { 48 | return 49 | } 50 | log.Fatal(err) 51 | } 52 | }() 53 | t.Cleanup(func() { 54 | require.NoError(t, server.Shutdown(context.Background())) 55 | }) 56 | } 57 | 58 | func TestWebSearch(t *testing.T) { 59 | testCases := []struct { 60 | name string 61 | query string 62 | expOutputPath string 63 | }{ 64 | { 65 | name: "simple", 66 | query: "hello world", 67 | expOutputPath: "testdata/expected_parsed_results.json", 68 | }, 69 | } 70 | mockSearchService(t) 71 | for _, tc := range testCases { 72 | t.Run(tc.name, func(t *testing.T) { 73 | ws := NewWebSearch("http://localhost:8080") 74 | output, err := ws.Execute(json.RawMessage( 75 | fmt.Sprintf(`{"query": %q}`, tc.query), 76 | )) 77 | require.NoError(t, err) 78 | var results []SearchResult 79 | require.NoError(t, json.Unmarshal(output, &results)) 80 | _, err = os.Stat(tc.expOutputPath) 81 | if os.IsNotExist(err) { 82 | require.NoError(t, ioutil.WriteFile(tc.expOutputPath, output, 0644)) 83 | } 84 | 
expOutput, err := ioutil.ReadFile(tc.expOutputPath) 85 | require.NoError(t, err) 86 | var expResults []SearchResult 87 | require.NoError(t, json.Unmarshal(expOutput, &expResults)) 88 | require.Equal(t, expResults, results) 89 | }) 90 | } 91 | } 92 | -------------------------------------------------------------------------------- /tools/isolated_python_repl.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "os" 7 | "os/exec" 8 | "path" 9 | "strings" 10 | ) 11 | 12 | type IsolatedPythonREPL struct { 13 | } 14 | 15 | func NewIsolatedPythonREPL() *IsolatedPythonREPL { 16 | return &IsolatedPythonREPL{} 17 | } 18 | 19 | func (repl *IsolatedPythonREPL) Execute(arg json.RawMessage) (json.RawMessage, error) { 20 | var args struct { 21 | Code string `json:"code"` 22 | Modules []string `json:"modules"` 23 | } 24 | err := json.Unmarshal(arg, &args) 25 | if err != nil { 26 | return nil, fmt.Errorf("failed to unmarshal args: %w", err) 27 | } 28 | tmpDir, err := os.MkdirTemp("", "python_repl") 29 | if err != nil { 30 | return nil, fmt.Errorf("failed to create temp dir: %w", err) 31 | } 32 | defer os.RemoveAll(tmpDir) 33 | err = os.WriteFile(path.Join(tmpDir, "script.py"), []byte(args.Code), 0644) 34 | if err != nil { 35 | return nil, fmt.Errorf("failed to write script to file: %w", err) 36 | } 37 | cmdArgs := []string{"run", "--rm", "-v", fmt.Sprintf("%s:/app", tmpDir), "python:3.11-alpine"} 38 | shArgs := []string{} 39 | if len(args.Modules) > 0 { 40 | shArgs = append(shArgs, "python", "-m", "pip", "install", "--quiet") 41 | shArgs = append(shArgs, args.Modules...) 42 | shArgs = append(shArgs, "&&") 43 | } 44 | shArgs = append(shArgs, "python", path.Join("app", "script.py")) 45 | cmdArgs = append(cmdArgs, "sh", "-c", strings.Join(shArgs, " ")) 46 | cmd := exec.Command("docker", cmdArgs...) 
47 | out, err := cmd.Output() 48 | if err != nil { 49 | if exitError, ok := err.(*exec.ExitError); ok { 50 | return nil, fmt.Errorf("python exited with code %d: %s", exitError.ExitCode(), string(exitError.Stderr)) 51 | } 52 | return nil, err 53 | } 54 | return json.Marshal(string(out)) 55 | } 56 | 57 | func (repl *IsolatedPythonREPL) Name() string { 58 | return "isolated_python" 59 | } 60 | 61 | func (repl *IsolatedPythonREPL) Description() string { 62 | return "A Python REPL that runs in a Docker container. " + 63 | "Use this to run any Python code that can help you complete your task." 64 | } 65 | 66 | func (repl *IsolatedPythonREPL) ArgsSchema() json.RawMessage { 67 | return json.RawMessage(`{"code": "the Python code to execute", "modules": ["a list", "of modules", "to install"]}`) 68 | } 69 | 70 | func (repl *IsolatedPythonREPL) CompactArgs(args json.RawMessage) json.RawMessage { 71 | return args 72 | } 73 | -------------------------------------------------------------------------------- /tools/isolated_python_repl_test.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "strings" 7 | "testing" 8 | 9 | "github.com/samber/lo" 10 | "github.com/stretchr/testify/assert" 11 | "github.com/stretchr/testify/require" 12 | ) 13 | 14 | func TestPythonRepl(t *testing.T) { 15 | testCases := []struct { 16 | name string 17 | code string 18 | modules []string 19 | expError error 20 | expOutput string 21 | }{ 22 | { 23 | name: "simple", 24 | code: "print('hello world')", 25 | expOutput: `"hello world\n"`, 26 | }, 27 | { 28 | name: "error", 29 | code: "print('hello world')\nprint(1/0)", 30 | expError: fmt.Errorf("python exited with code 1: Traceback (most recent call last):\n File \"//app/script.py\", line 2, in \n print(1/0)\n ~^~\nZeroDivisionError: division by zero\n"), 31 | }, 32 | { 33 | name: "with modules", 34 | code: "import requests\nprint('hello world')", 35 | modules: 
[]string{ 36 | "requests", 37 | }, 38 | expOutput: `"hello world\n"`, 39 | }, 40 | { 41 | name: "no existing module - error", 42 | code: "import requests\nprint('hello world')", 43 | modules: []string{ 44 | "requests", 45 | "does-not-exist", 46 | }, 47 | expError: fmt.Errorf("python exited with code 1: ERROR: Could not find a version that satisfies the requirement does-not-exist"), 48 | }, 49 | } 50 | for _, tc := range testCases { 51 | t.Run(tc.name, func(t *testing.T) { 52 | repl := NewIsolatedPythonREPL() 53 | output, err := repl.Execute(json.RawMessage( 54 | fmt.Sprintf( 55 | `{"code": %q, "modules": [%s]}`, 56 | tc.code, 57 | strings.Join( 58 | lo.Map(tc.modules, func(in string, _ int) string { return fmt.Sprintf("%q", in) }), 59 | ",", 60 | ), 61 | ), 62 | )) 63 | if tc.expError != nil { 64 | actualError := err.Error() 65 | require.True(t, strings.HasPrefix(actualError, tc.expError.Error()), "expected error to start with %q, got %q", tc.expError.Error(), actualError) 66 | return 67 | } 68 | require.NoError(t, err) 69 | assert.Equal(t, tc.expOutput, string(output)) 70 | }) 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /tools/json_autofixer.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "regexp" 7 | 8 | "github.com/hashicorp/go-multierror" 9 | "github.com/natexcvi/go-llm/engines" 10 | log "github.com/sirupsen/logrus" 11 | ) 12 | 13 | var ErrMaxRetriesExceeded = fmt.Errorf("max retries exceeded") 14 | 15 | type JSONAutoFixer struct { 16 | engine engines.LLM 17 | maxRetries int 18 | } 19 | 20 | func (t *JSONAutoFixer) prompt(args json.RawMessage) *engines.ChatPrompt { 21 | prompt := engines.ChatPrompt{ 22 | History: []*engines.ChatMessage{ 23 | { 24 | Role: engines.ConvRoleSystem, 25 | Text: "You are an automated JSON fixer. 
" + 26 | "You will receive a JSON payload that might contain " + 27 | "errors, and you must fix them and return a valid JSON payload.", 28 | }, 29 | { 30 | Role: engines.ConvRoleUser, 31 | Text: `{"name": "John "Doe", "age": 30, "car": null}`, 32 | }, 33 | { 34 | Role: engines.ConvRoleAssistant, 35 | Text: `{"name": "John \"Doe", "age": 30, "car": null}`, 36 | }, 37 | }, 38 | } 39 | prompt.History = append(prompt.History, &engines.ChatMessage{ 40 | Role: engines.ConvRoleUser, 41 | Text: string(args), 42 | }) 43 | return &prompt 44 | } 45 | 46 | func (t *JSONAutoFixer) validateJSON(raw string) error { 47 | var obj any 48 | if err := json.Unmarshal([]byte(raw), &obj); err != nil { 49 | return fmt.Errorf("invalid JSON: %w", err) 50 | } 51 | return nil 52 | } 53 | 54 | func (t *JSONAutoFixer) extractJSONFromResponse(response string) string { 55 | wrappedJSONRegex := regexp.MustCompile(`\x60\x60\x60(?:json)?\s(?P[\s\S]+)\s\x60\x60\x60`) 56 | if wrappedJSONRegex.MatchString(response) { 57 | return wrappedJSONRegex.FindStringSubmatch(response)[1] 58 | } 59 | return response 60 | } 61 | 62 | func (t *JSONAutoFixer) Process(args json.RawMessage) (json.RawMessage, error) { 63 | if err := t.validateJSON(string(args)); err == nil { 64 | return args, nil 65 | } 66 | log.Debugf("Running JSON fixer") 67 | prompt := t.prompt(args) 68 | var cumErr *multierror.Error 69 | for i := 0; i < t.maxRetries; i++ { 70 | resp, err := t.engine.Chat(prompt) 71 | if err != nil { 72 | return nil, fmt.Errorf("error running JSON auto fixer: %w", err) 73 | } 74 | respJSON := t.extractJSONFromResponse(resp.Text) 75 | if err := t.validateJSON(respJSON); err != nil { 76 | cumErr = multierror.Append(cumErr, fmt.Errorf("invalid JSON returned by JSON auto fixer: %w", err)) 77 | continue 78 | } 79 | log.Debugf("JSON auto fixer succeeded after %d retries", i+1) 80 | log.Debugf("Fixed JSON payload: %s", respJSON) 81 | return json.RawMessage(respJSON), nil 82 | } 83 | return nil, multierror.Append(cumErr, 
ErrMaxRetriesExceeded) 84 | } 85 | 86 | func NewJSONAutoFixer(engine engines.LLM, maxRetries int) *JSONAutoFixer { 87 | return &JSONAutoFixer{ 88 | engine: engine, 89 | maxRetries: maxRetries, 90 | } 91 | } 92 | -------------------------------------------------------------------------------- /tools/json_autofixer_test.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "encoding/json" 5 | "testing" 6 | 7 | "github.com/golang/mock/gomock" 8 | "github.com/natexcvi/go-llm/engines" 9 | enginemocks "github.com/natexcvi/go-llm/engines/mocks" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestJSONAutoFixer_Process(t *testing.T) { 14 | testCases := []struct { 15 | name string 16 | maxRetries int 17 | raw string 18 | expectedErr error 19 | expectedOutput string 20 | modelResponses []string 21 | }{ 22 | { 23 | name: "Valid JSON", 24 | maxRetries: 1, 25 | raw: `{"foo":"bar"}`, 26 | expectedErr: nil, 27 | expectedOutput: `{"foo":"bar"}`, 28 | modelResponses: []string{}, 29 | }, 30 | { 31 | name: "Error JSON", 32 | maxRetries: 1, 33 | raw: `{"foo":"bar"`, 34 | expectedErr: nil, 35 | expectedOutput: `{"foo": "bar"}`, 36 | modelResponses: []string{ 37 | `{"foo": "bar"}`, 38 | }, 39 | }, 40 | { 41 | name: "Error JSON wrapped response", 42 | maxRetries: 1, 43 | raw: `{"foo":"bar"`, 44 | expectedErr: nil, 45 | expectedOutput: `{"foo": "bar"}`, 46 | modelResponses: []string{ 47 | "Here if your fixed JSON payload:" + 48 | "\x60\x60\x60json\n" + 49 | `{"foo": "bar"}` + 50 | "\n\x60\x60\x60", 51 | }, 52 | }, 53 | { 54 | name: "Error JSON wrapped response (just backtics)", 55 | maxRetries: 1, 56 | raw: `{"foo":"bar"`, 57 | expectedErr: nil, 58 | expectedOutput: `{"foo": "bar"}`, 59 | modelResponses: []string{ 60 | "Here if your fixed JSON payload:" + 61 | "\x60\x60\x60\n" + 62 | `{"foo": "bar"}` + 63 | "\n\x60\x60\x60", 64 | }, 65 | }, 66 | { 67 | name: "Error JSON with max retries", 68 | maxRetries: 
2, 69 | raw: `{"foo":"bar"`, 70 | expectedErr: nil, 71 | expectedOutput: `{"foo": "bar"}`, 72 | modelResponses: []string{ 73 | `{"foo": "bar"`, 74 | `{"foo": "bar"}`, 75 | }, 76 | }, 77 | { 78 | name: "Error JSON - max retries exceeded", 79 | maxRetries: 2, 80 | raw: `{"foo":"bar"`, 81 | expectedErr: ErrMaxRetriesExceeded, 82 | expectedOutput: `{"foo": "bar"`, 83 | modelResponses: []string{ 84 | `{"foo": "bar"`, 85 | `{"foo": "bar"`, 86 | }, 87 | }, 88 | } 89 | 90 | for _, tc := range testCases { 91 | t.Run(tc.name, func(t *testing.T) { 92 | ctrl := gomock.NewController(t) 93 | engineMock := enginemocks.NewMockLLM(ctrl) 94 | i := -1 95 | engineMock.EXPECT().Chat(gomock.Any()).DoAndReturn(func(prompt *engines.ChatPrompt) (*engines.ChatMessage, error) { 96 | i++ 97 | return &engines.ChatMessage{ 98 | Role: engines.ConvRoleAssistant, 99 | Text: tc.modelResponses[i], 100 | }, nil 101 | }).Times(len(tc.modelResponses)) 102 | autoFixer := NewJSONAutoFixer(engineMock, tc.maxRetries) 103 | output, err := autoFixer.Process(json.RawMessage(tc.raw)) 104 | if tc.expectedErr != nil { 105 | assert.ErrorIs(t, err, tc.expectedErr, "error") 106 | assert.Nil(t, output, "output") 107 | } else { 108 | assert.NoError(t, err, "error") 109 | assert.Equal(t, json.RawMessage(tc.expectedOutput), output, "output") 110 | } 111 | }) 112 | } 113 | } 114 | -------------------------------------------------------------------------------- /tools/key_value_store.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "fmt" 7 | 8 | "text/template" 9 | 10 | "github.com/hashicorp/go-multierror" 11 | "golang.org/x/exp/maps" 12 | ) 13 | 14 | type KeyValueStore struct { 15 | store map[string]string 16 | } 17 | 18 | type kvStoreArgs struct { 19 | Command string `json:"command"` 20 | Key string `json:"key"` 21 | Value string `json:"value"` 22 | } 23 | 24 | func (s *KeyValueStore) Execute(args json.RawMessage) 
(json.RawMessage, error) { 25 | var command kvStoreArgs 26 | err := json.Unmarshal(args, &command) 27 | if err != nil { 28 | return nil, err 29 | } 30 | if command.Command == "" { 31 | if command.Key != "" && command.Value != "" { 32 | command.Command = "set" 33 | } 34 | if command.Key != "" && command.Value == "" { 35 | command.Command = "get" 36 | } 37 | if command.Key == "" && command.Value == "" { 38 | command.Command = "list" 39 | } 40 | } 41 | switch command.Command { 42 | case "get": 43 | value, ok := s.store[command.Key] 44 | if !ok { 45 | return nil, fmt.Errorf("key not found: %s", command.Key) 46 | } 47 | return json.Marshal(value) 48 | case "set": 49 | s.store[command.Key] = command.Value 50 | return json.Marshal("stored successfully") 51 | case "list": 52 | keys := maps.Keys(s.store) 53 | return json.Marshal(keys) 54 | default: 55 | return nil, fmt.Errorf("unknown command: %s", command.Command) 56 | } 57 | } 58 | 59 | func (s *KeyValueStore) Name() string { 60 | return "store" 61 | } 62 | 63 | func (s *KeyValueStore) Description() string { 64 | return "A place where you can store any key-value pairs " + 65 | "of data. This is useful mainly for long values, which you should " + 66 | "store here to save memory. To use a value you have stored, " + 67 | "reference it by Go template syntax: {{ store \"key\" }}. " + 68 | "where \"key\" is the key you used to store the value. " + 69 | "You can reference a saved value anywhere you want, including " + 70 | "arguments to other tools." 71 | } 72 | 73 | func (s *KeyValueStore) ArgsSchema() json.RawMessage { 74 | return json.RawMessage(`{"command": "either 'set', 'get' or 'list'", "key": "the key to store or retrieve. Specify only for 'get' and 'set'.", "value": "the value to store. 
Specify only for 'set'."}`) 75 | } 76 | 77 | func (s *KeyValueStore) CompactArgs(args json.RawMessage) json.RawMessage { 78 | var command kvStoreArgs 79 | err := json.Unmarshal(args, &command) 80 | if err != nil { 81 | return args 82 | } 83 | switch command.Command { 84 | case "set": 85 | return json.RawMessage(fmt.Sprintf(`{"command": "set", "key": "%s", "value": ""}`, command.Key)) 86 | default: 87 | return args 88 | } 89 | } 90 | 91 | func (s *KeyValueStore) recursivelyProcessStringFields(input any, processor func(string) string) any { 92 | switch input := input.(type) { 93 | case map[string]interface{}: 94 | output := map[string]interface{}{} 95 | for k, v := range input { 96 | output[k] = s.recursivelyProcessStringFields(v, processor) 97 | } 98 | return output 99 | case []interface{}: 100 | output := make([]interface{}, len(input)) 101 | for i, v := range input { 102 | output[i] = s.recursivelyProcessStringFields(v, processor) 103 | } 104 | return output 105 | case string: 106 | return processor(input) 107 | default: 108 | return input 109 | } 110 | } 111 | 112 | func (s *KeyValueStore) Process(args json.RawMessage) (json.RawMessage, error) { 113 | tmpl := template.New("store").Funcs(template.FuncMap{ 114 | "store": func(key string) string { 115 | value, ok := s.store[key] 116 | if !ok { 117 | return fmt.Sprintf("{{ store %q }}", key) 118 | } 119 | return value 120 | }, 121 | }) 122 | var unmarshaledArgs any 123 | err := json.Unmarshal(args, &unmarshaledArgs) 124 | if err != nil { 125 | return nil, fmt.Errorf("error unmarshaling args: %s", err) 126 | } 127 | var temlErr *multierror.Error 128 | processedArgs := s.recursivelyProcessStringFields(unmarshaledArgs, func(input string) string { 129 | tmpl, err := tmpl.Parse(input) 130 | if err != nil { 131 | temlErr = multierror.Append(temlErr, fmt.Errorf("error parsing args: %s", err)) 132 | return input 133 | } 134 | var processedArgs bytes.Buffer 135 | err = tmpl.Execute(&processedArgs, nil) 136 | if err != nil { 
137 | temlErr = multierror.Append(temlErr, fmt.Errorf("error processing args: %s", err)) 138 | return input 139 | } 140 | return processedArgs.String() 141 | }) 142 | if temlErr.ErrorOrNil() != nil { 143 | return nil, temlErr 144 | } 145 | var marshaledProcessedArgs json.RawMessage 146 | marshaledProcessedArgs, err = json.Marshal(processedArgs) 147 | if err != nil { 148 | return nil, fmt.Errorf("error marshaling processed args: %s", err) 149 | } 150 | return marshaledProcessedArgs, nil 151 | } 152 | 153 | func NewKeyValueStore() *KeyValueStore { 154 | return &KeyValueStore{ 155 | store: map[string]string{}, 156 | } 157 | } 158 | -------------------------------------------------------------------------------- /tools/key_value_store_test.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "encoding/json" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | "github.com/stretchr/testify/require" 9 | ) 10 | 11 | func TestKeyValueStore(t *testing.T) { 12 | testCases := []struct { 13 | name string 14 | storeState map[string]string 15 | input json.RawMessage 16 | output json.RawMessage 17 | expErr error 18 | }{ 19 | { 20 | name: "setting value", 21 | storeState: map[string]string{}, 22 | input: json.RawMessage(`{"command": "set", "key": "hello", "value": "world"}`), 23 | output: json.RawMessage(`"stored successfully"`), 24 | }, 25 | { 26 | name: "getting value", 27 | storeState: map[string]string{"hello": "world"}, 28 | input: json.RawMessage(`{"command": "get", "key": "hello"}`), 29 | output: json.RawMessage(`"world"`), 30 | }, 31 | { 32 | name: "list keys", 33 | storeState: map[string]string{"hello": "world"}, 34 | input: json.RawMessage(`{"command": "list"}`), 35 | output: json.RawMessage(`["hello"]`), 36 | }, 37 | } 38 | 39 | for _, testCase := range testCases { 40 | t.Run(testCase.name, func(t *testing.T) { 41 | store := NewKeyValueStore() 42 | store.store = testCase.storeState 43 | 
output, err := store.Execute(testCase.input) 44 | if testCase.expErr != nil { 45 | require.EqualError(t, err, testCase.expErr.Error()) 46 | return 47 | } 48 | require.NoError(t, err) 49 | assert.JSONEq(t, string(testCase.output), string(output)) 50 | }) 51 | } 52 | } 53 | 54 | func TestKeyValueStorePreprocessing(t *testing.T) { 55 | testCases := []struct { 56 | name string 57 | input string 58 | output string 59 | }{ 60 | { 61 | name: "no store", 62 | input: "hello world", 63 | output: "hello world", 64 | }, 65 | { 66 | name: "single store", 67 | input: "hello {{ store \"key\" }}", 68 | output: "hello world", 69 | }, 70 | { 71 | name: "multiple stores", 72 | input: "hello {{ store \"key1\" }} {{ store \"key2\" }}", 73 | output: "hello world world", 74 | }, 75 | { 76 | name: "multiple stores with other text", 77 | input: "hello {{ store \"key1\" }} world {{ store \"key2\" }}", 78 | output: "hello world world world", 79 | }, 80 | } 81 | 82 | for _, testCase := range testCases { 83 | t.Run(testCase.name, func(t *testing.T) { 84 | store := NewKeyValueStore() 85 | store.store = map[string]string{ 86 | "key": "world", 87 | "key1": "world", 88 | "key2": "world", 89 | } 90 | marshaledInput, err := json.Marshal(testCase.input) 91 | require.NoError(t, err) 92 | output, err := store.Process(marshaledInput) 93 | require.NoError(t, err) 94 | var unmarshaledOutput string 95 | err = json.Unmarshal(output, &unmarshaledOutput) 96 | require.NoError(t, err) 97 | assert.Equal(t, testCase.output, unmarshaledOutput) 98 | }) 99 | } 100 | } 101 | -------------------------------------------------------------------------------- /tools/mocks/tool.go: -------------------------------------------------------------------------------- 1 | // Code generated by MockGen. DO NOT EDIT. 2 | // Source: tool.go 3 | 4 | // Package mocks is a generated GoMock package. 
5 | package mocks 6 | 7 | import ( 8 | json "encoding/json" 9 | reflect "reflect" 10 | 11 | gomock "github.com/golang/mock/gomock" 12 | ) 13 | 14 | // MockTool is a mock of Tool interface. 15 | type MockTool struct { 16 | ctrl *gomock.Controller 17 | recorder *MockToolMockRecorder 18 | } 19 | 20 | // MockToolMockRecorder is the mock recorder for MockTool. 21 | type MockToolMockRecorder struct { 22 | mock *MockTool 23 | } 24 | 25 | // NewMockTool creates a new mock instance. 26 | func NewMockTool(ctrl *gomock.Controller) *MockTool { 27 | mock := &MockTool{ctrl: ctrl} 28 | mock.recorder = &MockToolMockRecorder{mock} 29 | return mock 30 | } 31 | 32 | // EXPECT returns an object that allows the caller to indicate expected use. 33 | func (m *MockTool) EXPECT() *MockToolMockRecorder { 34 | return m.recorder 35 | } 36 | 37 | // ArgsSchema mocks base method. 38 | func (m *MockTool) ArgsSchema() json.RawMessage { 39 | m.ctrl.T.Helper() 40 | ret := m.ctrl.Call(m, "ArgsSchema") 41 | ret0, _ := ret[0].(json.RawMessage) 42 | return ret0 43 | } 44 | 45 | // ArgsSchema indicates an expected call of ArgsSchema. 46 | func (mr *MockToolMockRecorder) ArgsSchema() *gomock.Call { 47 | mr.mock.ctrl.T.Helper() 48 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ArgsSchema", reflect.TypeOf((*MockTool)(nil).ArgsSchema)) 49 | } 50 | 51 | // CompactArgs mocks base method. 52 | func (m *MockTool) CompactArgs(args json.RawMessage) json.RawMessage { 53 | m.ctrl.T.Helper() 54 | ret := m.ctrl.Call(m, "CompactArgs", args) 55 | ret0, _ := ret[0].(json.RawMessage) 56 | return ret0 57 | } 58 | 59 | // CompactArgs indicates an expected call of CompactArgs. 60 | func (mr *MockToolMockRecorder) CompactArgs(args interface{}) *gomock.Call { 61 | mr.mock.ctrl.T.Helper() 62 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CompactArgs", reflect.TypeOf((*MockTool)(nil).CompactArgs), args) 63 | } 64 | 65 | // Description mocks base method. 
66 | func (m *MockTool) Description() string { 67 | m.ctrl.T.Helper() 68 | ret := m.ctrl.Call(m, "Description") 69 | ret0, _ := ret[0].(string) 70 | return ret0 71 | } 72 | 73 | // Description indicates an expected call of Description. 74 | func (mr *MockToolMockRecorder) Description() *gomock.Call { 75 | mr.mock.ctrl.T.Helper() 76 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Description", reflect.TypeOf((*MockTool)(nil).Description)) 77 | } 78 | 79 | // Execute mocks base method. 80 | func (m *MockTool) Execute(args json.RawMessage) (json.RawMessage, error) { 81 | m.ctrl.T.Helper() 82 | ret := m.ctrl.Call(m, "Execute", args) 83 | ret0, _ := ret[0].(json.RawMessage) 84 | ret1, _ := ret[1].(error) 85 | return ret0, ret1 86 | } 87 | 88 | // Execute indicates an expected call of Execute. 89 | func (mr *MockToolMockRecorder) Execute(args interface{}) *gomock.Call { 90 | mr.mock.ctrl.T.Helper() 91 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Execute", reflect.TypeOf((*MockTool)(nil).Execute), args) 92 | } 93 | 94 | // Name mocks base method. 95 | func (m *MockTool) Name() string { 96 | m.ctrl.T.Helper() 97 | ret := m.ctrl.Call(m, "Name") 98 | ret0, _ := ret[0].(string) 99 | return ret0 100 | } 101 | 102 | // Name indicates an expected call of Name. 103 | func (mr *MockToolMockRecorder) Name() *gomock.Call { 104 | mr.mock.ctrl.T.Helper() 105 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Name", reflect.TypeOf((*MockTool)(nil).Name)) 106 | } 107 | 108 | // MockPreprocessingTool is a mock of PreprocessingTool interface. 109 | type MockPreprocessingTool struct { 110 | ctrl *gomock.Controller 111 | recorder *MockPreprocessingToolMockRecorder 112 | } 113 | 114 | // MockPreprocessingToolMockRecorder is the mock recorder for MockPreprocessingTool. 115 | type MockPreprocessingToolMockRecorder struct { 116 | mock *MockPreprocessingTool 117 | } 118 | 119 | // NewMockPreprocessingTool creates a new mock instance. 
120 | func NewMockPreprocessingTool(ctrl *gomock.Controller) *MockPreprocessingTool { 121 | mock := &MockPreprocessingTool{ctrl: ctrl} 122 | mock.recorder = &MockPreprocessingToolMockRecorder{mock} 123 | return mock 124 | } 125 | 126 | // EXPECT returns an object that allows the caller to indicate expected use. 127 | func (m *MockPreprocessingTool) EXPECT() *MockPreprocessingToolMockRecorder { 128 | return m.recorder 129 | } 130 | 131 | // Process mocks base method. 132 | func (m *MockPreprocessingTool) Process(args json.RawMessage) (json.RawMessage, error) { 133 | m.ctrl.T.Helper() 134 | ret := m.ctrl.Call(m, "Process", args) 135 | ret0, _ := ret[0].(json.RawMessage) 136 | ret1, _ := ret[1].(error) 137 | return ret0, ret1 138 | } 139 | 140 | // Process indicates an expected call of Process. 141 | func (mr *MockPreprocessingToolMockRecorder) Process(args interface{}) *gomock.Call { 142 | mr.mock.ctrl.T.Helper() 143 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Process", reflect.TypeOf((*MockPreprocessingTool)(nil).Process), args) 144 | } 145 | -------------------------------------------------------------------------------- /tools/python_repl.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "os" 7 | "os/exec" 8 | ) 9 | 10 | type PythonREPL struct { 11 | pythonBinary string 12 | } 13 | 14 | func (p *PythonREPL) createVenv() error { 15 | if _, err := os.Stat(".venv"); err == nil { 16 | return nil 17 | } 18 | _, err := exec.Command(p.pythonBinary, "-m", "venv", ".venv").Output() 19 | if err != nil { 20 | return err 21 | } 22 | return nil 23 | } 24 | 25 | func (p *PythonREPL) installModules(modules []string) error { 26 | args := []string{"-m", "pip", "install"} 27 | args = append(args, modules...) 
28 | _, err := exec.Command("./.venv/bin/python", args...).Output() 29 | if err != nil { 30 | return err 31 | } 32 | return nil 33 | } 34 | 35 | func (p *PythonREPL) Execute(args json.RawMessage) (json.RawMessage, error) { 36 | var command struct { 37 | Code string `json:"code"` 38 | Modules []string `json:"modules"` 39 | } 40 | err := json.Unmarshal(args, &command) 41 | if err != nil { 42 | return nil, fmt.Errorf("failed to unmarshal args: %w", err) 43 | } 44 | err = p.createVenv() 45 | if err != nil { 46 | return nil, fmt.Errorf("failed to create venv: %w", err) 47 | } 48 | if len(command.Modules) > 0 { 49 | err = p.installModules(command.Modules) 50 | if err != nil { 51 | if exitError, ok := err.(*exec.ExitError); ok { 52 | return nil, fmt.Errorf("failed to install modules: %s", string(exitError.Stderr)) 53 | } 54 | return nil, fmt.Errorf("failed to install modules: %w", err) 55 | } 56 | } 57 | out, err := exec.Command("./.venv/bin/python", "-c", command.Code).Output() 58 | if err != nil { 59 | if exitError, ok := err.(*exec.ExitError); ok { 60 | return nil, fmt.Errorf("python exited with code %d: %s", exitError.ExitCode(), string(exitError.Stderr)) 61 | } 62 | return nil, err 63 | } 64 | return json.RawMessage(out), nil 65 | } 66 | 67 | func (p *PythonREPL) Name() string { 68 | return "python" 69 | } 70 | 71 | func (p *PythonREPL) Description() string { 72 | return "A Python REPL. Use this to execute scripts that help you complete your task. " + 73 | "If you need to install any modules, you can do so by passing a list of modules to the modules argument." 
74 | } 75 | 76 | func (p *PythonREPL) ArgsSchema() json.RawMessage { 77 | return json.RawMessage(`{"code": "the Python code to execute, properly escaped", "modules": ["a list", "of modules", "to install"]}`) 78 | } 79 | 80 | func (p *PythonREPL) CompactArgs(args json.RawMessage) json.RawMessage { 81 | return args 82 | } 83 | 84 | func NewPythonREPL() *PythonREPL { 85 | return &PythonREPL{ 86 | pythonBinary: "python3", 87 | } 88 | } 89 | 90 | func NewPythonREPLWithCustomBinary(pythonBinary string) *PythonREPL { 91 | return &PythonREPL{ 92 | pythonBinary: pythonBinary, 93 | } 94 | } 95 | -------------------------------------------------------------------------------- /tools/python_repl_test.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "os" 7 | "testing" 8 | 9 | "github.com/stretchr/testify/assert" 10 | "github.com/stretchr/testify/require" 11 | ) 12 | 13 | func TestPythonREPL(t *testing.T) { 14 | testCases := []struct { 15 | name string 16 | repl *PythonREPL 17 | input json.RawMessage 18 | output json.RawMessage 19 | expErr error 20 | }{ 21 | { 22 | name: "simple", 23 | repl: NewPythonREPL(), 24 | input: json.RawMessage(`{"code": "print(1 + 1)"}`), 25 | output: json.RawMessage(`2`), 26 | }, 27 | { 28 | name: "error", 29 | repl: NewPythonREPL(), 30 | input: json.RawMessage(`{"code": "print(1 + 1"}`), 31 | expErr: fmt.Errorf("python exited with code 1: File \"\", line 1\n print(1 + 1\n ^\nSyntaxError: '(' was never closed\n"), 32 | }, 33 | { 34 | name: "multiline code", 35 | repl: NewPythonREPL(), 36 | input: json.RawMessage(`{ 37 | "code": "print('[')\nfor i in range(3):\n print(i)\n print(',')\nprint('9]')" 38 | }`), 39 | output: json.RawMessage(`[0,1,2,9]`), 40 | }, 41 | { 42 | name: "with modules", 43 | repl: NewPythonREPL(), 44 | input: json.RawMessage(`{ 45 | "code": "import dotenv\ndotenv.load_dotenv()\nprint([1,2,3])", 46 | "modules": ["python-dotenv"] 
47 | }`), 48 | output: json.RawMessage(`[1,2,3]`), 49 | }, 50 | } 51 | for _, testCase := range testCases { 52 | t.Run(testCase.name, func(t *testing.T) { 53 | t.Cleanup(func() { 54 | require.NoError(t, os.RemoveAll(".venv")) 55 | }) 56 | output, err := testCase.repl.Execute(testCase.input) 57 | if testCase.expErr != nil { 58 | require.EqualError(t, err, testCase.expErr.Error()) 59 | return 60 | } 61 | require.NoError(t, err) 62 | assert.JSONEq(t, string(testCase.output), string(output)) 63 | }) 64 | } 65 | } 66 | -------------------------------------------------------------------------------- /tools/testdata/expected_parsed_results.json: -------------------------------------------------------------------------------- 1 | [{"Title":"\"Hello, World!\" program - Wikipedia","URL":"https://en.wikipedia.org/wiki/%22Hello,_World!%22_program"},{"Title":"תוכנית Hello world - ויקיפדיה","URL":"https://he.wikipedia.org/wiki/%D7%AA%D7%95%D7%9B%D7%A0%D7%99%D7%AA_Hello_world"},{"Title":"Hello World | Code.org","URL":"https://code.org/helloworld"},{"Title":"Hello World - Raspberry Pi","URL":"https://helloworld.raspberrypi.org/"},{"Title":"Hello World - Go by Example","URL":"https://gobyexample.com/hello-world"},{"Title":"Hello World (2019) - IMDb","URL":"https://www.imdb.com/title/tt9418812/"},{"Title":"Hello, World! - Free Interactive Python Tutorial","URL":"https://www.learnpython.org/en/Hello,_World!"},{"Title":"Hello World - GitHub Docs","URL":"https://docs.github.com/en/get-started/quickstart/hello-world"},{"Title":"Total immersion, Serious fun! 
with Hello-World!","URL":"https://www.hello-world.com/"}] -------------------------------------------------------------------------------- /tools/tool.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import "encoding/json" 4 | 5 | // Tool is a capability that an agent can invoke by name with JSON 6 | // arguments. Its name, description and argument schema are what the 7 | // agent is shown when deciding how to call it. 8 | // 9 | //go:generate mockgen -source=tool.go -destination=mocks/tool.go -package=mocks 10 | type Tool interface { 11 | // Executes the tool with the given 12 | // arguments. If the arguments are 13 | // invalid, an error should be returned 14 | // which will be displayed to the agent. 15 | Execute(args json.RawMessage) (json.RawMessage, error) 16 | // The name of the tool, as it will be 17 | // displayed to the agent. 18 | Name() string 19 | // A short description of the tool, as 20 | // it will be displayed to the agent. 21 | Description() string 22 | // A 'fuzzy schema' of the arguments 23 | // that the tool expects. This is used 24 | // to instruct the agent on how to 25 | // generate the arguments. 26 | ArgsSchema() json.RawMessage 27 | // Generates a compact representation 28 | // of the arguments, to be used in the 29 | // agent's memory. 30 | CompactArgs(args json.RawMessage) json.RawMessage 31 | } 32 | 33 | // PreprocessingTool rewrites tool arguments before they are passed to any tool. 34 | type PreprocessingTool interface { 35 | // Preprocesses the arguments before 36 | // they are passed to any tool. 37 | Process(args json.RawMessage) (json.RawMessage, error) 38 | } 39 |
-------------------------------------------------------------------------------- /tools/utils.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "strings" 7 | 8 | "github.com/natexcvi/go-llm/engines" 9 | ) 10 | 11 | var ( 12 | // ErrCannotAutoConvertArgSchema is returned, possibly wrapped, when a 13 | // tool's ArgsSchema cannot be mapped onto engines.ParameterSpecs. 14 | ErrCannotAutoConvertArgSchema = fmt.Errorf("cannot auto-convert arg schema") 15 | ) 16 | 17 | // ConvertToNativeFunctionSpecs builds the engines.FunctionSpecs for tool 18 | // from its Name, Description and 'fuzzy' ArgsSchema. It fails with a JSON 19 | // error if the schema is not valid JSON, or with an error matching 20 | // ErrCannotAutoConvertArgSchema if the schema uses an unsupported construct. 21 | func ConvertToNativeFunctionSpecs(tool Tool) (engines.FunctionSpecs, error) { 22 | parameterSpecs, err := convertArgSchemaToParameterSpecs(tool.ArgsSchema()) 23 | if err != nil { 24 | return engines.FunctionSpecs{}, err 25 | } 26 | return engines.FunctionSpecs{ 27 | Name: tool.Name(), 28 | Description: tool.Description(), 29 | Parameters: &parameterSpecs, 30 | }, nil 31 | } 32 | 33 | // convertArgSchemaToParameterSpecs recursively maps an example-style schema 34 | // onto parameter specs: objects become "object" specs with one property per 35 | // key (none marked required), arrays become "array" specs whose item spec is 36 | // inferred from the elements, strings become "string" specs described by the 37 | // string itself, and numbers and booleans get a generic description. JSON 38 | // null and arrays mixing value types are rejected. 39 | func convertArgSchemaToParameterSpecs(argSchema json.RawMessage) (engines.ParameterSpecs, error) { 40 | var unmarshaledSchema any 41 | if err := json.Unmarshal(argSchema, &unmarshaledSchema); err != nil { 42 | return engines.ParameterSpecs{}, err 43 | } 44 | switch schema := unmarshaledSchema.(type) { 45 | case map[string]any: 46 | specs := engines.ParameterSpecs{ 47 | Type: "object", 48 | Properties: map[string]*engines.ParameterSpecs{}, 49 | Required: []string{}, 50 | } 51 | for key, value := range schema { 52 | marshaledValue, err := json.Marshal(value) 53 | if err != nil { 54 | return engines.ParameterSpecs{}, err 55 | } 56 | propertySpecs, err := convertArgSchemaToParameterSpecs(marshaledValue) 57 | if err != nil { 58 | return engines.ParameterSpecs{}, err 59 | } 60 | specs.Properties[key] = &propertySpecs 61 | // specs.Required = append(specs.Required, key) 62 | } 63 | return specs, nil 64 | case []any: 65 | specs := engines.ParameterSpecs{ 66 | Type: "array", 67 | Items: nil, 68 | } 69 | // infer the item spec from all elements (they must share one type; differing descriptions are space-joined)
70 | for _, value := range schema { 71 | marshaledValue, err := json.Marshal(value) 72 | if err != nil { 73 | return engines.ParameterSpecs{}, err 74 | } 75 | propertySpecs, err := convertArgSchemaToParameterSpecs(marshaledValue) 76 | if err != nil { 77 | return engines.ParameterSpecs{}, err 78 | } 79 | if specs.Items != nil && specs.Items.Type != propertySpecs.Type { 80 | return engines.ParameterSpecs{}, fmt.Errorf("%w: arrays with values of more than one type not currently supported", ErrCannotAutoConvertArgSchema) 81 | } 82 | if specs.Items != nil && specs.Items.Description != propertySpecs.Description { 83 | propertySpecs.Description = strings.Join([]string{ 84 | specs.Items.Description, 85 | propertySpecs.Description, 86 | }, " ") 87 | } 88 | specs.Items = &propertySpecs 89 | } 90 | return specs, nil 91 | case string: 92 | return engines.ParameterSpecs{ 93 | Type: "string", 94 | Description: schema, 95 | }, nil 96 | case float64, int: 97 | return engines.ParameterSpecs{ 98 | Type: "number", 99 | Description: "a number", 100 | }, nil 101 | case bool: 102 | return engines.ParameterSpecs{ 103 | Type: "boolean", 104 | Description: "a boolean value", 105 | }, nil 106 | default: 107 | return engines.ParameterSpecs{}, ErrCannotAutoConvertArgSchema 108 | } 109 | } 110 |
-------------------------------------------------------------------------------- /tools/utils_test.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "encoding/json" 5 | "testing" 6 | 7 | "github.com/natexcvi/go-llm/engines" 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | type mockTool struct { 12 | name string 13 | description string 14 | argsSchema string 15 | } 16 | 17 | func (t *mockTool) Execute(args json.RawMessage) (json.RawMessage, error) { 18 | return nil, nil 19 | } 20 | 21 | func (t *mockTool) Name() string { 22 | return t.name 23 | } 24 | 25 | func (t *mockTool) Description() string { 26 | return t.description 27 | } 28 | 29 | func (t *mockTool) ArgsSchema() json.RawMessage { 30 | return []byte(t.argsSchema) 31 | } 32 | func (t *mockTool) CompactArgs(args json.RawMessage) json.RawMessage { 33 | return args 34 | } 35 | 36 | func TestConvertToLLMFunctionSpecs(t *testing.T) { 37 | testCases := []struct { 38 | name string 39 | tool Tool 40 | expectedOutput engines.FunctionSpecs 41 | expectedErr error 42 | }{ 43 | { 44 | name: "Simple tool args", 45 | tool: &mockTool{ 46 | name: "test", 47 | description: "This is a test.", 48 | argsSchema: `{"text": "some text", "num": 0, "booly": true}`, 49 | }, 50 | expectedOutput: engines.FunctionSpecs{ 51 | Name: "test", 52 | Description: "This is a test.", 53 | Parameters: &engines.ParameterSpecs{ 54 | Type: "object", 55 | Properties: map[string]*engines.ParameterSpecs{ 56 | "text": { 57 | Type: "string", 58 | Description: "some text", 59 | }, 60 | "num": { 61 | Type: "number", 62 | Description: "a number", 63 | }, 64 | "booly": { 65 | Type: "boolean", 66 | Description: "a boolean value", 67 | }, 68 | }, 69 | Required: []string{}, 70 | }, 71 | }, 72 | expectedErr: nil, 73 | },
74 | { 75 | name: "Tool args with array", 76 | tool: &mockTool{ 77 | name: "array_tool", 78 | description: "This is a test.", 79 | argsSchema: `{"text": "some text", "num": 0, "arr": ["this", "is", "an", "array"]}`, 80 | }, 81 | expectedOutput: engines.FunctionSpecs{ 82 | Name: "array_tool", 83 | Description: "This is a test.", 84 | Parameters: &engines.ParameterSpecs{ 85 | Type: "object", 86 | Properties: map[string]*engines.ParameterSpecs{ 87 | "text": { 88 | Type: "string", 89 | Description: "some text", 90 | }, 91 | "num": { 92 | Type: "number", 93 | Description: "a number", 94 | }, 95 | "arr": { 96 | Type: "array", 97 | Items: &engines.ParameterSpecs{Type: "string", Description: "this is an array"}, 98 | }, 99 | }, 100 | Required: []string{}, 101 | }, 102 | }, 103 | expectedErr: nil, 104 | }, 105 | } 106 | for _, tc := range testCases { 107 | t.Run(tc.name, func(t *testing.T) { 108 | output, err := ConvertToNativeFunctionSpecs(tc.tool) 109 | assert.Equal(t, tc.expectedOutput, output) 110 | assert.Equal(t, tc.expectedErr, err) 111 | }) 112 | } 113 | } 114 | -------------------------------------------------------------------------------- /tools/webpage_summary.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "io/ioutil" 7 | "net/http" 8 | "regexp" 9 | "strings" 10 | 11 | "github.com/PuerkitoBio/goquery" 12 | "github.com/natexcvi/go-llm/engines" 13 | ) 14 | 15 | type WebpageSummary struct { 16 | model engines.LLM 17 | } 18 | 19 | func (*WebpageSummary) stripHTMLTags(s string) string { 20 | // Remove HTML tags 21 | document, err := goquery.NewDocumentFromReader(strings.NewReader(s)) 22 | if err != nil { 23 | return "" 24 | } 25 | document.Find("script, style").Each(func(index int, item *goquery.Selection) { 26 | item.Remove() 27 | }) 28 | text := document.Text() 29 | 30 | // Remove JavaScript code 31 | re := regexp.MustCompile(`(?m)^world", 122 | expected: "hello, world", 123 | }, 124 | { 125 | name: "with newlines", 126 | input: "hello,\n\nworld", 127 | expected: "hello, world", 128 | }, 129 | { 130 | name: "with nested tags", 131 | input: "hello, world and universe", 132 | expected: "hello, world and universe", 133 | }, 134 | } 135 | for _, tc := range testCases { 136 | t.Run(tc.name, func(t *testing.T) { 137 | assert.Equal(t, tc.expected, (&WebpageSummary{}).stripHTMLTags(tc.input)) 138 | }) 139 | } 140 | } 141 |
-------------------------------------------------------------------------------- /tools/wolfram_alpha.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "io" 7 | "net/http" 8 | "time" 9 | ) 10 | 11 | // wolframAlphaClient is shared by every WolframAlpha tool so connections 12 | // can be reused. Its Timeout bounds the whole exchange, including reading 13 | // the body, so a stalled server cannot hang the agent; it is kept above the 14 | // timeout=10 query parameter that shortAnswer passes to the service. 15 | var wolframAlphaClient = &http.Client{Timeout: 15 * time.Second} 16 | 17 | // WolframAlpha is a Tool that answers free-text queries by calling the 18 | // WolframAlpha short-answer endpoint at ServiceURL. 19 | type WolframAlpha struct { 20 | // ServiceURL is the endpoint that queries are sent to. 21 | ServiceURL string 22 | // AppID is sent to the service as the "appid" query parameter. 23 | AppID string 24 | } 25 | 26 | // NewWolframAlpha returns a WolframAlpha tool that queries the public 27 | // WolframAlpha short-answer API with the given application ID. 28 | func NewWolframAlpha(appID string) *WolframAlpha { 29 | return &WolframAlpha{ 30 | ServiceURL: "http://api.wolframalpha.com/v1/result", 31 | AppID: appID, 32 | } 33 | } 34 | 35 | // NewWolframAlphaWithServiceURL returns a WolframAlpha tool that sends its 36 | // queries to serviceURL instead of the public API, e.g. a test server. 37 | func NewWolframAlphaWithServiceURL(serviceURL, appID string) *WolframAlpha { 38 | return &WolframAlpha{ 39 | ServiceURL: serviceURL, 40 | AppID: appID, 41 | } 42 | } 43 | 44 | // shortAnswer sends query to ServiceURL and returns the raw response body 45 | // as the answer. Any status other than 200 OK is reported as an error. 46 | func (wa *WolframAlpha) shortAnswer(query string) (answer string, err error) { 47 | req, err := http.NewRequest("GET", wa.ServiceURL, nil) 48 | if err != nil { 49 | return "", fmt.Errorf("failed to create request: %w", err) 50 | } 51 | q := req.URL.Query() 52 | q.Add("appid", wa.AppID) 53 | q.Add("i", query) 54 | q.Add("timeout", "10") 55 | req.URL.RawQuery = q.Encode()
56 | resp, err := wolframAlphaClient.Do(req) 57 | if err != nil { 58 | return "", fmt.Errorf("failed to send request: %w", err) 59 | } 60 | defer resp.Body.Close() 61 | if resp.StatusCode != http.StatusOK { 62 | return "", fmt.Errorf("failed to get short answer: %s", resp.Status) 63 | } 64 | body, err := io.ReadAll(resp.Body) 65 | if err != nil { 66 | return "", fmt.Errorf("failed to read response body: %w", err) 67 | } 68 | return string(body), nil 69 | } 70 | 71 | // Execute expects args of the form {"query": "..."} and returns the 72 | // service's answer encoded as a JSON string. 73 | func (wa *WolframAlpha) Execute(args json.RawMessage) (json.RawMessage, error) { 74 | var query struct { 75 | Query string `json:"query"` 76 | } 77 | err := json.Unmarshal(args, &query) 78 | if err != nil { 79 | return nil, fmt.Errorf("failed to unmarshal args: %w", err) 80 | } 81 | answer, err := wa.shortAnswer(query.Query) 82 | if err != nil { 83 | return nil, fmt.Errorf("failed to query WolframAlpha: %w", err) 84 | } 85 | return json.Marshal(answer) 86 | } 87 | 88 | // Name implements Tool. 89 | func (wa *WolframAlpha) Name() string { 90 | return "wolfram_alpha" 91 | } 92 | 93 | // Description implements Tool. 94 | func (wa *WolframAlpha) Description() string { 95 | return "A tool for querying WolframAlpha. Use it for " + 96 | "factual information retrieval, calculations, etc." 97 | } 98 | 99 | // ArgsSchema implements Tool. 100 | func (wa *WolframAlpha) ArgsSchema() json.RawMessage { 101 | return json.RawMessage(`{"query": "the search query, e.g. '2+2' or 'what is the capital of France?'"}`) 102 | } 103 | 104 | // CompactArgs implements Tool. The arguments are already compact, so they 105 | // are returned unchanged. 106 | func (wa *WolframAlpha) CompactArgs(args json.RawMessage) json.RawMessage { 107 | return args 108 | } 109 | -------------------------------------------------------------------------------- /tools/wolfram_alpha_test.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "net/http" 7 | "net/http/httptest" 8 | "testing" 9 | 10 | "github.com/stretchr/testify/require" 11 | ) 12 | 13 | // mockWolframServer starts an in-process server that answers requests to 14 | // /v1/result with the given plain-text response and returns that endpoint's 15 | // full URL. httptest.NewServer listens on a free port before returning, so 16 | // tests neither collide on a fixed port nor race the server's start-up. 17 | // Requests missing the "appid" or "i" parameter get 400 Bad Request, which 18 | // surfaces as an error from WolframAlpha.Execute. 19 | func mockWolframServer(t *testing.T, response string) string { 20 | t.Helper() 21 | mux := http.NewServeMux() 22 | mux.HandleFunc("/v1/result", func(w http.ResponseWriter, r *http.Request) { 23 | q := r.URL.Query() 24 | if q.Get("appid") == "" || q.Get("i") == "" { 25 | http.Error(w, "missing appid or i parameter", http.StatusBadRequest) 26 | return 27 | } 28 | w.Header().Set("Content-Type", "text/plain") 29 | w.Write([]byte(response)) 30 | }) 31 | server := httptest.NewServer(mux) 32 | t.Cleanup(server.Close) 33 | return server.URL + "/v1/result" 34 | }
35 | 36 | func TestWolframAlpha(t *testing.T) { 37 | testCases := []struct { 38 | name string 39 | query string 40 | mockResponse string 41 | expectedError error 42 | }{ 43 | { 44 | name: "simple calculation", 45 | query: "(1 + 2) * 3", 46 | mockResponse: "9", 47 | }, 48 | } 49 | for _, tc := range testCases { 50 | t.Run(tc.name, func(t *testing.T) { 51 | serviceURL := mockWolframServer(t, tc.mockResponse) 52 | wa := NewWolframAlphaWithServiceURL(serviceURL, "1234") 53 | out, err := wa.Execute(json.RawMessage( 54 | fmt.Sprintf(`{"query": %q}`, tc.query), 55 | )) 56 | if tc.expectedError != nil { 57 | require.Error(t, err) 58 | require.EqualError(t, err, tc.expectedError.Error()) 59 | return 60 | } 61 | require.NoError(t, err) 62 | // Execute returns the plain-text answer encoded as a JSON string. 63 | var answer string 64 | require.NoError(t, json.Unmarshal(out, &answer)) 65 | require.Equal(t, tc.mockResponse, answer) 66 | }) 67 | } 68 | } 69 | --------------------------------------------------------------------------------