├── .github └── workflows │ ├── release.yml │ └── validate.yml ├── .gitignore ├── .vscode └── launch.json ├── EXAMPLES.md ├── LICENSE ├── README.md ├── go.mod ├── go.sum ├── img ├── chats.gif ├── piping.gif └── profiles.gif ├── internal ├── chat │ ├── chat.go │ └── handler.go ├── create_queriers.go ├── glob │ ├── glob.go │ └── glob_test.go ├── models │ ├── completion │ │ └── types.go │ ├── models.go │ └── models_tests.go ├── photo │ ├── conf.go │ ├── funimation_0.go │ └── prompt.go ├── reply │ ├── replay.go │ └── reply.go ├── setup.go ├── setup │ ├── setup.go │ ├── setup_actions.go │ ├── setup_actions_test.go │ └── setup_test.go ├── setup_config_migrations.go ├── setup_config_migrations_test.go ├── setup_flags.go ├── setup_flags_test.go ├── text │ ├── conf.go │ ├── conf_profile.go │ ├── generic │ │ ├── stream_completer.go │ │ ├── stream_completer_models.go │ │ └── stream_completer_setup.go │ ├── querier.go │ ├── querier_cmd_mode.go │ ├── querier_cmd_mode_test.go │ ├── querier_setup.go │ └── querier_test.go ├── tools │ ├── bash_tool_cat.go │ ├── bash_tool_file.go │ ├── bash_tool_find.go │ ├── bash_tool_freetext_command.go │ ├── bash_tool_ls.go │ ├── bash_tool_rg.go │ ├── bash_tool_tree.go │ ├── handler.go │ ├── models.go │ ├── programming_tool_go.go │ ├── programming_tool_rows_between.go │ ├── programming_tool_rows_between_test.go │ ├── programming_tool_sed.go │ ├── programming_tool_sed_test.go │ ├── programming_tool_write_file.go │ ├── programming_tool_write_file_test.go │ ├── web_tool_website_text.go │ └── web_tool_website_text_test.go ├── utils │ ├── config.go │ ├── config_test.go │ ├── errors.go │ ├── file.go │ ├── file_test.go │ ├── input.go │ ├── misc.go │ ├── misc_test.go │ ├── print.go │ ├── print_test.go │ ├── prompt.go │ ├── prompt_test.go │ └── term.go ├── vendors │ ├── anthropic │ │ ├── claude.go │ │ ├── claude_models.go │ │ ├── claude_setup.go │ │ ├── claude_setup_test.go │ │ ├── claude_stream.go │ │ ├── claude_stream_block_events.go │ │ ├── 
claude_stream_test.go │ │ ├── claude_test.go │ │ └── constants.go │ ├── deepseek │ │ ├── deepseek.go │ │ ├── deepseek_setup.go │ │ └── models.go │ ├── mistral │ │ ├── mistral.go │ │ ├── mistral_setup.go │ │ └── models.go │ ├── novita │ │ ├── models.go │ │ ├── novita.go │ │ └── novita_setup.go │ ├── ollama │ │ ├── models.go │ │ ├── ollama.go │ │ └── ollama_setup.go │ └── openai │ │ ├── constants.go │ │ ├── dalle.go │ │ ├── gpt.go │ │ ├── gpt_setup.go │ │ └── models.go └── version.go ├── main.go ├── oopsies.go └── setup.sh /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Simple Go Pipeline - release 2 | 3 | on: 4 | push: 5 | tags: 6 | - "v[0-9]+.[0-9]+.[0-9]+" 7 | jobs: 8 | call-workflow: 9 | uses: baalimago/simple-go-pipeline/.github/workflows/release.yml@v0.3.0 10 | with: 11 | project-name: clai 12 | branch: main 13 | version-var: "github.com/baalimago/clai/internal.BUILD_VERSION" 14 | -------------------------------------------------------------------------------- /.github/workflows/validate.yml: -------------------------------------------------------------------------------- 1 | name: Simple Go Pipeline - validate 2 | 3 | on: 4 | push: 5 | branches: ["main"] 6 | pull_request: 7 | branches: ["main"] 8 | 9 | jobs: 10 | call-workflow: 11 | uses: baalimago/simple-go-pipeline/.github/workflows/validate.yml@main 12 | with: 13 | go-version: "1.24" 14 | staticcheck-version: "2025.1.1" 15 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | coverage.out 2 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "0.2.0", 3 | "configurations": [ 4 | { 5 | "name": "ChatGPT - Query", 6 | "type": "go", 7 | "request": "launch", 8 | 
"program": "${workspaceFolder}", 9 | "args": [ 10 | "-cm", 11 | "gpt-4-turbo-preview", 12 | "query", 13 | "I'm debugging my cli ai too. Write a short consice response." 14 | ], 15 | "env": { 16 | "NO_COLOR": "true" 17 | } 18 | }, 19 | { 20 | "name": "ChatGPT - Query - Tool", 21 | "type": "go", 22 | "request": "launch", 23 | "program": "${workspaceFolder}", 24 | "args": [ 25 | "-t", 26 | "-cm", 27 | "gpt-4-turbo", 28 | "query", 29 | "try to call the file tree command using /home/imago as agument, i'm debugging this functionality." 30 | ], 31 | "env": { 32 | "NO_COLOR": "true" 33 | } 34 | }, 35 | { 36 | "name": "Claude - Query", 37 | "type": "go", 38 | "request": "launch", 39 | "program": "${workspaceFolder}", 40 | "args": [ 41 | "query", 42 | "test" 43 | ], 44 | "env": { 45 | "NO_COLOR": "true" 46 | } 47 | }, 48 | { 49 | "name": "Claude - Query - Tool", 50 | "type": "go", 51 | "request": "launch", 52 | "program": "${workspaceFolder}", 53 | "args": [ 54 | "-t", 55 | "-cm", 56 | "claude-3-opus-20240229", 57 | "query", 58 | "try to call the file ls command on ~/, i'm debugging this functionality." 
59 | ], 60 | "env": { 61 | "NO_COLOR": "true" 62 | } 63 | }, 64 | { 65 | "name": "ChatGPT - Chat - GlobMode", 66 | "type": "go", 67 | "request": "launch", 68 | "program": "${workspaceFolder}", 69 | "args": [ 70 | "-cm", 71 | "gpt-4o", 72 | "-glob", 73 | "README.md", 74 | "chat", 75 | "new", 76 | "Explain this project in 5 words" 77 | ], 78 | "env": { 79 | "NO_COLOR": "true" 80 | } 81 | }, 82 | { 83 | "name": "ChatGPT - Chat - Cmd", 84 | "type": "go", 85 | "request": "launch", 86 | "program": "${workspaceFolder}", 87 | "args": [ 88 | "-cm", 89 | "gpt-4o", 90 | "cmd", 91 | "give me a command to show my current directory" 92 | ], 93 | "env": { 94 | "NO_COLOR": "true" 95 | } 96 | } 97 | ] 98 | } 99 | -------------------------------------------------------------------------------- /EXAMPLES.md: -------------------------------------------------------------------------------- 1 | # Examples 2 | 3 | All of the commands support xargs-like `-i`/`-I`/`-replace` flags. 4 | 5 | Example: `clai h | clai -i q Summarize this for me: {}`, this would summarize the output of `clai h`. 6 | 7 | Regardless of you wish to generate a photo, continue a chat or reply to your previous query, the prompt system remains the same. 8 | 9 | ```bash 10 | clai help `# For more info about the available commands (and shorthands)` 11 | ``` 12 | 13 | ### Queries 14 | ```bash 15 | clai query My favorite color is blue, tell me some facts about it 16 | ``` 17 | ```bash 18 | clai -re `# Use the -re flag to use the previous query as context for some next query` \ 19 | q Write a poem about my favorite colour 20 | ``` 21 | 22 | Personally I have `alias ask=clai q` and then `alias rask=clai -re q`. 23 | This way I can `ask` -> `rask` -> `rask` for a temporary conversation. 24 | 25 | Every 'temporary conversation' is also saved as a chat, so it's possible to continue it later, see below on how to list chats. 26 | 27 | ### Tooling 28 | Many vendors support function calling/tooling. 
29 | This basically means that the AI model will ask *your local machine* to run some command, then it will analyze the output of said command. 30 | 31 | See all the currently available tools [here](./internal/tools/), please create an issue if you'd like to see some tool added. 32 | ```bash 33 | clai -t q `# Specify you wish to enable tools with -t/-tools` \ 34 | Analyze the project found at ~/Projects/clai and give me a brief summary of what it does 35 | ``` 36 | 37 | ### Chatting 38 | ```bash 39 | clai -chat-model claude-3-opus-20240229 ` # Using some other model` \ 40 | chat new Lets have a conversation about Hegel 41 | ``` 42 | 43 | The `-cm`/`-chat-model` flag works for any text-like command. 44 | Meaning: you can start a conversation with one chat model, then continue it with another. 45 | ```bash 46 | clai chat list 47 | ``` 48 | ```bash 49 | clai c continue Lets_have_a_conversation_about 50 | ``` 51 | 52 | ```bash 53 | clai c continue 1 kant is better `# Continue some previous chat with message ` 54 | ``` 55 | 56 | ### Globs 57 | ```bash 58 | clai -raw `# Don't format output as markdown` \ 59 | glob '*.go' Generate a README for this project > README.md 60 | ``` 61 | The `-raw` flag will ensure that the output stays what the model outputs, without `glow` or animations. 62 | 63 | Note that the glob mode also can be used by using the `-g` flag, as in, `clai -g '' query/chat/photo/cmd`. 64 | Glob-as-arg will be deprecated at some point. 65 | 66 | ### Cmd 67 | ```bash 68 | clai cmd to show all files in home 69 | ``` 70 | 71 | Will work like many of the popular command suggestion LLM tools out there. 72 | Flags works with this mode as well, such as `clai -re -g 'some_file.go' cmd to cleanup this messy code`, but it's not guaranteed the LLM will output an executable output. 73 | 74 | ### Profiles 75 | 1. `clai setup -> 2 -> n` 76 | 1. Write some profile, example 'gopher' 77 | 1. 
`clai -p gopher -g './internal/moneymaker/handler_test.go' q Fix the tests in this file` 78 | 79 | Profiles allow you to preconfigure certain fields which will be passed to the LLMs, most notably the prompt and which tools to use. 80 | This, in turn, enables you to quickly swap between different 'LLM-modes'. 81 | 82 | For instance, you may have one profile named "gopher" which is prompted for golang programming tasks and has the tools `write_file`, `rip grep` and `go` enabled, and then another profile for terraform named "terry". 83 | With these, you don't have to 'pre-prompt' with `clai q _in terraform_ ...` or `clai q _in golang_ ...` but instead can use `clai -p terry q ...`/`clai -p gopher q ...` and also restrict which tools are allowed, as opposed to enabling _all_ tools (with `-t`). 84 | 85 | These profiles are saved as json at [os.UserConfigDir()](https://pkg.go.dev/os#UserConfigDir)`/.clai/profiles`. 86 | This means that you can sync them across all of your machines and tweak your prompts wherever you code. 87 | 88 | Yet again, I've personally utilized aliases here. 89 | `ask` -> Generic profile-less prompt 90 | `gask` -> `clai -p gopher q`, `grask` -> `clai -re -p gopher q` and then `task` -> `clai -p terry q`, etc. 91 | These aliases are later on synced with the rest of my dotfiles + clai profiles, so they're shared on all my development machines. 
92 | 93 | ### Photos 94 | ```bash 95 | printf "flowers" | clai -i ` # stdin replacement works for photos also` \ 96 | --photo-prefix=flowercat ` # Sets the prefix for local photo` \ 97 | --photo-dir=/tmp ` # Sets the output dir` \ 98 | photo A cat made out of {} 99 | ``` 100 | 101 | Since -N alternatives are disabled for many newer OpenAI models, you can use [repeater](https://github.com/baalimago/repeater) to generate several responses from the same prompt: 102 | ```bash 103 | NO_COLOR=true repeater -n 10 -w 3 -increment -file out.txt -output BOTH \ 104 | clai -pp flower_INC p A cat made of flowers 105 | ``` 106 | 107 | 108 | ## Configuration 109 | `clai` will create configuration files at [os.UserConfigDir()](https://pkg.go.dev/os#UserConfigDir)`/.clai/`. 110 | The first time you run `clai`, two default command-related ones, `textConfig.json` and `photoConfig.json`, will be created and then one for each specific model. 111 | The configuration precedence is as follows (from lowest to highest): 112 | 1. Default hard-coded configurations [such as this](./internal/text/conf.go), these get written to file the first time you run `clai` 113 | 1. Configurations from local `textConfig.json` or `photoConfig.json` file 114 | 1. Profiles 115 | 1. Flags 116 | 117 | The `textConfig.json/photoConfig.json` files configure _what_ you want done, not _how_ the models should perform it. 118 | This way it scales for any vendor + model. 119 | 120 | ### Models 121 | There are two ways to configure the models: 122 | 1. Set flag `-chat-model` or `-photo-model` 123 | 1. Set the `model` field in the `textConfig.json` or `photoConfig.json` file. This will make it default, if not overwritten by flags. 124 | 125 | Then, for each model, a new configuration file will be created. 126 | Since each vendor's model supports quite different configurations, the model configurations aren't exposed as flags. 
127 | Instead, modify the model by adjusting its configuration file, found in [os.UserConfigDir()](https://pkg.go.dev/os#UserConfigDir)`/.clai/__.json`. 128 | This config json will in effect be unmarshaled into a request sent to the model's vendor. 129 | 130 | ### Conversations 131 | Within [os.UserConfigDir()](https://pkg.go.dev/os#UserConfigDir)`/.clai/conversations` you'll find all the conversations. 132 | You can also modify the chats here as a way to prompt, or create entirely new ones as you see fit. 133 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 baalimago 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # clai: command line artificial intelligence 2 | 3 | [![Go Report Card](https://goreportcard.com/badge/github.com/baalimago/clai)](https://goreportcard.com/report/github.com/baalimago/clai) 4 | ![Wakatime](https://wakatime.com/badge/user/018cc8d2-3fd9-47ef-81dc-e4ad645d5f34/project/018e07e1-bd22-4077-a213-c16290d3db52.svg) 5 | 6 | `clai` integrates AI models of multiple vendors via cli. 7 | You can generate images, text, summarize content and chat while using native terminal functionality, such as pipes and termination signals. 8 | 9 | It's not (only) an LLM-powered command suggester; instead it's a cli native LLM context feeder designed to fit into each user's own workflows. 10 | 11 | The multi-vendor aspect enables easy comparisons between different models, and also removes the need for multiple subscriptions: most APIs are usage-based (some with expiration time). 12 | 13 | ## Features 14 | 15 | Piping into LLM: 16 | ![piping](./img/piping.gif "Piping data into queries") 17 | 18 | Easily configurable profiles (note the built in tools!): 19 | ![profiles](./img/profiles.gif "Profiles allowing easily customized prompts") 20 | 21 | Conversation history and simple TUI to browse and continue old chats: 22 | ![chats](./img/chats.gif "Conversation history and simple TUI to continue old chats:") 23 | 24 | These are the core features which can be combined. 25 | For instance, you can pipe data into an existing chat. 26 | You can also continue a chat with another profile, or another chat model. 27 | 28 | All the configuration files and chats are json, so manual tweaks and manipulation are easy to do. 29 | 30 | If you have time, check out [this blogpost](https://lorentz.app/blog-item.html?id=clai) for a slightly more structured introduction on how to use clai efficiently. 
31 | 32 | ## Supported vendors 33 | 34 | - **OpenAI API Key:** Set the `OPENAI_API_KEY` env var to your [OpenAI API key](https://platform.openai.com/docs/quickstart/step-2-set-up-your-api-key). [Text models](https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo), [photo models](https://platform.openai.com/docs/models/dall-e). 35 | - **Anthropic API Key:** Set the `ANTHROPIC_API_KEY` env var to your [Anthropic API key](https://console.anthropic.com/login?returnTo=%2F). [Text models](https://docs.anthropic.com/claude/docs/models-overview#model-recommendations). 36 | - **Mistral API Key:** Set the `MISTRAL_API_KEY` env var to your [Mistral API key](https://console.mistral.ai/). [Text models](https://docs.mistral.ai/getting-started/models/) 37 | - **Deepseek:** Set the `DEEPSEEK_API_KEY` env var to your [Deepseek API key](https://api-docs.deepseek.com/). [Text models](https://api-docs.deepseek.com/quick_start/pricing) 38 | - **Novita AI:** Set the `NOVITA_API_KEY` env var to your [Novita API key](https://novita.ai/settings?utm_source=github_clai&utm_medium=github_readme&utm_campaign=link#key-management). Target the model using novita prefix, like this: `novita:`, where `` is one of the [text models](https://novita.ai/model-api/product/llm-api?utm_source=github_clai&utm_medium=github_readme&utm_campaign=link). 39 | - **Ollama:** Start your ollama server (defaults to localhost:11434). Target using model format `ollama:`, where `` is optional (defaults to llama3). Reconfigure url with `clai setup -> 1 -> ` 40 | 41 | Note that you can only use the models that you have bought an API key for. 42 | 43 | ## Get started 44 | 45 | ```bash 46 | go install github.com/baalimago/clai@latest 47 | ``` 48 | 49 | You may also use the setup script: 50 | 51 | ```bash 52 | curl -fsSL https://raw.githubusercontent.com/baalimago/clai/main/setup.sh | sh 53 | ``` 54 | 55 | Either look at `clai help` or the [examples](./EXAMPLES.md) for how to use `clai`. 
56 | 57 | Install [Glow](https://github.com/charmbracelet/glow) for formatted markdown output when querying text responses. 58 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/baalimago/clai 2 | 3 | go 1.23 4 | 5 | require github.com/baalimago/go_away_boilerplate v1.3.34 6 | 7 | require golang.org/x/net v0.24.0 8 | 9 | require golang.org/x/exp v0.0.0-20240119083558-1b970713d09a // indirect 10 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/baalimago/go_away_boilerplate v1.3.34 h1:6fzbpN/mWPYkboO9TF8F6jdV7wNQfahyNh4pQ2NxM3A= 2 | github.com/baalimago/go_away_boilerplate v1.3.34/go.mod h1:2O+zQ0Zm8vPD5SeccFFlgyf3AnYWQSHAut/ecPMmRdU= 3 | golang.org/x/exp v0.0.0-20240119083558-1b970713d09a h1:Q8/wZp0KX97QFTc2ywcOE0YRjZPVIx+MXInMzdvQqcA= 4 | golang.org/x/exp v0.0.0-20240119083558-1b970713d09a/go.mod h1:idGWGoKP1toJGkd5/ig9ZLuPcZBC3ewk7SzmH0uou08= 5 | golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w= 6 | golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8= 7 | -------------------------------------------------------------------------------- /img/chats.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baalimago/clai/5fa8b3b81c64165fb2583dbcfbb32e151423e12b/img/chats.gif -------------------------------------------------------------------------------- /img/piping.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baalimago/clai/5fa8b3b81c64165fb2583dbcfbb32e151423e12b/img/piping.gif -------------------------------------------------------------------------------- /img/profiles.gif: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/baalimago/clai/5fa8b3b81c64165fb2583dbcfbb32e151423e12b/img/profiles.gif -------------------------------------------------------------------------------- /internal/chat/chat.go: -------------------------------------------------------------------------------- 1 | package chat 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "os" 7 | "path" 8 | "strings" 9 | 10 | "github.com/baalimago/clai/internal/models" 11 | "github.com/baalimago/clai/internal/utils" 12 | "github.com/baalimago/go_away_boilerplate/pkg/ancli" 13 | "github.com/baalimago/go_away_boilerplate/pkg/misc" 14 | ) 15 | // FromPath reads the file at path and decodes its JSON content into a models.Chat. Read and decode failures are returned wrapped, together with an empty chat. 16 | func FromPath(path string) (models.Chat, error) { 17 | if misc.Truthy(os.Getenv("DEBUG")) || misc.Truthy(os.Getenv("DEBUG_REPLY_MODE")) { 18 | ancli.PrintOK(fmt.Sprintf("reading chat from '%v'\n", path)) 19 | } 20 | b, err := os.ReadFile(path) 21 | if err != nil { 22 | return models.Chat{}, fmt.Errorf("failed to read file: %w", err) 23 | } 24 | var chat models.Chat 25 | err = json.Unmarshal(b, &chat) 26 | if err != nil { 27 | return models.Chat{}, fmt.Errorf("failed to decode JSON: %w", err) 28 | } 29 | 30 | return chat, nil 31 | } 32 | // Save encodes chat as JSON and writes it to the file '<chat.ID>.json' inside the directory saveAt, using file mode 0o644. 33 | func Save(saveAt string, chat models.Chat) error { 34 | b, err := json.Marshal(chat) 35 | if err != nil { 36 | return fmt.Errorf("failed to encode JSON: %w", err) 37 | } 38 | fileName := path.Join(saveAt, fmt.Sprintf("%v.json", chat.ID)) 39 | if misc.Truthy(os.Getenv("DEBUG")) || misc.Truthy(os.Getenv("DEBUG_REPLY_MODE")) { 40 | ancli.PrintOK(fmt.Sprintf("saving chat to: '%v', content (on new line):\n'%v'\n", fileName, string(b))) 41 | } 42 | return os.WriteFile(fileName, b, 0o644) 43 | } 44 | // IDFromPrompt derives a chat ID from prompt: the tokens returned by utils.GetFirstTokens (called with a limit of 5 on the space-split prompt) are joined with '_', then path separators are replaced with '.'. 45 | func IDFromPrompt(prompt string) string { 46 | id := strings.Join(utils.GetFirstTokens(strings.Split(prompt, " "), 5), "_") 47 | // Slashes messes up the save path pretty bad 48 | id = strings.ReplaceAll(id, "/", ".") 49 | // You're welcome, 
windows users. You're also weird. 50 | id = strings.ReplaceAll(id, "\\", ".") 51 | return id 52 | } 53 | -------------------------------------------------------------------------------- /internal/create_queriers.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "strings" 7 | 8 | "github.com/baalimago/clai/internal/chat" 9 | "github.com/baalimago/clai/internal/models" 10 | "github.com/baalimago/clai/internal/photo" 11 | "github.com/baalimago/clai/internal/text" 12 | "github.com/baalimago/clai/internal/vendors/anthropic" 13 | "github.com/baalimago/clai/internal/vendors/deepseek" 14 | "github.com/baalimago/clai/internal/vendors/mistral" 15 | "github.com/baalimago/clai/internal/vendors/novita" 16 | "github.com/baalimago/clai/internal/vendors/ollama" 17 | "github.com/baalimago/clai/internal/vendors/openai" 18 | "github.com/baalimago/go_away_boilerplate/pkg/ancli" 19 | "github.com/baalimago/go_away_boilerplate/pkg/misc" 20 | ) 21 | 22 | // CreateTextQuerier by checking the model for which vendor to use, then initiating 23 | // a TextQuerier. 24 | // 25 | // Exactly one vendor querier is constructed. The explicit vendor prefixes 26 | // ("ollama:", "novita:") are matched first, so that a model such as 27 | // "ollama:deepseek-r1" never constructs (or fails on) the deepseek querier. 28 | // The substring matches follow in an order which resolves a model name matching 29 | // several vendors to the same vendor as before. 30 | // 31 | // If conf.ChatMode is set, the querier is handed to chat.New together with the 32 | // user config directory, and the resulting chat querier is returned instead. 33 | func CreateTextQuerier(conf text.Configurations) (models.Querier, error) { 34 | var q models.Querier 35 | 36 | switch { 37 | case strings.HasPrefix(conf.Model, "ollama:") || conf.Model == "ollama": 38 | defaultCpy := ollama.OLLAMA_DEFAULT 39 | // A bare "ollama" (or an empty suffix) keeps the model from OLLAMA_DEFAULT 40 | if len(conf.Model) > 7 { 41 | defaultCpy.Model = conf.Model[7:] 42 | } 43 | qTmp, err := text.NewQuerier(conf, &defaultCpy) 44 | if err != nil { 45 | return nil, fmt.Errorf("failed to create text querier: %w", err) 46 | } 47 | q = &qTmp 48 | case strings.HasPrefix(conf.Model, "novita:"): 49 | defaultCpy := novita.NOVITA_DEFAULT 50 | defaultCpy.Model = conf.Model[7:] 51 | qTmp, err := text.NewQuerier(conf, &defaultCpy) 52 | if err != nil { 53 | return nil, fmt.Errorf("failed to create text querier: %w", err) 54 | } 55 | q = &qTmp 
56 | case strings.Contains(conf.Model, "mistral") || strings.Contains(conf.Model, "mixtral"): 57 | defaultCpy := mistral.MINSTRAL_DEFAULT 58 | defaultCpy.Model = conf.Model 59 | qTmp, err := text.NewQuerier(conf, &defaultCpy) 60 | if err != nil { 61 | return nil, fmt.Errorf("failed to create text querier: %w", err) 62 | } 63 | q = &qTmp 64 | case strings.Contains(conf.Model, "deepseek"): 65 | defaultCpy := deepseek.DEEPSEEK_DEFAULT 66 | defaultCpy.Model = conf.Model 67 | qTmp, err := text.NewQuerier(conf, &defaultCpy) 68 | if err != nil { 69 | return nil, fmt.Errorf("failed to create text querier: %w", err) 70 | } 71 | q = &qTmp 72 | case strings.Contains(conf.Model, "gpt"): 73 | defaultCpy := openai.GPT_DEFAULT 74 | defaultCpy.Model = conf.Model 75 | qTmp, err := text.NewQuerier(conf, &defaultCpy) 76 | if err != nil { 77 | return nil, fmt.Errorf("failed to create text querier: %w", err) 78 | } 79 | q = &qTmp 80 | case strings.Contains(conf.Model, "claude"): 81 | defaultCpy := anthropic.CLAUDE_DEFAULT 82 | // The model determines where to check for the config using 83 | // cfgdir/vendor_model_version.json. If it doesn't find it, 84 | // it will use the default to create a new config with this 85 | // path and the default values. In there, the model needs to be 86 | // the configured model (not the default one) 87 | defaultCpy.Model = conf.Model 88 | qTmp, err := text.NewQuerier(conf, &defaultCpy) 89 | if err != nil { 90 | return nil, fmt.Errorf("failed to create text querier: %w", err) 91 | } 92 | q = &qTmp 93 | default: 94 | return nil, fmt.Errorf("failed to find text querier for model: %v", conf.Model) 95 | } 
96 | 97 | if misc.Truthy(os.Getenv("DEBUG")) { 98 | ancli.PrintOK(fmt.Sprintf("chat mode: %v\n", conf.ChatMode)) 99 | } 100 | if conf.ChatMode { 101 | tq, isTextQuerier := q.(models.ChatQuerier) 102 | if !isTextQuerier { 103 | return nil, fmt.Errorf("failed to cast Querier using model: '%v' to TextQuerier, cannot proceed to chat", conf.Model) 104 | } 105 | configDir, err := os.UserConfigDir() 106 | if err != nil { 107 | return nil, fmt.Errorf("failed to find user config dir: %w", err) 108 | } 109 | chatQ, err := chat.New(tq, configDir, conf.PostProccessedPrompt, conf.InitialPrompt.Messages, chat.NotCyclicalImport{ 110 | UseTools: conf.UseTools, 111 | UseProfile: conf.UseProfile, 112 | Model: conf.Model, 113 | }) 114 | if err != nil { 115 | return nil, fmt.Errorf("failed to create chat querier: %w", err) 116 | } 117 | q = chatQ 118 | } 119 | return q, nil 120 | } 121 | 122 | // NewPhotoQuerier validates the configured output type, requires the output directory to exist when the 123 | // output type is photo.LOCAL, and returns the openai photo querier for models containing "dall-e". Other models yield an error. 124 | func NewPhotoQuerier(conf photo.Configurations) (models.Querier, error) { 125 | if err := photo.ValidateOutputType(conf.Output.Type); err != nil { 126 | return nil, err 127 | } 128 | 129 | if conf.Output.Type == photo.LOCAL { 130 | if _, err := os.Stat(conf.Output.Dir); os.IsNotExist(err) { 131 | return nil, fmt.Errorf("failed to find photo output directory: %w", err) 132 | } 133 | } 134 | 135 | if strings.Contains(conf.Model, "dall-e") { 136 | q, err := openai.NewPhotoQuerier(conf) 137 | if err != nil { 138 | return nil, fmt.Errorf("failed to create dall-e photo querier: %w", err) 139 | } 140 | return q, nil 141 | } 142 | 143 | return nil, fmt.Errorf("failed to find photo querier for model: %v", conf.Model) 144 | } 145 | -------------------------------------------------------------------------------- /internal/glob/glob.go: -------------------------------------------------------------------------------- 1 | package glob 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "path/filepath" 7 | "strings" 8 | 9 | 
"github.com/baalimago/clai/internal/models" 10 | "github.com/baalimago/clai/internal/utils" 11 | "github.com/baalimago/go_away_boilerplate/pkg/ancli" 12 | "github.com/baalimago/go_away_boilerplate/pkg/misc" 13 | ) 14 | 15 | // Setup the glob parsing. Currently this is a bit messy as it works 16 | // both for flag glob and arg glob. Once arg glob is deprecated, this 17 | // function may be cleaned up. 18 | // 19 | // If args starts with "g" or "glob", the glob is taken from args[1] and the 20 | // leading element is dropped from the returned args. Otherwise flagGlob is used 21 | // and args is returned untouched, regardless of its length. 22 | func Setup(flagGlob string, args []string) (string, []string, error) { 23 | globArg := len(args) > 0 && (args[0] == "g" || args[0] == "glob") 24 | if globArg && len(args) < 2 { 25 | return "", args, fmt.Errorf("not enough arguments provided") 26 | } 27 | // Only index into args once it is known to hold a glob; indexing 28 | // unconditionally panics for argument lists shorter than two 29 | glob := flagGlob 30 | if globArg { 31 | glob = args[1] 32 | if flagGlob != "" { 33 | ancli.PrintWarn(fmt.Sprintf("both glob-arg and glob-flag is specified. This is confusing. Using glob-arg query: %v\n", glob)) 34 | } 35 | args = args[1:] 36 | } 37 | if !strings.Contains(glob, "*") { 38 | ancli.PrintWarn(fmt.Sprintf("found no '*' in glob: %v, has it already been expanded? Consider enclosing glob in single quotes\n", glob)) 39 | } 40 | if misc.Truthy(os.Getenv("DEBUG")) { 41 | ancli.PrintOK(fmt.Sprintf("found glob: %v\n", glob)) 42 | } 43 | return glob, args, nil 44 | } 45 | 46 | // CreateChat reads all files matching glob and returns them as a chat with the ID 47 | // "glob_" followed by the glob. NOTE: systemPrompt is currently unused, the system 48 | // message is the fixed one set by constructGlobMessages. 49 | func CreateChat(glob, systemPrompt string) (models.Chat, error) { 50 | fileMessages, err := parseGlob(glob) 51 | if err != nil { 52 | return models.Chat{}, fmt.Errorf("failed to parse glob string: '%v', err: %w", glob, err) 53 | } 54 | 55 | return models.Chat{ 56 | ID: fmt.Sprintf("glob_%v", glob), 57 | Messages: constructGlobMessages(fileMessages), 58 | }, nil 59 | } 60 | 61 | // constructGlobMessages wraps globMessages between a system message explaining 62 | // the format and a final user message containing only '#####'. 63 | func constructGlobMessages(globMessages []models.Message) []models.Message { 64 | ret := make([]models.Message, 0, len(globMessages)+4) 65 | ret = append(ret, models.Message{ 66 | Role: "system", 67 | Content: "You will be given a series of messages each containing contents from files, then a message containing this: '#####'. 
Using the file content as context, perform the request given in the message after the '#####'.", 68 | }) 69 | ret = append(ret, globMessages...) 70 | ret = append(ret, models.Message{ 71 | Role: "user", 72 | Content: "#####", 73 | }) 74 | return ret 75 | } 76 | 77 | // parseGlob passes glob through utils.ReplaceTildeWithHome, expands it and returns 78 | // one user message per readable file. Unreadable files are skipped with a warning. 79 | // An invalid pattern or a pattern without matches yields an error. 80 | func parseGlob(glob string) ([]models.Message, error) { 81 | glob, err := utils.ReplaceTildeWithHome(glob) 82 | if err != nil { 83 | return nil, fmt.Errorf("parseGlob, ReplaceTildeWithHome: %w", err) 84 | } 85 | files, err := filepath.Glob(glob) 86 | if err != nil { 87 | return nil, fmt.Errorf("failed to parse glob: %w", err) 88 | } 89 | if misc.Truthy(os.Getenv("DEBUG")) { 90 | ancli.PrintOK(fmt.Sprintf("found %d files: %v\n", len(files), files)) 91 | } 92 | 93 | if len(files) == 0 { 94 | return nil, fmt.Errorf("no files found") 95 | } 96 | 97 | ret := make([]models.Message, 0, len(files)) 98 | for _, file := range files { 99 | data, err := os.ReadFile(file) 100 | if err != nil { 101 | ancli.PrintWarn(fmt.Sprintf("failed to read file: %v\n", err)) 102 | continue 103 | } 104 | // NOTE(review): file name and data are inserted verbatim, so the content is 105 | // not valid JSON once the data contains quotes or newlines 106 | ret = append(ret, models.Message{ 107 | Role: "user", 108 | Content: fmt.Sprintf("{\"fileName\": \"%v\", \"data\": \"%v\"}", file, string(data)), 109 | }) 110 | } 111 | return ret, nil 112 | } 113 | -------------------------------------------------------------------------------- /internal/glob/glob_test.go: -------------------------------------------------------------------------------- 1 | package glob 2 | 3 | import ( 4 | "flag" 5 | "fmt" 6 | "os" 7 | "testing" 8 | 9 | "github.com/baalimago/clai/internal/models" 10 | ) 11 | 12 | func TestParseGlob(t *testing.T) { 13 | // Create two test files in a temporary directory 14 | tmpDir := t.TempDir() 15 | os.WriteFile(fmt.Sprintf("%v/%v", tmpDir, "test1.txt"), []byte("content1"), 0o644) 16 | os.WriteFile(fmt.Sprintf("%v/%v", tmpDir, "test2.txt"), []byte("content2"), 0o644) 17 | 18 | // Test case 19 | tests := []struct { 20 | name string 21 | glob string 22 | want int // Number of messages expected 23 | 
wantErr bool 24 | }{ 25 | {"two files", fmt.Sprintf("%v/*.txt", tmpDir), 2, false}, 26 | {"no match", "*.log", 0, true}, 27 | {"invalid pattern", "[", 0, true}, 28 | {"home directory", "~/*.txt", 2, false}, // This test case will fail on windows and plan9 29 | } 30 | oldHome := os.Getenv("HOME") 31 | _ = os.Setenv("HOME", tmpDir) 32 | defer func() { 33 | _ = os.Setenv("HOME", oldHome) 34 | }() 35 | for _, tt := range tests { 36 | t.Run(tt.name, func(t *testing.T) { 37 | got, err := parseGlob(tt.glob) 38 | if (err != nil) != tt.wantErr { 39 | t.Errorf("parseGlob() error = %v, wantErr %v", err, tt.wantErr) 40 | return 41 | } 42 | if len(got) != tt.want { 43 | t.Errorf("parseGlob() got %v messages, want %v", len(got), tt.want) 44 | } 45 | }) 46 | } 47 | } 48 | // TestSetup checks whether Setup returns an error for a lone "glob" argument and none for a longer argument list. NOTE(review): args[0] is "clai" in the 'Valid glob' case, so the glob-arg branch of Setup is not exercised there. 49 | func TestSetup(t *testing.T) { 50 | // Set up test cases 51 | testCases := []struct { 52 | name string 53 | args []string 54 | expectedErr bool 55 | }{ 56 | { 57 | name: "Not enough arguments", 58 | args: []string{"glob"}, 59 | expectedErr: true, 60 | }, 61 | { 62 | name: "Valid glob", 63 | args: []string{"clai", "glob", "*.go", "argument"}, 64 | expectedErr: false, 65 | }, 66 | } 67 | 68 | // Run test cases 69 | for _, tc := range testCases { 70 | t.Run(tc.name, func(t *testing.T) { 71 | flag.Parse() 72 | _, _, err := Setup("", tc.args) 73 | if tc.expectedErr && err == nil { 74 | t.Errorf("Expected an error, but got none") 75 | } 76 | if !tc.expectedErr && err != nil { 77 | t.Errorf("Unexpected error: %v", err) 78 | } 79 | }) 80 | } 81 | } 82 | // TestCreateChat checks the ID and the minimum message count of the chat created for "*.go". NOTE: the glob is relative to the working directory, so at least two .go files must exist there to reach four messages. 83 | func TestCreateChat(t *testing.T) { 84 | // Set up test case 85 | glob := "*.go" 86 | systemPrompt := "You are a helpful assistant." 
87 | 88 | // Run the function 89 | chat, err := CreateChat(glob, systemPrompt) 90 | if err != nil { 91 | t.Fatalf("Unexpected error: %v", err) 92 | } 93 | 94 | // Check the chat ID 95 | expectedID := "glob_*.go" 96 | if chat.ID != expectedID { 97 | t.Errorf("Expected chat ID: %s, got: %s", expectedID, chat.ID) 98 | } 99 | 100 | // Check the number of messages 101 | if len(chat.Messages) < 4 { 102 | t.Errorf("Expected at least 4 messages, got: %d", len(chat.Messages)) 103 | } 104 | } 105 | // TestConstructGlobMessages checks that the file messages are wrapped by exactly two messages: the system message first and the '#####' user message last. 106 | func TestConstructGlobMessages(t *testing.T) { 107 | // Set up test case 108 | globMessages := []models.Message{ 109 | {Role: "user", Content: "{\"fileName\": \"file1.go\", \"data\": \"package main\"}"}, 110 | {Role: "user", Content: "{\"fileName\": \"file2.go\", \"data\": \"func main()\"}"}, 111 | } 112 | 113 | // Run the function 114 | messages := constructGlobMessages(globMessages) 115 | 116 | // Check the number of messages 117 | expectedLen := len(globMessages) + 2 118 | if len(messages) != expectedLen { 119 | t.Errorf("Expected %d messages, got: %d", expectedLen, len(messages)) 120 | } 121 | 122 | // Check the system message 123 | expectedSystemMsg := "You will be given a series of messages each containing contents from files, then a message containing this: '#####'. Using the file content as context, perform the request given in the message after the '#####'." 
124 | if messages[0].Content != expectedSystemMsg { 125 | t.Errorf("Expected system message: %s, got: %s", expectedSystemMsg, messages[0].Content) 126 | } 127 | 128 | // Check the user message 129 | expectedUserMsg := "#####" 130 | if messages[len(messages)-1].Content != expectedUserMsg { 131 | t.Errorf("Expected user message: %s, got: %s", expectedUserMsg, messages[len(messages)-1].Content) 132 | } 133 | } 134 | -------------------------------------------------------------------------------- /internal/models/completion/types.go: -------------------------------------------------------------------------------- 1 | package completion 2 | // Type is an integer enumeration, its values are declared in the const block below. 3 | type Type int 4 | 5 | const ( 6 | ERROR Type = iota 7 | TOKEN 8 | ) 9 | -------------------------------------------------------------------------------- /internal/models/models.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "time" 7 | 8 | "github.com/baalimago/clai/internal/tools" 9 | ) 10 | // Querier is implemented by anything which can run a query under the given context. 11 | type Querier interface { 12 | Query(ctx context.Context) error 13 | } 14 | // ChatQuerier is a Querier which can additionally run a text query on a Chat, returning a Chat. 15 | type ChatQuerier interface { 16 | Querier 17 | TextQuery(context.Context, Chat) (Chat, error) 18 | } 19 | 20 | type StreamCompleter interface { 21 | // Setup the stream completer, do things like init http.Client/websocket etc 22 | // Will be called synchronously. Should return error if setup fails 23 | Setup() error 24 | 25 | // StreamCompletions and return a channel which sends CompletionsEvents. 26 | // The CompletionEvents should be a string, an error, NoopEvent or a models.Call. If there is 27 | // a catastrophic error, return the error and close the channel. 
28 | StreamCompletions(context.Context, Chat) (chan CompletionEvent, error) 29 | } 30 | 31 | // A ToolBox can register tools which later on will be added to the chat completion queries 32 | type ToolBox interface { 33 | // RegisterTool registers a tool to the ToolBox 34 | RegisterTool(tools.AiTool) 35 | } 36 | 37 | type CompletionEvent any 38 | 39 | type NoopEvent struct{} 40 | 41 | type Message struct { 42 | Role string `json:"role"` 43 | Content string `json:"content,omitempty"` 44 | ToolCalls []tools.Call `json:"tool_calls,omitempty"` 45 | ToolCallID string `json:"tool_call_id,omitempty"` 46 | } 47 | 48 | type Chat struct { 49 | Created time.Time `json:"created,omitempty"` 50 | ID string `json:"id"` 51 | Messages []Message `json:"messages"` 52 | } 53 | 54 | // FirstSystemMessage returns the first encountered Message with role 'system' 55 | func (c *Chat) FirstSystemMessage() (Message, error) { 56 | for _, msg := range c.Messages { 57 | if msg.Role == "system" { 58 | return msg, nil 59 | } 60 | } 61 | return Message{}, errors.New("failed to find any system message") 62 | } 63 | 64 | func (c *Chat) FirstUserMessage() (Message, error) { 65 | for _, msg := range c.Messages { 66 | if msg.Role == "user" { 67 | return msg, nil 68 | } 69 | } 70 | return Message{}, errors.New("failed to find any user message") 71 | } 72 | -------------------------------------------------------------------------------- /internal/models/models_tests.go: -------------------------------------------------------------------------------- 1 | // This package contains test intended to be used by the implementations of the 2 | // Querier, ChatQuerier and StreamCompleter interfaces 3 | package models 4 | 5 | import ( 6 | "context" 7 | "testing" 8 | "time" 9 | 10 | "github.com/baalimago/go_away_boilerplate/pkg/testboil" 11 | ) 12 | 13 | func Querier_Context_Test(t *testing.T, q Querier) { 14 | testboil.ReturnsOnContextCancel(t, func(ctx context.Context) { 15 | q.Query(ctx) 16 | }, time.Second) 17 | 
} 18 | 19 | func ChatQuerier_Test(t *testing.T, q ChatQuerier) { 20 | testboil.ReturnsOnContextCancel(t, func(ctx context.Context) { 21 | q.TextQuery(ctx, Chat{}) 22 | }, time.Second) 23 | } 24 | 25 | func StreamCompleter_Test(t *testing.T, s StreamCompleter) { 26 | testboil.ReturnsOnContextCancel(t, func(ctx context.Context) { 27 | s.StreamCompletions(ctx, Chat{}) 28 | }, time.Second) 29 | } 30 | -------------------------------------------------------------------------------- /internal/photo/conf.go: -------------------------------------------------------------------------------- 1 | package photo 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | ) 7 | 8 | type Configurations struct { 9 | Model string `json:"model"` 10 | // Format of the prompt, will place prompt at '%v' 11 | PromptFormat string `json:"prompt-format"` 12 | Output Output `json:"output"` 13 | Raw bool `json:"raw"` 14 | StdinReplace string `json:"-"` 15 | ReplyMode bool `json:"-"` 16 | Prompt string `json:"-"` 17 | } 18 | 19 | type Output struct { 20 | Type OutputType `json:"type"` 21 | Dir string `json:"dir"` 22 | Prefix string `json:"prefix"` 23 | } 24 | 25 | var DEFAULT = Configurations{ 26 | Model: "dall-e-3", 27 | PromptFormat: "I NEED to test how the tool works with extremely simple prompts. 
DO NOT add any detail, just use it AS-IS: '%v'", 28 | Output: Output{ 29 | Type: LOCAL, 30 | Dir: fmt.Sprintf("%v/Pictures", os.Getenv("HOME")), 31 | Prefix: "clai", 32 | }, 33 | } 34 | 35 | type OutputType string 36 | 37 | const ( 38 | URL OutputType = "url" 39 | LOCAL OutputType = "local" 40 | ) 41 | 42 | func ValidateOutputType(outputType OutputType) error { 43 | switch outputType { 44 | case URL, LOCAL: 45 | return nil 46 | default: 47 | return fmt.Errorf("invalid output type: %v", outputType) 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /internal/photo/funimation_0.go: -------------------------------------------------------------------------------- 1 | package photo 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | "time" 7 | 8 | "github.com/baalimago/clai/internal/utils" 9 | "github.com/baalimago/go_away_boilerplate/pkg/ancli" 10 | ) 11 | 12 | func StartAnimation() func() { 13 | t0 := time.Now() 14 | ticker := time.NewTicker(time.Second / 60) 15 | stop := make(chan struct{}) 16 | termWidth, err := utils.TermWidth() 17 | if err != nil { 18 | ancli.PrintWarn(fmt.Sprintf("failed to get terminal size: %v\n", err)) 19 | termWidth = 100 20 | } 21 | go func() { 22 | for { 23 | select { 24 | case <-ticker.C: 25 | cTick := time.Since(t0) 26 | clearLine := strings.Repeat(" ", termWidth) 27 | fmt.Printf("\r%v", clearLine) 28 | fmt.Printf("\rElapsed time: %v - %v", funimation(cTick), cTick) 29 | case <-stop: 30 | return 31 | } 32 | } 33 | }() 34 | return func() { 35 | close(stop) 36 | } 37 | } 38 | 39 | func funimation(t time.Duration) string { 40 | images := []string{ 41 | "🕛", 42 | "🕧", 43 | "🕐", 44 | "🕜", 45 | "🕑", 46 | "🕝", 47 | "🕒", 48 | "🕞", 49 | "🕓", 50 | "🕟", 51 | "🕔", 52 | "🕠", 53 | "🕕", 54 | "🕡", 55 | "🕖", 56 | "🕢", 57 | "🕗", 58 | "🕣", 59 | "🕘", 60 | "🕤", 61 | "🕙", 62 | "🕥", 63 | "🕚", 64 | "🕦", 65 | } 66 | // 1 nanosecond / 23 frames = 43478260 nanoseconds. 
Too low brainjuice to know 67 | // why that works right now 68 | return images[int(t.Nanoseconds()/43478260)%len(images)] 69 | } 70 | -------------------------------------------------------------------------------- /internal/photo/prompt.go: -------------------------------------------------------------------------------- 1 | package photo 2 | 3 | import ( 4 | "encoding/json" 5 | "flag" 6 | "fmt" 7 | "os" 8 | "path" 9 | 10 | "github.com/baalimago/clai/internal/reply" 11 | "github.com/baalimago/clai/internal/utils" 12 | "github.com/baalimago/go_away_boilerplate/pkg/ancli" 13 | "github.com/baalimago/go_away_boilerplate/pkg/misc" 14 | ) 15 | 16 | func (c *Configurations) SetupPrompts() error { 17 | args := flag.Args() 18 | if c.ReplyMode { 19 | confDir, err := os.UserConfigDir() 20 | if err != nil { 21 | return fmt.Errorf("failed to get config dir: %w", err) 22 | } 23 | iP, err := reply.Load(path.Join(confDir, ".clai")) 24 | if err != nil { 25 | return fmt.Errorf("failed to load previous query: %w", err) 26 | } 27 | if len(iP.Messages) > 0 { 28 | replyMessages := "You will be given a serie of messages from different roles, then a prompt descibing what to do with these messages. " 29 | replyMessages += "Between the messages and the prompt, there will be this line: '-------------'." 30 | replyMessages += "The format is json with the structure {\"role\": \"\", \"content\": \"\"}. " 31 | replyMessages += "The roles are 'system' and 'user'. 
" 32 | b, err := json.Marshal(iP.Messages) 33 | if err != nil { 34 | return fmt.Errorf("failed to encode reply JSON: %w", err) 35 | } 36 | replyMessages = fmt.Sprintf("%vMessages:\n%v\n-------------\n", replyMessages, string(b)) 37 | c.Prompt += replyMessages 38 | } 39 | } 40 | prompt, err := utils.Prompt(c.StdinReplace, args) 41 | if err != nil { 42 | return fmt.Errorf("failed to setup prompt from stdin: %w", err) 43 | } 44 | if misc.Truthy(os.Getenv("DEBUG")) { 45 | ancli.PrintOK(fmt.Sprintf("format: '%v', prompt: '%v'\n", c.PromptFormat, prompt)) 46 | } 47 | c.Prompt += fmt.Sprintf(c.PromptFormat, prompt) 48 | return nil 49 | } 50 | -------------------------------------------------------------------------------- /internal/reply/replay.go: -------------------------------------------------------------------------------- 1 | package reply 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | 7 | "github.com/baalimago/clai/internal/utils" 8 | ) 9 | 10 | func Replay(raw bool) error { 11 | prevReply, err := Load("") 12 | if err != nil { 13 | return fmt.Errorf("failed to load previous reply: %v", err) 14 | } 15 | amMessages := len(prevReply.Messages) 16 | if amMessages == 0 { 17 | return errors.New("failed to find any recent reply") 18 | } 19 | mostRecentMsg := prevReply.Messages[amMessages-1] 20 | utils.AttemptPrettyPrint(mostRecentMsg, "system", raw) 21 | return nil 22 | } 23 | -------------------------------------------------------------------------------- /internal/reply/reply.go: -------------------------------------------------------------------------------- 1 | package reply 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "io/fs" 7 | "os" 8 | "path" 9 | "time" 10 | 11 | "github.com/baalimago/clai/internal/chat" 12 | "github.com/baalimago/clai/internal/models" 13 | "github.com/baalimago/go_away_boilerplate/pkg/ancli" 14 | ) 15 | 16 | // SaveAsPreviousQuery at claiConfDir/conversations/prevQuery.json with ID prevQuery 17 | func SaveAsPreviousQuery(claiConfDir string, msgs 
[]models.Message) error { 18 | prevQueryChat := models.Chat{ 19 | Created: time.Now(), 20 | ID: "prevQuery", 21 | Messages: msgs, 22 | } 23 | // This check avoid storing queries without any replies, which would most likely 24 | // flood the conversations needlessly 25 | if len(msgs) > 2 { 26 | firstUserMsg, err := prevQueryChat.FirstUserMessage() 27 | if err != nil { 28 | return fmt.Errorf("failed to get first user message: %w", err) 29 | } 30 | convChat := models.Chat{ 31 | Created: time.Now(), 32 | ID: chat.IDFromPrompt(firstUserMsg.Content), 33 | Messages: msgs, 34 | } 35 | err = chat.Save(path.Join(claiConfDir, "conversations"), convChat) 36 | if err != nil { 37 | return fmt.Errorf("failed to save previous query as new conversation: %w", err) 38 | } 39 | } 40 | 41 | return chat.Save(path.Join(claiConfDir, "conversations"), prevQueryChat) 42 | } 43 | 44 | // Load the prevQuery.json from the claiConfDir/conversations directory 45 | // If claiConfDir is left empty, it will be re-constructed. 
The technical debt 46 | // is piling up quite fast here 47 | func Load(claiConfDir string) (models.Chat, error) { 48 | if claiConfDir == "" { 49 | confDir, err := os.UserConfigDir() 50 | if err != nil { 51 | return models.Chat{}, fmt.Errorf("failed to find home dir: %v", err) 52 | } 53 | claiConfDir = path.Join(confDir, ".clai") 54 | } 55 | 56 | c, err := chat.FromPath(path.Join(claiConfDir, "conversations", "prevQuery.json")) 57 | if err != nil { 58 | if errors.Is(err, fs.ErrNotExist) { 59 | ancli.PrintWarn("no previous query found\n") 60 | } else { 61 | return models.Chat{}, fmt.Errorf("failed to read from path: %w", err) 62 | } 63 | } 64 | return c, nil 65 | } 66 | -------------------------------------------------------------------------------- /internal/setup/setup.go: -------------------------------------------------------------------------------- 1 | package setup 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "path" 7 | "path/filepath" 8 | "strings" 9 | 10 | "github.com/baalimago/clai/internal/text" 11 | "github.com/baalimago/clai/internal/utils" 12 | ) 13 | 14 | type config struct { 15 | name string 16 | filePath string 17 | } 18 | 19 | type action uint8 20 | 21 | const ( 22 | unset action = iota 23 | conf 24 | del 25 | newaction 26 | confWithEditor 27 | ) 28 | 29 | func (a action) String() string { 30 | switch a { 31 | case unset: 32 | return "unset" 33 | case conf: 34 | return "[c]onfigure" 35 | case del: 36 | return "[d]el" 37 | case newaction: 38 | return "create [n]ew" 39 | case confWithEditor: 40 | return "configure with [e]ditor" 41 | default: 42 | return "unset" 43 | } 44 | } 45 | 46 | const stage_0 = `Do you wish to configure: 47 | 0. mode-files (example: /.clai/textConfig.json- or photoConfig.json) 48 | 1. model files (example: /.clai/openai-gpt-4o.json, /.clai/anthropic-claude-opus.json) 49 | 2. 
text generation profiles (see: "clai [h]elp [p]rofile" for additional info) 50 | [0/1/2]: ` 51 | 52 | // Run the setup to configure the different files 53 | func Run() error { 54 | fmt.Print(stage_0) 55 | 56 | input, err := utils.ReadUserInput() 57 | if err != nil { 58 | return fmt.Errorf("failed to read input while running: %w", err) 59 | } 60 | var configs []config 61 | var a action 62 | configDir, err := os.UserConfigDir() 63 | if err != nil { 64 | return fmt.Errorf("failed to get user config directory: %v", err) 65 | } 66 | claiDir := filepath.Join(configDir, ".clai") 67 | switch input { 68 | case "0": 69 | t, err := getConfigs(filepath.Join(claiDir, "*Config.json"), []string{}) 70 | if err != nil { 71 | return fmt.Errorf("failed to get configs files: %w", err) 72 | } 73 | configs = t 74 | a = conf 75 | case "1": 76 | t, err := getConfigs(filepath.Join(claiDir, "*.json"), []string{"textConfig", "photoConfig"}) 77 | if err != nil { 78 | return fmt.Errorf("failed to get configs files: %w", err) 79 | } 80 | configs = t 81 | qAct, err := queryForAction([]action{conf, del, confWithEditor}) 82 | if err != nil { 83 | return fmt.Errorf("failed to find action: %w", err) 84 | } 85 | a = qAct 86 | case "2": 87 | profilesDir := filepath.Join(claiDir, "profiles") 88 | t, err := getConfigs(filepath.Join(profilesDir, "*.json"), []string{}) 89 | if err != nil { 90 | return fmt.Errorf("failed to get configs files: %w", err) 91 | } 92 | configs = t 93 | qAct, err := queryForAction([]action{conf, del, newaction, confWithEditor}) 94 | if err != nil { 95 | return fmt.Errorf("failed to find action: %w", err) 96 | } 97 | a = qAct 98 | if a == newaction { 99 | c, err := createProFile(profilesDir) 100 | if err != nil { 101 | return fmt.Errorf("failed to create profile file: %w", err) 102 | } 103 | // Reset config list as the user most likely only wants to edit the newly configured profile 104 | configs = make([]config, 0) 105 | configs = append(configs, c) 106 | // Once new file has 
potentially been created, potentially alter it 107 | a = conf 108 | } 109 | case "q", "quit", "e", "exit": 110 | return utils.ErrUserInitiatedExit 111 | default: 112 | return fmt.Errorf("unrecognized selection: %v", input) 113 | } 114 | return configure(configs, a) 115 | } 116 | 117 | // createProFile, as in create profile file. I'm a very funny person. 118 | func createProFile(profilePath string) (config, error) { 119 | if _, err := os.Stat(profilePath); os.IsNotExist(err) { 120 | os.MkdirAll(profilePath, os.ModePerm) 121 | } 122 | fmt.Print("Enter profile name: ") 123 | profileName, err := utils.ReadUserInput() 124 | if err != nil { 125 | return config{}, err 126 | } 127 | newProfilePath := path.Join(profilePath, fmt.Sprintf("%v.json", profileName)) 128 | err = utils.CreateFile(newProfilePath, &text.DEFAULT_PROFILE) 129 | if err != nil { 130 | return config{}, err 131 | } 132 | return config{ 133 | name: profileName, 134 | filePath: newProfilePath, 135 | }, nil 136 | } 137 | 138 | // getConfigs using a glob, and then exclude files using strings.Contains() 139 | func getConfigs(includeGlob string, excludeContains []string) ([]config, error) { 140 | files, err := filepath.Glob(includeGlob) 141 | if err != nil { 142 | return nil, fmt.Errorf("failed to glob pattern %v: %v", includeGlob, err) 143 | } 144 | var configs []config 145 | OUTER: 146 | for _, file := range files { 147 | // The moment this becomes a performance issue it's time to think about 148 | // maybe reducing the amount of config files 149 | for _, e := range excludeContains { 150 | if strings.Contains(filepath.Base(file), e) { 151 | continue OUTER 152 | } 153 | } 154 | configs = append(configs, config{ 155 | name: filepath.Base(file), 156 | filePath: file, 157 | }) 158 | } 159 | 160 | return configs, nil 161 | } 162 | -------------------------------------------------------------------------------- /internal/setup/setup_actions_test.go: 
-------------------------------------------------------------------------------- 1 | package setup 2 | 3 | import ( 4 | "os" 5 | "path/filepath" 6 | "reflect" 7 | "testing" 8 | ) 9 | 10 | func TestQueryForAction(t *testing.T) { 11 | tests := []struct { 12 | name string 13 | options []action 14 | input string 15 | want action 16 | wantErr bool 17 | }{ 18 | {"Configure", []action{conf}, "c", conf, false}, 19 | {"Delete", []action{del}, "d", del, false}, 20 | {"New", []action{newaction}, "n", newaction, false}, 21 | {"Quit", []action{conf, del, newaction}, "q", unset, true}, 22 | {"Invalid", []action{conf, del, newaction}, "x", unset, true}, 23 | } 24 | 25 | for _, tt := range tests { 26 | t.Run(tt.name, func(t *testing.T) { 27 | // Simulate user input 28 | oldStdin := os.Stdin 29 | defer func() { os.Stdin = oldStdin }() 30 | r, w, _ := os.Pipe() 31 | os.Stdin = r 32 | w.Write([]byte(tt.input + "\n")) 33 | w.Close() 34 | 35 | got, err := queryForAction(tt.options) 36 | if (err != nil) != tt.wantErr { 37 | t.Errorf("queryForAction() error = %v, wantErr %v", err, tt.wantErr) 38 | return 39 | } 40 | if got != tt.want { 41 | t.Errorf("queryForAction() = %v, want %v", got, tt.want) 42 | } 43 | }) 44 | } 45 | } 46 | 47 | func TestCastPrimitive(t *testing.T) { 48 | tests := []struct { 49 | name string 50 | input any 51 | want any 52 | }{ 53 | {"String to int", "42", 42}, 54 | {"String to float", "3.14", 3.14}, 55 | {"String remains string", "hello", "hello"}, 56 | {"Boolean true", "true", true}, 57 | {"Boolean false", "false", false}, 58 | } 59 | 60 | for _, tt := range tests { 61 | t.Run(tt.name, func(t *testing.T) { 62 | got := castPrimitive(tt.input) 63 | if !reflect.DeepEqual(got, tt.want) { 64 | t.Errorf("castPrimitive() = %v, want %v", got, tt.want) 65 | } 66 | }) 67 | } 68 | } 69 | 70 | func TestGetToolsValue(t *testing.T) { 71 | oldStdin := os.Stdin 72 | defer func() { os.Stdin = oldStdin }() 73 | 74 | input := "0,2,4\n" 75 | r, w, _ := os.Pipe() 76 | os.Stdin = r 77 
| w.Write([]byte(input)) 78 | w.Close() 79 | 80 | initialTools := []any{"tool1", "tool2", "tool3", "tool4", "tool5"} 81 | 82 | result, err := getToolsValue(initialTools) 83 | if err != nil { 84 | t.Fatalf("getToolsValue() error = %v", err) 85 | } 86 | 87 | // The actual tool names might be different, so we'll just check the length 88 | if len(result) != 3 { 89 | t.Errorf("getToolsValue() returned %d tools, want 3", len(result)) 90 | } 91 | } 92 | 93 | func TestReconfigureWithEditor(t *testing.T) { 94 | tests := []struct { 95 | name string 96 | editor string 97 | content string 98 | wantErr bool 99 | }{ 100 | { 101 | name: "No editor set", 102 | editor: "", 103 | content: "", 104 | wantErr: true, 105 | }, 106 | { 107 | name: "Valid editor", 108 | editor: "echo", 109 | content: "{\"test\": \"value\"}", 110 | wantErr: false, 111 | }, 112 | } 113 | 114 | for _, tt := range tests { 115 | t.Run(tt.name, func(t *testing.T) { 116 | // Setup temporary file 117 | tmpDir := t.TempDir() 118 | tmpFile := filepath.Join(tmpDir, "config.json") 119 | if err := os.WriteFile(tmpFile, []byte(tt.content), 0o644); err != nil { 120 | t.Fatal(err) 121 | } 122 | 123 | // Set environment 124 | oldEditor := os.Getenv("EDITOR") 125 | defer os.Setenv("EDITOR", oldEditor) 126 | os.Setenv("EDITOR", tt.editor) 127 | 128 | cfg := config{ 129 | name: "test", 130 | filePath: tmpFile, 131 | } 132 | 133 | err := reconfigureWithEditor(cfg) 134 | if (err != nil) != tt.wantErr { 135 | t.Errorf("reconfigureWithEditor() error = %v, wantErr %v", err, tt.wantErr) 136 | } 137 | }) 138 | } 139 | } 140 | -------------------------------------------------------------------------------- /internal/setup/setup_test.go: -------------------------------------------------------------------------------- 1 | package setup 2 | 3 | import ( 4 | "os" 5 | "path/filepath" 6 | "reflect" 7 | "testing" 8 | ) 9 | 10 | func TestGetConfigs(t *testing.T) { 11 | // Create a temporary directory for test files 12 | tempDir, err := 
os.MkdirTemp("", "test_configs") 13 | if err != nil { 14 | t.Fatalf("Failed to create temp directory: %v", err) 15 | } 16 | defer os.RemoveAll(tempDir) 17 | 18 | // Create test files 19 | testFiles := []string{ 20 | "config1.json", 21 | "config2.json", 22 | "textConfig.json", 23 | "photoConfig.json", 24 | "otherFile.txt", 25 | } 26 | for _, file := range testFiles { 27 | _, err := os.Create(filepath.Join(tempDir, file)) 28 | if err != nil { 29 | t.Fatalf("Failed to create test file %s: %v", file, err) 30 | } 31 | } 32 | 33 | tests := []struct { 34 | name string 35 | includeGlob string 36 | excludeContains []string 37 | want []config 38 | }{ 39 | { 40 | name: "All JSON files", 41 | includeGlob: filepath.Join(tempDir, "*.json"), 42 | excludeContains: []string{}, 43 | want: []config{ 44 | {name: "config1.json", filePath: filepath.Join(tempDir, "config1.json")}, 45 | {name: "config2.json", filePath: filepath.Join(tempDir, "config2.json")}, 46 | {name: "photoConfig.json", filePath: filepath.Join(tempDir, "photoConfig.json")}, 47 | {name: "textConfig.json", filePath: filepath.Join(tempDir, "textConfig.json")}, 48 | }, 49 | }, 50 | { 51 | name: "Exclude text and photo configs", 52 | includeGlob: filepath.Join(tempDir, "*.json"), 53 | excludeContains: []string{"textConfig", "photoConfig"}, 54 | want: []config{ 55 | {name: "config1.json", filePath: filepath.Join(tempDir, "config1.json")}, 56 | {name: "config2.json", filePath: filepath.Join(tempDir, "config2.json")}, 57 | }, 58 | }, 59 | } 60 | 61 | for _, tt := range tests { 62 | t.Run(tt.name, func(t *testing.T) { 63 | got, err := getConfigs(tt.includeGlob, tt.excludeContains) 64 | if err != nil { 65 | t.Errorf("getConfigs() error = %v", err) 66 | return 67 | } 68 | if !reflect.DeepEqual(got, tt.want) { 69 | t.Errorf("getConfigs() = %v, want %v", got, tt.want) 70 | } 71 | }) 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /internal/setup_config_migrations.go: 
-------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "path" 7 | 8 | "github.com/baalimago/clai/internal/photo" 9 | "github.com/baalimago/clai/internal/text" 10 | "github.com/baalimago/clai/internal/utils" 11 | "github.com/baalimago/clai/internal/vendors/openai" 12 | "github.com/baalimago/go_away_boilerplate/pkg/ancli" 13 | "github.com/baalimago/go_away_boilerplate/pkg/misc" 14 | ) 15 | 16 | type oldChatConfig struct { 17 | Model string `json:"model"` 18 | SystemPrompt string `json:"system_prompt"` 19 | Raw bool `json:"raw"` 20 | URL string `json:"url"` 21 | FrequencyPenalty float64 `json:"frequency_penalty"` 22 | MaxTokens *int `json:"max_tokens"` // Use a pointer to allow null value 23 | PresencePenalty float64 `json:"presence_penalty"` 24 | Temperature float64 `json:"temperature"` 25 | TopP float64 `json:"top_p"` 26 | } 27 | 28 | type oldPhotoConfig struct { 29 | Model string `json:"model"` 30 | PictureDir string `json:"photo-dir"` 31 | PicturePrefix string `json:"photo-prefix"` 32 | PromptFormat string `json:"prompt-format"` 33 | } 34 | 35 | // migrateOldChatConfig by first checking if file chatConfig exists, then 36 | // reading + copying the fields to the new text.Configurations struct. Then write the 37 | // file as textConfig. For the remaining fields, create a vendor specific openai.ChatGPT 38 | // struct and write that to openai_gpt_<model>.json. 
39 | func migrateOldChatConfig(configDirPath string) error { 40 | oldChatConfigPath := fmt.Sprintf("%v/chatConfig.json", configDirPath) 41 | if _, err := os.Stat(oldChatConfigPath); os.IsNotExist(err) { 42 | // Nothing to migrate 43 | return nil 44 | } 45 | var oldConf oldChatConfig 46 | err := utils.ReadAndUnmarshal(oldChatConfigPath, &oldConf) 47 | if err != nil { 48 | return fmt.Errorf("failed to unmarshal old photo config: %w", err) 49 | } 50 | ancli.PrintOK("migrating old chat config to new format in textConfg.json\n") 51 | migratedTextConfig := text.Configurations{ 52 | Model: oldConf.Model, 53 | SystemPrompt: oldConf.SystemPrompt, 54 | } 55 | 56 | err = os.Remove(oldChatConfigPath) 57 | if err != nil { 58 | return fmt.Errorf("failed to remove old chatConfig: %w", err) 59 | } 60 | err = utils.CreateFile(fmt.Sprintf("%v/textConfig.json", configDirPath), &migratedTextConfig) 61 | if err != nil { 62 | return fmt.Errorf("failed to write new text config: %w", err) 63 | } 64 | 65 | migratedChatgptConfig := openai.ChatGPT{ 66 | FrequencyPenalty: oldConf.FrequencyPenalty, 67 | MaxTokens: oldConf.MaxTokens, 68 | PresencePenalty: oldConf.PresencePenalty, 69 | Temperature: oldConf.Temperature, 70 | TopP: oldConf.TopP, 71 | Model: oldConf.Model, 72 | Url: oldConf.URL, 73 | } 74 | 75 | err = utils.CreateFile(fmt.Sprintf("%v/openai_gpt_%v.json", configDirPath, oldConf.Model), &migratedChatgptConfig) 76 | if err != nil { 77 | return fmt.Errorf("failed to write gpt4 turbo preview config: %w", err) 78 | } 79 | return nil 80 | } 81 | 82 | // migrateOldPhotoConfig by attempting to read and unmarshal the photoConfig.json file 83 | // and transferring the fields which are applicable to the new photo.Configurations struct. 84 | // Then writes the new photoConfig.json file. 
85 | func migrateOldPhotoConfig(configDirPath string) error { 86 | oldPhotoConfigPath := fmt.Sprintf("%v/photoConfig.json", configDirPath) 87 | if _, err := os.Stat(oldPhotoConfigPath); os.IsNotExist(err) { 88 | // Nothing to migrate, return 89 | return nil 90 | } 91 | var oldConf oldPhotoConfig 92 | err := utils.ReadAndUnmarshal(oldPhotoConfigPath, &oldConf) 93 | if err != nil { 94 | return fmt.Errorf("failed to unmarshal old photo config: %w", err) 95 | } 96 | if misc.Truthy(os.Getenv("DEBUG")) { 97 | ancli.PrintOK(fmt.Sprintf("oldConf: %+v\n", oldConf)) 98 | } 99 | if oldConf.PictureDir == "" { 100 | // Field is empty only if the photoConfig already has been migrated. Super hacky dodge, but good enough for now 101 | return nil 102 | } 103 | newFilePath := path.Join(configDirPath, "photoConfig.json") 104 | ancli.PrintOK(fmt.Sprintf("migrating old photo config to new format saved to: '%v'\n", newFilePath)) 105 | migratedPhotoConfig := photo.Configurations{ 106 | Model: oldConf.Model, 107 | PromptFormat: oldConf.PromptFormat, 108 | Output: photo.Output{ 109 | Type: photo.LOCAL, 110 | Dir: oldConf.PictureDir, 111 | Prefix: oldConf.PicturePrefix, 112 | }, 113 | } 114 | err = os.Remove(oldPhotoConfigPath) 115 | if err != nil { 116 | return fmt.Errorf("failed to remove old photoConfig: %w", err) 117 | } 118 | err = utils.CreateFile(newFilePath, &migratedPhotoConfig) 119 | if err != nil { 120 | return fmt.Errorf("failed to write new chat config: %w", err) 121 | } 122 | 123 | return nil 124 | } 125 | -------------------------------------------------------------------------------- /internal/setup_config_migrations_test.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "os" 5 | "path/filepath" 6 | "testing" 7 | 8 | "github.com/baalimago/clai/internal/photo" 9 | "github.com/baalimago/clai/internal/utils" 10 | ) 11 | 12 | func TestMigrateOldChatConfig(t *testing.T) { 13 | // Create a temporary 
directory for testing 14 | tempDir, err := os.MkdirTemp("", "test") 15 | if err != nil { 16 | t.Fatalf("failed to create temp dirr: %v", err) 17 | } 18 | defer os.RemoveAll(tempDir) 19 | 20 | // Create an old chat config file 21 | oldChatConfig := oldChatConfig{ 22 | Model: "gpt-3.5-turbo", 23 | SystemPrompt: "You are a helpful assistant.", 24 | FrequencyPenalty: 0.5, 25 | MaxTokens: nil, 26 | PresencePenalty: 0.5, 27 | Temperature: 0.8, 28 | TopP: 1.0, 29 | URL: "https://api.openai.com", 30 | } 31 | oldChatConfigPath := filepath.Join(tempDir, "chatConfig.json") 32 | err = utils.CreateFile(oldChatConfigPath, &oldChatConfig) 33 | if err != nil { 34 | t.Fatalf("failed to create file: %v", err) 35 | } 36 | 37 | // Run the migration function 38 | err = migrateOldChatConfig(tempDir) 39 | if err != nil { 40 | t.Fatalf("failed to migrate old chat config: %v", err) 41 | } 42 | 43 | // Check if the new text config file is created 44 | newTextConfigPath := filepath.Join(tempDir, "textConfig.json") 45 | _, err = os.Stat(newTextConfigPath) 46 | if err != nil { 47 | t.Fatalf("failed to find new config file: %v", err) 48 | } 49 | 50 | // Check if the old chat config file is removed 51 | _, err = os.Stat(oldChatConfigPath) 52 | if !os.IsNotExist(err) { 53 | t.Fatalf("failed to remove old chat config file: %v", err) 54 | } 55 | 56 | // Check if the new vendor-specific config file is created 57 | newVendorConfigPath := filepath.Join(tempDir, "openai_gpt_gpt-3.5-turbo.json") 58 | _, err = os.Stat(newVendorConfigPath) 59 | if err != nil { 60 | t.Fatalf("failed to create new config: %v", err) 61 | } 62 | } 63 | 64 | func TestMigrateOldPhotoConfig(t *testing.T) { 65 | // Create a temporary directory for testing 66 | tempDir := t.TempDir() 67 | 68 | // Create an old photoConfig.json file with test data 69 | oldPhotoConfigData := `{ 70 | "model": "test-model", 71 | "photo-dir": "test-photo-dir", 72 | "photo-prefix": "test-photo-prefix", 73 | "prompt-format": "test-prompt-format" 74 | }` 
75 | oldPhotoConfigPath := filepath.Join(tempDir, "photoConfig.json") 76 | err := os.WriteFile(oldPhotoConfigPath, []byte(oldPhotoConfigData), 0o644) 77 | if err != nil { 78 | t.Fatalf("Failed to create old photoConfig.json: %v", err) 79 | } 80 | 81 | // Call migrateOldPhotoConfig 82 | err = migrateOldPhotoConfig(tempDir) 83 | if err != nil { 84 | t.Fatalf("migrateOldPhotoConfig failed: %v", err) 85 | } 86 | 87 | // Check if the new photoConfig.json file was created 88 | newPhotoConfigPath := filepath.Join(tempDir, "photoConfig.json") 89 | if _, err := os.Stat(newPhotoConfigPath); os.IsNotExist(err) { 90 | t.Error("New photoConfig.json file was not created") 91 | } 92 | 93 | // Read the new photoConfig.json file and check its contents 94 | var newPhotoConfig photo.Configurations 95 | err = utils.ReadAndUnmarshal(newPhotoConfigPath, &newPhotoConfig) 96 | if err != nil { 97 | t.Fatalf("Failed to read new photoConfig.json: %v", err) 98 | } 99 | 100 | expectedPhotoConfig := photo.Configurations{ 101 | Model: "test-model", 102 | PromptFormat: "test-prompt-format", 103 | Output: photo.Output{ 104 | Type: photo.LOCAL, 105 | Dir: "test-photo-dir", 106 | Prefix: "test-photo-prefix", 107 | }, 108 | } 109 | 110 | if newPhotoConfig != expectedPhotoConfig { 111 | t.Errorf("Unexpected photo config.\nExpected: %+v\nGot: %+v", expectedPhotoConfig, newPhotoConfig) 112 | } 113 | } 114 | -------------------------------------------------------------------------------- /internal/setup_flags_test.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "flag" 5 | "os" 6 | "testing" 7 | 8 | "github.com/baalimago/clai/internal/text" 9 | "github.com/baalimago/go_away_boilerplate/pkg/testboil" 10 | ) 11 | 12 | // helper function to reset flags between tests 13 | func resetFlags() { 14 | flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError) 15 | } 16 | 17 | func TestSetupFlags(t *testing.T) { 18 | testCases := 
[]struct { 19 | name string 20 | args []string 21 | defaults Configurations 22 | expected Configurations 23 | }{ 24 | { 25 | name: "Default Values", 26 | args: []string{"cmd"}, 27 | defaults: Configurations{ 28 | ChatModel: "gpt-4-turbo-preview", 29 | PhotoModel: "dall-e-3", 30 | PhotoPrefix: "clai", 31 | PhotoDir: "picDir", 32 | StdinReplace: "stdInReplace", 33 | PrintRaw: false, 34 | ReplyMode: false, 35 | }, 36 | expected: Configurations{ 37 | ChatModel: "gpt-4-turbo-preview", 38 | PhotoModel: "dall-e-3", 39 | PhotoPrefix: "clai", 40 | PhotoDir: "picDir", 41 | StdinReplace: "stdInReplace", 42 | PrintRaw: false, 43 | ReplyMode: false, 44 | }, 45 | }, 46 | { 47 | name: "Short Flags", 48 | args: []string{"cmd", "-cm", "gpt-4", "-pm", "dall-e-2", "-pd", "/tmp", "-pp", "test-", "-I", "[stdin]", "-r", "-re"}, 49 | defaults: Configurations{}, 50 | expected: Configurations{ 51 | ChatModel: "gpt-4", 52 | PhotoModel: "dall-e-2", 53 | PhotoDir: "/tmp", 54 | PhotoPrefix: "test-", 55 | StdinReplace: "[stdin]", 56 | PrintRaw: true, 57 | ReplyMode: true, 58 | }, 59 | }, 60 | { 61 | name: "Long Flags", 62 | args: []string{"cmd", "-chat-model", "gpt-4", "-photo-model", "dall-e-2", "-photo-dir", "/tmp", "-photo-prefix", "test-", "-replace", "[stdin]", "-raw", "-reply"}, 63 | defaults: Configurations{}, 64 | expected: Configurations{ 65 | ChatModel: "gpt-4", 66 | PhotoModel: "dall-e-2", 67 | PhotoDir: "/tmp", 68 | PhotoPrefix: "test-", 69 | StdinReplace: "[stdin]", 70 | PrintRaw: true, 71 | ReplyMode: true, 72 | }, 73 | }, 74 | { 75 | name: "Precedence", 76 | args: []string{"cmd", "-cm", "gpt-4-short", "-pm", "dall-e-2-short"}, 77 | defaults: Configurations{ 78 | ChatModel: "shouldBeReplaced", 79 | PhotoModel: "shouldBeReplaced", 80 | }, 81 | expected: Configurations{ 82 | ChatModel: "gpt-4-short", 83 | PhotoModel: "dall-e-2-short", 84 | }, 85 | }, 86 | { 87 | name: "-i should cause stdin replace", 88 | args: []string{"cmd", "-i"}, 89 | defaults: Configurations{ 90 | ChatModel: 
"gpt-4", 91 | PhotoModel: "dall-e-2", 92 | PhotoDir: "/tmp", 93 | PhotoPrefix: "test-", 94 | StdinReplace: "{}", 95 | PrintRaw: true, 96 | ReplyMode: true, 97 | ExpectReplace: false, 98 | }, 99 | expected: Configurations{ 100 | ChatModel: "gpt-4", 101 | PhotoModel: "dall-e-2", 102 | PhotoDir: "/tmp", 103 | PhotoPrefix: "test-", 104 | StdinReplace: "{}", 105 | PrintRaw: true, 106 | ReplyMode: true, 107 | ExpectReplace: true, 108 | }, 109 | }, 110 | } 111 | 112 | for _, tc := range testCases { 113 | t.Run(tc.name, func(t *testing.T) { 114 | resetFlags() 115 | os.Args = tc.args 116 | result := setupFlags(tc.defaults) 117 | if result != tc.expected { 118 | t.Errorf("Expected %+v, but got %+v", tc.expected, result) 119 | } 120 | }) 121 | } 122 | } 123 | 124 | func Test_applyFlagOverridesForTest(t *testing.T) { 125 | testCases := []struct { 126 | desc string 127 | given text.Configurations 128 | flagSet Configurations 129 | defaultFlags Configurations 130 | want text.Configurations 131 | }{ 132 | { 133 | desc: "it should set stdinput config if flagged and default is empty", 134 | given: text.Configurations{ 135 | StdinReplace: "", 136 | }, 137 | flagSet: Configurations{ 138 | ExpectReplace: true, 139 | StdinReplace: "{}", 140 | }, 141 | // Use real defualtFlags here to check for regressions if defaults change 142 | defaultFlags: defaultFlags, 143 | want: text.Configurations{ 144 | StdinReplace: "{}", 145 | }, 146 | }, 147 | } 148 | 149 | for _, tc := range testCases { 150 | t.Run(tc.desc, func(t *testing.T) { 151 | applyFlagOverridesForText(&tc.given, tc.flagSet, tc.defaultFlags) 152 | testboil.FailTestIfDiff(t, tc.given.StdinReplace, tc.want.StdinReplace) 153 | }) 154 | } 155 | } 156 | -------------------------------------------------------------------------------- /internal/text/conf.go: -------------------------------------------------------------------------------- 1 | package text 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | 7 | "github.com/baalimago/clai/internal/chat" 
// Configurations used to setup the requirements of text models.
//
// Fields tagged `json:"-"` are skipped by encoding/json; they are never read
// from nor written to a configuration file and are set at runtime instead.
type Configurations struct {
	// Model is the model name. NewQuerier passes it to vendorType to work out
	// which vendor specific configuration file to load.
	Model string `json:"model"`
	// SystemPrompt is used by SetupPrompts as the content of the leading
	// "system" message.
	SystemPrompt string `json:"system-prompt"`
	// CmdModePrompt is presumably the system prompt used in command mode;
	// where it is swapped in is not shown here - TODO confirm.
	CmdModePrompt string `json:"cmd-mode-prompt"`
	// Raw is copied onto the querier by NewQuerier.
	Raw bool `json:"raw"`
	// UseTools makes NewQuerier register tools on models which implement
	// models.ToolBox.
	UseTools bool `json:"use-tools"`
	// TokenWarnLimit is copied onto the querier by NewQuerier.
	TokenWarnLimit int `json:"token-warn-limit"`
	// SaveReplyAsConv is combined with ChatMode in NewQuerier to decide if the
	// reply should be saved. Note that the json key is "save-reply-as-prompt",
	// which differs from the key used by Profile ("save-reply-as-conv").
	SaveReplyAsConv bool `json:"save-reply-as-prompt"`
	// ConfigDir is the directory which holds the model configuration files
	// and from which SetupPrompts loads the previous query.
	ConfigDir string `json:"-"`
	// StdinReplace is passed to utils.Prompt when the prompt is built.
	StdinReplace string `json:"-"`
	Stream       bool   `json:"-"`
	// ReplyMode makes SetupPrompts append the previously stored conversation.
	ReplyMode bool `json:"-"`
	// ChatMode makes SetupPrompts skip appending the user prompt, since the
	// chat querier handles the initial message.
	ChatMode bool `json:"-"`
	CmdMode  bool `json:"-"`
	// Glob, when non-empty, makes SetupPrompts build the initial chat from
	// glob.CreateChat.
	Glob string `json:"-"`
	// InitialPrompt is the chat assembled by SetupPrompts.
	InitialPrompt models.Chat `json:"-"`
	// UseProfile is the name of the profile applied by ProfileOverrides. An
	// empty string means that no profile is applied.
	UseProfile string `json:"-"`
	// Tools lists the names of the tools to register. When empty and UseTools
	// is set, NewQuerier registers every known tool.
	Tools []string `json:"-"`
	// PostProccessedPrompt is the prompt returned by utils.Prompt, i.e. after
	// its strings have been replaced etc.
	PostProccessedPrompt string `json:"-"`
}
Do not shell escape single or double quotes.", 54 | Raw: false, 55 | UseTools: false, 56 | // Aproximately $1 for the worst input rates as of 2024-05 57 | TokenWarnLimit: 17000, 58 | SaveReplyAsConv: true, 59 | } 60 | 61 | var DEFAULT_PROFILE = Profile{ 62 | UseTools: true, 63 | SaveReplyAsConv: true, 64 | Tools: []string{}, 65 | } 66 | 67 | func (c *Configurations) SetupPrompts(args []string) error { 68 | if c.Glob != "" && c.ReplyMode { 69 | ancli.PrintWarn("Using glob + reply modes together might yield strange results. The prevQuery will be appended after the glob messages.\n") 70 | } 71 | 72 | if !c.ReplyMode { 73 | c.InitialPrompt = models.Chat{ 74 | Messages: []models.Message{ 75 | {Role: "system", Content: c.SystemPrompt}, 76 | }, 77 | } 78 | } 79 | if c.Glob != "" { 80 | globChat, err := glob.CreateChat(c.Glob, c.SystemPrompt) 81 | if err != nil { 82 | return fmt.Errorf("failed to get glob chat: %w", err) 83 | } 84 | if misc.Truthy(os.Getenv("DEBUG")) { 85 | ancli.PrintOK(fmt.Sprintf("glob messages: %v", globChat.Messages)) 86 | } 87 | c.InitialPrompt = globChat 88 | } 89 | 90 | if c.ReplyMode { 91 | iP, err := reply.Load(c.ConfigDir) 92 | if err != nil { 93 | return fmt.Errorf("failed to load previous query: %w", err) 94 | } 95 | c.InitialPrompt.Messages = append(c.InitialPrompt.Messages, iP.Messages...) 96 | 97 | if c.CmdMode { 98 | // Replace the initial message with the cmd prompt. 
This sort of 99 | // destroys the history, but since the conversation might be long it's fine 100 | c.InitialPrompt.Messages[0].Content = c.SystemPrompt 101 | } 102 | } 103 | 104 | prompt, err := utils.Prompt(c.StdinReplace, args) 105 | if err != nil { 106 | return fmt.Errorf("failed to setup prompt: %w", err) 107 | } 108 | // If chatmode, the initial message will be handled by the chat querier 109 | if !c.ChatMode { 110 | c.InitialPrompt.Messages = append(c.InitialPrompt.Messages, models.Message{ 111 | Role: "user", 112 | Content: prompt, 113 | }) 114 | } 115 | 116 | if misc.Truthy(os.Getenv("DEBUG")) { 117 | ancli.PrintOK(fmt.Sprintf("InitialPrompt: %v\n", debug.IndentedJsonFmt(c.InitialPrompt))) 118 | } 119 | c.PostProccessedPrompt = prompt 120 | if c.InitialPrompt.ID == "" { 121 | c.InitialPrompt.ID = chat.IDFromPrompt(prompt) 122 | } 123 | return nil 124 | } 125 | -------------------------------------------------------------------------------- /internal/text/conf_profile.go: -------------------------------------------------------------------------------- 1 | package text 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "path" 7 | 8 | "github.com/baalimago/clai/internal/utils" 9 | ) 10 | 11 | func findProfile(profileName string) (Profile, error) { 12 | cfg, _ := os.UserConfigDir() 13 | profilePath := path.Join(cfg, ".clai", "profiles") 14 | var p Profile 15 | err := utils.ReadAndUnmarshal(path.Join(profilePath, fmt.Sprintf("%v.json", profileName)), &p) 16 | if err != nil { 17 | p.Name = profileName 18 | return p, err 19 | } 20 | return p, nil 21 | } 22 | 23 | func (c *Configurations) ProfileOverrides() error { 24 | if c.UseProfile == "" { 25 | return nil 26 | } 27 | profile, err := findProfile(c.UseProfile) 28 | if err != nil { 29 | return fmt.Errorf("failed to find profile: %w", err) 30 | } 31 | c.Model = profile.Model 32 | newPrompt := profile.Prompt 33 | if c.CmdMode { 34 | // SystmePrompt here is CmdPrompt, keep it and remoind llm to only suggest cmd 35 | newPrompt 
= fmt.Sprintf("You will get this pattern: || | ||. It is VERY vital that you DO NOT disobey the with whatever is posted in 0 { 67 | reqData.Tools = s.tools 68 | reqData.ToolChoice = s.ToolChoice 69 | } 70 | if s.debug { 71 | ancli.PrintOK(fmt.Sprintf("generic streamcompleter request: %v\n", debug.IndentedJsonFmt(reqData))) 72 | } 73 | jsonData, err := json.Marshal(reqData) 74 | if err != nil { 75 | return nil, fmt.Errorf("failed to encode JSON: %w", err) 76 | } 77 | 78 | req, err := http.NewRequestWithContext(ctx, "POST", s.url, bytes.NewBuffer(jsonData)) 79 | if err != nil { 80 | return nil, fmt.Errorf("failed to create request: %w", err) 81 | } 82 | 83 | req.Header.Set("Content-Type", "application/json") 84 | req.Header.Set("Authorization", fmt.Sprintf("Bearer %v", s.apiKey)) 85 | req.Header.Set("Accept", "text/event-stream") 86 | req.Header.Set("Connection", "keep-alive") 87 | return req, nil 88 | } 89 | 90 | func (s *StreamCompleter) handleStreamResponse(ctx context.Context, res *http.Response) (chan models.CompletionEvent, error) { 91 | outChan := make(chan models.CompletionEvent) 92 | go func() { 93 | br := bufio.NewReader(res.Body) 94 | defer func() { 95 | res.Body.Close() 96 | close(outChan) 97 | }() 98 | for { 99 | if ctx.Err() != nil { 100 | close(outChan) 101 | return 102 | } 103 | token, err := br.ReadBytes('\n') 104 | if err != nil { 105 | outChan <- fmt.Errorf("failed to read line: %w", err) 106 | } 107 | outChan <- s.handleStreamChunk(token) 108 | } 109 | }() 110 | 111 | return outChan, nil 112 | } 113 | 114 | func (s *StreamCompleter) handleStreamChunk(token []byte) models.CompletionEvent { 115 | token = bytes.TrimPrefix(token, dataPrefix) 116 | token = bytes.TrimSpace(token) 117 | if string(token) == "[DONE]" { 118 | return models.NoopEvent{} 119 | } 120 | 121 | if s.debug { 122 | ancli.PrintOK(fmt.Sprintf("token: %+v\n", string(token))) 123 | } 124 | var chunk chatCompletionChunk 125 | err := json.Unmarshal(token, &chunk) 126 | if err != nil { 127 
| if misc.Truthy(os.Getenv("DEBUG")) { 128 | // Expect some failing unmarshalls, which seems to be fine 129 | ancli.PrintWarn(fmt.Sprintf("failed to unmarshal token: %v, err: %v\n", token, err)) 130 | return models.NoopEvent{} 131 | } 132 | } 133 | if len(chunk.Choices) == 0 { 134 | return models.NoopEvent{} 135 | } 136 | 137 | var chosen models.CompletionEvent 138 | for _, choice := range chunk.Choices { 139 | compEvent := s.handleChoice(choice) 140 | switch compEvent.(type) { 141 | // Set chosen to the first error, string 142 | case error, string, models.NoopEvent: 143 | _, isNoopEvent := chosen.(models.NoopEvent) 144 | if chosen == nil || isNoopEvent { 145 | chosen = compEvent 146 | } 147 | case tools.Call: 148 | // Always prefer tools call, if possible 149 | chosen = compEvent 150 | } 151 | } 152 | 153 | if s.debug { 154 | ancli.PrintOK(fmt.Sprintf("chosen: %T - %+v\n", chosen, chosen)) 155 | } 156 | return chosen 157 | } 158 | 159 | func (s *StreamCompleter) handleChoice(choice Choice) models.CompletionEvent { 160 | // If there is no tools call, just handle it as a strins. 
This works for most cases 161 | if len(choice.Delta.ToolCalls) == 0 && choice.FinishReason != "tool_calls" { 162 | return choice.Delta.Content 163 | } 164 | 165 | // Function name is only shown in first chunk of a functions call 166 | // TODO: Implement support for parallel function calls, now we only handle first tools call in list 167 | var funcName, argChunk string 168 | if len(choice.Delta.ToolCalls) > 0 && choice.Delta.ToolCalls[0].Function.Name != "" { 169 | funcName = choice.Delta.ToolCalls[0].Function.Name 170 | s.toolsCallName = choice.Delta.ToolCalls[0].Function.Name 171 | s.toolsCallID = choice.Delta.ToolCalls[0].ID 172 | } 173 | 174 | if len(choice.Delta.ToolCalls) > 0 { 175 | argChunk = choice.Delta.ToolCalls[0].Function.Arguments 176 | // The arguments is streamed as a stringified json for chatgpt, chunk by chunk, with no apparent structure 177 | s.toolsCallArgsString += argChunk 178 | if s.debug { 179 | ancli.PrintOK(fmt.Sprintf("toolsCallArgsString: %v\n", s.toolsCallArgsString)) 180 | } 181 | } 182 | 183 | if choice.FinishReason != "" || 184 | // This is an indication that chatgpt wants to call another function, or a variant of the function call 185 | (s.toolsCallArgsString != "" && argChunk == "" && funcName != "") { 186 | return s.doToolsCall() 187 | } 188 | return models.NoopEvent{} 189 | } 190 | 191 | // doToolsCall by parsing the arguments 192 | func (s *StreamCompleter) doToolsCall() models.CompletionEvent { 193 | defer func() { 194 | // Reset tools call construction strings to prepare for consequtive calls 195 | s.toolsCallName = "" 196 | s.toolsCallArgsString = "" 197 | }() 198 | var input tools.Input 199 | err := json.Unmarshal([]byte(s.toolsCallArgsString), &input) 200 | if err != nil { 201 | return fmt.Errorf("failed to unmarshal argument string: %w, argsString: %v", err, s.toolsCallArgsString) 202 | } 203 | 204 | userFunc := tools.UserFunctionFromName(s.toolsCallName) 205 | userFunc.Arguments = s.toolsCallArgsString 206 | 
// StreamCompleter is a struct which follows the model for both OpenAI and Mistral.
//
// All exported fields are tagged `json:"-"` and are therefore skipped by
// encoding/json. The unexported fields hold runtime state which is filled in
// by Setup, InternalRegisterTool and the stream handling methods.
type StreamCompleter struct {
	Model            string   `json:"-"`
	FrequencyPenalty *float64 `json:"-"`
	MaxTokens        *int     `json:"-"`
	PresencePenalty  *float64 `json:"-"`
	Temperature      *float64 `json:"-"`
	TopP             *float64 `json:"-"`
	// ToolChoice is copied into the request together with the registered tools.
	ToolChoice *string `json:"-"`
	// Clean is presumably applied to the messages before they are sent; the
	// call site is not visible here - TODO confirm.
	Clean func([]models.Message) []models.Message `json:"-"`
	// url is the endpoint requests are posted to, set by Setup.
	url string
	// tools holds the tools added through InternalRegisterTool.
	tools []ToolSuper
	// toolsCallName is the function name of the tools call currently being
	// streamed. It is reset once the call has been constructed.
	toolsCallName string
	// Argument string exists since the arguments for function calls are streamed token by token... yeah... great idea
	toolsCallArgsString string
	// toolsCallID is the ID from the chunk that carried the function name.
	toolsCallID string
	// client is created by Setup.
	client *http.Client
	// apiKey is read by Setup from the environment variable it is given and
	// is sent as a Bearer token.
	apiKey string
	// debug enables verbose printing, set by Setup based on environment variables.
	debug bool
}
`json:"tool_choice,omitempty"` 91 | Tools []ToolSuper `json:"tools,omitempty"` 92 | ParalellToolCalls bool `json:"parallel_tools_call,omitempty"` 93 | } 94 | -------------------------------------------------------------------------------- /internal/text/generic/stream_completer_setup.go: -------------------------------------------------------------------------------- 1 | package generic 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "os" 7 | 8 | "github.com/baalimago/clai/internal/tools" 9 | "github.com/baalimago/go_away_boilerplate/pkg/misc" 10 | ) 11 | 12 | func (s *StreamCompleter) Setup(apiKeyEnv, url, debugEnv string) error { 13 | apiKey := os.Getenv(apiKeyEnv) 14 | if apiKey == "" { 15 | return fmt.Errorf("environment variable '%v' not set", apiKeyEnv) 16 | } 17 | s.client = &http.Client{} 18 | s.apiKey = apiKey 19 | s.url = url 20 | 21 | if misc.Truthy(os.Getenv("DEBUG")) || misc.Truthy(os.Getenv(debugEnv)) { 22 | s.debug = true 23 | } 24 | 25 | return nil 26 | } 27 | 28 | func (g *StreamCompleter) InternalRegisterTool(tool tools.AiTool) { 29 | g.tools = append(g.tools, ToolSuper{ 30 | Type: "function", 31 | Function: convertToGenericTool(tool.UserFunction()), 32 | }) 33 | } 34 | 35 | func convertToGenericTool(tool tools.UserFunction) Tool { 36 | return Tool{ 37 | Name: tool.Name, 38 | Description: tool.Description, 39 | Inputs: *tool.Inputs, 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /internal/text/querier_cmd_mode.go: -------------------------------------------------------------------------------- 1 | package text 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "os" 7 | "os/exec" 8 | "strings" 9 | 10 | "github.com/baalimago/clai/internal/utils" 11 | "github.com/baalimago/go_away_boilerplate/pkg/ancli" 12 | ) 13 | 14 | var errFormat = "code: %v, stderr: '%v'\n" 15 | 16 | func (q *Querier[C]) handleCmdMode() error { 17 | // Tokens stream end without endline 18 | fmt.Println() 19 | 20 | if q.execErr != nil { 
21 | return nil 22 | } 23 | 24 | for { 25 | fmt.Print("Do you want to [e]xecute cmd, [q]uit?: ") 26 | input, err := utils.ReadUserInput() 27 | if err != nil { 28 | return err 29 | } 30 | switch strings.ToLower(input) { 31 | case "q": 32 | return nil 33 | case "e": 34 | err := q.executeLlmCmd() 35 | if err == nil { 36 | return nil 37 | } else { 38 | return fmt.Errorf("failed to execute cmd: %v", err) 39 | } 40 | default: 41 | ancli.PrintWarn(fmt.Sprintf("unrecognized command: %v, please try again\n", input)) 42 | } 43 | } 44 | } 45 | 46 | func (q *Querier[C]) executeLlmCmd() error { 47 | fullMsg, err := utils.ReplaceTildeWithHome(q.fullMsg) 48 | if err != nil { 49 | return fmt.Errorf("parseGlob, ReplaceTildeWithHome: %w", err) 50 | } 51 | // Quotes are, in 99% of the time, expanded by the shell in 52 | // different ways and then passed into the shell. So when LLM 53 | // suggests a command, executeAiCmd needs to act the same (meaning) 54 | // remove/expand the quotes 55 | fullMsg = strings.ReplaceAll(fullMsg, "\"", "") 56 | split := strings.Split(fullMsg, " ") 57 | if len(split) < 1 { 58 | return errors.New("Querier.executeAiCmd: too few tokens in q.fullMsg") 59 | } 60 | cmd := split[0] 61 | args := split[1:] 62 | 63 | if len(cmd) == 0 { 64 | return errors.New("Querier.executeAiCmd: command is empty") 65 | } 66 | 67 | command := exec.Command(cmd, args...) 
68 | command.Stdout = os.Stdout 69 | command.Stderr = os.Stderr 70 | err = command.Run() 71 | if err != nil { 72 | cast := &exec.ExitError{} 73 | if errors.As(err, &cast) { 74 | return fmt.Errorf(errFormat, cast.ExitCode()) 75 | } else { 76 | return fmt.Errorf("Querier.executeAiCmd - run error: %w", err) 77 | } 78 | } 79 | 80 | return nil 81 | } 82 | -------------------------------------------------------------------------------- /internal/text/querier_cmd_mode_test.go: -------------------------------------------------------------------------------- 1 | package text 2 | 3 | import ( 4 | "context" 5 | "os" 6 | "path/filepath" 7 | "testing" 8 | 9 | "github.com/baalimago/clai/internal/models" 10 | "github.com/baalimago/go_away_boilerplate/pkg/testboil" 11 | ) 12 | 13 | type mockCompleter struct{} 14 | 15 | func (m mockCompleter) Setup() error { 16 | return nil 17 | } 18 | 19 | func (m mockCompleter) StreamCompletions(ctx context.Context, c models.Chat) (chan models.CompletionEvent, error) { 20 | return nil, nil 21 | } 22 | 23 | func Test_executeAiCmd(t *testing.T) { 24 | testCases := []struct { 25 | description string 26 | setup func(t *testing.T) 27 | given string 28 | want string 29 | wantErr error 30 | }{ 31 | { 32 | description: "it should run shell cmd", 33 | given: "printf 'test'", 34 | want: "'test'", 35 | wantErr: nil, 36 | }, 37 | { 38 | description: "it should work with quotes", 39 | setup: func(t *testing.T) { 40 | t.Helper() 41 | os.Chdir(filepath.Dir(testboil.CreateTestFile(t, "testfile").Name())) 42 | }, 43 | given: "find ./ -name \"testfile\"", 44 | want: "./testfile\n", 45 | wantErr: nil, 46 | }, 47 | { 48 | description: "it should work without quotes", 49 | setup: func(t *testing.T) { 50 | t.Helper() 51 | os.Chdir(filepath.Dir(testboil.CreateTestFile(t, "testfile").Name())) 52 | }, 53 | given: "find ./ -name testfile", 54 | want: "./testfile\n", 55 | wantErr: nil, 56 | }, 57 | } 58 | 59 | for _, tc := range testCases { 60 | t.Run(tc.description, 
// vendorType derives the vendor, the model family and the model version from
// a model name. The three values are used to build the name of the model
// specific configuration file.
//
// Matching is done on substrings and the first match wins, in the order:
// gpt, claude, ollama, novita, mistral/mixtral, deepseek, mock. Unknown models
// yield "VENDOR", "NOT", "FOUND".
func vendorType(fromModel string) (string, string, string) {
	has := func(sub string) bool {
		return strings.Contains(fromModel, sub)
	}

	switch {
	case has("gpt"):
		return "openai", "gpt", fromModel
	case has("claude"):
		return "anthropic", "claude", fromModel
	case has("ollama"):
		// "ollama:<model>" selects a specific model, plain "ollama" falls
		// back to llama3
		if strings.HasPrefix(fromModel, "ollama:") {
			return "ollama", strings.TrimPrefix(fromModel, "ollama:"), fromModel
		}
		return "ollama", "llama3", fromModel
	case has("novita"):
		// "novita:<model>/<version>" is split into its two parts, anything
		// else keeps an empty model and the full name as version
		if strings.HasPrefix(fromModel, "novita:") {
			segments := strings.Split(strings.TrimPrefix(fromModel, "novita:"), "/")
			if len(segments) > 1 {
				return "novita", segments[0], segments[1]
			}
		}
		return "novita", "", fromModel
	case has("mistral"), has("mixtral"):
		return "mistral", "mistral", fromModel
	case has("deepseek"):
		return "deepseek", "deepseek", fromModel
	case has("mock"):
		return "mock", "mock", "mock"
	default:
		return "VENDOR", "NOT", "FOUND"
	}
}
misc.Truthy(os.Getenv("DEBUG")) { 101 | ancli.PrintOK(fmt.Sprintf("\tadding tool: %T\n", tool)) 102 | } 103 | toolBox.RegisterTool(tool) 104 | } 105 | } else { 106 | for _, t := range userConf.Tools { 107 | tool, exists := tools.Tools[t] 108 | if !exists { 109 | ancli.PrintWarn(fmt.Sprintf("attempted to find tool: '%v', which doesn't exist, skipping", tool)) 110 | continue 111 | } 112 | 113 | if misc.Truthy(os.Getenv("DEBUG")) { 114 | ancli.PrintOK(fmt.Sprintf("\tadding tool: %T\n", tool)) 115 | } 116 | toolBox.RegisterTool(tool) 117 | } 118 | } 119 | } 120 | 121 | err = modelConf.Setup() 122 | if err != nil { 123 | return Querier[C]{}, fmt.Errorf("failed to setup model: %w", err) 124 | } 125 | 126 | termWidth, err := utils.TermWidth() 127 | if err == nil { 128 | querier.termWidth = termWidth 129 | } else { 130 | ancli.PrintWarn(fmt.Sprintf("failed to get terminal size: %v\n", err)) 131 | } 132 | currentUser, err := user.Current() 133 | if err == nil { 134 | querier.username = currentUser.Username 135 | } else { 136 | querier.username = "user" 137 | } 138 | querier.Model = modelConf 139 | if misc.Truthy(os.Getenv("DEBUG")) { 140 | ancli.PrintOK(fmt.Sprintf("querier: %v,\n===\nmodels: %v\n", debug.IndentedJsonFmt(querier), debug.IndentedJsonFmt(modelConf))) 141 | } 142 | querier.chat = userConf.InitialPrompt 143 | if misc.Truthy(os.Getenv("DEBUG")) || misc.Truthy(os.Getenv("TEXT_QUERIER_DEBUG")) { 144 | querier.debug = true 145 | } 146 | querier.Raw = userConf.Raw 147 | querier.cmdMode = userConf.CmdMode 148 | querier.shouldSaveReply = !userConf.ChatMode && userConf.SaveReplyAsConv 149 | querier.tokenWarnLimit = userConf.TokenWarnLimit 150 | return querier, nil 151 | } 152 | -------------------------------------------------------------------------------- /internal/tools/bash_tool_cat.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "fmt" 5 | "os/exec" 6 | ) 7 | 8 | type CatTool UserFunction 9 
| 10 | var Cat = CatTool{ 11 | Name: "cat", 12 | Description: "Display the contents of a file. Uses the linux command 'cat'.", 13 | Inputs: &InputSchema{ 14 | Type: "object", 15 | Properties: map[string]ParameterObject{ 16 | "file": { 17 | Type: "string", 18 | Description: "The file to display the contents of.", 19 | }, 20 | "number": { 21 | Type: "boolean", 22 | Description: "Number all output lines.", 23 | }, 24 | "showEnds": { 25 | Type: "boolean", 26 | Description: "Display $ at end of each line.", 27 | }, 28 | "squeezeBlank": { 29 | Type: "boolean", 30 | Description: "Suppress repeated empty output lines.", 31 | }, 32 | }, 33 | Required: []string{"file"}, 34 | }, 35 | } 36 | 37 | func (c CatTool) Call(input Input) (string, error) { 38 | file, ok := input["file"].(string) 39 | if !ok { 40 | return "", fmt.Errorf("file must be a string") 41 | } 42 | cmd := exec.Command("cat", file) 43 | if input["number"] != nil { 44 | number, ok := input["number"].(bool) 45 | if !ok { 46 | return "", fmt.Errorf("number must be a boolean") 47 | } 48 | if number { 49 | cmd.Args = append(cmd.Args, "-n") 50 | } 51 | } 52 | if input["showEnds"] != nil { 53 | showEnds, ok := input["showEnds"].(bool) 54 | if !ok { 55 | return "", fmt.Errorf("showEnds must be a boolean") 56 | } 57 | if showEnds { 58 | cmd.Args = append(cmd.Args, "-E") 59 | } 60 | } 61 | if input["squeezeBlank"] != nil { 62 | squeezeBlank, ok := input["squeezeBlank"].(bool) 63 | if !ok { 64 | return "", fmt.Errorf("squeezeBlank must be a boolean") 65 | } 66 | if squeezeBlank { 67 | cmd.Args = append(cmd.Args, "-s") 68 | } 69 | } 70 | output, err := cmd.CombinedOutput() 71 | if err != nil { 72 | return "", fmt.Errorf("failed to run cat: %w, output: %v", err, string(output)) 73 | } 74 | return string(output), nil 75 | } 76 | 77 | func (c CatTool) UserFunction() UserFunction { 78 | return UserFunction(Cat) 79 | } 80 | -------------------------------------------------------------------------------- 
/internal/tools/bash_tool_file.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "fmt" 5 | "os/exec" 6 | ) 7 | 8 | type FileTypeTool UserFunction 9 | 10 | var FileType = FileTypeTool{ 11 | Name: "file_type", 12 | Description: "Determine the file type of a given file. Uses the linux command 'file'.", 13 | Inputs: &InputSchema{ 14 | Type: "object", 15 | Properties: map[string]ParameterObject{ 16 | "file_path": { 17 | Type: "string", 18 | Description: "The path to the file to analyze.", 19 | }, 20 | "mime_type": { 21 | Type: "boolean", 22 | Description: "Whether to display the MIME type of the file.", 23 | }, 24 | }, 25 | Required: []string{"file_path"}, 26 | }, 27 | } 28 | 29 | func (f FileTypeTool) Call(input Input) (string, error) { 30 | filePath, ok := input["file_path"].(string) 31 | if !ok { 32 | return "", fmt.Errorf("file_path must be a string") 33 | } 34 | cmd := exec.Command("file", filePath) 35 | if input["mime_type"] != nil { 36 | mimeType, ok := input["mime_type"].(bool) 37 | if !ok { 38 | return "", fmt.Errorf("mime_type must be a boolean") 39 | } 40 | if mimeType { 41 | cmd.Args = append(cmd.Args, "--mime-type") 42 | } 43 | } 44 | output, err := cmd.CombinedOutput() 45 | if err != nil { 46 | return "", fmt.Errorf("failed to run file command: %w, output: %v", err, string(output)) 47 | } 48 | return string(output), nil 49 | } 50 | 51 | func (f FileTypeTool) UserFunction() UserFunction { 52 | return UserFunction(FileType) 53 | } 54 | -------------------------------------------------------------------------------- /internal/tools/bash_tool_find.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "fmt" 5 | "os/exec" 6 | ) 7 | 8 | type FindTool UserFunction 9 | 10 | var Find = FindTool{ 11 | Name: "find", 12 | Description: "Search for files in a directory hierarchy. 
Uses linux command 'find'.", 13 | Inputs: &InputSchema{ 14 | Type: "object", 15 | Properties: map[string]ParameterObject{ 16 | "directory": { 17 | Type: "string", 18 | Description: "The directory to start the search from.", 19 | }, 20 | "name": { 21 | Type: "string", 22 | Description: "The name pattern to search for.", 23 | }, 24 | "type": { 25 | Type: "string", 26 | Description: "The file type to search for (f: regular file, d: directory).", 27 | }, 28 | "maxdepth": { 29 | Type: "integer", 30 | Description: "The maximum depth of directories to search.", 31 | }, 32 | }, 33 | Required: []string{"directory"}, 34 | }, 35 | } 36 | 37 | func (f FindTool) Call(input Input) (string, error) { 38 | directory, ok := input["directory"].(string) 39 | if !ok { 40 | return "", fmt.Errorf("directory must be a string") 41 | } 42 | cmd := exec.Command("find", directory) 43 | if input["name"] != nil { 44 | name, ok := input["name"].(string) 45 | if !ok { 46 | return "", fmt.Errorf("name must be a string") 47 | } 48 | cmd.Args = append(cmd.Args, "-name", name) 49 | } 50 | if input["type"] != nil { 51 | fileType, ok := input["type"].(string) 52 | if !ok { 53 | return "", fmt.Errorf("type must be a string") 54 | } 55 | cmd.Args = append(cmd.Args, "-type", fileType) 56 | } 57 | if input["maxdepth"] != nil { 58 | maxdepth, ok := input["maxdepth"].(float64) 59 | if !ok { 60 | return "", fmt.Errorf("maxdepth must be a number") 61 | } 62 | cmd.Args = append(cmd.Args, "-maxdepth", fmt.Sprintf("%v", maxdepth)) 63 | } 64 | output, err := cmd.CombinedOutput() 65 | if err != nil { 66 | return "", fmt.Errorf("failed to run find: %w, output: %v", err, string(output)) 67 | } 68 | return string(output), nil 69 | } 70 | 71 | func (f FindTool) UserFunction() UserFunction { 72 | return UserFunction(Find) 73 | } 74 | -------------------------------------------------------------------------------- /internal/tools/bash_tool_freetext_command.go: 
-------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "fmt" 5 | "os/exec" 6 | "strings" 7 | ) 8 | 9 | type FreetextCmdTool UserFunction 10 | 11 | var FreetextCmd = FreetextCmdTool{ 12 | Name: "freetext_command", 13 | Description: "Run any entered string as a terminal command.", 14 | Inputs: &InputSchema{ 15 | Type: "object", 16 | Properties: map[string]ParameterObject{ 17 | "command": { 18 | Type: "string", 19 | Description: "The freetext comand. May be any string. Will return error on non-zero exit code.", 20 | }, 21 | }, 22 | Required: []string{"command"}, 23 | }, 24 | } 25 | 26 | func (r FreetextCmdTool) Call(input Input) (string, error) { 27 | freetextCmd, ok := input["command"].(string) 28 | if !ok { 29 | return "", fmt.Errorf("freetextCmd must be a string") 30 | } 31 | freetextCmdSplit := strings.Split(freetextCmd, " ") 32 | var potentialArgsFlags []string 33 | if len(freetextCmdSplit) > 0 { 34 | potentialArgsFlags = freetextCmdSplit[1:] 35 | } 36 | cmd := exec.Command(freetextCmdSplit[0], potentialArgsFlags...) 37 | 38 | output, err := cmd.CombinedOutput() 39 | if err != nil { 40 | return "", fmt.Errorf("error: '%w', output: %v", err, string(output)) 41 | } 42 | return string(output), nil 43 | } 44 | 45 | func (r FreetextCmdTool) UserFunction() UserFunction { 46 | return UserFunction(FreetextCmd) 47 | } 48 | -------------------------------------------------------------------------------- /internal/tools/bash_tool_ls.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "fmt" 5 | "os/exec" 6 | ) 7 | 8 | type LsTool UserFunction 9 | 10 | var LS = LsTool{ 11 | Name: "ls", 12 | Description: "List the files in a directory. 
Uses the Linux command 'ls'.", 13 | Inputs: &InputSchema{ 14 | Type: "object", 15 | Properties: map[string]ParameterObject{ 16 | "directory": { 17 | Type: "string", 18 | Description: "The directory to list the files of.", 19 | }, 20 | "all": { 21 | Type: "boolean", 22 | Description: "Show all files, including hidden files.", 23 | }, 24 | "long": { 25 | Type: "boolean", 26 | Description: "Use a long listing format.", 27 | }, 28 | }, 29 | Required: []string{"directory"}, 30 | }, 31 | } 32 | 33 | func (f LsTool) Call(input Input) (string, error) { 34 | directory, ok := input["directory"].(string) 35 | if !ok { 36 | return "", fmt.Errorf("directory must be a string") 37 | } 38 | cmd := exec.Command("ls", directory) 39 | if input["all"] != nil { 40 | all, ok := input["all"].(bool) 41 | if !ok { 42 | return "", fmt.Errorf("all must be a boolean") 43 | } 44 | if all { 45 | cmd.Args = append(cmd.Args, "-a") 46 | } 47 | } 48 | if input["long"] != nil { 49 | long, ok := input["long"].(bool) 50 | if !ok { 51 | return "", fmt.Errorf("long must be a boolean") 52 | } 53 | if long { 54 | cmd.Args = append(cmd.Args, "-l") 55 | } 56 | } 57 | output, err := cmd.CombinedOutput() 58 | if err != nil { 59 | return "", fmt.Errorf("failed to run ls: %w, output: %v", err, string(output)) 60 | } 61 | return string(output), nil 62 | } 63 | 64 | func (f LsTool) UserFunction() UserFunction { 65 | return UserFunction(LS) 66 | } 67 | -------------------------------------------------------------------------------- /internal/tools/bash_tool_rg.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "fmt" 5 | "os/exec" 6 | ) 7 | 8 | type RipGrepTool UserFunction 9 | 10 | var RipGrep = RipGrepTool{ 11 | Name: "rg", 12 | Description: "Search for a pattern in files using ripgrep.", 13 | Inputs: &InputSchema{ 14 | Type: "object", 15 | Properties: map[string]ParameterObject{ 16 | "pattern": { 17 | Type: "string", 18 | Description: "The 
pattern to search for.", 19 | }, 20 | "path": { 21 | Type: "string", 22 | Description: "The path to search in.", 23 | }, 24 | "case_sensitive": { 25 | Type: "boolean", 26 | Description: "Whether the search should be case sensitive.", 27 | }, 28 | "line_number": { 29 | Type: "boolean", 30 | Description: "Whether to show line numbers.", 31 | }, 32 | "hidden": { 33 | Type: "boolean", 34 | Description: "Whether to search hidden files and directories.", 35 | }, 36 | }, 37 | Required: []string{"pattern"}, 38 | }, 39 | } 40 | 41 | func (r RipGrepTool) Call(input Input) (string, error) { 42 | pattern, ok := input["pattern"].(string) 43 | if !ok { 44 | return "", fmt.Errorf("pattern must be a string") 45 | } 46 | cmd := exec.Command("rg", pattern) 47 | if input["path"] != nil { 48 | path, ok := input["path"].(string) 49 | if !ok { 50 | return "", fmt.Errorf("path must be a string") 51 | } 52 | cmd.Args = append(cmd.Args, path) 53 | } 54 | if input["case_sensitive"] != nil { 55 | caseSensitive, ok := input["case_sensitive"].(bool) 56 | if !ok { 57 | return "", fmt.Errorf("case_sensitive must be a boolean") 58 | } 59 | if caseSensitive { 60 | cmd.Args = append(cmd.Args, "--case-sensitive") 61 | } 62 | } 63 | if input["line_number"] != nil { 64 | lineNumber, ok := input["line_number"].(bool) 65 | if !ok { 66 | return "", fmt.Errorf("line_number must be a boolean") 67 | } 68 | if lineNumber { 69 | cmd.Args = append(cmd.Args, "--line-number") 70 | } 71 | } 72 | if input["hidden"] != nil { 73 | hidden, ok := input["hidden"].(bool) 74 | if !ok { 75 | return "", fmt.Errorf("hidden must be a boolean") 76 | } 77 | if hidden { 78 | cmd.Args = append(cmd.Args, "--hidden") 79 | } 80 | } 81 | output, err := cmd.CombinedOutput() 82 | if err != nil { 83 | // exit status 1 is not found, and not to be considered an error 84 | if err.Error() == "exit status 1" { 85 | err = nil 86 | output = []byte(fmt.Sprintf("found no hits with pattern: '%s'", pattern)) 87 | } else { 88 | return "", 
fmt.Errorf("failed to run rg: %w, output: %v", err, string(output)) 89 | } 90 | } 91 | return string(output), nil 92 | } 93 | 94 | func (r RipGrepTool) UserFunction() UserFunction { 95 | return UserFunction(RipGrep) 96 | } 97 | -------------------------------------------------------------------------------- /internal/tools/bash_tool_tree.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "fmt" 5 | "os/exec" 6 | ) 7 | 8 | type FileTreeTool UserFunction 9 | 10 | var FileTree = FileTreeTool{ 11 | Name: "file_tree", 12 | Description: "List the filetree of some directory. Uses linux command 'tree'.", 13 | Inputs: &InputSchema{ 14 | Type: "object", 15 | Properties: map[string]ParameterObject{ 16 | "directory": { 17 | Type: "string", 18 | Description: "The directory to list the filetree of.", 19 | }, 20 | "level": { 21 | Type: "integer", 22 | Description: "The depth of the tree to display.", 23 | }, 24 | }, 25 | Required: []string{"directory"}, 26 | }, 27 | } 28 | 29 | func (f FileTreeTool) Call(input Input) (string, error) { 30 | directory, ok := input["directory"].(string) 31 | if !ok { 32 | return "", fmt.Errorf("directory must be a string") 33 | } 34 | cmd := exec.Command("tree", directory) 35 | if input["level"] != nil { 36 | level, ok := input["level"].(float64) 37 | if !ok { 38 | return "", fmt.Errorf("level must be a number") 39 | } 40 | cmd.Args = append(cmd.Args, "-L") 41 | cmd.Args = append(cmd.Args, fmt.Sprintf("%v", level)) 42 | } 43 | output, err := cmd.CombinedOutput() 44 | if err != nil { 45 | return "", fmt.Errorf("failed to run tree: %w, output: %v", err, string(output)) 46 | } 47 | return string(output), nil 48 | } 49 | 50 | func (f FileTreeTool) UserFunction() UserFunction { 51 | return UserFunction(FileTree) 52 | } 53 | -------------------------------------------------------------------------------- /internal/tools/handler.go: 
-------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import "fmt" 4 | 5 | var Tools = map[string]AiTool{ 6 | "file_tree": FileTree, 7 | "cat": Cat, 8 | "find": Find, 9 | "file_type": FileType, 10 | "ls": LS, 11 | "website_text": WebsiteText, 12 | "rg": RipGrep, 13 | "go": Go, 14 | "write_file": WriteFile, 15 | "freetext_command": FreetextCmd, 16 | "sed": Sed, 17 | "rows_between": RowsBetween, 18 | } 19 | 20 | // Invoke the call, and gather both error and output in the same string 21 | func Invoke(call Call) string { 22 | t, exists := Tools[call.Name] 23 | if !exists { 24 | return "ERROR: unknown tool call: " + call.Name 25 | } 26 | out, err := t.Call(call.Inputs) 27 | if err != nil { 28 | return fmt.Sprintf("ERROR: failed to run tool: %v, error: %v", call.Name, err) 29 | } 30 | return out 31 | } 32 | 33 | func UserFunctionFromName(name string) UserFunction { 34 | t, exists := Tools[name] 35 | if !exists { 36 | return UserFunction{} 37 | } 38 | return t.UserFunction() 39 | } 40 | -------------------------------------------------------------------------------- /internal/tools/models.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "slices" 7 | ) 8 | 9 | type UserFunction struct { 10 | Name string `json:"name"` 11 | Description string `json:"description,omitempty"` 12 | // Format is the same, but name of the field different. 
So this way, each 13 | // vendor can set their own field name 14 | Inputs *InputSchema `json:"input_schema,omitempty"` 15 | // Chatgpt wants this 16 | Arguments string `json:"arguments,omitempty"` 17 | } 18 | 19 | type InputSchema struct { 20 | Type string `json:"type"` 21 | Required []string `json:"required"` 22 | Properties map[string]ParameterObject `json:"properties"` 23 | } 24 | 25 | type Input map[string]any 26 | 27 | type Call struct { 28 | ID string `json:"id,omitempty"` 29 | Name string `json:"name,omitempty"` 30 | Type string `json:"type,omitempty"` 31 | Inputs Input `json:"inputs,omitempty"` 32 | Function UserFunction `json:"function,omitempty"` 33 | } 34 | 35 | // PrettyPrint the call, showing name and what input params is used 36 | // on a concise way 37 | func (c Call) PrettyPrint() string { 38 | paramStr := "" 39 | i := 0 40 | lenInp := len(c.Inputs) 41 | for flag, val := range c.Inputs { 42 | paramStr += fmt.Sprintf("'%v': '%v'", flag, val) 43 | if i < lenInp-1 { 44 | paramStr += "," 45 | } 46 | i++ 47 | } 48 | 49 | return fmt.Sprintf("Call: '%s', inputs: [ %s ]", c.Name, paramStr) 50 | } 51 | 52 | func (c Call) JSON() string { 53 | json, err := json.MarshalIndent(c, "", " ") 54 | if err != nil { 55 | return fmt.Sprintf("ERROR: Failed to unmarshal: %v", err) 56 | } 57 | return string(json) 58 | } 59 | 60 | type ParameterObject struct { 61 | Type string `json:"type"` 62 | Description string `json:"description"` 63 | Enum []string `json:"enum,omitempty"` 64 | } 65 | 66 | type ValidationError struct { 67 | fieldsMissing []string 68 | } 69 | 70 | func NewValidationError(fieldsMissing []string) error { 71 | // Sort for deterministic error print 72 | slices.Sort(fieldsMissing) 73 | return ValidationError{fieldsMissing: fieldsMissing} 74 | } 75 | 76 | func (v ValidationError) Error() string { 77 | return fmt.Sprintf("validation error, fields missing: %v", v.fieldsMissing) 78 | } 79 | 80 | type AiTool interface { 81 | // Call the AI tool with the given 
Input. Returns output from the tool or an error 82 | // if the call returned an error-like. An error-like is either exit code non-zero or 83 | // restful response non 2xx. 84 | Call(Input) (string, error) 85 | 86 | // Return the UserFunction, later on used 87 | // by text queriers to send to their respective 88 | // models 89 | UserFunction() UserFunction 90 | } 91 | -------------------------------------------------------------------------------- /internal/tools/programming_tool_go.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "fmt" 5 | "os/exec" 6 | "strings" 7 | ) 8 | 9 | type GoTool UserFunction 10 | 11 | var Go = GoTool{ 12 | Name: "go", 13 | Description: "Run Go commands like 'go test' and 'go run' to compile, test, and run Go programs. Run 'go help' to get details of this tool.", 14 | Inputs: &InputSchema{ 15 | Type: "object", 16 | Properties: map[string]ParameterObject{ 17 | "command": { 18 | Type: "string", 19 | Description: "The Go command to run (e.g., 'run', 'test', 'build').", 20 | }, 21 | "args": { 22 | Type: "string", 23 | Description: "Additional arguments for the Go command (e.g., file names, flags).", 24 | }, 25 | "dir": { 26 | Type: "string", 27 | Description: "The directory to run the command in (optional, defaults to current directory).", 28 | }, 29 | }, 30 | Required: []string{"command"}, 31 | }, 32 | } 33 | 34 | func (g GoTool) Call(input Input) (string, error) { 35 | command, ok := input["command"].(string) 36 | if !ok { 37 | return "", fmt.Errorf("command must be a string") 38 | } 39 | 40 | args := []string{command} 41 | 42 | if inputArgs, ok := input["args"].(string); ok { 43 | args = append(args, strings.Fields(inputArgs)...) 44 | } 45 | 46 | cmd := exec.Command("go", args...) 
47 | 48 | if dir, ok := input["dir"].(string); ok { 49 | cmd.Dir = dir 50 | } 51 | 52 | output, err := cmd.CombinedOutput() 53 | if err != nil { 54 | return "", fmt.Errorf("failed to run go command: %w, output: %v", err, string(output)) 55 | } 56 | 57 | return string(output), nil 58 | } 59 | 60 | func (g GoTool) UserFunction() UserFunction { 61 | return UserFunction(Go) 62 | } 63 | -------------------------------------------------------------------------------- /internal/tools/programming_tool_rows_between.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "bufio" 5 | "fmt" 6 | "os" 7 | "strconv" 8 | "strings" 9 | ) 10 | 11 | type RowsBetweenTool UserFunction 12 | 13 | var RowsBetween = RowsBetweenTool{ 14 | Name: "rows_between", 15 | Description: "Fetch the lines between two line numbers (inclusive) from a file.", 16 | Inputs: &InputSchema{ 17 | Type: "object", 18 | Properties: map[string]ParameterObject{ 19 | "file_path": { 20 | Type: "string", 21 | Description: "The path to the file to read.", 22 | }, 23 | "start_line": { 24 | Type: "integer", 25 | Description: "First line to include (1-based, inclusive).", 26 | }, 27 | "end_line": { 28 | Type: "integer", 29 | Description: "Last line to include (1-based, inclusive).", 30 | }, 31 | }, 32 | Required: []string{"file_path", "start_line", "end_line"}, 33 | }, 34 | } 35 | 36 | func (r RowsBetweenTool) Call(input Input) (string, error) { 37 | filePath, ok := input["file_path"].(string) 38 | if !ok { 39 | return "", fmt.Errorf("file_path must be a string") 40 | } 41 | startLine, ok := input["start_line"].(int) 42 | if !ok { 43 | // Accept float64 (from JSON decoding) 44 | if f, ok := input["start_line"].(float64); ok { 45 | startLine = int(f) 46 | } else if s, ok := input["start_line"].(string); ok { 47 | startLine, _ = strconv.Atoi(s) 48 | } else { 49 | return "", fmt.Errorf("start_line must be an integer") 50 | } 51 | } 52 | endLine, ok := 
input["end_line"].(int) 53 | if !ok { 54 | if f, ok := input["end_line"].(float64); ok { 55 | endLine = int(f) 56 | } else if s, ok := input["end_line"].(string); ok { 57 | endLine, _ = strconv.Atoi(s) 58 | } else { 59 | return "", fmt.Errorf("end_line must be an integer") 60 | } 61 | } 62 | 63 | if startLine <= 0 || endLine < startLine { 64 | return "", fmt.Errorf("invalid line range") 65 | } 66 | 67 | file, err := os.Open(filePath) 68 | if err != nil { 69 | return "", fmt.Errorf("failed to open file: %w", err) 70 | } 71 | defer file.Close() 72 | 73 | var lines []string 74 | scanner := bufio.NewScanner(file) 75 | for i := 1; scanner.Scan(); i++ { 76 | if i >= startLine && i <= endLine { 77 | lines = append(lines, scanner.Text()) 78 | } 79 | if i > endLine { 80 | break 81 | } 82 | } 83 | if err := scanner.Err(); err != nil { 84 | return "", fmt.Errorf("failed to scan file: %w", err) 85 | } 86 | 87 | return strings.Join(lines, "\n"), nil 88 | } 89 | 90 | func (r RowsBetweenTool) UserFunction() UserFunction { 91 | return UserFunction(RowsBetween) 92 | } 93 | -------------------------------------------------------------------------------- /internal/tools/programming_tool_rows_between_test.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "os" 5 | "testing" 6 | ) 7 | 8 | func TestRowsBetweenTool_Call(t *testing.T) { 9 | const fileName = "test_rows_between.txt" 10 | initial := "one\ntwo\nthree\nfour\nfive\n" 11 | err := os.WriteFile(fileName, []byte(initial), 0o644) 12 | if err != nil { 13 | t.Fatalf("setup failed: %v", err) 14 | } 15 | defer os.Remove(fileName) 16 | 17 | cases := []struct { 18 | start, end int 19 | expected string 20 | }{ 21 | {1, 3, "one\ntwo\nthree"}, 22 | {2, 4, "two\nthree\nfour"}, 23 | {4, 5, "four\nfive"}, 24 | {3, 3, "three"}, 25 | } 26 | 27 | for _, tc := range cases { 28 | got, err := RowsBetween.Call(Input{ 29 | "file_path": fileName, 30 | "start_line": tc.start, 31 | 
"end_line": tc.end, 32 | }) 33 | if err != nil { 34 | t.Errorf("unexpected error: %v", err) 35 | } 36 | if got != tc.expected { 37 | t.Errorf("unexpected output: got %q want %q (start=%d, end=%d)", got, tc.expected, tc.start, tc.end) 38 | } 39 | } 40 | } 41 | 42 | func TestRowsBetweenTool_BadInputs(t *testing.T) { 43 | _, err := RowsBetween.Call(Input{"file_path": "nonexistent.txt", "start_line": 1, "end_line": 3}) 44 | if err == nil { 45 | t.Error("expected error for missing file") 46 | } 47 | 48 | _, err = RowsBetween.Call(Input{"file_path": "", "start_line": 1, "end_line": 3}) 49 | if err == nil { 50 | t.Error("expected error for missing file_path") 51 | } 52 | _, err = RowsBetween.Call(Input{"file_path": "test_rows_between.txt", "start_line": -2, "end_line": 3}) 53 | if err == nil { 54 | t.Error("expected error for bad start_line") 55 | } 56 | _, err = RowsBetween.Call(Input{"file_path": "test_rows_between.txt", "start_line": 4, "end_line": 2}) 57 | if err == nil { 58 | t.Error("expected error for inverted lines") 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /internal/tools/programming_tool_sed.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "regexp" 7 | "strconv" 8 | "strings" 9 | ) 10 | 11 | type SedTool UserFunction 12 | 13 | var Sed = SedTool{ 14 | Name: "sed", 15 | Description: "Perform a basic regex substitution on each line or within a specific line range of a file (like 'sed s/pattern/repl/g'). 
Overwrites the file.", 16 | Inputs: &InputSchema{ 17 | Type: "object", 18 | Properties: map[string]ParameterObject{ 19 | "file_path": { 20 | Type: "string", 21 | Description: "The path to the file to modify.", 22 | }, 23 | "pattern": { 24 | Type: "string", 25 | Description: "The regex pattern to search for.", 26 | }, 27 | "repl": { 28 | Type: "string", 29 | Description: "The replacement string.", 30 | }, 31 | "start_line": { 32 | Type: "integer", 33 | Description: "Optional. First line to modify (1-based, inclusive).", 34 | }, 35 | "end_line": { 36 | Type: "integer", 37 | Description: "Optional. Last line to modify (1-based, inclusive).", 38 | }, 39 | }, 40 | Required: []string{"file_path", "pattern", "repl"}, 41 | }, 42 | } 43 | 44 | func (s SedTool) Call(input Input) (string, error) { 45 | filePath, ok := input["file_path"].(string) 46 | if !ok { 47 | return "", fmt.Errorf("file_path must be a string") 48 | } 49 | pattern, ok := input["pattern"].(string) 50 | if !ok { 51 | return "", fmt.Errorf("pattern must be a string") 52 | } 53 | repl, ok := input["repl"].(string) 54 | if !ok { 55 | return "", fmt.Errorf("repl must be a string") 56 | } 57 | 58 | var startLine, endLine int 59 | if v, ok := input["start_line"]; ok { 60 | switch n := v.(type) { 61 | case float64: 62 | startLine = int(n) 63 | case int: 64 | startLine = n 65 | case string: 66 | startLine, _ = strconv.Atoi(n) 67 | } 68 | } 69 | if v, ok := input["end_line"]; ok { 70 | switch n := v.(type) { 71 | case float64: 72 | endLine = int(n) 73 | case int: 74 | endLine = n 75 | case string: 76 | endLine, _ = strconv.Atoi(n) 77 | } 78 | } 79 | 80 | raw, err := os.ReadFile(filePath) 81 | if err != nil { 82 | return "", fmt.Errorf("failed to read file: %w", err) 83 | } 84 | 85 | re, err := regexp.Compile(pattern) 86 | if err != nil { 87 | return "", fmt.Errorf("invalid regex: %w", err) 88 | } 89 | 90 | lines := strings.Split(string(raw), "\n") 91 | for i := range lines { 92 | lineNum := i + 1 93 | if (startLine 
== 0 && endLine == 0) || 94 | (startLine > 0 && endLine > 0 && lineNum >= startLine && lineNum <= endLine) || 95 | (startLine > 0 && endLine == 0 && lineNum >= startLine) || 96 | (startLine == 0 && endLine > 0 && lineNum <= endLine) { 97 | lines[i] = re.ReplaceAllString(lines[i], repl) 98 | } 99 | } 100 | 101 | out := strings.Join(lines, "\n") 102 | err = os.WriteFile(filePath, []byte(out), 0o644) 103 | if err != nil { 104 | return "", fmt.Errorf("failed to write file: %w", err) 105 | } 106 | return fmt.Sprintf("sed: replaced occurrences of %q with %q in %s (%d-%d)", pattern, repl, filePath, startLine, endLine), nil 107 | } 108 | 109 | func (s SedTool) UserFunction() UserFunction { 110 | return UserFunction(Sed) 111 | } 112 | -------------------------------------------------------------------------------- /internal/tools/programming_tool_sed_test.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "os" 5 | "testing" 6 | ) 7 | 8 | func TestSedTool_Call(t *testing.T) { 9 | const fileName = "test_sed.txt" 10 | initial := "apple\nbanana\napple pie\n" 11 | err := os.WriteFile(fileName, []byte(initial), 0o644) 12 | if err != nil { 13 | t.Fatalf("setup failed: %v", err) 14 | } 15 | defer os.Remove(fileName) 16 | 17 | _, err = Sed.Call(Input{ 18 | "file_path": fileName, 19 | "pattern": "apple", 20 | "repl": "orange", 21 | }) 22 | if err != nil { 23 | t.Fatalf("sed failed: %v", err) 24 | } 25 | 26 | result, err := os.ReadFile(fileName) 27 | if err != nil { 28 | t.Fatalf("read failed: %v", err) 29 | } 30 | 31 | expected := "orange\nbanana\norange pie\n" 32 | if string(result) != expected { 33 | t.Errorf("unexpected output: got\n%q\nwant\n%q", string(result), expected) 34 | } 35 | } 36 | 37 | func TestSedTool_Range(t *testing.T) { 38 | const fileName = "test_sed_range.txt" 39 | initial := "foo\nfoo\nfoo\nfoo\n" 40 | err := os.WriteFile(fileName, []byte(initial), 0o644) 41 | if err != nil { 42 | 
t.Fatalf("setup failed: %v", err) 43 | } 44 | defer os.Remove(fileName) 45 | 46 | _, err = Sed.Call(Input{ 47 | "file_path": fileName, 48 | "pattern": "foo", 49 | "repl": "bar", 50 | "start_line": 2, 51 | "end_line": 3, 52 | }) 53 | if err != nil { 54 | t.Fatalf("sed with range failed: %v", err) 55 | } 56 | 57 | result, err := os.ReadFile(fileName) 58 | if err != nil { 59 | t.Fatalf("read failed: %v", err) 60 | } 61 | 62 | expected := "foo\nbar\nbar\nfoo\n" 63 | if string(result) != expected { 64 | t.Errorf("unexpected output: got\n%q\nwant\n%q", string(result), expected) 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /internal/tools/programming_tool_write_file.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "path/filepath" 7 | ) 8 | 9 | type WriteFileTool UserFunction 10 | 11 | var WriteFile = WriteFileTool{ 12 | Name: "write_file", 13 | Description: "Write content to a file. 
Creates the file if it doesn't exist, or overwrites it if it does.", 14 | Inputs: &InputSchema{ 15 | Type: "object", 16 | Properties: map[string]ParameterObject{ 17 | "file_path": { 18 | Type: "string", 19 | Description: "The path to the file to write to.", 20 | }, 21 | "content": { 22 | Type: "string", 23 | Description: "The content to write to the file.", 24 | }, 25 | "append": { 26 | Type: "boolean", 27 | Description: "If true, append to the file instead of overwriting it.", 28 | }, 29 | }, 30 | Required: []string{"file_path", "content"}, 31 | }, 32 | } 33 | 34 | func (w WriteFileTool) Call(input Input) (string, error) { 35 | filePath, ok := input["file_path"].(string) 36 | if !ok { 37 | return "", fmt.Errorf("file_path must be a string") 38 | } 39 | 40 | content, ok := input["content"].(string) 41 | if !ok { 42 | return "", fmt.Errorf("content must be a string") 43 | } 44 | 45 | append := false 46 | if input["append"] != nil { 47 | append, ok = input["append"].(bool) 48 | if !ok { 49 | return "", fmt.Errorf("append must be a boolean") 50 | } 51 | } 52 | 53 | // Ensure the directory exists 54 | dir := filepath.Dir(filePath) 55 | if err := os.MkdirAll(dir, 0o755); err != nil { 56 | return "", fmt.Errorf("failed to create directory: %w", err) 57 | } 58 | 59 | var flag int 60 | if append { 61 | flag = os.O_APPEND | os.O_CREATE | os.O_WRONLY 62 | } else { 63 | flag = os.O_TRUNC | os.O_CREATE | os.O_WRONLY 64 | } 65 | 66 | file, err := os.OpenFile(filePath, flag, 0o644) 67 | if err != nil { 68 | return "", fmt.Errorf("failed to open file: %w", err) 69 | } 70 | defer file.Close() 71 | 72 | _, err = file.WriteString(content) 73 | if err != nil { 74 | return "", fmt.Errorf("failed to write to file: %w", err) 75 | } 76 | 77 | return fmt.Sprintf("Successfully wrote %d bytes to %s", len(content), filePath), nil 78 | } 79 | 80 | func (w WriteFileTool) UserFunction() UserFunction { 81 | return UserFunction(WriteFile) 82 | } 83 | 
-------------------------------------------------------------------------------- /internal/tools/programming_tool_write_file_test.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "os" 5 | "path/filepath" 6 | "testing" 7 | ) 8 | 9 | func TestWriteFileTool_Call(t *testing.T) { 10 | tempDir := t.TempDir() 11 | 12 | tests := []struct { 13 | name string 14 | input Input 15 | wantErr bool 16 | check func(t *testing.T, filePath string) 17 | }{ 18 | { 19 | name: "write new file", 20 | input: Input{ 21 | "file_path": filepath.Join(tempDir, "test1.txt"), 22 | "content": "Hello, World!", 23 | }, 24 | wantErr: false, 25 | check: func(t *testing.T, filePath string) { 26 | content, err := os.ReadFile(filePath) 27 | if err != nil { 28 | t.Fatalf("Failed to read file: %v", err) 29 | } 30 | if string(content) != "Hello, World!" { 31 | t.Errorf("Expected content 'Hello, World!', got '%s'", string(content)) 32 | } 33 | }, 34 | }, 35 | { 36 | name: "overwrite existing file", 37 | input: Input{ 38 | "file_path": filepath.Join(tempDir, "test2.txt"), 39 | "content": "New content", 40 | }, 41 | wantErr: false, 42 | check: func(t *testing.T, filePath string) { 43 | content, err := os.ReadFile(filePath) 44 | if err != nil { 45 | t.Fatalf("Failed to read file: %v", err) 46 | } 47 | if string(content) != "New content" { 48 | t.Errorf("Expected content 'New content', got '%s'", string(content)) 49 | } 50 | }, 51 | }, 52 | { 53 | name: "append to existing file", 54 | input: Input{ 55 | "file_path": filepath.Join(tempDir, "test3.txt"), 56 | "content": " Appended content", 57 | "append": true, 58 | }, 59 | wantErr: false, 60 | check: func(t *testing.T, filePath string) { 61 | content, err := os.ReadFile(filePath) 62 | if err != nil { 63 | t.Fatalf("Failed to read file: %v", err) 64 | } 65 | if string(content) != "Initial content Appended content" { 66 | t.Errorf("Expected content 'Initial content Appended content', got 
'%s'", string(content)) 67 | } 68 | }, 69 | }, 70 | { 71 | name: "missing file_path", 72 | input: Input{ 73 | "content": "Some content", 74 | }, 75 | wantErr: true, 76 | }, 77 | { 78 | name: "missing content", 79 | input: Input{ 80 | "file_path": filepath.Join(tempDir, "test4.txt"), 81 | }, 82 | wantErr: true, 83 | }, 84 | { 85 | name: "invalid append type", 86 | input: Input{ 87 | "file_path": filepath.Join(tempDir, "test5.txt"), 88 | "content": "Some content", 89 | "append": "true", 90 | }, 91 | wantErr: true, 92 | }, 93 | } 94 | 95 | writeTool := WriteFileTool{} 96 | 97 | // Set up file for append test 98 | if err := os.WriteFile(filepath.Join(tempDir, "test3.txt"), []byte("Initial content"), 0o644); err != nil { 99 | t.Fatalf("Failed to set up append test: %v", err) 100 | } 101 | 102 | for _, tt := range tests { 103 | t.Run(tt.name, func(t *testing.T) { 104 | result, err := writeTool.Call(tt.input) 105 | if (err != nil) != tt.wantErr { 106 | t.Errorf("WriteFileTool.Call() error = %v, wantErr %v", err, tt.wantErr) 107 | return 108 | } 109 | if !tt.wantErr { 110 | if result == "" { 111 | t.Errorf("WriteFileTool.Call() returned empty result") 112 | } 113 | if tt.check != nil { 114 | tt.check(t, tt.input["file_path"].(string)) 115 | } 116 | } 117 | }) 118 | } 119 | } 120 | 121 | func TestWriteFileTool_UserFunction(t *testing.T) { 122 | writeTool := WriteFileTool{} 123 | userFunc := writeTool.UserFunction() 124 | 125 | if userFunc.Name != "write_file" { 126 | t.Errorf("Expected name 'write_file', got '%s'", userFunc.Name) 127 | } 128 | 129 | if userFunc.Description != "Write content to a file. Creates the file if it doesn't exist, or overwrites it if it does." 
{ 130 | t.Errorf("Unexpected description: %s", userFunc.Description) 131 | } 132 | 133 | if len(userFunc.Inputs.Required) != 2 || userFunc.Inputs.Required[0] != "file_path" || userFunc.Inputs.Required[1] != "content" { 134 | t.Errorf("Unexpected required inputs: %v", userFunc.Inputs.Required) 135 | } 136 | 137 | if len(userFunc.Inputs.Properties) != 3 { 138 | t.Errorf("Expected 3 properties, got %d", len(userFunc.Inputs.Properties)) 139 | } 140 | } 141 | -------------------------------------------------------------------------------- /internal/tools/web_tool_website_text.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "io" 7 | "net/http" 8 | "strings" 9 | 10 | "golang.org/x/net/html" 11 | ) 12 | 13 | type WebsiteTextTool UserFunction 14 | 15 | var WebsiteText = WebsiteTextTool{ 16 | Name: "website_text", 17 | Description: "Get the text content of a website by stripping all non-text tags and trimming whitespace.", 18 | Inputs: &InputSchema{ 19 | Type: "object", 20 | Properties: map[string]ParameterObject{ 21 | "url": { 22 | Type: "string", 23 | Description: "The URL of the website to retrieve the text content from.", 24 | }, 25 | }, 26 | Required: []string{"url"}, 27 | }, 28 | } 29 | 30 | func (w WebsiteTextTool) Call(input Input) (string, error) { 31 | url, ok := input["url"].(string) 32 | if !ok { 33 | return "", fmt.Errorf("url must be a string") 34 | } 35 | resp, err := http.Get(url) 36 | if err != nil { 37 | return "", fmt.Errorf("failed to fetch website: %w", err) 38 | } 39 | defer resp.Body.Close() 40 | 41 | var text strings.Builder 42 | tokenizer := html.NewTokenizer(resp.Body) 43 | for { 44 | tt := tokenizer.Next() 45 | if tt == html.ErrorToken { 46 | if tokenizer.Err() == io.EOF { 47 | break 48 | } 49 | return "", fmt.Errorf("tokenizer error: %w", tokenizer.Err()) 50 | } 51 | if tt == html.TextToken { 52 | trimmed := bytes.TrimSpace(tokenizer.Text()) 53 | if 
len(trimmed) > 0 { 54 | text.Write(trimmed) 55 | text.WriteRune('\n') 56 | } 57 | } 58 | } 59 | return text.String(), nil 60 | } 61 | 62 | func (w WebsiteTextTool) UserFunction() UserFunction { 63 | return UserFunction(WebsiteText) 64 | } 65 | -------------------------------------------------------------------------------- /internal/tools/web_tool_website_text_test.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "testing" 7 | ) 8 | 9 | func TestWebsiteTextTool(t *testing.T) { 10 | // Test successful case 11 | server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 12 | w.Write([]byte(` 13 | 14 | 15 |

Hello World

16 |

This is some text

17 | 18 | `)) 19 | })) 20 | defer server.Close() 21 | 22 | input := Input{"url": server.URL} 23 | expected := "Hello World\nThis is some text\n" 24 | 25 | actual, err := WebsiteText.Call(input) 26 | if err != nil { 27 | t.Errorf("Unexpected error: %v", err) 28 | } 29 | if actual != expected { 30 | t.Errorf("Expected %q, got %q", expected, actual) 31 | } 32 | 33 | // Test invalid URL 34 | input = Input{"url": "invalid"} 35 | _, err = WebsiteText.Call(input) 36 | if err == nil { 37 | t.Error("Expected error for invalid URL, got nil") 38 | } 39 | 40 | // Test invalid input type 41 | input = Input{"url": 123} 42 | _, err = WebsiteText.Call(input) 43 | if err == nil { 44 | t.Error("Expected error for invalid input type, got nil") 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /internal/utils/config.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "path" 7 | "path/filepath" 8 | "reflect" 9 | 10 | "github.com/baalimago/go_away_boilerplate/pkg/ancli" 11 | "github.com/baalimago/go_away_boilerplate/pkg/debug" 12 | "github.com/baalimago/go_away_boilerplate/pkg/misc" 13 | ) 14 | 15 | func createConfigDir(configDirPath string) error { 16 | if _, err := os.Stat(configDirPath); os.IsNotExist(err) { 17 | err := setupClaiConfigDir(configDirPath) 18 | if err != nil { 19 | return fmt.Errorf("failed to setup config dotdir: %w", err) 20 | } 21 | } 22 | return nil 23 | } 24 | 25 | func setupClaiConfigDir(configPath string) error { 26 | conversationsDir := filepath.Join(configPath, "conversations") 27 | ancli.PrintOK("created conversations directory\n") 28 | 29 | // Create the .clai directory. 
30 | if err := os.MkdirAll(conversationsDir, os.ModePerm); err != nil { 31 | return fmt.Errorf("failed to create .clai + .clai/conversations directory: %w", err) 32 | } 33 | ancli.PrintOK(fmt.Sprintf("created .clai directory at: '%v'\n", configPath)) 34 | return nil 35 | } 36 | 37 | func createDefaultConfigFile[T any](configDirPath, configFileName string, dflt *T) error { 38 | configFilePath := filepath.Join(configDirPath, configFileName) 39 | if _, err := os.Stat(configFilePath); os.IsNotExist(err) { 40 | if misc.Truthy(os.Getenv("DEBUG")) { 41 | ancli.PrintOK(fmt.Sprintf("attempting to create file: '%v'\n", configFilePath)) 42 | } 43 | err := CreateFile(configFilePath, dflt) 44 | if err != nil { 45 | return fmt.Errorf("failed to write config: '%v', error: %w", configFileName, err) 46 | } 47 | } 48 | return nil 49 | } 50 | 51 | func runMigrationCallback(migrationCb func(string) error, configDirPath string) error { 52 | if migrationCb != nil { 53 | err := migrationCb(configDirPath) 54 | if err != nil { 55 | ancli.PrintWarn(fmt.Sprintf("failed to migrate for config, error: %v\n", err)) 56 | return err 57 | } 58 | } 59 | return nil 60 | } 61 | 62 | func LoadConfigFromFile[T any]( 63 | placeConfigPath, 64 | configFileName string, 65 | migrationCb func(string) error, 66 | dflt *T, 67 | ) (T, error) { 68 | configDirPath := fmt.Sprintf("%v/.clai/", placeConfigPath) 69 | if misc.Truthy(os.Getenv("DEBUG")) { 70 | ancli.PrintOK(fmt.Sprintf("attempting to load file: %v%v\n", configDirPath, configFileName)) 71 | } 72 | 73 | err := createConfigDir(configDirPath) 74 | if err != nil { 75 | var nilVal T 76 | return nilVal, err 77 | } 78 | 79 | err = createDefaultConfigFile(configDirPath, configFileName, dflt) 80 | if err != nil { 81 | var nilVal T 82 | return nilVal, err 83 | } 84 | 85 | err = runMigrationCallback(migrationCb, configDirPath) 86 | if err != nil { 87 | var nilVal T 88 | return nilVal, err 89 | } 90 | 91 | configPath := path.Join(configDirPath, configFileName) 92 | 
// setNonZeroValueFields fills a using b as template: every exported field
// which is the zero value in a and non-zero in b is copied from b into a.
//
// It returns the names of the copied fields, in declaration order. The name
// is the field's json tag name (options such as ",omitempty" removed),
// falling back to the Go field name when there is no usable tag. The
// returned slice is never nil.
//
// When a or b is nil, or T is not a struct, nothing is changed and an empty
// slice is returned.
func setNonZeroValueFields[T any](a, b *T) []string {
	hasChanged := []string{}
	if a == nil || b == nil {
		return hasChanged
	}
	target := reflect.ValueOf(a).Elem()
	template := reflect.ValueOf(b).Elem()
	// NumField/Field panic for anything but structs.
	if target.Kind() != reflect.Struct {
		return hasChanged
	}
	structType := target.Type()
	for i := 0; i < structType.NumField(); i++ {
		field := structType.Field(i)
		// Unexported fields cannot be set through reflection.
		if !field.IsExported() {
			continue
		}
		targetVal := target.Field(i)
		templateVal := template.Field(i)
		if targetVal.IsZero() && !templateVal.IsZero() {
			hasChanged = append(hasChanged, jsonFieldName(field))
			targetVal.Set(templateVal)
		}
	}
	return hasChanged
}

// jsonFieldName returns the name part of the field's json tag, or the Go
// field name when the tag is missing, empty or "-".
func jsonFieldName(field reflect.StructField) string {
	name := field.Tag.Get("json")
	for i := 0; i < len(name); i++ {
		if name[i] == ',' {
			name = name[:i]
			break
		}
	}
	if name == "" || name == "-" {
		return field.Name
	}
	return name
}

// ReturnNonDefault returns whichever of a and b differs from defaultVal.
// When neither differs, defaultVal is returned. When both differ, even if
// they are equal to each other, the values are considered mutually exclusive
// and an error is returned together with defaultVal.
func ReturnNonDefault[T comparable](a, b, defaultVal T) (T, error) {
	aSet, bSet := a != defaultVal, b != defaultVal
	switch {
	case aSet && bSet:
		return defaultVal, fmt.Errorf("values are mutually exclusive")
	case aSet:
		return a, nil
	case bSet:
		return b, nil
	default:
		return defaultVal, nil
	}
}
"os" 5 | "path/filepath" 6 | "testing" 7 | 8 | "github.com/baalimago/go_away_boilerplate/pkg/testboil" 9 | ) 10 | 11 | func TestReturnNonDefault(t *testing.T) { 12 | tests := []struct { 13 | name string 14 | a interface{} 15 | b interface{} 16 | defaultVal interface{} 17 | want interface{} 18 | wantErr bool 19 | }{ 20 | { 21 | name: "Both defaults", 22 | a: "default", 23 | b: "default", 24 | defaultVal: "default", 25 | want: "default", 26 | wantErr: false, 27 | }, 28 | { 29 | name: "A non-default", 30 | a: "non-default", 31 | b: "default", 32 | defaultVal: "default", 33 | want: "non-default", 34 | wantErr: false, 35 | }, 36 | { 37 | name: "B non-default", 38 | a: "default", 39 | b: "non-default", 40 | defaultVal: "default", 41 | want: "non-default", 42 | wantErr: false, 43 | }, 44 | { 45 | name: "Both non-default", 46 | a: "non-default-a", 47 | b: "non-default-b", 48 | defaultVal: "default", 49 | want: "default", 50 | wantErr: true, 51 | }, 52 | { 53 | name: "Both non-default same value", 54 | a: "non-default", 55 | b: "non-default", 56 | defaultVal: "default", 57 | want: "default", 58 | wantErr: true, 59 | }, 60 | } 61 | 62 | for _, tt := range tests { 63 | t.Run(tt.name, func(t *testing.T) { 64 | got, err := ReturnNonDefault(tt.a, tt.b, tt.defaultVal) 65 | if (err != nil) != tt.wantErr { 66 | t.Errorf("ReturnNonDefault() error = %v, wantErr %v", err, tt.wantErr) 67 | return 68 | } 69 | if !tt.wantErr && got != tt.want { 70 | t.Errorf("ReturnNonDefault() = %v, want %v", got, tt.want) 71 | } 72 | }) 73 | } 74 | } 75 | 76 | func TestRunMigrationCallback(t *testing.T) { 77 | // Create a test migration callback 78 | var migrationCalled bool 79 | migrationCb := func(configDirPath string) error { 80 | migrationCalled = true 81 | return nil 82 | } 83 | 84 | // Test running the migration callback 85 | configDirPath := "/path/to/config" 86 | err := runMigrationCallback(migrationCb, configDirPath) 87 | if err != nil { 88 | t.Errorf("Unexpected error running migration 
callback: %v", err) 89 | } 90 | if !migrationCalled { 91 | t.Error("Expected migration callback to be called") 92 | } 93 | 94 | // Test running the migration callback with nil callback 95 | migrationCalled = false 96 | err = runMigrationCallback(nil, configDirPath) 97 | if err != nil { 98 | t.Errorf("Unexpected error running nil migration callback: %v", err) 99 | } 100 | if migrationCalled { 101 | t.Error("Expected migration callback not to be called") 102 | } 103 | } 104 | 105 | func TestCreateConfigDir(t *testing.T) { 106 | // Create a temporary directory for testing 107 | configDirPath := filepath.Join(t.TempDir(), ".clai") 108 | 109 | // Test creating a new config directory 110 | err := createConfigDir(configDirPath) 111 | if err != nil { 112 | t.Errorf("Unexpected error creating config directory: %v", err) 113 | } 114 | if _, err := os.Stat(configDirPath); os.IsNotExist(err) { 115 | t.Error("Expected config directory to exist") 116 | } 117 | 118 | // Test creating an existing config directory 119 | err = createConfigDir(configDirPath) 120 | if err != nil { 121 | t.Errorf("Unexpected error creating existing config directory: %v", err) 122 | } 123 | } 124 | 125 | func TestCreateDefaultConfigFile(t *testing.T) { 126 | // Create a temporary directory for testing 127 | tempDir := t.TempDir() 128 | os.MkdirAll(filepath.Join(tempDir, ".clai"), 0o755) 129 | 130 | configDirPath := filepath.Join(tempDir, ".clai") 131 | configFileName := "config.json" 132 | 133 | // Test creating a new default config file 134 | dflt := &struct { 135 | Name string `json:"name"` 136 | }{Name: "John"} 137 | err := createDefaultConfigFile(configDirPath, configFileName, dflt) 138 | if err != nil { 139 | t.Errorf("Unexpected error creating default config file: %v", err) 140 | } 141 | configFilePath := filepath.Join(configDirPath, configFileName) 142 | if _, err := os.Stat(configFilePath); os.IsNotExist(err) { 143 | t.Error("Expected default config file to exist") 144 | } 145 | 146 | // Test 
creating an existing default config file 147 | err = createDefaultConfigFile(configDirPath, configFileName, dflt) 148 | if err != nil { 149 | t.Errorf("Unexpected error creating existing default config file: %v", err) 150 | } 151 | } 152 | 153 | type testStruct struct { 154 | A string 155 | B string 156 | } 157 | 158 | func Test_appendNewFieldsFromDefault(t *testing.T) { 159 | testCases := []struct { 160 | desc string 161 | given testStruct 162 | when testStruct 163 | want testStruct 164 | }{ 165 | { 166 | desc: "it should append new fields from default if they are zero value in want", 167 | given: testStruct{ 168 | A: "filled", 169 | }, 170 | when: testStruct{ 171 | A: "filled", 172 | B: "new", 173 | }, 174 | want: testStruct{ 175 | A: "filled", 176 | B: "new", 177 | }, 178 | }, 179 | } 180 | for _, tC := range testCases { 181 | t.Run(tC.desc, func(t *testing.T) { 182 | setNonZeroValueFields(&tC.given, &tC.when) 183 | got := tC.given 184 | testboil.FailTestIfDiff(t, got.A, tC.want.A) 185 | testboil.FailTestIfDiff(t, got.B, tC.want.B) 186 | }) 187 | } 188 | } 189 | -------------------------------------------------------------------------------- /internal/utils/errors.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import "errors" 4 | 5 | var ErrUserInitiatedExit = errors.New("user exit") 6 | -------------------------------------------------------------------------------- /internal/utils/file.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "encoding/json" 5 | "errors" 6 | "fmt" 7 | "io" 8 | "io/fs" 9 | "os" 10 | ) 11 | 12 | func CreateFile[T any](path string, toCreate *T) error { 13 | file, err := os.Create(path) 14 | if err != nil { 15 | return fmt.Errorf("failed to create config file: %w", err) 16 | } 17 | defer file.Close() 18 | b, err := json.MarshalIndent(toCreate, "", " ") 19 | if err != nil { 20 | return fmt.Errorf("failed 
to marshal config: %w", err) 21 | } 22 | if _, err := file.Write(b); err != nil { 23 | return fmt.Errorf("failed to write config: %w", err) 24 | } 25 | return nil 26 | } 27 | 28 | func WriteFile[T any](path string, toWrite *T) error { 29 | fileBytes, err := json.MarshalIndent(toWrite, "", " ") 30 | if err != nil { 31 | return fmt.Errorf("failed to marshal file: %w", err) 32 | } 33 | err = os.WriteFile(path, fileBytes, 0o644) 34 | if err != nil { 35 | return fmt.Errorf("failed to write file: %w", err) 36 | } 37 | return nil 38 | } 39 | 40 | // ReadAndUnmarshal by first finding the file, then attempting to read + unmarshal to T 41 | func ReadAndUnmarshal[T any](filePath string, config *T) error { 42 | if _, err := os.Stat(filePath); errors.Is(err, fs.ErrNotExist) { 43 | return fmt.Errorf("failed to find file: %w", err) 44 | } 45 | file, err := os.Open(filePath) 46 | if err != nil { 47 | return fmt.Errorf("failed to open file: %w", err) 48 | } 49 | defer file.Close() 50 | fileBytes, err := io.ReadAll(file) 51 | if err != nil { 52 | return fmt.Errorf("failed to read file: %w", err) 53 | } 54 | err = json.Unmarshal(fileBytes, config) 55 | if err != nil { 56 | return fmt.Errorf("failed to unmarshal file: %w", err) 57 | } 58 | 59 | return nil 60 | } 61 | -------------------------------------------------------------------------------- /internal/utils/file_test.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "os" 5 | "testing" 6 | ) 7 | 8 | type TestData struct { 9 | Name string `json:"name"` 10 | Age int `json:"age"` 11 | } 12 | 13 | func TestCreateFile(t *testing.T) { 14 | filePath := "test_create.json" 15 | defer os.Remove(filePath) 16 | 17 | data := &TestData{Name: "John", Age: 30} 18 | err := CreateFile(filePath, data) 19 | if err != nil { 20 | t.Errorf("CreateFile failed: %v", err) 21 | } 22 | 23 | if _, err := os.Stat(filePath); os.IsNotExist(err) { 24 | t.Errorf("File not created: %v", err) 
25 | } 26 | } 27 | 28 | func TestWriteFile(t *testing.T) { 29 | filePath := "test_write.json" 30 | defer os.Remove(filePath) 31 | 32 | data := &TestData{Name: "Alice", Age: 25} 33 | err := WriteFile(filePath, data) 34 | if err != nil { 35 | t.Errorf("WriteFile failed: %v", err) 36 | } 37 | 38 | if _, err := os.Stat(filePath); os.IsNotExist(err) { 39 | t.Errorf("File not written: %v", err) 40 | } 41 | } 42 | 43 | func TestReadAndUnmarshal(t *testing.T) { 44 | filePath := "test_read.json" 45 | defer os.Remove(filePath) 46 | 47 | expected := &TestData{Name: "Bob", Age: 35} 48 | err := CreateFile(filePath, expected) 49 | if err != nil { 50 | t.Fatalf("Failed to create test file: %v", err) 51 | } 52 | 53 | var actual TestData 54 | err = ReadAndUnmarshal(filePath, &actual) 55 | if err != nil { 56 | t.Errorf("ReadAndUnmarshal failed: %v", err) 57 | } 58 | 59 | if actual.Name != expected.Name || actual.Age != expected.Age { 60 | t.Errorf("ReadAndUnmarshal returned unexpected data: got %+v, want %+v", actual, expected) 61 | } 62 | } 63 | 64 | func TestReadAndUnmarshal_FileNotFound(t *testing.T) { 65 | filePath := "nonexistent.json" 66 | var data TestData 67 | err := ReadAndUnmarshal(filePath, &data) 68 | if err == nil { 69 | t.Error("ReadAndUnmarshal should have failed for non-existent file") 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /internal/utils/input.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "bufio" 5 | "errors" 6 | "fmt" 7 | "os" 8 | "os/signal" 9 | "slices" 10 | "strings" 11 | ) 12 | 13 | // ReadUserInput and return on interrupt channel 14 | func ReadUserInput() (string, error) { 15 | sigChan := make(chan os.Signal, 1) 16 | signal.Notify(sigChan, os.Interrupt) 17 | defer signal.Stop(sigChan) 18 | inputChan := make(chan string) 19 | errChan := make(chan error) 20 | 21 | go func() { 22 | reader := bufio.NewReader(os.Stdin) 23 | 
// RandomPrefix returns a 10 character string where each character is picked
// with math/rand from the ASCII letters and digits.
func RandomPrefix() string {
	const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
	const prefixLen = 10

	var buf [prefixLen]byte
	for i := 0; i < prefixLen; i++ {
		buf[i] = charset[rand.Intn(len(charset))]
	}
	return string(buf[:])
}

// GetFirstTokens returns the first n tokens of the prompt, or the whole prompt if it has less than n tokens.
// Empty tokens are skipped and do not count towards n. The result is never nil.
func GetFirstTokens(prompt []string, n int) []string {
	tokens := []string{}
	for i := 0; i < len(prompt) && len(tokens) < n; i++ {
		if prompt[i] != "" {
			tokens = append(tokens, prompt[i])
		}
	}
	return tokens
}
// ClearTermTo blanks the current terminal line and the upTo lines above it by
// overwriting each with termWidth spaces, leaving the cursor at the start of
// the topmost cleared line. A negative termWidth is treated as 0.
func ClearTermTo(termWidth, upTo int) {
	// strings.Repeat panics on a negative count.
	if termWidth < 0 {
		termWidth = 0
	}
	clearLine := strings.Repeat(" ", termWidth)
	// Move cursor up line by line and clear the line
	for ; upTo > 0; upTo-- {
		fmt.Printf("\r%v", clearLine)
		fmt.Printf("\033[%dA", 1)
	}
	fmt.Printf("\r%v", clearLine)
	// Place cursor at start of line
	fmt.Printf("\r")
}

// UpdateMessageTerminalMetadata updates the terminal metadata. Meaning the lineCount, to eventually
// clear the terminal.
//
// The text considered is *line followed by msg. *lineCount is set to the
// number of terminal rows that text occupies at termWidth columns: every
// newline-separated segment counts ceil(runes/termWidth) rows and an empty
// segment counts one row. A termWidth below 1 is treated as 1.
//
// *line is set to the last segment when it fits within termWidth. Otherwise
// it is set to that segment's last space-separated word, cut down to its
// final termWidth runes. All widths are measured in runes, so multi-byte
// characters are never split.
func UpdateMessageTerminalMetadata(msg string, line *string, lineCount *int, termWidth int) {
	if termWidth <= 0 {
		termWidth = 1
	}

	newlineSplit := strings.Split(*line+msg, "\n")
	*lineCount = 0
	for _, segment := range newlineSplit {
		if len(segment) == 0 {
			*lineCount++
			continue
		}
		runeCount := utf8.RuneCountInString(segment)
		// Integer ceiling of runeCount / termWidth.
		*lineCount += (runeCount + termWidth - 1) / termWidth
	}
	if *lineCount == 0 {
		*lineCount = 1
	}

	lastSegment := newlineSplit[len(newlineSplit)-1]
	if utf8.RuneCountInString(lastSegment) <= termWidth {
		*line = lastSegment
		return
	}
	lastWords := strings.Split(lastSegment, " ")
	lastWord := []rune(lastWords[len(lastWords)-1])
	if len(lastWord) > termWidth {
		lastWord = lastWord[len(lastWord)-termWidth:]
	}
	*line = string(lastWord)
}
So, replace it to [thinking] 97 | inp = strings.ReplaceAll(inp, "", "[thinking]") 98 | inp = strings.ReplaceAll(inp, "", "[/thinking]") 99 | cmd.Stdin = bytes.NewBufferString(inp) 100 | cmd.Stdout = os.Stdout 101 | fmt.Printf("%v:", ancli.ColoredMessage(color, role)) 102 | if err := cmd.Run(); err != nil { 103 | return fmt.Errorf("failed to run glow: %w", err) 104 | } 105 | return nil 106 | } 107 | -------------------------------------------------------------------------------- /internal/utils/print_test.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import "testing" 4 | 5 | func TestUpdateMessageTerminalMetadata(t *testing.T) { 6 | testCases := []struct { 7 | name string 8 | msg string 9 | line string 10 | lineCount int 11 | termWidth int 12 | expectedLine string 13 | expectedLineCount int 14 | }{ 15 | { 16 | name: "Single line message", 17 | msg: "Hello", 18 | line: "", 19 | lineCount: 0, 20 | termWidth: 10, 21 | expectedLine: "Hello", 22 | expectedLineCount: 1, 23 | }, 24 | { 25 | name: "Message with newline", 26 | msg: "Hello\nWorld", 27 | line: "", 28 | lineCount: 0, 29 | termWidth: 10, 30 | expectedLine: "World", 31 | expectedLineCount: 2, 32 | }, 33 | { 34 | name: "Message exceeding terminal width", 35 | msg: "Hello World", 36 | line: "", 37 | lineCount: 0, 38 | termWidth: 5, 39 | expectedLine: "World", 40 | expectedLineCount: 3, 41 | }, 42 | { 43 | name: "Append to existing line", 44 | msg: "World", 45 | line: "Hello ", 46 | lineCount: 0, 47 | termWidth: 20, 48 | expectedLine: "Hello World", 49 | expectedLineCount: 1, 50 | }, 51 | { 52 | name: "It should handle multiple termwidth overflows", 53 | msg: "1111 2222 3333 4444", 54 | line: "", 55 | lineCount: 0, 56 | termWidth: 5, 57 | expectedLine: "4444", 58 | expectedLineCount: 4, 59 | }, 60 | { 61 | name: "It should handle multiple termwidth overflows + newlines", 62 | msg: "1111 22\n3333 4444", 63 | line: "", 64 | lineCount: 0, 65 | termWidth: 
5, 66 | expectedLine: "4444", 67 | expectedLineCount: 4, 68 | }, 69 | { 70 | name: "It should handle multiple termwidth overflows + newlines", 71 | msg: "11 22 33 44 55 66", 72 | line: "", 73 | lineCount: 0, 74 | termWidth: 3, 75 | expectedLine: "66", 76 | expectedLineCount: 6, 77 | }, 78 | { 79 | name: "it should not fail on this edge case that I found", 80 | msg: "Debugging involves systematically finding and resolving issues within your code or software. Start by identifying the problem, replicate the error, and use tools like breakpoints or logging to trace the source. Testing changes iteratively helps ensure the fix is successful and doesn't cause new issues.", 81 | // This is not correct, but that's fine, the last line functionality isn't used anywhere anyways 82 | expectedLine: "issues.", 83 | lineCount: 0, 84 | termWidth: 223, 85 | expectedLineCount: 2, 86 | }, 87 | { 88 | name: "it should not fail on this edge case that I found", 89 | msg: "*Hurrmph* I'm as well as a 90-year old can be, which is better than the alternative, I suppose. My joints are creaking like an old rocking chair, but my mind is still sharp as a tack.\n\nWhat can I help you with today, young whippersnapper? 
*adjusts spectacles*\n", 90 | expectedLine: "", 91 | lineCount: 0, 92 | termWidth: 127, 93 | expectedLineCount: 5, 94 | }, 95 | } 96 | 97 | for _, tc := range testCases { 98 | t.Run(tc.name, func(t *testing.T) { 99 | line := tc.line 100 | lineCount := tc.lineCount 101 | 102 | UpdateMessageTerminalMetadata(tc.msg, &line, &lineCount, tc.termWidth) 103 | 104 | if line != tc.expectedLine { 105 | t.Errorf("Expected line: %q, got: %q", tc.expectedLine, line) 106 | } 107 | 108 | if lineCount != tc.expectedLineCount { 109 | t.Errorf("Expected lineCount: %d, got: %d", tc.expectedLineCount, lineCount) 110 | } 111 | }) 112 | } 113 | } 114 | -------------------------------------------------------------------------------- /internal/utils/prompt.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "io" 7 | "os" 8 | "strings" 9 | 10 | "github.com/baalimago/go_away_boilerplate/pkg/ancli" 11 | "github.com/baalimago/go_away_boilerplate/pkg/misc" 12 | ) 13 | 14 | // Prompt returns the prompt by checking all the arguments and stdin. 15 | // If there is no arguments, but data in stdin, stdin will become the prompt. 
16 | // If there are arguments and data in stdin, all stdinReplace tokens will be substituted 17 | // with the data in stdin 18 | func Prompt(stdinReplace string, args []string) (string, error) { 19 | debug := misc.Truthy(os.Getenv("DEBUG")) 20 | if debug { 21 | ancli.PrintOK(fmt.Sprintf("stdinReplace: %v\n", stdinReplace)) 22 | } 23 | fi, err := os.Stdin.Stat() 24 | if err != nil { 25 | panic(err) 26 | } 27 | var hasPipe bool 28 | if fi.Mode()&os.ModeNamedPipe == 0 { 29 | hasPipe = false 30 | } else { 31 | hasPipe = true 32 | } 33 | 34 | if len(args) == 1 && !hasPipe { 35 | return "", errors.New("found no prompt, set args or pipe in some string") 36 | } 37 | // First argument is the command, so we skip it 38 | args = args[1:] 39 | // If no data is in stdin, simply return args 40 | if !hasPipe { 41 | return strings.Join(args, " "), nil 42 | } 43 | 44 | inputData, err := io.ReadAll(os.Stdin) 45 | if err != nil { 46 | return "", fmt.Errorf("failed to read stdin: %v", err) 47 | } 48 | pipeIn := string(inputData) 49 | // Add the pipeIn to the args if there are no args 50 | if len(args) == 0 { 51 | args = append(args, strings.Split(pipeIn, " ")...) 
52 | } else if stdinReplace == "" && hasPipe { 53 | stdinReplace = "{}" 54 | args = append(args, "{}") 55 | } 56 | 57 | // Replace all occurrence of stdinReplaceSignal with pipeIn 58 | if stdinReplace != "" { 59 | if debug { 60 | ancli.PrintOK(fmt.Sprintf("attempting to replace: '%v' with stdin\n", stdinReplace)) 61 | } 62 | for i, arg := range args { 63 | if strings.Contains(arg, stdinReplace) { 64 | args[i] = strings.ReplaceAll(arg, stdinReplace, pipeIn) 65 | } 66 | } 67 | } 68 | 69 | if debug { 70 | ancli.PrintOK(fmt.Sprintf("args: %v\n", args)) 71 | } 72 | return strings.Join(args, " "), nil 73 | } 74 | 75 | func ReplaceTildeWithHome(s string) (string, error) { 76 | home, err := os.UserHomeDir() 77 | if err != nil && strings.Contains(s, "~/") { // only fail if glob contains ~/ and home dir is not found 78 | return "", fmt.Errorf("failed to get home dir: %w", err) 79 | } 80 | return strings.ReplaceAll(s, "~", home), nil 81 | } 82 | -------------------------------------------------------------------------------- /internal/utils/prompt_test.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "os" 5 | "testing" 6 | ) 7 | 8 | func TestPrompt(t *testing.T) { 9 | testCases := []struct { 10 | name string 11 | stdinReplace string 12 | args []string 13 | stdin string 14 | expectedPrompt string 15 | expectedError bool 16 | }{ 17 | { 18 | name: "No arguments and no stdin", 19 | stdinReplace: "", 20 | args: []string{""}, 21 | stdin: "", 22 | expectedPrompt: "", 23 | expectedError: true, 24 | }, 25 | { 26 | name: "Arguments only", 27 | stdinReplace: "", 28 | args: []string{"cmd", "arg1", "arg2"}, 29 | stdin: "", 30 | expectedPrompt: "arg1 arg2", 31 | expectedError: false, 32 | }, 33 | { 34 | name: "Stdin only", 35 | stdinReplace: "", 36 | args: []string{"cmd"}, 37 | stdin: "input from stdin", 38 | expectedPrompt: "input from stdin", 39 | expectedError: false, 40 | }, 41 | { 42 | name: "Arguments and 
stdin", 43 | stdinReplace: "{}", 44 | args: []string{"cmd", "arg1", "arg2", "{}"}, 45 | stdin: "input from stdin", 46 | expectedPrompt: "arg1 arg2 input from stdin", 47 | expectedError: false, 48 | }, 49 | { 50 | name: "Arguments with stdinReplace", 51 | stdinReplace: "", 52 | args: []string{"cmd", "prefix", "", "suffix"}, 53 | stdin: "input from stdin", 54 | expectedPrompt: "prefix input from stdin suffix", 55 | expectedError: false, 56 | }, 57 | { 58 | name: "Arguments with stdinReplace", 59 | stdinReplace: "", 60 | args: []string{"cmd", "prefix", "suffix"}, 61 | stdin: "input from stdin", 62 | expectedPrompt: "prefix suffix input from stdin", 63 | expectedError: false, 64 | }, 65 | } 66 | 67 | for _, tc := range testCases { 68 | t.Run(tc.name, func(t *testing.T) { 69 | if tc.stdin != "" { 70 | // Set up stdin 71 | oldStdin := os.Stdin 72 | t.Cleanup(func() { os.Stdin = oldStdin }) 73 | r, w, err := os.Pipe() 74 | if err != nil { 75 | t.Fatal(err) 76 | } 77 | os.Stdin = r 78 | _, err = w.WriteString(tc.stdin) 79 | if err != nil { 80 | t.Fatal(err) 81 | } 82 | w.Close() 83 | } 84 | 85 | // Call the function 86 | prompt, err := Prompt(tc.stdinReplace, tc.args) 87 | 88 | // Check the error 89 | if tc.expectedError && err == nil { 90 | t.Error("Expected an error, but got nil") 91 | } else if !tc.expectedError && err != nil { 92 | t.Errorf("Unexpected error: %v", err) 93 | } 94 | 95 | // Check the prompt 96 | if prompt != tc.expectedPrompt { 97 | t.Errorf("Prompt mismatch. 
// TermWidth returns the number of columns of the terminal attached to
// stderr, obtained with the TIOCGWINSZ ioctl. When the ioctl fails, 0 and the
// resulting errno are returned.
func TermWidth() (int, error) {
	// Field order and sizes mirror what the ioctl writes back.
	type winsize struct {
		Row, Col       uint16
		Xpixel, Ypixel uint16
	}
	var size winsize

	ret, _, errno := syscall.Syscall(
		syscall.SYS_IOCTL,
		uintptr(syscall.Stderr),
		uintptr(syscall.TIOCGWINSZ),
		uintptr(unsafe.Pointer(&size)),
	)
	if int(ret) == -1 {
		return 0, errno
	}
	return int(size.Col), nil
}
1024, 36 | TopP: -1, 37 | TopK: -1, 38 | StopSequences: make([]string, 0), 39 | } 40 | 41 | type claudeReq struct { 42 | Model string `json:"model"` 43 | Messages []models.Message `json:"messages"` 44 | MaxTokens int `json:"max_tokens"` 45 | Stream bool `json:"stream"` 46 | System string `json:"system"` 47 | Temperature float64 `json:"temperature"` 48 | TopP float64 `json:"top_p"` 49 | TopK int `json:"top_k"` 50 | StopSequences []string `json:"stop_sequences"` 51 | Tools []tools.UserFunction `json:"tools,omitempty"` 52 | } 53 | 54 | // claudifyMessages converts from 'normal' openai chat format into a format which claud prefers 55 | func claudifyMessages(msgs []models.Message) []models.Message { 56 | cleanedMsgs := make([]models.Message, 0, len(msgs)) 57 | // Remove any additional fields from the messages 58 | for _, msg := range msgs { 59 | cleanedMsgs = append(cleanedMsgs, models.Message{ 60 | Role: msg.Role, 61 | Content: msg.Content, 62 | }) 63 | } 64 | msgs = cleanedMsgs 65 | 66 | // If the first message is a system one, assume it's the system prompt and pop it 67 | if msgs[0].Role == "system" { 68 | msgs = msgs[1:] 69 | } 70 | 71 | // Convert system messages from 'system' to 'assistant' 72 | for i, v := range msgs { 73 | if v.Role == "system" { 74 | msgs[i].Role = "assistant" 75 | } 76 | } 77 | 78 | for i, v := range msgs { 79 | if v.Role == "tool" { 80 | msgs[i].Role = "user" 81 | } 82 | } 83 | 84 | // Merge consecutive assistant messages into the first one 85 | for i := 1; i < len(msgs); i++ { 86 | if msgs[i].Role == "assistant" && msgs[i-1].Role == "assistant" { 87 | msgs[i-1].Content += "\n" + msgs[i].Content 88 | msgs = append(msgs[:i], msgs[i+1:]...) 89 | i-- 90 | } 91 | } 92 | 93 | // Merge consecutive user messages into the last one 94 | for i := len(msgs) - 2; i >= 0; i-- { 95 | if msgs[i].Role == "user" && msgs[i+1].Role == "user" { 96 | msgs[i+1].Content = msgs[i].Content + "\n" + msgs[i+1].Content 97 | msgs = append(msgs[:i], msgs[i+1:]...) 
98 | } 99 | } 100 | 101 | // If the first message is from an assistant, keep it as is 102 | // (no need to merge it into the upcoming user message) 103 | 104 | // If the last message is from an assistant, remove it 105 | if len(msgs) > 0 && msgs[len(msgs)-1].Role == "assistant" { 106 | msgs = msgs[:len(msgs)-1] 107 | } 108 | 109 | return msgs 110 | } 111 | -------------------------------------------------------------------------------- /internal/vendors/anthropic/claude_models.go: -------------------------------------------------------------------------------- 1 | package anthropic 2 | 3 | import ( 4 | "github.com/baalimago/clai/internal/tools" 5 | ) 6 | 7 | type ClaudeResponse struct { 8 | Content []ClaudeMessage `json:"content"` 9 | ID string `json:"id"` 10 | Model string `json:"model"` 11 | Role string `json:"role"` 12 | StopReason string `json:"stop_reason"` 13 | StopSequence any `json:"stop_sequence"` 14 | Type string `json:"type"` 15 | Usage TokenInfo `json:"usage"` 16 | } 17 | 18 | type ClaudeMessage struct { 19 | ID string `json:"id,omitempty"` 20 | Input tools.Input `json:"input,omitempty"` 21 | Name string `json:"name,omitempty"` 22 | Text string `json:"text,omitempty"` 23 | Type string `json:"type"` 24 | } 25 | 26 | type TokenInfo struct { 27 | InputTokens int `json:"input_tokens"` 28 | OutputTokens int `json:"output_tokens"` 29 | } 30 | 31 | type Delta struct { 32 | Type string `json:"type"` 33 | Text string `json:"text,omitempty"` 34 | PartialJson string `json:"partial_json,omitempty"` 35 | } 36 | 37 | type ContentBlockDelta struct { 38 | Type string `json:"type"` 39 | Index int `json:"index"` 40 | Delta Delta `json:"delta"` 41 | } 42 | 43 | type ContentBlockSuper struct { 44 | Type string `json:"type"` 45 | Index int `json:"index"` 46 | ContentBlock ContentBlock `json:"content_block"` 47 | } 48 | 49 | type ContentBlock struct { 50 | Type string `json:"type"` 51 | ID string `json:"id"` 52 | Name string `json:"name"` 53 | Input map[string]interface{} 
`json:"input"` 54 | } 55 | 56 | type Root struct { 57 | Type string `json:"type"` 58 | Index int `json:"index"` 59 | ContentBlock ContentBlock `json:"content_block"` 60 | } 61 | -------------------------------------------------------------------------------- /internal/vendors/anthropic/claude_setup.go: -------------------------------------------------------------------------------- 1 | package anthropic 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "os" 7 | 8 | "github.com/baalimago/clai/internal/tools" 9 | "github.com/baalimago/go_away_boilerplate/pkg/misc" 10 | ) 11 | 12 | func (c *Claude) Setup() error { 13 | apiKey := os.Getenv("ANTHROPIC_API_KEY") 14 | if apiKey == "" { 15 | return fmt.Errorf("environment variable 'ANTHROPIC_API_KEY' not set") 16 | } 17 | c.client = &http.Client{} 18 | c.apiKey = apiKey 19 | if misc.Truthy(os.Getenv("DEBUG")) || misc.Truthy(os.Getenv("ANTHROPIC_DEBUG")) { 20 | c.debug = true 21 | } 22 | return nil 23 | } 24 | 25 | func (c *Claude) RegisterTool(tool tools.AiTool) { 26 | c.tools = append(c.tools, tool.UserFunction()) 27 | } 28 | -------------------------------------------------------------------------------- /internal/vendors/anthropic/claude_setup_test.go: -------------------------------------------------------------------------------- 1 | package anthropic 2 | 3 | import "testing" 4 | 5 | func Test_Setup(t *testing.T) { 6 | c := Claude{} 7 | 8 | t.Run("it should load environment variable from ANTHROPIC_API_KEY", func(t *testing.T) { 9 | want := "some-key" 10 | t.Setenv("ANTHROPIC_API_KEY", want) 11 | err := c.Setup() 12 | if err != nil { 13 | t.Fatalf("failed to run setup: %v", err) 14 | } 15 | got := c.apiKey 16 | if got != want { 17 | t.Fatalf("expected: %v, got: %v", want, got) 18 | } 19 | }) 20 | } 21 | -------------------------------------------------------------------------------- /internal/vendors/anthropic/claude_stream.go: -------------------------------------------------------------------------------- 1 | package 
anthropic 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "context" 7 | "encoding/json" 8 | "errors" 9 | "fmt" 10 | "io" 11 | "net/http" 12 | "strings" 13 | 14 | "github.com/baalimago/clai/internal/models" 15 | "github.com/baalimago/clai/internal/tools" 16 | "github.com/baalimago/go_away_boilerplate/pkg/ancli" 17 | ) 18 | 19 | func (c *Claude) StreamCompletions(ctx context.Context, chat models.Chat) (chan models.CompletionEvent, error) { 20 | req, err := c.constructRequest(ctx, chat) 21 | if err != nil { 22 | return nil, fmt.Errorf("failed to construct request: %w", err) 23 | } 24 | 25 | return c.stream(ctx, req) 26 | } 27 | 28 | func (c *Claude) stream(ctx context.Context, req *http.Request) (chan models.CompletionEvent, error) { 29 | resp, err := c.client.Do(req) 30 | if err != nil { 31 | return nil, fmt.Errorf("failed to do request: %w", err) 32 | } 33 | 34 | if resp.StatusCode != http.StatusOK { 35 | body, _ := io.ReadAll(resp.Body) 36 | return nil, fmt.Errorf("failed to execute request: %v, body: %v", resp.Status, string(body)) 37 | } 38 | 39 | outChan, err := c.handleStreamResponse(ctx, resp) 40 | if err != nil { 41 | return nil, fmt.Errorf("failed to parse response: %w", err) 42 | } 43 | return outChan, nil 44 | } 45 | 46 | func (c *Claude) handleStreamResponse(ctx context.Context, resp *http.Response) (chan models.CompletionEvent, error) { 47 | outChan := make(chan models.CompletionEvent) 48 | go func() { 49 | br := bufio.NewReader(resp.Body) 50 | defer func() { 51 | resp.Body.Close() 52 | close(outChan) 53 | }() 54 | for { 55 | token, err := br.ReadString('\n') 56 | if err != nil { 57 | if errors.Is(err, io.EOF) { 58 | if token != "" { 59 | c.handleFullResponse(token, outChan) 60 | } else { 61 | outChan <- err 62 | } 63 | } 64 | outChan <- models.CompletionEvent(fmt.Errorf("failed to read line: %w", err)) 65 | return 66 | } 67 | token = strings.TrimSpace(token) 68 | if ctx.Err() != nil { 69 | outChan <- models.CompletionEvent(errors.New("context cancelled")) 
70 | return 71 | } 72 | if token == "" { 73 | continue 74 | } 75 | outChan <- c.handleToken(br, token) 76 | } 77 | }() 78 | return outChan, nil 79 | } 80 | 81 | func (c *Claude) handleFullResponse(token string, outChan chan models.CompletionEvent) { 82 | var rspBody ClaudeResponse 83 | err := json.Unmarshal([]byte(token), &rspBody) 84 | if err != nil { 85 | outChan <- models.CompletionEvent(fmt.Errorf("failed to unmarshal response: %w, resp body as string: %v", err, token)) 86 | return 87 | } 88 | for _, content := range rspBody.Content { 89 | switch content.Type { 90 | case "text": 91 | outChan <- content.Text 92 | case "tool_use": 93 | outChan <- tools.Call{ 94 | Name: content.Name, 95 | Inputs: content.Input, 96 | } 97 | } 98 | } 99 | } 100 | 101 | func (c *Claude) handleToken(br *bufio.Reader, token string) models.CompletionEvent { 102 | tokSplit := strings.Split(token, " ") 103 | if len(tokSplit) != 2 { 104 | return fmt.Errorf("unexpected token length for token: '%v', expected format: 'event: '", token) 105 | } 106 | eventTok := tokSplit[0] 107 | eventType := tokSplit[1] 108 | if eventTok != "event:" { 109 | return fmt.Errorf("unexpected token, want: 'event:', got: '%v'", eventTok) 110 | } 111 | eventType = strings.TrimSpace(eventType) 112 | if c.debug { 113 | fmt.Printf("eventTok: '%v', eventType: '%s'\n", eventTok, eventType) 114 | } 115 | switch eventType { 116 | case "message_stop": 117 | return io.EOF 118 | 119 | case "content_block_start": 120 | blockStart, err := br.ReadString('\n') 121 | if err != nil { 122 | return fmt.Errorf("failed to read content_block_delta: %w", err) 123 | } 124 | return c.handleContentBlockStart(blockStart) 125 | // TODO: Print token amount 126 | case "content_block_delta": 127 | deltaToken, err := br.ReadString('\n') 128 | if err != nil { 129 | return fmt.Errorf("failed to read content_block_delta: %w", err) 130 | } 131 | return c.handleContentBlockDelta(deltaToken) 132 | case "content_block_stop": 133 | blockStop, err := 
br.ReadString('\n') 134 | if err != nil { 135 | return fmt.Errorf("failed to read content_block_stop: %w", err) 136 | } 137 | return c.handleContentBlockStop(blockStop) 138 | } 139 | 140 | // Jump down one line to setup next event 141 | br.ReadString('\n') 142 | return models.NoopEvent{} 143 | } 144 | 145 | func trimDataPrefix(data string) string { 146 | return strings.TrimPrefix(data, "data: ") 147 | } 148 | 149 | func (c *Claude) stringFromDeltaToken(deltaToken string) (Delta, error) { 150 | deltaTokSplit := strings.Split(deltaToken, " ") 151 | if deltaTokSplit[0] != "data:" { 152 | return Delta{}, fmt.Errorf("unexpected split token. Expected: 'data:', got: '%v'", deltaTokSplit[0]) 153 | } 154 | deltaJsonString := strings.Join(deltaTokSplit[1:], " ") 155 | var contentBlockDelta ContentBlockDelta 156 | err := json.Unmarshal([]byte(deltaJsonString), &contentBlockDelta) 157 | if err != nil { 158 | return Delta{}, fmt.Errorf("failed to unmarshal deltaJsonString: '%v' to struct, err: %w", deltaJsonString, err) 159 | } 160 | if c.debug { 161 | ancli.PrintOK(fmt.Sprintf("delta struct: %+v\nstring: %v", contentBlockDelta, deltaJsonString)) 162 | } 163 | return contentBlockDelta.Delta, nil 164 | } 165 | 166 | func (c *Claude) constructRequest(ctx context.Context, chat models.Chat) (*http.Request, error) { 167 | // ignored for now as error is not used 168 | sysMsg, _ := chat.FirstSystemMessage() 169 | if c.debug { 170 | ancli.PrintOK(fmt.Sprintf("pre-claudified messages: %+v\n", chat.Messages)) 171 | } 172 | msgCopy := make([]models.Message, len(chat.Messages)) 173 | copy(msgCopy, chat.Messages) 174 | claudifiedMsgs := claudifyMessages(msgCopy) 175 | if c.debug { 176 | ancli.PrintOK(fmt.Sprintf("claudified messages: %+v\n", claudifiedMsgs)) 177 | } 178 | 179 | reqData := claudeReq{ 180 | Model: c.Model, 181 | Messages: claudifiedMsgs, 182 | MaxTokens: c.MaxTokens, 183 | Stream: true, 184 | System: sysMsg.Content, 185 | Temperature: c.Temperature, 186 | TopP: c.TopP, 187 | 
TopK: c.TopK, 188 | StopSequences: c.StopSequences, 189 | } 190 | if len(c.tools) > 0 { 191 | reqData.Tools = c.tools 192 | } 193 | jsonData, err := json.Marshal(reqData) 194 | if err != nil { 195 | return nil, fmt.Errorf("failed to marshal ClaudeReq: %w", err) 196 | } 197 | req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.Url, bytes.NewBuffer(jsonData)) 198 | if err != nil { 199 | return nil, fmt.Errorf("failed to create request: %w", err) 200 | } 201 | req.Header.Set("Content-Type", "application/json") 202 | req.Header.Set("x-api-key", c.apiKey) 203 | req.Header.Set("anthropic-version", c.AnthropicVersion) 204 | req.Header.Set("anthropic-beta", c.AnthropicBeta) 205 | if c.debug { 206 | ancli.PrintOK(fmt.Sprintf("Request: %+v\n", req)) 207 | } 208 | return req, nil 209 | } 210 | -------------------------------------------------------------------------------- /internal/vendors/anthropic/claude_stream_block_events.go: -------------------------------------------------------------------------------- 1 | package anthropic 2 | 3 | import ( 4 | "encoding/json" 5 | "errors" 6 | "fmt" 7 | 8 | "github.com/baalimago/clai/internal/models" 9 | "github.com/baalimago/clai/internal/tools" 10 | ) 11 | 12 | func (c *Claude) handleContentBlockStart(blockStart string) models.CompletionEvent { 13 | var blockSuper ContentBlockSuper 14 | blockStart = trimDataPrefix(blockStart) 15 | if err := json.Unmarshal([]byte(blockStart), &blockSuper); err != nil { 16 | return fmt.Errorf("failed to unmarshal blockStart with content: %v, error: %w", blockStart, err) 17 | } 18 | block := blockSuper.ContentBlock 19 | c.contentBlockType = block.Type 20 | switch block.Type { 21 | case "tool_use": 22 | c.functionName = block.Name 23 | } 24 | return models.NoopEvent{} 25 | } 26 | 27 | func (c *Claude) handleContentBlockDelta(deltaToken string) models.CompletionEvent { 28 | delta, err := c.stringFromDeltaToken(deltaToken) 29 | if err != nil { 30 | return fmt.Errorf("failed to convert string to 
delta token: %w", err) 31 | } 32 | if c.debug { 33 | fmt.Printf("deltaToken: '%v', claudeMsg: '%v'", deltaToken, delta) 34 | } 35 | switch delta.Type { 36 | case "text_delta": 37 | if delta.Text == "" { 38 | return errors.New("unexpected empty response") 39 | } 40 | return delta.Text 41 | case "input_json_delta": 42 | return c.handleInputJsonDelta(delta) 43 | default: 44 | return fmt.Errorf("unexpected delta type: %v", delta.Type) 45 | } 46 | } 47 | 48 | func (c *Claude) handleInputJsonDelta(delta Delta) models.CompletionEvent { 49 | partial := delta.PartialJson 50 | c.functionJson += partial 51 | return partial 52 | } 53 | 54 | func (c *Claude) handleContentBlockStop(blockStop string) models.CompletionEvent { 55 | var block ContentBlock 56 | blockStop = trimDataPrefix(blockStop) 57 | if err := json.Unmarshal([]byte(blockStop), &block); err != nil { 58 | return fmt.Errorf("failed to unmarshal blockStop: %w", err) 59 | } 60 | 61 | switch c.contentBlockType { 62 | case "tool_use": 63 | var inputs tools.Input 64 | if err := json.Unmarshal([]byte(c.functionJson), &inputs); err != nil { 65 | return fmt.Errorf("failed to unmarshal functionJson: %v, error is: %w", c.functionJson, err) 66 | } 67 | c.functionJson = "" 68 | return tools.Call{ 69 | Name: c.functionName, 70 | Inputs: inputs, 71 | } 72 | } 73 | return models.NoopEvent{} 74 | } 75 | -------------------------------------------------------------------------------- /internal/vendors/anthropic/claude_stream_test.go: -------------------------------------------------------------------------------- 1 | package anthropic 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "io" 7 | "net/http" 8 | "net/http/httptest" 9 | "testing" 10 | "time" 11 | 12 | "github.com/baalimago/clai/internal/models" 13 | "github.com/baalimago/go_away_boilerplate/pkg/testboil" 14 | ) 15 | 16 | func Test_StreamCompletions(t *testing.T) { 17 | want := "Hello!" 
18 | messages := [][]byte{ 19 | []byte(`event: message_start 20 | data: {"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-opus-20240229", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 25, "output_tokens": 1}}} 21 | 22 | `), 23 | []byte(`event: content_block_start 24 | data: {"type": "content_block_start", "index": 0, "content_block": {"type": "text", "text": ""}} 25 | 26 | `), 27 | 28 | []byte(`event: ping 29 | data: {"type": "ping"} 30 | 31 | `), 32 | // This should be picked up 33 | []byte(`event: content_block_delta 34 | data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "Hello"}} 35 | 36 | `), 37 | // This should also be picked up 38 | []byte(`event: content_block_delta 39 | data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "!"}} 40 | 41 | `), 42 | []byte(`event: content_block_stop 43 | data: {"type": "content_block_stop", "index": 0} 44 | 45 | `), 46 | []byte(`event: message_delta 47 | data: {"type": "message_delta", "delta": {"stop_reason": "end_turn", "stop_sequence":null, "usage":{"output_tokens": 15}}} 48 | 49 | `), 50 | []byte(`event: message_stop 51 | data: {"type": "message_stop"} 52 | 53 | `), 54 | } 55 | testDone := make(chan string) 56 | testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 57 | w.Header().Set("Access-Control-Allow-Origin", "*") 58 | w.Header().Set("Access-Control-Expose-Headers", "Content-Type") 59 | w.Header().Set("Content-Type", "text/event-stream") 60 | w.Header().Set("Cache-Control", "no-cache") 61 | w.Header().Set("Connection", "keep-alive") 62 | for _, msg := range messages { 63 | w.Write(msg) 64 | w.(http.Flusher).Flush() 65 | } 66 | <-testDone 67 | })) 68 | context, contextCancel := context.WithTimeout(context.Background(), time.Second/100) 69 | t.Cleanup(func() { 70 | 
contextCancel() 71 | // Can't seem to figure out how to close the testserver. so well... it'll have to remain open 72 | // testServer.Close() 73 | close(testDone) 74 | }) 75 | 76 | // Use the test server's URL as the backend URL in your code 77 | c := Claude{ 78 | Url: testServer.URL, 79 | } 80 | t.Setenv("ANTHROPIC_API_KEY", "somekey") 81 | err := c.Setup() 82 | if err != nil { 83 | t.Fatalf("failed to setup claude: %v", err) 84 | } 85 | out, err := c.StreamCompletions(context, models.Chat{ 86 | ID: "test", 87 | Messages: []models.Message{ 88 | {Role: "system", Content: "test"}, 89 | {Role: "user", Content: "test"}, 90 | }, 91 | }) 92 | if err != nil { 93 | t.Fatalf("failed to stream completions: %v", err) 94 | } 95 | 96 | got := "" 97 | OUTER: 98 | for { 99 | select { 100 | case <-context.Done(): 101 | t.Fatal("test timeout") 102 | case tok, ok := <-out: 103 | if !ok { 104 | break OUTER 105 | } 106 | switch sel := tok.(type) { 107 | case string: 108 | got += sel 109 | case error: 110 | if errors.Is(sel, io.EOF) { 111 | break OUTER 112 | } 113 | t.Fatalf("unexpected error: %v", sel) 114 | } 115 | } 116 | } 117 | 118 | if got != want { 119 | t.Fatalf("expected: %v, got: %v", want, got) 120 | } 121 | } 122 | 123 | func Test_context(t *testing.T) { 124 | testDone := make(chan struct{}) 125 | testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 126 | <-testDone 127 | })) 128 | t.Cleanup(func() { 129 | testServer.Close() 130 | close(testDone) 131 | }) 132 | 133 | // Use the test server's URL as the backend URL in your code 134 | c := Claude{ 135 | Url: testServer.URL, 136 | } 137 | t.Setenv("ANTHROPIC_API_KEY", "somekey") 138 | err := c.Setup() 139 | if err != nil { 140 | t.Fatal(err) 141 | } 142 | testboil.ReturnsOnContextCancel(t, func(ctx context.Context) { 143 | c.StreamCompletions(ctx, models.Chat{ 144 | ID: "test", 145 | Messages: []models.Message{ 146 | {Role: "system", Content: "test"}, 147 | {Role: "user", Content: 
"test"}, 148 | }, 149 | }) 150 | }, time.Second) 151 | } 152 | -------------------------------------------------------------------------------- /internal/vendors/anthropic/claude_test.go: -------------------------------------------------------------------------------- 1 | package anthropic 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/baalimago/clai/internal/models" 7 | ) 8 | 9 | func Test_claudifyMessage(t *testing.T) { 10 | testCases := []struct { 11 | desc string 12 | given []models.Message 13 | want []models.Message 14 | }{ 15 | { 16 | desc: "it should remove the first message if it is a system message", 17 | given: []models.Message{ 18 | {Role: "system", Content: "system message"}, 19 | {Role: "user", Content: "user message"}, 20 | }, 21 | want: []models.Message{ 22 | {Role: "user", Content: "user message"}, 23 | }, 24 | }, 25 | { 26 | desc: "it should convert system messages to assistant messages", 27 | given: []models.Message{ 28 | {Role: "user", Content: "user message"}, 29 | {Role: "system", Content: "system message"}, 30 | {Role: "user", Content: "user message"}, 31 | }, 32 | want: []models.Message{ 33 | {Role: "user", Content: "user message"}, 34 | {Role: "assistant", Content: "system message"}, 35 | {Role: "user", Content: "user message"}, 36 | }, 37 | }, 38 | { 39 | desc: "it should merge user messages into the upcoming message", 40 | given: []models.Message{ 41 | {Role: "user", Content: "user message 1"}, 42 | {Role: "assistant", Content: "assistant message 1"}, 43 | {Role: "user", Content: "user message 2"}, 44 | {Role: "user", Content: "user message 3"}, 45 | }, 46 | want: []models.Message{ 47 | {Role: "user", Content: "user message 1"}, 48 | {Role: "assistant", Content: "assistant message 1"}, 49 | {Role: "user", Content: "user message 2\nuser message 3"}, 50 | }, 51 | }, 52 | { 53 | desc: "glob message should start with assistant message", 54 | given: []models.Message{ 55 | {Role: "system", Content: "system message 1"}, 56 | {Role: "system", 
Content: "system message 2"}, 57 | {Role: "user", Content: "user message 1"}, 58 | }, 59 | want: []models.Message{ 60 | {Role: "assistant", Content: "system message 2"}, 61 | {Role: "user", Content: "user message 1"}, 62 | }, 63 | }, 64 | { 65 | desc: "tricky example 1", 66 | given: []models.Message{ 67 | {Role: "user", Content: "user message 1"}, 68 | {Role: "user", Content: "user message 2"}, 69 | {Role: "assistant", Content: "assistant message 1"}, 70 | {Role: "user", Content: "user message 3"}, 71 | {Role: "user", Content: "user message 4"}, 72 | {Role: "user", Content: "user message 5"}, 73 | }, 74 | want: []models.Message{ 75 | {Role: "user", Content: "user message 1\nuser message 2"}, 76 | {Role: "assistant", Content: "assistant message 1"}, 77 | {Role: "user", Content: "user message 3\nuser message 4\nuser message 5"}, 78 | }, 79 | }, 80 | } 81 | 82 | for _, tC := range testCases { 83 | t.Run(tC.desc, func(t *testing.T) { 84 | got := claudifyMessages(tC.given) 85 | if len(tC.want) != len(got) { 86 | t.Fatalf("incorrect length. 
expected: %v, got: %v", tC.want, got) 87 | } 88 | 89 | for i := range tC.want { 90 | if tC.want[i].Role != got[i].Role { 91 | t.Fatalf("expected: %q, got: %q", tC.want[i].Role, got[i].Role) 92 | } 93 | if tC.want[i].Content != got[i].Content { 94 | t.Fatalf("expected: %q, got: %q", tC.want[i].Content, got[i].Content) 95 | } 96 | } 97 | }) 98 | } 99 | } 100 | -------------------------------------------------------------------------------- /internal/vendors/anthropic/constants.go: -------------------------------------------------------------------------------- 1 | package anthropic 2 | 3 | const ClaudeURL = "https://api.anthropic.com/v1/messages" 4 | -------------------------------------------------------------------------------- /internal/vendors/deepseek/deepseek.go: -------------------------------------------------------------------------------- 1 | package deepseek 2 | 3 | import ( 4 | "github.com/baalimago/clai/internal/text/generic" 5 | ) 6 | 7 | var DEEPSEEK_DEFAULT = Deepseek{ 8 | Model: "deepseek-chat", 9 | Temperature: 1.0, 10 | TopP: 1.0, 11 | Url: ChatURL, 12 | } 13 | 14 | type Deepseek struct { 15 | generic.StreamCompleter 16 | Model string `json:"model"` 17 | FrequencyPenalty float64 `json:"frequency_penalty"` 18 | MaxTokens *int `json:"max_tokens"` // Use a pointer to allow null value 19 | PresencePenalty float64 `json:"presence_penalty"` 20 | Temperature float64 `json:"temperature"` 21 | TopP float64 `json:"top_p"` 22 | Url string `json:"url"` 23 | } 24 | -------------------------------------------------------------------------------- /internal/vendors/deepseek/deepseek_setup.go: -------------------------------------------------------------------------------- 1 | package deepseek 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | 7 | "github.com/baalimago/clai/internal/tools" 8 | ) 9 | 10 | const ChatURL = "https://api.deepseek.com/chat/completions" 11 | 12 | func (g *Deepseek) Setup() error { 13 | if os.Getenv("DEEPSEEK_API_KEY") == "" { 14 | 
os.Setenv("DEEPSEEK_API_KEY", "deepseek") 15 | } 16 | err := g.StreamCompleter.Setup("DEEPSEEK_API_KEY", ChatURL, "DEEPSEEK_DEBUG") 17 | if err != nil { 18 | return fmt.Errorf("failed to setup stream completer: %w", err) 19 | } 20 | g.StreamCompleter.Model = g.Model 21 | g.StreamCompleter.FrequencyPenalty = &g.FrequencyPenalty 22 | g.StreamCompleter.MaxTokens = g.MaxTokens 23 | g.StreamCompleter.Temperature = &g.Temperature 24 | g.StreamCompleter.TopP = &g.TopP 25 | toolChoice := "auto" 26 | g.ToolChoice = &toolChoice 27 | return nil 28 | } 29 | 30 | func (g *Deepseek) RegisterTool(tool tools.AiTool) { 31 | g.InternalRegisterTool(tool) 32 | } 33 | -------------------------------------------------------------------------------- /internal/vendors/deepseek/models.go: -------------------------------------------------------------------------------- 1 | package deepseek 2 | 3 | import "github.com/baalimago/clai/internal/tools" 4 | 5 | // Copy from novita, and openai 6 | type ChatCompletion struct { 7 | ID string `json:"id"` 8 | Object string `json:"object"` 9 | Created int64 `json:"created"` 10 | Model string `json:"model"` 11 | Choices []Choice `json:"choices"` 12 | Usage Usage `json:"usage"` 13 | SystemFingerprint string `json:"system_fingerprint"` 14 | } 15 | 16 | type Choice struct { 17 | Index int `json:"index"` 18 | Delta Delta `json:"delta"` 19 | Logprobs interface{} `json:"logprobs"` // null or complex object, hence interface{} 20 | FinishReason string `json:"finish_reason"` 21 | } 22 | 23 | type Usage struct { 24 | PromptTokens int `json:"prompt_tokens"` 25 | CompletionTokens int `json:"completion_tokens"` 26 | TotalTokens int `json:"total_tokens"` 27 | } 28 | 29 | type Delta struct { 30 | Content any `json:"content"` 31 | Role string `json:"role"` 32 | ToolCalls []ToolsCall `json:"tool_calls"` 33 | } 34 | 35 | type ToolsCall struct { 36 | Function GptFunc `json:"function"` 37 | ID string `json:"id"` 38 | Index int `json:"index"` 39 | Type string `json:"type"` 
40 | } 41 | 42 | type GptFunc struct { 43 | Arguments string `json:"arguments"` 44 | Name string `json:"name"` 45 | } 46 | 47 | type GptTool struct { 48 | Name string `json:"name"` 49 | Description string `json:"description"` 50 | Inputs tools.InputSchema `json:"parameters"` 51 | } 52 | 53 | type GptToolSuper struct { 54 | Type string `json:"type"` 55 | Function GptTool `json:"function"` 56 | } 57 | -------------------------------------------------------------------------------- /internal/vendors/mistral/mistral.go: -------------------------------------------------------------------------------- 1 | package mistral 2 | 3 | import ( 4 | "github.com/baalimago/clai/internal/models" 5 | "github.com/baalimago/clai/internal/text/generic" 6 | ) 7 | 8 | const MistralURL = "https://api.mistral.ai/v1/chat/completions" 9 | 10 | var MINSTRAL_DEFAULT = Mistral{ 11 | Model: "mistral-large-latest", 12 | Temperature: 0.7, 13 | TopP: 1.0, 14 | Url: MistralURL, 15 | MaxTokens: 100000, 16 | } 17 | 18 | type Mistral struct { 19 | generic.StreamCompleter 20 | Model string `json:"model"` 21 | Url string `json:"url"` 22 | TopP float64 `json:"top_p"` 23 | Temperature float64 `json:"temperature"` 24 | SafePrompt bool `json:"safe_prompt"` 25 | MaxTokens int `json:"max_tokens"` 26 | RandomSeed int `json:"random_seed"` 27 | } 28 | 29 | func clean(msg []models.Message) []models.Message { 30 | // Mistral doesn't like additional fields in the tools call 31 | for i, m := range msg { 32 | if m.Role == "assistant" { 33 | if len(m.ToolCalls) > 0 { 34 | m.Content = "" 35 | } 36 | for j, tc := range m.ToolCalls { 37 | tc.Name = "" 38 | tc.Inputs = nil 39 | tc.Function.Description = "" 40 | m.ToolCalls[j] = tc 41 | } 42 | } 43 | msg[i] = m 44 | } 45 | 46 | for i := 0; i < len(msg)-1; i++ { 47 | if msg[i].Role == "tool" && msg[i+1].Role == "system" { 48 | msg[i+1].Role = "assistant" 49 | } 50 | } 51 | 52 | // Merge consequtive assistant messages 53 | for i := 1; i < len(msg); i++ { 54 | if msg[i].Role 
== "assistant" && msg[i-1].Role == "assistant" { 55 | msg[i-1].Content += "\n" + msg[i].Content 56 | msg = append(msg[:i], msg[i+1:]...) 57 | i-- 58 | } 59 | } 60 | 61 | return msg 62 | } 63 | -------------------------------------------------------------------------------- /internal/vendors/mistral/mistral_setup.go: -------------------------------------------------------------------------------- 1 | package mistral 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 7 | "github.com/baalimago/clai/internal/models" 8 | "github.com/baalimago/clai/internal/tools" 9 | ) 10 | 11 | func (m *Mistral) Setup() error { 12 | err := m.StreamCompleter.Setup("MISTRAL_API_KEY", MistralURL, "DEBUG_MISTRAL") 13 | if err != nil { 14 | return fmt.Errorf("failed to setup stream completer: %w", err) 15 | } 16 | m.StreamCompleter.Model = m.Model 17 | m.StreamCompleter.FrequencyPenalty = m.FrequencyPenalty 18 | m.StreamCompleter.MaxTokens = &m.MaxTokens 19 | m.StreamCompleter.Temperature = &m.Temperature 20 | m.StreamCompleter.TopP = &m.TopP 21 | toolChoice := "auto" 22 | m.StreamCompleter.ToolChoice = &toolChoice 23 | m.StreamCompleter.Clean = clean 24 | 25 | return nil 26 | } 27 | 28 | func (m *Mistral) StreamCompletions(ctx context.Context, chat models.Chat) (chan models.CompletionEvent, error) { 29 | return m.StreamCompleter.StreamCompletions(ctx, chat) 30 | } 31 | 32 | func (m *Mistral) RegisterTool(tool tools.AiTool) { 33 | m.StreamCompleter.InternalRegisterTool(tool) 34 | } 35 | -------------------------------------------------------------------------------- /internal/vendors/mistral/models.go: -------------------------------------------------------------------------------- 1 | package mistral 2 | 3 | import ( 4 | "github.com/baalimago/clai/internal/models" 5 | "github.com/baalimago/clai/internal/tools" 6 | ) 7 | 8 | type MistralToolSuper struct { 9 | Function MistralTool `json:"function"` 10 | Type string `json:"type"` 11 | } 12 | 13 | type MistralTool struct { 14 | Description string 
// Request describes the JSON payload of a chat request.
//
// Every field is tagged 'omitempty', so fields holding their zero value
// (for example a Temperature of 0 or a false Stream) are left out of the
// encoded JSON entirely.
type Request struct {
	MaxTokens   int                `json:"max_tokens,omitempty"`
	Messages    []models.Message   `json:"messages,omitempty"`
	Model       string             `json:"model,omitempty"`
	RandomSeed  int                `json:"random_seed,omitempty"`
	SafePrompt  bool               `json:"safe_prompt,omitempty"`
	Stream      bool               `json:"stream,omitempty"`
	Temperature float64            `json:"temperature,omitempty"`
	ToolChoice  string             `json:"tool_choice,omitempty"`
	Tools       []MistralToolSuper `json:"tools,omitempty"`
	TopP        float64            `json:"top_p,omitempty"`
}
// Choice is a single entry of ChatCompletion.Choices.
type Choice struct {
	Index        int         `json:"index"`
	Delta        Delta       `json:"delta"`
	Logprobs     interface{} `json:"logprobs"` // null or complex object, hence interface{}
	FinishReason string      `json:"finish_reason"`
}

// Usage holds the token counts reported in a ChatCompletion.
type Usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

// Delta is the payload of a Choice. Content is typed `any`, so the decoder
// accepts whatever JSON value the server sends for "content".
type Delta struct {
	Content   any         `json:"content"`
	Role      string      `json:"role"`
	ToolCalls []ToolsCall `json:"tool_calls"`
}

// ToolsCall is one tool invocation listed in Delta.ToolCalls.
type ToolsCall struct {
	Function GptFunc `json:"function"`
	ID       string  `json:"id"`
	Index    int     `json:"index"`
	Type     string  `json:"type"`
}

// GptFunc names the function of a ToolsCall and carries its arguments as a
// string.
// NOTE(review): Arguments is presumably a JSON-encoded object - verify.
type GptFunc struct {
	Arguments string `json:"arguments"`
	Name      string `json:"name"`
}

// GptTool describes a tool; Inputs is encoded under the "parameters" key.
type GptTool struct {
	Name        string            `json:"name"`
	Description string            `json:"description"`
	Inputs      tools.InputSchema `json:"parameters"`
}

// GptToolSuper wraps a GptTool together with a type discriminator string.
type GptToolSuper struct {
	Type     string  `json:"type"`
	Function GptTool `json:"function"`
}

// NOVITA_DEFAULT holds the default Novita settings: the mythomax model,
// temperature and top-p of 1.0 and the ChatURL endpoint. All other fields
// keep their zero value.
var NOVITA_DEFAULT = Novita{
	Model:       "gryphe/mythomax-l2-13b",
	Temperature: 1.0,
	TopP:        1.0,
	Url:         ChatURL,
}
null value 19 | PresencePenalty float64 `json:"presence_penalty"` 20 | Temperature float64 `json:"temperature"` 21 | TopP float64 `json:"top_p"` 22 | Url string `json:"url"` 23 | } 24 | -------------------------------------------------------------------------------- /internal/vendors/novita/novita_setup.go: -------------------------------------------------------------------------------- 1 | package novita 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | 7 | "github.com/baalimago/clai/internal/tools" 8 | ) 9 | 10 | const ChatURL = "https://api.novita.ai/v3/openai/chat/completions" 11 | 12 | func (g *Novita) Setup() error { 13 | if os.Getenv("NOVITA_API_KEY") == "" { 14 | os.Setenv("NOVITA_API_KEY", "novita") 15 | } 16 | err := g.StreamCompleter.Setup("NOVITA_API_KEY", ChatURL, "NOVITA_DEBUG") 17 | if err != nil { 18 | return fmt.Errorf("failed to setup stream completer: %w", err) 19 | } 20 | g.StreamCompleter.Model = g.Model 21 | g.StreamCompleter.FrequencyPenalty = &g.FrequencyPenalty 22 | g.StreamCompleter.MaxTokens = g.MaxTokens 23 | g.StreamCompleter.Temperature = &g.Temperature 24 | g.StreamCompleter.TopP = &g.TopP 25 | toolChoice := "auto" 26 | g.StreamCompleter.ToolChoice = &toolChoice 27 | return nil 28 | } 29 | 30 | func (g *Novita) RegisterTool(tool tools.AiTool) { 31 | g.StreamCompleter.InternalRegisterTool(tool) 32 | } 33 | -------------------------------------------------------------------------------- /internal/vendors/ollama/models.go: -------------------------------------------------------------------------------- 1 | package ollama 2 | 3 | import "github.com/baalimago/clai/internal/tools" 4 | 5 | // since we can use ollama in OpenAI compatible mode, we use the same types as `openai` package 6 | type ChatCompletion struct { 7 | ID string `json:"id"` 8 | Object string `json:"object"` 9 | Created int64 `json:"created"` 10 | Model string `json:"model"` 11 | Choices []Choice `json:"choices"` 12 | Usage Usage `json:"usage"` 13 | SystemFingerprint string 
// Choice is a single entry of ChatCompletion.Choices.
type Choice struct {
	Index        int         `json:"index"`
	Delta        Delta       `json:"delta"`
	Logprobs     interface{} `json:"logprobs"` // null or complex object, hence interface{}
	FinishReason string      `json:"finish_reason"`
}

// Usage holds the token counts reported in a ChatCompletion.
type Usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

// Delta is the payload of a Choice. Content is typed `any`, so the decoder
// accepts whatever JSON value the server sends for "content".
type Delta struct {
	Content   any         `json:"content"`
	Role      string      `json:"role"`
	ToolCalls []ToolsCall `json:"tool_calls"`
}

// ToolsCall is one tool invocation listed in Delta.ToolCalls.
type ToolsCall struct {
	Function GptFunc `json:"function"`
	ID       string  `json:"id"`
	Index    int     `json:"index"`
	Type     string  `json:"type"`
}

// GptFunc names the function of a ToolsCall and carries its arguments as a
// string.
// NOTE(review): Arguments is presumably a JSON-encoded object - verify.
type GptFunc struct {
	Arguments string `json:"arguments"`
	Name      string `json:"name"`
}

// GptTool describes a tool; Inputs is encoded under the "parameters" key.
type GptTool struct {
	Name        string            `json:"name"`
	Description string            `json:"description"`
	Inputs      tools.InputSchema `json:"parameters"`
}

// GptToolSuper wraps a GptTool together with a type discriminator string.
type GptToolSuper struct {
	Type     string  `json:"type"`
	Function GptTool `json:"function"`
}

// OLLAMA_DEFAULT holds the default Ollama settings: the llama3 model,
// temperature and top-p of 1.0 and the ChatURL endpoint. All other fields
// keep their zero value.
var OLLAMA_DEFAULT = Ollama{
	Model:       "llama3",
	Temperature: 1.0,
	TopP:        1.0,
	Url:         ChatURL,
}

// Ollama is the JSON-serialisable configuration for the Ollama vendor. It
// embeds generic.StreamCompleter, so the completer's methods are promoted
// onto Ollama.
type Ollama struct {
	generic.StreamCompleter
	Model            string  `json:"model"`
	FrequencyPenalty float64 `json:"frequency_penalty"`
	MaxTokens        *int    `json:"max_tokens"` // Use a pointer to allow null value
	PresencePenalty  float64 `json:"presence_penalty"`
	Temperature      float64 `json:"temperature"`
	TopP             float64 `json:"top_p"`
	// NOTE(review): Setup in ollama_setup.go passes the ChatURL constant to
	// the completer rather than this field - confirm whether Url is honoured
	// anywhere.
	Url string `json:"url"`
}
-------------------------------------------------------------------------------- /internal/vendors/ollama/ollama_setup.go: -------------------------------------------------------------------------------- 1 | package ollama 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | 7 | "github.com/baalimago/clai/internal/tools" 8 | ) 9 | 10 | const ChatURL = "http://localhost:11434/v1/chat/completions" 11 | 12 | func (g *Ollama) Setup() error { 13 | if os.Getenv("OLLAMA_API_KEY") == "" { 14 | os.Setenv("OLLAMA_API_KEY", "ollama") 15 | } 16 | err := g.StreamCompleter.Setup("OLLAMA_API_KEY", ChatURL, "OLLAMA_DEBUG") 17 | if err != nil { 18 | return fmt.Errorf("failed to setup stream completer: %w", err) 19 | } 20 | g.StreamCompleter.Model = g.Model 21 | g.StreamCompleter.FrequencyPenalty = &g.FrequencyPenalty 22 | g.StreamCompleter.MaxTokens = g.MaxTokens 23 | g.StreamCompleter.Temperature = &g.Temperature 24 | g.StreamCompleter.TopP = &g.TopP 25 | toolChoice := "auto" 26 | g.StreamCompleter.ToolChoice = &toolChoice 27 | return nil 28 | } 29 | 30 | func (g *Ollama) RegisterTool(tool tools.AiTool) { 31 | g.StreamCompleter.InternalRegisterTool(tool) 32 | } 33 | -------------------------------------------------------------------------------- /internal/vendors/openai/constants.go: -------------------------------------------------------------------------------- 1 | package openai 2 | 3 | const ( 4 | ChatURL = "https://api.openai.com/v1/chat/completions" 5 | PhotoURL = "https://api.openai.com/v1/images/generations" 6 | ) 7 | -------------------------------------------------------------------------------- /internal/vendors/openai/dalle.go: -------------------------------------------------------------------------------- 1 | package openai 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "encoding/base64" 7 | "encoding/json" 8 | "fmt" 9 | "io" 10 | "net/http" 11 | "os" 12 | 13 | "github.com/baalimago/clai/internal/models" 14 | "github.com/baalimago/clai/internal/photo" 15 | 
"github.com/baalimago/clai/internal/utils" 16 | "github.com/baalimago/go_away_boilerplate/pkg/ancli" 17 | "github.com/baalimago/go_away_boilerplate/pkg/misc" 18 | ) 19 | 20 | type DallE struct { 21 | Model string `json:"model"` 22 | N int `json:"n"` 23 | Size string `json:"size"` 24 | Quality string `json:"quality"` 25 | Style string `json:"style"` 26 | Output photo.Output `json:"output"` 27 | // Don't save this as this is set via the Output struct 28 | ResponseFormat string `json:"-"` 29 | Prompt string `json:"-"` 30 | client *http.Client `json:"-"` 31 | debug bool `json:"-"` 32 | raw bool `json:"-"` 33 | apiKey string `json:"-"` 34 | } 35 | 36 | type DallERequest struct { 37 | Model string `json:"model"` 38 | N int `json:"n"` 39 | Size string `json:"size"` 40 | Quality string `json:"quality"` 41 | Style string `json:"style"` 42 | ResponseFormat string `json:"response_format"` 43 | Prompt string `json:"prompt"` 44 | } 45 | 46 | type ImageResponse struct { 47 | RevisedPrompt string `json:"revised_prompt"` 48 | URL string `json:"url"` 49 | B64_JSON string `json:"b64_json"` 50 | } 51 | 52 | type ImageResponses struct { 53 | Created int `json:"created"` 54 | Data []ImageResponse `json:"data"` 55 | } 56 | 57 | var defaultDalle = DallE{ 58 | Model: "dall-e-3", 59 | Size: "1024x1024", 60 | N: 1, 61 | Style: "vivid", 62 | Quality: "hd", 63 | } 64 | 65 | func NewPhotoQuerier(pConf photo.Configurations) (models.Querier, error) { 66 | home, _ := os.UserConfigDir() 67 | apiKey := os.Getenv("OPENAI_API_KEY") 68 | if apiKey == "" { 69 | return nil, fmt.Errorf("environment variable 'OPENAI_API_KEY' not set") 70 | } 71 | model := pConf.Model 72 | defaultCpy := defaultDalle 73 | defaultCpy.Model = model 74 | defaultCpy.Output = pConf.Output 75 | // Load config based on model, allowing for different configs for each model 76 | dalleQuerier, err := utils.LoadConfigFromFile(home, fmt.Sprintf("openai_dalle_%v.json", model), nil, &defaultCpy) 77 | if dalleQuerier.Output.Type == 
photo.URL { 78 | dalleQuerier.ResponseFormat = "url" 79 | } else if dalleQuerier.Output.Type == photo.LOCAL { 80 | dalleQuerier.ResponseFormat = "b64_json" 81 | } 82 | 83 | if misc.Truthy(os.Getenv("DEBUG")) { 84 | dalleQuerier.debug = true 85 | } 86 | if err != nil { 87 | ancli.PrintWarn(fmt.Sprintf("failed to load config for model: %v, error: %v\n", model, err)) 88 | } 89 | dalleQuerier.client = &http.Client{} 90 | dalleQuerier.apiKey = apiKey 91 | dalleQuerier.Prompt = pConf.Prompt 92 | if err != nil { 93 | return nil, fmt.Errorf("failed to load config: %w", err) 94 | } 95 | return &dalleQuerier, nil 96 | } 97 | 98 | func (q *DallE) createRequest(ctx context.Context) (*http.Request, error) { 99 | if q.debug { 100 | ancli.PrintOK(fmt.Sprintf("DallE request: %+v\n", q)) 101 | } 102 | reqVersion := DallERequest{ 103 | Model: q.Model, 104 | N: q.N, 105 | Size: q.Size, 106 | Quality: q.Quality, 107 | Style: q.Style, 108 | ResponseFormat: q.ResponseFormat, 109 | Prompt: q.Prompt, 110 | } 111 | bodyBytes, err := json.Marshal(reqVersion) 112 | if err != nil { 113 | return nil, fmt.Errorf("failed to encode JSON: %w", err) 114 | } 115 | 116 | req, err := http.NewRequestWithContext(ctx, "POST", PhotoURL, bytes.NewBuffer(bodyBytes)) 117 | if err != nil { 118 | return nil, fmt.Errorf("failed to create request: %w", err) 119 | } 120 | 121 | req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", q.apiKey)) 122 | req.Header.Set("Content-Type", "application/json") 123 | 124 | ancli.PrintOK(fmt.Sprintf("pre-revision prompt: '%v'\n", q.Prompt)) 125 | return req, nil 126 | } 127 | 128 | func (q *DallE) do(req *http.Request) error { 129 | if !q.raw { 130 | stop := photo.StartAnimation() 131 | defer stop() 132 | } 133 | resp, err := q.client.Do(req) 134 | if err != nil { 135 | return fmt.Errorf("failed tosending request: %w", err) 136 | } 137 | defer resp.Body.Close() 138 | 139 | b, err := io.ReadAll(resp.Body) 140 | if err != nil { 141 | return fmt.Errorf("failed to read response 
body: %w", err) 142 | } 143 | if resp.StatusCode != 200 { 144 | return fmt.Errorf("non-OK status: %v, body: %v", resp.Status, string(b)) 145 | } 146 | var imgResps ImageResponses 147 | err = json.Unmarshal(b, &imgResps) 148 | if err != nil { 149 | return fmt.Errorf("failed to decode JSON: %w", err) 150 | } 151 | 152 | if q.Output.Type == photo.LOCAL { 153 | localPath, err := q.saveImage(imgResps.Data[0]) 154 | if err != nil { 155 | return fmt.Errorf("failed to save image: %w", err) 156 | } 157 | // Defer to let animator finish first 158 | defer func() { 159 | ancli.PrintOK(fmt.Sprintf("image saved to: '%v'\n", localPath)) 160 | }() 161 | } else { 162 | defer func() { 163 | ancli.PrintOK(fmt.Sprintf("image URL: '%v'", imgResps.Data[0].URL)) 164 | }() 165 | } 166 | defer func() { 167 | fmt.Println() 168 | ancli.PrintOK(fmt.Sprintf("revised prompt: '%v'\n", imgResps.Data[0].RevisedPrompt)) 169 | }() 170 | 171 | return nil 172 | } 173 | 174 | func (q *DallE) saveImage(imgResp ImageResponse) (string, error) { 175 | data, err := base64.StdEncoding.DecodeString(imgResp.B64_JSON) 176 | if err != nil { 177 | return "", fmt.Errorf("failed to decode base64: %w", err) 178 | } 179 | pictureName := fmt.Sprintf("%v_%v.jpg", q.Output.Prefix, utils.RandomPrefix()) 180 | outFile := fmt.Sprintf("%v/%v", q.Output.Dir, pictureName) 181 | err = os.WriteFile(outFile, data, 0o644) 182 | if err != nil { 183 | ancli.PrintWarn(fmt.Sprintf("failed to write file: '%v', attempting tmp file...\n", err)) 184 | outFile = fmt.Sprintf("/tmp/%v", pictureName) 185 | err = os.WriteFile(outFile, data, 0o644) 186 | if err != nil { 187 | return "", fmt.Errorf("failed to write file: %w", err) 188 | } 189 | } 190 | return outFile, nil 191 | } 192 | 193 | func (q *DallE) Query(ctx context.Context) error { 194 | req, err := q.createRequest(ctx) 195 | if err != nil { 196 | return fmt.Errorf("failed to create request: %w", err) 197 | } 198 | err = q.do(req) 199 | if err != nil { 200 | return fmt.Errorf("failed 
to do request: %w", err) 201 | } 202 | return nil 203 | } 204 | -------------------------------------------------------------------------------- /internal/vendors/openai/gpt.go: -------------------------------------------------------------------------------- 1 | package openai 2 | 3 | import ( 4 | "github.com/baalimago/clai/internal/text/generic" 5 | ) 6 | 7 | var GPT_DEFAULT = ChatGPT{ 8 | Model: "gpt-4.1-mini", 9 | Temperature: 1.0, 10 | TopP: 1.0, 11 | Url: ChatURL, 12 | } 13 | 14 | type ChatGPT struct { 15 | generic.StreamCompleter 16 | Model string `json:"model"` 17 | FrequencyPenalty float64 `json:"frequency_penalty"` 18 | MaxTokens *int `json:"max_tokens"` // Use a pointer to allow null value 19 | PresencePenalty float64 `json:"presence_penalty"` 20 | Temperature float64 `json:"temperature"` 21 | TopP float64 `json:"top_p"` 22 | Url string `json:"url"` 23 | } 24 | -------------------------------------------------------------------------------- /internal/vendors/openai/gpt_setup.go: -------------------------------------------------------------------------------- 1 | package openai 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/baalimago/clai/internal/tools" 7 | ) 8 | 9 | func (g *ChatGPT) Setup() error { 10 | err := g.StreamCompleter.Setup("OPENAI_API_KEY", ChatURL, "DEBUG_OPENAI") 11 | if err != nil { 12 | return fmt.Errorf("failed to setup stream completer: %w", err) 13 | } 14 | g.StreamCompleter.Model = g.Model 15 | g.StreamCompleter.FrequencyPenalty = &g.FrequencyPenalty 16 | g.StreamCompleter.MaxTokens = g.MaxTokens 17 | g.StreamCompleter.Temperature = &g.Temperature 18 | g.StreamCompleter.TopP = &g.TopP 19 | toolChoice := "auto" 20 | g.StreamCompleter.ToolChoice = &toolChoice 21 | return nil 22 | } 23 | 24 | func (g *ChatGPT) RegisterTool(tool tools.AiTool) { 25 | g.StreamCompleter.InternalRegisterTool(tool) 26 | } 27 | -------------------------------------------------------------------------------- /internal/vendors/openai/models.go: 
// ChatCompletion describes one decoded JSON chat completion object.
// NOTE(review): Choice carries a "delta" field, which suggests this models a
// streamed chunk rather than a complete reply - confirm against the caller.
type ChatCompletion struct {
	ID                string   `json:"id"`
	Object            string   `json:"object"`
	Created           int64    `json:"created"`
	Model             string   `json:"model"`
	Choices           []Choice `json:"choices"`
	Usage             Usage    `json:"usage"`
	SystemFingerprint string   `json:"system_fingerprint"`
}

// Choice is a single entry of ChatCompletion.Choices.
type Choice struct {
	Index        int         `json:"index"`
	Delta        Delta       `json:"delta"`
	Logprobs     interface{} `json:"logprobs"` // null or complex object, hence interface{}
	FinishReason string      `json:"finish_reason"`
}

// Usage holds the token counts reported in a ChatCompletion.
type Usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

// Delta is the payload of a Choice. Content is typed `any`, so the decoder
// accepts whatever JSON value the server sends for "content".
type Delta struct {
	Content   any         `json:"content"`
	Role      string      `json:"role"`
	ToolCalls []ToolsCall `json:"tool_calls"`
}

// ToolsCall is one tool invocation listed in Delta.ToolCalls.
type ToolsCall struct {
	Function GptFunc `json:"function"`
	ID       string  `json:"id"`
	Index    int     `json:"index"`
	Type     string  `json:"type"`
}

// GptFunc names the function of a ToolsCall and carries its arguments as a
// string.
// NOTE(review): Arguments is presumably a JSON-encoded object - verify.
type GptFunc struct {
	Arguments string `json:"arguments"`
	Name      string `json:"name"`
}

// GptTool describes a tool; Inputs is encoded under the "parameters" key.
type GptTool struct {
	Name        string            `json:"name"`
	Description string            `json:"description"`
	Inputs      tools.InputSchema `json:"parameters"`
}

// GptToolSuper wraps a GptTool together with a type discriminator string.
type GptToolSuper struct {
	Type     string  `json:"type"`
	Function GptTool `json:"function"`
}
-------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "os" 8 | "runtime/pprof" 9 | 10 | "github.com/baalimago/clai/internal" 11 | "github.com/baalimago/clai/internal/utils" 12 | "github.com/baalimago/go_away_boilerplate/pkg/ancli" 13 | "github.com/baalimago/go_away_boilerplate/pkg/misc" 14 | "github.com/baalimago/go_away_boilerplate/pkg/shutdown" 15 | ) 16 | 17 | const usage = `clai - (c)ommand (l)ine (a)rtificial (i)ntelligence 18 | 19 | Prerequisites: 20 | - Set the OPENAI_API_KEY environment variable to your OpenAI API key 21 | - Set the ANTHROPIC_API_KEY environment variable to your Anthropic API key 22 | - Set the MISTRAL_API_KEY environment variable to your Mistral API key 23 | - Set the DEEPSEEK_API_KEY environment variable to your Deepseek API key 24 | - Set the NOVITA_API_KEY environment variable to your Novita API key 25 | - (Optional) Set the NO_COLOR environment variable to disable ansi color output 26 | - (Optional) Install glow - https://github.com/charmbracelet/glow for formated markdown output 27 | 28 | Usage: clai [flags] 29 | 30 | Flags: 31 | -re, -reply bool Set to true to reply to the previous query, meaning that it will be used as context for your next query. (default %v) 32 | -r, -raw bool Set to true to print raw output (no animation, no glow). (default %v) 33 | -cm, -chat-model string Set the chat model to use. (default is found in textConfig.json) 34 | -pm, -photo-model string Set the image model to use. (default is found in photoConfig.json) 35 | -pd, -photo-dir string Set the directory to store the generated pictures. (default %v) 36 | -pp, -photo-prefix string Set the prefix for the generated pictures. (default %v) 37 | -I, -replace string Set the string to replace with stdin. 
(default %v) 38 | -i bool Set to true to replace '-replace' flag value with stdin. This is overwritten by -I and -replace. (default %v) 39 | -t, -tools bool Set to true to use text tools. Some models might not support streaming. (default %v) 40 | -g, -glob string Set the glob to use for globbing. Same as glob mode. (default '%v') 41 | -p, -profile string Set the profile which should be used. For details, see 'clai help profile'. (default '%v') 42 | 43 | Commands: 44 | h|help Display this help message 45 | s|setup Setup the configuration files 46 | q|query Query the chat model with the given text 47 | p|photo Ask the photo model a picture with the requested prompt 48 | g|glob Query the chat model with the contents of the files found by the glob and the given text 49 | cmd Describe the command you wish to do, then execute the suggested command. It's a bit wonky when used with -re. 50 | re|replay Replay the most recent message. 51 | 52 | c|chat n|new Create a new chat with the given prompt. 53 | c|chat c|continue Continue an existing chat with the given chat ID. 54 | c|chat d|delete Delete the chat with the given chat ID. 55 | c|chat l|list List all existing chats. 56 | c|chat h|help Display detailed help for chat subcommands. 57 | 58 | Examples: 59 | - clai h | clai -i q generate some examples for this usage string: '{}' 60 | - clai query "What's the weather like in Tokyo?" 61 | - clai glob "*.txt" "Summarize these documents." 62 | - clai -cm claude-3-opus-20240229 chat new "What are the latest advancements in AI?" 63 | - clai photo "A futuristic cityscape" 64 | - clai -pm dall-e-2 photo A cat in space 65 | - clai -pd ~/Downloads -pp holiday A beach at sunset 66 | - docker logs example | clai -I LOG q "Find errors in these logs: LOG" 67 | - clai c new "Let's have a conversation about climate change." 
68 | - clai c list 69 | - clai c help 70 | ` 71 | 72 | func main() { 73 | if misc.Truthy(os.Getenv("DEBUG_CPU")) { 74 | f, err := os.Create("cpu_profile.prof") 75 | ok := true 76 | if err != nil { 77 | ancli.PrintErr(fmt.Sprintf("failed to create profiler file: %v", err)) 78 | } 79 | if ok { 80 | defer f.Close() 81 | // Start the CPU profile 82 | err = pprof.StartCPUProfile(f) 83 | if err != nil { 84 | ancli.PrintErr(fmt.Sprintf("failed to start profiler : %v", err)) 85 | } 86 | defer pprof.StopCPUProfile() 87 | } 88 | } 89 | 90 | err := handleOopsies() 91 | if err != nil { 92 | ancli.PrintWarn(fmt.Sprintf("failed to handle oopsies, but as we didn't panic, it should be benign. Error: %v\n", err)) 93 | } 94 | querier, err := internal.Setup(usage) 95 | if err != nil { 96 | if errors.Is(err, utils.ErrUserInitiatedExit) { 97 | ancli.Okf("Seems like you wanted out. Byebye!\n") 98 | os.Exit(0) 99 | } 100 | ancli.PrintErr(fmt.Sprintf("failed to setup: %v\n", err)) 101 | os.Exit(1) 102 | } 103 | ctx, cancel := context.WithCancel(context.Background()) 104 | go func() { shutdown.Monitor(cancel) }() 105 | err = querier.Query(ctx) 106 | if err != nil { 107 | if errors.Is(err, utils.ErrUserInitiatedExit) { 108 | ancli.Okf("Seems like you wanted out. Byebye!\n") 109 | os.Exit(0) 110 | } else { 111 | ancli.PrintErr(fmt.Sprintf("failed to run: %v\n", err)) 112 | os.Exit(1) 113 | } 114 | } 115 | cancel() 116 | if misc.Truthy(os.Getenv("DEBUG")) { 117 | ancli.PrintOK("things seems to have worked out. Bye bye! 🚀\n") 118 | } 119 | } 120 | -------------------------------------------------------------------------------- /oopsies.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "path" 7 | 8 | "github.com/baalimago/go_away_boilerplate/pkg/ancli" 9 | ) 10 | 11 | // moveConfFromHomeToConfig since I didn't know that there was a os.UserConfigDir function 12 | // to call. 
Better to follow standards as much as possible, even if it might cause some migration 13 | // issues 14 | func moveConfFromHomeToConfig() error { 15 | homeDir, err := os.UserHomeDir() 16 | if err != nil { 17 | return fmt.Errorf("failed to get home dir: %w", err) 18 | } 19 | oldClaiDir := path.Join(homeDir, ".clai") 20 | if _, err := os.Stat(oldClaiDir); !os.IsNotExist(err) { 21 | confDir, err := os.UserConfigDir() 22 | if err != nil { 23 | return fmt.Errorf("failed to get conf dir: %w", err) 24 | } 25 | ancli.PrintWarn(fmt.Sprintf("oopsie detected: attempting to move config from: %v, to %v, to better adhere to standards\n", oldClaiDir, confDir)) 26 | newClaiDir := path.Join(confDir, ".clai") 27 | err = os.Rename(oldClaiDir, newClaiDir) 28 | if err != nil { 29 | return fmt.Errorf("failed to rename: %w", err) 30 | } else { 31 | ancli.PrintOK(fmt.Sprintf("oopsie resolved: you'll now find your clai configurations in directory: '%v'\n", newClaiDir)) 32 | } 33 | } 34 | return nil 35 | } 36 | 37 | // handleOopsies by attempting to migrate and fix previous errors and issues caused by me, the writer of 38 | // the application, due to lack of knowledge and/or foresight 39 | func handleOopsies() error { 40 | err := moveConfFromHomeToConfig() 41 | if err != nil { 42 | ancli.PrintErr(fmt.Sprintf("failed to move conf from home to config: %v\n", err)) 43 | ancli.PrintErr("manual intervention is advised, sorry for this inconvenience. The configuration has moved from os.UserHomeDir() -> os.UserConfigDir(). 
#!/bin/sh
#
# Installs the latest clai release binary for the current OS and architecture.
# Installs to /usr/local/bin when run as root, otherwise to $HOME/.local/bin.

# get_latest_release_url OS ARCH
# Prints the download URL of the latest release asset matching OS and ARCH.
# Returns non-zero when the GitHub API call fails or no asset matches.
get_latest_release_url() {
  repo="baalimago/clai"
  os="$1"
  arch="$2"

  # Fetch the latest release data from GitHub API.
  # -f: fail on HTTP errors (e.g. rate limiting) instead of parsing an error page.
  release_data=$(curl -fsSL "https://api.github.com/repos/$repo/releases/latest") || return 1

  # Extract the asset URL for the specified OS and architecture.
  # The arch must not be followed by another alphanumeric character, so that
  # "arm" does not also select "arm64" assets. Only the first match is used.
  download_url=$(printf '%s\n' "$release_data" |
    grep "browser_download_url" |
    grep "$os" |
    grep -E "${arch}([^0-9A-Za-z]|\$)" |
    cut -d '"' -f 4 |
    head -n 1)

  [ -n "$download_url" ] || return 1
  echo "$download_url"
}

# Detect the OS
case "$(uname)" in
  Linux*)
    os="linux"
    ;;
  Darwin*)
    os="darwin"
    ;;
  *)
    echo "Unsupported OS: $(uname)"
    exit 1
    ;;
esac

# Detect the architecture
arch=$(uname -m)
case "$arch" in
  x86_64)
    arch="amd64"
    ;;
  armv7*)
    arch="arm"
    ;;
  aarch64|arm64)
    arch="arm64"
    ;;
  i?86)
    arch="386"
    ;;
  *)
    echo "Unsupported architecture: $arch"
    exit 1
    ;;
esac

printf "detected os: '%s', arch: '%s'\n" "$os" "$arch"

# Get the download URL for the latest release
printf "finding asset url..."
if ! download_url=$(get_latest_release_url "$os" "$arch"); then
  echo
  echo "Failed to find a release asset for os '$os', arch '$arch'."
  exit 1
fi
printf "OK!\n"

# Download the binary
if ! tmp_file=$(mktemp); then
  echo "Failed to create a temporary file."
  exit 1
fi

printf "downloading binary..."
# -f: do not save an HTTP error page as if it were the binary.
if ! curl -fsSL -o "$tmp_file" "$download_url"; then
  echo
  echo "Failed to download the binary."
  rm -f "$tmp_file"
  exit 1
fi
printf "OK!\n"

printf "setting file executable file permissions..."
# Make the binary executable
if ! chmod +x "$tmp_file"; then
  echo
  echo "Failed to make the binary executable. Try running the script with sudo."
  exit 1
fi
printf "OK!\n"

# Move the binary to standard XDG location and handle permission errors
INSTALL_DIR="$HOME/.local/bin"
# If run as 'sudo', install to /usr/local/bin for systemwide use
if [ "$(id -u 2>/dev/null)" = "0" ]; then
  INSTALL_DIR=/usr/local/bin
fi

# The target directory is not guaranteed to exist on a fresh system.
if ! mkdir -p "$INSTALL_DIR"; then
  echo "Failed to create $INSTALL_DIR, see error above. The downloaded binary is kept at $tmp_file."
  exit 1
fi

if ! mv "$tmp_file" "$INSTALL_DIR/clai"; then
  echo "Failed to move the binary to $INSTALL_DIR/clai, see error above. Try making sure you have write permission there, or run 'mv $tmp_file <destination>'."
  exit 1
fi

echo "clai installed successfully in $INSTALL_DIR, try it out with 'clai h'"

# A binary outside PATH would make the hint above fail, so say so.
case ":$PATH:" in
  *":$INSTALL_DIR:"*) ;;
  *) echo "note: $INSTALL_DIR is not in your PATH, add it to be able to run 'clai' directly" ;;
esac