├── .gitignore ├── img ├── banner.jpg ├── chats.gif └── piping.gif ├── internal ├── utils │ ├── errors.go │ ├── context_keys.go │ ├── term.go │ ├── path.go │ ├── misc.go │ ├── input.go │ ├── file.go │ ├── file_test.go │ ├── prompt.go │ ├── prompt_test.go │ └── misc_test.go ├── vendors │ ├── anthropic │ │ ├── constants.go │ │ ├── claude_setup.go │ │ ├── claude_test.go │ │ ├── claude_models.go │ │ └── claude_stream_block_events.go │ ├── openai │ │ ├── constants.go │ │ ├── gpt.go │ │ └── gpt_test.go │ ├── inception │ │ ├── inception_test.go │ │ └── inception.go │ ├── xai │ │ ├── xai.go │ │ └── xai_test.go │ ├── gemini │ │ ├── gemini.go │ │ └── gemini_test.go │ ├── ollama │ │ ├── ollama.go │ │ └── ollama_test.go │ ├── deepseek │ │ ├── deepseek.go │ │ └── deepseek_test.go │ ├── novita │ │ ├── novita.go │ │ └── novita_test.go │ └── mistral │ │ ├── mistral_test.go │ │ └── mistral.go ├── models │ ├── completion │ │ └── types.go │ ├── errors_test.go │ ├── model_generic_tests.go │ ├── model_generic_tests_test.go │ ├── models.go │ └── models_test.go ├── version.go ├── chat │ ├── replay.go │ ├── chat_test.go │ ├── handler_test.go │ ├── chat.go │ └── reply.go ├── tools │ ├── models.go │ ├── bash_tool_pwd_test.go │ ├── bash_tool_cat_test.go │ ├── bash_tool_find_test.go │ ├── bash_tool_pwd.go │ ├── programming_tool_line_count_test.go │ ├── bash_tool_date_test.go │ ├── mcp │ │ ├── client_test.go │ │ ├── models.go │ │ ├── testserver │ │ │ └── main.go │ │ ├── manager_test.go │ │ ├── client.go │ │ └── tool.go │ ├── clai_tool_help.go │ ├── cmd.go │ ├── clai_tool_result.go │ ├── programming_tool_line_count.go │ ├── clai_tool_check.go │ ├── bash_tool_freetext_command.go │ ├── bash_tool_tree.go │ ├── bash_tool_file.go │ ├── programming_tool_sed_test.go │ ├── programming_tool_go.go │ ├── bash_tool_ls.go │ ├── programming_tool_rows_between_test.go │ ├── programming_tool_git_test.go │ ├── bash_tool_cat.go │ ├── bash_tool_find.go │ ├── programming_tool_write_file.go │ ├── 
programming_tool_recall.go │ ├── bash_tool_date.go │ ├── programming_tool_rows_between.go │ ├── clai_tool_wait_for_workers_test.go │ ├── bash_tool_rg.go │ ├── handler.go │ ├── registry.go │ ├── programming_tool_git.go │ └── programming_tool_sed.go ├── photo │ ├── conf_test.go │ ├── store.go │ ├── conf.go │ ├── funimation_0.go │ ├── store_test.go │ ├── prompt.go │ └── prompt_additional_test.go ├── video │ ├── conf_test.go │ ├── store.go │ ├── conf.go │ ├── prompt.go │ └── store_test.go ├── text │ ├── generic │ │ ├── stream_completer_setup.go │ │ └── stream_completer_models.go │ ├── querier_setup_tools_test.go │ ├── conf_profile.go │ ├── querier_cmd_mode_test.go │ └── querier_cmd_mode.go ├── setup │ ├── setup_actions_test.go │ ├── setup_test.go │ └── mcp_parser.go ├── setup_test.go ├── glob │ └── glob.go └── profiles │ └── cmd.go ├── pkg └── text │ ├── models │ ├── configurations.go │ ├── tools_test.go │ └── chat_test.go │ ├── full_test.go │ └── full.go ├── go.mod ├── .github └── workflows │ ├── validate.yml │ └── release.yml ├── examples └── profiles │ ├── gopher.json │ └── cody.json ├── go.sum ├── LICENSE ├── setup.sh └── .vscode └── launch.json /.gitignore: -------------------------------------------------------------------------------- 1 | coverage.out 2 | -------------------------------------------------------------------------------- /img/banner.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baalimago/clai/HEAD/img/banner.jpg -------------------------------------------------------------------------------- /img/chats.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baalimago/clai/HEAD/img/chats.gif -------------------------------------------------------------------------------- /img/piping.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/baalimago/clai/HEAD/img/piping.gif -------------------------------------------------------------------------------- /internal/utils/errors.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import "errors" 4 | 5 | var ErrUserInitiatedExit = errors.New("user exit") 6 | -------------------------------------------------------------------------------- /internal/vendors/anthropic/constants.go: -------------------------------------------------------------------------------- 1 | package anthropic 2 | 3 | const ClaudeURL = "https://api.anthropic.com/v1/messages" 4 | -------------------------------------------------------------------------------- /internal/models/completion/types.go: -------------------------------------------------------------------------------- 1 | package completion 2 | 3 | type Type int 4 | 5 | const ( 6 | ERROR Type = iota 7 | TOKEN 8 | ) 9 | -------------------------------------------------------------------------------- /internal/utils/context_keys.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | type ContextKey string 4 | 5 | const ( 6 | ContextCancelKey ContextKey = "contextCancel" 7 | ) 8 | -------------------------------------------------------------------------------- /internal/version.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | // Set via linker -X flags when built by the release pipeline (see version-var in release.yml); empty when built with go install 4 | var ( 5 | BuildVersion = "" 6 | BUILD_CHECKSUM = "" 7 | ) 8 | -------------------------------------------------------------------------------- /pkg/text/models/configurations.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | type Configurations struct { 4 | Model string 5 | SystemPrompt string 6 | ConfigDir string 7 | InternalTools []ToolName 8 | McpServers
[]McpServer 9 | } 10 | -------------------------------------------------------------------------------- /internal/vendors/openai/constants.go: -------------------------------------------------------------------------------- 1 | package openai 2 | 3 | const ( 4 | ChatURL = "https://api.openai.com/v1/chat/completions" 5 | PhotoURL = "https://api.openai.com/v1/images/generations" 6 | VideoURL = "https://api.openai.com/v1/videos" 7 | FilesURL = "https://api.openai.com/v1/files" 8 | ) 9 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/baalimago/clai 2 | 3 | go 1.24 4 | 5 | require github.com/baalimago/go_away_boilerplate v1.3.337 6 | 7 | require golang.org/x/net v0.43.0 8 | 9 | require golang.org/x/exp v0.0.0-20240119083558-1b970713d09a 10 | 11 | require golang.org/x/text v0.28.0 // indirect 12 | -------------------------------------------------------------------------------- /.github/workflows/validate.yml: -------------------------------------------------------------------------------- 1 | name: Simple Go Pipeline - validate 2 | 3 | on: 4 | push: 5 | branches: ["main"] 6 | pull_request: 7 | branches: ["main"] 8 | 9 | jobs: 10 | call-workflow: 11 | uses: baalimago/simple-go-pipeline/.github/workflows/validate.yml@main 12 | with: 13 | go-version: "1.24" 14 | staticcheck-version: "2025.1.1" 15 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Simple Go Pipeline - release 2 | 3 | on: 4 | push: 5 | tags: 6 | - "v[0-9]+.[0-9]+.[0-9]+" 7 | jobs: 8 | call-workflow: 9 | uses: baalimago/simple-go-pipeline/.github/workflows/release.yml@v0.3.0 10 | with: 11 | project-name: clai 12 | branch: main 13 | version-var: "github.com/baalimago/clai/internal.BuildVersion" 14 |
-------------------------------------------------------------------------------- /internal/utils/term.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "syscall" 5 | "unsafe" 6 | ) 7 | 8 | func TermWidth() (int, error) { 9 | ws := &struct { 10 | Row uint16 11 | Col uint16 12 | Xpixel uint16 13 | Ypixel uint16 14 | }{} 15 | 16 | retCode, _, errno := syscall.Syscall( 17 | syscall.SYS_IOCTL, 18 | uintptr(syscall.Stderr), 19 | uintptr(syscall.TIOCGWINSZ), 20 | uintptr(unsafe.Pointer(ws)), 21 | ) 22 | 23 | if int(retCode) == -1 { 24 | return 0, errno 25 | } 26 | 27 | return int(ws.Col), nil 28 | } 29 | -------------------------------------------------------------------------------- /internal/chat/replay.go: -------------------------------------------------------------------------------- 1 | package chat 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | 7 | "github.com/baalimago/clai/internal/utils" 8 | ) 9 | 10 | func Replay(raw bool) error { 11 | prevReply, err := LoadPrevQuery("") 12 | if err != nil { 13 | return fmt.Errorf("failed to load previous reply: %v", err) 14 | } 15 | amMessages := len(prevReply.Messages) 16 | if amMessages == 0 { 17 | return errors.New("failed to find any recent reply") 18 | } 19 | mostRecentMsg := prevReply.Messages[amMessages-1] 20 | utils.AttemptPrettyPrint(mostRecentMsg, "system", raw) 21 | return nil 22 | } 23 | -------------------------------------------------------------------------------- /internal/tools/models.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import pub_models "github.com/baalimago/clai/pkg/text/models" 4 | 5 | type LLMTool interface { 6 | // Call the LLM tool with the given Input. Returns output from the tool or an error 7 | // if the call returned an error-like. An error-like is either exit code non-zero or 8 | // http response which isn't 2xx or 3xx. 
9 | Call(pub_models.Input) (string, error) 10 | 11 | // Return the Specification, later on used 12 | // by text queriers to send to their respective 13 | // models 14 | Specification() pub_models.Specification 15 | } 16 | 17 | type McpServerConfig map[string]pub_models.McpServer 18 | -------------------------------------------------------------------------------- /examples/profiles/gopher.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": "gpt-4.1", 3 | "prompt": "You are golang programming assistant. Answer consicely. ONLY use go standard library. NEVER write any files unless asked to. If context is lacking, assume that you should writ a function which achieves the task. Example: 'traverses a file tree' =\u003e 'write a function that traverses a file tree'", 4 | "save-reply-as-conv": true, 5 | "tools": [ 6 | "cat", 7 | "file_type", 8 | "file_tree", 9 | "ls", 10 | "rg", 11 | "go", 12 | "write_file", 13 | "find", 14 | "sed", 15 | "rows_between", 16 | "mcp_fetch" 17 | ], 18 | "use_tools": true 19 | } 20 | -------------------------------------------------------------------------------- /internal/photo/conf_test.go: -------------------------------------------------------------------------------- 1 | package photo 2 | 3 | import "testing" 4 | 5 | func TestValidateOutputType(t *testing.T) { 6 | valid := []OutputType{LOCAL, URL, UNSET} 7 | for _, v := range valid { 8 | if err := ValidateOutputType(v); err != nil { 9 | t.Errorf("expected no error for %v", v) 10 | } 11 | } 12 | if err := ValidateOutputType(OutputType("bad")); err == nil { 13 | t.Error("expected error for invalid output type") 14 | } 15 | } 16 | 17 | func TestFunimation(t *testing.T) { 18 | if funimation(0) != "🕛" { 19 | t.Errorf("unexpected image for 0") 20 | } 21 | if funimation(43478260) != "🕧" { 22 | t.Errorf("unexpected image for step") 23 | } 24 | } 25 | -------------------------------------------------------------------------------- 
/internal/models/errors_test.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | import ( 4 | "strings" 5 | "testing" 6 | "time" 7 | ) 8 | 9 | func TestRateLimitError(t *testing.T) { 10 | reset := time.Now().Add(time.Minute).Round(time.Second) 11 | err := NewRateLimitError(reset, 1234, 567) 12 | rl, ok := err.(*ErrRateLimit) 13 | if !ok { 14 | t.Fatalf("expected *ErrRateLimit, got %T", err) 15 | } 16 | if rl.MaxInputTokens != 1234 || rl.TokensRemaining != 567 || !rl.ResetAt.Equal(reset) { 17 | t.Fatalf("unexpected values: %#v", rl) 18 | } 19 | msg := rl.Error() 20 | if !strings.Contains(msg, "reset at:") || !strings.Contains(msg, "input tokens used") { 21 | t.Fatalf("unexpected error message: %q", msg) 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /internal/tools/bash_tool_pwd_test.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "os" 5 | "strings" 6 | "testing" 7 | ) 8 | 9 | func TestPwd_ReturnsCurrentWorkingDirectory(t *testing.T) { 10 | // Arrange: get the Go process' current working directory 11 | want, err := os.Getwd() 12 | if err != nil { 13 | t.Fatalf("os.Getwd failed: %v", err) 14 | } 15 | 16 | // Act: call the tool 17 | out, err := Pwd.Call(map[string]any{}) 18 | if err != nil { 19 | t.Fatalf("Pwd.Call returned error: %v", err) 20 | } 21 | 22 | // pwd usually prints a trailing newline; trim whitespace 23 | got := strings.TrimSpace(out) 24 | 25 | // Assert 26 | if got != want { 27 | t.Fatalf("pwd output mismatch.\nwant: %q\ngot: %q", want, got) 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /internal/vendors/anthropic/claude_setup.go: -------------------------------------------------------------------------------- 1 | package anthropic 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "os" 7 | 8 |
"github.com/baalimago/clai/internal/tools" 9 | "github.com/baalimago/go_away_boilerplate/pkg/misc" 10 | ) 11 | 12 | // Setup reads the API key from ANTHROPIC_API_KEY (returning an error if it is unset or empty), creates the HTTP client, and enables debug output when DEBUG or ANTHROPIC_DEBUG is truthy. 13 | func (c *Claude) Setup() error { 14 | apiKey := os.Getenv("ANTHROPIC_API_KEY") 15 | if apiKey == "" { 16 | return fmt.Errorf("environment variable 'ANTHROPIC_API_KEY' not set") 17 | } 18 | c.client = &http.Client{} 19 | c.apiKey = apiKey 20 | if misc.Truthy(os.Getenv("DEBUG")) || misc.Truthy(os.Getenv("ANTHROPIC_DEBUG")) { 21 | c.debug = true 22 | } 23 | return nil 24 | } 25 | 26 | // RegisterTool appends the tool's specification to c.tools. 27 | func (c *Claude) RegisterTool(tool tools.LLMTool) { 28 | c.tools = append(c.tools, tool.Specification()) 29 | } 30 | -------------------------------------------------------------------------------- /examples/profiles/cody.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": "claude-sonnet-4-0", 3 | "prompt": "I need your help to achieve diverse programming and devops tasks. I'll give you some information retrieved from my IDE, and you should use this + tooling to build up your context.
Then, either answer, or generate the code.\n\nRequirements:\n\t* ONLY OUTPUT LINES SHORTER THAN 72 CHARACTERS\n\t* PRIMARILY USE 'rows_between' TO BUILD CONTEXT\n\t* ONLY USE 'cat' TOOL IF YOU KNOW THE FILE IS LESS THAN 400 lines\n\t* DO NOT WRAP THE OUTPUT IN QUOTES OR BACKTICKS.", 4 | "save-reply-as-conv": true, 5 | "tools": [ 6 | "git", 7 | "file_tree", 8 | "find", 9 | "file_type", 10 | "rows_between", 11 | "cat", 12 | "ls", 13 | "website_text", 14 | "rg" 15 | ], 16 | "use_tools": true 17 | } 18 | -------------------------------------------------------------------------------- /internal/tools/bash_tool_cat_test.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "os" 5 | "path/filepath" 6 | "testing" 7 | 8 | pub_models "github.com/baalimago/clai/pkg/text/models" 9 | ) 10 | 11 | func TestCatTool_Call(t *testing.T) { 12 | tmp := t.TempDir() 13 | f := filepath.Join(tmp, "file.txt") 14 | if err := os.WriteFile(f, []byte("hello\nworld"), 0o644); err != nil { 15 | t.Fatal(err) 16 | } 17 | out, err := Cat.Call(pub_models.Input{"file": f}) 18 | if err != nil { 19 | t.Fatalf("cat failed: %v", err) 20 | } 21 | if out != "hello\nworld" { 22 | t.Errorf("unexpected output: %q", out) 23 | } 24 | } 25 | 26 | func TestCatTool_BadType(t *testing.T) { 27 | if _, err := Cat.Call(pub_models.Input{"file": 123}); err == nil { 28 | t.Error("expected error for bad file type") 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /internal/tools/bash_tool_find_test.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "os" 5 | "path/filepath" 6 | "strings" 7 | "testing" 8 | 9 | pub_models "github.com/baalimago/clai/pkg/text/models" 10 | ) 11 | 12 | func TestFindTool_Call(t *testing.T) { 13 | tmp := t.TempDir() 14 | os.WriteFile(filepath.Join(tmp, "a.txt"), []byte("hi"), 0o644) 15 |
os.WriteFile(filepath.Join(tmp, "b.log"), []byte("bye"), 0o644) 16 | out, err := Find.Call(pub_models.Input{"directory": tmp, "name": "*.txt"}) 17 | if err != nil { 18 | t.Fatalf("find failed: %v", err) 19 | } 20 | if !strings.Contains(out, "a.txt") { 21 | t.Errorf("expected to find a.txt, got %q", out) 22 | } 23 | } 24 | 25 | func TestFindTool_BadType(t *testing.T) { 26 | if _, err := Find.Call(pub_models.Input{"directory": 123}); err == nil { 27 | t.Error("expected error for bad directory type") 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /internal/video/conf_test.go: -------------------------------------------------------------------------------- 1 | package video 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestValidateOutputType(t *testing.T) { 8 | tests := []struct { 9 | name string 10 | input OutputType 11 | wantErr bool 12 | }{ 13 | { 14 | name: "valid local", 15 | input: LOCAL, 16 | wantErr: false, 17 | }, 18 | { 19 | name: "valid url", 20 | input: URL, 21 | wantErr: false, 22 | }, 23 | { 24 | name: "valid unset", 25 | input: UNSET, 26 | wantErr: false, 27 | }, 28 | { 29 | name: "invalid type", 30 | input: "invalid", 31 | wantErr: true, 32 | }, 33 | } 34 | 35 | for _, tt := range tests { 36 | t.Run(tt.name, func(t *testing.T) { 37 | if err := ValidateOutputType(tt.input); (err != nil) != tt.wantErr { 38 | t.Errorf("ValidateOutputType() error = %v, wantErr %v", err, tt.wantErr) 39 | } 40 | }) 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /internal/utils/path.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "path/filepath" 7 | ) 8 | 9 | // GetClaiConfigDir returns the path to the clai configuration directory. 10 | // The directory is located inside the user's configuration directory 11 | // (os.UserConfigDir) as the subdirectory ".clai".
12 | func GetClaiConfigDir() (string, error) { 13 | cfg, err := os.UserConfigDir() 14 | if err != nil { 15 | return "", fmt.Errorf("failed to get user config directory: %w", err) 16 | } 17 | return filepath.Join(cfg, ".clai"), nil 18 | } 19 | 20 | // GetClaiCacheDir returns the path to the clai cache directory. 21 | // The directory is located inside the user's cache directory 22 | // (os.UserCacheDir) as the subdirectory "clai". 23 | func GetClaiCacheDir() (string, error) { 24 | cacheDir, err := os.UserCacheDir() 25 | if err != nil { 26 | return "", fmt.Errorf("failed to get user cache directory: %w", err) 27 | } 28 | return filepath.Join(cacheDir, "clai"), nil 29 | } 30 | -------------------------------------------------------------------------------- /internal/tools/bash_tool_pwd.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "fmt" 5 | "os/exec" 6 | 7 | pub_models "github.com/baalimago/clai/pkg/text/models" 8 | ) 9 | 10 | type PwdTool pub_models.Specification 11 | 12 | var Pwd = PwdTool{ 13 | Name: "pwd", 14 | Description: "Print the current working directory.
Uses the Linux command 'pwd'.", 15 | Inputs: &pub_models.InputSchema{ 16 | Type: "object", 17 | Required: make([]string, 0), 18 | Properties: map[string]pub_models.ParameterObject{}, 19 | }, 20 | } 21 | 22 | func (p PwdTool) Call(input pub_models.Input) (string, error) { 23 | cmd := exec.Command("pwd") 24 | output, err := cmd.CombinedOutput() 25 | if err != nil { 26 | return "", fmt.Errorf("failed to run pwd: %w, output: %v", err, string(output)) 27 | } 28 | return string(output), nil 29 | } 30 | 31 | func (p PwdTool) Specification() pub_models.Specification { 32 | return pub_models.Specification(p) 33 | } 34 | -------------------------------------------------------------------------------- /internal/video/store.go: -------------------------------------------------------------------------------- 1 | package video 2 | 3 | import ( 4 | "encoding/base64" 5 | "fmt" 6 | "os" 7 | "path/filepath" 8 | 9 | "github.com/baalimago/clai/internal/utils" 10 | "github.com/baalimago/go_away_boilerplate/pkg/ancli" 11 | ) 12 | 13 | // SaveVideo decodes the base64 payload and writes it to out.Dir as out.Prefix + "_" + a random suffix + "." + container, retrying in os.TempDir() if that write fails. It returns the path of the written file. 14 | func SaveVideo(out Output, b64JSON, container string) (string, error) { 15 | data, err := base64.StdEncoding.DecodeString(b64JSON) 16 | if err != nil { 17 | return "", fmt.Errorf("failed to decode base64: %w", err) 18 | } 19 | videoName := fmt.Sprintf("%v_%v.%v", out.Prefix, utils.RandomPrefix(), container) 20 | outFile := filepath.Join(out.Dir, videoName) 21 | err = os.WriteFile(outFile, data, 0o644) 22 | if err != nil { 23 | ancli.PrintWarn(fmt.Sprintf("failed to write file: '%v', attempting tmp file...\n", err)) 24 | outFile = filepath.Join(os.TempDir(), videoName) 25 | err = os.WriteFile(outFile, data, 0o644) 26 | if err != nil { 27 | return "", fmt.Errorf("failed to write file: %w", err) 28 | } 29 | } 30 | return outFile, nil 31 | } 32 | -------------------------------------------------------------------------------- /internal/photo/store.go: -------------------------------------------------------------------------------- 1 | package photo 2 | 3 | import ( 4 |
"encoding/base64" 5 | "fmt" 6 | "os" 7 | 8 | "github.com/baalimago/clai/internal/utils" 9 | "github.com/baalimago/go_away_boilerplate/pkg/ancli" 10 | ) 11 | 12 | func SaveImage(out Output, b64JSON, encoding string) (string, error) { 13 | data, err := base64.StdEncoding.DecodeString(b64JSON) 14 | if err != nil { 15 | return "", fmt.Errorf("failed to decode base64: %w", err) 16 | } 17 | pictureName := fmt.Sprintf("%v_%v.%v", out.Prefix, utils.RandomPrefix(), encoding) 18 | outFile := fmt.Sprintf("%v/%v", out.Dir, pictureName) 19 | err = os.WriteFile(outFile, data, 0o644) 20 | if err != nil { 21 | ancli.PrintWarn(fmt.Sprintf("failed to write file: '%v', attempting tmp file...\n", err)) 22 | outFile = fmt.Sprintf("/tmp/%v", pictureName) 23 | err = os.WriteFile(outFile, data, 0o644) 24 | if err != nil { 25 | return "", fmt.Errorf("failed to write file: %w", err) 26 | } 27 | } 28 | return outFile, nil 29 | } 30 | -------------------------------------------------------------------------------- /internal/tools/programming_tool_line_count_test.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "os" 5 | "testing" 6 | 7 | pub_models "github.com/baalimago/clai/pkg/text/models" 8 | ) 9 | 10 | func TestLineCountTool_Call(t *testing.T) { 11 | const fileName = "test_line_count.txt" 12 | content := "one\ntwo\nthree\n" 13 | if err := os.WriteFile(fileName, []byte(content), 0o644); err != nil { 14 | t.Fatalf("setup failed: %v", err) 15 | } 16 | defer os.Remove(fileName) 17 | 18 | out, err := LineCount.Call(pub_models.Input{"file_path": fileName}) 19 | if err != nil { 20 | t.Fatalf("unexpected error: %v", err) 21 | } 22 | if out != "3" { 23 | t.Errorf("unexpected output: got %q want \"3\"", out) 24 | } 25 | } 26 | 27 | func TestLineCountTool_BadInputs(t *testing.T) { 28 | if _, err := LineCount.Call(pub_models.Input{"file_path": 123}); err == nil { 29 | t.Error("expected error for bad file_path type") 30 | } 31 | 
if _, err := LineCount.Call(pub_models.Input{"file_path": "no_such_file.txt"}); err == nil { 32 | t.Error("expected error for missing file") 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/baalimago/go_away_boilerplate v1.3.34 h1:6fzbpN/mWPYkboO9TF8F6jdV7wNQfahyNh4pQ2NxM3A= 2 | github.com/baalimago/go_away_boilerplate v1.3.34/go.mod h1:2O+zQ0Zm8vPD5SeccFFlgyf3AnYWQSHAut/ecPMmRdU= 3 | github.com/baalimago/go_away_boilerplate v1.3.337 h1:7vq2hlpklKB7tABy3KiV34jfopOEAYrAueVa3Yp3sLI= 4 | github.com/baalimago/go_away_boilerplate v1.3.337/go.mod h1:2O+zQ0Zm8vPD5SeccFFlgyf3AnYWQSHAut/ecPMmRdU= 5 | golang.org/x/exp v0.0.0-20240119083558-1b970713d09a h1:Q8/wZp0KX97QFTc2ywcOE0YRjZPVIx+MXInMzdvQqcA= 6 | golang.org/x/exp v0.0.0-20240119083558-1b970713d09a/go.mod h1:idGWGoKP1toJGkd5/ig9ZLuPcZBC3ewk7SzmH0uou08= 7 | golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w= 8 | golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8= 9 | golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE= 10 | golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg= 11 | golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= 12 | golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= 13 | -------------------------------------------------------------------------------- /internal/models/model_generic_tests.go: -------------------------------------------------------------------------------- 1 | // This package contains test intended to be used by the implementations of the 2 | // Querier, ChatQuerier and StreamCompleter interfaces 3 | package models 4 | 5 | import ( 6 | "context" 7 | "testing" 8 | "time" 9 | 10 | pub_models "github.com/baalimago/clai/pkg/text/models" 11 | 
"github.com/baalimago/go_away_boilerplate/pkg/testboil" 12 | ) 13 | 14 | // These tests are used in other places of code, an attempt at generic testing 15 | // to ensure implementation standards are kept 16 | func Querier_Context_Test(t *testing.T, q Querier) { 17 | testboil.ReturnsOnContextCancel(t, func(ctx context.Context) { 18 | q.Query(ctx) 19 | }, time.Second) 20 | } 21 | 22 | func ChatQuerier_Test(t *testing.T, q ChatQuerier) { 23 | testboil.ReturnsOnContextCancel(t, func(ctx context.Context) { 24 | q.TextQuery(ctx, pub_models.Chat{}) 25 | }, time.Second) 26 | } 27 | 28 | func StreamCompleter_Test(t *testing.T, s StreamCompleter) { 29 | testboil.ReturnsOnContextCancel(t, func(ctx context.Context) { 30 | s.StreamCompletions(ctx, pub_models.Chat{}) 31 | }, time.Second) 32 | } 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 baalimago 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /internal/text/generic/stream_completer_setup.go: -------------------------------------------------------------------------------- 1 | package generic 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "os" 7 | 8 | "github.com/baalimago/clai/internal/tools" 9 | pub_models "github.com/baalimago/clai/pkg/text/models" 10 | "github.com/baalimago/go_away_boilerplate/pkg/misc" 11 | ) 12 | 13 | // Setup reads the API key from the environment variable named apiKeyEnv (returning an error if it is unset or empty), creates the HTTP client, stores url as the endpoint, and enables debug output when DEBUG or debugEnv is truthy. 14 | func (s *StreamCompleter) Setup(apiKeyEnv, url, debugEnv string) error { 15 | apiKey := os.Getenv(apiKeyEnv) 16 | if apiKey == "" { 17 | return fmt.Errorf("environment variable '%v' not set", apiKeyEnv) 18 | } 19 | s.client = &http.Client{} 20 | s.apiKey = apiKey 21 | s.URL = url 22 | 23 | if misc.Truthy(os.Getenv("DEBUG")) || misc.Truthy(os.Getenv(debugEnv)) { 24 | s.debug = true 25 | } 26 | 27 | return nil 28 | } 29 | 30 | // InternalRegisterTool wraps the tool's specification in a "function" ToolSuper and appends it to s.tools. 31 | func (s *StreamCompleter) InternalRegisterTool(tool tools.LLMTool) { 32 | s.tools = append(s.tools, ToolSuper{ 33 | Type: "function", 34 | Function: convertToGenericTool(tool.Specification()), 35 | }) 36 | } 37 | 38 | // convertToGenericTool maps a Specification onto the generic Tool shape. A nil Inputs schema is replaced by an empty object schema instead of being dereferenced (which would panic). 39 | func convertToGenericTool(tool pub_models.Specification) Tool { 40 | inputs := pub_models.InputSchema{ 41 | Type: "object", 42 | Required: make([]string, 0), 43 | Properties: map[string]pub_models.ParameterObject{}, 44 | } 45 | if tool.Inputs != nil { 46 | inputs = *tool.Inputs 47 | } 48 | return Tool{ 49 | Name: tool.Name, 50 | Description: tool.Description, 51 | Inputs: inputs, 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /internal/utils/misc.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "errors" 5 | "math/rand" 6 | ) 7 | 8 | // RandomPrefix returns a random 10-character string drawn from [a-zA-Z0-9] via math/rand, so it is not suitable for secrets. 9 | func RandomPrefix() string { 10 | const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" 11 | result := make([]byte, 10) 12 |
for i := range result { 13 | result[i] = charset[rand.Intn(len(charset))] 14 | } 15 | 16 | return string(result) 17 | } 18 | 19 | // GetFirstTokens returns the first n non-empty tokens of prompt, or all non-empty tokens if there are fewer than n 20 | func GetFirstTokens(prompt []string, n int) []string { 21 | ret := make([]string, 0) 22 | for _, token := range prompt { 23 | if token == "" { 24 | continue 25 | } 26 | if len(ret) < n { 27 | ret = append(ret, token) 28 | } else { 29 | return ret 30 | } 31 | } 32 | return ret 33 | } 34 | 35 | // DeleteRange removes elements in [start, end] (inclusive) from the slice and returns the new slice. 36 | // It returns s unchanged together with an error if the range is invalid; otherwise the backing array of s is modified in place. Works for any slice type (Go 1.18+). 37 | func DeleteRange[T any](s []T, start, end int) ([]T, error) { 38 | if start < 0 || end >= len(s) || start > end { 39 | return s, errors.New("invalid range for DeleteRange") 40 | } 41 | return append(s[:start], s[end+1:]...), nil 42 | } 43 | -------------------------------------------------------------------------------- /internal/tools/bash_tool_date_test.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "strings" 5 | "testing" 6 | ) 7 | 8 | func TestDate_Default(t *testing.T) { 9 | out, err := Date.Call(map[string]any{}) 10 | if err != nil { 11 | t.Fatalf("Date.Call default returned error: %v", err) 12 | } 13 | if strings.TrimSpace(out) == "" { 14 | t.Fatalf("expected non-empty output from date") 15 | } 16 | } 17 | 18 | func TestDate_Unix(t *testing.T) { 19 | out, err := Date.Call(map[string]any{"unix": true}) 20 | if err != nil { 21 | t.Fatalf("Date.Call unix returned error: %v", err) 22 | } 23 | out = strings.TrimSpace(out) 24 | if len(out) == 0 { 25 | t.Fatalf("expected unix timestamp, got empty string") 26 | } 27 | for _, r := range out { 28 | if r < '0' || r > '9' { 29 | t.Fatalf("expected numeric unix timestamp, got %q", out) 30 | } 31 | } 32 | } 33 | 34 | func
TestDate_UTCAndRFC3339(t *testing.T) { 35 | out, err := Date.Call(map[string]any{"utc": true, "rfc3339": true}) 36 | if err != nil { 37 | t.Fatalf("Date.Call utc+rfc3339 returned error: %v", err) 38 | } 39 | out = strings.TrimSpace(out) 40 | if !strings.Contains(out, "T") { 41 | t.Fatalf("expected RFC3339-like output containing 'T', got %q", out) 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /internal/tools/mcp/client_test.go: -------------------------------------------------------------------------------- 1 | package mcp 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "testing" 7 | 8 | pub_models "github.com/baalimago/clai/pkg/text/models" 9 | ) 10 | 11 | func TestClient(t *testing.T) { 12 | ctx, cancel := context.WithCancel(context.Background()) 13 | defer cancel() 14 | 15 | srv := pub_models.McpServer{Command: "go", Args: []string{"run", "./testserver"}} 16 | in, out, err := Client(ctx, srv) 17 | if err != nil { 18 | t.Fatalf("client: %v", err) 19 | } 20 | 21 | req := Request{JSONRPC: "2.0", ID: 1, Method: "initialize"} 22 | in <- req 23 | msg := <-out 24 | raw, ok := msg.(json.RawMessage) 25 | if !ok { 26 | t.Fatalf("unexpected type %T", msg) 27 | } 28 | var resp Response 29 | if err := json.Unmarshal(raw, &resp); err != nil { 30 | t.Fatalf("decode: %v", err) 31 | } 32 | if resp.ID != 1 || resp.Error != nil { 33 | t.Errorf("unexpected response: %+v", resp) 34 | } 35 | } 36 | 37 | func TestClientBadCommand(t *testing.T) { 38 | ctx, cancel := context.WithCancel(context.Background()) 39 | defer cancel() 40 | _, _, err := Client(ctx, pub_models.McpServer{Command: "does-not-exist"}) 41 | if err == nil { 42 | t.Fatal("expected error for bad command") 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /internal/video/conf.go: -------------------------------------------------------------------------------- 1 | package video 2 | 3 | import ( 4 | "fmt" 5 | "os" 6
| ) 7 | 8 | type Configurations struct { 9 | Model string `json:"model"` 10 | // Format of the prompt, will place prompt at '%v' 11 | PromptFormat string `json:"prompt-format"` 12 | Output Output `json:"output"` 13 | Raw bool `json:"raw"` 14 | StdinReplace string `json:"-"` 15 | ReplyMode bool `json:"-"` 16 | Prompt string `json:"-"` 17 | 18 | PromptImageB64 string `json:"-"` 19 | } 20 | 21 | type Output struct { 22 | Type OutputType `json:"type"` 23 | Dir string `json:"dir"` 24 | Prefix string `json:"prefix"` 25 | } 26 | 27 | var Default = Configurations{ 28 | Model: "sora-2", 29 | PromptFormat: "%v", 30 | Output: Output{ 31 | Type: UNSET, 32 | Dir: fmt.Sprintf("%v/Videos", os.Getenv("HOME")), 33 | Prefix: "clai", 34 | }, 35 | } 36 | 37 | type OutputType string 38 | 39 | const ( 40 | LOCAL OutputType = "local" 41 | URL OutputType = "url" 42 | UNSET OutputType = "unset" 43 | ) 44 | 45 | func ValidateOutputType(outputType OutputType) error { 46 | switch outputType { 47 | case URL, LOCAL, UNSET: 48 | return nil 49 | default: 50 | return fmt.Errorf("invalid output type: %v", outputType) 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /internal/tools/clai_tool_help.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "fmt" 5 | "os/exec" 6 | 7 | pub_models "github.com/baalimago/clai/pkg/text/models" 8 | ) 9 | 10 | // ClaiHelp - Run `clai help` 11 | var ClaiHelp = &claiHelpTool{} 12 | 13 | type claiHelpTool struct{} 14 | 15 | const desc = `Run 'clai help' to output instructions on how to use the tool. 
16 | 17 | Guidelines when using clai tools: 18 | * Always run 'clai help' to understand how to use the tool 19 | * Always run 'clai profiles' to know which profiles to use 20 | * Always run 'clai tools' to find which tools you can utilize for the clai_run subprocess workers` 21 | 22 | func (t *claiHelpTool) Specification() pub_models.Specification { 23 | return pub_models.Specification{ 24 | Name: "clai_help", 25 | Description: desc, 26 | Inputs: &pub_models.InputSchema{ 27 | Type: "object", 28 | Properties: map[string]pub_models.ParameterObject{}, 29 | Required: make([]string, 0), 30 | }, 31 | } 32 | } 33 | 34 | func (t *claiHelpTool) Call(input pub_models.Input) (string, error) { 35 | cmd := exec.Command(ClaiBinaryPath, "help") 36 | out, err := cmd.CombinedOutput() 37 | if err != nil { 38 | return string(out), fmt.Errorf("failed to run clai help: %w", err) 39 | } 40 | return string(out), nil 41 | } 42 | -------------------------------------------------------------------------------- /internal/chat/chat_test.go: -------------------------------------------------------------------------------- 1 | package chat 2 | 3 | import ( 4 | "os" 5 | "path/filepath" 6 | "reflect" 7 | "testing" 8 | 9 | pub_models "github.com/baalimago/clai/pkg/text/models" 10 | ) 11 | 12 | func TestSaveAndFromPath(t *testing.T) { 13 | tmp := t.TempDir() 14 | ch := pub_models.Chat{ 15 | ID: "my_chat", 16 | Messages: []pub_models.Message{{Role: "user", Content: "hello"}}, 17 | } 18 | if err := Save(tmp, ch); err != nil { 19 | t.Fatalf("save failed: %v", err) 20 | } 21 | file := filepath.Join(tmp, "my_chat.json") 22 | if _, err := os.Stat(file); err != nil { 23 | t.Fatalf("expected file %v to exist: %v", file, err) 24 | } 25 | loaded, err := FromPath(file) 26 | if err != nil { 27 | t.Fatalf("frompath failed: %v", err) 28 | } 29 | if !reflect.DeepEqual(ch, loaded) { 30 | t.Errorf("loaded chat mismatch: %+v vs %+v", loaded, ch) 31 | } 32 | } 33 | 34 | func TestFromPathError(t *testing.T) { 35 | if 
_, err := FromPath("nonexistent.json"); err == nil { 36 | t.Error("expected error for missing file") 37 | } 38 | } 39 | 40 | func TestIDFromPrompt(t *testing.T) { 41 | prompt := "hello world some/test path\\dir other extra" 42 | got := IDFromPrompt(prompt) 43 | want := "hello_world_some.test_path.dir_other" 44 | if got != want { 45 | t.Errorf("IDFromPrompt() = %q, want %q", got, want) 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /internal/vendors/inception/inception_test.go: -------------------------------------------------------------------------------- 1 | package inception 2 | 3 | import "testing" 4 | 5 | func TestSetupConfigMapping(t *testing.T) { 6 | v := Default 7 | fp := 0.1 8 | v.FrequencyPenalty = fp 9 | mt := 777 10 | v.MaxTokens = &mt 11 | v.Temperature = 0.9 12 | v.Model = "inc-custom" 13 | 14 | t.Setenv("INCEPTION_API_KEY", "k") 15 | if err := v.Setup(); err != nil { 16 | t.Fatalf("setup failed: %v", err) 17 | } 18 | if v.StreamCompleter.Model != v.Model { 19 | t.Errorf("expected Model %q, got %q", v.Model, v.StreamCompleter.Model) 20 | } 21 | if v.StreamCompleter.FrequencyPenalty == nil || *v.StreamCompleter.FrequencyPenalty != v.FrequencyPenalty { 22 | t.Errorf("frequency penalty not mapped, got %#v want %v", v.StreamCompleter.FrequencyPenalty, v.FrequencyPenalty) 23 | } 24 | if v.StreamCompleter.MaxTokens == nil || *v.StreamCompleter.MaxTokens != *v.MaxTokens { 25 | t.Errorf("max tokens not mapped, got %#v want %v", v.StreamCompleter.MaxTokens, *v.MaxTokens) 26 | } 27 | if v.StreamCompleter.Temperature == nil || *v.StreamCompleter.Temperature != v.Temperature { 28 | t.Errorf("temperature not mapped, got %#v want %v", v.StreamCompleter.Temperature, v.Temperature) 29 | } 30 | if v.ToolChoice == nil || *v.ToolChoice != "auto" { 31 | t.Errorf("tool choice expected 'auto', got %#v", v.ToolChoice) 32 | } 33 | } 34 | -------------------------------------------------------------------------------- 
/internal/vendors/inception/inception.go:
--------------------------------------------------------------------------------
1 | package inception
2 |
3 | import (
4 | 	"fmt"
5 |
6 | 	"github.com/baalimago/clai/internal/text/generic"
7 | 	"github.com/baalimago/clai/internal/tools"
8 | )
9 |
10 | // Default is the baseline configuration for Inception Labs models.
11 | var Default = Inception{
12 | 	Model: "mercury",
13 | 	URL:   ChatURL,
14 | }
15 |
16 | type Inception struct {
17 | 	generic.StreamCompleter
18 | 	Model            string  `json:"model"`
19 | 	FrequencyPenalty float64 `json:"frequency_penalty"`
20 | 	MaxTokens        *int    `json:"max_tokens"` // Use a pointer to allow null value
21 | 	PresencePenalty  float64 `json:"presence_penalty"`
22 | 	Temperature      float64 `json:"temperature"`
23 | 	URL              string  `json:"url"`
24 | }
25 |
26 | const ChatURL = "https://api.inceptionlabs.ai/v1/chat/completions"
27 |
28 | func (g *Inception) Setup() error {
29 | 	err := g.StreamCompleter.Setup("INCEPTION_API_KEY", ChatURL, "INCEPTION_DEBUG")
30 | 	if err != nil {
31 | 		return fmt.Errorf("failed to setup stream completer: %w", err)
32 | 	}
33 | 	g.StreamCompleter.Model = g.Model
34 | 	g.StreamCompleter.FrequencyPenalty = &g.FrequencyPenalty
35 | 	g.StreamCompleter.MaxTokens = g.MaxTokens
36 | 	g.StreamCompleter.Temperature = &g.Temperature
37 | 	toolChoice := "auto"
38 | 	g.ToolChoice = &toolChoice
39 | 	return nil
40 | }
41 |
42 | func (g *Inception) RegisterTool(tool tools.LLMTool) {
43 | 	g.InternalRegisterTool(tool)
44 | }
45 |
--------------------------------------------------------------------------------
/internal/photo/conf.go:
--------------------------------------------------------------------------------
1 | package photo
2 |
3 | import (
4 | 	"fmt"
5 | 	"os"
6 | )
7 |
8 | type Configurations struct {
9 | 	Model string `json:"model"`
10 | 	// Format of the prompt, will place prompt at '%v'
11 | 	PromptFormat string `json:"prompt-format"`
12 | 	Output       Output `json:"output"`
13 | 	Raw          bool   `json:"raw"`
14 | 	StdinReplace string `json:"-"`
15 | 	ReplyMode    bool   `json:"-"`
16 | 	Prompt       string `json:"-"`
17 | }
18 |
19 | type Output struct {
20 | 	Type   OutputType `json:"type"`
21 | 	Dir    string     `json:"dir"`
22 | 	Prefix string     `json:"prefix"`
23 | }
24 |
25 | var DEFAULT = Configurations{
26 | 	Model:        "gpt-image-1",
27 | 	PromptFormat: "I NEED to test how the tool works with extremely simple prompts. DO NOT add any detail, just use it AS-IS: '%v'",
28 | 	Output: Output{
29 | 		Type:   UNSET,
30 | 		Dir:    fmt.Sprintf("%v/Pictures", os.Getenv("HOME")),
31 | 		Prefix: "clai",
32 | 	},
33 | }
34 |
35 | type OutputType string
36 |
37 | const (
38 | 	LOCAL OutputType = "local"
39 | 	URL   OutputType = "url"
40 | 	UNSET OutputType = "unset"
41 | )
42 |
43 | // ValidateOutputType returns an error unless outputType is LOCAL, URL or UNSET.
44 | func ValidateOutputType(outputType OutputType) error {
45 | 	switch outputType {
46 | 	case URL, LOCAL, UNSET:
47 | 		return nil
48 | 	default:
49 | 		return fmt.Errorf("invalid output type: %v", outputType)
50 | 	}
51 | }
52 |
--------------------------------------------------------------------------------
/internal/tools/cmd.go:
--------------------------------------------------------------------------------
1 | package tools
2 |
3 | import (
4 | 	"context"
5 | 	"encoding/json"
6 | 	"fmt"
7 | 	"sort"
8 |
9 | 	"github.com/baalimago/clai/internal/utils"
10 | )
11 |
12 | // SubCmd implements the tools subcommand. If args has a second element, the
13 | // specification of the registered tool with that name is printed as indented
14 | // JSON (an error is returned if no such tool is registered). Otherwise every
15 | // registered tool is listed, sorted by name, with its description passed
16 | // through utils.WidthAppropriateStringTrunk. On success it returns
17 | // utils.ErrUserInitiatedExit.
18 | func SubCmd(ctx context.Context, args []string) error {
19 | 	if len(args) > 1 {
20 | 		toolName := args[1]
21 | 		tool, exists := Registry.Get(toolName)
22 | 		if !exists {
23 | 			return fmt.Errorf("tool '%s' not found", toolName)
24 | 		}
25 | 		spec := tool.Specification()
26 | 		jsonSpec, err := json.MarshalIndent(spec, "", " ")
27 | 		if err != nil {
28 | 			return fmt.Errorf("failed to marshal tool specification: %w", err)
29 | 		}
30 | 		fmt.Printf("%s\n", string(jsonSpec))
31 | 		return utils.ErrUserInitiatedExit
32 | 	}
33 |
34 | 	tls := Registry.All()
35 | 	toolNames := make([]string, 0, len(tls))
36 | 	for k := range tls {
37 | 		toolNames = append(toolNames, k)
38 | 	}
39 | 	sort.Strings(toolNames)
40 |
41 | 	fmt.Printf("Available Tools:\n")
42 | 	for _, name := range toolNames {
43 | 		tool := tls[name]
44 | 		spec := tool.Specification()
45 | 		prefix := fmt.Sprintf("- %s: ", name)
46 |
47 | 		maybeShortenedDesc, err := utils.WidthAppropriateStringTrunk(spec.Description, prefix, 5)
48 | 		if err != nil {
49 | 			return fmt.Errorf("failed to truncate description: %w", err)
50 | 		}
51 | 		fmt.Println(maybeShortenedDesc)
52 | 	}
53 | 	fmt.Println("\nRun 'clai tools ' for more details.")
54 | 	return utils.ErrUserInitiatedExit
55 | }
56 |
--------------------------------------------------------------------------------
/internal/chat/handler_test.go:
--------------------------------------------------------------------------------
1 | package chat
2 |
3 | import (
4 | 	"context"
5 | 	"testing"
6 | 	"time"
7 |
8 | 	pub_models "github.com/baalimago/clai/pkg/text/models"
9 | )
10 |
11 | type mockChatQuerier struct{}
12 |
13 | func (mockChatQuerier) Query(ctx context.Context) error { return nil }
14 | func (mockChatQuerier) TextQuery(ctx context.Context, c pub_models.Chat) (pub_models.Chat, error) {
15 | 	return c, nil
16 | }
17 |
18 | func TestChatHandlerListAndFind(t *testing.T) {
19 | 	tmp := t.TempDir()
20 | 	chats := []pub_models.Chat{
21 | 		{ID: "one", Created: time.Now().Add(-time.Hour)},
22 | 		{ID: "two", Created: time.Now()},
23 | 	}
24 | 	for _, c := range chats {
25 | 		if err := Save(tmp, c); err != nil {
26 | 			t.Fatalf("save: %v", err)
27 | 		}
28 | 	}
29 | 	h := &ChatHandler{convDir: tmp, q: mockChatQuerier{}}
30 | 	got, err := h.list()
31 | 	if err != nil {
32 | 		t.Fatalf("list err: %v", err)
33 | 	}
34 | 	if len(got) != 2 || got[0].ID != "two" {
35 | 		t.Fatalf("unexpected list result: %+v", got)
36 | 	}
37 |
38 | 	res, err := h.findChatByID("1 extra words")
39 | 	if err != nil {
40 | 		t.Fatalf("findChatByID err: %v", err)
41 | 	}
42 | 	if res.ID != "one" || h.prompt != "extra words" {
43 | 		t.Errorf("unexpected chat or prompt: %+v %q", res, h.prompt)
44 | 	}
45 | 	res, err = h.findChatByID("two")
46 | 	if err != nil || res.ID != "two" {
47 | 		t.Errorf("find by id failed: %v %+v", err, res)
48 | 	}
49 | }
50 |
--------------------------------------------------------------------------------
/internal/tools/clai_tool_result.go:
--------------------------------------------------------------------------------
1 | package tools
2 |
3 | import (
4 | 	"fmt"
5 |
6 | 	pub_models "github.com/baalimago/clai/pkg/text/models"
7 | )
8 |
9 | // ClaiResult - Get result
10 | var ClaiResult = &claiResultTool{}
11 |
12 | type claiResultTool struct{}
13 |
14 | func (t *claiResultTool) Specification() pub_models.Specification {
15 | 	return pub_models.Specification{
16 | 		Name:        "clai_result",
17 | 		Description: "Get the stdout, stderr and statuscode of the run-id",
18 | 		Inputs: &pub_models.InputSchema{
19 | 			Type:     "object",
20 | 			Required: []string{"run_id"},
21 | 			Properties: map[string]pub_models.ParameterObject{
22 | 				"run_id": {
23 | 					Type:        "string",
24 | 					Description: "The run-id returned by clai_run",
25 | 				},
26 | 			},
27 | 		},
28 | 	}
29 | }
30 |
31 | func (t *claiResultTool) Call(input pub_models.Input) (string, error) {
32 | 	runIDRaw, ok := input["run_id"]
33 | 	if !ok {
34 | 		return "", fmt.Errorf("missing run_id")
35 | 	}
36 | 	runID, ok := runIDRaw.(string)
37 | 	if !ok {
38 | 		return "", fmt.Errorf("run_id must be a string")
39 | 	}
40 |
41 | 	claiRunsMu.Lock()
42 | 	process, ok := claiRuns[runID]
43 | 	claiRunsMu.Unlock()
44 |
45 | 	if !ok {
46 | 		return "", fmt.Errorf("unknown run_id: %s", runID)
47 | 	}
48 |
49 | 	// NOTE(review): the process fields below are read after claiRunsMu is
50 | 	// released; confirm the worker that writes them cannot race with this read.
51 | 	return fmt.Sprintf("Exit Code: %d\nStdout:\n%s\nStderr:\n%s", process.exitCode, process.stdout.String(), process.stderr.String()), nil
52 | }
53 |
--------------------------------------------------------------------------------
/internal/tools/programming_tool_line_count.go:
--------------------------------------------------------------------------------
1 | package tools
2 |
3 | import (
4 | 	"bufio"
5 | 	"errors"
6 | 	"fmt"
7 | 	"io"
8 | 	"os"
9 |
10 | 	pub_models "github.com/baalimago/clai/pkg/text/models"
11 | )
12 |
13 | type LineCountTool pub_models.Specification
14 |
15 | var LineCount = LineCountTool{
16 | 	Name:        "line_count",
17 | 	Description: "Count the number of lines in a file.",
18 | 	Inputs: &pub_models.InputSchema{
19 | 		Type: "object",
20 | 		Properties: map[string]pub_models.ParameterObject{
21 | 			"file_path": {
22 | 				Type:        "string",
23 | 				Description: "The path to the file to count lines of.",
24 | 			},
25 | 		},
26 | 		Required: []string{"file_path"},
27 | 	},
28 | }
29 |
30 | // Call counts the lines in the file named by input["file_path"]. Every '\n'
31 | // ends a line, and trailing bytes without a final '\n' count as one more line
32 | // (the same result bufio.ScanLines gives). Newlines are counted directly
33 | // instead of tokenising with bufio.Scanner, so files containing lines longer
34 | // than bufio.MaxScanTokenSize are counted rather than failing.
35 | func (l LineCountTool) Call(input pub_models.Input) (string, error) {
36 | 	filePath, ok := input["file_path"].(string)
37 | 	if !ok {
38 | 		return "", fmt.Errorf("file_path must be a string")
39 | 	}
40 | 	file, err := os.Open(filePath)
41 | 	if err != nil {
42 | 		return "", fmt.Errorf("failed to open file: %w", err)
43 | 	}
44 | 	defer file.Close()
45 |
46 | 	reader := bufio.NewReader(file)
47 | 	count := 0
48 | 	unterminated := false // true while bytes have been seen since the last '\n'
49 | 	for {
50 | 		// ReadSlice returns ErrBufferFull for a line longer than the buffer;
51 | 		// that is just a partial line, so keep reading.
52 | 		chunk, err := reader.ReadSlice('\n')
53 | 		if n := len(chunk); n > 0 {
54 | 			unterminated = chunk[n-1] != '\n'
55 | 			if !unterminated {
56 | 				count++
57 | 			}
58 | 		}
59 | 		if err == nil || errors.Is(err, bufio.ErrBufferFull) {
60 | 			continue
61 | 		}
62 | 		if errors.Is(err, io.EOF) {
63 | 			break
64 | 		}
65 | 		return "", fmt.Errorf("failed to read file: %w", err)
66 | 	}
67 | 	if unterminated {
68 | 		count++
69 | 	}
70 | 	return fmt.Sprintf("%d", count), nil
71 | }
72 |
73 | func (l LineCountTool) Specification() pub_models.Specification {
74 | 	return pub_models.Specification(LineCount)
75 | }
76 |
--------------------------------------------------------------------------------
/internal/vendors/openai/gpt.go:
--------------------------------------------------------------------------------
1 | package openai
2 |
3 | import (
4 | 	"fmt"
5 |
6 | 	"github.com/baalimago/clai/internal/text/generic"
7 | 	"github.com/baalimago/clai/internal/tools"
8 | )
9 |
10 | var GptDefault = ChatGPT{
11 | 	Model:       "gpt-4.1-mini",
12 | 	Temperature: 1.0,
13 | 	TopP:        1.0,
14 | 	URL:         ChatURL,
15 | }
16 |
17 | type ChatGPT struct {
18 | 	generic.StreamCompleter
19 | 	Model            string  `json:"model"`
20 | 	FrequencyPenalty float64 `json:"frequency_penalty"`
21 | 	MaxTokens        *int    `json:"max_tokens"` // Use a pointer to allow null value
22 | 	PresencePenalty  float64 `json:"presence_penalty"`
23 | 	Temperature      float64 `json:"temperature"`
24 | 	TopP             float64 `json:"top_p"`
25 | 	URL              string  `json:"url"`
26 | }
27 |
28 | func (g *ChatGPT) Setup() error {
29 | 	err := g.StreamCompleter.Setup("OPENAI_API_KEY", ChatURL, "DEBUG_OPENAI")
30 | 	if err != nil {
31 | 		return fmt.Errorf("failed to setup stream completer: %w", err)
32 | 	}
33 | 	g.StreamCompleter.Model = g.Model
34 | 	g.StreamCompleter.FrequencyPenalty = &g.FrequencyPenalty
35 | 	g.StreamCompleter.MaxTokens = g.MaxTokens
36 | 	g.StreamCompleter.Temperature = &g.Temperature
37 | 	g.StreamCompleter.TopP = &g.TopP
38 | 	toolChoice := "auto"
39 | 	g.ToolChoice = &toolChoice
40 | 	return nil
41 | }
42 |
43 | func (g *ChatGPT) RegisterTool(tool tools.LLMTool) {
44 | 	g.InternalRegisterTool(tool)
45 | }
46 |
--------------------------------------------------------------------------------
/internal/tools/mcp/models.go:
--------------------------------------------------------------------------------
1 | package mcp
2 |
3 | import (
4 | 	"encoding/json"
5 |
6 | 	pub_models "github.com/baalimago/clai/pkg/text/models"
7 | )
8 |
9 | // ControlEvent instructs the Manager to register a new MCP server.
10 | type ControlEvent struct {
11 | 	ServerName string
12 | 	Server     pub_models.McpServer
13 | 	InputChan  chan<- any
14 | 	OutputChan <-chan any
15 | }
16 |
17 | // Request represents a JSON-RPC request.
18 | type Request struct {
19 | 	JSONRPC string         `json:"jsonrpc"`
20 | 	ID      int            `json:"id,omitempty"`
21 | 	Method  string         `json:"method"`
22 | 	Params  map[string]any `json:"params,omitempty"`
23 | }
24 |
25 | // Response represents a JSON-RPC response.
26 | type Response struct {
27 | 	JSONRPC string          `json:"jsonrpc"`
28 | 	ID      int             `json:"id,omitempty"`
29 | 	Result  json.RawMessage `json:"result,omitempty"`
30 | 	Error   *RPCError       `json:"error,omitempty"`
31 | }
32 |
33 | // RPCError represents a JSON-RPC error structure.
34 | type RPCError struct { 35 | Code int `json:"code"` 36 | Message string `json:"message"` 37 | Data any `json:"data,omitempty"` 38 | } 39 | 40 | // Tool describes a tool as returned by tools/list. 41 | type Tool struct { 42 | Name string `json:"name"` 43 | Description string `json:"description,omitempty"` 44 | InputSchema pub_models.InputSchema `json:"inputSchema"` 45 | } 46 | -------------------------------------------------------------------------------- /internal/models/model_generic_tests_test.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | pub_models "github.com/baalimago/clai/pkg/text/models" 8 | ) 9 | 10 | type mockQuerier struct{} 11 | 12 | func (m *mockQuerier) Query(ctx context.Context) error { 13 | <-ctx.Done() 14 | return ctx.Err() 15 | } 16 | 17 | func TestQuerier_Context_Test(t *testing.T) { 18 | // Should pass for a compliant Querier 19 | Querier_Context_Test(t, &mockQuerier{}) 20 | } 21 | 22 | type mockChatQuerier struct{} 23 | 24 | func (m *mockChatQuerier) Query(ctx context.Context) error { 25 | <-ctx.Done() 26 | return ctx.Err() 27 | } 28 | 29 | func (m *mockChatQuerier) TextQuery(ctx context.Context, chat pub_models.Chat) (pub_models.Chat, error) { 30 | <-ctx.Done() 31 | return pub_models.Chat{}, ctx.Err() 32 | } 33 | 34 | func TestChatQuerier_Test(t *testing.T) { 35 | // Should pass for a compliant ChatQuerier 36 | ChatQuerier_Test(t, &mockChatQuerier{}) 37 | } 38 | 39 | type mockStreamCompleter struct{} 40 | 41 | func (m *mockStreamCompleter) Setup() error { 42 | return nil 43 | } 44 | 45 | func (m *mockStreamCompleter) StreamCompletions(ctx context.Context, chat pub_models.Chat) (chan CompletionEvent, error) { 46 | <-ctx.Done() 47 | return nil, ctx.Err() 48 | } 49 | 50 | func TestStreamCompleter_Test(t *testing.T) { 51 | // Should pass for a compliant StreamCompleter 52 | StreamCompleter_Test(t, &mockStreamCompleter{}) 53 | } 54 | 
--------------------------------------------------------------------------------
/internal/vendors/xai/xai.go:
--------------------------------------------------------------------------------
1 | package xai
2 |
3 | import (
4 | 	"fmt"
5 | 	"os"
6 |
7 | 	"github.com/baalimago/clai/internal/text/generic"
8 | 	"github.com/baalimago/clai/internal/tools"
9 | )
10 |
11 | var Default = XAI{
12 | 	Model:       "grok-code-fast-1",
13 | 	Temperature: 0,
14 | 	TopP:        1.0,
15 | 	URL:         ChatURL,
16 | }
17 |
18 | type XAI struct {
19 | 	generic.StreamCompleter
20 | 	Model           string  `json:"model"`
21 | 	MaxTokens       *int    `json:"max_tokens"` // Use a pointer to allow null value
22 | 	PresencePenalty float64 `json:"presence_penalty"`
23 | 	Temperature     float64 `json:"temperature"`
24 | 	TopP            float64 `json:"top_p"`
25 | 	URL             string  `json:"url"`
26 | }
27 |
28 | const ChatURL = "https://api.x.ai/v1/chat/completions"
29 |
30 | func (g *XAI) Setup() error {
31 | 	if os.Getenv("XAI_API_KEY") == "" {
32 | 		if err := os.Setenv("XAI_API_KEY", "xai"); err != nil {
33 | 			return fmt.Errorf("failed to set placeholder XAI_API_KEY: %w", err)
34 | 		}
35 | 	}
36 | 	err := g.StreamCompleter.Setup("XAI_API_KEY", ChatURL, "XAI_DEBUG")
37 | 	if err != nil {
38 | 		return fmt.Errorf("failed to setup stream completer: %w", err)
39 | 	}
40 | 	g.StreamCompleter.Model = g.Model
41 | 	g.StreamCompleter.MaxTokens = g.MaxTokens
42 | 	g.StreamCompleter.Temperature = &g.Temperature
43 | 	g.StreamCompleter.TopP = &g.TopP
44 | 	toolChoice := "auto"
45 | 	g.ToolChoice = &toolChoice
46 | 	return nil
47 | }
48 |
49 | func (g *XAI) RegisterTool(tool tools.LLMTool) {
50 | 	g.InternalRegisterTool(tool)
51 | }
52 |
--------------------------------------------------------------------------------
/internal/utils/input.go:
--------------------------------------------------------------------------------
1 | package utils
2 |
3 | import (
4 | 	"bufio"
5 | 	"errors"
6 | 	"fmt"
7 | 	"os"
8 | 	"os/signal"
9 | 	"slices"
10 | 	"strings"
11 | )
12 |
13 | // ReadUserInput reads one line from the controlling terminal (/dev/tty) and
14 | // returns it with surrounding whitespace trimmed. It returns
15 | // ErrUserInitiatedExit when the user enters "q" or "quit", or when an
16 | // interrupt signal arrives before a line has been read.
17 | func ReadUserInput() (string, error) {
18 | 	sigChan := make(chan os.Signal, 1)
19 | 	signal.Notify(sigChan, os.Interrupt)
20 | 	defer signal.Stop(sigChan)
21 | 	// Buffered so the reader goroutine never blocks on its send: it can hand
22 | 	// off its result and exit once its read completes, even if this function
23 | 	// has already returned because of an interrupt.
24 | 	inputChan := make(chan string, 1)
25 | 	errChan := make(chan error, 1)
26 |
27 | 	go func() {
28 | 		// Open /dev/tty for direct terminal access
29 | 		tty, err := os.Open("/dev/tty")
30 | 		if err != nil {
31 | 			errChan <- fmt.Errorf("cannot open terminal: %w", err)
32 | 			return
33 | 		}
34 | 		defer tty.Close()
35 |
36 | 		reader := bufio.NewReader(tty)
37 | 		userInput, err := reader.ReadString('\n')
38 | 		if err != nil {
39 | 			errChan <- err
40 | 			return
41 | 		}
42 | 		inputChan <- userInput
43 | 	}()
44 |
45 | 	select {
46 | 	case <-sigChan:
47 | 		return "", ErrUserInitiatedExit
48 | 	case err := <-errChan:
49 | 		return "", fmt.Errorf("failed to read user input: %w", err)
50 | 	case userInput, open := <-inputChan:
51 | 		if !open {
52 | 			return "", errors.New("user input channel closed. Not sure how we ended up here🤔")
53 | 		}
54 | 		trimmedInput := strings.TrimSpace(userInput)
55 | 		quitters := []string{"q", "quit"}
56 | 		if slices.Contains(quitters, trimmedInput) {
57 | 			return "", ErrUserInitiatedExit
58 | 		}
59 | 		return trimmedInput, nil
60 | 	}
61 | }
62 |
--------------------------------------------------------------------------------
/internal/tools/clai_tool_check.go:
--------------------------------------------------------------------------------
1 | package tools
2 |
3 | import (
4 | 	"fmt"
5 |
6 | 	pub_models "github.com/baalimago/clai/pkg/text/models"
7 | )
8 |
9 | // ClaiCheck - Check status
10 | var ClaiCheck = &claiCheckTool{}
11 |
12 | type claiCheckTool struct{}
13 |
14 | func (t *claiCheckTool) Specification() pub_models.Specification {
15 | 	return pub_models.Specification{
16 | 		Name:        "clai_check",
17 | 		Description: "Check status of the run-id: RUNNING, COMPLETED, FAILED",
18 | 		Inputs: &pub_models.InputSchema{
19 | 			Type:     "object",
20 | 			Required: []string{"run_id"},
21 | 			Properties: map[string]pub_models.ParameterObject{
22 | 				"run_id": {
23 | 					Type:        "string",
24 | 					Description: "The run-id returned by clai_run",
25 | 				},
26 | 			},
27 | 		},
28 | 	}
29 | }
30 |
31 | func (t *claiCheckTool) Call(input pub_models.Input) (string, error) {
32 | 	runIDRaw, ok := input["run_id"]
33 | 	if !ok {
34 | 		return "", fmt.Errorf("missing run_id")
35 | 	}
36 | 	runID, ok := runIDRaw.(string)
37 | 	if !ok {
38 | 		return "", fmt.Errorf("run_id must be a string")
39 | 	}
40 |
41 | 	claiRunsMu.Lock()
42 | 	process, ok := claiRuns[runID]
43 | 	claiRunsMu.Unlock()
44 |
45 | 	if !ok {
46 | 		return "", fmt.Errorf("unknown run_id: %s", runID)
47 | 	}
48 |
49 | 	// NOTE(review): the process fields below are read after claiRunsMu is
50 | 	// released; confirm the worker that writes them cannot race with this read.
51 | 	if !process.done {
52 | 		return "RUNNING", nil
53 | 	}
54 |
55 | 	if process.exitCode != 0 || process.err != nil {
56 | 		return "FAILED", nil
57 | 	}
58 |
59 | 	return "COMPLETED", nil
60 | }
61 |
--------------------------------------------------------------------------------
/internal/photo/funimation_0.go:
--------------------------------------------------------------------------------
1 | package photo
2 |
3 | import (
4 | 	"fmt"
5 | 	"strings"
6 | 	"time"
7 |
8 | 	"github.com/baalimago/clai/internal/utils"
9 | 	"github.com/baalimago/go_away_boilerplate/pkg/ancli"
10 | )
11 |
12 | // StartAnimation redraws an "Elapsed time" status line (clock face plus
13 | // elapsed duration) on every tick of a 1/60s ticker until the returned
14 | // function is called. Calling the returned function more than once panics,
15 | // because it closes a channel.
16 | func StartAnimation() func() {
17 | 	t0 := time.Now()
18 | 	ticker := time.NewTicker(time.Second / 60)
19 | 	stop := make(chan struct{})
20 | 	termWidth, err := utils.TermWidth()
21 | 	if err != nil {
22 | 		ancli.PrintWarn(fmt.Sprintf("failed to get terminal size: %v\n", err))
23 | 		termWidth = 100
24 | 	}
25 | 	// Blank line used to wipe the previous frame; termWidth never changes.
26 | 	clearLine := strings.Repeat(" ", termWidth)
27 | 	go func() {
28 | 		// Release the ticker once the animation is stopped.
29 | 		defer ticker.Stop()
30 | 		for {
31 | 			select {
32 | 			case <-ticker.C:
33 | 				cTick := time.Since(t0)
34 | 				fmt.Printf("\r%v", clearLine)
35 | 				fmt.Printf("\rElapsed time: %v - %v", funimation(cTick), cTick)
36 | 			case <-stop:
37 | 				return
38 | 			}
39 | 		}
40 | 	}()
41 | 	return func() {
42 | 		close(stop)
43 | 	}
44 | }
45 |
46 | // funimation returns the clock-face emoji to show after elapsed time t.
47 | func funimation(t time.Duration) string {
48 | 	images := []string{
49 | 		"🕛",
50 | 		"🕧",
51 | 		"🕐",
52 | 		"🕜",
53 | 		"🕑",
54 | 		"🕝",
55 | 		"🕒",
56 | 		"🕞",
57 | 		"🕓",
58 | 		"🕟",
59 | 		"🕔",
60 | 		"🕠",
61 | 		"🕕",
62 | 		"🕡",
63 | 		"🕖",
64 | 		"🕢",
65 | 		"🕗",
66 | 		"🕣",
67 | 		"🕘",
68 | 		"🕤",
69 | 		"🕙",
70 | 		"🕥",
71 | 		"🕚",
72 | 		"🕦",
73 | 	}
74 | 	// Each face is shown for 43478260ns (about 1s/23), so the 24 faces
75 | 	// complete one full revolution roughly every 1.04s.
76 | 	return images[int(t.Nanoseconds()/43478260)%len(images)]
77 | }
78 |
--------------------------------------------------------------------------------
/internal/vendors/gemini/gemini.go:
--------------------------------------------------------------------------------
1 | package gemini
2 |
3 | import (
4 | 	"fmt"
5 | 	"os"
6 |
7 | 	"github.com/baalimago/clai/internal/text/generic"
8 | 	"github.com/baalimago/clai/internal/tools"
9 | )
10 |
11 | var Default = Gemini{
12 | 	Model:       "gemini-2.5-flash",
13 | 	Temperature: 1.0,
14 | 	TopP:        1.0,
15 | 	URL:         ChatURL,
16 | }
17 |
18 | type Gemini struct {
19 | 	generic.StreamCompleter
20 | 	Model           string  `json:"model"`
21 | 	MaxTokens       *int    `json:"max_tokens"` // Use a pointer to allow null value
22 | 	PresencePenalty float64 `json:"presence_penalty"`
23 | 	Temperature     float64 `json:"temperature"`
24 | 	TopP            float64 `json:"top_p"`
25 | 	URL             string  `json:"url"`
26 | }
27 |
28 | const ChatURL = "https://generativelanguage.googleapis.com/v1beta/openai/chat/completions"
29 |
30 | func (g *Gemini) Setup() error {
31 | 	if os.Getenv("GEMINI_API_KEY") == "" {
32 | 		if err := os.Setenv("GEMINI_API_KEY", "gemini"); err != nil {
33 | 			return fmt.Errorf("failed to set placeholder GEMINI_API_KEY: %w", err)
34 | 		}
35 | 	}
36 | 	err := g.StreamCompleter.Setup("GEMINI_API_KEY", ChatURL, "GEMINI_DEBUG")
37 | 	if err != nil {
38 | 		return fmt.Errorf("failed to setup stream completer: %w", err)
39 | 	}
40 | 	g.StreamCompleter.Model = g.Model
41 | 	g.StreamCompleter.MaxTokens = g.MaxTokens
42 | 	g.StreamCompleter.Temperature = &g.Temperature
43 | 	g.StreamCompleter.TopP = &g.TopP
44 | 	toolChoice := "auto"
45 | 	g.ToolChoice = &toolChoice
46 | 	return nil
47 | }
48 |
49 | func (g *Gemini) RegisterTool(tool tools.LLMTool) {
50 | 	g.InternalRegisterTool(tool)
51 | }
52 |
--------------------------------------------------------------------------------
/internal/tools/bash_tool_freetext_command.go:
--------------------------------------------------------------------------------
1 | package tools
2 |
3 | import (
4 | 	"fmt"
5 | 	"os/exec"
6 | 	"strings"
7 |
8 | 	pub_models "github.com/baalimago/clai/pkg/text/models"
9 | )
10 |
11 | type FreetextCmdTool pub_models.Specification
12 |
13 | var FreetextCmd = FreetextCmdTool{
14 | 	Name:        "freetext_command",
15 | 	Description: "Run any entered string as a terminal command.",
16 | 	Inputs: &pub_models.InputSchema{
17 | 		Type: "object",
18 | 		Properties: map[string]pub_models.ParameterObject{
19 | 			"command": {
20 | 				Type:        "string",
21 | 				Description: "The freetext command. May be any string. Will return error on non-zero exit code.",
22 | 			},
23 | 		},
24 | 		Required: []string{"command"},
25 | 	},
26 | }
27 |
28 | // Call runs input["command"] as a process. The string is split on whitespace
29 | // with strings.Fields; the first field is the executable and the rest are its
30 | // arguments. No shell is involved, so quoting, pipes and redirection are not
31 | // interpreted. Combined stdout/stderr is returned; if the command fails
32 | // (including a non-zero exit) the error carries that output.
33 | func (r FreetextCmdTool) Call(input pub_models.Input) (string, error) {
34 | 	freetextCmd, ok := input["command"].(string)
35 | 	if !ok {
36 | 		return "", fmt.Errorf("command must be a string")
37 | 	}
38 | 	// strings.Fields drops leading, trailing and repeated whitespace, so a
39 | 	// command such as "ls  -la" does not yield an empty argument.
40 | 	fields := strings.Fields(freetextCmd)
41 | 	if len(fields) == 0 {
42 | 		return "", fmt.Errorf("command must not be empty")
43 | 	}
44 | 	cmd := exec.Command(fields[0], fields[1:]...)
45 |
46 | 	output, err := cmd.CombinedOutput()
47 | 	if err != nil {
48 | 		return "", fmt.Errorf("error: '%w', output: %v", err, string(output))
49 | 	}
50 | 	return string(output), nil
51 | }
52 |
53 | func (r FreetextCmdTool) Specification() pub_models.Specification {
54 | 	return pub_models.Specification(FreetextCmd)
55 | }
56 |
--------------------------------------------------------------------------------
/internal/photo/store_test.go:
--------------------------------------------------------------------------------
1 | package photo
2 |
3 | import (
4 | 	"bytes"
5 | 	"encoding/base64"
6 | 	"os"
7 | 	"path/filepath"
8 | 	"strings"
9 | 	"testing"
10 | )
11 |
12 | func TestSaveImage_PrimaryDir(t *testing.T) {
13 | 	tmp := t.TempDir()
14 | 	// simple 1x1 transparent PNG
15 | 	b64 := "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMB/ak5tqkAAAAASUVORK5CYII="
16 | 	out, err := SaveImage(Output{Dir: tmp, Prefix: "x"}, b64, "png")
17 | 	if err != nil {
18 | 		t.Fatalf("SaveImage error: %v", err)
19 | 	}
20 | 	if !strings.HasSuffix(out, ".png") {
21 | 		t.Fatalf("expected .png suffix, got %q", out)
22 | 	}
23 | 	data, err := os.ReadFile(out)
24 | 	if err != nil {
25 | 		t.Fatalf("read: %v", err)
26 | 	}
27 | 	want, err := base64.StdEncoding.DecodeString(b64)
28 | 	if err != nil {
29 | 		t.Fatalf("decode fixture: %v", err)
30 | 	}
31 | 	// The saved file must contain exactly the decoded image bytes.
32 | 	if !bytes.Equal(data, want) {
33 | 		t.Fatalf("written file (%d bytes) differs from decoded image (%d bytes)", len(data), len(want))
34 | 	}
35 | }
36 |
37 | func TestSaveImage_FallbackTmp(t *testing.T) {
38 | 	// Permission bits do not restrict root, so the primary write would succeed
39 | 	// and the fallback path would never be exercised.
40 | 	if os.Geteuid() == 0 {
41 | 		t.Skip("running as root: read-only directory permissions are not enforced")
42 | 	}
43 | 	// Create a directory and remove write perms so first write fails
44 | 	dir := filepath.Join(t.TempDir(), "nope")
45 | 	if err := os.MkdirAll(dir, 0o555); err != nil {
46 | 		t.Fatal(err)
47 | 	}
48 | 	b64 := "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMB/ak5tqkAAAAASUVORK5CYII="
49 | 	out, err := SaveImage(Output{Dir: dir, Prefix: "y"}, b64, "png")
50 | 	if err != nil {
51 | 		t.Fatalf("SaveImage fallback error: %v", err)
52 | 	}
53 | 	// The fallback file is written outside t.TempDir, so remove it explicitly
54 | 	// (best effort: a failed removal must not fail the test).
55 | 	t.Cleanup(func() { _ = os.Remove(out) })
56 | 	if !strings.HasPrefix(out, "/tmp/") {
57 | 		t.Fatalf("expected fallback to /tmp, got %q", out)
58 | 	}
59 | }
60 |
--------------------------------------------------------------------------------
/internal/vendors/openai/gpt_test.go:
--------------------------------------------------------------------------------
1 | package openai
2 |
3 | import "testing"
4 |
5 | func TestSetupConfigMapping(t *testing.T) {
6 | 	v := GptDefault
7 | 	// customize some values
8 | 	fp := 0.21
9 | 	v.FrequencyPenalty = fp
10 | 	mt := 4096
11 | 	v.MaxTokens = &mt
12 | 	v.Temperature = 0.55
13 | 	v.TopP = 0.66
14 | 	v.Model = "gpt-custom"
15 |
16 | 	t.Setenv("OPENAI_API_KEY", "key")
17 | 	if err := v.Setup(); err != nil {
18 | 		t.Fatalf("setup failed: %v", err)
19 | 	}
20 | 	if v.StreamCompleter.Model != v.Model {
21 | 		t.Errorf("expected Model %q, got %q", v.Model, v.StreamCompleter.Model)
22 | 	}
23 | 	if v.StreamCompleter.FrequencyPenalty == nil || *v.StreamCompleter.FrequencyPenalty != v.FrequencyPenalty {
24 | 		t.Errorf("frequency penalty not mapped, got %#v want %v", v.StreamCompleter.FrequencyPenalty, v.FrequencyPenalty)
25 | 	}
26 | 	if v.StreamCompleter.MaxTokens == nil || *v.StreamCompleter.MaxTokens != *v.MaxTokens {
27 | 		t.Errorf("max tokens not mapped, got %#v want %v", v.StreamCompleter.MaxTokens, *v.MaxTokens)
28 | 	}
29 | 	if v.StreamCompleter.Temperature == nil || *v.StreamCompleter.Temperature != v.Temperature {
30 | 		t.Errorf("temperature not mapped, got %#v want %v", v.StreamCompleter.Temperature, v.Temperature)
31 | 	}
32 | 	if v.StreamCompleter.TopP == nil || *v.StreamCompleter.TopP != v.TopP {
33 | 		t.Errorf("top_p not mapped, got %#v want %v", v.StreamCompleter.TopP, v.TopP)
34 | 	}
35 | 	if v.ToolChoice == nil || *v.ToolChoice != "auto" {
36 | 		t.Errorf("tool choice expected 'auto', got %#v", v.ToolChoice)
37 | 	}
38 | }
39 |
--------------------------------------------------------------------------------
/internal/vendors/xai/xai_test.go:
--------------------------------------------------------------------------------
1 | package xai
2 |
3 | import (
4 | 	"os"
5 | 	"testing"
6 | )
7 |
8 | func TestSetupConfigMapping(t *testing.T) {
9 | 	v := Default
10 | 	mt := 8192
11 | 	v.MaxTokens = &mt
12 | 	v.Temperature = 0.12
13 | 	v.TopP = 0.34
14 | 	v.Model = "xai-custom"
15 |
16 | 	t.Setenv("XAI_API_KEY", "k")
17 | 	if err := v.Setup(); err != nil {
18 | 		t.Fatalf("setup failed: %v", err)
19 | 	}
20 | 	if v.StreamCompleter.Model != v.Model {
21 | 		t.Errorf("expected Model %q, got %q", v.Model, v.StreamCompleter.Model)
22 | 	}
23 | 	if v.StreamCompleter.MaxTokens == nil || *v.StreamCompleter.MaxTokens != *v.MaxTokens {
24 | 		t.Errorf("max tokens not mapped, got %#v want %v", v.StreamCompleter.MaxTokens, *v.MaxTokens)
25 | 	}
26 | 	if v.StreamCompleter.Temperature == nil || *v.StreamCompleter.Temperature != v.Temperature {
27 | 		t.Errorf("temperature not mapped, got %#v want %v", v.StreamCompleter.Temperature, v.Temperature)
28 | 	}
29 | 	if v.StreamCompleter.TopP == nil || *v.StreamCompleter.TopP != v.TopP {
30 | 		t.Errorf("top_p not mapped, got %#v want %v", v.StreamCompleter.TopP, v.TopP)
31 | 	}
32 | 	if v.ToolChoice == nil || *v.ToolChoice != "auto" {
33 | 		t.Errorf("tool choice expected 'auto', got %#v", v.ToolChoice)
34 | 	}
35 | }
36 |
37 | func TestSetupSetsDefaultEnvWhenMissingXAI(t *testing.T) {
38 | 	v := Default
39 | 	t.Setenv("XAI_API_KEY", "")
40 | 	if err := v.Setup(); err != nil {
41 | 		t.Fatalf("setup failed: %v", err)
42 | 	}
43 | 	if os.Getenv("XAI_API_KEY") == "" {
44 | 		t.Fatalf("expected XAI_API_KEY to be set by Setup when missing")
45 | 	}
46 | }
47 |
--------------------------------------------------------------------------------
/internal/vendors/ollama/ollama.go:
--------------------------------------------------------------------------------
1 | package ollama
2 |
3 | import (
4 | 	"fmt"
5 | 	"os"
6 | 	"strings"
7 |
8 | 	"github.com/baalimago/clai/internal/text/generic"
9 | 	"github.com/baalimago/clai/internal/tools"
10 | )
11 |
12 | const ChatURL = "http://localhost:11434/v1/chat/completions"
13 |
14 | var Default = Ollama{
15 | 	Model:       "llama3",
16 | 	Temperature: 1.0,
17 | 	TopP:        1.0,
18 | }
19 |
20 | type Ollama struct {
21 | 	generic.StreamCompleter
22 | 	Model            string  `json:"model"`
23 | 	FrequencyPenalty float64 `json:"frequency_penalty"`
24 | 	MaxTokens        *int    `json:"max_tokens"` // Use a pointer to allow null value
25 | 	PresencePenalty  float64 `json:"presence_penalty"`
26 | 	Temperature      float64 `json:"temperature"`
27 | 	TopP             float64 `json:"top_p"`
28 | }
29 |
30 | func (g *Ollama) Setup() error {
31 | 	if os.Getenv("OLLAMA_API_KEY") == "" {
32 | 		if err := os.Setenv("OLLAMA_API_KEY", "ollama"); err != nil {
33 | 			return fmt.Errorf("failed to set placeholder OLLAMA_API_KEY: %w", err)
34 | 		}
35 | 	}
36 | 	err := g.StreamCompleter.Setup("OLLAMA_API_KEY", ChatURL, "OLLAMA_DEBUG")
37 | 	if err != nil {
38 | 		return fmt.Errorf("failed to setup stream completer: %w", err)
39 | 	}
40 | 	modelName := strings.TrimPrefix(g.Model, "ollama:")
41 | 	g.StreamCompleter.Model = modelName
42 | 	g.StreamCompleter.FrequencyPenalty = &g.FrequencyPenalty
43 | 	g.StreamCompleter.MaxTokens = g.MaxTokens
44 | 	g.StreamCompleter.Temperature = &g.Temperature
45 | 	g.StreamCompleter.TopP = &g.TopP
46 | 	toolChoice := "auto"
47 | 	g.ToolChoice = &toolChoice
48 | 	return nil
49 | }
50 |
51 | func (g *Ollama) RegisterTool(tool tools.LLMTool) {
52 | 	g.InternalRegisterTool(tool)
53 | }
54 |
--------------------------------------------------------------------------------
/internal/vendors/deepseek/deepseek.go:
--------------------------------------------------------------------------------
1 | package deepseek
2 |
3 | import (
4 | 	"fmt"
5 | 	"os"
6 |
7 | 	"github.com/baalimago/clai/internal/text/generic"
8 | 	"github.com/baalimago/clai/internal/tools"
9 | )
10 |
11 | var Default = Deepseek{
12 | 	Model:       "deepseek-chat",
13 | 	Temperature: 1.0,
14 | 	TopP:        1.0,
15 | 	URL:         ChatURL,
16 | }
17 |
18 | type Deepseek struct {
19 | 	generic.StreamCompleter
20 | 	Model            string  `json:"model"`
21 | 	FrequencyPenalty float64 `json:"frequency_penalty"`
22 | 	MaxTokens        *int    `json:"max_tokens"` // Use a pointer to allow null value
23 | 	PresencePenalty  float64
`json:"presence_penalty"` 24 | Temperature float64 `json:"temperature"` 25 | TopP float64 `json:"top_p"` 26 | URL string `json:"url"` 27 | } 28 | 29 | const ChatURL = "https://api.deepseek.com/chat/completions" 30 | 31 | func (g *Deepseek) Setup() error { 32 | if os.Getenv("DEEPSEEK_API_KEY") == "" { 33 | os.Setenv("DEEPSEEK_API_KEY", "deepseek") 34 | } 35 | err := g.StreamCompleter.Setup("DEEPSEEK_API_KEY", ChatURL, "DEEPSEEK_DEBUG") 36 | if err != nil { 37 | return fmt.Errorf("failed to setup stream completer: %w", err) 38 | } 39 | g.StreamCompleter.Model = g.Model 40 | g.StreamCompleter.FrequencyPenalty = &g.FrequencyPenalty 41 | g.StreamCompleter.MaxTokens = g.MaxTokens 42 | g.StreamCompleter.Temperature = &g.Temperature 43 | g.StreamCompleter.TopP = &g.TopP 44 | toolChoice := "auto" 45 | g.ToolChoice = &toolChoice 46 | return nil 47 | } 48 | 49 | func (g *Deepseek) RegisterTool(tool tools.LLMTool) { 50 | g.InternalRegisterTool(tool) 51 | } 52 | -------------------------------------------------------------------------------- /internal/tools/bash_tool_tree.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "fmt" 5 | "os/exec" 6 | 7 | pub_models "github.com/baalimago/clai/pkg/text/models" 8 | ) 9 | 10 | type FileTreeTool pub_models.Specification 11 | 12 | var FileTree = FileTreeTool{ 13 | Name: "file_tree", 14 | Description: "List the filetree of some directory. 
Uses linux command 'tree'.", 15 | Inputs: &pub_models.InputSchema{ 16 | Type: "object", 17 | Properties: map[string]pub_models.ParameterObject{ 18 | "directory": { 19 | Type: "string", 20 | Description: "The directory to list the filetree of.", 21 | }, 22 | "level": { 23 | Type: "integer", 24 | Description: "The depth of the tree to display.", 25 | }, 26 | }, 27 | Required: []string{"directory"}, 28 | }, 29 | } 30 | 31 | func (f FileTreeTool) Call(input pub_models.Input) (string, error) { 32 | directory, ok := input["directory"].(string) 33 | if !ok { 34 | return "", fmt.Errorf("directory must be a string") 35 | } 36 | cmd := exec.Command("tree", directory) 37 | if input["level"] != nil { 38 | level, ok := input["level"].(float64) 39 | if !ok { 40 | return "", fmt.Errorf("level must be a number") 41 | } 42 | cmd.Args = append(cmd.Args, "-L") 43 | cmd.Args = append(cmd.Args, fmt.Sprintf("%v", level)) 44 | } 45 | output, err := cmd.CombinedOutput() 46 | if err != nil { 47 | return "", fmt.Errorf("failed to run tree: %w, output: %v", err, string(output)) 48 | } 49 | return string(output), nil 50 | } 51 | 52 | func (f FileTreeTool) Specification() pub_models.Specification { 53 | return pub_models.Specification(FileTree) 54 | } 55 | -------------------------------------------------------------------------------- /internal/tools/bash_tool_file.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "fmt" 5 | "os/exec" 6 | 7 | pub_models "github.com/baalimago/clai/pkg/text/models" 8 | ) 9 | 10 | type FileTypeTool pub_models.Specification 11 | 12 | var FileType = FileTypeTool{ 13 | Name: "file_type", 14 | Description: "Determine the file type of a given file. 
Uses the linux command 'file'.", 15 | Inputs: &pub_models.InputSchema{ 16 | Type: "object", 17 | Properties: map[string]pub_models.ParameterObject{ 18 | "file_path": { 19 | Type: "string", 20 | Description: "The path to the file to analyze.", 21 | }, 22 | "mime_type": { 23 | Type: "boolean", 24 | Description: "Whether to display the MIME type of the file.", 25 | }, 26 | }, 27 | Required: []string{"file_path"}, 28 | }, 29 | } 30 | 31 | func (f FileTypeTool) Call(input pub_models.Input) (string, error) { 32 | filePath, ok := input["file_path"].(string) 33 | if !ok { 34 | return "", fmt.Errorf("file_path must be a string") 35 | } 36 | cmd := exec.Command("file", filePath) 37 | if input["mime_type"] != nil { 38 | mimeType, ok := input["mime_type"].(bool) 39 | if !ok { 40 | return "", fmt.Errorf("mime_type must be a boolean") 41 | } 42 | if mimeType { 43 | cmd.Args = append(cmd.Args, "--mime-type") 44 | } 45 | } 46 | output, err := cmd.CombinedOutput() 47 | if err != nil { 48 | return "", fmt.Errorf("failed to run file command: %w, output: %v", err, string(output)) 49 | } 50 | return string(output), nil 51 | } 52 | 53 | func (f FileTypeTool) Specification() pub_models.Specification { 54 | return pub_models.Specification(FileType) 55 | } 56 | -------------------------------------------------------------------------------- /internal/vendors/gemini/gemini_test.go: -------------------------------------------------------------------------------- 1 | package gemini 2 | 3 | import "testing" 4 | 5 | func TestSetupConfigMapping(t *testing.T) { 6 | v := Default 7 | mt := 2048 8 | v.MaxTokens = &mt 9 | v.Temperature = 0.33 10 | v.TopP = 0.44 11 | v.Model = "gemini-custom" 12 | 13 | t.Setenv("GEMINI_API_KEY", "key") 14 | if err := v.Setup(); err != nil { 15 | t.Fatalf("setup failed: %v", err) 16 | } 17 | if v.StreamCompleter.Model != v.Model { 18 | t.Errorf("expected Model %q, got %q", v.Model, v.StreamCompleter.Model) 19 | } 20 | if v.StreamCompleter.MaxTokens == nil || 
*v.StreamCompleter.MaxTokens != *v.MaxTokens { 21 | t.Errorf("max tokens not mapped, got %#v want %v", v.StreamCompleter.MaxTokens, *v.MaxTokens) 22 | } 23 | if v.StreamCompleter.Temperature == nil || *v.StreamCompleter.Temperature != v.Temperature { 24 | t.Errorf("temperature not mapped, got %#v want %v", v.StreamCompleter.Temperature, v.Temperature) 25 | } 26 | if v.StreamCompleter.TopP == nil || *v.StreamCompleter.TopP != v.TopP { 27 | t.Errorf("top_p not mapped, got %#v want %v", v.StreamCompleter.TopP, v.TopP) 28 | } 29 | if v.ToolChoice == nil || *v.ToolChoice != "auto" { 30 | t.Errorf("tool choice expected 'auto', got %#v", v.ToolChoice) 31 | } 32 | } 33 | 34 | func TestSetupSetsDefaultEnvWhenMissing(t *testing.T) { 35 | v := Default 36 | // ensure env missing 37 | t.Setenv("GEMINI_API_KEY", "") 38 | if err := v.Setup(); err != nil { 39 | t.Fatalf("setup failed: %v", err) 40 | } 41 | if got := v.StreamCompleter; got.Model == "" { 42 | t.Fatalf("expected setup to have initialized model and env fallback; model empty implies setup not run correctly") 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /internal/vendors/novita/novita.go: -------------------------------------------------------------------------------- 1 | package novita 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "strings" 7 | 8 | "github.com/baalimago/clai/internal/text/generic" 9 | "github.com/baalimago/clai/internal/tools" 10 | ) 11 | 12 | var Default = Novita{ 13 | Model: "gryphe/mythomax-l2-13b", 14 | Temperature: 1.0, 15 | TopP: 1.0, 16 | URL: ChatURL, 17 | } 18 | 19 | type Novita struct { 20 | generic.StreamCompleter 21 | Model string `json:"model"` 22 | FrequencyPenalty float64 `json:"frequency_penalty"` 23 | MaxTokens *int `json:"max_tokens"` // Use a pointer to allow null value 24 | PresencePenalty float64 `json:"presence_penalty"` 25 | Temperature float64 `json:"temperature"` 26 | TopP float64 `json:"top_p"` 27 | URL string `json:"url"` 28 | 
} 29 | 30 | const ChatURL = "https://api.novita.ai/openai/v1/chat/completions" 31 | 32 | func (g *Novita) Setup() error { 33 | if os.Getenv("NOVITA_API_KEY") == "" { 34 | os.Setenv("NOVITA_API_KEY", "novita") 35 | } 36 | err := g.StreamCompleter.Setup("NOVITA_API_KEY", ChatURL, "NOVITA_DEBUG") 37 | if err != nil { 38 | return fmt.Errorf("failed to setup stream completer: %w", err) 39 | } 40 | 41 | modelName := strings.TrimPrefix(g.Model, "novita:") 42 | g.StreamCompleter.Model = modelName 43 | g.StreamCompleter.FrequencyPenalty = &g.FrequencyPenalty 44 | g.StreamCompleter.MaxTokens = g.MaxTokens 45 | g.StreamCompleter.Temperature = &g.Temperature 46 | g.StreamCompleter.TopP = &g.TopP 47 | toolChoice := "auto" 48 | g.ToolChoice = &toolChoice 49 | return nil 50 | } 51 | 52 | func (g *Novita) RegisterTool(tool tools.LLMTool) { 53 | g.InternalRegisterTool(tool) 54 | } 55 | -------------------------------------------------------------------------------- /internal/utils/file.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "encoding/json" 5 | "errors" 6 | "fmt" 7 | "io" 8 | "io/fs" 9 | "os" 10 | ) 11 | 12 | func CreateFile[T any](path string, toCreate *T) error { 13 | file, err := os.Create(path) 14 | if err != nil { 15 | return fmt.Errorf("failed to create config file: %w", err) 16 | } 17 | defer file.Close() 18 | b, err := json.MarshalIndent(toCreate, "", " ") 19 | if err != nil { 20 | return fmt.Errorf("failed to marshal config: %w", err) 21 | } 22 | if _, err := file.Write(b); err != nil { 23 | return fmt.Errorf("failed to write config: %w", err) 24 | } 25 | return nil 26 | } 27 | 28 | func WriteFile[T any](path string, toWrite *T) error { 29 | fileBytes, err := json.MarshalIndent(toWrite, "", " ") 30 | if err != nil { 31 | return fmt.Errorf("failed to marshal file: %w", err) 32 | } 33 | err = os.WriteFile(path, fileBytes, 0o644) 34 | if err != nil { 35 | return fmt.Errorf("failed to write 
file: %w", err) 36 | } 37 | return nil 38 | } 39 | 40 | // ReadAndUnmarshal by first finding the file, then attempting to read + unmarshal to T 41 | func ReadAndUnmarshal[T any](filePath string, config *T) error { 42 | if _, err := os.Stat(filePath); errors.Is(err, fs.ErrNotExist) { 43 | return fmt.Errorf("failed to find file: %w", err) 44 | } 45 | file, err := os.Open(filePath) 46 | if err != nil { 47 | return fmt.Errorf("failed to open file: %w", err) 48 | } 49 | defer file.Close() 50 | fileBytes, err := io.ReadAll(file) 51 | if err != nil { 52 | return fmt.Errorf("failed to read file: %w", err) 53 | } 54 | err = json.Unmarshal(fileBytes, config) 55 | if err != nil { 56 | return fmt.Errorf("failed to unmarshal file: %w", err) 57 | } 58 | 59 | return nil 60 | } 61 | -------------------------------------------------------------------------------- /internal/chat/chat.go: -------------------------------------------------------------------------------- 1 | package chat 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "os" 7 | "path" 8 | "strings" 9 | 10 | "github.com/baalimago/clai/internal/utils" 11 | pub_models "github.com/baalimago/clai/pkg/text/models" 12 | "github.com/baalimago/go_away_boilerplate/pkg/ancli" 13 | "github.com/baalimago/go_away_boilerplate/pkg/misc" 14 | ) 15 | 16 | func FromPath(path string) (pub_models.Chat, error) { 17 | if misc.Truthy(os.Getenv("DEBUG")) || misc.Truthy(os.Getenv("DEBUG_REPLY_MODE")) { 18 | ancli.PrintOK(fmt.Sprintf("reading chat from '%v'\n", path)) 19 | } 20 | b, err := os.ReadFile(path) 21 | if err != nil { 22 | return pub_models.Chat{}, fmt.Errorf("failed to read file: %w", err) 23 | } 24 | var chat pub_models.Chat 25 | err = json.Unmarshal(b, &chat) 26 | if err != nil { 27 | return pub_models.Chat{}, fmt.Errorf("failed to decode JSON: %w", err) 28 | } 29 | 30 | return chat, nil 31 | } 32 | 33 | func Save(saveAt string, chat pub_models.Chat) error { 34 | b, err := json.Marshal(chat) 35 | if err != nil { 36 | return 
fmt.Errorf("failed to encode JSON: %w", err) 37 | } 38 | fileName := path.Join(saveAt, fmt.Sprintf("%v.json", chat.ID)) 39 | if misc.Truthy(os.Getenv("DEBUG")) && misc.Truthy(os.Getenv("DEBUG_VERBOSE")) || misc.Truthy(os.Getenv("DEBUG_REPLY_MODE")) { 40 | ancli.PrintOK(fmt.Sprintf("saving chat to: '%v'", fileName)) 41 | } 42 | return os.WriteFile(fileName, b, 0o644) 43 | } 44 | 45 | func IDFromPrompt(prompt string) string { 46 | id := strings.Join(utils.GetFirstTokens(strings.Split(prompt, " "), 5), "_") 47 | // Slashes messes up the save path pretty bad 48 | id = strings.ReplaceAll(id, "/", ".") 49 | // You're welcome, windows users. You're also weird. 50 | id = strings.ReplaceAll(id, "\\", ".") 51 | return id 52 | } 53 | -------------------------------------------------------------------------------- /internal/tools/programming_tool_sed_test.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "os" 5 | "testing" 6 | 7 | pub_models "github.com/baalimago/clai/pkg/text/models" 8 | ) 9 | 10 | func TestSedTool_Call(t *testing.T) { 11 | const fileName = "test_sed.txt" 12 | initial := "apple\nbanana\napple pie\n" 13 | err := os.WriteFile(fileName, []byte(initial), 0o644) 14 | if err != nil { 15 | t.Fatalf("setup failed: %v", err) 16 | } 17 | defer os.Remove(fileName) 18 | 19 | _, err = Sed.Call(pub_models.Input{ 20 | "file_path": fileName, 21 | "pattern": "apple", 22 | "repl": "orange", 23 | }) 24 | if err != nil { 25 | t.Fatalf("sed failed: %v", err) 26 | } 27 | 28 | result, err := os.ReadFile(fileName) 29 | if err != nil { 30 | t.Fatalf("read failed: %v", err) 31 | } 32 | 33 | expected := "orange\nbanana\norange pie\n" 34 | if string(result) != expected { 35 | t.Errorf("unexpected output: got\n%q\nwant\n%q", string(result), expected) 36 | } 37 | } 38 | 39 | func TestSedTool_Range(t *testing.T) { 40 | const fileName = "test_sed_range.txt" 41 | initial := "foo\nfoo\nfoo\nfoo\n" 42 | err := 
os.WriteFile(fileName, []byte(initial), 0o644) 43 | if err != nil { 44 | t.Fatalf("setup failed: %v", err) 45 | } 46 | defer os.Remove(fileName) 47 | 48 | _, err = Sed.Call(pub_models.Input{ 49 | "file_path": fileName, 50 | "pattern": "foo", 51 | "repl": "bar", 52 | "start_line": 2, 53 | "end_line": 3, 54 | }) 55 | if err != nil { 56 | t.Fatalf("sed with range failed: %v", err) 57 | } 58 | 59 | result, err := os.ReadFile(fileName) 60 | if err != nil { 61 | t.Fatalf("read failed: %v", err) 62 | } 63 | 64 | expected := "foo\nbar\nbar\nfoo\n" 65 | if string(result) != expected { 66 | t.Errorf("unexpected output: got\n%q\nwant\n%q", string(result), expected) 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /internal/photo/prompt.go: -------------------------------------------------------------------------------- 1 | package photo 2 | 3 | import ( 4 | "encoding/json" 5 | "flag" 6 | "fmt" 7 | "os" 8 | 9 | "github.com/baalimago/clai/internal/chat" 10 | "github.com/baalimago/clai/internal/utils" 11 | "github.com/baalimago/go_away_boilerplate/pkg/ancli" 12 | "github.com/baalimago/go_away_boilerplate/pkg/misc" 13 | ) 14 | 15 | func (c *Configurations) SetupPrompts() error { 16 | args := flag.Args() 17 | if c.ReplyMode { 18 | confDir, err := utils.GetClaiConfigDir() 19 | if err != nil { 20 | return fmt.Errorf("failed to get config dir: %w", err) 21 | } 22 | iP, err := chat.LoadPrevQuery(confDir) 23 | if err != nil { 24 | return fmt.Errorf("failed to load previous query: %w", err) 25 | } 26 | if len(iP.Messages) > 0 { 27 | replyMessages := "You will be given a serie of messages from different roles, then a prompt descibing what to do with these messages. " 28 | replyMessages += "Between the messages and the prompt, there will be this line: '-------------'." 29 | replyMessages += "The format is json with the structure {\"role\": \"\", \"content\": \"\"}. " 30 | replyMessages += "The roles are 'system' and 'user'. 
" 31 | b, err := json.Marshal(iP.Messages) 32 | if err != nil { 33 | return fmt.Errorf("failed to encode reply JSON: %w", err) 34 | } 35 | replyMessages = fmt.Sprintf("%vMessages:\n%v\n-------------\n", replyMessages, string(b)) 36 | c.Prompt += replyMessages 37 | } 38 | } 39 | prompt, err := utils.Prompt(c.StdinReplace, args) 40 | if err != nil { 41 | return fmt.Errorf("failed to setup prompt from stdin: %w", err) 42 | } 43 | if misc.Truthy(os.Getenv("DEBUG")) { 44 | ancli.PrintOK(fmt.Sprintf("format: '%v', prompt: '%v'\n", c.PromptFormat, prompt)) 45 | } 46 | c.Prompt += fmt.Sprintf(c.PromptFormat, prompt) 47 | return nil 48 | } 49 | -------------------------------------------------------------------------------- /internal/tools/programming_tool_go.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "fmt" 5 | "os/exec" 6 | "strings" 7 | 8 | pub_models "github.com/baalimago/clai/pkg/text/models" 9 | ) 10 | 11 | type GoTool pub_models.Specification 12 | 13 | var Go = GoTool{ 14 | Name: "go", 15 | Description: "Run Go commands like 'go test' and 'go run' to compile, test, and run Go programs. 
Run 'go help' to get details of this tool.", 16 | Inputs: &pub_models.InputSchema{ 17 | Type: "object", 18 | Properties: map[string]pub_models.ParameterObject{ 19 | "command": { 20 | Type: "string", 21 | Description: "The Go command to run (e.g., 'run', 'test', 'build').", 22 | }, 23 | "args": { 24 | Type: "string", 25 | Description: "Additional arguments for the Go command (e.g., file names, flags).", 26 | }, 27 | "dir": { 28 | Type: "string", 29 | Description: "The directory to run the command in (optional, defaults to current directory).", 30 | }, 31 | }, 32 | Required: []string{"command"}, 33 | }, 34 | } 35 | 36 | func (g GoTool) Call(input pub_models.Input) (string, error) { 37 | command, ok := input["command"].(string) 38 | if !ok { 39 | return "", fmt.Errorf("command must be a string") 40 | } 41 | 42 | args := []string{command} 43 | 44 | if inputArgs, ok := input["args"].(string); ok { 45 | args = append(args, strings.Fields(inputArgs)...) 46 | } 47 | 48 | cmd := exec.Command("go", args...) 49 | 50 | if dir, ok := input["dir"].(string); ok { 51 | cmd.Dir = dir 52 | } 53 | 54 | output, err := cmd.CombinedOutput() 55 | if err != nil { 56 | return "", fmt.Errorf("failed to run go command: %w, output: %v", err, string(output)) 57 | } 58 | 59 | return string(output), nil 60 | } 61 | 62 | func (g GoTool) Specification() pub_models.Specification { 63 | return pub_models.Specification(Go) 64 | } 65 | -------------------------------------------------------------------------------- /internal/tools/bash_tool_ls.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "fmt" 5 | "os/exec" 6 | 7 | pub_models "github.com/baalimago/clai/pkg/text/models" 8 | ) 9 | 10 | type LsTool pub_models.Specification 11 | 12 | var LS = LsTool{ 13 | Name: "ls", 14 | Description: "List the files in a directory. 
Uses the Linux command 'ls'.", 15 | Inputs: &pub_models.InputSchema{ 16 | Type: "object", 17 | Properties: map[string]pub_models.ParameterObject{ 18 | "directory": { 19 | Type: "string", 20 | Description: "The directory to list the files of.", 21 | }, 22 | "all": { 23 | Type: "boolean", 24 | Description: "Show all files, including hidden files.", 25 | }, 26 | "long": { 27 | Type: "boolean", 28 | Description: "Use a long listing format.", 29 | }, 30 | }, 31 | Required: []string{"directory"}, 32 | }, 33 | } 34 | 35 | func (f LsTool) Call(input pub_models.Input) (string, error) { 36 | directory, ok := input["directory"].(string) 37 | if !ok { 38 | return "", fmt.Errorf("directory must be a string") 39 | } 40 | cmd := exec.Command("ls", directory) 41 | if input["all"] != nil { 42 | all, ok := input["all"].(bool) 43 | if !ok { 44 | return "", fmt.Errorf("all must be a boolean") 45 | } 46 | if all { 47 | cmd.Args = append(cmd.Args, "-a") 48 | } 49 | } 50 | if input["long"] != nil { 51 | long, ok := input["long"].(bool) 52 | if !ok { 53 | return "", fmt.Errorf("long must be a boolean") 54 | } 55 | if long { 56 | cmd.Args = append(cmd.Args, "-l") 57 | } 58 | } 59 | output, err := cmd.CombinedOutput() 60 | if err != nil { 61 | return "", fmt.Errorf("failed to run ls: %w, output: %v", err, string(output)) 62 | } 63 | return string(output), nil 64 | } 65 | 66 | func (f LsTool) Specification() pub_models.Specification { 67 | return pub_models.Specification(LS) 68 | } 69 | -------------------------------------------------------------------------------- /internal/utils/file_test.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "os" 5 | "testing" 6 | ) 7 | 8 | type TestData struct { 9 | Name string `json:"name"` 10 | Age int `json:"age"` 11 | } 12 | 13 | func TestCreateFile(t *testing.T) { 14 | filePath := "test_create.json" 15 | defer os.Remove(filePath) 16 | 17 | data := &TestData{Name: "John", Age: 30} 
18 | err := CreateFile(filePath, data) 19 | if err != nil { 20 | t.Errorf("CreateFile failed: %v", err) 21 | } 22 | 23 | if _, err := os.Stat(filePath); os.IsNotExist(err) { 24 | t.Errorf("File not created: %v", err) 25 | } 26 | } 27 | 28 | func TestWriteFile(t *testing.T) { 29 | filePath := "test_write.json" 30 | defer os.Remove(filePath) 31 | 32 | data := &TestData{Name: "Alice", Age: 25} 33 | err := WriteFile(filePath, data) 34 | if err != nil { 35 | t.Errorf("WriteFile failed: %v", err) 36 | } 37 | 38 | if _, err := os.Stat(filePath); os.IsNotExist(err) { 39 | t.Errorf("File not written: %v", err) 40 | } 41 | } 42 | 43 | func TestReadAndUnmarshal(t *testing.T) { 44 | filePath := "test_read.json" 45 | defer os.Remove(filePath) 46 | 47 | expected := &TestData{Name: "Bob", Age: 35} 48 | err := CreateFile(filePath, expected) 49 | if err != nil { 50 | t.Fatalf("Failed to create test file: %v", err) 51 | } 52 | 53 | var actual TestData 54 | err = ReadAndUnmarshal(filePath, &actual) 55 | if err != nil { 56 | t.Errorf("ReadAndUnmarshal failed: %v", err) 57 | } 58 | 59 | if actual.Name != expected.Name || actual.Age != expected.Age { 60 | t.Errorf("ReadAndUnmarshal returned unexpected data: got %+v, want %+v", actual, expected) 61 | } 62 | } 63 | 64 | func TestReadAndUnmarshal_FileNotFound(t *testing.T) { 65 | filePath := "nonexistent.json" 66 | var data TestData 67 | err := ReadAndUnmarshal(filePath, &data) 68 | if err == nil { 69 | t.Error("ReadAndUnmarshal should have failed for non-existent file") 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /internal/setup/setup_actions_test.go: -------------------------------------------------------------------------------- 1 | package setup 2 | 3 | import ( 4 | "os" 5 | "path/filepath" 6 | "reflect" 7 | "testing" 8 | ) 9 | 10 | func TestCastPrimitive(t *testing.T) { 11 | tests := []struct { 12 | name string 13 | input any 14 | want any 15 | }{ 16 | {"String to int", "42", 
42}, 17 | {"String to float", "3.14", 3.14}, 18 | {"String remains string", "hello", "hello"}, 19 | {"Boolean true", "true", true}, 20 | {"Boolean false", "false", false}, 21 | } 22 | 23 | for _, tt := range tests { 24 | t.Run(tt.name, func(t *testing.T) { 25 | got := castPrimitive(tt.input) 26 | if !reflect.DeepEqual(got, tt.want) { 27 | t.Errorf("castPrimitive() = %v, want %v", got, tt.want) 28 | } 29 | }) 30 | } 31 | } 32 | 33 | func TestReconfigureWithEditor(t *testing.T) { 34 | tests := []struct { 35 | name string 36 | editor string 37 | content string 38 | wantErr bool 39 | }{ 40 | { 41 | name: "No editor set", 42 | editor: "", 43 | content: "", 44 | wantErr: true, 45 | }, 46 | { 47 | name: "Valid editor", 48 | editor: "echo", 49 | content: "{\"test\": \"value\"}", 50 | wantErr: false, 51 | }, 52 | } 53 | 54 | for _, tt := range tests { 55 | t.Run(tt.name, func(t *testing.T) { 56 | // Setup temporary file 57 | tmpDir := t.TempDir() 58 | tmpFile := filepath.Join(tmpDir, "config.json") 59 | if err := os.WriteFile(tmpFile, []byte(tt.content), 0o644); err != nil { 60 | t.Fatal(err) 61 | } 62 | 63 | // Set environment 64 | oldEditor := os.Getenv("EDITOR") 65 | defer os.Setenv("EDITOR", oldEditor) 66 | os.Setenv("EDITOR", tt.editor) 67 | 68 | cfg := config{ 69 | name: "test", 70 | filePath: tmpFile, 71 | } 72 | 73 | err := reconfigureWithEditor(cfg) 74 | if (err != nil) != tt.wantErr { 75 | t.Errorf("reconfigureWithEditor() error = %v, wantErr %v", err, tt.wantErr) 76 | } 77 | }) 78 | } 79 | } 80 | -------------------------------------------------------------------------------- /internal/vendors/novita/novita_test.go: -------------------------------------------------------------------------------- 1 | package novita 2 | 3 | import ( 4 | "os" 5 | "testing" 6 | ) 7 | 8 | func TestSetupConfigMappingAndModelPrefixTrim(t *testing.T) { 9 | v := Default 10 | fp := 0.5 11 | v.FrequencyPenalty = fp 12 | mt := 321 13 | v.MaxTokens = &mt 14 | v.Temperature = 0.7 15 | v.TopP = 
0.8 16 | v.Model = "novita:gryphe/some-model" 17 | 18 | t.Setenv("NOVITA_API_KEY", "k") 19 | if err := v.Setup(); err != nil { 20 | t.Fatalf("setup failed: %v", err) 21 | } 22 | if v.StreamCompleter.Model != "gryphe/some-model" { 23 | t.Errorf("expected model to be trimmed of novita: prefix, got %q", v.StreamCompleter.Model) 24 | } 25 | if v.StreamCompleter.FrequencyPenalty == nil || *v.StreamCompleter.FrequencyPenalty != v.FrequencyPenalty { 26 | t.Errorf("frequency penalty not mapped, got %#v want %v", v.StreamCompleter.FrequencyPenalty, v.FrequencyPenalty) 27 | } 28 | if v.StreamCompleter.MaxTokens == nil || *v.StreamCompleter.MaxTokens != *v.MaxTokens { 29 | t.Errorf("max tokens not mapped, got %#v want %v", v.StreamCompleter.MaxTokens, *v.MaxTokens) 30 | } 31 | if v.StreamCompleter.Temperature == nil || *v.StreamCompleter.Temperature != v.Temperature { 32 | t.Errorf("temperature not mapped, got %#v want %v", v.StreamCompleter.Temperature, v.Temperature) 33 | } 34 | if v.StreamCompleter.TopP == nil || *v.StreamCompleter.TopP != v.TopP { 35 | t.Errorf("top_p not mapped, got %#v want %v", v.StreamCompleter.TopP, v.TopP) 36 | } 37 | if v.ToolChoice == nil || *v.ToolChoice != "auto" { 38 | t.Errorf("tool choice expected 'auto', got %#v", v.ToolChoice) 39 | } 40 | } 41 | 42 | func TestSetupSetsDefaultEnvWhenMissingNOVITA(t *testing.T) { 43 | v := Default 44 | t.Setenv("NOVITA_API_KEY", "") 45 | if err := v.Setup(); err != nil { 46 | t.Fatalf("setup failed: %v", err) 47 | } 48 | if got := os.Getenv("NOVITA_API_KEY"); got == "" { 49 | t.Fatalf("expected NOVITA_API_KEY to be set by Setup") 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /internal/setup_test.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "context" 5 | "flag" 6 | "testing" 7 | 8 | "github.com/baalimago/clai/internal/text" 9 | 
"github.com/baalimago/clai/internal/vendors/ollama" 10 | "github.com/baalimago/go_away_boilerplate/pkg/testboil" 11 | ) 12 | 13 | func TestGetModeFromArgs(t *testing.T) { 14 | tests := []struct { 15 | arg string 16 | want Mode 17 | }{ 18 | {"p", PHOTO}, 19 | {"chat", CHAT}, 20 | {"q", QUERY}, 21 | {"glob", GLOB}, 22 | {"re", REPLAY}, 23 | {"cmd", CMD}, 24 | {"setup", SETUP}, 25 | {"version", VERSION}, 26 | {"tools", TOOLS}, 27 | } 28 | for _, tc := range tests { 29 | got, err := getModeFromArgs(tc.arg) 30 | if err != nil { 31 | t.Errorf("unexpected error for %s: %v", tc.arg, err) 32 | } 33 | if got != tc.want { 34 | t.Errorf("mode for %s = %v, want %v", tc.arg, got, tc.want) 35 | } 36 | } 37 | if _, err := getModeFromArgs("unknown"); err == nil { 38 | t.Error("expected error for unknown command") 39 | } 40 | } 41 | 42 | func Test_setupTextQuerier(t *testing.T) { 43 | testDir := t.TempDir() 44 | // Issue reported here: https://github.com/baalimago/clai/pull/16#issuecomment-3506586071 45 | t.Run("deepseek url on ollama:deepseek-r1:8b chat model", func(t *testing.T) { 46 | t.Setenv("DEBUG", "1") 47 | oldFS := flag.CommandLine 48 | defer func() { flag.CommandLine = oldFS }() 49 | fs := flag.NewFlagSet("clai", flag.ContinueOnError) 50 | _ = fs.Parse([]string{"q", "noop"}) 51 | flag.CommandLine = fs 52 | 53 | got, err := setupTextQuerier(context.Background(), 54 | QUERY, 55 | testDir, 56 | Configurations{ 57 | ChatModel: "ollama:deepseek-r1:8b", 58 | }) 59 | if err != nil { 60 | t.Fatal(err) 61 | } 62 | 63 | ollamaModel, ok := got.(*text.Querier[*ollama.Ollama]) 64 | if !ok { 65 | t.Fatalf("expected type *text.Querier[*ollama.Ollama]), got: '%T'", got) 66 | } 67 | 68 | testboil.FailTestIfDiff(t, ollamaModel.Model.URL, ollama.ChatURL) 69 | }) 70 | } 71 | -------------------------------------------------------------------------------- /internal/tools/programming_tool_rows_between_test.go: -------------------------------------------------------------------------------- 1 | 
package tools

import (
	"fmt"
	"os"
	"path/filepath"
	"testing"

	pub_models "github.com/baalimago/clai/pkg/text/models"
)

// writeRowsBetweenFixture writes a five-line fixture ("one".."five") into a
// per-test temporary directory and returns its path. Using t.TempDir keeps
// the tests out of the package working directory and gives every test its
// own copy of the file.
func writeRowsBetweenFixture(t *testing.T) string {
	t.Helper()
	fileName := filepath.Join(t.TempDir(), "test_rows_between.txt")
	initial := "one\ntwo\nthree\nfour\nfive\n"
	if err := os.WriteFile(fileName, []byte(initial), 0o644); err != nil {
		t.Fatalf("setup failed: %v", err)
	}
	return fileName
}

// TestRowsBetweenTool_Call checks that RowsBetween returns the requested
// inclusive line range, each line prefixed with its 1-based number.
func TestRowsBetweenTool_Call(t *testing.T) {
	fileName := writeRowsBetweenFixture(t)

	cases := []struct {
		start, end int
		expected   string
	}{
		{1, 3, "1: one\n2: two\n3: three"},
		{2, 4, "2: two\n3: three\n4: four"},
		{4, 5, "4: four\n5: five"},
		{3, 3, "3: three"},
	}

	for _, tc := range cases {
		t.Run(fmt.Sprintf("lines %d-%d", tc.start, tc.end), func(t *testing.T) {
			got, err := RowsBetween.Call(pub_models.Input{
				"file_path":  fileName,
				"start_line": tc.start,
				"end_line":   tc.end,
			})
			if err != nil {
				t.Fatalf("unexpected error: %v", err)
			}
			if got != tc.expected {
				t.Errorf("unexpected output: got %q want %q (start=%d, end=%d)", got, tc.expected, tc.start, tc.end)
			}
		})
	}
}

// TestRowsBetweenTool_BadInputs checks that RowsBetween rejects a missing
// file, an empty file_path, a negative start_line and an inverted range.
func TestRowsBetweenTool_BadInputs(t *testing.T) {
	// The fixture must exist so the start_line/end_line cases fail on
	// argument validation rather than on a missing file.
	fileName := writeRowsBetweenFixture(t)
	missing := filepath.Join(t.TempDir(), "nonexistent.txt")

	cases := []struct {
		name  string
		input pub_models.Input
	}{
		{"missing file", pub_models.Input{"file_path": missing, "start_line": 1, "end_line": 3}},
		{"missing file_path", pub_models.Input{"file_path": "", "start_line": 1, "end_line": 3}},
		{"bad start_line", pub_models.Input{"file_path": fileName, "start_line": -2, "end_line": 3}},
		{"inverted lines", pub_models.Input{"file_path": fileName, "start_line": 4, "end_line": 2}},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if _, err := RowsBetween.Call(tc.input); err == nil {
				t.Errorf("expected error for %s", tc.name)
			}
		})
	}
}
--------------------------------------------------------------------------------
/internal/models/models.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "time" 7 | 8 | "github.com/baalimago/clai/internal/tools" 9 | 10 | pub_models "github.com/baalimago/clai/pkg/text/models" 11 | ) 12 | 13 | type Querier interface { 14 | Query(ctx context.Context) error 15 | } 16 | 17 | type ChatQuerier interface { 18 | Querier 19 | TextQuery(context.Context, pub_models.Chat) (pub_models.Chat, error) 20 | } 21 | 22 | type StreamCompleter interface { 23 | // Setup the stream completer, do things like init http.Client/websocket etc 24 | // Will be called synchronously. Should return error if setup fails 25 | Setup() error 26 | 27 | // StreamCompletions and return a channel which sends CompletionsEvents. 28 | // The CompletionEvents should be a string, an error, NoopEvent or a models.Call. If there is 29 | // a catastrophic error, return the error and close the channel. 30 | StreamCompletions(context.Context, pub_models.Chat) (chan CompletionEvent, error) 31 | } 32 | 33 | // InputTokenCounter can return the amount of input tokens for a chat. 
34 | type InputTokenCounter interface { 35 | CountInputTokens(context.Context, pub_models.Chat) (int, error) 36 | } 37 | 38 | // ToolBox can register tools which later on will be added to the chat completion queries 39 | type ToolBox interface { 40 | // RegisterTool registers a tool to the ToolBox 41 | RegisterTool(tools.LLMTool) 42 | } 43 | 44 | type CompletionEvent any 45 | 46 | type NoopEvent struct{} 47 | 48 | type StopEvent struct{} 49 | 50 | type ErrRateLimit struct { 51 | ResetAt time.Time 52 | TokensRemaining int 53 | MaxInputTokens int 54 | } 55 | 56 | func (erl *ErrRateLimit) Error() string { 57 | return fmt.Sprintf("reset at: '%v', input tokens used at time of rate limit: '%v'", erl.ResetAt, erl.TokensRemaining) 58 | } 59 | 60 | func NewRateLimitError(resetAt time.Time, maxInputTokens int, tokensRemaining int) error { 61 | return &ErrRateLimit{ 62 | ResetAt: resetAt, 63 | MaxInputTokens: maxInputTokens, 64 | TokensRemaining: tokensRemaining, 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /internal/vendors/deepseek/deepseek_test.go: -------------------------------------------------------------------------------- 1 | package deepseek 2 | 3 | import ( 4 | "os" 5 | "testing" 6 | ) 7 | 8 | func TestSetupConfigMapping(t *testing.T) { 9 | v := Default 10 | // customize some values to ensure mapping is from struct to embedded StreamCompleter 11 | tmp := 0.42 12 | v.FrequencyPenalty = tmp 13 | mt := 1234 14 | v.MaxTokens = &mt 15 | v.Temperature = 0.77 16 | v.TopP = 0.88 17 | v.Model = "deepseek-test-model" 18 | 19 | t.Setenv("DEEPSEEK_API_KEY", "any-key") 20 | if err := v.Setup(); err != nil { 21 | t.Fatalf("setup failed: %v", err) 22 | } 23 | // Assert fields mapped into embedded generic.StreamCompleter 24 | if v.StreamCompleter.Model != v.Model { 25 | t.Errorf("expected Model %q, got %q", v.Model, v.StreamCompleter.Model) 26 | } 27 | if v.StreamCompleter.FrequencyPenalty == nil || 
*v.StreamCompleter.FrequencyPenalty != v.FrequencyPenalty { 28 | t.Errorf("frequency penalty not mapped, got %#v want %v", v.StreamCompleter.FrequencyPenalty, v.FrequencyPenalty) 29 | } 30 | if v.StreamCompleter.MaxTokens == nil || *v.StreamCompleter.MaxTokens != *v.MaxTokens { 31 | t.Errorf("max tokens not mapped, got %#v want %v", v.StreamCompleter.MaxTokens, *v.MaxTokens) 32 | } 33 | if v.StreamCompleter.Temperature == nil || *v.StreamCompleter.Temperature != v.Temperature { 34 | t.Errorf("temperature not mapped, got %#v want %v", v.StreamCompleter.Temperature, v.Temperature) 35 | } 36 | if v.StreamCompleter.TopP == nil || *v.StreamCompleter.TopP != v.TopP { 37 | t.Errorf("top_p not mapped, got %#v want %v", v.StreamCompleter.TopP, v.TopP) 38 | } 39 | if v.ToolChoice == nil || *v.ToolChoice != "auto" { 40 | t.Errorf("tool choice expected 'auto', got %#v", v.ToolChoice) 41 | } 42 | } 43 | 44 | func TestSetupSetsDefaultEnvWhenMissing(t *testing.T) { 45 | v := Default 46 | // ensure env missing 47 | t.Setenv("DEEPSEEK_API_KEY", "") 48 | if err := v.Setup(); err != nil { 49 | t.Fatalf("setup failed: %v", err) 50 | } 51 | if got := os.Getenv("DEEPSEEK_API_KEY"); got == "" { 52 | t.Fatalf("expected DEEPSEEK_API_KEY to be set, got empty") 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /internal/text/querier_setup_tools_test.go: -------------------------------------------------------------------------------- 1 | package text 2 | 3 | import ( 4 | "slices" 5 | "testing" 6 | 7 | "github.com/baalimago/go_away_boilerplate/pkg/ancli" 8 | ) 9 | 10 | func Test_filterMcpServersByProfile(t *testing.T) { 11 | tests := []struct { 12 | name string 13 | files []string 14 | userConf Configurations 15 | want []string 16 | }{ 17 | { 18 | name: "No specific tools configured, return all files", 19 | files: []string{"server1.json", "server2.json"}, 20 | userConf: Configurations{ 21 | Tools: []string{}, 22 | }, 23 | want: 
[]string{"server1.json", "server2.json"}, 24 | }, 25 | { 26 | name: "Specific tool matches one server", 27 | files: []string{"server1.json", "server2.json"}, 28 | userConf: Configurations{ 29 | Tools: []string{"mcp_server1"}, 30 | }, 31 | want: []string{"server1.json"}, 32 | }, 33 | { 34 | name: "Wildcard matches all mcp", 35 | files: []string{"server1.json", "server2.json"}, 36 | userConf: Configurations{ 37 | Tools: []string{"mcp_*"}, 38 | }, 39 | want: []string{"server1.json", "server2.json"}, 40 | }, 41 | { 42 | name: "Wildcard match on some servers", 43 | files: []string{"server1.json", "server2.json"}, 44 | userConf: Configurations{ 45 | Tools: []string{"mcp_server1*"}, 46 | }, 47 | want: []string{"server1.json"}, 48 | }, 49 | { 50 | name: "Match on server tool", 51 | files: []string{"server1.json", "server2.json"}, 52 | userConf: Configurations{ 53 | Tools: []string{"mcp_server1_tool0"}, 54 | }, 55 | want: []string{"server1.json"}, 56 | }, 57 | { 58 | name: "No match for any servers", 59 | files: []string{"server1.json", "server2.json"}, 60 | userConf: Configurations{ 61 | Tools: []string{"mcp_server3"}, 62 | }, 63 | want: []string{}, 64 | }, 65 | } 66 | 67 | for _, tt := range tests { 68 | t.Run(tt.name, func(t *testing.T) { 69 | ancli.Noticef("== test: %v\n", tt.name) 70 | got := filterMcpServersByProfile(tt.files, tt.userConf) 71 | if !slices.Equal(got, tt.want) { 72 | t.Errorf("want %v, got: %v", tt.want, got) 73 | } 74 | }) 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /internal/text/conf_profile.go: -------------------------------------------------------------------------------- 1 | package text 2 | 3 | import ( 4 | "fmt" 5 | "path" 6 | "strings" 7 | 8 | "github.com/baalimago/clai/internal/utils" 9 | ) 10 | 11 | func findProfile(profileName string) (Profile, error) { 12 | cfg, _ := utils.GetClaiConfigDir() 13 | profilePath := path.Join(cfg, "profiles") 14 | var p Profile 15 | err := 
utils.ReadAndUnmarshal(path.Join(profilePath, fmt.Sprintf("%v.json", profileName)), &p) 16 | if err != nil { 17 | // Backwards compatibility: if we fail to load, at least surface the requested name. 18 | p.Name = profileName 19 | return p, err 20 | } 21 | // If Name is empty in the stored profile, normalize it to the filename/profileName. 22 | if strings.TrimSpace(p.Name) == "" { 23 | p.Name = profileName 24 | } 25 | return p, nil 26 | } 27 | 28 | func findProfileByPath(p string) (Profile, error) { 29 | var prof Profile 30 | err := utils.ReadAndUnmarshal(p, &prof) 31 | if err != nil { 32 | return prof, err 33 | } 34 | return prof, nil 35 | } 36 | 37 | func (c *Configurations) ProfileOverrides() error { 38 | if c.UseProfile == "" && c.ProfilePath == "" { 39 | return nil 40 | } 41 | if c.UseProfile != "" && c.ProfilePath != "" { 42 | return fmt.Errorf("profile and profile-path are mutually exclusive") 43 | } 44 | var profile Profile 45 | var err error 46 | if c.ProfilePath != "" { 47 | profile, err = findProfileByPath(c.ProfilePath) 48 | } else { 49 | profile, err = findProfile(c.UseProfile) 50 | } 51 | if err != nil { 52 | return fmt.Errorf("failed to find profile: %w", err) 53 | } 54 | c.Model = profile.Model 55 | newPrompt := profile.Prompt 56 | if c.CmdMode { 57 | // SystmePrompt here is CmdPrompt, keep it and remind llm to only suggest cmd 58 | newPrompt = fmt.Sprintf("You will get this pattern: || | ||. It is VERY vital that you DO NOT disobey the with whatever is posted in . 
|| %v| %v ||", c.CmdModePrompt, profile.Prompt) 59 | } 60 | c.SystemPrompt = newPrompt 61 | c.UseTools = profile.UseTools && !c.CmdMode 62 | c.Tools = profile.Tools 63 | c.SaveReplyAsConv = profile.SaveReplyAsConv 64 | return nil 65 | } 66 | -------------------------------------------------------------------------------- /internal/text/querier_cmd_mode_test.go: -------------------------------------------------------------------------------- 1 | package text 2 | 3 | import ( 4 | "context" 5 | "os" 6 | "path/filepath" 7 | "testing" 8 | 9 | "github.com/baalimago/clai/internal/models" 10 | pub_models "github.com/baalimago/clai/pkg/text/models" 11 | 12 | "github.com/baalimago/go_away_boilerplate/pkg/testboil" 13 | ) 14 | 15 | type mockCompleter struct{} 16 | 17 | func (m mockCompleter) Setup() error { 18 | return nil 19 | } 20 | 21 | func (m mockCompleter) StreamCompletions(ctx context.Context, c pub_models.Chat) (chan models.CompletionEvent, error) { 22 | return nil, nil 23 | } 24 | 25 | func Test_executeAiCmd(t *testing.T) { 26 | testCases := []struct { 27 | description string 28 | setup func(t *testing.T) 29 | given string 30 | want string 31 | wantErr error 32 | }{ 33 | { 34 | description: "it should run shell cmd", 35 | given: "printf 'test'", 36 | want: "'test'", 37 | wantErr: nil, 38 | }, 39 | { 40 | description: "it should work with quotes", 41 | setup: func(t *testing.T) { 42 | t.Helper() 43 | os.Chdir(filepath.Dir(testboil.CreateTestFile(t, "testfile").Name())) 44 | }, 45 | given: "find ./ -name \"testfile\"", 46 | want: "./testfile\n", 47 | wantErr: nil, 48 | }, 49 | { 50 | description: "it should work without quotes", 51 | setup: func(t *testing.T) { 52 | t.Helper() 53 | os.Chdir(filepath.Dir(testboil.CreateTestFile(t, "testfile").Name())) 54 | }, 55 | given: "find ./ -name testfile", 56 | want: "./testfile\n", 57 | wantErr: nil, 58 | }, 59 | } 60 | 61 | for _, tc := range testCases { 62 | t.Run(tc.description, func(t *testing.T) { 63 | var gotErr error 64 | 
got := testboil.CaptureStdout(t, func(t *testing.T) { 65 | q := Querier[mockCompleter]{} 66 | if tc.setup != nil { 67 | tc.setup(t) 68 | } 69 | q.fullMsg = tc.given 70 | tmp := q.executeLlmCmd() 71 | gotErr = tmp 72 | }) 73 | if got != tc.want { 74 | t.Fatalf("expected: %v, got: %v", tc.want, got) 75 | } 76 | 77 | if gotErr != tc.wantErr { 78 | t.Fatalf("expected error: %v, got: %v", tc.wantErr, gotErr) 79 | } 80 | }) 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /internal/tools/programming_tool_git_test.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "os" 5 | "os/exec" 6 | "path/filepath" 7 | "strings" 8 | "testing" 9 | 10 | pub_models "github.com/baalimago/clai/pkg/text/models" 11 | ) 12 | 13 | func setupRepo(t *testing.T) string { 14 | dir := t.TempDir() 15 | cmd := exec.Command("git", "init") 16 | cmd.Dir = dir 17 | if out, err := cmd.CombinedOutput(); err != nil { 18 | t.Fatalf("git init failed: %v, %s", err, out) 19 | } 20 | // Configure user 21 | exec.Command("git", "-C", dir, "config", "user.email", "test@example.com").Run() 22 | exec.Command("git", "-C", dir, "config", "user.name", "tester").Run() 23 | 24 | os.WriteFile(filepath.Join(dir, "a.txt"), []byte("hello"), 0o644) 25 | exec.Command("git", "-C", dir, "add", "a.txt").Run() 26 | exec.Command("git", "-C", dir, "commit", "-m", "first").Run() 27 | 28 | os.WriteFile(filepath.Join(dir, "a.txt"), []byte("hello world"), 0o644) 29 | exec.Command("git", "-C", dir, "commit", "-am", "second").Run() 30 | return dir 31 | } 32 | 33 | func TestGitTool_Log(t *testing.T) { 34 | repo := setupRepo(t) 35 | out, err := Git.Call(pub_models.Input{"operation": "log", "n": 1, "range": "HEAD", "dir": repo}) 36 | if err != nil { 37 | t.Fatalf("git log failed: %v", err) 38 | } 39 | if !strings.Contains(out, "second") { 40 | t.Errorf("expected log to contain second commit, got %q", out) 41 | } 42 | } 
43 | 44 | func TestGitTool_Diff(t *testing.T) { 45 | repo := setupRepo(t) 46 | out, err := Git.Call(pub_models.Input{"operation": "diff", "range": "HEAD~1..HEAD", "file": "a.txt", "dir": repo}) 47 | if err != nil { 48 | t.Fatalf("git diff failed: %v", err) 49 | } 50 | if !strings.Contains(out, "hello world") { 51 | t.Errorf("unexpected diff output: %q", out) 52 | } 53 | } 54 | 55 | func TestGitTool_Status(t *testing.T) { 56 | repo := setupRepo(t) 57 | os.WriteFile(filepath.Join(repo, "b.txt"), []byte("x"), 0o644) 58 | out, err := Git.Call(pub_models.Input{"operation": "status", "dir": repo}) 59 | if err != nil { 60 | t.Fatalf("git status failed: %v", err) 61 | } 62 | if !strings.Contains(out, "?? b.txt") { 63 | t.Errorf("expected status to show untracked file, got %q", out) 64 | } 65 | } 66 | -------------------------------------------------------------------------------- /internal/tools/mcp/testserver/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "encoding/json" 5 | "os" 6 | ) 7 | 8 | type Request struct { 9 | JSONRPC string `json:"jsonrpc"` 10 | ID int `json:"id,omitempty"` 11 | Method string `json:"method"` 12 | Params json.RawMessage `json:"params,omitempty"` 13 | } 14 | 15 | func main() { 16 | dec := json.NewDecoder(os.Stdin) 17 | enc := json.NewEncoder(os.Stdout) 18 | for { 19 | var req Request 20 | if err := dec.Decode(&req); err != nil { 21 | return 22 | } 23 | switch req.Method { 24 | case "initialize": 25 | enc.Encode(map[string]any{ 26 | "jsonrpc": "2.0", 27 | "id": req.ID, 28 | "result": map[string]any{}, 29 | }) 30 | case "tools/list": 31 | enc.Encode(map[string]any{ 32 | "jsonrpc": "2.0", 33 | "id": req.ID, 34 | "result": map[string]any{ 35 | "tools": []map[string]any{ 36 | { 37 | "name": "echo", 38 | "description": "echo text", 39 | "inputSchema": map[string]any{ 40 | "type": "object", 41 | "required": []string{"text"}, 42 | "properties": map[string]any{ 43 | "text": 
map[string]any{ 44 | "type": "string", 45 | "description": "text to echo", 46 | }, 47 | }, 48 | }, 49 | }, 50 | }, 51 | }, 52 | }) 53 | case "tools/call": 54 | var p struct { 55 | Name string `json:"name"` 56 | Arguments map[string]any `json:"arguments"` 57 | } 58 | json.Unmarshal(req.Params, &p) 59 | text, _ := p.Arguments["text"].(string) 60 | result := map[string]any{ 61 | "content": []map[string]any{{"type": "text", "text": text}}, 62 | "isError": false, 63 | } 64 | if text == "error" { 65 | result["isError"] = true 66 | } 67 | enc.Encode(map[string]any{ 68 | "jsonrpc": "2.0", 69 | "id": req.ID, 70 | "result": result, 71 | }) 72 | default: 73 | enc.Encode(map[string]any{ 74 | "jsonrpc": "2.0", 75 | "id": req.ID, 76 | "error": map[string]any{ 77 | "code": -32601, 78 | "message": "method not found", 79 | }, 80 | }) 81 | } 82 | } 83 | } 84 | -------------------------------------------------------------------------------- /internal/text/querier_cmd_mode.go: -------------------------------------------------------------------------------- 1 | package text 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "os" 7 | "os/exec" 8 | "strings" 9 | 10 | "github.com/baalimago/clai/internal/utils" 11 | "github.com/baalimago/go_away_boilerplate/pkg/ancli" 12 | ) 13 | 14 | var errFormat = "code: %v, stderr: '%v'\n" 15 | 16 | func (q *Querier[C]) handleCmdMode() error { 17 | // Tokens stream end without endline 18 | fmt.Println() 19 | 20 | if q.execErr != nil { 21 | return nil 22 | } 23 | 24 | for { 25 | fmt.Print("Do you want to [e]xecute cmd, [q]uit?: ") 26 | input, err := utils.ReadUserInput() 27 | if err != nil { 28 | return err 29 | } 30 | switch strings.ToLower(input) { 31 | case "q": 32 | return nil 33 | case "e": 34 | err := q.executeLlmCmd() 35 | if err == nil { 36 | return nil 37 | } else { 38 | return fmt.Errorf("failed to execute cmd: %v", err) 39 | } 40 | default: 41 | ancli.PrintWarn(fmt.Sprintf("unrecognized command: %v, please try again\n", input)) 42 | } 43 | } 44 | } 
45 | 46 | func (q *Querier[C]) executeLlmCmd() error { 47 | fullMsg, err := utils.ReplaceTildeWithHome(q.fullMsg) 48 | if err != nil { 49 | return fmt.Errorf("parseGlob, ReplaceTildeWithHome: %w", err) 50 | } 51 | // Quotes are, in 99% of the time, expanded by the shell in 52 | // different ways and then passed into the shell. So when LLM 53 | // suggests a command, executeAiCmd needs to act the same (meaning) 54 | // remove/expand the quotes 55 | fullMsg = strings.ReplaceAll(fullMsg, "\"", "") 56 | split := strings.Split(fullMsg, " ") 57 | if len(split) < 1 { 58 | return errors.New("Querier.executeAiCmd: too few tokens in q.fullMsg") 59 | } 60 | cmd := split[0] 61 | args := split[1:] 62 | 63 | if len(cmd) == 0 { 64 | return errors.New("Querier.executeAiCmd: command is empty") 65 | } 66 | 67 | command := exec.Command(cmd, args...) 68 | command.Stdout = os.Stdout 69 | command.Stderr = os.Stderr 70 | err = command.Run() 71 | if err != nil { 72 | cast := &exec.ExitError{} 73 | if errors.As(err, &cast) { 74 | return fmt.Errorf(errFormat, cast.ExitCode()) 75 | } else { 76 | return fmt.Errorf("Querier.executeAiCmd - run error: %w", err) 77 | } 78 | } 79 | 80 | return nil 81 | } 82 | -------------------------------------------------------------------------------- /internal/vendors/anthropic/claude_test.go: -------------------------------------------------------------------------------- 1 | package anthropic 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | 7 | pub_models "github.com/baalimago/clai/pkg/text/models" 8 | ) 9 | 10 | func Test_claudifyMessages(t *testing.T) { 11 | tests := []struct { 12 | name string 13 | msgs []pub_models.Message 14 | want []ClaudeConvMessage 15 | }{ 16 | { 17 | name: "Single text message", 18 | msgs: []pub_models.Message{ 19 | {Role: "user", Content: "Hello"}, 20 | }, 21 | want: []ClaudeConvMessage{ 22 | {Role: "user", Content: []any{TextContentBlock{Type: "text", Text: "Hello"}}}, 23 | }, 24 | }, 25 | { 26 | name: "Multiple text messages 
same role", 27 | msgs: []pub_models.Message{ 28 | {Role: "user", Content: "Hello"}, 29 | {Role: "user", Content: "World"}, 30 | }, 31 | want: []ClaudeConvMessage{ 32 | {Role: "user", Content: []any{ 33 | TextContentBlock{Type: "text", Text: "Hello"}, 34 | TextContentBlock{Type: "text", Text: "World"}, 35 | }}, 36 | }, 37 | }, 38 | { 39 | name: "Tool call and result", 40 | msgs: []pub_models.Message{ 41 | {Role: "user", ToolCalls: []pub_models.Call{ 42 | {Name: "exampleTool", ID: "tool1", Inputs: &pub_models.Input{"test": 0}}, 43 | }}, 44 | {Role: "tool", ToolCallID: "tool1", Content: "tool result"}, 45 | }, 46 | want: []ClaudeConvMessage{ 47 | {Role: "user", Content: []any{ 48 | ToolUseContentBlock{Type: "tool_use", ID: "tool1", Name: "exampleTool", Input: &map[string]interface{}{"test": 0}}, 49 | ToolResultContentBlock{Type: "tool_result", ToolUseID: "tool1", Content: "tool result"}, 50 | }}, 51 | }, 52 | }, 53 | { 54 | name: "System message ignored", 55 | msgs: []pub_models.Message{ 56 | {Role: "system", Content: "system message"}, 57 | {Role: "user", Content: "Hello"}, 58 | }, 59 | want: []ClaudeConvMessage{ 60 | {Role: "user", Content: []any{TextContentBlock{Type: "text", Text: "Hello"}}}, 61 | }, 62 | }, 63 | } 64 | 65 | for _, tt := range tests { 66 | t.Run(tt.name, func(t *testing.T) { 67 | if got := claudifyMessages(tt.msgs); !reflect.DeepEqual(got, tt.want) { 68 | t.Errorf("claudifyMessages() = %v, want %v", got, tt.want) 69 | } 70 | }) 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /internal/setup/setup_test.go: -------------------------------------------------------------------------------- 1 | package setup 2 | 3 | import ( 4 | "os" 5 | "path/filepath" 6 | "reflect" 7 | "testing" 8 | ) 9 | 10 | func TestGetConfigs(t *testing.T) { 11 | // Create a temporary directory for test files 12 | tempDir, err := os.MkdirTemp("", "test_configs") 13 | if err != nil { 14 | t.Fatalf("Failed to create temp directory: 
%v", err) 15 | } 16 | defer os.RemoveAll(tempDir) 17 | 18 | // Create test files 19 | testFiles := []string{ 20 | "config1.json", 21 | "config2.json", 22 | "textConfig.json", 23 | "photoConfig.json", 24 | "otherFile.txt", 25 | } 26 | for _, file := range testFiles { 27 | _, err := os.Create(filepath.Join(tempDir, file)) 28 | if err != nil { 29 | t.Fatalf("Failed to create test file %s: %v", file, err) 30 | } 31 | } 32 | 33 | tests := []struct { 34 | name string 35 | includeGlob string 36 | excludeContains []string 37 | want []config 38 | }{ 39 | { 40 | name: "All JSON files", 41 | includeGlob: filepath.Join(tempDir, "*.json"), 42 | excludeContains: []string{}, 43 | want: []config{ 44 | {name: "config1.json", filePath: filepath.Join(tempDir, "config1.json")}, 45 | {name: "config2.json", filePath: filepath.Join(tempDir, "config2.json")}, 46 | {name: "photoConfig.json", filePath: filepath.Join(tempDir, "photoConfig.json")}, 47 | {name: "textConfig.json", filePath: filepath.Join(tempDir, "textConfig.json")}, 48 | }, 49 | }, 50 | { 51 | name: "Exclude text and photo configs", 52 | includeGlob: filepath.Join(tempDir, "*.json"), 53 | excludeContains: []string{"textConfig", "photoConfig"}, 54 | want: []config{ 55 | {name: "config1.json", filePath: filepath.Join(tempDir, "config1.json")}, 56 | {name: "config2.json", filePath: filepath.Join(tempDir, "config2.json")}, 57 | }, 58 | }, 59 | } 60 | 61 | for _, tt := range tests { 62 | t.Run(tt.name, func(t *testing.T) { 63 | got, err := getConfigs(tt.includeGlob, tt.excludeContains) 64 | if err != nil { 65 | t.Errorf("getConfigs() error = %v", err) 66 | return 67 | } 68 | if !reflect.DeepEqual(got, tt.want) { 69 | t.Errorf("getConfigs() = %v, want %v", got, tt.want) 70 | } 71 | }) 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /internal/tools/bash_tool_cat.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 
| "fmt" 5 | "os/exec" 6 | 7 | pub_models "github.com/baalimago/clai/pkg/text/models" 8 | ) 9 | 10 | type CatTool pub_models.Specification 11 | 12 | var Cat = CatTool{ 13 | Name: "cat", 14 | Description: "Display the contents of a file. Uses the linux command 'cat'.", 15 | Inputs: &pub_models.InputSchema{ 16 | Type: "object", 17 | Properties: map[string]pub_models.ParameterObject{ 18 | "file": { 19 | Type: "string", 20 | Description: "The file to display the contents of.", 21 | }, 22 | "number": { 23 | Type: "boolean", 24 | Description: "Number all output lines.", 25 | }, 26 | "showEnds": { 27 | Type: "boolean", 28 | Description: "Display $ at end of each line.", 29 | }, 30 | "squeezeBlank": { 31 | Type: "boolean", 32 | Description: "Suppress repeated empty output lines.", 33 | }, 34 | }, 35 | Required: []string{"file"}, 36 | }, 37 | } 38 | 39 | func (c CatTool) Call(input pub_models.Input) (string, error) { 40 | file, ok := input["file"].(string) 41 | if !ok { 42 | return "", fmt.Errorf("file must be a string") 43 | } 44 | cmd := exec.Command("cat", file) 45 | if input["number"] != nil { 46 | number, ok := input["number"].(bool) 47 | if !ok { 48 | return "", fmt.Errorf("number must be a boolean") 49 | } 50 | if number { 51 | cmd.Args = append(cmd.Args, "-n") 52 | } 53 | } 54 | if input["showEnds"] != nil { 55 | showEnds, ok := input["showEnds"].(bool) 56 | if !ok { 57 | return "", fmt.Errorf("showEnds must be a boolean") 58 | } 59 | if showEnds { 60 | cmd.Args = append(cmd.Args, "-E") 61 | } 62 | } 63 | if input["squeezeBlank"] != nil { 64 | squeezeBlank, ok := input["squeezeBlank"].(bool) 65 | if !ok { 66 | return "", fmt.Errorf("squeezeBlank must be a boolean") 67 | } 68 | if squeezeBlank { 69 | cmd.Args = append(cmd.Args, "-s") 70 | } 71 | } 72 | output, err := cmd.CombinedOutput() 73 | if err != nil { 74 | return "", fmt.Errorf("failed to run cat: %w, output: %v", err, string(output)) 75 | } 76 | return string(output), nil 77 | } 78 | 79 | func (c CatTool) 
Specification() pub_models.Specification { 80 | return pub_models.Specification(Cat) 81 | } 82 | -------------------------------------------------------------------------------- /internal/tools/bash_tool_find.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "fmt" 5 | "os/exec" 6 | 7 | pub_models "github.com/baalimago/clai/pkg/text/models" 8 | ) 9 | 10 | type FindTool pub_models.Specification 11 | 12 | var Find = FindTool{ 13 | Name: "find", 14 | Description: "Search for files in a directory hierarchy. Uses linux command 'find'.", 15 | Inputs: &pub_models.InputSchema{ 16 | Type: "object", 17 | Properties: map[string]pub_models.ParameterObject{ 18 | "directory": { 19 | Type: "string", 20 | Description: "The directory to start the search from.", 21 | }, 22 | "name": { 23 | Type: "string", 24 | Description: "The name pattern to search for.", 25 | }, 26 | "type": { 27 | Type: "string", 28 | Description: "The file type to search for (f: regular file, d: directory).", 29 | }, 30 | "maxdepth": { 31 | Type: "integer", 32 | Description: "The maximum depth of directories to search.", 33 | }, 34 | }, 35 | Required: []string{"directory"}, 36 | }, 37 | } 38 | 39 | func (f FindTool) Call(input pub_models.Input) (string, error) { 40 | directory, ok := input["directory"].(string) 41 | if !ok { 42 | return "", fmt.Errorf("directory must be a string") 43 | } 44 | cmd := exec.Command("find", directory) 45 | if input["name"] != nil { 46 | name, ok := input["name"].(string) 47 | if !ok { 48 | return "", fmt.Errorf("name must be a string") 49 | } 50 | cmd.Args = append(cmd.Args, "-name", name) 51 | } 52 | if input["type"] != nil { 53 | fileType, ok := input["type"].(string) 54 | if !ok { 55 | return "", fmt.Errorf("type must be a string") 56 | } 57 | cmd.Args = append(cmd.Args, "-type", fileType) 58 | } 59 | if input["maxdepth"] != nil { 60 | maxdepth, ok := input["maxdepth"].(float64) 61 | if !ok { 62 | return 
"", fmt.Errorf("maxdepth must be a number") 63 | } 64 | cmd.Args = append(cmd.Args, "-maxdepth", fmt.Sprintf("%v", maxdepth)) 65 | } 66 | output, err := cmd.CombinedOutput() 67 | if err != nil { 68 | return "", fmt.Errorf("failed to run find: %w, output: %v", err, string(output)) 69 | } 70 | return string(output), nil 71 | } 72 | 73 | func (f FindTool) Specification() pub_models.Specification { 74 | return pub_models.Specification(Find) 75 | } 76 | -------------------------------------------------------------------------------- /pkg/text/models/tools_test.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | import ( 4 | "encoding/json" 5 | "strings" 6 | "testing" 7 | ) 8 | 9 | func TestCallPatchAndPretty(t *testing.T) { 10 | // empty -> defaults 11 | c := Call{} 12 | c.Patch() 13 | if c.Type != "function" { 14 | t.Fatalf("expected default type function, got %q", c.Type) 15 | } 16 | if c.Function.Name == "" { 17 | t.Fatalf("expected function name filled from Name or placeholder") 18 | } 19 | if c.Function.Arguments == "" { 20 | t.Fatalf("expected arguments to be auto-filled with JSON") 21 | } 22 | // Test PrettyPrint and JSON on populated object 23 | inp := Input{"path": "a", "flags": 2} 24 | c = Call{Name: "ls", Inputs: &inp} 25 | c.Patch() 26 | if c.Function.Name != "ls" || c.Type != "function" { 27 | t.Fatalf("unexpected patch results: %#v", c) 28 | } 29 | 30 | // Test PrettyPrint output 31 | pp := c.PrettyPrint() 32 | if !strings.Contains(pp, "Call: 'ls'") { 33 | t.Errorf("PrettyPrint expected to contain name 'ls', got %q", pp) 34 | } 35 | // Since map iteration is random, we check if keys exist in the string 36 | if !strings.Contains(pp, "'path': 'a'") { 37 | t.Errorf("PrettyPrint expected to contain path input, got %q", pp) 38 | } 39 | if !strings.Contains(pp, "'flags': '2'") { 40 | t.Errorf("PrettyPrint expected to contain flags input, got %q", pp) 41 | } 42 | 43 | // Test JSON output 44 | js := 
c.JSON() 45 | if !json.Valid([]byte(js)) { 46 | t.Errorf("JSON() returned invalid json: %s", js) 47 | } 48 | if !strings.Contains(js, `"name":"ls"`) { 49 | t.Errorf("JSON output missing name field: %s", js) 50 | } 51 | } 52 | 53 | func TestInputSchemaPatchAndIsOk(t *testing.T) { 54 | is := &InputSchema{} 55 | is.Patch() 56 | if is.Type != "object" || is.Required == nil || is.Properties == nil { 57 | t.Fatalf("patch did not initialize fields: %#v", is) 58 | } 59 | 60 | // array without items -> not ok 61 | is.Properties["arr"] = ParameterObject{Type: "array"} 62 | if is.IsOk() { 63 | t.Fatalf("expected IsOk to fail when array items are missing") 64 | } 65 | 66 | // array with items -> ok 67 | is.Properties["arr"] = ParameterObject{Type: "array", Items: &ParameterObject{Type: "string"}} 68 | if !is.IsOk() { 69 | t.Fatalf("expected IsOk to pass when array items are provided") 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /internal/chat/reply.go: -------------------------------------------------------------------------------- 1 | package chat 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "io/fs" 7 | "os" 8 | "path" 9 | "time" 10 | 11 | "github.com/baalimago/clai/internal/utils" 12 | pub_models "github.com/baalimago/clai/pkg/text/models" 13 | "github.com/baalimago/go_away_boilerplate/pkg/ancli" 14 | ) 15 | 16 | // SaveAsPreviousQuery at claiConfDir/conversations/prevQuery.json with ID prevQuery 17 | func SaveAsPreviousQuery(claiConfDir string, msgs []pub_models.Message) error { 18 | prevQueryChat := pub_models.Chat{ 19 | Created: time.Now(), 20 | ID: "prevQuery", 21 | Messages: msgs, 22 | } 23 | // This check avoid storing queries without any replies, which would most likely 24 | // flood the conversations needlessly 25 | if len(msgs) > 2 { 26 | firstUserMsg, err := prevQueryChat.FirstUserMessage() 27 | if err != nil { 28 | return fmt.Errorf("failed to get first user message: %w", err) 29 | } 30 | convChat := 
pub_models.Chat{ 31 | Created: time.Now(), 32 | ID: IDFromPrompt(firstUserMsg.Content), 33 | Messages: msgs, 34 | } 35 | convPath := path.Join(claiConfDir, "conversations") 36 | if _, convDirExistsErr := os.Stat(convPath); convDirExistsErr != nil { 37 | os.MkdirAll(convPath, 0o755) 38 | } 39 | err = Save(convPath, convChat) 40 | if err != nil { 41 | return fmt.Errorf("failed to save previous query as new conversation: %w", err) 42 | } 43 | } 44 | 45 | return Save(path.Join(claiConfDir, "conversations"), prevQueryChat) 46 | } 47 | 48 | // LoadPrevQuery the prevQuery.json from the claiConfDir/conversations directory 49 | // If claiConfDir is left empty, it will be re-constructed. The technical debt 50 | // is piling up quite fast here 51 | func LoadPrevQuery(claiConfDir string) (pub_models.Chat, error) { 52 | if claiConfDir == "" { 53 | dir, err := utils.GetClaiConfigDir() 54 | if err != nil { 55 | return pub_models.Chat{}, fmt.Errorf("failed to find home dir: %v", err) 56 | } 57 | claiConfDir = dir 58 | } 59 | 60 | c, err := FromPath(path.Join(claiConfDir, "conversations", "prevQuery.json")) 61 | if err != nil { 62 | if errors.Is(err, fs.ErrNotExist) { 63 | ancli.PrintWarn("no previous query found\n") 64 | } else { 65 | return pub_models.Chat{}, fmt.Errorf("failed to read from path: %w", err) 66 | } 67 | } 68 | return c, nil 69 | } 70 | -------------------------------------------------------------------------------- /internal/models/models_test.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | import ( 4 | "testing" 5 | 6 | pub_models "github.com/baalimago/clai/pkg/text/models" 7 | ) 8 | 9 | func TestLastOfRole(t *testing.T) { 10 | chat := pub_models.Chat{Messages: []pub_models.Message{ 11 | {Role: "system", Content: "sys"}, 12 | {Role: "user", Content: "first"}, 13 | {Role: "admin", Content: "admin-msg"}, 14 | {Role: "user", Content: "last"}, 15 | }} 16 | 17 | msg, i, err := chat.LastOfRole("admin") 18 | 
if err != nil { 19 | t.Fatalf("unexpected err: %v", err) 20 | } 21 | if msg.Content != "admin-msg" { 22 | t.Errorf("expected 'admin-msg', got %q", msg.Content) 23 | } 24 | if i != 2 { 25 | t.Errorf("expected '2', got %v", i) 26 | } 27 | 28 | msg, i, err = chat.LastOfRole("user") 29 | if err != nil { 30 | t.Fatalf("unexpected err: %v", err) 31 | } 32 | if msg.Content != "last" { 33 | t.Errorf("expected 'last', got %q", msg.Content) 34 | } 35 | if i != 3 { 36 | t.Errorf("expected '3', got %v", i) 37 | } 38 | 39 | _, _, err = chat.LastOfRole("nonexistent") 40 | if err == nil { 41 | t.Error("expected error for nonexistent role") 42 | } 43 | } 44 | 45 | func TestFirstSystemMessage(t *testing.T) { 46 | chat := pub_models.Chat{Messages: []pub_models.Message{ 47 | {Role: "user", Content: "hi"}, 48 | {Role: "system", Content: "rules"}, 49 | }} 50 | msg, err := chat.FirstSystemMessage() 51 | if err != nil { 52 | t.Fatalf("unexpected err: %v", err) 53 | } 54 | if msg.Content != "rules" { 55 | t.Errorf("expected 'rules', got %q", msg.Content) 56 | } 57 | chat.Messages = []pub_models.Message{{Role: "user", Content: "hi"}} 58 | if _, err := chat.FirstSystemMessage(); err == nil { 59 | t.Error("expected error when no system message") 60 | } 61 | } 62 | 63 | func TestFirstUserMessage(t *testing.T) { 64 | chat := pub_models.Chat{Messages: []pub_models.Message{ 65 | {Role: "system", Content: "sys"}, 66 | {Role: "user", Content: "ok"}, 67 | }} 68 | msg, err := chat.FirstUserMessage() 69 | if err != nil { 70 | t.Fatalf("unexpected err: %v", err) 71 | } 72 | if msg.Content != "ok" { 73 | t.Errorf("expected 'ok', got %q", msg.Content) 74 | } 75 | chat.Messages = []pub_models.Message{{Role: "system", Content: "sys"}} 76 | if _, err := chat.FirstUserMessage(); err == nil { 77 | t.Error("expected error when no user message") 78 | } 79 | } 80 | -------------------------------------------------------------------------------- /internal/tools/programming_tool_write_file.go: 
-------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "path/filepath" 7 | 8 | pub_models "github.com/baalimago/clai/pkg/text/models" 9 | ) 10 | 11 | type WriteFileTool pub_models.Specification 12 | 13 | var WriteFile = WriteFileTool{ 14 | Name: "write_file", 15 | Description: "Write content to a file. Creates the file if it doesn't exist, or overwrites it if it does.", 16 | Inputs: &pub_models.InputSchema{ 17 | Type: "object", 18 | Properties: map[string]pub_models.ParameterObject{ 19 | "file_path": { 20 | Type: "string", 21 | Description: "The path to the file to write to.", 22 | }, 23 | "content": { 24 | Type: "string", 25 | Description: "The content to write to the file.", 26 | }, 27 | "append": { 28 | Type: "boolean", 29 | Description: "If true, append to the file instead of overwriting it.", 30 | }, 31 | }, 32 | Required: []string{"file_path", "content"}, 33 | }, 34 | } 35 | 36 | func (w WriteFileTool) Call(input pub_models.Input) (string, error) { 37 | filePath, ok := input["file_path"].(string) 38 | if !ok { 39 | return "", fmt.Errorf("file_path must be a string") 40 | } 41 | 42 | content, ok := input["content"].(string) 43 | if !ok { 44 | return "", fmt.Errorf("content must be a string") 45 | } 46 | 47 | append := false 48 | if input["append"] != nil { 49 | append, ok = input["append"].(bool) 50 | if !ok { 51 | return "", fmt.Errorf("append must be a boolean") 52 | } 53 | } 54 | 55 | // Ensure the directory exists 56 | dir := filepath.Dir(filePath) 57 | if err := os.MkdirAll(dir, 0o755); err != nil { 58 | return "", fmt.Errorf("failed to create directory: %w", err) 59 | } 60 | 61 | var flag int 62 | if append { 63 | flag = os.O_APPEND | os.O_CREATE | os.O_WRONLY 64 | } else { 65 | flag = os.O_TRUNC | os.O_CREATE | os.O_WRONLY 66 | } 67 | 68 | file, err := os.OpenFile(filePath, flag, 0o644) 69 | if err != nil { 70 | return "", fmt.Errorf("failed to open file: %w", err) 71 | 
} 72 | defer file.Close() 73 | 74 | _, err = file.WriteString(content) 75 | if err != nil { 76 | return "", fmt.Errorf("failed to write to file: %w", err) 77 | } 78 | 79 | return fmt.Sprintf("Successfully wrote %d bytes to %s", len(content), filePath), nil 80 | } 81 | 82 | func (w WriteFileTool) Specification() pub_models.Specification { 83 | return pub_models.Specification(WriteFile) 84 | } 85 | -------------------------------------------------------------------------------- /pkg/text/full_test.go: -------------------------------------------------------------------------------- 1 | package text 2 | 3 | import ( 4 | "context" 5 | "os" 6 | "path/filepath" 7 | "testing" 8 | "time" 9 | 10 | pub_models "github.com/baalimago/clai/pkg/text/models" 11 | ) 12 | 13 | func TestNewFullResponseQuerier(t *testing.T) { 14 | q := NewFullResponseQuerier(pub_models.Configurations{Model: "gpt-4o", ConfigDir: t.TempDir()}) 15 | if q == nil { 16 | t.Fatal("expected non-nil") 17 | } 18 | } 19 | 20 | func TestPubConfigToInternalAndInternalToolsToString(t *testing.T) { 21 | cfg := pub_models.Configurations{ 22 | Model: "gpt-4o", 23 | SystemPrompt: "sys", 24 | ConfigDir: t.TempDir(), 25 | InternalTools: []pub_models.ToolName{ 26 | pub_models.CatTool, pub_models.LSTool, 27 | }, 28 | } 29 | ic := pubConfigToInternal(cfg) 30 | if !ic.UseTools || ic.Model != cfg.Model || ic.SystemPrompt != cfg.SystemPrompt { 31 | t.Fatalf("unexpected mapping: %#v", ic) 32 | } 33 | if len(ic.Tools) != 2 || ic.Tools[0] != string(pub_models.CatTool) { 34 | t.Fatalf("tools mapping unexpected: %#v", ic.Tools) 35 | } 36 | } 37 | 38 | func TestSetupCreatesDirsEvenOnError(t *testing.T) { 39 | tmp := t.TempDir() 40 | cfg := pub_models.Configurations{Model: "mock", ConfigDir: tmp} 41 | pq := NewFullResponseQuerier(cfg).(*publicQuerier) 42 | 43 | ctx, cancel := context.WithTimeout(context.Background(), time.Second) 44 | defer cancel() 45 | _ = pq.Setup(ctx) // may fail depending on vendor selection; we only care about 
dir side-effects 46 | 47 | // required dirs that Setup creates up-front 48 | if _, err := os.Stat(filepath.Join(pq.conf.ConfigDir, "mcpServers")); err != nil { 49 | t.Fatalf("expected mcpServers dir: %v", err) 50 | } 51 | // conversations dir creation depends on a condition in current code; do not assert strictly 52 | _ = os.MkdirAll(filepath.Join(pq.conf.ConfigDir, "conversations"), 0o755) 53 | } 54 | 55 | func TestQueryReturnsErrorWhenSetupFails(t *testing.T) { 56 | // model that will not be found by selectTextQuerier => Setup fails 57 | pq := NewFullResponseQuerier(pub_models.Configurations{Model: "unknown-model", ConfigDir: t.TempDir()}) 58 | ctx := context.Background() 59 | chat := pub_models.Chat{} 60 | out, err := pq.Query(ctx, chat) 61 | if err == nil { 62 | t.Fatalf("expected error from Query due to failing Setup") 63 | } 64 | if out.Messages != nil || out.ID != "" || !out.Created.IsZero() { 65 | t.Fatalf("expected zero value chat on error, got: %#v", out) 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /internal/vendors/anthropic/claude_models.go: -------------------------------------------------------------------------------- 1 | package anthropic 2 | 3 | import ( 4 | pub_models "github.com/baalimago/clai/pkg/text/models" 5 | ) 6 | 7 | type ClaudeResponse struct { 8 | Content []ClaudeMessage `json:"content"` 9 | ID string `json:"id"` 10 | Model string `json:"model"` 11 | Role string `json:"role"` 12 | StopReason string `json:"stop_reason"` 13 | StopSequence any `json:"stop_sequence"` 14 | Type string `json:"type"` 15 | Usage TokenInfo `json:"usage"` 16 | } 17 | 18 | type ClaudeMessage struct { 19 | ID string `json:"id,omitempty"` 20 | Input *pub_models.Input `json:"input,omitempty"` 21 | Name string `json:"name,omitempty"` 22 | Text string `json:"text,omitempty"` 23 | Type string `json:"type"` 24 | } 25 | 26 | type TokenInfo struct { 27 | InputTokens int `json:"input_tokens"` 28 | OutputTokens int 
`json:"output_tokens"` 29 | } 30 | 31 | type Delta struct { 32 | Type string `json:"type"` 33 | Text string `json:"text,omitempty"` 34 | PartialJSON string `json:"partial_json,omitempty"` 35 | } 36 | 37 | type ContentBlockDelta struct { 38 | Type string `json:"type"` 39 | Index int `json:"index"` 40 | Delta Delta `json:"delta"` 41 | } 42 | 43 | type ContentBlockSuper struct { 44 | Type string `json:"type"` 45 | Index int `json:"index"` 46 | ToolContentBlock ToolUseContentBlock `json:"content_block"` 47 | } 48 | 49 | type ToolUseContentBlock struct { 50 | Type string `json:"type"` 51 | ID string `json:"id"` 52 | Name string `json:"name"` 53 | Input *map[string]interface{} `json:"input,omitempty"` 54 | } 55 | 56 | type ToolResultContentBlock struct { 57 | Type string `json:"type"` 58 | Content string `json:"content"` 59 | ToolUseID string `json:"tool_use_id"` 60 | } 61 | 62 | type TextContentBlock struct { 63 | Type string `json:"type"` 64 | Text string `json:"text"` 65 | } 66 | 67 | type Root struct { 68 | Type string `json:"type"` 69 | Index int `json:"index"` 70 | ContentBlock ToolUseContentBlock `json:"content_block"` 71 | } 72 | 73 | type ClaudeConvMessage struct { 74 | Role string `json:"role"` 75 | // Content may be either ToolContentBlock or TextContentBlock 76 | Content []any `json:"content"` 77 | } 78 | -------------------------------------------------------------------------------- /internal/tools/mcp/manager_test.go: -------------------------------------------------------------------------------- 1 | package mcp 2 | 3 | import ( 4 | "context" 5 | "sync" 6 | "testing" 7 | "time" 8 | 9 | "github.com/baalimago/clai/internal/tools" 10 | pub_models "github.com/baalimago/clai/pkg/text/models" 11 | ) 12 | 13 | func TestHandleServerRegistersTool(t *testing.T) { 14 | ctx, cancel := context.WithCancel(context.Background()) 15 | defer cancel() 16 | 17 | srv := pub_models.McpServer{Command: "go", Args: []string{"run", "./testserver"}} 18 | in, out, err := Client(ctx, 
srv) 19 | if err != nil { 20 | t.Fatalf("client: %v", err) 21 | } 22 | 23 | orig := tools.Registry 24 | tools.Registry = tools.NewRegistry() 25 | defer func() { tools.Registry = orig }() 26 | 27 | ev := ControlEvent{ServerName: "echo", Server: srv, InputChan: in, OutputChan: out} 28 | readyChan := make(chan struct{}, 1) 29 | if serveErr := handleServer(ctx, ev, readyChan); serveErr != nil { 30 | t.Fatalf("handleServer: %v", serveErr) 31 | } 32 | 33 | tool, ok := tools.Registry.Get("mcp_echo_echo") 34 | if !ok { 35 | t.Fatal("tool not registered") 36 | } 37 | res, err := tool.Call(pub_models.Input{"text": "hello"}) 38 | if err != nil { 39 | t.Fatalf("call: %v", err) 40 | } 41 | if res != "hello" { 42 | t.Errorf("unexpected response %q", res) 43 | } 44 | 45 | if _, err := tool.Call(pub_models.Input{"text": "error"}); err == nil { 46 | t.Error("expected error on isError=true") 47 | } 48 | } 49 | 50 | func TestManager(t *testing.T) { 51 | ctx, cancel := context.WithCancel(context.Background()) 52 | defer cancel() 53 | srv := pub_models.McpServer{Command: "go", Args: []string{"run", "./testserver"}} 54 | in, out, err := Client(ctx, srv) 55 | if err != nil { 56 | t.Fatalf("client: %v", err) 57 | } 58 | 59 | orig := tools.Registry 60 | tools.Registry = tools.NewRegistry() 61 | defer func() { tools.Registry = orig }() 62 | 63 | controlCh := make(chan ControlEvent) 64 | statusCh := make(chan error, 1) 65 | var wg sync.WaitGroup 66 | wg.Add(1) 67 | go Manager(ctx, controlCh, statusCh, &wg) 68 | 69 | controlCh <- ControlEvent{ServerName: "echo", Server: srv, InputChan: in, OutputChan: out} 70 | 71 | var ok bool 72 | for i := 0; i < 20; i++ { 73 | _, ok = tools.Registry.Get("mcp_echo_echo") 74 | if ok { 75 | break 76 | } 77 | time.Sleep(50 * time.Millisecond) 78 | } 79 | if !ok { 80 | t.Fatal("tool not registered") 81 | } 82 | 83 | cancel() 84 | wg.Wait() 85 | if err := <-statusCh; err != nil { 86 | t.Fatalf("manager error: %v", err) 87 | } 88 | } 89 | 
-------------------------------------------------------------------------------- /internal/utils/prompt.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "io" 7 | "os" 8 | "strings" 9 | 10 | "github.com/baalimago/go_away_boilerplate/pkg/ancli" 11 | "github.com/baalimago/go_away_boilerplate/pkg/misc" 12 | ) 13 | 14 | // Prompt returns the prompt by checking all the arguments and stdin. 15 | // If there is no arguments, but data in stdin, stdin will become the prompt. 16 | // If there are arguments and data in stdin, all stdinReplace tokens will be substituted 17 | // with the data in stdin 18 | func Prompt(stdinReplace string, args []string) (string, error) { 19 | debug := misc.Truthy(os.Getenv("DEBUG")) 20 | if debug { 21 | ancli.PrintOK(fmt.Sprintf("stdinReplace: %v\n", stdinReplace)) 22 | } 23 | fi, err := os.Stdin.Stat() 24 | if err != nil { 25 | panic(err) 26 | } 27 | var hasPipe bool 28 | if fi.Mode()&os.ModeNamedPipe == 0 { 29 | hasPipe = false 30 | } else { 31 | hasPipe = true 32 | } 33 | 34 | if len(args) == 1 && !hasPipe { 35 | return "", errors.New("found no prompt, set args or pipe in some string") 36 | } 37 | // First argument is the command, so we skip it 38 | args = args[1:] 39 | // If no data is in stdin, simply return args 40 | if !hasPipe { 41 | return strings.Join(args, " "), nil 42 | } 43 | 44 | inputData, err := io.ReadAll(os.Stdin) 45 | if err != nil { 46 | return "", fmt.Errorf("failed to read stdin: %v", err) 47 | } 48 | pipeIn := string(inputData) 49 | // Add the pipeIn to the args if there are no args 50 | if len(args) == 0 { 51 | args = append(args, strings.Split(pipeIn, " ")...) 
52 | } else if stdinReplace == "" && hasPipe { 53 | stdinReplace = "{}" 54 | args = append(args, "{}") 55 | } 56 | 57 | // Replace all occurrence of stdinReplaceSignal with pipeIn 58 | if stdinReplace != "" { 59 | if debug { 60 | ancli.PrintOK(fmt.Sprintf("attempting to replace: '%v' with stdin\n", stdinReplace)) 61 | } 62 | for i, arg := range args { 63 | if strings.Contains(arg, stdinReplace) { 64 | args[i] = strings.ReplaceAll(arg, stdinReplace, pipeIn) 65 | } 66 | } 67 | } 68 | 69 | if debug { 70 | ancli.PrintOK(fmt.Sprintf("args: %v\n", args)) 71 | } 72 | return strings.Join(args, " "), nil 73 | } 74 | 75 | func ReplaceTildeWithHome(s string) (string, error) { 76 | home, err := os.UserHomeDir() 77 | if err != nil && strings.Contains(s, "~/") { // only fail if glob contains ~/ and home dir is not found 78 | return "", fmt.Errorf("failed to get home dir: %w", err) 79 | } 80 | return strings.ReplaceAll(s, "~", home), nil 81 | } 82 | -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # Function to get the latest release download URL for the specified OS and architecture 4 | get_latest_release_url() { 5 | repo="baalimago/clai" 6 | os="$1" 7 | arch="$2" 8 | 9 | # Fetch the latest release data from GitHub API 10 | release_data=$(curl -s "https://api.github.com/repos/$repo/releases/latest") 11 | 12 | # Extract the asset URL for the specified OS and architecture 13 | download_url=$(echo "$release_data" | grep "browser_download_url" | grep "$os" | grep "$arch" | cut -d '"' -f 4) 14 | 15 | echo "$download_url" 16 | } 17 | 18 | # Detect the OS 19 | case "$(uname)" in 20 | Linux*) 21 | os="linux" 22 | ;; 23 | Darwin*) 24 | os="darwin" 25 | ;; 26 | *) 27 | echo "Unsupported OS: $(uname)" 28 | exit 1 29 | ;; 30 | esac 31 | 32 | # Detect the architecture 33 | arch=$(uname -m) 34 | case "$arch" in 35 | x86_64) 36 | arch="amd64" 37 | ;; 
38 | armv7*) 39 | arch="arm" 40 | ;; 41 | aarch64|arm64) 42 | arch="arm64" 43 | ;; 44 | i?86) 45 | arch="386" 46 | ;; 47 | *) 48 | echo "Unsupported architecture: $arch" 49 | exit 1 50 | ;; 51 | esac 52 | 53 | printf "detected os: '%s', arch: '%s'\n" "$os" "$arch" 54 | 55 | # Get the download URL for the latest release 56 | printf "finding asset url..." 57 | download_url=$(get_latest_release_url "$os" "$arch") 58 | printf "OK!\n" 59 | 60 | # Download the binary 61 | tmp_file=$(mktemp) 62 | 63 | printf "downloading binary..." 64 | if ! curl -s -L -o "$tmp_file" "$download_url"; then 65 | echo 66 | echo "Failed to download the binary." 67 | exit 1 68 | fi 69 | printf "OK!\n" 70 | 71 | printf "setting file executable file permissions..." 72 | # Make the binary executable 73 | 74 | if ! chmod +x "$tmp_file"; then 75 | echo 76 | echo "Failed to make the binary executable. Try running the script with sudo." 77 | exit 1 78 | fi 79 | printf "OK!\n" 80 | 81 | # Move the binary to standard XDG location and handle permission errors 82 | INSTALL_DIR=$HOME/.local/bin 83 | # If run as 'sudo', install to /usr/local/bin for systemwide use 84 | if [ -x /usr/bin/id ]; then 85 | if [ `/usr/bin/id -u` -eq 0 ]; then 86 | INSTALL_DIR=/usr/local/bin 87 | fi 88 | fi 89 | 90 | if ! mv "$tmp_file" $INSTALL_DIR/clai; then 91 | echo "Failed to move the binary to $INSTALL_DIR/clai, see error above. Try making sure you have write permission there, or run 'mv $tmp_file '." 
92 | exit 1 93 | fi 94 | 95 | echo "clai installed successfully in $INSTALL_DIR, try it out with 'clai h'" 96 | -------------------------------------------------------------------------------- /internal/tools/mcp/client.go: -------------------------------------------------------------------------------- 1 | package mcp 2 | 3 | import ( 4 | "bufio" 5 | "context" 6 | "encoding/json" 7 | "errors" 8 | "fmt" 9 | "io" 10 | "os" 11 | "os/exec" 12 | 13 | pub_models "github.com/baalimago/clai/pkg/text/models" 14 | "github.com/baalimago/go_away_boilerplate/pkg/ancli" 15 | ) 16 | 17 | // Client starts the MCP server process defined by mcpConfig and returns channels 18 | // for sending requests and receiving responses. 19 | func Client(ctx context.Context, mcpConfig pub_models.McpServer) (chan<- any, <-chan any, error) { 20 | cmd := exec.CommandContext(ctx, mcpConfig.Command, mcpConfig.Args...) 21 | cmd.Env = os.Environ() 22 | for k, v := range mcpConfig.Env { 23 | cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", k, v)) 24 | } 25 | 26 | stdout, err := cmd.StdoutPipe() 27 | if err != nil { 28 | return nil, nil, fmt.Errorf("stdout pipe: %w", err) 29 | } 30 | stdin, err := cmd.StdinPipe() 31 | if err != nil { 32 | return nil, nil, fmt.Errorf("stdin pipe: %w", err) 33 | } 34 | stderr, err := cmd.StderrPipe() 35 | if err != nil { 36 | return nil, nil, fmt.Errorf("stderr pipe: %w", err) 37 | } 38 | 39 | if err := cmd.Start(); err != nil { 40 | return nil, nil, fmt.Errorf("start mcp server: %w", err) 41 | } 42 | 43 | in := make(chan any) 44 | out := make(chan any) 45 | 46 | go func() { 47 | enc := json.NewEncoder(stdin) 48 | for { 49 | select { 50 | case msg, ok := <-in: 51 | if !ok { 52 | return 53 | } 54 | enc.Encode(msg) 55 | case <-ctx.Done(): 56 | return 57 | } 58 | } 59 | }() 60 | 61 | go func() { 62 | dec := json.NewDecoder(stdout) 63 | for { 64 | var raw json.RawMessage 65 | if err := dec.Decode(&raw); err != nil { 66 | if err == io.EOF { 67 | close(out) 68 | return 69 | } 70 | 
out <- fmt.Errorf("decode: %w", err) 71 | close(out) 72 | return 73 | } 74 | out <- raw 75 | } 76 | }() 77 | 78 | go func() { 79 | scanner := bufio.NewScanner(stderr) 80 | for scanner.Scan() { 81 | line := scanner.Text() 82 | if line != "" { 83 | ancli.Noticef("mcp_%v: %v\n", mcpConfig.Name, line) 84 | } 85 | } 86 | if ctx.Err() != nil && errors.Is(ctx.Err(), context.Canceled) { 87 | return 88 | } 89 | if err := scanner.Err(); err != nil { 90 | ancli.Errf("mcp_%v: %s\n", mcpConfig.Name, err) 91 | } 92 | }() 93 | 94 | go func() { 95 | <-ctx.Done() 96 | stdin.Close() 97 | cmd.Wait() 98 | }() 99 | 100 | return in, out, nil 101 | } 102 | -------------------------------------------------------------------------------- /internal/video/prompt.go: -------------------------------------------------------------------------------- 1 | package video 2 | 3 | import ( 4 | "encoding/json" 5 | "flag" 6 | "fmt" 7 | "os" 8 | "strings" 9 | 10 | "github.com/baalimago/clai/internal/chat" 11 | "github.com/baalimago/clai/internal/utils" 12 | "github.com/baalimago/go_away_boilerplate/pkg/ancli" 13 | "github.com/baalimago/go_away_boilerplate/pkg/misc" 14 | ) 15 | 16 | func (c *Configurations) SetupPrompts() error { 17 | args := flag.Args() 18 | if c.ReplyMode { 19 | confDir, err := utils.GetClaiConfigDir() 20 | if err != nil { 21 | return fmt.Errorf("failed to get config dir: %w", err) 22 | } 23 | iP, err := chat.LoadPrevQuery(confDir) 24 | if err != nil { 25 | return fmt.Errorf("failed to load previous query: %w", err) 26 | } 27 | if len(iP.Messages) > 0 { 28 | replyMessages := "You will be given a serie of messages from different roles, then a prompt descibing what to do with these messages. " 29 | replyMessages += "Between the messages and the prompt, there will be this line: '-------------'. " 30 | replyMessages += "The format is json with the structure {\"role\": \"\", \"content\": \"\"}. " 31 | replyMessages += "The roles are 'system' and 'user'. 
" 32 | b, err := json.Marshal(iP.Messages) 33 | if err != nil { 34 | return fmt.Errorf("failed to encode reply JSON: %w", err) 35 | } 36 | replyMessages = fmt.Sprintf("%vMessages:\n%v\n-------------\n", replyMessages, string(b)) 37 | c.Prompt += replyMessages 38 | } 39 | } 40 | prompt, err := utils.Prompt(c.StdinReplace, args) 41 | if err != nil { 42 | return fmt.Errorf("failed to setup prompt from stdin: %w", err) 43 | } 44 | chat, err := chat.PromptToImageMessage(prompt) 45 | if err != nil { 46 | return fmt.Errorf("failed to convert to chat with image message") 47 | } 48 | isImagePrompt := false 49 | for _, m := range chat { 50 | for _, cp := range m.ContentParts { 51 | if cp.Type == "image_url" { 52 | isImagePrompt = true 53 | c.PromptImageB64 = cp.ImageB64.RawB64 54 | } 55 | if cp.Type == "text" { 56 | c.Prompt = cp.Text 57 | } 58 | } 59 | } 60 | if misc.Truthy(os.Getenv("DEBUG")) { 61 | ancli.PrintOK(fmt.Sprintf("format: '%v', prompt: '%v'\n", c.PromptFormat, prompt)) 62 | } 63 | // Don't do additional weird stuff if it's an image prompt 64 | if isImagePrompt { 65 | return nil 66 | } 67 | // If prompt format has %v, formatting it, otherwise just appending 68 | if strings.Contains(c.PromptFormat, "%v") { 69 | c.Prompt += fmt.Sprintf(c.PromptFormat, prompt) 70 | } else { 71 | c.Prompt += prompt 72 | } 73 | return nil 74 | } 75 | -------------------------------------------------------------------------------- /internal/tools/mcp/tool.go: -------------------------------------------------------------------------------- 1 | package mcp 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "errors" 7 | "fmt" 8 | "os" 9 | "sync" 10 | 11 | pub_models "github.com/baalimago/clai/pkg/text/models" 12 | "github.com/baalimago/go_away_boilerplate/pkg/ancli" 13 | "github.com/baalimago/go_away_boilerplate/pkg/debug" 14 | "github.com/baalimago/go_away_boilerplate/pkg/misc" 15 | ) 16 | 17 | // mcpTool wraps a tool provided by an MCP server and implements tools.LLMTool. 
18 | type mcpTool struct { 19 | remoteName string 20 | spec pub_models.Specification 21 | inputChan chan<- any 22 | outputChan <-chan any 23 | 24 | mu sync.Mutex 25 | seq int 26 | } 27 | 28 | func (m *mcpTool) nextID() int { 29 | m.mu.Lock() 30 | defer m.mu.Unlock() 31 | m.seq++ 32 | return m.seq 33 | } 34 | 35 | func (m *mcpTool) Call(input pub_models.Input) (string, error) { 36 | nonNullableInp := make(map[string]any) 37 | if len(input) != 0 { 38 | nonNullableInp = input 39 | } 40 | id := m.nextID() 41 | req := Request{ 42 | JSONRPC: "2.0", 43 | ID: id, 44 | Method: "tools/call", 45 | Params: map[string]any{ 46 | "name": m.remoteName, 47 | "arguments": nonNullableInp, 48 | }, 49 | } 50 | if misc.Truthy(os.Getenv("DEBUG_CALL")) { 51 | ancli.Noticef("mcpTool.Call req: %v", debug.IndentedJsonFmt(req)) 52 | } 53 | 54 | m.inputChan <- req 55 | 56 | for msg := range m.outputChan { 57 | raw, ok := msg.(json.RawMessage) 58 | if !ok { 59 | if err, ok := msg.(error); ok { 60 | return "", err 61 | } 62 | continue 63 | } 64 | var resp Response 65 | if err := json.Unmarshal(raw, &resp); err != nil { 66 | continue 67 | } 68 | if resp.ID != id { 69 | continue 70 | } 71 | if resp.Error != nil { 72 | return "", errors.New(resp.Error.Message) 73 | } 74 | var result struct { 75 | Content []struct { 76 | Type string `json:"type"` 77 | Text string `json:"text"` 78 | } `json:"content"` 79 | IsError bool `json:"isError"` 80 | } 81 | if err := json.Unmarshal(resp.Result, &result); err != nil { 82 | return "", fmt.Errorf("decode result: %w", err) 83 | } 84 | var buf bytes.Buffer 85 | for _, c := range result.Content { 86 | if c.Type == "text" { 87 | buf.WriteString(c.Text) 88 | } 89 | } 90 | if result.IsError { 91 | return "", errors.New(buf.String()) 92 | } 93 | return buf.String(), nil 94 | } 95 | return "", fmt.Errorf("connection closed") 96 | } 97 | 98 | func (m *mcpTool) Specification() pub_models.Specification { 99 | return m.spec 100 | } 101 | 
-------------------------------------------------------------------------------- /internal/tools/programming_tool_recall.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "os" 7 | "path" 8 | "strconv" 9 | 10 | pub_models "github.com/baalimago/clai/pkg/text/models" 11 | ) 12 | 13 | // RecallTool fetches a message from a stored conversation 14 | // given its name and message index. 15 | type RecallTool pub_models.Specification 16 | 17 | // Recall is the exported instance of RecallTool. 18 | var Recall = RecallTool{ 19 | Name: "recall", 20 | Description: "Recall a message from a stored conversation by name and index.", 21 | Inputs: &pub_models.InputSchema{ 22 | Type: "object", 23 | Properties: map[string]pub_models.ParameterObject{ 24 | "conversation": { 25 | Type: "string", 26 | Description: "Conversation name or id", 27 | }, 28 | "index": { 29 | Type: "integer", 30 | Description: "Index of the message to retrieve", 31 | }, 32 | }, 33 | Required: []string{"conversation", "index"}, 34 | }, 35 | } 36 | 37 | func (r RecallTool) Call(input pub_models.Input) (string, error) { 38 | convName, ok := input["conversation"].(string) 39 | if !ok { 40 | return "", fmt.Errorf("conversation must be a string") 41 | } 42 | 43 | var idx int 44 | switch v := input["index"].(type) { 45 | case int: 46 | idx = v 47 | case float64: 48 | idx = int(v) 49 | case string: 50 | n, err := strconv.Atoi(v) 51 | if err != nil { 52 | return "", fmt.Errorf("index must be a number") 53 | } 54 | idx = n 55 | default: 56 | return "", fmt.Errorf("index must be a number") 57 | } 58 | 59 | confDir, err := os.UserConfigDir() 60 | if err != nil { 61 | return "", fmt.Errorf("failed to get user config dir: %w", err) 62 | } 63 | pathToConv := path.Join(confDir, ".clai", "conversations", fmt.Sprintf("%s.json", convName)) 64 | b, err := os.ReadFile(pathToConv) 65 | if err != nil { 66 | return "", 
fmt.Errorf("failed to load conversation: %w", err) 67 | } 68 | 69 | var conv struct { 70 | Messages []struct { 71 | Role string `json:"role"` 72 | Content string `json:"content"` 73 | } `json:"messages"` 74 | } 75 | 76 | if err := json.Unmarshal(b, &conv); err != nil { 77 | return "", fmt.Errorf("failed to decode conversation: %w", err) 78 | } 79 | 80 | lenConv := len(conv.Messages) 81 | if idx < 0 || idx >= lenConv { 82 | return "", fmt.Errorf("index out of range. Am messages: %v, index: %v", lenConv, idx) 83 | } 84 | msg := conv.Messages[idx] 85 | return fmt.Sprintf("%s: %s", msg.Role, msg.Content), nil 86 | } 87 | 88 | func (r RecallTool) Specification() pub_models.Specification { 89 | return pub_models.Specification(Recall) 90 | } 91 | -------------------------------------------------------------------------------- /internal/vendors/mistral/mistral_test.go: -------------------------------------------------------------------------------- 1 | package mistral 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | pub_models "github.com/baalimago/clai/pkg/text/models" 8 | ) 9 | 10 | func TestCleanRemovesExtraToolFieldsAndMergesAssistants(t *testing.T) { 11 | msgs := []pub_models.Message{ 12 | {Role: "system", Content: "sys"}, 13 | {Role: "assistant", Content: "a1", ToolCalls: []pub_models.Call{{Name: "x", Function: pub_models.Specification{Name: "fn", Description: "desc"}}}}, 14 | {Role: "assistant", Content: "a2"}, 15 | {Role: "tool", Content: "tool-res"}, 16 | {Role: "system", Content: "after"}, 17 | } 18 | cleaned := clean(append([]pub_models.Message(nil), msgs...)) 19 | 20 | // assistant fields stripped on first assistant with tool calls 21 | if cleaned[1].ToolCalls[0].Name != "" || cleaned[1].ToolCalls[0].Function.Description != "" { 22 | t.Fatalf("expected tool fields cleared, got %+v", cleaned[1].ToolCalls[0]) 23 | } 24 | // content merged from consecutive assistants (first assistant content cleared, so only second remains with a leading newline) 25 | if 
cleaned[1].Content != "\na2" { 26 | t.Fatalf("expected merged content with leading newline, got %q", cleaned[1].Content) 27 | } 28 | // role change tool followed by system -> assistant 29 | if cleaned[3].Role != "assistant" { 30 | t.Fatalf("expected role assistant at idx 3, got %q", cleaned[3].Role) 31 | } 32 | // Ensure there are no consecutive assistants left 33 | for i := 1; i < len(cleaned); i++ { 34 | if cleaned[i].Role == "assistant" && cleaned[i-1].Role == "assistant" { 35 | t.Fatalf("expected no consecutive assistant messages after merge at positions %d and %d", i-1, i) 36 | } 37 | } 38 | } 39 | 40 | func TestSetupAssignsFieldsAndToolChoice(t *testing.T) { 41 | m := Default 42 | t.Setenv("MISTRAL_API_KEY", "k") 43 | if err := m.Setup(); err != nil { 44 | t.Fatalf("setup failed: %v", err) 45 | } 46 | if m.StreamCompleter.Model != m.Model { 47 | t.Errorf("model not mapped: %q vs %q", m.StreamCompleter.Model, m.Model) 48 | } 49 | if m.ToolChoice == nil || *m.ToolChoice != "any" { 50 | t.Errorf("expected tool choice 'any', got %#v", m.ToolChoice) 51 | } 52 | if m.Clean == nil { 53 | t.Errorf("expected Clean callback to be set") 54 | } 55 | } 56 | 57 | func TestStreamCompletionsDelegates(t *testing.T) { 58 | m := Default 59 | t.Setenv("MISTRAL_API_KEY", "k") 60 | _ = m.Setup() // ignore error, we will not actually perform network 61 | // Using a canceled context must quickly return an error channel or error per generic StreamCompleter tests 62 | ctx, cancel := context.WithCancel(context.Background()) 63 | cancel() 64 | _, _ = m.StreamCompletions(ctx, pub_models.Chat{}) 65 | } 66 | -------------------------------------------------------------------------------- /internal/tools/bash_tool_date.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "fmt" 5 | "os/exec" 6 | 7 | pub_models "github.com/baalimago/clai/pkg/text/models" 8 | ) 9 | 10 | type DateTool pub_models.Specification 11 | 12 | var 
Date = DateTool{ 13 | Name: "date", 14 | Description: "Get or format the current date and time. Wraps the linux 'date' command and is optimized for agentic workloads.", 15 | Inputs: &pub_models.InputSchema{ 16 | Type: "object", 17 | Required: make([]string, 0), 18 | Properties: map[string]pub_models.ParameterObject{ 19 | "format": { 20 | Type: "string", 21 | Description: "Optional format string passed to 'date +FORMAT'. Common example: '%Y-%m-%d %H:%M:%S'. If omitted, uses system default format.", 22 | }, 23 | "utc": { 24 | Type: "boolean", 25 | Description: "If true, returns time in UTC (equivalent to 'TZ=UTC date').", 26 | }, 27 | "rfc3339": { 28 | Type: "boolean", 29 | Description: "If true, returns time in RFC3339 format (e.g. 2006-01-02T15:04:05Z07:00). Overrides 'format' if both are set.", 30 | }, 31 | "unix": { 32 | Type: "boolean", 33 | Description: "If true, returns the current Unix timestamp in seconds. Overrides 'format' if set.", 34 | }, 35 | "args": { 36 | Type: "string", 37 | Description: "Raw argument string forwarded to the underlying 'date' command. Use only if other flags are not sufficient.", 38 | }, 39 | }, 40 | }, 41 | } 42 | 43 | func (d DateTool) Call(input pub_models.Input) (string, error) { 44 | var args []string 45 | 46 | // Highest priority: unix or rfc3339 helper flags (agent-friendly) 47 | if v, ok := input["unix"].(bool); ok && v { 48 | args = append(args, "+%s") 49 | } else if v, ok := input["rfc3339"].(bool); ok && v { 50 | args = append(args, "+%Y-%m-%dT%H:%M:%S%z") 51 | } else if format, ok := input["format"].(string); ok && format != "" { 52 | args = append(args, "+"+format) 53 | } 54 | 55 | // Raw args (lowest level escape hatch) 56 | if raw, ok := input["args"].(string); ok && raw != "" { 57 | // Let the user fully control arguments; do not mix with above 58 | args = []string{raw} 59 | } 60 | 61 | cmd := exec.Command("date", args...) 
62 | 63 | // Support UTC via env var; avoids shell wrapping 64 | if v, ok := input["utc"].(bool); ok && v { 65 | cmd.Env = append(cmd.Env, "TZ=UTC") 66 | } 67 | 68 | output, err := cmd.CombinedOutput() 69 | if err != nil { 70 | return "", fmt.Errorf("failed to run date: %w, output: %v", err, string(output)) 71 | } 72 | return string(output), nil 73 | } 74 | 75 | func (d DateTool) Specification() pub_models.Specification { 76 | return pub_models.Specification(Date) 77 | } 78 | -------------------------------------------------------------------------------- /internal/tools/programming_tool_rows_between.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "bufio" 5 | "fmt" 6 | "os" 7 | "strconv" 8 | "strings" 9 | 10 | pub_models "github.com/baalimago/clai/pkg/text/models" 11 | ) 12 | 13 | type RowsBetweenTool pub_models.Specification 14 | 15 | var RowsBetween = RowsBetweenTool{ 16 | Name: "rows_between", 17 | Description: "Fetch the lines between two line numbers (inclusive) from a file.", 18 | Inputs: &pub_models.InputSchema{ 19 | Type: "object", 20 | Properties: map[string]pub_models.ParameterObject{ 21 | "file_path": { 22 | Type: "string", 23 | Description: "The path to the file to read.", 24 | }, 25 | "start_line": { 26 | Type: "integer", 27 | Description: "First line to include (1-based, inclusive).", 28 | }, 29 | "end_line": { 30 | Type: "integer", 31 | Description: "Last line to include (1-based, inclusive).", 32 | }, 33 | }, 34 | Required: []string{"file_path", "start_line", "end_line"}, 35 | }, 36 | } 37 | 38 | func (r RowsBetweenTool) Call(input pub_models.Input) (string, error) { 39 | filePath, ok := input["file_path"].(string) 40 | if !ok { 41 | return "", fmt.Errorf("file_path must be a string") 42 | } 43 | startLine, ok := input["start_line"].(int) 44 | if !ok { 45 | // Accept float64 (from JSON decoding) 46 | if f, isFloat := input["start_line"].(float64); isFloat { 47 | startLine = 
int(f) 48 | } else if s, isString := input["start_line"].(string); isString { 49 | startLine, _ = strconv.Atoi(s) 50 | } else { 51 | return "", fmt.Errorf("start_line must be an integer") 52 | } 53 | } 54 | endLine, ok := input["end_line"].(int) 55 | if !ok { 56 | if f, ok := input["end_line"].(float64); ok { 57 | endLine = int(f) 58 | } else if s, ok := input["end_line"].(string); ok { 59 | endLine, _ = strconv.Atoi(s) 60 | } else { 61 | return "", fmt.Errorf("end_line must be an integer") 62 | } 63 | } 64 | 65 | if startLine <= 0 || endLine < startLine { 66 | return "", fmt.Errorf("invalid line range") 67 | } 68 | 69 | file, err := os.Open(filePath) 70 | if err != nil { 71 | return "", fmt.Errorf("failed to open file: %w", err) 72 | } 73 | defer file.Close() 74 | 75 | var lines []string 76 | scanner := bufio.NewScanner(file) 77 | for i := 1; scanner.Scan(); i++ { 78 | if i >= startLine && i <= endLine { 79 | lineWithNumber := fmt.Sprintf("%d: %s", i, scanner.Text()) 80 | lines = append(lines, lineWithNumber) 81 | } 82 | if i > endLine { 83 | break 84 | } 85 | } 86 | if err := scanner.Err(); err != nil { 87 | return "", fmt.Errorf("failed to scan file: %w", err) 88 | } 89 | 90 | return strings.Join(lines, "\n"), nil 91 | } 92 | 93 | func (r RowsBetweenTool) Specification() pub_models.Specification { 94 | return pub_models.Specification(RowsBetween) 95 | } 96 | -------------------------------------------------------------------------------- /internal/photo/prompt_additional_test.go: -------------------------------------------------------------------------------- 1 | package photo 2 | 3 | import ( 4 | "flag" 5 | "os" 6 | "path" 7 | "strings" 8 | "testing" 9 | 10 | "github.com/baalimago/clai/internal/chat" 11 | pub_models "github.com/baalimago/clai/pkg/text/models" 12 | ) 13 | 14 | func withFlagArgs(t *testing.T, args []string, fn func()) { 15 | old := flag.CommandLine 16 | t.Cleanup(func() { flag.CommandLine = old }) 17 | flag.CommandLine = flag.NewFlagSet("test", 
flag.ContinueOnError) 18 | _ = flag.CommandLine.Parse(args) 19 | fn() 20 | } 21 | 22 | func TestSetupPrompts_ArgsOnly(t *testing.T) { 23 | withFlagArgs(t, []string{"cmd", "hello", "world"}, func() { 24 | c := &Configurations{PromptFormat: "%v"} 25 | if err := c.SetupPrompts(); err != nil { 26 | t.Fatalf("SetupPrompts error: %v", err) 27 | } 28 | if got, want := c.Prompt, "hello world"; got != want { 29 | t.Fatalf("got prompt %q, want %q", got, want) 30 | } 31 | }) 32 | } 33 | 34 | func TestSetupPrompts_StdinOnly(t *testing.T) { 35 | // Prepare stdin pipe 36 | oldStdin := os.Stdin 37 | r, w, err := os.Pipe() 38 | if err != nil { 39 | t.Fatal(err) 40 | } 41 | os.Stdin = r 42 | t.Cleanup(func() { os.Stdin = oldStdin }) 43 | _, _ = w.WriteString("piped content") 44 | _ = w.Close() 45 | 46 | withFlagArgs(t, []string{"cmd"}, func() { 47 | c := &Configurations{PromptFormat: "==%v=="} 48 | if err := c.SetupPrompts(); err != nil { 49 | t.Fatalf("SetupPrompts error: %v", err) 50 | } 51 | if got, want := c.Prompt, "==piped content=="; got != want { 52 | t.Fatalf("got prompt %q, want %q", got, want) 53 | } 54 | }) 55 | } 56 | 57 | func TestSetupPrompts_ReplyModePrependsMessages(t *testing.T) { 58 | // Point config dir to a temp XDG config home 59 | tmp := t.TempDir() 60 | oldEnv := os.Getenv("XDG_CONFIG_HOME") 61 | t.Cleanup(func() { _ = os.Setenv("XDG_CONFIG_HOME", oldEnv) }) 62 | _ = os.Setenv("XDG_CONFIG_HOME", tmp) 63 | claiConfDir := path.Join(tmp, ".clai") 64 | if err := os.MkdirAll(path.Join(claiConfDir, "conversations"), 0o755); err != nil { 65 | t.Fatalf("mkdir: %v", err) 66 | } 67 | 68 | msgs := []pub_models.Message{ 69 | {Role: "user", Content: "hi there"}, 70 | } 71 | if err := chat.SaveAsPreviousQuery(claiConfDir, msgs); err != nil { 72 | t.Fatalf("save prev query: %v", err) 73 | } 74 | 75 | withFlagArgs(t, []string{"cmd", "hello"}, func() { 76 | c := &Configurations{PromptFormat: "%v", ReplyMode: true} 77 | if err := c.SetupPrompts(); err != nil { 78 | 
t.Fatalf("SetupPrompts error: %v", err) 79 | } 80 | if !strings.Contains(c.Prompt, "Messages:") || !strings.Contains(c.Prompt, "\"hi there\"") { 81 | t.Fatalf("expected previous messages embedded, got: %q", c.Prompt) 82 | } 83 | if !strings.HasSuffix(c.Prompt, "hello") { 84 | t.Fatalf("expected prompt to end with formatted args, got: %q", c.Prompt) 85 | } 86 | }) 87 | } 88 | -------------------------------------------------------------------------------- /internal/tools/clai_tool_wait_for_workers_test.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "bytes" 5 | "os/exec" 6 | "strings" 7 | "testing" 8 | "time" 9 | ) 10 | 11 | // TestClaiWaitForWorkers_NoWorkers ensures that the tool behaves when there are no workers. 12 | func TestClaiWaitForWorkers_NoWorkers(t *testing.T) { 13 | tool := &claiWaitForWorkersTool{} 14 | 15 | out, err := tool.Call(map[string]any{"timeout_seconds": 1}) 16 | if err != nil { 17 | t.Fatalf("expected no error, got %v", err) 18 | } 19 | if !strings.Contains(out, "No active workers") { 20 | t.Fatalf("unexpected output: %q", out) 21 | } 22 | } 23 | 24 | // TestClaiWaitForWorkers_WaitsAndAggregates simulates a small number of workers and verifies aggregation. 
25 | func TestClaiWaitForWorkers_WaitsAndAggregates(t *testing.T) { 26 | // Setup fake workers 27 | claiRunsMu.Lock() 28 | claiRuns = map[string]*claiProcess{ 29 | "worker1": { 30 | cmd: &exec.Cmd{}, 31 | stdout: bytes.NewBufferString("output1"), 32 | stderr: bytes.NewBufferString("err1"), 33 | done: true, 34 | exitCode: 0, 35 | }, 36 | "worker2": { 37 | cmd: &exec.Cmd{}, 38 | stdout: bytes.NewBufferString("output2"), 39 | stderr: bytes.NewBufferString("err2"), 40 | done: true, 41 | exitCode: 1, 42 | }, 43 | } 44 | claiRunsMu.Unlock() 45 | 46 | tool := &claiWaitForWorkersTool{} 47 | 48 | out, err := tool.Call(map[string]any{"timeout_seconds": 1}) 49 | if err != nil { 50 | t.Fatalf("expected no error, got %v", err) 51 | } 52 | 53 | if !strings.Contains(out, "worker1") || !strings.Contains(out, "worker2") { 54 | t.Fatalf("expected output to contain both worker IDs, got: %q", out) 55 | } 56 | if !strings.Contains(out, "COMPLETED") || !strings.Contains(out, "FAILED") { 57 | t.Fatalf("expected output to contain statuses, got: %q", out) 58 | } 59 | } 60 | 61 | // TestClaiWaitForWorkers_Timeout ensures that a timeout results in an error and sends interrupts. 
62 | func TestClaiWaitForWorkers_Timeout(t *testing.T) { 63 | cmd := exec.Command("sleep", "10") 64 | stdout := &bytes.Buffer{} 65 | stderr := &bytes.Buffer{} 66 | cmd.Stdout = stdout 67 | cmd.Stderr = stderr 68 | 69 | process := &claiProcess{ 70 | cmd: cmd, 71 | stdout: stdout, 72 | stderr: stderr, 73 | } 74 | 75 | if err := cmd.Start(); err != nil { 76 | t.Skipf("unable to start sleep command: %v", err) 77 | } 78 | 79 | claiRunsMu.Lock() 80 | claiRuns = map[string]*claiProcess{"worker1": process} 81 | claiRunsMu.Unlock() 82 | 83 | tool := &claiWaitForWorkersTool{} 84 | 85 | start := time.Now() 86 | _, err := tool.Call(map[string]any{"timeout_seconds": 0.5}) 87 | if err == nil { 88 | t.Fatalf("expected timeout error, got nil") 89 | } 90 | dur := time.Since(start) 91 | if dur < 500*time.Millisecond { 92 | t.Fatalf("expected to wait at least 500ms, waited %v", dur) 93 | } 94 | } 95 | -------------------------------------------------------------------------------- /internal/tools/bash_tool_rg.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "fmt" 5 | "os/exec" 6 | 7 | pub_models "github.com/baalimago/clai/pkg/text/models" 8 | ) 9 | 10 | type RipGrepTool pub_models.Specification 11 | 12 | var RipGrep = RipGrepTool{ 13 | Name: "rg", 14 | Description: "Search for a pattern in files using ripgrep.", 15 | Inputs: &pub_models.InputSchema{ 16 | Type: "object", 17 | Properties: map[string]pub_models.ParameterObject{ 18 | "pattern": { 19 | Type: "string", 20 | Description: "The pattern to search for.", 21 | }, 22 | "path": { 23 | Type: "string", 24 | Description: "The path to search in.", 25 | }, 26 | "case_sensitive": { 27 | Type: "boolean", 28 | Description: "Whether the search should be case sensitive.", 29 | }, 30 | "line_number": { 31 | Type: "boolean", 32 | Description: "Whether to show line numbers.", 33 | }, 34 | "hidden": { 35 | Type: "boolean", 36 | Description: "Whether to search hidden 
files and directories.", 37 | }, 38 | }, 39 | Required: []string{"pattern"}, 40 | }, 41 | } 42 | 43 | func (r RipGrepTool) Call(input pub_models.Input) (string, error) { 44 | pattern, ok := input["pattern"].(string) 45 | if !ok { 46 | return "", fmt.Errorf("pattern must be a string") 47 | } 48 | cmd := exec.Command("rg", pattern) 49 | if input["path"] != nil { 50 | path, ok := input["path"].(string) 51 | if !ok { 52 | return "", fmt.Errorf("path must be a string") 53 | } 54 | cmd.Args = append(cmd.Args, path) 55 | } 56 | if input["case_sensitive"] != nil { 57 | caseSensitive, ok := input["case_sensitive"].(bool) 58 | if !ok { 59 | return "", fmt.Errorf("case_sensitive must be a boolean") 60 | } 61 | if caseSensitive { 62 | cmd.Args = append(cmd.Args, "--case-sensitive") 63 | } 64 | } 65 | if input["line_number"] != nil { 66 | lineNumber, ok := input["line_number"].(bool) 67 | if !ok { 68 | return "", fmt.Errorf("line_number must be a boolean") 69 | } 70 | if lineNumber { 71 | cmd.Args = append(cmd.Args, "--line-number") 72 | } 73 | } 74 | if input["hidden"] != nil { 75 | hidden, ok := input["hidden"].(bool) 76 | if !ok { 77 | return "", fmt.Errorf("hidden must be a boolean") 78 | } 79 | if hidden { 80 | cmd.Args = append(cmd.Args, "--hidden") 81 | } 82 | } 83 | output, err := cmd.CombinedOutput() 84 | if err != nil { 85 | // exit status 1 is not found, and not to be considered an error 86 | if err.Error() == "exit status 1" { 87 | err = nil 88 | output = []byte(fmt.Sprintf("found no hits with pattern: '%s'", pattern)) 89 | } else { 90 | return "", fmt.Errorf("failed to run rg: %w, output: %v", err, string(output)) 91 | } 92 | } 93 | return string(output), nil 94 | } 95 | 96 | func (r RipGrepTool) Specification() pub_models.Specification { 97 | return pub_models.Specification(RipGrep) 98 | } 99 | -------------------------------------------------------------------------------- /internal/video/store_test.go: 
-------------------------------------------------------------------------------- 1 | package video 2 | 3 | import ( 4 | "encoding/base64" 5 | "os" 6 | "path/filepath" 7 | "strings" 8 | "testing" 9 | ) 10 | 11 | func TestSaveVideo(t *testing.T) { 12 | // Create a temporary directory for testing 13 | tmpDir, err := os.MkdirTemp("", "video_test") 14 | if err != nil { 15 | t.Fatalf("Failed to create temp dir: %v", err) 16 | } 17 | defer os.RemoveAll(tmpDir) 18 | 19 | // Create valid base64 data 20 | testContent := "Hello Video" 21 | b64Data := base64.StdEncoding.EncodeToString([]byte(testContent)) 22 | 23 | container := "mp4" 24 | 25 | t.Run("success write to dir", func(t *testing.T) { 26 | out := Output{ 27 | Dir: tmpDir, 28 | Prefix: "test", 29 | } 30 | 31 | filePath, err := SaveVideo(out, b64Data, container) 32 | if err != nil { 33 | t.Fatalf("SaveVideo failed: %v", err) 34 | } 35 | 36 | // Verify file exists 37 | if _, err := os.Stat(filePath); os.IsNotExist(err) { 38 | t.Errorf("File was not created at %v", filePath) 39 | } 40 | 41 | // Verify content 42 | content, err := os.ReadFile(filePath) 43 | if err != nil { 44 | t.Fatalf("Failed to read created file: %v", err) 45 | } 46 | if string(content) != testContent { 47 | t.Errorf("File content mismatch. 
Got %s, want %s", string(content), testContent) 48 | } 49 | 50 | // Check filename format 51 | baseName := filepath.Base(filePath) 52 | if !strings.HasPrefix(baseName, "test_") || !strings.HasSuffix(baseName, "."+container) { 53 | t.Errorf("Filename format incorrect: %v", baseName) 54 | } 55 | }) 56 | 57 | t.Run("invalid base64", func(t *testing.T) { 58 | out := Output{ 59 | Dir: tmpDir, 60 | Prefix: "test", 61 | } 62 | _, err := SaveVideo(out, "invalid-base64!!!!", container) 63 | if err == nil { 64 | t.Error("Expected error for invalid base64, got nil") 65 | } 66 | }) 67 | 68 | t.Run("fallback to tmp", func(t *testing.T) { 69 | // Use a non-existent directory to force error (and thus fallback) 70 | // Or a directory we can"t write to. 71 | // Using a nested non-existent dir usually causes write error unless MkdirAll involves, 72 | // but WriteFile doesn"t create parent dirs. 73 | nonExistentDir := filepath.Join(tmpDir, "nonexistent") 74 | 75 | out := Output{ 76 | Dir: nonExistentDir, 77 | Prefix: "fallback", 78 | } 79 | 80 | filePath, err := SaveVideo(out, b64Data, container) 81 | if err != nil { 82 | t.Fatalf("SaveVideo failed during fallback test: %v", err) 83 | } 84 | 85 | // Check that it wrote to /tmp 86 | if !strings.HasPrefix(filePath, "/tmp/") { 87 | t.Errorf("Expected fallback to /tmp, got path: %v", filePath) 88 | } 89 | 90 | // Verify file exists 91 | if _, err := os.Stat(filePath); os.IsNotExist(err) { 92 | t.Errorf("File was not created at fallback location %v", filePath) 93 | } 94 | 95 | // Clean up the file in /tmp 96 | defer os.Remove(filePath) 97 | }) 98 | } 99 | -------------------------------------------------------------------------------- /internal/tools/handler.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | 7 | pub_models "github.com/baalimago/clai/pkg/text/models" 8 | "github.com/baalimago/go_away_boilerplate/pkg/ancli" 9 | 
"github.com/baalimago/go_away_boilerplate/pkg/debug" 10 | "github.com/baalimago/go_away_boilerplate/pkg/misc" 11 | ) 12 | 13 | // Registry is the global registry of available LLM tools. 14 | var Registry = NewRegistry() 15 | 16 | // Init initializes the global Registry with available local LLM tools. 17 | // If the Registry has already been initialized, it simply returns. 18 | func Init() { 19 | if Registry.hasBeenInit { 20 | return 21 | } 22 | Registry.hasBeenInit = true 23 | Registry.Set(FileTree.Specification().Name, FileTree) 24 | Registry.Set(Cat.Specification().Name, Cat) 25 | Registry.Set(Find.Specification().Name, Find) 26 | Registry.Set(FileType.Specification().Name, FileType) 27 | Registry.Set(LS.Specification().Name, LS) 28 | Registry.Set(WebsiteText.Specification().Name, WebsiteText) 29 | Registry.Set(RipGrep.Specification().Name, RipGrep) 30 | Registry.Set(Go.Specification().Name, Go) 31 | Registry.Set(WriteFile.Specification().Name, WriteFile) 32 | Registry.Set(FreetextCmd.Specification().Name, FreetextCmd) 33 | Registry.Set(Sed.Specification().Name, Sed) 34 | Registry.Set(RowsBetween.Specification().Name, RowsBetween) 35 | Registry.Set(LineCount.Specification().Name, LineCount) 36 | Registry.Set(Git.Specification().Name, Git) 37 | Registry.Set(Recall.Specification().Name, Recall) 38 | Registry.Set(FFProbe.Specification().Name, FFProbe) 39 | Registry.Set(Date.Specification().Name, Date) 40 | Registry.Set(Pwd.Specification().Name, Pwd) 41 | Registry.Set(ClaiHelp.Specification().Name, ClaiHelp) 42 | Registry.Set(ClaiRun.Specification().Name, ClaiRun) 43 | Registry.Set(ClaiCheck.Specification().Name, ClaiCheck) 44 | Registry.Set(ClaiResult.Specification().Name, ClaiResult) 45 | Registry.Set(ClaiWaitForWorkers.Specification().Name, ClaiWaitForWorkers) 46 | Registry.Set(Date.Specification().Name, Date) 47 | } 48 | 49 | // Invoke the call, and gather both error and output in the same string 50 | func Invoke(call pub_models.Call) string { 51 | t, exists := 
Registry.Get(call.Name) 52 | if !exists { 53 | return "ERROR: unknown tool call: " + call.Name 54 | } 55 | if misc.Truthy(os.Getenv("DEBUG_CALL")) { 56 | ancli.Noticef("Invoke call: %v", debug.IndentedJsonFmt(call)) 57 | } 58 | inp := pub_models.Input{} 59 | if call.Inputs != nil { 60 | inp = *call.Inputs 61 | } 62 | out, err := t.Call(inp) 63 | if err != nil { 64 | return fmt.Sprintf("ERROR: failed to run tool: %v, error: %v", call.Name, err) 65 | } 66 | return out 67 | } 68 | 69 | // ToolFromName looks at the static tools.Tools map 70 | func ToolFromName(name string) pub_models.Specification { 71 | t, exists := Registry.Get(name) 72 | if !exists { 73 | return pub_models.Specification{} 74 | } 75 | return t.Specification() 76 | } 77 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "0.2.0", 3 | "configurations": [ 4 | { 5 | "name": "ChatGPT - Query", 6 | "type": "go", 7 | "request": "launch", 8 | "program": "${workspaceFolder}", 9 | "args": [ 10 | "-cm", 11 | "gpt-4-turbo-preview", 12 | "query", 13 | "I'm debugging my cli ai too. Write a short consice response." 14 | ], 15 | "env": { 16 | "NO_COLOR": "true" 17 | } 18 | }, 19 | { 20 | "name": "ChatGPT - Query - Tool", 21 | "type": "go", 22 | "request": "launch", 23 | "program": "${workspaceFolder}", 24 | "args": [ 25 | "-t", 26 | "-cm", 27 | "gpt-4-turbo", 28 | "query", 29 | "try to call the file tree command using /home/imago as agument, i'm debugging this functionality." 
30 | ], 31 | "env": { 32 | "NO_COLOR": "true" 33 | } 34 | }, 35 | { 36 | "name": "Claude - Query", 37 | "type": "go", 38 | "request": "launch", 39 | "program": "${workspaceFolder}", 40 | "args": [ 41 | "query", 42 | "test" 43 | ], 44 | "env": { 45 | "NO_COLOR": "true" 46 | } 47 | }, 48 | { 49 | "name": "Claude - Query - Tool", 50 | "type": "go", 51 | "request": "launch", 52 | "program": "${workspaceFolder}", 53 | "args": [ 54 | "-t", 55 | "-cm", 56 | "claude-3-opus-20240229", 57 | "query", 58 | "try to call the file ls command on ~/, i'm debugging this functionality." 59 | ], 60 | "env": { 61 | "NO_COLOR": "true" 62 | } 63 | }, 64 | { 65 | "name": "ChatGPT - Chat - GlobMode", 66 | "type": "go", 67 | "request": "launch", 68 | "program": "${workspaceFolder}", 69 | "args": [ 70 | "-cm", 71 | "gpt-4o", 72 | "-glob", 73 | "README.md", 74 | "chat", 75 | "new", 76 | "Explain this project in 5 words" 77 | ], 78 | "env": { 79 | "NO_COLOR": "true" 80 | } 81 | }, 82 | { 83 | "name": "ChatGPT - Chat - Cmd", 84 | "type": "go", 85 | "request": "launch", 86 | "program": "${workspaceFolder}", 87 | "args": [ 88 | "-cm", 89 | "gpt-4o", 90 | "cmd", 91 | "give me a command to show my current directory" 92 | ], 93 | "env": { 94 | "NO_COLOR": "true" 95 | } 96 | }, 97 | { 98 | "name": "ChatGPT - Query - MCP", 99 | "type": "go", 100 | "request": "launch", 101 | "program": "${workspaceFolder}", 102 | "args": [ 103 | "-t", 104 | "q", 105 | "Use my browser a little" 106 | ], 107 | "env": { 108 | "NO_COLOR": "true" 109 | } 110 | } 111 | ] 112 | } 113 | -------------------------------------------------------------------------------- /internal/tools/registry.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "os" 5 | "strings" 6 | "sync" 7 | 8 | "github.com/baalimago/go_away_boilerplate/pkg/ancli" 9 | "github.com/baalimago/go_away_boilerplate/pkg/misc" 10 | ) 11 | 12 | // registry is a threadsafe storage for LLMTools. 
13 | type registry struct { 14 | mu sync.RWMutex 15 | tools map[string]LLMTool 16 | debug bool 17 | hasBeenInit bool 18 | } 19 | 20 | // NewRegistry returns an empty tools registry. 21 | func NewRegistry() *registry { 22 | return ®istry{tools: make(map[string]LLMTool), debug: misc.Truthy(os.Getenv("DEBUG"))} 23 | } 24 | 25 | // Get returns the tool registered under name. 26 | func (r *registry) Get(name string) (LLMTool, bool) { 27 | r.mu.RLock() 28 | defer r.mu.RUnlock() 29 | t, ok := r.tools[name] 30 | return t, ok 31 | } 32 | 33 | // Add to registry.go 34 | func (r *registry) WildcardGet(pattern string) []LLMTool { 35 | r.mu.RLock() 36 | defer r.mu.RUnlock() 37 | 38 | var matches []LLMTool 39 | for name, tool := range r.tools { 40 | if WildcardMatch(pattern, name) { 41 | matches = append(matches, tool) 42 | } 43 | } 44 | return matches 45 | } 46 | 47 | func WildcardMatch(pattern, name string) bool { 48 | if pattern == "*" { 49 | return true 50 | } 51 | 52 | // Simple wildcard matching - supports * at start, end, or middle 53 | if strings.HasPrefix(pattern, "*") && strings.HasSuffix(pattern, "*") { 54 | // *substring* 55 | substr := pattern[1 : len(pattern)-1] 56 | return strings.Contains(name, substr) 57 | } else if strings.HasPrefix(pattern, "*") { 58 | // *suffix 59 | suffix := pattern[1:] 60 | return strings.HasSuffix(name, suffix) 61 | } else if strings.HasSuffix(pattern, "*") { 62 | // prefix* 63 | prefix := pattern[:len(pattern)-1] 64 | return strings.HasPrefix(name, prefix) 65 | } 66 | 67 | // No wildcards - exact match 68 | return pattern == name 69 | } 70 | 71 | // Set registers tool under the provided name. 72 | func (r *registry) Set(name string, t LLMTool) { 73 | r.mu.Lock() 74 | if strings.Contains(name, "printEnv") { 75 | ancli.Warnf("found env printing tool, skipping for security's sake. 
Tool name: '%v'", name) 76 | } 77 | if r.debug || misc.Truthy(os.Getenv("DEBUG_TOOLS_REGISTRY_SET")) { 78 | ancli.Okf("adding tool too registry, name: %v\n", t.Specification().Name) 79 | } 80 | r.tools[name] = t 81 | r.mu.Unlock() 82 | } 83 | 84 | // All returns a copy of all registered tools keyed by name. 85 | func (r *registry) All() map[string]LLMTool { 86 | r.mu.RLock() 87 | defer r.mu.RUnlock() 88 | cp := make(map[string]LLMTool, len(r.tools)) 89 | for k, v := range r.tools { 90 | cp[k] = v 91 | } 92 | return cp 93 | } 94 | 95 | // Reset removes all registered tools. Primarily used for tests. 96 | func (r *registry) Reset() { 97 | r.mu.Lock() 98 | r.tools = make(map[string]LLMTool) 99 | r.mu.Unlock() 100 | } 101 | -------------------------------------------------------------------------------- /internal/utils/prompt_test.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "os" 5 | "testing" 6 | ) 7 | 8 | func TestPrompt(t *testing.T) { 9 | testCases := []struct { 10 | name string 11 | stdinReplace string 12 | args []string 13 | stdin string 14 | expectedPrompt string 15 | expectedError bool 16 | }{ 17 | { 18 | name: "No arguments and no stdin", 19 | stdinReplace: "", 20 | args: []string{""}, 21 | stdin: "", 22 | expectedPrompt: "", 23 | expectedError: true, 24 | }, 25 | { 26 | name: "Arguments only", 27 | stdinReplace: "", 28 | args: []string{"cmd", "arg1", "arg2"}, 29 | stdin: "", 30 | expectedPrompt: "arg1 arg2", 31 | expectedError: false, 32 | }, 33 | { 34 | name: "Stdin only", 35 | stdinReplace: "", 36 | args: []string{"cmd"}, 37 | stdin: "input from stdin", 38 | expectedPrompt: "input from stdin", 39 | expectedError: false, 40 | }, 41 | { 42 | name: "Arguments and stdin", 43 | stdinReplace: "{}", 44 | args: []string{"cmd", "arg1", "arg2", "{}"}, 45 | stdin: "input from stdin", 46 | expectedPrompt: "arg1 arg2 input from stdin", 47 | expectedError: false, 48 | }, 49 | { 50 | name: 
"Arguments with stdinReplace", 51 | stdinReplace: "", 52 | args: []string{"cmd", "prefix", "", "suffix"}, 53 | stdin: "input from stdin", 54 | expectedPrompt: "prefix input from stdin suffix", 55 | expectedError: false, 56 | }, 57 | { 58 | name: "Arguments with stdinReplace", 59 | stdinReplace: "", 60 | args: []string{"cmd", "prefix", "suffix"}, 61 | stdin: "input from stdin", 62 | expectedPrompt: "prefix suffix input from stdin", 63 | expectedError: false, 64 | }, 65 | } 66 | 67 | for _, tc := range testCases { 68 | t.Run(tc.name, func(t *testing.T) { 69 | if tc.stdin != "" { 70 | // Set up stdin 71 | oldStdin := os.Stdin 72 | t.Cleanup(func() { os.Stdin = oldStdin }) 73 | r, w, err := os.Pipe() 74 | if err != nil { 75 | t.Fatal(err) 76 | } 77 | os.Stdin = r 78 | _, err = w.WriteString(tc.stdin) 79 | if err != nil { 80 | t.Fatal(err) 81 | } 82 | w.Close() 83 | } 84 | 85 | // Call the function 86 | prompt, err := Prompt(tc.stdinReplace, tc.args) 87 | 88 | // Check the error 89 | if tc.expectedError && err == nil { 90 | t.Error("Expected an error, but got nil") 91 | } else if !tc.expectedError && err != nil { 92 | t.Errorf("Unexpected error: %v", err) 93 | } 94 | 95 | // Check the prompt 96 | if prompt != tc.expectedPrompt { 97 | t.Errorf("Prompt mismatch. 
Expected: %q, Got: %q", tc.expectedPrompt, prompt) 98 | } 99 | }) 100 | } 101 | } 102 | -------------------------------------------------------------------------------- /internal/vendors/mistral/mistral.go: -------------------------------------------------------------------------------- 1 | package mistral 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 7 | "github.com/baalimago/clai/internal/models" 8 | "github.com/baalimago/clai/internal/text/generic" 9 | "github.com/baalimago/clai/internal/tools" 10 | "github.com/baalimago/clai/internal/utils" 11 | pub_models "github.com/baalimago/clai/pkg/text/models" 12 | "github.com/baalimago/go_away_boilerplate/pkg/ancli" 13 | ) 14 | 15 | const MistralURL = "https://api.mistral.ai/v1/chat/completions" 16 | 17 | var Default = Mistral{ 18 | Model: "mistral-large-latest", 19 | Temperature: 0.7, 20 | TopP: 1.0, 21 | URL: MistralURL, 22 | MaxTokens: 100000, 23 | } 24 | 25 | type Mistral struct { 26 | generic.StreamCompleter 27 | Model string `json:"model"` 28 | URL string `json:"url"` 29 | TopP float64 `json:"top_p"` 30 | Temperature float64 `json:"temperature"` 31 | SafePrompt bool `json:"safe_prompt"` 32 | MaxTokens int `json:"max_tokens"` 33 | RandomSeed int `json:"random_seed"` 34 | } 35 | 36 | func clean(msg []pub_models.Message) []pub_models.Message { 37 | // Mistral doesn't like additional fields in the tools call 38 | for i, m := range msg { 39 | if m.Role == "assistant" { 40 | if len(m.ToolCalls) > 0 { 41 | m.Content = "" 42 | } 43 | for j, tc := range m.ToolCalls { 44 | tc.Name = "" 45 | tc.Inputs = nil 46 | tc.Function.Description = "" 47 | tc.Function.Inputs = nil 48 | tc.ExtraContent = nil 49 | m.ToolCalls[j] = tc 50 | } 51 | } 52 | msg[i] = m 53 | } 54 | 55 | for i := 0; i < len(msg)-1; i++ { 56 | if msg[i].Role == "tool" && msg[i+1].Role == "system" { 57 | msg[i+1].Role = "assistant" 58 | } 59 | } 60 | 61 | // Merge consequtive assistant messages 62 | for i := 1; i < len(msg); i++ { 63 | if msg[i].Role == "assistant" 
&& msg[i-1].Role == "assistant" { 64 | msg[i-1].Content += "\n" + msg[i].Content 65 | nMsg, err := utils.DeleteRange(msg, i, i) 66 | if err != nil { 67 | ancli.Errf("failed to delete range. No error management here... Not great. Why error here? Stop please...: %v", err) 68 | } 69 | msg = nMsg 70 | i-- 71 | } 72 | } 73 | 74 | return msg 75 | } 76 | 77 | func (m *Mistral) Setup() error { 78 | err := m.StreamCompleter.Setup("MISTRAL_API_KEY", MistralURL, "DEBUG_MISTRAL") 79 | if err != nil { 80 | return fmt.Errorf("failed to setup stream completer: %w", err) 81 | } 82 | m.StreamCompleter.Model = m.Model 83 | m.StreamCompleter.MaxTokens = &m.MaxTokens 84 | m.StreamCompleter.Temperature = &m.Temperature 85 | m.StreamCompleter.TopP = &m.TopP 86 | toolChoice := "any" 87 | m.ToolChoice = &toolChoice 88 | m.Clean = clean 89 | 90 | return nil 91 | } 92 | 93 | func (m *Mistral) StreamCompletions(ctx context.Context, chat pub_models.Chat) (chan models.CompletionEvent, error) { 94 | return m.StreamCompleter.StreamCompletions(ctx, chat) 95 | } 96 | 97 | func (m *Mistral) RegisterTool(tool tools.LLMTool) { 98 | m.InternalRegisterTool(tool) 99 | } 100 | -------------------------------------------------------------------------------- /pkg/text/models/chat_test.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | import ( 4 | "encoding/json" 5 | "testing" 6 | "time" 7 | ) 8 | 9 | func TestMessageJSON(t *testing.T) { 10 | // Test round-trip with simple string content 11 | simpleMsg := Message{Role: "user", Content: "hello world"} 12 | data, err := json.Marshal(simpleMsg) 13 | if err != nil { 14 | t.Fatalf("failed to marshal simple message: %v", err) 15 | } 16 | var decodedSimple Message 17 | if err := json.Unmarshal(data, &decodedSimple); err != nil { 18 | t.Fatalf("failed to unmarshal simple message: %v", err) 19 | } 20 | if decodedSimple.Role != simpleMsg.Role || decodedSimple.Content != simpleMsg.Content { 21 | 
t.Errorf("simple message roundtrip mismatch. got: %+v, want: %+v", decodedSimple, simpleMsg) 22 | } 23 | if len(decodedSimple.ContentParts) != 0 { 24 | t.Errorf("expected nil/empty ContentParts, got %v", decodedSimple.ContentParts) 25 | } 26 | 27 | // Test round-trip with ContentParts 28 | partsMsg := Message{ 29 | Role: "user", 30 | ContentParts: []ImageOrTextInput{ 31 | {Type: "text", Text: "describe this image"}, 32 | {Type: "image_url", ImageB64: &ImageURL{URL: "http://example.com/img.png", Detail: "high"}}, 33 | }, 34 | } 35 | data, err = json.Marshal(partsMsg) 36 | if err != nil { 37 | t.Fatalf("failed to marshal parts message: %v", err) 38 | } 39 | var decodedParts Message 40 | if err := json.Unmarshal(data, &decodedParts); err != nil { 41 | t.Fatalf("failed to unmarshal parts message: %v", err) 42 | } 43 | if decodedParts.Role != partsMsg.Role { 44 | t.Errorf("parts message role mismatch. got: %v, want: %v", decodedParts.Role, partsMsg.Role) 45 | } 46 | if decodedParts.Content != "" { 47 | t.Errorf("expected empty Content, got %q", decodedParts.Content) 48 | } 49 | if len(decodedParts.ContentParts) != 2 { 50 | t.Fatalf("expected 2 content parts, got %d", len(decodedParts.ContentParts)) 51 | } 52 | if decodedParts.ContentParts[0].Text != "describe this image" { 53 | t.Errorf("expected text part match, got %v", decodedParts.ContentParts[0]) 54 | } 55 | if decodedParts.ContentParts[1].ImageB64.URL != "http://example.com/img.png" { 56 | t.Errorf("expected image url match, got %v", decodedParts.ContentParts[1].ImageB64) 57 | } 58 | } 59 | 60 | func TestChatHelpers(t *testing.T) { 61 | c := Chat{ 62 | Created: time.Now(), 63 | ID: "id1", 64 | Messages: []Message{ 65 | {Role: "system", Content: "sys"}, 66 | {Role: "user", Content: "u1"}, 67 | {Role: "assistant", Content: "a"}, 68 | {Role: "user", Content: "u2"}, 69 | }, 70 | } 71 | 72 | // First system 73 | if m, err := c.FirstSystemMessage(); err != nil || m.Content != "sys" { 74 | t.Fatalf("FirstSystemMessage 
unexpected: %v, %v", m, err) 75 | } 76 | // First user 77 | if m, err := c.FirstUserMessage(); err != nil || m.Content != "u1" { 78 | t.Fatalf("FirstUserMessage unexpected: %v, %v", m, err) 79 | } 80 | // Last of role 81 | m, idx, err := c.LastOfRole("user") 82 | if err != nil || m.Content != "u2" || idx != 3 { 83 | t.Fatalf("LastOfRole unexpected: %v, %v, %d", m, err, idx) 84 | } 85 | // Missing role 86 | if _, _, err := c.LastOfRole("none"); err == nil { 87 | t.Fatalf("expected error for missing role") 88 | } 89 | } 90 | -------------------------------------------------------------------------------- /pkg/text/full.go: -------------------------------------------------------------------------------- 1 | package text 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "os" 7 | "path" 8 | 9 | "github.com/baalimago/clai/internal" 10 | priv_models "github.com/baalimago/clai/internal/models" 11 | "github.com/baalimago/clai/internal/text" 12 | "github.com/baalimago/clai/pkg/text/models" 13 | ) 14 | 15 | // FullResponse text querier, as opposed to returning a stream or something 16 | type FullResponse interface { 17 | Setup(context.Context) error 18 | 19 | // Query the underlying llm with some prompt. Will cancel on context cancel. 
20 | Query(context.Context, models.Chat) (models.Chat, error) 21 | } 22 | 23 | type publicQuerier struct { 24 | conf text.Configurations 25 | querier priv_models.ChatQuerier 26 | } 27 | 28 | func NewFullResponseQuerier(c models.Configurations) FullResponse { 29 | return &publicQuerier{ 30 | conf: pubConfigToInternal(c), 31 | } 32 | } 33 | 34 | func internalToolsToString(in []models.ToolName) (ret []string) { 35 | for _, s := range in { 36 | ret = append(ret, string(s)) 37 | } 38 | return 39 | } 40 | 41 | func pubConfigToInternal(c models.Configurations) text.Configurations { 42 | claiDir := path.Join(c.ConfigDir, "clai") 43 | 44 | return text.Configurations{ 45 | Model: c.Model, 46 | SystemPrompt: c.SystemPrompt, 47 | UseTools: true, 48 | ConfigDir: claiDir, 49 | TokenWarnLimit: 300000, 50 | ToolOutputRuneLimit: 30000, 51 | SaveReplyAsConv: true, 52 | Stream: true, 53 | UseProfile: "", 54 | ProfilePath: "", 55 | Tools: internalToolsToString(c.InternalTools), 56 | } 57 | } 58 | 59 | // Setup the public querier by creating a config dir + supportive directories, then by initiating 60 | // the querier following the config 61 | func (pq *publicQuerier) Setup(ctx context.Context) error { 62 | if _, err := os.Stat(pq.conf.ConfigDir); os.IsNotExist(err) { 63 | os.Mkdir(pq.conf.ConfigDir, 0o755) 64 | } 65 | mcpServersDir := path.Join(pq.conf.ConfigDir, "mcpServers") 66 | if _, err := os.Stat(mcpServersDir); os.IsNotExist(err) { 67 | os.Mkdir(mcpServersDir, 0o755) 68 | } 69 | conversationsDir := path.Join(pq.conf.ConfigDir, "conversations") 70 | if _, err := os.Stat(mcpServersDir); os.IsNotExist(err) { 71 | os.Mkdir(conversationsDir, 0o755) 72 | } 73 | querier, err := internal.CreateTextQuerier(ctx, pq.conf) 74 | if err != nil { 75 | return fmt.Errorf("publicQuerier.Setup failed to CreateTextQuerier: %v", err) 76 | } 77 | tq, isChatQuerier := querier.(priv_models.ChatQuerier) 78 | if !isChatQuerier { 79 | return fmt.Errorf("failed to cast Querier using model: '%v' to 
TextQuerier, cannot proceed", pq.conf.Model) 80 | } 81 | pq.querier = tq 82 | return nil 83 | } 84 | 85 | // Query the model with some input chat. Will return a chat containing updated responses. The returning chat may 86 | // append multiple messages to the chat, if the querier is configured to be agentic (use tools) 87 | func (pq *publicQuerier) Query(ctx context.Context, inpChat models.Chat) (models.Chat, error) { 88 | err := pq.Setup(ctx) 89 | if err != nil { 90 | return models.Chat{}, fmt.Errorf("pq.Query failed to Setup clone: %v", err) 91 | } 92 | return pq.querier.TextQuery(ctx, inpChat) 93 | } 94 | -------------------------------------------------------------------------------- /internal/glob/glob.go: -------------------------------------------------------------------------------- 1 | package glob 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "path/filepath" 7 | "strings" 8 | 9 | "github.com/baalimago/clai/internal/utils" 10 | pub_models "github.com/baalimago/clai/pkg/text/models" 11 | "github.com/baalimago/go_away_boilerplate/pkg/ancli" 12 | "github.com/baalimago/go_away_boilerplate/pkg/misc" 13 | ) 14 | 15 | // Setup the glob parsing. Currently this is a bit messy as it works 16 | // both for flag glob and arg glob. Once arg glob is deprecated, this 17 | // function may be cleaned up 18 | func Setup(flagGlob string, args []string) (string, []string, error) { 19 | globArg := args[0] == "g" || args[0] == "glob" 20 | if globArg && len(args) < 2 { 21 | return "", args, fmt.Errorf("not enough arguments provided") 22 | } 23 | glob := args[1] 24 | if globArg { 25 | if flagGlob != "" { 26 | ancli.PrintWarn(fmt.Sprintf("both glob-arg and glob-flag is specified. This is confusing. Using glob-arg query: %v\n", glob)) 27 | } 28 | args = args[1:] 29 | } else { 30 | glob = flagGlob 31 | } 32 | if !strings.Contains(glob, "*") { 33 | ancli.PrintWarn(fmt.Sprintf("found no '*' in glob: %v, has it already been expanded? 
Consider enclosing glob in single quotes\n", glob)) 34 | } 35 | if misc.Truthy(os.Getenv("DEBUG")) { 36 | ancli.PrintOK(fmt.Sprintf("found glob: %v\n", glob)) 37 | } 38 | return glob, args, nil 39 | } 40 | 41 | func CreateChat(glob, systemPrompt string) (pub_models.Chat, error) { 42 | fileMessages, err := parseGlob(glob) 43 | if err != nil { 44 | return pub_models.Chat{}, fmt.Errorf("failed to parse glob string: '%v', err: %w", glob, err) 45 | } 46 | 47 | return pub_models.Chat{ 48 | ID: fmt.Sprintf("glob_%v", glob), 49 | Messages: constructGlobMessages(fileMessages), 50 | }, nil 51 | } 52 | 53 | func constructGlobMessages(globMessages []pub_models.Message) []pub_models.Message { 54 | ret := make([]pub_models.Message, 0, len(globMessages)+4) 55 | ret = append(ret, pub_models.Message{ 56 | Role: "system", 57 | Content: "You will be given a series of messages each containing contents from files, then a message containing this: '#####'. Using the file content as context, perform the request given in the message after the '#####'.", 58 | }) 59 | ret = append(ret, globMessages...) 
60 | ret = append(ret, pub_models.Message{ 61 | Role: "user", 62 | Content: "#####", 63 | }) 64 | return ret 65 | } 66 | 67 | func parseGlob(glob string) ([]pub_models.Message, error) { 68 | glob, err := utils.ReplaceTildeWithHome(glob) 69 | if err != nil { 70 | return nil, fmt.Errorf("parseGlob, ReplaceTildeWithHome: %w", err) 71 | } 72 | files, err := filepath.Glob(glob) 73 | ret := make([]pub_models.Message, 0, len(files)) 74 | if err != nil { 75 | return nil, fmt.Errorf("failed to parse glob: %w", err) 76 | } 77 | if misc.Truthy(os.Getenv("DEBUG")) { 78 | ancli.PrintOK(fmt.Sprintf("found %d files: %v\n", len(files), files)) 79 | } 80 | 81 | if len(files) == 0 { 82 | return nil, fmt.Errorf("no files found") 83 | } 84 | 85 | for _, file := range files { 86 | data, err := os.ReadFile(file) 87 | if err != nil { 88 | ancli.PrintWarn(fmt.Sprintf("failed to read file: %v\n", err)) 89 | continue 90 | } 91 | ret = append(ret, pub_models.Message{ 92 | Role: "user", 93 | Content: fmt.Sprintf("{\"fileName\": \"%v\", \"data\": \"%v\"}", file, string(data)), 94 | }) 95 | } 96 | return ret, nil 97 | } 98 | -------------------------------------------------------------------------------- /internal/tools/programming_tool_git.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "fmt" 5 | "os/exec" 6 | "strconv" 7 | 8 | pub_models "github.com/baalimago/clai/pkg/text/models" 9 | ) 10 | 11 | type GitTool pub_models.Specification 12 | 13 | var Git = GitTool{ 14 | Name: "git", 15 | Description: "Run read-only git commands like log, diff, show, blame and status.", 16 | Inputs: &pub_models.InputSchema{ 17 | Type: "object", 18 | Properties: map[string]pub_models.ParameterObject{ 19 | "operation": { 20 | Type: "string", 21 | Description: "The git operation to run.", 22 | Enum: &[]string{"log", "diff", "show", "status", "blame"}, 23 | }, 24 | "file": { 25 | Type: "string", 26 | Description: "Optional file path used by 
diff, show or blame.", 27 | }, 28 | "commit": { 29 | Type: "string", 30 | Description: "Optional commit hash used by show or diff.", 31 | }, 32 | "range": { 33 | Type: "string", 34 | Description: "Optional revision range for log or diff.", 35 | }, 36 | "n": { 37 | Type: "integer", 38 | Description: "Number of log entries to display.", 39 | }, 40 | "dir": { 41 | Type: "string", 42 | Description: "Directory containing the git repository (optional).", 43 | }, 44 | }, 45 | Required: []string{"operation"}, 46 | }, 47 | } 48 | 49 | func (g GitTool) Call(input pub_models.Input) (string, error) { 50 | op, ok := input["operation"].(string) 51 | if !ok { 52 | return "", fmt.Errorf("operation must be a string") 53 | } 54 | 55 | args := []string{op} 56 | 57 | switch op { 58 | case "log": 59 | if v, ok := input["n"]; ok { 60 | num := 0 61 | switch n := v.(type) { 62 | case int: 63 | num = n 64 | case float64: 65 | num = int(n) 66 | case string: 67 | if i, err := strconv.Atoi(n); err == nil { 68 | num = i 69 | } 70 | } 71 | if num > 0 { 72 | args = append(args, "-n", fmt.Sprintf("%d", num)) 73 | } 74 | } 75 | if r, ok := input["range"].(string); ok && r != "" { 76 | args = append(args, r) 77 | } 78 | case "diff": 79 | if r, ok := input["range"].(string); ok && r != "" { 80 | args = append(args, r) 81 | } 82 | if f, ok := input["file"].(string); ok && f != "" { 83 | args = append(args, "--", f) 84 | } 85 | case "show": 86 | if c, ok := input["commit"].(string); ok && c != "" { 87 | args = append(args, c) 88 | } 89 | if f, ok := input["file"].(string); ok && f != "" { 90 | args = append(args, "--", f) 91 | } 92 | case "status": 93 | args = []string{"status", "--short"} 94 | case "blame": 95 | if f, ok := input["file"].(string); ok && f != "" { 96 | args = append(args, f) 97 | } else { 98 | return "", fmt.Errorf("file is required for blame") 99 | } 100 | default: 101 | return "", fmt.Errorf("unsupported git operation: %s", op) 102 | } 103 | 104 | cmd := exec.Command("git", args...) 
105 | if d, ok := input["dir"].(string); ok && d != "" { 106 | cmd.Dir = d 107 | } 108 | output, err := cmd.CombinedOutput() 109 | if err != nil { 110 | return "", fmt.Errorf("failed to run git %s: %w, output: %s", op, err, string(output)) 111 | } 112 | return string(output), nil 113 | } 114 | 115 | func (g GitTool) Specification() pub_models.Specification { 116 | return pub_models.Specification(Git) 117 | } 118 | -------------------------------------------------------------------------------- /internal/setup/mcp_parser.go: -------------------------------------------------------------------------------- 1 | package setup 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "path/filepath" 7 | "strings" 8 | 9 | "github.com/baalimago/clai/internal/utils" 10 | pub_models "github.com/baalimago/clai/pkg/text/models" 11 | "github.com/baalimago/go_away_boilerplate/pkg/ancli" 12 | ) 13 | 14 | // McpServerInput represents the external format that users might paste 15 | type McpServerInput struct { 16 | McpServers map[string]McpServerExternal `json:"mcpServers"` 17 | } 18 | 19 | // McpServerExternal represents various external MCP server formats 20 | type McpServerExternal struct { 21 | Command string `json:"command"` 22 | Args []string `json:"args"` 23 | Env map[string]string `json:"env,omitempty"` 24 | } 25 | 26 | // ParseAndAddMcpServer parses pasted MCP server configuration and adds it to the system 27 | func ParseAndAddMcpServer(mcpServersDir, pastedConfig string) ([]string, error) { 28 | // Try to parse as the external format first 29 | var input McpServerInput 30 | if err := json.Unmarshal([]byte(pastedConfig), &input); err != nil { 31 | return nil, fmt.Errorf("failed to parse MCP server configuration: %w", err) 32 | } 33 | 34 | if len(input.McpServers) == 0 { 35 | return nil, fmt.Errorf("no MCP servers found in configuration") 36 | } 37 | 38 | ret := make([]string, 0) 39 | 40 | // Convert and save each server 41 | for serverName, externalServer := range 
input.McpServers { 42 | internalServer := convertToInternalFormat(externalServer) 43 | 44 | // Save to individual file 45 | serverPath := filepath.Join(mcpServersDir, fmt.Sprintf("%s.json", serverName)) 46 | if err := utils.CreateFile(serverPath, &internalServer); err != nil { 47 | return nil, fmt.Errorf("failed to create server file for %s: %w", serverName, err) 48 | } 49 | 50 | ancli.Noticef("Added MCP server: %s\n", serverName) 51 | ret = append(ret, serverName) 52 | } 53 | 54 | return ret, nil 55 | } 56 | 57 | // convertToInternalFormat converts external format to internal pub_models.McpServer format 58 | func convertToInternalFormat(external McpServerExternal) pub_models.McpServer { 59 | internal := pub_models.McpServer{ 60 | Command: external.Command, 61 | Args: external.Args, 62 | Env: external.Env, 63 | } 64 | 65 | // Initialize empty env map if nil 66 | if internal.Env == nil { 67 | internal.Env = make(map[string]string) 68 | } 69 | 70 | return internal 71 | } 72 | 73 | // ValidateMcpServerConfig validates that the pasted config is valid 74 | func ValidateMcpServerConfig(pastedConfig string) error { 75 | pastedConfig = strings.TrimSpace(pastedConfig) 76 | if pastedConfig == "" { 77 | return fmt.Errorf("empty configuration provided") 78 | } 79 | 80 | // Try to parse as JSON 81 | var input McpServerInput 82 | if err := json.Unmarshal([]byte(pastedConfig), &input); err != nil { 83 | return fmt.Errorf("invalid JSON format: %w", err) 84 | } 85 | 86 | if len(input.McpServers) == 0 { 87 | return fmt.Errorf("no 'mcpServers' section found or it's empty") 88 | } 89 | 90 | // Validate each server 91 | for serverName, server := range input.McpServers { 92 | if serverName == "" { 93 | return fmt.Errorf("server name cannot be empty") 94 | } 95 | if server.Command == "" { 96 | return fmt.Errorf("command cannot be empty for server '%s'", serverName) 97 | } 98 | } 99 | 100 | return nil 101 | } 102 | 
-------------------------------------------------------------------------------- /internal/tools/programming_tool_sed.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "regexp" 7 | "strconv" 8 | "strings" 9 | 10 | pub_models "github.com/baalimago/clai/pkg/text/models" 11 | ) 12 | 13 | type SedTool pub_models.Specification 14 | 15 | var Sed = SedTool{ 16 | Name: "sed", 17 | Description: "Perform a basic regex substitution on each line or within a specific line range of a file (like 'sed s/pattern/repl/g'). Overwrites the file.", 18 | Inputs: &pub_models.InputSchema{ 19 | Type: "object", 20 | Properties: map[string]pub_models.ParameterObject{ 21 | "file_path": { 22 | Type: "string", 23 | Description: "The path to the file to modify.", 24 | }, 25 | "pattern": { 26 | Type: "string", 27 | Description: "The regex pattern to search for.", 28 | }, 29 | "repl": { 30 | Type: "string", 31 | Description: "The replacement string.", 32 | }, 33 | "start_line": { 34 | Type: "integer", 35 | Description: "Optional. First line to modify (1-based, inclusive).", 36 | }, 37 | "end_line": { 38 | Type: "integer", 39 | Description: "Optional. 
Last line to modify (1-based, inclusive).", 40 | }, 41 | }, 42 | Required: []string{"file_path", "pattern", "repl"}, 43 | }, 44 | } 45 | 46 | func (s SedTool) Call(input pub_models.Input) (string, error) { 47 | filePath, ok := input["file_path"].(string) 48 | if !ok { 49 | return "", fmt.Errorf("file_path must be a string") 50 | } 51 | pattern, ok := input["pattern"].(string) 52 | if !ok { 53 | return "", fmt.Errorf("pattern must be a string") 54 | } 55 | repl, ok := input["repl"].(string) 56 | if !ok { 57 | return "", fmt.Errorf("repl must be a string") 58 | } 59 | 60 | var startLine, endLine int 61 | if v, ok := input["start_line"]; ok { 62 | switch n := v.(type) { 63 | case float64: 64 | startLine = int(n) 65 | case int: 66 | startLine = n 67 | case string: 68 | startLine, _ = strconv.Atoi(n) 69 | } 70 | } 71 | if v, ok := input["end_line"]; ok { 72 | switch n := v.(type) { 73 | case float64: 74 | endLine = int(n) 75 | case int: 76 | endLine = n 77 | case string: 78 | endLine, _ = strconv.Atoi(n) 79 | } 80 | } 81 | 82 | raw, err := os.ReadFile(filePath) 83 | if err != nil { 84 | return "", fmt.Errorf("failed to read file: %w", err) 85 | } 86 | 87 | re, err := regexp.Compile(pattern) 88 | if err != nil { 89 | return "", fmt.Errorf("invalid regex: %w", err) 90 | } 91 | 92 | lines := strings.Split(string(raw), "\n") 93 | for i := range lines { 94 | lineNum := i + 1 95 | if (startLine == 0 && endLine == 0) || 96 | (startLine > 0 && endLine > 0 && lineNum >= startLine && lineNum <= endLine) || 97 | (startLine > 0 && endLine == 0 && lineNum >= startLine) || 98 | (startLine == 0 && endLine > 0 && lineNum <= endLine) { 99 | lines[i] = re.ReplaceAllString(lines[i], repl) 100 | } 101 | } 102 | 103 | out := strings.Join(lines, "\n") 104 | err = os.WriteFile(filePath, []byte(out), 0o644) 105 | if err != nil { 106 | return "", fmt.Errorf("failed to write file: %w", err) 107 | } 108 | return fmt.Sprintf("sed: replaced occurrences of %q with %q in %s (%d-%d)", pattern, repl, 
filePath, startLine, endLine), nil 109 | } 110 | 111 | func (s SedTool) Specification() pub_models.Specification { 112 | return pub_models.Specification(Sed) 113 | } 114 | -------------------------------------------------------------------------------- /internal/vendors/anthropic/claude_stream_block_events.go: -------------------------------------------------------------------------------- 1 | package anthropic 2 | 3 | import ( 4 | "encoding/json" 5 | "errors" 6 | "fmt" 7 | 8 | "github.com/baalimago/clai/internal/models" 9 | pub_models "github.com/baalimago/clai/pkg/text/models" 10 | "github.com/baalimago/go_away_boilerplate/pkg/debug" 11 | ) 12 | 13 | func (c *Claude) handleContentBlockStart(blockStart string) models.CompletionEvent { 14 | var blockSuper ContentBlockSuper 15 | blockStart = trimDataPrefix(blockStart) 16 | if err := json.Unmarshal([]byte(blockStart), &blockSuper); err != nil { 17 | return fmt.Errorf("failed to unmarshal blockStart with content: %v, error: %w", blockStart, err) 18 | } 19 | block := blockSuper.ToolContentBlock 20 | c.contentBlockType = block.Type 21 | switch block.Type { 22 | case "tool_use": 23 | c.functionName = block.Name 24 | c.functionID = block.ID 25 | } 26 | return models.NoopEvent{} 27 | } 28 | 29 | // handleContentBlockDelta processes a delta token to generate a CompletionEvent. 30 | // It converts the delta token into a structured format and evaluates the type of 31 | // delta to determine the appropriate action. The function handles "text_delta" 32 | // types by checking if the text content is empty, and returns an error if so. 33 | // For "input_json_delta" types, it delegates processing to handleInputJSONDelta. 34 | // Returns an error for unexpected delta types. JSON data is printed if debugging 35 | // is enabled. 36 | // 37 | // Parameters: 38 | // - deltaToken: A string representing the delta token to be processed. 
39 | // 40 | // Returns: 41 | // - models.CompletionEvent: A response event generated from the delta token. 42 | // - error: An error is returned if the delta type is unexpected or the text is empty. 43 | func (c *Claude) handleContentBlockDelta(deltaToken string) models.CompletionEvent { 44 | delta, err := c.stringFromDeltaToken(deltaToken) 45 | if err != nil { 46 | return fmt.Errorf("failed to convert string to delta token: %w", err) 47 | } 48 | if c.debug { 49 | fmt.Printf("deltaStruct: '%v'\n---\n", 50 | debug.IndentedJsonFmt(delta)) 51 | } 52 | switch delta.Type { 53 | case "text_delta": 54 | if delta.Text == "" { 55 | return errors.New("unexpected empty response") 56 | } 57 | return delta.Text 58 | case "input_json_delta": 59 | return c.handleInputJSONDelta(delta) 60 | default: 61 | return fmt.Errorf("unexpected delta type: %v", delta.Type) 62 | } 63 | } 64 | 65 | func (c *Claude) handleInputJSONDelta(delta Delta) models.CompletionEvent { 66 | partial := delta.PartialJSON 67 | c.functionJSON += partial 68 | return partial 69 | } 70 | 71 | func (c *Claude) handleContentBlockStop(blockStop string) models.CompletionEvent { 72 | defer func() { 73 | c.debugFullStreamMsg = "" 74 | c.functionJSON = "" 75 | }() 76 | var block ToolUseContentBlock 77 | blockStop = trimDataPrefix(blockStop) 78 | if err := json.Unmarshal([]byte(blockStop), &block); err != nil { 79 | return fmt.Errorf("failed to unmarshal blockStop: %w", err) 80 | } 81 | 82 | switch c.contentBlockType { 83 | case "tool_use": 84 | var inputs pub_models.Input 85 | if c.functionJSON != "" { 86 | if err := json.Unmarshal([]byte(c.functionJSON), &inputs); err != nil { 87 | return fmt.Errorf("failed to unmarshal functionJSON: %v, error is: %w", c.functionJSON, err) 88 | } 89 | } 90 | return pub_models.Call{ 91 | Name: c.functionName, 92 | Inputs: &inputs, 93 | ID: c.functionID, 94 | } 95 | } 96 | return models.NoopEvent{} 97 | } 98 | 
-------------------------------------------------------------------------------- /internal/vendors/ollama/ollama_test.go: -------------------------------------------------------------------------------- 1 | package ollama 2 | 3 | import ( 4 | "os" 5 | "testing" 6 | ) 7 | 8 | func TestSetup_Default_SetsFields(t *testing.T) { 9 | v := Default 10 | t.Setenv("OLLAMA_API_KEY", "") 11 | if err := v.Setup(); err != nil { 12 | t.Fatalf("setup failed: %v", err) 13 | } 14 | if v.ToolChoice == nil { 15 | t.Fatalf("toolchoice nil") 16 | } 17 | if *v.ToolChoice != "auto" { 18 | t.Fatalf("toolchoice got %q want %q", *v.ToolChoice, "auto") 19 | } 20 | if v.StreamCompleter.Temperature == nil { 21 | t.Fatalf("temperature ptr nil") 22 | } 23 | if v.StreamCompleter.TopP == nil { 24 | t.Fatalf("top_p ptr nil") 25 | } 26 | if v.StreamCompleter.FrequencyPenalty == nil { 27 | t.Fatalf("freq ptr nil") 28 | } 29 | // should keep model when no prefix is present 30 | if v.StreamCompleter.Model != "llama3" { 31 | t.Fatalf("model got %q want %q", 32 | v.StreamCompleter.Model, "llama3") 33 | } 34 | } 35 | 36 | func TestSetup_TrimsOllamaPrefix(t *testing.T) { 37 | v := Default 38 | v.Model = "ollama:deepseek-r1:8b" 39 | t.Setenv("OLLAMA_API_KEY", "") 40 | if err := v.Setup(); err != nil { 41 | t.Fatalf("setup failed: %v", err) 42 | } 43 | want := "deepseek-r1:8b" 44 | if v.StreamCompleter.Model != want { 45 | t.Fatalf("model got %q want %q", 46 | v.StreamCompleter.Model, want) 47 | } 48 | } 49 | 50 | func TestSetup_RespectsExistingAPIKey(t *testing.T) { 51 | v := Default 52 | t.Setenv("OLLAMA_API_KEY", "some-key") 53 | if err := v.Setup(); err != nil { 54 | t.Fatalf("setup failed: %v", err) 55 | } 56 | if got := os.Getenv("OLLAMA_API_KEY"); got != "some-key" { 57 | t.Fatalf("api key got %q want %q", got, "some-key") 58 | } 59 | } 60 | 61 | func TestSetupConfigMapping(t *testing.T) { 62 | v := Default 63 | fp := 0.11 64 | v.FrequencyPenalty = fp 65 | mt := 123 66 | v.MaxTokens = &mt 67 | 
v.Temperature = 0.22 68 | v.TopP = 0.33 69 | v.Model = "llama3:custom" 70 | 71 | t.Setenv("OLLAMA_API_KEY", "k") 72 | if err := v.Setup(); err != nil { 73 | t.Fatalf("setup failed: %v", err) 74 | } 75 | if v.StreamCompleter.Model != v.Model { 76 | t.Errorf("expected Model %q, got %q", v.Model, v.StreamCompleter.Model) 77 | } 78 | if v.StreamCompleter.FrequencyPenalty == nil || *v.StreamCompleter.FrequencyPenalty != v.FrequencyPenalty { 79 | t.Errorf("frequency penalty not mapped, got %#v want %v", v.StreamCompleter.FrequencyPenalty, v.FrequencyPenalty) 80 | } 81 | if v.StreamCompleter.MaxTokens == nil || *v.StreamCompleter.MaxTokens != *v.MaxTokens { 82 | t.Errorf("max tokens not mapped, got %#v want %v", v.StreamCompleter.MaxTokens, *v.MaxTokens) 83 | } 84 | if v.StreamCompleter.Temperature == nil || *v.StreamCompleter.Temperature != v.Temperature { 85 | t.Errorf("temperature not mapped, got %#v want %v", v.StreamCompleter.Temperature, v.Temperature) 86 | } 87 | if v.StreamCompleter.TopP == nil || *v.StreamCompleter.TopP != v.TopP { 88 | t.Errorf("top_p not mapped, got %#v want %v", v.StreamCompleter.TopP, v.TopP) 89 | } 90 | if v.ToolChoice == nil || *v.ToolChoice != "auto" { 91 | t.Errorf("tool choice expected 'auto', got %#v", v.ToolChoice) 92 | } 93 | } 94 | 95 | func TestSetupSetsDefaultEnvWhenMissingOLLAMA(t *testing.T) { 96 | v := Default 97 | t.Setenv("OLLAMA_API_KEY", "") 98 | if err := v.Setup(); err != nil { 99 | t.Fatalf("setup failed: %v", err) 100 | } 101 | if got := os.Getenv("OLLAMA_API_KEY"); got == "" { 102 | t.Fatalf("expected OLLAMA_API_KEY to be set by Setup") 103 | } 104 | } 105 | -------------------------------------------------------------------------------- /internal/utils/misc_test.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | ) 7 | 8 | func TestGetFirstTokens(t *testing.T) { 9 | tests := []struct { 10 | name string 11 | prompt 
[]string 12 | n int 13 | want []string 14 | }{ 15 | { 16 | name: "Empty prompt", 17 | prompt: []string{}, 18 | n: 5, 19 | want: []string{}, 20 | }, 21 | { 22 | name: "Prompt with less than n tokens", 23 | prompt: []string{"Hello", "World"}, 24 | n: 5, 25 | want: []string{"Hello", "World"}, 26 | }, 27 | { 28 | name: "Prompt with exactly n tokens", 29 | prompt: []string{"This", "is", "a", "test", "prompt"}, 30 | n: 5, 31 | want: []string{"This", "is", "a", "test", "prompt"}, 32 | }, 33 | { 34 | name: "Prompt with more than n tokens", 35 | prompt: []string{"This", "is", "a", "longer", "test", "prompt"}, 36 | n: 4, 37 | want: []string{"This", "is", "a", "longer"}, 38 | }, 39 | { 40 | name: "Prompt with empty tokens", 41 | prompt: []string{"", "Hello", "", "World", ""}, 42 | n: 3, 43 | want: []string{"Hello", "World"}, 44 | }, 45 | } 46 | 47 | for _, tt := range tests { 48 | t.Run(tt.name, func(t *testing.T) { 49 | got := GetFirstTokens(tt.prompt, tt.n) 50 | if !reflect.DeepEqual(got, tt.want) { 51 | t.Errorf("GetFirstTokens() = %v, want %v", got, tt.want) 52 | } 53 | }) 54 | } 55 | } 56 | 57 | func TestDeleteRange(t *testing.T) { 58 | orig := []int{1, 2, 3, 4, 5, 6, 7, 8, 9} 59 | 60 | t.Run("middle range", func(t *testing.T) { 61 | tt := make([]int, len(orig)) 62 | copy(tt, orig) 63 | res, _ := DeleteRange(tt, 2, 5) // Should remove 3,4,5,6 (indices 2-5) 64 | want := []int{1, 2, 7, 8, 9} 65 | if !reflect.DeepEqual(res, want) { 66 | t.Errorf("DeleteRange() = %v, want %v", res, want) 67 | } 68 | }) 69 | t.Run("remove first", func(t *testing.T) { 70 | tt := make([]int, len(orig)) 71 | copy(tt, orig) 72 | res, _ := DeleteRange(tt, 0, 0) 73 | want := []int{2, 3, 4, 5, 6, 7, 8, 9} 74 | if !reflect.DeepEqual(res, want) { 75 | t.Errorf("DeleteRange(remove first) = %v, want %v", res, want) 76 | } 77 | }) 78 | t.Run("remove last", func(t *testing.T) { 79 | tt := make([]int, len(orig)) 80 | copy(tt, orig) 81 | res, _ := DeleteRange(tt, len(tt)-1, len(tt)-1) 82 | want := []int{1, 
2, 3, 4, 5, 6, 7, 8} 83 | if !reflect.DeepEqual(res, want) { 84 | t.Errorf("DeleteRange(remove last) = %v, want %v", res, want) 85 | } 86 | }) 87 | t.Run("remove all", func(t *testing.T) { 88 | tt := make([]int, len(orig)) 89 | copy(tt, orig) 90 | res, _ := DeleteRange(tt, 0, len(tt)-1) 91 | want := []int{} 92 | if !reflect.DeepEqual(res, want) { 93 | t.Errorf("DeleteRange(remove all) = %v, want %v", res, want) 94 | } 95 | }) 96 | } 97 | 98 | func TestDeleteRangeInvalidInputs(t *testing.T) { 99 | orig := []int{1, 2, 3, 4, 5} 100 | 101 | t.Run("invalid range start greater than end", func(t *testing.T) { 102 | tt := make([]int, len(orig)) 103 | copy(tt, orig) 104 | _, err := DeleteRange(tt, 3, 2) 105 | if err == nil { 106 | t.Errorf("DeleteRange() expected error for start greater than end, got nil") 107 | } 108 | }) 109 | 110 | t.Run("start index out of bounds", func(t *testing.T) { 111 | tt := make([]int, len(orig)) 112 | copy(tt, orig) 113 | _, err := DeleteRange(tt, -1, 2) 114 | if err == nil { 115 | t.Errorf("DeleteRange() expected error for start index out of bounds, got nil") 116 | } 117 | }) 118 | 119 | t.Run("end index out of bounds", func(t *testing.T) { 120 | tt := make([]int, len(orig)) 121 | copy(tt, orig) 122 | _, err := DeleteRange(tt, 1, 10) 123 | if err == nil { 124 | t.Errorf("DeleteRange() expected error for end index out of bounds, got nil") 125 | } 126 | }) 127 | } 128 | -------------------------------------------------------------------------------- /internal/profiles/cmd.go: -------------------------------------------------------------------------------- 1 | package profiles 2 | 3 | import ( 4 | "context" 5 | "flag" 6 | "fmt" 7 | "os" 8 | "path/filepath" 9 | "strings" 10 | 11 | "github.com/baalimago/clai/internal/utils" 12 | "github.com/baalimago/go_away_boilerplate/pkg/ancli" 13 | ) 14 | 15 | // SubCmd handles `clai profiles` sub-commands. 
16 | // 17 | // Usage: 18 | // 19 | // clai profiles # list configured profiles 20 | // clai profiles list 21 | // 22 | // Additional sub-commands can be added later (e.g. show, delete, etc.). 23 | func SubCmd(ctx context.Context, args []string) error { 24 | _ = ctx // currently unused; kept for future expansion 25 | 26 | // We expect args[0] to be "profiles". 27 | fs := flag.NewFlagSet("profiles", flag.ContinueOnError) 28 | fs.SetOutput(nil) // silence default usage output; we handle errors ourselves 29 | 30 | if err := fs.Parse(args[1:]); err != nil { 31 | return fmt.Errorf("failed to parse profiles flags: %w", err) 32 | } 33 | 34 | rest := fs.Args() 35 | if len(rest) == 0 || rest[0] == "list" { 36 | return runProfilesList() 37 | } 38 | 39 | return fmt.Errorf("unknown profiles subcommand: %q", rest[0]) 40 | } 41 | 42 | // runProfilesList lists all static profiles from /.clai/profiles. 43 | func runProfilesList() error { 44 | configDir, err := utils.GetClaiConfigDir() 45 | if err != nil { 46 | return fmt.Errorf("failed to get clai config dir: %w", err) 47 | } 48 | 49 | profilesDir := filepath.Join(configDir, "profiles") 50 | if _, err := os.Stat(profilesDir); os.IsNotExist(err) { 51 | ancli.Warnf("no profiles directory found at %s\n", profilesDir) 52 | return utils.ErrUserInitiatedExit 53 | } 54 | 55 | files, err := os.ReadDir(profilesDir) 56 | if err != nil { 57 | return fmt.Errorf("failed to read profiles directory: %w", err) 58 | } 59 | 60 | if len(files) == 0 { 61 | ancli.Warnf("no profiles found in %s\n", profilesDir) 62 | return utils.ErrUserInitiatedExit 63 | } 64 | 65 | // local view of the on-disk profile; we only need a subset of fields here 66 | type profile struct { 67 | Name string `json:"name"` 68 | Model string `json:"model"` 69 | Tools []string `json:"tools"` 70 | Prompt string `json:"prompt"` 71 | } 72 | 73 | validCount := 0 74 | for _, f := range files { 75 | if f.IsDir() || filepath.Ext(f.Name()) != ".json" { 76 | continue 77 | } 78 | 79 | 
fullPath := filepath.Join(profilesDir, f.Name()) 80 | 81 | var p profile 82 | if err := utils.ReadAndUnmarshal(fullPath, &p); err != nil { 83 | // Skip malformed profile files 84 | continue 85 | } 86 | 87 | // Backwards compatible: if Name is empty, derive from filename (without .json). 88 | if strings.TrimSpace(p.Name) == "" { 89 | base := filepath.Base(f.Name()) 90 | p.Name = strings.TrimSuffix(base, filepath.Ext(base)) 91 | } 92 | 93 | fmt.Printf("Name: %s\nModel: %s\nTools: %v\nFirst sentence prompt: %s\n---\n", 94 | p.Name, 95 | p.Model, 96 | p.Tools, 97 | getFirstSentence(p.Prompt), 98 | ) 99 | validCount++ 100 | } 101 | 102 | if validCount == 0 { 103 | ancli.Warnf("no valid profiles found in %s\n", profilesDir) 104 | } 105 | 106 | return utils.ErrUserInitiatedExit 107 | } 108 | 109 | // getFirstSentence returns the first sentence / line of a prompt, used for summaries. 110 | func getFirstSentence(s string) string { 111 | if s == "" { 112 | return "[Empty prompt]" 113 | } 114 | 115 | idxDot := strings.Index(s, ".") 116 | idxExcl := strings.Index(s, "!") 117 | idxQues := strings.Index(s, "?") 118 | idxNewLine := strings.Index(s, "\n") 119 | 120 | minIdx := len(s) 121 | for _, idx := range []int{idxDot, idxExcl, idxQues, idxNewLine} { 122 | if idx != -1 && idx < minIdx { 123 | minIdx = idx 124 | } 125 | } 126 | 127 | if minIdx < len(s) { 128 | return s[:minIdx+1] 129 | } 130 | return s 131 | } 132 | -------------------------------------------------------------------------------- /internal/text/generic/stream_completer_models.go: -------------------------------------------------------------------------------- 1 | package generic 2 | 3 | import ( 4 | "net/http" 5 | 6 | pub_models "github.com/baalimago/clai/pkg/text/models" 7 | ) 8 | 9 | // StreamCompleter is a struct which follows the model for both OpenAI and Mistral 10 | type StreamCompleter struct { 11 | Model string `json:"-"` 12 | FrequencyPenalty *float64 `json:"-"` 13 | MaxTokens *int `json:"-"` 14 | 
PresencePenalty *float64 `json:"-"` 15 | Temperature *float64 `json:"-"` 16 | TopP *float64 `json:"-"` 17 | ToolChoice *string `json:"-"` 18 | Clean func([]pub_models.Message) []pub_models.Message `json:"-"` 19 | URL string 20 | tools []ToolSuper 21 | toolsCallName string 22 | // Argument string exists since the arguments for function calls is streamed token by token... yeah... great idea 23 | toolsCallArgsString string 24 | toolsCallID string 25 | extraContent map[string]any 26 | client *http.Client 27 | apiKey string 28 | debug bool 29 | } 30 | 31 | type ToolSuper struct { 32 | Type string `json:"type"` 33 | Function Tool `json:"function"` 34 | } 35 | 36 | type Tool struct { 37 | Name string `json:"name"` 38 | Description string `json:"description"` 39 | Inputs pub_models.InputSchema `json:"parameters,omitempty"` 40 | } 41 | 42 | type chatCompletionChunk struct { 43 | ID string `json:"id"` 44 | Object string `json:"object"` 45 | Created int `json:"created"` 46 | Model string `json:"model"` 47 | SystemFingerprint string `json:"system_fingerprint"` 48 | Choices []Choice `json:"choices"` 49 | } 50 | 51 | type Choice struct { 52 | Index int `json:"index"` 53 | Delta Delta `json:"delta"` 54 | Logprobs interface{} `json:"logprobs"` // null or complex object, hence interface{} 55 | FinishReason string `json:"finish_reason"` 56 | } 57 | 58 | type Delta struct { 59 | Content any `json:"content"` 60 | Role string `json:"role"` 61 | ToolCalls []ToolsCall `json:"tool_calls"` 62 | } 63 | 64 | type ExtraContent map[string]map[string]any 65 | 66 | type ToolsCall struct { 67 | Function Func `json:"function"` 68 | ID string `json:"id"` 69 | Index int `json:"index"` 70 | Type string `json:"type"` 71 | 72 | // ExtraContent for initially google thought_signature 73 | ExtraContent map[string]any `json:"extra_content,omitempty"` 74 | } 75 | 76 | type Func struct { 77 | Arguments string `json:"arguments"` 78 | Name string `json:"name"` 79 | } 80 | 81 | type responseFormat struct { 82 | 
Type string `json:"type"` 83 | } 84 | 85 | type req struct { 86 | Model string `json:"model,omitempty"` 87 | ResponseFormat responseFormat `json:"response_format,omitempty"` 88 | Messages []pub_models.Message `json:"messages,omitempty"` 89 | Stream bool `json:"stream,omitempty"` 90 | FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` 91 | MaxTokens *int `json:"max_tokens,omitempty"` 92 | PresencePenalty *float64 `json:"presence_penalty,omitempty"` 93 | Temperature *float64 `json:"temperature,omitempty"` 94 | TopP *float64 `json:"top_p,omitempty"` 95 | ToolChoice *string `json:"tool_choice,omitempty"` 96 | Tools []ToolSuper `json:"tools,omitempty"` 97 | ParalellToolCalls bool `json:"parallel_tools_call,omitempty"` 98 | } 99 | --------------------------------------------------------------------------------