├── .chglog ├── CHANGELOG.tpl.md └── config.yml ├── .dockerignore ├── .github └── workflows │ ├── go.yml │ └── release.yml ├── .gitignore ├── .golangci.yml ├── .goreleaser.yaml ├── CHANGELOG.md ├── LICENSE ├── Makefile ├── Makefile.def ├── README.md ├── README_ja.md ├── README_zh.md ├── cmd └── cli │ └── main.go ├── go.mod ├── go.sum ├── hack └── boilerplate.go.txt ├── installer └── dockerfile │ └── cli │ └── Dockerfile └── internal ├── ai ├── ai.go ├── ai_options.go ├── logging.go ├── model.go ├── prompts.go └── types.go ├── cli ├── ask │ └── ask.go ├── cli.go ├── coder │ └── coder.go ├── commit │ ├── commit.go │ └── options.go ├── completion │ └── completion.go ├── configure │ ├── configure.go │ ├── echo.go │ └── reset.go ├── convo │ ├── convo.go │ ├── ls.go │ ├── rm.go │ └── show.go ├── hook │ └── hook.go ├── loadctx │ ├── clean.go │ ├── context.go │ ├── list.go │ └── load.go ├── manpage │ └── manpage.go ├── profiling.go ├── review │ └── review.go └── version │ └── version.go ├── convo ├── chat_history_store.go ├── chat_history_store_test.go ├── conversation.go ├── factory.go ├── format.go ├── sha.go └── sqlite3 │ ├── convo_store.go │ ├── convo_store_options.go │ ├── convo_store_test.go │ ├── loadcontext_store.go │ └── loadcontext_store_test.go ├── errbook ├── error.go └── handle.go ├── git ├── git.go ├── git_test.go └── options.go ├── options ├── basic_flags.go ├── config.go ├── config_template.yml ├── config_test.go ├── datastore_flags.go ├── model_flags.go └── options.go ├── prompt ├── language.go ├── language_test.go ├── prompt.go ├── prompt_test.go ├── template_loader.go └── templates │ ├── code_review_file_diff.tmpl │ ├── commit-msg.tmpl │ ├── conventional_commit.tmpl │ ├── summarize_file_diff.tmpl │ ├── summarize_title.tmpl │ └── translation.tmpl ├── runner ├── output.go ├── output_test.go ├── runner.go └── runner_test.go ├── system ├── system.go └── types.go ├── ui ├── chat │ ├── chat.go │ └── options.go ├── coders │ ├── auto_coder.go │ ├── 
auto_coder_options.go │ ├── banner.txt │ ├── code_editor.go │ ├── commands.go │ ├── commands_test.go │ ├── completer.go │ ├── context.go │ ├── diff_block_editor.go │ ├── diff_block_editor_test.go │ ├── fence.go │ ├── fence_test.go │ ├── llm_callback.go │ ├── prompts.go │ ├── prompts_test.go │ └── types.go ├── console │ ├── anim.go │ ├── confirm.go │ ├── error.go │ ├── info.go │ ├── key_bindings.go │ ├── prompt.go │ ├── render.go │ ├── styles.go │ ├── success.go │ └── warn.go ├── textarea.go └── types.go └── util ├── debug ├── log.go ├── log_test.go └── options.go ├── flag └── flag.go ├── genericclioptions └── io_options.go ├── rest └── rest.go ├── templates ├── command_groups.go ├── markdown.go ├── normalizers.go ├── templater.go └── templates.go └── term ├── pipe.go ├── resize.go ├── term.go ├── term_writer.go └── term_writer_test.go /.chglog/CHANGELOG.tpl.md: -------------------------------------------------------------------------------- 1 | # Change Log 2 | 3 | {{ range .Versions }} 4 | 5 | ## {{ if .Tag.Previous }}[{{ .Tag.Name }}]{{ else }}{{ .Tag.Name }}{{ end }} - {{ datetime "2006-01-02" .Tag.Date }} 6 | {{ range .CommitGroups -}} 7 | ### {{ .Title }} 8 | {{ range .Commits -}} 9 | - {{ if .Scope }}**{{ .Scope }}:** {{ end }}{{ .Subject }} 10 | {{ end }} 11 | {{ end -}} 12 | 13 | {{- if .RevertCommits -}} 14 | ### Reverts 15 | {{ range .RevertCommits -}} 16 | - {{ .Revert.Header }} 17 | {{ end }} 18 | {{ end -}} 19 | 20 | {{- if .MergeCommits -}} 21 | ### Pull Requests 22 | {{ range .MergeCommits -}} 23 | - {{ .Header }} 24 | {{ end }} 25 | {{ end -}} 26 | 27 | {{- if .NoteGroups -}} 28 | {{ range .NoteGroups -}} 29 | ### {{ .Title }} 30 | {{ range .Notes }} 31 | {{ .Body }} 32 | {{ end }} 33 | {{ end -}} 34 | {{ end -}} 35 | {{ end -}} 36 | 37 | {{- if .Versions }} 38 | [Unreleased]: {{ .Info.RepositoryURL }}/compare/{{ $latest := index .Versions 0 }}{{ $latest.Tag.Name }}...HEAD 39 | {{ range .Versions -}} 40 | {{ if .Tag.Previous -}} 41 | [{{ .Tag.Name 
}}]: {{ $.Info.RepositoryURL }}/compare/{{ .Tag.Previous.Name }}...{{ .Tag.Name }} 42 | {{ end -}} 43 | {{ end -}} 44 | {{ end -}} -------------------------------------------------------------------------------- /.chglog/config.yml: -------------------------------------------------------------------------------- 1 | style: github 2 | template: CHANGELOG.tpl.md 3 | info: 4 | title: CHANGELOG 5 | repository_url: https://github.com/coding-hui/ai-terminal 6 | options: 7 | commits: 8 | filters: 9 | Type: 10 | - feat 11 | - fix 12 | - perf 13 | - refactor 14 | commit_groups: 15 | title_maps: 16 | feat: Features 17 | fix: Bug Fixes 18 | perf: Performance Improvements 19 | refactor: Code Refactoring 20 | header: 21 | pattern: "^(\\w*)\\:\\s(.*)$" 22 | pattern_maps: 23 | - Type 24 | - Subject 25 | notes: 26 | keywords: 27 | - BREAKING CHANGE -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | # More info: https://docs.docker.com/engine/reference/builder/#dockerignore-file 2 | # Ignore build and test binaries. 
3 | bin/ 4 | -------------------------------------------------------------------------------- /.github/workflows/go.yml: -------------------------------------------------------------------------------- 1 | # This workflow will build a golang project 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-go 3 | 4 | name: Go 5 | 6 | on: 7 | push: 8 | branches: [ "main" ] 9 | pull_request: 10 | branches: [ "*" ] 11 | 12 | jobs: 13 | build: 14 | runs-on: ubuntu-latest 15 | steps: 16 | - uses: actions/checkout@v4 17 | 18 | - name: Set up Go 19 | uses: actions/setup-go@v5 20 | with: 21 | go-version-file: './go.mod' 22 | 23 | - name: Set CI ENV 24 | run: export ENV=CI 25 | 26 | - name: Build 27 | run: make 28 | 29 | - name: Lint 30 | run: make lint 31 | 32 | - name: Test 33 | env: 34 | OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} 35 | OPENAI_MODEL: ${{ vars.OPENAI_MODEL }} 36 | OPENAI_API_BASE: ${{ vars.OPENAI_API_BASE }} 37 | SILICONCLOUD_API_KEY: ${{ secrets.SILICONCLOUD_API_KEY }} 38 | run: make test 39 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | # .github/workflows/release.yaml 2 | 3 | name: Release 4 | 5 | on: 6 | push: 7 | tags: [ "v*" ] 8 | 9 | permissions: 10 | contents: write 11 | packages: write 12 | 13 | jobs: 14 | # homebrew: 15 | # runs-on: ubuntu-latest 16 | # steps: 17 | # - name: Update Homebrew formula 18 | # uses: dawidd6/action-homebrew-bump-formula@v3 19 | # with: 20 | # GitHub token, required, not the default one 21 | # token: ${{secrets.TOKEN}} 22 | # Optional, defaults to homebrew/core 23 | # tap: USER/REPO 24 | # Formula name, required 25 | # formula: FORMULA 26 | # Optional, will be determined automatically 27 | # tag: ${{github.ref}} 28 | # Optional, will be determined automatically 29 | # revision: ${{github.sha}} 30 | # Optional, 
if don't want to check for already open PRs 31 | # force: false # true 32 | build-go-binary: 33 | name: Release Go Binary 34 | runs-on: ubuntu-latest 35 | strategy: 36 | matrix: 37 | goos: [ linux, windows, darwin ] 38 | goarch: [ amd64, arm64 ] 39 | exclude: 40 | - goarch: "386" 41 | goos: darwin 42 | - goarch: arm64 43 | goos: windows 44 | steps: 45 | - uses: actions/checkout@v4 46 | - uses: wangyoucao577/go-release-action@v1 47 | with: 48 | github_token: ${{ secrets.GITHUB_TOKEN }} 49 | goos: ${{ matrix.goos }} 50 | goarch: ${{ matrix.goarch }} 51 | binary_name: "ai" 52 | build_command: make build 53 | extra_files: ./bin/ai LICENSE README.md 54 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Binaries for programs and plugins 2 | *.exe 3 | *.exe~ 4 | *.dll 5 | *.so 6 | *.dylib 7 | bin/* 8 | Dockerfile.cross 9 | 10 | # Test binary, built with `go test -c` 11 | *.test 12 | 13 | # Output of the go coverage tool, specifically when used with LiteIDE 14 | *.out 15 | 16 | # Go workspace file 17 | go.work 18 | 19 | # Kubernetes Generated files - skip generated files, except for vendored files 20 | !vendor/**/zz_generated.* 21 | 22 | # editor and IDE paraphernalia 23 | .idea 24 | .vscode 25 | *.swp 26 | *.swo 27 | *~ 28 | 29 | logs 30 | *.log 31 | dist/ 32 | 33 | completions/* 34 | manpages/* -------------------------------------------------------------------------------- /.golangci.yml: -------------------------------------------------------------------------------- 1 | run: 2 | timeout: 5m 3 | allow-parallel-runners: true 4 | 5 | issues: 6 | # don't skip warning about doc comments 7 | # don't exclude the default set of lint 8 | exclude-use-default: false 9 | # restore some of the defaults 10 | # (fill in the rest as needed) 11 | exclude-rules: 12 | - path: "api/*" 13 | linters: 14 | - lll 15 | - path: "internal/*" 16 | linters: 17 | - 
dupl 18 | - lll 19 | linters: 20 | disable-all: true 21 | enable: 22 | - dupl 23 | - errcheck 24 | - copyloopvar 25 | - ginkgolinter 26 | - goconst 27 | - gocyclo 28 | - gofmt 29 | - goimports 30 | - gosimple 31 | - govet 32 | - ineffassign 33 | - lll 34 | - misspell 35 | - nakedret 36 | - prealloc 37 | - staticcheck 38 | - typecheck 39 | - unconvert 40 | - unparam 41 | - unused 42 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 coding-hui 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Makefile.def: -------------------------------------------------------------------------------- 1 | ifeq ($(origin VERSION), undefined) 2 | VERSION := $(shell git describe --tags --always --match='v*') 3 | endif 4 | 5 | # Check if the tree is dirty. default to dirty 6 | GIT_TREE_STATE:="dirty" 7 | ifeq (, $(shell git status --porcelain 2>/dev/null)) 8 | GIT_TREE_STATE="clean" 9 | endif 10 | ifeq ($(origin GIT_COMMIT), undefined) 11 | GIT_COMMIT:=$(shell git rev-parse HEAD) 12 | endif 13 | 14 | GitSHA=`git rev-parse HEAD` 15 | 16 | VERSION_PACKAGE=github.com/coding-hui/common/version 17 | 18 | GO_LDFLAGS += -X $(VERSION_PACKAGE).GitVersion=$(VERSION) \ 19 | -X $(VERSION_PACKAGE).GitCommit=$(GIT_COMMIT) \ 20 | -X $(VERSION_PACKAGE).GitTreeState=$(GIT_TREE_STATE) \ 21 | -X $(VERSION_PACKAGE).BuildDate=$(shell date -u +'%Y-%m-%dT%H:%M:%SZ') 22 | ifneq ($(DLV),) 23 | GO_BUILD_FLAGS += -gcflags "all=-N -l" 24 | LDFLAGS = "" 25 | endif 26 | GO_BUILD_FLAGS += -ldflags "$(GO_LDFLAGS)" 27 | 28 | ifeq ($(GOOS),windows) 29 | GO_OUT_EXT := .exe 30 | endif 31 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Language : [🇺🇸 us](./README.md) | [🇨🇳 zh](./README_zh.md) | [🇯🇵 ja](./README_ja.md) 2 | 3 | # AI-Terminal 4 | 5 | AI-Terminal is an advanced AI-powered CLI that enhances terminal workflows through AI-driven automation and 6 | optimization. It efficiently manages tasks such as file management, data processing, and system diagnostics. 7 | 8 | ## Key Features 9 | 10 | - **Contextual Assistance:** Learns from user commands to provide syntax suggestions. 11 | - **Automated Tasks:** Recognizes repetitive task patterns and creates shortcuts. 12 | - **Intelligent Search:** Conducts searches within files, directories, and specific file types. 
13 | - **Error Correction:** Corrects incorrect commands and suggests alternatives. 14 | - **Custom Integrations:** Supports integration with various tools and services via plugins or APIs. 15 | 16 | ## Getting Started 17 | 18 | ### Prerequisites 19 | 20 | - Go v1.23.2 or higher (only needed to build from source). 21 | 22 | ### Installation 23 | 24 | Install using Homebrew: 25 | 26 | ```bash 27 | brew install coding-hui/tap/ai-terminal 28 | ``` 29 | 30 | Or, download it: 31 | 32 | - [Packages][releases] are available in Debian and RPM formats 33 | - [Binaries][releases] are available for Linux, macOS, and Windows 34 | 35 | [releases]: https://github.com/coding-hui/ai-terminal/releases 36 | 37 | Or, build from source (requires Go 1.23.2+): 38 | 39 | ```sh 40 | make build 41 | ``` 42 | 43 | Then initialize configuration: 44 | ```sh 45 | ai configure 46 | ``` 47 | 48 |
49 | Shell Completions 50 | 51 | All packages and archives come with pre-generated completion files for Bash, 52 | ZSH, Fish, and PowerShell. 53 | 54 | If you built it from source, you can generate them with: 55 | 56 | ```bash 57 | ai completion bash -h 58 | ai completion zsh -h 59 | ai completion fish -h 60 | ai completion powershell -h 61 | ``` 62 | 63 | If you use a package (like Homebrew, Debs, etc), the completions should be set 64 | up automatically, given your shell is configured properly. 65 | 66 |
67 | 68 | ### Usage 69 | 70 | Here are some examples of how to use AI-Terminal, grouped by functionality: 71 | 72 | #### Chat 73 | 74 | - **Initiate a Chat:** 75 | ```sh 76 | ai ask "What is the best way to manage Docker containers?" 77 | ``` 78 | 79 | - **Use a Prompt File:** 80 | ```sh 81 | ai ask --file /path/to/prompt_file.txt 82 | ``` 83 | 84 | - **Pipe Input:** 85 | ```sh 86 | cat some_script.go | ai ask generate unit tests 87 | ``` 88 | 89 | #### Code Generation 90 | 91 | - **Interactive Code Generation:** 92 | ```sh 93 | ai coder 94 | ``` 95 | Starts interactive mode for generating code based on prompts. 96 | 97 | - **CLI-based Code Generation:** 98 | ```sh 99 | ai ctx load /path/to/context_file 100 | ai coder -c session_id -p "improve comments and add unit tests" 101 | ``` 102 | Load context files and specify session ID for batch processing. Supports: 103 | - Code improvement 104 | - Comment enhancement 105 | - Unit test generation 106 | - Code refactoring 107 | 108 | - **Generate Code with Context:** 109 | ```sh 110 | ai ctx load /path/to/context_file 111 | ai coder "implement feature xxx" 112 | ``` 113 | Load context files first to provide additional information for code generation. 114 | 115 | #### Code Review 116 | 117 | - **Review Code Changes:** 118 | ```sh 119 | ai review --exclude-list "*.md,*.txt" 120 | ``` 121 | 122 | #### Commit Messages 123 | 124 | - **Generate a Commit Message:** 125 | ```sh 126 | ai commit --diff-unified 3 --lang en 127 | ``` 128 | 129 | ## Contributing 130 | 131 | We welcome contributions! Please see our [Contribution Guidelines](CONTRIBUTING.md) for more details. 132 | 133 | ### Changelog 134 | 135 | Check out the [CHANGELOG.md](CHANGELOG.md) for detailed updates and changes to the project. 136 | 137 | ### License 138 | 139 | Copyright 2024 coding-hui. Licensed under the [MIT License](LICENSE). 
140 | 141 | -------------------------------------------------------------------------------- /README_ja.md: -------------------------------------------------------------------------------- 1 | # AI-ターミナル 2 | 3 | AI-ターミナルは、AI駆動の自動化と最適化を通じてターミナルワークフローを強化する高度なAI駆動のCLIです。ファイル管理、データ処理、システム診断などのタスクを効率的に管理します。 4 | 5 | ## 主な機能 6 | 7 | - **コンテキストアシスタンス:** ユーザーのコマンドから学習し、構文の提案を提供します。 8 | - **自動化タスク:** 繰り返しのタスクパターンを認識し、ショートカットを作成します。 9 | - **インテリジェント検索:** ファイル、ディレクトリ、および特定のファイルタイプ内で検索を行います。 10 | - **エラー修正:** 不正確なコマンドを修正し、代替案を提案します。 11 | - **カスタム統合:** プラグインやAPIを介してさまざまなツールやサービスとの統合をサポートします。 12 | 13 | ## はじめに 14 | 15 | ### 前提条件 16 | 17 | - Goバージョンv1.23.2以上(ソースからビルドする場合のみ必要)。 18 | 19 | ### インストール 20 | 21 | Homebrewを使用してインストール: 22 | 23 | ```bash 24 | brew install coding-hui/tap/ai-terminal 25 | ``` 26 | 27 | または直接ダウンロード: 28 | 29 | - [パッケージ][releases] Debian および RPM 形式で提供 30 | - [バイナリ][releases] Linux、macOS、Windows 用 31 | 32 | [releases]: https://github.com/coding-hui/ai-terminal/releases 33 | 34 | またはソースからビルド(Go 1.23.2+ が必要): 35 | 36 | ```sh 37 | make build 38 | ``` 39 | 40 | 設定を初期化: 41 | ```sh 42 | ai configure 43 | ``` 44 | 45 |
46 | シェル補完 47 | 48 | すべてのパッケージとアーカイブには、Bash、ZSH、Fish、PowerShell 用の事前生成された補完ファイルが含まれています。 49 | 50 | ソースからビルドした場合、以下のコマンドで生成できます: 51 | 52 | ```bash 53 | ai completion bash -h 54 | ai completion zsh -h 55 | ai completion fish -h 56 | ai completion powershell -h 57 | ``` 58 | 59 | パッケージ(Homebrew、Debs など)を使用する場合、シェルの設定が正しければ補完は自動的に設定されます。 60 | 61 |
62 | 63 | ### 使用方法 64 | 65 | AI-ターミナルの使用例を機能別に紹介します: 66 | 67 | #### チャット 68 | 69 | - **チャットを開始する:** 70 | ```sh 71 | ai ask "Dockerコンテナを管理する最良の方法は何ですか?" 72 | ``` 73 | 74 | - **プロンプトファイルを使用する:** 75 | ```sh 76 | ai ask --file /path/to/prompt_file.txt 77 | ``` 78 | 79 | - **パイプ入力:** 80 | ```sh 81 | cat some_script.go | ai ask generate unit tests 82 | ``` 83 | 84 | #### コード生成 85 | 86 | - **対話型コード生成:** 87 | ```sh 88 | ai coder 89 | ``` 90 | プロンプトに基づいて対話的にコードを生成します。 91 | 92 | - **CLIベースのコード生成:** 93 | ```sh 94 | ai ctx load /path/to/context_file 95 | ai coder -c session_id -p "コメントを改善し、単体テストを追加" 96 | ``` 97 | コンテキストファイルを読み込み、セッションIDを指定してバッチ処理を実行。以下をサポート: 98 | - コード改善 99 | - コメントの強化 100 | - 単体テストの生成 101 | - コードリファクタリング 102 | 103 | - **コンテキスト付きでコードを生成する:** 104 | ```sh 105 | ai ctx load /path/to/context_file 106 | ai coder "feature X を実装" 107 | ``` 108 | コード生成のための追加情報を提供するために、まずコンテキストファイルを読み込みます。 109 | 110 | #### コードレビュー 111 | 112 | - **コード変更をレビューする:** 113 | ```sh 114 | ai review --exclude-list "*.md,*.txt" 115 | ``` 116 | 117 | #### コミットメッセージ 118 | 119 | - **コミットメッセージを生成する:** 120 | ```sh 121 | ai commit --diff-unified 3 --lang ja 122 | ``` 123 | 124 | ## 貢献 125 | 126 | 貢献を歓迎します!詳細については、[貢献ガイドライン](CONTRIBUTING.md)をご覧ください。 127 | 128 | ### 変更履歴 129 | 130 | プロジェクトの詳細な更新と変更については、[CHANGELOG.md](CHANGELOG.md)をご覧ください。 131 | 132 | ### ライセンス 133 | 134 | 2024年 coding-hui。 [MITライセンス](LICENSE)の下でライセンスされています。 135 | -------------------------------------------------------------------------------- /README_zh.md: -------------------------------------------------------------------------------- 1 | # AI-终端 2 | 3 | AI-终端是一个先进的AI驱动的CLI,通过AI驱动的自动化和优化来增强终端工作流程。它有效地管理任务,如文件管理、数据处理和系统诊断。 4 | 5 | ## 主要特点 6 | 7 | - **上下文帮助:** 从用户命令中学习,提供语法建议。 8 | - **自动化任务:** 识别重复任务模式并创建快捷方式。 9 | - **智能搜索:** 在文件、目录和特定文件类型中进行搜索。 10 | - **错误纠正:** 纠正不正确的命令并提供替代方案。 11 | - **自定义集成:** 通过插件或API支持与各种工具和服务集成。 12 | 13 | ## 入门 14 | 15 | ### 先决条件 16 | 17 | - Go版本v1.23.2或更高(仅从源码构建时需要)。 18 | 19 | ### 安装 20 | 21 | 使用 Homebrew 安装: 22 | 23 |
```bash 24 | brew install coding-hui/tap/ai-terminal 25 | ``` 26 | 27 | 或者直接下载: 28 | 29 | - [软件包][releases] 提供 Debian 和 RPM 格式 30 | - [二进制文件][releases] 适用于 Linux、macOS 和 Windows 31 | 32 | [releases]: https://github.com/coding-hui/ai-terminal/releases 33 | 34 | 或者从源码编译(需要 Go 1.23.2+): 35 | 36 | ```sh 37 | make build 38 | ``` 39 | 40 | 然后初始化配置: 41 | ```sh 42 | ai configure 43 | ``` 44 | 45 |
46 | Shell 自动补全 47 | 48 | 所有软件包和压缩包都包含预生成的 Bash、ZSH、Fish 和 PowerShell 的自动补全文件。 49 | 50 | 如果从源码构建,可以使用以下命令生成: 51 | 52 | ```bash 53 | ai completion bash -h 54 | ai completion zsh -h 55 | ai completion fish -h 56 | ai completion powershell -h 57 | ``` 58 | 59 | 如果使用软件包(如 Homebrew、Debs 等),只要 shell 配置正确,自动补全应该会自动设置。 60 | 61 |
62 | 63 | ### 使用 64 | 65 | 以下是一些使用AI-终端的示例,按功能分组: 66 | 67 | #### 聊天 68 | 69 | - **启动聊天:** 70 | ```sh 71 | ai ask "管理Docker容器的最佳方式是什么?" 72 | ``` 73 | 74 | - **使用提示文件:** 75 | ```sh 76 | ai ask --file /path/to/prompt_file.txt 77 | ``` 78 | 79 | - **管道输入:** 80 | ```sh 81 | cat some_script.go | ai ask generate unit tests 82 | ``` 83 | 84 | #### 代码生成 85 | 86 | - **交互式代码生成:** 87 | ```sh 88 | ai coder 89 | ``` 90 | 启动交互模式,根据提示生成代码。 91 | 92 | - **CLI模式代码生成:** 93 | ```sh 94 | ai ctx load /path/to/context_file 95 | ai coder -c 会话ID -p "完善注释并补充单元测试" 96 | ``` 97 | 加载上下文文件并指定会话ID进行批量处理。支持: 98 | - 代码优化 99 | - 注释完善 100 | - 单元测试生成 101 | - 代码重构 102 | 103 | - **带上下文生成代码:** 104 | ```sh 105 | ai ctx load /path/to/context_file 106 | ai coder "实现功能xxx" 107 | ``` 108 | 先加载上下文文件为代码生成提供额外信息。 109 | 110 | #### 代码审查 111 | 112 | - **审查代码更改:** 113 | ```sh 114 | ai review --exclude-list "*.md,*.txt" 115 | ``` 116 | 117 | #### 提交消息 118 | 119 | - **生成提交消息:** 120 | ```sh 121 | ai commit --diff-unified 3 --lang zh 122 | ``` 123 | 124 | ## 贡献 125 | 126 | 我们欢迎贡献!请参阅我们的[贡献指南](CONTRIBUTING_zh.md)以获取更多详细信息。 127 | 128 | ## 许可证 129 | 130 | 版权所有 2024 coding-hui。根据[MIT许可证](LICENSE)授权。 131 | -------------------------------------------------------------------------------- /cmd/cli/main.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023 coding-hui. All rights reserved. 2 | // Use of this source code is governed by a MIT style 3 | // license that can be found in the LICENSE file. 
4 | 5 | package main 6 | 7 | import ( 8 | "os" 9 | 10 | "github.com/coding-hui/ai-terminal/internal/cli" 11 | "github.com/coding-hui/ai-terminal/internal/errbook" 12 | ) 13 | 14 | // main creates the default AI command via cli.NewDefaultAICommand and 15 | // executes it. If execution fails, the error is reported through 16 | // errbook.HandleError and the process exits with status 1. 17 | func main() { 18 | command := cli.NewDefaultAICommand() 19 | if err := command.Execute(); err != nil { 20 | errbook.HandleError(err) 21 | os.Exit(1) 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/coding-hui/ai-terminal 2 | 3 | go 1.23.2 4 | 5 | require ( 6 | github.com/AlekSi/pointer v1.2.0 7 | github.com/MakeNowJust/heredoc/v2 v2.0.1 8 | github.com/PuerkitoBio/goquery v1.10.2 9 | github.com/adrg/xdg v0.5.3 10 | github.com/atotto/clipboard v0.1.4 11 | github.com/caarlos0/duration v0.0.0-20241219124531-2bb7dc683aa4 12 | github.com/caarlos0/env/v9 v9.0.0 13 | github.com/caarlos0/go-shellwords v1.0.12 14 | github.com/caarlos0/timea.go v1.2.0 15 | github.com/charmbracelet/bubbles v0.20.0 16 | github.com/charmbracelet/bubbletea v1.3.3 17 | github.com/charmbracelet/glamour v0.8.0 18 | github.com/charmbracelet/huh v0.6.0 19 | github.com/charmbracelet/lipgloss v1.0.0 20 | github.com/charmbracelet/x/exp/ordered v0.1.0 21 | github.com/charmbracelet/x/exp/strings v0.0.0-20250213125511-a0c32e22e4fc 22 | github.com/coding-hui/common v0.8.7 23 | github.com/coding-hui/wecoding-sdk-go v0.8.15 24 | github.com/elk-language/go-prompt v1.1.5 25 | github.com/erikgeiser/promptkit v0.9.0 26 | github.com/fatih/color v1.18.0 27 | github.com/ghodss/yaml v1.0.0 28 | github.com/jmoiron/sqlx v1.4.0 29 | github.com/lucasb-eyer/go-colorful v1.2.0 30 | github.com/mattn/go-isatty v0.0.20 31 | github.com/mitchellh/go-wordwrap v1.0.1 32 | github.com/moby/term v0.5.2 33 | github.com/muesli/mango-cobra v1.2.0 34 | github.com/muesli/roff v0.1.0 35 | github.com/muesli/termenv
v0.15.3-0.20240618155329-98d742f6907a 36 | github.com/russross/blackfriday v1.6.0 37 | github.com/spf13/cobra v1.8.1 38 | github.com/spf13/pflag v1.0.6 39 | github.com/stretchr/testify v1.10.0 40 | github.com/volcengine/volcengine-go-sdk v1.0.181 41 | gopkg.in/yaml.v3 v3.0.1 42 | k8s.io/klog/v2 v2.130.1 43 | modernc.org/sqlite v1.35.0 44 | ) 45 | 46 | require ( 47 | dario.cat/mergo v1.0.1 // indirect 48 | github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c // indirect 49 | github.com/Masterminds/goutils v1.1.1 // indirect 50 | github.com/Masterminds/semver/v3 v3.3.1 // indirect 51 | github.com/Masterminds/sprig/v3 v3.3.0 // indirect 52 | github.com/alecthomas/chroma/v2 v2.15.0 // indirect 53 | github.com/andybalholm/cascadia v1.3.3 // indirect 54 | github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect 55 | github.com/aymerick/douceur v0.2.0 // indirect 56 | github.com/catppuccin/go v0.2.0 // indirect 57 | github.com/charmbracelet/x/ansi v0.8.0 // indirect 58 | github.com/charmbracelet/x/term v0.2.1 // indirect 59 | github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect 60 | github.com/dlclark/regexp2 v1.11.5 // indirect 61 | github.com/dustin/go-humanize v1.0.1 // indirect 62 | github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect 63 | github.com/go-logr/logr v1.4.2 // indirect 64 | github.com/google/pprof v0.0.0-20241029153458-d1b30febd7db // indirect 65 | github.com/google/uuid v1.6.0 // indirect 66 | github.com/goph/emperror v0.17.2 // indirect 67 | github.com/gorilla/css v1.0.1 // indirect 68 | github.com/gosuri/uitable v0.0.4 // indirect 69 | github.com/h2non/filetype v1.1.3 // indirect 70 | github.com/huandu/xstrings v1.5.0 // indirect 71 | github.com/inconshreveable/mousetrap v1.1.0 // indirect 72 | github.com/jmespath/go-jmespath v0.4.0 // indirect 73 | github.com/json-iterator/go v1.1.12 // indirect 74 | github.com/mattn/go-colorable v0.1.14 // indirect 75 | github.com/mattn/go-localereader v0.0.1 
// indirect 76 | github.com/mattn/go-runewidth v0.0.16 // indirect 77 | github.com/mattn/go-sqlite3 v1.14.24 // indirect 78 | github.com/mattn/go-tty v0.0.7 // indirect 79 | github.com/microcosm-cc/bluemonday v1.0.27 // indirect 80 | github.com/mitchellh/copystructure v1.2.0 // indirect 81 | github.com/mitchellh/hashstructure/v2 v2.0.2 // indirect 82 | github.com/mitchellh/reflectwalk v1.0.2 // indirect 83 | github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect 84 | github.com/modern-go/reflect2 v1.0.2 // indirect 85 | github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect 86 | github.com/muesli/cancelreader v0.2.2 // indirect 87 | github.com/muesli/mango v0.1.0 // indirect 88 | github.com/muesli/mango-pflag v0.1.0 // indirect 89 | github.com/muesli/reflow v0.3.0 // indirect 90 | github.com/ncruces/go-strftime v0.1.9 // indirect 91 | github.com/nikolalohinski/gonja v1.5.3 // indirect 92 | github.com/pelletier/go-toml/v2 v2.1.0 // indirect 93 | github.com/pkg/errors v0.9.1 // indirect 94 | github.com/pkg/term v1.2.0-beta.2 // indirect 95 | github.com/pkoukk/tiktoken-go v0.1.7 // indirect 96 | github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect 97 | github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect 98 | github.com/rivo/uniseg v0.4.7 // indirect 99 | github.com/sashabaranov/go-openai v1.37.0 // indirect 100 | github.com/shopspring/decimal v1.4.0 // indirect 101 | github.com/sirupsen/logrus v1.9.3 // indirect 102 | github.com/spf13/cast v1.7.1 // indirect 103 | github.com/volcengine/volc-sdk-golang v1.0.23 // indirect 104 | github.com/yargevad/filepathx v1.0.0 // indirect 105 | github.com/yuin/goldmark v1.7.4 // indirect 106 | github.com/yuin/goldmark-emoji v1.0.3 // indirect 107 | golang.org/x/crypto v0.33.0 // indirect 108 | golang.org/x/exp v0.0.0-20250210185358-939b2ce775ac // indirect 109 | golang.org/x/net v0.35.0 // indirect 110 | golang.org/x/sync v0.11.0 // 
indirect 111 | golang.org/x/sys v0.30.0 // indirect 112 | golang.org/x/term v0.29.0 // indirect 113 | golang.org/x/text v0.22.0 // indirect 114 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect 115 | gopkg.in/yaml.v2 v2.4.0 // indirect 116 | modernc.org/libc v1.61.13 // indirect 117 | modernc.org/mathutil v1.7.1 // indirect 118 | modernc.org/memory v1.8.2 // indirect 119 | ) 120 | -------------------------------------------------------------------------------- /hack/boilerplate.go.txt: -------------------------------------------------------------------------------- 1 | /* 2 | MIT License 3 | 4 | Copyright (c) 2024 coding-hui 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 
23 | */ -------------------------------------------------------------------------------- /installer/dockerfile/cli/Dockerfile: -------------------------------------------------------------------------------- 1 | # Build the manager binary 2 | FROM golang:1.25 AS builder 3 | ARG TARGETOS 4 | ARG TARGETARCH 5 | 6 | WORKDIR /workspace 7 | # Copy the Go Modules manifests 8 | ADD . ai-terminal 9 | 10 | RUN cd ai-terminal && CGO_ENABLED=0 GOOS=${TARGETOS:-linux} GOARCH=${TARGETARCH} make build 11 | 12 | # Use distroless as minimal base image to package the manager binary 13 | # Refer to https://github.com/GoogleContainerTools/distroless for more details 14 | FROM gcr.io/distroless/static:nonroot 15 | WORKDIR / 16 | COPY --from=builder /workspace/ai-terminal/bin/ai . 17 | USER 65532:65532 18 | 19 | ENTRYPOINT ["/ai"] 20 | -------------------------------------------------------------------------------- /internal/ai/ai.go: -------------------------------------------------------------------------------- 1 | package ai 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "html" 7 | "strings" 8 | 9 | "github.com/coding-hui/common/util/slices" 10 | "github.com/coding-hui/wecoding-sdk-go/services/ai/llms" 11 | 12 | "github.com/coding-hui/ai-terminal/internal/convo" 13 | "github.com/coding-hui/ai-terminal/internal/errbook" 14 | "github.com/coding-hui/ai-terminal/internal/options" 15 | "github.com/coding-hui/ai-terminal/internal/ui/console" 16 | ) 17 | 18 | const ( 19 | // noExec prefixes a model reply that must not be marked executable in 20 | // ExecEngineMode. 21 | noExec = "[noexec]" 22 | ) 23 | 24 | // Engine sends chat messages to the configured model, optionally streaming 25 | // partial output over a channel, and records messages in a conversation store. 26 | type Engine struct { 27 | mode EngineMode 28 | // running is true while a completion request is in flight. 29 | running bool 30 | // channel carries streamed chunks as well as the final and interrupt markers. 31 | channel chan StreamCompletionOutput 32 | 33 | convoStore convo.Store 34 | model Model 35 | 36 | Config *options.Config 37 | } 38 | 39 | // New creates an Engine configured by the given options. 40 | func New(ops ...Option) (*Engine, error) { 41 | return applyOptions(ops...) 42 | }
43 | 44 | // SetMode sets the engine mode, e.g. ExecEngineMode. 45 | func (e *Engine) SetMode(m EngineMode) { 46 | e.mode = m 47 | } 48 | 49 | // GetMode returns the current engine mode. 50 | func (e *Engine) GetMode() EngineMode { 51 | return e.mode 52 | } 53 | 54 | // GetChannel returns the channel on which streamed output is published. 55 | func (e *Engine) GetChannel() chan StreamCompletionOutput { 56 | return e.channel 57 | } 58 | 59 | // GetConvoStore returns the conversation store used by the engine. 60 | func (e *Engine) GetConvoStore() convo.Store { 61 | return e.convoStore 62 | } 63 | 64 | // Interrupt publishes a final "[Interrupt]" message on the output channel and 65 | // clears the running flag. 66 | // 67 | // NOTE(review): the send blocks until a receiver takes the value (or buffer 68 | // space is available), so the channel must be drained by the caller. 69 | func (e *Engine) Interrupt() { 70 | e.channel <- StreamCompletionOutput{ 71 | Content: "[Interrupt]", 72 | Last: true, 73 | Interrupt: true, 74 | Executable: false, 75 | } 76 | 77 | e.running = false 78 | } 79 | 80 | // CreateCompletion runs a single non-streaming completion for messages. 81 | // The messages are first prepared by setupChatContext; the model reply is 82 | // HTML-unescaped, recorded with appendAssistantMessage and returned as the 83 | // Explanation of the output. 84 | // 85 | // The running flag is set for the duration of the call and cleared on every 86 | // return path. An error is returned when the model yields no choices. 87 | func (e *Engine) CreateCompletion(ctx context.Context, messages []llms.ChatMessage) (*CompletionOutput, error) { 88 | e.running = true 89 | // Clear the flag on every exit, including the error returns below. 90 | defer func() { e.running = false }() 91 | 92 | if err := e.setupChatContext(ctx, &messages); err != nil { 93 | return nil, err 94 | } 95 | 96 | rsp, err := e.model.GenerateContent(ctx, slices.Map(messages, convert), e.callOptions()...) 97 | if err != nil { 98 | return nil, errbook.Wrap("Failed to create completion.", err) 99 | } 100 | // Guard the Choices[0] access below; an empty reply would otherwise panic. 101 | if len(rsp.Choices) == 0 { 102 | return nil, errbook.Wrap("Failed to create completion.", fmt.Errorf("model returned no choices")) 103 | } 104 | 105 | content := html.UnescapeString(rsp.Choices[0].Content) 106 | 107 | e.appendAssistantMessage(content) 108 | 109 | return &CompletionOutput{ 110 | Command: "", 111 | Explanation: content, 112 | Executable: false, 113 | Usage: rsp.Usage, 114 | }, nil 115 | } 116 |
117 | // CreateStreamCompletion runs a streaming completion for messages. 118 | // 119 | // Unless Config.Quiet is set, every streamed chunk is forwarded on the engine 120 | // channel, followed by a final message with Last set. The prepared messages 121 | // are added to the conversation Config.CacheWriteToID; a failure to store one 122 | // is reported but does not abort the request. In ExecEngineMode a single-line 123 | // reply that does not start with noExec is marked executable. 124 | // 125 | // The running flag is cleared on every return path. An error is returned when 126 | // the model yields no choices. 127 | func (e *Engine) CreateStreamCompletion(ctx context.Context, messages []llms.ChatMessage) (*StreamCompletionOutput, error) { 128 | e.running = true 129 | 130 | streamingFunc := func(ctx context.Context, chunk []byte) error { 131 | if !e.Config.Quiet { 132 | e.channel <- StreamCompletionOutput{ 133 | Content: string(chunk), 134 | Last: false, 135 | } 136 | } 137 | return nil 138 | } 139 | 140 | if err := e.setupChatContext(ctx, &messages); err != nil { 141 | e.running = false 142 | return nil, err 143 | } 144 | 145 | for _, v := range messages { 146 | err := e.convoStore.AddMessage(ctx, e.Config.CacheWriteToID, v) 147 | if err != nil { 148 | // Best effort: report the failure and continue with the request. 149 | errbook.HandleError(errbook.Wrap("Failed to add user chat input message to convo", err)) 150 | } 151 | } 152 | 153 | messageParts := slices.Map(messages, convert) 154 | rsp, err := e.model.GenerateContent(ctx, messageParts, e.callOptions(streamingFunc)...) 155 | if err != nil { 156 | e.running = false 157 | return nil, errbook.Wrap("Failed to create stream completion.", err) 158 | } 159 | // Guard the Choices[0] access below; an empty reply would otherwise panic. 160 | if len(rsp.Choices) == 0 { 161 | e.running = false 162 | return nil, errbook.Wrap("Failed to create stream completion.", fmt.Errorf("model returned no choices")) 163 | } 164 | 165 | executable := false 166 | output := rsp.Choices[0].Content 167 | 168 | if e.mode == ExecEngineMode { 169 | // Only a single-line reply without the noExec marker is runnable. 170 | if !strings.HasPrefix(output, noExec) && !strings.Contains(output, "\n") { 171 | executable = true 172 | } 173 | } 174 | 175 | output = html.UnescapeString(output) 176 | 177 | if !e.Config.Quiet { 178 | e.channel <- StreamCompletionOutput{ 179 | Content: "", 180 | Last: true, 181 | Executable: executable, 182 | Usage: rsp.Usage, 183 | } 184 | } 185 | e.running = false 186 | 187 | e.appendAssistantMessage(output) 188 | 189 | return &StreamCompletionOutput{ 190 | Content: output, 191 | Last: true, 192 | Executable: executable, 193 | Usage: rsp.Usage, 194 | }, nil 195 | } 196 | 197 | // callOptions builds the llms call options from the engine configuration. 198 | // MaxTokens is applied only when positive, and the optional streaming function 199 | // only when it is non-nil. 200 | func (e *Engine) callOptions(streamingFunc ...func(ctx context.Context, chunk []byte) error) []llms.CallOption { 201 | var opts []llms.CallOption 202 | if e.Config.MaxTokens > 0 { 203 | opts = append(opts, llms.WithMaxTokens(e.Config.MaxTokens)) 204 | } 205 | if len(streamingFunc) > 0 && streamingFunc[0] != nil { 206 | opts = append(opts, llms.WithStreamingFunc(streamingFunc[0])) 207 | } 208 | opts = append(opts, llms.WithModel(e.Config.CurrentModel.Name)) 209 | opts = append(opts, llms.WithMaxLength(e.Config.CurrentModel.MaxChars)) 210 | opts = append(opts, llms.WithTemperature(e.Config.Temperature)) 211 | opts = append(opts, llms.WithTopP(e.Config.TopP)) 212 | opts = append(opts, llms.WithTopK(e.Config.TopK)) 213 | opts = append(opts, llms.WithMultiContent(false)) 214 | 215 | return opts 216 | } 217 | 218 | func (e *Engine) setupChatContext(ctx context.Context, messages *[]llms.ChatMessage) error { 219 | store := e.convoStore 220 | if store == nil { 221
| return errbook.New("no chat convo store found") 175 | } 176 | 177 | if !e.Config.NoCache && e.Config.CacheReadFromID != "" { 178 | history, err := store.Messages(ctx, e.Config.CacheReadFromID) 179 | if err != nil { 180 | return errbook.Wrap(fmt.Sprintf( 181 | "There was a problem reading the cache. Use %s / %s to disable it.", 182 | console.StderrStyles().InlineCode.Render("--no-cache"), 183 | console.StderrStyles().InlineCode.Render("NO_CACHE"), 184 | ), err) 185 | } 186 | *messages = append(*messages, history...) 187 | } 188 | 189 | return nil 190 | } 191 | 192 | func (e *Engine) appendAssistantMessage(content string) { 193 | if e.convoStore != nil && e.Config.CacheWriteToID != "" { 194 | if err := e.convoStore.AddAIMessage(context.Background(), e.Config.CacheWriteToID, content); err != nil { 195 | errbook.HandleError(errbook.Wrap("failed to add assistant chat output message to convo", err)) 196 | } 197 | } 198 | } 199 | 200 | func convert(msg llms.ChatMessage) llms.MessageContent { 201 | return llms.MessageContent{ 202 | Role: msg.GetType(), 203 | Parts: []llms.ContentPart{llms.TextPart(msg.GetContent())}, 204 | } 205 | } 206 | -------------------------------------------------------------------------------- /internal/ai/ai_options.go: -------------------------------------------------------------------------------- 1 | package ai 2 | 3 | import ( 4 | "github.com/volcengine/volcengine-go-sdk/service/arkruntime" 5 | 6 | "github.com/coding-hui/ai-terminal/internal/convo" 7 | "github.com/coding-hui/ai-terminal/internal/errbook" 8 | "github.com/coding-hui/ai-terminal/internal/options" 9 | 10 | "github.com/coding-hui/wecoding-sdk-go/services/ai/llms/openai" 11 | "github.com/coding-hui/wecoding-sdk-go/services/ai/llms/volcengine" 12 | ) 13 | 14 | type Option func(*Engine) 15 | 16 | func WithMode(mode EngineMode) Option { 17 | return func(e *Engine) { 18 | e.mode = mode 19 | } 20 | } 21 | 22 | func WithConfig(cfg *options.Config) Option { 23 | return func(e *Engine) { 
24 | e.Config = cfg 25 | } 26 | } 27 | 28 | func WithStore(store convo.Store) Option { 29 | return func(a *Engine) { 30 | a.convoStore = store 31 | } 32 | } 33 | 34 | func applyOptions(engineOpts ...Option) (engine *Engine, err error) { 35 | engine = &Engine{ 36 | channel: make(chan StreamCompletionOutput), 37 | running: false, 38 | } 39 | 40 | for _, option := range engineOpts { 41 | option(engine) 42 | } 43 | 44 | cfg := engine.Config 45 | if cfg == nil { 46 | return nil, errbook.New("Failed to initialize engine. Config is nil.") 47 | } 48 | 49 | if engine.convoStore == nil { 50 | engine.convoStore, err = convo.GetConversationStore(cfg) 51 | if err != nil { 52 | return nil, errbook.Wrap("Failed to get chat convo store.", err) 53 | } 54 | } 55 | 56 | cfg.CurrentModel, err = cfg.GetModel(cfg.Model) 57 | if err != nil { 58 | return nil, err 59 | } 60 | 61 | cfg.CurrentAPI, err = cfg.GetAPI(cfg.API) 62 | if err != nil { 63 | return nil, err 64 | } 65 | 66 | mod, api := cfg.CurrentModel, cfg.CurrentAPI 67 | switch api.Name { 68 | case ModelTypeARK: 69 | engine.model, err = volcengine.NewClientWithApiKey( 70 | cfg.CurrentAPI.APIKey, 71 | arkruntime.WithBaseUrl(api.BaseURL), 72 | arkruntime.WithRegion(api.Region), 73 | arkruntime.WithTimeout(api.Timeout), 74 | arkruntime.WithRetryTimes(api.RetryTimes), 75 | ) 76 | if err != nil { 77 | return nil, err 78 | } 79 | default: 80 | engine.model, err = openai.New( 81 | openai.WithModel(mod.Name), 82 | openai.WithBaseURL(api.BaseURL), 83 | openai.WithToken(api.APIKey), 84 | ) 85 | if err != nil { 86 | return nil, err 87 | } 88 | } 89 | 90 | return engine, nil 91 | } 92 | -------------------------------------------------------------------------------- /internal/ai/logging.go: -------------------------------------------------------------------------------- 1 | package ai 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "strings" 7 | 8 | "k8s.io/klog/v2" 9 | 10 | "github.com/coding-hui/wecoding-sdk-go/services/ai/callbacks" 11 | 
"github.com/coding-hui/wecoding-sdk-go/services/ai/llms" 12 | ) 13 | 14 | // LogHandler is a callback handler that prints to the standard output. 15 | type LogHandler struct{} 16 | 17 | var _ callbacks.Handler = LogHandler{} 18 | 19 | func (l LogHandler) HandleLLMGenerateContentStart(_ context.Context, ms []llms.MessageContent) { 20 | info := strings.Builder{} 21 | info.WriteString("Entering LLM with messages:\n") 22 | for _, m := range ms { 23 | var buf strings.Builder 24 | for _, t := range m.Parts { 25 | if t, ok := t.(llms.TextContent); ok { 26 | buf.WriteString(t.Text) 27 | } 28 | } 29 | info.WriteString(fmt.Sprintf(" Role: %s\n", m.Role)) 30 | info.WriteString(fmt.Sprintf(" Text: %s\n", strings.TrimSpace(buf.String()))) 31 | } 32 | klog.V(2).Info(info.String()) 33 | } 34 | 35 | func (l LogHandler) HandleLLMGenerateContentEnd(_ context.Context, res *llms.ContentResponse) { 36 | info := strings.Builder{} 37 | info.WriteString("Exiting LLM with response:\n") 38 | for _, c := range res.Choices { 39 | if c.Content != "" { 40 | info.WriteString(fmt.Sprintf(" Content: %s\n", c.Content)) 41 | } 42 | if c.StopReason != "" { 43 | info.WriteString(fmt.Sprintf(" StopReason: %s\n", c.StopReason)) 44 | } 45 | if len(c.GenerationInfo) > 0 { 46 | info.WriteString(" GenerationInfo:\n") 47 | for k, v := range c.GenerationInfo { 48 | info.WriteString(fmt.Sprintf("%20s: %v\n", k, v)) 49 | } 50 | } 51 | if c.FuncCall != nil { 52 | info.WriteString(fmt.Sprintf(" FuncCall: %s %s\n", c.FuncCall.Name, c.FuncCall.Arguments)) 53 | } 54 | } 55 | klog.V(2).Info(info.String()) 56 | } 57 | 58 | func (l LogHandler) HandleStreamingFunc(_ context.Context, chunk []byte) { 59 | fmt.Println(string(chunk)) 60 | } 61 | 62 | func (l LogHandler) HandleText(_ context.Context, text string) { 63 | fmt.Println(text) 64 | } 65 | 66 | func (l LogHandler) HandleLLMStart(_ context.Context, prompts []string) { 67 | fmt.Println("Entering LLM with prompt:", prompts) 68 | } 69 | 70 | func (l LogHandler) 
HandleLLMError(_ context.Context, err error) { 71 | fmt.Println("Exiting LLM with error:", err) 72 | } 73 | 74 | func (l LogHandler) HandleChainStart(_ context.Context, inputs map[string]any) { 75 | fmt.Println("Entering chain with inputs:", formatChainValues(inputs)) 76 | } 77 | 78 | func (l LogHandler) HandleChainEnd(_ context.Context, outputs map[string]any) { 79 | fmt.Println("Exiting chain with outputs:", formatChainValues(outputs)) 80 | } 81 | 82 | func (l LogHandler) HandleChainError(_ context.Context, err error) { 83 | fmt.Println("Exiting chain with error:", err) 84 | } 85 | 86 | func (l LogHandler) HandleToolStart(_ context.Context, input string) { 87 | fmt.Println("Entering tool with input:", removeNewLines(input)) 88 | } 89 | 90 | func (l LogHandler) HandleToolEnd(_ context.Context, output string) { 91 | fmt.Println("Exiting tool with output:", removeNewLines(output)) 92 | } 93 | 94 | func (l LogHandler) HandleToolError(_ context.Context, err error) { 95 | fmt.Println("Exiting tool with error:", err) 96 | } 97 | 98 | func formatChainValues(values map[string]any) string { 99 | output := "" 100 | for key, value := range values { 101 | output += fmt.Sprintf("\"%s\" : \"%s\", ", removeNewLines(key), removeNewLines(value)) 102 | } 103 | 104 | return output 105 | } 106 | 107 | func removeNewLines(s any) string { 108 | return strings.ReplaceAll(fmt.Sprint(s), "\n", " ") 109 | } 110 | -------------------------------------------------------------------------------- /internal/ai/model.go: -------------------------------------------------------------------------------- 1 | package ai 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/coding-hui/wecoding-sdk-go/services/ai/llms" 7 | ) 8 | 9 | const ( 10 | ModelTypeOpenAI = "openai" 11 | ModelTypeARK = "ark" 12 | ) 13 | 14 | type Model interface { 15 | GenerateContent(context.Context, []llms.MessageContent, ...llms.CallOption) (*llms.ContentResponse, error) 16 | } 17 | 
-------------------------------------------------------------------------------- /internal/ai/prompts.go: -------------------------------------------------------------------------------- 1 | package ai 2 | -------------------------------------------------------------------------------- /internal/ai/types.go: -------------------------------------------------------------------------------- 1 | package ai 2 | 3 | import "github.com/coding-hui/wecoding-sdk-go/services/ai/llms" 4 | 5 | type EngineMode int 6 | 7 | const ( 8 | ChatEngineMode EngineMode = iota 9 | ExecEngineMode 10 | ) 11 | 12 | func (m EngineMode) String() string { 13 | if m == ExecEngineMode { 14 | return "exec" 15 | } else { 16 | return "chat" 17 | } 18 | } 19 | 20 | type CompletionOutput struct { 21 | Command string `json:"cmd"` 22 | Explanation string `json:"exp"` 23 | Executable bool `json:"exec"` 24 | 25 | Usage llms.Usage `json:"usage"` 26 | } 27 | 28 | func (c CompletionOutput) GetCommand() string { 29 | return c.Command 30 | } 31 | 32 | func (c CompletionOutput) GetExplanation() string { 33 | return c.Explanation 34 | } 35 | 36 | func (c CompletionOutput) IsExecutable() bool { 37 | return c.Executable 38 | } 39 | 40 | // CompletionInput is a tea.Msg that wraps the content read from stdin. 41 | type CompletionInput struct { 42 | Messages []llms.ChatMessage 43 | } 44 | 45 | // StreamCompletionOutput a tea.Msg that wraps the content returned from ai. 
46 | type StreamCompletionOutput struct { 47 | Content string 48 | Last bool 49 | Interrupt bool 50 | Executable bool 51 | 52 | Usage llms.Usage `json:"usage"` 53 | } 54 | 55 | func (c StreamCompletionOutput) GetContent() string { 56 | return c.Content 57 | } 58 | 59 | func (c StreamCompletionOutput) IsLast() bool { 60 | return c.Last 61 | } 62 | 63 | func (c StreamCompletionOutput) IsInterrupt() bool { 64 | return c.Interrupt 65 | } 66 | 67 | func (c StreamCompletionOutput) IsExecutable() bool { 68 | return c.Executable 69 | } 70 | 71 | func (c StreamCompletionOutput) GetUsage() llms.Usage { 72 | return c.Usage 73 | } 74 | -------------------------------------------------------------------------------- /internal/cli/ask/ask.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023 coding-hui. All rights reserved. 2 | // Use of this source code is governed by a MIT style 3 | // license that can be found in the LICENSE file. 4 | 5 | package ask 6 | 7 | import ( 8 | "os" 9 | "strings" 10 | 11 | "github.com/spf13/cobra" 12 | 13 | "github.com/coding-hui/ai-terminal/internal/ai" 14 | "github.com/coding-hui/ai-terminal/internal/errbook" 15 | "github.com/coding-hui/ai-terminal/internal/options" 16 | "github.com/coding-hui/ai-terminal/internal/ui" 17 | "github.com/coding-hui/ai-terminal/internal/ui/chat" 18 | "github.com/coding-hui/ai-terminal/internal/util/genericclioptions" 19 | "github.com/coding-hui/ai-terminal/internal/util/templates" 20 | "github.com/coding-hui/ai-terminal/internal/util/term" 21 | ) 22 | 23 | var askExample = templates.Examples(` 24 | # You can ask any question, enforcing 💬 ask prompt mode: 25 | ai ask generate me a go application example using fiber 26 | 27 | # You can also pipe input that will be taken into account in your request: 28 | cat some_script.go | ai ask generate unit tests 29 | 30 | # Write new sections for a readme": 31 | cat README.md | ai ask "write a new section to this README 
documenting a pdf sharing feature" 32 | `) 33 | 34 | // Options is a struct to support ask command. 35 | type Options struct { 36 | genericclioptions.IOStreams 37 | pipe string 38 | prompts []string 39 | tempPromptFile string 40 | cfg *options.Config 41 | } 42 | 43 | // NewOptions returns initialized Options. 44 | func NewOptions(ioStreams genericclioptions.IOStreams, cfg *options.Config) *Options { 45 | return &Options{ 46 | IOStreams: ioStreams, 47 | cfg: cfg, 48 | } 49 | } 50 | 51 | // NewCmdASK returns a cobra command for ask any question. 52 | func NewCmdASK(ioStreams genericclioptions.IOStreams, cfg *options.Config) *cobra.Command { 53 | o := NewOptions(ioStreams, cfg) 54 | cmd := &cobra.Command{ 55 | Use: "ask", 56 | Short: "CLI mode is made to be integrated in your command lines workflow.", 57 | Example: askExample, 58 | PreRunE: func(c *cobra.Command, args []string) error { 59 | err := o.preparePrompts(args) 60 | if err != nil { 61 | return err 62 | } 63 | return nil 64 | }, 65 | RunE: func(cmd *cobra.Command, args []string) error { 66 | return o.Run() 67 | }, 68 | PostRunE: func(c *cobra.Command, args []string) error { 69 | if o.tempPromptFile != "" { 70 | err := os.Remove(o.tempPromptFile) 71 | if err != nil { 72 | return errbook.Wrap("Failed to remove temporary file: "+o.tempPromptFile, err) 73 | } 74 | } 75 | return nil 76 | }, 77 | } 78 | 79 | cmd.Flags().BoolVarP(&o.cfg.Interactive, "interactive", "i", o.cfg.Interactive, "Interactive dialogue model.") 80 | cmd.Flags().StringVarP(&o.cfg.PromptFile, "file", "f", o.cfg.PromptFile, "File containing prompt.") 81 | 82 | return cmd 83 | } 84 | 85 | // Validate validates the provided options. 86 | func (o *Options) Validate() error { 87 | return nil 88 | } 89 | 90 | // Run executes ask command. 
91 | func (o *Options) Run() error { 92 | runMode := ui.CliMode 93 | if o.cfg.Interactive { 94 | runMode = ui.ReplMode 95 | } 96 | 97 | engine, err := ai.New(ai.WithConfig(o.cfg)) 98 | if err != nil { 99 | return err 100 | } 101 | 102 | chatModel := chat.NewChat(o.cfg, 103 | chat.WithContent(o.pipe+"\n\n"+strings.Join(o.prompts, "\n\n")), 104 | chat.WithRunMode(runMode), 105 | chat.WithEngine(engine), 106 | ) 107 | 108 | return chatModel.Run() 109 | } 110 | 111 | func (o *Options) preparePrompts(args []string) error { 112 | if len(args) > 0 { 113 | o.prompts = append(o.prompts, strings.Join(args, " ")) 114 | } 115 | 116 | if o.cfg.PromptFile != "" { 117 | bytes, err := os.ReadFile(o.cfg.PromptFile) 118 | if err != nil { 119 | return errbook.Wrap("Couldn't reading prompt file.", err) 120 | } 121 | o.prompts = append(o.prompts, string(bytes)) 122 | } 123 | 124 | o.pipe = term.ReadPipeInput() 125 | 126 | return nil 127 | } 128 | -------------------------------------------------------------------------------- /internal/cli/cli.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023-2024 coding-hui. All rights reserved. 2 | // Use of this source code is governed by a MIT style 3 | // license that can be found in the LICENSE file. 4 | 5 | // Package cli create a root cobra command and add subcommands to it. 
6 | package cli 7 | 8 | import ( 9 | "flag" 10 | "io" 11 | "os" 12 | "slices" 13 | 14 | "github.com/spf13/cobra" 15 | 16 | cliflag "github.com/coding-hui/common/cli/flag" 17 | 18 | "github.com/coding-hui/ai-terminal/internal/cli/ask" 19 | "github.com/coding-hui/ai-terminal/internal/cli/coder" 20 | "github.com/coding-hui/ai-terminal/internal/cli/commit" 21 | "github.com/coding-hui/ai-terminal/internal/cli/completion" 22 | "github.com/coding-hui/ai-terminal/internal/cli/configure" 23 | "github.com/coding-hui/ai-terminal/internal/cli/convo" 24 | "github.com/coding-hui/ai-terminal/internal/cli/hook" 25 | "github.com/coding-hui/ai-terminal/internal/cli/loadctx" 26 | "github.com/coding-hui/ai-terminal/internal/cli/manpage" 27 | "github.com/coding-hui/ai-terminal/internal/cli/review" 28 | "github.com/coding-hui/ai-terminal/internal/cli/version" 29 | "github.com/coding-hui/ai-terminal/internal/errbook" 30 | "github.com/coding-hui/ai-terminal/internal/options" 31 | "github.com/coding-hui/ai-terminal/internal/util/debug" 32 | "github.com/coding-hui/ai-terminal/internal/util/genericclioptions" 33 | "github.com/coding-hui/ai-terminal/internal/util/templates" 34 | 35 | _ "github.com/coding-hui/ai-terminal/internal/convo/sqlite3" 36 | ) 37 | 38 | // var logFlushFreq = pflag.Duration(options.FlagLogFlushFrequency, 5*time.Second, "Maximum number of seconds between log flushes") 39 | 40 | // NewDefaultAICommand creates the `ai` command with default arguments. 41 | func NewDefaultAICommand() *cobra.Command { 42 | return NewAICommand(os.Stdin, os.Stdout, os.Stderr) 43 | } 44 | 45 | // NewAICommand returns new initialized instance of 'ai' root command. 46 | func NewAICommand(in io.Reader, out, errOut io.Writer) *cobra.Command { 47 | cfg, err := options.EnsureConfig() 48 | if err != nil { 49 | errbook.HandleError(errbook.Wrap("Could not load your configuration file.", err)) 50 | // if user is editing the settings, only print out the error, but do 51 | // not exit. 
52 | if !slices.Contains(os.Args, "--settings") { 53 | os.Exit(1) 54 | } 55 | } 56 | 57 | // Parent command to which all subcommands are added. 58 | cmds := &cobra.Command{ 59 | Use: "ai", 60 | Short: "AI driven development in your terminal.", 61 | Long: templates.LongDesc(` 62 | AI driven development in your terminal. 63 | 64 | Find more information at: 65 | https://github.com/coding-hui/ai-terminal`), 66 | SilenceUsage: true, 67 | SilenceErrors: true, 68 | Run: runHelp, 69 | // Hook before and after Run initialize and write profiles to disk, 70 | // respectively. 71 | PersistentPreRunE: func(*cobra.Command, []string) error { 72 | return initProfiling() 73 | }, 74 | PersistentPostRunE: func(*cobra.Command, []string) error { 75 | return postRunHook(&cfg) 76 | }, 77 | } 78 | 79 | flags := cmds.PersistentFlags() 80 | flags.SetNormalizeFunc(cliflag.WarnWordSepNormalizeFunc) // Warn for "_" flags 81 | 82 | // Normalize all flags that are coming from other packages or pre-configurations 83 | // a.k.a. change all "_" to "-". e.g. 
glog package 84 | flags.SetNormalizeFunc(cliflag.WordSepNormalizeFunc) 85 | 86 | addProfilingFlags(flags) 87 | 88 | options.AddBasicFlags(flags, &cfg) 89 | 90 | cmds.PersistentFlags().AddGoFlagSet(flag.CommandLine) 91 | 92 | // From this point and forward we get warnings on flags that contain "_" separators 93 | cmds.SetGlobalNormalizationFunc(cliflag.WarnWordSepNormalizeFunc) 94 | 95 | ioStreams := genericclioptions.IOStreams{In: in, Out: out, ErrOut: errOut} 96 | 97 | groups := templates.CommandGroups{ 98 | templates.CommandGroup{ 99 | Message: "AI Commands:", 100 | Commands: []*cobra.Command{ 101 | coder.NewCmdCoder(&cfg), 102 | ask.NewCmdASK(ioStreams, &cfg), 103 | convo.NewCmdConversation(ioStreams, &cfg), 104 | commit.NewCmdCommit(ioStreams, &cfg), 105 | review.NewCmdCommit(ioStreams, &cfg), 106 | loadctx.NewCmdContext(ioStreams, &cfg), 107 | }, 108 | }, 109 | templates.CommandGroup{ 110 | Message: "Settings Commands:", 111 | Commands: []*cobra.Command{ 112 | configure.NewCmdConfigure(ioStreams, &cfg), 113 | completion.NewCmdCompletion(), 114 | manpage.NewCmdManPage(cmds), 115 | hook.NewCmdHook(), 116 | }, 117 | }, 118 | } 119 | groups.Add(cmds) 120 | 121 | filters := []string{"options"} 122 | templates.ActsAsRootCommand(cmds, filters, groups...) 
123 | 124 | cmds.AddCommand(version.NewCmdVersion(ioStreams)) 125 | cmds.AddCommand(options.NewCmdOptions(ioStreams.Out)) 126 | 127 | defer func() { 128 | debug.Teardown() 129 | }() 130 | 131 | return cmds 132 | } 133 | 134 | func runHelp(cmd *cobra.Command, _ []string) { 135 | _ = cmd.Help() 136 | } 137 | 138 | func postRunHook(_ *options.Config) error { 139 | if err := flushProfiling(); err != nil { 140 | return err 141 | } 142 | return nil 143 | } 144 | -------------------------------------------------------------------------------- /internal/cli/coder/coder.go: -------------------------------------------------------------------------------- 1 | package coder 2 | 3 | import ( 4 | "path/filepath" 5 | "strings" 6 | 7 | "github.com/spf13/cobra" 8 | 9 | "github.com/coding-hui/ai-terminal/internal/ai" 10 | "github.com/coding-hui/ai-terminal/internal/convo" 11 | "github.com/coding-hui/ai-terminal/internal/errbook" 12 | "github.com/coding-hui/ai-terminal/internal/git" 13 | "github.com/coding-hui/ai-terminal/internal/options" 14 | "github.com/coding-hui/ai-terminal/internal/ui/coders" 15 | ) 16 | 17 | type Options struct { 18 | cfg *options.Config 19 | prompt string 20 | } 21 | 22 | func NewCmdCoder(cfg *options.Config) *cobra.Command { 23 | ops := &Options{cfg: cfg} 24 | cmd := &cobra.Command{ 25 | Use: "coder", 26 | Short: "Automatically generate code based on prompts.", 27 | RunE: ops.run, 28 | } 29 | 30 | cmd.Flags().StringVarP(&ops.prompt, "prompt", "p", "", "Prompt to generate code.") 31 | 32 | return cmd 33 | } 34 | 35 | func (o *Options) run(_ *cobra.Command, args []string) error { 36 | if len(args) > 0 { 37 | o.prompt = strings.Join(args, " ") + "\n" + o.prompt 38 | } 39 | 40 | repo := git.New() 41 | root, err := repo.GitDir() 42 | if err != nil { 43 | return errbook.Wrap("Could not get git root", err) 44 | } 45 | 46 | engine, err := ai.New(ai.WithConfig(o.cfg)) 47 | if err != nil { 48 | return errbook.Wrap("Could not initialized ai engine", err) 49 | } 50 | 51 
| store, err := convo.GetConversationStore(o.cfg) 52 | if err != nil { 53 | return errbook.Wrap("Could not initialize conversation store", err) 54 | } 55 | 56 | autoCoder := coders.NewAutoCoder( 57 | coders.WithConfig(o.cfg), 58 | coders.WithEngine(engine), 59 | coders.WithRepo(repo), 60 | coders.WithCodeBasePath(filepath.Dir(root)), 61 | coders.WithStore(store), 62 | coders.WithPrompt(o.prompt), 63 | ) 64 | 65 | return autoCoder.Run() 66 | } 67 | -------------------------------------------------------------------------------- /internal/cli/commit/options.go: -------------------------------------------------------------------------------- 1 | package commit 2 | 3 | import ( 4 | "github.com/coding-hui/wecoding-sdk-go/services/ai/llms" 5 | 6 | "github.com/coding-hui/ai-terminal/internal/options" 7 | "github.com/coding-hui/ai-terminal/internal/util/genericclioptions" 8 | ) 9 | 10 | type Options struct { 11 | commitMsgFile string 12 | preview bool 13 | diffUnified int 14 | excludeList []string 15 | templateFile string 16 | templateString string 17 | commitAmend bool 18 | noConfirm bool 19 | commitLang string 20 | userPrompt string 21 | commitPrefix string 22 | 23 | cfg *options.Config 24 | genericclioptions.IOStreams 25 | 26 | FilesToAdd []string 27 | TokenUsage llms.Usage 28 | 29 | // Step token usages 30 | CodeReviewUsage llms.Usage 31 | SummarizeTitleUsage llms.Usage 32 | SummarizePrefixUsage llms.Usage 33 | TranslationUsage llms.Usage 34 | } 35 | 36 | // Option defines a function type for configuring Options 37 | type Option func(*Options) 38 | 39 | // WithNoConfirm sets the noConfirm flag 40 | func WithNoConfirm(noConfirm bool) Option { 41 | return func(o *Options) { 42 | o.noConfirm = noConfirm 43 | } 44 | } 45 | 46 | // WithFilesToAdd sets the files to add 47 | func WithFilesToAdd(files []string) Option { 48 | return func(o *Options) { 49 | o.FilesToAdd = files 50 | } 51 | } 52 | 53 | // WithIOStreams sets the IO streams 54 | func WithIOStreams(ioStreams 
genericclioptions.IOStreams) Option { 55 | return func(o *Options) { 56 | o.IOStreams = ioStreams 57 | } 58 | } 59 | 60 | // WithConfig sets the configuration 61 | func WithConfig(cfg *options.Config) Option { 62 | return func(o *Options) { 63 | o.cfg = cfg 64 | } 65 | } 66 | 67 | // WithCommitPrefix sets the commit prefix 68 | func WithCommitPrefix(prefix string) Option { 69 | return func(o *Options) { 70 | o.commitPrefix = prefix 71 | } 72 | } 73 | 74 | // WithCommitLang sets the commit language 75 | func WithCommitLang(lang string) Option { 76 | return func(o *Options) { 77 | o.commitLang = lang 78 | } 79 | } 80 | -------------------------------------------------------------------------------- /internal/cli/completion/completion.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023 coding-hui. All rights reserved. 2 | // Use of this source code is governed by a MIT style 3 | // license that can be found in the LICENSE file. 4 | 5 | // Package completion output shell completion code for the specified shell (bash or zsh). 6 | package completion 7 | 8 | import ( 9 | "os" 10 | 11 | "github.com/spf13/cobra" 12 | ) 13 | 14 | // NewCmdCompletion creates the `completion` command. 
15 | func NewCmdCompletion() *cobra.Command { 16 | return &cobra.Command{ 17 | Use: "completion [bash|zsh|fish|powershell]", 18 | Short: "Generate completion script", 19 | Long: "To load completions", 20 | DisableFlagsInUseLine: true, 21 | ValidArgs: []string{"bash", "zsh", "fish", "powershell"}, 22 | Args: cobra.MatchAll(cobra.ExactArgs(1), cobra.OnlyValidArgs), 23 | Run: func(cmd *cobra.Command, args []string) { 24 | switch args[0] { 25 | case "bash": 26 | _ = cmd.Root().GenBashCompletion(os.Stdout) 27 | case "zsh": 28 | _ = cmd.Root().GenZshCompletion(os.Stdout) 29 | case "fish": 30 | _ = cmd.Root().GenFishCompletion(os.Stdout, true) 31 | case "powershell": 32 | _ = cmd.Root().GenPowerShellCompletionWithDesc(os.Stdout) 33 | } 34 | }, 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /internal/cli/configure/configure.go: -------------------------------------------------------------------------------- 1 | package configure 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | 7 | "github.com/spf13/cobra" 8 | 9 | "github.com/coding-hui/ai-terminal/internal/errbook" 10 | "github.com/coding-hui/ai-terminal/internal/options" 11 | "github.com/coding-hui/ai-terminal/internal/runner" 12 | "github.com/coding-hui/ai-terminal/internal/system" 13 | "github.com/coding-hui/ai-terminal/internal/util/genericclioptions" 14 | ) 15 | 16 | // NewCmdConfigure implements the configure command. 
17 | func NewCmdConfigure(ioStreams genericclioptions.IOStreams, cfg *options.Config) *cobra.Command { 18 | cmd := &cobra.Command{ 19 | Use: "configure", 20 | Aliases: []string{"conf", "cfg", "config", "settings"}, 21 | Short: "Manage AI Terminal configuration settings", 22 | Example: ` # Open settings in default editor 23 | ai cfg 24 | 25 | # Reset settings to defaults 26 | ai cfg reset 27 | 28 | # Show current configuration 29 | ai cfg echo`, 30 | RunE: func(cmd *cobra.Command, args []string) error { 31 | editor := system.Analyse().GetEditor() 32 | editorCmd := runner.PrepareEditSettingsCommand(editor, cfg.SettingsPath) 33 | editorCmd.Stdin, editorCmd.Stdout, editorCmd.Stderr = ioStreams.In, ioStreams.Out, ioStreams.ErrOut 34 | 35 | err := editorCmd.Start() 36 | if err != nil { 37 | return errbook.Wrap("Could not edit your settings file.", err) 38 | } 39 | 40 | err = editorCmd.Wait() 41 | if err != nil { 42 | return errbook.Wrap("Could not wait for your settings file to be saved.", err) 43 | } 44 | 45 | if !cfg.Quiet { 46 | fmt.Fprintln(os.Stderr, "Wrote config file to:", cfg.SettingsPath) 47 | } 48 | 49 | return nil 50 | }, 51 | } 52 | 53 | cmd.AddCommand(newCmdResetConfig(ioStreams, cfg)) 54 | cmd.AddCommand(newCmdEchoConfig(ioStreams, cfg)) 55 | 56 | return cmd 57 | } 58 | -------------------------------------------------------------------------------- /internal/cli/configure/echo.go: -------------------------------------------------------------------------------- 1 | package configure 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "os" 7 | "text/template" 8 | 9 | "github.com/spf13/cobra" 10 | 11 | "github.com/coding-hui/ai-terminal/internal/errbook" 12 | "github.com/coding-hui/ai-terminal/internal/options" 13 | "github.com/coding-hui/ai-terminal/internal/util/genericclioptions" 14 | ) 15 | 16 | type echo struct { 17 | genericclioptions.IOStreams 18 | Template string 19 | } 20 | 21 | func newCmdEchoConfig(ioStreams genericclioptions.IOStreams, cfg *options.Config) 
*cobra.Command { 22 | e := &echo{ 23 | IOStreams: ioStreams, 24 | Template: "", 25 | } 26 | cmd := &cobra.Command{ 27 | Use: "echo", 28 | Short: "Display current configuration settings", 29 | Example: ` # Show full configuration 30 | ai cfg echo 31 | 32 | # Show specific setting using template 33 | ai cfg echo -t "{{.model}}" 34 | 35 | # Check cache location 36 | ai cfg echo -t "Cache: {{.DataStore.CachePath}}"`, 37 | SilenceUsage: true, 38 | RunE: func(cmd *cobra.Command, args []string) error { 39 | return e.echoSettings(cfg) 40 | }, 41 | } 42 | 43 | cmd.Flags().StringVarP(&e.Template, "template", "t", e.Template, "Template string to format the settings output.") 44 | 45 | return cmd 46 | } 47 | 48 | func (e *echo) echoSettings(cfg *options.Config) error { 49 | _, err := os.Stat(cfg.SettingsPath) 50 | if err != nil { 51 | return errbook.Wrap("Couldn't read config file.", err) 52 | } 53 | inputFile, err := os.Open(cfg.SettingsPath) 54 | if err != nil { 55 | return errbook.Wrap("Couldn't open config file.", err) 56 | } 57 | defer inputFile.Close() //nolint:errcheck 58 | 59 | if e.Template != "" { 60 | tmpl, err := template.New("settings").Parse(e.Template) 61 | if err != nil { 62 | return errbook.Wrap("Couldn't pares template.", err) 63 | } 64 | err = tmpl.Execute(e.Out, cfg) 65 | if err != nil { 66 | return errbook.Wrap("Couldn't render template.", err) 67 | } 68 | return nil 69 | } 70 | 71 | _, _ = fmt.Fprintln(e.Out, "Current settings:") 72 | _, err = io.Copy(e.Out, inputFile) 73 | if err != nil { 74 | return errbook.Wrap("Couldn't echo config file.", err) 75 | } 76 | 77 | return nil 78 | } 79 | -------------------------------------------------------------------------------- /internal/cli/configure/reset.go: -------------------------------------------------------------------------------- 1 | package configure 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "os" 7 | 8 | "github.com/spf13/cobra" 9 | 10 | "github.com/coding-hui/ai-terminal/internal/errbook" 11 | 
"github.com/coding-hui/ai-terminal/internal/options" 12 | "github.com/coding-hui/ai-terminal/internal/ui/console" 13 | "github.com/coding-hui/ai-terminal/internal/util/genericclioptions" 14 | ) 15 | 16 | type reset struct { 17 | genericclioptions.IOStreams 18 | } 19 | 20 | func newCmdResetConfig(ioStreams genericclioptions.IOStreams, cfg *options.Config) *cobra.Command { 21 | r := &reset{ 22 | IOStreams: ioStreams, 23 | } 24 | cmd := &cobra.Command{ 25 | Use: "reset", 26 | Short: "Reset configuration to default values", 27 | Example: ` # Reset settings with backup 28 | ai cfg reset 29 | 30 | # Reset and open new config in editor 31 | ai cfg reset && ai cfg`, 32 | SilenceUsage: true, 33 | RunE: func(cmd *cobra.Command, args []string) error { 34 | return r.resetSettings(cfg) 35 | }, 36 | } 37 | 38 | return cmd 39 | } 40 | 41 | func (r *reset) resetSettings(cfg *options.Config) error { 42 | _, err := os.Stat(cfg.SettingsPath) 43 | if err != nil { 44 | return errbook.Wrap("Couldn't read config file.", err) 45 | } 46 | inputFile, err := os.Open(cfg.SettingsPath) 47 | if err != nil { 48 | return errbook.Wrap("Couldn't open config file.", err) 49 | } 50 | defer inputFile.Close() //nolint:errcheck 51 | outputFile, err := os.Create(cfg.SettingsPath + ".bak") 52 | if err != nil { 53 | return errbook.Wrap("Couldn't backup config file.", err) 54 | } 55 | defer outputFile.Close() //nolint:errcheck 56 | _, err = io.Copy(outputFile, inputFile) 57 | if err != nil { 58 | return errbook.Wrap("Couldn't write config file.", err) 59 | } 60 | // The copy was successful, so now delete the original file 61 | err = os.Remove(cfg.SettingsPath) 62 | if err != nil { 63 | return errbook.Wrap("Couldn't remove config file.", err) 64 | } 65 | err = options.WriteConfigFile(cfg.SettingsPath) 66 | if err != nil { 67 | return errbook.Wrap("Couldn't write new config file.", err) 68 | } 69 | if !cfg.Quiet { 70 | _, _ = fmt.Fprintln(r.Out, "\nSettings restored to defaults!") 71 | _, _ = 
fmt.Fprintf(r.Out, 72 | "\n %s %s\n\n", 73 | console.StderrStyles().Comment.Render("Your old settings have been saved to:"), 74 | console.StderrStyles().Link.Render(cfg.SettingsPath+".bak"), 75 | ) 76 | } 77 | return nil 78 | } 79 | -------------------------------------------------------------------------------- /internal/cli/convo/convo.go: -------------------------------------------------------------------------------- 1 | package convo 2 | 3 | import ( 4 | "github.com/spf13/cobra" 5 | 6 | "github.com/coding-hui/ai-terminal/internal/options" 7 | "github.com/coding-hui/ai-terminal/internal/util/genericclioptions" 8 | ) 9 | 10 | // NewCmdConversation returns a cobra command for manager convo. 11 | func NewCmdConversation(ioStreams genericclioptions.IOStreams, cfg *options.Config) *cobra.Command { 12 | cmd := &cobra.Command{ 13 | Use: "convo", 14 | Short: "Managing chat conversations.", 15 | Run: func(cmd *cobra.Command, args []string) { 16 | _ = cmd.Help() 17 | }, 18 | } 19 | 20 | cmd.AddCommand(newCmdLsConversation(ioStreams, cfg)) 21 | cmd.AddCommand(newCmdRemoveConversation(ioStreams, cfg)) 22 | cmd.AddCommand(newCmdShowConversation(ioStreams, cfg)) 23 | 24 | return cmd 25 | } 26 | -------------------------------------------------------------------------------- /internal/cli/convo/ls.go: -------------------------------------------------------------------------------- 1 | package convo 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "os" 8 | 9 | "github.com/atotto/clipboard" 10 | timeago "github.com/caarlos0/timea.go" 11 | "github.com/charmbracelet/huh" 12 | "github.com/muesli/termenv" 13 | "github.com/spf13/cobra" 14 | 15 | "github.com/coding-hui/ai-terminal/internal/convo" 16 | "github.com/coding-hui/ai-terminal/internal/options" 17 | "github.com/coding-hui/ai-terminal/internal/ui/console" 18 | "github.com/coding-hui/ai-terminal/internal/util/genericclioptions" 19 | "github.com/coding-hui/ai-terminal/internal/util/term" 20 | ) 21 | 22 | type ls 
struct{} 23 | 24 | func newCmdLsConversation(ioStreams genericclioptions.IOStreams, cfg *options.Config) *cobra.Command { 25 | o := &ls{} 26 | cmd := &cobra.Command{ 27 | Use: "ls", 28 | Short: "Show chat conversations.", 29 | Example: `# Managing conversations: 30 | ai convo ls`, 31 | RunE: func(cmd *cobra.Command, _ []string) error { 32 | return o.Run(ioStreams, cfg) 33 | }, 34 | } 35 | 36 | return cmd 37 | } 38 | 39 | // Run executes convo command. 40 | func (o *ls) Run(ioStreams genericclioptions.IOStreams, cfg *options.Config) error { 41 | store, err := convo.GetConversationStore(cfg) 42 | if err != nil { 43 | return err 44 | } 45 | 46 | conversations, err := store.ListConversations(context.Background()) 47 | if err != nil { 48 | return err 49 | } 50 | 51 | if len(conversations) == 0 { 52 | _, _ = fmt.Fprintln(ioStreams.ErrOut, "No conversations found.") 53 | return nil 54 | } 55 | 56 | if term.IsInputTTY() && term.IsOutputTTY() { 57 | selectFromList(conversations) 58 | return nil 59 | } 60 | 61 | printList(conversations) 62 | 63 | return nil 64 | } 65 | 66 | func makeOptions(conversations []convo.Conversation) []huh.Option[string] { 67 | opts := make([]huh.Option[string], 0, len(conversations)) 68 | for _, c := range conversations { 69 | timea := console.StdoutStyles().Timeago.Render(timeago.Of(c.UpdatedAt)) 70 | left := console.StdoutStyles().SHA1.Render(c.ID[:convo.Sha1short]) 71 | right := console.StdoutStyles().ConversationList.Render(c.Title, timea) 72 | if c.Model != nil { 73 | right += console.StdoutStyles().Comment.Render(*c.Model) 74 | } 75 | opts = append(opts, huh.NewOption(left+" "+right, c.ID)) 76 | } 77 | return opts 78 | } 79 | 80 | func selectFromList(conversations []convo.Conversation) { 81 | var selected string 82 | if err := huh.NewForm( 83 | huh.NewGroup( 84 | huh.NewSelect[string](). 85 | Title("Conversations"). 86 | Value(&selected). 
87 | Options(makeOptions(conversations)...), 88 | ), 89 | ).Run(); err != nil { 90 | if !errors.Is(err, huh.ErrUserAborted) { 91 | fmt.Fprintln(os.Stderr, err.Error()) 92 | } 93 | return 94 | } 95 | 96 | _ = clipboard.WriteAll(selected) 97 | termenv.Copy(selected) 98 | console.PrintConfirmation("COPIED", selected) 99 | // suggest actions to use this conversation ID 100 | fmt.Println(console.StdoutStyles().Comment.Render( 101 | "You can use this conversation ID with the following commands:", 102 | )) 103 | 104 | type suggestion struct { 105 | cmd string 106 | usage string 107 | } 108 | 109 | suggestions := []suggestion{ 110 | { 111 | cmd: "show-convo", 112 | usage: "ai convo show", 113 | }, 114 | { 115 | cmd: "continue", 116 | usage: "ai ask --continue", 117 | }, 118 | { 119 | cmd: "rm-convo", 120 | usage: "ai convo rm", 121 | }, 122 | } 123 | for _, flag := range suggestions { 124 | fmt.Printf( 125 | " %-44s %s\n", 126 | console.StdoutStyles().Flag.Render(flag.usage), 127 | console.StdoutStyles().FlagDesc.Render(options.Help[flag.cmd]), 128 | ) 129 | } 130 | } 131 | 132 | func printList(conversations []convo.Conversation) { 133 | for _, conversation := range conversations { 134 | _, _ = fmt.Fprintf( 135 | os.Stdout, 136 | "%s\t%s\t%s\n", 137 | console.StdoutStyles().SHA1.Render(conversation.ID[:convo.Sha1short]), 138 | conversation.Title, 139 | console.StdoutStyles().Timeago.Render(timeago.Of(conversation.UpdatedAt)), 140 | ) 141 | } 142 | } 143 | -------------------------------------------------------------------------------- /internal/cli/convo/rm.go: -------------------------------------------------------------------------------- 1 | package convo 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "os" 7 | "time" 8 | 9 | "github.com/spf13/cobra" 10 | 11 | "github.com/coding-hui/ai-terminal/internal/convo" 12 | "github.com/coding-hui/ai-terminal/internal/errbook" 13 | "github.com/coding-hui/ai-terminal/internal/options" 14 | 
"github.com/coding-hui/ai-terminal/internal/ui/console" 15 | "github.com/coding-hui/ai-terminal/internal/util/flag" 16 | "github.com/coding-hui/ai-terminal/internal/util/genericclioptions" 17 | ) 18 | 19 | type rm struct { 20 | genericclioptions.IOStreams 21 | cfg *options.Config 22 | DeleteOlderThan time.Duration 23 | DeleteAll bool 24 | } 25 | 26 | func newCmdRemoveConversation(ioStreams genericclioptions.IOStreams, cfg *options.Config) *cobra.Command { 27 | o := &rm{IOStreams: ioStreams, cfg: cfg} 28 | cmd := &cobra.Command{ 29 | Use: "rm", 30 | Short: "Remove chat conversation.", 31 | SilenceUsage: true, 32 | RunE: func(cmd *cobra.Command, args []string) error { 33 | return o.Run(args) 34 | }, 35 | } 36 | 37 | cmd.Flags().Var(flag.NewDurationFlag(o.DeleteOlderThan, &o.DeleteOlderThan), "delete-older-than", console.StdoutStyles().FlagDesc.Render(options.Help["rm-convo-older-than"])) 38 | cmd.Flags().BoolVar(&o.DeleteAll, "all", false, console.StdoutStyles().FlagDesc.Render(options.Help["rm-all-convo"])) 39 | 40 | return cmd 41 | } 42 | 43 | func (r *rm) Run(args []string) error { 44 | if r.DeleteOlderThan <= 0 && len(args) <= 0 && !r.DeleteAll { 45 | return errbook.New("Please provide at least one conversation ID or --delete-older-than flag") 46 | } 47 | 48 | store, err := convo.GetConversationStore(r.cfg) 49 | if err != nil { 50 | return err 51 | } 52 | 53 | if r.DeleteOlderThan > 0 { 54 | return r.deleteConversationOlderThan(store, false) 55 | } 56 | 57 | if r.DeleteAll { 58 | return r.deleteConversationOlderThan(store, true) 59 | } 60 | 61 | return r.deleteConversation(store, args[0]) 62 | } 63 | 64 | func (r *rm) deleteConversation(store convo.Store, conversationID string) error { 65 | conversation, err := store.GetConversation(context.Background(), conversationID) 66 | if err != nil { 67 | return errbook.Wrap("Couldn't find conversation to delete.", err) 68 | } 69 | 70 | err = store.DeleteConversation(context.Background(), conversation.ID) 71 | if err != 
nil { 72 | return errbook.Wrap("Couldn't delete conversation.", err) 73 | } 74 | 75 | _, err = store.CleanContexts(context.Background(), conversation.ID) 76 | if err != nil { 77 | return errbook.Wrap("Couldn't clean conversation load contexts.", err) 78 | } 79 | 80 | err = store.InvalidateMessages(context.Background(), conversation.ID) 81 | if err != nil { 82 | return errbook.Wrap("Couldn't invalidate conversation.", err) 83 | } 84 | 85 | if !r.cfg.Quiet { 86 | fmt.Fprintln(os.Stderr, "Conversation deleted:", conversation.ID[:convo.Sha1minLen]) 87 | } 88 | 89 | return nil 90 | } 91 | 92 | func (r *rm) deleteConversationOlderThan(store convo.Store, deleteAll bool) error { 93 | var err error 94 | var conversations []convo.Conversation 95 | 96 | if deleteAll { 97 | conversations, err = store.ListConversations(context.Background()) 98 | if err != nil { 99 | return errbook.Wrap("Couldn't list conversations.", err) 100 | } 101 | } else { 102 | conversations, err = store.ListConversationsOlderThan(context.Background(), r.DeleteOlderThan) 103 | if err != nil { 104 | return errbook.Wrap("Couldn't list conversations.", err) 105 | } 106 | } 107 | 108 | if len(conversations) == 0 { 109 | if !r.cfg.Quiet { 110 | fmt.Fprintln(os.Stderr, "No conversations found.") 111 | return nil 112 | } 113 | return nil 114 | } 115 | 116 | if !r.cfg.Quiet { 117 | printList(conversations) 118 | confirmTitle := "Delete all conversations?" 
119 | if !deleteAll { 120 | confirmTitle = fmt.Sprintf("Delete conversations older than %s?", r.DeleteOlderThan) 121 | } 122 | confirm := console.WaitForUserConfirm(console.Yes, "%s", confirmTitle) 123 | if !confirm { 124 | return errbook.NewUserErrorf("Aborted by user.") 125 | } 126 | } 127 | 128 | for _, c := range conversations { 129 | if err := r.deleteConversation(store, c.ID); err != nil { 130 | return err 131 | } 132 | } 133 | 134 | return nil 135 | } 136 | -------------------------------------------------------------------------------- /internal/cli/convo/show.go: -------------------------------------------------------------------------------- 1 | package convo 2 | 3 | import ( 4 | "github.com/spf13/cobra" 5 | 6 | "github.com/coding-hui/ai-terminal/internal/cli/ask" 7 | "github.com/coding-hui/ai-terminal/internal/options" 8 | "github.com/coding-hui/ai-terminal/internal/util/genericclioptions" 9 | ) 10 | 11 | type show struct { 12 | genericclioptions.IOStreams 13 | last bool 14 | } 15 | 16 | func newCmdShowConversation(ioStreams genericclioptions.IOStreams, cfg *options.Config) *cobra.Command { 17 | o := &show{IOStreams: ioStreams} 18 | cmd := &cobra.Command{ 19 | Use: "show", 20 | Short: "Show chat conversation.", 21 | SilenceUsage: true, 22 | RunE: func(cmd *cobra.Command, args []string) error { 23 | return o.Run(args, cfg) 24 | }, 25 | } 26 | 27 | cmd.Flags().BoolVarP(&o.last, "last", "l", false, "show last chat conversation.") 28 | 29 | return cmd 30 | } 31 | 32 | func (s *show) Run(args []string, cfg *options.Config) error { 33 | cfg.ShowLast = s.last 34 | if len(args) > 0 { 35 | cfg.Show = args[0] 36 | } 37 | if err := ask.NewOptions(s.IOStreams, cfg).Run(); err != nil { 38 | return err 39 | } 40 | 41 | return nil 42 | } 43 | -------------------------------------------------------------------------------- /internal/cli/hook/hook.go: -------------------------------------------------------------------------------- 1 | package hook 2 | 3 | import ( 4 | 
"github.com/fatih/color" 5 | "github.com/spf13/cobra" 6 | 7 | "github.com/coding-hui/ai-terminal/internal/git" 8 | ) 9 | 10 | func NewCmdHook() *cobra.Command { 11 | hookCmd := &cobra.Command{ 12 | Use: "hook", 13 | Short: "install/uninstall git prepare-commit-msg hook", 14 | } 15 | 16 | hookCmd.AddCommand( 17 | &cobra.Command{ 18 | Use: "install", 19 | Short: "install git prepare-commit-msg hook", 20 | RunE: func(cmd *cobra.Command, args []string) error { 21 | g := git.New() 22 | 23 | if err := g.InstallHook(); err != nil { 24 | return err 25 | } 26 | color.Green("Install git hook: prepare-commit-msg successfully") 27 | color.Green("You can see the hook file: .git/hooks/prepare-commit-msg") 28 | 29 | return nil 30 | }, 31 | }, 32 | &cobra.Command{ 33 | Use: "uninstall", 34 | Short: "uninstall git prepare-commit-msg hook", 35 | RunE: func(cmd *cobra.Command, args []string) error { 36 | g := git.New() 37 | 38 | if err := g.UninstallHook(); err != nil { 39 | return err 40 | } 41 | color.Green("Remove git hook: prepare-commit-msg successfully") 42 | 43 | return nil 44 | }, 45 | }, 46 | ) 47 | 48 | return hookCmd 49 | } 50 | -------------------------------------------------------------------------------- /internal/cli/loadctx/clean.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2024 coding-hui. All rights reserved. 2 | // Use of this source code is governed by a MIT style 3 | // license that can be found in the LICENSE file. 4 | 5 | package loadctx 6 | 7 | import ( 8 | "context" 9 | 10 | "github.com/spf13/cobra" 11 | 12 | "github.com/coding-hui/ai-terminal/internal/convo" 13 | "github.com/coding-hui/ai-terminal/internal/errbook" 14 | "github.com/coding-hui/ai-terminal/internal/options" 15 | "github.com/coding-hui/ai-terminal/internal/ui/console" 16 | "github.com/coding-hui/ai-terminal/internal/util/genericclioptions" 17 | ) 18 | 19 | // clean represents a command to clean up loaded contexts in the AI terminal. 
20 | // It provides functionality to remove all context data for the current conversation. 21 | type clean struct { 22 | genericclioptions.IOStreams 23 | cfg *options.Config // Configuration for the AI terminal 24 | convoStore convo.Store // Storage interface for conversation data 25 | } 26 | 27 | // newClean creates and initializes a new clean command instance. 28 | // Parameters: 29 | // - ioStreams: Streams for input/output operations 30 | // - cfg: Configuration settings for the AI terminal 31 | // 32 | // Returns: 33 | // - *clean: A new clean command instance 34 | func newClean(ioStreams genericclioptions.IOStreams, cfg *options.Config) *clean { 35 | return &clean{ 36 | IOStreams: ioStreams, 37 | cfg: cfg, 38 | } 39 | } 40 | 41 | // newCmdClean creates a new cobra command for the clean operation. 42 | // It sets up the command structure and binds it to the clean functionality. 43 | func newCmdClean(ioStreams genericclioptions.IOStreams, cfg *options.Config) *cobra.Command { 44 | o := newClean(ioStreams, cfg) 45 | 46 | cmd := &cobra.Command{ 47 | Use: "clean", 48 | Aliases: []string{"drop"}, 49 | Short: "Delete all loaded contexts", 50 | Long: "Clean command removes all loaded contexts for the current conversation", 51 | RunE: func(cmd *cobra.Command, args []string) error { 52 | return o.Run() 53 | }, 54 | } 55 | 56 | return cmd 57 | } 58 | 59 | // Run executes the clean command logic. 60 | // It performs the following steps: 61 | // 1. Initializes the conversation store 62 | // 2. Retrieves the current conversation ID 63 | // 3. Cleans all contexts associated with the current conversation 64 | // 4. 
Displays the operation result to the user 65 | // 66 | // Returns: 67 | // - error: Any error that occurred during execution 68 | func (o *clean) Run() error { 69 | ctx := context.Background() 70 | 71 | // Initialize conversation store 72 | var err error 73 | o.convoStore, err = convo.GetConversationStore(o.cfg) 74 | if err != nil { 75 | return errbook.Wrap("failed to initialize conversation store", err) 76 | } 77 | 78 | // Get current conversation details 79 | conversation, err := convo.GetCurrentConversationID(ctx, o.cfg, o.convoStore) 80 | if err != nil { 81 | return errbook.Wrap("failed to get current conversation", err) 82 | } 83 | 84 | // Clean all contexts for current conversation 85 | count, err := o.convoStore.CleanContexts(ctx, conversation.ReadID) 86 | if err != nil { 87 | return errbook.Wrap("failed to clean contexts", err) 88 | } 89 | 90 | // Display operation result 91 | if count > 0 { 92 | console.Render("Successfully deleted %d loaded contexts", count) 93 | } else { 94 | console.Render("No contexts to delete") 95 | } 96 | 97 | return nil 98 | } 99 | -------------------------------------------------------------------------------- /internal/cli/loadctx/context.go: -------------------------------------------------------------------------------- 1 | package loadctx 2 | 3 | import ( 4 | "github.com/spf13/cobra" 5 | 6 | "github.com/coding-hui/ai-terminal/internal/options" 7 | "github.com/coding-hui/ai-terminal/internal/util/genericclioptions" 8 | ) 9 | 10 | // NewCmdContext returns a cobra command for managing context 11 | func NewCmdContext(ioStreams genericclioptions.IOStreams, cfg *options.Config) *cobra.Command { 12 | cmd := &cobra.Command{ 13 | Use: "context", 14 | Aliases: []string{"context", "ctx"}, 15 | Short: "Manage context for AI interactions", 16 | Run: func(cmd *cobra.Command, args []string) { 17 | _ = cmd.Help() 18 | }, 19 | } 20 | 21 | cmd.AddCommand( 22 | newCmdLoad(ioStreams, cfg), 23 | newCmdList(ioStreams, cfg), 24 | 
newCmdClean(ioStreams, cfg), 25 | ) 26 | 27 | return cmd 28 | } 29 | -------------------------------------------------------------------------------- /internal/cli/loadctx/list.go: -------------------------------------------------------------------------------- 1 | package loadctx 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "os" 7 | 8 | timeago "github.com/caarlos0/timea.go" 9 | "github.com/spf13/cobra" 10 | 11 | "github.com/coding-hui/ai-terminal/internal/convo" 12 | "github.com/coding-hui/ai-terminal/internal/errbook" 13 | "github.com/coding-hui/ai-terminal/internal/options" 14 | "github.com/coding-hui/ai-terminal/internal/ui/console" 15 | "github.com/coding-hui/ai-terminal/internal/util/genericclioptions" 16 | ) 17 | 18 | type list struct { 19 | genericclioptions.IOStreams 20 | cfg *options.Config 21 | convoStore convo.Store 22 | } 23 | 24 | func newList(ioStreams genericclioptions.IOStreams, cfg *options.Config) *list { 25 | return &list{ 26 | IOStreams: ioStreams, 27 | cfg: cfg, 28 | } 29 | } 30 | 31 | func newCmdList(ioStreams genericclioptions.IOStreams, cfg *options.Config) *cobra.Command { 32 | o := newList(ioStreams, cfg) 33 | 34 | cmd := &cobra.Command{ 35 | Use: "list", 36 | Aliases: []string{"ls"}, 37 | Short: "List all loaded contexts", 38 | RunE: func(cmd *cobra.Command, args []string) error { 39 | return o.Run() 40 | }, 41 | } 42 | 43 | return cmd 44 | } 45 | 46 | func (o *list) Run() error { 47 | // Initialize conversation store 48 | var err error 49 | o.convoStore, err = convo.GetConversationStore(o.cfg) 50 | if err != nil { 51 | return errbook.Wrap("failed to initialize conversation store", err) 52 | } 53 | 54 | conversation, err := convo.GetCurrentConversationID(context.Background(), o.cfg, o.convoStore) 55 | if err != nil { 56 | return errbook.Wrap("failed to get current conversation", err) 57 | } 58 | 59 | // Get contexts for current conversation 60 | ctxs, err := o.convoStore.ListContextsByteConvoID(context.Background(), conversation.ReadID) 
61 | if err != nil { 62 | return errbook.Wrap("failed to list contexts", err) 63 | } 64 | 65 | if len(ctxs) == 0 { 66 | console.Render("No contexts loaded") 67 | return nil 68 | } 69 | 70 | // Display contexts in a table 71 | printList(ctxs) 72 | 73 | return nil 74 | } 75 | 76 | func printList(ctxs []convo.LoadContext) { 77 | for _, ctx := range ctxs { 78 | _, _ = fmt.Fprintf( 79 | os.Stdout, 80 | "%s\t%s\t%s\t%s\n", 81 | console.StdoutStyles().SHA1.Render(ctx.ConversationID[:convo.Sha1short]), 82 | ctx.Name, 83 | ctx.Type, 84 | console.StdoutStyles().Timeago.Render(timeago.Of(ctx.UpdatedAt)), 85 | ) 86 | } 87 | } 88 | -------------------------------------------------------------------------------- /internal/cli/loadctx/load.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2024 coding-hui. All rights reserved. 2 | // Use of this source code is governed by a MIT style 3 | // license that can be found in the LICENSE file. 4 | 5 | package loadctx 6 | 7 | import ( 8 | "context" 9 | "fmt" 10 | "os" 11 | "path/filepath" 12 | 13 | "github.com/spf13/cobra" 14 | 15 | "github.com/coding-hui/ai-terminal/internal/convo" 16 | "github.com/coding-hui/ai-terminal/internal/errbook" 17 | "github.com/coding-hui/ai-terminal/internal/options" 18 | "github.com/coding-hui/ai-terminal/internal/ui/console" 19 | "github.com/coding-hui/ai-terminal/internal/util/genericclioptions" 20 | "github.com/coding-hui/ai-terminal/internal/util/rest" 21 | "github.com/coding-hui/ai-terminal/internal/util/term" 22 | ) 23 | 24 | // load is a struct to support load command 25 | type load struct { 26 | genericclioptions.IOStreams 27 | cfg *options.Config 28 | convoStore convo.Store 29 | currentConversation convo.CacheDetailsMsg 30 | } 31 | 32 | // newLoad returns initialized load 33 | func newLoad(ioStreams genericclioptions.IOStreams, cfg *options.Config) *load { 34 | return &load{ 35 | IOStreams: ioStreams, 36 | cfg: cfg, 37 | } 38 | } 39 | 40 | func 
newCmdLoad(ioStreams genericclioptions.IOStreams, cfg *options.Config) *cobra.Command { 41 | o := newLoad(ioStreams, cfg) 42 | 43 | cmd := &cobra.Command{ 44 | Use: "load ", 45 | Short: "Preload files or remote documents for later use", 46 | Example: ` # Load a local file 47 | ai context load ./example.txt 48 | 49 | # Load a remote document 50 | ai context load https://example.com/doc.txt`, 51 | RunE: func(cmd *cobra.Command, args []string) error { 52 | if len(args) == 0 { 53 | return errbook.New("Please provide at least one file or URL to load") 54 | } 55 | return o.Run(args) 56 | }, 57 | } 58 | 59 | return cmd 60 | } 61 | 62 | // Run executes the load command 63 | func (o *load) Run(args []string) (err error) { 64 | // Initialize conversation store 65 | o.convoStore, err = convo.GetConversationStore(o.cfg) 66 | if err != nil { 67 | return errbook.Wrap("Failed to initialize conversation store", err) 68 | } 69 | 70 | // Get current conversation ID 71 | details, err := convo.GetCurrentConversationID(context.Background(), o.cfg, o.convoStore) 72 | if err != nil { 73 | return errbook.Wrap("Failed to get current conversation ID", err) 74 | } 75 | o.currentConversation = details 76 | 77 | for _, path := range args { 78 | if err = o.loadPath(path); err != nil { 79 | return err 80 | } 81 | } 82 | 83 | err = o.convoStore.SaveConversation( 84 | context.Background(), 85 | o.currentConversation.WriteID, 86 | fmt.Sprintf("load-contexts-%s", o.currentConversation.WriteID[:convo.Sha1short]), 87 | o.currentConversation.Model, 88 | ) 89 | if err != nil { 90 | return errbook.Wrap("Failed to save conversation", err) 91 | } 92 | 93 | return nil 94 | } 95 | 96 | func (o *load) loadPath(path string) error { 97 | // Handle remote URLs 98 | if rest.IsValidURL(path) { 99 | console.Render("Loading remote content [%s]", path) 100 | content, err := rest.FetchURLContent(path) 101 | if err != nil { 102 | return errbook.Wrap("Failed to load remote content", err) 103 | } 104 | return 
o.saveContent(path, content, convo.ContentTypeURL) 105 | } 106 | 107 | // Handle local files 108 | console.Render("Loading local file [%s]", path) 109 | if err := o.saveContent(path, "", convo.ContentTypeFile); err != nil { 110 | return err 111 | } 112 | 113 | return nil 114 | } 115 | 116 | func (o *load) saveContent(sourcePath, content string, contentType convo.ContentType) error { 117 | // Create cache directory if it doesn't exist 118 | cacheDir := filepath.Join(o.cfg.DataStore.CachePath, "loaded") 119 | if err := os.MkdirAll(cacheDir, 0755); err != nil { 120 | return errbook.Wrap("Failed to create cache directory", err) 121 | } 122 | 123 | // Generate safe filename 124 | filename := term.SanitizeFilename(sourcePath) 125 | cachePath := sourcePath 126 | 127 | if contentType != convo.ContentTypeFile { 128 | cachePath = filepath.Join(cacheDir, filename) 129 | // Save content to cache 130 | if err := os.WriteFile(cachePath, []byte(content), 0644); err != nil { 131 | return errbook.Wrap("Failed to save content", err) 132 | } 133 | } 134 | 135 | err := o.convoStore.SaveContext(context.Background(), &convo.LoadContext{ 136 | ConversationID: o.currentConversation.WriteID, 137 | Name: filename, 138 | Type: contentType, 139 | URL: sourcePath, 140 | FilePath: cachePath, 141 | Content: content, 142 | }) 143 | if err != nil { 144 | return errbook.Wrap("Failed to save load content", err) 145 | } 146 | 147 | // save to conversation 148 | console.Render("Saved content [%s] to conversation [%s]", cachePath, o.currentConversation.WriteID[:convo.Sha1short]) 149 | 150 | return nil 151 | } 152 | -------------------------------------------------------------------------------- /internal/cli/manpage/manpage.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023 coding-hui. All rights reserved. 2 | // Use of this source code is governed by a MIT style 3 | // license that can be found in the LICENSE file. 
4 | 5 | package manpage 6 | 7 | import ( 8 | "fmt" 9 | "os" 10 | 11 | mcobra "github.com/muesli/mango-cobra" 12 | "github.com/muesli/roff" 13 | "github.com/spf13/cobra" 14 | ) 15 | 16 | // NewCmdManPage creates the `manpage` command. 17 | func NewCmdManPage(rootCmd *cobra.Command) *cobra.Command { 18 | return &cobra.Command{ 19 | Use: "manpage", 20 | Short: "Generates manpages", 21 | SilenceUsage: true, 22 | DisableFlagsInUseLine: true, 23 | Hidden: true, 24 | Args: cobra.NoArgs, 25 | RunE: func(*cobra.Command, []string) error { 26 | manPage, err := mcobra.NewManPage(1, rootCmd) 27 | if err != nil { 28 | //nolint:wrapcheck 29 | return err 30 | } 31 | _, err = fmt.Fprint(os.Stdout, manPage.Build(roff.NewDocument())) 32 | //nolint:wrapcheck 33 | return err 34 | }, 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /internal/cli/profiling.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023 coding-hui. All rights reserved. 2 | // Use of this source code is governed by a MIT style 3 | // license that can be found in the LICENSE file. 4 | 5 | package cli 6 | 7 | import ( 8 | "fmt" 9 | "os" 10 | "runtime" 11 | "runtime/pprof" 12 | 13 | "github.com/spf13/pflag" 14 | ) 15 | 16 | var ( 17 | profileName string 18 | profileOutput string 19 | ) 20 | 21 | func addProfilingFlags(flags *pflag.FlagSet) { 22 | flags.StringVar( 23 | &profileName, 24 | "profile", 25 | "none", 26 | "Name of profile to capture. 
One of (none|cpu|heap|goroutine|threadcreate|block|mutex)", 27 | ) 28 | flags.StringVar(&profileOutput, "profile-output", "profile.pprof", "Name of the file to write the profile to") 29 | } 30 | 31 | func initProfiling() error { 32 | switch profileName { 33 | case "none": 34 | return nil 35 | case "cpu": 36 | f, err := os.Create(profileOutput) 37 | if err != nil { 38 | return err 39 | } 40 | 41 | return pprof.StartCPUProfile(f) 42 | // Block and mutex profiles need a call to Set{Block,Mutex}ProfileRate to 43 | // output anything. We choose to sample all events. 44 | case "block": 45 | runtime.SetBlockProfileRate(1) 46 | 47 | return nil 48 | case "mutex": 49 | runtime.SetMutexProfileFraction(1) 50 | 51 | return nil 52 | default: 53 | // Check the profile name is valid. 54 | if profile := pprof.Lookup(profileName); profile == nil { 55 | return fmt.Errorf("unknown profile '%s'", profileName) 56 | } 57 | } 58 | 59 | return nil 60 | } 61 | 62 | func flushProfiling() error { 63 | switch profileName { 64 | case "none": 65 | return nil 66 | case "cpu": 67 | pprof.StopCPUProfile() 68 | case "heap": 69 | runtime.GC() 70 | 71 | fallthrough 72 | default: 73 | profile := pprof.Lookup(profileName) 74 | if profile == nil { 75 | return nil 76 | } 77 | 78 | f, err := os.Create(profileOutput) 79 | if err != nil { 80 | return err 81 | } 82 | _ = profile.WriteTo(f, 0) 83 | } 84 | 85 | return nil 86 | } 87 | -------------------------------------------------------------------------------- /internal/cli/review/review.go: -------------------------------------------------------------------------------- 1 | package review 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "strings" 7 | 8 | "github.com/fatih/color" 9 | "github.com/spf13/cobra" 10 | 11 | "github.com/coding-hui/ai-terminal/internal/ai" 12 | "github.com/coding-hui/ai-terminal/internal/git" 13 | "github.com/coding-hui/ai-terminal/internal/options" 14 | "github.com/coding-hui/ai-terminal/internal/prompt" 15 | 
"github.com/coding-hui/ai-terminal/internal/runner" 16 | "github.com/coding-hui/ai-terminal/internal/util/genericclioptions" 17 | ) 18 | 19 | type Options struct { 20 | diffUnified int 21 | excludeList []string 22 | commitAmend bool 23 | commitLang string 24 | 25 | cfg *options.Config 26 | genericclioptions.IOStreams 27 | } 28 | 29 | // NewCmdCommit returns a cobra command for commit msg. 30 | func NewCmdCommit(ioStreams genericclioptions.IOStreams, cfg *options.Config) *cobra.Command { 31 | ops := &Options{ 32 | IOStreams: ioStreams, 33 | cfg: cfg, 34 | } 35 | 36 | reviewCmd := &cobra.Command{ 37 | Use: "review", 38 | Short: "Auto review code changes", 39 | RunE: ops.reviewCode, 40 | } 41 | 42 | reviewCmd.Flags().IntVar(&ops.diffUnified, "diff-unified", 3, "generate diffs with lines of context, default is 3") 43 | reviewCmd.Flags().StringSliceVar(&ops.excludeList, "exclude-list", []string{}, "exclude file from git diff command") 44 | reviewCmd.Flags().BoolVar(&ops.commitAmend, "amend", false, "replace the tip of the current branch by creating a new commit.") 45 | reviewCmd.Flags().StringVar(&ops.commitLang, "lang", "en", "summarizing language uses English by default. "+ 46 | "support en, zh-cn, zh-tw, ja, pt, pt-br.") 47 | 48 | return reviewCmd 49 | } 50 | 51 | func (o *Options) reviewCode(cmd *cobra.Command, args []string) error { 52 | if !runner.IsCommandAvailable("git") { 53 | return errors.New("git command not found on your system's PATH. 
Please install Git and try again") 54 | } 55 | 56 | llmEngine, err := ai.New(ai.WithConfig(o.cfg)) 57 | if err != nil { 58 | return err 59 | } 60 | 61 | g := git.New( 62 | git.WithDiffUnified(o.diffUnified), 63 | git.WithExcludeList(o.excludeList), 64 | git.WithEnableAmend(o.commitAmend), 65 | ) 66 | diff, err := g.DiffFiles() 67 | if err != nil { 68 | return err 69 | } 70 | 71 | vars := map[string]any{prompt.FileDiffsKey: diff} 72 | 73 | reviewPrompt, err := prompt.GetPromptStringByTemplateName(prompt.CodeReviewTemplate, vars) 74 | if err != nil { 75 | return err 76 | } 77 | 78 | // Get summarize comment from diff datas 79 | color.Cyan("We are trying to review code changes") 80 | reviewResp, err := llmEngine.CreateCompletion(context.Background(), reviewPrompt.Messages()) 81 | if err != nil { 82 | return err 83 | } 84 | 85 | reviewMessage := reviewResp.Explanation 86 | if prompt.GetLanguage(o.commitLang) != prompt.DefaultLanguage { 87 | translationPrompt, err := prompt.GetPromptStringByTemplateName( 88 | prompt.TranslationTemplate, map[string]any{ 89 | prompt.OutputLanguageKey: prompt.GetLanguage(o.commitLang), 90 | prompt.OutputMessageKey: reviewMessage, 91 | }, 92 | ) 93 | if err != nil { 94 | return err 95 | } 96 | 97 | color.Cyan("we are trying to translate code review to " + o.commitLang + " language") 98 | translationResp, err := llmEngine.CreateCompletion(context.Background(), translationPrompt.Messages()) 99 | if err != nil { 100 | return err 101 | } 102 | reviewMessage = translationResp.Explanation 103 | } 104 | 105 | // Output core review summary 106 | color.Yellow("================Review Summary====================") 107 | color.Yellow("\n" + strings.TrimSpace(reviewMessage) + "\n\n") 108 | color.Yellow("==================================================") 109 | 110 | return nil 111 | } 112 | -------------------------------------------------------------------------------- /internal/cli/version/version.go: 
-------------------------------------------------------------------------------- 1 | // Copyright (c) 2023 coding-hui. All rights reserved. 2 | // Use of this source code is governed by a MIT style 3 | // license that can be found in the LICENSE file. 4 | 5 | // Package version print the client and server version information. 6 | package version 7 | 8 | import ( 9 | "encoding/json" 10 | "fmt" 11 | "text/template" 12 | 13 | "github.com/ghodss/yaml" 14 | "github.com/spf13/cobra" 15 | 16 | "github.com/coding-hui/common/version" 17 | 18 | "github.com/coding-hui/ai-terminal/internal/errbook" 19 | "github.com/coding-hui/ai-terminal/internal/util/genericclioptions" 20 | "github.com/coding-hui/ai-terminal/internal/util/templates" 21 | ) 22 | 23 | var versionExample = templates.Examples(` 24 | # Print the client and server versions for the current context 25 | ai version 26 | 27 | # Print the version in JSON format 28 | ai version --output json 29 | 30 | # Print the version in YAML format 31 | ai version --output yaml 32 | 33 | # Print the version using a custom Go template 34 | ai version --template '{{.GitVersion}}'`) 35 | 36 | // Options is a struct to support version command. 37 | type Options struct { 38 | ClientOnly bool 39 | Short bool 40 | Output string 41 | Template string 42 | 43 | genericclioptions.IOStreams 44 | } 45 | 46 | // NewOptions returns initialized Options. 47 | func NewOptions(ioStreams genericclioptions.IOStreams) *Options { 48 | return &Options{ 49 | IOStreams: ioStreams, 50 | Template: "", 51 | } 52 | } 53 | 54 | // NewCmdVersion returns a cobra command for fetching versions. 55 | func NewCmdVersion(ioStreams genericclioptions.IOStreams) *cobra.Command { 56 | o := NewOptions(ioStreams) 57 | cmd := &cobra.Command{ 58 | Use: "version", 59 | Short: "Print the cli version information", 60 | Long: `Print the cli version information for the current context. 
61 | 62 | The version information includes the following fields: 63 | - GitVersion: The semantic version of the build. 64 | - GitCommit: The git commit hash of the build. 65 | - GitTreeState: The state of the git tree, either 'clean' or 'dirty'. 66 | - BuildDate: The date of the build. 67 | - GoVersion: The version of Go used to compile the binary. 68 | - Compiler: The compiler used to compile the binary. 69 | - Platform: The platform (OS/Architecture) for which the binary was built.`, 70 | Example: versionExample, 71 | RunE: func(cmd *cobra.Command, args []string) error { 72 | if err := o.Validate(); err != nil { 73 | return err 74 | } 75 | return o.Run() 76 | }, 77 | } 78 | 79 | cmd.Flags().BoolVar(&o.Short, "short", o.Short, "If true, print just the version number.") 80 | cmd.Flags().StringVarP(&o.Output, "output", "o", o.Output, "One of 'yaml' or 'json'.") 81 | cmd.Flags().StringVarP(&o.Template, "template", "t", o.Template, "Template string to format the version output.") 82 | 83 | return cmd 84 | } 85 | 86 | // Validate validates the provided options. 87 | func (o *Options) Validate() error { 88 | if o.Output != "" && o.Output != "yaml" && o.Output != "json" { 89 | return errbook.New("Invalid output format. Please use 'yaml' or 'json'.") 90 | } 91 | 92 | return nil 93 | } 94 | 95 | // Run executes version command. 
96 | func (o *Options) Run() error { 97 | versionInfo := version.Get() 98 | 99 | if o.Template != "" { 100 | tmpl, err := template.New("version").Parse(o.Template) 101 | if err != nil { 102 | return errbook.Wrap("Failed to parse template", err) 103 | } 104 | err = tmpl.Execute(o.Out, versionInfo) 105 | if err != nil { 106 | return errbook.Wrap("Failed to execute template", err) 107 | } 108 | } else { 109 | switch o.Output { 110 | case "": 111 | if o.Short { 112 | fmt.Fprintf(o.Out, "%s\n", versionInfo.GitVersion) //nolint:errcheck 113 | } else { 114 | fmt.Fprintf(o.Out, "Version: %s\n", versionInfo.GitVersion) //nolint:errcheck 115 | } 116 | case "yaml": 117 | marshaled, err := yaml.Marshal(&versionInfo) 118 | if err != nil { 119 | return errbook.Wrap("Failed to marshal version info to yaml", err) 120 | } 121 | fmt.Fprintln(o.Out, string(marshaled)) //nolint:errcheck 122 | case "json": 123 | marshaled, err := json.MarshalIndent(&versionInfo, "", " ") 124 | if err != nil { 125 | return errbook.Wrap("Failed to marshal version info to json", err) 126 | } 127 | fmt.Fprintln(o.Out, string(marshaled)) //nolint:errcheck 128 | default: 129 | // There is a bug in the program if we hit this case. 130 | // However, we follow a policy of never panicking. 131 | return errbook.New("Invalid output format: %q. 
Please use 'yaml' or 'json'.", o.Output) 132 | } 133 | } 134 | 135 | return nil 136 | } 137 | -------------------------------------------------------------------------------- /internal/convo/chat_history_store.go: -------------------------------------------------------------------------------- 1 | package convo 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "os" 8 | "path/filepath" 9 | "sync" 10 | 11 | "github.com/coding-hui/wecoding-sdk-go/services/ai/llms" 12 | ) 13 | 14 | const cacheExt = ".gob" 15 | 16 | var errInvalidID = errors.New("invalid id") 17 | 18 | type SimpleChatHistoryStore struct { 19 | dir string 20 | messages map[string][]llms.ChatMessage 21 | loaded map[string]bool // tracks which conversations have been loaded 22 | 23 | sync.RWMutex // protects access to messages and loaded maps 24 | } 25 | 26 | func NewSimpleChatHistoryStore(dir string) *SimpleChatHistoryStore { 27 | return &SimpleChatHistoryStore{ 28 | dir: dir, 29 | messages: make(map[string][]llms.ChatMessage), 30 | loaded: make(map[string]bool), 31 | } 32 | } 33 | 34 | // AddAIMessage adds an AIMessage to the chat message convo. 35 | func (h *SimpleChatHistoryStore) AddAIMessage(ctx context.Context, convoID, message string) error { 36 | return h.AddMessage(ctx, convoID, llms.AIChatMessage{Content: message}) 37 | } 38 | 39 | // AddUserMessage adds a user to the chat message convo. 
40 | func (h *SimpleChatHistoryStore) AddUserMessage(ctx context.Context, convoID, message string) error { 41 | return h.AddMessage(ctx, convoID, llms.HumanChatMessage{Content: message}) 42 | } 43 | 44 | func (h *SimpleChatHistoryStore) AddMessage(_ context.Context, convoID string, message llms.ChatMessage) error { 45 | h.Lock() 46 | defer h.Unlock() 47 | 48 | if !h.loaded[convoID] { 49 | if err := h.load(convoID); err != nil && !errors.Is(err, os.ErrNotExist) { 50 | return err 51 | } 52 | h.loaded[convoID] = true 53 | } 54 | h.messages[convoID] = append(h.messages[convoID], message) 55 | return nil 56 | } 57 | 58 | func (h *SimpleChatHistoryStore) SetMessages(ctx context.Context, convoID string, messages []llms.ChatMessage) error { 59 | if err := h.InvalidateMessages(ctx, convoID); err != nil && !errors.Is(err, os.ErrNotExist) { 60 | return err 61 | } 62 | 63 | h.Lock() 64 | defer h.Unlock() 65 | h.messages[convoID] = messages 66 | return h.PersistentMessages(ctx, convoID) 67 | } 68 | 69 | func (h *SimpleChatHistoryStore) Messages(_ context.Context, convoID string) ([]llms.ChatMessage, error) { 70 | h.RLock() 71 | defer h.RUnlock() 72 | 73 | if !h.loaded[convoID] { 74 | if err := h.load(convoID); err != nil && !errors.Is(err, os.ErrNotExist) { 75 | return nil, err 76 | } 77 | h.loaded[convoID] = true 78 | } 79 | return h.messages[convoID], nil 80 | } 81 | 82 | func (h *SimpleChatHistoryStore) load(convoID string) error { 83 | if convoID == "" { 84 | return fmt.Errorf("read: %w", errInvalidID) 85 | } 86 | file, err := os.Open(filepath.Join(h.dir, convoID+cacheExt)) 87 | if err != nil { 88 | return fmt.Errorf("read: %w", err) 89 | } 90 | defer file.Close() //nolint:errcheck 91 | 92 | var rawMessages []llms.ChatMessageModel 93 | if err := decode(file, &rawMessages); err != nil { 94 | return fmt.Errorf("read: %w", err) 95 | } 96 | 97 | h.messages[convoID] = nil 98 | for _, v := range rawMessages { 99 | h.messages[convoID] = append(h.messages[convoID], 
v.ToChatMessage()) 100 | } 101 | 102 | return nil 103 | } 104 | 105 | func (h *SimpleChatHistoryStore) PersistentMessages(_ context.Context, convoID string) error { 106 | if convoID == "" { 107 | return fmt.Errorf("write: %w", errInvalidID) 108 | } 109 | 110 | // Ensure directory exists 111 | if err := os.MkdirAll(h.dir, 0755); err != nil { 112 | return fmt.Errorf("create directory: %w", err) 113 | } 114 | 115 | file, err := os.Create(filepath.Join(h.dir, convoID+cacheExt)) 116 | if err != nil { 117 | return fmt.Errorf("write: %w", err) 118 | } 119 | defer file.Close() //nolint:errcheck 120 | 121 | var rawMessages []llms.ChatMessageModel 122 | for _, v := range h.messages[convoID] { 123 | if v != nil { 124 | rawMessages = append(rawMessages, llms.ConvertChatMessageToModel(v)) 125 | } 126 | } 127 | if err := encode(file, &rawMessages); err != nil { 128 | return fmt.Errorf("write: %w", err) 129 | } 130 | 131 | return nil 132 | } 133 | 134 | func (h *SimpleChatHistoryStore) InvalidateMessages(_ context.Context, convoID string) error { 135 | h.Lock() 136 | defer h.Unlock() 137 | 138 | if convoID == "" { 139 | return fmt.Errorf("delete: %w", errInvalidID) 140 | } 141 | if err := os.Remove(filepath.Join(h.dir, convoID+cacheExt)); err != nil && !errors.Is(err, os.ErrNotExist) { 142 | return fmt.Errorf("delete: %w", err) 143 | } 144 | delete(h.messages, convoID) 145 | delete(h.loaded, convoID) 146 | return nil 147 | } 148 | -------------------------------------------------------------------------------- /internal/convo/chat_history_store_test.go: -------------------------------------------------------------------------------- 1 | package convo 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "os" 7 | "sync" 8 | "testing" 9 | "time" 10 | 11 | "github.com/stretchr/testify/require" 12 | ) 13 | 14 | func TestStore(t *testing.T) { 15 | convoID := NewConversationID() 16 | 17 | t.Run("read non-existent", func(t *testing.T) { 18 | store := NewSimpleChatHistoryStore(t.TempDir()) 19 | 
err := store.load("super-fake") 20 | require.ErrorIs(t, err, os.ErrNotExist) 21 | defer func() { 22 | _ = store.InvalidateMessages(context.Background(), convoID) 23 | }() 24 | }) 25 | 26 | t.Run("set messages", func(t *testing.T) { 27 | ctx := context.Background() 28 | store := NewSimpleChatHistoryStore(t.TempDir()) 29 | _ = store.InvalidateMessages(context.Background(), convoID) 30 | require.NoError(t, store.AddUserMessage(ctx, convoID, "hello")) 31 | require.NoError(t, store.AddAIMessage(ctx, convoID, "hi")) 32 | require.NoError(t, store.AddAIMessage(ctx, convoID, "bye")) 33 | 34 | // Messages should be in memory but not persisted 35 | messages, err := store.Messages(ctx, convoID) 36 | require.NoError(t, err) 37 | require.Equal(t, 3, len(messages)) 38 | 39 | // After persist, messages should be saved 40 | require.NoError(t, store.PersistentMessages(ctx, convoID)) 41 | persistedMessages, err := store.Messages(ctx, convoID) 42 | require.NoError(t, err) 43 | require.Equal(t, messages, persistedMessages) 44 | 45 | defer func() { 46 | _ = store.InvalidateMessages(context.Background(), convoID) 47 | }() 48 | }) 49 | 50 | t.Run("delete", func(t *testing.T) { 51 | store := NewSimpleChatHistoryStore(t.TempDir()) 52 | require.NoError(t, store.PersistentMessages(context.Background(), convoID)) 53 | require.NoError(t, store.InvalidateMessages(context.Background(), convoID)) 54 | require.ErrorIs(t, store.load(convoID), os.ErrNotExist) 55 | defer func() { 56 | _ = store.InvalidateMessages(context.Background(), convoID) 57 | }() 58 | }) 59 | 60 | t.Run("concurrent access", func(t *testing.T) { 61 | ctx := context.Background() 62 | store := NewSimpleChatHistoryStore(t.TempDir()) 63 | convoID := NewConversationID() 64 | 65 | var wg sync.WaitGroup 66 | const numWorkers = 100 67 | const numMessages = 100 68 | 69 | // Concurrent writers 70 | wg.Add(numWorkers) 71 | for i := 0; i < numWorkers; i++ { 72 | go func(i int) { 73 | defer wg.Done() 74 | for j := 0; j < numMessages; j++ { 75 
| if i%2 == 0 { 76 | require.NoError(t, store.AddUserMessage(ctx, convoID, fmt.Sprintf("user %d-%d", i, j))) 77 | } else { 78 | require.NoError(t, store.AddAIMessage(ctx, convoID, fmt.Sprintf("ai %d-%d", i, j))) 79 | } 80 | time.Sleep(time.Millisecond) // Add some delay to increase contention 81 | } 82 | }(i) 83 | } 84 | 85 | // Concurrent readers 86 | wg.Add(numWorkers) 87 | for i := 0; i < numWorkers; i++ { 88 | go func() { 89 | defer wg.Done() 90 | for j := 0; j < numMessages; j++ { 91 | msgs, err := store.Messages(ctx, convoID) 92 | require.NoError(t, err) 93 | require.True(t, len(msgs) <= numWorkers*numMessages) 94 | time.Sleep(time.Millisecond) 95 | } 96 | }() 97 | } 98 | 99 | wg.Wait() 100 | 101 | // Verify final message count 102 | msgs, err := store.Messages(ctx, convoID) 103 | require.NoError(t, err) 104 | require.Equal(t, numWorkers*numMessages, len(msgs)) 105 | 106 | // Verify persistence 107 | require.NoError(t, store.PersistentMessages(ctx, convoID)) 108 | persistedMsgs, err := store.Messages(ctx, convoID) 109 | require.NoError(t, err) 110 | require.Equal(t, msgs, persistedMsgs) 111 | 112 | defer func() { 113 | _ = store.InvalidateMessages(context.Background(), convoID) 114 | }() 115 | }) 116 | } 117 | -------------------------------------------------------------------------------- /internal/convo/factory.go: -------------------------------------------------------------------------------- 1 | package convo 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "sync" 7 | 8 | "github.com/coding-hui/ai-terminal/internal/errbook" 9 | "github.com/coding-hui/ai-terminal/internal/options" 10 | ) 11 | 12 | var ( 13 | conversationStoreFactories = make(map[string]Factory) 14 | conversationStores = make(map[string]Store) 15 | 16 | lock = sync.Mutex{} 17 | ) 18 | 19 | type Factory interface { 20 | // Type unique type of the convo store 21 | Type() string 22 | // Create relevant convo store by type 23 | Create(options *options.Config) (Store, error) 24 | } 25 | 26 | func 
GetConversationStore(cfg *options.Config) (Store, error) { 27 | dsType := cfg.DataStore.Type 28 | 29 | // Check if store already exists 30 | store, exists := conversationStores[dsType] 31 | if !exists { 32 | // Create new store if it doesn't exist 33 | factory, ok := conversationStoreFactories[dsType] 34 | if !ok { 35 | return nil, fmt.Errorf("chat convo store `%s` is not supported", dsType) 36 | } 37 | 38 | newStore, err := factory.Create(cfg) 39 | if err != nil { 40 | return nil, errbook.Wrap("Failed to create chat convo store: "+dsType, err) 41 | } 42 | 43 | lock.Lock() 44 | conversationStores[dsType] = newStore 45 | lock.Unlock() 46 | store = newStore 47 | } 48 | 49 | // Handle conversation ID if not provided 50 | if cfg.ConversationID == "" { 51 | // Try to get latest conversation 52 | latest, err := store.LatestConversation(context.Background()) 53 | if err != nil { 54 | return nil, errbook.Wrap("Failed to get latest conversation", err) 55 | } 56 | 57 | if latest != nil && latest.ID != "" { 58 | cfg.ConversationID = latest.ID 59 | } else { 60 | // No conversations exist, generate new ID 61 | cfg.ConversationID = NewConversationID() 62 | //debug.Trace("conversation id not provided, generating new id `%s`", cfg.ConversationID) 63 | } 64 | } 65 | 66 | return store, nil 67 | } 68 | 69 | func RegisterConversationStore(factory Factory) { 70 | conversationStoreFactories[factory.Type()] = factory 71 | } 72 | -------------------------------------------------------------------------------- /internal/convo/format.go: -------------------------------------------------------------------------------- 1 | package convo 2 | 3 | import ( 4 | "encoding/gob" 5 | "fmt" 6 | "io" 7 | 8 | "github.com/coding-hui/wecoding-sdk-go/services/ai/llms" 9 | ) 10 | 11 | func encode(w io.Writer, messages *[]llms.ChatMessageModel) error { 12 | if err := gob.NewEncoder(w).Encode(messages); err != nil { 13 | return fmt.Errorf("encode: %w", err) 14 | } 15 | return nil 16 | } 17 | 18 | func 
decode(r io.Reader, messages *[]llms.ChatMessageModel) error { 19 | if err := gob.NewDecoder(r).Decode(messages); err != nil { 20 | return fmt.Errorf("decode: %w", err) 21 | } 22 | return nil 23 | } 24 | -------------------------------------------------------------------------------- /internal/convo/sha.go: -------------------------------------------------------------------------------- 1 | package convo 2 | 3 | import ( 4 | "crypto/rand" 5 | "crypto/sha1" //nolint: gosec 6 | "fmt" 7 | "regexp" 8 | ) 9 | 10 | const ( 11 | Sha1short = 7 12 | Sha1minLen = 4 13 | Sha1ReadBlockSize = 4096 14 | ) 15 | 16 | var Sha1reg = regexp.MustCompile(`\b[0-9a-f]{40}\b`) 17 | 18 | func NewConversationID() string { 19 | b := make([]byte, Sha1ReadBlockSize) 20 | _, _ = rand.Read(b) 21 | return fmt.Sprintf("%x", sha1.Sum(b)) //nolint: gosec 22 | } 23 | 24 | func MatchSha1(s string) bool { 25 | return Sha1reg.MatchString(s) 26 | } 27 | -------------------------------------------------------------------------------- /internal/convo/sqlite3/convo_store_options.go: -------------------------------------------------------------------------------- 1 | package sqlite3 2 | 3 | import ( 4 | "context" 5 | "os" 6 | 7 | "github.com/jmoiron/sqlx" 8 | 9 | "github.com/coding-hui/ai-terminal/internal/convo" 10 | "github.com/coding-hui/ai-terminal/internal/errbook" 11 | ) 12 | 13 | // DefaultSchema sets a default schema to be run after connecting. 
14 | const DefaultSchema = `CREATE TABLE 15 | IF NOT EXISTS conversations ( 16 | id string NOT NULL PRIMARY KEY, 17 | title string NOT NULL, 18 | model string NOT NULL, 19 | updated_at datetime NOT NULL DEFAULT (strftime ('%Y-%m-%d %H:%M:%f', 'now')), 20 | CHECK (id <> ''), 21 | CHECK (title <> '') 22 | ); 23 | CREATE INDEX IF NOT EXISTS idx_conv_id ON conversations (id); 24 | CREATE INDEX IF NOT EXISTS idx_conv_title ON conversations (title); 25 | 26 | CREATE TABLE IF NOT EXISTS load_contexts ( 27 | id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, 28 | type string NOT NULL, 29 | url string, 30 | file_path string, 31 | content text NOT NULL, 32 | name string NOT NULL, 33 | conversation_id string NOT NULL, 34 | updated_at datetime NOT NULL DEFAULT (strftime ('%Y-%m-%d %H:%M:%f', 'now')), 35 | CHECK (name <> ''), 36 | CHECK (conversation_id <> ''), 37 | FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE 38 | ); 39 | CREATE INDEX IF NOT EXISTS idx_loadctx_convo ON load_contexts (conversation_id); 40 | ` 41 | 42 | // SqliteChatMessageHistoryOption is a function for creating new 43 | // chat message convo with other than the default values. 44 | type SqliteChatMessageHistoryOption func(m *SqliteStore) 45 | 46 | // WithDataPath is an option for NewSqliteChatMessageHistory for 47 | // setting a path to the data directory. 48 | func WithDataPath(path string) SqliteChatMessageHistoryOption { 49 | return func(m *SqliteStore) { 50 | m.DataPath = path 51 | } 52 | } 53 | 54 | // WithDB is an option for NewSqliteChatMessageHistory for adding 55 | // a database connection. 56 | func WithDB(db *sqlx.DB) SqliteChatMessageHistoryOption { 57 | return func(m *SqliteStore) { 58 | m.DB = db 59 | } 60 | } 61 | 62 | // WithContext is an option for NewSqliteChatMessageHistory 63 | // to use a context internally when running Schema. 
64 | func WithContext(ctx context.Context) SqliteChatMessageHistoryOption { 65 | return func(m *SqliteStore) { 66 | m.Ctx = ctx //nolint:fatcontext 67 | } 68 | } 69 | 70 | // WithDBAddress is an option for NewSqliteChatMessageHistory for 71 | // specifying an address or file path for when connecting the db. 72 | func WithDBAddress(addr string) SqliteChatMessageHistoryOption { 73 | return func(m *SqliteStore) { 74 | m.DBAddress = addr 75 | } 76 | } 77 | 78 | // WithConversation is an option for NewSqliteChatMessageHistory for 79 | // setting a session name or ConvoID for the convo. 80 | func WithConversation(convoID string) SqliteChatMessageHistoryOption { 81 | return func(m *SqliteStore) { 82 | m.ConvoID = convoID 83 | } 84 | } 85 | 86 | func applyChatOptions(options ...SqliteChatMessageHistoryOption) *SqliteStore { 87 | h := &SqliteStore{} 88 | 89 | for _, option := range options { 90 | option(h) 91 | } 92 | 93 | if h.Ctx == nil { 94 | h.Ctx = context.Background() 95 | } 96 | 97 | if h.DBAddress == "" { 98 | h.DBAddress = ":memory:" 99 | } 100 | 101 | if h.ConvoID == "" { 102 | h.ConvoID = "default" 103 | } 104 | 105 | if h.DataPath == "" { 106 | h.DataPath = "./data" 107 | } 108 | 109 | if h.DB == nil { 110 | db, err := sqlx.Open("sqlite", h.DBAddress) 111 | if err != nil { 112 | errbook.HandleError(errbook.Wrap("Could not open database.", err)) 113 | os.Exit(1) 114 | } 115 | h.DB = db 116 | } 117 | 118 | if err := h.DB.Ping(); err != nil { 119 | errbook.HandleError(errbook.Wrap("Could not connect to database.", err)) 120 | os.Exit(1) 121 | } 122 | 123 | if _, err := h.DB.ExecContext(h.Ctx, DefaultSchema); err != nil { 124 | errbook.HandleError(errbook.Wrap("Could not create convo db table.", err)) 125 | os.Exit(1) 126 | } 127 | 128 | h.SimpleChatHistoryStore = convo.NewSimpleChatHistoryStore(h.DataPath) 129 | h.sqliteLoadContextStore = newLoadContextStore(h.DB) 130 | 131 | return h 132 | } 133 | 
-------------------------------------------------------------------------------- /internal/convo/sqlite3/convo_store_test.go: -------------------------------------------------------------------------------- 1 | package sqlite3 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "testing" 7 | "time" 8 | 9 | "github.com/stretchr/testify/assert" 10 | "github.com/stretchr/testify/require" 11 | 12 | "github.com/coding-hui/wecoding-sdk-go/services/ai/llms" 13 | 14 | "github.com/coding-hui/ai-terminal/internal/convo" 15 | ) 16 | 17 | func TestSqliteStore(t *testing.T) { 18 | t.Parallel() 19 | 20 | ctx := context.Background() 21 | convoID := convo.NewConversationID() 22 | h := NewSqliteStore( 23 | WithConversation(convoID), 24 | WithContext(ctx), 25 | WithDataPath(t.TempDir()), 26 | ) 27 | 28 | t.Run("Save and Get conversation", func(t *testing.T) { 29 | err := h.SaveConversation(ctx, convoID, "foo", "test") 30 | require.NoError(t, err) 31 | 32 | convo, err := h.GetConversation(ctx, convoID) 33 | require.NoError(t, err) 34 | assert.Equal(t, "foo", convo.Title) 35 | assert.False(t, convo.UpdatedAt.IsZero()) 36 | }) 37 | 38 | t.Run("Get non-existent conversation", func(t *testing.T) { 39 | _, err := h.GetConversation(ctx, "nonexistent") 40 | assert.True(t, errors.Is(err, errNoMatches)) 41 | }) 42 | 43 | t.Run("Delete conversation", func(t *testing.T) { 44 | require.NoError(t, h.DeleteConversation(ctx, convoID)) 45 | 46 | ok, err := h.ConversationExists(ctx, convoID) 47 | require.NoError(t, err) 48 | assert.False(t, ok) 49 | }) 50 | 51 | t.Run("List conversations", func(t *testing.T) { 52 | // Create multiple conversations 53 | ids := []string{ 54 | convo.NewConversationID(), 55 | convo.NewConversationID(), 56 | } 57 | require.NoError(t, h.SaveConversation(ctx, ids[0], "first", "test")) 58 | require.NoError(t, h.SaveConversation(ctx, ids[1], "second", "test")) 59 | 60 | convos, err := h.ListConversations(ctx) 61 | require.NoError(t, err) 62 | assert.GreaterOrEqual(t, len(convos), 
2) 63 | }) 64 | 65 | t.Run("ListOlderThan", func(t *testing.T) { 66 | oldID := convo.NewConversationID() 67 | require.NoError(t, h.SaveConversation(ctx, oldID, "old", "test")) 68 | 69 | // Update timestamp to be old 70 | _, err := h.DB.ExecContext(ctx, ` 71 | UPDATE conversations 72 | SET updated_at = datetime('now', '-1 hour') 73 | WHERE id = ? 74 | `, oldID) 75 | require.NoError(t, err) 76 | 77 | convos, err := h.ListConversationsOlderThan(ctx, 30*time.Minute) 78 | require.NoError(t, err) 79 | assert.GreaterOrEqual(t, len(convos), 1) 80 | }) 81 | 82 | t.Run("Clear all conversations", func(t *testing.T) { 83 | require.NoError(t, h.ClearConversations(ctx)) 84 | 85 | convos, err := h.ListConversations(ctx) 86 | require.NoError(t, err) 87 | assert.Empty(t, convos) 88 | }) 89 | } 90 | 91 | func TestSqliteChatMessageHistory(t *testing.T) { 92 | t.Parallel() 93 | 94 | ctx := context.Background() 95 | convoID := convo.NewConversationID() 96 | h := NewSqliteStore( 97 | WithContext(ctx), 98 | WithDataPath(t.TempDir()), 99 | ) 100 | 101 | t.Run("Add and get messages", func(t *testing.T) { 102 | err := h.AddAIMessage(ctx, convoID, "foo") 103 | require.NoError(t, err) 104 | 105 | err = h.AddUserMessage(ctx, convoID, "bar") 106 | require.NoError(t, err) 107 | 108 | messages, err := h.Messages(ctx, convoID) 109 | require.NoError(t, err) 110 | 111 | assert.Equal(t, []llms.ChatMessage{ 112 | llms.AIChatMessage{Content: "foo"}, 113 | llms.HumanChatMessage{Content: "bar"}, 114 | }, messages) 115 | }) 116 | 117 | t.Run("Set and add messages", func(t *testing.T) { 118 | err := h.SetMessages(ctx, 119 | convoID, 120 | []llms.ChatMessage{ 121 | llms.AIChatMessage{Content: "foo"}, 122 | llms.SystemChatMessage{Content: "bar"}, 123 | }) 124 | require.NoError(t, err) 125 | 126 | err = h.AddUserMessage(ctx, convoID, "zoo") 127 | require.NoError(t, err) 128 | 129 | messages, err := h.Messages(ctx, convoID) 130 | require.NoError(t, err) 131 | 132 | assert.Equal(t, []llms.ChatMessage{ 133 | 
llms.AIChatMessage{Content: "foo"}, 134 | llms.SystemChatMessage{Content: "bar"}, 135 | llms.HumanChatMessage{Content: "zoo"}, 136 | }, messages) 137 | }) 138 | 139 | t.Run("Get messages from non-existent conversation", func(t *testing.T) { 140 | messages, err := h.Messages(ctx, "nonexistent") 141 | assert.NoError(t, err) 142 | assert.Len(t, messages, 0) 143 | }) 144 | } 145 | -------------------------------------------------------------------------------- /internal/convo/sqlite3/loadcontext_store.go: -------------------------------------------------------------------------------- 1 | package sqlite3 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "errors" 7 | "fmt" 8 | 9 | "github.com/jmoiron/sqlx" 10 | 11 | "github.com/coding-hui/ai-terminal/internal/convo" 12 | ) 13 | 14 | var ( 15 | errLoadContextNotFound = errors.New("load context not found") 16 | ) 17 | 18 | type sqliteLoadContextStore struct { 19 | db *sqlx.DB 20 | } 21 | 22 | func newLoadContextStore(db *sqlx.DB) *sqliteLoadContextStore { 23 | return &sqliteLoadContextStore{db: db} 24 | } 25 | 26 | func (s *sqliteLoadContextStore) SaveContext(ctx context.Context, lc *convo.LoadContext) error { 27 | res, err := s.db.ExecContext(ctx, s.db.Rebind(` 28 | UPDATE load_contexts 29 | SET 30 | type = ?, 31 | url = ?, 32 | file_path = ?, 33 | content = ?, 34 | name = ?, 35 | conversation_id = ?, 36 | updated_at = CURRENT_TIMESTAMP 37 | WHERE 38 | id = ? 39 | `), lc.Type, lc.URL, lc.FilePath, lc.Content, lc.Name, lc.ConversationID, lc.ID) 40 | if err != nil { 41 | return fmt.Errorf("SaveContext: %w", err) 42 | } 43 | 44 | rows, err := res.RowsAffected() 45 | if err != nil { 46 | return fmt.Errorf("SaveContext: %w", err) 47 | } 48 | 49 | if rows > 0 { 50 | return nil 51 | } 52 | 53 | resp, err := s.db.ExecContext(ctx, s.db.Rebind(` 54 | INSERT INTO load_contexts ( 55 | type, url, file_path, content, name, conversation_id 56 | ) VALUES ( 57 | ?, ?, ?, ?, ?, ? 
58 | ) 59 | `), lc.Type, lc.URL, lc.FilePath, lc.Content, lc.Name, lc.ConversationID) 60 | if err != nil { 61 | return fmt.Errorf("SaveContext: %w", err) 62 | } 63 | 64 | lastInsertId, err := resp.LastInsertId() 65 | if err != nil { 66 | return fmt.Errorf("SaveContext: %w", err) 67 | } 68 | lc.ID = uint64(lastInsertId) 69 | 70 | return nil 71 | } 72 | 73 | func (s *sqliteLoadContextStore) GetContext(ctx context.Context, id uint64) (*convo.LoadContext, error) { 74 | var lc convo.LoadContext 75 | err := s.db.GetContext(ctx, &lc, s.db.Rebind(` 76 | SELECT id, type, url, file_path, content, name, conversation_id, updated_at 77 | FROM load_contexts WHERE id = ? 78 | `), id) 79 | if err != nil { 80 | if errors.Is(err, sql.ErrNoRows) { 81 | return nil, fmt.Errorf("%w: %v", errLoadContextNotFound, err) 82 | } 83 | return nil, fmt.Errorf("GetContext: %w", err) 84 | } 85 | return &lc, nil 86 | } 87 | 88 | func (s *sqliteLoadContextStore) ListContextsByteConvoID(ctx context.Context, conversationID string) ([]convo.LoadContext, error) { 89 | var contexts []convo.LoadContext 90 | if err := s.db.SelectContext(ctx, &contexts, s.db.Rebind(` 91 | SELECT id, type, url, file_path, content, name, conversation_id, updated_at 92 | FROM load_contexts WHERE conversation_id = ? 93 | `), conversationID); err != nil { 94 | return nil, fmt.Errorf("ListContextsByteConvoID: %w", err) 95 | } 96 | return contexts, nil 97 | } 98 | 99 | func (s *sqliteLoadContextStore) DeleteContexts(ctx context.Context, id uint64) error { 100 | res, err := s.db.ExecContext(ctx, s.db.Rebind(` 101 | DELETE FROM load_contexts WHERE id = ? 
102 | `), id) 103 | if err != nil { 104 | return fmt.Errorf("DeleteContexts: %w", err) 105 | } 106 | 107 | rows, err := res.RowsAffected() 108 | if err != nil { 109 | return fmt.Errorf("DeleteContexts: %w", err) 110 | } 111 | 112 | if rows == 0 { 113 | return fmt.Errorf("DeleteContexts: no rows affected") 114 | } 115 | 116 | return nil 117 | } 118 | 119 | func (s *sqliteLoadContextStore) CleanContexts(ctx context.Context, conversationID string) (int64, error) { 120 | res, err := s.db.ExecContext(ctx, s.db.Rebind(` 121 | DELETE FROM load_contexts WHERE conversation_id = ? 122 | `), conversationID) 123 | if err != nil { 124 | return 0, fmt.Errorf("CleanContexts: %w", err) 125 | } 126 | 127 | rows, err := res.RowsAffected() 128 | if err != nil { 129 | return 0, fmt.Errorf("CleanContexts: %w", err) 130 | } 131 | 132 | return rows, nil 133 | } 134 | -------------------------------------------------------------------------------- /internal/convo/sqlite3/loadcontext_store_test.go: -------------------------------------------------------------------------------- 1 | package sqlite3 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/jmoiron/sqlx" 8 | "github.com/stretchr/testify/assert" 9 | "github.com/stretchr/testify/require" 10 | 11 | "github.com/coding-hui/ai-terminal/internal/convo" 12 | ) 13 | 14 | func TestLoadContextStore(t *testing.T) { 15 | t.Parallel() 16 | 17 | ctx := context.Background() 18 | db := setupTestDB(t) 19 | store := newLoadContextStore(db) 20 | 21 | t.Run("SaveContext and GetContext LoadContext", func(t *testing.T) { 22 | lc := &convo.LoadContext{ 23 | Type: "file", 24 | URL: "", 25 | FilePath: "/path/to/file", 26 | Content: "test content", 27 | Name: "test.txt", 28 | ConversationID: "conv1", 29 | } 30 | 31 | err := store.SaveContext(ctx, lc) 32 | require.NoError(t, err) 33 | 34 | retrieved, err := store.GetContext(ctx, lc.ID) 35 | require.NoError(t, err) 36 | assert.Equal(t, lc.ID, retrieved.ID) 37 | assert.Equal(t, lc.Type, 
retrieved.Type) 38 | assert.Equal(t, lc.FilePath, retrieved.FilePath) 39 | assert.Equal(t, lc.Content, retrieved.Content) 40 | assert.Equal(t, lc.Name, retrieved.Name) 41 | assert.Equal(t, lc.ConversationID, retrieved.ConversationID) 42 | }) 43 | 44 | t.Run("GetContext non-existent LoadContext", func(t *testing.T) { 45 | _, err := store.GetContext(ctx, uint64(999)) 46 | require.Error(t, err) 47 | assert.Contains(t, err.Error(), "load context not found") 48 | }) 49 | 50 | t.Run("ListContextsByteConvoID LoadContexts by convo", func(t *testing.T) { 51 | convID := "conv2" 52 | lc1 := &convo.LoadContext{ 53 | ID: uint64(2), 54 | Type: "file", 55 | FilePath: "/path/to/file1", 56 | Content: "content1", 57 | Name: "file1.txt", 58 | ConversationID: convID, 59 | } 60 | lc2 := &convo.LoadContext{ 61 | ID: uint64(3), 62 | Type: "file", 63 | FilePath: "/path/to/file2", 64 | Content: "content2", 65 | Name: "file2.txt", 66 | ConversationID: convID, 67 | } 68 | 69 | require.NoError(t, store.SaveContext(ctx, lc1)) 70 | require.NoError(t, store.SaveContext(ctx, lc2)) 71 | 72 | contexts, err := store.ListContextsByteConvoID(ctx, convID) 73 | require.NoError(t, err) 74 | assert.Len(t, contexts, 2) 75 | }) 76 | 77 | t.Run("DeleteContexts LoadContext", func(t *testing.T) { 78 | lc := &convo.LoadContext{ 79 | ID: uint64(4), 80 | Type: "file", 81 | FilePath: "/path/to/file", 82 | Content: "test content", 83 | Name: "test.txt", 84 | ConversationID: "conv3", 85 | } 86 | 87 | require.NoError(t, store.SaveContext(ctx, lc)) 88 | require.NoError(t, store.DeleteContexts(ctx, lc.ID)) 89 | 90 | _, err := store.GetContext(ctx, lc.ID) 91 | require.Error(t, err) 92 | }) 93 | 94 | t.Run("CleanContexts LoadContexts by convo", func(t *testing.T) { 95 | convID := "conv4" 96 | lc1 := &convo.LoadContext{ 97 | ID: uint64(5), 98 | Type: "file", 99 | FilePath: "/path/to/file1", 100 | Content: "content1", 101 | Name: "file1.txt", 102 | ConversationID: convID, 103 | } 104 | lc2 := &convo.LoadContext{ 105 | ID: 
uint64(6), 106 | Type: "file", 107 | FilePath: "/path/to/file2", 108 | Content: "content2", 109 | Name: "file2.txt", 110 | ConversationID: convID, 111 | } 112 | 113 | require.NoError(t, store.SaveContext(ctx, lc1)) 114 | require.NoError(t, store.SaveContext(ctx, lc2)) 115 | 116 | _, err := store.CleanContexts(ctx, convID) 117 | require.NoError(t, err) 118 | 119 | contexts, err := store.ListContextsByteConvoID(ctx, convID) 120 | require.NoError(t, err) 121 | assert.Empty(t, contexts) 122 | }) 123 | 124 | t.Run("Update existing LoadContext", func(t *testing.T) { 125 | lc := &convo.LoadContext{ 126 | ID: uint64(7), 127 | Type: "file", 128 | FilePath: "/path/to/file", 129 | Content: "initial content", 130 | Name: "test.txt", 131 | ConversationID: "conv5", 132 | } 133 | 134 | require.NoError(t, store.SaveContext(ctx, lc)) 135 | 136 | // Update content 137 | lc.Content = "updated content" 138 | require.NoError(t, store.SaveContext(ctx, lc)) 139 | 140 | retrieved, err := store.GetContext(ctx, lc.ID) 141 | require.NoError(t, err) 142 | assert.Equal(t, "updated content", retrieved.Content) 143 | }) 144 | } 145 | 146 | func setupTestDB(t *testing.T) *sqlx.DB { 147 | db, err := sqlx.Open("sqlite", ":memory:") 148 | require.NoError(t, err) 149 | 150 | // Each connection to ":memory:" opens its own empty database, so pin the pool 151 | // to a single connection that always sees the schema created below. 152 | db.SetMaxOpenConns(1) 153 | 154 | _, err = db.Exec(DefaultSchema) 155 | require.NoError(t, err) 156 | 157 | t.Cleanup(func() { 158 | _ = db.Close() 159 | }) 160 | 161 | return db 162 | } 163 | -------------------------------------------------------------------------------- /internal/errbook/error.go: -------------------------------------------------------------------------------- 1 | package errbook 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | ) 7 | 8 | // ErrUserAborted is the error returned when a user exits the form before submitting. 9 | var ErrUserAborted = errors.New("user aborted") 10 | 11 | // ErrTimeout is the error returned when the timeout is reached.
12 | var ErrTimeout = errors.New("timeout") 13 | 14 | // ErrTimeoutUnsupported is the error returned when timeout is used while in accessible mode. 15 | var ErrTimeoutUnsupported = errors.New("timeout is not supported in accessible mode") 16 | 17 | // ErrInvalidArgument is the error returned when the input is invalid. 18 | var ErrInvalidArgument = errors.New("invalid argument") 19 | 20 | // NewUserErrorf formats a user-facing error. 21 | // It exists mostly to avoid linters complaining about error strings starting with a capitalized letter. 22 | func NewUserErrorf(format string, a ...any) error { 23 | return fmt.Errorf(format, a...) 24 | } 25 | 26 | // AiError wraps an error with a reason that HandleError prints next to the error header. 27 | type AiError struct { 28 | err error 29 | reason string 30 | } 31 | 32 | // New returns an AiError whose underlying error is built with fmt.Errorf(format, a...). 33 | // Its reason is left empty. 34 | func New(format string, a ...any) error { 35 | return AiError{ 36 | err: fmt.Errorf(format, a...), 37 | } 38 | } 39 | 40 | // Wrap returns an AiError that attaches reason to err. 41 | func Wrap(reason string, err error) error { 42 | return AiError{ 43 | err: err, 44 | reason: reason, 45 | } 46 | } 47 | 48 | // Error returns the wrapped error's message, or the reason when no error is wrapped 49 | // (so Wrap(reason, nil).Error() does not panic). 50 | func (m AiError) Error() string { 51 | if m.err == nil { 52 | return m.reason 53 | } 54 | return m.err.Error() 55 | } 56 | 57 | // Reason returns the context attached by Wrap; it is empty for errors built by New. 58 | func (m AiError) Reason() string { 59 | return m.reason 60 | } 61 | 62 | // Unwrap returns the wrapped error so errors.Is and errors.As see through AiError, 63 | // e.g. errors.Is(Wrap("cancelled", ErrUserAborted), ErrUserAborted) is true. 64 | func (m AiError) Unwrap() error { 65 | return m.err 66 | } 67 | -------------------------------------------------------------------------------- /internal/errbook/handle.go: -------------------------------------------------------------------------------- 1 | package errbook 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "io" 7 | "os" 8 | 9 | "github.com/coding-hui/ai-terminal/internal/ui/console" 10 | "github.com/coding-hui/ai-terminal/internal/util/term" 11 | ) 12 | 13 | // HandleError prints err to stderr using the console error styles. For an AiError 14 | // the error header and reason come first, followed by the error details unless the 15 | // wrapped error is ErrUserAborted. A nil err is ignored. 16 | func HandleError(err error) { 17 | if err == nil { 18 | return 19 | } 20 | 21 | // When stdin is not a terminal (e.g. piped input), read and discard what is left of it. 22 | if !term.IsInputTTY() { 23 | _, _ = io.ReadAll(os.Stdin) 24 | } 25 | 26 | format := "\n%s\n\n" 27 | 28 | var args []any 29 | var aiTermErr AiError 30 | if errors.As(err, &aiTermErr) { 31 | args = []any{ 32 | console.StderrStyles().ErrPadding.Render(console.StderrStyles().ErrorHeader.String(), aiTermErr.reason), 33 | } 34 | 35 | // Skip the error details if the user simply canceled out of huh. 36 | if !errors.Is(aiTermErr.err, ErrUserAborted) { 37 | format += "%s\n\n" 38 | args = append(args, console.StderrStyles().ErrPadding.Render(console.StderrStyles().ErrorDetails.Render(err.Error()))) 39 | } 40 | } else { 41 | args = []any{ 42 | console.StderrStyles().ErrPadding.Render(console.StderrStyles().ErrorDetails.Render(err.Error())), 43 | } 44 | } 45 | 46 | _, _ = fmt.Fprintf(os.Stderr, format, args...) 47 | } 48 |
-------------------------------------------------------------------------------- /internal/git/git_test.go: -------------------------------------------------------------------------------- 1 | package git 2 | 3 | import ( 4 | "strings" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | "github.com/stretchr/testify/require" 9 | ) 10 | 11 | func TestCommand_GitDir(t *testing.T) { 12 | g := New() 13 | dir, err := g.GitDir() 14 | require.NoError(t, err) 15 | assert.NotEmpty(t, dir) 16 | assert.True(t, strings.HasSuffix(strings.TrimSpace(dir), ".git")) 17 | } 18 | -------------------------------------------------------------------------------- /internal/git/options.go: -------------------------------------------------------------------------------- 1 | package git 2 | 3 | // optionFunc is a type of function that can be used to implement the Option interface. 4 | // It takes a pointer to a config struct and modifies it. 5 | type optionFunc func(*config) 6 | 7 | // Option is a functional option that modifies the git package's config (see the With* constructors). 8 | type Option interface { 9 | apply(*config) 10 | } 11 | 12 | // Ensure that optionFunc satisfies the Option interface. 13 | var _ Option = (*optionFunc)(nil) 14 | 15 | // apply calls o with c, which lets a plain func(*config) satisfy Option.
16 | func (o optionFunc) apply(c *config) { 17 | o(c) 18 | } 19 | 20 | // WithDiffUnified returns an Option that generates diffs with val lines of context instead of the usual three. 21 | func WithDiffUnified(val int) Option { 22 | return optionFunc(func(c *config) { 23 | c.diffUnified = val 24 | }) 25 | } 26 | 27 | // WithExcludeList returns an Option that sets the excludeList field of a config object to the given value. 28 | func WithExcludeList(val []string) Option { 29 | return optionFunc(func(c *config) { 30 | // An empty list keeps the current excludeList instead of clearing it. 31 | if len(val) == 0 { 32 | return 33 | } 34 | c.excludeList = val 35 | }) 36 | } 37 | 38 | // WithEnableAmend returns an Option that sets the isAmend field of a config object to the given value. 39 | func WithEnableAmend(val bool) Option { 40 | return optionFunc(func(c *config) { 41 | c.isAmend = val 42 | }) 43 | } 44 | 45 | // config stores the settings that Option values modify. 46 | type config struct { 47 | diffUnified int 48 | excludeList []string 49 | isAmend bool 50 | } 51 | -------------------------------------------------------------------------------- /internal/options/basic_flags.go: -------------------------------------------------------------------------------- 1 | package options 2 | 3 | import ( 4 | "github.com/spf13/pflag" 5 | 6 | "github.com/coding-hui/ai-terminal/internal/ui/console" 7 | ) 8 | 9 | const ( 10 | FlagLogFlushFrequency = "log-flush-frequency" 11 | ) 12 | 13 | // AddBasicFlags binds client configuration flags to a given flagset.
14 | func AddBasicFlags(flags *pflag.FlagSet, cfg *Config) { 15 | flags.StringVarP(&cfg.Model, "model", "m", cfg.Model, console.StdoutStyles().FlagDesc.Render(Help["model"])) 16 | flags.StringVarP(&cfg.API, "api", "a", cfg.API, console.StdoutStyles().FlagDesc.Render(Help["api"])) 17 | //flags.StringVarP(&cfg.HTTPProxy, "http-proxy", "x", cfg.HTTPProxy, console.StdoutStyles().FlagDesc.Render(Help["http-proxy"])) 18 | //flags.BoolVarP(&cfg.Format, "format", "f", cfg.Format, console.StdoutStyles().FlagDesc.Render(Help["format"])) 19 | flags.StringVar(&cfg.FormatAs, "format-as", cfg.FormatAs, console.StdoutStyles().FlagDesc.Render(Help["format-as"])) 20 | flags.BoolVarP(&cfg.Raw, "raw", "r", cfg.Raw, console.StdoutStyles().FlagDesc.Render(Help["raw"])) 21 | flags.BoolVarP(&cfg.Quiet, "quiet", "q", cfg.Quiet, console.StdoutStyles().FlagDesc.Render(Help["quiet"])) 22 | flags.IntVar(&cfg.MaxRetries, "max-retries", cfg.MaxRetries, console.StdoutStyles().FlagDesc.Render(Help["max-retries"])) 23 | flags.BoolVar(&cfg.NoLimit, "no-limit", cfg.NoLimit, console.StdoutStyles().FlagDesc.Render(Help["no-limit"])) 24 | flags.IntVar(&cfg.MaxTokens, "max-tokens", cfg.MaxTokens, console.StdoutStyles().FlagDesc.Render(Help["max-tokens"])) 25 | flags.IntVar(&cfg.WordWrap, "word-wrap", cfg.WordWrap, console.StdoutStyles().FlagDesc.Render(Help["word-wrap"])) 26 | flags.Float64Var(&cfg.Temperature, "temp", cfg.Temperature, console.StdoutStyles().FlagDesc.Render(Help["temp"])) 27 | flags.StringArrayVar(&cfg.Stop, "stop", cfg.Stop, console.StdoutStyles().FlagDesc.Render(Help["stop"])) 28 | flags.Float64Var(&cfg.TopP, "topp", cfg.TopP, console.StdoutStyles().FlagDesc.Render(Help["topp"])) 29 | flags.IntVar(&cfg.TopK, "topk", cfg.TopK, console.StdoutStyles().FlagDesc.Render(Help["topk"])) 30 | flags.UintVar(&cfg.Fanciness, "fanciness", cfg.Fanciness, console.StdoutStyles().FlagDesc.Render(Help["fanciness"])) 31 | flags.StringVar(&cfg.LoadingText, "loading-text", cfg.LoadingText,
console.StdoutStyles().FlagDesc.Render(Help["status-text"])) 32 | flags.BoolVar(&cfg.NoCache, "no-cache", cfg.NoCache, console.StdoutStyles().FlagDesc.Render(Help["no-cache"])) 33 | flags.StringVarP(&cfg.Show, "show", "s", cfg.Show, console.StdoutStyles().FlagDesc.Render(Help["show"])) 34 | flags.BoolVarP(&cfg.ShowLast, "show-last", "S", false, console.StdoutStyles().FlagDesc.Render(Help["show-last"])) 35 | flags.StringVarP(&cfg.Continue, "continue", "c", "", console.StdoutStyles().FlagDesc.Render(Help["continue"])) 36 | flags.BoolVarP(&cfg.ContinueLast, "continue-last", "C", false, console.StdoutStyles().FlagDesc.Render(Help["continue-last"])) 37 | flags.StringVarP(&cfg.Title, "title", "T", cfg.Title, console.StdoutStyles().FlagDesc.Render(Help["title"])) 38 | flags.IntVarP(&cfg.Verbose, "verbose", "v", cfg.Verbose, console.StdoutStyles().FlagDesc.Render(Help["verbose"])) 39 | //flags.StringVarP(&cfg.Role, "role", "R", cfg.Role, console.StdoutStyles().FlagDesc.Render(Help["role"])) 40 | //flags.BoolVar(&cfg.ListRoles, "list-roles", cfg.ListRoles, console.StdoutStyles().FlagDesc.Render(Help["list-roles"])) 41 | //flags.StringVar(&cfg.Theme, "theme", "charm", console.StdoutStyles().FlagDesc.Render(Help["theme"])) 42 | } 43 | -------------------------------------------------------------------------------- /internal/options/config_test.go: -------------------------------------------------------------------------------- 1 | package options 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/require" 7 | "gopkg.in/yaml.v3" 8 | ) 9 | 10 | func TestConfig(t *testing.T) { 11 | t.Run("old format text", func(t *testing.T) { 12 | var cfg Config 13 | require.NoError(t, yaml.Unmarshal([]byte("format-text: as markdown"), &cfg)) 14 | require.Equal(t, FormatText(map[string]string{ 15 | "markdown": "as markdown", 16 | }), cfg.FormatText) 17 | }) 18 | t.Run("new format text", func(t *testing.T) { 19 | var cfg Config 20 | require.NoError(t,
yaml.Unmarshal([]byte("format-text:\n markdown: as markdown\n json: as json"), &cfg)) 21 | require.Equal(t, FormatText(map[string]string{ 22 | "markdown": "as markdown", 23 | "json": "as json", 24 | }), cfg.FormatText) 25 | }) 26 | } 27 | -------------------------------------------------------------------------------- /internal/options/datastore_flags.go: -------------------------------------------------------------------------------- 1 | package options 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/AlekSi/pointer" 7 | "github.com/spf13/pflag" 8 | ) 9 | 10 | const ( 11 | FlagDatastoreType = "datastore.type" 12 | FlagDatastoreUrl = "datastore.url" 13 | FlagDatastorePath = "datastore.path" 14 | FlagDatastoreUsername = "datastore.username" 15 | FlagDatastorePassword = "datastore.password" 16 | ) 17 | 18 | type DataStoreFlags struct { 19 | Type *string 20 | Url *string 21 | Path *string 22 | Username *string 23 | Password *string 24 | } 25 | 26 | // NewDatastoreFlags returns DataStoreFlags with default values set. 27 | func NewDatastoreFlags(dsType string) *DataStoreFlags { 28 | return &DataStoreFlags{ 29 | Type: pointer.ToString(dsType), 30 | Url: pointer.ToString(""), 31 | } 32 | } 33 | 34 | // AddFlags binds client configuration flags to a given cmd.
35 | func (d *DataStoreFlags) AddFlags(flags *pflag.FlagSet) { 36 | if d.Type != nil { 37 | flags.StringVar(d.Type, FlagDatastoreType, *d.Type, "Datastore provider type") 38 | } 39 | if d.Url != nil { 40 | flags.StringVar(d.Url, FlagDatastoreUrl, *d.Url, "Datastore connection url") 41 | } 42 | if d.Path != nil { 43 | flags.StringVar(d.Path, FlagDatastorePath, *d.Path, "Datastore save path") 44 | } 45 | if d.Username != nil { 46 | flags.StringVar(d.Username, FlagDatastoreUsername, *d.Username, "Datastore username") 47 | } 48 | if d.Password != nil { 49 | flags.StringVar(d.Password, FlagDatastorePassword, *d.Password, "Datastore password") 50 | } 51 | } 52 | 53 | func (d *DataStoreFlags) Validate() error { 54 | if d.Type != nil { 55 | dsType := *d.Type 56 | if dsType != "file" && dsType != "mongo" { 57 | return fmt.Errorf("invalid datastore type: %s", dsType) 58 | } 59 | } 60 | return nil 61 | } 62 | -------------------------------------------------------------------------------- /internal/options/model_flags.go: -------------------------------------------------------------------------------- 1 | package options 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/AlekSi/pointer" 7 | "github.com/spf13/pflag" 8 | ) 9 | 10 | const ( 11 | FlagDefaultSystemPrompt = "system-prompt" 12 | FlagAiModel = "model" 13 | FlagAiToken = "token" 14 | FlagAiApiBase = "api-base" 15 | FlagAiTemperature = "temperature" 16 | FlagAiTopP = "top-p" 17 | FlagAiMaxTokens = "max-tokens" 18 | FlagOutputFormat = "output-format" 19 | FlagMultiContentEnabled = "multi-content-enabled" 20 | ) 21 | 22 | type ModelFlags struct { 23 | Token *string 24 | Model *string 25 | ApiBase *string 26 | Temperature *float64 27 | TopP *float64 28 | MaxTokens *int 29 | Proxy *string 30 | OutputFormat *string 31 | MultiContentEnabled *bool 32 | } 33 | 34 | // NewModelFlags returns ModelFlags with default values set.
35 | func NewModelFlags() *ModelFlags { 36 | return &ModelFlags{ 37 | Token: pointer.ToString(""), 38 | Model: pointer.ToString(""), 39 | ApiBase: pointer.ToString(""), 40 | Temperature: pointer.ToFloat64(0.5), 41 | TopP: pointer.ToFloat64(0.5), 42 | MaxTokens: pointer.ToInt(1024), 43 | OutputFormat: pointer.ToString(string(MarkdownOutputFormat)), 44 | } 45 | } 46 | 47 | // AddFlags binds client configuration flags to a given flagset. 48 | func (m *ModelFlags) AddFlags(flags *pflag.FlagSet) { 49 | if m.Token != nil { 50 | flags.StringVar(m.Token, FlagAiToken, *m.Token, "Api token to use for CLI requests") 51 | } 52 | if m.Model != nil { 53 | flags.StringVar(m.Model, FlagAiModel, *m.Model, "The encoding of the model to be called.") 54 | } 55 | if m.ApiBase != nil { 56 | flags.StringVar(m.ApiBase, FlagAiApiBase, *m.ApiBase, "Interface for the API.") 57 | } 58 | if m.Temperature != nil { 59 | flags.Float64Var(m.Temperature, FlagAiTemperature, *m.Temperature, "Sampling temperature to control the randomness of the output.") 60 | } 61 | if m.TopP != nil { 62 | flags.Float64Var(m.TopP, FlagAiTopP, *m.TopP, "Nucleus sampling method to control the probability mass of the output.") 63 | } 64 | if m.MaxTokens != nil { 65 | flags.IntVar(m.MaxTokens, FlagAiMaxTokens, *m.MaxTokens, "The maximum number of tokens the model can output.") 66 | } 67 | if m.OutputFormat != nil { 68 | flags.StringVarP(m.OutputFormat, FlagOutputFormat, "o", *m.OutputFormat, "Output format.
One of: (markdown, raw).") 69 | } 70 | if m.MultiContentEnabled != nil { 71 | flags.BoolVar(m.MultiContentEnabled, FlagMultiContentEnabled, *m.MultiContentEnabled, "LLM multi content is enabled") 72 | } 73 | } 74 | 75 | func (m *ModelFlags) Validate() error { 76 | if m.OutputFormat != nil { 77 | output := *m.OutputFormat 78 | if output != string(MarkdownOutputFormat) && output != string(RawOutputFormat) { 79 | return fmt.Errorf("invalid output format: %s", output) 80 | } 81 | } 82 | return nil 83 | } 84 | -------------------------------------------------------------------------------- /internal/options/options.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023 coding-hui. All rights reserved. 2 | // Use of this source code is governed by a MIT style 3 | // license that can be found in the LICENSE file. 4 | 5 | // Package options holds the shared CLI configuration and flags, and the `options` command that prints the global command-line options (applies to all commands). 6 | package options 7 | 8 | import ( 9 | "io" 10 | 11 | "github.com/spf13/cobra" 12 | 13 | "github.com/coding-hui/ai-terminal/internal/util/templates" 14 | ) 15 | 16 | var optionsExample = templates.Examples(` 17 | # Print flags inherited by all commands 18 | ai options`) 19 | 20 | // NewCmdOptions returns the `options` command, which prints the flags inherited by all commands. 21 | func NewCmdOptions(out io.Writer) *cobra.Command { 22 | cmd := &cobra.Command{ 23 | Use: "options", 24 | Short: "Print the list of flags inherited by all commands", 25 | Example: optionsExample, 26 | Run: func(cmd *cobra.Command, args []string) { 27 | _ = cmd.Usage() 28 | }, 29 | } 30 | 31 | // The `options` command writes its usage to `out` (typically stdout); without 32 | // this, Usage() falls back to stderr. SetOut and SetErr replace the deprecated 33 | // SetOutput, which redirected both streams.
34 | cmd.SetOut(out) 35 | cmd.SetErr(out) 36 | 37 | templates.UseOptionsTemplates(cmd) 38 | 39 | return cmd 40 | } 41 | -------------------------------------------------------------------------------- /internal/prompt/language.go: -------------------------------------------------------------------------------- 1 | package prompt 2 | 3 | import "strings" 4 | 5 | const DefaultLanguage = "en" 6 | 7 | var languageMaps = map[string]string{ 8 | "en": "English", 9 | "zh-tw": "Traditional Chinese", 10 | "zh-cn": "Simplified Chinese", 11 | "ja": "Japanese", 12 | "pt": "Portuguese", 13 | "pt-br": "Brazilian Portuguese", 14 | } 15 | 16 | // GetLanguage returns the language name for the given language code, 17 | // or the default language if the code is not recognized. The lookup ignores 18 | // case and surrounding whitespace and accepts "_" as the region separator, 19 | // so "zh-CN", "zh_cn" and " ZH-CN " all resolve to "Simplified Chinese". 20 | func GetLanguage(langCode string) string { 21 | code := strings.ReplaceAll(strings.ToLower(strings.TrimSpace(langCode)), "_", "-") 22 | if language, ok := languageMaps[code]; ok { 23 | return language 24 | } 25 | return languageMaps[DefaultLanguage] 26 | } 27 | -------------------------------------------------------------------------------- /internal/prompt/language_test.go: -------------------------------------------------------------------------------- 1 | package prompt 2 | 3 | import "testing" 4 | 5 | func TestGetLanguage(t *testing.T) { 6 | testCases := []struct { 7 | langCode string 8 | expected string 9 | }{ 10 | {"en", "English"}, 11 | {"zh-tw", "Traditional Chinese"}, 12 | {"zh-cn", "Simplified Chinese"}, 13 | {"ja", "Japanese"}, 14 | {"fr", "English"}, 15 | {"zh-CN", "Simplified Chinese"}, 16 | {"pt_BR", "Brazilian Portuguese"}, 17 | } 18 | 19 | for _, tc := range testCases { 20 | result := GetLanguage(tc.langCode) 21 | if result != tc.expected { 22 | t.Errorf("GetLanguage(%q) = %q, expected %q", tc.langCode, result, tc.expected) 23 | } 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /internal/prompt/prompt.go: -------------------------------------------------------------------------------- 1 | package prompt 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/coding-hui/wecoding-sdk-go/services/ai/llms" 7 |
"github.com/coding-hui/wecoding-sdk-go/services/ai/prompts" 8 | ) 9 | 10 | func GetPromptStringByTemplate(promptTemplate string, vars map[string]any) (llms.PromptValue, error) { 11 | tpl := prompts.NewPromptTemplate(promptTemplate, nil) 12 | return tpl.FormatPrompt(vars) 13 | } 14 | 15 | func GetPromptStringByTemplateName(templateName string, vars map[string]any) (llms.PromptValue, error) { 16 | t, ok := promptTemplates[templateName] 17 | if !ok { 18 | return nil, fmt.Errorf("prompt template %s not found", templateName) 19 | } 20 | res, err := t.template.FormatPrompt(vars) 21 | if err != nil { 22 | return nil, err 23 | } 24 | return res, nil 25 | } 26 | -------------------------------------------------------------------------------- /internal/prompt/prompt_test.go: -------------------------------------------------------------------------------- 1 | package prompt 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | "github.com/stretchr/testify/require" 8 | ) 9 | 10 | func TestGetPromptStringByTemplate(t *testing.T) { 11 | promptTemplate := `{{ .summarize_prefix }}: {{ .summarize_title }} 12 | 13 | {{ .summarize_message }} 14 | ` 15 | p, err := GetPromptStringByTemplate(promptTemplate, map[string]any{ 16 | "summarize_prefix": "feat", 17 | "summarize_title": "ut test", 18 | "summarize_message": "test test test", 19 | }) 20 | require.NoError(t, err) 21 | assert.NotEmpty(t, p) 22 | assert.Contains(t, p, "feat") 23 | } 24 | -------------------------------------------------------------------------------- /internal/prompt/template_loader.go: -------------------------------------------------------------------------------- 1 | package prompt 2 | 3 | import ( 4 | "embed" 5 | 6 | "github.com/coding-hui/wecoding-sdk-go/services/ai/prompts" 7 | "k8s.io/klog/v2" 8 | ) 9 | 10 | //go:embed templates/* 11 | var templatesFS embed.FS 12 | 13 | // Template file names 14 | const ( 15 | CodeReviewTemplate = "code_review_file_diff.tmpl" 16 | 
SummarizeFileDiffTemplate = "summarize_file_diff.tmpl" 17 | SummarizeTitleTemplate = "summarize_title.tmpl" 18 | ConventionalCommitTemplate = "conventional_commit.tmpl" 19 | TranslationTemplate = "translation.tmpl" 20 | CommitMessageTemplate = "commit-msg.tmpl" 21 | 22 | UserAdditionalPrompt = "user_additional_prompt" 23 | SummarizePrefixKey = "summarize_prefix" 24 | SummarizeTitleKey = "summarize_title" 25 | SummarizeMessageKey = "summarize_message" 26 | SummarizePointsKey = "summary_points" 27 | FileDiffsKey = "file_diffs" 28 | OutputLanguageKey = "output_language" 29 | OutputMessageKey = "output_message" 30 | ) 31 | 32 | type prompt struct { 33 | inputVars []string 34 | template prompts.PromptTemplate 35 | } 36 | 37 | var ( 38 | templatesDir = "templates" 39 | promptTemplates = map[string]*prompt{ 40 | CodeReviewTemplate: { 41 | inputVars: []string{FileDiffsKey}, 42 | }, 43 | SummarizeFileDiffTemplate: { 44 | inputVars: []string{FileDiffsKey}, 45 | }, 46 | SummarizeTitleTemplate: { 47 | inputVars: []string{SummarizePointsKey}, 48 | }, 49 | ConventionalCommitTemplate: { 50 | inputVars: []string{SummarizePointsKey}, 51 | }, 52 | TranslationTemplate: { 53 | inputVars: []string{OutputLanguageKey, OutputMessageKey}, 54 | }, 55 | CommitMessageTemplate: { 56 | inputVars: []string{SummarizePrefixKey, SummarizeTitleKey, SummarizeMessageKey}, 57 | }, 58 | } 59 | ) 60 | 61 | // Initializes the prompt package by loading the templates from the embedded file system. 62 | func init() { //nolint:gochecknoinits 63 | if err := loadPromptTemplates(templatesFS); err != nil { 64 | klog.Fatal(err) 65 | } 66 | } 67 | 68 | // LoadPromptTemplates loads all the prompt templates found in the templates directory. 
69 | func loadPromptTemplates(files embed.FS) error { 70 | for k, v := range promptTemplates { 71 | content, err := files.ReadFile(templatesDir + "/" + k) 72 | if err != nil { 73 | return err 74 | } 75 | v.template = prompts.NewPromptTemplate(string(content), v.inputVars) 76 | } 77 | return nil 78 | } 79 | -------------------------------------------------------------------------------- /internal/prompt/templates/code_review_file_diff.tmpl: -------------------------------------------------------------------------------- 1 | Bellow is the code patch, please help me do a brief code review if any bug risk, security vulnerabilities and improvement suggestion are welcome 2 | 3 | THE Code Patch TO BE Reviewed: 4 | 5 | {{ .file_diffs }} 6 | -------------------------------------------------------------------------------- /internal/prompt/templates/commit-msg.tmpl: -------------------------------------------------------------------------------- 1 | {{ .summarize_prefix }}: {{ .summarize_title }} 2 | 3 | {{ .summarize_message }} 4 | -------------------------------------------------------------------------------- /internal/prompt/templates/conventional_commit.tmpl: -------------------------------------------------------------------------------- 1 | You are an expert programmer, and you are trying to summarize a code change. 2 | You went over every file that was changed in it. 3 | For some of these files changes where too big and were omitted in the files diff summary. 4 | Determine the best label for the commit. 5 | 6 | Here are the labels you can choose from: 7 | 8 | - build: Changes that affect the build system or external dependencies (example scopes: gulp, broccoli, npm) 9 | - chore: Updating libraries, copyrights or other repo setting, includes updating dependencies. 
10 | - ci: Changes to our CI configuration files and scripts (example scopes: Travis, Circle, GitHub Actions) 11 | - docs: Non-code changes, such as fixing typos or adding new documentation (example scopes: Markdown file) 12 | - feat: A commit of the type feat introduces a new feature to the codebase 13 | - fix: A commit of the type fix patches a bug in your codebase 14 | - perf: A code change that improves performance 15 | - refactor: A code change that neither fixes a bug nor adds a feature 16 | - style: Changes that do not affect the meaning of the code (white-space, formatting, missing semi-colons, etc) 17 | - test: Adding missing tests or correcting existing tests 18 | 19 | 20 | THE FILE SUMMARIES: 21 | 22 | {{- if .user_additional_prompt }} 23 | # {{ .user_additional_prompt }} 24 | {{- end }} 25 | {{ .summary_points }} 26 | 27 | Based on the changes described in the file summaries, what's the best label for the commit? Your answer must be one of the labels above. Don't describe the changes, just write the label. -------------------------------------------------------------------------------- /internal/prompt/templates/summarize_file_diff.tmpl: -------------------------------------------------------------------------------- 1 | You are an expert programmer, and you are trying to summarize a git diff. 2 | And your main language is {{ .output_language }}. 3 | Reminders about the git diff format: 4 | For every file, there are a few metadata lines, like (for example): 5 | ``` 6 | diff --git a/lib/index.js b/lib/index.js 7 | index aadf691..bfef603 100644 8 | --- a/lib/index.js 9 | +++ b/lib/index.js 10 | ``` 11 | This means that `lib/index.js` was modified in this commit. Note that this is only an example. 12 | Then there is a specifier of the lines that were modified. 13 | A line starting with `+` means it was added. 14 | A line starting with `-` means it was deleted.
15 | A line that starts with neither `+` nor `-` is code given for context and better understanding. 16 | It is not part of the diff. 17 | After the git diff of the first file, there will be an empty line, and then the git diff of the next file. 18 | 19 | Do not include the file name as another part of the comment. 20 | Do not use the characters `[` or `]` in the summary. 21 | Write every summary comment in a new line. 22 | Comments should be in a bullet point list, each line starting with a `-`. 23 | The summary should not include comments copied from the code. 24 | The output should be easily readable. When in doubt, write fewer comments rather than more. Do not output comments that simply repeat the contents of the file. 25 | Readability is top priority. Write only the most important comments about the diff. 26 | 27 | EXAMPLE SUMMARY COMMENTS: 28 | 29 | - Raise the amount of returned recordings from `10` to `100` 30 | - Fix a typo in the github action name 31 | - Move the `octokit` initialization to a separate file 32 | - Add an OpenAI API for completions 33 | - Lower numeric tolerance for test files 34 | - Add 2 tests for the inclusive string split function 35 | 36 | Most commits will have fewer comments than this example list. 37 | The last comment does not include the file names, 38 | because there were more than two relevant files in the hypothetical commit. 39 | Do not include parts of the example in your summary. 40 | It is given only as an example of appropriate comments. 41 | 42 | 43 | THE GIT DIFF TO BE SUMMARIZED: 44 | 45 | {{ .file_diffs }} 46 | 47 | THE SUMMARY: 48 | -------------------------------------------------------------------------------- /internal/prompt/templates/summarize_title.tmpl: -------------------------------------------------------------------------------- 1 | You are an expert programmer, and you are trying to title a pull request. 2 | And your main language is {{ .output_language }}. 3 | You went over every file that was changed in it.
4 | For some of these files the changes were too big and were omitted from the file diff summaries. 5 | Please summarize the pull request into a single specific theme. 6 | Write your response using the imperative tense following the kernel git commit style guide. 7 | Write a high level title. 8 | The title must be concise and technically accurate. Refrain from using "update", "temp", "xxx", etc. 9 | Do not repeat the commit summaries or the file summaries. 10 | Do not list individual changes in the title. 11 | 12 | EXAMPLE SUMMARY COMMENTS: 13 | ``` 14 | Raise the amount of returned recordings 15 | Switch to internal API for completions 16 | Lower numeric tolerance for test files 17 | Schedule all GitHub actions on all OSs 18 | ``` 19 | 20 | THE FILE SUMMARIES: 21 | 22 | {{- if .user_additional_prompt }} 23 | {{ .user_additional_prompt }} 24 | {{- end }} 25 | {{ .summary_points }} 26 | 27 | Remember to write only one line, no more than 50 characters. 28 | THE PULL REQUEST TITLE: 29 | -------------------------------------------------------------------------------- /internal/prompt/templates/translation.tmpl: -------------------------------------------------------------------------------- 1 | You are a professional programmer and translator, and you are trying to translate a git commit message. 2 | You want to ensure that the translation is high level and in line with the programmer's consensus, taking care to keep the formatting intact. 3 | 4 | Now, translate the following message into {{ .output_language }}. 5 | 6 | GIT COMMIT MESSAGE: 7 | 8 | {{ .output_message }} 9 | 10 | Remember to translate the entire git commit message.
11 | THE TRANSLATION: 12 | -------------------------------------------------------------------------------- /internal/runner/output.go: -------------------------------------------------------------------------------- 1 | package runner 2 | 3 | import "fmt" 4 | 5 | type Output struct { 6 | error error 7 | errorMessage string 8 | successMessage string 9 | } 10 | 11 | func NewRunOutput(error error, errorMessage string, successMessage string) Output { 12 | return Output{ 13 | error: error, 14 | errorMessage: errorMessage, 15 | successMessage: successMessage, 16 | } 17 | } 18 | 19 | func (o Output) HasError() bool { 20 | return o.error != nil 21 | } 22 | 23 | func (o Output) GetErrorMessage() string { 24 | return fmt.Sprintf("%s: %s", o.errorMessage, o.error) 25 | } 26 | 27 | func (o Output) GetSuccessMessage() string { 28 | return o.successMessage 29 | } 30 | -------------------------------------------------------------------------------- /internal/runner/output_test.go: -------------------------------------------------------------------------------- 1 | package runner 2 | 3 | import ( 4 | "errors" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestRunOutput(t *testing.T) { 11 | t.Run("HasError", testHasError) 12 | t.Run("GetErrorMessage", testGetErrorMessage) 13 | t.Run("GetSuccessMessage", testGetSuccessMessage) 14 | } 15 | 16 | func testHasError(t *testing.T) { 17 | err := errors.New("test error") 18 | runOutputWithError := NewRunOutput(err, "Error occurred", "Success") 19 | runOutputWithoutError := NewRunOutput(nil, "Error occurred", "Success") 20 | 21 | assert.True(t, runOutputWithError.HasError(), "RunOutput should have an error.") 22 | assert.False(t, runOutputWithoutError.HasError(), "RunOutput should not have an error.") 23 | } 24 | 25 | func testGetErrorMessage(t *testing.T) { 26 | err := errors.New("test error") 27 | runOutput := NewRunOutput(err, "Error occurred", "Success") 28 | 29 | expectedErrorMessage := "Error occurred: 
test error" 30 | actualErrorMessage := runOutput.GetErrorMessage() 31 | 32 | assert.Equal(t, expectedErrorMessage, actualErrorMessage, "The error messages should be the same.") 33 | } 34 | 35 | func testGetSuccessMessage(t *testing.T) { 36 | runOutput := NewRunOutput(nil, "Error occurred", "Success") 37 | 38 | expectedSuccessMessage := "Success" 39 | actualSuccessMessage := runOutput.GetSuccessMessage() 40 | 41 | assert.Equal(t, expectedSuccessMessage, actualSuccessMessage, "The success messages should be the same.") 42 | } 43 | -------------------------------------------------------------------------------- /internal/runner/runner.go: -------------------------------------------------------------------------------- 1 | package runner 2 | 3 | import ( 4 | "fmt" 5 | "os/exec" 6 | "runtime" 7 | ) 8 | 9 | const ( 10 | defaultShellUnix = "bash" 11 | defaultShellWin = "cmd.exe" 12 | ) 13 | 14 | func Run(cmd string, arg ...string) (string, error) { 15 | out, err := exec.Command(cmd, arg...).Output() 16 | if err != nil { 17 | return fmt.Sprintf("error: %v", err), err 18 | } 19 | 20 | return string(out), nil 21 | } 22 | 23 | func PrepareInteractiveCommand(input string) *exec.Cmd { 24 | if isWindows() { 25 | return prepareWindowsCommand(input) 26 | } 27 | return prepareUnixCommand(input) 28 | } 29 | 30 | func PrepareEditSettingsCommand(editor, filename string) *exec.Cmd { 31 | switch editor { 32 | case "vim": 33 | return exec.Command("vim", "+normal G$", filename) 34 | case "nano": 35 | return exec.Command("nano", "+99999999", filename) 36 | default: 37 | return exec.Command(editor, filename) 38 | } 39 | } 40 | 41 | // IsCommandAvailable checks whether a command is available in the PATH. 
42 | func IsCommandAvailable(cmd string) bool {
43 | 	_, err := exec.LookPath(cmd)
44 | 	return err == nil
45 | }
46 |
47 | func isWindows() bool {
48 | 	return runtime.GOOS == "windows"
49 | }
50 |
51 | func prepareWindowsCommand(args ...string) *exec.Cmd {
52 | 	return exec.Command(
53 | 		defaultShellWin,
54 | 		append([]string{"/c"}, args...)...,
55 | 	)
56 | }
57 |
58 | func prepareUnixCommand(args ...string) *exec.Cmd {
59 | 	return exec.Command(
60 | 		defaultShellUnix,
61 | 		append([]string{"-c"}, args...)...,
62 | 	)
63 | }
64 |
--------------------------------------------------------------------------------
/internal/runner/runner_test.go:
--------------------------------------------------------------------------------
1 | package runner
2 |
3 | import (
4 | 	"os/exec"
5 | 	"testing"
6 |
7 | 	"github.com/stretchr/testify/assert"
8 | )
9 |
10 | func TestRun(t *testing.T) {
11 | 	t.Run("PrepareUnixEditSettingsCommand", testPrepareUnixEditSettingsCommand)
12 | 	t.Run("PrepareWinEditSettingsCommand", testPrepareWinEditSettingsCommand)
13 | }
14 |
15 | func testPrepareUnixEditSettingsCommand(t *testing.T) {
16 | 	cmd := prepareUnixCommand("ai.json")
17 |
18 | 	expectedCmd := exec.Command(
19 | 		"bash",
20 | 		"-c",
21 | 		"ai.json",
22 | 	)
23 |
24 | 	assert.Equal(t, expectedCmd.Args, cmd.Args, "The command arguments should be the same.")
25 | }
26 |
27 | func testPrepareWinEditSettingsCommand(t *testing.T) {
28 | 	cmd := prepareWindowsCommand("notepad.exe ai.json")
29 |
30 | 	expectedCmd := exec.Command(
31 | 		"cmd.exe",
32 | 		"/c",
33 | 		"notepad.exe ai.json",
34 | 	)
35 |
36 | 	assert.Equal(t, expectedCmd.Args, cmd.Args, "The command arguments should be the same.")
37 | }
38 |
--------------------------------------------------------------------------------
/internal/system/system.go:
--------------------------------------------------------------------------------
1 | package system
2 |
3 | import (
4 | 	"os"
5 | 	"path/filepath"
6 | 	"runtime"
7 | 	"strings"
8 |
9 | 	"github.com/adrg/xdg"
10 |
11 | 	"github.com/coding-hui/common/util/homedir"
12 |
13 | 	"github.com/coding-hui/ai-terminal/internal/runner"
14 | )
15 |
16 | // Defaults reported by Analysis.GetApplicationName and used by GetEditor.
17 | const (
18 | 	DefaultApplicationName = "ai"
19 | 	DefaultEditor          = "vim"
20 | )
21 |
22 | // Analysis is a snapshot of host environment facts collected by Analyse.
23 | type Analysis struct {
24 | 	operatingSystem OperatingSystem
25 | 	distribution    string
26 | 	shell           string
27 | 	homeDirectory   string
28 | 	username        string
29 | 	editor          string
30 | 	configFile      string
31 | }
32 |
33 | // GetApplicationName returns DefaultApplicationName.
34 | func (a *Analysis) GetApplicationName() string {
35 | 	return DefaultApplicationName
36 | }
37 |
38 | // GetOperatingSystem returns the operating system detected by Analyse.
39 | func (a *Analysis) GetOperatingSystem() OperatingSystem {
40 | 	return a.operatingSystem
41 | }
42 |
43 | // GetDistribution returns the distribution detected by Analyse ("" if unknown).
44 | func (a *Analysis) GetDistribution() string {
45 | 	return a.distribution
46 | }
47 |
48 | // GetShell returns the shell name detected by Analyse ("" if unknown).
49 | func (a *Analysis) GetShell() string {
50 | 	return a.shell
51 | }
52 |
53 | // GetHomeDirectory returns the home directory detected by Analyse.
54 | func (a *Analysis) GetHomeDirectory() string {
55 | 	return a.homeDirectory
56 | }
57 |
58 | // GetUsername returns the user name detected by Analyse ("" if unknown).
59 | func (a *Analysis) GetUsername() string {
60 | 	return a.username
61 | }
62 |
63 | // GetEditor returns the editor command detected by Analyse.
64 | func (a *Analysis) GetEditor() string {
65 | 	return a.editor
66 | }
67 |
68 | // GetConfigFile returns the configuration file path detected by Analyse.
69 | func (a *Analysis) GetConfigFile() string {
70 | 	return a.configFile
71 | }
72 |
73 | // Analyse gathers every environment fact once and returns them as an Analysis.
74 | func Analyse() *Analysis {
75 | 	return &Analysis{
76 | 		operatingSystem: GetOperatingSystem(),
77 | 		distribution:    GetDistribution(),
78 | 		shell:           GetShell(),
79 | 		homeDirectory:   GetHomeDirectory(),
80 | 		username:        GetUsername(),
81 | 		editor:          GetEditor(),
82 | 		configFile:      GetConfigFile(),
83 | 	}
84 | }
85 |
86 | // GetOperatingSystem maps runtime.GOOS to an OperatingSystem value.
87 | func GetOperatingSystem() OperatingSystem {
88 | 	switch runtime.GOOS {
89 | 	case "linux":
90 | 		return LinuxOperatingSystem
91 | 	case "darwin":
92 | 		return MacOperatingSystem
93 | 	case "windows":
94 | 		return WindowsOperatingSystem
95 | 	default:
96 | 		return UnknownOperatingSystem
97 | 	}
98 | }
99 |
100 | // GetDistribution returns the output of `lsb_release -sd` with surrounding
101 | // whitespace and double quotes removed, or "" if the command fails.
102 | func GetDistribution() string {
103 | 	dist, err := runner.Run("lsb_release", "-sd")
104 | 	if err != nil {
105 | 		return ""
106 | 	}
107 |
108 | 	return strings.Trim(strings.TrimSpace(dist), "\"")
109 | }
110 |
111 | // GetShell returns the base name of the user's shell taken from $SHELL
112 | // (%COMSPEC% on Windows), or "" when that variable is unset.
113 | func GetShell() string {
114 | 	// Read the variable directly rather than spawning `echo`: on Windows echo is
115 | 	// a cmd.exe builtin, not an executable, so running it via exec fails there.
116 | 	envVar := "SHELL"
117 | 	if GetOperatingSystem() == WindowsOperatingSystem {
118 | 		envVar = "COMSPEC"
119 | 	}
120 |
121 | 	shell := strings.TrimSpace(os.Getenv(envVar))
122 | 	shellPath := filepath.ToSlash(shell) // Normalize path separators to forward slash
123 | 	shellParts := strings.Split(shellPath, "/")
124 |
125 | 	return shellParts[len(shellParts)-1]
126 | }
127 |
128 | // GetHomeDirectory returns the current user's home directory.
129 | func GetHomeDirectory() string {
130 | 	return homedir.HomeDir()
131 | }
132 |
133 | // GetUsername returns the current user's name from $USER, falling back to the
134 | // output of `whoami`. Only the last path component is kept: filepath.ToSlash
135 | // turns OS separators into "/", so e.g. whoami's "DOMAIN\user" on Windows
136 | // becomes "user". It returns "" when no name can be determined.
137 | func GetUsername() string {
138 | 	// Read $USER directly: spawning `echo` fails on Windows (no echo executable),
139 | 	// which would return "" before the whoami fallback is ever tried.
140 | 	name := strings.TrimSpace(os.Getenv("USER"))
141 | 	if name == "" {
142 | 		out, err := runner.Run("whoami")
143 | 		if err != nil {
144 | 			return ""
145 | 		}
146 | 		name = out
147 | 	}
148 |
149 | 	nameParts := strings.Split(filepath.ToSlash(name), "/")
150 |
151 | 	return strings.TrimSpace(nameParts[len(nameParts)-1])
152 | }
153 |
154 | // GetEditor returns $EDITOR, falling back to $VISUAL and then DefaultEditor.
155 | func GetEditor() string {
156 | 	editor := os.Getenv("EDITOR")
157 | 	if editor == "" {
158 | 		editor = os.Getenv("VISUAL")
159 | 		if editor == "" {
160 | 			editor = DefaultEditor
161 | 		}
162 | 	}
163 | 	return editor
164 | }
165 |
166 | // GetConfigFile returns the XDG config path of ai-terminal/config.yml.
167 | // Any error from xdg.ConfigFile is ignored.
168 | func GetConfigFile() string {
169 | 	sp, _ := xdg.ConfigFile(filepath.Join("ai-terminal", "config.yml"))
170 | 	return sp
171 | }
172 |
--------------------------------------------------------------------------------
/internal/system/types.go:
--------------------------------------------------------------------------------
1 | package system
2 |
3 | // OperatingSystem identifies the host operating system family.
4 | type OperatingSystem int
5 |
6 | // Supported OperatingSystem values; UnknownOperatingSystem is the zero value.
7 | const (
8 | 	UnknownOperatingSystem OperatingSystem = iota
9 | 	LinuxOperatingSystem
10 | 	MacOperatingSystem
11 | 	WindowsOperatingSystem
12 | )
13 |
14 | // String returns a display name for o, or "unknown" for unrecognized values.
15 | func (o OperatingSystem) String() string {
16 | 	switch o {
17 | 	case LinuxOperatingSystem:
18 | 		return "linux"
19 | 	case MacOperatingSystem:
20 | 		return "macOS"
21 | 	case WindowsOperatingSystem:
22 | 		return "windows"
23 | 	default:
24 | 		return "unknown"
25 | 	}
26 | }
27 |
--------------------------------------------------------------------------------
/internal/ui/chat/options.go:
--------------------------------------------------------------------------------
1 | package chat
2 |
3 | import (
4 | 	"context"
5 |
6 | 	"github.com/charmbracelet/lipgloss"
7 |
8 | 	"github.com/coding-hui/wecoding-sdk-go/services/ai/llms"
9 |
10 | 	"github.com/coding-hui/ai-terminal/internal/ai"
11 | 	"github.com/coding-hui/ai-terminal/internal/ui"
12 | 	"github.com/coding-hui/ai-terminal/internal/ui/console"
13 | )
14 |
15 | // Options configures a chat session. Build it with NewOptions.
16 | type Options struct {
17 | 	ctx             context.Context
18 | 	runMode         ui.RunMode
19 | 	promptMode      ui.PromptMode
20 | 	renderer        *lipgloss.Renderer
21 | 	wordWrap        int
22 | 	copyToClipboard bool
23 |
24 | 	engine *ai.Engine
25 |
26 | 	content  string
27 | 	messages []llms.ChatMessage
28 | }
29 |
30 | // Option mutates an Options value; NewOptions applies options in order.
31 | type Option func(*Options)
32 |
33 | // WithContext sets the context.
34 | func WithContext(ctx context.Context) Option {
35 | 	return func(o *Options) {
36 | 		o.ctx = ctx
37 | 	}
38 | }
39 |
40 | // WithEngine sets the AI engine.
41 | func WithEngine(engine *ai.Engine) Option {
42 | 	return func(o *Options) {
43 | 		o.engine = engine
44 | 	}
45 | }
46 |
47 | // WithRunMode sets the run mode (default ui.CliMode).
48 | func WithRunMode(runMode ui.RunMode) Option {
49 | 	return func(o *Options) {
50 | 		o.runMode = runMode
51 | 	}
52 | }
53 |
54 | // WithPromptMode sets the prompt mode (default ui.ChatPromptMode).
55 | func WithPromptMode(promptMode ui.PromptMode) Option {
56 | 	return func(o *Options) {
57 | 		o.promptMode = promptMode
58 | 	}
59 | }
60 |
61 | // WithContent sets the content.
62 | func WithContent(content string) Option {
63 | 	return func(o *Options) {
64 | 		o.content = content
65 | 	}
66 | }
67 |
68 | // WithMessages sets the chat messages.
69 | func WithMessages(messages []llms.ChatMessage) Option {
70 | 	return func(o *Options) {
71 | 		o.messages = messages
72 | 	}
73 | }
74 |
75 | // WithWordWrap sets the word-wrap width.
76 | func WithWordWrap(wordWrap int) Option {
77 | 	return func(o *Options) {
78 | 		o.wordWrap = wordWrap
79 | 	}
80 | }
81 |
82 | // WithRenderer sets the lipgloss renderer (default console.StderrRenderer()).
83 | func WithRenderer(renderer *lipgloss.Renderer) Option {
84 | 	return func(o *Options) {
85 | 		o.renderer = renderer
86 | 	}
87 | }
88 |
89 | // WithCopyToClipboard sets the copy-to-clipboard flag.
90 | func WithCopyToClipboard(enabled bool) Option {
91 | 	return func(o *Options) {
92 | 		o.copyToClipboard = enabled
93 | 	}
94 | }
95 |
96 | // NewOptions returns Options with defaults (CLI run mode, chat prompt mode,
97 | // stderr renderer) and then applies opts in order.
98 | func NewOptions(opts ...Option) *Options {
99 | 	o := &Options{
100 | 		runMode:    ui.CliMode,
101 | 		promptMode: ui.ChatPromptMode,
102 | 		renderer:   console.StderrRenderer(),
103 | 	}
104 |
105 | 	for _, opt := range opts {
106 | 		opt(o)
107 | 	}
108 |
109 | 	return o
110 | }
111 |
--------------------------------------------------------------------------------
/internal/ui/coders/auto_coder.go:
--------------------------------------------------------------------------------
1 | package coders
2 |
3 | import (
4 | 	"context"
5 | 	_ "embed"
6 | 	"fmt"
7 | 	"strings"
8 |
9 | 	"github.com/coding-hui/common/version"
10 |
11 | 	"github.com/coding-hui/ai-terminal/internal/ai"
12 | 	"github.com/coding-hui/ai-terminal/internal/convo"
13 | 	"github.com/coding-hui/ai-terminal/internal/errbook"
14 | 	"github.com/coding-hui/ai-terminal/internal/git"
15 | 	"github.com/coding-hui/ai-terminal/internal/options"
16 | 	"github.com/coding-hui/ai-terminal/internal/ui/console"
17 | )
18 |
19 | //go:embed banner.txt
20 | var banner string
21 |
22 | type AutoCoder struct {
23 | 	codeBasePath, prompt string
24 | 	repo                 *git.Command
25 | 	loadedContexts       []*convo.LoadContext
26 | 	engine               *ai.Engine
27 | 	store                convo.Store
28 |
29 | 	versionInfo version.Info
30 | 	cfg         *options.Config
31 | }
32 |
33 | func NewAutoCoder(opts ...AutoCoderOption) *AutoCoder {
34 | 	return applyAutoCoderOptions(opts...)
35 | }
36 |
37 | // saveContext persists the conversation context to the store for future reference
38 | func (a *AutoCoder) saveContext(ctx context.Context, lc *convo.LoadContext) error {
39 | 	lc.ConversationID = a.cfg.CacheWriteToID
40 | 	return a.store.SaveContext(ctx, lc)
41 | }
42 |
43 | // deleteContext removes a specific conversation context from the store by its ID
44 | func (a *AutoCoder) deleteContext(ctx context.Context, id uint64) error {
45 | 	return a.store.DeleteContexts(ctx, id)
46 | }
47 |
48 | func (a *AutoCoder) loadExistingContexts() error {
49 | 	// Get current conversation details
50 | 	details, err := convo.GetCurrentConversationID(context.Background(), a.cfg, a.store)
51 | 	if err != nil {
52 | 		return errbook.Wrap("Failed to get current conversation", err)
53 | 	}
54 |
55 | 	a.cfg.CacheWriteToID = details.WriteID
56 | 	a.cfg.CacheWriteToTitle = details.Title
57 | 	a.cfg.CacheReadFromID = details.ReadID
58 | 	a.cfg.Model = details.Model
59 |
60 | 	// Load all conversation contexts associated with the current session
61 | 	contexts, err := a.store.ListContextsByteConvoID(context.Background(), details.WriteID)
62 | 	if err != nil {
63 | 		return errbook.Wrap("Failed to load conversation contexts", err)
64 | 	}
65 |
66 | 	// Point at the slice elements themselves: &v of a range variable aliases one shared variable before Go 1.22.
67 | 	for i := range contexts {
68 | 		a.loadedContexts = append(a.loadedContexts, &contexts[i])
69 | 	}
70 |
71 | 	return nil
72 | }
73 |
74 | func (a *AutoCoder) Run() error {
75 | 	// Load existing contexts first: this sets the session fields that printWelcome displays.
76 | 	if err := a.loadExistingContexts(); err != nil {
77 | 		return err
78 | 	}
79 |
80 | 	codingCmd := strings.TrimSpace(a.prompt) != ""
81 | 	if !codingCmd {
82 | 		a.printWelcome()
83 | 	}
84 |
85 | 	cmdExecutor := NewCommandExecutor(a)
86 |
87 | 	if codingCmd {
88 | 		cmdExecutor.Executor(fmt.Sprintf("/coding %s", a.prompt))
89 | 		return nil
90 | 	}
91 |
92 | 	cmdCompleter := NewCommandCompleter(a.repo)
93 | 	p := console.NewPrompt(
94 | 		a.cfg.AutoCoder.PromptPrefix,
95 | 		true,
96 | 		cmdCompleter.Complete,
97 | 		cmdExecutor.Executor,
98 | 	)
99 |
100 | 	// Start the interactive REPL (Read-Eval-Print Loop) for command processing
101 | 	p.Run()
102 |
103 | 	return nil
104 | }
105 |
106 | func (a *AutoCoder) printWelcome() {
107 | 	fmt.Println(banner)
108 | 	console.RenderComment("")
109 | 	console.RenderComment("Welcome to AutoCoder - Your AI Coding Assistant! (%s) [Model: %s]\n", a.versionInfo.GitVersion, a.cfg.CurrentModel.Name)
110 |
111 | 	// Get current conversation info from config
112 | 	if a.cfg.CacheWriteToID != "" {
113 | 		console.RenderComment("Current Session:")
114 | 		console.RenderComment(" • ID: %s", a.cfg.CacheWriteToID)
115 | 		if a.cfg.CacheWriteToTitle != "" {
116 | 			console.RenderComment(" • Title: %s", a.cfg.CacheWriteToTitle)
117 | 		}
118 | 		console.RenderComment("")
119 | 	}
120 |
121 | 	console.Render("Let's start coding! 🚀")
122 | 	console.RenderComment("")
123 | }
124 |
125 | func (a *AutoCoder) determineBeatCodeFences(rawCode string) (string, string) {
126 | 	// Prefer the configured fences when exactly an open/close pair is set.
127 | 	if f := a.cfg.AutoCoder.GetDefaultFences(); len(f) == 2 {
128 | 		return f[0], f[1]
129 | 	}
130 | 	return chooseBestFence(rawCode)
131 | }
132 |
--------------------------------------------------------------------------------
/internal/ui/coders/auto_coder_options.go:
--------------------------------------------------------------------------------
1 | package coders
2 |
3 | import (
4 | 	"github.com/coding-hui/ai-terminal/internal/ai"
5 | 	"github.com/coding-hui/ai-terminal/internal/convo"
6 | 	"github.com/coding-hui/ai-terminal/internal/git"
7 | 	"github.com/coding-hui/ai-terminal/internal/options"
8 |
9 | 	"github.com/coding-hui/common/version"
10 | )
11 |
12 | // AutoCoderOption configures an AutoCoder; see NewAutoCoder.
13 | type AutoCoderOption func(*AutoCoder)
14 |
15 | // WithConfig sets the configuration.
16 | func WithConfig(cfg *options.Config) AutoCoderOption {
17 | 	return func(a *AutoCoder) {
18 | 		a.cfg = cfg
19 | 	}
20 | }
21 |
22 | // WithEngine sets the AI engine.
23 | func WithEngine(engine *ai.Engine) AutoCoderOption {
24 | 	return func(a *AutoCoder) {
25 | 		a.engine = engine
26 | 	}
27 | }
28 |
29 | // WithRepo sets the git repository.
30 | func WithRepo(repo *git.Command) AutoCoderOption {
31 | 	return func(a *AutoCoder) {
32 | 		a.repo = repo
33 | 	}
34 | }
35 |
36 | // WithCodeBasePath sets the code base path.
37 | func WithCodeBasePath(path string) AutoCoderOption {
38 | 	return func(a *AutoCoder) {
39 | 		a.codeBasePath = path
40 | 	}
41 | }
42 |
43 | // WithStore sets the conversation store.
44 | func WithStore(store convo.Store) AutoCoderOption {
45 | 	return func(a *AutoCoder) {
46 | 		a.store = store
47 | 	}
48 | }
49 |
50 | // WithLoadedContexts sets the initially loaded contexts (nil is normalized to empty).
51 | func WithLoadedContexts(contexts []*convo.LoadContext) AutoCoderOption {
52 | 	return func(a *AutoCoder) {
53 | 		a.loadedContexts = contexts
54 | 	}
55 | }
56 |
57 | // WithPrompt sets a prompt that Run executes as a one-shot /coding command instead of starting the REPL.
58 | func WithPrompt(prompt string) AutoCoderOption {
59 | 	return func(a *AutoCoder) {
60 | 		a.prompt = prompt
61 | 	}
62 | }
63 |
64 | // applyAutoCoderOptions builds an AutoCoder with version info and an empty
65 | // context list, applies opts in order, and guarantees loadedContexts is non-nil.
66 | func applyAutoCoderOptions(opts ...AutoCoderOption) *AutoCoder {
67 | 	ac := &AutoCoder{
68 | 		versionInfo:    version.Get(),
69 | 		loadedContexts: []*convo.LoadContext{},
70 | 	}
71 |
72 | 	for _, opt := range opts {
73 | 		opt(ac)
74 | 	}
75 |
76 | 	// WithLoadedContexts(nil) would otherwise leave the slice nil.
77 | 	if ac.loadedContexts == nil {
78 | 		ac.loadedContexts = []*convo.LoadContext{}
79 | 	}
80 |
81 | 	return ac
82 | }
83 |
--------------------------------------------------------------------------------
/internal/ui/coders/banner.txt:
--------------------------------------------------------------------------------
1 | _ _ _ _____ ___ ___ ___ ___ ___ ___
2 | /_\| | | |_ _/ _ \ ___ / __/ _ \| \| __| _ \
3 | / _ \ |_| | | || (_) |___| (_| (_) | |) | _|| /
4 | /_/ \_\___/ |_| \___/ \___\___/|___/|___|_|_\
--------------------------------------------------------------------------------
/internal/ui/coders/code_editor.go:
--------------------------------------------------------------------------------
1 | // Package coders provides interfaces and types for working with code editors.
2 | package coders
3 |
4 | import (
5 | 	"context"
6 | 	"fmt"
7 |
8 | 	"github.com/coding-hui/wecoding-sdk-go/services/ai/llms"
9 | 	"github.com/coding-hui/wecoding-sdk-go/services/ai/prompts"
10 | )
11 |
12 | // PartialCodeBlock represents a partial code block with its file path and original and updated text.
13 | type PartialCodeBlock struct {
14 | 	// Path is the file path of the code block.
15 | 	Path         string
16 | 	// OriginalText is the original text of the code block.
17 | 	OriginalText string
18 | 	// UpdatedText is the updated text of the code block.
19 | 	UpdatedText  string
20 | }
21 |
22 | func (p PartialCodeBlock) String() string {
23 | 	return fmt.Sprintf("%s\n```<<<<<< Original\n%s\n====== Updated\n%s```",
24 | 		p.Path, p.OriginalText, p.UpdatedText,
25 | 	)
26 | }
27 |
28 | // Coder defines the interface for a code editor.
29 | type Coder interface {
30 | 	// Name returns the name of the code editor.
31 | 	Name() string
32 | 	// Prompt returns the prompt template used by the code editor.
33 | 	Prompt() prompts.ChatPromptTemplate
34 | 	// FormatMessages formats the messages with the provided values and returns the formatted messages.
35 | 	FormatMessages(values map[string]any) ([]llms.ChatMessage, error)
36 | 	// GetEdits retrieves the list of edits made to the code.
37 | 	GetEdits(ctx context.Context, codes string, fences []string) ([]PartialCodeBlock, error)
38 | 	// GetModifiedFiles retrieves the list of files that have been modified.
39 | 	GetModifiedFiles(ctx context.Context) ([]string, error)
40 | 	// UpdateCodeFences chooses code fences for the given code; the two results are presumably the open and close fence (TODO confirm).
41 | 	UpdateCodeFences(ctx context.Context, code string) (string, string)
42 | 	// ApplyEdits applies the given list of edits to the code.
43 | 	ApplyEdits(ctx context.Context, edits []PartialCodeBlock) error
44 | 	// Execute runs the code editor with the specified input messages.
45 | 	Execute(ctx context.Context, messages []llms.ChatMessage) error
46 | }
47 |
--------------------------------------------------------------------------------
/internal/ui/coders/commands_test.go:
--------------------------------------------------------------------------------
1 | package coders
2 |
3 | import (
4 | 	"testing"
5 | )
6 |
7 | func TestCommands(t *testing.T) {
8 | }
9 |
--------------------------------------------------------------------------------
/internal/ui/coders/completer.go:
--------------------------------------------------------------------------------
1 | package coders
2 |
3 | import (
4 | 	"strings"
5 |
6 | 	"github.com/elk-language/go-prompt"
7 | 	pstrings "github.com/elk-language/go-prompt/strings"
8 |
9 | 	"github.com/coding-hui/ai-terminal/internal/git"
10 | )
11 |
12 | // CommandCompleter provides prompt completions for AutoCoder input.
13 | type CommandCompleter struct {
14 | 	cmds []string
15 | 	repo *git.Command
16 | }
17 |
18 | // NewCommandCompleter returns a completer for the supported slash commands and the files of repo.
19 | func NewCommandCompleter(repo *git.Command) CommandCompleter {
20 | 	return CommandCompleter{
21 | 		cmds: getSupportedCommands(),
22 | 		repo: repo,
23 | 	}
24 | }
25 |
26 | // Complete suggests completions for the word before the cursor: slash commands
27 | // for "/", repository files (fuzzy-matched without the "@") for "@", and a fixed
28 | // flag list for "--". Any other word yields no suggestions. The returned range
29 | // covers the whole word, including its "/", "@" or "--" prefix.
30 | func (c CommandCompleter) Complete(d prompt.Document) (suggestions []prompt.Suggest, startChar, endChar pstrings.RuneNumber) {
31 | 	endIndex := d.CurrentRuneIndex()
32 | 	w := d.GetWordBeforeCursor()
33 | 	startIndex := endIndex - pstrings.RuneCount([]byte(w))
34 |
35 | 	// if the input starts with "/", then we use the command completer
36 | 	if strings.HasPrefix(w, "/") {
37 | 		var completions []prompt.Suggest
38 | 		for _, v := range c.cmds {
39 | 			completions = append(completions, prompt.Suggest{Text: v})
40 | 		}
41 | 		return prompt.FilterHasPrefix(completions, w, true), startIndex, endIndex
42 | 	}
43 |
44 | 	// if the input starts with "@", then we use the file completer
45 | 	if strings.HasPrefix(w, "@") {
46 | 		// Best effort: the error from ListAllFiles is deliberately ignored.
47 | 		files, _ := c.repo.ListAllFiles()
48 | 		var completions []prompt.Suggest
49 | 		for _, v := range files {
50 | 			completions = append(completions, prompt.Suggest{Text: v})
51 | 		}
52 | 		w = strings.TrimPrefix(w, "@")
53 | 		return prompt.FilterFuzzy(completions, w, true), startIndex, endIndex
54 | 	}
55 |
56 | 	// if the input starts with "--", then we use the flag completer
57 | 	if strings.HasPrefix(w, "--") {
58 | 		completions := []prompt.Suggest{
59 | 			{Text: "--verbose"},
60 | 			{Text: "--help"},
61 | 		}
62 | 		return prompt.FilterContains(completions, w, true), startIndex, endIndex
63 | 	}
64 |
65 | 	return prompt.FilterHasPrefix([]prompt.Suggest{}, w, true), startIndex, endIndex
66 | }
67 |
--------------------------------------------------------------------------------
/internal/ui/coders/context.go:
--------------------------------------------------------------------------------
1 | package coders
2 |
3 | import (
4 | 	"context"
5 | 	"path/filepath"
6 |
7 | 	"github.com/coding-hui/ai-terminal/internal/ai"
8 | 	"github.com/coding-hui/ai-terminal/internal/errbook"
9 | 	"github.com/coding-hui/ai-terminal/internal/git"
10 | 	"github.com/coding-hui/ai-terminal/internal/options"
11 | )
12 |
13 | // CoderContext bundles the code base path, git repository, AI engine and a set
14 | // of absolute file names. It embeds a context.Context so it can be used as one.
15 | type CoderContext struct {
16 | 	context.Context
17 |
18 | 	codeBasePath string
19 | 	repo         *git.Command
20 | 	absFileNames map[string]struct{}
21 | 	engine       *ai.Engine
22 | }
23 |
24 | // NewCoderContext creates a CoderContext whose code base path is the parent
25 | // directory of the repository's git dir, with an AI engine built from cfg.
26 | // It returns an error only if the engine cannot be created.
27 | func NewCoderContext(cfg *options.Config) (*CoderContext, error) {
28 | 	repo := git.New()
29 | 	// Best effort: the GitDir error is ignored; filepath.Dir("") is "." (NOTE(review): confirm GitDir returns "" on failure).
30 | 	root, _ := repo.GitDir()
31 | 	engine, err := ai.New(ai.WithConfig(cfg))
32 | 	if err != nil {
33 | 		return nil, errbook.Wrap("Could not initialized ai engine", err)
34 | 	}
35 | 	return &CoderContext{
36 | 		// Initialize the embedded Context: left nil, any Deadline/Done/Err/Value
37 | 		// call made through a CoderContext would panic.
38 | 		Context:      context.Background(),
39 | 		codeBasePath: filepath.Dir(root),
40 | 		repo:         repo,
41 | 		engine:       engine,
42 | 		absFileNames: map[string]struct{}{},
43 | 	}, nil
44 | }
45 |
--------------------------------------------------------------------------------
/internal/ui/coders/fence.go:
--------------------------------------------------------------------------------
1 | package coders
2 |
3 | import (
4 | 	"fmt"
5 | 	"path/filepath"
6 | 	"strings"
7 | )
8 |
9 | var fences = [][]string{
10 | 	{"``" + "`", "``" + "`"},
11 | 	{"<source>", "</source>"},
12 | 	{"<code>", "</code>"},
13 | 	{"<pre>", "</pre>"},
14 | 	{"<codeblock>", "</codeblock>"},
15 | 	{"<sourcecode>", "</sourcecode>"},
16 | }
17 |
18 | func wrapFenceWithType(rawContent, filename string, fences []string) string {
19 | 	fileExt := strings.TrimLeft(filepath.Ext(filename), ".")
20 | 	openFence, closeFence := defaultBestFence()
21 | 	if len(fences) == 2 {
22 | 		openFence, closeFence = fences[0], fences[1]
23 | 	}
24 | 	return fmt.Sprintf("\n%s%s\n%s\n%s\n", openFence, fileExt, rawContent, closeFence)
25 | }
26 |
27 | func defaultBestFence() (open string, close string) {
28 | 	return fences[0][0], fences[0][1]
29 | }
30 |
31 | // chooseExistingFence finds and returns the first existing code fence pair in the rawContent
32 | // where the open fence is followed by "<<<<<<< SEARCH" and the close fence is preceded by ">>>>>>> REPLACE".
33 | // If no such pair is found, it falls back to chooseBestFence(rawContent).
34 | func chooseExistingFence(rawContent string) (open string, close string) {
35 | 	// Iterate through all supported fence pairs
36 | 	for _, fence := range fences {
37 | 		openFence, closeFence := fence[0], fence[1]
38 |
39 | 		// Find the open fence
40 | 		openIndex := strings.Index(rawContent, openFence)
41 | 		if openIndex == -1 {
42 | 			continue
43 | 		}
44 |
45 | 		// Find the close fence after the open fence
46 | 		remainingContent := rawContent[openIndex+len(openFence):]
47 | 		closeIndex := strings.Index(remainingContent, closeFence)
48 | 		if closeIndex == -1 {
49 | 			continue
50 | 		}
51 |
52 | 		// Check if the line before close fence is ">>>>>>> REPLACE"
53 | 		beforeClose := remainingContent[:closeIndex]
54 | 		lastNewline := strings.LastIndex(beforeClose, "\n")
55 | 		if lastNewline == -1 || !strings.HasSuffix(strings.TrimSpace(beforeClose[:lastNewline]), UPDATED) {
56 | 			continue
57 | 		}
58 |
59 | 		// Check if the line after open fence is "<<<<<<< SEARCH"
60 | 		firstNewline := strings.Index(beforeClose, "\n")
61 | 		if firstNewline == -1 || firstNewline > closeIndex {
62 | 			continue
63 | 		}
64 | 		if !strings.Contains(strings.TrimSpace(beforeClose[firstNewline:]), HEAD) {
65 | 			continue
66 | 		}
67 |
68 | 		return openFence, closeFence
69 | 	}
70 |
71 | 	return chooseBestFence(rawContent)
72 | }
73 |
74 | func chooseBestFence(rawContent string) (open string, close string) {
75 | 	for _, fence := range fences {
76 | 		if strings.Contains(rawContent, fence[0]) || strings.Contains(rawContent, fence[1]) {
77 | 			continue
78 | 		}
79 | 		open, close = fence[0], fence[1]
80 | 		return
81 | 	}
82 |
83 | 	// Every fence marker already occurs in rawContent; fall back to the default fence.
84 | 	return defaultBestFence()
85 | }
86 |
--------------------------------------------------------------------------------
/internal/ui/coders/llm_callback.go:
--------------------------------------------------------------------------------
1 | package coders
2 |
3 | import (
4 | 	"context"
5 |
6 | 	"github.com/coding-hui/wecoding-sdk-go/services/ai/callbacks"
7 | 	"github.com/coding-hui/wecoding-sdk-go/services/ai/llms"
8 | )
9 |
10 | // llmCallback is a callbacks.Handler whose methods all do nothing.
11 | type llmCallback struct{}
12 |
13 | // Compile-time check that llmCallback implements callbacks.Handler.
14 | var _ callbacks.Handler = llmCallback{}
15 |
16 | func (l llmCallback) HandleLLMGenerateContentStart(_ context.Context, ms []llms.MessageContent) {
17 | }
18 |
19 | func (l llmCallback) HandleLLMGenerateContentEnd(_ context.Context, res *llms.ContentResponse) {
20 | }
21 |
22 | func (l llmCallback) HandleStreamingFunc(_ context.Context, chunk []byte) {
23 | }
24 |
25 | func (l llmCallback) HandleText(_ context.Context, text string) {
26 | }
27 |
28 | func (l llmCallback) HandleLLMStart(_ context.Context, prompts []string) {
29 | }
30 |
31 | func (l llmCallback) HandleLLMError(_ context.Context, err error) {
32 | }
33 |
34 | func (l llmCallback) HandleChainStart(_ context.Context, inputs map[string]any) {
35 | }
36 |
37 | func (l llmCallback) HandleChainEnd(_ context.Context, outputs map[string]any) {
38 | }
39 |
40 | func (l llmCallback) HandleChainError(_ context.Context, err error) {
41 | }
42 |
43 | func (l llmCallback) HandleToolStart(_ context.Context, input string) {
44 | }
45 |
46 | func (l llmCallback) HandleToolEnd(_ context.Context, output string) {
47 | }
48 |
49 | func (l llmCallback) HandleToolError(_ context.Context, err error) {
50 | }
51 |
--------------------------------------------------------------------------------
/internal/ui/coders/prompts_test.go:
--------------------------------------------------------------------------------
1 | package coders
2 |
3 | import (
4 | 	"fmt"
5 | 	"html"
6 | 	"testing"
7 |
8 | 	"github.com/stretchr/testify/require"
9 | )
10 |
11 | func TestPrompts(t *testing.T) {
12 | 	t.Run("editBlockCoderPrompt", testEditBlockCoderPrompt)
13 | }
14 |
15 | // testEditBlockCoderPrompt renders promptBaseCoder with sample values. It only
16 | // asserts that formatting succeeds; the output is printed for manual inspection.
17 | func testEditBlockCoderPrompt(t *testing.T) {
18 | 	tpl, err := promptBaseCoder.FormatPrompt(map[string]any{
19 | 		userQuestionKey: "add comment",
20 | 		addedFilesKey:   "test",
21 | 		openFenceKey:    "```",
22 | 		closeFenceKey:   "```",
23 | 		lazyPromptKey:   lazyPrompt,
24 | 	})
25 | 	require.NoError(t, err)
26 | 	fmt.Println(html.UnescapeString(tpl.String()))
27 | }
28 |
--------------------------------------------------------------------------------
/internal/ui/coders/types.go:
--------------------------------------------------------------------------------
1 | package coders
2 |
3 | const (
4 | 	FlagVerbose = "verbose"
5 | )
6 |
7 | // PromptMode represents the mode of the prompt.
8 | type PromptMode int
9 |
10 | // Constants representing different prompt modes.
11 | const (
12 | 	ExecPromptMode    PromptMode = iota // ExecPromptMode represents the execution mode.
13 | 	ChatPromptMode                      // ChatPromptMode represents the chat mode.
14 | 	DefaultPromptMode                   // DefaultPromptMode represents the default mode.
15 | )
16 |
17 | // String returns the string representation of the PromptMode.
18 | func (m PromptMode) String() string {
19 | 	switch m {
20 | 	case ExecPromptMode:
21 | 		return "exec"
22 | 	case ChatPromptMode:
23 | 		return "chat"
24 | 	default:
25 | 		return "default"
26 | 	}
27 | }
28 |
29 | // GetPromptModeFromString parses "exec" or "chat"; any other string maps to DefaultPromptMode.
30 | func GetPromptModeFromString(s string) PromptMode {
31 | 	switch s {
32 | 	case "exec":
33 | 		return ExecPromptMode
34 | 	case "chat":
35 | 		return ChatPromptMode
36 | 	default:
37 | 		return DefaultPromptMode
38 | 	}
39 | }
40 |
--------------------------------------------------------------------------------
/internal/ui/console/confirm.go:
--------------------------------------------------------------------------------
1 | package console
2 |
3 | import (
4 | 	"fmt"
5 | 	"strings"
6 |
7 | 	"github.com/charmbracelet/lipgloss"
8 | 	"github.com/erikgeiser/promptkit/confirmation"
9 | )
10 |
11 | // No and Yes re-export confirmation.No and confirmation.Yes so callers can pass
12 | // them as the defaultVal of WaitForUserConfirm.
13 | var No = confirmation.No
14 | var Yes = confirmation.Yes
15 |
16 | // WaitForUserConfirm shows a yes/no prompt whose message is built from format
17 | // and args, with defaultVal as the default answer, and returns the user's
18 | // choice. A prompt error is treated as "no" (false).
19 | func WaitForUserConfirm(defaultVal confirmation.Value, format string, args ...interface{}) bool {
20 | 	input := confirmation.New(fmt.Sprintf(format, args...), defaultVal)
21 | 	confirm, err := input.RunPrompt()
22 | 	if err != nil {
23 | 		return false
24 | 	}
25 | 	return confirm
26 | }
27 |
28 | // action messages
29 |
30 | const defaultAction = "WROTE"
31 |
32 | var outputHeader = lipgloss.NewStyle().Foreground(lipgloss.Color("#F1F1F1")).Background(lipgloss.Color("#6C50FF")).Bold(true).Padding(0, 1).MarginRight(1)
33 |
34 | // PrintConfirmation prints content to standard output, prefixed by a badge
35 | // showing action in upper case (defaultAction when action is empty).
36 | func PrintConfirmation(action, content string) {
37 | 	if action == "" {
38 | 		action = defaultAction
39 | 	}
40 | 	// Derive a per-call style rather than overwriting the shared outputHeader.
41 | 	header := outputHeader.SetString(strings.ToUpper(action))
42 | 	fmt.Println(lipgloss.JoinHorizontal(lipgloss.Center, header.String(), content))
43 | }
44 |
--------------------------------------------------------------------------------
/internal/ui/console/error.go:
--------------------------------------------------------------------------------
1 | package console
2 |
3 | import (
4 | 	"fmt"
5 | 	"os"
6 |
7 | 	"github.com/charmbracelet/lipgloss"
8 | )
9 |
10 | // style renders error messages bold, in ANSI color 9, with one line of top padding.
11 | var style = lipgloss.NewStyle().
12 | 	Bold(true).
13 | 	PaddingTop(1).
14 | 	Foreground(lipgloss.Color("9"))
15 |
16 | // Error prints args (formatted with %s) as a styled error message on standard output.
17 | //
18 | // Deprecated: use console.StderrRenderer() instead.
19 | func Error(args interface{}) {
20 | 	Errorf("%s", args)
21 | }
22 |
23 | // Errorf formats format and args and prints the result as a styled error message on standard output.
24 | //
25 | // Deprecated: use console.StderrRenderer() instead.
26 | func Errorf(format string, args ...interface{}) {
27 | 	fmt.Println(style.Render(fmt.Sprintf(format, args...)))
28 | }
29 |
30 | // Fatal prints args like Error and then exits the program with status code 1.
31 | //
32 | // Deprecated: use console.StderrRenderer() instead.
33 | func Fatal(args interface{}) {
34 | 	Error(args)
35 | 	os.Exit(1)
36 | }
37 |
38 | // Fatalf prints a formatted message like Errorf and then exits the program with status code 1.
39 | //
40 | // Deprecated: use console.StderrRenderer() instead.
41 | func Fatalf(format string, args ...interface{}) {
42 | 	Errorf(format, args...)
43 | 	os.Exit(1)
44 | }
45 |
--------------------------------------------------------------------------------
/internal/ui/console/info.go:
--------------------------------------------------------------------------------
1 | package console
2 |
3 | import (
4 | 	"fmt"
5 |
6 | 	"github.com/charmbracelet/lipgloss"
7 | )
8 |
9 | // infoStyle defines the style for informational messages.
10 | var infoStyle = lipgloss.NewStyle().
11 | 	Bold(false).
12 | 	PaddingTop(1).
13 | 	PaddingBottom(1)
14 |
15 | // Info prints an informational message with the defined style.
16 | func Info(text string) {
17 | 	fmt.Println(infoStyle.Render(text))
18 | }
19 |
20 | // Infof formats and prints an informational message with the defined style.
21 | func Infof(format string, a ...interface{}) {
22 | 	Info(fmt.Sprintf(format, a...))
23 | }
24 |
--------------------------------------------------------------------------------
/internal/ui/console/key_bindings.go:
--------------------------------------------------------------------------------
1 | package console
2 |
3 | import (
4 | 	"fmt"
5 | 	"os"
6 |
7 | 	"github.com/elk-language/go-prompt"
8 | 	pstrings "github.com/elk-language/go-prompt/strings"
9 | )
10 |
11 | var (
12 | 	// altD holds the UTF-8 bytes of U+2202 ('∂'), which some keyboard layouts
13 | 	// (e.g. macOS Option+D) emit. NOTE(review): confirm this is the intended trigger.
14 | 	altD = []byte{226, 136, 130}
15 | )
16 |
17 | // deleteWholeLine asks go-prompt to delete len(Text) runes before the cursor.
18 | // len counts bytes, which is never less than the rune count, so this covers
19 | // everything before the cursor (assuming go-prompt clamps the count; TODO
20 | // confirm). Despite the name, text after the cursor is presumably kept.
21 | // NOTE(review): this ASCIICodeBind (altD) does not appear to be registered
22 | // with the prompt; only its Fn is reused for Ctrl+U below.
23 | var deleteWholeLine = prompt.ASCIICodeBind{
24 | 	ASCIICode: altD,
25 | 	Fn: func(p *prompt.Prompt) bool {
26 | 		p.DeleteBeforeCursorRunes(pstrings.RuneNumber(len(p.Buffer().Document().Text)))
27 | 		return true
28 | 	},
29 | }
30 |
31 | // makeNecessaryKeyBindings returns the custom key bindings: Ctrl+H deletes the
32 | // character before the cursor, Ctrl+U runs deleteWholeLine.Fn, and Ctrl+C
33 | // prints "Bye!" and exits the process with status 0.
34 | func makeNecessaryKeyBindings() []prompt.KeyBind {
35 | 	keyBinds := []prompt.KeyBind{
36 | 		{
37 | 			Key: prompt.ControlH,
38 | 			Fn:  prompt.DeleteBeforeChar,
39 | 		},
40 | 		{
41 | 			Key: prompt.ControlU,
42 | 			Fn:  deleteWholeLine.Fn,
43 | 		},
44 | 		{
45 | 			Key: prompt.ControlC,
46 | 			Fn: func(b *prompt.Prompt) bool {
47 | 				fmt.Println("Bye!")
48 | 				os.Exit(0)
49 | 				return true
50 | 			},
51 | 		},
52 | 	}
53 |
54 | 	return keyBinds
55 | }
56 |
--------------------------------------------------------------------------------
/internal/ui/console/prompt.go:
--------------------------------------------------------------------------------
1 | package console
2 |
3 | import (
4 | 	"github.com/elk-language/go-prompt"
5 | )
6 |
7 | const suggestLimit = 5
8 |
9 | var (
10 | 	options = []prompt.Option{
11 | 		prompt.WithPrefixTextColor(prompt.Turquoise),
12 | 		prompt.WithCompletionOnDown(),
13 | 		prompt.WithSuggestionBGColor(prompt.DarkGray),
14 | 		prompt.WithSuggestionTextColor(prompt.White),
15 | 		prompt.WithDescriptionBGColor(prompt.LightGray),
16 | 		prompt.WithDescriptionTextColor(prompt.Black),
17 | 		prompt.WithSelectedSuggestionBGColor(prompt.Black),
18 | 		prompt.WithSelectedSuggestionTextColor(prompt.White),
19 |
prompt.WithSelectedDescriptionBGColor(prompt.DarkGray), 20 | prompt.WithScrollbarThumbColor(prompt.Black), 21 | prompt.WithScrollbarBGColor(prompt.White), 22 | prompt.WithMaxSuggestion(suggestLimit), 23 | prompt.WithKeyBind(makeNecessaryKeyBindings()...), 24 | } 25 | ) 26 | 27 | func NewPrompt(prefix string, enableColor bool, completer prompt.Completer, exec func(string)) *prompt.Prompt { 28 | promptOptions := append(options, prompt.WithPrefix(prefix+" > ")) 29 | promptOptions = append(promptOptions, prompt.WithCompleter(completer)) 30 | if !enableColor { 31 | promptOptions = append(promptOptions, prompt.WithPrefixTextColor(prompt.DefaultColor)) 32 | } 33 | 34 | return prompt.New( 35 | exec, 36 | promptOptions..., 37 | ) 38 | } 39 | -------------------------------------------------------------------------------- /internal/ui/console/render.go: -------------------------------------------------------------------------------- 1 | package console 2 | 3 | import ( 4 | "fmt" 5 | "html" 6 | "os" 7 | "sync" 8 | 9 | "github.com/atotto/clipboard" 10 | "github.com/charmbracelet/lipgloss" 11 | "github.com/muesli/termenv" 12 | 13 | "github.com/coding-hui/wecoding-sdk-go/services/ai/llms" 14 | ) 15 | 16 | var StdoutRenderer = sync.OnceValue(func() *lipgloss.Renderer { 17 | return lipgloss.DefaultRenderer() 18 | }) 19 | 20 | var StdoutStyles = sync.OnceValue(func() Styles { 21 | return MakeStyles(StdoutRenderer()) 22 | }) 23 | 24 | var StderrRenderer = sync.OnceValue(func() *lipgloss.Renderer { 25 | return lipgloss.NewRenderer(os.Stderr, termenv.WithColorCache(true)) 26 | }) 27 | 28 | var StderrStyles = sync.OnceValue(func() Styles { 29 | return MakeStyles(StderrRenderer()) 30 | }) 31 | 32 | func Render(format string, args ...interface{}) { 33 | msg := StdoutStyles().AppName.Render(fmt.Sprintf(format, args...)) 34 | fmt.Println(msg) 35 | } 36 | 37 | func RenderComment(format string, args ...interface{}) { 38 | msg := StdoutStyles().Comment.Render(fmt.Sprintf(format, args...)) 39 
| fmt.Println(msg) 40 | } 41 | 42 | // RenderStep renders commit process step messages with a prefix 43 | func RenderStep(format string, args ...interface{}) { 44 | msg := StdoutStyles().CommitStep.Render(fmt.Sprintf("➤ "+format, args...)) 45 | fmt.Println(msg) 46 | } 47 | 48 | // RenderSuccess renders successful commit messages 49 | func RenderSuccess(format string, args ...interface{}) { 50 | msg := StdoutStyles().CommitSuccess.Render(fmt.Sprintf("✓ "+format, args...)) 51 | fmt.Println(msg) 52 | } 53 | 54 | func RenderError(err error, reason string, args ...interface{}) { 55 | header := StderrStyles().ErrPadding.Render(StderrStyles().ErrorHeader.String(), err.Error()) 56 | detail := StderrStyles().ErrPadding.Render(StderrStyles().ErrorDetails.Render(fmt.Sprintf(reason, args...))) 57 | _, _ = fmt.Printf("\n%s\n%s\n\n", header, detail) 58 | } 59 | 60 | func RenderAppName(appName string, suffix string, args ...interface{}) { 61 | appName = MakeGradientText(StdoutStyles().AppName, appName) 62 | fmt.Print(appName + " " + fmt.Sprintf(suffix, args...)) 63 | } 64 | 65 | func RenderChatMessages(messages []llms.ChatMessage) error { 66 | content, err := llms.GetBufferString(messages, "", "human", "ai") 67 | if err != nil { 68 | return err 69 | } 70 | out := html.UnescapeString(content) 71 | fmt.Println(out) 72 | _ = clipboard.WriteAll(out) 73 | termenv.Copy(out) 74 | PrintConfirmation("COPIED", "The content copied to clipboard!") 75 | return nil 76 | } 77 | -------------------------------------------------------------------------------- /internal/ui/console/styles.go: -------------------------------------------------------------------------------- 1 | package console 2 | 3 | import ( 4 | "github.com/charmbracelet/lipgloss" 5 | ) 6 | 7 | type Styles struct { 8 | AppName, 9 | CliArgs, 10 | Comment, 11 | CyclingChars, 12 | ErrorHeader, 13 | ErrorDetails, 14 | ErrPadding, 15 | Flag, 16 | FlagComma, 17 | FlagDesc, 18 | InlineCode, 19 | Link, 20 | Pipe, 21 | Quote, 22 | 
ConversationList, 23 | SHA1, 24 | Timeago, 25 | CommitStep, 26 | CommitSuccess, 27 | DiffHeader, 28 | DiffFileHeader, 29 | DiffHunkHeader, 30 | DiffAdded, 31 | DiffRemoved, 32 | DiffContext lipgloss.Style 33 | } 34 | 35 | func MakeStyles(r *lipgloss.Renderer) (s Styles) { 36 | const horizontalEdgePadding = 2 37 | s.AppName = r.NewStyle().Bold(true) 38 | s.CliArgs = r.NewStyle().Foreground(lipgloss.Color("#585858")) 39 | s.Comment = r.NewStyle().Foreground(lipgloss.Color("#757575")) 40 | s.CyclingChars = r.NewStyle().Foreground(lipgloss.Color("#FF87D7")) 41 | s.ErrorHeader = r.NewStyle().Foreground(lipgloss.Color("#F1F1F1")).Background(lipgloss.Color("#FF5F87")).Bold(true).Padding(0, 1).SetString("ERROR") 42 | s.ErrorDetails = s.Comment 43 | s.ErrPadding = r.NewStyle().Padding(0, horizontalEdgePadding) 44 | s.Flag = r.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#00B594", Dark: "#3EEFCF"}).Bold(true) 45 | s.FlagComma = r.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#5DD6C0", Dark: "#427C72"}).SetString(",") 46 | s.FlagDesc = s.Comment 47 | s.InlineCode = r.NewStyle().Foreground(lipgloss.Color("#FF5F87")).Background(lipgloss.Color("#3A3A3A")).Padding(0, 1) 48 | s.Link = r.NewStyle().Foreground(lipgloss.Color("#00AF87")).Underline(true) 49 | s.Quote = r.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#FF71D0", Dark: "#FF78D2"}) 50 | s.Pipe = r.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#8470FF", Dark: "#745CFF"}) 51 | s.ConversationList = r.NewStyle().Padding(0, 1) 52 | s.SHA1 = s.Flag 53 | s.Timeago = r.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#999", Dark: "#555"}) 54 | 55 | // Commit message styles 56 | s.CommitStep = r.NewStyle().Foreground(lipgloss.Color("#00CED1")).Bold(true) 57 | s.CommitSuccess = r.NewStyle().Foreground(lipgloss.Color("#32CD32")) 58 | 59 | // Diff styles 60 | s.DiffHeader = r.NewStyle().Bold(true) 61 | s.DiffFileHeader = r.NewStyle().Foreground(lipgloss.Color("#00CED1")).Bold(true) 62 | 
s.DiffHunkHeader = r.NewStyle().Foreground(lipgloss.Color("#888888")).Bold(true) 63 | s.DiffAdded = r.NewStyle().Foreground(lipgloss.Color("#00AA00")) 64 | s.DiffRemoved = r.NewStyle().Foreground(lipgloss.Color("#AA0000")) 65 | s.DiffContext = r.NewStyle().Foreground(lipgloss.Color("#888888")) 66 | 67 | return s 68 | } 69 | -------------------------------------------------------------------------------- /internal/ui/console/success.go: -------------------------------------------------------------------------------- 1 | package console 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/charmbracelet/lipgloss" 7 | ) 8 | 9 | var successStyle = lipgloss.NewStyle(). 10 | Bold(true). 11 | PaddingTop(1). 12 | Foreground(lipgloss.Color("2")) 13 | 14 | // Success prints a bold green success message with top padding. 15 | // The message is styled with a green foreground color (terminal color 2) and bold text. 16 | // Example usage: console.Success("Operation completed successfully") 17 | func Success(text string) { 18 | fmt.Println(successStyle.Render(text)) 19 | } 20 | 21 | // Successf prints a formatted bold green success message with top padding. 22 | // Uses fmt.Sprintf syntax for formatting and applies the same styling as Success. 23 | // Example usage: console.Successf("Successfully processed %d items", count) 24 | func Successf(format string, args ...interface{}) { 25 | fmt.Println(successStyle.Render(fmt.Sprintf(format, args...))) 26 | } 27 | -------------------------------------------------------------------------------- /internal/ui/console/warn.go: -------------------------------------------------------------------------------- 1 | package console 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/charmbracelet/lipgloss" 7 | ) 8 | 9 | var warnStyle = lipgloss.NewStyle(). 10 | Bold(true). 11 | PaddingTop(1). 12 | Foreground(lipgloss.Color("3")) 13 | 14 | // Warn prints the given text with a warning style. 
15 | func Warn(text string) { 16 | fmt.Println(warnStyle.Render(text)) 17 | } 18 | 19 | // Warnf formats args with format (fmt.Sprintf rules) and prints the result to stdout rendered with warnStyle. 20 | func Warnf(format string, args ...interface{}) { 21 | fmt.Println(warnStyle.Render(fmt.Sprintf(format, args...))) 22 | } 23 | -------------------------------------------------------------------------------- /internal/ui/textarea.go: -------------------------------------------------------------------------------- 1 | package ui 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | 7 | "github.com/charmbracelet/bubbles/textarea" 8 | tea "github.com/charmbracelet/bubbletea" 9 | ) 10 | 11 | type errMsg error 12 | 13 | type TextareaModel struct { 14 | Textarea textarea.Model 15 | Err error 16 | } 17 | 18 | func InitialTextareaPrompt(value string) TextareaModel { 19 | ti := textarea.New() 20 | ti.InsertString(value) 21 | ti.SetWidth(80) 22 | ti.SetHeight(len(strings.Split(value, "\n"))) 23 | ti.Focus() 24 | 25 | return TextareaModel{ 26 | Textarea: ti, 27 | Err: nil, 28 | } 29 | } 30 | 31 | func (m TextareaModel) Init() tea.Cmd { 32 | return textarea.Blink 33 | } 34 | 35 | func (m TextareaModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { 36 | var cmds []tea.Cmd 37 | var cmd tea.Cmd 38 | 39 | switch msg := msg.(type) { 40 | case tea.KeyMsg: 41 | switch msg.Type { //nolint:exhaustive 42 | case tea.KeyEsc: 43 | if m.Textarea.Focused() { 44 | m.Textarea.Blur() 45 | } 46 | case tea.KeyCtrlC: 47 | return m, tea.Quit 48 | default: 49 | if !m.Textarea.Focused() { 50 | cmd = m.Textarea.Focus() 51 | cmds = append(cmds, cmd) 52 | } 53 | } 54 | 55 | // Errors are stored on the model and returned immediately; they are not forwarded to the textarea. 56 | case errMsg: 57 | m.Err = msg 58 | return m, nil 59 | } 60 | 61 | m.Textarea, cmd = m.Textarea.Update(msg) 62 | cmds = append(cmds, cmd) 63 | return m, tea.Batch(cmds...)
64 | } 65 | 66 | func (m TextareaModel) View() string { 67 | return fmt.Sprintf( 68 | "Please confirm the following commit message.\n\n%s\n\n%s", 69 | m.Textarea.View(), 70 | "(ctrl+c to continue.)", 71 | ) + "\n\n" 72 | } 73 | -------------------------------------------------------------------------------- /internal/ui/types.go: -------------------------------------------------------------------------------- 1 | package ui 2 | 3 | // PromptMode represents different modes of user interaction 4 | type PromptMode int 5 | 6 | const ( 7 | // ExecPromptMode is for executing commands 8 | ExecPromptMode PromptMode = iota 9 | 10 | // TokenConfigPromptMode is for configuring API tokens 11 | TokenConfigPromptMode 12 | 13 | // ModelConfigPromptMode is for configuring models 14 | ModelConfigPromptMode 15 | 16 | // ApiBaseConfigPromptMode is for configuring API base URLs 17 | ApiBaseConfigPromptMode 18 | 19 | // ChatPromptMode is for chat interactions 20 | ChatPromptMode 21 | 22 | // DefaultPromptMode is the fallback mode 23 | DefaultPromptMode 24 | ) 25 | 26 | // String returns the name of m in the form GetPromptModeFromString accepts; ChatPromptMode is "ask" and unknown values are "default". 27 | func (m PromptMode) String() string { 28 | switch m { 29 | case ExecPromptMode: 30 | return "exec" 31 | case TokenConfigPromptMode: 32 | return "tokenConfig" 33 | case ModelConfigPromptMode: 34 | return "modelConfig" 35 | case ApiBaseConfigPromptMode: 36 | return "apiBaseConfig" 37 | case ChatPromptMode: 38 | return "ask" 39 | default: 40 | return "default" 41 | } 42 | } 43 | 44 | // GetPromptModeFromString is the inverse of PromptMode.String: it maps a name to its PromptMode and returns DefaultPromptMode for unrecognized input. 45 | func GetPromptModeFromString(s string) PromptMode { 46 | switch s { 47 | case "exec": 48 | return ExecPromptMode 49 | case "tokenConfig": 50 | return TokenConfigPromptMode 51 | case "modelConfig": 52 | return ModelConfigPromptMode 53 | case "apiBaseConfig": 54 | return ApiBaseConfigPromptMode 55 | case "ask": 56 | return ChatPromptMode 57 | default: 58 | return DefaultPromptMode 59 | } 60 | } 61 | 62
| // RunMode represents different runtime modes 63 | type RunMode int 64 | 65 | const ( 66 | // CliMode is for command-line interface operation 67 | CliMode RunMode = iota 68 | 69 | // ReplMode is for read-eval-print-loop operation 70 | ReplMode 71 | ) 72 | 73 | // String returns "cli" for CliMode and "repl" for every other value. 74 | func (m RunMode) String() string { 75 | if m == CliMode { 76 | return "cli" 77 | } 78 | return "repl" 79 | } 80 | -------------------------------------------------------------------------------- /internal/util/debug/log_test.go: -------------------------------------------------------------------------------- 1 | //nolint:errcheck 2 | package debug 3 | 4 | import ( 5 | "os" 6 | "testing" 7 | ) 8 | 9 | func TestTraceUn(t *testing.T) { 10 | // Setup test environment 11 | os.Setenv(envEnableLog, "true") 12 | os.Setenv(envLogFormat, FormatJSON) 13 | defer os.Unsetenv(envEnableLog) 14 | defer os.Unsetenv(envLogFormat) 15 | 16 | // Initialize logger 17 | Initialize() 18 | 19 | funcName := "TestFunction" 20 | args := []interface{}{"arg1", 123, true} 21 | 22 | // Test Trace 23 | traceName := Trace(funcName, args...)
24 | if traceName != funcName { 25 | t.Errorf("Trace() returned unexpected name: got %v want %v", traceName, funcName) 26 | } 27 | 28 | // Test Untrace 29 | Untrace(funcName) 30 | 31 | // Clean up 32 | Teardown() 33 | } 34 | -------------------------------------------------------------------------------- /internal/util/debug/options.go: -------------------------------------------------------------------------------- 1 | package debug 2 | 3 | // Option is a functional option that configures a Logger; it may return an error. 4 | type Option func(*Logger) error 5 | 6 | // WithFormatter returns an Option that sets the Logger's formatter to f. 7 | func WithFormatter(f Formatter) Option { 8 | return func(l *Logger) error { 9 | l.formatter = f 10 | return nil 11 | } 12 | } 13 | 14 | // WithBufferSize returns an Option that sets the Logger's bufferSize to size (no validation is performed here). 15 | func WithBufferSize(size int) Option { 16 | return func(l *Logger) error { 17 | l.bufferSize = size 18 | return nil 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /internal/util/flag/flag.go: -------------------------------------------------------------------------------- 1 | package flag 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/caarlos0/duration" 7 | ) 8 | 9 | // NewDurationFlag stores val in *p and returns p viewed as a *DurationFlag, so the default 10 | // value and the flag's destination are set in a single call. 11 | func NewDurationFlag(val time.Duration, p *time.Duration) *DurationFlag { 12 | *p = val 13 | return (*DurationFlag)(p) 14 | } 15 | 16 | // DurationFlag is a time.Duration whose Set parses input with duration.Parse 17 | // (github.com/caarlos0/duration). Its Set/String/Type methods match the pflag.Value method set. 18 | type DurationFlag time.Duration 19 | 20 | // Set parses s and stores the result. If parsing fails, the current value is left 21 | // unchanged and the parse error is returned. 22 | func (d *DurationFlag) Set(s string) error { 23 | v, err := duration.Parse(s) 24 | if err != nil { 25 | // Keep the previous value instead of overwriting it with whatever Parse returned alongside the error. 26 | //nolint: wrapcheck 27 | return err 28 | } 29 | *d = DurationFlag(v) 30 | return nil 31 | } 32 | 33 | // String returns the value formatted by time.Duration.String. 34 | func (d *DurationFlag) String() string { 35 | return time.Duration(*d).String() 36 | } 37 | 38 | // Type returns the type name "duration". 39 | func (*DurationFlag) Type() string { 40 | return "duration" 41 | } 42 | -------------------------------------------------------------------------------- /internal/util/genericclioptions/io_options.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023 coding-hui. All rights reserved.
2 | // Use of this source code is governed by a MIT style 3 | // license that can be found in the LICENSE file. 4 | 5 | package genericclioptions 6 | 7 | import ( 8 | "bytes" 9 | "io" 10 | ) 11 | 12 | // IOStreams provides the standard names for iostreams. This is useful for embedding and for unit testing. 13 | // Inconsistent and different names make it hard to read and review code. 14 | type IOStreams struct { 15 | // In is the input stream, typically os.Stdin. 16 | In io.Reader 17 | // Out is the output stream, typically os.Stdout. 18 | Out io.Writer 19 | // ErrOut is the error stream, typically os.Stderr. 20 | ErrOut io.Writer 21 | } 22 | 23 | // NewTestIOStreams returns an IOStreams backed by three new, empty bytes.Buffers plus those buffers, so tests can feed input and inspect output (the named result io shadows package io inside the function). 24 | func NewTestIOStreams() (io IOStreams, in *bytes.Buffer, out *bytes.Buffer, errOut *bytes.Buffer) { 25 | in = &bytes.Buffer{} 26 | out = &bytes.Buffer{} 27 | errOut = &bytes.Buffer{} 28 | 29 | return IOStreams{ 30 | In: in, 31 | Out: out, 32 | ErrOut: errOut, 33 | }, in, out, errOut 34 | } 35 | 36 | // NewTestIOStreamsDiscard returns an IOStreams whose In is an empty bytes.Buffer and whose Out and ErrOut are io.Discard.
37 | func NewTestIOStreamsDiscard() IOStreams { 38 | in := &bytes.Buffer{} 39 | 40 | return IOStreams{ 41 | In: in, 42 | Out: io.Discard, 43 | ErrOut: io.Discard, 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /internal/util/rest/rest.go: -------------------------------------------------------------------------------- 1 | package rest 2 | 3 | import ( 4 | "io" 5 | "net/http" 6 | "net/url" 7 | "regexp" 8 | "strings" 9 | "time" 10 | 11 | "github.com/PuerkitoBio/goquery" 12 | 13 | "github.com/coding-hui/ai-terminal/internal/errbook" 14 | ) 15 | 16 | const ( 17 | // Limits applied by FetchURLContent. 18 | maxRedirections = 10 19 | httpTimeout = 30 * time.Second 20 | maxContentSizeInMB = 10 21 | ) 22 | 23 | // FetchURLContent GETs url (timeout httpTimeout, at most maxRedirections redirects) and returns the body as a string. 24 | // Only the first maxContentSizeInMB MiB of the body are read; anything beyond that is silently dropped. A non-2xx 25 | // status is returned as an error. When the Content-Type contains "text/html", the page is reduced to its text via 26 | // ExtractTextualContent; any other content is returned unchanged. 27 | func FetchURLContent(url string) (string, error) { 28 | client := &http.Client{ 29 | Timeout: httpTimeout, 30 | CheckRedirect: func(req *http.Request, via []*http.Request) error { 31 | if len(via) >= maxRedirections { 32 | return errbook.New("stopped after too many redirects") 33 | } 34 | return nil 35 | }, 36 | } 37 | 38 | resp, err := client.Get(url) 39 | if err != nil { 40 | return "", err 41 | } 42 | defer resp.Body.Close() //nolint:errcheck 43 | 44 | if resp.StatusCode < 200 || resp.StatusCode >= 300 { 45 | return "", errbook.New("unexpected status code %d", resp.StatusCode) 46 | } 47 | 48 | // Limit the response reader to a maximum amount 49 | limitedReader := io.LimitReader(resp.Body, maxContentSizeInMB*1024*1024) 50 | 51 | content, err := io.ReadAll(limitedReader) 52 | if err != nil { 53 | return "", err 54 | } 55 | 56 | contentType := resp.Header.Get("Content-Type") 57 | if strings.Contains(contentType, "text/html") { 58 | return ExtractTextualContent(string(content)), nil 59 | } 60 | return string(content), nil 61 | } 62 | 63 | // ExtractTextualContent parses htmlContent with goquery and returns the combined text of the whole document; 64 | // it returns "" if the document cannot be parsed. 65 | func ExtractTextualContent(htmlContent string) string { 66 | r := strings.NewReader(htmlContent) 67 | doc, err := goquery.NewDocumentFromReader(r) 68 | if err != nil { 69 | return "" 70 | } 71 | 72 | return doc.Text() 73 | } 74 | 75 | // schemePrefix matches everything up to and including the first "://" (e.g. "https://"). 76 | var schemePrefix = regexp.MustCompile(`^.*?://`) 77 | 78 | // urlFilenameReplacer replaces characters that are awkward or invalid in file names with "_". 79 | var urlFilenameReplacer = strings.NewReplacer(":", "_", "/", "_", "?", "_", "&", "_", "=", "_", "#", "_", "%", "_", "*", "_", " ", "_") 80 | 81 | // SanitizeURL turns url into a string usable as a file name: it strips the scheme prefix and 82 | // replaces each of : / ? & = # % * and space with "_". The regexp and replacer are built once 83 | // at package initialization instead of on every call. 84 | func SanitizeURL(url string) string { 85 | return urlFilenameReplacer.Replace(schemePrefix.ReplaceAllString(url, "")) 86 | } 87 | 88 | // IsValidURL reports whether str parses as a URL with a non-empty scheme and host. 89 | func IsValidURL(str string) bool { 90 | u, err := url.Parse(str) 91 | return err == nil && u.Scheme != "" && u.Host != "" 92 | } 93 |
-------------------------------------------------------------------------------- /internal/util/templates/command_groups.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023 coding-hui. All rights reserved. 2 | // Use of this source code is governed by a MIT style 3 | // license that can be found in the LICENSE file. 4 | 5 | package templates 6 | 7 | import ( 8 | "github.com/spf13/cobra" 9 | ) 10 | 11 | // CommandGroup is a named list of commands: Message describes the group and Commands lists its members. 12 | type CommandGroup struct { 13 | Message string 14 | Commands []*cobra.Command 15 | } 16 | 17 | // CommandGroups is an ordered list of CommandGroup values. 18 | type CommandGroups []CommandGroup 19 | 20 | // Add registers every command of every group in g as a subcommand of c. 21 | func (g CommandGroups) Add(c *cobra.Command) { 22 | for _, group := range g { 23 | c.AddCommand(group.Commands...) 24 | } 25 | } 26 | 27 | // Has reports whether c (compared by pointer) belongs to any group in g. 28 | func (g CommandGroups) Has(c *cobra.Command) bool { 29 | for _, group := range g { 30 | for _, command := range group.Commands { 31 | if command == c { 32 | return true 33 | } 34 | } 35 | } 36 | return false 37 | } 38 | 39 | // AddAdditionalCommands returns g with one extra group, titled message, holding each command in cmds that is 40 | // not already in g and has a non-empty Short description. If no command qualifies, g is returned unchanged. 41 | func AddAdditionalCommands(g CommandGroups, message string, cmds []*cobra.Command) CommandGroups { 42 | group := CommandGroup{Message: message} 43 | for _, c := range cmds { 44 | // Don't show commands that have no short description 45 | if !g.Has(c) && len(c.Short) != 0 { 46 | group.Commands = append(group.Commands, c) 47 | } 48 | } 49 | if len(group.Commands) == 0 { 50 | return g 51 | } 52 | return append(g, group) 53 | } 54 |
-------------------------------------------------------------------------------- /internal/util/templates/normalizers.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023 coding-hui. All rights reserved. 2 | // Use of this source code is governed by a MIT style 3 | // license that can be found in the LICENSE file. 4 | 5 | package templates 6 | 7 | import ( 8 | "strings" 9 | 10 | "github.com/MakeNowJust/heredoc/v2" 11 | "github.com/russross/blackfriday" 12 | "github.com/spf13/cobra" 13 | ) 14 | 15 | const Indentation = ` ` 16 | 17 | // LongDesc normalizes a command's long description: it dedents s (heredoc), renders its markdown as plain text with ASCIIRenderer, and trims surrounding whitespace. Empty input is returned unchanged. 18 | func LongDesc(s string) string { 19 | if len(s) == 0 { 20 | return s 21 | } 22 | return normalizer{s}.heredoc().markdown().trim().string 23 | } 24 | 25 | // Examples normalizes a command's examples: it trims s and re-indents every line with Indentation. Empty input is returned unchanged. 26 | func Examples(s string) string { 27 | if len(s) == 0 { 28 | return s 29 | } 30 | return normalizer{s}.trim().indent().string 31 | } 32 | 33 | // Normalize applies LongDesc to cmd.Long and Examples to cmd.Example (each only when non-empty) and returns cmd.
34 | func Normalize(cmd *cobra.Command) *cobra.Command { 35 | if len(cmd.Long) > 0 { 36 | cmd.Long = LongDesc(cmd.Long) 37 | } 38 | if len(cmd.Example) > 0 { 39 | cmd.Example = Examples(cmd.Example) 40 | } 41 | return cmd 42 | } 43 | 44 | // NormalizeAll normalizes cmd and, recursively, every subcommand beneath it (children before their parent), then returns cmd. 45 | func NormalizeAll(cmd *cobra.Command) *cobra.Command { 46 | if cmd.HasSubCommands() { 47 | for _, subCmd := range cmd.Commands() { 48 | NormalizeAll(subCmd) 49 | } 50 | } 51 | Normalize(cmd) 52 | return cmd 53 | } 54 | 55 | type normalizer struct { 56 | string 57 | } 58 | 59 | func (s normalizer) markdown() normalizer { 60 | bytes := []byte(s.string) 61 | formatted := blackfriday.Markdown( 62 | bytes, 63 | &ASCIIRenderer{Indentation: Indentation}, 64 | blackfriday.EXTENSION_NO_INTRA_EMPHASIS, 65 | ) 66 | s.string = string(formatted) 67 | return s 68 | } 69 | 70 | func (s normalizer) heredoc() normalizer { 71 | s.string = heredoc.Doc(s.string) 72 | return s 73 | } 74 | 75 | func (s normalizer) trim() normalizer { 76 | s.string = strings.TrimSpace(s.string) 77 | return s 78 | } 79 | 80 | func (s normalizer) indent() normalizer { 81 | indentedLines := []string{} 82 | for _, line := range strings.Split(s.string, "\n") { 83 | trimmed := strings.TrimSpace(line) 84 | indented := Indentation + trimmed 85 | indentedLines = append(indentedLines, indented) 86 | } 87 | s.string = strings.Join(indentedLines, "\n") 88 | return s 89 | } 90 | -------------------------------------------------------------------------------- /internal/util/templates/templates.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023 coding-hui. All rights reserved. 2 | // Use of this source code is governed by a MIT style 3 | // license that can be found in the LICENSE file. 4 | 5 | // Package templates provides cobra help/usage templates, command grouping helpers, and normalizers for command descriptions and examples.
6 | package templates 7 | 8 | import ( 9 | "strings" 10 | "unicode" 11 | ) 12 | 13 | const ( 14 | // SectionVars is the help template section that declares variables to be used in the template. 15 | SectionVars = `{{$isRootCmd := isRootCmd .}}` + 16 | `{{$rootCmd := rootCmd .}}` + 17 | `{{$visibleFlags := visibleFlags (flagsNotIntersected .LocalFlags .PersistentFlags)}}` + 18 | `{{$explicitlyExposedFlags := exposed .}}` + 19 | `{{$optionsCmdFor := optionsCmdFor .}}` + 20 | `{{$usageLine := usageLine .}}` 21 | 22 | // SectionAliases is the help template section that displays command aliases. 23 | SectionAliases = `{{if gt .Aliases 0}}Aliases: 24 | {{.NameAndAliases}} 25 | 26 | {{end}}` 27 | 28 | // SectionExamples is the help template section that displays command examples. 29 | SectionExamples = `{{if .HasExample}}Examples: 30 | {{trimRight .Example}} 31 | 32 | {{end}}` 33 | 34 | // SectionSubcommands is the help template section that displays the command's subcommands. 35 | SectionSubcommands = `{{if .HasAvailableSubCommands}}{{cmdGroupsString .}} 36 | 37 | {{end}}` 38 | 39 | // SectionFlags is the help template section that displays the command's flags. 40 | SectionFlags = `{{ if or $visibleFlags.HasFlags $explicitlyExposedFlags.HasFlags}}Options: 41 | {{ if $visibleFlags.HasFlags}}{{trimRight (flagsUsages $visibleFlags)}}{{end}}{{ if $explicitlyExposedFlags.HasFlags}}{{ if $visibleFlags.HasFlags}} 42 | {{end}}{{trimRight (flagsUsages $explicitlyExposedFlags)}}{{end}} 43 | 44 | {{end}}` 45 | 46 | // SectionUsage is the help template section that displays the command's usage. 47 | SectionUsage = `{{if and .Runnable (ne .UseLine "") (ne .UseLine $rootCmd)}}Usage: 48 | {{$usageLine}} 49 | 50 | {{end}}` 51 | 52 | // SectionTipsHelp is the help template section that displays the '<command> --help' hint for commands with subcommands. 53 | SectionTipsHelp = `{{if .HasSubCommands}}Use "{{$rootCmd}} <command> --help" for more information about a given command.
54 | {{end}}` 55 | 56 | // SectionTipsGlobalOptions is the help template section that displays the 'options' hint for displaying global 57 | // flags. 58 | SectionTipsGlobalOptions = `{{if $optionsCmdFor}}Use "{{$optionsCmdFor}}" for a list of global command-line options (applies to all commands). 59 | {{end}}` 60 | ) 61 | 62 | // MainHelpTemplate is the template for 'help' used by most commands. 63 | func MainHelpTemplate() string { 64 | return `{{with or .Long .Short }}{{. | trim}}{{end}}{{if or .Runnable .HasSubCommands}}{{.UsageString}}{{end}}` 65 | } 66 | 67 | // MainUsageTemplate is the template for 'usage' used by most commands; it joins the Section* constants in order and trims trailing whitespace. 68 | func MainUsageTemplate() string { 69 | sections := []string{ 70 | "\n\n", 71 | SectionVars, 72 | SectionAliases, 73 | SectionExamples, 74 | SectionSubcommands, 75 | SectionFlags, 76 | SectionUsage, 77 | SectionTipsHelp, 78 | SectionTipsGlobalOptions, 79 | } 80 | return strings.TrimRightFunc(strings.Join(sections, ""), unicode.IsSpace) 81 | } 82 | 83 | // OptionsHelpTemplate is the template for 'help' used by the 'options' command (currently empty). 84 | func OptionsHelpTemplate() string { 85 | return "" 86 | } 87 | 88 | // OptionsUsageTemplate is the template for 'usage' used by the 'options' command. 89 | func OptionsUsageTemplate() string { 90 | return `{{ if .HasInheritedFlags}}The following options can be passed to any command: 91 | 92 | {{flagsUsages .InheritedFlags}}{{end}}` 93 | } 94 | -------------------------------------------------------------------------------- /internal/util/term/pipe.go: -------------------------------------------------------------------------------- 1 | package term 2 | 3 | import ( 4 | "bufio" 5 | "io" 6 | "os" 7 | "strings" 8 | 9 | "k8s.io/klog/v2" 10 | ) 11 | 12 | // ReadPipeInput reads input from a pipe and returns it as a string. 13 | // It checks if the input is coming from a pipe and reads the content until EOF. 14 | // Returns: 15 | // - A string containing the pipe input if successful.
16 | // - An empty string if there is an error or no input from the pipe. 17 | func ReadPipeInput() string { 18 | // Check the status of standard input to determine if it's a pipe. 19 | stat, err := os.Stdin.Stat() 20 | if err != nil { 21 | klog.Warningf("Failed to start pipe input: %v", err) 22 | return "" 23 | } 24 | pipe := "" 25 | // Read only when stdin is a named pipe or has a non-zero size (e.g. a redirected file); 26 | // otherwise, for example for an interactive terminal, nothing is read. 27 | if !(stat.Mode()&os.ModeNamedPipe == 0 && stat.Size() == 0) { 28 | reader := bufio.NewReader(os.Stdin) 29 | var builder strings.Builder 30 | 31 | // Read runes from the pipe until EOF is reached. ReadRune decodes invalid UTF-8 32 | // bytes as U+FFFD. 33 | for { 34 | r, _, err := reader.ReadRune() 35 | if err == io.EOF { 36 | break 37 | } 38 | if err != nil { 39 | // Stop on any other error: a persistent read error would otherwise spin this 40 | // loop forever, appending the zero rune ReadRune returns on failure. 41 | klog.Warningf("Failed to read pipe input: %v", err) 42 | return "" 43 | } 44 | // strings.Builder.WriteRune always returns a nil error. 45 | _, _ = builder.WriteRune(r) 46 | } 47 | 48 | // Trim any leading or trailing whitespace from the input. 49 | pipe = strings.TrimSpace(builder.String()) 50 | } 51 | 52 | return pipe 53 | } 54 | -------------------------------------------------------------------------------- /internal/util/term/resize.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023 coding-hui. All rights reserved. 2 | // Use of this source code is governed by a MIT style 3 | // license that can be found in the LICENSE file. 4 | 5 | package term 6 | 7 | import ( 8 | "github.com/moby/term" 9 | ) 10 | 11 | // TerminalSize represents the width and height of a terminal. 12 | type TerminalSize struct { 13 | Width uint16 14 | Height uint16 15 | } 16 | 17 | // TerminalSizeQueue is capable of returning terminal resize events as they occur. 18 | type TerminalSizeQueue interface { 19 | // Next returns the new terminal size after the terminal has been resized. It returns nil when 20 | // monitoring has been stopped. 21 | Next() *TerminalSize 22 | } 23 | 24 | // GetSize returns the current size of the user's terminal.
If it isn't a terminal, 25 | // nil is returned. 26 | func (t TTY) GetSize() *TerminalSize { 27 | outFd, isTerminal := term.GetFdInfo(t.Out) 28 | if !isTerminal { 29 | return nil 30 | } 31 | return GetSize(outFd) 32 | } 33 | 34 | // GetSize returns the current size of the terminal associated with fd. 35 | func GetSize(fd uintptr) *TerminalSize { 36 | winsize, err := term.GetWinsize(fd) 37 | if err != nil { 38 | // runtime.HandleError(fmt.Errorf("unable to get terminal size: %v", err)) 39 | return nil 40 | } 41 | 42 | return &TerminalSize{Width: winsize.Width, Height: winsize.Height} 43 | } 44 | -------------------------------------------------------------------------------- /internal/util/term/term.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023 coding-hui. All rights reserved. 2 | // Use of this source code is governed by a MIT style 3 | // license that can be found in the LICENSE file. 4 | 5 | // Package term provides structures and helper functions to work with 6 | // terminal (state, sizes). 7 | package term 8 | 9 | import ( 10 | "io" 11 | "os" 12 | "regexp" 13 | "strings" 14 | "sync" 15 | 16 | "github.com/mattn/go-isatty" 17 | ) 18 | 19 | // TTY helps invoke a function and preserve the state of the terminal, even if the process is 20 | // terminated during execution. It also provides support for terminal resizing for remote command 21 | // execution/attachment. 22 | type TTY struct { 23 | // In is a reader representing stdin. It is a required field. 24 | In io.Reader 25 | // Out is a writer representing stdout. It must be set to support terminal resizing. It is an 26 | // optional field. 27 | Out io.Writer 28 | // Raw is true if the terminal should be set raw. 29 | Raw bool 30 | // TryDev indicates the TTY should try to open /dev/tty if the provided input 31 | // is not a file descriptor.
32 | TryDev bool 33 | } 34 | 35 | // IsInputTTY reports whether os.Stdin is a terminal; the check runs once, on first call, and is cached. 36 | var IsInputTTY = sync.OnceValue(func() bool { 37 | return isatty.IsTerminal(os.Stdin.Fd()) 38 | }) 39 | 40 | // IsOutputTTY reports whether os.Stdout is a terminal; the check runs once, on first call, and is cached. 41 | var IsOutputTTY = sync.OnceValue(func() bool { 42 | return isatty.IsTerminal(os.Stdout.Fd()) 43 | }) 44 | 45 | // unsafeFilenameChars matches characters that are reserved in file names on common platforms (notably Windows). 46 | var unsafeFilenameChars = regexp.MustCompile(`[<>:"|?*]`) 47 | 48 | // maxFilenameBytes is the maximum length, in bytes, of a name returned by SanitizeFilename. 49 | const maxFilenameBytes = 64 50 | 51 | // SanitizeFilename cleans up a filename to make it safe for use in file systems. It replaces path 52 | // separators and the characters < > : " | ? * with "_", trims surrounding whitespace and then surrounding 53 | // dots, returns "unnamed" if nothing is left, and truncates the result to at most maxFilenameBytes bytes 54 | // without splitting a multi-byte UTF-8 character. 55 | func SanitizeFilename(filename string) string { 56 | // Replace path separators so the result is a single path component. 57 | filename = strings.ReplaceAll(filename, "/", "_") 58 | filename = strings.ReplaceAll(filename, "\\", "_") 59 | 60 | // Replace other potentially problematic characters. 61 | filename = unsafeFilenameChars.ReplaceAllString(filename, "_") 62 | 63 | // Trim spaces and dots from start/end 64 | filename = strings.TrimSpace(filename) 65 | filename = strings.Trim(filename, ".") 66 | 67 | // Ensure filename is not empty 68 | if filename == "" { 69 | filename = "unnamed" 70 | } 71 | 72 | // Limit the length (well under the common 255-byte filesystem limit). Cut at the last rune 73 | // boundary that fits: slicing at a fixed byte offset could keep half of a multi-byte 74 | // character and produce invalid UTF-8. 75 | if len(filename) > maxFilenameBytes { 76 | cut := 0 77 | for i := range filename { // i is the byte offset at which each rune starts 78 | if i > maxFilenameBytes { 79 | break 80 | } 81 | cut = i 82 | } 83 | filename = filename[:cut] 84 | } 85 | 86 | return filename 87 | } 88 | -------------------------------------------------------------------------------- /internal/util/term/term_writer.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023 coding-hui. All rights reserved. 2 | // Use of this source code is governed by a MIT style 3 | // license that can be found in the LICENSE file. 4 | 5 | package term 6 | 7 | import ( 8 | "io" 9 | "os" 10 | 11 | "github.com/mitchellh/go-wordwrap" 12 | "github.com/moby/term" 13 | ) 14 | 15 | type wordWrapWriter struct { 16 | limit uint 17 | writer io.Writer 18 | } 19 | 20 | // NewResponsiveWriter creates a Writer that detects the column width of the 21 | // terminal we are in, and adjusts every line width to fit and use recommended 22 | // terminal sizes for better readability.
Does proper word wrapping automatically. 23 | // 24 | // if terminal width >= 120 columns use 120 columns 25 | // if terminal width >= 100 columns use 100 columns 26 | // if terminal width >= 80 columns use 80 columns 27 | // 28 | // In case we're not in a terminal or if it's smaller than 80 columns width, 29 | // doesn't do any wrapping. 30 | func NewResponsiveWriter(w io.Writer) io.Writer { 31 | file, ok := w.(*os.File) 32 | if !ok { 33 | return w 34 | } 35 | fd := file.Fd() 36 | if !term.IsTerminal(fd) { 37 | return w 38 | } 39 | 40 | terminalSize := GetSize(fd) 41 | if terminalSize == nil { 42 | return w 43 | } 44 | 45 | var limit uint 46 | switch { 47 | case terminalSize.Width >= 120: 48 | limit = 120 49 | case terminalSize.Width >= 100: 50 | limit = 100 51 | case terminalSize.Width >= 80: 52 | limit = 80 53 | } 54 | 55 | return NewWordWrapWriter(w, limit) 56 | } 57 | 58 | // NewWordWrapWriter is a Writer that supports a limit of characters on every line 59 | // and does auto word wrapping that respects that limit. 60 | func NewWordWrapWriter(w io.Writer, limit uint) io.Writer { 61 | return &wordWrapWriter{ 62 | limit: limit, 63 | writer: w, 64 | } 65 | } 66 | 67 | func (w wordWrapWriter) Write(p []byte) (nn int, err error) { 68 | if w.limit == 0 { 69 | return w.writer.Write(p) 70 | } 71 | original := string(p) 72 | wrapped := wordwrap.WrapString(original, w.limit) 73 | return w.writer.Write([]byte(wrapped)) 74 | } 75 | 76 | // NewPunchCardWriter is a NewWordWrapWriter that limits the line width to 80 columns. 77 | func NewPunchCardWriter(w io.Writer) io.Writer { 78 | return NewWordWrapWriter(w, 80) 79 | } 80 | 81 | type maxWidthWriter struct { 82 | maxWidth uint 83 | currentWidth uint 84 | written uint 85 | writer io.Writer 86 | } 87 | 88 | // NewMaxWidthWriter is a Writer that supports a limit of characters on every 89 | // line, but doesn't do any word wrapping automatically. 
90 | func NewMaxWidthWriter(w io.Writer, maxWidth uint) io.Writer { 91 | return &maxWidthWriter{ 92 | maxWidth: maxWidth, 93 | writer: w, 94 | } 95 | } 96 | 97 | func (m maxWidthWriter) Write(p []byte) (nn int, err error) { 98 | for _, b := range p { 99 | if m.currentWidth == m.maxWidth { 100 | _, err := m.writer.Write([]byte{'\n'}) 101 | if err != nil { 102 | return 0, err 103 | } 104 | m.currentWidth = 0 105 | } 106 | if b == '\n' { 107 | m.currentWidth = 0 108 | } 109 | _, err := m.writer.Write([]byte{b}) 110 | if err != nil { 111 | return int(m.written), err 112 | } 113 | m.written++ 114 | m.currentWidth++ 115 | } 116 | return len(p), nil 117 | } 118 | -------------------------------------------------------------------------------- /internal/util/term/term_writer_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023 coding-hui. All rights reserved. 2 | // Use of this source code is governed by a MIT style 3 | // license that can be found in the LICENSE file. 
4 | 5 | package term 6 | 7 | import ( 8 | "bytes" 9 | "strings" 10 | "testing" 11 | ) 12 | 13 | const test = "Iam Iam Iam Iam Iam Iam Iam Iam Iam Iam Iam Iam Iam Iam Iam Iam Iam Iam Iam Iam Iam Iam Iam Iam Iam Iam Iam Iam Iam Iam Iam Iam Iam Iam Iam Iam Iam Iam Iam Iam Iam Iam Iam Iam Iam Iam Iam Iam Iam Iam Iam Iam Iam Iam Iam Iam Iam Iam Iam Iam Iam Iam Iam Iam Iam Iam Iam Iam" 14 | 15 | func TestWordWrapWriter(t *testing.T) { 16 | testcases := map[string]struct { 17 | input string 18 | maxWidth uint 19 | }{ 20 | "max 10": {input: test, maxWidth: 10}, 21 | "max 80": {input: test, maxWidth: 80}, 22 | "max 120": {input: test, maxWidth: 120}, 23 | "max 5000": {input: test, maxWidth: 5000}, 24 | } 25 | for k, tc := range testcases { 26 | b := bytes.NewBufferString("") 27 | w := NewWordWrapWriter(b, tc.maxWidth) 28 | _, err := w.Write([]byte(tc.input)) 29 | if err != nil { 30 | t.Errorf("%s: Unexpected error: %v", k, err) 31 | } 32 | result := b.String() 33 | if !strings.Contains(result, "Iam") { 34 | t.Errorf("%s: Expected to contain \"Iam\"", k) 35 | } 36 | if len(result) < len(tc.input) { 37 | t.Errorf( 38 | "%s: Unexpectedly short string, got %d wanted at least %d chars: %q", 39 | k, 40 | len(result), 41 | len(tc.input), 42 | result, 43 | ) 44 | } 45 | for _, line := range strings.Split(result, "\n") { 46 | if len(line) > int(tc.maxWidth) { 47 | t.Errorf("%s: Every line must be at most %d chars long, got %d: %q", k, tc.maxWidth, len(line), line) 48 | } 49 | } 50 | for _, word := range strings.Split(result, " ") { 51 | if !strings.Contains(word, "Iam") { 52 | t.Errorf("%s: Unexpected broken word: %q", k, word) 53 | } 54 | } 55 | } 56 | } 57 | 58 | func TestMaxWidthWriter(t *testing.T) { 59 | testcases := map[string]struct { 60 | input string 61 | maxWidth uint 62 | }{ 63 | "max 10": {input: test, maxWidth: 10}, 64 | "max 80": {input: test, maxWidth: 80}, 65 | "max 120": {input: test, maxWidth: 120}, 66 | "max 5000": {input: test, maxWidth: 5000}, 67 | } 68 | for k, 
tc := range testcases { 69 | b := bytes.NewBufferString("") 70 | w := NewMaxWidthWriter(b, tc.maxWidth) 71 | _, err := w.Write([]byte(tc.input)) 72 | if err != nil { 73 | t.Errorf("%s: Unexpected error: %v", k, err) 74 | } 75 | result := b.String() 76 | if !strings.Contains(result, "Iam") { 77 | t.Errorf("%s: Expected to contain \"Iam\"", k) 78 | } 79 | if len(result) < len(tc.input) { 80 | t.Errorf( 81 | "%s: Unexpectedly short string, got %d wanted at least %d chars: %q", 82 | k, 83 | len(result), 84 | len(tc.input), 85 | result, 86 | ) 87 | } 88 | lines := strings.Split(result, "\n") 89 | for i, line := range lines { 90 | if len(line) > int(tc.maxWidth) { 91 | t.Errorf("%s: Every line must be at most %d chars long, got %d: %q", k, tc.maxWidth, len(line), line) 92 | } 93 | if i < len(lines)-1 && len(line) != int(tc.maxWidth) { 94 | t.Errorf( 95 | "%s: Lines except the last one are expected to be exactly %d chars long, got %d: %q", 96 | k, 97 | tc.maxWidth, 98 | len(line), 99 | line, 100 | ) 101 | } 102 | } 103 | } 104 | } 105 | --------------------------------------------------------------------------------