├── .github └── workflows │ ├── publish.yml │ └── tests.yml ├── .gitignore ├── License.md ├── README.md ├── docs └── images │ ├── ascii.gif │ ├── python.gif │ └── sh.gif ├── poetry.lock ├── pyproject.toml ├── pytest.ini ├── script └── eval ├── shy_sh ├── agents │ ├── __init__.py │ ├── chains │ │ ├── __init__.py │ │ ├── alternative_commands.py │ │ ├── explain.py │ │ ├── python_expert.py │ │ ├── shell_expert.py │ │ └── shy_agent.py │ ├── llms.py │ ├── misc.py │ ├── shy_agent │ │ ├── __init__.py │ │ ├── agent.py │ │ ├── audio.py │ │ ├── edges │ │ │ ├── __init__.py │ │ │ ├── final_response.py │ │ │ └── tool_calls.py │ │ ├── graph.py │ │ └── nodes │ │ │ ├── __init__.py │ │ │ ├── chatbot.py │ │ │ └── tools_handler.py │ └── tools │ │ ├── __init__.py │ │ ├── python_expert.py │ │ ├── shell.py │ │ ├── shell_expert.py │ │ └── shell_history.py ├── main.py ├── manager │ ├── chat_manager.py │ ├── history.py │ └── sql_models.py ├── models.py ├── settings.py └── utils.py └── tests ├── conftest.py ├── test_cli.py ├── test_eval.py ├── test_eval_experts.py └── utils.py /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | on: 2 | push: 3 | tags: 4 | - '*.*.*' 5 | 6 | name: Publish shy-sh to PyPI 7 | 8 | jobs: 9 | build: 10 | name: Release to PyPI 11 | runs-on: ubuntu-latest 12 | permissions: 13 | contents: write 14 | 15 | steps: 16 | - uses: actions/checkout@v4 17 | 18 | - name: Set up Python 3.11 19 | uses: actions/setup-python@v5 20 | with: 21 | python-version: 3.11 22 | - name: Install Poetry 23 | uses: snok/install-poetry@v1 24 | with: 25 | version: 1.8.4 26 | - name: Set up Poetry 27 | run: poetry config pypi-token.pypi ${{ secrets.PYPY_API_TOKEN }} 28 | - name: Publish package 29 | run: poetry publish --build 30 | - name: Zip dist 31 | run: zip --junk-paths dist.zip dist/* 32 | - name: Create Release 33 | id: create_release 34 | uses: actions/create-release@v1 35 | env: 36 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 37 | with: 38 | tag_name: ${{ github.ref_name }} 39 | release_name: Release ${{ github.ref_name }} 40 | draft: false 41 | prerelease: false 42 | - name: Upload Release Asset 43 | id: upload-release-asset 44 | uses: actions/upload-release-asset@v1 45 | env: 46 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 47 | with: 48 | upload_url: ${{ steps.create_release.outputs.upload_url }} 49 | asset_path: ./dist.zip 50 | asset_name: shy-sh.zip 51 | asset_content_type: application/zip 52 | - uses: actions/checkout@v4 53 | with: 54 | ref: main 55 | - name: Bump version 56 | run: | 57 | git checkout main 58 | git config --global user.name 'github-actions[bot]' 59 | git config --global user.email 'github-actions[bot]@users.noreply.github.com' 60 | poetry version patch 61 | git add pyproject.toml 62 | git commit -m "Bump version to $(poetry version -s)" 63 | git push 64 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | on: 2 | - push 3 | 4 | name: Tests 5 | 6 | jobs: 7 | build: 8 | name: Tests 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v4 12 | - name: Set up Python 3.11 13 | uses: actions/setup-python@v5 14 | with: 15 | python-version: 3.11 16 | - name: Install Poetry 17 | uses: snok/install-poetry@v1 18 | with: 19 | version: 1.8.4 20 | - name: Install dependencies 21 | run: poetry install 22 | - name: Test 23 | run: poetry run pytest 24 | 
-------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | 164 | .DS_Store 165 | shy.yml -------------------------------------------------------------------------------- /License.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Mattia Cecchini 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Shy.sh 2 | 3 | Sh shell AI copilot 4 | 5 | ![image_cover](./docs/images/sh.gif) 6 | 7 | ## Install 8 | 9 | ```sh 10 | pip install shy-sh 11 | ``` 12 | 13 | Configure your LLM 14 | 15 | ```sh 16 | shy --configure 17 | ``` 18 | 19 | Supported providers: openai, anthropic, google, groq, aws, ollama 20 | 21 | ## Help 22 | 23 | Usage: `shy [OPTIONS] [PROMPT]...` 24 | 25 | Arguments 26 | prompt [PROMPT] 27 | 28 | Options 29 | 30 | - -x Do not ask confirmation before executing scripts 31 | - -e Explain the given shell command 32 | - --configure Configure LLM 33 | - --help Show this message and exit. 
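A few more flags are defined in `shy_sh/main.py` but not listed above: `-o` (one shot mode), `-a` (interactive mode with audio input) and `--version` (show version). A quick illustrative example combining flags with a prompt (the prompt text is just a placeholder):

```sh
# run a single task without asking confirmation, then exit
shy -x -o find all TODO comments in this repo
```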
34 | 
35 | ## Settings
36 | 
37 | ```sh
38 | shy --configure
39 | Provider: ollama
40 | Model: llama3.2
41 | Agent Pattern: react
42 | Temperature: 0.0
43 | Language: klingon
44 | Sandbox Mode: Yes
45 | ```
46 | 
47 | #### Configurable settings
48 | 
49 | - Provider: The LLM provider to use [OpenAI, Anthropic, Google, Groq, AWS Bedrock, Ollama(local)].
50 | - API Key: The API key for the LLM provider. (Format for aws bedrock: `region_name access_key secret_key`)
51 | - Model: The LLM model to use.
52 | - Agent Pattern: react or function_call. (If you are not using OpenAI, Anthropic or Google, react is recommended)
53 | - Temperature: The LLM model's temperature setting.
54 | - Language: The language for the LLM's final answers.
55 | - Sandbox Mode: When enabled, no commands or scripts will be executed on your system; you will only receive suggestions. This feature is recommended for beginners.
56 | 
57 | All the settings are saved in `~/.config/shy/config.yml`
58 | 
59 | ## Examples
60 | 
61 | ```sh
62 | > shy find all python files in this folder
63 | 
64 | 🛠️ find . -type f -name '*.py'
65 | 
66 | Do you want to execute this command? [Yes/no/copy/explain/alternatives]:
67 | 
68 | ./src/chat_models.py
69 | ./src/agent/tools.py
70 | ./src/agent/__init__.py
71 | ./src/agent/agent.py
72 | ./src/settings.py
73 | ./src/main.py
74 | 
75 | 🤖: Here are all the Python files found in the current folder and its subfolders.
76 | ```
77 | 
78 | ```sh
79 | > shy -x convert aaa.png to jpeg and resize to 200x200
80 | 
81 | 🛠️ convert aaa.png -resize 200x200 aaa.jpg
82 | 
83 | 🤖: I converted the file aaa.png to JPEG format and resized it to 200x200 pixels.
84 | ```
85 | 
86 | ```sh
87 | > shy resize movie.avi to 1024x768 and save it in mp4
88 | 
89 | 🛠️ ffmpeg -i movie.avi -vf scale=1024:768 -c:v libx264 output.mp4
90 | 
91 | Do you want to execute this command? [Yes/no/copy/explain/alternatives]: c
92 | 
93 | 🤖: Command copied to the clipboard!
94 | ```
95 | 
96 | ```sh
97 | > shy
98 | 
99 | ✨: Hello, how are you?
100 | 
101 | 🤖: Hello! I'm fine thanks
102 | 
103 | ✨: how many files in this folder
104 | 
105 | 🛠️ ls | wc -l
106 | 
107 | Do you want to execute this command? [Yes/no/copy/explain/alternatives]:
108 | 
109 | 5
110 | 
111 | ✨: exit
112 | 
113 | 🤖: 👋 Bye!
114 | ```
115 | 
116 | ```sh
117 | > shy -e "find . -type f -name '*.py' | wc -l"
118 | 
119 | 🤖: This shell command uses `find` to search for files (`-type f`) with the extension `.py` (`-name '*.py'`) in the current directory (`.`) and its subdirectories.
120 | The results are then piped to `wc -l`, which counts the number of lines.
121 | In conclusion, the command presents the total count of Python files (*.py) located within the current directory and its subdirectories.
122 | ```
123 | 
124 | ![image_python](./docs/images/python.gif)
125 | 
126 | ![image_ascii](./docs/images/ascii.gif)
127 | 
128 | ## Chat commands
129 | 
130 | You can use these commands during the chat:
131 | 
132 | - `/chats` to list all the chats
133 | - `/clear` to clear the current chat
134 | - `/history` to list the recently executed commands/scripts
135 | - `/load [CHAT_ID]` to continue a previous chat
136 | 
137 | ## Privacy
138 | 
139 | If you are not using Ollama as provider, please note that information such as the current path, your operating system name, and the last commands executed in the shell may be included in the LLM context.
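## Example config

A minimal sketch of what `~/.config/shy/config.yml` can look like. The keys are inferred from the options above and from how `shy_sh/settings.py` is used in the code (`settings.llm.provider`, `settings.llm.name`, `settings.llm.agent_pattern`, `settings.llm.temperature`, `settings.llm.api_key`, `settings.language`, `settings.sandbox_mode`), so treat the exact schema as illustrative:

```yaml
llm:
  provider: ollama      # openai | anthropic | google | groq | aws | ollama
  name: llama3.2        # the model to use
  agent_pattern: react  # react | function_call
  temperature: 0.0
  api_key: ""           # for aws bedrock: "region_name access_key secret_key"
language: english       # language for the final answers
sandbox_mode: false     # when true, commands are only suggested, never executed
```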
140 | -------------------------------------------------------------------------------- /docs/images/ascii.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mceck/shy-sh/828cb619fdca19533783014b77daaa9973ae5b46/docs/images/ascii.gif -------------------------------------------------------------------------------- /docs/images/python.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mceck/shy-sh/828cb619fdca19533783014b77daaa9973ae5b46/docs/images/python.gif -------------------------------------------------------------------------------- /docs/images/sh.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mceck/shy-sh/828cb619fdca19533783014b77daaa9973ae5b46/docs/images/sh.gif -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "shy_sh" 3 | version = "1.2.3" 4 | description = "Shell copilot - sh shell AI copilot" 5 | authors = ["Mattia Cecchini "] 6 | license = "MIT" 7 | repository = "https://github.com/mceck/shy-sh" 8 | readme = "README.md" 9 | classifiers=[ 10 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 11 | 'Topic :: Terminals', 12 | 'Topic :: Utilities', 13 | 'Programming Language :: Python :: 3', 14 | 'Programming Language :: Python :: 3.10', 15 | 'Programming Language :: Python :: 3.11', 16 | 'Programming Language :: Python :: 3.12', 17 | ] 18 | packages = [{include = "shy_sh"}] 19 | 20 | [tool.poetry.dependencies] 21 | python = "<3.14,>=3.10" 22 | typer = "^0.15.2" 23 | langchain = "^0.3.20" 24 | langgraph = "^0.2.76" 25 | langchain-anthropic = "^0.3.9" 26 | langchain-google-genai = "^2.0.11" 27 | langchain-groq = "^0.2.5" 28 | langchain-ollama = "^0.2.3" 29 | langchain-openai = "^0.3.8" 30 | tiktoken = "^0.9.0" 31 | pyyaml = "^6.0.2" 32 | pydantic-settings = "^2.8.1" 33 | pyperclip = "^1.9.0" 34 | questionary = "^2.1.0" 35 | pyreadline3 = {version = "^3.5.4", platform = "win32"} 36 | speechrecognition = {version = "^3.14.1", optional = true} 37 | langchain-aws = {version = "^0.2.15", optional = true} 38 | tzlocal = "^5.3.1" 39 | 40 | [tool.poetry.extras] 41 | audio = ["speechrecognition"] 42 | aws = ["langchain-aws"] 43 | 44 | [tool.poetry.group.dev.dependencies] 45 | poetry-bumpversion = "^0.3.2" 46 | pytest = "^8.3.3" 47 | pytest-mock = "^3.14.0" 48 | langchain-community = "^0.3.19" 49 | 50 | 51 | [tool.poetry.scripts] 52 | shy = "shy_sh.main:main" 53 | 54 | [build-system] 55 | requires = ["poetry-core"] 56 | build-backend = "poetry.core.masonry.api" 57 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | log_cli=true 3 | log_level=ERROR 4 | markers = 5 | eval: evaluation tests for llm 6 | filterwarnings= 7 | ignore:This process \(pid=[0-9]+\) is multi-threaded, use of forkpty\(\) may lead to deadlocks in the child. 
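8 | # Note: tests marked `eval` are opt-in; script/eval runs them via `poetry run pytest --eval`
9 | # (it is assumed that the custom --eval flag is registered in tests/conftest.py)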
-------------------------------------------------------------------------------- /script/eval: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | poetry install 4 | poetry run pytest --eval -------------------------------------------------------------------------------- /shy_sh/agents/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mceck/shy-sh/828cb619fdca19533783014b77daaa9973ae5b46/shy_sh/agents/__init__.py -------------------------------------------------------------------------------- /shy_sh/agents/chains/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mceck/shy-sh/828cb619fdca19533783014b77daaa9973ae5b46/shy_sh/agents/chains/__init__.py -------------------------------------------------------------------------------- /shy_sh/agents/chains/alternative_commands.py: -------------------------------------------------------------------------------- 1 | import re 2 | from langchain_core.runnables import chain 3 | from langchain_core.output_parsers import StrOutputParser 4 | from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder 5 | from langchain_core.messages import AIMessage 6 | from shy_sh.agents.llms import get_llm 7 | from shy_sh.settings import settings 8 | from shy_sh.utils import detect_shell 9 | from textwrap import dedent 10 | from rich.live import Live 11 | 12 | 13 | msg_template = dedent( 14 | """ 15 | Find some alternative commands using different tools and variations of the original that can be used to accomplish the same task as the given command. 16 | 17 | Generate only one line of code for each alternative command. 18 | Do not suggest alternatives that are not compatible with the current shell or operating system. 19 | Sort the commands by relevance and usefulness, the most relevant and useful commands should be at the top. 20 | Output each command in a separate block of code respecting the following format: 21 | 22 | # Max 1 line description of the command before each command block{lang_spec} 23 | ``` 24 | command 25 | ``` 26 | 27 | Given command: 28 | ``` 29 | {cmd} 30 | ``` 31 | """ 32 | ) 33 | 34 | example_sh = dedent( 35 | """ 36 | # List all files in the current directory in long format 37 | ``` 38 | ls -l 39 | ``` 40 | 41 | # List all files in the current directory in long format including hidden files 42 | ``` 43 | ls -la 44 | ``` 45 | 46 | # Find all files in the current directory and its subdirectories including hidden files 47 | ``` 48 | find . -type f 49 | ``` 50 | 51 | # Find all files in the current directory and print the full path 52 | ``` 53 | find . 
-type f -print
54 |     ```
55 |     """
56 | )
57 | 
58 | example_cmd = dedent(
59 |     """
60 |     # List all files in the current directory in long format including hidden files
61 |     ```
62 |     dir /A /B
63 |     ```
64 | 
65 |     # Find all files in the current directory and its subdirectories
66 |     ```
67 |     dir /S
68 |     ```
69 | 
70 |     # Find all files in the current directory and its subdirectories using tree
71 |     ```
72 |     tree /F
73 |     ```
74 | 
75 |     # Find all files in the current directory and its subdirectories in plain ASCII
76 |     ```
77 |     tree /F /A
78 |     ```
79 |     """
80 | )
81 | 
82 | example_psh = dedent(
83 |     """
84 |     # List all files in the current directory including hidden files
85 |     ```
86 |     Get-ChildItem -Force
87 |     ```
88 | 
89 |     # Find all files in the current directory and its subdirectories
90 |     ```
91 |     Get-ChildItem -Recurse
92 |     ```
93 | 
94 |     # List all file names in the current directory and its subdirectories, including hidden files
95 |     ```
96 |     Get-ChildItem -Recurse -Force -Name
97 |     ```
98 | 
99 |     # List all directories in the current directory
100 |     ```
101 |     Get-ChildItem -Directory
102 |     ```
103 |     """
104 | )
105 | 
106 | 
107 | def get_example():
108 |     shell = detect_shell()
109 |     if shell == "powershell":
110 |         return example_psh
111 |     if shell == "cmd":
112 |         return example_cmd
113 |     return example_sh
114 | 
115 | 
116 | @chain
117 | def alternative_commands_chain(inputs):
118 |     llm = get_llm()
119 |     prompt = ChatPromptTemplate.from_messages(
120 |         [
121 |             (
122 |                 "system",
123 |                 "You are a helpful shell assistant. The current date and time is {timestamp}.\nYou are running on {system} using {shell} as shell",
124 |             ),
125 |             (
126 |                 "human",
127 |                 msg_template.format(cmd="ls", lang_spec=""),
128 |             ),
129 |             AIMessage(content=get_example()),
130 |             MessagesPlaceholder("history"),
131 |             ("human", msg_template),
132 |         ]
133 |     )
134 |     return prompt | llm | StrOutputParser()
135 | 
136 | 
137 | def get_alternative_commands(inputs):
138 |     with Live() as live:
139 |         live.update("⏱️ Finding alternative solutions...")
140 |         response = alternative_commands_chain.invoke(
141 |             {
142 |                 **inputs,
143 |                 "lang_spec": (
144 |                     f" in {settings.language} language" if settings.language else ""
145 |                 ),
146 |             }
147 |         )
148 |         live.update("")
149 |     return re.findall(r"([^\n]+)\n```[^\n]*\n([^\n]+)\n```", response)
150 | 
--------------------------------------------------------------------------------
/shy_sh/agents/chains/explain.py:
--------------------------------------------------------------------------------
1 | import pyperclip
2 | from langchain_core.runnables import chain
3 | from langchain_core.output_parsers import StrOutputParser
4 | from langchain.prompts import ChatPromptTemplate
5 | from shy_sh.agents.llms import get_llm
6 | from shy_sh.utils import ask_confirm
7 | from shy_sh.models import ToolMeta
8 | from shy_sh.settings import settings
9 | from textwrap import dedent
10 | from rich.live import Live
11 | from rich.markdown import Markdown
12 | from rich import print
13 | 
14 | 
15 | msg_template = dedent(
16 |     """
17 |     The given task was {task}.
18 |     Explain this {script_type} and why it should solve the task{lang_spec}.
19 |     Be concise and please limit your explanation to the provided {script_type} and avoid suggesting alternative solutions or directly referencing the given task.
20 |     You can use markdown formatting to enhance the explanation.
21 |     ```
22 |     {script}
23 |     ```
24 |     """
25 | )
26 | 
27 | 
28 | @chain
29 | def explain_chain(_):
30 |     llm = get_llm()
31 |     prompt = ChatPromptTemplate.from_messages(
32 |         [
33 |             (
34 |                 "system",
35 |                 "You are a shell expert. The current date and time is {timestamp}.",
36 |             ),
37 |             ("human", msg_template),
38 |         ]
39 |     )
40 |     return prompt | llm | StrOutputParser()
41 | 
42 | 
43 | def explain(inputs, ask_execute=True, ask_alternative=False):
44 |     with Live(vertical_overflow="visible") as live:
45 |         text = "🤖: "
46 |         for chunk in explain_chain.stream(
47 |             {
48 |                 **inputs,
49 |                 "lang_spec": (
50 |                     f" in {settings.language} language" if settings.language else ""
51 |                 ),
52 |             }
53 |         ):
54 |             text += chunk
55 |             live.update(Markdown(text))
56 |     print()
57 | 
58 |     if not ask_execute:
59 |         return
60 |     print(f"🛠️ [bold green] {inputs['script']} [/bold green]")
61 |     confirm = ask_confirm(explain=False, alternatives=ask_alternative)
62 |     if confirm == "n":
63 |         return "Command canceled by user", ToolMeta(
64 |             stop_execution=True, skip_print=True
65 |         )
66 |     elif confirm == "c":
67 |         pyperclip.copy(inputs["script"])
68 |         return "Script copied to the clipboard!", ToolMeta(stop_execution=True)
69 |     elif confirm == "a":
70 |         return "alternative"
71 | 
--------------------------------------------------------------------------------
/shy_sh/agents/chains/python_expert.py:
--------------------------------------------------------------------------------
1 | from langchain_core.runnables import chain
2 | from langchain_core.output_parsers import StrOutputParser
3 | from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
4 | from shy_sh.agents.llms import get_llm
5 | from textwrap import dedent
6 | 
7 | msg_template = dedent(
8 |     """
9 |     Output only a block of python code like this:
10 |     ```python
11 |     [your python code]
12 |     ```
13 | 
14 |     Write a python script that accomplishes the task. Try to avoid the usage of external libraries if not explicitly requested.
15 |     Task: {input}
16 |     """
17 | )
18 | 
19 | 
20 | @chain
21 | def pyexpert_chain(_):
22 |     llm = get_llm()
23 |     prompt = ChatPromptTemplate.from_messages(
24 |         [
25 |             (
26 |                 "system",
27 |                 "You are a python expert. The current date and time is {timestamp}",
28 |             ),
29 |             MessagesPlaceholder("history"),
30 |             ("human", msg_template),
31 |         ]
32 |     )
33 |     return prompt | llm | StrOutputParser()
34 | 
--------------------------------------------------------------------------------
/shy_sh/agents/chains/shell_expert.py:
--------------------------------------------------------------------------------
1 | from langchain_core.runnables import chain
2 | from langchain_core.output_parsers import StrOutputParser
3 | from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
4 | from shy_sh.agents.llms import get_llm
5 | from textwrap import dedent
6 | 
7 | 
8 | msg_template = dedent(
9 |     """
10 |     Output only a block of code like this:
11 |     ```sh
12 |     #!/bin/sh
13 |     set -e
14 |     [your shell script]
15 |     ```
16 | 
17 |     This is a template for sh, but you should use the right shell syntax depending on your system.
18 |     Don't install new packages if not explicitly requested.
19 |     Write a shell script that accomplishes the task.
20 | 
21 |     Task: {input}
22 |     """
23 | )
24 | 
25 | 
26 | @chain
27 | def shexpert_chain(_):
28 |     llm = get_llm()
29 |     prompt = ChatPromptTemplate.from_messages(
30 |         [
31 |             (
32 |                 "system",
33 |                 "You are a shell expert. The current date and time is {timestamp}\nYou are running on {system} using {shell} as shell",
34 |             ),
35 |             MessagesPlaceholder("history"),
36 |             ("human", msg_template),
37 |         ]
38 |     )
39 |     return prompt | llm | StrOutputParser()
40 | 
--------------------------------------------------------------------------------
/shy_sh/agents/chains/shy_agent.py:
--------------------------------------------------------------------------------
1 | from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
2 | from langchain_core.runnables import chain
3 | from shy_sh.agents.llms import get_llm
4 | from shy_sh.settings import settings
5 | from shy_sh.agents.tools import tools
6 | from textwrap import dedent
7 | 
8 | SYS_TEMPLATES = {
9 |     "function_call": dedent(
10 |         """
11 |         You are a helpful shell assistant. The current date and time is {timestamp}.
12 |         Solve the tasks that I request you to do.
13 | 
14 |         Answer truthfully with the information you have. Output your answer in {lang_spec} language.
15 |         """
16 |     ),
17 |     "react": dedent(
18 |         """
19 |         You are a helpful shell assistant. The current date and time is {timestamp}.
20 |         Solve the tasks that I request you to do.
21 | 
22 |         You can use the following tools to accomplish the tasks:
23 |         {tools_instructions}
24 | 
25 |         Rules:
26 |         You can use only the tools provided in this prompt to accomplish the tasks
27 |         If you need to use tools your response must be in JSON format with this structure: {{ "tool": "...", "arg": "...", "thoughts": "..." }}
28 |         Use the shell and your other tools to gather all the information that you need before starting the actual task and also to double check the results if needed before giving the final answer
29 |         After you have completed the task, output your final answer to the task in {lang_spec} language without including any json
30 |         You can use markdown to format your final answer
31 |         Answer truthfully with the information you have
32 |         You cannot use tools and complete the task with your final answer in the same message so remember to use the tools that you need first
33 |         """
34 |     ),
35 | }
36 | 
37 | 
38 | @chain
39 | def shy_agent_chain(_):
40 |     llm = get_llm()
41 |     if settings.llm.agent_pattern == "function_call":
42 |         llm = llm.bind_tools(tools)
43 |     template = SYS_TEMPLATES[settings.llm.agent_pattern]
44 |     prompt = ChatPromptTemplate.from_messages(
45 |         [
46 |             ("system", template),
47 |             MessagesPlaceholder("few_shot_examples", optional=True),
48 |             MessagesPlaceholder("history"),
49 |             MessagesPlaceholder("tool_history", optional=True),
50 |         ]
51 |     )
52 |     return prompt | llm
53 | 
--------------------------------------------------------------------------------
/shy_sh/agents/llms.py:
--------------------------------------------------------------------------------
1 | from shy_sh.settings import settings, BaseLLMSchema
2 | from functools import lru_cache
3 | 
4 | 
5 | @lru_cache
6 | def get_llm():
7 |     return _get_llm(settings.llm)
8 | 
9 | 
10 | def _get_llm(llm_config: BaseLLMSchema):
11 |     llm = None
12 |     match llm_config.provider:
13 |         case "openai":
14 |             from langchain_openai import ChatOpenAI
15 | 
16 |             llm = ChatOpenAI(
17 |                 model=llm_config.name,
18 |                 temperature=llm_config.temperature,
19 |                 api_key=llm_config.api_key,
20 |                 base_url=llm_config.base_url,
21 |             )
22 |         case "ollama":
23 |             from langchain_ollama import ChatOllama
24 | 
25 |             llm = ChatOllama(model=llm_config.name, temperature=llm_config.temperature)
26 | 
27 |         case "groq":
28 |             from langchain_groq import ChatGroq
29 | 
30 |             llm = ChatGroq(
31 |                 model=llm_config.name,
32 |                 temperature=llm_config.temperature,
33 |                 api_key=llm_config.api_key,
34 |             )
35 | 
36 |         case "anthropic":
37 |             from langchain_anthropic import ChatAnthropic
38 | 
39 |             llm = ChatAnthropic(
40 |                 model_name=llm_config.name,
41 |                 temperature=llm_config.temperature,
42 |                 anthropic_api_key=llm_config.api_key,
43 |             )
44 | 
45 |         case "google":
46 |             from langchain_google_genai import ChatGoogleGenerativeAI
47 |             import os
48 | 
49 |             # Suppress logging warnings
50 |             os.environ["GRPC_VERBOSITY"] = "ERROR"
51 |             os.environ["GLOG_minloglevel"] = "2"
52 | 
53 |             llm = ChatGoogleGenerativeAI(
54 |                 model=llm_config.name,
55 |                 temperature=llm_config.temperature,
56 |                 api_key=llm_config.api_key,
57 |             )
58 | 
59 |         case "aws":
60 |             from langchain_aws import ChatBedrockConverse
61 | 
62 |             region, access_key, secret_key = llm_config.api_key.split(" ")
63 | 
64 |             llm = ChatBedrockConverse(
65 |                 model=llm_config.name,
66 |                 temperature=llm_config.temperature,
67 |                 region_name=region,
68 |                 aws_access_key_id=access_key,
69 |                 aws_secret_access_key=secret_key,
70 |             )
71 |         case _:
72 |             raise ValueError(f"Unknown LLM provider: {llm_config.provider}")
73 |     return llm
74 | 
75 | 
76 | DEFAULT_CONTEXT_LEN = 8192
77 | LLM_CONTEXT_WINDOWS = {
78 |     "openai": {
79 |         "default": DEFAULT_CONTEXT_LEN * 4,
80 |     },
81 |     "ollama": {
82 |         "default": DEFAULT_CONTEXT_LEN,
83 |     },
84 |     "groq": {
85 |         "default": DEFAULT_CONTEXT_LEN,
86 |     },
87 |     "anthropic": {
88 |         "default": DEFAULT_CONTEXT_LEN * 4,
89 |     },
90 |     "google": {
91 |         "default": DEFAULT_CONTEXT_LEN * 4,
92 |     },
93 |     "aws": {
94 |         "default": DEFAULT_CONTEXT_LEN * 4,
95 |     },
96 | }
97 | 
98 | 
99 | def get_llm_context():
100 |     provider = LLM_CONTEXT_WINDOWS.get(settings.llm.provider, None)
101 |     if not provider:
102 |         return DEFAULT_CONTEXT_LEN
103 |     return provider.get(
104 |         settings.llm.name,
105 |         provider.get("default", DEFAULT_CONTEXT_LEN),
106 |     )
107 | 
--------------------------------------------------------------------------------
/shy_sh/agents/misc.py:
--------------------------------------------------------------------------------
1 | import re
2 | import json
3 | from time import strftime
4 | from uuid import uuid4
5 | from langchain_core.messages import HumanMessage, AIMessage, ToolMessage
6 | from shy_sh.settings import settings
7 | from shy_sh.utils import detect_shell, detect_os, run_shell
8 | from shy_sh.agents.tools import tools
9 | from shy_sh.models import ToolRequest
10 | 
11 | 
12 | def get_graph_inputs(
13 |     history: list,
14 |     examples: list,
15 |     ask_before_execute: bool,
16 | ):
17 |     return {
18 |         "history": history,
19 |         "timestamp": strftime("%Y-%m-%d %H:%M %Z"),
20 |         "ask_before_execute": ask_before_execute,
21 |         "lang_spec": settings.language or "",
22 |         "few_shot_examples": examples,
23 |         "tools_instructions": _format_tools(),
24 |     }
25 | 
26 | 
27 | def parse_react_tool(message):
28 |     start_idx = message.content.find("{")  # find() returns -1 when missing, so the guard below works
29 |     if start_idx < 0:
30 |         raise ValueError("No tool call found")
31 |     end_idx = start_idx + 1
32 |     open_brackets = 1
33 |     while open_brackets > 0 and end_idx < len(message.content):
34 |         if message.content[end_idx] == "{":
35 |             open_brackets += 1
36 |         elif message.content[end_idx] == "}":
37 |             open_brackets -= 1
38 |         end_idx += 1
39 |     maybe_tool = message.content[start_idx:end_idx]
40 |     try:
41 |         return ToolRequest.model_validate_json(maybe_tool)
42 |     except Exception:
43 |         try:
44 |             maybe_tool = message.content[start_idx : message.content.rindex("}") + 1]
45 |             return ToolRequest.model_validate_json(maybe_tool)
except Exception: 47 | maybe_tool = re.sub(r"\\(?!\\)", r"\\\\", maybe_tool) 48 | return ToolRequest.model_validate_json(maybe_tool) 49 | 50 | 51 | def has_tool_calls(message): 52 | if settings.llm.agent_pattern == "function_call": 53 | return bool(getattr(message, "tool_calls", None)) 54 | elif settings.llm.agent_pattern == "react": 55 | try: 56 | parse_react_tool(message) 57 | return True 58 | except Exception: 59 | pass 60 | return False 61 | 62 | 63 | def run_few_shot_examples(): 64 | shell = detect_shell() 65 | os = detect_os() 66 | actions = [ 67 | { 68 | "tool": "shell", 69 | "arg": "echo test" if shell in ["powershell", "cmd"] else "echo $SHELL", 70 | "thoughts": "I'm checking the current shell", 71 | }, 72 | { 73 | "tool": "shell", 74 | "arg": "echo %cd%" if shell in ["powershell", "cmd"] else "pwd", 75 | "thoughts": "I'm checking the current working directory", 76 | }, 77 | { 78 | "tool": "shell", 79 | "arg": "git rev-parse --abbrev-ref HEAD", 80 | "thoughts": "I'm checking if it's a git repository", 81 | }, 82 | ] 83 | result = [] 84 | result.append( 85 | HumanMessage( 86 | content=f"You are on {os} system using {shell} as shell. Check your tools to get started." 87 | ) 88 | ) 89 | for action in actions: 90 | uid = str(uuid4()) 91 | ai_message, response = _run_example(action, uid) 92 | result.append(ai_message) 93 | if settings.llm.agent_pattern == "react": 94 | result.append(HumanMessage(content=f"Tool response:\n{response}")) 95 | elif settings.llm.agent_pattern == "function_call": 96 | result.append(ToolMessage(content=response, tool_call_id=uid)) 97 | result.append(AIMessage(content="All set! 👍")) 98 | return result 99 | 100 | 101 | def _run_example(action, uid): 102 | ai_message = AIMessage( 103 | content="", 104 | tool_calls=[ 105 | { 106 | "id": uid, 107 | "type": "tool_call", 108 | "name": action["tool"], 109 | "args": {"arg": action["arg"]}, 110 | } 111 | ], 112 | ) 113 | if settings.llm.agent_pattern == "react": 114 | ai_message.content = json.dumps(action) 115 | ai_message.tool_calls = [] 116 | return ai_message, run_shell(action["arg"]) 117 | 118 | 119 | def _format_tools(): 120 | if settings.llm.agent_pattern == "function_call": 121 | return None 122 | return "\n".join(map(lambda tool: f'- "{tool.name}": {tool.description}', tools)) 123 | -------------------------------------------------------------------------------- /shy_sh/agents/shy_agent/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mceck/shy-sh/828cb619fdca19533783014b77daaa9973ae5b46/shy_sh/agents/shy_agent/__init__.py -------------------------------------------------------------------------------- /shy_sh/agents/shy_agent/agent.py: -------------------------------------------------------------------------------- 1 | from rich import print 2 | from shy_sh.settings import settings 3 | from shy_sh.agents.shy_agent.graph import shy_agent_graph 4 | from shy_sh.agents.misc import get_graph_inputs, run_few_shot_examples 5 | from shy_sh.agents.shy_agent.audio import capture_prompt 6 | from shy_sh.manager.history import ( 7 | print_recent_commands, 8 | load_chat_history, 9 | print_chat_history, 10 | save_chat_history, 11 | print_chat, 12 | ) 13 | from shy_sh.utils import command_completer, save_history 14 | from langchain_core.messages import HumanMessage 15 | 16 | 17 | class ShyAgent: 18 | def __init__( 19 | self, 20 | interactive=False, 21 | ask_before_execute=True, 22 | audio=False, 23 | ): 24 | self.interactive = interactive 25 | 
self.ask_before_execute = ask_before_execute 26 | self.audio = audio 27 | self.history = [] 28 | self.executed_scripts = [] 29 | self.examples = run_few_shot_examples() 30 | self.chat_id = None 31 | 32 | def _run(self, task: str): 33 | self.history.append(HumanMessage(content=task)) 34 | inputs = get_graph_inputs( 35 | history=self.history, 36 | examples=self.examples, 37 | ask_before_execute=self.ask_before_execute, 38 | ) 39 | 40 | res = shy_agent_graph.invoke(inputs) 41 | self.history += res["tool_history"] 42 | self.executed_scripts += res["executed_scripts"] 43 | 44 | def _handle_command(self, command: str): 45 | if not command.startswith("/"): 46 | return False 47 | command = command[1:] 48 | match command: 49 | case "chats" | "c": 50 | print_chat_history() 51 | print("\n🤖: Here is the list of all chats!") 52 | return True 53 | case "history" | "h": 54 | print_recent_commands() 55 | print("\n🤖: Here are the most recently executed commands!") 56 | return True 57 | case "clear": 58 | self.history = [] 59 | self.executed_scripts = [] 60 | self.chat_id = None 61 | print("🤖: New chat started") 62 | return True 63 | case command if command.startswith("chat "): 64 | try: 65 | print_chat(int(command[5:])) 66 | except ValueError: 67 | print(f"🚨 [bold red]Invalid chat ID {command[5:]}[/]") 68 | return True 69 | case command if command.startswith("c "): 70 | try: 71 | print() 72 | print_chat(int(command[2:])) 73 | except ValueError: 74 | print(f"🚨 [bold red]Invalid chat ID {command[2:]}[/]") 75 | return True 76 | case command if command.startswith("load "): 77 | try: 78 | chat_id = int(command[5:]) 79 | self.history = load_chat_history(chat_id) 80 | if self.history: 81 | self.chat_id = chat_id 82 | print(f"🤖: Loaded history for chat ID {chat_id}") 83 | else: 84 | print(f"🤖: No history found for chat ID {chat_id}") 85 | except ValueError: 86 | print(f"🚨 [bold red]Invalid chat ID {command[5:]}[/]") 87 | return True 88 | return False 89 | 90 | def _load_autocomplete(self): 91 | try: 92 | import readline 93 | 94 | readline.set_completer(command_completer) 95 | except Exception: 96 | pass 97 | 98 | def _dispose_autocomplete(self): 99 | try: 100 | import readline 101 | 102 | readline.set_completer(lambda *args: None) 103 | except Exception: 104 | pass 105 | 106 | def start(self, task: str): 107 | if task: 108 | self._run(task) 109 | if self.history: 110 | save_chat_history( 111 | chat_id=self.chat_id, 112 | title=task if not self.chat_id else None, 113 | messages=self.history, 114 | executed_scripts=self.executed_scripts, 115 | meta=settings.llm.model_dump(exclude={"api_key"}), 116 | ) 117 | if self.interactive: 118 | if self.audio: 119 | new_task = None 120 | while not new_task: 121 | print(f"\n🎤: ", end="") 122 | new_task = capture_prompt().strip() 123 | print(new_task) 124 | else: 125 | self._load_autocomplete() 126 | new_task = input("\n✨: ") 127 | self._dispose_autocomplete() 128 | while new_task.endswith("\\"): 129 | new_task = new_task[:-1] + "\n" + input(" ") 130 | save_history() 131 | 132 | if new_task == "exit" or new_task == "quit" or new_task == "q": 133 | print("\n🤖: 👋 Bye!\n") 134 | return 135 | if new_task.startswith("/"): 136 | r = self._handle_command(new_task) 137 | if r: 138 | new_task = "" 139 | 140 | self.start(new_task) 141 | -------------------------------------------------------------------------------- /shy_sh/agents/shy_agent/audio.py: -------------------------------------------------------------------------------- 1 | import os 2 | from shy_sh.settings import settings 3 | 4 
| 5 | def capture_prompt(): 6 | import speech_recognition as sr 7 | 8 | r = sr.Recognizer() 9 | r.energy_threshold = 4000 10 | r.pause_threshold = 1.5 11 | r.phrase_threshold = 1.5 12 | with sr.Microphone() as source: 13 | audio = r.listen(source, timeout=5) 14 | 15 | try: 16 | os.environ["GROQ_API_KEY"] = ( 17 | os.environ.get("GROQ_API_KEY") or settings.llm.api_key 18 | ) 19 | return r.recognize_groq(audio) 20 | except Exception: 21 | return "" 22 | -------------------------------------------------------------------------------- /shy_sh/agents/shy_agent/edges/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mceck/shy-sh/828cb619fdca19533783014b77daaa9973ae5b46/shy_sh/agents/shy_agent/edges/__init__.py -------------------------------------------------------------------------------- /shy_sh/agents/shy_agent/edges/final_response.py: -------------------------------------------------------------------------------- 1 | from rich import print 2 | from langgraph.graph import END 3 | from shy_sh.models import State, ToolMeta 4 | from shy_sh.settings import settings 5 | from shy_sh.utils import syntax 6 | 7 | 8 | def final_response_edge(state: State): 9 | last_tool = state["tool_history"][-1] 10 | artifact = getattr(last_tool, "artifact", None) 11 | if isinstance(artifact, ToolMeta) and artifact.stop_execution: 12 | message = last_tool.content 13 | if settings.llm.agent_pattern == "react": 14 | message = message.replace("Tool response:\n", "", 1) 15 | if not artifact.skip_print: 16 | print(syntax(f"🤖: {message}")) 17 | return END 18 | print() 19 | return "chatbot" 20 | -------------------------------------------------------------------------------- /shy_sh/agents/shy_agent/edges/tool_calls.py: -------------------------------------------------------------------------------- 1 | from langgraph.graph import END 2 | from shy_sh.models import State 3 | from shy_sh.agents.misc import has_tool_calls 4 | 5 | 6 | def tool_calls_edge(state: State): 7 | last_message = state["tool_history"][-1] 8 | if has_tool_calls(last_message): 9 | return "tools" 10 | return END 11 | -------------------------------------------------------------------------------- /shy_sh/agents/shy_agent/graph.py: -------------------------------------------------------------------------------- 1 | from langgraph.graph import StateGraph, START 2 | from shy_sh.agents.shy_agent.nodes.chatbot import chatbot 3 | from shy_sh.agents.shy_agent.nodes.tools_handler import tools_handler 4 | from shy_sh.agents.shy_agent.edges.final_response import final_response_edge 5 | from shy_sh.agents.shy_agent.edges.tool_calls import tool_calls_edge 6 | from shy_sh.models import State 7 | 8 | graph_builder = StateGraph(State) 9 | 10 | 11 | graph_builder.add_node("chatbot", chatbot) 12 | graph_builder.add_node("tools", tools_handler) 13 | 14 | graph_builder.add_edge(START, "chatbot") 15 | graph_builder.add_conditional_edges("chatbot", tool_calls_edge) 16 | graph_builder.add_conditional_edges("tools", final_response_edge) 17 | 18 | shy_agent_graph = graph_builder.compile() 19 | -------------------------------------------------------------------------------- /shy_sh/agents/shy_agent/nodes/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mceck/shy-sh/828cb619fdca19533783014b77daaa9973ae5b46/shy_sh/agents/shy_agent/nodes/__init__.py -------------------------------------------------------------------------------- 
/shy_sh/agents/shy_agent/nodes/chatbot.py: -------------------------------------------------------------------------------- 1 | from langchain_core.messages import AIMessage 2 | from shy_sh.settings import settings 3 | from shy_sh.models import State 4 | from shy_sh.agents.llms import get_llm_context 5 | from shy_sh.utils import count_tokens 6 | from shy_sh.agents.chains.shy_agent import shy_agent_chain 7 | from shy_sh.agents.misc import has_tool_calls 8 | from rich.live import Live 9 | from rich.markdown import Markdown 10 | 11 | console_theme = { 12 | "lexer": "console", 13 | "theme": "one-dark", 14 | "background_color": "#181818", 15 | } 16 | 17 | loading_str = "⏱️ Loading..." 18 | 19 | 20 | def chatbot(state: State): 21 | final_message = None 22 | history = _compress_history(state["history"], state["tool_history"]) 23 | with Live(vertical_overflow="visible") as live: 24 | live.update(loading_str) 25 | for chunk in shy_agent_chain.stream({**state, "history": history}): 26 | final_message = chunk if final_message is None else final_message + chunk # type: ignore 27 | message = _parse_chunk_message(final_message) 28 | if _maybe_have_tool_calls(final_message, message): 29 | live.update(loading_str) 30 | else: 31 | live.update( 32 | Markdown(f"🤖: {message}"), 33 | refresh=True, 34 | ) 35 | message = _parse_chunk_message(final_message) 36 | ai_message = AIMessage( 37 | content=message, tool_calls=getattr(final_message, "tool_calls", []) 38 | ) 39 | has_tools = has_tool_calls(ai_message) 40 | if not message or (settings.llm.agent_pattern == "react" and has_tools): 41 | live.update("") 42 | else: 43 | nl = "\n" if has_tools else "" 44 | live.update(Markdown(f"\n🤖: {message}{nl}")) 45 | return {"tool_history": [ai_message]} 46 | 47 | 48 | def _maybe_have_tool_calls(message, parsed_message): 49 | return ( 50 | not message.content 51 | or getattr(message, "tool_calls", None) 52 | or (parsed_message.startswith("{") and settings.llm.agent_pattern == "react") 53 | ) 54 | 55 | 56 | def _parse_chunk_message(chunk): 57 | if isinstance(chunk.content, list): 58 | return "".join(c.get("text") for c in chunk.content if c.get("type") == "text") 59 | else: 60 | return chunk.content 61 | 62 | 63 | def _compress_history(history, tool_history): 64 | max_len = get_llm_context() 65 | tokens = count_tokens(history + tool_history) 66 | while tokens > max_len: 67 | history = history[2:] 68 | if not history: 69 | break 70 | tokens = count_tokens(history + tool_history) 71 | return history 72 | -------------------------------------------------------------------------------- /shy_sh/agents/shy_agent/nodes/tools_handler.py: -------------------------------------------------------------------------------- 1 | from uuid import uuid4 2 | from langchain_core.messages import ToolMessage, HumanMessage 3 | from rich import print 4 | from shy_sh.models import State 5 | from shy_sh.agents.tools import tools_by_name 6 | from shy_sh.settings import settings 7 | from shy_sh.agents.misc import parse_react_tool 8 | 9 | 10 | def tools_handler(state: State): 11 | last_message = state["tool_history"][-1] 12 | t_calls = _get_tool_calls(last_message) 13 | 14 | tool_answers = [] 15 | executed_scripts = [] 16 | for t_call in t_calls: 17 | try: 18 | t = tools_by_name[t_call["name"]] 19 | message = t.invoke( 20 | { 21 | **t_call, 22 | "args": {"state": state, **t_call["args"]}, 23 | } 24 | ) 25 | except Exception as e: 26 | print(f"[bold red]🚨 Tool error: {e}[/bold red]") 27 | message = ToolMessage(f"Tool error: {e}", 
tool_call_id=t_call["id"])
28 | 
29 |         if settings.llm.agent_pattern == "react":
30 |             m = HumanMessage(content=f"Tool response:\n{message.content}")
31 |             m.artifact = getattr(message, "artifact", None)
32 |             message = m
33 |         tool_answers.append(message)
34 |         if getattr(message, "artifact", None) and message.artifact.executed_scripts:
35 |             executed_scripts += message.artifact.executed_scripts
36 |     return {"tool_history": tool_answers, "executed_scripts": executed_scripts}
37 | 
38 | 
39 | def _get_react_tool_calls(message):
40 |     react_tool = parse_react_tool(message)
41 |     return [
42 |         {
43 |             "name": react_tool.tool,
44 |             "args": {"arg": react_tool.arg},
45 |             "id": uuid4().hex,
46 |             "type": "tool_call",
47 |         }
48 |     ]
49 | 
50 | 
51 | def _get_function_call_tool_calls(message):
52 |     return message.tool_calls
53 | 
54 | 
55 | def _get_tool_calls(message):
56 |     match (settings.llm.agent_pattern):
57 |         case "react":
58 |             return _get_react_tool_calls(message)
59 |         case "function_call":
60 |             return _get_function_call_tool_calls(message)
61 |         case _:
62 |             raise ValueError("Unknown agent pattern")
63 | 
--------------------------------------------------------------------------------
/shy_sh/agents/tools/__init__.py:
--------------------------------------------------------------------------------
1 | from shy_sh.agents.tools.shell import shell
2 | from shy_sh.agents.tools.shell_expert import shell_expert
3 | from shy_sh.agents.tools.python_expert import python_expert
4 | from shy_sh.agents.tools.shell_history import shell_history
5 | 
6 | tools = [shell, shell_expert, python_expert, shell_history]
7 | tools_by_name = {tool.name: tool for tool in tools}
8 | 
--------------------------------------------------------------------------------
/shy_sh/agents/tools/python_expert.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import pyperclip
4 | from typing import Annotated
5 | from tempfile import NamedTemporaryFile
6 | from rich import print
7 | from rich.live import Live
8 | from langgraph.prebuilt import InjectedState
9 | from langchain.tools import tool
10 | from shy_sh.models import State, ToolMeta
11 | from shy_sh.utils import ask_confirm, tools_to_human, syntax, run_python, parse_code
12 | from shy_sh.agents.chains.python_expert import pyexpert_chain
13 | from shy_sh.agents.chains.explain import explain
14 | 
15 | 
16 | @tool(response_format="content_and_artifact")
17 | def python_expert(arg: str, state: Annotated[State, InjectedState]):
18 |     """to delegate the task to a python expert that can write and execute python code, use only if you can't resolve the task with shell; just explain what you want to achieve in a short sentence in the arg without including any python code"""
19 |     print(f"🐍 [bold yellow]Generating python script...[/bold yellow]\n")
20 |     inputs = {
21 |         "timestamp": state["timestamp"],
22 |         "history": tools_to_human(state["history"] + state["tool_history"]),
23 |         "input": arg,
24 |     }
25 |     code = ""
26 |     with Live(vertical_overflow="visible") as live:
27 |         for chunk in pyexpert_chain.stream(inputs):
28 |             code += chunk  # type: ignore
29 |             live.update(syntax(code, "python"))
30 |         code = parse_code(code)
31 |         live.update(syntax(code.strip(), "python"))
32 | 
33 |     confirm = "y"
34 |     if state["ask_before_execute"]:
35 |         confirm = ask_confirm()
36 |     print()
37 |     if confirm == "n":
38 |         return "Command canceled by user", ToolMeta(
39 |             stop_execution=True, skip_print=True
40 |         )
41 |     elif confirm == "c":
42 |         pyperclip.copy(code)
return "Script copied to the clipboard!", ToolMeta(stop_execution=True) 44 | elif confirm == "e": 45 | inputs = { 46 | "task": arg, 47 | "script_type": "python script", 48 | "script": code, 49 | "timestamp": state["timestamp"], 50 | } 51 | ret = explain(inputs) 52 | if ret: 53 | return ret 54 | 55 | if sys.version_info >= (3, 12): 56 | with NamedTemporaryFile("w+", suffix=".py", delete_on_close=False) as file: 57 | file.write(code) 58 | file.close() 59 | os.chmod(file.name, 0o755) 60 | result = run_python(file.name) 61 | 62 | if len(result) > 20000: 63 | print("\n🐳 [bold red]Output too long! It will be truncated[/bold red]") 64 | result = ( 65 | result[:9000] 66 | + "\n...(OUTPUT TOO LONG TRUNCATED!)...\n" 67 | + result[-9000:] 68 | ) 69 | 70 | else: 71 | with NamedTemporaryFile("w+", suffix=".py", delete=False) as file: 72 | file.write(code) 73 | file.close() 74 | os.chmod(file.name, 0o755) 75 | result = run_python(file.name) 76 | 77 | if len(result) > 20000: 78 | print("\n🐳 [bold red]Output too long! It will be truncated[/bold red]") 79 | result = ( 80 | result[:9000] 81 | + "\n...(OUTPUT TOO LONG TRUNCATED!)...\n" 82 | + result[-9000:] 83 | ) 84 | os.unlink(file.name) 85 | 86 | ret = f"\nScript executed:\n```python\n{code.strip()}\n```\n\nOutput:\n{result}" 87 | if len(ret) > 20000: 88 | ret = f"Output:\n{result}" 89 | return ( 90 | ret, 91 | ToolMeta( 92 | executed_scripts=[{"script": code, "type": "python", "result": result}] 93 | ), 94 | ) 95 | -------------------------------------------------------------------------------- /shy_sh/agents/tools/shell.py: -------------------------------------------------------------------------------- 1 | import pyperclip 2 | from typing import Annotated 3 | from rich import print 4 | from langgraph.prebuilt import InjectedState 5 | from langchain.tools import tool 6 | from questionary import select, Style, Choice 7 | from shy_sh.models import State, ToolMeta 8 | from shy_sh.utils import ( 9 | ask_confirm, 10 | run_command, 11 | tools_to_human, 12 | detect_shell, 13 | detect_os, 14 | ) 15 | from shy_sh.agents.chains.explain import explain 16 | from shy_sh.agents.chains.alternative_commands import get_alternative_commands 17 | from shy_sh.settings import settings 18 | 19 | _text_style = { 20 | "qmark": "", 21 | "style": Style.from_dict( 22 | { 23 | "selected": "fg:ansigreen noreverse bold", 24 | "question": "fg:darkorange nobold", 25 | "highlighted": "fg:ansigreen bold", 26 | "text": "fg:ansigreen bold", 27 | "answer": "fg:ansigreen bold", 28 | "instruction": "fg:ansigreen", 29 | } 30 | ), 31 | } 32 | 33 | _select_style = { 34 | "pointer": "►", 35 | "instruction": " ", 36 | **_text_style, 37 | } 38 | 39 | 40 | @tool(response_format="content_and_artifact") 41 | def shell(arg: str, state: Annotated[State, InjectedState]): 42 | """to execute a shell command in the terminal, useful for every task that requires to interact with the current system or local files, do not pass multiple lines commands, avoid to install new packages if not explicitly requested""" 43 | print(f"🛠️ [bold green] {arg} [/bold green]") 44 | result = "" 45 | confirm = "y" 46 | if state["ask_before_execute"]: 47 | confirm = ask_confirm(alternatives=True) 48 | print() 49 | if confirm == "n": 50 | return "Command interrupted by the user", ToolMeta( 51 | stop_execution=True, skip_print=True 52 | ) 53 | elif confirm == "c": 54 | pyperclip.copy(arg) 55 | return "Command copied to the clipboard!", ToolMeta(stop_execution=True) 56 | elif confirm == "a": 57 | r = 
58 |         print()
59 |         if r == "None":
60 |             return "Command interrupted by the user", ToolMeta(
61 |                 stop_execution=True, skip_print=True
62 |             )
63 |         if settings.sandbox_mode:
64 |             pyperclip.copy(r)
65 |             return "Command copied to the clipboard!", ToolMeta(stop_execution=True)
66 |         arg = r
67 |         result += f"The user decided to execute this alternative command `{arg}`\n\n"
68 |     elif confirm == "e":
69 |         inputs = {
70 |             "task": state["history"][-1].content,
71 |             "script_type": "shell command",
72 |             "script": arg,
73 |             "timestamp": state["timestamp"],
74 |         }
75 |         ret = explain(inputs, ask_alternative=True)
76 |         if ret == "alternative":
77 |             r = _select_alternative_command(arg, state)
78 |             print()
79 |             if r == "None":
80 |                 return "Command interrupted by the user", ToolMeta(
81 |                     stop_execution=True, skip_print=True
82 |                 )
83 |             arg = r
84 |             result += (
85 |                 f"The user decided to execute this alternative command `{arg}`\n\n"
86 |             )
87 |         elif ret:
88 |             return ret
89 |     output = run_command(arg)
90 |     result += output
91 | 
92 |     if len(result) > 20000:
93 |         print("\n🐳 [bold red]Output too long! It will be truncated[/bold red]")
94 |         result = (
95 |             result[:9000] + "\n...(OUTPUT TOO LONG TRUNCATED!)...\n" + result[-9000:]
96 |         )
97 |     return result, ToolMeta(
98 |         executed_scripts=[{"script": arg, "type": "shell", "result": output}]
99 |     )
100 | 
101 | 
102 | def _select_alternative_command(arg, state):
103 |     inputs = {
104 |         "timestamp": state["timestamp"],
105 |         "shell": detect_shell(),
106 |         "system": detect_os(),
107 |         "history": tools_to_human(state["history"] + state["tool_history"]),
108 |         "cmd": arg,
109 |     }
110 |     cmds = get_alternative_commands(inputs)
111 |     r = select(
112 |         (
113 |             "Pick the command to copy to the clipboard"
114 |             if settings.sandbox_mode
115 |             else "Pick the command to execute"
116 |         ),
117 |         choices=[
118 |             Choice([("fg:ansired bold", "Cancel")], "None"),
119 |             Choice(
120 |                 [
121 |                     ("fg:ansiyellow bold", arg),
122 |                     ("fg:gray", " # Original command"),
123 |                 ],
124 |                 arg,
125 |             ),
126 |             *[
127 |                 Choice(
128 |                     [("fg:ansigreen bold", c[1]), ("fg:gray", " " + c[0])],
129 |                     c[1],
130 |                 )
131 |                 for c in cmds
132 |             ],
133 |         ],
134 |         **_select_style,
135 |     ).unsafe_ask()
136 | 
137 |     return r
138 | 
--------------------------------------------------------------------------------
/shy_sh/agents/tools/shell_expert.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import pyperclip
4 | from typing import Annotated
5 | from tempfile import NamedTemporaryFile
6 | from rich import print
7 | from rich.live import Live
8 | from langgraph.prebuilt import InjectedState
9 | from langchain.tools import tool
10 | from shy_sh.models import State, ToolMeta
11 | from shy_sh.utils import ask_confirm, detect_shell, detect_os, parse_code
12 | from shy_sh.agents.chains.shell_expert import shexpert_chain
13 | from shy_sh.utils import run_command, tools_to_human, syntax
14 | from shy_sh.agents.chains.explain import explain
15 | 
16 | 
17 | @tool(response_format="content_and_artifact")
18 | def shell_expert(arg: str, state: Annotated[State, InjectedState]):
19 |     """to delegate the task to a shell expert that can write and execute long and complex shell scripts, use only if you can't resolve the task with a simple shell command; just explain what you want to achieve in a short sentence in the arg without including any shell code"""
20 |     print(f"💻 [bold dark_green]Generating shell script...[/bold 
dark_green]\n") 21 | shell = detect_shell() 22 | system = detect_os() 23 | inputs = { 24 | "input": arg, 25 | "system": system, 26 | "shell": shell, 27 | "timestamp": state["timestamp"], 28 | "history": tools_to_human(state["history"] + state["tool_history"]), 29 | } 30 | code = "" 31 | with Live(vertical_overflow="visible") as live: 32 | for chunk in shexpert_chain.stream(inputs): 33 | code += chunk # type: ignore 34 | live.update(syntax(code)) 35 | 36 | code = parse_code(code) 37 | live.update(syntax(code.strip())) 38 | 39 | confirm = "y" 40 | if state["ask_before_execute"]: 41 | confirm = ask_confirm() 42 | print() 43 | if confirm == "n": 44 | return "Script interrupted by the user", ToolMeta( 45 | stop_execution=True, skip_print=True 46 | ) 47 | elif confirm == "c": 48 | pyperclip.copy(code) 49 | return "Script copied to the clipboard!", ToolMeta(stop_execution=True) 50 | elif confirm == "e": 51 | inputs = { 52 | "task": arg, 53 | "script_type": "shell script", 54 | "script": code, 55 | "timestamp": state["timestamp"], 56 | } 57 | ret = explain(inputs) 58 | if ret: 59 | return ret 60 | 61 | ext = ".sh" 62 | if shell == "cmd": 63 | ext = ".bat" 64 | elif shell == "powershell": 65 | ext = ".ps1" 66 | 67 | if sys.version_info >= (3, 12): 68 | with NamedTemporaryFile("w+", suffix=ext, delete_on_close=False) as file: 69 | file.write(code) 70 | file.close() 71 | os.chmod(file.name, 0o755) 72 | result = run_command(file.name) 73 | 74 | if len(result) > 20000: 75 | print("\n🐳 [bold red]Output too long! It will be truncated[/bold red]") 76 | result = ( 77 | result[:9000] 78 | + "\n...(OUTPUT TOO LONG TRUNCATED!)...\n" 79 | + result[-9000:] 80 | ) 81 | 82 | else: 83 | with NamedTemporaryFile("w+", suffix=ext, delete=False) as file: 84 | file.write(code) 85 | file.close() 86 | os.chmod(file.name, 0o755) 87 | result = run_command(file.name) 88 | 89 | if len(result) > 20000: 90 | print("\n🐳 [bold red]Output too long! 
It will be truncated[/bold red]") 91 | result = ( 92 | result[:9000] 93 | + "\n...(OUTPUT TOO LONG TRUNCATED!)...\n" 94 | + result[-9000:] 95 | ) 96 | os.unlink(file.name) 97 | print() 98 | ret = f"Script executed:\n{code}\n\nOutput:\n{result}" 99 | if len(ret) > 20000: 100 | ret = f"Output:\n{result}" 101 | return ret, ToolMeta( 102 | executed_scripts=[{"script": code, "type": "shell_script", "result": result}] 103 | ) 104 | -------------------------------------------------------------------------------- /shy_sh/agents/tools/shell_history.py: -------------------------------------------------------------------------------- 1 | from typing import Annotated 2 | from rich import print 3 | from langgraph.prebuilt import InjectedState 4 | from langchain.tools import tool 5 | from shy_sh.models import State, ToolMeta 6 | from shy_sh.utils import get_shell_history 7 | 8 | 9 | @tool(response_format="content_and_artifact") 10 | def shell_history(arg: str, state: Annotated[State, InjectedState]): 11 | """to get the history of the last commands executed by the user in the shell, useful to understand what the user has already tried; use it if the user asks to check a previous command launched before the chat started""" 12 | print("📜 [bold yellow]Let me check...[/bold yellow]\n") 13 | history = get_shell_history() 14 | return f"These are the last commands executed by the user:\n{history}", ToolMeta() 15 | -------------------------------------------------------------------------------- /shy_sh/main.py: -------------------------------------------------------------------------------- 1 | import typer 2 | from typing import Optional, Annotated 3 | from importlib.metadata import version 4 | from shy_sh.agents.shy_agent.agent import ShyAgent 5 | from shy_sh.manager.history import truncate_chats 6 | from shy_sh.settings import settings, configure_yaml 7 | from shy_sh.agents.chains.explain import explain as do_explain 8 | from shy_sh.utils import load_history 9 | from rich import print 10 | from time import strftime 11 | 12 | 13 | def exec( 14 | prompt: Annotated[Optional[list[str]], typer.Argument(allow_dash=False)] = None, 15 | oneshot: Annotated[ 16 | Optional[bool], 17 | typer.Option( 18 | "-o", 19 | help="One shot mode", 20 | ), 21 | ] = False, 22 | no_ask: Annotated[ 23 | Optional[bool], 24 | typer.Option( 25 | "-x", 26 | help="Do not ask for confirmation before executing scripts", 27 | ), 28 | ] = False, 29 | explain: Annotated[ 30 | Optional[bool], 31 | typer.Option( 32 | "-e", 33 | help="Explain the given shell command", 34 | ), 35 | ] = False, 36 | audio: Annotated[ 37 | Optional[bool], 38 | typer.Option( 39 | "-a", 40 | help="Interactive mode with audio input", 41 | ), 42 | ] = False, 43 | configure: Annotated[ 44 | Optional[bool], typer.Option("--configure", help="Configure LLM") 45 | ] = False, 46 | display_version: Annotated[ 47 | Optional[bool], typer.Option("--version", help="Show version") 48 | ] = False, 49 | ): 50 | if display_version: 51 | print(f"Version: {version(__package__ or 'shy-sh')}") 52 | return 53 | if configure: 54 | configure_yaml() 55 | return 56 | task = " ".join(prompt or []) 57 | print(f"[bold italic dark_orange]{settings.llm.provider} - {settings.llm.name}[/]") 58 | if explain: 59 | if not task: 60 | print("🚨 [bold red]No command provided[/]") 61 | return 62 | do_explain( 63 | { 64 | "task": "explain this shell command", 65 | "script_type": "shell command", 66 | "script": task, 67 | "timestamp": strftime("%Y-%m-%d %H:%M:%S"), 68 | }, 69 | 
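# note: ask_execute=False below presumably keeps the -e flow print-only, rendering the explanation without the run prompt (inferred from the flag name, not stated in the source)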
ask_execute=False, 70 | ) 71 | return 72 | interactive = not oneshot 73 | if task: 74 | print(f"\n✨: {task}\n") 75 | try: 76 | ShyAgent( 77 | interactive=interactive, 78 | ask_before_execute=not no_ask, 79 | audio=bool(audio), 80 | ).start(task) 81 | except Exception as e: 82 | print(f"🚨 [bold red]{e}[/bold red]") 83 | 84 | 85 | def main(): 86 | try: 87 | truncate_chats(1000) 88 | import readline 89 | 90 | readline.set_history_length(100) 91 | readline.set_completer_delims(" \t\n") 92 | readline.parse_and_bind("tab: complete") 93 | readline.parse_and_bind("set show-all-if-unmodified on") 94 | except Exception: 95 | pass 96 | load_history() 97 | typer.run(exec) 98 | 99 | 100 | if __name__ == "__main__": 101 | main() 102 | -------------------------------------------------------------------------------- /shy_sh/manager/chat_manager.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import delete, select, func 2 | from sqlalchemy.orm import Session 3 | from shy_sh.manager.sql_models import Chat, ExecutedScript, ScriptType 4 | 5 | 6 | class ChatManager: 7 | def __init__(self, session: Session): 8 | self.session = session 9 | 10 | def get_chat(self, chat_id): 11 | return self.session.get(Chat, chat_id) 12 | 13 | def get_all_chats(self): 14 | return ( 15 | self.session.execute(select(Chat).order_by(Chat.created_at)).scalars().all() 16 | ) 17 | 18 | def get_recent_scripts( 19 | self, 20 | script_type: ScriptType | None = None, 21 | limit: int = 20, 22 | ): 23 | q = ( 24 | select( 25 | ExecutedScript.script, 26 | func.count(ExecutedScript.script), 27 | func.max(ExecutedScript.type), 28 | func.max(ExecutedScript.created_at), 29 | ) 30 | .group_by(ExecutedScript.script) 31 | .order_by(ExecutedScript.created_at.desc()) 32 | .limit(limit) 33 | ) 34 | if script_type: 35 | q = q.where(ExecutedScript.type == script_type) 36 | return self.session.execute(q).all() 37 | 38 | def save_chat( 39 | self, 40 | id: int | None = None, 41 | **kwargs, 42 | ): 43 | kwargs = {k: v for k, v in kwargs.items() if v is not None} 44 | if id: 45 | chat = self.get_chat(id) 46 | if not chat: 47 | raise ValueError(f"Chat with ID {id} not found.") 48 | for key, value in kwargs.items(): 49 | setattr(chat, key, value) 50 | else: 51 | chat = Chat(**kwargs) 52 | self.session.add(chat) 53 | 54 | def truncate_chats(self, keep: int = 100): 55 | self.session.execute( 56 | delete(Chat).where( 57 | Chat.id.not_in( 58 | select(Chat.id).order_by(Chat.created_at.desc()).limit(keep) 59 | ) 60 | ) 61 | ) 62 | -------------------------------------------------------------------------------- /shy_sh/manager/history.py: -------------------------------------------------------------------------------- 1 | from shy_sh.agents.misc import has_tool_calls, parse_react_tool 2 | from shy_sh.manager.sql_models import ExecutedScript, ScriptType, session 3 | from shy_sh.manager.chat_manager import ChatManager 4 | from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, BaseMessage 5 | from rich import print 6 | from rich.markdown import Markdown 7 | 8 | from shy_sh.utils import syntax, to_local 9 | 10 | 11 | def print_chat_history(): 12 | print("\n[bold green]Chats[/bold green]") 13 | with session() as db: 14 | chat_manager = ChatManager(db) 15 | chats = chat_manager.get_all_chats() 16 | if not chats: 17 | print("No chat history found.") 18 | return 19 | 20 | for chat in chats: 21 | print(f"[bold magenta][{chat.id}][/] - [bold green]{chat.title}[/]") 22 | print( 23 | f" Model: 
{chat.meta.get('provider', '')} - {chat.meta.get('name', 'Unknown')}" 24 | ) 25 | print(f" Messages: {len(chat.messages)}") 26 | print(f" Created: {to_local(chat.created_at).strftime('%Y-%m-%d %H:%M')}") 27 | 28 | 29 | _MESSAGE_MAP = { 30 | "human": HumanMessage, 31 | "ai": AIMessage, 32 | "tool": ToolMessage, 33 | } 34 | 35 | 36 | def _deserialize_messages(messages: list[dict]) -> list[BaseMessage]: 37 | return [ 38 | _MESSAGE_MAP[message["type"]].model_validate(message) for message in messages 39 | ] 40 | 41 | 42 | def load_chat_history(chat_id: int) -> list[BaseMessage]: 43 | with session() as db: 44 | chat_manager = ChatManager(db) 45 | chat = chat_manager.get_chat(chat_id) 46 | if not chat: 47 | return [] 48 | return _deserialize_messages(chat.messages) 49 | 50 | 51 | def print_chat(chat_id: int): 52 | hist = load_chat_history(chat_id) 53 | print(Markdown("---")) 54 | print() 55 | for message in hist: 56 | if hasattr(message, "tool_calls") and len(message.tool_calls) > 0: # type: ignore 57 | print(syntax(message.tool_calls[0]["args"].get("arg"))) # type: ignore 58 | print() 59 | else: 60 | try: 61 | rtool = parse_react_tool(message) 62 | except Exception: 63 | rtool = None 64 | if rtool: 65 | print(syntax(rtool.arg)) 66 | else: 67 | avatar = "✨" if isinstance(message, HumanMessage) else "🤖" 68 | msg = str(message.content) 69 | if msg.startswith("Tool response:"): 70 | msg = msg[14:].strip() 71 | print(Markdown(f"{avatar}: {msg}")) 72 | print() 73 | print(Markdown("---")) 74 | 75 | 76 | def print_recent_commands( 77 | script_type: ScriptType | None = None, 78 | ): 79 | print("\n[bold green]Command history[/bold green]") 80 | with session() as db: 81 | chat_manager = ChatManager(db) 82 | scripts = chat_manager.get_recent_scripts(script_type) 83 | for script, count, kind, created_at in reversed(scripts): 84 | if len(script) > 100: 85 | script = script[:100] + "..." 
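# each row from ChatManager.get_recent_scripts unpacks to (script, run_count, script_type, last_run_at), grouped by script text; long scripts are previewed at 100 chars here while the full text stays in the database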
86 | print( 87 | f"\n[magenta][{kind.name}][/] {to_local(created_at).strftime('%Y-%m-%d %H:%M')}" 88 | ) 89 | print( 90 | syntax( 91 | script, lexer="python" if kind == ScriptType.PYTHON else "console" 92 | ) 93 | ) 94 | 95 | 96 | def save_chat_history( 97 | *, 98 | chat_id: int | None = None, 99 | title: str | None = None, 100 | messages: list[BaseMessage] | None = None, 101 | meta: dict | None = None, 102 | executed_scripts: list[dict] | None = None, 103 | ): 104 | with session() as db: 105 | chat_manager = ChatManager(db) 106 | chat_manager.save_chat( 107 | chat_id, 108 | title=title, 109 | messages=( 110 | [message.model_dump() for message in messages] if messages else None 111 | ), 112 | meta=meta, 113 | executed_scripts=( 114 | [ExecutedScript(**script) for script in executed_scripts] 115 | if executed_scripts 116 | else None 117 | ), 118 | ) 119 | 120 | 121 | def truncate_chats(keep: int = 100): 122 | with session() as db: 123 | chat_manager = ChatManager(db) 124 | chat_manager.truncate_chats(keep) 125 | -------------------------------------------------------------------------------- /shy_sh/manager/sql_models.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | from datetime import datetime 3 | import os 4 | from sqlalchemy import DateTime, Text, create_engine, JSON, ForeignKey 5 | from sqlalchemy.orm import ( 6 | declarative_base, 7 | sessionmaker, 8 | mapped_column, 9 | relationship, 10 | Mapped, 11 | ) 12 | from enum import Enum 13 | 14 | Base = declarative_base() 15 | 16 | 17 | class Chat(Base): 18 | __tablename__ = "chat" 19 | 20 | id: Mapped[int] = mapped_column(primary_key=True) 21 | title: Mapped[str] 22 | messages: Mapped[list[dict]] = mapped_column(JSON) 23 | meta: Mapped[dict] = mapped_column(JSON) 24 | created_at: Mapped[datetime] = mapped_column( 25 | DateTime, default=lambda: datetime.utcnow() 26 | ) 27 | 28 | executed_scripts: Mapped[list["ExecutedScript"]] = relationship( 29 | "ExecutedScript", back_populates="chat", cascade="all, delete-orphan" 30 | ) 31 | 32 | 33 | class ScriptType(str, Enum): 34 | SHELL = "shell" 35 | SHELL_SCRIPT = "shell_script" 36 | PYTHON = "python" 37 | 38 | 39 | class ExecutedScript(Base): 40 | __tablename__ = "executed_script" 41 | 42 | id: Mapped[int] = mapped_column(primary_key=True) 43 | type: Mapped[ScriptType] 44 | script: Mapped[str] = mapped_column(Text) 45 | result: Mapped[str | None] = mapped_column(Text) 46 | created_at: Mapped[datetime] = mapped_column( 47 | DateTime, default=lambda: datetime.utcnow() 48 | ) 49 | 50 | chat_id: Mapped[int] = mapped_column(ForeignKey("chat.id")) 51 | chat: Mapped[Chat] = relationship("Chat", back_populates="executed_scripts") 52 | 53 | 54 | DATABASE_URL = os.path.expanduser("~/.config/shy/shy.db") 55 | if not os.path.exists(os.path.dirname(DATABASE_URL)): 56 | os.makedirs(os.path.dirname(DATABASE_URL), exist_ok=True) 57 | engine = create_engine(f"sqlite:///{DATABASE_URL}") 58 | SessionLocal = sessionmaker(bind=engine) 59 | Base.metadata.create_all(engine) 60 | 61 | 62 | @contextmanager 63 | def session(): 64 | session = SessionLocal() 65 | try: 66 | yield session 67 | session.commit() 68 | except Exception as e: 69 | session.rollback() 70 | raise e 71 | finally: 72 | session.close() 73 | -------------------------------------------------------------------------------- /shy_sh/models.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from 
typing_extensions import TypedDict, Annotated 3 | from langgraph.graph.message import add_messages 4 | from pydantic import BaseModel 5 | 6 | 7 | class ToolRequest(BaseModel): 8 | tool: str 9 | arg: str 10 | thoughts: Optional[str] = None 11 | 12 | 13 | class FinalResponse(BaseModel): 14 | response: str 15 | 16 | 17 | class ToolMeta(BaseModel): 18 | stop_execution: bool = False 19 | skip_print: bool = False 20 | executed_scripts: list[dict] = [] 21 | 22 | 23 | def append(left: list, right: list, **kwargs) -> list: 24 | """Append right to left and return the result""" 25 | left.extend(right) 26 | return left 27 | 28 | 29 | class State(TypedDict): 30 | timestamp: str 31 | lang_spec: str 32 | ask_before_execute: bool = True 33 | tools_instructions: str | None = None 34 | few_shot_examples: Annotated[list, add_messages] = [] 35 | history: Annotated[list, add_messages] = [] 36 | tool_history: Annotated[list, add_messages] = [] 37 | executed_scripts: Annotated[list, append] = [] 38 | -------------------------------------------------------------------------------- /shy_sh/settings.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | 4 | from typing import Type, Any, Literal 5 | 6 | from pydantic_settings import ( 7 | BaseSettings, 8 | SettingsConfigDict, 9 | YamlConfigSettingsSource, 10 | ) 11 | from pydantic import BaseModel 12 | from pathlib import Path 13 | from questionary import confirm, text, select, password, Style 14 | 15 | 16 | class BaseLLMSchema(BaseModel): 17 | provider: str 18 | name: str 19 | api_key: str = "" 20 | base_url: str | None = None 21 | temperature: float = 0.0 22 | 23 | 24 | class LLMSchema(BaseLLMSchema): 25 | agent_pattern: Literal["function_call", "react"] = "react" 26 | 27 | 28 | class _Settings(BaseModel): 29 | llm: LLMSchema = LLMSchema(provider="ollama", name="llama3.2") 30 | 31 | language: str = "" 32 | sandbox_mode: bool = False 33 | 34 | 35 | class Settings(BaseSettings, _Settings): 36 | model_config = SettingsConfigDict( 37 | extra="ignore", 38 | yaml_file=[ 39 | "~/.config/shy/config.yaml", 40 | "~/.config/shy/config.yml", 41 | "./shy.yaml", 42 | "./shy.yml", 43 | ], 44 | ) 45 | 46 | @classmethod 47 | def settings_customise_sources( 48 | cls, 49 | settings_cls: Type[BaseSettings], 50 | **kwargs: Any, 51 | ): 52 | return (YamlConfigSettingsSource(settings_cls),) 53 | 54 | 55 | settings = Settings() 56 | PROVIDERS = ["ollama", "openai", "google", "anthropic", "groq", "aws"] 57 | 58 | 59 | def get_or_create_settings_path(): 60 | file_name = None 61 | for f in reversed(settings.model_config["yaml_file"]): # type: ignore 62 | f = Path(f).expanduser() 63 | if os.path.exists(f): 64 | file_name = f 65 | break 66 | if not file_name: 67 | file_name = settings.model_config["yaml_file"][0] # type: ignore 68 | file_name = Path(file_name).expanduser() 69 | os.makedirs(os.path.dirname(file_name), exist_ok=True) 70 | return file_name 71 | 72 | 73 | _text_style = { 74 | "qmark": "", 75 | "style": Style.from_dict( 76 | { 77 | "selected": "fg:darkorange noreverse", 78 | "question": "fg:ansigreen nobold", 79 | "highlighted": "fg:darkorange", 80 | "text": "fg:darkorange", 81 | "answer": "fg:darkorange nobold", 82 | "instruction": "fg:darkorange", 83 | } 84 | ), 85 | } 86 | 87 | _select_style = { 88 | "pointer": "►", 89 | "instruction": " ", 90 | **_text_style, 91 | } 92 | 93 | 94 | def _try_float(x): 95 | try: 96 | float(x) 97 | return True 98 | except ValueError: 99 | return "Please enter a valid number" 100 | 101 
| 102 | def configure_yaml(): 103 | provider = select( 104 | message="Provider:", 105 | choices=PROVIDERS, 106 | default=settings.llm.provider, 107 | **_select_style, 108 | ).unsafe_ask() 109 | if provider != "ollama": 110 | api_key = password( 111 | message="API Key:", 112 | default=settings.llm.api_key, 113 | **_text_style, 114 | ).unsafe_ask() 115 | else: 116 | api_key = settings.llm.api_key 117 | base_url = settings.llm.base_url 118 | if provider == "openai": 119 | base_url = ( 120 | text( 121 | message="Base URL:", 122 | default=settings.llm.base_url or "", 123 | **_text_style, 124 | ) 125 | .unsafe_ask() 126 | .strip() 127 | ) 128 | base_url = base_url or None 129 | model = input_model(provider, api_key, settings.llm.name) 130 | agent_pattern = select( 131 | message="Agent Pattern:", 132 | choices=["function_call", "react"], 133 | default=settings.llm.agent_pattern, 134 | **_select_style, 135 | ).unsafe_ask() 136 | temperature = text( 137 | message="Temperature:", 138 | default=str(settings.llm.temperature), 139 | validate=lambda x: _try_float(x), 140 | **_text_style, 141 | ).unsafe_ask() 142 | 143 | llm = { 144 | "provider": provider, 145 | "name": model, 146 | "api_key": api_key, 147 | "base_url": base_url, 148 | "temperature": float(temperature), 149 | "agent_pattern": agent_pattern, 150 | } 151 | 152 | language = text("Language:", default=settings.language, **_text_style).unsafe_ask() 153 | sandbox_mode = confirm( 154 | "Sandbox Mode:", 155 | default=settings.sandbox_mode, 156 | **_text_style, 157 | ).unsafe_ask() 158 | 159 | file_name = get_or_create_settings_path() 160 | 161 | with open(file_name, "w") as f: 162 | f.write( 163 | yaml.dump( 164 | { 165 | "llm": llm, 166 | "language": language, 167 | "sandbox_mode": sandbox_mode, 168 | } 169 | ) 170 | ) 171 | 172 | print(f"\nConfiguration saved to {file_name}") 173 | 174 | 175 | def input_model(provider: str, api_key: str, default_model: str | None = None): 176 | try: 177 | match provider: 178 | case "ollama": 179 | from ollama import list 180 | 181 | r = list() 182 | model_list = [l.model.replace(":latest", "") for l in r["models"]] 183 | return select( 184 | message="Model:", 185 | choices=model_list, 186 | default=default_model if default_model in model_list else None, 187 | **_select_style, 188 | ).unsafe_ask() 189 | case "openai": 190 | from openai import OpenAI 191 | 192 | r = OpenAI(api_key=api_key).models.list() 193 | model_list = [l.id for l in r.data] 194 | return select( 195 | message="Model:", 196 | choices=model_list, 197 | default=default_model if default_model in model_list else None, 198 | **_select_style, 199 | ).unsafe_ask() 200 | case "google": 201 | from google.ai.generativelanguage import ModelServiceClient 202 | from google.auth.api_key import Credentials 203 | 204 | r = ModelServiceClient(credentials=Credentials(api_key)).list_models() 205 | model_list = [l.name.replace("models/", "", 1) for l in r] 206 | return select( 207 | message="Model:", 208 | choices=model_list, 209 | default=default_model if default_model in model_list else None, 210 | **_select_style, 211 | ).unsafe_ask() 212 | case "anthropic": 213 | import requests 214 | 215 | r = requests.get( 216 | "https://api.anthropic.com/v1/models", 217 | headers={"x-api-key": api_key, "anthropic-version": "2023-06-01"}, 218 | ).json() 219 | model_list = [l["id"] for l in r["data"]] 220 | 221 | return select( 222 | message="Model:", 223 | choices=model_list, 224 | default=default_model if default_model in model_list else None, 225 | **_select_style, 226 | 
).unsafe_ask() 227 | case "groq": 228 | from groq import Client 229 | 230 | r = Client(api_key=api_key).models.list() 231 | model_list = [l.id for l in r.data] 232 | return select( 233 | message="Model:", 234 | choices=model_list, 235 | default=default_model if default_model in model_list else None, 236 | **_select_style, 237 | ).unsafe_ask() 238 | case "aws": 239 | from boto3 import client 240 | 241 | region, access_key, secret_key = api_key.split(" ") 242 | r = client( 243 | "bedrock", 244 | region_name=region, 245 | aws_access_key_id=access_key, 246 | aws_secret_access_key=secret_key, 247 | ).list_foundation_models( 248 | byOutputModality="TEXT", 249 | ) 250 | model_list = [l["modelId"] for l in r["modelSummaries"]] 251 | return select( 252 | message="Model:", 253 | choices=model_list, 254 | default=default_model if default_model in model_list else None, 255 | **_select_style, 256 | ).unsafe_ask() 257 | case _: 258 | raise ValueError("Invalid provider") 259 | except Exception: 260 | return text(message="Model:", **_text_style).unsafe_ask() 261 | -------------------------------------------------------------------------------- /shy_sh/utils.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime, timezone 2 | import re 3 | import os 4 | import platform 5 | import subprocess 6 | from typing import Literal 7 | from tiktoken import get_encoding 8 | from rich.prompt import Prompt 9 | from rich.syntax import Syntax 10 | from langchain_core.messages import HumanMessage, ToolMessage, AIMessage 11 | from shy_sh.settings import settings 12 | 13 | try: 14 | import readline 15 | except Exception: 16 | readline = None 17 | 18 | 19 | RL_HISTORY_FILE = os.path.expanduser("~/.config/shy/.history") 20 | 21 | 22 | def load_history(): 23 | if not readline: 24 | return 25 | try: 26 | readline.read_history_file(RL_HISTORY_FILE) 27 | except Exception: 28 | pass 29 | 30 | 31 | def save_history(): 32 | if not readline: 33 | return 34 | readline.write_history_file(RL_HISTORY_FILE) 35 | 36 | 37 | def clear_history(): 38 | if not readline: 39 | return 40 | readline.clear_history() 41 | 42 | 43 | def ask_confirm(explain=True, alternatives=False) -> Literal["y", "n", "c", "e", "a"]: 44 | clear_history() 45 | choices = ["n", "c", "no", "copy"] 46 | if explain: 47 | choices.extend(["e", "explain"]) 48 | if alternatives: 49 | choices.extend(["a", "alternatives"]) 50 | if not settings.sandbox_mode: 51 | choices.extend(["y", "yes"]) 52 | 53 | ret = Prompt.ask( 54 | f"""\n [dark_orange]{ 55 | 'Do you need more details?' if settings.sandbox_mode else 'Do you want to execute this command?' 
56 | }[/] [bold magenta][{ 57 | '[underline]C[/]opy/[underline]n[/]o' if settings.sandbox_mode else '[underline]Y[/]es/[underline]n[/]o/[underline]c[/]opy' 58 | }{ 59 | '/[underline]e[/]xplain' if explain else '' 60 | }{ 61 | '/[underline]a[/]lternatives' if alternatives else '' 62 | }][/]""", 63 | choices=choices, 64 | default="c" if settings.sandbox_mode else "y", 65 | show_default=False, 66 | show_choices=False, 67 | case_sensitive=False, 68 | ).lower()[0] 69 | clear_history() 70 | load_history() 71 | return ret # type: ignore 72 | 73 | 74 | def syntax(text: str, lexer: str = "console"): 75 | return Syntax( 76 | text, 77 | lexer, 78 | word_wrap=True, 79 | background_color="default", 80 | ) 81 | 82 | 83 | def decode_output(process): 84 | try: 85 | response = process.stdout.decode() or process.stderr.decode() 86 | except UnicodeDecodeError: 87 | # windows 88 | import ctypes 89 | 90 | oemCP = ctypes.windll.kernel32.GetConsoleOutputCP() 91 | encoding = "cp" + str(oemCP) 92 | response = process.stdout.decode(encoding) or process.stderr.decode(encoding) 93 | return response 94 | 95 | 96 | def decode_output2(text: bytes): 97 | try: 98 | response = text.decode() 99 | except UnicodeDecodeError: 100 | # windows 101 | import ctypes 102 | 103 | oemCP = ctypes.windll.kernel32.GetConsoleOutputCP() 104 | encoding = "cp" + str(oemCP) 105 | response = text.decode(encoding) 106 | return response 107 | 108 | 109 | def run_shell(cmd: str): 110 | if cmd == "history" or cmd.startswith("history "): 111 | return get_shell_history() 112 | result = subprocess.run( 113 | cmd, 114 | stdout=subprocess.PIPE, 115 | stderr=subprocess.PIPE, 116 | shell=True, 117 | ) 118 | return decode_output(result) 119 | 120 | 121 | def run_pty(cmd: str): 122 | import pty 123 | 124 | if cmd == "history" or cmd.startswith("history "): 125 | return get_shell_history() 126 | stdout = b"" 127 | 128 | def read(fd): 129 | nonlocal stdout 130 | ret = os.read(fd, 1024) 131 | stdout += ret 132 | return ret 133 | 134 | ret_code = pty.spawn([detect_raw_shell(), "-c", cmd], read) 135 | return ret_code, decode_output2(stdout) 136 | 137 | 138 | def stream_shell(cmd: str): 139 | if cmd == "history" or cmd.startswith("history "): 140 | return get_shell_history() 141 | result = subprocess.Popen( 142 | cmd, 143 | stdout=subprocess.PIPE, 144 | stderr=subprocess.PIPE, 145 | shell=True, 146 | ) 147 | while result.poll() is None: 148 | chunk = b"" 149 | if result.stdout is not None and result.stdout.readable(): 150 | chunk += result.stdout.read(1) 151 | yield decode_output2(chunk) 152 | remaining = b"" 153 | if result.stdout is not None and result.stdout.readable(): 154 | remaining += result.stdout.read() 155 | if result.stderr is not None and result.stderr.readable(): 156 | remaining += result.stderr.read() 157 | 158 | if remaining: 159 | yield decode_output2(remaining) 160 | 161 | 162 | def run_command(cmd: str): 163 | if detect_shell() in ["powershell", "cmd"]: 164 | result = "" 165 | for chunk in stream_shell(cmd): 166 | print(chunk, end="", flush=True) 167 | result += chunk 168 | result = result or "Exit code: 0" 169 | else: 170 | ret_code, result = run_pty(cmd) 171 | result = result or f"Exit code: {ret_code}" 172 | return result 173 | 174 | 175 | def run_python(file: str): 176 | return run_command(f"python {file}") 177 | 178 | 179 | def detect_raw_shell(): 180 | return os.environ.get("SHELL") or os.environ.get("COMSPEC") or "sh" 181 | 182 | 183 | def detect_shell(): 184 | shell = detect_raw_shell() 185 | shell = shell.lower().split("/")[-1] 186 | 
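# e.g. "/usr/bin/zsh" -> "zsh"; a Windows COMSPEC like "C:\WINDOWS\system32\cmd.exe" lowercases without splitting on backslashes but still contains "cmd", so the branch below catches it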
if "powershell" in shell: 187 | return "powershell" 188 | elif "cmd" in shell: 189 | return "cmd" 190 | 191 | return shell 192 | 193 | 194 | def detect_os(): 195 | system = platform.system() or "linux" 196 | if system.lower() == "darwin": 197 | return "macos" 198 | return system 199 | 200 | 201 | def count_tokens( 202 | messages: list, encoding_name: str = "o200k_base", offset: int = 2000 203 | ) -> int: 204 | text = "\n".join(msg.content for msg in messages) 205 | encoding = get_encoding(encoding_name) 206 | return len(encoding.encode(text)) + offset 207 | 208 | 209 | def tools_to_human(messages): 210 | return [ 211 | ( 212 | HumanMessage(msg.content) 213 | if isinstance(msg, ToolMessage) 214 | else ( 215 | AIMessage(msg.content or "tool_request") 216 | if isinstance(msg, AIMessage) 217 | else msg 218 | ) 219 | ) 220 | for msg in messages 221 | ] 222 | 223 | 224 | HISTORY_FILES = { 225 | "bash": ".bash_history", 226 | "sh": ".bash_history", 227 | "zsh": ".zsh_history", 228 | "fish": ".local/share/fish/fish_history", 229 | "ksh": ".ksh_history", 230 | "tcsh": ".history", 231 | } 232 | 233 | 234 | def get_shell_history(): 235 | try: 236 | shell = detect_shell() 237 | history_file = HISTORY_FILES[shell] 238 | with open(os.path.expanduser(f"~/{history_file}"), "r") as f: 239 | history = f.read() 240 | return "\n".join( 241 | [ 242 | cmd 243 | for cmd in history.strip().split("\n")[:-1] 244 | if cmd != "shy" 245 | and not cmd.startswith("shy ") 246 | and ";shy " not in cmd 247 | and not cmd.endswith(";shy") 248 | ][-5:] 249 | ) 250 | except Exception: 251 | return "I can't get the history for this shell" 252 | 253 | 254 | def parse_code(code): 255 | code = re.sub(r"```[^\n]*\n", "", code) 256 | return code[: code.rfind("```")] 257 | 258 | 259 | def to_local(date: datetime): 260 | try: 261 | import tzlocal 262 | 263 | local_tz = tzlocal.get_localzone() 264 | return date.replace(tzinfo=timezone.utc).astimezone(local_tz) 265 | except Exception: 266 | return date 267 | 268 | 269 | SUGGESTIONS = ["/chats", "/clear", "/history", "/load ", "quit"] 270 | 271 | 272 | def command_completer(text, state): 273 | matches = [s for s in SUGGESTIONS if s.startswith(text)] 274 | return matches[state] if state < len(matches) else None 275 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from typer import Typer 3 | from typer.testing import CliRunner 4 | from shy_sh.main import exec as main 5 | from tests.utils import mock_settings 6 | 7 | 8 | def pytest_addoption(parser): 9 | parser.addoption( 10 | "--eval", action="store_true", default=False, help="run slow tests" 11 | ) 12 | 13 | 14 | def pytest_configure(config): 15 | config.addinivalue_line("markers", "slow: mark test as slow to run") 16 | 17 | 18 | def pytest_collection_modifyitems(config, items): 19 | is_eval = config.getoption("--eval") 20 | skip_eval = pytest.mark.skip(reason="need --eval option to run") 21 | for item in items: 22 | if ("eval" in item.keywords and not is_eval) or ( 23 | "eval" not in item.keywords and is_eval 24 | ): 25 | item.add_marker(skip_eval) 26 | 27 | 28 | @pytest.fixture(autouse=True) 29 | def mock_app_settings(request): 30 | if "eval" in request.keywords: 31 | return 32 | mock_settings( 33 | { 34 | "language": "esperanto", 35 | "llm": { 36 | "agent_pattern": "react", 37 | "api_key": "xxx", 38 | "name": "test", 39 | "provider": "ollama", 40 | "temperature": 1.0, 41 | }, 42 | 
} 43 | ) 44 | 45 | 46 | @pytest.fixture(autouse=True) 47 | def mock_readline(mocker): 48 | mocker.patch("readline.set_history_length") 49 | mocker.patch("readline.read_history_file") 50 | mocker.patch("readline.write_history_file") 51 | 52 | 53 | @pytest.fixture() 54 | def exec(): 55 | runner = CliRunner() 56 | app = Typer() 57 | app.command()(main) 58 | 59 | def invoke(cmd, **kwargs): 60 | return runner.invoke(app, cmd.split(" "), **kwargs) 61 | 62 | return invoke 63 | -------------------------------------------------------------------------------- /tests/test_cli.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | from tests.utils import mock_llm 3 | 4 | 5 | def test_version(exec): 6 | result = exec("--version") 7 | assert result.exit_code == 0 8 | assert result.stdout.startswith("Version: ") 9 | 10 | 11 | def test_question(exec, mocker): 12 | with mock_llm(mocker): 13 | result = exec("-o how are you") 14 | assert result.exit_code == 0 15 | 16 | assert "✨: how are you" in result.stdout 17 | assert "🤖: test" in result.stdout 18 | 19 | 20 | def test_interactive_question(exec, mocker): 21 | with mock_llm(mocker, ["fine thanks"]): 22 | result = exec("how are you", input="exit\n") 23 | assert result.exit_code == 0 24 | 25 | assert "✨: how are you" in result.stdout 26 | assert "🤖: fine thanks" in result.stdout 27 | assert "👋 Bye!" in result.stdout 28 | 29 | 30 | def test_use_shell_tool(exec, mocker): 31 | confirm = mocker.patch("shy_sh.agents.tools.shell.ask_confirm", return_value="y") 32 | with mock_llm( 33 | mocker, 34 | [ 35 | '{"tool": "shell", "arg": "echo fine thanks", "thoughts": "test"}', 36 | "fine thanks", 37 | ], 38 | ): 39 | result = exec("-o how are you") 40 | assert result.exit_code == 0 41 | 42 | assert "✨: how are you" in result.stdout 43 | assert confirm.call_count == 1 44 | assert "🛠️ echo fine thanks" in result.stdout 45 | assert "🤖: fine thanks" in result.stdout 46 | 47 | 48 | def test_use_shell_tool_with_alternatives(exec, mocker): 49 | confirm = mocker.patch("shy_sh.agents.tools.shell.ask_confirm", return_value="a") 50 | select = mocker.patch( 51 | "shy_sh.agents.tools.shell._select_alternative_command", 52 | return_value="echo fine thanks", 53 | ) 54 | with mock_llm( 55 | mocker, 56 | [ 57 | '{"tool": "shell", "arg": "echo fine thanks", "thoughts": "test"}', 58 | "fine thanks", 59 | ], 60 | ): 61 | result = exec("-o how are you") 62 | assert result.exit_code == 0 63 | 64 | assert "✨: how are you" in result.stdout 65 | assert confirm.call_count == 1 66 | assert select.call_count == 1 67 | assert "🛠️ echo fine thanks" in result.stdout 68 | assert "🤖: fine thanks" in result.stdout 69 | 70 | 71 | def test_use_shell_tool_with_explain(exec, mocker): 72 | confirm = mocker.patch( 73 | "shy_sh.agents.tools.shell.ask_confirm", side_effect=["e", "y"] 74 | ) 75 | with mock_llm( 76 | mocker, 77 | [ 78 | '{"tool": "shell", "arg": "echo fine thanks", "thoughts": "test"}', 79 | "fine thanks", 80 | ], 81 | ): 82 | result = exec("-o how are you") 83 | assert result.exit_code == 0 84 | 85 | assert "✨: how are you" in result.stdout 86 | # assert confirm.call_count == 2 FIXME 87 | assert "🛠️ echo fine thanks" in result.stdout 88 | assert "🤖: fine thanks" in result.stdout 89 | 90 | 91 | def test_use_shell_tool_no_confirmation(exec, mocker): 92 | with mock_llm( 93 | mocker, 94 | [ 95 | '{"tool": "shell", "arg": "echo fine thanks", "thoughts": "test"}', 96 | "fine thanks", 97 | ], 98 | ): 99 | result = exec("-o -x how are you") 100 | assert 
result.exit_code == 0 101 | 102 | assert "✨: how are you" in result.stdout 103 | assert ( 104 | "Do you want to execute this command? [Yes/no/copy/explain]" 105 | not in result.stdout 106 | ) 107 | assert "🛠️ echo fine thanks" in result.stdout 108 | assert "🤖: fine thanks" in result.stdout 109 | 110 | 111 | def test_use_shell_expert_tool(exec, mocker): 112 | confirm = mocker.patch( 113 | "shy_sh.agents.tools.shell_expert.ask_confirm", return_value="y" 114 | ) 115 | with mock_llm( 116 | mocker, 117 | [ 118 | '{"tool": "shell_expert", "arg": "say how are you", "thoughts": "test"}', 119 | "```sh\necho fine thanks\n```", 120 | "fine thanks", 121 | ], 122 | ): 123 | result = exec("-o how are you") 124 | assert result.exit_code == 0 125 | 126 | assert "✨: how are you" in result.stdout 127 | assert "💻 Generating shell script..." in result.stdout 128 | assert confirm.call_count == 1 129 | assert "echo fine thanks" in result.stdout 130 | assert "🤖: fine thanks" in result.stdout 131 | 132 | 133 | def test_use_shell_expert_tool_no_confirmation(exec, mocker): 134 | confirm = mocker.patch( 135 | "shy_sh.agents.tools.shell_expert.ask_confirm", return_value="y" 136 | ) 137 | with mock_llm( 138 | mocker, 139 | [ 140 | '{"tool": "shell_expert", "arg": "say how are you", "thoughts": "test"}', 141 | "```sh\necho fine thanks\n```", 142 | "fine thanks", 143 | ], 144 | ): 145 | result = exec("-o -x how are you") 146 | assert result.exit_code == 0 147 | 148 | assert "✨: how are you" in result.stdout 149 | assert "💻 Generating shell script..." in result.stdout 150 | assert confirm.call_count == 0 151 | assert "echo fine thanks" in result.stdout 152 | assert "🤖: fine thanks" in result.stdout 153 | 154 | 155 | def test_use_python_expert_tool(exec, mocker): 156 | confirm = mocker.patch( 157 | "shy_sh.agents.tools.python_expert.ask_confirm", return_value="y" 158 | ) 159 | with mock_llm( 160 | mocker, 161 | [ 162 | '{"tool": "python_expert", "arg": "say how are you", "thoughts": "test"}', 163 | "```python\nprint('fine thanks')\n```", 164 | "fine thanks", 165 | ], 166 | ): 167 | result = exec("-o how are you") 168 | assert result.exit_code == 0 169 | 170 | assert "✨: how are you" in result.stdout 171 | assert "🐍 Generating python script..." in result.stdout 172 | assert confirm.call_count == 1 173 | assert "print('fine thanks')" in result.stdout 174 | assert "🤖: fine thanks" in result.stdout 175 | 176 | 177 | def test_use_python_expert_tool_no_confirmation(exec, mocker): 178 | confirm = mocker.patch( 179 | "shy_sh.agents.tools.python_expert.ask_confirm", return_value="y" 180 | ) 181 | with mock_llm( 182 | mocker, 183 | [ 184 | '{"tool": "python_expert", "arg": "say how are you", "thoughts": "test"}', 185 | "```python\nprint('fine thanks')\n```", 186 | "fine thanks", 187 | ], 188 | ): 189 | result = exec("-o -x how are you") 190 | assert result.exit_code == 0 191 | 192 | assert "✨: how are you" in result.stdout 193 | assert "🐍 Generating python script..." in result.stdout 194 | assert confirm.call_count == 0 195 | assert "print('fine thanks')" in result.stdout 196 | assert "🤖: fine thanks" in result.stdout 197 | 198 | 199 | def test_use_shell_history_tool(exec, mocker): 200 | with mock_llm( 201 | mocker, 202 | [ 203 | '{"tool": "shell_history", "arg": "", "thoughts": "test"}', 204 | "history readed", 205 | ], 206 | ): 207 | result = exec("-o how are you") 208 | assert result.exit_code == 0 209 | 210 | assert "✨: how are you" in result.stdout 211 | assert "📜 Let me check..." 
in result.stdout 212 | assert "These are the last commands executed by the user:" not in result.stdout 213 | assert "🤖: history readed" in result.stdout 214 | 215 | 216 | def test_use_tool_chain(exec, mocker): 217 | with mock_llm( 218 | mocker, 219 | [ 220 | '{"tool": "shell", "arg": "echo fine thanks", "thoughts": "test"}', 221 | '{"tool": "shell_expert", "arg": "say how are you with shell", "thoughts": "test"}', 222 | "```sh\necho fine thanks\n```", 223 | '{"tool": "python_expert", "arg": "say how are you", "thoughts": "test"}', 224 | "```python\nprint('fine thanks')\n```", 225 | "fine thanks", 226 | ], 227 | ): 228 | result = exec("-o -x how are you") 229 | assert result.exit_code == 0 230 | 231 | assert "✨: how are you" in result.stdout 232 | assert "🛠️ echo fine thanks" in result.stdout 233 | assert "💻 Generating shell script..." in result.stdout 234 | assert "🐍 Generating python script..." in result.stdout 235 | assert "🤖: fine thanks" in result.stdout 236 | -------------------------------------------------------------------------------- /tests/test_eval.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import warnings 3 | from langchain_core.messages import HumanMessage, AIMessage, ToolMessage 4 | from shy_sh.agents.misc import get_graph_inputs, run_few_shot_examples 5 | from langgraph.graph import StateGraph, START, END 6 | from shy_sh.agents.shy_agent.nodes.chatbot import chatbot 7 | from shy_sh.agents.shy_agent.nodes.tools_handler import _get_tool_calls 8 | from shy_sh.agents.shy_agent.edges.tool_calls import tool_calls_edge 9 | from shy_sh.models import State 10 | 11 | 12 | def tools_handler(state: State): 13 | last_message = state["tool_history"][-1] 14 | t_calls = _get_tool_calls(last_message) 15 | return { 16 | "tool_history": [ 17 | ToolMessage(content=tc["args"]["arg"], tool_call_id=tc["id"]) 18 | for tc in t_calls 19 | ] 20 | } 21 | 22 | 23 | @pytest.fixture 24 | def graph(): 25 | graph_builder = StateGraph(State) 26 | graph_builder.add_node("chatbot", chatbot) 27 | graph_builder.add_node("tools", tools_handler) 28 | 29 | graph_builder.add_edge(START, "chatbot") 30 | graph_builder.add_conditional_edges("chatbot", tool_calls_edge) 31 | graph_builder.add_conditional_edges("tools", lambda _: END) 32 | 33 | return graph_builder.compile() 34 | 35 | 36 | @pytest.fixture 37 | def run_graph(graph): 38 | def _run(task): 39 | inputs = get_graph_inputs( 40 | history=[HumanMessage(content=task)], 41 | examples=run_few_shot_examples(), 42 | ask_before_execute=False, 43 | ) 44 | return graph.invoke(inputs)["tool_history"] 45 | 46 | return _run 47 | 48 | 49 | @pytest.mark.eval 50 | def test_list_files(run_graph): 51 | history = run_graph("list files") 52 | tool_msg = [x for x in history if isinstance(x, ToolMessage)][0] 53 | assert [ 54 | True 55 | for x in history 56 | if isinstance(x, AIMessage) and '{"tool": "shell"' in x.content 57 | ] 58 | assert tool_msg.content.startswith("ls") 59 | 60 | 61 | @pytest.mark.eval 62 | def test_convert_image(run_graph): 63 | history = run_graph("convert test.png to jpg") 64 | tool_msg = [x for x in history if isinstance(x, ToolMessage)][0] 65 | assert [ 66 | True 67 | for x in history 68 | if isinstance(x, AIMessage) and '{"tool": "shell"' in x.content 69 | ] 70 | assert ( 71 | tool_msg.content.startswith("convert") 72 | or tool_msg.content.startswith("magick") 73 | or tool_msg.content.startswith("mogrify") 74 | or tool_msg.content.startswith("sips") 75 | ) 76 | 77 | 78 | @pytest.mark.eval 79 | 
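# like the other eval-marked tests, this one drives the real graph against a live LLM; conftest.pytest_collection_modifyitems skips it unless pytest runs with --eval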
def test_find_files(run_graph): 80 | history = run_graph("find all python files") 81 | tool_msg = [x for x in history if isinstance(x, ToolMessage)][0] 82 | assert [ 83 | True 84 | for x in history 85 | if isinstance(x, AIMessage) and '{"tool": "shell"' in x.content 86 | ] 87 | assert tool_msg.content.startswith("find") 88 | 89 | 90 | @pytest.mark.eval 91 | def test_terraform_base(run_graph): 92 | history = run_graph( 93 | "give me the command to apply the terraform configuration using the local.tfvars file" 94 | ) 95 | tool_msg = [x for x in history if isinstance(x, ToolMessage)][0] 96 | assert [ 97 | True 98 | for x in history 99 | if isinstance(x, AIMessage) and '{"tool": "shell"' in x.content 100 | ] 101 | assert "terraform apply -var-file=local.tfvars" in tool_msg.content 102 | 103 | 104 | @pytest.mark.eval 105 | def test_terraform_advanced(run_graph): 106 | history = run_graph( 107 | "give me the terraform command to import the file base.json to the current state" 108 | ) 109 | tool_msg = [x for x in history if isinstance(x, ToolMessage)][0] 110 | assert [ 111 | True 112 | for x in history 113 | if isinstance(x, AIMessage) and '{"tool": "shell"' in x.content 114 | ] 115 | assert ( 116 | "terraform state" in tool_msg.content or "terraform import" in tool_msg.content 117 | ) 118 | if "base.json" not in tool_msg.content: 119 | warnings.warn( 120 | f"base.json not found in terraform command, probably a wrong command\n`{tool_msg.content}`" 121 | ) 122 | 123 | 124 | @pytest.mark.eval 125 | def test_git_diff(run_graph): 126 | history = run_graph("show me the git diff of the last commit") 127 | tool_msg = [x for x in history if isinstance(x, ToolMessage)][0] 128 | assert "git diff" in tool_msg.content 129 | assert [ 130 | True 131 | for x in history 132 | if isinstance(x, AIMessage) and '{"tool": "shell"' in x.content 133 | ] 134 | 135 | 136 | @pytest.mark.eval 137 | def test_git_tag(run_graph): 138 | history = run_graph("tag with 1.0.0 and push it to remote, all in one command") 139 | tool_msg = [x for x in history if isinstance(x, ToolMessage)][0] 140 | assert [ 141 | True 142 | for x in history 143 | if isinstance(x, AIMessage) and '{"tool": "shell"' in x.content 144 | ] 145 | assert "git tag" in tool_msg.content 146 | assert "1.0.0" in tool_msg.content 147 | assert "git push origin" in tool_msg.content 148 | 149 | 150 | @pytest.mark.eval 151 | def test_python_expert_call(run_graph): 152 | history = run_graph("write a python script to say hello") 153 | assert [ 154 | True 155 | for x in history 156 | if isinstance(x, AIMessage) and '{"tool": "python_expert"' in x.content 157 | ] 158 | 159 | 160 | @pytest.mark.eval 161 | def test_shell_expert_call(run_graph): 162 | history = run_graph( 163 | "write a shell script to write 5 files with a small story inside" 164 | ) 165 | assert [ 166 | True 167 | for x in history 168 | if isinstance(x, AIMessage) and '{"tool": "shell_expert"' in x.content 169 | ] 170 | 171 | 172 | @pytest.mark.eval 173 | def test_shell_history_call(run_graph): 174 | history = run_graph("show me the last 5 commands I ran") 175 | assert [ 176 | True 177 | for x in history 178 | if isinstance(x, AIMessage) and '{"tool": "shell_history"' in x.content 179 | ] 180 | -------------------------------------------------------------------------------- /tests/test_eval_experts.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from langchain_core.messages import HumanMessage, AIMessage 3 | from shy_sh.agents.chains.shell_expert 
import shexpert_chain 4 | from shy_sh.agents.chains.python_expert import pyexpert_chain 5 | from shy_sh.agents.chains.alternative_commands import get_alternative_commands 6 | from shy_sh.utils import parse_code 7 | 8 | 9 | @pytest.mark.eval 10 | def test_shell_expert_base(): 11 | code = shexpert_chain.invoke( 12 | { 13 | "input": "list files", 14 | "timestamp": "2022-01-01", 15 | "system": "linux", 16 | "shell": "bash", 17 | "history": [], 18 | } 19 | ) 20 | code = parse_code(code) 21 | assert code.startswith("#!/") 22 | assert "set -e" in code 23 | assert "ls" in code 24 | 25 | 26 | @pytest.mark.eval 27 | def test_shell_expert_with_history(): 28 | code = shexpert_chain.invoke( 29 | { 30 | "input": "create a script that works", 31 | "timestamp": "2022-01-01", 32 | "system": "linux", 33 | "shell": "bash", 34 | "history": [ 35 | HumanMessage(content="list files"), 36 | AIMessage(content="""{"tool": "shell", "args": {"arg": "dir"}}"""), 37 | HumanMessage(content="I'm on linux"), 38 | ], 39 | } 40 | ) 41 | code = parse_code(code) 42 | assert code.startswith("#!/") 43 | assert "ls" in code 44 | 45 | 46 | @pytest.mark.eval 47 | def test_python_expert_base(): 48 | code = pyexpert_chain.invoke( 49 | { 50 | "input": "write a script that prints the current date and time plus the word 'hello'", 51 | "timestamp": "2022-01-01", 52 | "history": [], 53 | } 54 | ) 55 | code = parse_code(code) 56 | assert "print(" in code 57 | assert "datetime" in code 58 | assert "now()" in code 59 | 60 | 61 | @pytest.mark.eval 62 | def test_alternative_commands(): 63 | result = get_alternative_commands( 64 | { 65 | "cmd": "ls *.py", 66 | "timestamp": "2022-01-01", 67 | "system": "linux", 68 | "shell": "bash", 69 | "history": [HumanMessage(content="find all python files")], 70 | } 71 | ) 72 | cmds = [r[1] for r in result] 73 | assert [True for c in cmds if "find" in c] 74 | assert [True for c in cmds if "ls" in c] 75 | 76 | 77 | @pytest.mark.eval 78 | def test_alternative_commands_powershell(): 79 | result = get_alternative_commands( 80 | { 81 | "cmd": "ls *.py", 82 | "timestamp": "2022-01-01", 83 | "system": "windows", 84 | "shell": "powershell", 85 | "history": [HumanMessage(content="find all python files")], 86 | } 87 | ) 88 | 89 | assert [True for r in result if "Get-ChildItem" in r[1]] 90 | -------------------------------------------------------------------------------- /tests/utils.py: -------------------------------------------------------------------------------- 1 | from shy_sh.settings import settings, _Settings 2 | from shy_sh.agents.llms import get_llm 3 | from langchain_community.llms.fake import FakeListLLM 4 | from langchain_core.messages import AIMessage 5 | from contextlib import contextmanager 6 | 7 | 8 | def mock_settings(config): 9 | config = _Settings.model_validate(config) 10 | for key in config.model_dump().keys(): 11 | setattr(settings, key, getattr(config, key)) 12 | 13 | 14 | @contextmanager 15 | def mock_llm(mocker, responses=["test", "test2", "test3", "test4"]): 16 | def to_ai_message(x): 17 | return AIMessage(content=x) 18 | 19 | llm = FakeListLLM(responses=responses) | to_ai_message 20 | mocker.patch("shy_sh.agents.llms._get_llm", return_value=llm) 21 | yield llm 22 | get_llm.cache_clear() 23 | --------------------------------------------------------------------------------
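The test helpers above compose naturally when adding new cases. A minimal sketch of a new CLI test, assuming the same react agent pattern the conftest fixture configures; the test name and the task string are invented for illustration, while mock_llm and the exec fixture are the real helpers defined in tests/utils.py and tests/conftest.py:

from tests.utils import mock_llm


def test_echo_roundtrip(exec, mocker):
    # queue two fake LLM turns: a react-style JSON tool call, then the final answer
    with mock_llm(
        mocker,
        [
            '{"tool": "shell", "arg": "echo hi", "thoughts": "demo"}',
            "hi",
        ],
    ):
        result = exec("-o -x say hi")  # -o: one shot, -x: skip the confirmation prompt
    assert result.exit_code == 0
    assert "🤖: hi" in result.stdout

And a similar sketch of the persistence layer (sql_models.session plus ChatManager); running it would write to the real ~/.config/shy/shy.db, so treat it as illustration only:

from shy_sh.manager.sql_models import session, ScriptType
from shy_sh.manager.chat_manager import ChatManager

with session() as db:  # commits on success, rolls back on any exception
    manager = ChatManager(db)
    manager.save_chat(title="demo", messages=[], meta={"provider": "ollama", "name": "llama3.2"})
    # rows unpack to (script, run_count, script_type, last_run_at), grouped by script text
    for script, count, kind, last_run in manager.get_recent_scripts(ScriptType.SHELL, limit=5):
        print(f"[{kind}] x{count} {script!r}")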