├── .flake8
├── .github
│   └── workflows
│       ├── pypi-publish.yml
│       └── python-app.yml
├── .gitignore
├── LICENSE
├── README.md
├── gptcli
│   ├── __init__.py
│   ├── assistant.py
│   ├── cli.py
│   ├── completion.py
│   ├── composite.py
│   ├── config.py
│   ├── cost.py
│   ├── gpt.py
│   ├── logging_utils.py
│   ├── providers
│   │   ├── __init__.py
│   │   ├── anthropic.py
│   │   ├── azure_openai.py
│   │   ├── cohere.py
│   │   ├── google.py
│   │   ├── llama.py
│   │   └── openai.py
│   ├── session.py
│   └── shell.py
├── pyproject.toml
├── screenshot.png
├── tests
│   ├── __init__.py
│   ├── test_assistant.py
│   └── test_session.py
└── uv.lock

/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length = 120
3 | ignore = E203, E402, W503
--------------------------------------------------------------------------------
/.github/workflows/pypi-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries
3 | 
4 | # This workflow uses actions that are not certified by GitHub.
5 | # They are provided by a third-party and are governed by
6 | # separate terms of service, privacy policy, and support
7 | # documentation.
8 | 
9 | name: Upload Python Package
10 | 
11 | on:
12 |   release:
13 |     types: [published]
14 | 
15 | permissions:
16 |   contents: read
17 | 
18 | jobs:
19 |   deploy:
20 | 
21 |     runs-on: ubuntu-latest
22 |     environment:
23 |       name: pypi-publish
24 |       url: https://pypi.org/p/gpt-command-line
25 |     permissions:
26 |       id-token: write
27 | 
28 |     steps:
29 |     - uses: actions/checkout@v3
30 |     - name: Set up Python
31 |       uses: actions/setup-python@v3
32 |       with:
33 |         python-version: '3.x'
34 |     - name: Install dependencies
35 |       run: |
36 |         python -m pip install --upgrade pip
37 |         pip install build
38 |     - name: Build package
39 |       run: python -m build
40 |     - name: Publish package distributions to PyPI
41 |       uses: pypa/gh-action-pypi-publish@release/v1
42 | 
--------------------------------------------------------------------------------
/.github/workflows/python-app.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python
2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
3 | 
4 | name: Python application
5 | 
6 | on:
7 |   push:
8 |     branches: [ "main" ]
9 |   pull_request:
10 |     branches: [ "main" ]
11 | 
12 | permissions:
13 |   contents: read
14 | 
15 | jobs:
16 |   build:
17 | 
18 |     runs-on: ubuntu-latest
19 | 
20 |     steps:
21 |     - uses: actions/checkout@v3
22 |     - name: Set up Python 3.9
23 |       uses: actions/setup-python@v3
24 |       with:
25 |         python-version: "3.9"
26 |     - name: Install dependencies
27 |       run: |
28 |         python -m pip install --upgrade pip
29 |         pip install pytest
30 |         pip install .
31 |     #- name: Lint with flake8
32 |       #run: |
33 |         # stop the build if there are Python syntax errors or undefined names
34 |         #flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
35 |         # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
36 |         #flake8 . 
--count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 37 | - name: Test with pytest 38 | run: | 39 | pytest 40 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | .vscode/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2023 by Val Kharitonov 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # gpt-cli 2 | 3 | Command-line interface for chat LLMs. 4 | 5 | ## Try now 6 | ``` 7 | export ANTHROPIC_API_KEY=xcxc 8 | uvx --from gpt-command-line gpt 9 | ``` 10 | 11 | ## Supported providers 12 | 13 | - OpenAI 14 | - Anthropic 15 | - Google Gemini 16 | - Cohere 17 | - Other APIs compatible with OpenAI (e.g. Together, OpenRouter, local models with LM Studio) 18 | 19 | ![screenshot](https://github.com/kharvd/gpt-cli/assets/466920/ecbcccc4-7cfa-4c04-83c3-a822b6596f01) 20 | 21 | ## Features 22 | 23 | - **Command-Line Interface**: Interact with ChatGPT or Claude directly from your terminal. 24 | - **Model Customization**: Override the default model, temperature, and top_p values for each assistant, giving you fine-grained control over the AI's behavior. 25 | - **Extended Thinking Mode**: Enable Claude 3.7's extended thinking capability to see its reasoning process for complex problems. 
26 | - **Usage tracking**: Track your API usage with token count and price information. 27 | - **Keyboard Shortcuts**: Use Ctrl-C, Ctrl-D, and Ctrl-R shortcuts for easier conversation management and input control. 28 | - **Multi-Line Input**: Enter multi-line mode for more complex queries or conversations. 29 | - **Markdown Support**: Enable or disable markdown formatting for chat sessions to tailor the output to your preferences. 30 | - **Predefined Messages**: Set up predefined messages for your custom assistants to establish context or role-play scenarios. 31 | - **Multiple Assistants**: Easily switch between different assistants, including general, dev, and custom assistants defined in the config file. 32 | - **Flexible Configuration**: Define your assistants, model parameters, and API key in a YAML configuration file, allowing for easy customization and management. 33 | 34 | ## Installation 35 | 36 | This install assumes a Linux/OSX machine with Python and pip available. 37 | 38 | ```bash 39 | pip install gpt-command-line 40 | ``` 41 | 42 | Install latest version from source: 43 | 44 | ```bash 45 | pip install git+https://github.com/kharvd/gpt-cli.git 46 | ``` 47 | 48 | Or install by cloning the repository manually: 49 | 50 | ```bash 51 | git clone https://github.com/kharvd/gpt-cli.git 52 | cd gpt-cli 53 | pip install . 54 | ``` 55 | 56 | Add the OpenAI API key to your `.bashrc` file (in the root of your home folder). 57 | In this example we use nano, you can use any text editor. 58 | 59 | ``` 60 | nano ~/.bashrc 61 | export OPENAI_API_KEY= 62 | ``` 63 | 64 | Run the tool 65 | 66 | ``` 67 | gpt 68 | ``` 69 | 70 | You can also use a `gpt.yml` file for configuration. See the [Configuration](README.md#Configuration) section below. 71 | 72 | ## Usage 73 | 74 | Make sure to set the `OPENAI_API_KEY` environment variable to your OpenAI API key (or put it in the `~/.config/gpt-cli/gpt.yml` file as described below). 75 | 76 | ``` 77 | usage: gpt [-h] [--no_markdown] [--model MODEL] [--temperature TEMPERATURE] [--top_p TOP_P] 78 | [--thinking THINKING_BUDGET] [--log_file LOG_FILE] 79 | [--log_level {DEBUG,INFO,WARNING,ERROR,CRITICAL}] [--prompt PROMPT] 80 | [--execute EXECUTE] [--no_stream] [{dev,general,bash}] 81 | 82 | Run a chat session with ChatGPT. See https://github.com/kharvd/gpt-cli for more information. 83 | 84 | positional arguments: 85 | {dev,general,bash} 86 | The name of assistant to use. `general` (default) is a generally helpful 87 | assistant, `dev` is a software development assistant with shorter 88 | responses. You can specify your own assistants in the config file 89 | ~/.config/gpt-cli/gpt.yml. See the README for more information. 90 | 91 | optional arguments: 92 | -h, --help show this help message and exit 93 | --no_markdown Disable markdown formatting in the chat session. 94 | --model MODEL The model to use for the chat session. Overrides the default model defined 95 | for the assistant. 96 | --temperature TEMPERATURE 97 | The temperature to use for the chat session. Overrides the default 98 | temperature defined for the assistant. 99 | --top_p TOP_P The top_p to use for the chat session. Overrides the default top_p defined 100 | for the assistant. 101 | --thinking THINKING_BUDGET 102 | Enable Claude's extended thinking mode with the specified token budget. 103 | Only applies to Claude 3.7 models. 104 | --log_file LOG_FILE The file to write logs to. Supports strftime format codes. 
105 |   --log_level {DEBUG,INFO,WARNING,ERROR,CRITICAL}
106 |                         The log level to use
107 |   --prompt PROMPT, -p PROMPT
108 |                         If specified, will not start an interactive chat session and instead will
109 |                         print the response to standard output and exit. May be specified multiple
110 |                         times. Use `-` to read the prompt from standard input. Implies
111 |                         --no_markdown.
112 |   --execute EXECUTE, -e EXECUTE
113 |                         If specified, passes the prompt to the assistant and allows the user to
114 |                         edit the produced shell command before executing it. Implies --no_stream.
115 |                         Use `-` to read the prompt from standard input.
116 |   --no_stream           If specified, will not stream the response to standard output. This is
117 |                         useful if you want to use the response in a script. Ignored when the
118 |                         --prompt option is not specified.
119 |   --no_price            Disable price logging.
120 | ```
121 | 
122 | Type `:q` or Ctrl-D to exit, `:c` or Ctrl-C to clear the conversation, `:r` or Ctrl-R to re-generate the last response.
123 | To enter multi-line mode, enter a backslash `\` followed by a new line. Exit the multi-line mode by pressing ESC and then Enter.
124 | 
125 | The `dev` assistant is instructed to be an expert in software development and provide short responses.
126 | 
127 | ```bash
128 | $ gpt dev
129 | ```
130 | 
131 | The `bash` assistant is instructed to be an expert in bash scripting and provide only bash commands. Use the `--execute` option to execute the commands. It works best with the `gpt-4` model.
132 | 
133 | ```bash
134 | gpt bash -e "How do I list files in a directory?"
135 | ```
136 | 
137 | This will prompt you to edit the command in your `$EDITOR` before executing it.
138 | 
139 | ## Configuration
140 | 
141 | You can configure the assistants in the config file `~/.config/gpt-cli/gpt.yml`. The file is a YAML file with the following structure (see also [config.py](./gptcli/config.py)):
142 | 
143 | ```yaml
144 | default_assistant: <assistant_name>
145 | markdown: False
146 | openai_api_key: <openai_api_key>
147 | anthropic_api_key: <anthropic_api_key>
148 | log_file: <log_file>
149 | log_level: <log_level>
150 | assistants:
151 |   <assistant_name>:
152 |     model: <model_name>
153 |     temperature: <temperature>
154 |     top_p: <top_p>
155 |     thinking_budget: <thinking_budget> # Claude 3.7 models only
156 |     messages:
157 |       - { role: <role>, content: <message> }
158 |       - ...
159 |   <assistant_name>:
160 |     ...
161 | ```
162 | 
163 | You can override the parameters for the pre-defined assistants as well.
164 | 
165 | You can specify the default assistant to use by setting the `default_assistant` field. If you don't specify it, the default assistant is `general`. You can also specify the `model`, `temperature` and `top_p` to use for the assistant. If you don't specify them, the default values are used. These parameters can also be overridden by the command-line arguments.
166 | 
167 | Example:
168 | 
169 | ```yaml
170 | default_assistant: dev
171 | markdown: True
172 | openai_api_key: <openai_api_key>
173 | assistants:
174 |   pirate:
175 |     model: gpt-4
176 |     temperature: 1.0
177 |     messages:
178 |       - { role: system, content: "You are a pirate." }
179 | ```
180 | 
181 | ```
182 | $ gpt pirate
183 | 
184 | > Arrrr
185 | Ahoy, matey! What be bringing ye to these here waters? Be it treasure or adventure ye seek, we be sailing the high seas together. Ready yer map and compass, for we have a long voyage ahead!
186 | ```
187 | 
188 | ### Read other context to the assistant with !include
189 | 
190 | You can read files into the assistant's context with the `!include` directive.
191 | 
192 | ```yaml
193 | default_assistant: dev
194 | markdown: True
195 | openai_api_key: <openai_api_key>
196 | assistants:
197 |   pirate:
198 |     model: gpt-4
199 |     temperature: 1.0
200 |     messages:
201 |       - { role: system, content: !include "pirate.txt" }
202 | ```
203 | 
204 | ### Customize OpenAI API URL
205 | 
206 | If you are using other models compatible with the OpenAI Python SDK, you can configure them by modifying the `openai_base_url` setting in the config file or using the `OPENAI_BASE_URL` environment variable.
207 | 
208 | Example:
209 | 
210 | ```
211 | openai_base_url: https://your-custom-api-url.com/v1
212 | ```
213 | 
214 | Use the `oai-compat:` prefix for the model name to pass non-GPT model names to the API. For example, to chat with Llama3-70b on [Together](https://together.ai), use the following command:
215 | 
216 | ```bash
217 | OPENAI_API_KEY=$TOGETHER_API_KEY OPENAI_BASE_URL=https://api.together.xyz/v1 gpt general --model oai-compat:meta-llama/Llama-3-70b-chat-hf
218 | ```
219 | 
220 | The prefix is stripped before sending the request to the API.
221 | 
222 | Similarly, use the `oai-azure:` model name prefix to use a model deployed via Azure OpenAI. For example, `oai-azure:my-deployment-name`.
223 | 
224 | With assistant configuration, you can override the base URL and API key for a specific assistant.
225 | 
226 | ```yaml
227 | # ~/.config/gpt-cli/gpt.yml
228 | assistants:
229 |   llama:
230 |     model: oai-compat:meta-llama/llama-3.3-70b-instruct
231 |     openai_base_url_override: https://openrouter.ai/api/v1
232 |     openai_api_key_override: $OPENROUTER_API_KEY
233 | ```
234 | 
235 | ## Other chat bots
236 | 
237 | ### Anthropic Claude
238 | 
239 | To use Claude, you should have an API key from [Anthropic](https://console.anthropic.com/) (currently there is a waitlist for API access). After getting the API key, you can add an environment variable:
240 | 
241 | ```bash
242 | export ANTHROPIC_API_KEY=<anthropic_api_key>
243 | ```
244 | 
245 | or a config line in `~/.config/gpt-cli/gpt.yml`:
246 | 
247 | ```yaml
248 | anthropic_api_key: <anthropic_api_key>
249 | ```
250 | 
251 | Now you should be able to run `gpt` with `--model claude-3-(opus|sonnet|haiku)-<date>`.
252 | 
253 | ```bash
254 | gpt --model claude-3-opus-20240229
255 | ```
256 | 
257 | #### Claude 3.7 Sonnet Extended Thinking Mode
258 | 
259 | Claude 3.7 Sonnet supports an extended thinking mode, which shows Claude's reasoning process before delivering the final answer. This is useful for complex analysis, advanced STEM problems, and tasks with multiple constraints.
260 | 
261 | Enable it with the `--thinking` parameter, specifying the token budget for the thinking process:
262 | 
263 | ```bash
264 | gpt --model claude-3-7-sonnet-20250219 --thinking 32000
265 | ```
266 | 
267 | You can also configure thinking mode for specific assistants in your config:
268 | 
269 | ```yaml
270 | assistants:
271 |   math:
272 |     model: claude-3-7-sonnet-20250219
273 |     thinking_budget: 32000
274 |     messages:
275 |       - { role: system, content: "You are a math expert." }
276 | ```
277 | 
278 | **Note**: When thinking mode is enabled, the temperature is automatically set to 1.0 and top_p is unset as required by the Claude API.
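Under the hood, `--thinking` maps to the `thinking` parameter of the Anthropic Messages API (see [anthropic.py](./gptcli/providers/anthropic.py)): the budget is sent as `budget_tokens`, and `max_tokens` is kept above the budget. A minimal sketch of the equivalent raw SDK call — the prompt, budget, and `max_tokens` values here are only illustrative:

```python
import anthropic

client = anthropic.Anthropic()  # reads ANTHROPIC_API_KEY from the environment

with client.messages.stream(
    model="claude-3-7-sonnet-20250219",
    max_tokens=36000,  # must be larger than the thinking budget
    temperature=1.0,   # the API requires temperature 1.0 when thinking is enabled
    thinking={"type": "enabled", "budget_tokens": 32000},
    messages=[{"role": "user", "content": "Prove that sqrt(2) is irrational."}],
) as stream:
    for event in stream:
        if event.type == "content_block_delta":
            if event.delta.type == "thinking_delta":
                print(event.delta.thinking, end="", flush=True)  # reasoning tokens
            elif event.delta.type == "text_delta":
                print(event.delta.text, end="", flush=True)      # final answer tokens
```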
279 | 280 | ### Google Gemini 281 | 282 | ```bash 283 | export GOOGLE_API_KEY= 284 | ``` 285 | 286 | or 287 | 288 | ```yaml 289 | google_api_key: 290 | ``` 291 | 292 | ### Cohere 293 | 294 | ```bash 295 | export COHERE_API_KEY= 296 | ``` 297 | 298 | or 299 | 300 | ```yaml 301 | cohere_api_key: 302 | ``` 303 | -------------------------------------------------------------------------------- /gptcli/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.4.3" 2 | -------------------------------------------------------------------------------- /gptcli/assistant.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from attr import dataclass 4 | import platform 5 | from typing import Any, Dict, Iterator, Optional, TypedDict, List 6 | 7 | from gptcli.completion import ( 8 | CompletionEvent, 9 | CompletionProvider, 10 | Message, 11 | ) 12 | from gptcli.providers.google import GoogleCompletionProvider 13 | from gptcli.providers.llama import LLaMACompletionProvider 14 | from gptcli.providers.openai import OpenAICompletionProvider 15 | from gptcli.providers.anthropic import AnthropicCompletionProvider 16 | from gptcli.providers.cohere import CohereCompletionProvider 17 | from gptcli.providers.azure_openai import AzureOpenAICompletionProvider 18 | 19 | 20 | class AssistantConfig(TypedDict, total=False): 21 | messages: List[Message] 22 | model: str 23 | openai_base_url_override: Optional[str] 24 | openai_api_key_override: Optional[str] 25 | temperature: float 26 | top_p: float 27 | thinking_budget: Optional[int] 28 | 29 | 30 | CONFIG_DEFAULTS = { 31 | "model": "gpt-3.5-turbo", 32 | "temperature": 0.7, 33 | "top_p": 1.0, 34 | } 35 | 36 | DEFAULT_ASSISTANTS: Dict[str, AssistantConfig] = { 37 | "dev": { 38 | "messages": [ 39 | { 40 | "role": "system", 41 | "content": f"You are a helpful assistant who is an expert in software development. \ 42 | You are helping a user who is a software developer. Your responses are short and concise. \ 43 | You include code snippets when appropriate. Code snippets are formatted using Markdown \ 44 | with a correct language tag. User's `uname`: {platform.uname()}", 45 | }, 46 | { 47 | "role": "user", 48 | "content": "Your responses must be short and concise. Do not include explanations unless asked.", 49 | }, 50 | { 51 | "role": "assistant", 52 | "content": "Understood.", 53 | }, 54 | ], 55 | }, 56 | "general": { 57 | "messages": [], 58 | }, 59 | "bash": { 60 | "messages": [ 61 | { 62 | "role": "system", 63 | "content": f"You output only valid and correct shell commands according to the user's prompt. \ 64 | You don't provide any explanations or any other text that is not valid shell commands. \ 65 | User's `uname`: {platform.uname()}. 
User's `$SHELL`: {os.environ.get('SHELL')}.", 66 | } 67 | ], 68 | }, 69 | } 70 | 71 | 72 | def get_completion_provider( 73 | model: str, 74 | openai_base_url_override: Optional[str] = None, 75 | openai_api_key_override: Optional[str] = None, 76 | ) -> CompletionProvider: 77 | if ( 78 | model.startswith("gpt") 79 | or model.startswith("ft:gpt") 80 | or model.startswith("oai-compat:") 81 | or model.startswith("chatgpt") 82 | or model.startswith("o1") 83 | or model.startswith("o3") 84 | or model.startswith("o4") 85 | ): 86 | return OpenAICompletionProvider( 87 | openai_base_url_override, openai_api_key_override 88 | ) 89 | elif model.startswith("oai-azure:"): 90 | return AzureOpenAICompletionProvider() 91 | elif model.startswith("claude"): 92 | return AnthropicCompletionProvider() 93 | elif model.startswith("llama"): 94 | return LLaMACompletionProvider() 95 | elif model.startswith("command") or model.startswith("c4ai"): 96 | return CohereCompletionProvider() 97 | elif model.startswith("gemini") or model.startswith("gemma"): 98 | return GoogleCompletionProvider() 99 | else: 100 | raise ValueError(f"Unknown model: {model}") 101 | 102 | 103 | class Assistant: 104 | def __init__(self, config: AssistantConfig): 105 | self.config = config 106 | 107 | @classmethod 108 | def from_config(cls, name: str, config: AssistantConfig): 109 | config = config.copy() 110 | if name in DEFAULT_ASSISTANTS: 111 | # Merge the config with the default config 112 | # If a key is in both, use the value from the config 113 | default_config = DEFAULT_ASSISTANTS[name] 114 | for key in [*config.keys(), *default_config.keys()]: 115 | if config.get(key) is None: 116 | config[key] = default_config[key] 117 | 118 | return cls(config) 119 | 120 | def init_messages(self) -> List[Message]: 121 | return self.config.get("messages", [])[:] 122 | 123 | def _param(self, param: str) -> Any: 124 | # Use the value from the config if exists 125 | # Otherwise, use the default value 126 | return self.config.get(param, CONFIG_DEFAULTS.get(param, None)) 127 | 128 | def complete_chat(self, messages, stream: bool = True) -> Iterator[CompletionEvent]: 129 | model = self._param("model") 130 | completion_provider = get_completion_provider( 131 | model, 132 | self._param("openai_base_url_override"), 133 | self._param("openai_api_key_override"), 134 | ) 135 | 136 | args = { 137 | "model": model, 138 | "temperature": float(self._param("temperature")), 139 | "top_p": float(self._param("top_p")), 140 | } 141 | 142 | # Add thinking budget if it's specified and we're using Claude 3.7 143 | thinking_budget = self.config.get("thinking_budget") 144 | if thinking_budget is not None and "claude-3-7" in model: 145 | args["thinking_budget"] = thinking_budget 146 | 147 | return completion_provider.complete( 148 | messages, 149 | args, 150 | stream, 151 | ) 152 | 153 | 154 | @dataclass 155 | class AssistantGlobalArgs: 156 | assistant_name: str 157 | model: Optional[str] = None 158 | temperature: Optional[float] = None 159 | top_p: Optional[float] = None 160 | thinking_budget: Optional[int] = None 161 | 162 | 163 | def init_assistant( 164 | args: AssistantGlobalArgs, custom_assistants: Dict[str, AssistantConfig] 165 | ) -> Assistant: 166 | name = args.assistant_name 167 | if name in custom_assistants: 168 | assistant = Assistant.from_config(name, custom_assistants[name]) 169 | elif name in DEFAULT_ASSISTANTS: 170 | assistant = Assistant.from_config(name, DEFAULT_ASSISTANTS[name]) 171 | else: 172 | print(f"Unknown assistant: {name}") 173 | sys.exit(1) 174 | 175 | # 
Override config with command line arguments 176 | if args.temperature is not None: 177 | assistant.config["temperature"] = args.temperature 178 | if args.model is not None: 179 | assistant.config["model"] = args.model 180 | if args.top_p is not None: 181 | assistant.config["top_p"] = args.top_p 182 | if args.thinking_budget is not None: 183 | assistant.config["thinking_budget"] = args.thinking_budget 184 | return assistant 185 | -------------------------------------------------------------------------------- /gptcli/cli.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from openai import BadRequestError, OpenAIError 4 | from prompt_toolkit import PromptSession 5 | from prompt_toolkit.history import FileHistory 6 | from prompt_toolkit.key_binding import KeyBindings, KeyPressEvent 7 | from prompt_toolkit.key_binding.bindings import named_commands 8 | from rich.console import Console 9 | from rich.live import Live 10 | from rich.markdown import Markdown 11 | from rich.text import Text 12 | 13 | from gptcli.completion import ToolCallEvent 14 | from gptcli.session import ( 15 | ALL_COMMANDS, 16 | COMMAND_CLEAR, 17 | COMMAND_QUIT, 18 | COMMAND_RERUN, 19 | ChatListener, 20 | InvalidArgumentError, 21 | ResponseStreamer, 22 | UserInputProvider, 23 | ) 24 | 25 | TERMINAL_WELCOME = """ 26 | Hi! I'm here to help. Type `:q` or Ctrl-D to exit, `:c` or Ctrl-C and Enter to clear 27 | the conversation, `:r` or Ctrl-R to re-generate the last response. 28 | To enter multi-line mode, enter a backslash `\\` followed by a new line. 29 | Exit the multi-line mode by pressing ESC and then Enter (Meta+Enter). 30 | Try `:?` for help. 31 | """ 32 | 33 | 34 | class StreamingMarkdownPrinter: 35 | def __init__(self, console: Console, markdown: bool, style: str = "green"): 36 | self.console = console 37 | self.current_text = "" 38 | self.markdown = markdown 39 | self.style = style 40 | self.live: Optional[Live] = None 41 | 42 | def __enter__(self) -> "StreamingMarkdownPrinter": 43 | if self.markdown: 44 | self.live = Live( 45 | console=self.console, auto_refresh=False, vertical_overflow="visible" 46 | ) 47 | self.live.__enter__() 48 | return self 49 | 50 | def print(self, text: str): 51 | self.current_text += text 52 | if self.markdown: 53 | assert self.live 54 | content = Markdown(self.current_text, style=self.style) 55 | self.live.update(content) 56 | self.live.refresh() 57 | else: 58 | self.console.print(Text(text, style=self.style), end="") 59 | 60 | def __exit__(self, *args): 61 | if self.markdown: 62 | assert self.live 63 | self.live.__exit__(*args) 64 | self.console.print() 65 | 66 | 67 | class CLIResponseStreamer(ResponseStreamer): 68 | def __init__(self, console: Console, markdown: bool): 69 | self.console = console 70 | self.markdown = markdown 71 | self.printer = StreamingMarkdownPrinter(self.console, self.markdown) 72 | self.thinking_printer = None 73 | self.first_token = True 74 | 75 | def __enter__(self): 76 | self.printer.__enter__() 77 | return self 78 | 79 | def on_next_token(self, token: str): 80 | if self.first_token and token.startswith(" "): 81 | token = token[1:] 82 | self.first_token = False 83 | if self.thinking_printer is not None: 84 | self.printer.print("\n") 85 | self.thinking_printer.__exit__() 86 | self.thinking_printer = None 87 | self.printer.print(token) 88 | 89 | def on_thinking_token(self, token: str): 90 | if self.thinking_printer is None: 91 | self.console.print("[bold blue]Thinking...[/bold blue]", end="\n\n") 
92 | self.thinking_printer = StreamingMarkdownPrinter( 93 | self.console, self.markdown, style="dim blue" 94 | ) 95 | self.thinking_printer.__enter__() 96 | self.thinking_printer.print(token) 97 | 98 | def on_tool_call(self, tool_call: ToolCallEvent): 99 | self.console.print(f"[bold green]{tool_call.text}[/bold green]", end="\n") 100 | 101 | def __exit__(self, *args): 102 | if self.thinking_printer: 103 | self.thinking_printer.__exit__(*args) 104 | self.printer.__exit__(*args) 105 | 106 | 107 | class CLIChatListener(ChatListener): 108 | def __init__(self, markdown: bool): 109 | self.markdown = markdown 110 | self.console = Console() 111 | 112 | def on_chat_start(self): 113 | console = Console(width=80) 114 | console.print(Markdown(TERMINAL_WELCOME)) 115 | 116 | def on_chat_clear(self): 117 | self.console.print("[bold]Cleared the conversation.[/bold]") 118 | 119 | def on_chat_rerun(self, success: bool): 120 | if success: 121 | self.console.print("[bold]Re-running the last message.[/bold]") 122 | else: 123 | self.console.print("[bold]Nothing to re-run.[/bold]") 124 | 125 | def on_error(self, e: Exception): 126 | if isinstance(e, BadRequestError): 127 | self.console.print( 128 | f"[red]Request Error. The last prompt was not saved: {type(e)}: {e}[/red]" 129 | ) 130 | elif isinstance(e, OpenAIError): 131 | self.console.print( 132 | f"[red]API Error. Type `r` or Ctrl-R to try again: {type(e)}: {e}[/red]" 133 | ) 134 | elif isinstance(e, InvalidArgumentError): 135 | self.console.print(f"[red]{e.message}[/red]") 136 | else: 137 | self.console.print(f"[red]Error: {type(e)}: {e}[/red]") 138 | 139 | def response_streamer(self) -> ResponseStreamer: 140 | return CLIResponseStreamer(self.console, self.markdown) 141 | 142 | 143 | class CLIFileHistory(FileHistory): 144 | def append_string(self, string: str) -> None: 145 | if string in ALL_COMMANDS: 146 | return 147 | return super().append_string(string) 148 | 149 | 150 | class CLIUserInputProvider(UserInputProvider): 151 | def __init__(self, history_filename) -> None: 152 | self.prompt_session = PromptSession[str]( 153 | history=CLIFileHistory(history_filename) 154 | ) 155 | 156 | def get_user_input(self) -> str: 157 | while (next_user_input := self._request_input()) == "": 158 | pass 159 | 160 | return next_user_input 161 | 162 | def prompt(self, multiline=False): 163 | bindings = KeyBindings() 164 | 165 | bindings.add("c-a")(named_commands.get_by_name("beginning-of-line")) 166 | bindings.add("c-b")(named_commands.get_by_name("backward-char")) 167 | bindings.add("c-e")(named_commands.get_by_name("end-of-line")) 168 | bindings.add("c-f")(named_commands.get_by_name("forward-char")) 169 | bindings.add("c-left")(named_commands.get_by_name("backward-word")) 170 | bindings.add("c-right")(named_commands.get_by_name("forward-word")) 171 | 172 | @bindings.add("c-c") 173 | def _(event: KeyPressEvent): 174 | if len(event.current_buffer.text) == 0 and not multiline: 175 | event.current_buffer.text = COMMAND_CLEAR[0] 176 | event.current_buffer.cursor_right(len(COMMAND_CLEAR[0])) 177 | else: 178 | event.app.exit(exception=KeyboardInterrupt, style="class:aborting") 179 | 180 | @bindings.add("c-d") 181 | def _(event: KeyPressEvent): 182 | if len(event.current_buffer.text) == 0: 183 | if not multiline: 184 | event.current_buffer.text = COMMAND_QUIT[0] 185 | event.current_buffer.validate_and_handle() 186 | 187 | @bindings.add("c-r") 188 | def _(event: KeyPressEvent): 189 | if len(event.current_buffer.text) == 0: 190 | event.current_buffer.text = COMMAND_RERUN[0] 191 | 
event.current_buffer.validate_and_handle() 192 | 193 | try: 194 | return self.prompt_session.prompt( 195 | "> " if not multiline else "multiline> ", 196 | vi_mode=True, 197 | multiline=multiline, 198 | enable_open_in_editor=True, 199 | key_bindings=bindings, 200 | ) 201 | except KeyboardInterrupt: 202 | return "" 203 | 204 | def _request_input(self): 205 | line = self.prompt() 206 | 207 | if line != "\\": 208 | return line 209 | 210 | return self.prompt(multiline=True) 211 | -------------------------------------------------------------------------------- /gptcli/completion.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import Iterator, List, Literal, TypedDict, Union 3 | 4 | from attr import dataclass 5 | 6 | 7 | class Message(TypedDict): 8 | role: str 9 | content: str 10 | 11 | 12 | class Pricing(TypedDict): 13 | prompt: float 14 | response: float 15 | 16 | 17 | @dataclass 18 | class MessageDeltaEvent: 19 | text: str 20 | type: Literal["message_delta"] = "message_delta" 21 | 22 | 23 | @dataclass 24 | class ThinkingDeltaEvent: 25 | text: str 26 | type: Literal["thinking_delta"] = "thinking_delta" 27 | 28 | 29 | @dataclass 30 | class ToolCallEvent: 31 | text: str 32 | type: Literal["tool_call"] = "tool_call" 33 | 34 | 35 | @dataclass 36 | class UsageEvent: 37 | prompt_tokens: int 38 | completion_tokens: int 39 | total_tokens: int 40 | cost: float 41 | type: Literal["usage"] = "usage" 42 | 43 | @staticmethod 44 | def with_pricing( 45 | prompt_tokens: int, completion_tokens: int, total_tokens: int, pricing: Pricing 46 | ) -> "UsageEvent": 47 | return UsageEvent( 48 | prompt_tokens=prompt_tokens, 49 | completion_tokens=completion_tokens, 50 | total_tokens=total_tokens, 51 | cost=prompt_tokens * pricing["prompt"] 52 | + completion_tokens * pricing["response"], 53 | ) 54 | 55 | 56 | CompletionEvent = Union[ 57 | MessageDeltaEvent, ThinkingDeltaEvent, UsageEvent, ToolCallEvent 58 | ] 59 | 60 | 61 | class CompletionProvider: 62 | @abstractmethod 63 | def complete( 64 | self, messages: List[Message], args: dict, stream: bool = False 65 | ) -> Iterator[CompletionEvent]: 66 | pass 67 | 68 | 69 | class CompletionError(Exception): 70 | pass 71 | 72 | 73 | class BadRequestError(CompletionError): 74 | pass 75 | -------------------------------------------------------------------------------- /gptcli/composite.py: -------------------------------------------------------------------------------- 1 | from gptcli.completion import Message, ToolCallEvent, UsageEvent 2 | from gptcli.session import ChatListener, ResponseStreamer 3 | 4 | 5 | from typing import List, Optional 6 | 7 | 8 | class CompositeResponseStreamer(ResponseStreamer): 9 | def __init__(self, streamers: List[ResponseStreamer]): 10 | self.streamers = streamers 11 | 12 | def __enter__(self): 13 | for streamer in self.streamers: 14 | streamer.__enter__() 15 | return self 16 | 17 | def on_next_token(self, token: str): 18 | for streamer in self.streamers: 19 | streamer.on_next_token(token) 20 | 21 | def on_thinking_token(self, token: str): 22 | for streamer in self.streamers: 23 | streamer.on_thinking_token(token) 24 | 25 | def on_tool_call(self, tool_call: ToolCallEvent): 26 | for streamer in self.streamers: 27 | streamer.on_tool_call(tool_call) 28 | 29 | def __exit__(self, *args): 30 | for streamer in self.streamers: 31 | streamer.__exit__(*args) 32 | 33 | 34 | class CompositeChatListener(ChatListener): 35 | def __init__(self, listeners: List[ChatListener]): 36 | 
self.listeners = listeners 37 | 38 | def on_chat_start(self): 39 | for listener in self.listeners: 40 | listener.on_chat_start() 41 | 42 | def on_chat_clear(self): 43 | for listener in self.listeners: 44 | listener.on_chat_clear() 45 | 46 | def on_chat_rerun(self, success: bool): 47 | for listener in self.listeners: 48 | listener.on_chat_rerun(success) 49 | 50 | def on_error(self, e: Exception): 51 | for listener in self.listeners: 52 | listener.on_error(e) 53 | 54 | def response_streamer(self) -> ResponseStreamer: 55 | return CompositeResponseStreamer( 56 | [listener.response_streamer() for listener in self.listeners] 57 | ) 58 | 59 | def on_chat_message(self, message: Message): 60 | for listener in self.listeners: 61 | listener.on_chat_message(message) 62 | 63 | def on_chat_response( 64 | self, 65 | messages: List[Message], 66 | response: Message, 67 | usage: Optional[UsageEvent], 68 | ): 69 | for listener in self.listeners: 70 | listener.on_chat_response(messages, response, usage) 71 | -------------------------------------------------------------------------------- /gptcli/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Dict, List, Optional 3 | 4 | import yaml 5 | from attr import dataclass 6 | 7 | from gptcli.assistant import AssistantConfig 8 | from gptcli.providers.llama import LLaMAModelConfig 9 | 10 | CONFIG_FILE_PATHS = [ 11 | os.path.join(os.path.expanduser("~"), ".config", "gpt-cli", "gpt.yml"), 12 | os.path.join(os.path.expanduser("~"), ".gptrc"), 13 | ] 14 | 15 | 16 | @dataclass 17 | class GptCliConfig: 18 | default_assistant: str = "general" 19 | markdown: bool = True 20 | show_price: bool = True 21 | api_key: Optional[str] = os.environ.get("OPENAI_API_KEY") 22 | openai_api_key: Optional[str] = os.environ.get("OPENAI_API_KEY") 23 | openai_base_url: Optional[str] = os.environ.get("OPENAI_BASE_URL") 24 | openai_azure_api_version: str = "2024-10-21" 25 | anthropic_api_key: Optional[str] = os.environ.get("ANTHROPIC_API_KEY") 26 | google_api_key: Optional[str] = os.environ.get("GOOGLE_API_KEY") 27 | cohere_api_key: Optional[str] = os.environ.get("COHERE_API_KEY") 28 | log_file: Optional[str] = None 29 | log_level: str = "INFO" 30 | assistants: Dict[str, AssistantConfig] = {} 31 | interactive: Optional[bool] = None 32 | llama_models: Optional[Dict[str, LLaMAModelConfig]] = None 33 | 34 | 35 | def choose_config_file(paths: List[str]) -> str: 36 | for path in paths: 37 | if os.path.isfile(path): 38 | return path 39 | return "" 40 | 41 | 42 | # Custom YAML Loader with !include support 43 | class CustomLoader(yaml.SafeLoader): 44 | pass 45 | 46 | 47 | def include_constructor(loader, node): 48 | # Get the file path from the node 49 | file_path = loader.construct_scalar(node) 50 | # Read and return the content of the included file 51 | with open(file_path, "r") as include_file: 52 | return include_file.read() 53 | 54 | 55 | # Register the !include constructor 56 | CustomLoader.add_constructor("!include", include_constructor) 57 | 58 | 59 | def read_yaml_config(file_path: str) -> GptCliConfig: 60 | with open(file_path, "r") as file: 61 | config = yaml.load(file, Loader=CustomLoader) 62 | return GptCliConfig(**config) 63 | -------------------------------------------------------------------------------- /gptcli/cost.py: -------------------------------------------------------------------------------- 1 | from gptcli.assistant import Assistant 2 | from gptcli.completion import Message, UsageEvent 3 | from 
gptcli.session import ChatListener 4 | 5 | from rich.console import Console 6 | 7 | import logging 8 | from typing import List, Optional 9 | 10 | 11 | class PriceChatListener(ChatListener): 12 | def __init__(self, assistant: Assistant): 13 | self.assistant = assistant 14 | self.current_spend = 0 15 | self.logger = logging.getLogger("gptcli-price") 16 | self.console = Console() 17 | 18 | def on_chat_clear(self): 19 | self.current_spend = 0 20 | 21 | def on_chat_response( 22 | self, 23 | messages: List[Message], 24 | response: Message, 25 | usage: Optional[UsageEvent] = None, 26 | ): 27 | if usage is None: 28 | return 29 | 30 | model = self.assistant._param("model") 31 | num_tokens = usage.total_tokens 32 | cost = usage.cost 33 | 34 | if cost is None: 35 | self.logger.error(f"Cannot get cost information for model {model}") 36 | return 37 | 38 | self.current_spend += cost 39 | self.logger.info(f"Token usage {num_tokens}") 40 | self.logger.info(f"Message price (model: {model}): ${cost:.3f}") 41 | self.logger.info(f"Current spend: ${self.current_spend:.3f}") 42 | self.console.print( 43 | f"Tokens: {num_tokens} | Price: ${cost:.3f} | Total: ${self.current_spend:.3f}", 44 | justify="right", 45 | style="dim", 46 | ) 47 | -------------------------------------------------------------------------------- /gptcli/gpt.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import sys 4 | 5 | MIN_PYTHON = (3, 9) 6 | if sys.version_info < MIN_PYTHON: 7 | sys.exit("Python %s.%s or later is required.\n" % MIN_PYTHON) 8 | 9 | import os 10 | from typing import cast 11 | import openai 12 | import argparse 13 | import sys 14 | import logging 15 | import datetime 16 | import gptcli.providers.anthropic 17 | import gptcli.providers.cohere 18 | import gptcli.providers.google as google 19 | from gptcli.assistant import ( 20 | Assistant, 21 | DEFAULT_ASSISTANTS, 22 | AssistantGlobalArgs, 23 | init_assistant, 24 | ) 25 | from gptcli.cli import ( 26 | CLIChatListener, 27 | CLIUserInputProvider, 28 | ) 29 | from gptcli.composite import CompositeChatListener 30 | from gptcli.config import ( 31 | CONFIG_FILE_PATHS, 32 | GptCliConfig, 33 | choose_config_file, 34 | read_yaml_config, 35 | ) 36 | from gptcli.providers.llama import init_llama_models 37 | from gptcli.logging_utils import LoggingChatListener 38 | from gptcli.cost import PriceChatListener 39 | from gptcli.session import ChatSession 40 | from gptcli.shell import execute, simple_response 41 | 42 | 43 | logger = logging.getLogger("gptcli") 44 | 45 | default_exception_handler = sys.excepthook 46 | 47 | 48 | def exception_handler(type, value, traceback): 49 | logger.exception("Uncaught exception", exc_info=(type, value, traceback)) 50 | print("An uncaught exception occurred. Please report this issue on GitHub.") 51 | default_exception_handler(type, value, traceback) 52 | 53 | 54 | sys.excepthook = exception_handler 55 | 56 | 57 | def parse_args(config: GptCliConfig): 58 | parser = argparse.ArgumentParser( 59 | description="Run a chat session with ChatGPT. See https://github.com/kharvd/gpt-cli for more information." 60 | ) 61 | parser.add_argument( 62 | "assistant_name", 63 | type=str, 64 | default=config.default_assistant, 65 | nargs="?", 66 | choices=list(set([*DEFAULT_ASSISTANTS.keys(), *config.assistants.keys()])), 67 | help="The name of assistant to use. `general` (default) is a generally helpful assistant, `dev` is a software \ 68 | development assistant with shorter responses. 
You can specify your own assistants in the config file \ 69 | ~/.config/gpt-cli/gpt.yml. See the README for more information.", 70 | ) 71 | parser.add_argument( 72 | "--no_markdown", 73 | action="store_false", 74 | dest="markdown", 75 | help="Disable markdown formatting in the chat session.", 76 | default=config.markdown, 77 | ) 78 | parser.add_argument( 79 | "--model", 80 | type=str, 81 | default=None, 82 | help="The model to use for the chat session. Overrides the default model defined for the assistant.", 83 | ) 84 | parser.add_argument( 85 | "--temperature", 86 | type=float, 87 | default=None, 88 | help="The temperature to use for the chat session. Overrides the default temperature defined \ 89 | for the assistant.", 90 | ) 91 | parser.add_argument( 92 | "--top_p", 93 | type=float, 94 | default=None, 95 | help="The top_p to use for the chat session. Overrides the default top_p defined for the assistant.", 96 | ) 97 | parser.add_argument( 98 | "--thinking", 99 | type=int, 100 | dest="thinking_budget", 101 | default=None, 102 | help="Enable Claude's extended thinking mode with the specified token budget. Only applies to Claude 3.7 models.", 103 | ) 104 | parser.add_argument( 105 | "--log_file", 106 | type=str, 107 | default=config.log_file, 108 | help="The file to write logs to. Supports strftime format codes.", 109 | ) 110 | parser.add_argument( 111 | "--log_level", 112 | type=str, 113 | default=config.log_level, 114 | choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], 115 | help="The log level to use", 116 | ) 117 | parser.add_argument( 118 | "--prompt", 119 | "-p", 120 | type=str, 121 | action="append", 122 | default=None, 123 | help="If specified, will not start an interactive chat session and instead will print the response to standard \ 124 | output and exit. May be specified multiple times. Use `-` to read the prompt from standard input. \ 125 | Implies --no_markdown.", 126 | ) 127 | parser.add_argument( 128 | "--execute", 129 | "-e", 130 | type=str, 131 | default=None, 132 | help="If specified, passes the prompt to the assistant and allows the user to edit the produced shell command \ 133 | before executing it. Implies --no_stream. Use `-` to read the prompt from standard input.", 134 | ) 135 | parser.add_argument( 136 | "--no_stream", 137 | action="store_true", 138 | default=False, 139 | help="If specified, will not stream the response to standard output. This is useful if you want to use the \ 140 | response in a script. Ignored when the --prompt option is not specified.", 141 | ) 142 | parser.add_argument( 143 | "--no_price", 144 | action="store_false", 145 | dest="show_price", 146 | help="Disable price logging.", 147 | default=config.show_price, 148 | ) 149 | parser.add_argument( 150 | "--version", 151 | "-v", 152 | action="version", 153 | version=f"gpt-cli v{gptcli.__version__}", 154 | help="Print the version number and exit.", 155 | ) 156 | 157 | return parser.parse_args() 158 | 159 | 160 | def validate_args(args): 161 | if args.prompt is not None and args.execute is not None: 162 | print( 163 | "The --prompt and --execute options are mutually exclusive. Please specify only one of them." 
164 | ) 165 | sys.exit(1) 166 | 167 | 168 | def main(): 169 | config_file_path = choose_config_file(CONFIG_FILE_PATHS) 170 | if config_file_path: 171 | config = read_yaml_config(config_file_path) 172 | else: 173 | config = GptCliConfig() 174 | args = parse_args(config) 175 | 176 | if args.log_file is not None: 177 | filename = datetime.datetime.now().strftime(args.log_file) 178 | logging.basicConfig( 179 | filename=filename, 180 | level=args.log_level, 181 | format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", 182 | ) 183 | # Disable overly verbose logging for markdown_it 184 | logging.getLogger("markdown_it").setLevel(logging.INFO) 185 | 186 | if config.openai_base_url: 187 | openai.base_url = config.openai_base_url 188 | 189 | if config.openai_azure_api_version: 190 | openai.api_version = config.openai_azure_api_version 191 | 192 | if config.api_key: 193 | openai.api_key = config.api_key 194 | elif config.openai_api_key: 195 | openai.api_key = config.openai_api_key 196 | 197 | if config.anthropic_api_key: 198 | gptcli.providers.anthropic.api_key = config.anthropic_api_key 199 | 200 | if config.cohere_api_key: 201 | gptcli.providers.cohere.api_key = config.cohere_api_key 202 | 203 | if config.google_api_key: 204 | google.api_key = config.google_api_key 205 | 206 | if config.llama_models is not None: 207 | init_llama_models(config.llama_models) 208 | 209 | assistant = init_assistant(cast(AssistantGlobalArgs, args), config.assistants) 210 | 211 | if args.prompt is not None: 212 | run_non_interactive(args, assistant) 213 | elif args.execute is not None: 214 | run_execute(args, assistant) 215 | else: 216 | run_interactive(args, assistant) 217 | 218 | 219 | def run_execute(args, assistant): 220 | logger.info( 221 | "Starting a non-interactive execution session with prompt '%s'. Assistant config: %s", 222 | args.prompt, 223 | assistant.config, 224 | ) 225 | if args.execute == "-": 226 | args.execute = "".join(sys.stdin.readlines()) 227 | execute(assistant, args.execute) 228 | 229 | 230 | def run_non_interactive(args, assistant): 231 | logger.info( 232 | "Starting a non-interactive session with prompt '%s'. Assistant config: %s", 233 | args.prompt, 234 | assistant.config, 235 | ) 236 | if "-" in args.prompt: 237 | args.prompt[args.prompt.index("-")] = "".join(sys.stdin.readlines()) 238 | 239 | simple_response(assistant, "\n".join(args.prompt), stream=not args.no_stream) 240 | 241 | 242 | class CLIChatSession(ChatSession): 243 | def __init__( 244 | self, assistant: Assistant, markdown: bool, show_price: bool, stream: bool 245 | ): 246 | listeners = [ 247 | CLIChatListener(markdown), 248 | LoggingChatListener(), 249 | ] 250 | 251 | if show_price: 252 | listeners.append(PriceChatListener(assistant)) 253 | 254 | listener = CompositeChatListener(listeners) 255 | super().__init__(assistant, listener, stream) 256 | 257 | 258 | def run_interactive(args, assistant): 259 | logger.info("Starting a new chat session. 
Assistant config: %s", assistant.config) 260 | session = CLIChatSession( 261 | assistant=assistant, 262 | markdown=args.markdown, 263 | show_price=args.show_price, 264 | stream=not args.no_stream, 265 | ) 266 | history_filename = os.path.expanduser("~/.config/gpt-cli/history") 267 | os.makedirs(os.path.dirname(history_filename), exist_ok=True) 268 | input_provider = CLIUserInputProvider(history_filename=history_filename) 269 | session.loop(input_provider) 270 | 271 | 272 | if __name__ == "__main__": 273 | main() 274 | -------------------------------------------------------------------------------- /gptcli/logging_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from gptcli.completion import Message 3 | from gptcli.session import ChatListener 4 | 5 | 6 | class LoggingChatListener(ChatListener): 7 | def __init__(self): 8 | self.logger = logging.getLogger("gptcli-session") 9 | 10 | def on_chat_start(self): 11 | self.logger.info("Chat started") 12 | 13 | def on_chat_clear(self): 14 | self.logger.info("Cleared the conversation.") 15 | 16 | def on_chat_rerun(self, success: bool): 17 | if success: 18 | self.logger.info("Re-generating the last message.") 19 | 20 | def on_error(self, e: Exception): 21 | self.logger.exception(e) 22 | 23 | def on_chat_message(self, message: Message): 24 | self.logger.info(f"{message['role']}: {message['content']}") 25 | -------------------------------------------------------------------------------- /gptcli/providers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kharvd/gpt-cli/bff5d87311abf00a900fb17595d443953750e29c/gptcli/providers/__init__.py -------------------------------------------------------------------------------- /gptcli/providers/anthropic.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Iterator, List, Optional 3 | import anthropic 4 | 5 | from gptcli.completion import ( 6 | CompletionEvent, 7 | CompletionProvider, 8 | Message, 9 | CompletionError, 10 | BadRequestError, 11 | MessageDeltaEvent, 12 | Pricing, 13 | UsageEvent, 14 | ThinkingDeltaEvent, 15 | ) 16 | 17 | api_key = os.environ.get("ANTHROPIC_API_KEY") 18 | 19 | 20 | def get_client(): 21 | if not api_key: 22 | raise ValueError("ANTHROPIC_API_KEY environment variable not set") 23 | 24 | return anthropic.Anthropic(api_key=api_key) 25 | 26 | 27 | class AnthropicCompletionProvider(CompletionProvider): 28 | def complete( 29 | self, messages: List[Message], args: dict, stream: bool = False 30 | ) -> Iterator[CompletionEvent]: 31 | # Default max tokens and max allowed by Claude API 32 | DEFAULT_MAX_TOKENS = 4096 33 | CLAUDE_MAX_TOKENS_LIMIT = 64000 34 | 35 | # Set initial max_tokens value 36 | max_tokens = DEFAULT_MAX_TOKENS 37 | 38 | # If thinking mode is enabled, adjust max_tokens accordingly 39 | if "thinking_budget" in args and "claude-3-7" in args["model"]: 40 | thinking_budget = args["thinking_budget"] 41 | # Max tokens must be greater than thinking budget 42 | # Calculate required max_tokens, but don't exceed the API limit 43 | response_tokens = min( 44 | DEFAULT_MAX_TOKENS, CLAUDE_MAX_TOKENS_LIMIT - thinking_budget 45 | ) 46 | max_tokens = min(thinking_budget + response_tokens, CLAUDE_MAX_TOKENS_LIMIT) 47 | 48 | kwargs = { 49 | "stop_sequences": [anthropic.HUMAN_PROMPT], 50 | "max_tokens": max_tokens, 51 | "model": args["model"], 52 | } 53 | 54 | # Check if thinking mode is 
enabled 55 | thinking_enabled = "thinking_budget" in args and "claude-3-7" in args["model"] 56 | 57 | # Handle temperature and top_p 58 | if thinking_enabled: 59 | # When thinking is enabled, temperature must be set to 1.0 and top_p must be unset 60 | kwargs["temperature"] = 1.0 61 | # Do not set top_p in this case 62 | else: 63 | # Normal mode - apply user settings 64 | if "temperature" in args: 65 | kwargs["temperature"] = args["temperature"] 66 | if "top_p" in args: 67 | kwargs["top_p"] = args["top_p"] 68 | 69 | # Handle thinking mode 70 | if thinking_enabled: 71 | kwargs["thinking"] = { 72 | "type": "enabled", 73 | "budget_tokens": args["thinking_budget"], 74 | } 75 | 76 | if len(messages) > 0 and messages[0]["role"] == "system": 77 | kwargs["system"] = messages[0]["content"] 78 | messages = messages[1:] 79 | 80 | kwargs["messages"] = messages 81 | 82 | client = get_client() 83 | input_tokens = None 84 | try: 85 | if stream: 86 | with client.messages.stream(**kwargs) as completion: 87 | for event in completion: 88 | if event.type == "content_block_delta": 89 | if event.delta.type == "thinking_delta": 90 | yield ThinkingDeltaEvent(event.delta.thinking) 91 | elif event.delta.type == "text_delta": 92 | yield MessageDeltaEvent(event.delta.text) 93 | # Skip other delta types 94 | if event.type == "message_start": 95 | input_tokens = event.message.usage.input_tokens 96 | if ( 97 | event.type == "message_delta" 98 | and (pricing := claude_pricing(args["model"])) 99 | and input_tokens 100 | ): 101 | yield UsageEvent.with_pricing( 102 | prompt_tokens=input_tokens, 103 | completion_tokens=event.usage.output_tokens, 104 | total_tokens=input_tokens + event.usage.output_tokens, 105 | pricing=pricing, 106 | ) 107 | 108 | else: 109 | response = client.messages.create(**kwargs, stream=False) 110 | yield MessageDeltaEvent( 111 | "".join( 112 | c.text if c.type == "text" else "" for c in response.content 113 | ) 114 | ) 115 | if pricing := claude_pricing(args["model"]): 116 | yield UsageEvent.with_pricing( 117 | prompt_tokens=response.usage.input_tokens, 118 | completion_tokens=response.usage.output_tokens, 119 | total_tokens=response.usage.input_tokens 120 | + response.usage.output_tokens, 121 | pricing=pricing, 122 | ) 123 | except anthropic.BadRequestError as e: 124 | raise BadRequestError(e.message) from e 125 | except anthropic.APIError as e: 126 | raise CompletionError(e.message) from e 127 | 128 | 129 | CLAUDE_PRICE_PER_TOKEN: Pricing = { 130 | "prompt": 11.02 / 1_000_000, 131 | "response": 32.68 / 1_000_000, 132 | } 133 | 134 | CLAUDE_INSTANT_PRICE_PER_TOKEN: Pricing = { 135 | "prompt": 1.63 / 1_000_000, 136 | "response": 5.51 / 1_000_000, 137 | } 138 | 139 | CLAUDE_3_OPUS_PRICING: Pricing = { 140 | "prompt": 15.0 / 1_000_000, 141 | "response": 75.0 / 1_000_000, 142 | } 143 | 144 | CLAUDE_3_SONNET_PRICING: Pricing = { 145 | "prompt": 3.0 / 1_000_000, 146 | "response": 15.0 / 1_000_000, 147 | } 148 | 149 | CLAUDE_3_7_SONNET_PRICING: Pricing = { 150 | "prompt": 3.0 / 1_000_000, 151 | "response": 15.0 / 1_000_000, 152 | } 153 | 154 | CLAUDE_3_HAIKU_PRICING: Pricing = { 155 | "prompt": 0.25 / 1_000_000, 156 | "response": 1.25 / 1_000_000, 157 | } 158 | 159 | 160 | def claude_pricing(model: str) -> Optional[Pricing]: 161 | if "instant" in model: 162 | pricing = CLAUDE_INSTANT_PRICE_PER_TOKEN 163 | elif "claude-3" in model: 164 | if "opus" in model: 165 | pricing = CLAUDE_3_OPUS_PRICING 166 | elif "3-7-sonnet" in model: 167 | pricing = CLAUDE_3_7_SONNET_PRICING 168 | elif "sonnet" in model: 169 | 
pricing = CLAUDE_3_SONNET_PRICING 170 | elif "haiku" in model: 171 | pricing = CLAUDE_3_HAIKU_PRICING 172 | else: 173 | return None 174 | elif "claude-2" in model: 175 | pricing = CLAUDE_PRICE_PER_TOKEN 176 | else: 177 | return None 178 | return pricing 179 | -------------------------------------------------------------------------------- /gptcli/providers/azure_openai.py: -------------------------------------------------------------------------------- 1 | import openai 2 | from openai import AzureOpenAI 3 | from gptcli.providers.openai import OpenAICompletionProvider 4 | 5 | 6 | class AzureOpenAICompletionProvider(OpenAICompletionProvider): 7 | def __init__(self): 8 | super().__init__() 9 | self.client = AzureOpenAI( 10 | api_key=openai.api_key, 11 | base_url=openai.base_url, 12 | api_version=openai.api_version, 13 | ) 14 | -------------------------------------------------------------------------------- /gptcli/providers/cohere.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cohere 3 | from typing import Iterator, List 4 | 5 | from gptcli.completion import ( 6 | CompletionEvent, 7 | CompletionProvider, 8 | Message, 9 | CompletionError, 10 | BadRequestError, 11 | MessageDeltaEvent, 12 | Pricing, 13 | UsageEvent, 14 | ) 15 | 16 | api_key = os.environ.get("COHERE_API_KEY") 17 | 18 | ROLE_MAP = { 19 | "system": "SYSTEM", 20 | "user": "USER", 21 | "assistant": "CHATBOT", 22 | } 23 | 24 | 25 | def map_message(message: Message) -> cohere.Message: 26 | if message["role"] == "system": 27 | return cohere.Message_System(message=message["content"]) 28 | elif message["role"] == "user": 29 | return cohere.Message_User(message=message["content"]) 30 | elif message["role"] == "assistant": 31 | return cohere.Message_Chatbot(message=message["content"]) 32 | else: 33 | raise ValueError(f"Unknown message role: {message['role']}") 34 | 35 | 36 | class CohereCompletionProvider(CompletionProvider): 37 | def __init__(self): 38 | self.client = cohere.Client(api_key=api_key) 39 | 40 | def complete( 41 | self, messages: List[Message], args: dict, stream: bool = False 42 | ) -> Iterator[CompletionEvent]: 43 | kwargs = {} 44 | if "temperature" in args: 45 | kwargs["temperature"] = args["temperature"] 46 | if "top_p" in args: 47 | kwargs["p"] = args["top_p"] 48 | 49 | model = args["model"] 50 | 51 | if messages[0]["role"] == "system": 52 | kwargs["preamble"] = messages[0]["content"] 53 | messages = messages[1:] 54 | 55 | message = messages[-1] 56 | assert message["role"] == "user", "Last message must be user message" 57 | 58 | chat_history = [map_message(m) for m in messages[:-1]] 59 | 60 | try: 61 | if stream: 62 | response_iter = self.client.chat_stream( 63 | chat_history=chat_history, 64 | message=message["content"], 65 | model=model, 66 | **kwargs, 67 | ) 68 | 69 | for response in response_iter: 70 | if response.event_type == "text-generation": 71 | yield MessageDeltaEvent(response.text) 72 | 73 | if ( 74 | response.event_type == "stream-end" 75 | and response.response.meta 76 | and response.response.meta.tokens 77 | and (pricing := COHERE_PRICING.get(args["model"])) 78 | ): 79 | input_tokens = int( 80 | response.response.meta.tokens.input_tokens or 0 81 | ) 82 | output_tokens = int( 83 | response.response.meta.tokens.output_tokens or 0 84 | ) 85 | total_tokens = input_tokens + output_tokens 86 | 87 | yield UsageEvent.with_pricing( 88 | prompt_tokens=input_tokens, 89 | completion_tokens=output_tokens, 90 | total_tokens=total_tokens, 91 | 
pricing=pricing, 92 | ) 93 | 94 | else: 95 | response = self.client.chat( 96 | chat_history=chat_history, 97 | message=message["content"], 98 | model=model, 99 | **kwargs, 100 | ) 101 | yield MessageDeltaEvent(response.text) 102 | 103 | if ( 104 | response.meta 105 | and response.meta.tokens 106 | and (pricing := COHERE_PRICING.get(args["model"])) 107 | ): 108 | input_tokens = int(response.meta.tokens.input_tokens or 0) 109 | output_tokens = int(response.meta.tokens.output_tokens or 0) 110 | total_tokens = input_tokens + output_tokens 111 | 112 | yield UsageEvent.with_pricing( 113 | prompt_tokens=input_tokens, 114 | completion_tokens=output_tokens, 115 | total_tokens=total_tokens, 116 | pricing=pricing, 117 | ) 118 | 119 | except cohere.BadRequestError as e: 120 | raise BadRequestError(e.body) from e 121 | except ( 122 | cohere.TooManyRequestsError, 123 | cohere.InternalServerError, 124 | cohere.core.api_error.ApiError, # type: ignore 125 | ) as e: 126 | raise CompletionError(e.body) from e 127 | 128 | 129 | COHERE_PRICING: dict[str, Pricing] = { 130 | "command-r": { 131 | "prompt": 0.5 / 1_000_000, 132 | "response": 1.5 / 1_000_000, 133 | }, 134 | "command-r-plus": { 135 | "prompt": 3.0 / 1_000_000, 136 | "response": 15.0 / 1_000_000, 137 | }, 138 | } 139 | -------------------------------------------------------------------------------- /gptcli/providers/google.py: -------------------------------------------------------------------------------- 1 | import os 2 | from google import genai 3 | from google.genai import types 4 | 5 | from typing import Iterator, List, Optional 6 | 7 | from gptcli.completion import ( 8 | CompletionEvent, 9 | CompletionProvider, 10 | Message, 11 | MessageDeltaEvent, 12 | Pricing, 13 | UsageEvent, 14 | ) 15 | 16 | ROLE_MAP = { 17 | "user": "user", 18 | "assistant": "model", 19 | } 20 | 21 | 22 | api_key = os.environ.get("GEMINI_API_KEY") 23 | 24 | 25 | class GoogleCompletionProvider(CompletionProvider): 26 | def complete( 27 | self, messages: List[Message], args: dict, stream: bool = False 28 | ) -> Iterator[CompletionEvent]: 29 | client = genai.Client(api_key=api_key) 30 | model = args["model"] 31 | system_instruction = None 32 | if messages[0]["role"] == "system": 33 | system_instruction = messages[0]["content"] 34 | messages = messages[1:] 35 | 36 | contents = [ 37 | types.Content( 38 | role=ROLE_MAP[m["role"]], 39 | parts=[types.Part.from_text(text=m["content"])], 40 | ) 41 | for m in messages 42 | ] 43 | 44 | generate_content_config = types.GenerateContentConfig( 45 | system_instruction=system_instruction, 46 | temperature=args.get("temperature"), 47 | top_p=args.get("top_p"), 48 | thinking_config=( 49 | types.ThinkingConfig( 50 | include_thoughts=True, 51 | thinking_budget=args.get("thinking_budget"), 52 | ) 53 | if args.get("thinking_budget") 54 | else None 55 | ), 56 | response_mime_type="text/plain", 57 | ) 58 | 59 | if stream: 60 | response = client.models.generate_content_stream( 61 | model=model, 62 | contents=list(contents), 63 | config=generate_content_config, 64 | ) 65 | 66 | for chunk in response: 67 | if chunk.usage_metadata: 68 | prompt_tokens = chunk.usage_metadata.prompt_token_count or 0 69 | completion_tokens = chunk.usage_metadata.candidates_token_count or 0 70 | total_tokens = prompt_tokens + completion_tokens 71 | yield MessageDeltaEvent(chunk.text or "") 72 | 73 | else: 74 | response = client.models.generate_content( 75 | model=model, 76 | contents=list(contents), 77 | config=generate_content_config, 78 | ) 79 | yield 
MessageDeltaEvent(response.text or "") 80 | 81 | prompt_tokens = 0 82 | completion_tokens = 0 83 | total_tokens = 0 84 | if response.usage_metadata: 85 | prompt_tokens = response.usage_metadata.prompt_token_count or 0 86 | completion_tokens = response.usage_metadata.candidates_token_count or 0 87 | total_tokens = prompt_tokens + completion_tokens 88 | 89 | pricing = get_gemini_pricing(model, prompt_tokens) 90 | if pricing: 91 | yield UsageEvent.with_pricing( 92 | prompt_tokens=prompt_tokens, 93 | completion_tokens=completion_tokens, 94 | total_tokens=total_tokens, 95 | pricing=pricing, 96 | ) 97 | 98 | 99 | def get_gemini_pricing(model: str, prompt_tokens: int) -> Optional[Pricing]: 100 | if model.startswith("gemini-1.5-flash-8b"): 101 | return { 102 | "prompt": (0.0375 if prompt_tokens < 128000 else 0.075) / 1_000_000, 103 | "response": (0.15 if prompt_tokens < 128000 else 0.30) / 1_000_000, 104 | } 105 | if model.startswith("gemini-1.5-flash"): 106 | return { 107 | "prompt": (0.075 if prompt_tokens < 128000 else 0.15) / 1_000_000, 108 | "response": (0.30 if prompt_tokens < 128000 else 0.60) / 1_000_000, 109 | } 110 | elif model.startswith("gemini-1.5-pro"): 111 | return { 112 | "prompt": (1.25 if prompt_tokens < 128000 else 2.50) / 1_000_000, 113 | "response": (5.0 if prompt_tokens < 128000 else 10.0) / 1_000_000, 114 | } 115 | elif model.startswith("gemini-2.0-flash-lite"): 116 | return { 117 | "prompt": 0.075 / 1_000_000, 118 | "response": 0.30 / 1_000_000, 119 | } 120 | elif model.startswith("gemini-2.0-flash"): 121 | return { 122 | "prompt": 0.10 / 1_000_000, 123 | "response": 0.40 / 1_000_000, 124 | } 125 | elif model.startswith("gemini-2.5-pro"): 126 | return { 127 | "prompt": (1.25 if prompt_tokens < 200000 else 2.50) / 1_000_000, 128 | "response": (10.0 if prompt_tokens < 200000 else 15.0) / 1_000_000, 129 | } 130 | elif model.startswith("gemini-pro"): 131 | return { 132 | "prompt": 0.50 / 1_000_000, 133 | "response": 1.50 / 1_000_000, 134 | } 135 | else: 136 | return None 137 | -------------------------------------------------------------------------------- /gptcli/providers/llama.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from typing import Iterator, List, Optional, TypedDict, cast 4 | 5 | try: 6 | from llama_cpp import Completion, CompletionChunk, Llama 7 | 8 | LLAMA_AVAILABLE = True 9 | except ImportError: 10 | LLAMA_AVAILABLE = False 11 | 12 | from gptcli.completion import ( 13 | CompletionEvent, 14 | CompletionProvider, 15 | Message, 16 | MessageDeltaEvent, 17 | ) 18 | 19 | 20 | class LLaMAModelConfig(TypedDict): 21 | path: str 22 | human_prompt: str 23 | assistant_prompt: str 24 | 25 | 26 | LLAMA_MODELS: Optional[dict[str, LLaMAModelConfig]] = None 27 | 28 | 29 | def init_llama_models(models: dict[str, LLaMAModelConfig]): 30 | if not LLAMA_AVAILABLE: 31 | print( 32 | "Error: To use llama, you need to install gpt-command-line with the llama optional dependency: \ 33 | pip install gpt-command-line[llama]." 
34 | ) 35 | sys.exit(1) 36 | 37 | for name, model_config in models.items(): 38 | if not os.path.isfile(model_config["path"]): 39 | print(f"LLaMA model {name} not found at {model_config['path']}.") 40 | sys.exit(1) 41 | if not name.startswith("llama"): 42 | print(f"LLaMA model names must start with `llama`, but got `{name}`.") 43 | sys.exit(1) 44 | 45 | global LLAMA_MODELS 46 | LLAMA_MODELS = models 47 | 48 | 49 | def role_to_name(role: str, model_config: LLaMAModelConfig) -> str: 50 | if role == "system" or role == "user": 51 | return model_config["human_prompt"] 52 | elif role == "assistant": 53 | return model_config["assistant_prompt"] 54 | else: 55 | raise ValueError(f"Unknown role: {role}") 56 | 57 | 58 | def make_prompt(messages: List[Message], model_config: LLaMAModelConfig) -> str: 59 | prompt = "\n".join( 60 | [ 61 | f"{role_to_name(message['role'], model_config)} {message['content']}" 62 | for message in messages 63 | ] 64 | ) 65 | prompt += f"\n{model_config['assistant_prompt']}" 66 | return prompt 67 | 68 | 69 | class LLaMACompletionProvider(CompletionProvider): 70 | def complete( 71 | self, messages: List[Message], args: dict, stream: bool = False 72 | ) -> Iterator[CompletionEvent]: 73 | assert LLAMA_MODELS, "LLaMA models not initialized" 74 | 75 | model_config = LLAMA_MODELS[args["model"]] 76 | 77 | with suppress_stderr(): 78 | llm = Llama( 79 | model_path=model_config["path"], 80 | n_ctx=2048, 81 | verbose=False, 82 | use_mlock=True, 83 | ) 84 | prompt = make_prompt(messages, model_config) 85 | print(prompt) 86 | 87 | extra_args = {} 88 | if "temperature" in args: 89 | extra_args["temperature"] = args["temperature"] 90 | if "top_p" in args: 91 | extra_args["top_p"] = args["top_p"] 92 | 93 | gen = llm.create_completion( 94 | prompt, 95 | max_tokens=1024, 96 | stop=model_config["human_prompt"], 97 | stream=stream, 98 | echo=False, 99 | **extra_args, 100 | ) 101 | if stream: 102 | for x in cast(Iterator[CompletionChunk], gen): 103 | yield MessageDeltaEvent(x["choices"][0]["text"]) 104 | else: 105 | yield MessageDeltaEvent(cast(Completion, gen)["choices"][0]["text"]) 106 | 107 | 108 | # https://stackoverflow.com/a/50438156 109 | class suppress_stderr(object): 110 | def __enter__(self): 111 | self.errnull_file = open(os.devnull, "w") 112 | self.old_stderr_fileno_undup = sys.stderr.fileno() 113 | self.old_stderr_fileno = os.dup(sys.stderr.fileno()) 114 | self.old_stderr = sys.stderr 115 | os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup) 116 | sys.stderr = self.errnull_file 117 | return self 118 | 119 | def __exit__(self, *_): 120 | sys.stderr = self.old_stderr 121 | os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup) 122 | os.close(self.old_stderr_fileno) 123 | self.errnull_file.close() 124 | -------------------------------------------------------------------------------- /gptcli/providers/openai.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import Iterator, List, Optional, cast 3 | import openai 4 | from openai import OpenAI 5 | from openai.types.responses import ResponseInputParam 6 | 7 | from gptcli.completion import ( 8 | CompletionEvent, 9 | CompletionProvider, 10 | Message, 11 | CompletionError, 12 | BadRequestError, 13 | MessageDeltaEvent, 14 | Pricing, 15 | ThinkingDeltaEvent, 16 | ToolCallEvent, 17 | UsageEvent, 18 | ) 19 | 20 | 21 | def is_reasoning_model(model: str) -> bool: 22 | return model.startswith("o1") or model.startswith("o3") or model.startswith("o4") 23 | 24 | 25 | 
class OpenAICompletionProvider(CompletionProvider): 26 | def __init__(self, base_url: Optional[str] = None, api_key: Optional[str] = None): 27 | self.client = OpenAI( 28 | api_key=api_key or openai.api_key, base_url=base_url or openai.base_url 29 | ) 30 | 31 | def complete( 32 | self, messages: List[Message], args: dict, stream: bool = False 33 | ) -> Iterator[CompletionEvent]: 34 | model = args["model"] 35 | if model.startswith("oai-compat:"): 36 | model = model[len("oai-compat:") :] 37 | 38 | if model.startswith("oai-azure:"): 39 | model = model[len("oai-azure:") :] 40 | 41 | kwargs = {} 42 | is_reasoning = is_reasoning_model(args["model"]) 43 | if "temperature" in args and not is_reasoning: 44 | kwargs["temperature"] = args["temperature"] 45 | if "top_p" in args and not is_reasoning: 46 | kwargs["top_p"] = args["top_p"] 47 | if is_reasoning: 48 | kwargs["reasoning"] = {"effort": "high", "summary": "auto"} 49 | kwargs["tools"] = [ 50 | {"type": "web_search_preview"} 51 | ] # provide reasoning models with search capabilities 52 | 53 | try: 54 | if stream: 55 | response_iter = self.client.responses.create( 56 | model=model, 57 | input=cast(ResponseInputParam, messages), 58 | stream=True, 59 | store=False, 60 | **kwargs, 61 | ) 62 | 63 | for response in response_iter: 64 | if response.type == "response.output_text.delta": 65 | yield MessageDeltaEvent(response.delta) 66 | elif response.type == "response.reasoning_summary_text.delta": 67 | yield ThinkingDeltaEvent(response.delta) 68 | elif response.type == "response.reasoning_summary_part.done": 69 | yield ThinkingDeltaEvent("\n\n") 70 | elif response.type == "response.web_search_call.in_progress": 71 | yield ToolCallEvent("Searching the web...") 72 | elif response.type == "response.completed" and ( 73 | pricing := gpt_pricing(args["model"]) 74 | ): 75 | if response.response.usage: 76 | yield UsageEvent.with_pricing( 77 | prompt_tokens=response.response.usage.input_tokens, 78 | completion_tokens=response.response.usage.output_tokens, 79 | total_tokens=response.response.usage.input_tokens 80 | + response.response.usage.output_tokens, 81 | pricing=pricing, 82 | ) 83 | else: 84 | response = self.client.responses.create( 85 | model=model, 86 | input=cast(ResponseInputParam, messages), 87 | stream=False, 88 | store=False, 89 | **kwargs, 90 | ) 91 | 92 | yield MessageDeltaEvent(response.output_text) 93 | 94 | if response.usage and (pricing := gpt_pricing(args["model"])): 95 | yield UsageEvent.with_pricing( 96 | prompt_tokens=response.usage.input_tokens, 97 | completion_tokens=response.usage.output_tokens, 98 | total_tokens=response.usage.input_tokens 99 | + response.usage.output_tokens, 100 | pricing=pricing, 101 | ) 102 | 103 | except openai.BadRequestError as e: 104 | raise BadRequestError(e.message) from e 105 | except openai.APIError as e: 106 | raise CompletionError(e.message) from e 107 | 108 | 109 | GPT_3_5_TURBO_PRICE_PER_TOKEN: Pricing = { 110 | "prompt": 0.50 / 1_000_000, 111 | "response": 1.50 / 1_000_000, 112 | } 113 | 114 | GPT_3_5_TURBO_16K_PRICE_PER_TOKEN: Pricing = { 115 | "prompt": 0.003 / 1000, 116 | "response": 0.004 / 1000, 117 | } 118 | 119 | GPT_4_PRICE_PER_TOKEN: Pricing = { 120 | "prompt": 30.0 / 1_000_000, 121 | "response": 60.0 / 1_000_000, 122 | } 123 | 124 | GPT_4_TURBO_PRICE_PER_TOKEN: Pricing = { 125 | "prompt": 10.0 / 1_000_000, 126 | "response": 30.0 / 1_000_000, 127 | } 128 | 129 | GPT_4_32K_PRICE_PER_TOKEN: Pricing = { 130 | "prompt": 60.0 / 1_000_000, 131 | "response": 120.0 / 1_000_000, 132 | } 133 | 134 | 
GPT_4_O_2024_05_13_PRICE_PER_TOKEN: Pricing = { 135 | "prompt": 5.0 / 1_000_000, 136 | "response": 15.0 / 1_000_000, 137 | } 138 | 139 | GPT_4_O_2024_08_06_PRICE_PER_TOKEN: Pricing = { 140 | "prompt": 2.50 / 1_000_000, 141 | "response": 10.0 / 1_000_000, 142 | } 143 | 144 | GPT_4_O_MINI_PRICE_PER_TOKEN: Pricing = { 145 | "prompt": 0.150 / 1_000_000, 146 | "response": 0.600 / 1_000_000, 147 | } 148 | 149 | GPT_4_1_PRICE_PER_TOKEN: Pricing = { 150 | "prompt": 2.0 / 1_000_000, 151 | "response": 8.0 / 1_000_000, 152 | } 153 | 154 | GPT_4_1_MINI_PRICE_PER_TOKEN: Pricing = { 155 | "prompt": 0.400 / 1_000_000, 156 | "response": 1.600 / 1_000_000, 157 | } 158 | 159 | GPT_4_1_NANO_PRICE_PER_TOKEN: Pricing = { 160 | "prompt": 0.1 / 1_000_000, 161 | "response": 0.4 / 1_000_000, 162 | } 163 | 164 | GPT_4_5_PRICE_PER_TOKEN: Pricing = { 165 | "prompt": 75.0 / 1_000_000, 166 | "response": 150.0 / 1_000_000, 167 | } 168 | 169 | O_1_PRO_PRICE_PER_TOKEN: Pricing = { 170 | "prompt": 150.0 / 1_000_000, 171 | "response": 600.0 / 1_000_000, 172 | } 173 | 174 | O_1_PRICE_PER_TOKEN: Pricing = { 175 | "prompt": 15.0 / 1_000_000, 176 | "response": 60.0 / 1_000_000, 177 | } 178 | 179 | O_1_PREVIEW_PRICE_PER_TOKEN: Pricing = { 180 | "prompt": 15.0 / 1_000_000, 181 | "response": 60.0 / 1_000_000, 182 | } 183 | 184 | O_1_MINI_PRICE_PER_TOKEN: Pricing = { 185 | "prompt": 3.0 / 1_000_000, 186 | "response": 12.0 / 1_000_000, 187 | } 188 | 189 | O_3_MINI_PRICE_PER_TOKEN: Pricing = { 190 | "prompt": 1.1 / 1_000_000, 191 | "response": 4.4 / 1_000_000, 192 | } 193 | 194 | O_3_PRICE_PER_TOKEN: Pricing = { 195 | "prompt": 10.0 / 1_000_000, 196 | "response": 40.0 / 1_000_000, 197 | } 198 | 199 | O_4_MINI_PRICE_PER_TOKEN: Pricing = { 200 | "prompt": 1.1 / 1_000_000, 201 | "response": 4.4 / 1_000_000, 202 | } 203 | 204 | 205 | def gpt_pricing(model: str) -> Optional[Pricing]: 206 | if model.startswith("gpt-3.5-turbo-16k"): 207 | return GPT_3_5_TURBO_16K_PRICE_PER_TOKEN 208 | elif model.startswith("gpt-3.5-turbo"): 209 | return GPT_3_5_TURBO_PRICE_PER_TOKEN 210 | elif model.startswith("gpt-4-32k"): 211 | return GPT_4_32K_PRICE_PER_TOKEN 212 | elif model.startswith("gpt-4o-mini"): 213 | return GPT_4_O_MINI_PRICE_PER_TOKEN 214 | elif model.startswith("gpt-4o-2024-05-13") or model.startswith("chatgpt-4o-latest"): 215 | return GPT_4_O_2024_05_13_PRICE_PER_TOKEN 216 | elif model.startswith("gpt-4o"): 217 | return GPT_4_O_2024_08_06_PRICE_PER_TOKEN 218 | elif model.startswith("gpt-4.1-mini"): 219 | return GPT_4_1_MINI_PRICE_PER_TOKEN 220 | elif model.startswith("gpt-4.1-nano"): 221 | return GPT_4_1_NANO_PRICE_PER_TOKEN 222 | elif model.startswith("gpt-4.1"): 223 | return GPT_4_1_PRICE_PER_TOKEN 224 | elif model.startswith("gpt-4.5"): 225 | return GPT_4_5_PRICE_PER_TOKEN 226 | elif model.startswith("gpt-4-turbo") or re.match(r"gpt-4-\d\d\d\d-preview", model): 227 | return GPT_4_TURBO_PRICE_PER_TOKEN 228 | elif model.startswith("gpt-4"): 229 | return GPT_4_PRICE_PER_TOKEN 230 | elif model.startswith("o1-pro"): 231 | return O_1_PRO_PRICE_PER_TOKEN 232 | elif model.startswith("o1-preview"): 233 | return O_1_PREVIEW_PRICE_PER_TOKEN 234 | elif model.startswith("o1-mini"): 235 | return O_1_MINI_PRICE_PER_TOKEN 236 | elif model.startswith("o1"): 237 | return O_1_PRICE_PER_TOKEN 238 | elif model.startswith("o3-mini"): 239 | return O_3_MINI_PRICE_PER_TOKEN 240 | elif model.startswith("o3"): 241 | return O_3_PRICE_PER_TOKEN 242 | elif model.startswith("o4-mini"): 243 | return O_4_MINI_PRICE_PER_TOKEN 244 | else: 245 | return None 246 | 
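
Note (editorial, not part of openai.py): the Pricing tables above are per-token dollar rates, and gpt_pricing resolves a model name to one of them by prefix matching. The sketch below shows how those rates turn token counts into an estimated cost; the actual accounting is presumably done by UsageEvent.with_pricing in gptcli/completion.py (not shown here), so the estimate_cost helper is an illustrative assumption only.

# Illustrative sketch only -- estimate_cost is not part of the repository.
from typing import Optional

from gptcli.providers.openai import gpt_pricing


def estimate_cost(model: str, prompt_tokens: int, completion_tokens: int) -> Optional[float]:
    pricing = gpt_pricing(model)  # prefix-matched lookup; None for unknown models
    if pricing is None:
        return None
    # Per-token rates: "prompt" for input tokens, "response" for output tokens.
    return prompt_tokens * pricing["prompt"] + completion_tokens * pricing["response"]


# Example: 1,000 prompt tokens and 500 completion tokens on gpt-4o
# -> 1_000 * 2.50/1M + 500 * 10.0/1M = 0.0075 dollars.
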
-------------------------------------------------------------------------------- /gptcli/session.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from gptcli.assistant import Assistant 3 | from gptcli.completion import ( 4 | Message, 5 | CompletionError, 6 | BadRequestError, 7 | ToolCallEvent, 8 | UsageEvent, 9 | ) 10 | from typing import List, Optional 11 | 12 | 13 | class ResponseStreamer: 14 | def __enter__(self) -> "ResponseStreamer": 15 | return self 16 | 17 | def on_next_token(self, token: str): 18 | pass 19 | 20 | def on_thinking_token(self, token: str): 21 | pass 22 | 23 | def on_tool_call(self, tool_call: ToolCallEvent): 24 | pass 25 | 26 | def __exit__(self, *args): 27 | pass 28 | 29 | 30 | class ChatListener: 31 | def on_chat_start(self): 32 | pass 33 | 34 | def on_chat_clear(self): 35 | pass 36 | 37 | def on_chat_rerun(self, success: bool): 38 | pass 39 | 40 | def on_error(self, error: Exception): 41 | pass 42 | 43 | def response_streamer(self) -> ResponseStreamer: 44 | return ResponseStreamer() 45 | 46 | def on_chat_message(self, message: Message): 47 | pass 48 | 49 | def on_chat_response( 50 | self, 51 | messages: List[Message], 52 | response: Message, 53 | usage: Optional[UsageEvent] = None, 54 | ): 55 | pass 56 | 57 | 58 | class UserInputProvider: 59 | @abstractmethod 60 | def get_user_input(self) -> str: 61 | pass 62 | 63 | 64 | class InvalidArgumentError(Exception): 65 | def __init__(self, message: str): 66 | self.message = message 67 | 68 | 69 | COMMAND_CLEAR = (":clear", ":c") 70 | COMMAND_QUIT = (":quit", ":q") 71 | COMMAND_RERUN = (":rerun", ":r") 72 | COMMAND_HELP = (":help", ":h", ":?") 73 | ALL_COMMANDS = [*COMMAND_CLEAR, *COMMAND_QUIT, *COMMAND_RERUN, *COMMAND_HELP] 74 | COMMANDS_HELP = """ 75 | Commands: 76 | - `:clear` / `:c` / Ctrl+C - Clear the conversation. 77 | - `:quit` / `:q` / Ctrl+D - Quit the program. 78 | - `:rerun` / `:r` / Ctrl+R - Re-run the last message. 79 | - `:help` / `:h` / `:?` - Show this help message. 80 | """ 81 | 82 | 83 | class ChatSession: 84 | def __init__( 85 | self, 86 | assistant: Assistant, 87 | listener: ChatListener, 88 | stream: bool = True, 89 | ): 90 | self.assistant = assistant 91 | self.messages: List[Message] = assistant.init_messages() 92 | self.user_prompts: List[Message] = [] 93 | self.listener = listener 94 | self.stream = stream 95 | 96 | def _clear(self): 97 | self.messages = self.assistant.init_messages() 98 | self.user_prompts = [] 99 | self.listener.on_chat_clear() 100 | 101 | def _rerun(self): 102 | if len(self.user_prompts) == 0: 103 | self.listener.on_chat_rerun(False) 104 | return 105 | 106 | if self.messages[-1]["role"] == "assistant": 107 | self.messages = self.messages[:-1] 108 | 109 | self.listener.on_chat_rerun(True) 110 | self._respond() 111 | 112 | def _respond(self) -> bool: 113 | """ 114 | Respond to the user's input and return whether the assistant's response was saved. 
115 | """ 116 | next_response: str = "" 117 | usage: Optional[UsageEvent] = None 118 | try: 119 | completion_iter = self.assistant.complete_chat( 120 | self.messages, stream=self.stream 121 | ) 122 | 123 | with self.listener.response_streamer() as stream: 124 | for event in completion_iter: 125 | if event.type == "message_delta": 126 | next_response += event.text 127 | stream.on_next_token(event.text) 128 | elif event.type == "thinking_delta": 129 | stream.on_thinking_token(event.text) 130 | elif event.type == "tool_call": 131 | stream.on_tool_call(event) 132 | elif event.type == "usage": 133 | usage = event 134 | 135 | except KeyboardInterrupt: 136 | # If the user interrupts the chat completion, we'll just return what we have so far 137 | pass 138 | except BadRequestError as e: 139 | self.listener.on_error(e) 140 | return False 141 | except CompletionError as e: 142 | self.listener.on_error(e) 143 | return True 144 | 145 | next_message: Message = {"role": "assistant", "content": next_response} 146 | self.listener.on_chat_message(next_message) 147 | self.listener.on_chat_response(self.messages, next_message, usage) 148 | 149 | self.messages = self.messages + [next_message] 150 | return True 151 | 152 | def _add_user_message(self, user_input: str): 153 | user_message: Message = {"role": "user", "content": user_input} 154 | self.messages = self.messages + [user_message] 155 | self.listener.on_chat_message(user_message) 156 | self.user_prompts.append(user_message) 157 | 158 | def _rollback_user_message(self): 159 | self.messages = self.messages[:-1] 160 | self.user_prompts = self.user_prompts[:-1] 161 | 162 | def _print_help(self): 163 | with self.listener.response_streamer() as stream: 164 | stream.on_next_token(COMMANDS_HELP) 165 | 166 | def process_input(self, user_input: str): 167 | """ 168 | Process the user's input and return whether the session should continue. 
169 | """ 170 | if user_input in COMMAND_QUIT: 171 | return False 172 | elif user_input in COMMAND_CLEAR: 173 | self._clear() 174 | return True 175 | elif user_input in COMMAND_RERUN: 176 | self._rerun() 177 | return True 178 | elif user_input in COMMAND_HELP: 179 | self._print_help() 180 | return True 181 | 182 | self._add_user_message(user_input) 183 | response_saved = self._respond() 184 | if not response_saved: 185 | self._rollback_user_message() 186 | 187 | return True 188 | 189 | def loop(self, input_provider: UserInputProvider): 190 | self.listener.on_chat_start() 191 | while self.process_input(input_provider.get_user_input()): 192 | pass 193 | -------------------------------------------------------------------------------- /gptcli/shell.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import sys 4 | import subprocess 5 | import tempfile 6 | from gptcli.assistant import Assistant 7 | 8 | 9 | def simple_response(assistant: Assistant, prompt: str, stream: bool) -> None: 10 | messages = assistant.init_messages() 11 | messages.append({"role": "user", "content": prompt}) 12 | logging.info("User: %s", prompt) 13 | response_iter = assistant.complete_chat(messages, stream=stream) 14 | result = "" 15 | try: 16 | for response in response_iter: 17 | if response.type == "message_delta": 18 | result += response.text 19 | sys.stdout.write(response.text) 20 | except KeyboardInterrupt: 21 | pass 22 | finally: 23 | sys.stdout.flush() 24 | logging.info("Assistant: %s", result) 25 | 26 | 27 | def execute(assistant: Assistant, prompt: str) -> None: 28 | messages = assistant.init_messages() 29 | messages.append({"role": "user", "content": prompt}) 30 | logging.info("User: %s", prompt) 31 | response_iter = assistant.complete_chat(messages, stream=False) 32 | result = next(response_iter) 33 | assert result.type == "message_delta" 34 | result = result.text 35 | logging.info("Assistant: %s", result) 36 | 37 | with tempfile.NamedTemporaryFile(mode="w", prefix="gptcli-", delete=False) as f: 38 | f.write("# Edit the command to execute below. 
Save and exit to execute it.\n") 39 | f.write("# Delete the contents to cancel.\n") 40 | f.write(result) 41 | f.flush() 42 | 43 | editor = os.environ.get("EDITOR", "nano") 44 | subprocess.run([editor, f.name]) 45 | 46 | with open(f.name) as f: 47 | lines = [line for line in f.readlines() if not line.startswith("#")] 48 | command = "".join(lines).strip() 49 | 50 | if command == "": 51 | print("No command to execute.") 52 | return 53 | 54 | shell = os.environ.get("SHELL", "/bin/bash") 55 | 56 | logging.info(f"Executing: {command}") 57 | print(f"Executing:\n{command}") 58 | subprocess.run([shell, f.name]) 59 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "gpt-command-line" 3 | version = "0.4.3" 4 | description = "Command-line interface for ChatGPT and Claude" 5 | authors = [{name = "Val Kharitonov", email = "val@kharvd.com"}] 6 | readme = "README.md" 7 | license = {file = "LICENSE"} 8 | requires-python = ">=3.9,<3.13" 9 | keywords = ["cli", "command-line", "assistant", "openai", "claude", "cohere", "gpt-3", "gpt-4", "llm", "chatgpt", "gpt-cli", "anthropic", "gpt-client", "anthropic-claude"] 10 | classifiers = [ 11 | "Development Status :: 4 - Beta", 12 | "Environment :: Console", 13 | "Intended Audience :: Developers", 14 | "Intended Audience :: End Users/Desktop", 15 | "Intended Audience :: Science/Research", 16 | "License :: OSI Approved :: MIT License", 17 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 18 | ] 19 | dependencies = [ 20 | "anthropic~=0.47.1", 21 | "attrs~=25.1.0", 22 | "black~=25.1.0", 23 | "cohere~=5.13.12", 24 | "google-genai~=1.10.0", 25 | "openai~=1.75.0", 26 | "prompt-toolkit~=3.0.50", 27 | "pytest~=8.3.4", 28 | "PyYAML~=6.0.2", 29 | "rich~=13.9.4", 30 | "typing_extensions~=4.12.2", 31 | ] 32 | 33 | [project.optional-dependencies] 34 | llama = [ 35 | "llama-cpp-python==0.2.74", 36 | ] 37 | 38 | [project.urls] 39 | "Homepage" = "https://github.com/kharvd/gpt-cli" 40 | 41 | [project.scripts] 42 | gpt = "gptcli.gpt:main" 43 | 44 | [build-system] 45 | requires = ["pip>=23.0.0", "setuptools>=58.0.0", "wheel"] 46 | build-backend = "setuptools.build_meta" 47 | 48 | -------------------------------------------------------------------------------- /screenshot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kharvd/gpt-cli/bff5d87311abf00a900fb17595d443953750e29c/screenshot.png -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kharvd/gpt-cli/bff5d87311abf00a900fb17595d443953750e29c/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_assistant.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from gptcli.assistant import AssistantGlobalArgs, init_assistant 3 | 4 | 5 | @pytest.mark.parametrize( 6 | "args,custom_assistants,expected_config", 7 | [ 8 | ( 9 | AssistantGlobalArgs("dev"), 10 | {}, 11 | {}, 12 | ), 13 | ( 14 | AssistantGlobalArgs("dev", model="gpt-4"), 15 | {}, 16 | {"model": "gpt-4"}, 17 | ), 18 | ( 19 | AssistantGlobalArgs("dev", temperature=0.5, top_p=0.5), 20 | {}, 21 | {"temperature": 0.5, "top_p": 0.5}, 22 | ), 23 | ( 24 | 
AssistantGlobalArgs("dev"), 25 | { 26 | "dev": { 27 | "model": "gpt-4", 28 | }, 29 | }, 30 | {"model": "gpt-4"}, 31 | ), 32 | ( 33 | AssistantGlobalArgs("dev", model="gpt-4"), 34 | { 35 | "dev": { 36 | "model": "gpt-3.5-turbo", 37 | }, 38 | }, 39 | {"model": "gpt-4"}, 40 | ), 41 | ( 42 | AssistantGlobalArgs("custom"), 43 | { 44 | "custom": { 45 | "model": "gpt-4", 46 | "temperature": 0.5, 47 | "top_p": 0.5, 48 | "messages": [], 49 | }, 50 | }, 51 | {"model": "gpt-4", "temperature": 0.5, "top_p": 0.5}, 52 | ), 53 | ( 54 | AssistantGlobalArgs( 55 | "custom", model="gpt-3.5-turbo", temperature=1.0, top_p=1.0 56 | ), 57 | { 58 | "custom": { 59 | "model": "gpt-4", 60 | "temperature": 0.5, 61 | "top_p": 0.5, 62 | "messages": [], 63 | }, 64 | }, 65 | {"model": "gpt-3.5-turbo", "temperature": 1.0, "top_p": 1.0}, 66 | ), 67 | ], 68 | ) 69 | def test_init_assistant(args, custom_assistants, expected_config): 70 | assistant = init_assistant(args, custom_assistants) 71 | assert assistant.config.get("model") == expected_config.get("model") 72 | assert assistant.config.get("temperature") == expected_config.get("temperature") 73 | assert assistant.config.get("top_p") == expected_config.get("top_p") 74 | -------------------------------------------------------------------------------- /tests/test_session.py: -------------------------------------------------------------------------------- 1 | from unittest import mock 2 | from gptcli.completion import CompletionError, BadRequestError, MessageDeltaEvent 3 | from gptcli.session import ChatSession 4 | 5 | system_message = {"role": "system", "content": "system message"} 6 | 7 | 8 | def setup_assistant_mock(): 9 | assistant_mock = mock.MagicMock() 10 | assistant_mock.init_messages.return_value = [system_message] 11 | return assistant_mock 12 | 13 | 14 | def setup_listener_mock(): 15 | listener_mock = mock.MagicMock() 16 | response_streamer_mock = mock.MagicMock() 17 | response_streamer_mock.__enter__.return_value = response_streamer_mock 18 | listener_mock.response_streamer.return_value = response_streamer_mock 19 | return listener_mock, response_streamer_mock 20 | 21 | 22 | def setup_session(): 23 | assistant_mock = setup_assistant_mock() 24 | listener_mock, _ = setup_listener_mock() 25 | session = ChatSession(assistant_mock, listener_mock) 26 | return assistant_mock, listener_mock, session 27 | 28 | 29 | def test_simple_input(): 30 | assistant_mock, listener_mock, session = setup_session() 31 | 32 | expected_response = "assistant message" 33 | assistant_mock.complete_chat.return_value = [MessageDeltaEvent(expected_response)] 34 | 35 | user_input = "user message" 36 | should_continue = session.process_input(user_input) 37 | assert should_continue 38 | 39 | user_message = {"role": "user", "content": user_input} 40 | assistant_message = {"role": "assistant", "content": expected_response} 41 | 42 | assistant_mock.complete_chat.assert_called_once_with( 43 | [system_message, user_message], 44 | stream=True, 45 | ) 46 | listener_mock.on_chat_message.assert_has_calls( 47 | [mock.call(user_message), mock.call(assistant_message)] 48 | ) 49 | 50 | 51 | def test_quit(): 52 | _, _, session = setup_session() 53 | should_continue = session.process_input(":q") 54 | assert not should_continue 55 | 56 | 57 | def test_clear(): 58 | assistant_mock, listener_mock, session = setup_session() 59 | 60 | assistant_mock.init_messages.assert_called_once() 61 | assistant_mock.init_messages.reset_mock() 62 | 63 | assistant_mock.complete_chat.return_value = 
[MessageDeltaEvent("assistant_message")] 64 | 65 | should_continue = session.process_input("user_message") 66 | assert should_continue 67 | 68 | assistant_mock.complete_chat.assert_called_once_with( 69 | [system_message, {"role": "user", "content": "user_message"}], 70 | stream=True, 71 | ) 72 | listener_mock.on_chat_message.assert_has_calls( 73 | [ 74 | mock.call({"role": "user", "content": "user_message"}), 75 | mock.call({"role": "assistant", "content": "assistant_message"}), 76 | ] 77 | ) 78 | assistant_mock.complete_chat.reset_mock() 79 | listener_mock.on_chat_message.reset_mock() 80 | 81 | should_continue = session.process_input(":c") 82 | assert should_continue 83 | 84 | assistant_mock.init_messages.assert_called_once() 85 | listener_mock.on_chat_clear.assert_called_once() 86 | assistant_mock.complete_chat.assert_not_called() 87 | 88 | assistant_mock.complete_chat.return_value = [ 89 | MessageDeltaEvent("assistant_message_1") 90 | ] 91 | 92 | should_continue = session.process_input("user_message_1") 93 | assert should_continue 94 | 95 | assistant_mock.complete_chat.assert_called_once_with( 96 | [system_message, {"role": "user", "content": "user_message_1"}], 97 | stream=True, 98 | ) 99 | listener_mock.on_chat_message.assert_has_calls( 100 | [ 101 | mock.call({"role": "user", "content": "user_message_1"}), 102 | mock.call({"role": "assistant", "content": "assistant_message_1"}), 103 | ] 104 | ) 105 | 106 | 107 | def test_rerun(): 108 | assistant_mock, listener_mock, session = setup_session() 109 | 110 | assistant_mock.init_messages.assert_called_once() 111 | assistant_mock.init_messages.reset_mock() 112 | 113 | # Re-run before any input shouldn't do anything 114 | should_continue = session.process_input(":r") 115 | assert should_continue 116 | 117 | assistant_mock.init_messages.assert_not_called() 118 | assistant_mock.complete_chat.assert_not_called() 119 | listener_mock.on_chat_message.assert_not_called() 120 | listener_mock.on_chat_rerun.assert_called_once_with(False) 121 | 122 | listener_mock.on_chat_rerun.reset_mock() 123 | 124 | # Now proper re-run 125 | assistant_mock.complete_chat.return_value = [MessageDeltaEvent("assistant_message")] 126 | 127 | should_continue = session.process_input("user_message") 128 | assert should_continue 129 | 130 | assistant_mock.complete_chat.assert_called_once_with( 131 | [system_message, {"role": "user", "content": "user_message"}], 132 | stream=True, 133 | ) 134 | listener_mock.on_chat_message.assert_has_calls( 135 | [ 136 | mock.call({"role": "user", "content": "user_message"}), 137 | mock.call({"role": "assistant", "content": "assistant_message"}), 138 | ] 139 | ) 140 | assistant_mock.complete_chat.reset_mock() 141 | listener_mock.on_chat_message.reset_mock() 142 | 143 | assistant_mock.complete_chat.return_value = [ 144 | MessageDeltaEvent("assistant_message_1") 145 | ] 146 | 147 | should_continue = session.process_input(":r") 148 | assert should_continue 149 | 150 | listener_mock.on_chat_rerun.assert_called_once_with(True) 151 | 152 | assistant_mock.complete_chat.assert_called_once_with( 153 | [system_message, {"role": "user", "content": "user_message"}], 154 | stream=True, 155 | ) 156 | listener_mock.on_chat_message.assert_has_calls( 157 | [ 158 | mock.call({"role": "assistant", "content": "assistant_message_1"}), 159 | ] 160 | ) 161 | 162 | 163 | def test_invalid_request_error(): 164 | assistant_mock, listener_mock, session = setup_session() 165 | 166 | error = BadRequestError("error message") 167 | assistant_mock.complete_chat.side_effect 
= error 168 | 169 | user_input = "user message" 170 | should_continue = session.process_input(user_input) 171 | assert should_continue 172 | 173 | user_message = {"role": "user", "content": user_input} 174 | listener_mock.on_chat_message.assert_has_calls([mock.call(user_message)]) 175 | listener_mock.on_error.assert_called_once_with(error) 176 | 177 | # Now rerun shouldn't do anything because user input was not saved 178 | assistant_mock.complete_chat.reset_mock() 179 | listener_mock.on_chat_message.reset_mock() 180 | listener_mock.on_error.reset_mock() 181 | 182 | should_continue = session.process_input(":r") 183 | assert should_continue 184 | 185 | assistant_mock.complete_chat.assert_not_called() 186 | listener_mock.on_chat_message.assert_not_called() 187 | listener_mock.on_error.assert_not_called() 188 | listener_mock.on_chat_rerun.assert_called_once_with(False) 189 | 190 | 191 | def test_openai_error(): 192 | assistant_mock, listener_mock, session = setup_session() 193 | 194 | error = CompletionError("error message") 195 | assistant_mock.complete_chat.side_effect = error 196 | 197 | user_input = "user message" 198 | should_continue = session.process_input(user_input) 199 | assert should_continue 200 | 201 | user_message = {"role": "user", "content": user_input} 202 | listener_mock.on_chat_message.assert_has_calls([mock.call(user_message)]) 203 | listener_mock.on_error.assert_called_once_with(error) 204 | 205 | # Re-run should work 206 | assistant_mock.complete_chat.reset_mock() 207 | listener_mock.on_chat_message.reset_mock() 208 | listener_mock.on_error.reset_mock() 209 | 210 | assistant_mock.complete_chat.side_effect = None 211 | assistant_mock.complete_chat.return_value = [MessageDeltaEvent("assistant message")] 212 | 213 | should_continue = session.process_input(":r") 214 | assert should_continue 215 | 216 | assistant_mock.complete_chat.assert_called_once_with( 217 | [system_message, user_message], 218 | stream=True, 219 | ) 220 | listener_mock.on_chat_message.assert_has_calls( 221 | [ 222 | mock.call({"role": "assistant", "content": "assistant message"}), 223 | ] 224 | ) 225 | 226 | 227 | def test_stream(): 228 | assistant_mock, listener_mock, session = setup_session() 229 | assistant_message = "assistant message" 230 | assistant_mock.complete_chat.return_value = ( 231 | MessageDeltaEvent(tok) for tok in list(assistant_message) 232 | ) 233 | 234 | response_streamer_mock = listener_mock.response_streamer.return_value 235 | 236 | session.process_input("user message") 237 | 238 | response_streamer_mock.assert_has_calls( 239 | [mock.call.on_next_token(token) for token in assistant_message] 240 | ) 241 | --------------------------------------------------------------------------------
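
Note (editorial, not part of the repository): the tests above drive ChatSession entirely through mocks. The hedged sketch below shows the same wiring as a standalone reference for exercising the session programmatically; the mock-based assistant and listener mirror tests/test_session.py and are assumptions, not package APIs beyond ChatSession and MessageDeltaEvent.

# Hedged usage sketch; mirrors the setup in tests/test_session.py.
from unittest import mock

from gptcli.completion import MessageDeltaEvent
from gptcli.session import ChatSession

# Stand-in assistant: returns a single message delta for any chat completion.
assistant = mock.MagicMock()
assistant.init_messages.return_value = [{"role": "system", "content": "system message"}]
assistant.complete_chat.return_value = [MessageDeltaEvent("hello from the assistant")]

# Stand-in listener whose response_streamer() works as a context manager.
listener = mock.MagicMock()
streamer = mock.MagicMock()
streamer.__enter__.return_value = streamer
listener.response_streamer.return_value = streamer

session = ChatSession(assistant, listener)
assert session.process_input("hi") is True    # user + assistant messages recorded
assert session.process_input(":q") is False   # quit command ends the loop
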