├── .fetch.config ├── .github └── workflows │ ├── ci.yml │ ├── release.yml │ └── unstable.yml ├── .gitignore ├── .ovh.config ├── .shai.config ├── AUTHORS ├── CONTRIBUTING.md ├── CONTRIBUTORS ├── Cargo.lock ├── Cargo.toml ├── LICENSE ├── MAINTAINERS ├── README.md ├── SHAI.md ├── docs └── assets │ ├── auth.gif │ ├── shai-chain.gif │ ├── shai-headless.gif │ ├── shai-hello-world.gif │ ├── shai-http.png │ ├── shai-shell.png │ ├── shai-trace.gif │ └── shai.png ├── install.sh ├── shai-cli ├── Cargo.toml └── src │ ├── fc │ ├── client.rs │ ├── history.rs │ ├── mod.rs │ ├── protocol.rs │ ├── server.rs │ └── tests.rs │ ├── headless │ ├── app.rs │ ├── mod.rs │ └── tools.rs │ ├── main.rs │ ├── shell │ ├── mod.rs │ ├── pty.rs │ ├── rc.rs │ └── terminal.rs │ └── tui │ ├── app.rs │ ├── auth │ ├── auth.rs │ ├── config_env.rs │ ├── config_list.rs │ ├── config_model.rs │ ├── config_providers.rs │ └── mod.rs │ ├── cmdnav.rs │ ├── command.rs │ ├── helper.rs │ ├── input.rs │ ├── mod.rs │ ├── perm.rs │ ├── perm_alt_screen.rs │ └── theme.rs ├── shai-core ├── Cargo.lock ├── Cargo.toml ├── examples │ └── oauth_test.rs └── src │ ├── agent │ ├── README.md │ ├── actions │ │ ├── brain.rs │ │ ├── mod.rs │ │ └── tools.rs │ ├── agent.rs │ ├── brain.rs │ ├── builder.rs │ ├── claims.rs │ ├── error.rs │ ├── events.rs │ ├── mod.rs │ ├── output │ │ ├── log.rs │ │ ├── mod.rs │ │ ├── pretty.rs │ │ └── stdout.rs │ ├── protocol.rs │ ├── states │ │ ├── README.md │ │ ├── mod.rs │ │ ├── pause.rs │ │ ├── processing.rs │ │ ├── running.rs │ │ ├── starting.rs │ │ ├── states.rs │ │ └── terminal.rs │ └── tests.rs │ ├── config │ ├── agent.rs │ ├── config.rs │ └── mod.rs │ ├── lib.rs │ ├── logging.rs │ ├── runners │ ├── clifixer │ │ ├── fix.rs │ │ ├── mod.rs │ │ └── prompt.rs │ ├── coder │ │ ├── coder.rs │ │ ├── env.rs │ │ ├── mod.rs │ │ ├── prompt.rs │ │ └── tests.rs │ ├── compacter │ │ ├── compact.rs │ │ └── mod.rs │ ├── gerund │ │ ├── gerund.rs │ │ ├── mod.rs │ │ ├── prompt.rs │ │ └── tests.rs │ ├── mod.rs │ └── searcher │ │ ├── mod.rs │ │ ├── prompt.rs │ │ ├── searcher.rs │ │ └── tests.rs │ └── tools │ ├── bash │ ├── bash.rs │ ├── mod.rs │ ├── structs.rs │ └── tests.rs │ ├── fetch │ ├── fetch.rs │ ├── mod.rs │ ├── structs.rs │ └── tests.rs │ ├── fs │ ├── edit │ │ ├── edit.rs │ │ ├── mod.rs │ │ ├── structs.rs │ │ └── tests.rs │ ├── find │ │ ├── find.rs │ │ ├── mod.rs │ │ ├── structs.rs │ │ └── tests.rs │ ├── ls │ │ ├── ls.rs │ │ ├── mod.rs │ │ ├── structs.rs │ │ └── tests.rs │ ├── mod.rs │ ├── multiedit │ │ ├── mod.rs │ │ ├── multiedit.rs │ │ ├── structs.rs │ │ └── tests.rs │ ├── operation_log.rs │ ├── read │ │ ├── mod.rs │ │ ├── read.rs │ │ ├── structs.rs │ │ └── tests.rs │ ├── tests.rs │ └── write │ │ ├── mod.rs │ │ ├── structs.rs │ │ ├── tests.rs │ │ └── write.rs │ ├── highlight.rs │ ├── mcp │ ├── mcp.rs │ ├── mcp_config.rs │ ├── mcp_http.rs │ ├── mcp_oauth.rs │ ├── mcp_sse.rs │ ├── mcp_stdio.rs │ ├── mod.rs │ └── tests.rs │ ├── mod.rs │ ├── tests_llm.rs │ ├── todo │ ├── mod.rs │ ├── structs.rs │ ├── tests.rs │ └── todo.rs │ └── types.rs ├── shai-http ├── Cargo.toml └── src │ ├── apis │ ├── mod.rs │ ├── openai │ │ ├── completion │ │ │ ├── formatter.rs │ │ │ ├── handler.rs │ │ │ └── mod.rs │ │ ├── mod.rs │ │ └── response │ │ │ ├── formatter.rs │ │ │ ├── handler.rs │ │ │ ├── mod.rs │ │ │ └── types.rs │ └── simple │ │ ├── formatter.rs │ │ ├── handler.rs │ │ ├── mod.rs │ │ └── types.rs │ ├── error.rs │ ├── http.rs │ ├── lib.rs │ ├── session │ ├── lifecycle.rs │ ├── logger.rs │ ├── manager.rs │ ├── mod.rs │ └── session.rs │ └── streaming.rs 
├── shai-llm ├── Cargo.toml └── src │ ├── chat.rs │ ├── client.rs │ ├── examples │ ├── basic_query.rs │ ├── function_calling.rs │ ├── function_calling_streaming.rs │ ├── mod.rs │ ├── query_with_history.rs │ └── streaming_query.rs │ ├── lib.rs │ ├── provider.rs │ ├── providers │ ├── anthropic │ │ ├── anthropic.rs │ │ ├── api.rs │ │ ├── mod.rs │ │ └── tests.rs │ ├── mistral.rs │ ├── mod.rs │ ├── ollama.rs │ ├── openai.rs │ ├── openai_compatible.rs │ ├── openrouter │ │ ├── api.rs │ │ ├── mod.rs │ │ └── openrouter.rs │ ├── ovhcloud.rs │ └── tests.rs │ └── tool │ ├── call.rs │ ├── call_fc_auto.rs │ ├── call_fc_required.rs │ ├── call_structured_output.rs │ ├── mod.rs │ ├── test_so.rs │ └── tool.rs └── shai-macros ├── Cargo.toml └── src └── lib.rs /.fetch.config: -------------------------------------------------------------------------------- 1 | { 2 | "name": "example", 3 | "description": "Example agent with MCP fetch server for web scraping and API calls", 4 | "llm_provider": { 5 | "provider": "ovhcloud", 6 | "env_vars": { 7 | "OVH_BASE_URL": "https://gpt-oss-120b.endpoints.kepler.ai.cloud.ovh.net/api/openai_compat/v1" 8 | }, 9 | "model": "gpt-oss-120b", 10 | "tool_method": "FunctionCall" 11 | }, 12 | "tools": { 13 | "builtin": ["*"], 14 | "builtin_excluded": ["fetch"], 15 | "mcp": { 16 | "fetch": { 17 | "config": { 18 | "type": "stdio", 19 | "command": "uvx", 20 | "args": ["mcp-server-fetch"] 21 | }, 22 | "enabled_tools": ["*"] 23 | } 24 | } 25 | }, 26 | "temperature": 0.3 27 | } 28 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI Build 2 | 3 | on: 4 | push: 5 | branches: [ dev/*, feature/*, fix/* ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | jobs: 10 | build: 11 | name: Build for ${{ matrix.os }} 12 | runs-on: ${{ matrix.os }} 13 | strategy: 14 | matrix: 15 | include: 16 | - os: ubuntu-latest 17 | target: x86_64-unknown-linux-musl 18 | binary_name: shai 19 | asset_name: shai-linux-x86_64 20 | # - os: windows-latest 21 | # target: x86_64-pc-windows-msvc 22 | # binary_name: shai.exe 23 | # asset_name: shai-windows-x86_64.exe 24 | - os: macos-latest 25 | target: x86_64-apple-darwin 26 | binary_name: shai 27 | asset_name: shai-macos-x86_64 28 | - os: macos-latest 29 | target: aarch64-apple-darwin 30 | binary_name: shai 31 | asset_name: shai-macos-aarch64 32 | 33 | steps: 34 | - name: Checkout code 35 | uses: actions/checkout@v4 36 | 37 | - name: Install Rust 38 | uses: dtolnay/rust-toolchain@stable 39 | with: 40 | targets: ${{ matrix.target }} 41 | 42 | - name: Install dependencies (Linux) 43 | if: matrix.os == 'ubuntu-latest' 44 | run: | 45 | sudo apt-get update 46 | sudo apt-get install -y build-essential musl-tools 47 | 48 | - name: Add musl target (Linux) 49 | if: matrix.target == 'x86_64-unknown-linux-musl' 50 | run: rustup target add x86_64-unknown-linux-musl 51 | 52 | - name: Cache cargo registry 53 | uses: actions/cache@v4 54 | with: 55 | path: ~/.cargo/registry 56 | key: ${{ runner.os }}-cargo-registry-${{ hashFiles('**/Cargo.lock') }} 57 | restore-keys: | 58 | ${{ runner.os }}-cargo-registry- 59 | 60 | - name: Cache cargo index 61 | uses: actions/cache@v4 62 | with: 63 | path: ~/.cargo/git 64 | key: ${{ runner.os }}-cargo-index-${{ hashFiles('**/Cargo.lock') }} 65 | restore-keys: | 66 | ${{ runner.os }}-cargo-index- 67 | 68 | - name: Cache target directory 69 | uses: actions/cache@v4 70 | with: 71 | path: target 72 | key: ${{ 
runner.os }}-${{ matrix.target }}-target-${{ hashFiles('**/Cargo.lock') }} 73 | restore-keys: | 74 | ${{ runner.os }}-${{ matrix.target }}-target- 75 | 76 | - name: Build binary 77 | run: cargo build --release --target ${{ matrix.target }} 78 | 79 | - name: Prepare binary (Unix) 80 | if: matrix.os != 'windows-latest' 81 | run: | 82 | cp target/${{ matrix.target }}/release/${{ matrix.binary_name }} ${{ matrix.asset_name }} 83 | strip ${{ matrix.asset_name }} 84 | 85 | - name: Prepare binary (Windows) 86 | if: matrix.os == 'windows-latest' 87 | run: | 88 | cp target/${{ matrix.target }}/release/${{ matrix.binary_name }} ${{ matrix.asset_name }} 89 | 90 | - name: Upload build artifacts 91 | uses: actions/upload-artifact@v4 92 | with: 93 | name: ${{ matrix.asset_name }} 94 | path: ${{ matrix.asset_name }} 95 | retention-days: 10 96 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | release: 5 | types: [created] 6 | 7 | jobs: 8 | build: 9 | name: Build for ${{ matrix.os }} 10 | runs-on: ${{ matrix.os }} 11 | strategy: 12 | matrix: 13 | include: 14 | - os: ubuntu-latest 15 | target: x86_64-unknown-linux-musl 16 | binary_name: shai 17 | asset_name: shai-linux-x86_64 18 | #- os: windows-latest 19 | # target: x86_64-pc-windows-msvc 20 | # binary_name: shai.exe 21 | # asset_name: shai-windows-x86_64.exe 22 | - os: macos-latest 23 | target: x86_64-apple-darwin 24 | binary_name: shai 25 | asset_name: shai-macos-x86_64 26 | - os: macos-latest 27 | target: aarch64-apple-darwin 28 | binary_name: shai 29 | asset_name: shai-macos-aarch64 30 | 31 | steps: 32 | - name: Checkout code 33 | uses: actions/checkout@v4 34 | 35 | - name: Install Rust 36 | uses: dtolnay/rust-toolchain@stable 37 | with: 38 | targets: ${{ matrix.target }} 39 | 40 | - name: Install dependencies (Linux) 41 | if: matrix.os == 'ubuntu-latest' 42 | run: | 43 | sudo apt-get update 44 | sudo apt-get install -y build-essential musl-tools 45 | 46 | - name: Add musl target (Linux) 47 | if: matrix.target == 'x86_64-unknown-linux-musl' 48 | run: rustup target add x86_64-unknown-linux-musl 49 | 50 | - name: Cache cargo registry 51 | uses: actions/cache@v4 52 | with: 53 | path: ~/.cargo/registry 54 | key: ${{ runner.os }}-cargo-registry-${{ hashFiles('**/Cargo.lock') }} 55 | restore-keys: | 56 | ${{ runner.os }}-cargo-registry- 57 | 58 | - name: Cache cargo index 59 | uses: actions/cache@v4 60 | with: 61 | path: ~/.cargo/git 62 | key: ${{ runner.os }}-cargo-index-${{ hashFiles('**/Cargo.lock') }} 63 | restore-keys: | 64 | ${{ runner.os }}-cargo-index- 65 | 66 | - name: Cache target directory 67 | uses: actions/cache@v4 68 | with: 69 | path: target 70 | key: ${{ runner.os }}-${{ matrix.target }}-target-${{ hashFiles('**/Cargo.lock') }} 71 | restore-keys: | 72 | ${{ runner.os }}-${{ matrix.target }}-target- 73 | 74 | - name: Build binary 75 | run: cargo build --release --target ${{ matrix.target }} 76 | 77 | - name: Prepare binary (Unix) 78 | if: matrix.os != 'windows-latest' 79 | run: | 80 | cp target/${{ matrix.target }}/release/${{ matrix.binary_name }} ${{ matrix.asset_name }} 81 | strip ${{ matrix.asset_name }} 82 | 83 | - name: Prepare binary (Windows) 84 | if: matrix.os == 'windows-latest' 85 | run: | 86 | cp target/${{ matrix.target }}/release/${{ matrix.binary_name }} ${{ matrix.asset_name }} 87 | 88 | - name: Upload binary to release 89 | uses: 
actions/upload-release-asset@v1 90 | env: 91 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 92 | with: 93 | upload_url: ${{ github.event.release.upload_url }} 94 | asset_path: ${{ matrix.asset_name }} 95 | asset_name: ${{ matrix.asset_name }} 96 | asset_content_type: application/octet-stream -------------------------------------------------------------------------------- /.github/workflows/unstable.yml: -------------------------------------------------------------------------------- 1 | name: Unstable Build 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | 7 | jobs: 8 | build: 9 | name: Build Unstable for ${{ matrix.os }} 10 | runs-on: ${{ matrix.os }} 11 | strategy: 12 | matrix: 13 | include: 14 | - os: ubuntu-latest 15 | target: x86_64-unknown-linux-musl 16 | binary_name: shai 17 | asset_name: shai-unstable-linux-x86_64 18 | #- os: windows-latest 19 | # target: x86_64-pc-windows-msvc 20 | # binary_name: shai.exe 21 | # asset_name: shai-unstable-windows-x86_64.exe 22 | - os: macos-latest 23 | target: x86_64-apple-darwin 24 | binary_name: shai 25 | asset_name: shai-unstable-macos-x86_64 26 | - os: macos-latest 27 | target: aarch64-apple-darwin 28 | binary_name: shai 29 | asset_name: shai-unstable-macos-aarch64 30 | 31 | steps: 32 | - name: Checkout code 33 | uses: actions/checkout@v4 34 | 35 | - name: Install Rust 36 | uses: dtolnay/rust-toolchain@stable 37 | with: 38 | targets: ${{ matrix.target }} 39 | 40 | - name: Install dependencies (Linux) 41 | if: matrix.os == 'ubuntu-latest' 42 | run: | 43 | sudo apt-get update 44 | sudo apt-get install -y build-essential musl-tools 45 | 46 | - name: Add musl target (Linux) 47 | if: matrix.target == 'x86_64-unknown-linux-musl' 48 | run: rustup target add x86_64-unknown-linux-musl 49 | 50 | - name: Cache cargo registry 51 | uses: actions/cache@v4 52 | with: 53 | path: ~/.cargo/registry 54 | key: ${{ runner.os }}-cargo-registry-${{ hashFiles('**/Cargo.lock') }} 55 | restore-keys: | 56 | ${{ runner.os }}-cargo-registry- 57 | 58 | - name: Cache cargo index 59 | uses: actions/cache@v4 60 | with: 61 | path: ~/.cargo/git 62 | key: ${{ runner.os }}-cargo-index-${{ hashFiles('**/Cargo.lock') }} 63 | restore-keys: | 64 | ${{ runner.os }}-cargo-index- 65 | 66 | - name: Cache target directory 67 | uses: actions/cache@v4 68 | with: 69 | path: target 70 | key: ${{ runner.os }}-${{ matrix.target }}-target-${{ hashFiles('**/Cargo.lock') }} 71 | restore-keys: | 72 | ${{ runner.os }}-${{ matrix.target }}-target- 73 | 74 | - name: Build binary 75 | run: cargo build --release --target ${{ matrix.target }} 76 | 77 | - name: Prepare binary (Unix) 78 | if: matrix.os != 'windows-latest' 79 | run: | 80 | cp target/${{ matrix.target }}/release/${{ matrix.binary_name }} ${{ matrix.asset_name }} 81 | strip ${{ matrix.asset_name }} 82 | 83 | - name: Prepare binary (Windows) 84 | if: matrix.os == 'windows-latest' 85 | run: | 86 | cp target/${{ matrix.target }}/release/${{ matrix.binary_name }} ${{ matrix.asset_name }} 87 | 88 | - name: Upload unstable artifacts 89 | uses: actions/upload-artifact@v4 90 | with: 91 | name: ${{ matrix.asset_name }} 92 | path: ${{ matrix.asset_name }} 93 | retention-days: 30 94 | 95 | create-unstable-release: 96 | needs: build 97 | runs-on: ubuntu-latest 98 | steps: 99 | - name: Checkout code 100 | uses: actions/checkout@v4 101 | 102 | - name: Get current date 103 | id: date 104 | run: echo "date=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT 105 | 106 | - name: Download all artifacts 107 | uses: actions/download-artifact@v4 108 | 109 | - name: Delete existing 
unstable release 110 | continue-on-error: true 111 | run: | 112 | gh release delete unstable --yes 113 | env: 114 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 115 | 116 | - name: Create unstable release 117 | run: | 118 | gh release create unstable \ 119 | --title "Unstable Build (${{ steps.date.outputs.date }})" \ 120 | --notes "Automated unstable build from commit ${{ github.sha }}" \ 121 | --prerelease \ 122 | shai-unstable-linux-x86_64/shai-unstable-linux-x86_64 \ 123 | shai-unstable-macos-x86_64/shai-unstable-macos-x86_64 \ 124 | shai-unstable-macos-aarch64/shai-unstable-macos-aarch64 125 | env: 126 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Rust 2 | /target/ 3 | 4 | # IDE 5 | .vscode/ 6 | .idea/ 7 | *.swp 8 | .DS_Store 9 | .env 10 | 11 | # Logs 12 | *.log 13 | /logs/ 14 | .aider* 15 | .claude/ 16 | 17 | -------------------------------------------------------------------------------- /.ovh.config: -------------------------------------------------------------------------------- 1 | { 2 | "name": "ovh", 3 | "description": "OVH agent with OVH MCP server for cloud management and API calls", 4 | "llm_provider": { 5 | "provider": "ovhcloud", 6 | "env_vars": { 7 | "OVH_BASE_URL": "https://gpt-oss-120b.endpoints.kepler.ai.cloud.ovh.net/api/openai_compat/v1" 8 | }, 9 | "model": "gpt-oss-120b", 10 | "tool_method": "FunctionCall" 11 | }, 12 | "tools": { 13 | "builtin": ["*"], 14 | "mcp": { 15 | "ovh": { 16 | "config": { 17 | "type": "http", 18 | "url": "https://mcp.eu.ovhcloud.com/mcp" 19 | }, 20 | "enabled_tools": ["*"] 21 | } 22 | } 23 | }, 24 | "system_prompt": "{{CODER_BASE_PROMPT}}", 25 | "max_tokens": 4096, 26 | "temperature": 0.3 27 | } -------------------------------------------------------------------------------- /.shai.config: -------------------------------------------------------------------------------- 1 | { 2 | "providers": [ 3 | { 4 | "provider": "ovhcloud", 5 | "env_vars": { 6 | "OVH_BASE_URL": "https://gpt-oss-120b.endpoints.kepler.ai.cloud.ovh.net/api/openai_compat/v1" 7 | }, 8 | "model": "gpt-oss-120b", 9 | "tool_method": "FunctionCall" 10 | } 11 | ], 12 | "selected_provider": 0 13 | } -------------------------------------------------------------------------------- /AUTHORS: -------------------------------------------------------------------------------- 1 | # This is the official list of authors for copyright purposes. 2 | # This file is distinct from the CONTRIBUTORS files 3 | # and it lists the copyright holders only. 4 | 5 | # Names should be added to this file as one of 6 | # Organization's name 7 | # Individual's name 8 | # Individual's name 9 | # See CONTRIBUTORS for the meaning of multiple email addresses. 10 | 11 | # Please keep the list sorted. 12 | 13 | Lucien Loiseau 14 | OVH SAS 15 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to SHAI 2 | 3 | This project accepts contributions. In order to contribute, you should 4 | pay attention to a few things: 5 | 6 | 1. your code must follow the coding style rules 7 | 2. your code must be unit-tested 8 | 3. your code must be documented 9 | 4. your work must be signed (see below) 10 | 5. 
you may contribute through GitHub Pull Requests 11 | 12 | # Release Process 13 | 14 | When creating a new release, make sure to update the version in the following locations: 15 | 16 | 1. **Main CLI version**: Update `version` in `shai-cli/Cargo.toml` 17 | 2. **Core crate version**: Update `version` in `shai-core/Cargo.toml` 18 | 3. **LLM crate version**: Update `version` in `shai-llm/Cargo.toml` 19 | 4. **Macros crate version**: Update `version` in `shai-macros/Cargo.toml` 20 | 21 | The version banner in the CLI logo is automatically generated from the main CLI crate version using `env!("CARGO_PKG_VERSION")`, so no manual update is needed for the display version. 22 | 23 | After updating versions: 24 | 1. Run `cargo check` to ensure everything compiles 25 | 2. Create a git tag with the new version number 26 | 3. Push the tag to trigger the release workflow 27 | 28 | # Submitting Modifications 29 | 30 | Contributions should be submitted through GitHub Pull Requests 31 | and must follow the DCO defined below. 32 | 33 | # Licensing for new files 34 | 35 | SHAI is licensed under the Apache 2.0 license. Anything 36 | contributed to SHAI must be released under this license. 37 | 38 | When introducing a new file into the project, please make sure it has a 39 | copyright header making clear under which license it's being released. 40 | 41 | # Developer Certificate of Origin (DCO) 42 | 43 | To improve tracking of contributions to this project, we use a 44 | process modeled on the modified DCO 1.1 and a "sign-off" procedure 45 | on patches that are emailed around or contributed in any other 46 | way. 47 | 48 | The sign-off is a simple line at the end of the explanation for the 49 | patch, which certifies that you wrote it or otherwise have the right 50 | to pass it on as an open-source patch. The rules are pretty simple: 51 | if you can certify the below: 52 | 53 | By making a contribution to this project, I certify that: 54 | 55 | (a) The contribution was created in whole or in part by me and I have 56 | the right to submit it under the open source license indicated in 57 | the file; or 58 | 59 | (b) The contribution is based upon previous work that, to the best of 60 | my knowledge, is covered under an appropriate open source license 61 | and I have the right under that license to submit that work with 62 | modifications, whether created in whole or in part by me, under 63 | the same open source license (unless I am permitted to submit 64 | under a different license), as indicated in the file; or 65 | 66 | (c) The contribution was provided directly to me by some other person 67 | who certified (a), (b) or (c) and I have not modified it. 68 | 69 | (d) The contribution is made free of any other party's intellectual 70 | property claims or rights. 71 | 72 | (e) I understand and agree that this project and the contribution are 73 | public and that a record of the contribution (including all 74 | personal information I submit with it, including my sign-off) is 75 | maintained indefinitely and may be redistributed consistent with 76 | this project or the open source license(s) involved. 77 | 78 | 79 | then you just add a line saying 80 | 81 | Signed-off-by: Random J Developer 82 | 83 | using your real name (sorry, no pseudonyms or anonymous contributions).
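For convenience, Git can append this trailer automatically. A minimal example (the commit message shown is hypothetical):

```bash
# -s / --signoff appends "Signed-off-by: <user.name> <user.email>"
# taken from your local git configuration.
git commit -s -m "fix: restore terminal state on exit"
```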
-------------------------------------------------------------------------------- /CONTRIBUTORS: -------------------------------------------------------------------------------- 1 | # This is the official list of people who can contribute 2 | # (and typically have contributed) code to the repository. 3 | # 4 | # Names should be added to this file only after verifying that 5 | # the individual or the individual's organization has agreed to 6 | # the appropriate CONTRIBUTING.md file. 7 | # 8 | # Names should be added to this file like so: 9 | # Individual's name 10 | # Individual's name 11 | # 12 | # Please keep the list sorted. 13 | # 14 | 15 | Lucien Loiseau 16 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [workspace] 2 | resolver = "2" 3 | members = ["shai-macros", "shai-llm", "shai-core", "shai-cli", "shai-http"] 4 | 5 | [patch.crates-io] 6 | ratatui = { git = "https://github.com/Marlinski/ratatui", branch = "feature/viewport-resize-v29" } -------------------------------------------------------------------------------- /MAINTAINERS: -------------------------------------------------------------------------------- 1 | # This is the official list of the project maintainers. 2 | # This is mostly useful for contributors that want to push 3 | # significant pull requests or for project management issues. 4 | # 5 | # 6 | # Names should be added to this file like so: 7 | # Individual's name 8 | # Individual's name 9 | # 10 | # Please keep the list sorted. 11 | # 12 | 13 | Lucien Loiseau 14 | -------------------------------------------------------------------------------- /SHAI.md: -------------------------------------------------------------------------------- 1 | # Shai – Quick Guide 2 | 3 | ## How to compile 4 | ```bash 5 | # Debug build (default) 6 | cargo build 7 | 8 | # Optimized release build 9 | cargo build --release 10 | ``` 11 | - Binaries are placed in `target/debug/` or `target/release/`. 12 | 13 | 14 | ## Project layout 15 | - **`shai-cli/`** – Command‑line interface entry point (TUI, shell integration). 16 | - **`shai-core/`** – Core library (agent, state machine, tools, protocol). 17 | - **`shai-llm/`** – LLM provider wrappers. 18 | - **`shai-http/`** – HTTP server exposing agent sessions (OpenAI‑compatible APIs). 19 | - **`shai-macros/`** – Procedural macros used by the tool system. 20 | - **`docs/assets/`** – Images/GIFs used in the README. 21 | - **`shai-core/examples/`** – Small example programs; unit tests live in per‑module `tests.rs` files. 22 | - **`install.sh`** – Helper script to install the CLI (`cargo install --path .`). 23 | - Root `README.md`, `CONTRIBUTING.md`, `LICENSE`, etc. – Project meta‑information. 24 | 25 | ## About Shai 26 | *Shai is a coding agent, your pair‑programming buddy that lives in the terminal. 
Written in Rust with love <3.* 27 | -------------------------------------------------------------------------------- /docs/assets/auth.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ovh/shai/eda106e6e921da0f01c56d50cdc34371c1ba4678/docs/assets/auth.gif -------------------------------------------------------------------------------- /docs/assets/shai-chain.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ovh/shai/eda106e6e921da0f01c56d50cdc34371c1ba4678/docs/assets/shai-chain.gif -------------------------------------------------------------------------------- /docs/assets/shai-headless.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ovh/shai/eda106e6e921da0f01c56d50cdc34371c1ba4678/docs/assets/shai-headless.gif -------------------------------------------------------------------------------- /docs/assets/shai-hello-world.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ovh/shai/eda106e6e921da0f01c56d50cdc34371c1ba4678/docs/assets/shai-hello-world.gif -------------------------------------------------------------------------------- /docs/assets/shai-http.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ovh/shai/eda106e6e921da0f01c56d50cdc34371c1ba4678/docs/assets/shai-http.png -------------------------------------------------------------------------------- /docs/assets/shai-shell.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ovh/shai/eda106e6e921da0f01c56d50cdc34371c1ba4678/docs/assets/shai-shell.png -------------------------------------------------------------------------------- /docs/assets/shai-trace.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ovh/shai/eda106e6e921da0f01c56d50cdc34371c1ba4678/docs/assets/shai-trace.gif -------------------------------------------------------------------------------- /docs/assets/shai.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ovh/shai/eda106e6e921da0f01c56d50cdc34371c1ba4678/docs/assets/shai.png -------------------------------------------------------------------------------- /shai-cli/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "shai" 3 | version = "0.1.8" 4 | edition = "2021" 5 | 6 | 7 | [dependencies] 8 | shai-core = { path = "../shai-core" } 9 | shai-llm = { path = "../shai-llm" } 10 | shai-http = { path = "../shai-http" } 11 | openai_dive = "1.3.1" 12 | chrono = "0.4" 13 | clap = { version = "4.0", features = ["derive"] } 14 | tempfile = "3.20.0" 15 | serde = { version = "1.0", features = ["derive"] } 16 | rmp-serde = "1.1" 17 | tokio = { version = "1.0", features = ["full"] } 18 | crossterm = { version = "0.28", features = ["event-stream"] } 19 | futures = "0.3" 20 | ratatui = { git = "https://github.com/Marlinski/ratatui", branch = "feature/viewport-resize-v29", features = ["crossterm"] } 21 | ansi-to-tui = "7.0.0" 22 | # ratatui = { version = "0.29", features = ["crossterm"] } 23 | tui-textarea = { version = "0.7", features = ["crossterm", "ratatui"] } 24 | serde_json = "1.0" 25 | figrs = "0.3" 26 | rand 
= "0.9" 27 | async-trait = "0.1" 28 | console = "0.16" 29 | ringbuffer = "0.16" 30 | cli-clipboard = "0.4" 31 | textwrap = "0.16" 32 | jwalk = "0.8.1" 33 | tracing-subscriber = { version = "0.3", features = ["env-filter"] } 34 | 35 | 36 | [target.'cfg(unix)'.dependencies] 37 | libc = "0.2" 38 | 39 | [lints.rust] 40 | dead_code = "allow" 41 | unused_variables = "allow" 42 | unused_mut = "allow" 43 | unused_imports = "allow" 44 | -------------------------------------------------------------------------------- /shai-cli/src/fc/client.rs: -------------------------------------------------------------------------------- 1 | use std::os::unix::net::UnixStream; 2 | use std::path::Path; 3 | 4 | use crate::fc::history::{CommandEntry, CommandHistory, HistoryStats}; 5 | use crate::fc::protocol::{ShaiProtocol, ShaiRequest, ShaiResponse, ResponseData}; 6 | 7 | /// Client for querying the command history via Unix socket 8 | pub struct ShaiSessionClient { 9 | socket_path: String, 10 | } 11 | 12 | impl ShaiSessionClient { 13 | pub fn new(session_id: &str) -> Self { 14 | let socket_path = format!("/tmp/shai_history_{}", session_id); 15 | Self { socket_path } 16 | } 17 | 18 | pub fn get_last_commands(&self, n: usize) -> Result> { 19 | let mut stream = UnixStream::connect(&self.socket_path) 20 | .map_err(|_| "Could not connect to SHAI history session (is server running?)")?; 21 | 22 | let request = ShaiRequest::GetLastCmd { n }; 23 | ShaiProtocol::write_request(&mut stream, &request)?; 24 | 25 | let response = ShaiProtocol::read_response(&mut stream)?; 26 | 27 | match response { 28 | ShaiResponse::Ok { data: ResponseData::Commands(entries) } => Ok(entries.into()), 29 | ShaiResponse::Ok { .. } => Err("Unexpected response type".into()), 30 | ShaiResponse::Error { message } => Err(message.into()), 31 | } 32 | } 33 | 34 | pub fn get_all_commands(&self) -> Result> { 35 | let mut stream = UnixStream::connect(&self.socket_path) 36 | .map_err(|_| "Could not connect to SHAI history session (is server running?)")?; 37 | 38 | let request = ShaiRequest::GetAllCmd; 39 | ShaiProtocol::write_request(&mut stream, &request)?; 40 | 41 | let response = ShaiProtocol::read_response(&mut stream)?; 42 | 43 | match response { 44 | ShaiResponse::Ok { data: ResponseData::Commands(entries) } => Ok(entries.into()), 45 | ShaiResponse::Ok { .. } => Err("Unexpected response type".into()), 46 | ShaiResponse::Error { message } => Err(message.into()), 47 | } 48 | } 49 | 50 | pub fn clear(&self) -> Result<(), Box> { 51 | let mut stream = UnixStream::connect(&self.socket_path) 52 | .map_err(|_| "Could not connect to SHAI history session (is server running?)")?; 53 | 54 | let request = ShaiRequest::Clear; 55 | ShaiProtocol::write_request(&mut stream, &request)?; 56 | 57 | let response = ShaiProtocol::read_response(&mut stream)?; 58 | 59 | match response { 60 | ShaiResponse::Ok { .. } => Ok(()), 61 | ShaiResponse::Error { message } => Err(message.into()), 62 | } 63 | } 64 | 65 | pub fn get_status(&self) -> Result> { 66 | let mut stream = UnixStream::connect(&self.socket_path) 67 | .map_err(|_| "Could not connect to SHAI history session (is server running?)")?; 68 | 69 | let request = ShaiRequest::Status; 70 | ShaiProtocol::write_request(&mut stream, &request)?; 71 | 72 | let response = ShaiProtocol::read_response(&mut stream)?; 73 | 74 | match response { 75 | ShaiResponse::Ok { data: ResponseData::Stats(stats) } => Ok(stats), 76 | ShaiResponse::Ok { .. 
} => Err("Unexpected response type".into()), 77 | ShaiResponse::Error { message } => Err(message.into()), 78 | } 79 | } 80 | 81 | pub fn pre_command(&self, cmd: &str) -> Result<(), Box<dyn std::error::Error>> { 82 | let mut stream = UnixStream::connect(&self.socket_path) 83 | .map_err(|_| "Could not connect to SHAI history session (is server running?)")?; 84 | 85 | let request = ShaiRequest::PreCmd { cmd: cmd.to_string() }; 86 | ShaiProtocol::write_request(&mut stream, &request)?; 87 | 88 | let response = ShaiProtocol::read_response(&mut stream)?; 89 | 90 | match response { 91 | ShaiResponse::Ok { .. } => Ok(()), 92 | ShaiResponse::Error { message } => Err(message.into()), 93 | } 94 | } 95 | 96 | pub fn post_command(&self, exit_code: i32, cmd: &str) -> Result<(), Box<dyn std::error::Error>> { 97 | let mut stream = UnixStream::connect(&self.socket_path) 98 | .map_err(|_| "Could not connect to SHAI history session (is server running?)")?; 99 | 100 | let request = ShaiRequest::PostCmd { 101 | cmd: cmd.to_string(), 102 | exit_code 103 | }; 104 | ShaiProtocol::write_request(&mut stream, &request)?; 105 | 106 | let response = ShaiProtocol::read_response(&mut stream)?; 107 | 108 | match response { 109 | ShaiResponse::Ok { .. } => Ok(()), 110 | ShaiResponse::Error { message } => Err(message.into()), 111 | } 112 | } 113 | 114 | pub fn session_exists(&self) -> bool { 115 | Path::new(&self.socket_path).exists() 116 | } 117 | } 118 | 119 |
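A short usage sketch for the client above (editorial example; the session id is illustrative, and `CommandHistory` is assumed to derive `Debug` in `fc/history.rs`, which is not shown here):

```rust
fn show_recent() -> Result<(), Box<dyn std::error::Error>> {
    // Each call connects to /tmp/shai_history_<session_id>.
    let client = ShaiSessionClient::new("abc123");
    if !client.session_exists() {
        return Err("no running SHAI history server for this session".into());
    }
    // Fetch the last 10 recorded shell commands and dump them.
    let history = client.get_last_commands(10)?;
    println!("{:#?}", history);
    Ok(())
}
```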
-------------------------------------------------------------------------------- /shai-cli/src/fc/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod history; 2 | pub mod protocol; 3 | pub mod server; 4 | pub mod client; 5 | mod tests; -------------------------------------------------------------------------------- /shai-cli/src/fc/protocol.rs: -------------------------------------------------------------------------------- 1 | use std::os::unix::net::UnixStream; 2 | use std::io::{Write, Read}; 3 | use serde::{Serialize, Deserialize}; 4 | use rmp_serde::{Serializer, Deserializer}; 5 | 6 | use crate::fc::history::{CommandEntry, HistoryStats}; 7 | 8 | #[derive(Debug, Serialize, Deserialize)] 9 | pub enum ShaiRequest { 10 | // send signals 11 | PreCmd { cmd: String }, 12 | PostCmd { cmd: String, exit_code: i32}, 13 | 14 | // request data 15 | GetAllCmd, 16 | GetLastCmd { n: usize }, 17 | Clear, 18 | Status, 19 | } 20 | 21 | #[derive(Debug, Serialize, Deserialize)] 22 | pub enum ShaiResponse { 23 | Ok { data: ResponseData }, 24 | Error { message: String }, 25 | } 26 | 27 | #[derive(Debug, Serialize, Deserialize)] 28 | #[serde(untagged)] 29 | pub enum ResponseData { 30 | Commands(Vec<CommandEntry>), 31 | Stats(HistoryStats), 32 | Empty, 33 | } 34 | 35 | pub struct ShaiProtocol; 36 | 37 | impl ShaiProtocol { 38 | pub fn write_request(stream: &mut UnixStream, request: &ShaiRequest) -> Result<(), Box<dyn std::error::Error>> { 39 | Self::write_message(stream, request) 40 | } 41 | 42 | pub fn read_request(stream: &mut UnixStream) -> Result<ShaiRequest, Box<dyn std::error::Error>> { 43 | Self::read_message_request(stream) 44 | } 45 | 46 | pub fn write_response(stream: &mut UnixStream, response: &ShaiResponse) -> Result<(), Box<dyn std::error::Error>> { 47 | Self::write_message(stream, response) 48 | } 49 | 50 | pub fn read_response(stream: &mut UnixStream) -> Result<ShaiResponse, Box<dyn std::error::Error>> { 51 | Self::read_message_response(stream) 52 | } 53 | 54 | // Generic write method - eliminates duplication for writing 55 | fn write_message<T: Serialize>(stream: &mut UnixStream, message: &T) -> Result<(), Box<dyn std::error::Error>> { 56 | let mut buf = Vec::new(); 57 | message.serialize(&mut Serializer::new(&mut buf))?; 58 | 59 | // Write length prefix (4 bytes) then data 60 | stream.write_all(&(buf.len() as u32).to_le_bytes())?; 61 | stream.write_all(&buf)?; 62 | stream.flush()?; 63 | 64 | Ok(()) 65 | } 66 | 67 | // Specific read method for requests 68 | fn read_message_request(stream: &mut UnixStream) -> Result<ShaiRequest, Box<dyn std::error::Error>> { 69 | // Read length prefix 70 | let mut len_buf = [0u8; 4]; 71 | stream.read_exact(&mut len_buf)?; 72 | let len = u32::from_le_bytes(len_buf) as usize; 73 | 74 | // Read data 75 | let mut buf = vec![0u8; len]; 76 | stream.read_exact(&mut buf)?; 77 | 78 | let mut de = Deserializer::new(&buf[..]); 79 | let request = ShaiRequest::deserialize(&mut de)?; 80 | 81 | Ok(request) 82 | } 83 | 84 | // Specific read method for responses 85 | fn read_message_response(stream: &mut UnixStream) -> Result<ShaiResponse, Box<dyn std::error::Error>> { 86 | // Read length prefix 87 | let mut len_buf = [0u8; 4]; 88 | stream.read_exact(&mut len_buf)?; 89 | let len = u32::from_le_bytes(len_buf) as usize; 90 | 91 | // Read data 92 | let mut buf = vec![0u8; len]; 93 | stream.read_exact(&mut buf)?; 94 | 95 | let mut de = Deserializer::new(&buf[..]); 96 | let response = ShaiResponse::deserialize(&mut de)?; 97 | 98 | Ok(response) 99 | } 100 | } -------------------------------------------------------------------------------- /shai-cli/src/headless/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod tools; 2 | pub mod app; -------------------------------------------------------------------------------- /shai-cli/src/shell/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod pty; 2 | pub mod rc; 3 | pub mod terminal; -------------------------------------------------------------------------------- /shai-cli/src/shell/terminal.rs: -------------------------------------------------------------------------------- 1 | extern crate libc; 2 | 3 | pub struct TerminalManager { 4 | original_termios: libc::termios, 5 | } 6 | 7 | impl TerminalManager { 8 | pub fn new() -> Result<Self, Box<dyn std::error::Error>> { 9 | let original_termios = Self::setup_raw_mode()?; 10 | Ok(Self { original_termios }) 11 | } 12 | 13 | fn setup_raw_mode() -> Result<libc::termios, Box<dyn std::error::Error>> { 14 | let mut termios: libc::termios = unsafe { std::mem::zeroed() }; 15 | 16 | if unsafe { libc::tcgetattr(libc::STDIN_FILENO, &mut termios) } == -1 { 17 | return Err("Failed to get terminal attributes".into()); 18 | } 19 | 20 | let original = termios; 21 | 22 | // Set raw mode 23 | unsafe { libc::cfmakeraw(&mut termios) }; 24 | 25 | if unsafe { libc::tcsetattr(libc::STDIN_FILENO, libc::TCSANOW, &termios) } == -1 { 26 | return Err("Failed to set terminal attributes".into()); 27 | } 28 | 29 | Ok(original) 30 | } 31 | 32 | pub fn get_window_size() -> Result<libc::winsize, Box<dyn std::error::Error>> { 33 | let mut ws: libc::winsize = unsafe { std::mem::zeroed() }; 34 | 35 | if unsafe { libc::ioctl(libc::STDOUT_FILENO, libc::TIOCGWINSZ, &mut ws) } == -1 { 36 | return Err("Failed to get window size".into()); 37 | } 38 | 39 | Ok(ws) 40 | } 41 | 42 | pub fn set_window_size(fd: i32, ws: &libc::winsize) -> Result<(), Box<dyn std::error::Error>> { 43 | if unsafe { libc::ioctl(fd, libc::TIOCSWINSZ, ws) } == -1 { 44 | return Err("Failed to set window size".into()); 45 | } 46 | Ok(()) 47 | } 48 | 49 | pub fn restore(&self) { 50 | unsafe { 51 | libc::tcsetattr(libc::STDIN_FILENO, libc::TCSANOW, &self.original_termios); 52 | } 53 | } 54 | } 55 | 56 | impl Drop for TerminalManager { 57 | fn drop(&mut self) { 58 | self.restore(); 59 | } 60 | }
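A minimal sketch of how `TerminalManager` is meant to be used (hypothetical `main`; the real callers live in `pty.rs`, not shown here). The constructor switches stdin to raw mode, and the `Drop` impl restores the saved termios even on early return:

```rust
fn main() -> Result<(), Box<dyn std::error::Error>> {
    let _raw = TerminalManager::new()?; // raw mode until `_raw` is dropped
    let ws = TerminalManager::get_window_size()?;
    // In raw mode a bare "\n" does not return the cursor to column 0,
    // so print an explicit "\r\n".
    print!("terminal is {} cols x {} rows\r\n", ws.ws_col, ws.ws_row);
    Ok(())
}
```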
-------------------------------------------------------------------------------- /shai-cli/src/tui/auth/config_providers.rs: -------------------------------------------------------------------------------- 1 | use std::io; 2 | use crossterm::event::{KeyCode, KeyEvent}; 3 | use ratatui::{ 4 | layout::{Constraint, Layout, Rect}, 5 | style::{Color, Style}, 6 | symbols::border, 7 | text::{Line, Span, Text}, 8 | widgets::{Block, Borders, Padding, Paragraph}, 9 | Frame, 10 | }; 11 | use shai_core::config::config::ShaiConfig; 12 | use shai_llm::provider::ProviderInfo; 13 | 14 | use super::auth::NavAction; 15 | 16 | #[derive(Debug)] 17 | pub struct ModalProviders { 18 | config: ShaiConfig, 19 | providers: Vec<ProviderInfo>, 20 | selected_provider: usize, 21 | } 22 | 23 | #[derive(Debug)] 24 | pub enum ProviderAction { 25 | None, 26 | Selected(usize), 27 | Exit, 28 | } 29 | 30 | impl ModalProviders { 31 | pub fn new(config: ShaiConfig, providers: Vec<ProviderInfo>) -> Self { 32 | Self { 33 | config, 34 | providers, 35 | selected_provider: 0, 36 | } 37 | } 38 | 39 | pub fn selected_provider(&self) -> ProviderInfo { 40 | self.providers[self.selected_provider].clone() 41 | } 42 | 43 | pub fn providers(&self) -> &[ProviderInfo] { 44 | &self.providers 45 | } 46 | 47 | pub fn extract_state(self) -> (ShaiConfig, Vec<ProviderInfo>, ProviderInfo) { 48 | let selected_provider = self.providers[self.selected_provider].clone(); 49 | (self.config, self.providers, selected_provider) 50 | } 51 | } 52 | 53 | impl ModalProviders { 54 | pub async fn handle_event(&mut self, key_event: KeyEvent) -> NavAction { 55 | match key_event.code { 56 | KeyCode::Up => { 57 | if self.selected_provider > 0 { 58 | self.selected_provider -= 1; 59 | } 60 | } 61 | KeyCode::Down => { 62 | if self.selected_provider < self.providers.len() - 1 { 63 | self.selected_provider += 1; 64 | } 65 | } 66 | KeyCode::Enter => { 67 | return NavAction::Next 68 | } 69 | KeyCode::Esc => { 70 | return NavAction::Back 71 | } 72 | _ => {} 73 | } 74 | NavAction::None 75 | } 76 | 77 | pub fn draw(&self, frame: &mut Frame, area: Rect) { 78 | let [list, help] = Layout::vertical(vec![ 79 | Constraint::Length(2 + self.providers.len() as u16), 80 | Constraint::Length(1) 81 | ]).areas(area); 82 | 83 | let block = Block::default() 84 | .borders(Borders::ALL) 85 | .border_set(border::ROUNDED) 86 | .padding(Padding { left: 1, right: 1, top: 0, bottom: 0 }) 87 | .title(" Select AI Provider ") 88 | .style(Style::default().fg(Color::DarkGray)); 89 | 90 | let mut lines = vec![]; 91 | for (i, provider) in self.providers.iter().enumerate() { 92 | let prefix = if i == self.selected_provider { "● " } else { "○ " }; 93 | let line = format!("{}{}", prefix, provider.name); 94 | 95 | if i == self.selected_provider { 96 | lines.push(Line::from(vec![ 97 | Span::styled(line, Style::default().fg(Color::Green)) 98 | ])); 99 | } else { 100 | lines.push(Line::from(vec![ 101 | Span::styled(line, Style::default().fg(Color::DarkGray)) 102 | ])); 103 | } 104 | } 105 | 106 | let text = Text::from(lines); 107 | let paragraph = Paragraph::new(text).block(block); 108 | frame.render_widget(paragraph, list); 109 | 110 | frame.render_widget(Line::from(vec![ 111 | Span::styled(" ↑↓ navigate • Enter select • Esc exit", Style::default().fg(Color::DarkGray)) 112 | ]), help); 113 | } 114 | 115 | pub fn height(&self) -> usize { 116 | 2 + self.providers.len() + 1 117 | } 118 | } -------------------------------------------------------------------------------- /shai-cli/src/tui/auth/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod auth; 2 | pub mod config_list; 3 | pub mod config_providers; 4 | pub mod config_env; 5 |
pub mod config_model; 6 | 7 | pub use auth::AppAuth; -------------------------------------------------------------------------------- /shai-cli/src/tui/cmdnav.rs: -------------------------------------------------------------------------------- 1 | 2 | 3 | pub struct CommandNav { 4 | 5 | } -------------------------------------------------------------------------------- /shai-cli/src/tui/helper.rs: -------------------------------------------------------------------------------- 1 | use ansi_to_tui::IntoText; 2 | use ratatui::{layout::Rect, style::{Color, Style, Stylize}, symbols::border, text::{Line, Span}, widgets::{Block, Borders, Padding, Widget}, Frame}; 3 | 4 | 5 | 6 | pub struct HelpArea; 7 | 8 | impl HelpArea { 9 | fn helper_msg(&self) -> String { 10 | [ 11 | " ? to print help tap esc twice to clear input", 12 | " / for commands tap esc while agent is running to cancel", 13 | " ctrl^c to exit", 14 | "", 15 | " Available Commands:", 16 | " /exit exit from the tui", 17 | " /tc set tool call method: [auto | fc | fc2 | so]", 18 | " /tokens display token usage" 19 | ].join("\n").to_string() 20 | } 21 | } 22 | 23 | impl HelpArea { 24 | pub fn height(&self) -> u16 { 25 | 8 // content (3 general help lines + 1 blank + 1 header + 3 command lines) 26 | } 27 | 28 | pub fn draw(&self, f: &mut Frame, area: Rect) { 29 | let helper_text = self.helper_msg(); 30 | let x = helper_text.into_text().unwrap(); 31 | let x = x.style(Style::default().fg(Color::DarkGray).dim()); 32 | f.render_widget( 33 | x, 34 | area 35 | ); 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /shai-cli/src/tui/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod auth; 2 | pub mod app; 3 | pub mod input; 4 | pub mod perm; 5 | pub mod perm_alt_screen; 6 | pub mod theme; 7 | pub mod command; 8 | pub mod helper; 9 | pub mod cmdnav; 10 | 11 | pub use app::App; -------------------------------------------------------------------------------- /shai-cli/src/tui/perm_alt_screen.rs: -------------------------------------------------------------------------------- 1 | use std::io::{self, stdout, Write}; 2 | use crossterm::cursor::MoveTo; 3 | use crossterm::event::{self, EnableFocusChange, Event, KeyCode, KeyEvent, KeyEventKind, KeyModifiers, MouseButton, MouseEvent, MouseEventKind}; 4 | use crossterm::terminal::{disable_raw_mode, enable_raw_mode, Clear, ClearType, EnterAlternateScreen, LeaveAlternateScreen}; 5 | use crossterm::{execute, ExecutableCommand}; 6 | use crossterm::event::{EnableMouseCapture, DisableMouseCapture}; 7 | use futures::StreamExt; 8 | use tokio::time::{sleep, Duration}; 9 | use ratatui::{ 10 | layout::{Constraint, Direction, Layout, Rect}, 11 | prelude::CrosstermBackend, 12 | style::{Color, Style}, 13 | text::{Line, Span, Text}, 14 | widgets::{Block, Borders, Paragraph}, 15 | Frame, Terminal 16 | }; 17 | use shai_core::agent::events::PermissionRequest; 18 | 19 | use super::perm::{PermissionWidget, PermissionModalAction}; 20 | use super::theme::ThemePalette; 21 | 22 | pub struct AlternateScreenPermissionModal<'a> { 23 | widget: PermissionWidget<'a>, 24 | } 25 | 26 | impl AlternateScreenPermissionModal<'_> { 27 | pub fn new(widget: &PermissionWidget, palette: ThemePalette) -> io::Result<Self> { 28 | Ok(Self { 29 | widget: PermissionWidget::new( 30 | widget.request_id.clone(), 31 | widget.request.clone(), 32 | widget.remaining_perms, 33 | palette 34 | ) 35 | }) 36 | } 37 | 38 | pub fn draw(&self, frame: &mut Frame, area: Rect) { 39
| self.widget.draw(frame, area); 40 | } 41 | 42 | 43 | pub async fn run(&mut self) -> io::Result<PermissionModalAction> { 44 | // Enter alternate screen and enable mouse capture 45 | execute!(stdout(), EnterAlternateScreen, EnableMouseCapture)?; 46 | 47 | let result = self.run_modal().await; 48 | 49 | // Always clean up - leave alternate screen and disable mouse capture 50 | let _ = execute!(stdout(), LeaveAlternateScreen, DisableMouseCapture); 51 | let _ = stdout().flush(); 52 | 53 | // Small delay to ensure terminal state is properly restored 54 | sleep(Duration::from_millis(50)).await; 55 | 56 | result 57 | } 58 | 59 | async fn run_modal(&mut self) -> io::Result<PermissionModalAction> { 60 | let mut terminal = Terminal::new(CrosstermBackend::new(stdout()))?; 61 | let mut reader = event::EventStream::new(); 62 | 63 | loop { 64 | terminal.draw(|frame| { 65 | let area = frame.area(); 66 | self.widget.draw(frame, area); 67 | })?; 68 | 69 | if let Some(Ok(event)) = reader.next().await { 70 | match event { 71 | Event::Key(key_event) if key_event.kind == KeyEventKind::Press => { 72 | // Handle Ctrl+C to exit 73 | if matches!(key_event.code, KeyCode::Char('c')) 74 | && key_event.modifiers.contains(crossterm::event::KeyModifiers::CONTROL) { 75 | // Treat Ctrl+C as Escape (Deny) 76 | return Ok(PermissionModalAction::Response { 77 | request_id: "".to_string(), // We'll fix this access later 78 | choice: shai_core::agent::PermissionResponse::Deny, 79 | }); 80 | } 81 | 82 | // Pass all key events to the widget 83 | let action = self.widget.handle_key_event(key_event).await; 84 | if !matches!(action, PermissionModalAction::Nope) { 85 | return Ok(action); 86 | } 87 | } 88 | Event::Mouse(mouse_event) => { 89 | let _ = self.widget.handle_mouse_event(mouse_event).await; 90 | } 91 | Event::Resize(..) => { 92 | // Terminal was resized, redraw on next iteration 93 | } 94 | _ => {} 95 | } 96 | } 97 | } 98 | } 99 | } 100 | 101 | impl Drop for AlternateScreenPermissionModal<'_> { 102 | fn drop(&mut self) { 103 | let _ = execute!(stdout(), DisableMouseCapture, LeaveAlternateScreen); 104 | let _ = stdout().flush(); 105 | } 106 | } 107 | -------------------------------------------------------------------------------- /shai-core/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "shai-core" 3 | version = "0.1.8" 4 | edition = "2021" 5 | 6 | [dependencies] 7 | # New LLM module dependencies 8 | async-trait = "0.1" 9 | reqwest = { version = "0.12", features = ["json", "stream"] } 10 | serde = { version = "1.0", features = ["derive"] } 11 | serde_json = "1.0" 12 | uuid = { version = "1.0", features = ["v4"] } 13 | tokio = { version = "1.0", features = ["full"] } 14 | tokio-util = "0.7" 15 | futures = "0.3" 16 | termimad = "0.34" 17 | tree-sitter = "0.25" 18 | tree-sitter-highlight = "0.25" 19 | 20 | [target.'cfg(unix)'.dependencies] 21 | libc = "0.2" 22 | 23 | # Tool system dependencies 24 | schemars = "1.0.1" 25 | shai-macros = { path = "../shai-macros" } 26 | shai-llm = { path = "../shai-llm" } 27 | openai_dive = "1.3.1" 28 | regex = "1.12" 29 | walkdir = "2.4" 30 | chrono = { version = "0.4", features = ["serde"] } 31 | thiserror = "2.0" 32 | tracing = "0.1" 33 | tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt", "json"] } 34 | tracing-appender = "0.2" 35 | similar = "2.6" 36 | fs = "0.0.5" 37 | dirs = "6.0" 38 | rmcp = { version = "0.6.0", features = ["schemars", "auth", "client", "transport-child-process", "transport-streamable-http-client", "transport-sse-client"] } 39
| 40 | # OAuth dependencies 41 | oauth2 = "4.4" 42 | warp = "0.3" 43 | anyhow = "1.0" 44 | webbrowser = "1.0" 45 | 46 | [dev-dependencies] 47 | tempfile = "3.23.0" 48 | paste = "1.0" 49 | 50 | [lints.rust] 51 | dead_code = "allow" 52 | unused_variables = "allow" 53 | #unused_mut = "allow" 54 | #unused_imports = "allow" 55 | 56 | -------------------------------------------------------------------------------- /shai-core/examples/oauth_test.rs: -------------------------------------------------------------------------------- 1 | use shai_core::tools::mcp::mcp_oauth::signin_oauth; 2 | 3 | #[tokio::main] 4 | async fn main() { 5 | println!("🚀 Starting OAuth flow test..."); 6 | 7 | match signin_oauth("https://mcp.eu.ovhcloud.com/").await { 8 | Ok(access_token) => { 9 | println!("✅ OAuth flow completed successfully!"); 10 | println!("🎫 Access Token: {}", access_token); 11 | println!("🔑 Token length: {} characters", access_token.len()); 12 | } 13 | Err(e) => { 14 | println!("❌ OAuth flow failed: {}", e); 15 | std::process::exit(1); 16 | } 17 | } 18 | } -------------------------------------------------------------------------------- /shai-core/src/agent/README.md: -------------------------------------------------------------------------------- 1 | # Agent Architecture 2 | 3 | The Agent system is a modular architecture for autonomous task execution with LLM interaction, tool orchestration, and external control. 4 | 5 | ## Functional Diagram 6 | 7 | ``` 8 | ┌───────────────────────────────────────────────────────────────────────────────┐ 9 | │ AGENT CORE │ 10 | ├─────────────────┬───────────────────────────────────────┬─────────────────────┤ 11 | │ CONTROLLER │ CORE MODULE │ EVENTS I/O │ 12 | │ (Protocol) │ (LLM Interaction) │ (Communication) │ 13 | │ │ │ │ 14 | │ ┌─────────────┐ │ ┌─────────────┐ ┌─────────────────┐ │ ┌─────────────────┐ │ 15 | │ │ Protocol │ │ │ Thinker │ │ Tool Actions │ │ │ Event Handlers │ │ 16 | │ │ Commands │ │ │ (Brain) │ │ Orchestrator │ │ │ (Async) │ │ 17 | │ │ │ │ │ │ │ │ │ │ │ │ 18 | │ │ • Cancel │ │ │ • LLM Call │ │ • Tool Execute │ │ │ • StatusChanged │ │ 19 | │ │ • GetState │ │ │ • Decision │ │ • Result Handle │ │ │ • ToolCall* │ │ 20 | │ │ • UserInput │ │ │ • Continue/ │ │ • Cancellation │ │ │ • UserRequired │ │ 21 | │ │ • Response │ │ │ Pause │ │ • Concurrency │ │ │ • Permission* │ │ 22 | │ └─────────────┘ │ └─────────────┘ └─────────────────┘ │ └─────────────────┘ │ 23 | │ │ │ │ 24 | └─────────────────┼───────────────────────────────────────┼─────────────────────┘ 25 | │ │ 26 | ▼ ▼ 27 | ┌─────────────────┐ ┌─────────────────┐ 28 | │ STATE MACHINE │◄─────────────────►│ TOOL SYSTEM │ 29 | │ │ │ │ 30 | │ • Starting │ │ • Read Tools │ 31 | │ • Running │ │ • Write Tools │ 32 | │ • Processing │ │ • Network Tools │ 33 | │ • Paused │ │ • Permissions │ 34 | │ • Terminal │ │ • Validation │ 35 | └─────────────────┘ └─────────────────┘ 36 | ``` 37 | 38 | ## Core Modules 39 | 40 | ### Controller (Protocol) 41 | **Purpose**: External command interface and control channel 42 | - **Commands**: Cancel, GetState, SendUserInput, Permissions 43 | - **Responses**: Ack, State, Error 44 | - **Communication**: Async channel-based protocol 45 | - **Lifecycle**: Controls agent execution flow 46 | 47 | ### Brain Module (LLM Interaction) 48 | **Purpose**: AI decision-making and reasoning engine 49 | - **Thinker**: Core LLM interaction logic 50 | - **Context**: Maintains conversation trace and available tools 51 | - **Decision**: Determines next action (continue/pause/tool use) 52 | - **Flow 
Control**: Manages autonomous vs interactive execution 53 | 54 | ### Events I/O (Communication) 55 | **Purpose**: Asynchronous event handling and external communication 56 | - **Internal Events**: State machine communication (`BrainResult`, `ToolsCompleted`) 57 | - **External Events**: UI/Controller notifications (`StatusChanged`, `ToolCallStarted`) 58 | - **User Interaction**: Input requests and permission handling 59 | - **Event Handlers**: Pluggable async event processing 60 | 61 | ## Key Interactions 62 | 63 | 1. **Controller → Core**: Protocol commands control agent lifecycle 64 | 2. **Core → Brain**: Triggers thinking processes with current context 65 | 3. **Brain → Tools**: Executes tool calls based on LLM decisions 66 | 4. **Core → Events**: Emits events for external consumption 67 | 5. **Events → Controller**: Handles user input and permission requests 68 | 69 | ## Concurrency Model 70 | 71 | - **State Machine**: Single-threaded with async event handling 72 | - **Brain Execution**: Spawned tasks with cancellation tokens 73 | - **Tool Execution**: Concurrent tool calls with result aggregation 74 | - **Event Emission**: Non-blocking async event distribution 75 | - **Protocol**: Channel-based async command/response pattern 76 | 77 | The architecture enables autonomous agent operation while maintaining external control and observability through clean interfaces. -------------------------------------------------------------------------------- /shai-core/src/agent/actions/brain.rs: -------------------------------------------------------------------------------- 1 | use chrono::Utc; 2 | use openai_dive::v1::resources::chat::ChatMessage; 3 | use tracing::info; 4 | use tokio_util::sync::CancellationToken; 5 | use crate::agent::{AgentCore, AgentError, AgentEvent, InternalAgentEvent, InternalAgentState, ThinkerContext, ThinkerDecision, ThinkerFlowControl}; 6 | 7 | impl AgentCore { 8 | /// Launch a brain task to decide the next step 9 | pub async fn spawn_next_step(&mut self) { 10 | let cancellation_token = CancellationToken::new(); 11 | let cancel_token_clone = cancellation_token.clone(); 12 | let trace = self.trace.clone(); 13 | let tx_clone = self.internal_tx.clone(); 14 | let available_tools = self.available_tools.clone(); 15 | let method = self.method.clone(); 16 | let context = ThinkerContext { 17 | trace, 18 | available_tools, 19 | method 20 | }; 21 | let brain = self.brain.clone(); 22 | 23 | //////////////////////// TOKIO SPAWN 24 | tokio::spawn(async move { 25 | tokio::select! { 26 | result = async { 27 | brain.write().await.next_step(context).await 28 | } => { 29 | let _ = tx_clone.send(InternalAgentEvent::BrainResult { 30 | result 31 | }); 32 | } 33 | _ = cancel_token_clone.cancelled() => { 34 | // Brain thinking was cancelled, no need to send result 35 | } 36 | } 37 | }); 38 | //////////////////////// TOKIO SPAWN 39 | 40 | self.set_state(InternalAgentState::Processing { 41 | task_name: "next_step".to_string(), 42 | tools_exec_at: Utc::now(), 43 | cancellation_token 44 | }).await; 45 | } 46 | 47 | 48 | /// Process a brain task result 49 | pub async fn process_next_step(&mut self, result: Result<ThinkerDecision, AgentError>) -> Result<(), AgentError> { 50 | let ThinkerDecision{message, flow, token_usage} = self.handle_brain_error(result).await?; 51 | let ChatMessage::Assistant { content, reasoning_content, tool_calls, ..
} = message.clone() else { 52 | return self.handle_brain_error::<ThinkerDecision>( 53 | Err(AgentError::InvalidResponse(format!("ChatMessage::Assistant expected, but got {:?} instead", message)))).await.map(|_| () 54 | ); 55 | }; 56 | 57 | // Add the message to trace 58 | info!(target: "agent::think", reasoning_content = ?reasoning_content, content = ?content); 59 | let trace = self.trace.clone(); 60 | trace.write().await.push(message.clone()); 61 | 62 | // Emit event to external consumers 63 | let _ = self.emit_event(AgentEvent::BrainResult { 64 | timestamp: Utc::now(), 65 | thought: Ok(message.clone()) 66 | }).await; 67 | 68 | // Emit token usage event if available 69 | if let Some((input_tokens, output_tokens)) = token_usage { 70 | let _ = self.emit_event(AgentEvent::TokenUsage { 71 | input_tokens, 72 | output_tokens 73 | }).await; 74 | } 75 | 76 | // run tool call if any 77 | let tool_calls_from_brain = tool_calls.unwrap_or(vec![]); 78 | if !tool_calls_from_brain.is_empty() { 79 | self.spawn_tools(tool_calls_from_brain).await; 80 | return Ok(()) 81 | } 82 | 83 | // no tool call, thus we rely on flow control 84 | match flow { 85 | ThinkerFlowControl::AgentContinue => { 86 | self.set_state(InternalAgentState::Running).await; 87 | } 88 | ThinkerFlowControl::AgentPause => { 89 | self.set_state(InternalAgentState::Paused).await; 90 | } 91 | } 92 | Ok(()) 93 | } 94 | 95 | // Helper method that emits error events before returning the error 96 | async fn handle_brain_error<T>(&mut self, result: Result<T, AgentError>) -> Result<T, AgentError> { 97 | match result { 98 | Ok(value) => Ok(value), 99 | Err(error) => { 100 | self.set_state(InternalAgentState::Paused).await; 101 | let _ = self.emit_event(AgentEvent::BrainResult { 102 | timestamp: Utc::now(), 103 | thought: Err(error.clone()) 104 | }).await; 105 | Err(error) 106 | } 107 | } 108 | } 109 | } -------------------------------------------------------------------------------- /shai-core/src/agent/actions/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod brain; 2 | pub mod tools; 3 | -------------------------------------------------------------------------------- /shai-core/src/agent/brain.rs: -------------------------------------------------------------------------------- 1 | use std::sync::Arc; 2 | use async_trait::async_trait; 3 | use openai_dive::v1::resources::chat::ChatMessage; 4 | use shai_llm::ToolCallMethod; 5 | use tokio::sync::RwLock; 6 | 7 | use crate::tools::types::AnyToolBox; 8 | use super::error::AgentError; 9 | 10 | 11 | /// ThinkerContext is the agent's internal state 12 | pub struct ThinkerContext { 13 | pub trace: Arc<RwLock<Vec<ChatMessage>>>, 14 | pub available_tools: AnyToolBox, 15 | pub method: ToolCallMethod 16 | } 17 | 18 | /// ThinkerFlowControl drives the agentic flow 19 | #[derive(Debug, Clone)] 20 | pub enum ThinkerFlowControl { 21 | AgentContinue, 22 | AgentPause 23 | } 24 | 25 | /// This structure pilots the flow of the Agent. 26 | /// If tool calls are present in the chat message, the flow attribute is ignored. 27 | /// If no tool call is present, flow determines whether the agent pauses or continues. 28 | #[derive(Debug, Clone)] 29 | pub struct ThinkerDecision { 30 | pub message: ChatMessage, 31 | pub flow: ThinkerFlowControl, 32 | pub token_usage: Option<(u32, u32)>, // (input_tokens, output_tokens) 33 | } 34 | 35 | impl ThinkerDecision { 36 | pub fn new(message: ChatMessage) -> Self { 37 | ThinkerDecision{ 38 | message, 39 | flow: ThinkerFlowControl::AgentPause, 40 | token_usage: None, 41 | } 42 | } 43 | 44 | pub fn
agent_continue(message: ChatMessage) -> Self { 45 | ThinkerDecision{ 46 | message, 47 | flow: ThinkerFlowControl::AgentContinue, 48 | token_usage: None, 49 | } 50 | } 51 | 52 | pub fn agent_pause(message: ChatMessage) -> Self { 53 | ThinkerDecision{ 54 | message, 55 | flow: ThinkerFlowControl::AgentPause, 56 | token_usage: None, 57 | } 58 | } 59 | 60 | pub fn agent_continue_with_tokens(message: ChatMessage, input_tokens: u32, output_tokens: u32) -> Self { 61 | ThinkerDecision{ 62 | message, 63 | flow: ThinkerFlowControl::AgentContinue, 64 | token_usage: Some((input_tokens, output_tokens)), 65 | } 66 | } 67 | 68 | pub fn agent_pause_with_tokens(message: ChatMessage, input_tokens: u32, output_tokens: u32) -> Self { 69 | ThinkerDecision{ 70 | message, 71 | flow: ThinkerFlowControl::AgentPause, 72 | token_usage: Some((input_tokens, output_tokens)), 73 | } 74 | } 75 | 76 | pub fn unwrap(self) -> ChatMessage { 77 | self.message 78 | } 79 | } 80 | 81 | /// Core thinking interface - pure decision making 82 | #[async_trait] 83 | pub trait Brain: Send + Sync { 84 | /// This method is called at every step of the agent to decide the next step. 85 | /// Note that if the message contains tool calls, the agent will always continue. 86 | async fn next_step(&mut self, context: ThinkerContext) -> Result<ThinkerDecision, AgentError>; 87 | } 88 | 89 | 90 | -------------------------------------------------------------------------------- /shai-core/src/agent/error.rs: -------------------------------------------------------------------------------- 1 | use shai_llm::provider::LlmError; 2 | use thiserror::Error; 3 | 4 | #[derive(Error, Debug, Clone)] 5 | pub enum AgentError { 6 | #[error("Agent execution error: {0}")] 7 | ExecutionError(String), 8 | #[error("LLM error: {0}")] 9 | LlmError(String), 10 | #[error("Tool error: {0}")] 11 | ToolError(String), 12 | #[error("Agent session has been closed")] 13 | SessionClosed, 14 | #[error("Invalid response: {0}")] 15 | InvalidResponse(String), 16 | #[error("User interaction timeout")] 17 | UserTimeout, 18 | #[error("Permission denied")] 19 | PermissionDenied, 20 | #[error("User input cancelled")] 21 | UserInputCancelled, 22 | #[error("Configuration error: {0}")] 23 | ConfigurationError(String), 24 | #[error("Agent execution timed out")] 25 | TimeoutError, 26 | #[error("Maximum iterations reached")] 27 | MaxIterationsReached, 28 | #[error("Invalid state: {0}")] 29 | InvalidState(String), 30 | #[error("Invalid state transition: {0}")] 31 | InvalidStateTransition(String), 32 | } 33 | 34 | #[derive(Debug)] 35 | pub enum AgentExecutionError { 36 | LlmError(LlmError), 37 | ToolError(String), 38 | TimeoutError, 39 | MaxIterationsReached, 40 | ConfigurationError(String), 41 | } 42 | 43 | impl std::fmt::Display for AgentExecutionError { 44 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 45 | match self { 46 | AgentExecutionError::LlmError(e) => write!(f, "LLM error: {}", e), 47 | AgentExecutionError::ToolError(e) => write!(f, "Tool error: {}", e), 48 | AgentExecutionError::TimeoutError => write!(f, "Agent execution timed out"), 49 | AgentExecutionError::MaxIterationsReached => write!(f, "Maximum iterations reached"), 50 | AgentExecutionError::ConfigurationError(e) => write!(f, "Configuration error: {}", e), 51 | } 52 | } 53 | } 54 | 55 | impl std::error::Error for AgentExecutionError {} 56 | 57 | impl From<LlmError> for AgentExecutionError { 58 | fn from(error: LlmError) -> Self { 59 | AgentExecutionError::LlmError(error) 60 | } 61 | } 62 | 63 |
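[Editor's note: a minimal sketch of a custom `Brain` implementation. The `ForwardBrain` name is hypothetical and not part of the crate; the builder calls, error mapping, and trait shape mirror `SearcherBrain` further down in this dump.]

```rust
use std::sync::Arc;
use async_trait::async_trait;
use openai_dive::v1::resources::chat::ChatCompletionParametersBuilder;
use shai_llm::client::LlmClient;
use crate::agent::{AgentError, Brain, ThinkerContext, ThinkerDecision};

/// Hypothetical brain that forwards the conversation trace to an LLM
/// and always hands control back to the user afterwards.
pub struct ForwardBrain {
    pub llm: Arc<LlmClient>,
    pub model: String,
}

#[async_trait]
impl Brain for ForwardBrain {
    async fn next_step(&mut self, context: ThinkerContext) -> Result<ThinkerDecision, AgentError> {
        // Snapshot the shared conversation trace.
        let trace = context.trace.read().await.clone();
        let request = ChatCompletionParametersBuilder::default()
            .model(&self.model)
            .messages(trace)
            .build()
            .map_err(|e| AgentError::LlmError(e.to_string()))?;
        let response = self.llm.chat(request).await
            .map_err(|e| AgentError::LlmError(e.to_string()))?;
        // No tool calls are requested here, so the `flow` field decides what
        // happens next: `agent_pause` yields control back to the user.
        Ok(ThinkerDecision::agent_pause(response.choices[0].message.clone()))
    }
}
```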
-------------------------------------------------------------------------------- /shai-core/src/agent/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod builder; 2 | pub mod claims; 3 | pub mod error; 4 | pub mod brain; 5 | pub mod agent; 6 | pub mod protocol; 7 | pub mod events; 8 | pub mod states; 9 | pub mod actions; 10 | pub mod output; 11 | 12 | #[cfg(test)] 13 | mod tests; 14 | 15 | pub use agent::{ 16 | Agent, AgentCore, 17 | TaskAgentResponse, 18 | AgentResult 19 | }; 20 | pub use states::{InternalAgentState, PublicAgentState}; 21 | 22 | pub use protocol::{AgentRequest, AgentResponse, AgentController}; 23 | 24 | pub use events::{ 25 | InternalAgentEvent, AgentEvent, 26 | ClosureHandler, AgentEventHandler, DynEventHandler, closure_handler, 27 | UserRequest, UserResponse, PermissionRequest, PermissionResponse}; 28 | pub use output::StdoutEventManager; 29 | 30 | pub use builder::AgentBuilder; 31 | pub use claims::{ClaimManager, PermissionError}; 32 | pub use error::{AgentError, AgentExecutionError}; 33 | pub use brain::{Brain, ThinkerContext, ThinkerDecision, ThinkerFlowControl}; 34 | pub use crate::logging::LoggingConfig; -------------------------------------------------------------------------------- /shai-core/src/agent/output/log.rs: -------------------------------------------------------------------------------- 1 | use std::fs::OpenOptions; 2 | use std::io::Write; 3 | use async_trait::async_trait; 4 | use chrono::Utc; 5 | use crate::agent::{AgentEvent, AgentEventHandler}; 6 | 7 | /// File logger that writes all agent events to a debug log file 8 | pub struct FileEventLogger { 9 | log_path: String, 10 | } 11 | 12 | impl FileEventLogger { 13 | pub fn new(log_path: impl Into<String>) -> Self { 14 | Self { 15 | log_path: log_path.into(), 16 | } 17 | } 18 | 19 | pub fn default() -> Self { 20 | Self::new("agent_events.log") 21 | } 22 | 23 | fn write_event(&self, event: &AgentEvent) { 24 | let timestamp = Utc::now(); 25 | let event_str = match event { 26 | AgentEvent::StatusChanged { old_status, new_status } => { 27 | format!("StatusChanged: {:?} -> {:?}", old_status, new_status) 28 | } 29 | AgentEvent::ThinkingStart => { 30 | format!("ThinkingStart") 31 | } 32 | AgentEvent::BrainResult { timestamp: event_time, thought } => { 33 | format!("BrainResult: {:?} - {:?}", event_time, thought) 34 | } 35 | AgentEvent::ToolCallStarted { timestamp: event_time, call } => { 36 | format!("ToolCallStarted: {:?} - {}", event_time, call.tool_name) 37 | } 38 | AgentEvent::ToolCallCompleted { duration, call, result } => { 39 | format!("ToolCallCompleted: {} in {:?} - {:?}", call.tool_name, duration, result) 40 | } 41 | AgentEvent::UserInput { input } => { 42 | format!("UserInput: {}", input) 43 | } 44 | AgentEvent::UserInputRequired { request_id, request } => { 45 | format!("UserInputRequired: {} - {:?}", request_id, request) 46 | } 47 | AgentEvent::PermissionRequired { request_id, request } => { 48 | format!("PermissionRequired: {} - {}", request_id, request.operation) 49 | } 50 | AgentEvent::Error { error } => { 51 | format!("Error: {}", error) 52 | } 53 | AgentEvent::Completed { success, message } => { 54 | format!("Completed: success={} - {}", success, message) 55 | } 56 | AgentEvent::TokenUsage { input_tokens, output_tokens } => { 57 | format!("Token Usage: input={} output={} total={}", input_tokens, output_tokens, input_tokens + output_tokens) 58 | } 59 | }; 60 | 61 | let log_line = format!("[{}] {}\n", timestamp.format("%Y-%m-%d %H:%M:%S%.3f"),
event_str); 62 | 63 | if let Ok(mut file) = OpenOptions::new() 64 | .create(true) 65 | .append(true) 66 | .open(&self.log_path) 67 | { 68 | let _ = file.write_all(log_line.as_bytes()); 69 | let _ = file.flush(); 70 | } 71 | } 72 | } 73 | 74 | #[async_trait] 75 | impl AgentEventHandler for FileEventLogger { 76 | async fn handle_event(&self, event: AgentEvent) { 77 | self.write_event(&event); 78 | } 79 | } 80 | 81 | impl Default for FileEventLogger { 82 | fn default() -> Self { 83 | Self::default() // resolves to the inherent `default()` above (inherent items take precedence over trait methods), so this does not recurse 84 | } 85 | } -------------------------------------------------------------------------------- /shai-core/src/agent/output/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod stdout; 2 | pub mod pretty; 3 | pub mod log; 4 | 5 | pub use stdout::StdoutEventManager; 6 | pub use pretty::PrettyFormatter; 7 | pub use log::FileEventLogger; -------------------------------------------------------------------------------- /shai-core/src/agent/output/stdout.rs: -------------------------------------------------------------------------------- 1 | use std::io::{self, Write}; 2 | use async_trait::async_trait; 3 | use crate::agent::{AgentEvent, AgentEventHandler}; 4 | use super::pretty::PrettyFormatter; 5 | 6 | /// Stdout event manager that formats and prints agent activity in a user-friendly way 7 | pub struct StdoutEventManager { 8 | formatter: PrettyFormatter, 9 | } 10 | 11 | impl StdoutEventManager { 12 | pub fn new() -> Self { 13 | Self { 14 | formatter: PrettyFormatter::new(), 15 | } 16 | } 17 | } 18 | 19 | #[async_trait] 20 | impl AgentEventHandler for StdoutEventManager { 21 | async fn handle_event(&self, event: AgentEvent) { 22 | if let Some(formatted) = self.formatter.format_event(&event) { 23 | eprintln!("{}", formatted); 24 | let _ = io::stderr().flush(); // flush the stream we actually wrote to 25 | } 26 | } 27 | } 28 | 29 | impl Default for StdoutEventManager { 30 | fn default() -> Self { 31 | Self::new() 32 | } 33 | } -------------------------------------------------------------------------------- /shai-core/src/agent/states/README.md: -------------------------------------------------------------------------------- 1 | # Agent State Machine 2 | 3 | The Agent operates as a state machine with distinct states for lifecycle management and task execution.
4 | 5 | ## State Diagram 6 | 7 | ``` 8 | ┌─────────────┐ 9 | │ Starting │ 10 | └──────┬──────┘ 11 | │ AgentInitialized 12 | ▼ 13 | ┌─────────────┐ 14 | ┌─────▶│ Running │◄─────────┐ 15 | │ └──────┬──────┘ │ 16 | │ │ spawn_next_step │ BrainResult (continue) 17 | │ │ spawn_tools │ ToolsCompleted 18 | │ ▼ │ 19 | │ ┌─────────────┐ │ 20 | │ │ Processing │ │ 21 | │ │ (brain/ │──────────┘ 22 | │ │ tools) │ 23 | │ └──────┬──────┘ 24 | │ │ BrainResult (pause) 25 | │ │ CancelTask 26 | │ │ 27 | │ ▼ 28 | user input │ ┌─────────────┐ 29 | │ │ Paused │ 30 | └──────┤ (waiting │ 31 | │ for user) │ 32 | └─────────────┘ 33 | │ completion/error 34 | ▼ 35 | ┌─────────────┐ 36 | │ Terminal │ 37 | │ (Completed, │ 38 | │ Failed, or │ 39 | │ Cancelled) │ 40 | └─────────────┘ 41 | ``` 42 | 43 | ## States 44 | 45 | - **Starting**: Initial state during agent initialization 46 | - **Running**: Active state, ready to process the next step 47 | - **Processing**: Executing brain thinking or tool calls 48 | - **Paused**: Waiting for user input (agent decided to pause); this is skipped in the absence of a controller 49 | - **Terminal**: Final states (Completed, Failed, Cancelled) 50 | 51 | ## Key Events 52 | 53 | - `AgentInitialized`: Moves from Starting to Running/Paused 54 | - `ThinkingStart`: Triggers brain execution (Running → Processing) 55 | - `BrainResult`: Brain decision result (Processing → Running/Paused) 56 | - `ToolsCompleted`: Tool execution finished (Processing → Running) 57 | - `CancelTask`: Cancel current operation 58 | 59 | ## State Transitions 60 | 61 | States transition based on internal events and brain decisions. The agent automatically moves between Running and Processing states during normal operation, with the Paused state used when the brain decides to yield control back to the user. A condensed sketch of the dispatch is shown below.
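A condensed sketch of how internal events reach the per-state handlers defined in this directory (the `dispatch` method name is illustrative only; the real routing lives in the agent's main loop, and the handler names are taken from the modules that follow):

```rust
impl AgentCore {
    /// Route an internal event to the handler for the current state.
    async fn dispatch(&mut self, event: InternalAgentEvent) -> Result<(), AgentError> {
        match &self.state {
            InternalAgentState::Starting => self.state_starting_handle_event(event).await,
            InternalAgentState::Running => self.state_running_handle_event(event).await,
            InternalAgentState::Processing { .. } => self.state_processing_handle_event(event).await,
            InternalAgentState::Paused => self.state_pause_handle_event(event).await,
            // Terminal states ignore everything (see terminal.rs).
            InternalAgentState::Completed { .. } | InternalAgentState::Failed { .. } => {
                self.state_terminal_handle_event(event).await
            }
        }
    }
}
```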
62 | -------------------------------------------------------------------------------- /shai-core/src/agent/states/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod states; 2 | pub mod pause; 3 | pub mod running; 4 | pub mod starting; 5 | pub mod processing; 6 | pub mod terminal; 7 | 8 | pub use states::{InternalAgentState, PublicAgentState}; -------------------------------------------------------------------------------- /shai-core/src/agent/states/pause.rs: -------------------------------------------------------------------------------- 1 | use crate::agent::{AgentCore, InternalAgentEvent, AgentError}; 2 | use tracing::error; 3 | 4 | impl AgentCore { 5 | pub async fn state_pause_handle_event(&mut self, event: InternalAgentEvent) -> Result<(), AgentError> { 6 | match event { 7 | InternalAgentEvent::CancelTask => { 8 | // Silently ignore 9 | Ok(()) 10 | } 11 | _ => { 12 | // Paused state: All other events are illegal until the user sends something 13 | // ignore all events but log error 14 | error!("event {:?} unexpected in state {:?}", event, self.state.to_public()); 15 | Ok(()) 16 | } 17 | } 18 | } 19 | } -------------------------------------------------------------------------------- /shai-core/src/agent/states/processing.rs: -------------------------------------------------------------------------------- 1 | use crate::agent::{ 2 | AgentCore, AgentError, InternalAgentEvent 3 | }; 4 | use super::InternalAgentState; 5 | 6 | impl AgentCore { 7 | pub async fn state_processing_handle_event(&mut self, event: InternalAgentEvent) -> Result<(), AgentError> { 8 | match event { 9 | InternalAgentEvent::CancelTask => { 10 | self.cancel_task().await 11 | }, 12 | InternalAgentEvent::BrainResult { result } => { 13 | self.process_next_step(result).await 14 | }, 15 | InternalAgentEvent::ToolsCompleted { any_denied } => { 16 | if any_denied { 17 | self.set_state(InternalAgentState::Paused).await; 18 | } else { 19 | self.set_state(InternalAgentState::Running).await; 20 | } 21 | Ok(()) 22 | }, 23 | _ => { 24 | Ok(()) 25 | } 26 | } 27 | } 28 | 29 | /// cancel all pending tasks 30 | async fn cancel_task(&mut self) -> Result<(), AgentError> { 31 | let InternalAgentState::Processing { cancellation_token, ..
} = &self.state else { 32 | return Err(AgentError::InvalidState(format!("state Processing expected but current state is: {:?}", self.state.to_public()))); 33 | }; 34 | 35 | cancellation_token.cancel(); 36 | Ok(()) 37 | } 38 | } -------------------------------------------------------------------------------- /shai-core/src/agent/states/running.rs: -------------------------------------------------------------------------------- 1 | use crate::agent::{AgentCore, InternalAgentEvent, AgentError}; 2 | use super::InternalAgentState; 3 | use tracing::error; 4 | 5 | impl AgentCore { 6 | pub async fn state_running_handle_event(&mut self, event: InternalAgentEvent) -> Result<(), AgentError> { 7 | let InternalAgentState::Running = &self.state else { 8 | return Err(AgentError::InvalidState(format!("state Running expected but current state is: {:?}", self.state.to_public()))); 9 | }; 10 | 11 | match event { 12 | InternalAgentEvent::CancelTask => { 13 | // Silently ignore 14 | } 15 | InternalAgentEvent::ThinkingStart => { 16 | self.spawn_next_step().await; 17 | } 18 | _ => { 19 | // Running state: Most other events should be handled by the main loop or are illegal 20 | // ignore all events but log error 21 | error!("event {:?} unexpected in state {:?}", event, self.state.to_public()); 22 | } 23 | } 24 | Ok(()) 25 | } 26 | } -------------------------------------------------------------------------------- /shai-core/src/agent/states/starting.rs: -------------------------------------------------------------------------------- 1 | use crate::agent::{AgentCore, AgentError, InternalAgentEvent}; 2 | use super::InternalAgentState; 3 | use openai_dive::v1::resources::chat::ChatMessage; 4 | use tracing::error; 5 | 6 | impl AgentCore { 7 | pub async fn state_starting_handle_event(&mut self, event: InternalAgentEvent) -> Result<(), AgentError> { 8 | let InternalAgentState::Starting = &self.state else { 9 | return Err(AgentError::InvalidState(format!("state Starting expected but current state is: {:?}", self.state.to_public()))); 10 | }; 11 | 12 | match event { 13 | InternalAgentEvent::AgentInitialized => { 14 | self.handle_agent_initialized().await; 15 | } 16 | _ => { 17 | // ignore all events but log error 18 | error!("event {:?} unexpected in state {:?}", event, self.state.to_public()); 19 | } 20 | } 21 | Ok(()) 22 | } 23 | 24 | /// Handle agent initialization - move from Starting to Running or Paused based on goal 25 | async fn handle_agent_initialized(&mut self) { 26 | let trace = self.trace.clone(); 27 | let guard = trace.read().await; 28 | if let Some(ChatMessage::User { ..
}) = guard.last() { 29 | self.set_state(InternalAgentState::Running).await; 30 | } else { 31 | self.set_state(InternalAgentState::Paused).await; 32 | } 33 | } 34 | } -------------------------------------------------------------------------------- /shai-core/src/agent/states/states.rs: -------------------------------------------------------------------------------- 1 | use tokio_util::sync::CancellationToken; 2 | use chrono::{DateTime, Utc}; 3 | 4 | /// Internal agent status (contains channels and sync primitives) 5 | #[derive(Debug)] 6 | pub enum InternalAgentState { 7 | /// Agent is starting up 8 | Starting, 9 | /// Agent is actively running 10 | Running, 11 | /// Executing, might be doing multiple things at once 12 | Processing { 13 | task_name: String, 14 | tools_exec_at: DateTime<Utc>, 15 | cancellation_token: CancellationToken, 16 | }, 17 | /// Agent execution is paused 18 | Paused, 19 | /// Agent completed (the flag records whether it succeeded) 20 | Completed { success: bool }, 21 | /// Agent failed with error 22 | Failed { error: String }, 23 | } 24 | 25 | 26 | /// Public agent status (clean version without internal channels/sync primitives) 27 | #[derive(Debug, Clone)] 28 | pub enum PublicAgentState { 29 | /// Agent is starting up 30 | Starting, 31 | /// Agent is actively running 32 | Running, 33 | /// Agent is thinking 34 | Processing { 35 | task_name: String, 36 | tools_exec_at: DateTime<Utc>, 37 | }, 38 | /// Agent execution is paused 39 | Paused, 40 | /// Agent completed (the flag records whether it succeeded) 41 | Completed { success: bool }, 42 | /// Agent was cancelled 43 | Cancelled, 44 | /// Agent failed with error 45 | Failed { error: String }, 46 | } 47 | 48 | impl InternalAgentState { 49 | /// Convert internal status to public status (removing channels and sync primitives) 50 | pub fn to_public(&self) -> PublicAgentState { 51 | match self { 52 | InternalAgentState::Starting => PublicAgentState::Starting, 53 | InternalAgentState::Running => PublicAgentState::Running, 54 | InternalAgentState::Processing { task_name, tools_exec_at, ..
} => PublicAgentState::Processing { 55 | task_name: task_name.clone(), 56 | tools_exec_at: tools_exec_at.clone() 57 | }, 58 | InternalAgentState::Paused => PublicAgentState::Paused, 59 | InternalAgentState::Completed { success } => PublicAgentState::Completed { 60 | success: *success 61 | }, 62 | InternalAgentState::Failed { error } => PublicAgentState::Failed { 63 | error: error.clone() 64 | }, 65 | } 66 | } 67 | } -------------------------------------------------------------------------------- /shai-core/src/agent/states/terminal.rs: -------------------------------------------------------------------------------- 1 | use crate::agent::{AgentCore, InternalAgentEvent, AgentError}; 2 | use tracing::error; 3 | 4 | impl AgentCore { 5 | pub async fn state_terminal_handle_event(&mut self, event: InternalAgentEvent) -> Result<(), AgentError> { 6 | match event { 7 | _ => { 8 | // ignore all events but log error 9 | error!("event {:?} unexpected in state {:?}", event, self.state.to_public()); 10 | Ok(()) 11 | } 12 | } 13 | } 14 | } -------------------------------------------------------------------------------- /shai-core/src/config/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod config; 2 | pub mod agent; -------------------------------------------------------------------------------- /shai-core/src/lib.rs: -------------------------------------------------------------------------------- 1 | pub mod tools; 2 | pub mod agent; 3 | pub mod runners; 4 | pub mod logging; 5 | pub mod config; -------------------------------------------------------------------------------- /shai-core/src/runners/clifixer/fix.rs: -------------------------------------------------------------------------------- 1 | use std::sync::Arc; 2 | 3 | use openai_dive::v1::resources::chat::{ChatCompletionParametersBuilder, ChatCompletionResponseFormat, JsonSchemaBuilder, ChatMessage, ChatMessageContent}; 4 | use shai_llm::{client::LlmClient, provider::LlmError}; 5 | use serde::{Deserialize, Serialize}; 6 | 7 | use super::prompt::clifix_prompt; 8 | 9 | #[derive(Debug, Clone, Serialize, Deserialize)] 10 | pub struct CliFixResponse { 11 | pub short_rational: Option<String>, 12 | pub fixed_cli: String, 13 | } 14 | 15 | pub async fn clifix(llm: Arc<LlmClient>, model: String, messages: Vec<ChatMessage>) -> Result<CliFixResponse, LlmError> { 16 | let mut messages = messages.clone(); 17 | messages.push(ChatMessage::System { 18 | content: ChatMessageContent::Text(clifix_prompt()), 19 | name: None 20 | }); 21 | 22 | 23 | 24 | let request = ChatCompletionParametersBuilder::default() 25 | .model(model.clone()) 26 | .messages(messages) 27 | .temperature(0.1) 28 | .response_format(ChatCompletionResponseFormat::JsonSchema { 29 | json_schema: JsonSchemaBuilder::default() 30 | .name("cli_fix_response") 31 | .description("Response format for CLI fix with rationale and fixed command") 32 | .schema(serde_json::json!({ 33 | "type": "object", 34 | "properties": { 35 | "short_rational": { "type": "string" }, 36 | "fixed_cli": { "type": "string" } 37 | }, 38 | "required": ["fixed_cli"], 39 | "additionalProperties": false 40 | })) 41 | .strict(true) 42 | .build() 43 | .map_err(|e| -> LlmError { e.into() })?
44 | }) 45 | .build() 46 | .map_err(|e| -> LlmError { e.into() })?; 47 | 48 | /* 49 | if let Ok(json) = serde_json::to_string_pretty(&request) { 50 | let timestamp = chrono::Utc::now().format("%Y%m%d_%H%M%S"); 51 | let filename = format!("request_{}.json", timestamp); 52 | let _ = std::fs::write(&filename, json); 53 | } 54 | */ 55 | 56 | let response = llm.chat(request) 57 | .await?; 58 | 59 | if let ChatMessage::Assistant { content: Some(ChatMessageContent::Text(content)), .. } = response.choices[0].message.clone() { 60 | let parsed: CliFixResponse = serde_json::from_str(&content) 61 | .map_err(|e| -> LlmError { format!("Failed to parse CLI fix response: {}", e).into() })?; 62 | Ok(parsed) 63 | } else { 64 | Err("No content in response".into()) 65 | } 66 | } -------------------------------------------------------------------------------- /shai-core/src/runners/clifixer/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod prompt; 2 | pub mod fix; -------------------------------------------------------------------------------- /shai-core/src/runners/clifixer/prompt.rs: -------------------------------------------------------------------------------- 1 | use crate::runners::coder::env::{get_os_version, get_platform, get_today, get_working_dir, is_git_repo, env_all_key}; 2 | 3 | 4 | static CLIFIX_GOAL: &str = r#" 5 | You are SHAI's CLI error recovery assistant. When a user's command fails, you analyze the error and provide a corrected command. 6 | 7 | ## Your Task 8 | The user executed a command that failed. Your mission: 9 | 1. **Analyze the error** - Identify why the command failed (typo, wrong flag, missing dependency, etc.) 10 | 2. **Understand intent** - Consider command history to grasp what the user was trying to accomplish 11 | 3. 
**Provide solution** - Suggest the correct command that will work 12 | 13 | ## Common Error Patterns to Watch For: 14 | - **Command not found**: Suggest correct spelling or installation 15 | - **Invalid flags/options**: Provide valid alternatives 16 | - **Missing dependencies**: Include installation steps if needed 17 | - **Wrong syntax**: Fix parameter order or structure 18 | - **Permission issues**: Add sudo or ownership fixes 19 | - **Path problems**: Correct file/directory references 20 | 21 | ## Response Requirements 22 | Return valid JSON with exactly these fields: 23 | ```json 24 | { 25 | "short_rational": "Brief explanation of what went wrong (optional)", 26 | "fixed_cli": "corrected command ready to copy-paste" 27 | } 28 | ``` 29 | 30 | **Guidelines:** 31 | - Keep explanations concise and constructive 32 | - Ensure `fixed_cli` works in the current environment 33 | - No quotes or backticks around the command 34 | - Focus on the most likely fix, not all possibilities 35 | - If unsure, provide the safest/most common solution 36 | 37 | ## Environment Context 38 | 39 | Working directory: {working_dir} 40 | Is directory a git repo: {is_git_repo} 41 | Platform: {platform} 42 | OS Version: {os_version} 43 | Today's date: {today} 44 | 45 | Environment variables: 46 | {env} 47 | 48 | "#; 49 | 50 | 51 | pub fn clifix_prompt() -> String { 52 | let working_dir = get_working_dir(); 53 | let os = get_os_version(); 54 | let platform = get_platform(); 55 | let today = get_today(); 56 | let git_repo = is_git_repo(); 57 | let env = env_all_key(); 58 | 59 | CLIFIX_GOAL 60 | .replace("{working_dir}", &working_dir) 61 | .replace("{is_git_repo}", &git_repo.to_string()) 62 | .replace("{platform}", &platform) 63 | .replace("{os_version}", &os) 64 | .replace("{today}", &today) 65 | .replace("{env}", &env) 66 | .to_string() 67 | } -------------------------------------------------------------------------------- /shai-core/src/runners/coder/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod coder; 2 | pub mod prompt; 3 | pub mod env; 4 | 5 | pub use coder::CoderBrain; 6 | 7 | #[cfg(test)] 8 | mod tests; -------------------------------------------------------------------------------- /shai-core/src/runners/compacter/compact.rs: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ovh/shai/eda106e6e921da0f01c56d50cdc34371c1ba4678/shai-core/src/runners/compacter/compact.rs -------------------------------------------------------------------------------- /shai-core/src/runners/compacter/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod compact; -------------------------------------------------------------------------------- /shai-core/src/runners/gerund/gerund.rs: -------------------------------------------------------------------------------- 1 | use openai_dive::v1::resources::chat::{ChatCompletionParametersBuilder, ChatMessage, ChatMessageContent}; 2 | use shai_llm::{client::LlmClient, provider::LlmError}; 3 | 4 | use super::prompt::gerund_prompt; 5 | 6 | 7 | 8 | pub async fn gerund(llm: LlmClient, model: String, message: String) -> Result<ChatMessage, LlmError> { 9 | let message = if message.is_empty() { "the user has sent an empty message".to_string()} else {message}; 10 | let mut messages = vec![ChatMessage::User { content: ChatMessageContent::Text(message.clone()), name: None }]; 11 | messages.push(ChatMessage::System { 12 | content:
ChatMessageContent::Text(gerund_prompt()), 13 | name: None 14 | }); 15 | 16 | let request = ChatCompletionParametersBuilder::default() 17 | .model(model.clone()) 18 | .messages(messages) 19 | .temperature(0.1) 20 | .build() 21 | .map_err(|e| -> LlmError { e.into() })?; 22 | 23 | // submit the request to the model 24 | let response = llm.chat(request) 25 | .await?; 26 | 27 | Ok(response.choices[0].message.clone()) 28 | } -------------------------------------------------------------------------------- /shai-core/src/runners/gerund/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod prompt; 2 | pub mod gerund; 3 | 4 | #[cfg(test)] 5 | mod tests; -------------------------------------------------------------------------------- /shai-core/src/runners/gerund/prompt.rs: -------------------------------------------------------------------------------- 1 | 2 | static GERUND_PROMPT: &str = r#" 3 | Transform the user's request into a single uplifting verb ending in -ing that captures the essence of their message. Your response must be exactly one word - no explanations, no punctuation, no extra text. Capitalize only the first letter. 4 | 5 | Guidelines for word selection: 6 | • Choose words that spark joy and convey progress 7 | • Prioritize creativity and linguistic flair - unusual or sophisticated words are encouraged 8 | • Ensure strong thematic connection to the user's intent 9 | • Craft words that would make a developer smile when seen as a status indicator 10 | 11 | Forbidden categories: 12 | • System-related anxiety triggers (Connecting, Buffering, Loading, Syncing, Waiting) 13 | • Destructive actions (Terminating, Removing, Clearing, Purging, Erasing) 14 | • Potentially inappropriate terms in professional contexts 15 | • Negative or concerning language 16 | 17 | Think of yourself as a wordsmith creating delightful micro-poetry for status displays. The goal is to make routine development tasks feel more engaging and human.
18 | "#; 19 | 20 | 21 | pub fn gerund_prompt() -> String { 22 | GERUND_PROMPT.to_string() 23 | } -------------------------------------------------------------------------------- /shai-core/src/runners/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod coder; 2 | pub mod compacter; 3 | pub mod searcher; 4 | pub mod gerund; 5 | pub mod clifixer; -------------------------------------------------------------------------------- /shai-core/src/runners/searcher/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod searcher; 2 | pub mod prompt; 3 | 4 | #[cfg(test)] 5 | mod tests; 6 | 7 | pub use searcher::searcher; -------------------------------------------------------------------------------- /shai-core/src/runners/searcher/searcher.rs: -------------------------------------------------------------------------------- 1 | use std::sync::Arc; 2 | 3 | use openai_dive::v1::resources::chat::{ChatCompletionParametersBuilder, ChatCompletionToolChoice, ChatMessage, ChatMessageContent}; 4 | use shai_llm::client::LlmClient; 5 | use async_trait::async_trait; 6 | 7 | use crate::agent::brain::ThinkerDecision; 8 | use crate::agent::{Agent, AgentBuilder, AgentError, Brain, ThinkerContext}; 9 | use crate::tools::{AnyTool, FetchTool, FindTool, LsTool, ReadTool, TodoReadTool, TodoWriteTool, TodoStorage}; 10 | 11 | use super::prompt::searcher_next_step; 12 | 13 | #[derive(Clone)] 14 | pub struct SearcherBrain { 15 | pub llm: Arc<LlmClient>, 16 | pub model: String 17 | } 18 | 19 | impl SearcherBrain { 20 | pub fn new(llm: Arc<LlmClient>, model: String) -> Self { 21 | Self { llm, model } 22 | } 23 | 24 | /// Generic method to make LLM requests with custom system prompts and tools 25 | async fn chat_with_tools( 26 | &self, 27 | messages: Vec<ChatMessage>, 28 | tools: &Vec<Box<dyn AnyTool>>, 29 | tool_choice: ChatCompletionToolChoice, 30 | ) -> Result<ChatMessage, AgentError> { 31 | let request = ChatCompletionParametersBuilder::default() 32 | .model(&self.model) 33 | .messages(messages) 34 | .tools(tools.iter().map(|t| t.to_openai()).collect::<Vec<_>>()) 35 | .tool_choice(tool_choice) 36 | .temperature(0.1) 37 | .build() 38 | .map_err(|e| AgentError::LlmError(e.to_string()))?; 39 | 40 | let response = self 41 | .llm 42 | .chat(request) 43 | .await 44 | .map_err(|e| AgentError::LlmError(e.to_string()))?; 45 | 46 | Ok(response.choices[0].message.clone()) 47 | } 48 | } 49 | 50 | 51 | #[async_trait] 52 | impl Brain for SearcherBrain { 53 | async fn next_step(&mut self, context: ThinkerContext) -> Result<ThinkerDecision, AgentError> { 54 | let mut trace = context.trace.read().await.clone(); 55 | 56 | trace.insert(0, ChatMessage::System { 57 | content: ChatMessageContent::Text(searcher_next_step()), 58 | name: None, 59 | }); 60 | let brain_decision = self.chat_with_tools( 61 | trace, 62 | &context.available_tools, 63 | ChatCompletionToolChoice::Auto, 64 | ) 65 | .await?; 66 | 67 | // stop here if there are no further tool calls 68 | if let ChatMessage::Assistant { reasoning_content, content, tool_calls, ..
} = &brain_decision { 69 | if tool_calls.as_ref().map_or(true, |calls| calls.is_empty()) { 70 | return Ok(ThinkerDecision::agent_pause(brain_decision)); 71 | } 72 | } 73 | 74 | Ok(ThinkerDecision::agent_continue(brain_decision)) 75 | } 76 | } 77 | 78 | 79 | 80 | pub fn searcher(llm: Arc<LlmClient>, model: String) -> impl Agent { 81 | // Create shared storage for todo tools 82 | let todo_storage = Arc::new(TodoStorage::new()); 83 | 84 | // Only read-only tools for the searcher 85 | let fetch = Box::new(FetchTool::new()); 86 | let find = Box::new(FindTool::new()); 87 | let ls = Box::new(LsTool::new()); 88 | let read = Box::new(ReadTool::new(Arc::new(crate::tools::FsOperationLog::new()))); 89 | let todoread = Box::new(TodoReadTool::new(todo_storage.clone())); 90 | let todowrite = Box::new(TodoWriteTool::new(todo_storage.clone())); 91 | let toolbox: Vec<Box<dyn AnyTool>> = vec![fetch, find, ls, read, todoread, todowrite]; 92 | 93 | AgentBuilder::with_brain(Box::new(SearcherBrain{llm: llm.clone(), model})) 94 | .tools(toolbox) 95 | .build() 96 | } -------------------------------------------------------------------------------- /shai-core/src/tools/bash/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod structs; 2 | pub mod bash; 3 | 4 | #[cfg(test)] 5 | mod tests; 6 | 7 | pub use structs::BashToolParams; 8 | pub use bash::BashTool; -------------------------------------------------------------------------------- /shai-core/src/tools/bash/structs.rs: -------------------------------------------------------------------------------- 1 | use serde::Deserialize; 2 | use schemars::JsonSchema; 3 | use std::collections::HashMap; 4 | 5 | #[derive(Debug, Clone, Deserialize, JsonSchema)] 6 | pub struct BashToolParams { 7 | /// The bash command to execute 8 | pub command: String, 9 | /// Timeout in seconds (optional, None = no timeout) 10 | pub timeout: Option<u64>, 11 | /// Working directory for command execution (optional) 12 | pub working_dir: Option<String>, 13 | /// Environment variables to set (optional) 14 | #[serde(default)] 15 | pub env: HashMap<String, String>, 16 | } -------------------------------------------------------------------------------- /shai-core/src/tools/bash/tests.rs: -------------------------------------------------------------------------------- 1 | use super::structs::BashToolParams; 2 | use super::bash::BashTool; 3 | use crate::tools::{Tool, ToolCapability}; 4 | use shai_llm::ToolDescription; 5 | use std::collections::HashMap; 6 | use serde_json::json; 7 | 8 | #[test] 9 | fn test_bash_tool_permissions() { 10 | let tool = BashTool::new(); 11 | let perms = tool.capabilities(); 12 | assert!(perms.contains(&ToolCapability::Read)); 13 | assert!(perms.contains(&ToolCapability::Write)); 14 | assert!(perms.contains(&ToolCapability::Network)); 15 | assert_eq!(perms.len(), 3); 16 | } 17 | 18 | #[tokio::test] 19 | async fn test_bash_tool_creation() { 20 | let tool = BashTool::new(); 21 | assert_eq!(&tool.name(), "bash"); 22 | assert!(!tool.description().is_empty()); 23 | } 24 | 25 | #[tokio::test] 26 | async fn test_bash_tool_execution() { 27 | let tool = BashTool::new(); 28 | let params = BashToolParams { 29 | command: "echo hello".to_string(), 30 | timeout: None, 31 | working_dir: None, 32 | env: HashMap::new(), 33 | }; 34 | 35 | let result = Tool::execute(&tool, params, None).await; 36 | assert!(result.is_success()); 37 | if let crate::tools::types::ToolResult::Success { output, metadata } = result { 38 | assert!(output.contains("hello")); 39 | let metadata = metadata.unwrap(); 40 |
assert_eq!(metadata["exit_code"], json!(0)); 41 | assert_eq!(metadata["success"], json!(true)); 42 | } else { 43 | panic!("Expected success result"); 44 | } 45 | } -------------------------------------------------------------------------------- /shai-core/src/tools/fetch/fetch.rs: -------------------------------------------------------------------------------- 1 | use super::structs::{FetchToolParams, HttpMethod}; 2 | use crate::tools::{ToolResult, tool}; 3 | use serde_json::json; 4 | use std::collections::HashMap; 5 | use reqwest; 6 | use std::time::Duration; 7 | 8 | pub struct FetchTool; 9 | 10 | impl FetchTool { 11 | pub fn new() -> Self { 12 | Self 13 | } 14 | } 15 | 16 | #[tool(name = "fetch", description = r#"Retrieves content from a URL. This tool is ideal for accessing web pages, APIs, or other online resources. 17 | 18 | **Functionality:** 19 | - Supports `GET`, `POST`, `PUT`, and `DELETE` HTTP methods. 20 | - Allows for custom headers and request bodies, making it suitable for interacting with REST APIs. 21 | - Includes a timeout to prevent indefinite hangs on unresponsive servers. 22 | 23 | **Usage Notes:** 24 | - Provide a fully-qualified URL. 25 | - For API interactions, you can set the `Content-Type` header to `application/json` and provide a JSON string as the `body`. 26 | - The tool will return the raw response body, which you can then parse or analyze. 27 | 28 | **Examples:** 29 | - **Get a web page:** `fetch(url='https://example.com')` 30 | - **Get JSON data from an API:** `fetch(url='https://api.example.com/data')` 31 | - **Post JSON data to an API:** `fetch(url='https://api.example.com/users', method='POST', headers={'Content-Type': 'application/json'}, body='{"name": "John Doe"}')` 32 | "#, capabilities = [ToolCapability::Network])] 33 | impl FetchTool { 34 | async fn execute(&self, params: FetchToolParams) -> ToolResult { 35 | let client = reqwest::Client::builder() 36 | .timeout(Duration::from_secs(params.timeout)) 37 | .build(); 38 | 39 | let client = match client { 40 | Ok(c) => c, 41 | Err(e) => return ToolResult::error(format!("Failed to create HTTP client: {}", e)) 42 | }; 43 | 44 | // Build the request 45 | let mut request_builder = match params.method { 46 | HttpMethod::Get => client.get(&params.url), 47 | HttpMethod::Post => client.post(&params.url), 48 | HttpMethod::Put => client.put(&params.url), 49 | HttpMethod::Delete => client.delete(&params.url), 50 | }; 51 | 52 | // Add headers if provided 53 | if let Some(headers) = &params.headers { 54 | for (key, value) in headers { 55 | request_builder = request_builder.header(key, value); 56 | } 57 | } 58 | 59 | // Add body for POST/PUT requests 60 | if let Some(body) = &params.body { 61 | request_builder = request_builder.body(body.clone()); 62 | } 63 | 64 | // Execute the request 65 | match request_builder.send().await { 66 | Ok(response) => { 67 | let status = response.status(); 68 | let headers: HashMap<String, String> = response 69 | .headers() 70 | .iter() 71 | .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string())) 72 | .collect(); 73 | 74 | match response.text().await { 75 | Ok(body) => { 76 | let mut meta = HashMap::new(); 77 | meta.insert("url".to_string(), json!(params.url)); 78 | meta.insert("method".to_string(), json!(match params.method { 79 | HttpMethod::Get => "GET", 80 | HttpMethod::Post => "POST", 81 | HttpMethod::Put => "PUT", 82 | HttpMethod::Delete => "DELETE", 83 | })); 84 | meta.insert("status_code".to_string(), json!(status.as_u16())); 85 | meta.insert("response_headers".to_string(), json!(headers)); 86 |
meta.insert("content_length".to_string(), json!(body.len())); 87 | 88 | if status.is_success() { 89 | ToolResult::Success { 90 | output: body, 91 | metadata: Some(meta), 92 | } 93 | } else { 94 | ToolResult::Error { 95 | error: format!("HTTP request failed with status: {}", status), 96 | metadata: Some(meta), 97 | } 98 | } 99 | }, 100 | Err(e) => ToolResult::error(format!("Failed to read response body: {}", e)) 101 | } 102 | }, 103 | Err(e) => ToolResult::error(format!("HTTP request failed: {}", e)) 104 | } 105 | } 106 | } 107 | -------------------------------------------------------------------------------- /shai-core/src/tools/fetch/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod structs; 2 | pub mod fetch; 3 | 4 | #[cfg(test)] 5 | mod tests; 6 | 7 | pub use structs::{FetchToolParams, HttpMethod}; 8 | pub use fetch::FetchTool; -------------------------------------------------------------------------------- /shai-core/src/tools/fetch/structs.rs: -------------------------------------------------------------------------------- 1 | use serde::Deserialize; 2 | use schemars::JsonSchema; 3 | use std::collections::HashMap; 4 | 5 | #[derive(Debug, Clone, Deserialize, JsonSchema)] 6 | pub struct FetchToolParams { 7 | /// URL to fetch data from 8 | pub url: String, 9 | /// HTTP method to use 10 | #[serde(default = "default_method")] 11 | pub method: HttpMethod, 12 | /// HTTP headers to send (optional) 13 | #[serde(default)] 14 | pub headers: Option<HashMap<String, String>>, 15 | /// Request body for POST/PUT (optional) 16 | #[serde(default)] 17 | pub body: Option<String>, 18 | /// Request timeout in seconds (optional, defaults to 30) 19 | #[serde(default = "default_timeout")] 20 | pub timeout: u64, 21 | } 22 | 23 | #[derive(Debug, Clone, Deserialize, JsonSchema)] 24 | #[serde(rename_all = "UPPERCASE")] 25 | #[schemars(inline)] 26 | pub enum HttpMethod { 27 | Get, 28 | Post, 29 | Put, 30 | Delete, 31 | } 32 | 33 | fn default_method() -> HttpMethod { 34 | HttpMethod::Get 35 | } 36 | 37 | fn default_timeout() -> u64 { 38 | 30 39 | } -------------------------------------------------------------------------------- /shai-core/src/tools/fetch/tests.rs: -------------------------------------------------------------------------------- 1 | use super::fetch::FetchTool; 2 | use crate::tools::{Tool, ToolCapability}; 3 | use shai_llm::ToolDescription; 4 | 5 | #[test] 6 | fn test_fetch_tool_permissions() { 7 | let tool = FetchTool::new(); 8 | let perms = tool.capabilities(); 9 | assert!(perms.contains(&ToolCapability::Network)); 10 | assert_eq!(perms.len(), 1); 11 | } 12 | 13 | #[tokio::test] 14 | async fn test_fetch_tool_creation() { 15 | let tool = FetchTool::new(); 16 | assert_eq!(&tool.name(), "fetch"); 17 | assert!(!tool.description().is_empty()); 18 | } 19 | 20 | // Note: Actual network tests would require internet connectivity 21 | // In a real environment, you'd test with mock servers or local endpoints -------------------------------------------------------------------------------- /shai-core/src/tools/fs/edit/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod structs; 2 | pub mod edit; 3 | 4 | #[cfg(test)] 5 | mod tests; 6 | 7 | pub use structs::EditToolParams; 8 | pub use edit::EditTool; -------------------------------------------------------------------------------- /shai-core/src/tools/fs/edit/structs.rs: -------------------------------------------------------------------------------- 1 | use serde::Deserialize; 2 | use
schemars::JsonSchema; 3 | 4 | #[derive(Debug, Clone, Deserialize, JsonSchema)] 5 | pub struct EditToolParams { 6 | /// Path to the file to edit 7 | pub path: String, 8 | /// The text pattern to find and replace 9 | pub old_string: String, 10 | /// The replacement text 11 | pub new_string: String, 12 | /// Whether to replace all occurrences (default: false, replaces only first) 13 | #[serde(default)] 14 | pub replace_all: bool, 15 | } -------------------------------------------------------------------------------- /shai-core/src/tools/fs/find/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod structs; 2 | pub mod find; 3 | 4 | #[cfg(test)] 5 | mod tests; 6 | 7 | pub use structs::{FindToolParams, FindType, SearchResult}; 8 | pub use find::FindTool; 9 | -------------------------------------------------------------------------------- /shai-core/src/tools/fs/find/structs.rs: -------------------------------------------------------------------------------- 1 | use serde::{Deserialize, Serialize}; 2 | use schemars::JsonSchema; 3 | 4 | #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] 5 | pub struct FindToolParams { 6 | /// The pattern to search for (supports regex) 7 | pub pattern: String, 8 | /// Directory to search in (defaults to current directory) 9 | #[serde(default)] 10 | pub path: Option<String>, 11 | /// File extensions to include (e.g., "rs,js,py") 12 | #[serde(default)] 13 | pub include_extensions: Option<String>, 14 | /// File patterns to exclude (e.g., "target,node_modules,.git") 15 | #[serde(default)] 16 | pub exclude_patterns: Option<String>, 17 | /// Maximum number of results to return 18 | #[serde(default = "default_max_results")] 19 | pub max_results: u32, 20 | /// Whether to use case-sensitive search 21 | #[serde(default)] 22 | pub case_sensitive: bool, 23 | /// Find type: content (search file contents) or filename (search file names) 24 | #[serde(default = "default_find_type")] 25 | pub find_type: FindType, 26 | /// Show line numbers in results 27 | #[serde(default = "default_show_line_numbers")] 28 | pub show_line_numbers: bool, 29 | /// Maximum lines of context around matches 30 | #[serde(default)] 31 | pub context_lines: Option<u32>, 32 | /// Use whole word matching 33 | #[serde(default)] 34 | pub whole_word: bool, 35 | } 36 | 37 | #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] 38 | #[serde(rename_all = "lowercase")] 39 | #[schemars(inline)] 40 | pub enum FindType { 41 | Content, 42 | Filename, 43 | Both, 44 | } 45 | 46 | fn default_max_results() -> u32 { 100 } 47 | fn default_find_type() -> FindType { FindType::Content } 48 | fn default_show_line_numbers() -> bool { true } 49 | 50 | #[derive(Debug, Clone, Serialize, Deserialize)] 51 | pub struct SearchResult { 52 | pub file_path: String, 53 | pub line_number: Option<u32>, 54 | pub line_content: Option<String>, 55 | pub context_before: Vec<String>, 56 | pub context_after: Vec<String>, 57 | pub match_type: String, // "content" or "filename" 58 | } -------------------------------------------------------------------------------- /shai-core/src/tools/fs/ls/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod structs; 2 | pub mod ls; 3 | 4 | #[cfg(test)] 5 | mod tests; 6 | 7 | pub use structs::{LsToolParams, FileInfo}; 8 | pub use ls::LsTool; 9 | -------------------------------------------------------------------------------- /shai-core/src/tools/fs/ls/structs.rs: -------------------------------------------------------------------------------- 1 | use
serde::Deserialize; 2 | use schemars::JsonSchema; 3 | 4 | #[derive(Debug, Clone, Deserialize, JsonSchema)] 5 | pub struct LsToolParams { 6 | /// Directory to list (defaults to current directory) 7 | #[serde(default = "default_directory")] 8 | pub directory: String, 9 | /// Whether to list files recursively (defaults to false) 10 | #[serde(default)] 11 | pub recursive: bool, 12 | /// Show hidden files (files starting with .) 13 | #[serde(default)] 14 | pub show_hidden: bool, 15 | /// Show detailed information (size, permissions, etc.) 16 | #[serde(default)] 17 | pub long_format: bool, 18 | /// Maximum depth for recursive listing (None = unlimited) 19 | #[serde(default)] 20 | pub max_depth: Option<usize>, 21 | /// Maximum number of files to return (defaults to 200, set to None for unlimited) 22 | #[serde(default = "default_max_files")] 23 | pub max_files: Option<usize>, 24 | } 25 | 26 | fn default_directory() -> String { 27 | ".".to_string() 28 | } 29 | 30 | fn default_max_files() -> Option<usize> { 31 | Some(200) // Reasonable limit for LLM context 32 | } 33 | 34 | #[derive(Debug, Clone)] 35 | pub struct FileInfo { 36 | pub name: String, 37 | pub path: String, 38 | pub is_dir: bool, 39 | pub size: u64, 40 | pub modified: Option<String>, 41 | pub permissions: String, 42 | } -------------------------------------------------------------------------------- /shai-core/src/tools/fs/ls/tests.rs: -------------------------------------------------------------------------------- 1 | // Tests will be added here 2 | #[cfg(test)] 3 | mod tests { 4 | #[test] 5 | fn placeholder_test() { 6 | // Placeholder test to make module compile 7 | assert!(true); 8 | } 9 | } -------------------------------------------------------------------------------- /shai-core/src/tools/fs/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod edit; 2 | pub mod find; 3 | pub mod ls; 4 | pub mod multiedit; 5 | pub mod operation_log; 6 | pub mod read; 7 | pub mod write; 8 | 9 | #[cfg(test)] 10 | mod tests; 11 | 12 | pub use edit::EditTool; 13 | pub use find::FindTool; 14 | pub use ls::LsTool; 15 | pub use multiedit::MultiEditTool; 16 | pub use operation_log::{FsOperationLog, FsOperationType, FsOperation, FsOperationSummary}; 17 | pub use read::ReadTool; 18 | pub use write::WriteTool; -------------------------------------------------------------------------------- /shai-core/src/tools/fs/multiedit/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod structs; 2 | pub mod multiedit; 3 | 4 | #[cfg(test)] 5 | mod tests; 6 | 7 | pub use structs::{MultiEditToolParams, EditOperation}; 8 | pub use multiedit::MultiEditTool; -------------------------------------------------------------------------------- /shai-core/src/tools/fs/multiedit/structs.rs: -------------------------------------------------------------------------------- 1 | use serde::Deserialize; 2 | use schemars::JsonSchema; 3 | 4 | #[derive(Debug, Clone, Deserialize, JsonSchema)] 5 | #[schemars(inline)] 6 | pub struct EditOperation { 7 | /// The text pattern to find and replace 8 | pub old_string: String, 9 | /// The replacement text 10 | pub new_string: String, 11 | /// Whether to replace all occurrences (default: false, replaces only first) 12 | #[serde(default)] 13 | pub replace_all: bool, 14 | } 15 | 16 | #[derive(Debug, Clone, Deserialize, JsonSchema)] 17 | pub struct MultiEditToolParams { 18 | /// Path to the file to edit 19 | pub file_path: String, 20 | /// Array of edit operations to perform sequentially 21 |
pub edits: Vec<EditOperation>, 22 | } -------------------------------------------------------------------------------- /shai-core/src/tools/fs/multiedit/tests.rs: -------------------------------------------------------------------------------- 1 | use super::structs::{MultiEditToolParams, EditOperation}; 2 | use super::multiedit::MultiEditTool; 3 | use crate::tools::{Tool, ToolCapability, FsOperationLog}; 4 | use shai_llm::ToolDescription; 5 | use std::fs; 6 | use std::sync::Arc; 7 | use tempfile::tempdir; 8 | 9 | #[test] 10 | fn test_multiedit_tool_permissions() { 11 | let log = Arc::new(FsOperationLog::new()); 12 | let tool = MultiEditTool::new(log); 13 | let perms = tool.capabilities(); 14 | assert!(perms.contains(&ToolCapability::Read)); 15 | assert!(perms.contains(&ToolCapability::Write)); 16 | assert_eq!(perms.len(), 2); 17 | } 18 | 19 | #[tokio::test] 20 | async fn test_multiedit_tool_creation() { 21 | let log = Arc::new(FsOperationLog::new()); 22 | let tool = MultiEditTool::new(log); 23 | assert_eq!(&tool.name(), "multiedit"); 24 | assert!(!tool.description().is_empty()); 25 | } 26 | 27 | #[tokio::test] 28 | async fn test_multiedit_multiple_replacements() { 29 | let dir = tempdir().unwrap(); 30 | let file_path = dir.path().join("test.txt"); 31 | fs::write(&file_path, "Hello World, Hello Universe").unwrap(); 32 | 33 | let log = Arc::new(FsOperationLog::new()); 34 | // First read the file to satisfy the logging requirement 35 | log.log_operation(crate::tools::FsOperationType::Read, file_path.to_string_lossy().to_string()).await; 36 | 37 | let tool = MultiEditTool::new(log); 38 | let params = MultiEditToolParams { 39 | file_path: file_path.to_string_lossy().to_string(), 40 | edits: vec![ 41 | EditOperation { 42 | old_string: "Hello".to_string(), 43 | new_string: "Hi".to_string(), 44 | replace_all: false, 45 | }, 46 | EditOperation { 47 | old_string: "World".to_string(), 48 | new_string: "Earth".to_string(), 49 | replace_all: true, 50 | }, 51 | ], 52 | }; 53 | 54 | let result = tool.execute(params, None).await; 55 | assert!(result.is_success()); 56 | 57 | let content = fs::read_to_string(&file_path).unwrap(); 58 | assert_eq!(content, "Hi Earth, Hello Universe"); 59 | } 60 | 61 | #[tokio::test] 62 | async fn test_multiedit_preview_diff_output() { 63 | let dir = tempdir().unwrap(); 64 | let file_path = dir.path().join("test.txt"); 65 | fs::write(&file_path, "line1\nHello World\nline3\nGoodbye World").unwrap(); 66 | 67 | let log = Arc::new(FsOperationLog::new()); 68 | log.log_operation(crate::tools::FsOperationType::Read, file_path.to_string_lossy().to_string()).await; 69 | 70 | let tool = MultiEditTool::new(log); 71 | let params = MultiEditToolParams { 72 | file_path: file_path.to_string_lossy().to_string(), 73 | edits: vec![ 74 | EditOperation { 75 | old_string: "Hello".to_string(), 76 | new_string: "Hi".to_string(), 77 | replace_all: false, 78 | }, 79 | EditOperation { 80 | old_string: "Goodbye".to_string(), 81 | new_string: "Farewell".to_string(), 82 | replace_all: false, 83 | }, 84 | ], 85 | }; 86 | 87 | // Test preview - should return Some(ToolResult) with diff 88 | let preview_result = tool.execute_preview(params.clone()).await; 89 | assert!(preview_result.is_some()); 90 | 91 | let preview = preview_result.unwrap(); 92 | assert!(preview.is_success()); 93 | 94 | // Preview should contain diff output showing both changes 95 | let output = match preview { 96 | crate::tools::ToolResult::Success { output, ..
} => output, 97 | _ => panic!("Expected success result") 98 | }; 99 | 100 | println!("MultiEdit diff output:\n{}", output); 101 | 102 | // Should contain diff markers for both changes 103 | assert!(output.contains("-")); // Deletion markers 104 | assert!(output.contains("+")); // Addition markers 105 | assert!(output.contains("Hello")); // First old content 106 | assert!(output.contains("Hi")); // First new content 107 | assert!(output.contains("Goodbye")); // Second old content 108 | assert!(output.contains("Farewell")); // Second new content 109 | 110 | // Should contain ANSI color codes 111 | assert!(output.contains("\x1b[48;5;88;37m")); // Red background for deletions 112 | assert!(output.contains("\x1b[48;5;28;37m")); // Green background for additions 113 | 114 | // Original file should be unchanged after preview 115 | let original_content = fs::read_to_string(&file_path).unwrap(); 116 | assert_eq!(original_content, "line1\nHello World\nline3\nGoodbye World"); 117 | } -------------------------------------------------------------------------------- /shai-core/src/tools/fs/read/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod structs; 2 | pub mod read; 3 | 4 | #[cfg(test)] 5 | mod tests; 6 | 7 | pub use structs::ReadToolParams; 8 | pub use read::ReadTool; -------------------------------------------------------------------------------- /shai-core/src/tools/fs/read/structs.rs: -------------------------------------------------------------------------------- 1 | use serde::Deserialize; 2 | use schemars::JsonSchema; 3 | 4 | #[derive(Debug, Clone, Deserialize, JsonSchema)] 5 | pub struct ReadToolParams { 6 | /// Path to the file to read 7 | pub path: String, 8 | /// Starting line number (optional) 9 | #[serde(default)] 10 | pub line_start: Option<u32>, 11 | /// Ending line number (optional) 12 | #[serde(default)] 13 | pub line_end: Option<u32>, 14 | /// Whether to include line numbers in the output 15 | #[serde(default)] 16 | pub show_line_numbers: bool, 17 | } -------------------------------------------------------------------------------- /shai-core/src/tools/fs/write/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod structs; 2 | pub mod write; 3 | 4 | #[cfg(test)] 5 | mod tests; 6 | 7 | pub use structs::WriteToolParams; 8 | pub use write::WriteTool; -------------------------------------------------------------------------------- /shai-core/src/tools/fs/write/structs.rs: -------------------------------------------------------------------------------- 1 | use serde::Deserialize; 2 | use schemars::JsonSchema; 3 | 4 | #[derive(Debug, Clone, Deserialize, JsonSchema)] 5 | pub struct WriteToolParams { 6 | /// Path to the file to write 7 | pub path: String, 8 | /// Content to write to the file 9 | pub content: String, 10 | } -------------------------------------------------------------------------------- /shai-core/src/tools/fs/write/tests.rs: -------------------------------------------------------------------------------- 1 | use super::structs::WriteToolParams; 2 | use super::write::WriteTool; 3 | use crate::tools::{Tool, ToolCapability, FsOperationLog}; 4 | use shai_llm::ToolDescription; 5 | use std::fs; 6 | use std::sync::Arc; 7 | use tempfile::tempdir; 8 | 9 | 10 | #[test] 11 | fn test_write_tool_permissions() { 12 | let log = Arc::new(FsOperationLog::new()); 13 | let tool = WriteTool::new(log); 14 | let perms = tool.capabilities(); 15 | assert!(perms.contains(&ToolCapability::Write)); 16 |
assert_eq!(perms.len(), 1); 17 | } 18 | 19 | #[tokio::test] 20 | async fn test_write_tool_creation() { 21 | let log = Arc::new(FsOperationLog::new()); 22 | let tool = WriteTool::new(log); 23 | assert_eq!(&tool.name(), "write"); 24 | assert!(!tool.description().is_empty()); 25 | } 26 | 27 | #[tokio::test] 28 | async fn test_write_new_file() { 29 | let dir = tempdir().unwrap(); 30 | let file_path = dir.path().join("new_file.txt"); 31 | 32 | let log = Arc::new(FsOperationLog::new()); 33 | let tool = WriteTool::new(log); 34 | let params = WriteToolParams { 35 | path: file_path.to_string_lossy().to_string(), 36 | content: "Hello, World!".to_string(), 37 | }; 38 | 39 | let result = tool.execute(params, None).await; 40 | assert!(result.is_success()); 41 | if let crate::tools::types::ToolResult::Success { output, .. } = result { 42 | assert!(output.contains("created")); 43 | } else { 44 | panic!("Expected success result"); 45 | } 46 | 47 | let content = fs::read_to_string(&file_path).unwrap(); 48 | assert_eq!(content, "Hello, World!"); 49 | } -------------------------------------------------------------------------------- /shai-core/src/tools/fs/write/write.rs: -------------------------------------------------------------------------------- 1 | use super::structs::WriteToolParams; 2 | use super::super::{FsOperationLog, FsOperationType}; 3 | use crate::tools::{ToolResult, tool}; 4 | //use crate::tools::highlight::highlight_content; 5 | use serde_json::json; 6 | use std::collections::HashMap; 7 | use std::fs; 8 | use std::path::Path; 9 | use std::sync::Arc; 10 | 11 | #[derive(Clone)] 12 | pub struct WriteTool { 13 | operation_log: Arc<FsOperationLog>, 14 | } 15 | 16 | impl WriteTool { 17 | pub fn new(operation_log: Arc<FsOperationLog>) -> Self { 18 | Self { operation_log } 19 | } 20 | 21 | fn perform_write(&self, params: &WriteToolParams) -> Result<String, String> { 22 | let path = Path::new(&params.path); 23 | 24 | // Check if file exists before writing 25 | let file_existed = path.exists(); 26 | 27 | // Create parent directories if they don't exist 28 | if let Some(parent) = path.parent() { 29 | if !parent.exists() { 30 | fs::create_dir_all(parent).map_err(|e| e.to_string())?; 31 | } 32 | } 33 | 34 | // Write content to file (overwrites if exists) 35 | fs::write(path, &params.content).map_err(|e| e.to_string())?; 36 | 37 | let action = if file_existed { "updated" } else { "created" }; 38 | 39 | Ok(format!("Successfully {} file '{}' with {} bytes", 40 | action, params.path, params.content.len())) 41 | } 42 | } 43 | 44 | #[tool(name = "write", description = r#"Creates a new file with specified content or completely overwrites an existing file. This tool should be used with caution. 45 | 46 | **Guidelines** 47 | - To overwrite an existing file, you must first have read it with the `read` tool. This is a safety measure to ensure you are aware of the content being replaced. 48 | - This tool is primarily for creating new files when explicitly instructed. For modifying existing files, the `edit` or `multiedit` tools are the correct choice. 49 | - Do not create files proactively, especially documentation.
Only create files when the user's request cannot be fulfilled by modifying existing ones."#, capabilities = [ToolCapability::Write])] 50 | impl WriteTool { 51 | 52 | async fn execute_preview(&self, params: WriteToolParams) -> Option<ToolResult> { 53 | //let highlighted_content = highlight_content(&params.content, &params.path); 54 | 55 | let mut metadata = HashMap::new(); 56 | metadata.insert("path".to_string(), json!(params.path)); 57 | metadata.insert("content_length".to_string(), json!(params.content.len())); 58 | metadata.insert("line_count".to_string(), json!(params.content.lines().count())); 59 | metadata.insert("operation".to_string(), json!("write_preview")); 60 | 61 | Some(ToolResult::Success { 62 | output: params.content, 63 | metadata: Some(metadata), 64 | }) 65 | } 66 | 67 | async fn execute(&self, params: WriteToolParams) -> ToolResult { 68 | match self.perform_write(&params) { 69 | Ok(message) => { 70 | // Log the write operation 71 | self.operation_log.log_operation(FsOperationType::Write, params.path.clone()).await; 72 | 73 | let output = format!("{}\n{}", message, params.content); 74 | let mut meta = HashMap::new(); 75 | meta.insert("path".to_string(), json!(params.path)); 76 | meta.insert("content_length".to_string(), json!(params.content.len())); 77 | meta.insert("operation".to_string(), json!("write")); 78 | 79 | // Add file size information 80 | if let Ok(metadata) = std::fs::metadata(&params.path) { 81 | meta.insert("file_size_bytes".to_string(), json!(metadata.len())); 82 | } 83 | 84 | // Add line count information 85 | let line_count = params.content.lines().count(); 86 | meta.insert("line_count".to_string(), json!(line_count)); 87 | 88 | ToolResult::Success { 89 | output, 90 | metadata: Some(meta), 91 | } 92 | }, 93 | Err(e) => { 94 | ToolResult::error(format!("Write failed: {}", e)) 95 | } 96 | } 97 | } 98 | } 99 | 100 | -------------------------------------------------------------------------------- /shai-core/src/tools/mcp/mcp.rs: -------------------------------------------------------------------------------- 1 | use async_trait::async_trait; 2 | use shai_llm::ToolDescription; 3 | use std::sync::Arc; 4 | use tokio::sync::Mutex; 5 | 6 | use crate::tools::{ToolResult, ToolCall, AnyTool, ToolCapability}; 7 | 8 | #[derive(Debug, Clone)] 9 | pub struct McpToolDescription { 10 | pub name: String, 11 | pub description: String, 12 | pub parameters_schema: serde_json::Value, 13 | } 14 | 15 | #[async_trait] 16 | pub trait McpClient: Send + Sync { 17 | async fn connect(&mut self) -> Result<(), Box<dyn std::error::Error>>; 18 | async fn disconnect(&mut self) -> Result<(), Box<dyn std::error::Error>>; 19 | async fn list_tools(&self) -> Result<Vec<McpToolDescription>, Box<dyn std::error::Error>>; 20 | async fn execute_tool(&self, tool_call: ToolCall) -> Result<ToolResult, Box<dyn std::error::Error>>; 21 | } 22 | 23 | pub struct WrappedMcpTool { 24 | pub desc: McpToolDescription, 25 | pub client: Arc<Mutex<Box<dyn McpClient>>>, 26 | pub mcp_name: String, 27 | } 28 | 29 | impl ToolDescription for WrappedMcpTool { 30 | fn name(&self) -> String { 31 | self.desc.name.clone() 32 | } 33 | 34 | fn description(&self) -> String { 35 | self.desc.description.clone() 36 | } 37 | 38 | fn parameters_schema(&self) -> serde_json::Value { 39 | self.desc.parameters_schema.clone() 40 | } 41 | 42 | fn group(&self) -> Option<&str> { 43 | Some(&self.mcp_name) 44 | } 45 | } 46 | 47 | #[async_trait] 48 | impl AnyTool for WrappedMcpTool { 49 | fn capabilities(&self) -> &[ToolCapability] { 50 | &[ToolCapability::Network] 51 | } 52 | 53 | async fn execute_json(&self, params: serde_json::Value, cancel_token: Option) -> ToolResult { 54 | let tool_call = ToolCall { 55 | tool_call_id:
format!("mcp-{}", uuid::Uuid::new_v4()), 56 | tool_name: self.desc.name.clone(), 57 | parameters: params, 58 | }; 59 | 60 | // Lock the client for execution 61 | // right now we only do one call at a time per mcp server to avoid race condition 62 | let client = self.client.lock().await; 63 | 64 | match client.execute_tool(tool_call).await { 65 | Ok(result) => result, 66 | Err(e) => ToolResult::error(format!("MCP tool execution failed: {}", e)), 67 | } 68 | } 69 | 70 | async fn execute_preview_json(&self, _params: serde_json::Value) -> Option { 71 | None // MCP tools don't support preview mode 72 | } 73 | } 74 | 75 | /// Create AnyTool instances from an MCP client 76 | pub async fn get_mcp_tools(mut client: Box, mcp_name: &str) -> Result>, Box> { 77 | // Auto-connect if not already connected 78 | client.connect().await?; 79 | 80 | let tool_descriptions = client.list_tools().await?; 81 | let client_ref = Arc::new(Mutex::new(client)); 82 | 83 | let wrapped_tools: Vec> = tool_descriptions 84 | .into_iter() 85 | .map(|desc| { 86 | Box::new(WrappedMcpTool { 87 | desc, 88 | client: client_ref.clone(), 89 | mcp_name: mcp_name.to_string(), 90 | }) as Box 91 | }) 92 | .collect(); 93 | 94 | Ok(wrapped_tools) 95 | } 96 | 97 | -------------------------------------------------------------------------------- /shai-core/src/tools/mcp/mcp_config.rs: -------------------------------------------------------------------------------- 1 | use crate::tools::McpClient; 2 | use serde::{Serialize, Deserialize}; 3 | 4 | use super::{StdioClient, HttpClient, SseClient}; 5 | 6 | #[derive(Debug, Clone, Serialize, Deserialize)] 7 | #[serde(tag = "type")] 8 | pub enum McpConfig { 9 | #[serde(rename = "stdio")] 10 | Stdio { command: String, args: Vec }, 11 | #[serde(rename = "http")] 12 | Http { url: String, bearer_token: Option }, 13 | #[serde(rename = "sse")] 14 | Sse { url: String }, 15 | } 16 | 17 | /// Factory function to create an MCP client from configuration 18 | pub fn create_mcp_client(config: McpConfig) -> Box { 19 | match config { 20 | McpConfig::Stdio { command, args } => { 21 | Box::new(StdioClient::new(command, args)) 22 | } 23 | McpConfig::Http { url, bearer_token } => { 24 | Box::new(HttpClient::new_with_auth(url, bearer_token)) 25 | } 26 | McpConfig::Sse { url } => { 27 | Box::new(SseClient::new(url)) 28 | } 29 | } 30 | } -------------------------------------------------------------------------------- /shai-core/src/tools/mcp/mcp_http.rs: -------------------------------------------------------------------------------- 1 | use async_trait::async_trait; 2 | use rmcp::{ 3 | model::{CallToolRequestParam, ClientCapabilities, ClientInfo, Implementation, InitializeRequestParam}, 4 | service::{ServiceExt, RunningService}, 5 | transport::StreamableHttpClientTransport, 6 | RoleClient, 7 | }; 8 | use std::borrow::Cow; 9 | 10 | use crate::tools::{ToolResult, ToolCall}; 11 | use super::mcp::{McpClient, McpToolDescription}; 12 | 13 | pub struct HttpClient { 14 | url: String, 15 | bearer_token: Option, 16 | service: Option>, 17 | } 18 | 19 | impl HttpClient { 20 | pub fn new(url: String) -> Self { 21 | Self::new_with_auth(url, None) 22 | } 23 | 24 | pub fn new_with_auth(url: String, bearer_token: Option) -> Self { 25 | Self { 26 | url, 27 | bearer_token, 28 | service: None, 29 | } 30 | } 31 | } 32 | 33 | #[async_trait] 34 | impl McpClient for HttpClient { 35 | async fn connect(&mut self) -> Result<(), Box> { 36 | // Only connect if not already connected 37 | if self.service.is_some() { 38 | return Ok(()); 39 | } 40 | 41 
| let transport = if let Some(token) = &self.bearer_token { 42 | // Create a custom reqwest client with default bearer token 43 | let mut default_headers = reqwest::header::HeaderMap::new(); 44 | default_headers.insert( 45 | reqwest::header::AUTHORIZATION, 46 | reqwest::header::HeaderValue::from_str(&format!("Bearer {}", token))? 47 | ); 48 | let client = reqwest::Client::builder() 49 | .default_headers(default_headers) 50 | .build()?; 51 | 52 | StreamableHttpClientTransport::with_client( 53 | client, 54 | rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig { 55 | uri: self.url.clone().into(), 56 | ..Default::default() 57 | } 58 | ) 59 | } else { 60 | StreamableHttpClientTransport::from_uri(self.url.as_str()) 61 | }; 62 | 63 | let client_info = ClientInfo { 64 | protocol_version: Default::default(), 65 | capabilities: ClientCapabilities::default(), 66 | client_info: Implementation { 67 | name: "shai-mcp-http-client".to_string(), 68 | version: "0.1.0".to_string(), 69 | }, 70 | }; 71 | let service = client_info.serve(transport).await?; 72 | 73 | // Give the server a moment to process the initialization 74 | tokio::time::sleep(std::time::Duration::from_millis(100)).await; 75 | 76 | self.service = Some(service); 77 | Ok(()) 78 | } 79 | 80 | async fn disconnect(&mut self) -> Result<(), Box<dyn std::error::Error>> { 81 | if let Some(service) = self.service.take() { 82 | service.cancel().await?; 83 | } 84 | Ok(()) 85 | } 86 | 87 | async fn list_tools(&self) -> Result<Vec<McpToolDescription>, Box<dyn std::error::Error>> { 88 | let service = self.service.as_ref().ok_or("Not connected")?; 89 | let tools_result = service.list_tools(None).await?; 90 | 91 | let tool_descriptions = tools_result 92 | .tools 93 | .into_iter() 94 | .map(|tool| McpToolDescription { 95 | name: tool.name.to_string(), 96 | description: tool.description.unwrap_or_default().to_string(), 97 | parameters_schema: serde_json::Value::Object((*tool.input_schema).clone()), 98 | }) 99 | .collect(); 100 | 101 | Ok(tool_descriptions) 102 | } 103 | 104 | async fn execute_tool(&self, tool_call: ToolCall) -> Result<ToolResult, Box<dyn std::error::Error>> { 105 | let service = self.service.as_ref().ok_or("Not connected")?; 106 | 107 | let result = service 108 | .call_tool(CallToolRequestParam { 109 | name: Cow::Owned(tool_call.tool_name.clone()), 110 | arguments: tool_call.parameters.as_object().cloned(), 111 | }) 112 | .await?; 113 | 114 | let content = result 115 | .content 116 | .into_iter() 117 | .map(|c| match c.raw { 118 | rmcp::model::RawContent::Text(text_content) => text_content.text, 119 | rmcp::model::RawContent::Image(image_data) => format!("[Image: {} bytes]", image_data.data.len()), 120 | rmcp::model::RawContent::Resource(_) => "[Resource]".to_string(), 121 | rmcp::model::RawContent::Audio(audio_data) => format!("[Audio: {} bytes]", audio_data.data.len()), 122 | }) 123 | .collect::<Vec<String>>() 124 | .join("\n"); 125 | 126 | Ok(ToolResult::success(content)) 127 | } 128 | } -------------------------------------------------------------------------------- /shai-core/src/tools/mcp/mcp_sse.rs: -------------------------------------------------------------------------------- 1 | use async_trait::async_trait; 2 | use rmcp::{ 3 | model::{CallToolRequestParam, ClientCapabilities, ClientInfo, Implementation, InitializeRequestParam}, 4 | service::{ServiceExt, RunningService}, 5 | transport::SseClientTransport, 6 | RoleClient, 7 | }; 8 | use std::borrow::Cow; 9 | 10 | use crate::tools::{ToolResult, ToolCall}; 11 | use super::mcp::{McpClient, McpToolDescription}; 12 | 13 | pub struct SseClient { 14 | url: String, 15 | service: Option<RunningService<RoleClient, InitializeRequestParam>>,
16 | } 17 | 18 | impl SseClient { 19 | pub fn new(url: String) -> Self { 20 | Self { 21 | url, 22 | service: None, 23 | } 24 | } 25 | } 26 | 27 | #[async_trait] 28 | impl McpClient for SseClient { 29 | async fn connect(&mut self) -> Result<(), Box<dyn std::error::Error>> { 30 | // Only connect if not already connected 31 | if self.service.is_some() { 32 | return Ok(()); 33 | } 34 | 35 | let transport = SseClientTransport::start(self.url.as_str()).await?; 36 | let client_info = ClientInfo { 37 | protocol_version: Default::default(), 38 | capabilities: ClientCapabilities::default(), 39 | client_info: Implementation { 40 | name: "shai-mcp-sse-client".to_string(), 41 | version: "0.1.0".to_string(), 42 | }, 43 | }; 44 | let service = client_info.serve(transport).await?; 45 | self.service = Some(service); 46 | Ok(()) 47 | } 48 | 49 | async fn disconnect(&mut self) -> Result<(), Box<dyn std::error::Error>> { 50 | if let Some(service) = self.service.take() { 51 | service.cancel().await?; 52 | } 53 | Ok(()) 54 | } 55 | 56 | async fn list_tools(&self) -> Result<Vec<McpToolDescription>, Box<dyn std::error::Error>> { 57 | let service = self.service.as_ref().ok_or("Not connected")?; 58 | let tools_result = service.list_tools(Default::default()).await?; 59 | 60 | let tool_descriptions = tools_result 61 | .tools 62 | .into_iter() 63 | .map(|tool| McpToolDescription { 64 | name: tool.name.to_string(), 65 | description: tool.description.unwrap_or_default().to_string(), 66 | parameters_schema: serde_json::Value::Object((*tool.input_schema).clone()), 67 | }) 68 | .collect(); 69 | 70 | Ok(tool_descriptions) 71 | } 72 | 73 | async fn execute_tool(&self, tool_call: ToolCall) -> Result<ToolResult, Box<dyn std::error::Error>> { 74 | let service = self.service.as_ref().ok_or("Not connected")?; 75 | 76 | let result = service 77 | .call_tool(CallToolRequestParam { 78 | name: Cow::Owned(tool_call.tool_name.clone()), 79 | arguments: tool_call.parameters.as_object().cloned(), 80 | }) 81 | .await?; 82 | 83 | let content = result 84 | .content 85 | .into_iter() 86 | .map(|c| match c.raw { 87 | rmcp::model::RawContent::Text(text_content) => text_content.text, 88 | rmcp::model::RawContent::Image(image_data) => format!("[Image: {} bytes]", image_data.data.len()), 89 | rmcp::model::RawContent::Resource(_) => "[Resource]".to_string(), 90 | rmcp::model::RawContent::Audio(audio_data) => format!("[Audio: {} bytes]", audio_data.data.len()), 91 | }) 92 | .collect::<Vec<String>>() 93 | .join("\n"); 94 | 95 | Ok(ToolResult::success(content)) 96 | } 97 | } -------------------------------------------------------------------------------- /shai-core/src/tools/mcp/mcp_stdio.rs: -------------------------------------------------------------------------------- 1 | use async_trait::async_trait; 2 | use rmcp::{ 3 | model::CallToolRequestParam, 4 | service::{ServiceExt, RunningService}, 5 | transport::TokioChildProcess, 6 | RoleClient, 7 | }; 8 | use std::borrow::Cow; 9 | use tokio::process::Command; 10 | 11 | use crate::tools::{ToolResult, ToolCall}; 12 | use super::mcp::{McpClient, McpToolDescription}; 13 | 14 | pub struct StdioClient { 15 | command: String, 16 | args: Vec<String>, 17 | service: Option<RunningService<RoleClient, ()>>, 18 | } 19 | 20 | impl StdioClient { 21 | pub fn new(command: String, args: Vec<String>) -> Self { 22 | Self { 23 | command, 24 | args, 25 | service: None, 26 | } 27 | } 28 | } 29 | 30 | #[async_trait] 31 | impl McpClient for StdioClient { 32 | async fn connect(&mut self) -> Result<(), Box<dyn std::error::Error>> { 33 | // Only connect if not already connected 34 | if self.service.is_some() { 35 | return Ok(()); 36 | } 37 | 38 | let mut cmd = Command::new(&self.command); 39 | for arg in &self.args { 40 | cmd.arg(arg); 41 | }
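// The next lines hand the configured command to rmcp: TokioChildProcess spawns the
// MCP server as a child process and the client speaks JSON-RPC with it over the
// child's stdin/stdout, so any `command` + `args` pair that implements the MCP
// stdio transport can be plugged in here.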
42 | let transport = TokioChildProcess::new(cmd)?; 43 | let service = ().serve(transport).await?; 44 | self.service = Some(service); 45 | Ok(()) 46 | } 47 | 48 | async fn disconnect(&mut self) -> Result<(), Box<dyn std::error::Error>> { 49 | if let Some(service) = self.service.take() { 50 | service.cancel().await?; 51 | } 52 | Ok(()) 53 | } 54 | 55 | async fn list_tools(&self) -> Result<Vec<McpToolDescription>, Box<dyn std::error::Error>> { 56 | let service = self.service.as_ref().ok_or("Not connected")?; 57 | let tools_result = service.list_tools(Default::default()).await?; 58 | 59 | let tool_descriptions = tools_result 60 | .tools 61 | .into_iter() 62 | .map(|tool| McpToolDescription { 63 | name: tool.name.to_string(), 64 | description: tool.description.unwrap_or_default().to_string(), 65 | parameters_schema: serde_json::Value::Object((*tool.input_schema).clone()), 66 | }) 67 | .collect(); 68 | 69 | Ok(tool_descriptions) 70 | } 71 | 72 | async fn execute_tool(&self, tool_call: ToolCall) -> Result<ToolResult, Box<dyn std::error::Error>> { 73 | let service = self.service.as_ref().ok_or("Not connected")?; 74 | 75 | let result = service 76 | .call_tool(CallToolRequestParam { 77 | name: Cow::Owned(tool_call.tool_name.clone()), 78 | arguments: tool_call.parameters.as_object().cloned(), 79 | }) 80 | .await?; 81 | 82 | let content = result 83 | .content 84 | .into_iter() 85 | .map(|c| match c.raw { 86 | rmcp::model::RawContent::Text(text_content) => text_content.text, 87 | rmcp::model::RawContent::Image(image_data) => format!("[Image: {} bytes]", image_data.data.len()), 88 | rmcp::model::RawContent::Resource(_) => "[Resource]".to_string(), 89 | rmcp::model::RawContent::Audio(audio_data) => format!("[Audio: {} bytes]", audio_data.data.len()), 90 | }) 91 | .collect::<Vec<String>>() 92 | .join("\n"); 93 | 94 | Ok(ToolResult::success(content)) 95 | } 96 | } -------------------------------------------------------------------------------- /shai-core/src/tools/mcp/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod mcp; 2 | pub mod mcp_stdio; 3 | pub mod mcp_http; 4 | pub mod mcp_sse; 5 | pub mod mcp_config; 6 | pub mod mcp_oauth; 7 | 8 | #[cfg(test)] 9 | mod tests; 10 | 11 | pub use mcp::{McpClient, McpToolDescription, get_mcp_tools}; 12 | pub use mcp_config::{McpConfig, create_mcp_client}; 13 | pub use mcp_stdio::StdioClient; 14 | pub use mcp_http::HttpClient; 15 | pub use mcp_sse::SseClient; -------------------------------------------------------------------------------- /shai-core/src/tools/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod types; 2 | pub mod highlight; 3 | pub mod todo; 4 | pub mod fs; 5 | pub mod fetch; 6 | pub mod bash; 7 | pub mod mcp; 8 | 9 | #[cfg(test)] 10 | mod tests_llm; 11 | 12 | pub use shai_macros::tool; 13 | pub use types::{Tool, ToolCall, ToolResult, ToolError, ToolCapability, AnyTool, AnyToolBox, ToolEmptyParams}; 14 | 15 | // Re-export all tools 16 | pub use bash::BashTool; 17 | pub use fetch::FetchTool; 18 | pub use fs::{EditTool, FindTool, LsTool, MultiEditTool, ReadTool, WriteTool, FsOperationLog, FsOperationType, FsOperation, FsOperationSummary}; 19 | pub use todo::{TodoReadTool, TodoWriteTool, TodoStorage, TodoItem, TodoStatus, TodoWriteParams, TodoItemInput}; 20 | pub use mcp::{McpClient, McpToolDescription, McpConfig, create_mcp_client, get_mcp_tools, StdioClient, HttpClient, SseClient}; 21 | -------------------------------------------------------------------------------- /shai-core/src/tools/todo/mod.rs:
-------------------------------------------------------------------------------- 1 | pub mod structs; 2 | pub mod todo; 3 | 4 | #[cfg(test)] 5 | mod tests; 6 | 7 | pub use structs::{TodoStorage, TodoItem, TodoStatus}; 8 | pub use todo::{TodoReadTool, TodoWriteTool, TodoWriteParams, TodoItemInput}; -------------------------------------------------------------------------------- /shai-core/src/tools/todo/structs.rs: -------------------------------------------------------------------------------- 1 | use serde::{Deserialize, Serialize}; 2 | use schemars::JsonSchema; 3 | use tokio::sync::RwLock; 4 | 5 | pub struct TodoStorage { 6 | store: RwLock<Vec<TodoItem>> 7 | } 8 | 9 | impl TodoStorage { 10 | pub fn new() -> Self { 11 | Self { 12 | store: RwLock::new(Vec::new()) 13 | } 14 | } 15 | 16 | pub async fn get_all(&self) -> Vec<TodoItem> { 17 | self.store.read().await.clone() 18 | } 19 | 20 | pub async fn replace_all(&self, items: Vec<TodoItem>) { 21 | *self.store.write().await = items; 22 | } 23 | } 24 | 25 | 26 | #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] 27 | pub struct TodoItem { 28 | pub id: String, 29 | pub content: String, 30 | pub status: TodoStatus, 31 | pub created_at: String, 32 | pub updated_at: String, 33 | } 34 | 35 | #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] 36 | #[serde(rename_all = "snake_case")] 37 | #[schemars(inline)] 38 | pub enum TodoStatus { 39 | Pending, 40 | InProgress, 41 | Completed, 42 | } 43 | 44 | impl TodoItem { 45 | pub fn format_for_display(&self) -> String { 46 | let (checkbox, color_code) = match self.status { 47 | TodoStatus::Pending => ("☐", ""), 48 | TodoStatus::InProgress => ("☐", "\x1b[1;34m"), 49 | TodoStatus::Completed => ("☑", "\x1b[32m"), 50 | }; 51 | 52 | format!("{}{} {}\x1b[0m", color_code, checkbox, self.content) 53 | } 54 | } 55 | 56 | impl TodoStorage { 57 | pub fn format_all(&self, todos: &[TodoItem]) -> String { 58 | if todos.is_empty() { 59 | "No todos found. The todo list is empty.".to_string() 60 | } else { 61 | todos.iter() 62 | .map(|todo| todo.format_for_display()) 63 | .collect::<Vec<String>>() 64 | .join("\n") 65 | } 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /shai-core/src/tools/todo/todo.rs: -------------------------------------------------------------------------------- 1 | use super::{TodoItem, TodoStatus, TodoStorage}; 2 | use crate::tools::ToolEmptyParams; 3 | use crate::tools::{ToolResult, tool}; 4 | use std::sync::Arc; 5 | use serde_json::json; 6 | use serde::{Deserialize, Serialize}; 7 | use schemars::JsonSchema; 8 | use std::collections::HashMap; 9 | use chrono::Utc; 10 | use uuid::Uuid; 11 | 12 | 13 | // Input struct for creating todos 14 | #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] 15 | #[schemars(inline)] 16 | pub struct TodoItemInput { 17 | pub content: String, 18 | pub status: TodoStatus, 19 | } 20 | 21 | impl From<TodoItemInput> for TodoItem { 22 | fn from(input: TodoItemInput) -> Self { 23 | let now = Utc::now().to_rfc3339(); 24 | Self { 25 | id: Uuid::new_v4().to_string(), 26 | content: input.content, 27 | status: input.status, 28 | created_at: now.clone(), 29 | updated_at: now, 30 | } 31 | } 32 | } 33 | 34 | // Read Tool 35 | #[derive(Clone)] 36 | pub struct TodoReadTool { 37 | storage: Arc<TodoStorage> 38 | } 39 | 40 | #[tool(name = "todo_read", description = "Fetches the current to-do list for the session.
Use this proactively to stay informed about the status of ongoing tasks.")] 41 | impl TodoReadTool { 42 | pub fn new(storage: Arc<TodoStorage>) -> Self { 43 | Self { storage } 44 | } 45 | 46 | async fn execute(&self, params: ToolEmptyParams) -> ToolResult { 47 | let todos = self.storage.get_all().await; 48 | 49 | let output = self.storage.format_all(&todos); 50 | 51 | ToolResult::Success { 52 | output, 53 | metadata: Some({ 54 | let mut meta = HashMap::new(); 55 | meta.insert("todo_count".to_string(), json!(todos.len())); 56 | meta 57 | }), 58 | } 59 | } 60 | } 61 | 62 | // Write Tool Parameters 63 | #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] 64 | pub struct TodoWriteParams { 65 | /// List of todos to write (replaces entire list) 66 | pub todos: Vec<TodoItemInput>, 67 | } 68 | 69 | // Write Tool 70 | #[derive(Clone)] 71 | pub struct TodoWriteTool { 72 | storage: Arc<TodoStorage> 73 | } 74 | 75 | #[tool(name = "todo_write", description = "Creates and manages a structured task list for the coding session. This is vital for organizing complex work, tracking progress, and showing a clear plan.")] 76 | impl TodoWriteTool { 77 | pub fn new(storage: Arc<TodoStorage>) -> Self { 78 | Self { storage } 79 | } 80 | 81 | async fn execute(&self, params: TodoWriteParams) -> ToolResult { 82 | // Convert input todos to full TodoItems 83 | let todo_items: Vec<TodoItem> = params.todos.into_iter().map(|input| input.into()).collect(); 84 | 85 | // Replace entire list 86 | self.storage.replace_all(todo_items.clone()).await; 87 | 88 | let output = self.storage.format_all(&todo_items); 89 | 90 | ToolResult::Success { 91 | output, 92 | metadata: Some({ 93 | let mut meta = HashMap::new(); 94 | meta.insert("todo_count".to_string(), json!(todo_items.len())); 95 | meta 96 | }), 97 | } 98 | } 99 | } 100 | 101 | 102 | 103 | #[cfg(test)] 104 | mod tests { 105 | use super::*; 106 | use shai_llm::ToolDescription; 107 | 108 | #[test] 109 | fn test_todo_read_json_schema() { 110 | let store = TodoStorage::new(); 111 | let tool = TodoReadTool::new(Arc::new(store)); 112 | let schema = tool.parameters_schema(); 113 | println!("{}", serde_json::to_string_pretty(&schema).unwrap()); 114 | } 115 | 116 | 117 | #[test] 118 | fn test_todo_write_json_schema() { 119 | let store = TodoStorage::new(); 120 | let tool = TodoWriteTool::new(Arc::new(store)); 121 | let schema = tool.parameters_schema(); 122 | println!("{}", serde_json::to_string_pretty(&schema).unwrap()); 123 | } 124 | } 125 | -------------------------------------------------------------------------------- /shai-http/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "shai-http" 3 | version = "0.1.8" 4 | edition = "2021" 5 | 6 | [dependencies] 7 | shai-core = { path = "../shai-core" } 8 | shai-llm = { path = "../shai-llm" } 9 | 10 | # Web server 11 | axum = { version = "0.8.6", features = ["macros"] } 12 | tokio = { version = "1.0", features = ["full"] } 13 | tower = "0.5" 14 | tower-http = { version = "0.6", features = ["cors", "trace"] } 15 | 16 | # SSE and streaming 17 | tokio-stream = { version = "0.1", features = ["sync"] } 18 | futures = "0.3" 19 | async-trait = "0.1" 20 | 21 | # Serialization 22 | serde = { version = "1.0", features = ["derive"] } 23 | serde_json = "1.0" 24 | uuid = { version = "1.0", features = ["v4"] } 25 | 26 | # Logging 27 | tracing = "0.1" 28 | tracing-subscriber = { version = "0.3", features = ["env-filter"] } 29 | 30 | # Error handling 31 | thiserror = "2.0" 32 | anyhow = "1.0" 33 | 34 | # OpenAI types 35 | openai_dive = 
"1.3.1" 36 | chrono = { version = "0.4", features = ["serde"] } 37 | -------------------------------------------------------------------------------- /shai-http/src/apis/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod simple; 2 | pub mod openai; -------------------------------------------------------------------------------- /shai-http/src/apis/openai/completion/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod handler; 2 | pub mod formatter; 3 | 4 | pub use handler::*; 5 | -------------------------------------------------------------------------------- /shai-http/src/apis/openai/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod completion; 2 | pub mod response; 3 | 4 | pub use completion::handle_chat_completion; 5 | pub use response::{handle_response, handle_get_response, handle_cancel_response}; 6 | -------------------------------------------------------------------------------- /shai-http/src/apis/openai/response/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod handler; 2 | pub mod types; 3 | pub mod formatter; 4 | 5 | pub use handler::{handle_response, handle_get_response, handle_cancel_response}; -------------------------------------------------------------------------------- /shai-http/src/apis/simple/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod types; 2 | pub mod handler; 3 | pub mod formatter; 4 | 5 | pub use types::{MultiModalQuery, Message}; 6 | pub use handler::handle_multimodal_query_stream; 7 | pub use formatter::SimpleFormatter; -------------------------------------------------------------------------------- /shai-http/src/apis/simple/types.rs: -------------------------------------------------------------------------------- 1 | use serde::{Deserialize, Serialize}; 2 | use std::collections::HashMap; 3 | 4 | #[derive(Debug, Clone, Serialize, Deserialize)] 5 | pub struct ToolCall { 6 | pub tool: String, 7 | #[serde(default)] 8 | pub args: HashMap, 9 | #[serde(skip_serializing_if = "Option::is_none")] 10 | pub output: Option, 11 | } 12 | 13 | #[derive(Debug, Clone, Serialize, Deserialize)] 14 | pub struct ToolCallResult { 15 | #[serde(skip_serializing_if = "Option::is_none")] 16 | pub text: Option, 17 | #[serde(skip_serializing_if = "Option::is_none")] 18 | pub text_stream: Option, 19 | #[serde(skip_serializing_if = "Option::is_none")] 20 | pub image: Option, 21 | #[serde(skip_serializing_if = "Option::is_none")] 22 | pub speech: Option, 23 | #[serde(skip_serializing_if = "Option::is_none")] 24 | pub other: Option, 25 | #[serde(skip_serializing_if = "Option::is_none")] 26 | pub error: Option, 27 | #[serde(skip_serializing_if = "Option::is_none")] 28 | pub extra: Option>, 29 | } 30 | 31 | #[derive(Debug, Clone, Serialize, Deserialize)] 32 | pub struct PreviousCall { 33 | pub call: ToolCall, 34 | pub result: ToolCallResult, 35 | } 36 | 37 | #[derive(Debug, Clone, Serialize, Deserialize)] 38 | pub struct UserMessage { 39 | pub message: String, 40 | #[serde(skip_serializing_if = "Option::is_none")] 41 | pub attached_files: Option>, // { filename: base64file, ... 
} 42 | } 43 | 44 | #[derive(Debug, Clone, Serialize, Deserialize)] 45 | pub struct AssistantMessage { 46 | pub assistant: String, 47 | } 48 | 49 | #[derive(Debug, Clone, Serialize, Deserialize)] 50 | #[serde(untagged)] 51 | pub enum Message { 52 | User(UserMessage), 53 | Assistant(AssistantMessage), 54 | PreviousCall(PreviousCall), 55 | } 56 | 57 | /// Agent capability configuration 58 | #[derive(Debug, Clone, Serialize, Deserialize)] 59 | pub struct AgentCapability { 60 | #[serde(rename = "type")] 61 | pub tool_type: String, // "capability" 62 | #[serde(skip_serializing_if = "Option::is_none")] 63 | pub thinking: Option, 64 | #[serde(skip_serializing_if = "Option::is_none")] 65 | pub internet: Option, 66 | #[serde(skip_serializing_if = "Option::is_none")] 67 | pub image: Option, 68 | #[serde(skip_serializing_if = "Option::is_none")] 69 | pub speech: Option, 70 | } 71 | 72 | /// OpenAI API tool configuration 73 | #[derive(Debug, Clone, Serialize, Deserialize)] 74 | pub struct OpenAiApi { 75 | #[serde(rename = "type")] 76 | pub tool_type: String, // "openai" 77 | pub url: String, 78 | pub description: String, 79 | pub model: String, 80 | } 81 | 82 | /// MCP tool configuration 83 | #[derive(Debug, Clone, Serialize, Deserialize)] 84 | pub struct McpTool { 85 | #[serde(rename = "type")] 86 | pub tool_type: String, // "mcp" 87 | pub url: String, 88 | } 89 | 90 | /// Discriminated union of tool types 91 | #[derive(Debug, Clone, Serialize, Deserialize)] 92 | #[serde(tag = "type")] 93 | pub enum AgentTool { 94 | #[serde(rename = "capability")] 95 | Capability { 96 | #[serde(skip_serializing_if = "Option::is_none")] 97 | thinking: Option, 98 | #[serde(skip_serializing_if = "Option::is_none")] 99 | internet: Option, 100 | #[serde(skip_serializing_if = "Option::is_none")] 101 | image: Option, 102 | #[serde(skip_serializing_if = "Option::is_none")] 103 | speech: Option, 104 | }, 105 | #[serde(rename = "openai")] 106 | OpenAi { 107 | url: String, 108 | description: String, 109 | model: String, 110 | }, 111 | #[serde(rename = "mcp")] 112 | Mcp { 113 | url: String, 114 | }, 115 | } 116 | 117 | #[derive(Debug, Clone, Serialize, Deserialize)] 118 | pub struct MultiModalQuery { 119 | pub model: String, 120 | #[serde(default)] 121 | pub stream: bool, 122 | #[serde(skip_serializing_if = "Option::is_none")] 123 | pub messages: Option>, 124 | #[serde(skip_serializing_if = "Option::is_none")] 125 | pub tools: Option>, 126 | } 127 | 128 | #[derive(Debug, Clone, Serialize, Deserialize)] 129 | pub struct MultiModalStreamingResponse { 130 | pub id: String, 131 | pub model: String, 132 | #[serde(skip_serializing_if = "Option::is_none")] 133 | pub assistant: Option, 134 | #[serde(skip_serializing_if = "Option::is_none")] 135 | pub call: Option, 136 | #[serde(skip_serializing_if = "Option::is_none")] 137 | pub result: Option, 138 | } 139 | 140 | #[derive(Debug, Clone, Serialize, Deserialize)] 141 | #[serde(untagged)] 142 | pub enum ResponseMessage { 143 | Assistant(AssistantMessage), 144 | PreviousCall(PreviousCall), 145 | } 146 | 147 | #[derive(Debug, Clone, Serialize, Deserialize)] 148 | pub struct MultiModalResponse { 149 | pub id: String, 150 | pub model: String, 151 | pub result: Vec, 152 | } -------------------------------------------------------------------------------- /shai-http/src/error.rs: -------------------------------------------------------------------------------- 1 | use axum::{ 2 | extract::{rejection::JsonRejection, FromRequest}, 3 | http::StatusCode, 4 | response::{IntoResponse, Response, Json}, 5 
| }; 6 | use serde::{Deserialize, Serialize}; 7 | use tracing::error; 8 | 9 | /// Error response structure for API errors 10 | #[derive(Debug, Serialize, Deserialize)] 11 | pub struct ErrorResponse { 12 | pub error: ErrorDetail, 13 | } 14 | 15 | #[derive(Debug, Serialize, Deserialize)] 16 | pub struct ErrorDetail { 17 | pub message: String, 18 | pub r#type: String, 19 | #[serde(skip_serializing_if = "Option::is_none")] 20 | pub code: Option<String>, 21 | } 22 | 23 | impl ErrorResponse { 24 | pub fn new(message: String, error_type: String, code: Option<String>) -> Self { 25 | Self { 26 | error: ErrorDetail { 27 | message, 28 | r#type: error_type, 29 | code, 30 | }, 31 | } 32 | } 33 | 34 | pub fn not_found(message: String) -> Self { 35 | Self::new(message, "not_found".to_string(), Some("model_not_found".to_string())) 36 | } 37 | 38 | pub fn invalid_request(message: String) -> Self { 39 | Self::new(message, "invalid_request".to_string(), None) 40 | } 41 | 42 | pub fn internal_error(message: String) -> Self { 43 | Self::new(message, "internal_error".to_string(), None) 44 | } 45 | } 46 | 47 | impl IntoResponse for ErrorResponse { 48 | fn into_response(self) -> Response { 49 | let status = match self.error.r#type.as_str() { 50 | "not_found" => StatusCode::NOT_FOUND, 51 | "invalid_request" => StatusCode::BAD_REQUEST, 52 | _ => StatusCode::INTERNAL_SERVER_ERROR, 53 | }; 54 | (status, Json(self)).into_response() 55 | } 56 | } 57 | 58 | /// Custom JSON extractor that returns our ErrorResponse on deserialization failures 59 | #[derive(FromRequest)] 60 | #[from_request(via(axum::Json), rejection(ErrorResponse))] 61 | pub struct ApiJson<T>(pub T); 62 | 63 | impl From<JsonRejection> for ErrorResponse { 64 | fn from(rejection: JsonRejection) -> Self { 65 | let message = rejection.body_text(); 66 | error!("JSON deserialization error: {}", message); 67 | ErrorResponse::invalid_request(message) 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /shai-http/src/http.rs: -------------------------------------------------------------------------------- 1 | use axum::{ 2 | routing::{get, post}, 3 | Router, 4 | }; 5 | use std::sync::Arc; 6 | use tower_http::cors::CorsLayer; 7 | use tracing::info; 8 | 9 | use crate::session::{SessionManager, SessionManagerConfig}; 10 | use crate::apis; 11 | 12 | /// Configuration for the HTTP server 13 | #[derive(Clone, Debug)] 14 | pub struct ServerConfig { 15 | /// Server bind address (e.g., "127.0.0.1:8080") 16 | pub address: String, 17 | /// Session manager configuration 18 | pub session_manager: SessionManagerConfig, 19 | } 20 | 21 | impl ServerConfig { 22 | /// Create a new server config with the given address and default session manager config 23 | pub fn new(address: String) -> Self { 24 | Self { 25 | address, 26 | session_manager: SessionManagerConfig::default(), 27 | } 28 | } 29 | 30 | /// Set whether sessions are ephemeral by default 31 | pub fn with_ephemeral(mut self, ephemeral: bool) -> Self { 32 | self.session_manager.ephemeral = ephemeral; 33 | self 34 | } 35 | 36 | /// Set the maximum number of concurrent sessions 37 | pub fn with_max_sessions(mut self, max_sessions: Option<usize>) -> Self { 38 | self.session_manager.max_sessions = max_sessions; 39 | self 40 | } 41 | } 42 | 43 | /// Server state holding the session manager 44 | #[derive(Clone)] 45 | pub struct ServerState { 46 | pub session_manager: Arc<SessionManager>, 47 | } 48 | 49 | 50 | /// Start the HTTP server with SSE streaming 51 | pub async fn start_server( 52 | config: ServerConfig, 53 | ) -> Result<(), Box<dyn std::error::Error>> {
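// start_server wires everything together in three steps: build the shared
// SessionManager from the config, register the route table on an axum Router
// with that state, then bind a TcpListener and serve. The same ServerState is
// cloned into every handler, so all endpoints share one session pool.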
54 | // Create session manager 55 | let session_manager = SessionManager::new(config.session_manager.clone()); 56 | 57 | println!("✓ Session manager initialized"); 58 | if let Some(max) = config.session_manager.max_sessions { 59 | println!(" Max sessions: \x1b[1m{}\x1b[0m", max); 60 | } else { 61 | println!(" Max sessions: \x1b[1munlimited\x1b[0m"); 62 | } 63 | println!(" Default mode: \x1b[1m{}\x1b[0m", if config.session_manager.ephemeral { "ephemeral" } else { "persistent" }); 64 | println!(); 65 | 66 | let state = ServerState { 67 | session_manager: Arc::new(session_manager), 68 | }; 69 | 70 | let app = Router::new() 71 | // Simple API 72 | .route("/v1/multimodal", post(apis::simple::handle_multimodal_query_stream)) 73 | .route("/v1/multimodal/{session_id}", post(apis::simple::handle_multimodal_query_stream)) 74 | // OpenAI-compatible Response API 75 | .route("/v1/responses", post(apis::openai::handle_response)) 76 | .route("/v1/responses/{response_id}", get(apis::openai::handle_get_response)) 77 | .route("/v1/responses/{response_id}/cancel", post(apis::openai::handle_cancel_response)) 78 | // OpenAI-compatible Chat Completion API 79 | .route("/v1/chat/completions", post(apis::openai::handle_chat_completion)) 80 | .layer(CorsLayer::permissive()) 81 | .with_state(state); 82 | 83 | let listener = tokio::net::TcpListener::bind(&config.address).await?; 84 | 85 | // Print server info 86 | println!("Server starting on \x1b[1mhttp://{}\x1b[0m", config.address); 87 | println!("\nAvailable endpoints:"); 88 | println!(" \x1b[1mPOST /v1/chat/completions\x1b[0m - OpenAI Chat Completions API (ephemeral)"); 89 | println!(" \x1b[1mPOST /v1/responses\x1b[0m - OpenAI Responses API (stateful/stateless)"); 90 | println!(" \x1b[1mGET /v1/responses/:id\x1b[0m - Get response by ID"); 91 | println!(" \x1b[1mPOST /v1/responses/:id/cancel\x1b[0m - Cancel a response"); 92 | println!(" \x1b[1mPOST /v1/multimodal\x1b[0m - Simple multimodal API (streaming)"); 93 | println!(" \x1b[1mPOST /v1/multimodal/:session_id\x1b[0m - Simple multimodal API (with session)"); 94 | 95 | // List available agents 96 | use shai_core::config::agent::AgentConfig; 97 | match AgentConfig::list_agents() { 98 | Ok(agents) if !agents.is_empty() => { 99 | println!("\nAvailable agents: \x1b[2m{}\x1b[0m", agents.join(", ")); 100 | } 101 | _ => {} 102 | } 103 | 104 | println!("\nPress Ctrl+C to stop\n"); 105 | 106 | info!("HTTP server listening on {}", config.address); 107 | 108 | axum::serve(listener, app).await?; 109 | Ok(()) 110 | } -------------------------------------------------------------------------------- /shai-http/src/lib.rs: -------------------------------------------------------------------------------- 1 | pub mod http; 2 | pub mod apis; 3 | pub mod error; 4 | pub mod session; 5 | pub mod streaming; 6 | 7 | pub use error::{ApiJson, ErrorResponse}; 8 | pub use session::{SessionManager, SessionManagerConfig, AgentSession}; 9 | pub use streaming::{EventFormatter, event_to_sse_stream, session_to_sse_stream}; 10 | pub use http::{ServerConfig, ServerState, start_server}; -------------------------------------------------------------------------------- /shai-http/src/session/lifecycle.rs: -------------------------------------------------------------------------------- 1 | use shai_core::agent::AgentController; 2 | use tokio::sync::OwnedMutexGuard; 3 | use tracing::info; 4 | 5 | use crate::session::logger::colored_session_id; 6 | 7 | 8 | pub enum RequestLifecycle { 9 | Background { 10 | controller_guard: OwnedMutexGuard, 11 | request_id: 
String, 12 | session_id: String, 13 | }, 14 | Ephemeral { 15 | controller_guard: OwnedMutexGuard<AgentController>, 16 | request_id: String, 17 | session_id: String, 18 | }, 19 | } 20 | 21 | impl RequestLifecycle { 22 | pub fn new(ephemeral: bool, controller_guard: OwnedMutexGuard<AgentController>, request_id: String, session_id: String) -> Self { 23 | match ephemeral { 24 | true => Self::Ephemeral { controller_guard, request_id, session_id }, 25 | false => Self::Background { controller_guard, request_id, session_id }, 26 | } 27 | } 28 | } 29 | 30 | impl Drop for RequestLifecycle { 31 | fn drop(&mut self) { 32 | match self { 33 | Self::Background { request_id, session_id, .. } => { 34 | info!( 35 | "[{}] - {} Stream completed, releasing controller lock (background session)", 36 | request_id, 37 | colored_session_id(session_id) 38 | ); 39 | } 40 | Self::Ephemeral { controller_guard, request_id, session_id } => { 41 | info!( 42 | "[{}] - {} Stream completed, destroying agent (ephemeral session)", 43 | request_id, 44 | colored_session_id(session_id) 45 | ); 46 | 47 | // Clone before moving into async task 48 | let ctrl = controller_guard.clone(); 49 | tokio::spawn(async move { 50 | let _ = ctrl.terminate().await; 51 | }); 52 | } 53 | } 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /shai-http/src/session/logger.rs: -------------------------------------------------------------------------------- 1 | use std::hash::{Hash, Hasher}; 2 | use std::collections::hash_map::DefaultHasher; 3 | use shai_core::agent::AgentEvent; 4 | use tracing::{debug, error, info}; 5 | 6 | fn color_for_session(session_id: &str) -> u8 { 7 | let mut hasher = DefaultHasher::new(); 8 | session_id.hash(&mut hasher); 9 | let hash = hasher.finish(); 10 | // pick one of 216 “cube” colors from 16–231 11 | 16 + (hash % 216) as u8 12 | } 13 | 14 | pub fn colored_session_id(session_id: &str) -> String { 15 | let color = color_for_session(session_id); 16 | format!("\x1b[38;5;{}msid={}\x1b[0m", color, session_id) 17 | } 18 | 19 | pub fn log_event(event: &AgentEvent, session_id: &str) { 20 | let session_id = colored_session_id(session_id); 21 | match event { 22 | AgentEvent::ToolCallStarted { call, .. } => { 23 | debug!("{} - ToolCall: {}", session_id, call.tool_name); 24 | } 25 | AgentEvent::ToolCallCompleted { call, result, duration, .. } => { 26 | use shai_core::tools::ToolResult; 27 | match result { 28 | ToolResult::Success { .. } => { 29 | debug!("{} - ToolResult: {} ✓ ({}ms)", 30 | session_id, call.tool_name, duration.num_milliseconds()); 31 | } 32 | ToolResult::Error { error, .. } => { 33 | let error_oneline = error.lines().next().unwrap_or(error); 34 | debug!("{} - ToolResult: {} ✗ {}", 35 | session_id, call.tool_name, error_oneline); 36 | } 37 | ToolResult::Denied => { 38 | debug!("{} - ToolResult: {} ⊘ denied", 39 | session_id, call.tool_name); 40 | } 41 | } 42 | } 43 | AgentEvent::BrainResult { ..
} => { 44 | debug!("{} - BrainResult", session_id); 45 | } 46 | AgentEvent::StatusChanged { old_status, new_status } => { 47 | debug!("{} - Status: {:?} ← {:?}", 48 | session_id, new_status, old_status); 49 | } 50 | AgentEvent::Error { error } => { 51 | error!("{} - Error: {}", session_id, error); 52 | } 53 | AgentEvent::Completed { success, message } => { 54 | info!("{} - Completed: success={} msg={}", 55 | session_id, success, message); 56 | } 57 | _ => {} 58 | } 59 | } -------------------------------------------------------------------------------- /shai-http/src/session/mod.rs: -------------------------------------------------------------------------------- 1 | mod lifecycle; 2 | mod session; 3 | mod manager; 4 | mod logger; 5 | 6 | pub use logger::log_event; 7 | pub use lifecycle::{RequestLifecycle}; 8 | pub use session::{AgentSession, RequestSession}; 9 | pub use manager::{SessionManager, SessionManagerConfig}; 10 | 11 | -------------------------------------------------------------------------------- /shai-http/src/session/session.rs: -------------------------------------------------------------------------------- 1 | use shai_core::agent::{AgentController, AgentError, AgentEvent}; 2 | use openai_dive::v1::resources::chat::ChatMessage; 3 | use std::sync::Arc; 4 | use tokio::sync::{broadcast::Receiver, Mutex}; 5 | use tokio::task::JoinHandle; 6 | use tracing::info; 7 | use crate::session::logger::colored_session_id; 8 | 9 | use super::RequestLifecycle; 10 | 11 | 12 | /// Represents a single HTTP request session with automatic lifecycle management 13 | pub struct RequestSession { 14 | pub controller: AgentController, 15 | pub event_rx: Receiver<AgentEvent>, 16 | pub lifecycle: RequestLifecycle 17 | } 18 | 19 | /// A single agent session - represents one running agent instance 20 | /// Can be ephemeral (destroyed after request) or persistent (kept alive) 21 | /// Each request holds a guard on the controller so that only one query is processed per session 22 | /// - In background mode (ephemeral=false), the session survives the request and the guard is simply dropped 23 | /// - In ephemeral mode (ephemeral=true), the entire session stops and is deleted once the query ends or the client disconnects 24 | pub struct AgentSession { 25 | controller: Arc<Mutex<AgentController>>, 26 | event_rx: Receiver<AgentEvent>, 27 | logging_task: JoinHandle<()>, 28 | agent_task: JoinHandle<()>, 29 | 30 | pub session_id: String, 31 | pub agent_name: String, 32 | pub ephemeral: bool, 33 | } 34 | 35 | impl AgentSession { 36 | pub fn new( 37 | session_id: String, 38 | controller: AgentController, 39 | event_rx: Receiver<AgentEvent>, 40 | agent_task: JoinHandle<()>, 41 | logging_task: JoinHandle<()>, 42 | agent_name: Option<String>, 43 | ephemeral: bool, 44 | ) -> Self { 45 | let agent_name_display = agent_name.unwrap_or_else(|| "default".to_string()); 46 | 47 | Self { 48 | controller: Arc::new(Mutex::new(controller)), 49 | event_rx, 50 | logging_task, 51 | agent_task, 52 | session_id, 53 | agent_name: agent_name_display, 54 | ephemeral, 55 | } 56 | } 57 | 58 | /// Terminate a session 59 | pub async fn cancel(&self, http_request_id: &String) -> Result<(), AgentError> { 60 | let ctrl = self.controller.clone().lock_owned().await; 61 | info!("[{}] - {} cancelling session", http_request_id, colored_session_id(&self.session_id)); 62 | ctrl.terminate().await 63 | } 64 | 65 | /// Subscribe to events from this session (read-only, non-blocking) 66 | /// Used for GET /v1/responses/{response_id} to observe an ongoing session 67 | pub fn watch(&self) -> Receiver<AgentEvent> {
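// resubscribe() creates a fresh broadcast receiver at the current tail of the
// channel: an observer attached here only sees events emitted after this call,
// and a slow observer can lag without blocking the agent's own event loop.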
self.event_rx.resubscribe() 69 | } 70 | 71 | /// Handle a request for this agent session 72 | /// Returns a RequestSession that manages the lifecycle 73 | pub async fn handle_request(&self, http_request_id: &String, trace: Vec<ChatMessage>) -> Result<RequestSession, AgentError> { 74 | let controller_guard = self.controller.clone().lock_owned().await; 75 | controller_guard.wait_turn(None).await?; 76 | info!("[{}] - {} handling request", http_request_id, colored_session_id(&self.session_id)); 77 | 78 | controller_guard.send_trace(trace).await?; 79 | 80 | let event_rx = self.event_rx.resubscribe(); 81 | let controller = controller_guard.clone(); 82 | let lifecycle = RequestLifecycle::new(self.ephemeral, controller_guard, http_request_id.clone(), self.session_id.clone()); 83 | 84 | Ok(RequestSession{controller, event_rx, lifecycle}) 85 | } 86 | 87 | pub fn is_ephemeral(&self) -> bool { 88 | self.ephemeral 89 | } 90 | } 91 | 92 | impl Drop for AgentSession { 93 | fn drop(&mut self) { 94 | self.agent_task.abort(); 95 | self.logging_task.abort(); 96 | } 97 | } 98 | -------------------------------------------------------------------------------- /shai-llm/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "shai-llm" 3 | version = "0.1.8" 4 | edition = "2021" 5 | 6 | [dependencies] 7 | # New LLM module dependencies 8 | async-trait = "0.1" 9 | openssl = { version = "0.10", features = ["vendored"] } 10 | native-tls = { version = "0.2", features = ["vendored"] } 11 | reqwest = { version = "0.12", features = ["json", "stream"] } 12 | futures = "0.3" 13 | serde = { version = "1.0", features = ["derive"] } 14 | serde_json = "1.0" 15 | uuid = { version = "1.0", features = ["v4"] } 16 | tokio = { version = "1.0", features = ["full"] } 17 | openai_dive = { version = "1.3.1", features = ["stream"] } 18 | async-stream = "0.3" 19 | reqwest-eventsource = "0.6" 20 | regex = "1.0" 21 | schemars = "1.0.1" 22 | shai-macros = { path = "../shai-macros" } 23 | fastrand = "2.0" 24 | chrono = { version = "0.4", features = ["serde"] } 25 | 26 | [dev-dependencies] 27 | paste = "1.0" 28 | 29 | [lints.rust] 30 | dead_code = "allow" 31 | unused_variables = "allow" 32 | unused_mut = "allow" 33 | unused_imports = "allow" 34 | 35 | [[example]] 36 | name = "basic_query" 37 | path = "src/examples/basic_query.rs" 38 | 39 | [[example]] 40 | name = "query_with_history" 41 | path = "src/examples/query_with_history.rs" 42 | 43 | [[example]] 44 | name = "streaming_query" 45 | path = "src/examples/streaming_query.rs" 46 | 47 | [[example]] 48 | name = "function_calling" 49 | path = "src/examples/function_calling.rs" 50 | 51 | [[example]] 52 | name = "function_calling_streaming" 53 | path = "src/examples/function_calling_streaming.rs" 54 | -------------------------------------------------------------------------------- /shai-llm/src/examples/basic_query.rs: -------------------------------------------------------------------------------- 1 | // Basic query example showing simple chat completion with Mistral 2 | use shai_llm::{client::LlmClient, provider::LlmError}; 3 | use openai_dive::v1::resources::chat::{ChatCompletionParameters, ChatMessage, ChatMessageContent}; 4 | 5 | #[tokio::main] 6 | async fn main() -> Result<(), LlmError> { 7 | // Initialize Mistral client from environment variable (MISTRAL_API_KEY) 8 | let client = LlmClient::from_env_mistral() 9 | .expect("MISTRAL_API_KEY environment variable not set"); 10 | 11 | // Get the default model 12 | let model = client.provider().default_model().await?; 13 | 
println!("Using model: {}", model); 14 | 15 | // Create a simple chat request 16 | let request = ChatCompletionParameters { 17 | model: model, 18 | messages: vec![ 19 | ChatMessage::User { 20 | content: ChatMessageContent::Text("What is the capital of France?".to_string()), 21 | name: None, 22 | } 23 | ], 24 | temperature: Some(0.7), 25 | max_tokens: Some(100), 26 | ..Default::default() 27 | }; 28 | 29 | // Send the request and get response 30 | let response = client.chat(request).await?; 31 | 32 | // Print the response 33 | if let Some(choice) = response.choices.first() { 34 | match &choice.message { 35 | ChatMessage::Assistant { content: Some(ChatMessageContent::Text(text)), .. } => { 36 | println!("Response: {}", text); 37 | } 38 | _ => println!("No text response received"), 39 | } 40 | } 41 | 42 | Ok(()) 43 | } -------------------------------------------------------------------------------- /shai-llm/src/examples/mod.rs: -------------------------------------------------------------------------------- 1 | // Examples module for shai-llm 2 | // 3 | // This module contains practical examples demonstrating various use cases: 4 | // 5 | // - basic_query.rs: Simple chat completion 6 | // - query_with_history.rs: Multi-turn conversation with context 7 | // - streaming_query.rs: Real-time streaming responses 8 | // - function_calling.rs: Tool/function calling capabilities 9 | // 10 | // To run examples: 11 | // cargo run --example basic_query 12 | // cargo run --example query_with_history 13 | // cargo run --example streaming_query 14 | // cargo run --example function_calling 15 | 16 | pub mod basic_query; 17 | pub mod query_with_history; 18 | pub mod streaming_query; 19 | pub mod function_calling; -------------------------------------------------------------------------------- /shai-llm/src/examples/query_with_history.rs: -------------------------------------------------------------------------------- 1 | // Query with conversation history example 2 | use shai_llm::{client::LlmClient, provider::LlmError}; 3 | use openai_dive::v1::resources::chat::{ChatCompletionParameters, ChatMessage, ChatMessageContent}; 4 | 5 | #[tokio::main] 6 | async fn main() -> Result<(), LlmError> { 7 | // Initialize Mistral client from environment variable (MISTRAL_API_KEY) 8 | let client = LlmClient::from_env_mistral() 9 | .expect("MISTRAL_API_KEY environment variable not set"); 10 | 11 | // Get the default model 12 | let model = client.provider().default_model().await?; 13 | println!("Using model: {}", model); 14 | 15 | // Start conversation with system message and initial user message 16 | let mut conversation = vec![ 17 | ChatMessage::System { 18 | content: ChatMessageContent::Text("You are a helpful assistant that provides concise answers.".to_string()), 19 | name: None, 20 | }, 21 | ChatMessage::User { 22 | content: ChatMessageContent::Text("What is the capital of France?".to_string()), 23 | name: None, 24 | } 25 | ]; 26 | 27 | // First request 28 | let request = ChatCompletionParameters { 29 | model: model.clone(), 30 | messages: conversation.clone(), 31 | temperature: Some(0.7), 32 | max_tokens: Some(100), 33 | ..Default::default() 34 | }; 35 | 36 | let response = client.chat(request).await?; 37 | 38 | // Add assistant's response to conversation history 39 | if let Some(choice) = response.choices.first() { 40 | println!("Q: What is the capital of France?"); 41 | match &choice.message { 42 | ChatMessage::Assistant { content: Some(ChatMessageContent::Text(text)), .. 
} => { 43 | println!("A: {}", text); 44 | conversation.push(choice.message.clone()); 45 | } 46 | _ => println!("No text response received"), 47 | } 48 | } 49 | 50 | // Follow-up question using conversation history 51 | conversation.push(ChatMessage::User { 52 | content: ChatMessageContent::Text("What is the population of that city?".to_string()), 53 | name: None, 54 | }); 55 | 56 | let follow_up_request = ChatCompletionParameters { 57 | model: model, 58 | messages: conversation, 59 | temperature: Some(0.7), 60 | max_tokens: Some(100), 61 | ..Default::default() 62 | }; 63 | 64 | let follow_up_response = client.chat(follow_up_request).await?; 65 | 66 | // Print follow-up response 67 | if let Some(choice) = follow_up_response.choices.first() { 68 | println!("\nQ: What is the population of that city?"); 69 | match &choice.message { 70 | ChatMessage::Assistant { content: Some(ChatMessageContent::Text(text)), .. } => { 71 | println!("A: {}", text); 72 | } 73 | _ => println!("No text response received"), 74 | } 75 | } 76 | 77 | Ok(()) 78 | } -------------------------------------------------------------------------------- /shai-llm/src/examples/streaming_query.rs: -------------------------------------------------------------------------------- 1 | // Streaming query example showing real-time response chunks 2 | use shai_llm::{client::LlmClient, provider::LlmError}; 3 | use openai_dive::v1::resources::chat::{ChatCompletionParameters, ChatMessage, ChatMessageContent}; 4 | use futures::StreamExt; 5 | use openai_dive::v1::resources::chat::DeltaChatMessage; 6 | 7 | #[tokio::main] 8 | async fn main() -> Result<(), LlmError> { 9 | // Initialize Mistral client from environment variable (MISTRAL_API_KEY) 10 | let client = LlmClient::from_env_mistral() 11 | .expect("MISTRAL_API_KEY environment variable not set"); 12 | 13 | // Get the default model 14 | let model = client.provider().default_model().await?; 15 | println!("Using model: {}", model); 16 | 17 | // Create a streaming chat request 18 | let request = ChatCompletionParameters { 19 | model: model, 20 | messages: vec![ 21 | ChatMessage::User { 22 | content: ChatMessageContent::Text("Write a short story about a robot learning to paint.".to_string()), 23 | name: None, 24 | } 25 | ], 26 | temperature: Some(0.8), 27 | max_tokens: Some(300), 28 | stream: Some(true), 29 | ..Default::default() 30 | }; 31 | 32 | // Send the streaming request 33 | let mut stream = client.chat_stream(request).await?; 34 | 35 | println!("Streaming response:"); 36 | println!("=================="); 37 | 38 | let mut full_response = String::new(); 39 | 40 | // Process each chunk as it arrives 41 | while let Some(chunk_result) = stream.next().await { 42 | match chunk_result { 43 | Ok(chunk) => { 44 | // Extract content from the first choice if available 45 | if let Some(choice) = chunk.choices.first() { 46 | match &choice.delta { 47 | DeltaChatMessage::Assistant { content: Some(ChatMessageContent::Text(text)), .. } | 48 | DeltaChatMessage::Untagged { content: Some(ChatMessageContent::Text(text)), .. 
} => { 49 | if !text.is_empty() { 50 | print!("{}", text); 51 | full_response.push_str(text); 52 | // Flush stdout to show text immediately 53 | use std::io::{self, Write}; 54 | io::stdout().flush().unwrap(); 55 | } 56 | } 57 | _ => { 58 | // Handle other delta types or empty content 59 | } 60 | } 61 | } 62 | } 63 | Err(e) => { 64 | eprintln!("\nError processing stream chunk: {:?}", e); 65 | break; 66 | } 67 | } 68 | } 69 | 70 | println!("\n\n=================="); 71 | println!("Full response length: {} characters", full_response.len()); 72 | 73 | Ok(()) 74 | } -------------------------------------------------------------------------------- /shai-llm/src/lib.rs: -------------------------------------------------------------------------------- 1 | pub mod client; 2 | pub mod providers; 3 | pub mod provider; 4 | pub mod chat; 5 | pub mod tool; 6 | 7 | // Re-export our client 8 | pub use client::LlmClient; 9 | 10 | pub use tool::{ 11 | ToolDescription, 12 | ToolCallMethod, 13 | ToolBox, 14 | ContainsTool, 15 | StructuredOutputBuilder, 16 | AssistantResponse, 17 | IntoChatMessage, 18 | FunctionCallingAutoBuilder, 19 | FunctionCallingRequiredBuilder}; 20 | 21 | -------------------------------------------------------------------------------- /shai-llm/src/provider.rs: -------------------------------------------------------------------------------- 1 | use std::fmt::Debug; 2 | use async_trait::async_trait; 3 | use futures::Stream; 4 | use std::error::Error; 5 | use openai_dive::v1::endpoints::chat::Chat; 6 | use openai_dive::v1::resources::{ 7 | chat::{ChatCompletionParameters, ChatCompletionResponse, ChatCompletionChunkResponse}, 8 | model::ListModelResponse, 9 | }; 10 | 11 | pub type LlmError = Box<dyn Error + Send + Sync>; 12 | pub type LlmStream = Box<dyn Stream<Item = Result<ChatCompletionChunkResponse, LlmError>> + Send + Unpin>; 13 | 14 | #[derive(Debug, Clone)] 15 | pub struct EnvVar { 16 | pub name: String, 17 | pub description: String, 18 | pub required: bool, 19 | } 20 | 21 | #[derive(Debug, Clone)] 22 | pub struct ProviderInfo { 23 | pub name: &'static str, 24 | pub display_name: &'static str, 25 | pub env_vars: Vec<EnvVar>, 26 | } 27 | 28 | impl EnvVar { 29 | pub fn required(name: &str, description: &str) -> Self { 30 | Self { 31 | name: name.to_string(), 32 | description: description.to_string(), 33 | required: true, 34 | } 35 | } 36 | 37 | pub fn optional(name: &str, description: &str) -> Self { 38 | Self { 39 | name: name.to_string(), 40 | description: description.to_string(), 41 | required: false, 42 | } 43 | } 44 | } 45 | 46 | #[async_trait] 47 | pub trait LlmProvider: Send + Sync { 48 | async fn models(&self) -> Result<ListModelResponse, LlmError>; 49 | 50 | async fn default_model(&self) -> Result<String, LlmError> { 51 | let models = self.models().await?; 52 | models.data 53 | .first() 54 | .map(|m| m.id.clone()) 55 | .ok_or_else(|| "no model available".into()) 56 | } 57 | 58 | async fn chat(&self, request: ChatCompletionParameters) -> Result<ChatCompletionResponse, LlmError>; 59 | 60 | async fn chat_stream(&self, request: ChatCompletionParameters) -> Result<LlmStream, LlmError>; 61 | 62 | fn supports_functions(&self, model: String) -> bool; 63 | 64 | fn supports_structured_output(&self, model: String) -> bool; 65 | 66 | fn name(&self) -> &'static str; 67 | 68 | /// Returns provider information including environment variables 69 | fn info() -> ProviderInfo where Self: Sized; 70 | } 71 | 72 | impl Debug for dyn LlmProvider { 73 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 74 | let debug = format!("LlmProvider({})", self.name()); 75 | write!(f, "{}", debug) 76 | } 77 | } 78 | 
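The trait above is the whole provider contract: given any `dyn LlmProvider` you can list models, pick a default, and run one-shot or streaming chat without knowing which backend is behind it. A minimal sketch of a provider-agnostic helper, following the same patterns as the examples above (the `ask` function and its prompt handling are illustrative, not part of the crate):

use shai_llm::provider::{LlmProvider, LlmError};
use openai_dive::v1::resources::chat::{ChatCompletionParameters, ChatMessage, ChatMessageContent};

// Hypothetical helper: send one user prompt through any LlmProvider implementation.
async fn ask(provider: &dyn LlmProvider, prompt: &str) -> Result<String, LlmError> {
    // Resolve a model via the trait's default_model() (first listed model by default).
    let model = provider.default_model().await?;
    let request = ChatCompletionParameters {
        model,
        messages: vec![ChatMessage::User {
            content: ChatMessageContent::Text(prompt.to_string()),
            name: None,
        }],
        ..Default::default()
    };
    // chat() is the non-streaming entry point; chat_stream() would yield chunks instead.
    let response = provider.chat(request).await?;
    match response.choices.first().map(|c| &c.message) {
        Some(ChatMessage::Assistant { content: Some(ChatMessageContent::Text(text)), .. }) => {
            Ok(text.clone())
        }
        _ => Err("no text response".into()),
    }
}

Because every provider below (Anthropic, Ollama, OpenAI, ...) implements the same trait, a helper like this works unchanged across all of them.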
--------------------------------------------------------------------------------
/shai-llm/src/providers/anthropic/api.rs:
--------------------------------------------------------------------------------
use serde::{Serialize, Deserialize};

// Anthropic-specific streaming event types
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum AnthropicStreamEvent {
    MessageStart {
        message: AnthropicMessage,
    },
    ContentBlockStart {
        index: u32,
        content_block: AnthropicContentBlock,
    },
    ContentBlockDelta {
        index: u32,
        delta: AnthropicDelta,
    },
    ContentBlockStop {
        index: u32,
    },
    MessageDelta {
        delta: AnthropicMessageDelta,
        usage: Option<AnthropicUsage>,
    },
    MessageStop,
    Ping,
    Error {
        #[serde(flatten)]
        error: serde_json::Value,
    },
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnthropicMessage {
    pub id: String,
    #[serde(rename = "type")]
    pub message_type: String,
    pub role: String,
    pub content: Vec<AnthropicContentBlock>,
    pub model: String,
    pub stop_reason: Option<String>,
    pub stop_sequence: Option<String>,
    pub usage: AnthropicUsage,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnthropicContentBlock {
    #[serde(rename = "type")]
    pub block_type: String,
    pub text: Option<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum AnthropicDelta {
    TextDelta { text: String },
    InputJsonDelta { partial_json: String },
    ThinkingDelta { thinking: String },
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnthropicMessageDelta {
    pub stop_reason: Option<String>,
    pub stop_sequence: Option<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnthropicUsage {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input_tokens: Option<u32>,
    pub output_tokens: u32,
}

pub const ANTHROPIC_API_BASE: &str = "https://api.anthropic.com/v1";
--------------------------------------------------------------------------------
/shai-llm/src/providers/anthropic/mod.rs:
--------------------------------------------------------------------------------
pub mod api;
pub mod anthropic;
pub mod tests;

pub use anthropic::AnthropicProvider;
--------------------------------------------------------------------------------
/shai-llm/src/providers/mod.rs:
--------------------------------------------------------------------------------
pub mod openai;
pub mod openai_compatible;
pub mod openrouter;
pub mod ovhcloud;
pub mod anthropic;
pub mod ollama;
pub mod mistral;
// pub mod mistral_native; // TODO: Complete implementation

#[cfg(test)]
mod tests;
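Because `AnthropicStreamEvent` in `anthropic/api.rs` above is internally tagged on `type` with snake_case names, each SSE `data:` payload deserializes directly into the matching variant. A short sketch; the JSON sample is illustrative, shaped like Anthropic's `content_block_delta` event rather than captured from the API:

```rust
use shai_llm::providers::anthropic::api::{AnthropicDelta, AnthropicStreamEvent};

fn parse_event() -> Result<(), serde_json::Error> {
    // Illustrative payload only
    let data = r#"{"type":"content_block_delta","index":0,
                   "delta":{"type":"text_delta","text":"Hello"}}"#;
    match serde_json::from_str::<AnthropicStreamEvent>(data)? {
        AnthropicStreamEvent::ContentBlockDelta { delta: AnthropicDelta::TextDelta { text }, .. } => {
            assert_eq!(text, "Hello");
        }
        _ => unreachable!("sample is a content_block_delta"),
    }
    Ok(())
}
```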
--------------------------------------------------------------------------------
/shai-llm/src/providers/ollama.rs:
--------------------------------------------------------------------------------
// llm/providers/ollama.rs
use crate::provider::{LlmProvider, LlmError, LlmStream, ProviderInfo, EnvVar};
use async_trait::async_trait;
use futures::StreamExt;
use openai_dive::v1::{
    api::Client,
    resources::{
        chat::{ChatCompletionParameters, ChatCompletionResponse, ChatCompletionChunkResponse},
        model::ListModelResponse,
    },
};

const OLLAMA_BASE_URL: &str = "http://127.0.0.1:11434/v1";

pub struct OllamaProvider {
    client: Client,
}

impl OllamaProvider {
    pub fn new(base_url: Option<String>) -> Self {
        let mut client = Client::new(String::new());
        let url = base_url.unwrap_or_else(|| OLLAMA_BASE_URL.to_string());
        client.set_base_url(&url);
        Self { client }
    }

    /// Create Ollama provider from environment variables
    /// Returns None if OLLAMA_BASE_URL is not set
    pub fn from_env() -> Option<Self> {
        std::env::var("OLLAMA_BASE_URL")
            .ok()
            .map(|base_url| Self::new(Some(base_url)))
    }
}

#[async_trait]
impl LlmProvider for OllamaProvider {
    async fn models(&self) -> Result<ListModelResponse, LlmError> {
        let response = self.client.models().list().await
            .map_err(|e| Box::new(e) as LlmError)?;
        Ok(response)
    }

    async fn default_model(&self) -> Result<String, LlmError> {
        let models = self.models().await?;

        models.data.iter()
            .find(|m| m.id.to_lowercase().contains("smol"))
            .or_else(|| models.data.first())
            .map(|m| m.id.clone())
            .ok_or_else(|| "no model available".into())
    }

    async fn chat(&self, request: ChatCompletionParameters) -> Result<ChatCompletionResponse, LlmError> {
        let response = self.client.chat().create(request).await
            .map_err(|e| Box::new(e) as LlmError)?;
        Ok(response)
    }

    async fn chat_stream(&self, mut request: ChatCompletionParameters) -> Result<LlmStream, LlmError> {
        request.stream = Some(true);

        let stream = self.client.chat().create_stream(request).await
            .map_err(|e| Box::new(e) as LlmError)?;

        let converted_stream = stream.map(|result| {
            result.map_err(|e| Box::new(e) as LlmError)
        });

        Ok(Box::new(Box::pin(converted_stream)))
    }

    fn supports_functions(&self, _model: String) -> bool {
        true
    }

    fn supports_structured_output(&self, _model: String) -> bool {
        true
    }

    fn name(&self) -> &'static str {
        "ollama"
    }

    fn info() -> ProviderInfo {
        ProviderInfo {
            name: "ollama",
            display_name: "Ollama",
            env_vars: vec![
                EnvVar::optional("OLLAMA_BASE_URL", "Ollama base URL (OpenAI-compatible endpoint)"),
            ],
        }
    }
}
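A minimal sketch of wiring the Ollama provider up by hand; the remote hostname is an assumption for illustration, only the local default comes from the file above:

```rust
use shai_llm::provider::LlmProvider;
use shai_llm::providers::ollama::OllamaProvider;

#[tokio::main]
async fn main() -> Result<(), shai_llm::provider::LlmError> {
    // None falls back to the local default, http://127.0.0.1:11434/v1
    let local = OllamaProvider::new(None);
    // A remote Ollama host is just a different OpenAI-compatible base URL
    let _remote = OllamaProvider::new(Some("http://gpu-box:11434/v1".to_string()));

    for model in local.models().await?.data {
        println!("{}", model.id);
    }
    Ok(())
}
```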
--------------------------------------------------------------------------------
/shai-llm/src/providers/openai.rs:
--------------------------------------------------------------------------------
// llm/providers/openai.rs
use crate::provider::{LlmProvider, LlmError, LlmStream, ProviderInfo, EnvVar};
use async_trait::async_trait;
use futures::StreamExt;
use openai_dive::v1::{
    api::Client,
    resources::{
        chat::{ChatCompletionParameters, ChatCompletionResponse, ChatCompletionChunkResponse},
        model::ListModelResponse,
    },
};

pub struct OpenAIProvider {
    client: Client,
}

impl OpenAIProvider {
    pub fn new(api_key: String) -> Self {
        let mut client = Client::new(api_key);
        client.set_base_url("https://api.openai.com/v1");
        Self { client }
    }

    /// Create OpenAI provider from environment variables
    /// Returns None if required environment variables are not set
    pub fn from_env() -> Option<Self> {
        std::env::var("OPENAI_API_KEY").ok().map(Self::new)
    }
}

#[async_trait]
impl LlmProvider for OpenAIProvider {
    async fn models(&self) -> Result<ListModelResponse, LlmError> {
        let response = self.client.models().list().await
            .map_err(|e| Box::new(e) as LlmError)?;
        Ok(response)
    }

    async fn default_model(&self) -> Result<String, LlmError> {
        let models = self.models().await?;

        models.data.iter()
            .find(|m| m.id.to_lowercase().contains("gpt-4"))
            .or_else(|| models.data.first())
            .map(|m| m.id.clone())
            .ok_or_else(|| "no model available".into())
    }

    async fn chat(&self, request: ChatCompletionParameters) -> Result<ChatCompletionResponse, LlmError> {
        let response = self.client.chat().create(request).await
            .map_err(|e| Box::new(e) as LlmError)?;
        Ok(response)
    }

    async fn chat_stream(&self, mut request: ChatCompletionParameters) -> Result<LlmStream, LlmError> {
        // Ensure streaming is enabled
        request.stream = Some(true);

        let stream = self.client.chat().create_stream(request).await
            .map_err(|e| Box::new(e) as LlmError)?;

        let converted_stream = stream.map(|result| {
            result.map_err(|e| Box::new(e) as LlmError)
        });

        Ok(Box::new(Box::pin(converted_stream)))
    }

    fn supports_functions(&self, _model: String) -> bool {
        true
    }

    fn supports_structured_output(&self, _model: String) -> bool {
        true
    }

    fn name(&self) -> &'static str {
        "openai"
    }

    fn info() -> ProviderInfo {
        ProviderInfo {
            name: "openai",
            display_name: "OpenAI (GPT-4, GPT-3.5)",
            env_vars: vec![
                EnvVar::required("OPENAI_API_KEY", "OpenAI API key"),
            ],
        }
    }
}
--------------------------------------------------------------------------------
/shai-llm/src/providers/openai_compatible.rs:
--------------------------------------------------------------------------------
// llm/providers/openai_compatible.rs
use crate::provider::{LlmProvider, LlmError, LlmStream, ProviderInfo, EnvVar};
use async_trait::async_trait;
use futures::StreamExt;
use openai_dive::v1::{
    api::Client,
    resources::{
        chat::{ChatCompletionParameters, ChatCompletionResponse, ChatCompletionChunkResponse},
        model::ListModelResponse,
    },
};

pub struct OpenAICompatibleProvider {
    client: Client,
}

impl OpenAICompatibleProvider {
    pub fn new(api_key: String, base_url: String) -> Self {
        let mut client = Client::new(api_key);
        client.set_base_url(&base_url);
        Self { client }
    }

    /// Create OpenAI Compatible provider from environment variables
    /// Returns None if required environment variables are not set
    pub fn from_env() -> Option<Self> {
        match (std::env::var("OPENAI_COMPATIBLE_API_KEY"), std::env::var("OPENAI_COMPATIBLE_BASE_URL")) {
            (Ok(api_key), Ok(base_url)) => Some(Self::new(api_key, base_url)),
            _ => None,
        }
    }
}

#[async_trait]
impl LlmProvider for OpenAICompatibleProvider {
    async fn models(&self) -> Result<ListModelResponse, LlmError> {
        let response = self.client.models().list().await
            .map_err(|e| Box::new(e) as LlmError)?;
        Ok(response)
    }

    async fn chat(&self, request: ChatCompletionParameters) -> Result<ChatCompletionResponse, LlmError> {
        let response = self.client.chat().create(request).await
            .map_err(|e| Box::new(e) as LlmError)?;
        Ok(response)
    }

    async fn chat_stream(&self, mut request: ChatCompletionParameters) -> Result<LlmStream, LlmError> {
        // Ensure streaming is enabled
        request.stream = Some(true);

        let stream = self.client.chat().create_stream(request).await
            .map_err(|e| Box::new(e) as LlmError)?;

        let converted_stream = stream.map(|result| {
            result.map_err(|e| Box::new(e) as LlmError)
        });

        Ok(Box::new(Box::pin(converted_stream)))
    }

    fn supports_functions(&self, _model: String) -> bool {
        true
    }

    fn supports_structured_output(&self, _model: String) -> bool {
        true
    }

    fn name(&self) -> &'static str {
        "openai_compatible"
    }

    fn info() -> ProviderInfo {
        ProviderInfo {
            name: "openai_compatible",
            display_name: "OpenAI Compatible API",
            env_vars: vec![
                EnvVar::required("OPENAI_COMPATIBLE_API_KEY", "API key for OpenAI-compatible service"),
                EnvVar::required("OPENAI_COMPATIBLE_BASE_URL", "Base URL for OpenAI-compatible service"),
            ],
        }
    }
}
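All providers share the `from_env() -> Option<Self>` convention, so a caller can probe the environment in priority order and erase the concrete type behind `Arc<dyn LlmProvider>`. A hypothetical selection helper; the ordering here is an assumption for illustration, not the crate's actual policy:

```rust
use std::sync::Arc;
use shai_llm::provider::LlmProvider;
use shai_llm::providers::{
    ollama::OllamaProvider,
    openai::OpenAIProvider,
    openai_compatible::OpenAICompatibleProvider,
};

fn provider_from_env() -> Option<Arc<dyn LlmProvider>> {
    if let Some(p) = OpenAIProvider::from_env() {
        return Some(Arc::new(p) as Arc<dyn LlmProvider>);
    }
    if let Some(p) = OpenAICompatibleProvider::from_env() {
        return Some(Arc::new(p) as Arc<dyn LlmProvider>);
    }
    // Fall back to a locally running Ollama if OLLAMA_BASE_URL is set
    OllamaProvider::from_env().map(|p| Arc::new(p) as Arc<dyn LlmProvider>)
}
```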
--------------------------------------------------------------------------------
/shai-llm/src/providers/openrouter/api.rs:
--------------------------------------------------------------------------------
use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct OpenRouterModelsResponse {
    pub data: Vec<OpenRouterModel>,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct OpenRouterModel {
    pub id: String,
    pub name: String,
    pub created: i64,
    pub description: String,
    pub architecture: OpenRouterArchitecture,
    pub top_provider: OpenRouterTopProvider,
    pub pricing: OpenRouterPricing,
    pub context_length: i64,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub hugging_face_id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub per_request_limits: Option<serde_json::Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub supported_parameters: Option<Vec<String>>,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct OpenRouterArchitecture {
    pub input_modalities: Vec<String>,
    pub output_modalities: Vec<String>,
    pub tokenizer: String,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct OpenRouterTopProvider {
    pub is_moderated: bool,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct OpenRouterPricing {
    pub prompt: String,
    pub completion: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub image: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub request: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input_cache_read: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input_cache_write: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub web_search: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub internal_reasoning: Option<String>,
}

impl OpenRouterModel {
    /// Convert OpenRouter model to openai_dive Model format
    pub fn to_openai_model(&self) -> openai_dive::v1::resources::model::Model {
        openai_dive::v1::resources::model::Model {
            id: self.id.clone(),
            object: "model".to_string(),
            created: Some(self.created as u32),
            owned_by: "openrouter".to_string(),
        }
    }
}

impl OpenRouterModelsResponse {
    /// Convert OpenRouter models response to openai_dive ListModelResponse format
    pub fn to_openai_models_response(&self) -> openai_dive::v1::resources::model::ListModelResponse {
        openai_dive::v1::resources::model::ListModelResponse {
            object: "list".to_string(),
            data: self.data.iter().map(|m| m.to_openai_model()).collect(),
        }
    }
}
--------------------------------------------------------------------------------
/shai-llm/src/providers/openrouter/mod.rs:
--------------------------------------------------------------------------------
pub mod api;
pub mod openrouter;

pub use openrouter::OpenRouterProvider;
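The native OpenRouter metadata is much richer than openai_dive's `Model`, so the conversion above deliberately keeps only the id and creation time. A quick sketch of the round trip; the JSON payload is shaped like `GET /api/v1/models` but the values are invented:

```rust
use shai_llm::providers::openrouter::api::OpenRouterModelsResponse;

fn demo_conversion() -> Result<(), serde_json::Error> {
    // Illustrative payload, not real data; optional fields simply stay None
    let body = r#"{"data":[{
        "id":"mistralai/mistral-7b-instruct:free",
        "name":"Mistral 7B Instruct (free)",
        "created":1700000000,
        "description":"demo entry",
        "architecture":{"input_modalities":["text"],"output_modalities":["text"],"tokenizer":"Mistral"},
        "top_provider":{"is_moderated":false},
        "pricing":{"prompt":"0","completion":"0"},
        "context_length":32768
    }]}"#;

    let native: OpenRouterModelsResponse = serde_json::from_str(body)?;
    let models = native.to_openai_models_response();
    assert_eq!(models.data[0].owned_by, "openrouter");
    Ok(())
}
```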
--------------------------------------------------------------------------------
/shai-llm/src/providers/openrouter/openrouter.rs:
--------------------------------------------------------------------------------
use crate::provider::{LlmProvider, LlmError, LlmStream, ProviderInfo, EnvVar};
use super::api::OpenRouterModelsResponse;
use async_trait::async_trait;
use futures::StreamExt;
use openai_dive::v1::{
    api::Client,
    resources::{
        chat::{ChatCompletionParameters, ChatCompletionResponse, ChatCompletionChunkResponse},
        model::ListModelResponse,
    },
};

const OPENROUTER_API_BASE: &str = "https://openrouter.ai/api/v1";

pub struct OpenRouterProvider {
    client: Client,
    api_key: String,
    base_url: String,
    http_client: reqwest::Client,
}

impl OpenRouterProvider {
    pub fn new(api_key: String) -> Self {
        let mut client = Client::new(api_key.clone());
        client.set_base_url(OPENROUTER_API_BASE);
        Self {
            client,
            api_key,
            base_url: OPENROUTER_API_BASE.to_string(),
            http_client: reqwest::Client::new(),
        }
    }

    /// Create OpenRouter provider from environment variables
    /// Returns None if required environment variables are not set
    pub fn from_env() -> Option<Self> {
        std::env::var("OPENROUTER_API_KEY").ok().map(Self::new)
    }

    /// Get OpenRouter models using their native API format
    pub async fn openrouter_models(&self) -> Result<OpenRouterModelsResponse, LlmError> {
        let url = format!("{}/models", self.base_url);

        let response = self.http_client
            .get(&url)
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .send()
            .await
            .map_err(|e| Box::new(e) as LlmError)?;

        if !response.status().is_success() {
            let status = response.status();
            let text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
            return Err(Box::new(std::io::Error::new(
                std::io::ErrorKind::Other,
                format!("OpenRouter API error {}: {}", status, text)
            )) as LlmError);
        }

        let openrouter_response: OpenRouterModelsResponse = response
            .json()
            .await
            .map_err(|e| Box::new(e) as LlmError)?;

        Ok(openrouter_response)
    }
}

#[async_trait]
impl LlmProvider for OpenRouterProvider {
    async fn models(&self) -> Result<ListModelResponse, LlmError> {
        let openrouter_response = self.openrouter_models().await?;
        Ok(openrouter_response.to_openai_models_response())
    }

    async fn default_model(&self) -> Result<String, LlmError> {
        let models = self.models().await?;

        let keywords = ["free"];
        models.data.iter()
            .find(|m| keywords.iter().any(|kw| m.id.to_lowercase().contains(kw)))
            .or_else(|| models.data.first())
            .map(|m| m.id.clone())
            .ok_or_else(|| "no model available".into())
    }

    async fn chat(&self, request: ChatCompletionParameters) -> Result<ChatCompletionResponse, LlmError> {
        let response = self.client.chat().create(request).await
            .map_err(|e| Box::new(e) as LlmError)?;
        Ok(response)
    }

    async fn chat_stream(&self, mut request: ChatCompletionParameters) -> Result<LlmStream, LlmError> {
        // Ensure streaming is enabled
        request.stream = Some(true);

        let stream = self.client.chat().create_stream(request).await
            .map_err(|e| Box::new(e) as LlmError)?;

        let converted_stream = stream.map(|result| {
            result.map_err(|e| Box::new(e) as LlmError)
        });

        Ok(Box::new(Box::pin(converted_stream)))
    }

    fn supports_functions(&self, _model: String) -> bool {
        true
    }

    fn supports_structured_output(&self, _model: String) -> bool {
        true
    }

    fn name(&self) -> &'static str {
        "openrouter"
    }

    fn info() -> ProviderInfo {
        ProviderInfo {
            name: "openrouter",
            display_name: "OpenRouter (Multiple AI Providers)",
            env_vars: vec![
                EnvVar::required("OPENROUTER_API_KEY", "OpenRouter API key"),
            ],
        }
    }
}
--------------------------------------------------------------------------------
/shai-llm/src/providers/ovhcloud.rs:
--------------------------------------------------------------------------------
// llm/providers/ovhcloud.rs
use crate::provider::{LlmProvider, LlmError, LlmStream, ProviderInfo, EnvVar};
use async_trait::async_trait;
use futures::StreamExt;
use openai_dive::v1::{
    api::Client,
    resources::{
        chat::{ChatCompletionParameters, ChatCompletionResponse, ChatCompletionChunkResponse},
        model::ListModelResponse,
    },
};

const OVH_API_BASE: &str = "https://oai.endpoints.kepler.ai.cloud.ovh.net/v1";

pub struct OvhCloudProvider {
    client: Client,
}

impl OvhCloudProvider {
    pub fn new(api_key: String, base_url: Option<String>) -> Self {
        let mut client = Client::new(api_key);
        let url = base_url.unwrap_or_else(|| OVH_API_BASE.to_string());
        client.set_base_url(&url);
        Self { client }
    }

    /// Create OVH Cloud provider from environment variables
    /// Returns None if required environment variables are not set
    pub fn from_env() -> Option<Self> {
        std::env::var("OVH_API_KEY").ok().map(|api_key| {
            let base_url = std::env::var("OVH_BASE_URL").ok();
            Self::new(api_key, base_url)
        })
    }

    fn sanitize_request(&self, mut request: ChatCompletionParameters) -> ChatCompletionParameters {
        // OVH uses max_tokens instead of max_completion_tokens
        if request.max_completion_tokens.is_some() {
            request.max_tokens = request.max_completion_tokens;
            request.max_completion_tokens = None;
        }

        request
    }
}

#[async_trait]
impl LlmProvider for OvhCloudProvider {
    async fn models(&self) -> Result<ListModelResponse, LlmError> {
        let response = self.client.models().list().await
            .map_err(|e| Box::new(e) as LlmError)?;
        Ok(response)
    }

    async fn default_model(&self) -> Result<String, LlmError> {
        let models = self.models().await?;

        models.data.iter()
            .find(|m| m.id.to_lowercase().contains("nemo"))
            .or_else(|| models.data.first())
            .map(|m| m.id.clone())
            .ok_or_else(|| "no model available".into())
    }

    async fn chat(&self, request: ChatCompletionParameters) -> Result<ChatCompletionResponse, LlmError> {
        let sanitized_request = self.sanitize_request(request);
        let response = self.client.chat().create(sanitized_request).await
            .map_err(|e| Box::new(e) as LlmError)?;

        Ok(response)
    }

    async fn chat_stream(&self, mut request: ChatCompletionParameters) -> Result<LlmStream, LlmError> {
        request.stream = Some(true);
        let sanitized_request = self.sanitize_request(request);

        let stream = self.client.chat().create_stream(sanitized_request).await
            .map_err(|e| Box::new(e) as LlmError)?;

        let converted_stream = stream.map(|result| {
            result.map_err(|e| Box::new(e) as LlmError)
        });

        Ok(Box::new(Box::pin(converted_stream)))
    }

    fn supports_functions(&self, _model: String) -> bool {
        true
    }

    fn supports_structured_output(&self, _model: String) -> bool {
        true
    }

    fn name(&self) -> &'static str {
        "ovhcloud"
    }

    fn info() -> ProviderInfo {
        ProviderInfo {
            name: "ovhcloud",
            display_name: "OVHcloud AI Endpoints",
            env_vars: vec![
                EnvVar::required("OVH_API_KEY", "OVHcloud API key"),
                EnvVar::optional("OVH_BASE_URL", "OVHcloud base URL (defaults to standard endpoint)"),
            ],
        }
    }
}
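`sanitize_request` is the one OVH-specific twist: newer OpenAI-style clients send `max_completion_tokens`, while the OVH endpoint still expects `max_tokens`. A minimal sketch of the field move, replicating the provider's logic on a hand-built request (the model id is illustrative):

```rust
use openai_dive::v1::resources::chat::ChatCompletionParameters;

fn demo_sanitize() {
    let request = ChatCompletionParameters {
        model: "Mistral-Nemo-Instruct-2407".to_string(), // illustrative model id
        max_completion_tokens: Some(256),
        ..Default::default()
    };

    // What OvhCloudProvider::sanitize_request does internally:
    let mut sanitized = request;
    if sanitized.max_completion_tokens.is_some() {
        sanitized.max_tokens = sanitized.max_completion_tokens;
        sanitized.max_completion_tokens = None;
    }

    assert_eq!(sanitized.max_tokens, Some(256));
    assert_eq!(sanitized.max_completion_tokens, None);
}
```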
--------------------------------------------------------------------------------
/shai-llm/src/tool/call.rs:
--------------------------------------------------------------------------------
use async_trait::async_trait;

use openai_dive::v1::resources::chat::{ChatCompletionParameters, ChatCompletionResponse};

use crate::{
    provider::LlmError,
    tool::{
        call_fc_auto::ToolCallFunctionCallingAuto,
        call_fc_required::ToolCallFunctionCallingRequired,
        call_structured_output::ToolCallStructuredOutput,
        ToolBox,
    },
    LlmClient, ToolCallMethod,
};

#[async_trait]
pub trait LlmToolCall {
    async fn chat_with_tools(
        &self,
        request: ChatCompletionParameters,
        tools: &ToolBox,
        method: ToolCallMethod,
    ) -> Result<ChatCompletionResponse, LlmError>;
}

#[async_trait]
impl LlmToolCall for LlmClient {
    async fn chat_with_tools(
        &self,
        request: ChatCompletionParameters,
        tools: &ToolBox,
        method: ToolCallMethod,
    ) -> Result<ChatCompletionResponse, LlmError> {
        match method {
            ToolCallMethod::Auto => {
                self.chat_with_tools_try_all(request, tools).await
            }
            ToolCallMethod::FunctionCall => {
                self.chat_with_tools_fc_auto(request, tools).await
            }
            ToolCallMethod::FunctionCallRequired => {
                self.chat_with_tools_fc_required(request, tools).await
            }
            ToolCallMethod::StructuredOutput => {
                self.chat_with_tools_so(request, tools).await
            }
            ToolCallMethod::Parsing => {
                Err(LlmError::from("method not supported"))
            }
        }
    }
}

#[async_trait]
pub trait ToolCallAuto {
    async fn chat_with_tools_try_all(
        &self,
        request: ChatCompletionParameters,
        tools: &ToolBox,
    ) -> Result<ChatCompletionResponse, LlmError>;
}

#[async_trait]
impl ToolCallAuto for LlmClient {
    async fn chat_with_tools_try_all(
        &self,
        request: ChatCompletionParameters,
        tools: &ToolBox,
    ) -> Result<ChatCompletionResponse, LlmError> {
        if let Ok(result) = self.chat_with_tools_fc_auto(request.clone(), tools).await {
            return Ok(result);
        }

        if let Ok(result) = self.chat_with_tools_fc_required(request.clone(), tools).await {
            return Ok(result);
        }

        self.chat_with_tools_so(request, tools).await
    }
}
--------------------------------------------------------------------------------
/shai-llm/src/tool/call_fc_auto.rs:
--------------------------------------------------------------------------------
use async_trait::async_trait;

use openai_dive::v1::resources::chat::{
    ChatCompletionFunction, ChatCompletionParameters, ChatCompletionParametersBuilder,
    ChatCompletionResponse, ChatCompletionTool, ChatCompletionToolChoice, ChatCompletionToolType,
};

use crate::{provider::LlmError, tool::ToolBox, LlmClient, ToolDescription};

pub trait FunctionCallingAutoBuilder {
    fn with_function_calling_auto(&mut self, tools: &ToolBox) -> &mut Self;
}

impl FunctionCallingAutoBuilder for ChatCompletionParametersBuilder {
    fn with_function_calling_auto(&mut self, tools: &ToolBox) -> &mut Self {
        self
            .tools(tools.iter().map(|t| {
                ChatCompletionTool {
                    r#type: ChatCompletionToolType::Function,
                    function: ChatCompletionFunction {
                        name: t.name().to_string(),
                        description: Some(t.description().to_string()),
                        parameters: t.parameters_schema(),
                    },
                }
            }).collect::<Vec<_>>())
            .tool_choice(ChatCompletionToolChoice::Auto)
    }
}

#[async_trait]
pub trait ToolCallFunctionCallingAuto {
    async fn chat_with_tools_fc_auto(
        &self,
        request: ChatCompletionParameters,
        tools: &ToolBox,
    ) -> Result<ChatCompletionResponse, LlmError>;
}

#[async_trait]
impl ToolCallFunctionCallingAuto for LlmClient {
    async fn chat_with_tools_fc_auto(
        &self,
        request: ChatCompletionParameters,
        tools: &ToolBox,
    ) -> Result<ChatCompletionResponse, LlmError> {
        let request = ChatCompletionParametersBuilder::default()
            .model(&request.model)
            .messages(request.messages.clone())
            .with_function_calling_auto(tools)
            .temperature(0.3)
            .build()
            .map_err(|e| LlmError::from(e.to_string()))?;

        let response = self
            .chat(request.clone())
            .await
            .inspect_err(|_e| {
                // Save failed request to file for debugging
                let timestamp = chrono::Utc::now().format("%Y%m%d_%H%M%S");
                if let Ok(json) = serde_json::to_string_pretty(&request) {
                    let filename = format!("logs/request_{}.json", timestamp);
                    let _ = std::path::Path::new(&filename).parent()
                        .map(std::fs::create_dir_all).unwrap_or(Ok(()))
                        .and_then(|_| std::fs::write(&filename, json));
                }
            })
            .map_err(|e| LlmError::from(e.to_string()))?;

        Ok(response)
    }
}
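Putting the dispatch together: callers build a plain chat request and let `chat_with_tools` pick the negotiation strategy. A hedged sketch, assuming a `client: LlmClient` and a populated `ToolBox` are already in hand (the model id and prompt are illustrative):

```rust
use openai_dive::v1::resources::chat::{ChatCompletionParameters, ChatMessage, ChatMessageContent};
use shai_llm::tool::LlmToolCall;
use shai_llm::{LlmClient, ToolBox, ToolCallMethod};

async fn ask_with_tools(client: &LlmClient, tools: &ToolBox) -> Result<(), shai_llm::provider::LlmError> {
    let request = ChatCompletionParameters {
        model: "gpt-4o-mini".to_string(), // illustrative model id
        messages: vec![ChatMessage::User {
            content: ChatMessageContent::Text("List the files in the repo root".to_string()),
            name: None,
        }],
        ..Default::default()
    };

    // Auto first tries tool_choice=auto, then required, then structured output
    let response = client.chat_with_tools(request, tools, ToolCallMethod::Auto).await?;
    println!("{:?}", response.choices.first());
    Ok(())
}
```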
--------------------------------------------------------------------------------
/shai-llm/src/tool/call_fc_required.rs:
--------------------------------------------------------------------------------
use std::sync::Arc;
use async_trait::async_trait;

use openai_dive::v1::resources::chat::{
    ChatCompletionFunction, ChatCompletionParameters, ChatCompletionParametersBuilder,
    ChatCompletionResponse, ChatCompletionTool, ChatCompletionToolChoice, ChatCompletionToolType,
    ChatMessage, Function, ToolCall,
};
use crate::{provider::LlmError, tool::ToolBox, LlmClient, ToolDescription};

pub struct NoOp {}

impl ToolDescription for NoOp {
    fn name(&self) -> String {
        "no_op".to_string()
    }

    fn description(&self) -> String {
        "this tool is a no_op and does nothing. This tool must be called if you don't want to call any tool.".to_string()
    }

    fn parameters_schema(&self) -> serde_json::Value {
        serde_json::json!({})
    }
}

pub trait FunctionCallingRequiredBuilder {
    fn with_function_calling_required(&mut self, tools: &ToolBox) -> &mut Self;
}

impl FunctionCallingRequiredBuilder for ChatCompletionParametersBuilder {
    fn with_function_calling_required(&mut self, tools: &ToolBox) -> &mut Self {
        let mut tools = tools.clone();
        tools.push(Arc::new(NoOp {}));

        self
            .tools(tools.iter().map(|t| {
                ChatCompletionTool {
                    r#type: ChatCompletionToolType::Function,
                    function: ChatCompletionFunction {
                        name: t.name().to_string(),
                        description: Some(t.description().to_string()),
                        parameters: t.parameters_schema(),
                    },
                }
            }).collect::<Vec<_>>())
            .tool_choice(ChatCompletionToolChoice::Required)
    }
}

#[async_trait]
pub trait ToolCallFunctionCallingRequired {
    async fn chat_with_tools_fc_required(
        &self,
        request: ChatCompletionParameters,
        tools: &ToolBox,
    ) -> Result<ChatCompletionResponse, LlmError>;
}

#[async_trait]
impl ToolCallFunctionCallingRequired for LlmClient {
    async fn chat_with_tools_fc_required(
        &self,
        request: ChatCompletionParameters,
        tools: &ToolBox,
    ) -> Result<ChatCompletionResponse, LlmError> {
        let request = ChatCompletionParametersBuilder::default()
            .model(&request.model)
            .messages(request.messages.clone())
            .with_function_calling_required(tools)
            .temperature(0.3)
            .build()
            .map_err(|e| LlmError::from(e.to_string()))?;

        let mut response = self
            .chat(request.clone())
            .await
            .inspect_err(|_e| {
                // Save failed request to file for debugging
                let timestamp = chrono::Utc::now().format("%Y%m%d_%H%M%S");
                if let Ok(json) = serde_json::to_string_pretty(&request) {
                    let filename = format!("logs/request_{}.json", timestamp);
                    let _ = std::path::Path::new(&filename).parent()
                        .map(std::fs::create_dir_all).unwrap_or(Ok(()))
                        .and_then(|_| std::fs::write(&filename, json));
                }
            })
            .map_err(|e| LlmError::from(e.to_string()))?;

        // Strip the no_op sentinel if it is the only tool call in the response
        if let ChatMessage::Assistant { tool_calls, .. } = &mut response.choices[0].message {
            if let Some(calls) = tool_calls {
                if let [ToolCall { function: Function { name, .. }, .. }] = calls.as_slice() {
                    if name == "no_op" {
                        *tool_calls = None;
                    }
                }
            }
        }

        Ok(response)
    }
}
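The `no_op` sentinel is what makes `tool_choice: Required` safe: the model always has a legal "do nothing" move, and the post-processing above erases it, so callers only ever see real tool calls. A small sketch of the resulting contract; `wants_tools` is a hypothetical helper, assuming a response already obtained via `chat_with_tools_fc_required`:

```rust
use openai_dive::v1::resources::chat::{ChatCompletionResponse, ChatMessage};

// After chat_with_tools_fc_required, tool_calls is either None (the model
// declined via no_op) or Some(real calls); a bare no_op never leaks out.
fn wants_tools(response: &ChatCompletionResponse) -> bool {
    matches!(
        response.choices.first().map(|c| &c.message),
        Some(ChatMessage::Assistant { tool_calls: Some(calls), .. }) if !calls.is_empty()
    )
}
```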
--------------------------------------------------------------------------------
/shai-llm/src/tool/mod.rs:
--------------------------------------------------------------------------------
pub mod tool;
pub mod call;
pub mod call_fc_auto;
pub mod call_fc_required;
pub mod call_structured_output;

#[cfg(test)]
mod test_so;

pub use tool::{ToolDescription, ToolCallMethod, ToolBox, ContainsTool};
pub use call::{LlmToolCall, ToolCallAuto};
pub use call_structured_output::{AssistantResponse, StructuredOutputBuilder, IntoChatMessage};
pub use call_fc_auto::FunctionCallingAutoBuilder;
pub use call_fc_required::FunctionCallingRequiredBuilder;
--------------------------------------------------------------------------------
/shai-llm/src/tool/tool.rs:
--------------------------------------------------------------------------------
use std::sync::Arc;
use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum ToolCallMethod {
    /// let the system decide what technique to use
    Auto,
    /// use function call api with tool choice set to auto
    FunctionCall,
    /// use function call api with tool choice set to required (with a special "no tool" tool)
    FunctionCallRequired,
    /// use response_format to force structured output, add tool documentation in system prompt
    StructuredOutput,
    /// instruct llm to use special tag and parse the response from content, add tool documentation in system prompt
    Parsing,
}

/// A tool must be able to describe its parameters as a JSON schema
pub trait ToolDescription: Send + Sync {
    fn name(&self) -> String;

    fn description(&self) -> String;

    fn parameters_schema(&self) -> serde_json::Value;

    /// Return the group name for this tool (e.g., "builtin", "mcp_ovh")
    fn group(&self) -> Option<&str> {
        None
    }
}

/// A toolbox is a set of tools
pub type ToolBox = Vec<Arc<dyn ToolDescription>>;

pub trait ContainsTool {
    fn contains_tool(&self, name: &str) -> bool;
}

impl ContainsTool for ToolBox {
    fn contains_tool(&self, name: &str) -> bool {
        self.iter().any(|tool| tool.name() == name)
    }
}
--------------------------------------------------------------------------------
/shai-macros/Cargo.toml:
--------------------------------------------------------------------------------
[package]
name = "shai-macros"
version = "0.1.8"
edition = "2021"

[lib]
proc-macro = true

[dependencies]
proc-macro2 = "1.0"
quote = "1.0"
syn = { version = "2.0", features = ["full"] }
--------------------------------------------------------------------------------
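For a concrete sense of the `ToolDescription` trait, here is a hypothetical weather tool (name and schema invented for illustration) together with a `ContainsTool` check:

```rust
use std::sync::Arc;
use shai_llm::{ContainsTool, ToolBox, ToolDescription};

struct GetWeather;

impl ToolDescription for GetWeather {
    fn name(&self) -> String {
        "get_weather".to_string()
    }

    fn description(&self) -> String {
        "Look up the current weather for a city".to_string()
    }

    fn parameters_schema(&self) -> serde_json::Value {
        serde_json::json!({
            "type": "object",
            "properties": {
                "city": { "type": "string", "description": "City name" }
            },
            "required": ["city"]
        })
    }
}

fn main() {
    let tools: ToolBox = vec![Arc::new(GetWeather)];
    assert!(tools.contains_tool("get_weather"));
    assert!(!tools.contains_tool("no_op"));
}
```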