├── src └── shelloracle │ ├── __init__.py │ ├── py.typed │ ├── __main__.py │ ├── shell │ ├── shelloracle.fish │ ├── shelloracle.bash │ └── shelloracle.zsh │ ├── cli │ ├── config │ │ ├── __init__.py │ │ ├── edit.py │ │ ├── init.py │ │ └── show.py │ ├── application.py │ └── __init__.py │ ├── tty_log_handler.py │ ├── providers │ ├── google.py │ ├── openai.py │ ├── deepseek.py │ ├── openai_compat.py │ ├── localai.py │ ├── xai.py │ ├── ollama.py │ └── __init__.py │ ├── config.py │ ├── shelloracle.py │ └── bootstrap.py ├── .github ├── FUNDING.yml └── workflows │ ├── lint.yml │ ├── tests.yml │ └── release.yml ├── tests ├── conftest.py ├── providers │ ├── test_localai.py │ ├── test_openai.py │ ├── test_xai.py │ ├── conftest.py │ ├── test_deepseek.py │ └── test_ollama.py ├── test_shelloracle.py └── test_config.py ├── CITATION.cff ├── pyproject.toml ├── .gitignore ├── README.md └── LICENSE /src/shelloracle/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/shelloracle/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | github: [djcopley] 2 | -------------------------------------------------------------------------------- /src/shelloracle/__main__.py: -------------------------------------------------------------------------------- 1 | if __name__ == "__main__": 2 | from shelloracle.cli import main 3 | 4 | main() 5 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from shelloracle.cli import Application 4 | 5 | 6 | @pytest.fixture 7 | def global_app(): 8 | return 
Application() 9 | -------------------------------------------------------------------------------- /src/shelloracle/shell/shelloracle.fish: -------------------------------------------------------------------------------- 1 | function __shelloracle__ 2 | set -l output (shor) 3 | if test $status -ne 0 4 | return $status 5 | end 6 | commandline -r -- $output 7 | end 8 | 9 | bind \cf __shelloracle__ 10 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you reference this software, please cite it as below." 3 | authors: 4 | - family-names: "Copley" 5 | given-names: "Daniel" 6 | orcid: "https://orcid.org/0009-0003-9149-0518" 7 | title: "ShellOracle" 8 | url: "https://github.com/djcopley/ShellOracle" -------------------------------------------------------------------------------- /src/shelloracle/cli/config/__init__.py: -------------------------------------------------------------------------------- 1 | import click 2 | 3 | from shelloracle.cli.config.edit import edit 4 | from shelloracle.cli.config.init import init 5 | from shelloracle.cli.config.show import show 6 | 7 | 8 | @click.group() 9 | def config(): ... 
10 | 11 | 12 | config.add_command(edit) 13 | config.add_command(init) 14 | config.add_command(show) 15 | -------------------------------------------------------------------------------- /src/shelloracle/cli/config/edit.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING 4 | 5 | import click 6 | 7 | if TYPE_CHECKING: 8 | from shelloracle.cli.application import Application 9 | 10 | 11 | @click.command() 12 | @click.pass_obj 13 | def edit(app: Application): 14 | """Edit shelloracle configuration.""" 15 | click.edit(filename=app.config_path) 16 | -------------------------------------------------------------------------------- /src/shelloracle/cli/config/init.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING 4 | 5 | import click 6 | 7 | if TYPE_CHECKING: 8 | from shelloracle.cli import Application 9 | 10 | 11 | @click.command() 12 | @click.pass_obj 13 | def init(app: Application): 14 | """Install shelloracle keybindings.""" 15 | # nest this import in a function to avoid expensive module loads 16 | from shelloracle.bootstrap import bootstrap_shelloracle 17 | 18 | bootstrap_shelloracle(app.config_path) 19 | -------------------------------------------------------------------------------- /src/shelloracle/cli/application.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import typing 4 | from pathlib import Path 5 | 6 | if typing.TYPE_CHECKING: 7 | from shelloracle.config import Configuration 8 | 9 | shelloracle_home = Path.home() / ".shelloracle" 10 | shelloracle_home.mkdir(exist_ok=True) 11 | 12 | 13 | class Application: 14 | configuration: Configuration 15 | 16 | def __init__(self): 17 | self.config_path = shelloracle_home / "config.toml" 18 | self.log_path = 
shelloracle_home / "shelloracle.log" 19 | -------------------------------------------------------------------------------- /src/shelloracle/cli/config/show.py: -------------------------------------------------------------------------------- 1 | import click 2 | import pygments 3 | from prompt_toolkit import print_formatted_text 4 | from prompt_toolkit.formatted_text import PygmentsTokens 5 | from pygments.lexers import TOMLLexer 6 | 7 | from shelloracle.cli.application import Application 8 | 9 | 10 | @click.command() 11 | @click.pass_obj 12 | def show(app: Application): 13 | """Display shelloracle configuration.""" 14 | with app.config_path.open("r") as f: 15 | tokens = list(pygments.lex(f.read(), lexer=TOMLLexer())) 16 | print_formatted_text(PygmentsTokens(tokens)) 17 | -------------------------------------------------------------------------------- /src/shelloracle/shell/shelloracle.bash: -------------------------------------------------------------------------------- 1 | __shelloracle__() { 2 | local output 3 | output=$(shor) || return 4 | READLINE_LINE=${output#*$'\t'} 5 | if [[ -z "$READLINE_POINT" ]]; then 6 | echo "$READLINE_LINE" 7 | else 8 | READLINE_POINT=0x7fffffff 9 | fi 10 | } 11 | 12 | if (( BASH_VERSINFO[0] < 4 )); then 13 | bind -m emacs-standard '"\C-f": "\C-e \C-u\C-y\ey\C-u"$(__shelloracle__)"\e\C-e\er"' 14 | bind -m vi-command '"\C-f": "\C-z\C-r\C-z"' 15 | bind -m vi-insert '"\C-f": "\C-z\C-r\C-z"' 16 | else 17 | bind -m emacs-standard -x '"\C-f": __shelloracle__' 18 | bind -m vi-command -x '"\C-f": __shelloracle__' 19 | bind -m vi-insert -x '"\C-f": __shelloracle__' 20 | fi 21 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | on: 3 | push: 4 | pull_request: 5 | permissions: 6 | contents: read 7 | 8 | jobs: 9 | ruff: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: 
actions/checkout@v4 13 | - uses: actions/setup-python@v5 14 | with: 15 | python-version: "3.x" 16 | - uses: pypa/hatch@install 17 | - name: Run ruff linter and formatter 18 | run: hatch fmt --check 19 | mypy: 20 | runs-on: ubuntu-latest 21 | steps: 22 | - uses: actions/checkout@v4 23 | with: 24 | fetch-depth: 0 25 | - uses: actions/setup-python@v5 26 | with: 27 | python-version: "3.x" 28 | - uses: pypa/hatch@install 29 | - name: Run mypy 30 | run: hatch run types:check 31 | -------------------------------------------------------------------------------- /src/shelloracle/tty_log_handler.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import logging 4 | 5 | from prompt_toolkit import print_formatted_text 6 | from prompt_toolkit.application import create_app_session_from_tty 7 | from prompt_toolkit.formatted_text import FormattedText 8 | 9 | 10 | class TtyLogHandler(logging.Handler): 11 | def emit(self, record: logging.LogRecord): 12 | if record.levelno >= logging.ERROR: 13 | color = "ansired" 14 | elif record.levelno >= logging.WARNING: 15 | color = "ansiyellow" 16 | else: 17 | color = "ansiwhite" 18 | log_entry = self.format(record) 19 | formatted_log_entry = FormattedText([(color, f"\n{log_entry}")]) 20 | with create_app_session_from_tty(): 21 | print_formatted_text(formatted_log_entry) 22 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | on: 3 | push: 4 | pull_request: 5 | concurrency: 6 | group: test-${{ github.ref }} 7 | cancel-in-progress: true 8 | permissions: 9 | contents: read 10 | 11 | jobs: 12 | test: 13 | name: Test with ${{ matrix.py }} 14 | runs-on: ubuntu-latest 15 | strategy: 16 | fail-fast: false 17 | matrix: 18 | py: 19 | - "3.13" 20 | - "3.12" 21 | - "3.11" 22 | - "3.10" 23 | - "3.9" 24 | steps: 25 | - 
uses: pypa/hatch@install 26 | - uses: actions/checkout@v4 27 | with: 28 | fetch-depth: 0 29 | - name: Setup python for test ${{ matrix.py }} 30 | uses: actions/setup-python@v5 31 | with: 32 | python-version: ${{ matrix.py }} 33 | - name: Run test suite 34 | run: hatch test -py=${{ matrix.py }} 35 | -------------------------------------------------------------------------------- /src/shelloracle/shell/shelloracle.zsh: -------------------------------------------------------------------------------- 1 | # Define the function shelloracle-widget 2 | shelloracle-widget() { 3 | # Set options and suppress any error messages 4 | setopt localoptions noglobsubst noposixbuiltins pipefail no_aliases 2> /dev/null 5 | 6 | # Run shor (seeded with the text left of the cursor) and capture its output verbatim. 7 | # A plain scalar assignment keeps internal whitespace intact and propagates shor's exit status to $?. 8 | local selected ret 9 | selected="$(SHOR_DEFAULT_PROMPT=${LBUFFER} shor)" 10 | ret=$? 11 | 12 | # Reset the prompt 13 | zle reset-prompt 14 | 15 | # Only replace the buffer on success so a failed or cancelled run keeps what the user typed 16 | if (( ret == 0 )); then 17 | BUFFER=$selected 18 | # Set the CURSOR position at the end of BUFFER 19 | CURSOR=$#BUFFER 20 | fi 21 | 22 | # Return the status 23 | return $ret 24 | } 25 | 26 | # Register the function as a ZLE widget 27 | zle -N shelloracle-widget 28 | 29 | # Install the ZLE widget as a keyboard shortcut Ctrl+F 30 | bindkey '^F' shelloracle-widget -------------------------------------------------------------------------------- /tests/providers/test_localai.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from shelloracle.providers.localai import LocalAI 4 | 5 | 6 | class TestLocalAI: 7 | @pytest.fixture 8 | def localai_config(self): 9 | return { 10 | "shelloracle": {"provider": "LocalAI"}, 11 | "provider": {"LocalAI": {"host": "localhost", "port": 8080, "model": "mistral-openorca"}}, 12 | } 13 | 14 | @pytest.fixture 15 | def localai_instance(self, localai_config): 16 | return 
LocalAI(localai_config) 17 | 18 | def test_name(self): 19 | assert LocalAI.name == "LocalAI" 20 | 21 | def test_model(self, localai_instance): 22 | assert localai_instance.model == "mistral-openorca" 23 | 24 | @pytest.mark.asyncio 25 | async def test_generate(self, mock_asyncopenai, localai_instance): 26 | result = "" 27 | async for response in localai_instance.generate(""): 28 | result += response 29 | assert result == "head -c 100 /dev/urandom | hexdump -C" 30 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | on: 3 | release: 4 | types: [ published ] 5 | permissions: 6 | contents: read 7 | 8 | jobs: 9 | release-build: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v4 13 | - uses: actions/setup-python@v5 14 | with: 15 | python-version: "3.x" 16 | - name: Build release distributions 17 | run: | 18 | pip install build 19 | python -m build 20 | - name: Upload release distributions 21 | uses: actions/upload-artifact@v4 22 | with: 23 | name: release-dists 24 | path: dist/ 25 | pypi-publish: 26 | runs-on: ubuntu-latest 27 | needs: 28 | - release-build 29 | environment: release 30 | permissions: 31 | id-token: write 32 | steps: 33 | - name: Retrieve release distributions 34 | uses: actions/download-artifact@v4 35 | with: 36 | name: release-dists 37 | path: dist/ 38 | - name: Publish release distributions to PyPI 39 | uses: pypa/gh-action-pypi-publish@release/v1 40 | -------------------------------------------------------------------------------- /tests/providers/test_openai.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from shelloracle.config import Configuration 4 | from shelloracle.providers.openai import OpenAI 5 | 6 | 7 | class TestOpenAI: 8 | @pytest.fixture 9 | def openai_config(self): 10 | config = { 11 | "shelloracle": 
{"provider": "OpenAI"}, 12 | "provider": { 13 | "OpenAI": {"api_key": "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", "model": "gpt-3.5-turbo"} 14 | }, 15 | } 16 | return Configuration(config) 17 | 18 | @pytest.fixture 19 | def openai_instance(self, openai_config): 20 | return OpenAI(openai_config) 21 | 22 | def test_name(self): 23 | assert OpenAI.name == "OpenAI" 24 | 25 | def test_api_key(self, openai_instance): 26 | assert openai_instance.api_key == "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" 27 | 28 | def test_model(self, openai_instance): 29 | assert openai_instance.model == "gpt-3.5-turbo" 30 | 31 | @pytest.mark.asyncio 32 | async def test_generate(self, mock_asyncopenai, openai_instance): 33 | result = "" 34 | async for response in openai_instance.generate(""): 35 | result += response 36 | assert result == "head -c 100 /dev/urandom | hexdump -C" 37 | -------------------------------------------------------------------------------- /src/shelloracle/providers/google.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING 4 | 5 | import google.generativeai as genai 6 | 7 | from shelloracle.providers import Provider, ProviderError, Setting, system_prompt 8 | 9 | if TYPE_CHECKING: 10 | from collections.abc import AsyncIterator 11 | 12 | 13 | class Google(Provider): 14 | name = "Google" 15 | 16 | api_key = Setting(default="") 17 | model = Setting(default="gemini-2.0-flash") # Assuming a default model name 18 | 19 | def __init__(self, *args, **kwargs) -> None: 20 | super().__init__(*args, **kwargs) 21 | if not self.api_key: 22 | msg = "No API key provided" 23 | raise ProviderError(msg) 24 | genai.configure(api_key=self.api_key) 25 | self.model_instance = genai.GenerativeModel(self.model, system_instruction=system_prompt) 26 | 27 | async def generate(self, prompt: str) -> AsyncIterator[str]: 28 | try: 29 | response = await 
self.model_instance.generate_content_async( 30 | [prompt], 31 | stream=True, 32 | ) 33 | 34 | async for chunk in response: 35 | yield chunk.text 36 | except Exception as e: 37 | msg = f"Something went wrong while querying Google Gemini: {e}" 38 | raise ProviderError(msg) from e 39 | -------------------------------------------------------------------------------- /tests/providers/test_xai.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from shelloracle.config import Configuration 4 | from shelloracle.providers.xai import XAI 5 | 6 | 7 | class TestOpenAI: 8 | @pytest.fixture 9 | def xai_config(self): 10 | config = { 11 | "shelloracle": {"provider": "XAI"}, 12 | "provider": { 13 | "XAI": { 14 | "api_key": "xai-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", 15 | "model": "grok-beta", 16 | } 17 | }, 18 | } 19 | return Configuration(config) 20 | 21 | @pytest.fixture 22 | def xai_instance(self, xai_config): 23 | return XAI(xai_config) 24 | 25 | def test_name(self): 26 | assert XAI.name == "XAI" 27 | 28 | def test_api_key(self, xai_instance): 29 | assert ( 30 | xai_instance.api_key 31 | == "xai-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" 32 | ) 33 | 34 | def test_model(self, xai_instance): 35 | assert xai_instance.model == "grok-beta" 36 | 37 | @pytest.mark.asyncio 38 | async def test_generate(self, mock_asyncopenai, xai_instance): 39 | result = "" 40 | async for response in xai_instance.generate(""): 41 | result += response 42 | assert result == "head -c 100 /dev/urandom | hexdump -C" 43 | -------------------------------------------------------------------------------- /src/shelloracle/providers/openai.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING 4 | 5 | from openai import APIError, AsyncOpenAI 6 | 7 | from 
shelloracle.providers import Provider, ProviderError, Setting, system_prompt 8 | 9 | if TYPE_CHECKING: 10 | from collections.abc import AsyncIterator 11 | 12 | 13 | class OpenAI(Provider): 14 | name = "OpenAI" 15 | 16 | api_key = Setting(default="") 17 | model = Setting(default="gpt-3.5-turbo") 18 | 19 | def __init__(self, *args, **kwargs) -> None: 20 | super().__init__(*args, **kwargs) 21 | if not self.api_key: 22 | msg = "No API key provided" 23 | raise ProviderError(msg) 24 | self.client = AsyncOpenAI(api_key=self.api_key) 25 | 26 | async def generate(self, prompt: str) -> AsyncIterator[str]: 27 | try: 28 | stream = await self.client.chat.completions.create( 29 | model=self.model, 30 | messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}], 31 | stream=True, 32 | ) 33 | async for chunk in stream: 34 | if chunk.choices[0].delta.content is not None: 35 | yield chunk.choices[0].delta.content 36 | except APIError as e: 37 | msg = f"Something went wrong while querying OpenAI: {e}" 38 | raise ProviderError(msg) from e 39 | -------------------------------------------------------------------------------- /tests/providers/conftest.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import MagicMock 2 | 3 | import pytest 4 | 5 | 6 | def split_with_delimiter(string, delim): 7 | result = [] 8 | last_split = 0 9 | for index, character in enumerate(string): 10 | if character == delim: 11 | result.append(string[last_split : index + 1]) 12 | last_split = index + 1 13 | if last_split != len(string): 14 | result.append(string[last_split:]) 15 | return result 16 | 17 | 18 | @pytest.fixture 19 | def mock_asyncopenai(monkeypatch): 20 | class AsyncChatCompletionIterator: 21 | def __init__(self, answer: str): 22 | self.answer_index = 0 23 | self.answer_deltas = split_with_delimiter(answer, " ") 24 | 25 | def __aiter__(self): 26 | return self 27 | 28 | async def __anext__(self): 29 | if 
self.answer_index >= len(self.answer_deltas): 30 | raise StopAsyncIteration 31 | answer_chunk = self.answer_deltas[self.answer_index] 32 | self.answer_index += 1 33 | choice = MagicMock() 34 | choice.delta.content = answer_chunk 35 | chunk = MagicMock() 36 | chunk.choices = [choice] 37 | return chunk 38 | 39 | async def mock_acreate(*args, **kwargs): 40 | return AsyncChatCompletionIterator("head -c 100 /dev/urandom | hexdump -C") 41 | 42 | monkeypatch.setattr("openai.resources.chat.AsyncCompletions.create", mock_acreate) 43 | -------------------------------------------------------------------------------- /tests/providers/test_deepseek.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from shelloracle.config import Configuration 4 | from shelloracle.providers.deepseek import Deepseek 5 | 6 | 7 | class TestOpenAI: 8 | @pytest.fixture 9 | def deepseek_config(self): 10 | config = { 11 | "shelloracle": {"provider": "Deepseek"}, 12 | "provider": { 13 | "Deepseek": { 14 | "api_key": "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", 15 | "model": "grok-beta", 16 | } 17 | }, 18 | } 19 | return Configuration(config) 20 | 21 | @pytest.fixture 22 | def deepseek_instance(self, deepseek_config): 23 | return Deepseek(deepseek_config) 24 | 25 | def test_name(self): 26 | assert Deepseek.name == "Deepseek" 27 | 28 | def test_api_key(self, deepseek_instance): 29 | assert ( 30 | deepseek_instance.api_key 31 | == "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" 32 | ) 33 | 34 | def test_model(self, deepseek_instance): 35 | assert deepseek_instance.model == "grok-beta" 36 | 37 | @pytest.mark.asyncio 38 | async def test_generate(self, mock_asyncopenai, deepseek_instance): 39 | result = "" 40 | async for response in deepseek_instance.generate(""): 41 | result += response 42 | assert result == "head -c 100 /dev/urandom | hexdump -C" 43 | 
-------------------------------------------------------------------------------- /src/shelloracle/providers/deepseek.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING 4 | 5 | from openai import APIError, AsyncOpenAI 6 | 7 | from shelloracle.providers import Provider, ProviderError, Setting, system_prompt 8 | 9 | if TYPE_CHECKING: 10 | from collections.abc import AsyncIterator 11 | 12 | 13 | class Deepseek(Provider): 14 | name = "Deepseek" 15 | 16 | api_key = Setting(default="") 17 | model = Setting(default="deepseek-chat") 18 | 19 | def __init__(self, *args, **kwargs) -> None: 20 | super().__init__(*args, **kwargs) 21 | if not self.api_key: 22 | msg = "No API key provided" 23 | raise ProviderError(msg) 24 | self.client = AsyncOpenAI(base_url="https://api.deepseek.com/v1", api_key=self.api_key) 25 | 26 | async def generate(self, prompt: str) -> AsyncIterator[str]: 27 | try: 28 | stream = await self.client.chat.completions.create( 29 | model=self.model, 30 | messages=[ 31 | {"role": "system", "content": system_prompt}, 32 | {"role": "user", "content": prompt}, 33 | ], 34 | stream=True, 35 | ) 36 | async for chunk in stream: 37 | if chunk.choices[0].delta.content is not None: 38 | yield chunk.choices[0].delta.content 39 | except APIError as e: 40 | msg = f"Something went wrong while querying Deepseek: {e}" 41 | raise ProviderError(msg) from e 42 | -------------------------------------------------------------------------------- /src/shelloracle/providers/openai_compat.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING 4 | 5 | from openai import APIError, AsyncOpenAI 6 | 7 | from shelloracle.providers import Provider, ProviderError, Setting, system_prompt 8 | 9 | if TYPE_CHECKING: 10 | from collections.abc import AsyncIterator 11 | 12 | 13 | 
class OpenAICompat(Provider): 14 | name = "OpenAICompat" 15 | 16 | base_url = Setting(default="") 17 | api_key = Setting(default="") 18 | model = Setting(default="") 19 | 20 | def __init__(self, *args, **kwargs) -> None: 21 | super().__init__(*args, **kwargs) 22 | if not self.api_key: 23 | msg = "No API key provided. Use a dummy placeholder if no key is required" 24 | raise ProviderError(msg) 25 | self.client = AsyncOpenAI(api_key=self.api_key, base_url=self.base_url) 26 | 27 | async def generate(self, prompt: str) -> AsyncIterator[str]: 28 | try: 29 | stream = await self.client.chat.completions.create( 30 | model=self.model, 31 | messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}], 32 | stream=True, 33 | ) 34 | async for chunk in stream: 35 | if chunk.choices[0].delta.content is not None: 36 | yield chunk.choices[0].delta.content 37 | except APIError as e: 38 | msg = f"Something went wrong while querying OpenAICompat: {e}" 39 | raise ProviderError(msg) from e 40 | -------------------------------------------------------------------------------- /src/shelloracle/providers/localai.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING 4 | 5 | from openai import APIError, AsyncOpenAI 6 | 7 | from shelloracle.providers import Provider, ProviderError, Setting, system_prompt 8 | 9 | if TYPE_CHECKING: 10 | from collections.abc import AsyncIterator 11 | 12 | 13 | class LocalAI(Provider): 14 | name = "LocalAI" 15 | 16 | host = Setting(default="localhost") 17 | port = Setting(default=8080) 18 | model = Setting(default="mistral-openorca") 19 | 20 | @property 21 | def endpoint(self) -> str: 22 | return f"http://{self.host}:{self.port}" 23 | 24 | def __init__(self, *args, **kwargs) -> None: 25 | super().__init__(*args, **kwargs) 26 | # Use a placeholder API key so the client will work 27 | self.client = 
AsyncOpenAI(api_key="sk-xxx", base_url=self.endpoint) 28 | 29 | async def generate(self, prompt: str) -> AsyncIterator[str]: 30 | try: 31 | stream = await self.client.chat.completions.create( 32 | model=self.model, 33 | messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}], 34 | stream=True, 35 | ) 36 | async for chunk in stream: 37 | if chunk.choices[0].delta.content is not None: 38 | yield chunk.choices[0].delta.content 39 | except APIError as e: 40 | msg = f"Something went wrong while querying LocalAI: {e}" 41 | raise ProviderError(msg) from e 42 | -------------------------------------------------------------------------------- /src/shelloracle/providers/xai.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING 4 | 5 | from openai import APIError, AsyncOpenAI 6 | 7 | from shelloracle.providers import Provider, ProviderError, Setting, system_prompt 8 | 9 | if TYPE_CHECKING: 10 | from collections.abc import AsyncIterator 11 | 12 | 13 | class XAI(Provider): 14 | name = "XAI" 15 | 16 | api_key = Setting(default="") 17 | model = Setting(default="grok-beta") 18 | 19 | def __init__(self, *args, **kwargs) -> None: 20 | super().__init__(*args, **kwargs) 21 | if not self.api_key: 22 | msg = "No API key provided" 23 | raise ProviderError(msg) 24 | self.client = AsyncOpenAI( 25 | api_key=self.api_key, 26 | base_url="https://api.x.ai/v1", 27 | ) 28 | 29 | async def generate(self, prompt: str) -> AsyncIterator[str]: 30 | try: 31 | stream = await self.client.chat.completions.create( 32 | model=self.model, 33 | messages=[ 34 | {"role": "system", "content": system_prompt}, 35 | {"role": "user", "content": prompt}, 36 | ], 37 | stream=True, 38 | ) 39 | async for chunk in stream: 40 | if chunk.choices[0].delta.content is not None: 41 | yield chunk.choices[0].delta.content 42 | except APIError as e: 43 | msg = f"Something went wrong 
while querying XAI: {e}" 44 | raise ProviderError(msg) from e 45 | -------------------------------------------------------------------------------- /tests/test_shelloracle.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import os 4 | import sys 5 | from unittest.mock import MagicMock, call 6 | 7 | import pytest 8 | from yaspin.spinners import Spinners 9 | 10 | from shelloracle.shelloracle import get_query_from_pipe, spinner 11 | 12 | 13 | @pytest.fixture 14 | def mock_yaspin(monkeypatch): 15 | mock = MagicMock() 16 | monkeypatch.setattr("shelloracle.shelloracle.yaspin", mock) 17 | return mock 18 | 19 | 20 | @pytest.mark.parametrize(("spinner_style", "expected"), [(None, call()), ("earth", call(Spinners.earth))]) 21 | def test_spinner(spinner_style, expected, mock_yaspin): 22 | spinner(spinner_style) 23 | assert mock_yaspin.call_args == expected 24 | 25 | 26 | def test_spinner_fail(mock_yaspin): 27 | with pytest.raises(AttributeError): 28 | spinner("not a spinner style") 29 | 30 | 31 | @pytest.mark.parametrize( 32 | ("isatty", "readlines", "expected"), [(True, None, None), (False, [], None), (False, ["what is up"], "what is up")] 33 | ) 34 | def test_get_query_from_pipe(isatty, readlines, expected, monkeypatch): 35 | monkeypatch.setattr(os, "isatty", lambda _: isatty) 36 | monkeypatch.setattr(sys.stdin, "readlines", lambda: readlines) 37 | assert get_query_from_pipe() == expected 38 | 39 | 40 | def test_get_query_from_pipe_fail(monkeypatch): 41 | monkeypatch.setattr(os, "isatty", lambda _: False) 42 | monkeypatch.setattr(sys.stdin, "readlines", lambda: ["what is up", "what is down"]) 43 | with pytest.raises(ValueError, match="Multi-line input is not supported"): 44 | get_query_from_pipe() 45 | -------------------------------------------------------------------------------- /tests/providers/test_ollama.py: 
-------------------------------------------------------------------------------- 1 | import pytest 2 | from pytest_httpx import IteratorStream 3 | 4 | from shelloracle.config import Configuration 5 | from shelloracle.providers.ollama import Ollama 6 | 7 | 8 | class TestOllama: 9 | @pytest.fixture 10 | def ollama_config(self): 11 | config = { 12 | "shelloracle": {"provider": "Ollama"}, 13 | "provider": {"Ollama": {"host": "localhost", "port": 11434, "model": "dolphin-mistral"}}, 14 | } 15 | return Configuration(config) 16 | 17 | @pytest.fixture 18 | def ollama_instance(self, ollama_config): 19 | return Ollama(ollama_config) 20 | 21 | def test_name(self): 22 | assert Ollama.name == "Ollama" 23 | 24 | def test_host(self, ollama_instance): 25 | assert ollama_instance.host == "localhost" 26 | 27 | def test_port(self, ollama_instance): 28 | assert ollama_instance.port == 11434 29 | 30 | def test_model(self, ollama_instance): 31 | assert ollama_instance.model == "dolphin-mistral" 32 | 33 | def test_endpoint(self, ollama_instance): 34 | assert ollama_instance.endpoint == "http://localhost:11434/api/generate" 35 | 36 | @pytest.mark.asyncio 37 | async def test_generate(self, ollama_instance, httpx_mock): 38 | responses = [ 39 | b'{"response": "cat"}\n', 40 | b'{"response": " test"}\n', 41 | b'{"response": "."}\n', 42 | b'{"response": "py"}\n', 43 | b'{"response": ""}\n', 44 | ] 45 | httpx_mock.add_response(stream=IteratorStream(responses)) 46 | result = "" 47 | async for response in ollama_instance.generate(""): 48 | result += response 49 | assert result == "cat test.py" 50 | -------------------------------------------------------------------------------- /src/shelloracle/config.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import logging 4 | import sys 5 | from collections.abc import Iterator, Mapping 6 | from typing import TYPE_CHECKING, Any 7 | 8 | from yaspin.spinners import 
SPINNERS_DATA 9 | 10 | if TYPE_CHECKING: 11 | from pathlib import Path 12 | 13 | 14 | if sys.version_info < (3, 11): 15 | import tomli as tomllib 16 | else: 17 | import tomllib 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | class Configuration(Mapping): 23 | def __init__(self, config: dict[str, Any]) -> None: 24 | """ShellOracle application configuration (a read-only mapping over a parsed config dict) 25 | 26 | :param config: configuration dict, e.g. the parsed contents of config.toml 27 | (use :meth:`from_file` to load from disk; it raises FileNotFoundError if the file is missing) 28 | """ 29 | self._config = config 30 | 31 | def __getitem__(self, key: str) -> Any: 32 | return self._config[key] 33 | 34 | def __len__(self) -> int: 35 | return len(self._config) 36 | 37 | def __iter__(self) -> Iterator[Any]: 38 | return iter(self._config) 39 | 40 | def __str__(self) -> str: 41 | return f"Configuration({self._config})" 42 | 43 | def __repr__(self) -> str: 44 | return str(self) 45 | 46 | @property 47 | def raw_config(self) -> dict[str, Any]: 48 | return self._config 49 | 50 | @property 51 | def provider(self) -> str: 52 | return self["shelloracle"]["provider"] 53 | 54 | @property 55 | def spinner_style(self) -> str | None: 56 | style = self["shelloracle"].get("spinner_style", None) 57 | if not style: 58 | return None 59 | if style not in SPINNERS_DATA: 60 | logger.warning("invalid spinner style: %s", style) 61 | return None 62 | return style 63 | 64 | @classmethod 65 | def from_file(cls, filepath: Path) -> Configuration: 66 | with filepath.open("rb") as config_file: 67 | config = tomllib.load(config_file) 68 | return cls(config) 69 | -------------------------------------------------------------------------------- /tests/test_config.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import pytest 4 | import tomli_w 5 | 6 | from shelloracle.config import Configuration 7 | 8 | 9 | class TestConfiguration: 10 | @pytest.fixture 11 | def default_config(self): 12 | return Configuration( 13 | { 14 | "shelloracle": 
{"provider": "Ollama", "spinner_style": "earth"}, 15 | "provider": {"Ollama": {"host": "localhost", "port": 11434, "model": "dolphin-mistral"}}, 16 | } 17 | ) 18 | 19 | def test_from_file(self, default_config, tmp_path): 20 | config_path = tmp_path / "config.toml" 21 | with config_path.open("wb") as f: 22 | tomli_w.dump(default_config.raw_config, f) 23 | assert Configuration.from_file(config_path) == default_config 24 | 25 | def test_getitem(self, default_config): 26 | for key in default_config: 27 | assert default_config[key] == default_config.raw_config[key] 28 | 29 | def test_len(self, default_config): 30 | assert len(default_config) == len(default_config.raw_config) 31 | 32 | def test_iter(self, default_config): 33 | assert list(iter(default_config)) == list(iter(default_config.raw_config)) 34 | 35 | def test_provider(self, default_config): 36 | assert default_config.provider == "Ollama" 37 | 38 | def test_spinner_style(self, default_config): 39 | assert default_config.spinner_style == "earth" 40 | 41 | def test_no_spinner_style(self, caplog): 42 | config = Configuration( 43 | { 44 | "shelloracle": {"provider": "Ollama"}, 45 | "provider": {"Ollama": {"host": "localhost", "port": 11434, "model": "dolphin-mistral"}}, 46 | } 47 | ) 48 | assert config.spinner_style is None 49 | assert "invalid spinner style" not in caplog.text 50 | 51 | def test_invalid_spinner_style(self, caplog): 52 | config = Configuration( 53 | { 54 | "shelloracle": {"provider": "Ollama", "spinner_style": "invalid"}, 55 | "provider": {"Ollama": {"host": "localhost", "port": 11434, "model": "dolphin-mistral"}}, 56 | } 57 | ) 58 | assert config.spinner_style is None 59 | assert "invalid spinner style" in caplog.text 60 | -------------------------------------------------------------------------------- /src/shelloracle/cli/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import asyncio 4 | import logging 5 | 
import os 6 | import sys 7 | from typing import TYPE_CHECKING 8 | 9 | import click 10 | 11 | from shelloracle.cli.application import Application 12 | from shelloracle.cli.config import config 13 | from shelloracle.config import Configuration 14 | from shelloracle.shelloracle import shelloracle 15 | from shelloracle.tty_log_handler import TtyLogHandler 16 | 17 | if TYPE_CHECKING: 18 | from pathlib import Path 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | def configure_logging(log_path: Path): 24 | root_logger = logging.getLogger() 25 | root_logger.setLevel(logging.DEBUG) 26 | 27 | file_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s") 28 | file_handler = logging.FileHandler(log_path) 29 | file_handler.setLevel(logging.DEBUG) 30 | file_handler.setFormatter(file_formatter) 31 | 32 | tty_formatter = logging.Formatter("%(message)s") 33 | tty_handler = TtyLogHandler() 34 | tty_handler.setLevel(logging.WARNING) 35 | tty_handler.setFormatter(tty_formatter) 36 | 37 | root_logger.addHandler(file_handler) 38 | root_logger.addHandler(tty_handler) 39 | 40 | 41 | @click.group(invoke_without_command=True) 42 | @click.version_option() 43 | @click.pass_context 44 | def cli(ctx: click.Context): 45 | """ShellOracle command line interface.""" 46 | app = Application() 47 | configure_logging(app.log_path) 48 | ctx.obj = app 49 | 50 | if ctx.invoked_subcommand is not None: 51 | # If no subcommand is invoked, run the main CLI 52 | return 53 | 54 | try: 55 | app.configuration = Configuration.from_file(app.config_path) 56 | except FileNotFoundError: 57 | logger.warning("Configuration not found. 
Run `shor config init` to initialize.") 58 | sys.exit(1) 59 | 60 | asyncio.run(shelloracle(app)) 61 | 62 | 63 | cli.add_command(config) 64 | 65 | 66 | def main(): 67 | try: 68 | cli() 69 | except Exception: # noqa: BLE001 70 | import sys 71 | 72 | from rich.console import Console 73 | 74 | console = Console(stderr=True) 75 | shor_debug = os.getenv("SHOR_DEBUG") in {"1", "true"} 76 | console.print_exception(suppress=[click, asyncio], show_locals=shor_debug) 77 | sys.exit(1) 78 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling", "hatch-vcs"] 3 | build-backend = "hatchling.build" 4 | 5 | [project] 6 | name = "shelloracle" 7 | description = """ShellOracle is a pluggable terminal utility that takes a natural language description of a \ 8 | command and substitutes it into your terminal buffer.""" 9 | readme = "README.md" 10 | license = { file = "LICENSE" } 11 | dynamic = ["version"] 12 | authors = [ 13 | { name = "Daniel Copley", email = "djcopley@proton.me" }, 14 | ] 15 | requires-python = ">=3.9" 16 | classifiers = [ 17 | "Development Status :: 5 - Production/Stable", 18 | "Environment :: Console", 19 | "Operating System :: MacOS", 20 | "Operating System :: POSIX :: Linux", 21 | "Intended Audience :: Developers", 22 | "Programming Language :: Python :: 3", 23 | "Programming Language :: Python :: 3.9", 24 | "Programming Language :: Python :: 3.10", 25 | "Programming Language :: Python :: 3.11", 26 | "Programming Language :: Python :: 3.12", 27 | "Programming Language :: Python :: 3.13", 28 | "License :: OSI Approved :: GNU General Public License v3 (GPLv3)", 29 | ] 30 | dependencies = [ 31 | "tomli ~= 2.1; python_version < '3.11'", 32 | "click~=8.1", 33 | "dspy~=2.6", 34 | "httpx~=0.28", 35 | "openai~=1.66", 36 | "prompt-toolkit~=3.0", 37 | "yaspin~=3.1", 38 | "tomlkit~=0.13", 39 | 
"google-generativeai~=0.8", 40 | "pygments~=2.19", 41 | "rich>=13.9.4", 42 | ] 43 | 44 | [dependency-groups] 45 | test = [ 46 | "pytest~=8.3", 47 | "pytest-cov~=5.0", 48 | "pytest-sugar~=1.0", 49 | "pytest-xdist~=3.6", 50 | "pytest-asyncio~=0.24", 51 | "pytest-httpx~=0.30", 52 | "tomli-w~=1.2.0", 53 | ] 54 | 55 | [project.scripts] 56 | shor = "shelloracle.cli:main" 57 | 58 | [project.urls] 59 | Homepage = "https://github.com/djcopley/ShellOracle" 60 | Repository = "https://github.com/djcopley/ShellOracle.git" 61 | Issues = "https://github.com/djcopley/ShellOracle/issues" 62 | 63 | [tool.hatch.version] 64 | source = "vcs" 65 | 66 | [tool.hatch.envs.hatch-test] 67 | dependencies = [ 68 | "pytest~=8.3", 69 | "pytest-cov~=5.0", 70 | "pytest-sugar~=1.0", 71 | "pytest-xdist~=3.6", 72 | "pytest-asyncio~=0.24", 73 | "pytest-httpx~=0.30", 74 | "tomli-w>=1.2.0", 75 | ] 76 | 77 | [[tool.hatch.envs.hatch-test.matrix]] 78 | python = ["3.9", "3.10", "3.11", "3.12", "3.13"] 79 | 80 | [tool.hatch.envs.types] 81 | template = "hatch-test" 82 | extra-dependencies = [ 83 | "mypy~=1.0", 84 | "types-Pygments", 85 | ] 86 | 87 | [tool.hatch.envs.types.scripts] 88 | check = [ 89 | "mypy {args:src/shelloracle}", 90 | "mypy --explicit-package-bases tests" 91 | ] 92 | 93 | [tool.hatch.envs.profile.scripts] 94 | importtime = "python -X importtime -m shelloracle 2> {args}" 95 | 96 | [tool.pytest.ini_options] 97 | pythonpath = "src" 98 | addopts = [ 99 | "--import-mode=importlib", 100 | ] 101 | asyncio_default_fixture_loop_scope = "function" 102 | 103 | [tool.coverage.run] 104 | source = ["src/"] 105 | branch = true 106 | parallel = true 107 | 108 | [tool.coverage.report] 109 | exclude_lines = [ 110 | "no cov", 111 | "if __name__ == .__main__.:", 112 | "if TYPE_CHECKING:", 113 | ] 114 | 115 | [tool.ruff.lint.extend-per-file-ignores] 116 | "tests/*" = ["INP001", "ARG"] 117 | -------------------------------------------------------------------------------- /src/shelloracle/shelloracle.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import logging 4 | import os 5 | import sys 6 | from pathlib import Path 7 | from typing import TYPE_CHECKING 8 | 9 | from prompt_toolkit import PromptSession 10 | from prompt_toolkit.application import create_app_session_from_tty 11 | from prompt_toolkit.history import FileHistory 12 | from prompt_toolkit.patch_stdout import patch_stdout 13 | from yaspin import yaspin 14 | from yaspin.spinners import Spinners 15 | 16 | from shelloracle.providers import get_provider 17 | 18 | if TYPE_CHECKING: 19 | from yaspin.core import Yaspin 20 | 21 | from shelloracle.cli import Application 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | async def prompt_user(default_prompt: str | None = None) -> str: 27 | # stdin doesn't exist when running as a zle widget 28 | with create_app_session_from_tty(), patch_stdout(): 29 | history_file = Path.home() / ".shelloracle_history" 30 | prompt_session: PromptSession = PromptSession(history=FileHistory(str(history_file))) 31 | prompt_session.output.write_raw("\033[E") 32 | return await prompt_session.prompt_async("> ", default=default_prompt or "") 33 | 34 | 35 | def get_query_from_pipe() -> str | None: 36 | """Get a query from stdin pipe. 
37 | 38 | :raises ValueError: If the input is more than one line 39 | :return: The query from the stdin pipe 40 | """ 41 | if os.isatty(0) or not (lines := sys.stdin.readlines()): # Return 'None' if fd 0 is a tty (no pipe) 42 | return None 43 | if len(lines) > 1: 44 | msg = "Multi-line input is not supported" 45 | raise ValueError(msg) 46 | logger.debug("using query from stdin: %s", lines) 47 | return lines[0].rstrip() 48 | 49 | 50 | def spinner(style: str | None) -> Yaspin: 51 | """Get the correct spinner based on the user's configuration 52 | 53 | :param style: The spinner style 54 | :returns: yaspin object 55 | """ 56 | if style: 57 | style = getattr(Spinners, style) 58 | return yaspin(style) 59 | return yaspin() 60 | 61 | 62 | async def shelloracle(app: Application) -> None: 63 | """ShellOracle program entrypoint 64 | 65 | If there is a query from the input pipe, it processes the query to generate a response. 66 | If there isn't a query from the input pipe, it prompts the user for input. 67 | 68 | Environment variables: 69 | - SHOR_DEFAULT_PROMPT: This is the initial user prompt that can be configured via this environment variable. 
70 | 71 | :returns: None 72 | :raises KeyboardInterrupt: if the user presses CTRL+C 73 | """ 74 | if not (prompt := get_query_from_pipe()): 75 | default_prompt = os.environ.get("SHOR_DEFAULT_PROMPT") 76 | prompt = await prompt_user(default_prompt) 77 | logger.info("user prompt: %s", prompt) 78 | 79 | provider = get_provider(app.configuration.provider)(app.configuration) 80 | 81 | shell_command = "" 82 | with create_app_session_from_tty(), patch_stdout(raw=True), spinner(app.configuration.spinner_style) as sp: 83 | async for token in provider.generate(prompt): 84 | # some models may erroneously return a newline, which causes issues with the status spinner 85 | shell_command += token.replace("\n", "") 86 | sp.text = shell_command 87 | logger.info("generated shell command: %s", shell_command) 88 | sys.stdout.write(shell_command) 89 | -------------------------------------------------------------------------------- /src/shelloracle/providers/ollama.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import json 4 | from dataclasses import asdict, dataclass 5 | from typing import TYPE_CHECKING, Any 6 | 7 | import httpx 8 | 9 | from shelloracle.providers import Provider, ProviderError, Setting, system_prompt 10 | 11 | if TYPE_CHECKING: 12 | from collections.abc import AsyncIterator 13 | 14 | 15 | def dataclass_to_json(obj: Any) -> dict[str, Any]: 16 | """Convert dataclass to a json dict 17 | 18 | This function filters out 'None' values. 
19 | 20 | :param obj: the dataclass to serialize 21 | :return: serialized dataclass 22 | :raises TypeError: if obj is not a dataclass 23 | """ 24 | return {k: v for k, v in asdict(obj).items() if v is not None} 25 | 26 | 27 | @dataclass 28 | class GenerateRequest: 29 | model: str 30 | """(required) the model name""" 31 | prompt: str | None = None 32 | """the prompt to generate a response for""" 33 | images: list[str] | None = None 34 | """a list of base64-encoded images (for multimodal models such as llava)""" 35 | format: str | None = None 36 | """the format to return a response in. Currently the only accepted value is json""" 37 | options: dict | None = None 38 | """additional model parameters listed in the documentation for the Modelfile such as temperature""" 39 | system: str | None = None 40 | """system prompt to (overrides what is defined in the Modelfile)""" 41 | template: str | None = None 42 | """the full prompt or prompt template (overrides what is defined in the Modelfile)""" 43 | context: str | None = None 44 | """the context parameter returned from a previous request to /generate, this can be used to keep a short 45 | conversational memory""" 46 | stream: bool | None = None 47 | """if false the response will be returned as a single response object, rather than a stream of objects""" 48 | raw: bool | None = None 49 | """if true no formatting will be applied to the prompt and no context will be returned. You may choose to use 50 | the raw parameter if you are specifying a full templated prompt in your request to the API, and are managing 51 | history yourself. 
JSON mode""" 52 | 53 | 54 | class Ollama(Provider): 55 | name = "Ollama" 56 | 57 | host = Setting(default="localhost") 58 | port = Setting(default=11434) 59 | model = Setting(default="dolphin-mistral") 60 | 61 | @property 62 | def endpoint(self) -> str: 63 | # computed property because python descriptors need to be bound to an instance before access 64 | return f"http://{self.host}:{self.port}/api/generate" 65 | 66 | async def generate(self, prompt: str) -> AsyncIterator[str]: 67 | request = GenerateRequest(self.model, prompt, system=system_prompt, stream=True) 68 | data = dataclass_to_json(request) 69 | try: 70 | async with ( 71 | httpx.AsyncClient() as client, 72 | client.stream("POST", self.endpoint, json=data, timeout=20.0) as stream, 73 | ): 74 | async for line in stream.aiter_lines(): 75 | response = json.loads(line) 76 | if "error" in response: 77 | raise ProviderError(response["error"]) 78 | yield response["response"] 79 | except (httpx.HTTPError, httpx.StreamError) as e: 80 | msg = f"Something went wrong while querying Ollama: {e}" 81 | raise ProviderError(msg) from e 82 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # Jetbrains folder 156 | .idea/ 157 | 158 | # Installer dependencies 159 | /installer/installer/**/* 160 | 161 | # Demo scripts 162 | demo/ 163 | 164 | # Experimental scripts 165 | experiments/ 166 | 167 | # Jujutsu VCS 168 | /.jj/ 169 | 170 | # macos 171 | .DS_Store 172 | -------------------------------------------------------------------------------- /src/shelloracle/providers/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import abc 4 | from abc import abstractmethod 5 | from typing import TYPE_CHECKING, Generic, TypeVar 6 | 7 | if TYPE_CHECKING: 8 | from collections.abc import AsyncIterator 9 | 10 | from shelloracle.config import Configuration 11 | 12 | system_prompt = ( 13 | "Based on the following user description, generate a corresponding shell command. Focus solely " 14 | "on interpreting the requirements and translating them into a single, executable Bash command. 
" 15 | "Ensure accuracy and relevance to the user's description. The output should be a valid shell " 16 | "command that directly aligns with the user's intent, ready for execution in a command-line " 17 | "environment. Do not output anything except for the command. No code block, no English explanation, " 18 | "no newlines, and no start/end tags." 19 | ) 20 | 21 | 22 | class ProviderError(Exception): 23 | """LLM providers raise this error to gracefully indicate something has gone wrong.""" 24 | 25 | 26 | class Provider(abc.ABC): 27 | """ 28 | LLM Provider Protocol 29 | 30 | All LLM backends must implement this interface. 31 | """ 32 | 33 | name: str 34 | config: Configuration 35 | 36 | def __init__(self, config: Configuration) -> None: 37 | """Initialize the provider with the given configuration. 38 | 39 | :param config: the configuration object 40 | :return: none 41 | """ 42 | self.config = config 43 | 44 | @abstractmethod 45 | def generate(self, prompt: str) -> AsyncIterator[str]: 46 | """ 47 | This is an asynchronous generator method which defines the protocol that a provider implementation 48 | should adhere to. The method takes a prompt as an argument and produces an asynchronous stream 49 | of string results. 50 | 51 | :param prompt: A string value which serves as input to the provider's process of generating results. 52 | :return: An asynchronous generator yielding string results. 
53 | """ 54 | # If you are wondering why the 'generate' signature doesn't include 'async', see 55 | # https://mypy.readthedocs.io/en/stable/more_types.html#asynchronous-iterators 56 | 57 | 58 | T = TypeVar("T") 59 | 60 | 61 | class Setting(Generic[T]): 62 | def __init__(self, *, name: str | None = None, default: T | None = None) -> None: 63 | self.name = name 64 | self.default = default 65 | 66 | def __set_name__(self, owner: type[Provider], name: str) -> None: 67 | if not self.name: 68 | self.name = name 69 | 70 | def __get__(self, instance: Provider, owner: type[Provider]) -> T: 71 | if instance is None: 72 | # Accessing settings as a class attribute is not supported because it prevents 73 | # inspect.get_members from determining the object type 74 | msg = "Settings must be accessed through a provider instance." 75 | raise AttributeError(msg) 76 | try: 77 | return instance.config["provider"][owner.name][self.name] 78 | except KeyError: 79 | if self.default is None: 80 | raise 81 | return self.default 82 | 83 | 84 | def _providers() -> dict[str, type[Provider]]: 85 | from shelloracle.providers.deepseek import Deepseek 86 | from shelloracle.providers.google import Google 87 | from shelloracle.providers.localai import LocalAI 88 | from shelloracle.providers.ollama import Ollama 89 | from shelloracle.providers.openai import OpenAI 90 | from shelloracle.providers.openai_compat import OpenAICompat 91 | from shelloracle.providers.xai import XAI 92 | 93 | return { 94 | Ollama.name: Ollama, 95 | OpenAI.name: OpenAI, 96 | OpenAICompat.name: OpenAICompat, 97 | LocalAI.name: LocalAI, 98 | XAI.name: XAI, 99 | Deepseek.name: Deepseek, 100 | Google.name: Google, 101 | } 102 | 103 | 104 | def get_provider(name: str) -> type[Provider]: 105 | """Imports and loads a requested provider 106 | 107 | :param name: the provider name 108 | :return: the requested provider 109 | """ 110 | 111 | return _providers()[name] 112 | 113 | 114 | def list_providers() -> list[str]: 115 | return 
list(_providers()) 116 | -------------------------------------------------------------------------------- /src/shelloracle/bootstrap.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import inspect 4 | import shutil 5 | from pathlib import Path 6 | from typing import TYPE_CHECKING, Any 7 | 8 | import tomlkit 9 | from prompt_toolkit import print_formatted_text, prompt 10 | from prompt_toolkit.completion import WordCompleter 11 | from prompt_toolkit.formatted_text import FormattedText 12 | from prompt_toolkit.shortcuts import confirm 13 | 14 | from shelloracle.providers import Provider, Setting, get_provider, list_providers 15 | 16 | if TYPE_CHECKING: 17 | from collections.abc import Iterator, Sequence 18 | 19 | 20 | class UserError(Exception): ... 21 | 22 | 23 | def print_info(info: str) -> None: 24 | print_formatted_text(FormattedText([("ansiblue", info)])) 25 | 26 | 27 | def print_warning(warning: str) -> None: 28 | print_formatted_text(FormattedText([("ansiyellow", warning)])) 29 | 30 | 31 | def print_error(error: str) -> None: 32 | print_formatted_text(FormattedText([("ansired", error)])) 33 | 34 | 35 | def replace_home_with_tilde(path: Path) -> Path: 36 | relative_path = path.relative_to(Path.home()) 37 | return Path("~") / relative_path 38 | 39 | 40 | supported_shells = ("zsh", "bash", "fish") 41 | 42 | 43 | def get_installed_shells() -> list[str]: 44 | return [shell for shell in supported_shells if shutil.which(shell)] 45 | 46 | 47 | def get_bundled_script_path(shell: str) -> Path: 48 | shell_dir = Path(__file__).parent / "shell" 49 | if shell == "zsh": 50 | return shell_dir / "shelloracle.zsh" 51 | if shell == "fish": 52 | return shell_dir / "shelloracle.fish" 53 | return shell_dir / "shelloracle.bash" 54 | 55 | 56 | def get_script_path(shell: str) -> Path: 57 | if shell == "zsh": 58 | return Path.home() / ".shelloracle.zsh" 59 | if shell == "fish": 60 | return Path.home() / 
".shelloracle.fish" 61 | return Path.home() / ".shelloracle.bash" 62 | 63 | 64 | def get_rc_path(shell: str) -> Path: 65 | if shell == "zsh": 66 | return Path.home() / ".zshrc" 67 | if shell == "fish": 68 | return Path.home() / ".config/fish/config.fish" 69 | return Path.home() / ".bashrc" 70 | 71 | 72 | def write_script_home(shell: str) -> None: 73 | shelloracle = get_bundled_script_path(shell).read_bytes() 74 | destination = get_script_path(shell) 75 | destination.write_bytes(shelloracle) 76 | print_info(f"Successfully wrote key bindings to {replace_home_with_tilde(destination)}") 77 | 78 | 79 | def update_rc(shell: str) -> None: 80 | rc_path = get_rc_path(shell) 81 | rc_path.touch(exist_ok=True) 82 | with rc_path.open("r") as file: 83 | rc_content = file.read() 84 | if shell == "fish": 85 | line = f"if test -f {get_script_path(shell)}; source {get_script_path(shell)}; end" 86 | else: 87 | shelloracle_script = get_script_path(shell) 88 | line = f"[ -f {shelloracle_script} ] && source {shelloracle_script}" 89 | if line not in rc_content: 90 | with rc_path.open("a") as file: 91 | file.write("\n") 92 | file.write(line) 93 | print_info(f"Successfully updated {replace_home_with_tilde(rc_path)}") 94 | 95 | 96 | def get_settings(provider: type[Provider]) -> Iterator[tuple[str, Setting]]: 97 | settings = inspect.getmembers(provider, predicate=lambda p: isinstance(p, Setting)) 98 | 99 | def correct_name_setting(): 100 | for name, setting in settings: 101 | yield setting.name or name, setting 102 | 103 | yield from correct_name_setting() 104 | 105 | 106 | def write_shelloracle_config(provider: type[Provider], settings: dict[str, Any], config_path: Path) -> None: 107 | config = tomlkit.document() 108 | 109 | shor_table = tomlkit.table() 110 | shor_table.add("provider", provider.name) 111 | config.add("shelloracle", shor_table) 112 | 113 | provider_table = tomlkit.table() 114 | config.add("provider", provider_table) 115 | 116 | provider_configuration_table = tomlkit.table() 
117 | for setting, value in settings.items(): 118 | provider_configuration_table.add(setting, value) 119 | provider_table.add(provider.name, provider_configuration_table) 120 | 121 | with config_path.open("w") as config_file: 122 | tomlkit.dump(config, config_file) 123 | 124 | 125 | def install_keybindings() -> None: 126 | if not (shells := get_installed_shells()): 127 | print_warning( 128 | "Cannot install keybindings: no compatible shells found. " f"Supported shells: {' '.join(supported_shells)}" 129 | ) 130 | return 131 | if confirm("Enable terminal keybindings and update rc?", suffix=" ([y]/n) ") is False: 132 | return 133 | for shell in shells: 134 | write_script_home(shell) 135 | update_rc(shell) 136 | 137 | 138 | def user_configure_settings(provider: type[Provider]) -> dict[str, Any]: 139 | settings = {} 140 | for name, setting in get_settings(provider): 141 | user_input = prompt(f"{name}: ", default=str(setting.default)) 142 | type_ = type(setting.default) if setting.default else str 143 | value = type_(user_input) # type: ignore[operator] 144 | settings[name] = value 145 | return settings 146 | 147 | 148 | def case_correct_user_input(user_input: str, options: Sequence[str]) -> str | None: 149 | for option in options: 150 | if user_input.lower() == option.lower(): 151 | return option 152 | return None 153 | 154 | 155 | def user_select_provider() -> type[Provider]: 156 | providers = list_providers() 157 | completer = WordCompleter(providers, ignore_case=True) 158 | user_selected_provider = prompt(f"Choose your LLM provider ({', '.join(providers)}): ", completer=completer) 159 | if (provider_name := case_correct_user_input(user_selected_provider, providers)) is None: 160 | msg = f"Invalid provider: {user_selected_provider or 'no input'}" 161 | raise UserError(msg) 162 | return get_provider(provider_name) 163 | 164 | 165 | def bootstrap_shelloracle(config_path: Path) -> None: 166 | try: 167 | provider = user_select_provider() 168 | settings = 
user_configure_settings(provider) 169 | except UserError as e: 170 | print_error(str(e)) 171 | return 172 | except KeyboardInterrupt: 173 | return 174 | write_shelloracle_config(provider, settings, config_path) 175 | install_keybindings() 176 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 |
3 |