├── .DS_Store ├── .github └── workflows │ └── rust.yml ├── .gitignore ├── Cargo.toml ├── LICENSE ├── README.md ├── examples ├── client │ ├── Cargo.toml │ └── src │ │ └── main.rs ├── file_system │ ├── Cargo.toml │ ├── README.md │ └── src │ │ ├── lib.rs │ │ ├── main.rs │ │ └── server.rs ├── knowledge_graph_memory │ ├── Cargo.toml │ └── src │ │ ├── main.rs │ │ └── types.rs └── pingpong │ ├── Cargo.toml │ ├── README.md │ └── src │ ├── client.rs │ ├── lib.rs │ ├── main.rs │ └── server.rs ├── mcp_test.sh └── src ├── client.rs ├── lib.rs ├── protocol.rs ├── registry.rs ├── server.rs ├── sse ├── http_server.rs ├── middleware.rs └── mod.rs ├── transport ├── http_transport.rs ├── inmemory_transport.rs ├── mod.rs ├── sse_transport.rs ├── stdio_transport.rs └── ws_transport.rs └── types.rs /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/v3g42/async-mcp/bb77f9656d64993199dcd55b3bf208a67fd3735d/.DS_Store -------------------------------------------------------------------------------- /.github/workflows/rust.yml: -------------------------------------------------------------------------------- 1 | name: Rust 2 | 3 | on: 4 | pull_request: 5 | branches: [ "main" ] 6 | 7 | env: 8 | CARGO_TERM_COLOR: always 9 | 10 | jobs: 11 | build: 12 | 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - uses: actions/checkout@v4 17 | - name: Build 18 | run: cargo build --verbose 19 | - name: Run tests 20 | run: cargo test --verbose 21 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Generated by Cargo 2 | # will have compiled files and executables 3 | debug/ 4 | target/ 5 | 6 | # Remove Cargo.lock from gitignore if creating an executable, leave it for libraries 7 | # More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html 8 | Cargo.lock 9 | 10 | # These are backup files generated by rustfmt 11 | **/*.rs.bk 12 | 13 | # MSVC Windows builds of rustc generate these, which store debugging information 14 | *.pdb 15 | 16 | # RustRover 17 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 18 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 19 | # and can be added to the global gitignore or merged into this file. For a more nuclear 20 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
21 | #.idea/ 22 | 23 | # Added by cargo 24 | 25 | /target 26 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [workspace] 2 | members = [ 3 | ".", 4 | "examples/client", 5 | "examples/file_system", 6 | "examples/knowledge_graph_memory", 7 | "examples/pingpong", 8 | ] 9 | default-members = ["examples/file_system", "examples/pingpong"] 10 | # Your existing package configuration stays here 11 | [package] 12 | name = "async-mcp" 13 | version = "0.1.2" 14 | edition = "2021" 15 | description = "Async Implementation of Model Context Protocol (MCP)" 16 | repository = "https://github.com/v3g42/async-mcp" 17 | license = "Apache-2.0" 18 | authors = ["https://github.com/v3g42"] 19 | documentation = "https://github.com/v3g42/async-mcp#readme" 20 | homepage = "https://github.com/v3g42/async-mcp" 21 | keywords = ["async", "mcp", "protocol", "Anthropic"] 22 | categories = ["asynchronous", "network-programming"] 23 | readme = "README.md" 24 | [dependencies] 25 | tokio = { version = "1.0", features = ["full"] } 26 | serde = { version = "1.0", features = ["derive"] } 27 | serde_json = "1.0" 28 | anyhow = "1.0" 29 | async-trait = "0.1" 30 | url = { version = "2.5", features = ["serde"] } 31 | tracing = "0.1" 32 | reqwest = { version = "0.12.12", features = ["stream", "json"] } 33 | actix-web = "4" 34 | tokio-stream = "0.1" 35 | futures = "0.3" 36 | jsonwebtoken = "8.1" 37 | uuid = { version = "1.0", features = ["v4"] } 38 | actix-ws = "0.2.5" 39 | tokio-tungstenite = { version = "0.21", features = ["native-tls"] } 40 | 41 | [dev-dependencies] 42 | tracing-subscriber = "0.3" 43 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Async MCP 2 | A minimalistic async Rust implementation of the Model Context Protocol (MCP). This library extends the synchronous implementation from [mcp-sdk](https://github.com/AntigmaLabs/mcp-sdk) to support async operations and implements additional transports. Due to significant code changes, it is released as a separate crate. 
3 | 4 | [![Crates.io](https://img.shields.io/crates/v/async-mcp)](https://crates.io/crates/async-mcp) 5 | 6 | > **Note**: This project is still early in development. 7 | 8 | ## Installation 9 | 10 | Add this to your `Cargo.toml`: 11 | 12 | ```toml 13 | [dependencies] 14 | async-mcp = "0.1.2" 15 | ``` 16 | 17 | ## Overview 18 | This is an implementation of the [Model Context Protocol](https://github.com/modelcontextprotocol) defined by Anthropic. 19 | 20 | ## Features 21 | 22 | ### Supported Transports 23 | - Server-Sent Events (SSE) 24 | - Standard IO (Stdio) 25 | - In-Memory Channel 26 | - Websockets 27 | 28 | ## Usage Examples 29 | 30 | ### Server Implementation 31 | 32 | #### Using Stdio Transport 33 | ```rust 34 | let server = Server::builder(StdioTransport) 35 | .capabilities(ServerCapabilities { 36 | tools: Some(json!({})), 37 | ..Default::default() 38 | }) 39 | .request_handler("tools/list", list_tools) 40 | .request_handler("tools/call", call_tool) 41 | .request_handler("resources/list", |_req: ListRequest| { 42 | Ok(ResourcesListResponse { 43 | resources: vec![], 44 | next_cursor: None, 45 | meta: None, 46 | }) 47 | }) 48 | .build(); 49 | ``` 50 | 51 | #### Run an HTTP Server supporting both SSE and WS 52 | ```rust 53 | run_http_server(3004, None, |transport, _, _| async move { 54 | let server = build_server(transport); 55 | Ok(server) 56 | }) 57 | .await?; 58 | ``` 59 | 60 | Local endpoints: 61 | ``` 62 | WebSocket endpoint: ws://127.0.0.1:3004/ws 63 | SSE endpoint: http://127.0.0.1:3004/sse 64 | ``` 65 | 66 | ### Client Implementation 67 | 68 | #### Setting up Transport 69 | ```rust 70 | // Stdio Transport 71 | let transport = ClientStdioTransport::new("", &[], None)?; 72 | 73 | // In-Memory Transport 74 | let transport = ClientInMemoryTransport::new(|t| tokio::spawn(inmemory_server(t))); 75 | 76 | // SSE Transport 77 | let transport = ClientSseTransportBuilder::new(server_url).build(); 78 | 79 | // WS Transport 80 | let transport = async_mcp::transport::ClientWsTransportBuilder::new("ws://localhost:3004/ws".to_string()).build(); 81 | ``` 82 | 83 | #### Making Requests 84 | ```rust 85 | // Initialize transport 86 | transport.open().await?; 87 | 88 | // Create and start client 89 | let client = async_mcp::client::ClientBuilder::new(transport.clone()).build(); 90 | let client_clone = client.clone(); 91 | let _client_handle = tokio::spawn(async move { client_clone.start().await }); 92 | 93 | // Make a request 94 | client 95 | .request( 96 | "tools/call", 97 | Some(json!({"name": "ping", "arguments": {}})), 98 | RequestOptions::default().timeout(Duration::from_secs(5)), 99 | ) 100 | .await? 101 | ``` 102 | 103 | ## Complete Examples 104 | For full working examples, check out: 105 | - [Ping Pong Example](./examples/pingpong/) 106 | - [File System Example](examples/file_system/README.md) 107 | - [Knowledge Graph Memory Example](examples/knowledge_graph_memory/README.md) 108 | 109 | ## Related SDKs 110 | 111 | ### Official 112 | - [TypeScript SDK](https://github.com/modelcontextprotocol/typescript-sdk) 113 | - [Python SDK](https://github.com/modelcontextprotocol/python-sdk) 114 | 115 | ### Community 116 | - [Go SDK](https://github.com/mark3labs/mcp-go) 117 | 118 | For the complete feature set, please refer to the [MCP specification](https://spec.modelcontextprotocol.io/).
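## End-to-End Example (In-Memory)

The snippets above can be combined into one self-contained round trip by using the in-memory transport, which pairs a client and a server inside a single process. The sketch below is condensed from the pingpong example; it registers only a `tools/call` handler (no `tools/list`), and the `ping_pong_server` name is illustrative.

```rust
use std::time::Duration;

use anyhow::Result;
use async_mcp::{
    client::ClientBuilder,
    protocol::RequestOptions,
    server::Server,
    transport::{ClientInMemoryTransport, ServerInMemoryTransport, Transport},
    types::{CallToolRequest, CallToolResponse, ServerCapabilities, ToolResponseContent},
};
use serde_json::json;

// Minimal server with a single tool handler, served over the in-memory transport.
async fn ping_pong_server(transport: ServerInMemoryTransport) {
    let server = Server::builder(transport)
        .capabilities(ServerCapabilities {
            tools: Some(json!({})),
            ..Default::default()
        })
        .request_handler("tools/call", |_req: CallToolRequest| {
            Box::pin(async move {
                // Always replies "pong"; a real server would dispatch on _req.name.
                Ok(CallToolResponse {
                    content: vec![ToolResponseContent::Text {
                        text: "pong".to_string(),
                    }],
                    is_error: None,
                    meta: None,
                })
            })
        })
        .build();
    server.listen().await.unwrap();
}

#[tokio::main]
async fn main() -> Result<()> {
    // The closure receives the paired server-side transport and spawns the server task.
    let transport = ClientInMemoryTransport::new(|t| tokio::spawn(ping_pong_server(t)));
    transport.open().await?;

    let client = ClientBuilder::new(transport.clone()).build();
    let client_clone = client.clone();
    tokio::spawn(async move { client_clone.start().await });

    let response = client
        .request(
            "tools/call",
            Some(json!({"name": "ping", "arguments": {}})),
            RequestOptions::default().timeout(Duration::from_secs(5)),
        )
        .await?;
    println!("{response}");
    Ok(())
}
```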
119 | 120 | ## Implementation Status 121 | 122 | ### Core Protocol Features 123 | - [x] Basic Message Types 124 | - [ ] Error and Signal Handling 125 | - [x] Transport Layer 126 | - [x] Stdio 127 | - [x] In-Memory Channel 128 | - [x] SSE 129 | - [x] Websockets 130 | 131 | ### Server Features 132 | - [x] Tools Support 133 | - [ ] Prompts 134 | - [ ] Resources 135 | - [x] Pagination 136 | - [x] Completion 137 | 138 | ### Client Features 139 | Compatible with Claude Desktop: 140 | - [x] Stdio Support 141 | - [x] In-Memory Channel 142 | - [x] SSE Support 143 | 144 | ### Monitoring 145 | - [ ] Logging 146 | - [ ] Metrics 147 | 148 | ### Utilities 149 | - [ ] Ping 150 | - [ ] Cancellation 151 | - [ ] Progress Tracking 152 | -------------------------------------------------------------------------------- /examples/client/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "client" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | [dependencies] 7 | async-mcp = { path = "../.." } # This points to your main project 8 | tokio = { version = "1.0", features = ["full"] } 9 | serde = { version = "1.0", features = ["derive"] } 10 | serde_json = "1.0" 11 | anyhow = "1.0" 12 | tracing-subscriber = "0.3" 13 | tracing = "0.1" 14 | home = "0.5.9" 15 | -------------------------------------------------------------------------------- /examples/client/src/main.rs: -------------------------------------------------------------------------------- 1 | use std::time::Duration; 2 | 3 | use anyhow::Result; 4 | use async_mcp::{ 5 | client::ClientBuilder, 6 | protocol::RequestOptions, 7 | transport::{ClientStdioTransport, Transport}, 8 | }; 9 | 10 | #[tokio::main] 11 | async fn main() -> Result<()> { 12 | #[cfg(unix)] 13 | { 14 | // Create transport connected to cat command which will stay alive 15 | let transport = ClientStdioTransport::new("cat", &[], None)?; 16 | 17 | // Open transport 18 | transport.open().await?; 19 | 20 | let client = ClientBuilder::new(transport).build(); 21 | let client_clone = client.clone(); 22 | tokio::spawn(async move { client_clone.start().await }); 23 | let response = client 24 | .request( 25 | "echo", 26 | None, 27 | RequestOptions::default().timeout(Duration::from_secs(1)), 28 | ) 29 | .await?; 30 | println!("{:?}", response); 31 | } 32 | #[cfg(windows)] 33 | { 34 | println!("Windows is not supported yet"); 35 | } 36 | Ok(()) 37 | } 38 | -------------------------------------------------------------------------------- /examples/file_system/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "file_system" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | [dependencies] 7 | async-mcp = { path = "../.." } 8 | tokio = { version = "1.0", features = ["full"] } 9 | serde = { version = "1.0", features = ["derive"] } 10 | serde_json = "1.0" 11 | anyhow = "1.0" 12 | tracing-subscriber = "0.3" 13 | tracing = "0.1" 14 | home = "0.5.9" 15 | clap = { version = "4.4", features = ["derive"] } 16 | -------------------------------------------------------------------------------- /examples/file_system/README.md: -------------------------------------------------------------------------------- 1 | # Simple Example of A Read-Only File System 2 | 3 | This example demonstrates a simple read-only file system. It allows you to list the contents of a directory and read the contents of a file. 
4 | 5 | Similar to the [TypeScript filesystem example](https://github.com/modelcontextprotocol/servers/tree/main/src/filesystem), but restricted to read-only operations. 6 | 7 | ### Tools 8 | 9 | - **read_file** 10 | - Read complete contents of a file 11 | - Input: `path` (string) 12 | - Reads complete file contents with UTF-8 encoding 13 | 14 | - **list_directory** 15 | - List directory contents with [FILE] or [DIR] prefixes 16 | - Input: `path` (string) 17 | 18 | - **search_files** 19 | - Recursively search for files/directories 20 | - Inputs: 21 | - `path` (string): Starting directory 22 | - `pattern` (string): Search pattern 23 | - Case-insensitive matching 24 | - Returns full paths to matches 25 | 26 | - **get_file_info** 27 | - Get detailed file/directory metadata 28 | - Input: `path` (string) 29 | - Returns: metadata of the file or directory 30 | 31 | ## How to Build and Run the Example Locally 32 | 33 | ### Prerequisites 34 | - macOS (Windows support is planned) 35 | - The latest version of Claude Desktop installed 36 | - Install [Rust](https://www.rust-lang.org/tools/install) 37 | 38 | ### Build and Install Binary 39 | ```bash 40 | cd async-mcp/examples/file_system 41 | cargo install --path . 42 | ``` 43 | This will build the binary and install it to your local cargo bin directory. Later you will need to configure Claude Desktop to use this binary. 44 | ### Configure Claude Desktop 45 | 46 | If you are using macOS, open the `claude_desktop_config.json` file in a text editor: 47 | ```bash 48 | code ~/Library/Application\ Support/Claude/claude_desktop_config.json 49 | ``` 50 | 51 | Modify the `claude_desktop_config.json` file to include the following configuration 52 | (replace YOUR_USERNAME with your actual username): 53 | ```json 54 | { 55 | "mcpServers": { 56 | "mcp_example_file_system": { 57 | "command": "/Users/YOUR_USERNAME/.cargo/bin/file_system" 58 | } 59 | } 60 | } 61 | ``` 62 | Save the file, and restart Claude Desktop.
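### Run over HTTP/SSE (optional)

Claude Desktop launches this example over stdio, but `src/main.rs` also accepts a transport argument, so the same binary can serve SSE and WebSocket clients (the example hard-codes port 3004). A sketch of running it that way:

```bash
cd async-mcp/examples/file_system
# Serves SSE at http://127.0.0.1:3004/sse and WebSocket at ws://127.0.0.1:3004/ws
cargo run --bin file_system -- sse
```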
63 | ## What will it look like 64 | Screenshot 2024-11-30 at 12 44 19 PM 65 | 66 | ## Test locally 67 | ``` 68 | cat << 'EOF' | cargo run --bin file_system 69 | {"jsonrpc": "2.0", "method": "tools/call", "params": {"name": "list_directory", "arguments": {"path": "."}}, "id": 1} 70 | EOF 71 | ``` -------------------------------------------------------------------------------- /examples/file_system/src/lib.rs: -------------------------------------------------------------------------------- 1 | pub mod server; 2 | -------------------------------------------------------------------------------- /examples/file_system/src/main.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use async_mcp::{run_http_server, transport::ServerStdioTransport}; 3 | use clap::{Parser, ValueEnum}; 4 | use file_system::server::build_server; 5 | 6 | #[derive(Parser)] 7 | #[command(author, version, about, long_about = None)] 8 | struct Cli { 9 | /// Transport type to use 10 | #[arg(value_enum, default_value_t = TransportType::Stdio)] 11 | transport: TransportType, 12 | } 13 | 14 | #[derive(Copy, Clone, PartialEq, Eq, ValueEnum)] 15 | enum TransportType { 16 | Stdio, 17 | Sse, 18 | } 19 | 20 | #[tokio::main] 21 | async fn main() -> Result<()> { 22 | tracing_subscriber::fmt() 23 | .with_max_level(tracing::Level::DEBUG) 24 | // needs to be stderr due to stdio transport 25 | .with_writer(std::io::stderr) 26 | .init(); 27 | 28 | let cli = Cli::parse(); 29 | 30 | match cli.transport { 31 | TransportType::Stdio => { 32 | let server = build_server(ServerStdioTransport); 33 | server 34 | .listen() 35 | .await 36 | .map_err(|e| anyhow::anyhow!("Server error: {}", e))?; 37 | } 38 | TransportType::Sse => { 39 | run_http_server(3004, None, |transport, _, _| async move { 40 | let server = build_server(transport); 41 | Ok(server) 42 | }) 43 | .await?; 44 | } 45 | }; 46 | Ok(()) 47 | } 48 | -------------------------------------------------------------------------------- /examples/file_system/src/server.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | use std::path::{Path, PathBuf}; 3 | 4 | use anyhow::Result; 5 | use async_mcp::server::Server; 6 | use async_mcp::transport::Transport; 7 | use async_mcp::types::{ 8 | CallToolRequest, CallToolResponse, ListRequest, ResourcesListResponse, ServerCapabilities, 9 | ToolResponseContent, ToolsListResponse, 10 | }; 11 | use serde_json::json; 12 | 13 | pub fn build_server(t: T) -> Server { 14 | Server::builder(t) 15 | .capabilities(ServerCapabilities { 16 | tools: Some(json!({})), 17 | ..Default::default() 18 | }) 19 | .request_handler("tools/list", |req: ListRequest| { 20 | Box::pin(async move { list_tools(req) }) 21 | }) 22 | .request_handler("tools/call", |req: CallToolRequest| { 23 | Box::pin(async move { call_tool(req) }) 24 | }) 25 | .request_handler("resources/list", |_req: ListRequest| { 26 | Box::pin(async move { 27 | Ok(ResourcesListResponse { 28 | resources: vec![], 29 | next_cursor: None, 30 | meta: None, 31 | }) 32 | }) 33 | }) 34 | .build() 35 | } 36 | 37 | fn call_tool(req: CallToolRequest) -> Result { 38 | let name = req.name.as_str(); 39 | let args = req.arguments.unwrap_or_default(); 40 | let result = match name { 41 | "read_file" => { 42 | let path = get_path(&args)?; 43 | let content = std::fs::read_to_string(path)?; 44 | ToolResponseContent::Text { text: content } 45 | } 46 | "list_directory" => { 47 | let path = get_path(&args)?; 48 | let 
entries = std::fs::read_dir(path)?; 49 | let mut text = String::new(); 50 | for entry in entries { 51 | let entry = entry?; 52 | let prefix = if entry.file_type()?.is_dir() { 53 | "[DIR]" 54 | } else { 55 | "[FILE]" 56 | }; 57 | text.push_str(&format!( 58 | "{prefix} {}\n", 59 | entry.file_name().to_string_lossy() 60 | )); 61 | } 62 | ToolResponseContent::Text { text } 63 | } 64 | "search_files" => { 65 | let path = get_path(&args)?; 66 | let pattern = args["pattern"].as_str().unwrap(); 67 | let mut matches = Vec::new(); 68 | search_directory(&path, pattern, &mut matches)?; 69 | ToolResponseContent::Text { 70 | text: matches.join("\n"), 71 | } 72 | } 73 | "get_file_info" => { 74 | let path = get_path(&args)?; 75 | let metadata = std::fs::metadata(path)?; 76 | ToolResponseContent::Text { 77 | text: format!("{:?}", metadata), 78 | } 79 | } 80 | "list_allowed_directories" => ToolResponseContent::Text { 81 | text: "[]".to_string(), 82 | }, 83 | _ => return Err(anyhow::anyhow!("Unknown tool: {}", req.name)), 84 | }; 85 | Ok(CallToolResponse { 86 | content: vec![result], 87 | is_error: None, 88 | meta: None, 89 | }) 90 | } 91 | 92 | fn search_directory(dir: &Path, pattern: &str, matches: &mut Vec) -> Result<()> { 93 | for entry in std::fs::read_dir(dir)? { 94 | let entry = entry?; 95 | let path = entry.path(); 96 | let name = path 97 | .file_name() 98 | .unwrap_or_default() 99 | .to_string_lossy() 100 | .to_lowercase(); 101 | 102 | // Check if the current file/directory matches the pattern 103 | if name.contains(&pattern.to_lowercase()) { 104 | matches.push(path.to_string_lossy().to_string()); 105 | } 106 | 107 | // Recursively search subdirectories 108 | if path.is_dir() { 109 | search_directory(&path, pattern, matches)?; 110 | } 111 | } 112 | Ok(()) 113 | } 114 | 115 | fn get_path(args: &HashMap) -> Result { 116 | tracing::debug!("Args: {args:?}"); 117 | let path = args["path"] 118 | .as_str() 119 | .ok_or(anyhow::anyhow!("Missing path"))?; 120 | 121 | if path.starts_with('~') { 122 | let home = home::home_dir().ok_or(anyhow::anyhow!("Could not determine home directory"))?; 123 | // Strip the ~ and join with home path 124 | let path = home.join(path.strip_prefix("~/").unwrap_or_default()); 125 | Ok(path) 126 | } else { 127 | Ok(PathBuf::from(path)) 128 | } 129 | } 130 | 131 | fn list_tools(_req: ListRequest) -> Result { 132 | let response = json!({ 133 | "tools": [ 134 | { 135 | "name": "read_file", 136 | "description": 137 | "Read the complete contents of a file from the file system. \ 138 | Handles various text encodings and provides detailed error messages \ 139 | if the file cannot be read. Use this tool when you need to examine \ 140 | the contents of a single file. Only works within allowed directories.", 141 | "inputSchema": { 142 | "type": "object", 143 | "properties": { 144 | "path": { 145 | "type": "string" 146 | } 147 | }, 148 | "required": ["path"] 149 | }, 150 | }, 151 | { 152 | "name": "list_directory", 153 | "description": 154 | "Get a detailed listing of all files and directories in a specified path. \ 155 | Results clearly distinguish between files and directories with [FILE] and [DIR] \ 156 | prefixes. This tool is essential for understanding directory structure and \ 157 | finding specific files within a directory. 
Only works within allowed directories.", 158 | "inputSchema": { 159 | "type": "object", 160 | "properties": { 161 | "path": { 162 | "type": "string" 163 | } 164 | }, 165 | "required": ["path"] 166 | }, 167 | }, 168 | { 169 | "name": "search_files", 170 | "description": 171 | "Recursively search for files and directories matching a pattern. \ 172 | Searches through all subdirectories from the starting path. The search \ 173 | is case-insensitive and matches partial names. Returns full paths to all \ 174 | matching items. Great for finding files when you don't know their exact location. \ 175 | Only searches within allowed directories.", 176 | "inputSchema": { 177 | "type": "object", 178 | "properties": { 179 | "path": { 180 | "type": "string" 181 | }, 182 | "pattern": { 183 | "type": "string" 184 | } 185 | }, 186 | "required": ["path", "pattern"] 187 | }, 188 | }, 189 | { 190 | "name": "get_file_info", 191 | "description": 192 | "Retrieve detailed metadata about a file or directory. Returns comprehensive \ 193 | information including size, creation time, last modified time, permissions, \ 194 | and type. This tool is perfect for understanding file characteristics \ 195 | without reading the actual content. Only works within allowed directories.", 196 | "inputSchema": { 197 | "type": "object", 198 | "properties": { 199 | "path": { 200 | "type": "string" 201 | } 202 | }, 203 | "required": ["path"] 204 | }, 205 | } 206 | ], 207 | }); 208 | Ok(serde_json::from_value(response)?) 209 | } 210 | -------------------------------------------------------------------------------- /examples/knowledge_graph_memory/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "knowledge_graph_memory" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | [dependencies] 7 | async-mcp = { path = "../.." 
} # This points to your main project 8 | tokio = { version = "1.0", features = ["full"] } 9 | serde = { version = "1.0", features = ["derive"] } 10 | serde_json = "1.0" 11 | anyhow = "1.0" 12 | tracing-subscriber = "0.3" 13 | tracing = "0.1" 14 | -------------------------------------------------------------------------------- /examples/knowledge_graph_memory/src/main.rs: -------------------------------------------------------------------------------- 1 | use std::sync::{Arc, Mutex}; 2 | 3 | use async_mcp::{ 4 | server::{Server, ServerBuilder}, 5 | transport::ServerStdioTransport, 6 | types::{CallToolRequest, CallToolResponse, ServerCapabilities, Tool, ToolResponseContent}, 7 | }; 8 | use serde_json::json; 9 | use types::{AddObservationParams, DeleteObservationParams, Entity, KnowledgeGraph, Relation}; 10 | 11 | use anyhow::Result; 12 | mod types; 13 | 14 | #[tokio::main] 15 | async fn main() -> Result<()> { 16 | tracing_subscriber::fmt() 17 | .with_max_level(tracing::Level::DEBUG) 18 | // needs to be stderr due to stdio transport 19 | .with_writer(std::io::stderr) 20 | .init(); 21 | 22 | let mut server = Server::builder(ServerStdioTransport).capabilities(ServerCapabilities { 23 | tools: Some(json!({})), 24 | ..Default::default() 25 | }); 26 | register_tools(&mut server)?; 27 | 28 | let server = server.build(); 29 | server 30 | .listen() 31 | .await 32 | .map_err(|e| anyhow::anyhow!("Server error: {}", e))?; 33 | Ok(()) 34 | } 35 | 36 | fn register_tools(server: &mut ServerBuilder) -> Result<()> { 37 | let memory_file_path = "kb_memory.json"; 38 | let kg = KnowledgeGraph::load_from_file(memory_file_path)?; 39 | let kg = Arc::new(Mutex::new(kg)); 40 | 41 | let description = Tool { 42 | name: "create_entities".to_string(), 43 | description: Some("Create multiple new entities".to_string()), 44 | input_schema: json!({ 45 | "type":"object", 46 | "properties":{ 47 | "entities":{ 48 | "type":"array", 49 | "items":{ 50 | "type":"object", 51 | "properties":{ 52 | "name":{"type":"string"}, 53 | "entityType":{"type":"string"}, 54 | "observations":{ 55 | "type":"array", "items":{"type":"string"} 56 | } 57 | }, 58 | "required":["name","entityType","observations"] 59 | } 60 | } 61 | }, 62 | "required":["entities"] 63 | }), 64 | output_schema: None, 65 | }; 66 | 67 | let kg_clone = kg.clone(); 68 | server.register_tool(description, move |req: CallToolRequest| { 69 | let kg_clone = kg_clone.clone(); 70 | Box::pin(async move { 71 | let args = req.arguments.unwrap_or_default(); 72 | let entities = args 73 | .get("entities") 74 | .ok_or(anyhow::anyhow!("missing arguments `entities`"))?; 75 | let entities: Vec = serde_json::from_value(entities.clone())?; 76 | let created = kg_clone.lock().unwrap().create_entities(entities)?; 77 | kg_clone.lock().unwrap().save_to_file(memory_file_path)?; 78 | Ok(CallToolResponse { 79 | content: vec![ToolResponseContent::Text { 80 | text: json!(created).to_string(), 81 | }], 82 | is_error: None, 83 | meta: None, 84 | }) 85 | }) 86 | }); 87 | 88 | let description = Tool { 89 | name: "create_relations".to_string(), 90 | description: Some("Create multiple new relations".to_string()), 91 | input_schema: json!({ 92 | "type":"object", 93 | "properties":{ 94 | "relations":{ 95 | "type":"array", 96 | "items":{ 97 | "type":"object", 98 | "properties":{ 99 | "from":{"type":"string"}, 100 | "to":{"type":"string"}, 101 | "relationType":{"type":"string"} 102 | }, 103 | "required":["from","to","relationType"] 104 | } 105 | } 106 | }, 107 | "required":["relations"] 108 | }), 109 | 
output_schema: None, 110 | }; 111 | let kg_clone = kg.clone(); 112 | server.register_tool(description, move |req: CallToolRequest| { 113 | let kg_clone = kg_clone.clone(); 114 | Box::pin(async move { 115 | let args = req.arguments.unwrap_or_default(); 116 | let relations = args 117 | .get("relations") 118 | .ok_or(anyhow::anyhow!("missing arguments `relations`"))?; 119 | let relations: Vec = serde_json::from_value(relations.clone())?; 120 | let created = kg_clone.lock().unwrap().create_relations(relations)?; 121 | kg_clone.lock().unwrap().save_to_file(memory_file_path)?; 122 | Ok(CallToolResponse { 123 | content: vec![ToolResponseContent::Text { 124 | text: json!(created).to_string(), 125 | }], 126 | is_error: None, 127 | meta: None, 128 | }) 129 | }) 130 | }); 131 | 132 | let description = Tool { 133 | name: "add_observations".to_string(), 134 | description: Some("Add new observations to existing entities".to_string()), 135 | input_schema: json!({ 136 | "type": "object", 137 | "properties": { 138 | "observations": { 139 | "type": "array", 140 | "items": { 141 | "type": "object", 142 | "properties": { 143 | "entityName": {"type": "string"}, 144 | "contents": { 145 | "type": "array", 146 | "items": {"type": "string"} 147 | } 148 | }, 149 | "required": ["entityName", "contents"] 150 | } 151 | } 152 | }, 153 | "required": ["observations"] 154 | }), 155 | output_schema: None, 156 | }; 157 | let kg_clone = kg.clone(); 158 | server.register_tool(description, move |req: CallToolRequest| { 159 | let kg_clone = kg_clone.clone(); 160 | Box::pin(async move { 161 | let args = req.arguments.unwrap_or_default(); 162 | let observations = args 163 | .get("observations") 164 | .ok_or(anyhow::anyhow!("missing arguments `observations`"))?; 165 | let observations: Vec = 166 | serde_json::from_value(observations.clone())?; 167 | let results = kg_clone.lock().unwrap().add_observations(observations)?; 168 | kg_clone.lock().unwrap().save_to_file(memory_file_path)?; 169 | Ok(CallToolResponse { 170 | content: vec![ToolResponseContent::Text { 171 | text: json!(results).to_string(), 172 | }], 173 | is_error: None, 174 | meta: None, 175 | }) 176 | }) 177 | }); 178 | 179 | let description = Tool { 180 | name: "delete_entities".to_string(), 181 | description: Some("Delete multiple entities and their relations".to_string()), 182 | input_schema: json!({ 183 | "type": "object", 184 | "properties": { 185 | "entityNames": { 186 | "type": "array", 187 | "items": {"type": "string"} 188 | } 189 | }, 190 | "required": ["entityNames"] 191 | }), 192 | output_schema: None, 193 | }; 194 | let kg_clone = kg.clone(); 195 | server.register_tool(description, move |req: CallToolRequest| { 196 | let kg_clone = kg_clone.clone(); 197 | Box::pin(async move { 198 | let args = req.arguments.unwrap_or_default(); 199 | let entity_names = args 200 | .get("entityNames") 201 | .ok_or(anyhow::anyhow!("missing arguments `entityNames`"))?; 202 | let entity_names: Vec = serde_json::from_value(entity_names.clone())?; 203 | let mut kg_guard = kg_clone.lock().unwrap(); 204 | kg_guard.delete_entities(entity_names)?; 205 | kg_guard.save_to_file(memory_file_path)?; 206 | Ok(CallToolResponse { 207 | content: vec![ToolResponseContent::Text { 208 | text: "Entities deleted successfully".to_string(), 209 | }], 210 | is_error: None, 211 | meta: None, 212 | }) 213 | }) 214 | }); 215 | 216 | let description = Tool { 217 | name: "delete_observations".to_string(), 218 | description: Some("Delete specific observations from entities".to_string()), 219 | input_schema: 
json!({ 220 | "type": "object", 221 | "properties": { 222 | "deletions": { 223 | "type": "array", 224 | "items": { 225 | "type": "object", 226 | "properties": { 227 | "entityName": {"type": "string"}, 228 | "observations": { 229 | "type": "array", 230 | "items": {"type": "string"} 231 | } 232 | }, 233 | "required": ["entityName", "observations"] 234 | } 235 | } 236 | }, 237 | "required": ["deletions"] 238 | }), 239 | output_schema: None, 240 | }; 241 | let kg_clone = kg.clone(); 242 | server.register_tool(description, move |req: CallToolRequest| { 243 | let kg_clone = kg_clone.clone(); 244 | Box::pin(async move { 245 | let args = req.arguments.unwrap_or_default(); 246 | let deletions = args 247 | .get("deletions") 248 | .ok_or(anyhow::anyhow!("missing arguments `deletions`"))?; 249 | let deletions: Vec = 250 | serde_json::from_value(deletions.clone())?; 251 | let mut kg_guard = kg_clone.lock().unwrap(); 252 | kg_guard.delete_observations(deletions)?; 253 | kg_guard.save_to_file(memory_file_path)?; 254 | Ok(CallToolResponse { 255 | content: vec![ToolResponseContent::Text { 256 | text: "Observations deleted successfully".to_string(), 257 | }], 258 | is_error: None, 259 | meta: None, 260 | }) 261 | }) 262 | }); 263 | 264 | let description = Tool { 265 | name: "delete_relations".to_string(), 266 | description: Some("Delete multiple relations from the graph".to_string()), 267 | input_schema: json!({ 268 | "type": "object", 269 | "properties": { 270 | "relations": { 271 | "type": "array", 272 | "items": { 273 | "type": "object", 274 | "properties": { 275 | "from": {"type": "string"}, 276 | "to": {"type": "string"}, 277 | "relationType": {"type": "string"} 278 | }, 279 | "required": ["from", "to", "relationType"] 280 | } 281 | } 282 | }, 283 | "required": ["relations"] 284 | }), 285 | output_schema: None, 286 | }; 287 | let kg_clone = kg.clone(); 288 | server.register_tool(description, move |req: CallToolRequest| { 289 | let kg_clone = kg_clone.clone(); 290 | Box::pin(async move { 291 | let args = req.arguments.unwrap_or_default(); 292 | let relations = args 293 | .get("relations") 294 | .ok_or(anyhow::anyhow!("missing arguments `relations`"))?; 295 | let relations: Vec = serde_json::from_value(relations.clone())?; 296 | let mut kg_guard = kg_clone.lock().unwrap(); 297 | kg_guard.delete_relations(relations)?; 298 | kg_guard.save_to_file(memory_file_path)?; 299 | Ok(CallToolResponse { 300 | content: vec![ToolResponseContent::Text { 301 | text: "Relations deleted successfully".to_string(), 302 | }], 303 | is_error: None, 304 | meta: None, 305 | }) 306 | }) 307 | }); 308 | 309 | let description = Tool { 310 | name: "read_graph".to_string(), 311 | description: Some("Read the entire knowledge graph".to_string()), 312 | input_schema: json!({ 313 | "type": "object", 314 | "properties": {} 315 | }), 316 | output_schema: None, 317 | }; 318 | let kg_clone = kg.clone(); 319 | server.register_tool(description, move |_req: CallToolRequest| { 320 | let kg_clone = kg_clone.clone(); 321 | Box::pin(async move { 322 | Ok(CallToolResponse { 323 | content: vec![ToolResponseContent::Text { 324 | text: json!(*kg_clone.lock().unwrap()).to_string(), 325 | }], 326 | is_error: None, 327 | meta: None, 328 | }) 329 | }) 330 | }); 331 | 332 | let description = Tool { 333 | name: "search_nodes".to_string(), 334 | description: Some("Search for nodes in the knowledge graph".to_string()), 335 | input_schema: json!({ 336 | "type": "object", 337 | "properties": { 338 | "query": {"type": "string"} 339 | }, 340 | "required": 
["query"] 341 | }), 342 | output_schema: None, 343 | }; 344 | let kg_clone = kg.clone(); 345 | server.register_tool(description, move |req: CallToolRequest| { 346 | let kg_clone = kg_clone.clone(); 347 | Box::pin(async move { 348 | let args = req.arguments.unwrap_or_default(); 349 | let query = args 350 | .get("query") 351 | .ok_or(anyhow::anyhow!("missing argument `query`"))? 352 | .as_str() 353 | .ok_or(anyhow::anyhow!("query must be a string"))?; 354 | let results = kg_clone.lock().unwrap().search_nodes(query)?; 355 | Ok(CallToolResponse { 356 | content: vec![ToolResponseContent::Text { 357 | text: json!(results).to_string(), 358 | }], 359 | is_error: None, 360 | meta: None, 361 | }) 362 | }) 363 | }); 364 | 365 | let description = Tool { 366 | name: "open_nodes".to_string(), 367 | description: Some("Open specific nodes by their names".to_string()), 368 | input_schema: json!({ 369 | "type": "object", 370 | "properties": { 371 | "names": { 372 | "type": "array", 373 | "items": {"type": "string"} 374 | } 375 | }, 376 | "required": ["names"] 377 | }), 378 | output_schema: None, 379 | }; 380 | let kg_clone = kg.clone(); 381 | server.register_tool(description, move |req: CallToolRequest| { 382 | let kg_clone = kg_clone.clone(); 383 | Box::pin(async move { 384 | let args = req.arguments.unwrap_or_default(); 385 | let names = args 386 | .get("names") 387 | .ok_or(anyhow::anyhow!("missing arguments `names`"))?; 388 | let names: Vec = serde_json::from_value(names.clone())?; 389 | let results = kg_clone.lock().unwrap().open_nodes(names)?; 390 | Ok(CallToolResponse { 391 | content: vec![ToolResponseContent::Text { 392 | text: json!(results).to_string(), 393 | }], 394 | is_error: None, 395 | meta: None, 396 | }) 397 | }) 398 | }); 399 | 400 | Ok(()) 401 | } 402 | -------------------------------------------------------------------------------- /examples/knowledge_graph_memory/src/types.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use serde::{Deserialize, Serialize}; 3 | use std::{ 4 | fs::File, 5 | io::{BufRead, BufReader, Write}, 6 | path::Path, 7 | }; 8 | 9 | // ----------------------------------------------------------------------------- 10 | // Data Structures 11 | // ----------------------------------------------------------------------------- 12 | 13 | #[derive(Debug, Serialize, Deserialize, Clone)] 14 | pub struct Entity { 15 | pub name: String, 16 | #[serde(rename = "entityType")] 17 | pub entity_type: String, 18 | pub observations: Vec, 19 | } 20 | 21 | #[derive(Debug, Serialize, Deserialize, Clone)] 22 | pub struct Relation { 23 | pub from: String, 24 | pub to: String, 25 | #[serde(rename = "relationType")] 26 | pub relation_type: String, 27 | } 28 | 29 | #[derive(Debug, Serialize, Deserialize, Clone)] 30 | pub struct KnowledgeGraph { 31 | pub entities: Vec, 32 | pub relations: Vec, 33 | } 34 | 35 | impl KnowledgeGraph { 36 | pub fn load_from_file(memory_file_path: &str) -> Result { 37 | if !Path::new(memory_file_path).exists() { 38 | return Ok(Self { 39 | entities: vec![], 40 | relations: vec![], 41 | }); 42 | } 43 | 44 | let file = File::open(memory_file_path)?; 45 | let reader = BufReader::new(file); 46 | let mut kg = KnowledgeGraph { 47 | entities: vec![], 48 | relations: vec![], 49 | }; 50 | 51 | for line_res in reader.lines() { 52 | let line = line_res?; 53 | if line.trim().is_empty() { 54 | continue; 55 | } 56 | let json_val: serde_json::Value = serde_json::from_str(&line)?; 57 | if let Some(t) = 
json_val.get("type").and_then(|v| v.as_str()) { 58 | match t { 59 | "entity" => { 60 | let entity: Entity = serde_json::from_value(json_val)?; 61 | kg.entities.push(entity); 62 | } 63 | "relation" => { 64 | let relation: Relation = serde_json::from_value(json_val)?; 65 | kg.relations.push(relation); 66 | } 67 | _ => {} 68 | } 69 | } 70 | } 71 | 72 | Ok(kg) 73 | } 74 | 75 | pub fn save_to_file(&self, memory_file_path: &str) -> Result<()> { 76 | let mut file = File::create(memory_file_path)?; 77 | for entity in &self.entities { 78 | let mut map = serde_json::to_value(entity)?; 79 | if let Some(obj) = map.as_object_mut() { 80 | obj.insert( 81 | "type".to_string(), 82 | serde_json::Value::String("entity".into()), 83 | ); 84 | } 85 | let line = serde_json::to_string(&map)?; 86 | writeln!(file, "{}", line)?; 87 | } 88 | for relation in &self.relations { 89 | let mut map = serde_json::to_value(relation)?; 90 | if let Some(obj) = map.as_object_mut() { 91 | obj.insert( 92 | "type".to_string(), 93 | serde_json::Value::String("relation".into()), 94 | ); 95 | } 96 | let line = serde_json::to_string(&map)?; 97 | writeln!(file, "{}", line)?; 98 | } 99 | Ok(()) 100 | } 101 | 102 | pub fn create_entities(&mut self, entities: Vec) -> Result> { 103 | let mut newly_added = Vec::new(); 104 | for e in entities { 105 | if !self.entities.iter().any(|x| x.name == e.name) { 106 | self.entities.push(e.clone()); 107 | newly_added.push(e); 108 | } 109 | } 110 | Ok(newly_added) 111 | } 112 | 113 | pub fn create_relations(&mut self, relations: Vec) -> Result> { 114 | let mut newly_added = Vec::new(); 115 | for r in relations { 116 | if !self.relations.iter().any(|rel| { 117 | rel.from == r.from && rel.to == r.to && rel.relation_type == r.relation_type 118 | }) { 119 | self.relations.push(r.clone()); 120 | newly_added.push(r); 121 | } 122 | } 123 | Ok(newly_added) 124 | } 125 | 126 | pub fn add_observations( 127 | &mut self, 128 | observations: Vec, 129 | ) -> Result> { 130 | let mut results = Vec::new(); 131 | 132 | for obs in observations { 133 | let entity = self.entities.iter_mut().find(|e| e.name == obs.entity_name); 134 | 135 | if let Some(e) = entity { 136 | let mut added_contents = Vec::new(); 137 | for content in obs.contents { 138 | if !e.observations.contains(&content) { 139 | e.observations.push(content.clone()); 140 | added_contents.push(content); 141 | } 142 | } 143 | results.push(AddedObservationResult { 144 | entity_name: obs.entity_name, 145 | added_observations: added_contents, 146 | }); 147 | } else { 148 | anyhow::bail!("Entity with name {} not found", obs.entity_name); 149 | } 150 | } 151 | 152 | Ok(results) 153 | } 154 | 155 | pub fn delete_entities(&mut self, entity_names: Vec) -> Result<()> { 156 | self.entities.retain(|e| !entity_names.contains(&e.name)); 157 | self.relations 158 | .retain(|r| !entity_names.contains(&r.from) && !entity_names.contains(&r.to)); 159 | Ok(()) 160 | } 161 | 162 | pub fn delete_observations(&mut self, deletions: Vec) -> Result<()> { 163 | for d in deletions { 164 | if let Some(ent) = self.entities.iter_mut().find(|e| e.name == d.entity_name) { 165 | ent.observations.retain(|obs| !d.observations.contains(obs)); 166 | } 167 | } 168 | Ok(()) 169 | } 170 | 171 | pub fn delete_relations(&mut self, relations: Vec) -> Result<()> { 172 | self.relations.retain(|r| { 173 | !relations.iter().any(|del| { 174 | del.from == r.from && del.to == r.to && del.relation_type == r.relation_type 175 | }) 176 | }); 177 | Ok(()) 178 | } 179 | 180 | pub fn search_nodes(&self, query: &str) -> 
Result { 181 | let q_lower = query.to_lowercase(); 182 | 183 | let filtered_entities: Vec = self 184 | .entities 185 | .iter() 186 | .filter(|e| { 187 | e.name.to_lowercase().contains(&q_lower) 188 | || e.entity_type.to_lowercase().contains(&q_lower) 189 | || e.observations 190 | .iter() 191 | .any(|obs| obs.to_lowercase().contains(&q_lower)) 192 | }) 193 | .cloned() 194 | .collect(); 195 | 196 | let filtered_entity_names: Vec = 197 | filtered_entities.iter().map(|e| e.name.clone()).collect(); 198 | 199 | let filtered_relations: Vec = self 200 | .relations 201 | .iter() 202 | .filter(|r| { 203 | filtered_entity_names.contains(&r.from) && filtered_entity_names.contains(&r.to) 204 | }) 205 | .cloned() 206 | .collect(); 207 | 208 | Ok(KnowledgeGraph { 209 | entities: filtered_entities, 210 | relations: filtered_relations, 211 | }) 212 | } 213 | 214 | pub fn open_nodes(&self, names: Vec) -> Result { 215 | let filtered_entities: Vec = self 216 | .entities 217 | .iter() 218 | .filter(|e| names.contains(&e.name)) 219 | .cloned() 220 | .collect(); 221 | 222 | let filtered_entity_names: Vec = 223 | filtered_entities.iter().map(|e| e.name.clone()).collect(); 224 | 225 | let filtered_relations: Vec = self 226 | .relations 227 | .iter() 228 | .filter(|r| { 229 | filtered_entity_names.contains(&r.from) && filtered_entity_names.contains(&r.to) 230 | }) 231 | .cloned() 232 | .collect(); 233 | 234 | Ok(KnowledgeGraph { 235 | entities: filtered_entities, 236 | relations: filtered_relations, 237 | }) 238 | } 239 | } 240 | 241 | #[derive(Debug, Deserialize)] 242 | pub struct AddObservationParams { 243 | #[serde(rename = "entityName")] 244 | pub entity_name: String, 245 | pub contents: Vec, 246 | } 247 | 248 | #[derive(Debug, Serialize)] 249 | pub struct AddedObservationResult { 250 | #[serde(rename = "entityName")] 251 | pub entity_name: String, 252 | #[serde(rename = "addedObservations")] 253 | pub added_observations: Vec, 254 | } 255 | 256 | // For delete_observations 257 | #[derive(Debug, Deserialize)] 258 | pub struct DeleteObservationParams { 259 | #[serde(rename = "entityName")] 260 | pub entity_name: String, 261 | pub observations: Vec, 262 | } 263 | -------------------------------------------------------------------------------- /examples/pingpong/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "pingpong" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | [dependencies] 7 | async-mcp = { path = "../.." 
} 8 | tokio = { version = "1.0", features = ["full"] } 9 | serde = { version = "1.0", features = ["derive"] } 10 | serde_json = "1.0" 11 | anyhow = "1.0" 12 | tracing-subscriber = "0.3" 13 | tracing = "0.1" 14 | home = "0.5.9" 15 | clap = { version = "4.4", features = ["derive"] } 16 | 17 | [[bin]] 18 | name = "pingpong" 19 | path = "./src/main.rs" 20 | 21 | [[bin]] 22 | name = "pingpong_client" 23 | path = "./src/client.rs" 24 | -------------------------------------------------------------------------------- /examples/pingpong/README.md: -------------------------------------------------------------------------------- 1 | # Ping/Pong example with various transports 2 | 3 | This example demonstrates a simple ping/pong tool, with a client and server that can communicate over stdio, in-memory, SSE, or WebSocket transports. 4 | 5 | ### Tools 6 | 7 | - **ping** 8 | - Responds with pong 9 | 10 | 11 | ## Test locally 12 | ``` 13 | cat << 'EOF' | cargo run --bin pingpong 14 | {"jsonrpc": "2.0", "method": "tools/call", "params": {"name": "ping"}, "id": 1} 15 | EOF 16 | ``` -------------------------------------------------------------------------------- /examples/pingpong/src/client.rs: -------------------------------------------------------------------------------- 1 | use std::time::Duration; 2 | 3 | use anyhow::Result; 4 | use async_mcp::{ 5 | protocol::RequestOptions, 6 | transport::{ 7 | ClientInMemoryTransport, ClientSseTransportBuilder, ClientStdioTransport, Transport, 8 | }, 9 | }; 10 | use clap::{Parser, ValueEnum}; 11 | use pingpong::inmemory_server; 12 | use serde_json::json; 13 | use tracing::info; 14 | #[derive(Parser)] 15 | #[command(author, version, about, long_about = None)] 16 | struct Cli { 17 | /// Transport type to use 18 | #[arg(value_enum, default_value_t = TransportType::Stdio)] 19 | transport: TransportType, 20 | } 21 | 22 | #[derive(Copy, Clone, PartialEq, Eq, ValueEnum)] 23 | enum TransportType { 24 | Stdio, 25 | InMemory, 26 | Sse, 27 | Ws, 28 | } 29 | 30 | #[tokio::main] 31 | async fn main() -> Result<()> { 32 | tracing_subscriber::fmt() 33 | .with_max_level(tracing::Level::DEBUG) 34 | .with_writer(std::io::stderr) 35 | .init(); 36 | 37 | let cli = Cli::parse(); 38 | 39 | let response = match cli.transport { 40 | TransportType::Stdio => { 41 | // Build the server first 42 | // cargo build --bin pingpong 43 | let transport = ClientStdioTransport::new("./target/debug/pingpong", &[], None)?; 44 | transport.open().await?; 45 | // Create and start client 46 | let client = async_mcp::client::ClientBuilder::new(transport.clone()).build(); 47 | let client_clone = client.clone(); 48 | let _client_handle = tokio::spawn(async move { client_clone.start().await }); 49 | 50 | tokio::time::sleep(Duration::from_millis(100)).await; 51 | // Make a request 52 | client 53 | .request( 54 | "tools/call", 55 | Some(json!({"name": "ping", "arguments": {}})), 56 | RequestOptions::default().timeout(Duration::from_secs(5)), 57 | ) 58 | .await? 59 | } 60 | TransportType::Sse => { 61 | let transport = 62 | ClientSseTransportBuilder::new("http://localhost:3004".to_string()).build(); 63 | transport.open().await?; 64 | // Create and start client 65 | let client = async_mcp::client::ClientBuilder::new(transport.clone()).build(); 66 | let client_clone = client.clone(); 67 | let _client_handle = tokio::spawn(async move { client_clone.start().await }); 68 | 69 | // Make a request 70 | client 71 | .request( 72 | "tools/call", 73 | Some(json!({"name": "ping", "arguments": {}})), 74 | RequestOptions::default().timeout(Duration::from_secs(5)), 75 | ) 76 | .await?
77 | } 78 | TransportType::InMemory => { 79 | let client_transport = 80 | ClientInMemoryTransport::new(|t| tokio::spawn(inmemory_server(t))); 81 | client_transport.open().await?; 82 | let client = async_mcp::client::ClientBuilder::new(client_transport.clone()).build(); 83 | let client_clone = client.clone(); 84 | let _client_handle = tokio::spawn(async move { client_clone.start().await }); 85 | 86 | // Make a request 87 | client 88 | .request( 89 | "tools/call", 90 | Some(json!({"name": "ping", "arguments": {}})), 91 | RequestOptions::default().timeout(Duration::from_secs(5)), 92 | ) 93 | .await? 94 | } 95 | TransportType::Ws => { 96 | let transport = async_mcp::transport::ClientWsTransportBuilder::new( 97 | "ws://localhost:3004/ws".to_string(), 98 | ) 99 | .build(); 100 | transport.open().await?; 101 | // Create and start client 102 | let client = async_mcp::client::ClientBuilder::new(transport.clone()).build(); 103 | let client_clone = client.clone(); 104 | let _client_handle = tokio::spawn(async move { client_clone.start().await }); 105 | 106 | // Make a request 107 | client 108 | .request( 109 | "tools/call", 110 | Some(json!({"name": "ping", "arguments": {}})), 111 | RequestOptions::default().timeout(Duration::from_secs(5)), 112 | ) 113 | .await? 114 | } 115 | }; 116 | info!("response: {response}"); 117 | Ok(()) 118 | } 119 | -------------------------------------------------------------------------------- /examples/pingpong/src/lib.rs: -------------------------------------------------------------------------------- 1 | use async_mcp::transport::ServerInMemoryTransport; 2 | use server::build_server; 3 | 4 | pub mod server; 5 | pub async fn inmemory_server(transport: ServerInMemoryTransport) { 6 | let server = build_server(transport.clone()); 7 | server.listen().await.unwrap(); 8 | } 9 | -------------------------------------------------------------------------------- /examples/pingpong/src/main.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use async_mcp::{run_http_server, transport::ServerStdioTransport}; 3 | use clap::{Parser, ValueEnum}; 4 | use pingpong::server::build_server; 5 | 6 | #[derive(Parser)] 7 | #[command(author, version, about, long_about = None)] 8 | struct Cli { 9 | /// Transport type to use 10 | #[arg(value_enum, default_value_t = TransportType::Http)] 11 | transport: TransportType, 12 | } 13 | 14 | #[derive(Copy, Clone, PartialEq, Eq, ValueEnum)] 15 | enum TransportType { 16 | Stdio, 17 | Http, 18 | } 19 | 20 | #[tokio::main] 21 | async fn main() -> Result<()> { 22 | tracing_subscriber::fmt() 23 | .with_max_level(tracing::Level::DEBUG) 24 | // needs to be stderr due to stdio transport 25 | .with_writer(std::io::stderr) 26 | .init(); 27 | 28 | let cli = Cli::parse(); 29 | 30 | match cli.transport { 31 | TransportType::Stdio => { 32 | let server = build_server(ServerStdioTransport); 33 | server 34 | .listen() 35 | .await 36 | .map_err(|e| anyhow::anyhow!("Server error: {}", e))?; 37 | } 38 | TransportType::Http => { 39 | run_http_server(3004, None, |transport, _, _| async move { 40 | let server = build_server(transport); 41 | Ok(server) 42 | }) 43 | .await?; 44 | } 45 | }; 46 | Ok(()) 47 | } 48 | -------------------------------------------------------------------------------- /examples/pingpong/src/server.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use async_mcp::server::Server; 3 | use async_mcp::transport::Transport; 4 | use 
async_mcp::types::{ 5 | CallToolRequest, CallToolResponse, ListRequest, ResourcesListResponse, ServerCapabilities, 6 | ToolResponseContent, ToolsListResponse, 7 | }; 8 | use serde_json::json; 9 | 10 | pub fn build_server(t: T) -> Server { 11 | Server::builder(t) 12 | .capabilities(ServerCapabilities { 13 | tools: Some(json!({})), 14 | ..Default::default() 15 | }) 16 | .request_handler("tools/list", |req: ListRequest| { 17 | Box::pin(async move { list_tools(req) }) 18 | }) 19 | .request_handler("tools/call", |req: CallToolRequest| { 20 | Box::pin(async move { call_tool(req) }) 21 | }) 22 | .request_handler("resources/list", |_req: ListRequest| { 23 | Box::pin(async move { 24 | Ok(ResourcesListResponse { 25 | resources: vec![], 26 | next_cursor: None, 27 | meta: None, 28 | }) 29 | }) 30 | }) 31 | .build() 32 | } 33 | 34 | fn list_tools(_req: ListRequest) -> Result { 35 | let response = json!({ 36 | "tools": [ 37 | { 38 | "name": "ping", 39 | "description": "Send a ping to get a pong response", 40 | "inputSchema": { 41 | "type": "object", 42 | "properties": {}, 43 | "required": [] 44 | }, 45 | }, 46 | ]}); 47 | Ok(serde_json::from_value(response)?) 48 | } 49 | 50 | fn call_tool(req: CallToolRequest) -> Result { 51 | let name = req.name.as_str(); 52 | let result = match name { 53 | "ping" => ToolResponseContent::Text { 54 | text: "pong".to_string(), 55 | }, 56 | _ => return Err(anyhow::anyhow!("Unknown tool: {}", req.name)), 57 | }; 58 | Ok(CallToolResponse { 59 | content: vec![result], 60 | is_error: None, 61 | meta: None, 62 | }) 63 | } 64 | -------------------------------------------------------------------------------- /mcp_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | tee /tmp/mcp.fifo | ./target/debug/file_system | tee /tmp/mcp.fifo -------------------------------------------------------------------------------- /src/client.rs: -------------------------------------------------------------------------------- 1 | use crate::{ 2 | protocol::{Protocol, ProtocolBuilder, RequestOptions}, 3 | transport::Transport, 4 | types::{ 5 | ClientCapabilities, Implementation, InitializeRequest, InitializeResponse, 6 | RootCapabilities, LATEST_PROTOCOL_VERSION, 7 | }, 8 | }; 9 | 10 | use anyhow::Result; 11 | use tracing::debug; 12 | 13 | #[derive(Clone)] 14 | pub struct Client { 15 | protocol: Protocol, 16 | } 17 | 18 | impl Client { 19 | pub fn builder(transport: T) -> ClientBuilder { 20 | ClientBuilder::new(transport) 21 | } 22 | 23 | pub async fn initialize(&self, client_info: Implementation) -> Result { 24 | let request = InitializeRequest { 25 | protocol_version: LATEST_PROTOCOL_VERSION.to_string(), 26 | capabilities: ClientCapabilities { 27 | experimental: Some(serde_json::json!({})), 28 | sampling: Some(serde_json::json!({})), 29 | roots: Some(RootCapabilities { 30 | list_changed: Some(false), 31 | }), 32 | }, 33 | client_info, 34 | }; 35 | let response = self 36 | .request( 37 | "initialize", 38 | Some(serde_json::to_value(request)?), 39 | RequestOptions::default(), 40 | ) 41 | .await?; 42 | let response: InitializeResponse = serde_json::from_value(response) 43 | .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?; 44 | 45 | if response.protocol_version != LATEST_PROTOCOL_VERSION { 46 | return Err(anyhow::anyhow!( 47 | "Unsupported protocol version: {}", 48 | response.protocol_version 49 | )); 50 | } 51 | 52 | debug!( 53 | "Initialized with protocol version: {}", 54 | response.protocol_version 55 | ); 56 | self.protocol 
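// Having negotiated the protocol version, the client completes the MCP handshake by
// sending `notifications/initialized`; the server's handler flips its `initialized`
// flag when this notification arrives.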
57 | .notify("notifications/initialized", None) 58 | .await?; 59 | Ok(response) 60 | } 61 | 62 | pub async fn request( 63 | &self, 64 | method: &str, 65 | params: Option, 66 | options: RequestOptions, 67 | ) -> Result { 68 | let response = self.protocol.request(method, params, options).await?; 69 | response 70 | .result 71 | .ok_or_else(|| anyhow::anyhow!("Request failed: {:?}", response.error)) 72 | } 73 | 74 | pub async fn start(&self) -> Result<()> { 75 | self.protocol.listen().await 76 | } 77 | } 78 | 79 | pub struct ClientBuilder { 80 | protocol: ProtocolBuilder, 81 | } 82 | 83 | impl ClientBuilder { 84 | pub fn new(transport: T) -> Self { 85 | Self { 86 | protocol: ProtocolBuilder::new(transport), 87 | } 88 | } 89 | 90 | pub fn build(self) -> Client { 91 | Client { 92 | protocol: self.protocol.build(), 93 | } 94 | } 95 | } 96 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | pub mod client; 2 | pub mod protocol; 3 | pub mod registry; 4 | pub mod server; 5 | pub mod sse; 6 | pub use sse::http_server::run_http_server; 7 | pub mod transport; 8 | pub mod types; 9 | -------------------------------------------------------------------------------- /src/protocol.rs: -------------------------------------------------------------------------------- 1 | use super::transport::{ 2 | JsonRpcError, JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, Transport, 3 | }; 4 | use super::types::ErrorCode; 5 | use anyhow::anyhow; 6 | use anyhow::Result; 7 | use async_trait::async_trait; 8 | use serde::de::DeserializeOwned; 9 | use serde::Serialize; 10 | use std::pin::Pin; 11 | use std::sync::atomic::Ordering; 12 | use std::time::Duration; 13 | use std::{ 14 | collections::HashMap, 15 | sync::{atomic::AtomicU64, Arc}, 16 | }; 17 | use tokio::sync::oneshot; 18 | use tokio::sync::Mutex; 19 | use tokio::time::timeout; 20 | use tracing::debug; 21 | 22 | #[derive(Clone)] 23 | pub struct Protocol { 24 | transport: Arc, 25 | 26 | request_id: Arc, 27 | pending_requests: Arc>>>, 28 | request_handlers: Arc>>>, 29 | notification_handlers: Arc>>>, 30 | } 31 | 32 | impl Protocol { 33 | pub fn builder(transport: T) -> ProtocolBuilder { 34 | ProtocolBuilder::new(transport) 35 | } 36 | 37 | pub async fn notify(&self, method: &str, params: Option) -> Result<()> { 38 | let notification = JsonRpcNotification { 39 | method: method.to_string(), 40 | params, 41 | ..Default::default() 42 | }; 43 | let msg = JsonRpcMessage::Notification(notification); 44 | self.transport.send(&msg).await?; 45 | Ok(()) 46 | } 47 | 48 | pub async fn request( 49 | &self, 50 | method: &str, 51 | params: Option, 52 | options: RequestOptions, 53 | ) -> Result { 54 | let id = self.request_id.fetch_add(1, Ordering::SeqCst); 55 | 56 | // Create a oneshot channel for this request 57 | let (tx, rx) = oneshot::channel(); 58 | 59 | // Store the sender 60 | { 61 | let mut pending = self.pending_requests.lock().await; 62 | pending.insert(id, tx); 63 | } 64 | 65 | // Send the request 66 | let msg = JsonRpcMessage::Request(JsonRpcRequest { 67 | id, 68 | method: method.to_string(), 69 | params, 70 | ..Default::default() 71 | }); 72 | self.transport.send(&msg).await?; 73 | 74 | // Wait for response with timeout 75 | match timeout(options.timeout, rx) 76 | .await 77 | .map_err(|_| anyhow!("Request timed out"))? 
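// The oneshot receiver resolves once `listen()` sees a JsonRpcResponse whose id matches
// an entry in `pending_requests` and forwards it here; a timeout surfaces as an error
// before that can happen.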
78 | { 79 | Ok(response) => Ok(response), 80 | Err(_) => { 81 | // Clean up the pending request if receiver was dropped 82 | let mut pending = self.pending_requests.lock().await; 83 | pending.remove(&id); 84 | Err(anyhow!("Request cancelled")) 85 | } 86 | } 87 | } 88 | 89 | pub async fn listen(&self) -> Result<()> { 90 | debug!("Listening for requests"); 91 | loop { 92 | let message = self.transport.receive().await; 93 | 94 | let message = match message { 95 | Ok(msg) => msg, 96 | Err(e) => { 97 | tracing::error!("Failed to parse message: {:?}", e); 98 | continue; 99 | } 100 | }; 101 | 102 | // Exit loop when transport signals shutdown with None 103 | if message.is_none() { 104 | break; 105 | } 106 | 107 | match message.unwrap() { 108 | JsonRpcMessage::Request(request) => self.handle_request(request).await?, 109 | JsonRpcMessage::Response(response) => { 110 | let id = response.id; 111 | let mut pending = self.pending_requests.lock().await; 112 | if let Some(tx) = pending.remove(&id) { 113 | let _ = tx.send(response); 114 | } 115 | } 116 | JsonRpcMessage::Notification(notification) => { 117 | let handlers = self.notification_handlers.lock().await; 118 | if let Some(handler) = handlers.get(¬ification.method) { 119 | handler.handle(notification).await?; 120 | } 121 | } 122 | } 123 | } 124 | Ok(()) 125 | } 126 | 127 | async fn handle_request(&self, request: JsonRpcRequest) -> Result<()> { 128 | let handlers = self.request_handlers.lock().await; 129 | if let Some(handler) = handlers.get(&request.method) { 130 | match handler.handle(request.clone()).await { 131 | Ok(response) => { 132 | let msg = JsonRpcMessage::Response(response); 133 | self.transport.send(&msg).await?; 134 | } 135 | Err(e) => { 136 | let error_response = JsonRpcResponse { 137 | id: request.id, 138 | result: None, 139 | error: Some(JsonRpcError { 140 | code: ErrorCode::InternalError as i32, 141 | message: e.to_string(), 142 | data: None, 143 | }), 144 | ..Default::default() 145 | }; 146 | let msg = JsonRpcMessage::Response(error_response); 147 | self.transport.send(&msg).await?; 148 | } 149 | } 150 | } else { 151 | self.transport 152 | .send(&JsonRpcMessage::Response(JsonRpcResponse { 153 | id: request.id, 154 | error: Some(JsonRpcError { 155 | code: ErrorCode::MethodNotFound as i32, 156 | message: format!("Method not found: {}", request.method), 157 | data: None, 158 | }), 159 | ..Default::default() 160 | })) 161 | .await?; 162 | } 163 | Ok(()) 164 | } 165 | } 166 | 167 | /// The default request timeout, in milliseconds 168 | pub const DEFAULT_REQUEST_TIMEOUT_MSEC: u64 = 60000; 169 | pub struct RequestOptions { 170 | timeout: Duration, 171 | } 172 | 173 | impl RequestOptions { 174 | pub fn timeout(self, timeout: Duration) -> Self { 175 | Self { timeout } 176 | } 177 | } 178 | 179 | impl Default for RequestOptions { 180 | fn default() -> Self { 181 | Self { 182 | timeout: Duration::from_millis(DEFAULT_REQUEST_TIMEOUT_MSEC), 183 | } 184 | } 185 | } 186 | 187 | pub struct ProtocolBuilder { 188 | transport: T, 189 | request_handlers: HashMap>, 190 | notification_handlers: HashMap>, 191 | } 192 | impl ProtocolBuilder { 193 | pub fn new(transport: T) -> Self { 194 | Self { 195 | transport, 196 | request_handlers: HashMap::new(), 197 | notification_handlers: HashMap::new(), 198 | } 199 | } 200 | /// Register a typed request handler 201 | pub fn request_handler( 202 | mut self, 203 | method: &str, 204 | handler: impl Fn(Req) -> Pin> + Send>> 205 | + Send 206 | + Sync 207 | + 'static, 208 | ) -> Self 209 | where 210 | Req: 
DeserializeOwned + Send + Sync + 'static, 211 | Resp: Serialize + Send + Sync + 'static, 212 | { 213 | let handler = TypedRequestHandler { 214 | handler: Box::new(handler), 215 | _phantom: std::marker::PhantomData, 216 | }; 217 | 218 | self.request_handlers 219 | .insert(method.to_string(), Box::new(handler)); 220 | self 221 | } 222 | 223 | pub fn has_request_handler(&self, method: &str) -> bool { 224 | self.request_handlers.contains_key(method) 225 | } 226 | 227 | pub fn notification_handler( 228 | mut self, 229 | method: &str, 230 | handler: impl Fn(N) -> Pin> + Send>> 231 | + Send 232 | + Sync 233 | + 'static, 234 | ) -> Self 235 | where 236 | N: DeserializeOwned + Send + Sync + 'static, 237 | { 238 | self.notification_handlers.insert( 239 | method.to_string(), 240 | Box::new(TypedNotificationHandler { 241 | handler: Box::new(handler), 242 | _phantom: std::marker::PhantomData, 243 | }), 244 | ); 245 | self 246 | } 247 | 248 | pub fn build(self) -> Protocol { 249 | Protocol { 250 | transport: Arc::new(self.transport), 251 | request_handlers: Arc::new(Mutex::new(self.request_handlers)), 252 | notification_handlers: Arc::new(Mutex::new(self.notification_handlers)), 253 | request_id: Arc::new(AtomicU64::new(0)), 254 | pending_requests: Arc::new(Mutex::new(HashMap::new())), 255 | } 256 | } 257 | } 258 | 259 | // Update the handler traits to be async 260 | #[async_trait] 261 | trait RequestHandler: Send + Sync { 262 | async fn handle(&self, request: JsonRpcRequest) -> Result; 263 | } 264 | 265 | #[async_trait] 266 | trait NotificationHandler: Send + Sync { 267 | async fn handle(&self, notification: JsonRpcNotification) -> Result<()>; 268 | } 269 | 270 | // Update the TypedRequestHandler to use async handlers 271 | struct TypedRequestHandler 272 | where 273 | Req: DeserializeOwned + Send + Sync + 'static, 274 | Resp: Serialize + Send + Sync + 'static, 275 | { 276 | handler: Box< 277 | dyn Fn(Req) -> std::pin::Pin> + Send>> 278 | + Send 279 | + Sync, 280 | >, 281 | _phantom: std::marker::PhantomData<(Req, Resp)>, 282 | } 283 | 284 | #[async_trait] 285 | impl RequestHandler for TypedRequestHandler 286 | where 287 | Req: DeserializeOwned + Send + Sync + 'static, 288 | Resp: Serialize + Send + Sync + 'static, 289 | { 290 | async fn handle(&self, request: JsonRpcRequest) -> Result { 291 | let params: Req = if request.params.is_none() || request.params.as_ref().unwrap().is_null() 292 | { 293 | serde_json::from_value(serde_json::Value::Null)? 294 | } else { 295 | serde_json::from_value(request.params.unwrap())? 296 | }; 297 | let result = (self.handler)(params).await?; 298 | Ok(JsonRpcResponse { 299 | id: request.id, 300 | result: Some(serde_json::to_value(result)?), 301 | error: None, 302 | ..Default::default() 303 | }) 304 | } 305 | } 306 | 307 | struct TypedNotificationHandler 308 | where 309 | N: DeserializeOwned + Send + Sync + 'static, 310 | { 311 | handler: Box< 312 | dyn Fn(N) -> std::pin::Pin> + Send>> 313 | + Send 314 | + Sync, 315 | >, 316 | _phantom: std::marker::PhantomData, 317 | } 318 | 319 | #[async_trait] 320 | impl NotificationHandler for TypedNotificationHandler 321 | where 322 | N: DeserializeOwned + Send + Sync + 'static, 323 | { 324 | async fn handle(&self, notification: JsonRpcNotification) -> Result<()> { 325 | let params: N = 326 | if notification.params.is_none() || notification.params.as_ref().unwrap().is_null() { 327 | serde_json::from_value(serde_json::Value::Null)? 
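// When params are present, the branch below parses them leniently: a parse failure is
// logged as a warning and treated as empty params rather than tearing down the listen loop.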
328 | } else { 329 | match ¬ification.params { 330 | Some(params) => { 331 | let res = serde_json::from_value(params.clone()); 332 | match res { 333 | Ok(r) => r, 334 | Err(e) => { 335 | tracing::warn!( 336 | "Failed to parse notification params: {:?}. Params: {:?}", 337 | e, 338 | notification.params 339 | ); 340 | serde_json::from_value(serde_json::Value::Null)? 341 | } 342 | } 343 | } 344 | None => serde_json::from_value(serde_json::Value::Null)?, 345 | } 346 | }; 347 | (self.handler)(params).await 348 | } 349 | } 350 | -------------------------------------------------------------------------------- /src/registry.rs: -------------------------------------------------------------------------------- 1 | use crate::types::{CallToolRequest, CallToolResponse, Tool}; 2 | use anyhow::Result; 3 | use std::collections::HashMap; 4 | use std::future::Future; 5 | use std::pin::Pin; 6 | 7 | pub struct Tools { 8 | tool_handlers: HashMap, 9 | } 10 | 11 | impl Tools { 12 | pub(crate) fn new(map: HashMap) -> Self { 13 | Self { tool_handlers: map } 14 | } 15 | 16 | pub fn get_tool(&self, name: &str) -> Option { 17 | self.tool_handlers 18 | .get(name) 19 | .map(|tool_handler| tool_handler.tool.clone()) 20 | } 21 | 22 | pub async fn call_tool(&self, req: CallToolRequest) -> Result { 23 | let handler = self 24 | .tool_handlers 25 | .get(&req.name) 26 | .ok_or_else(|| anyhow::anyhow!("Tool not found: {}", req.name))?; 27 | 28 | (handler.f)(req).await 29 | } 30 | 31 | pub fn list_tools(&self) -> Vec { 32 | self.tool_handlers 33 | .values() 34 | .map(|tool_handler| tool_handler.tool.clone()) 35 | .collect() 36 | } 37 | } 38 | 39 | pub(crate) struct ToolHandler { 40 | pub tool: Tool, 41 | pub f: Box< 42 | dyn Fn(CallToolRequest) -> Pin> + Send>> 43 | + Send 44 | + Sync, 45 | >, 46 | } 47 | -------------------------------------------------------------------------------- /src/server.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | collections::HashMap, 3 | sync::{Arc, RwLock}, 4 | }; 5 | 6 | use crate::{ 7 | registry::{ToolHandler, Tools}, 8 | types::{CallToolRequest, CallToolResponse, ListRequest, Tool, ToolsListResponse}, 9 | }; 10 | 11 | use super::{ 12 | protocol::{Protocol, ProtocolBuilder}, 13 | transport::Transport, 14 | types::{ 15 | ClientCapabilities, Implementation, InitializeRequest, InitializeResponse, 16 | ServerCapabilities, LATEST_PROTOCOL_VERSION, 17 | }, 18 | }; 19 | use anyhow::Result; 20 | use serde::{de::DeserializeOwned, Serialize}; 21 | use std::future::Future; 22 | use std::pin::Pin; 23 | 24 | #[derive(Clone)] 25 | pub struct ServerState { 26 | client_capabilities: Option, 27 | client_info: Option, 28 | initialized: bool, 29 | } 30 | 31 | #[derive(Clone)] 32 | pub struct Server { 33 | protocol: Protocol, 34 | state: Arc>, 35 | } 36 | 37 | pub struct ServerBuilder { 38 | protocol: ProtocolBuilder, 39 | server_info: Implementation, 40 | capabilities: ServerCapabilities, 41 | tools: HashMap, 42 | } 43 | 44 | impl ServerBuilder { 45 | pub fn name>(mut self, name: S) -> Self { 46 | self.server_info.name = name.into(); 47 | self 48 | } 49 | 50 | pub fn version>(mut self, version: S) -> Self { 51 | self.server_info.version = version.into(); 52 | self 53 | } 54 | 55 | pub fn capabilities(mut self, capabilities: ServerCapabilities) -> Self { 56 | self.capabilities = capabilities; 57 | self 58 | } 59 | 60 | /// Register a typed request handler 61 | /// for higher-level api use add tool 62 | pub fn request_handler( 63 | mut self, 64 | method: 
&str, 65 | handler: impl Fn(Req) -> Pin> + Send>> 66 | + Send 67 | + Sync 68 | + 'static, 69 | ) -> Self 70 | where 71 | Req: DeserializeOwned + Send + Sync + 'static, 72 | Resp: Serialize + Send + Sync + 'static, 73 | { 74 | self.protocol = self.protocol.request_handler(method, handler); 75 | self 76 | } 77 | 78 | pub fn notification_handler( 79 | mut self, 80 | method: &str, 81 | handler: impl Fn(N) -> Pin> + Send>> 82 | + Send 83 | + Sync 84 | + 'static, 85 | ) -> Self 86 | where 87 | N: DeserializeOwned + Send + Sync + 'static, 88 | { 89 | self.protocol = self.protocol.notification_handler(method, handler); 90 | self 91 | } 92 | 93 | pub fn register_tool( 94 | &mut self, 95 | tool: Tool, 96 | f: impl Fn(CallToolRequest) -> Pin> + Send>> 97 | + Send 98 | + Sync 99 | + 'static, 100 | ) { 101 | self.tools.insert( 102 | tool.name.clone(), 103 | ToolHandler { 104 | tool, 105 | f: Box::new(f), 106 | }, 107 | ); 108 | } 109 | 110 | pub fn build(self) -> Server { 111 | Server::new(self) 112 | } 113 | } 114 | 115 | impl Server { 116 | pub fn builder(transport: T) -> ServerBuilder { 117 | ServerBuilder { 118 | protocol: Protocol::builder(transport), 119 | server_info: Implementation { 120 | name: env!("CARGO_PKG_NAME").to_string(), 121 | version: env!("CARGO_PKG_VERSION").to_string(), 122 | }, 123 | capabilities: Default::default(), 124 | tools: HashMap::new(), 125 | } 126 | } 127 | 128 | fn new(builder: ServerBuilder) -> Self { 129 | let state = Arc::new(RwLock::new(ServerState { 130 | client_capabilities: None, 131 | client_info: None, 132 | initialized: false, 133 | })); 134 | 135 | // Initialize protocol with handlers 136 | let mut protocol = builder 137 | .protocol 138 | .request_handler( 139 | "initialize", 140 | Self::handle_init(state.clone(), builder.server_info, builder.capabilities), 141 | ) 142 | .notification_handler( 143 | "notifications/initialized", 144 | Self::handle_initialized(state.clone()), 145 | ); 146 | 147 | // Add tools handlers if not already present 148 | if !protocol.has_request_handler("tools/list") { 149 | let tools = Arc::new(Tools::new(builder.tools)); 150 | let tools_clone = tools.clone(); 151 | let tools_list = tools.clone(); 152 | let tools_call = tools_clone.clone(); 153 | 154 | protocol = protocol 155 | .request_handler("tools/list", move |_req: ListRequest| { 156 | let tools = tools_list.clone(); 157 | Box::pin(async move { 158 | Ok(ToolsListResponse { 159 | tools: tools.list_tools(), 160 | next_cursor: None, 161 | meta: None, 162 | }) 163 | }) 164 | }) 165 | .request_handler("tools/call", move |req: CallToolRequest| { 166 | let tools = tools_call.clone(); 167 | Box::pin(async move { tools.call_tool(req).await }) 168 | }); 169 | } 170 | 171 | Server { 172 | protocol: protocol.build(), 173 | state, 174 | } 175 | } 176 | 177 | // Helper function for initialize handler 178 | fn handle_init( 179 | state: Arc>, 180 | server_info: Implementation, 181 | capabilities: ServerCapabilities, 182 | ) -> impl Fn( 183 | InitializeRequest, 184 | ) 185 | -> Pin> + Send>> { 186 | move |req| { 187 | let state = state.clone(); 188 | let server_info = server_info.clone(); 189 | let capabilities = capabilities.clone(); 190 | 191 | Box::pin(async move { 192 | let mut state = state 193 | .write() 194 | .map_err(|_| anyhow::anyhow!("Lock poisoned"))?; 195 | state.client_capabilities = Some(req.capabilities); 196 | state.client_info = Some(req.client_info); 197 | 198 | Ok(InitializeResponse { 199 | protocol_version: LATEST_PROTOCOL_VERSION.to_string(), 200 | capabilities, 201 | 
server_info, 202 | }) 203 | }) 204 | } 205 | } 206 | 207 | // Helper function for initialized handler 208 | fn handle_initialized( 209 | state: Arc>, 210 | ) -> impl Fn(()) -> Pin> + Send>> { 211 | move |_| { 212 | let state = state.clone(); 213 | Box::pin(async move { 214 | let mut state = state 215 | .write() 216 | .map_err(|_| anyhow::anyhow!("Lock poisoned"))?; 217 | state.initialized = true; 218 | Ok(()) 219 | }) 220 | } 221 | } 222 | 223 | pub fn get_client_capabilities(&self) -> Option { 224 | self.state.read().ok()?.client_capabilities.clone() 225 | } 226 | 227 | pub fn get_client_info(&self) -> Option { 228 | self.state.read().ok()?.client_info.clone() 229 | } 230 | 231 | pub fn is_initialized(&self) -> bool { 232 | self.state 233 | .read() 234 | .ok() 235 | .map(|state| state.initialized) 236 | .unwrap_or(false) 237 | } 238 | 239 | pub async fn listen(&self) -> Result<()> { 240 | self.protocol.listen().await 241 | } 242 | } 243 | -------------------------------------------------------------------------------- /src/sse/http_server.rs: -------------------------------------------------------------------------------- 1 | use actix_web::middleware::Logger; 2 | use actix_web::web::Payload; 3 | use actix_web::web::Query; 4 | use actix_web::HttpMessage; 5 | use actix_web::{web, App, HttpResponse, HttpServer}; 6 | use anyhow::Result; 7 | use futures::StreamExt; 8 | use uuid::Uuid; 9 | 10 | use crate::server::Server; 11 | use crate::sse::middleware::{AuthConfig, JwtAuth}; 12 | use crate::transport::ServerHttpTransport; 13 | use crate::transport::{handle_ws_connection, Message, ServerSseTransport, ServerWsTransport}; 14 | use serde::{Deserialize, Serialize}; 15 | use std::collections::HashMap; 16 | use std::fmt::Debug; 17 | use std::sync::{Arc, Mutex}; 18 | use tokio::sync::broadcast; 19 | use tracing::{debug, error, info}; 20 | 21 | /// Server-side SSE transport that handles HTTP POST requests for incoming messages 22 | /// and sends responses via SSE 23 | #[derive(Debug, Serialize, Deserialize)] 24 | pub struct Claims { 25 | pub exp: usize, 26 | pub iat: usize, 27 | } 28 | 29 | #[derive(Clone)] 30 | pub struct Endpoint(pub String); 31 | 32 | #[derive(Deserialize)] 33 | pub struct MessageQuery { 34 | #[serde(rename = "sessionId")] 35 | session_id: Option, 36 | } 37 | 38 | #[derive(Clone)] 39 | pub struct SessionState { 40 | sessions: Arc>>, 41 | build_server: Arc< 42 | dyn Fn( 43 | ServerHttpTransport, 44 | Option, 45 | String, 46 | ) 47 | -> futures::future::BoxFuture<'static, Result>> 48 | + Send 49 | + Sync, 50 | >, 51 | endpoint: String, 52 | } 53 | 54 | impl SessionState { 55 | /// Create a new SessionState instance with configurable parameters 56 | pub fn new( 57 | endpoint: String, 58 | build_server: Arc< 59 | dyn Fn( 60 | ServerHttpTransport, 61 | Option, 62 | String, 63 | ) 64 | -> futures::future::BoxFuture<'static, Result>> 65 | + Send 66 | + Sync, 67 | >, 68 | sessions: Arc>>, 69 | ) -> Self { 70 | Self { 71 | sessions, 72 | build_server, 73 | endpoint, 74 | } 75 | } 76 | } 77 | 78 | /// Run a server instance with the specified transport 79 | pub async fn run_http_server( 80 | port: u16, 81 | jwt_secret: Option, 82 | build_server: F, 83 | ) -> Result<()> 84 | where 85 | F: Fn(ServerHttpTransport, Option, String) -> Fut + Send + Sync + 'static, 86 | Fut: futures::Future>> + Send + 'static, 87 | { 88 | info!("Starting server on http://0.0.0.0:{}", port); 89 | info!("WebSocket endpoint: ws://0.0.0.0:{}/ws", port); 90 | info!("SSE endpoint: http://0.0.0.0:{}/sse", port); 91 | 92 | 
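// One shared session map: every SSE or WebSocket connection registers its transport here
// under a fresh session id, and the /message POST handler looks the transport up by that
// id to route incoming JSON-RPC messages to the right server instance.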
let sessions = Arc::new(Mutex::new(HashMap::new())); 93 | 94 | // Box the future when creating the Arc 95 | let build_server = Arc::new(move |t, o, session_id| { 96 | Box::pin(build_server(t, o, session_id)) as futures::future::BoxFuture<_> 97 | }); 98 | 99 | let auth_config = jwt_secret.map(|jwt_secret| AuthConfig { jwt_secret }); 100 | let http_server = http_server(port, sessions, auth_config, build_server); 101 | 102 | http_server.await?; 103 | Ok(()) 104 | } 105 | 106 | pub async fn http_server( 107 | port: u16, 108 | sessions: Arc>>, 109 | auth_config: Option, 110 | build_server: Arc< 111 | dyn Fn( 112 | ServerHttpTransport, 113 | Option, 114 | String, 115 | ) 116 | -> futures::future::BoxFuture<'static, Result>> 117 | + Send 118 | + Sync, 119 | >, 120 | ) -> std::result::Result<(), std::io::Error> { 121 | let session_state = SessionState { 122 | sessions, 123 | build_server, 124 | endpoint: format!("http://0.0.0.0:{}", port), 125 | }; 126 | 127 | let server = HttpServer::new(move || { 128 | let session_state = session_state.clone(); 129 | App::new() 130 | .wrap(Logger::default()) 131 | .wrap(JwtAuth::new(auth_config.clone())) 132 | .app_data(web::Data::new(session_state)) 133 | .route("/sse", web::get().to(sse_handler)) 134 | .route("/message", web::post().to(message_handler)) 135 | .route("/ws", web::get().to(ws_handler)) 136 | }) 137 | .bind(("0.0.0.0", port))? 138 | .run(); 139 | 140 | server.await 141 | } 142 | 143 | pub async fn sse_handler( 144 | req: actix_web::HttpRequest, 145 | session_state: web::Data, 146 | ) -> HttpResponse { 147 | let endpoint = req.extensions().get::().cloned(); 148 | let session_metadata = req.extensions().get::().cloned(); 149 | let client_ip = req 150 | .peer_addr() 151 | .map(|addr| addr.ip().to_string()) 152 | .unwrap_or_else(|| "unknown".to_string()); 153 | 154 | debug!("New SSE connection request from {}", client_ip); 155 | 156 | // Create new session 157 | let session_id = Uuid::new_v4().to_string(); 158 | 159 | // Create channel for SSE messages 160 | let (sse_tx, sse_rx) = broadcast::channel(100); 161 | 162 | // Create new transport for this session 163 | let transport = ServerHttpTransport::Sse(ServerSseTransport::new(sse_tx.clone())); 164 | 165 | // Store transport in sessions map 166 | session_state 167 | .sessions 168 | .lock() 169 | .unwrap() 170 | .insert(session_id.clone(), transport.clone()); 171 | 172 | debug!( 173 | "SSE connection established for {} with session_id {}", 174 | client_ip, session_id 175 | ); 176 | let endpoint = endpoint.map_or(session_state.endpoint.clone(), |e| e.0); 177 | // Create initial endpoint info event 178 | let endpoint_info = 179 | format!("event: endpoint\ndata: {endpoint}/message?sessionId={session_id}\n\n",); 180 | 181 | let stream = futures::stream::once(async move { 182 | Ok::<_, std::convert::Infallible>(web::Bytes::from(endpoint_info)) 183 | }) 184 | .chain(futures::stream::unfold(sse_rx, move |mut rx| { 185 | let client_ip = client_ip.clone(); 186 | async move { 187 | match rx.recv().await { 188 | Ok(msg) => { 189 | // Show first and last 500 characters for debugging 190 | let json = serde_json::to_string(&msg).unwrap(); 191 | if json.len() > 1000 { 192 | let first = &json[..500]; 193 | let last = &json[json.len() - 500..]; 194 | debug!("Sending SSE message to {}: {}...{}", client_ip, first, last); 195 | } else { 196 | debug!("Sending SSE message to {}: {}", client_ip, json); 197 | } 198 | let sse_data = format!("data: {}\n\n", json); 199 | Some(( 200 | Ok::<_, 
std::convert::Infallible>(web::Bytes::from(sse_data)), 201 | rx, 202 | )) 203 | } 204 | _ => None, 205 | } 206 | } 207 | })); 208 | 209 | // Create and start server instance for this session 210 | let transport_clone = transport.clone(); 211 | let build_server = session_state.build_server.clone(); 212 | let session_metadata = session_metadata.clone(); 213 | let ses_id = session_id.clone(); 214 | tokio::spawn(async move { 215 | match build_server(transport_clone, session_metadata, ses_id.clone()).await { 216 | Ok(server) => { 217 | if let Err(e) = server.listen().await { 218 | error!("Server error: {:?}", e); 219 | } 220 | } 221 | Err(e) => { 222 | error!("Failed to build server: {:?}", e); 223 | } 224 | } 225 | }); 226 | 227 | HttpResponse::Ok() 228 | .append_header(("X-Session-Id", session_id)) 229 | .content_type("text/event-stream") 230 | .streaming(stream) 231 | } 232 | 233 | pub async fn message_handler( 234 | query: Query, 235 | message: web::Json, 236 | session_state: web::Data, 237 | ) -> HttpResponse { 238 | if let Some(session_id) = &query.session_id { 239 | let sessions = session_state.sessions.lock().unwrap(); 240 | if let Some(transport) = sessions.get(session_id) { 241 | match transport { 242 | ServerHttpTransport::Sse(sse) => match sse.send_message(message.into_inner()).await 243 | { 244 | Ok(_) => { 245 | debug!("Successfully sent message to session {}", session_id); 246 | HttpResponse::Accepted().finish() 247 | } 248 | Err(e) => { 249 | error!("Failed to send message to session {}: {:?}", session_id, e); 250 | HttpResponse::InternalServerError().finish() 251 | } 252 | }, 253 | ServerHttpTransport::Ws(_) => HttpResponse::BadRequest() 254 | .body("Cannot send message to WebSocket connection through HTTP endpoint"), 255 | } 256 | } else { 257 | HttpResponse::NotFound().body(format!("Session {} not found", session_id)) 258 | } 259 | } else { 260 | HttpResponse::BadRequest().body("Session ID not specified") 261 | } 262 | } 263 | 264 | pub async fn ws_handler( 265 | req: actix_web::HttpRequest, 266 | body: Payload, 267 | session_state: web::Data, 268 | ) -> Result { 269 | let session_metadata = req.extensions().get::().cloned(); 270 | 271 | let (response, session, msg_stream) = actix_ws::handle(&req, body)?; 272 | 273 | let client_ip = req 274 | .peer_addr() 275 | .map(|addr| addr.ip().to_string()) 276 | .unwrap_or_else(|| "unknown".to_string()); 277 | 278 | info!("New WebSocket connection from {}", client_ip); 279 | 280 | // Create channels for message passing 281 | let (tx, rx) = broadcast::channel(100); 282 | let transport = 283 | ServerHttpTransport::Ws(ServerWsTransport::new(session.clone(), rx.resubscribe())); 284 | 285 | // Store transport in sessions map 286 | let session_id = Uuid::new_v4().to_string(); 287 | session_state 288 | .sessions 289 | .lock() 290 | .unwrap() 291 | .insert(session_id.clone(), transport.clone()); 292 | 293 | // Start WebSocket handling in the background 294 | actix_web::rt::spawn(async move { 295 | let _ = handle_ws_connection(session, msg_stream, tx.clone(), rx.resubscribe()).await; 296 | }); 297 | 298 | // Spawn server instance 299 | let build_server = session_state.build_server.clone(); 300 | let session_metadata = session_metadata.clone(); 301 | actix_web::rt::spawn(async move { 302 | if let Ok(server) = build_server(transport, session_metadata, session_id.clone()).await { 303 | let _ = server.listen().await; 304 | } 305 | }); 306 | 307 | Ok(response) 308 | } 309 | 
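The SSE flow above (an initial `endpoint` event carrying `/message?sessionId=...`, followed by `data:` frames) is what `ClientSseTransport` consumes on the client side. A minimal sketch of connecting to such a server, assuming it was started on port 3004 as in the pingpong example; the JWT secret is an assumed value and only needed if the server was launched with one:

```
use anyhow::Result;
use async_mcp::client::ClientBuilder;
use async_mcp::transport::{ClientSseTransportBuilder, Transport};

#[tokio::main]
async fn main() -> Result<()> {
    // open() performs the GET /sse handshake and waits for the endpoint/session-id event
    let transport = ClientSseTransportBuilder::new("http://localhost:3004".to_string())
        .with_auth("my-secret".to_string()) // assumed secret; omit if the server has no JWT auth
        .build();
    transport.open().await?;

    // start() drives the protocol loop; outgoing requests are POSTed to /message?sessionId=...
    let client = ClientBuilder::new(transport).build();
    let client_clone = client.clone();
    tokio::spawn(async move { client_clone.start().await });
    Ok(())
}
```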
-------------------------------------------------------------------------------- /src/sse/middleware.rs: -------------------------------------------------------------------------------- 1 | use actix_web::{ 2 | body::EitherBody, 3 | dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform}, 4 | Error, HttpResponse, 5 | }; 6 | use futures::future::LocalBoxFuture; 7 | use jsonwebtoken::{decode, DecodingKey, Validation}; 8 | use serde::{Deserialize, Serialize}; 9 | use std::future::{ready, Ready}; 10 | 11 | #[derive(Debug, Serialize, Deserialize)] 12 | pub struct Claims { 13 | pub exp: usize, 14 | pub iat: usize, 15 | } 16 | 17 | #[derive(Clone)] 18 | pub struct AuthConfig { 19 | pub jwt_secret: String, 20 | } 21 | 22 | pub struct JwtAuth(Option); 23 | 24 | impl JwtAuth { 25 | pub fn new(config: Option) -> Self { 26 | JwtAuth(config) 27 | } 28 | } 29 | 30 | impl Transform for JwtAuth 31 | where 32 | S: Service, Error = Error>, 33 | S::Future: 'static, 34 | B: 'static, 35 | { 36 | type Response = ServiceResponse>; 37 | type Error = Error; 38 | type InitError = (); 39 | type Transform = JwtAuthMiddleware; 40 | type Future = Ready>; 41 | 42 | fn new_transform(&self, service: S) -> Self::Future { 43 | ready(Ok(JwtAuthMiddleware { 44 | service, 45 | auth_config: self.0.clone(), 46 | })) 47 | } 48 | } 49 | 50 | pub struct JwtAuthMiddleware { 51 | service: S, 52 | auth_config: Option, 53 | } 54 | 55 | impl Service for JwtAuthMiddleware 56 | where 57 | S: Service, Error = Error>, 58 | S::Future: 'static, 59 | B: 'static, 60 | { 61 | type Response = ServiceResponse>; 62 | type Error = Error; 63 | type Future = LocalBoxFuture<'static, Result>; 64 | 65 | forward_ready!(service); 66 | 67 | fn call(&self, req: ServiceRequest) -> Self::Future { 68 | if let Some(config) = &self.auth_config { 69 | let auth_header = req 70 | .headers() 71 | .get("Authorization") 72 | .and_then(|h| h.to_str().ok()); 73 | 74 | match auth_header { 75 | Some(auth) if auth.starts_with("Bearer ") => { 76 | let token = &auth[7..]; 77 | match decode::( 78 | token, 79 | &DecodingKey::from_secret(config.jwt_secret.as_bytes()), 80 | &Validation::default(), 81 | ) { 82 | Ok(_) => { 83 | let fut = self.service.call(req); 84 | Box::pin( 85 | async move { fut.await.map(ServiceResponse::map_into_left_body) }, 86 | ) 87 | } 88 | Err(_) => { 89 | let (req, _) = req.into_parts(); 90 | Box::pin(async move { 91 | Ok( 92 | ServiceResponse::new( 93 | req, 94 | HttpResponse::Unauthorized().finish(), 95 | ) 96 | .map_into_right_body(), 97 | ) 98 | }) 99 | } 100 | } 101 | } 102 | _ => { 103 | let (req, _) = req.into_parts(); 104 | Box::pin(async move { 105 | Ok( 106 | ServiceResponse::new(req, HttpResponse::Unauthorized().finish()) 107 | .map_into_right_body(), 108 | ) 109 | }) 110 | } 111 | } 112 | } else { 113 | let fut = self.service.call(req); 114 | Box::pin(async move { fut.await.map(ServiceResponse::map_into_left_body) }) 115 | } 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /src/sse/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod http_server; 2 | pub mod middleware; 3 | -------------------------------------------------------------------------------- /src/transport/http_transport.rs: -------------------------------------------------------------------------------- 1 | use super::{ 2 | ClientSseTransport, ClientWsTransport, Message, ServerSseTransport, ServerWsTransport, 3 | Transport, 4 | }; 5 | use anyhow::Result; 6 | pub 
enum ServerHttpTransport { 7 | Sse(ServerSseTransport), 8 | Ws(ServerWsTransport), 9 | } 10 | pub enum ClientHttpTransport { 11 | Sse(ClientSseTransport), 12 | Ws(ClientWsTransport), 13 | } 14 | 15 | impl Clone for ServerHttpTransport { 16 | fn clone(&self) -> Self { 17 | match self { 18 | ServerHttpTransport::Sse(sse) => ServerHttpTransport::Sse(sse.clone()), 19 | ServerHttpTransport::Ws(ws) => ServerHttpTransport::Ws(ws.clone()), 20 | } 21 | } 22 | } 23 | 24 | #[async_trait::async_trait] 25 | impl Transport for ServerHttpTransport { 26 | async fn send(&self, message: &Message) -> Result<()> { 27 | match self { 28 | ServerHttpTransport::Sse(sse) => sse.send(message).await, 29 | ServerHttpTransport::Ws(ws) => ws.send(message).await, 30 | } 31 | } 32 | 33 | async fn receive(&self) -> Result> { 34 | match self { 35 | ServerHttpTransport::Sse(sse) => sse.receive().await, 36 | ServerHttpTransport::Ws(ws) => ws.receive().await, 37 | } 38 | } 39 | 40 | async fn open(&self) -> Result<()> { 41 | match self { 42 | ServerHttpTransport::Sse(sse) => sse.open().await, 43 | ServerHttpTransport::Ws(ws) => ws.open().await, 44 | } 45 | } 46 | 47 | async fn close(&self) -> Result<()> { 48 | match self { 49 | ServerHttpTransport::Sse(sse) => sse.close().await, 50 | ServerHttpTransport::Ws(ws) => ws.close().await, 51 | } 52 | } 53 | } 54 | 55 | impl Clone for ClientHttpTransport { 56 | fn clone(&self) -> Self { 57 | match self { 58 | ClientHttpTransport::Sse(sse) => ClientHttpTransport::Sse(sse.clone()), 59 | ClientHttpTransport::Ws(ws) => ClientHttpTransport::Ws(ws.clone()), 60 | } 61 | } 62 | } 63 | 64 | #[async_trait::async_trait] 65 | impl Transport for ClientHttpTransport { 66 | async fn send(&self, message: &Message) -> Result<()> { 67 | match self { 68 | ClientHttpTransport::Sse(sse) => sse.send(message).await, 69 | ClientHttpTransport::Ws(ws) => ws.send(message).await, 70 | } 71 | } 72 | 73 | async fn receive(&self) -> Result> { 74 | match self { 75 | ClientHttpTransport::Sse(sse) => sse.receive().await, 76 | ClientHttpTransport::Ws(ws) => ws.receive().await, 77 | } 78 | } 79 | 80 | async fn open(&self) -> Result<()> { 81 | match self { 82 | ClientHttpTransport::Sse(sse) => sse.open().await, 83 | ClientHttpTransport::Ws(ws) => ws.open().await, 84 | } 85 | } 86 | 87 | async fn close(&self) -> Result<()> { 88 | match self { 89 | ClientHttpTransport::Sse(sse) => sse.close().await, 90 | ClientHttpTransport::Ws(ws) => ws.close().await, 91 | } 92 | } 93 | } 94 | -------------------------------------------------------------------------------- /src/transport/inmemory_transport.rs: -------------------------------------------------------------------------------- 1 | use super::{Message, Transport}; 2 | use anyhow::Result; 3 | use async_trait::async_trait; 4 | use std::sync::Arc; 5 | use tokio::sync::mpsc::{self, Receiver, Sender}; 6 | use tokio::sync::Mutex; 7 | use tokio::task::JoinHandle; 8 | use tracing::debug; 9 | 10 | /// Server-side transport that receives messages from a channel 11 | #[derive(Clone)] 12 | pub struct ServerInMemoryTransport { 13 | rx: Arc>>>, 14 | tx: Sender, 15 | } 16 | 17 | impl Default for ServerInMemoryTransport { 18 | fn default() -> Self { 19 | let (tx, rx) = mpsc::channel(100); // Default buffer size of 100 20 | Self { 21 | rx: Arc::new(Mutex::new(Some(rx))), 22 | tx, 23 | } 24 | } 25 | } 26 | 27 | #[async_trait] 28 | impl Transport for ServerInMemoryTransport { 29 | async fn receive(&self) -> Result> { 30 | let mut rx_guard = self.rx.lock().await; 31 | let rx = rx_guard 32 | 
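// The receiver sits behind an Option so that close() can drop it; a receive() issued
// after close therefore fails fast with "Transport not opened".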
.as_mut() 33 | .ok_or_else(|| anyhow::anyhow!("Transport not opened"))?; 34 | 35 | match rx.recv().await { 36 | Some(message) => { 37 | debug!("Server received: {:?}", message); 38 | Ok(Some(message)) 39 | } 40 | None => { 41 | debug!("Client channel closed"); 42 | Ok(None) 43 | } 44 | } 45 | } 46 | 47 | async fn send(&self, message: &Message) -> Result<()> { 48 | debug!("Server sending: {:?}", message); 49 | self.tx 50 | .send(message.clone()) 51 | .await 52 | .map_err(|e| anyhow::anyhow!("Failed to send message: {}", e))?; 53 | Ok(()) 54 | } 55 | 56 | async fn open(&self) -> Result<()> { 57 | Ok(()) 58 | } 59 | 60 | async fn close(&self) -> Result<()> { 61 | *self.rx.lock().await = None; 62 | Ok(()) 63 | } 64 | } 65 | 66 | /// Client-side transport that communicates with a spawned server task 67 | #[derive(Clone)] 68 | pub struct ClientInMemoryTransport { 69 | tx: Arc>>>, 70 | rx: Arc>>>, 71 | server_handle: Arc>>>, 72 | server_factory: Arc JoinHandle<()> + Send + Sync>, 73 | } 74 | 75 | impl ClientInMemoryTransport { 76 | pub fn new(server_factory: F) -> Self 77 | where 78 | F: Fn(ServerInMemoryTransport) -> JoinHandle<()> + Send + Sync + 'static, 79 | { 80 | Self { 81 | tx: Arc::new(Mutex::new(None)), 82 | rx: Arc::new(Mutex::new(None)), 83 | server_handle: Arc::new(Mutex::new(None)), 84 | server_factory: Arc::new(server_factory), 85 | } 86 | } 87 | } 88 | 89 | #[async_trait] 90 | impl Transport for ClientInMemoryTransport { 91 | async fn receive(&self) -> Result> { 92 | let mut rx_guard = self.rx.lock().await; 93 | let rx = rx_guard 94 | .as_mut() 95 | .ok_or_else(|| anyhow::anyhow!("Transport not opened"))?; 96 | 97 | match rx.recv().await { 98 | Some(message) => { 99 | debug!("Client received: {:?}", message); 100 | Ok(Some(message)) 101 | } 102 | None => { 103 | debug!("Server channel closed"); 104 | Ok(None) 105 | } 106 | } 107 | } 108 | 109 | async fn send(&self, message: &Message) -> Result<()> { 110 | let tx_guard = self.tx.lock().await; 111 | let tx = tx_guard 112 | .as_ref() 113 | .ok_or_else(|| anyhow::anyhow!("Transport not opened"))?; 114 | 115 | debug!("Client sending: {:?}", message); 116 | tx.send(message.clone()) 117 | .await 118 | .map_err(|e| anyhow::anyhow!("Failed to send message: {}", e))?; 119 | Ok(()) 120 | } 121 | 122 | async fn open(&self) -> Result<()> { 123 | let (client_tx, server_rx) = mpsc::channel(100); 124 | let (server_tx, client_rx) = mpsc::channel(100); 125 | 126 | let server_transport = ServerInMemoryTransport { 127 | rx: Arc::new(Mutex::new(Some(server_rx))), 128 | tx: server_tx, 129 | }; 130 | 131 | let server_handle = (self.server_factory)(server_transport); 132 | 133 | *self.rx.lock().await = Some(client_rx); 134 | *self.tx.lock().await = Some(client_tx); 135 | *self.server_handle.lock().await = Some(server_handle); 136 | 137 | Ok(()) 138 | } 139 | 140 | async fn close(&self) -> Result<()> { 141 | *self.tx.lock().await = None; 142 | *self.rx.lock().await = None; 143 | 144 | if let Some(handle) = self.server_handle.lock().await.take() { 145 | handle.await?; 146 | } 147 | 148 | Ok(()) 149 | } 150 | } 151 | 152 | #[cfg(test)] 153 | mod tests { 154 | use super::*; 155 | use crate::transport::{JsonRpcMessage, JsonRpcRequest, JsonRpcVersion}; 156 | use std::time::Duration; 157 | 158 | async fn echo_server(transport: ServerInMemoryTransport) { 159 | while let Ok(Some(message)) = transport.receive().await { 160 | if transport.send(&message).await.is_err() { 161 | break; 162 | } 163 | } 164 | } 165 | 166 | #[tokio::test] 167 | async fn 
test_async_transport() -> Result<()> { 168 | let transport = ClientInMemoryTransport::new(|t| tokio::spawn(echo_server(t))); 169 | 170 | // Create a test message 171 | let test_message = JsonRpcMessage::Request(JsonRpcRequest { 172 | id: 1, 173 | method: "test".to_string(), 174 | params: Some(serde_json::json!({"hello": "world"})), 175 | jsonrpc: JsonRpcVersion::default(), 176 | }); 177 | 178 | // Open transport 179 | transport.open().await?; 180 | 181 | // Send message 182 | transport.send(&test_message).await?; 183 | 184 | // Receive echoed message 185 | let response = transport.receive().await?; 186 | 187 | // Verify the response matches 188 | assert_eq!(Some(test_message), response); 189 | 190 | // Clean up 191 | transport.close().await?; 192 | 193 | Ok(()) 194 | } 195 | 196 | #[tokio::test] 197 | async fn test_graceful_shutdown() -> Result<()> { 198 | let transport = ClientInMemoryTransport::new(|t| { 199 | tokio::spawn(async move { 200 | tokio::time::sleep(Duration::from_secs(5)).await; 201 | drop(t); 202 | }) 203 | }); 204 | 205 | transport.open().await?; 206 | 207 | // Spawn a task that will read from the transport 208 | let transport_clone = transport.clone(); 209 | let read_handle = tokio::spawn(async move { 210 | let result = transport_clone.receive().await; 211 | debug!("Receive returned: {:?}", result); 212 | result 213 | }); 214 | 215 | // Wait a bit to ensure the server is running 216 | tokio::time::sleep(Duration::from_millis(100)).await; 217 | 218 | // Initiate graceful shutdown 219 | let start = std::time::Instant::now(); 220 | transport.close().await?; 221 | let shutdown_duration = start.elapsed(); 222 | 223 | // Verify shutdown completed quickly 224 | assert!(shutdown_duration < Duration::from_secs(5)); 225 | 226 | // Verify receive operation was cancelled 227 | let read_result = read_handle.await?; 228 | assert!(read_result.is_ok()); 229 | assert_eq!(read_result.unwrap(), None); 230 | 231 | Ok(()) 232 | } 233 | 234 | #[tokio::test] 235 | async fn test_multiple_messages() -> Result<()> { 236 | let transport = ClientInMemoryTransport::new(|t| tokio::spawn(echo_server(t))); 237 | transport.open().await?; 238 | 239 | let messages: Vec<_> = (0..5) 240 | .map(|i| { 241 | JsonRpcMessage::Request(JsonRpcRequest { 242 | id: i, 243 | method: format!("test_{}", i), 244 | params: Some(serde_json::json!({"index": i})), 245 | jsonrpc: JsonRpcVersion::default(), 246 | }) 247 | }) 248 | .collect(); 249 | 250 | // Send all messages 251 | for msg in &messages { 252 | transport.send(msg).await?; 253 | } 254 | 255 | // Receive and verify all messages 256 | for expected in &messages { 257 | let received = transport.receive().await?; 258 | assert_eq!(Some(expected.clone()), received); 259 | } 260 | 261 | transport.close().await?; 262 | Ok(()) 263 | } 264 | } 265 | -------------------------------------------------------------------------------- /src/transport/mod.rs: -------------------------------------------------------------------------------- 1 | //! Transport layer for the MCP protocol 2 | //! handles the serialization and deserialization of message 3 | //! handles send and receive of messages 4 | //! 
defines transport layer types 5 | use anyhow::Result; 6 | use async_trait::async_trait; 7 | use serde::{Deserialize, Serialize}; 8 | 9 | mod stdio_transport; 10 | pub use stdio_transport::*; 11 | mod inmemory_transport; 12 | pub use inmemory_transport::*; 13 | mod sse_transport; 14 | pub use sse_transport::*; 15 | mod ws_transport; 16 | pub use ws_transport::*; 17 | mod http_transport; 18 | pub use http_transport::*; 19 | /// only JsonRpcMessage is supported for now 20 | /// https://spec.modelcontextprotocol.io/specification/basic/messages/ 21 | pub type Message = JsonRpcMessage; 22 | 23 | #[async_trait] 24 | pub trait Transport: Send + Sync + 'static { 25 | /// Send a message to the transport 26 | async fn send(&self, message: &Message) -> Result<()>; 27 | 28 | /// Receive a message from the transport 29 | /// this is blocking call 30 | async fn receive(&self) -> Result>; 31 | 32 | /// open the transport 33 | async fn open(&self) -> Result<()>; 34 | 35 | /// Close the transport 36 | async fn close(&self) -> Result<()>; 37 | } 38 | 39 | /// Request ID type 40 | pub type RequestId = u64; 41 | /// JSON RPC version type 42 | #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] 43 | #[serde(transparent)] 44 | pub struct JsonRpcVersion(String); 45 | 46 | impl Default for JsonRpcVersion { 47 | fn default() -> Self { 48 | JsonRpcVersion("2.0".to_owned()) 49 | } 50 | } 51 | 52 | impl JsonRpcVersion { 53 | pub fn as_str(&self) -> &str { 54 | &self.0 55 | } 56 | } 57 | 58 | #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] 59 | #[serde(deny_unknown_fields)] 60 | #[serde(untagged)] 61 | pub enum JsonRpcMessage { 62 | Response(JsonRpcResponse), 63 | Request(JsonRpcRequest), 64 | Notification(JsonRpcNotification), 65 | } 66 | 67 | // json rpc types 68 | #[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)] 69 | #[serde(deny_unknown_fields)] 70 | pub struct JsonRpcRequest { 71 | pub id: RequestId, 72 | pub method: String, 73 | #[serde(skip_serializing_if = "Option::is_none")] 74 | pub params: Option, 75 | pub jsonrpc: JsonRpcVersion, 76 | } 77 | 78 | #[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)] 79 | #[serde(rename_all = "camelCase")] 80 | #[serde(deny_unknown_fields)] 81 | #[serde(default)] 82 | pub struct JsonRpcNotification { 83 | pub method: String, 84 | #[serde(skip_serializing_if = "Option::is_none")] 85 | pub params: Option, 86 | pub jsonrpc: JsonRpcVersion, 87 | } 88 | 89 | #[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)] 90 | #[serde(deny_unknown_fields)] 91 | #[serde(rename_all = "camelCase")] 92 | #[serde(default)] 93 | pub struct JsonRpcResponse { 94 | /// The request ID this response corresponds to 95 | pub id: RequestId, 96 | /// The result of the request, if successful 97 | #[serde(skip_serializing_if = "Option::is_none")] 98 | pub result: Option, 99 | /// The error, if the request failed 100 | #[serde(skip_serializing_if = "Option::is_none")] 101 | pub error: Option, 102 | /// The JSON-RPC version 103 | pub jsonrpc: JsonRpcVersion, 104 | } 105 | 106 | #[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)] 107 | #[serde(rename_all = "camelCase")] 108 | #[serde(default)] 109 | pub struct JsonRpcError { 110 | /// Error code 111 | pub code: i32, 112 | /// Error message 113 | pub message: String, 114 | /// Optional additional error data 115 | #[serde(skip_serializing_if = "Option::is_none")] 116 | pub data: Option, 117 | } 118 | 119 | #[cfg(test)] 120 | mod tests { 121 | use super::*; 
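    // Additional sketch (not in the original crate): because JsonRpcMessage is an
    // untagged enum, serializing a notification yields a plain JSON-RPC object with
    // no enum wrapper around it.
    #[test]
    fn test_serialize_notification_sketch() {
        let msg = JsonRpcMessage::Notification(JsonRpcNotification {
            method: "notifications/initialized".to_string(),
            params: None,
            jsonrpc: JsonRpcVersion::default(),
        });
        assert_eq!(
            serde_json::to_string(&msg).unwrap(),
            r#"{"method":"notifications/initialized","jsonrpc":"2.0"}"#
        );
    }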
122 | #[test] 123 | fn test_deserialize_initialize_request() { 124 | let json = r#"{"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"claude-ai","version":"0.1.0"}},"jsonrpc":"2.0","id":0}"#; 125 | 126 | let message: Message = serde_json::from_str(json).unwrap(); 127 | match message { 128 | JsonRpcMessage::Request(req) => { 129 | assert_eq!(req.jsonrpc.as_str(), "2.0"); 130 | assert_eq!(req.id, 0); 131 | assert_eq!(req.method, "initialize"); 132 | 133 | // Verify params exist and are an object 134 | let params = req.params.expect("params should exist"); 135 | assert!(params.is_object()); 136 | 137 | let params_obj = params.as_object().unwrap(); 138 | assert_eq!(params_obj["protocolVersion"], "2024-11-05"); 139 | 140 | let client_info = params_obj["clientInfo"].as_object().unwrap(); 141 | assert_eq!(client_info["name"], "claude-ai"); 142 | assert_eq!(client_info["version"], "0.1.0"); 143 | } 144 | _ => panic!("Expected Request variant"), 145 | } 146 | } 147 | } 148 | -------------------------------------------------------------------------------- /src/transport/sse_transport.rs: -------------------------------------------------------------------------------- 1 | use crate::sse::middleware::{AuthConfig, Claims}; 2 | 3 | use super::{Message, Transport}; 4 | 5 | use actix_web::web::Bytes; 6 | use anyhow::Result; 7 | use async_trait::async_trait; 8 | use futures::StreamExt; 9 | use jsonwebtoken::{encode, EncodingKey, Header}; 10 | 11 | use std::collections::HashMap; 12 | use std::sync::Arc; 13 | use std::time::{SystemTime, UNIX_EPOCH}; 14 | use tokio::sync::{broadcast, mpsc, Mutex}; 15 | use tracing::debug; 16 | 17 | #[derive(Clone)] 18 | pub struct ServerSseTransport { 19 | // For receiving messages from HTTP POST requests 20 | message_rx: Arc>>, 21 | message_tx: mpsc::Sender, 22 | // For sending messages to SSE clients 23 | sse_tx: broadcast::Sender, 24 | } 25 | 26 | impl ServerSseTransport { 27 | pub fn new(sse_tx: broadcast::Sender) -> Self { 28 | let (message_tx, message_rx) = mpsc::channel(100); 29 | Self { 30 | message_rx: Arc::new(Mutex::new(message_rx)), 31 | message_tx, 32 | sse_tx, 33 | } 34 | } 35 | 36 | pub async fn send_message(&self, message: Message) -> Result<()> { 37 | self.message_tx.send(message).await?; 38 | Ok(()) 39 | } 40 | 41 | // Helper function to chunk message into SSE format 42 | fn format_sse_message(message: &Message) -> Result { 43 | const CHUNK_SIZE: usize = 16 * 1024; // 16KB chunks 44 | let json = serde_json::to_string(message)?; 45 | let mut result = String::new(); 46 | 47 | // Add event type 48 | result.push_str("event: message\n"); 49 | 50 | // If small enough, send as single chunk 51 | if json.len() <= CHUNK_SIZE { 52 | result.push_str(&format!("data: {}\n\n", json)); 53 | return Ok(result); 54 | } 55 | 56 | // For larger messages, split at proper boundaries (commas or spaces) 57 | let mut start = 0; 58 | while start < json.len() { 59 | let mut end = (start + CHUNK_SIZE).min(json.len()); 60 | 61 | // If we're not at the end, find a good split point 62 | if end < json.len() { 63 | // Look back for a comma or space to split at 64 | while end > start && !json[end..].starts_with([',', ' ']) { 65 | end -= 1; 66 | } 67 | // If we couldn't find a good split point, just use the max size 68 | if end == start { 69 | end = (start + CHUNK_SIZE).min(json.len()); 70 | } 71 | } 72 | 73 | result.push_str(&format!("data: {}\n", &json[start..end])); 74 | start = end; 75 | } 76 | 77 | result.push('\n'); 78 | Ok(result) 79 
| } 80 | } 81 | 82 | #[async_trait] 83 | impl Transport for ServerSseTransport { 84 | async fn receive(&self) -> Result> { 85 | let mut rx = self.message_rx.lock().await; 86 | match rx.recv().await { 87 | Some(message) => { 88 | debug!("Received message from POST request: {:?}", message); 89 | Ok(Some(message)) 90 | } 91 | None => Ok(None), 92 | } 93 | } 94 | 95 | async fn send(&self, message: &Message) -> Result<()> { 96 | let formatted = Self::format_sse_message(message)?; 97 | // Show first and last 500 characters for debugging 98 | if formatted.len() > 1000 { 99 | let first = &formatted[..500]; 100 | let last = &formatted[formatted.len() - 500..]; 101 | debug!("Sending chunked SSE message: {}...{}", first, last); 102 | } else { 103 | debug!("Sending chunked SSE message: {}", formatted); 104 | } 105 | 106 | self.sse_tx.send(message.clone())?; 107 | Ok(()) 108 | } 109 | 110 | async fn open(&self) -> Result<()> { 111 | Ok(()) 112 | } 113 | 114 | async fn close(&self) -> Result<()> { 115 | Ok(()) 116 | } 117 | } 118 | 119 | #[derive(Debug)] 120 | pub enum SseEvent { 121 | Message(Message), 122 | SessionId(String), 123 | } 124 | 125 | /// Client-side SSE transport that sends messages via HTTP POST 126 | /// and receives responses via SSE 127 | #[derive(Clone)] 128 | pub struct ClientSseTransport { 129 | tx: mpsc::Sender, 130 | rx: Arc>>, 131 | server_url: String, 132 | client: reqwest::Client, 133 | auth_config: Option, 134 | session_id: Arc>>, 135 | headers: HashMap, 136 | buffer: Arc>, // Add buffer for partial messages 137 | } 138 | 139 | impl ClientSseTransport { 140 | pub fn builder(url: String) -> ClientSseTransportBuilder { 141 | ClientSseTransportBuilder::new(url) 142 | } 143 | 144 | fn generate_token(&self) -> Result { 145 | let auth_config = self 146 | .auth_config 147 | .as_ref() 148 | .ok_or_else(|| anyhow::anyhow!("Auth config not set"))?; 149 | 150 | let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs() as usize; 151 | let claims = Claims { 152 | iat: now, 153 | exp: now + 3600, // Token expires in 1 hour 154 | }; 155 | 156 | encode( 157 | &Header::default(), 158 | &claims, 159 | &EncodingKey::from_secret(auth_config.jwt_secret.as_bytes()), 160 | ) 161 | .map_err(Into::into) 162 | } 163 | 164 | async fn add_auth_header( 165 | &self, 166 | request: reqwest::RequestBuilder, 167 | ) -> Result { 168 | if self.auth_config.is_some() { 169 | let token = self.generate_token()?; 170 | Ok(request.header("Authorization", format!("Bearer {}", token))) 171 | } else { 172 | Ok(request) 173 | } 174 | } 175 | 176 | fn parse_sse_message(event: &str) -> Option { 177 | let mut event_type = None; 178 | let mut current_data = String::new(); 179 | 180 | // Process each line 181 | for line in event.lines() { 182 | let line = line.trim(); 183 | if line.is_empty() { 184 | continue; 185 | } 186 | 187 | if line.starts_with("event:") { 188 | event_type = Some(line.trim_start_matches("event:").trim().to_string()); 189 | } else if line.starts_with("data:") { 190 | // Strip the "data:" prefix and any leading/trailing whitespace 191 | let data = line["data:".len()..].trim(); 192 | // For chunked messages, we just concatenate the data 193 | current_data.push_str(data); 194 | } 195 | } 196 | 197 | // If we have data, try to parse it 198 | if !current_data.is_empty() { 199 | let result = match (event_type.as_ref(), Some(¤t_data)) { 200 | (Some(endpoint), Some(url)) if endpoint == "endpoint" => Some(SseEvent::SessionId( 201 | url.split("sessionId=") 202 | .nth(1) 203 | .unwrap_or_default() 204 | 
.to_string(), 205 | )), 206 | (None, Some(data)) | (Some(_), Some(data)) => { 207 | match serde_json::from_str::<Message>(data) { 208 | Ok(msg) => Some(SseEvent::Message(msg)), 209 | Err(e) => { 210 | debug!( 211 | "Failed to parse SSE message: {}. Content preview: {}", 212 | e, 213 | if data.len() > 100 { 214 | format!("{}... (truncated)", &data[..100]) 215 | } else { 216 | data.to_string() 217 | } 218 | ); 219 | None 220 | } 221 | } 222 | } 223 | _ => None, 224 | }; 225 | 226 | if result.is_none() { 227 | debug!( 228 | "Unrecognized SSE event format - event_type: {:?}, data length: {}", 229 | event_type, 230 | current_data.len() 231 | ); 232 | } 233 | 234 | result 235 | } else { 236 | None 237 | } 238 | } 239 | 240 | async fn handle_sse_chunk( 241 | chunk: Bytes, 242 | tx: &mpsc::Sender<Message>, 243 | session_id: &Arc<Mutex<Option<String>>>, 244 | buffer: &Arc<Mutex<String>>, 245 | ) -> Result<()> { 246 | let chunk_str = String::from_utf8(chunk.to_vec())?; 247 | let mut buffer = buffer.lock().await; 248 | 249 | // Append new chunk to buffer 250 | buffer.push_str(&chunk_str); 251 | 252 | // Process complete messages 253 | while let Some(pos) = buffer.find("\n\n") { 254 | let complete_event = buffer[..pos + 2].to_string(); 255 | buffer.replace_range(..pos + 2, ""); 256 | 257 | if let Some(sse_event) = Self::parse_sse_message(&complete_event) { 258 | match sse_event { 259 | SseEvent::Message(message) => { 260 | debug!("Received SSE message: {:?}", message); 261 | tx.send(message).await?; 262 | } 263 | SseEvent::SessionId(id) => { 264 | debug!("Received session ID: {}", id); 265 | *session_id.lock().await = Some(id); 266 | } 267 | } 268 | } 269 | } 270 | 271 | Ok(()) 272 | } 273 | } 274 | 275 | #[derive(Default)] 276 | pub struct ClientSseTransportBuilder { 277 | server_url: String, 278 | auth_config: Option<AuthConfig>, 279 | headers: HashMap<String, String>, 280 | } 281 | 282 | impl ClientSseTransportBuilder { 283 | pub fn new(server_url: String) -> Self { 284 | Self { 285 | server_url, 286 | auth_config: None, 287 | headers: HashMap::new(), 288 | } 289 | } 290 | 291 | pub fn with_auth(mut self, jwt_secret: String) -> Self { 292 | self.auth_config = Some(AuthConfig { jwt_secret }); 293 | self 294 | } 295 | 296 | pub fn with_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self { 297 | self.headers.insert(key.into(), value.into()); 298 | self 299 | } 300 | 301 | pub fn build(self) -> ClientSseTransport { 302 | let (tx, rx) = mpsc::channel(100); 303 | ClientSseTransport { 304 | tx, 305 | rx: Arc::new(Mutex::new(rx)), 306 | server_url: self.server_url, 307 | client: reqwest::Client::new(), 308 | auth_config: self.auth_config, 309 | session_id: Arc::new(Mutex::new(None)), 310 | headers: self.headers, 311 | buffer: Arc::new(Mutex::new(String::new())), // Initialize buffer 312 | } 313 | } 314 | } 315 | 316 | #[async_trait] 317 | impl Transport for ClientSseTransport { 318 | async fn receive(&self) -> Result<Option<Message>> { 319 | let mut rx = self.rx.lock().await; 320 | match rx.recv().await { 321 | Some(message) => { 322 | debug!("Received SSE message: {:?}", message); 323 | Ok(Some(message)) 324 | } 325 | None => Ok(None), 326 | } 327 | } 328 | 329 | async fn send(&self, message: &Message) -> Result<()> { 330 | let session_id = self 331 | .session_id 332 | .lock() 333 | .await 334 | .as_ref() 335 | .ok_or_else(|| anyhow::anyhow!("No session ID available"))?
336 | .clone(); 337 | 338 | let request = self 339 | .client 340 | .post(format!( 341 | "{}/message?sessionId={}", 342 | self.server_url, session_id 343 | )) 344 | .json(message); 345 | 346 | let request = self.add_auth_header(request).await?; 347 | let response = request.send().await?; 348 | 349 | if !response.status().is_success() { 350 | let status = response.status(); 351 | let text = response.text().await?; 352 | return Err(anyhow::anyhow!( 353 | "Failed to send message, status: {status}, body: {text}", 354 | )); 355 | } 356 | 357 | Ok(()) 358 | } 359 | 360 | async fn open(&self) -> Result<()> { 361 | let tx = self.tx.clone(); 362 | let server_url = self.server_url.clone(); 363 | let auth_config = self.auth_config.clone(); 364 | let session_id = self.session_id.clone(); 365 | let headers = self.headers.clone(); 366 | let buffer = self.buffer.clone(); 367 | 368 | let handle = tokio::spawn(async move { 369 | let mut request = reqwest::Client::new().get(format!("{}/sse", server_url)); 370 | 371 | // Add custom headers 372 | for (key, value) in &headers { 373 | request = request.header(key, value); 374 | } 375 | 376 | // Add auth header if configured 377 | if let Some(auth_config) = auth_config { 378 | let claims = Claims { 379 | iat: SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs() as usize, 380 | exp: SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs() as usize + 3600, 381 | }; 382 | 383 | let token = encode( 384 | &Header::default(), 385 | &claims, 386 | &EncodingKey::from_secret(auth_config.jwt_secret.as_bytes()), 387 | )?; 388 | 389 | request = request.header("Authorization", format!("Bearer {}", token)); 390 | } 391 | 392 | let mut event_stream = request.send().await?.bytes_stream(); 393 | 394 | // Handle first message to get session ID 395 | if let Some(first_chunk) = event_stream.next().await { 396 | match first_chunk { 397 | Ok(bytes) => Self::handle_sse_chunk(bytes, &tx, &session_id, &buffer).await?, 398 | Err(e) => { 399 | return Err(anyhow::anyhow!("Failed to get initial SSE message: {}", e)) 400 | } 401 | } 402 | } else { 403 | return Err(anyhow::anyhow!( 404 | "SSE connection closed before receiving initial message" 405 | )); 406 | } 407 | 408 | // Handle remaining messages 409 | while let Some(chunk) = event_stream.next().await { 410 | if let Ok(bytes) = chunk { 411 | if let Err(e) = Self::handle_sse_chunk(bytes, &tx, &session_id, &buffer).await { 412 | debug!("Error handling SSE message: {:?}", e); 413 | } 414 | } 415 | } 416 | 417 | Ok::<_, anyhow::Error>(()) 418 | }); 419 | 420 | // Wait for the session ID to be set 421 | let mut attempts = 0; 422 | while attempts < 10 { 423 | if self.session_id.lock().await.is_some() { 424 | return Ok(()); 425 | } 426 | tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; 427 | attempts += 1; 428 | } 429 | 430 | handle.abort(); 431 | Err(anyhow::anyhow!("Timeout waiting for initial SSE message")) 432 | } 433 | 434 | async fn close(&self) -> Result<()> { 435 | Ok(()) 436 | } 437 | } 438 | 439 | #[cfg(test)] 440 | mod tests { 441 | use super::*; 442 | 443 | #[test] 444 | fn test_parse_large_sse_message() { 445 | // This is the problematic message format we're seeing 446 | let large_json = r#"{"id":0,"result":{"tools":[{"description":"A powerful web search tool that provides comprehensive, real-time results using Tavily's AI search engine. Returns relevant web content with customizable parameters for result count, content type, and domain filtering. 
Ideal for gathering current information, news, and detailed web content analysis.","inputSchema":{"properties":{"days":{"default":3,"description":"The number of days back from the current date to include in the search results. This specifies the time frame of data to be retrieved. Please note that this feature is only available when using the 'news' search topic","type":"number"}}},"name":"tavily-search"}]},"jsonrpc":"2.0"}"#; 447 | 448 | // Format it as an SSE message with multiple data chunks 449 | let mut sse_message = String::new(); 450 | sse_message.push_str("event: message\n"); 451 | 452 | // Split the JSON into smaller chunks (simulating what the server does) 453 | let chunk_size = 100; 454 | for chunk in large_json.as_bytes().chunks(chunk_size) { 455 | if let Ok(chunk_str) = std::str::from_utf8(chunk) { 456 | sse_message.push_str(&format!("data: {}\n", chunk_str)); 457 | } 458 | } 459 | sse_message.push('\n'); 460 | 461 | // Try to parse it 462 | let result = ClientSseTransport::parse_sse_message(&sse_message); 463 | assert!(result.is_some(), "Failed to parse SSE message"); 464 | 465 | if let Some(SseEvent::Message(msg)) = result { 466 | // Verify the parsed message matches the original 467 | let parsed_json = serde_json::to_string(&msg).unwrap(); 468 | assert_eq!(parsed_json, large_json); 469 | } else { 470 | panic!("Expected Message event"); 471 | } 472 | } 473 | 474 | #[test] 475 | fn test_parse_real_sse_message() { 476 | // The actual message that's failing, but properly formatted 477 | let sse_message = concat!( 478 | "data: {\"id\":0,\"result\":{\"tools\":[{\"description\":\"A powerful web search tool that provides comprehensive, real-time results using Tavily's AI search engine. Returns relevant web content with customizable parameters for result count, content type, and domain filtering. Ideal for gathering current information, news, and detailed web content analysis.\",\"inputSchema\":{\"properties\":{\"days\":{\"default\":3,\"description\":\"The number of days back from the current date to include in the search results. This specifies the time frame of data to be retrieved. Please note that this feature is only available when using the 'news' search topic\",\"type\":\"number\"},\"exclude_domains\":{\"default\":[],\"description\":\"List of domains to specifically exclude, if the user asks to exclude a domain set this to the domain of the site\",\"items\":{\"type\":\"string\"},\"type\":\"array\"},\"include_domains\":{\"default\":[],\"description\":\"A list of domains to specifically include in the search results, if the user asks to search on specific sites set this to the domain of the site\",\"items\":{\"type\":\"string\"},\"type\":\"array\"},\"include_image_descriptions\":{\"default\":false,\"description\":\"Include a list of query-related images and their descriptions in the response\",\"type\":\"boolean\"},\"include_images\":{\"default\":false,\"description\":\"Include a list of query-related images in the response\",\"type\":\"boolean\"},\"include_raw_content\":{\"default\":false,\"description\":\"Include the cleaned and parsed HTML content of each search result\",\"type\":\"boolean\"},\"max_results\":{\"default\":10,\"description\":\"The maximum number of search results to return\",\"maximum\":20,\"minimum\":5,\"type\":\"number\"},\"query\":{\"description\":\"Search query\",\"type\":\"string\"},\"search_depth\":{\"default\":\"basic\",\"description\":\"The depth of the search. 
It can be 'basic' or 'advanced'\",\"enum\":[\"basic\",\"advanced\"],\"type\":\"string\"},\"time_range\":{\"description\":\"The time range back from the current date to include in the search results. This feature is available for both 'general' and 'news' search topics\",\"enum\":[\"day\",\"week\",\"month\",\"year\",\"d\",\"w\",\"m\",\"y\"],\"type\":\"string\"},\"topic\":{\"default\":\"general\",\"description\":\"The category of the search. This will determine which of our agents will be used for the search\",\"enum\":[\"general\",\"news\"],\"type\":\"string\"}},\"required\":[\"query\"],\"type\":\"object\"},\"name\":\"tavily-search\"},{\"description\":\"A powerful web content extraction tool that retrieves and processes raw content from specified URLs, ideal for data collection, content analysis, and research tasks.\",\"inputSchema\":{\"properties\":{\"extract_depth\":{\"default\":\"basic\",\"description\":\"Depth of extraction - 'basic' or 'advanced', if usrls are linkedin use 'advanced' or if explicitly told to use advanced\",\"enum\":[\"basic\",\"advanced\"],\"type\":\"string\"},\"include_images\":{\"default\":false,\"description\":\"Include a list of images extracted from the urls in the response\",\"type\":\"boolean\"},\"urls\":{\"description\":\"List of URLs to extract content from\",\"items\":{\"type\":\"string\"},\"type\":\"array\"}},\"required\":[\"urls\"],\"type\":\"object\"},\"name\":\"tavily-extract\"},{\"description\":\"Read the complete contents of a file from the file system. Handles various text encodings and provides detailed error messages if the file cannot be read. Use this tool when you need to examine the contents of a single file. Only works within allowed directories.\",\"inputSchema\":{\"$schema\":\"http://json-schema.org/draft-07/schema#\",\"additionalProperties\":false,\"properties\":{\"path\":{\"type\":\"string\"}},\"required\":[\"path\"],\"type\":\"object\"},\"name\":\"read_file\"},{\"description\":\"Read the contents of multiple files simultaneously. This is more efficient than reading files one by one when you need to analyze or compare multiple files. Each file's content is returned with its path as a reference. Failed reads for individual files won't stop the entire operation. Only works within allowed directories.\",\"inputSchema\":{\"$schema\":\"http://json-schema.org/draft-07/schema#\",\"additionalProperties\":false,\"properties\":{\"paths\":{\"items\":{\"type\":\"string\"},\"type\":\"array\"}},\"required\":[\"paths\"],\"type\":\"object\"},\"name\":\"read_multiple_files\"},{\"description\":\"Create a new file or completely overwrite an existing file with new content. Use with caution as it will overwrite existing files without warning. Handles text content with proper encoding. Only works within allowed directories.\",\"inputSchema\":{\"$schema\":\"http://json-schema.org/draft-07/schema#\",\"additionalProperties\":false,\"properties\":{\"content\":{\"type\":\"string\"},\"path\":{\"type\":\"string\"}},\"required\":[\"path\",\"content\"],\"type\":\"object\"},\"name\":\"write_file\"},{\"description\":\"Make line-based edits to a text file. Each edit replaces exact line sequences with new content. Returns a git-style diff showing the changes made. 
Only works within allowed directories.\",\"inputSchema\":{\"$schema\":\"http://json-schema.org/draft-07/schema#\",\"additionalProperties\":false,\"properties\":{\"dryRun\":{\"default\":false,\"description\":\"Preview changes using git-style diff format\",\"type\":\"boolean\"},\"edits\":{\"items\":{\"additionalProperties\":false,\"properties\":{\"newText\":{\"description\":\"Text to replace with\",\"type\":\"string\"},\"oldText\":{\"description\":\"Text to search for - must match exactly\",\"type\":\"string\"}},\"required\":[\"oldText\",\"newText\"],\"type\":\"object\"},\"type\":\"array\"},\"path\":{\"type\":\"string\"}},\"required\":[\"path\",\"edits\"],\"type\":\"object\"},\"name\":\"edit_file\"},{\"description\":\"Create a new directory or ensure a directory exists. Can create multiple nested directories in one operation. If the directory already exists, this operation will succeed silently. Perfect for setting up directory structures for projects or ensuring required paths exist. Only works within allowed directories.\",\"inputSchema\":{\"$schema\":\"http://json-schema.org/draft-07/schema#\",\"additionalProperties\":false,\"properties\":{\"path\":{\"type\":\"string\"}},\"required\":[\"path\"],\"type\":\"object\"},\"name\":\"create_directory\"},{\"description\":\"Get a detailed listing of all files and directories in a specified path. Results clearly distinguish between files and directories with [FILE] and [DIR] prefixes. This tool is essential for understanding directory structure and finding specific files within a directory. Only works within allowed directories.\",\"inputSchema\":{\"$schema\":\"http://json-schema.org/draft-07/schema#\",\"additionalProperties\":false,\"properties\":{\"path\":{\"type\":\"string\"}},\"required\":[\"path\"],\"type\":\"object\"},\"name\":\"list_directory\"},{\"description\":\"Get a recursive tree view of files and directories as a JSON structure. Each entry includes 'name', 'type' (file/directory), and 'children' for directories. Files have no children array, while directories always have a children array (which may be empty). The output is formatted with 2-space indentation for readability. Only works within allowed directories.\",\"inputSchema\":{\"$schema\":\"http://json-schema.org/draft-07/schema#\",\"additionalProperties\":false,\"properties\":{\"path\":{\"type\":\"string\"}},\"required\":[\"path\"],\"type\":\"object\"},\"name\":\"directory_tree\"},{\"description\":\"Move or rename files and directories. Can move files between directories and rename them in a single operation. If the destination exists, the operation will fail. Works across different directories and can be used for simple renaming within the same directory. Both source and destination must be within allowed directories.\",\"inputSchema\":{\"$schema\":\"http://json-schema.org/draft-07/schema#\",\"additionalProperties\":false,\"properties\":{\"destination\":{\"type\":\"string\"},\"source\":{\"type\":\"string\"}},\"required\":[\"source\",\"destination\"],\"type\":\"object\"},\"name\":\"move_file\"},{\"description\":\"Recursively search for files and directories matching a pattern. Searches through all subdirectories from the starting path. The search is case-insensitive and matches partial names. Returns full paths to all matching items. Great for finding files when you don't know their exact location. 
Only searches within allowed directories.\",\"inputSchema\":{\"$schema\":\"http://json-schema.org/draft-07/schema#\",\"additionalProperties\":false,\"properties\":{\"excludePatterns\":{\"default\":[],\"items\":{\"type\":\"string\"},\"type\":\"array\"},\"path\":{\"type\":\"string\"},\"pattern\":{\"type\":\"string\"}},\"requ", 479 | "data: ired\":[\"path\",\"pattern\"],\"type\":\"object\"},\"name\":\"search_files\"},{\"description\":\"Retrieve detailed metadata about a file or directory. Returns comprehensive information including size, creation time, last modified time, permissions, and type. This tool is perfect for understanding file characteristics without reading the actual content. Only works within allowed directories.\",\"inputSchema\":{\"$schema\":\"http: //json-schema.org/draft-07/schema#\",\"additionalProperties\":false,\"properties\":{\"path\":{\"type\":\"string\"}},\"required\":[\"path\"],\"type\":\"object\"},\"name\":\"get_file_info\"},{\"description\":\"Returns the list of directories that this server is allowed to access. Use this to understand which directories are available before trying to access files.\",\"inputSchema\":{\"properties\":{},\"required\":[],\"type\":\"object\"},\"name\":\"list_allowed_directories\"}]},\"jsonrpc\":\"2.0\"}" 480 | ); 481 | 482 | let result = ClientSseTransport::parse_sse_message(sse_message); 483 | assert!(result.is_some(), "Failed to parse real SSE message"); 484 | 485 | // Verify we can parse the message into valid JSON 486 | if let Some(SseEvent::Message(msg)) = result { 487 | let json = serde_json::to_string(&msg).unwrap(); 488 | assert!(json.contains("\"description\":\"A powerful web search tool")); 489 | } else { 490 | panic!("Expected Message event"); 491 | } 492 | } 493 | } 494 | -------------------------------------------------------------------------------- /src/transport/stdio_transport.rs: -------------------------------------------------------------------------------- 1 | use super::{Message, Transport}; 2 | use anyhow::Result; 3 | use async_trait::async_trait; 4 | use std::collections::HashMap; 5 | use std::io::{self, BufRead, Write}; 6 | use std::process::Stdio; 7 | use std::sync::Arc; 8 | use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter}; 9 | use tokio::process::Child; 10 | use tokio::sync::Mutex; 11 | use tracing::debug; 12 | 13 | /// Stdio transport for server with json serialization 14 | /// TODO: support for other binary serialzation formats 15 | #[derive(Default, Clone)] 16 | pub struct ServerStdioTransport; 17 | #[async_trait] 18 | impl Transport for ServerStdioTransport { 19 | async fn receive(&self) -> Result> { 20 | let stdin = io::stdin(); 21 | let mut reader = stdin.lock(); 22 | let mut line = String::new(); 23 | reader.read_line(&mut line)?; 24 | if line.is_empty() { 25 | return Ok(None); 26 | } 27 | 28 | debug!("Received: {line}"); 29 | let message: Message = serde_json::from_str(&line)?; 30 | Ok(Some(message)) 31 | } 32 | 33 | async fn send(&self, message: &Message) -> Result<()> { 34 | let stdout = io::stdout(); 35 | let mut writer = stdout.lock(); 36 | let serialized = serde_json::to_string(message)?; 37 | debug!("Sending: {serialized}"); 38 | writer.write_all(serialized.as_bytes())?; 39 | writer.write_all(b"\n")?; 40 | writer.flush()?; 41 | Ok(()) 42 | } 43 | 44 | async fn open(&self) -> Result<()> { 45 | Ok(()) 46 | } 47 | 48 | async fn close(&self) -> Result<()> { 49 | Ok(()) 50 | } 51 | } 52 | 53 | /// ClientStdioTransport launches a child process and communicates with it via stdio 54 | 
#[derive(Clone)] 55 | pub struct ClientStdioTransport { 56 | stdin: Arc<Mutex<Option<BufWriter<tokio::process::ChildStdin>>>>, 57 | stdout: Arc<Mutex<Option<BufReader<tokio::process::ChildStdout>>>>, 58 | child: Arc<Mutex<Option<Child>>>, 59 | program: String, 60 | args: Vec<String>, 61 | env: Option<HashMap<String, String>>, 62 | } 63 | 64 | impl ClientStdioTransport { 65 | pub fn new(program: &str, args: &[&str], env: Option<HashMap<String, String>>) -> Result<Self> { 66 | Ok(ClientStdioTransport { 67 | stdin: Arc::new(Mutex::new(None)), 68 | stdout: Arc::new(Mutex::new(None)), 69 | child: Arc::new(Mutex::new(None)), 70 | program: program.to_string(), 71 | args: args.iter().map(|&s| s.to_string()).collect(), 72 | env, 73 | }) 74 | } 75 | } 76 | #[async_trait] 77 | impl Transport for ClientStdioTransport { 78 | async fn receive(&self) -> Result<Option<Message>> { 79 | debug!("ClientStdioTransport: Starting to receive message"); 80 | let mut stdout = self.stdout.lock().await; 81 | let stdout = stdout 82 | .as_mut() 83 | .ok_or_else(|| anyhow::anyhow!("Transport not opened"))?; 84 | 85 | let mut line = String::new(); 86 | debug!("ClientStdioTransport: Reading line from process"); 87 | let bytes_read = stdout.read_line(&mut line).await?; 88 | debug!("ClientStdioTransport: Read {} bytes", bytes_read); 89 | 90 | if bytes_read == 0 { 91 | debug!("ClientStdioTransport: Received EOF from process"); 92 | return Ok(None); 93 | } 94 | 95 | let row = if line.len() > 1000 { 96 | let start = &line[..100]; 97 | let end = &line[line.len() - 100..]; 98 | format!("{}...{}", start, end) 99 | } else { 100 | line.clone() 101 | }; 102 | 103 | debug!("ClientStdioTransport: Received from process: {}", row); 104 | let message: Message = serde_json::from_str(&line).map_err(|e| { 105 | tracing::error!("Failed to parse message: {}", e); 106 | e 107 | })?; 108 | debug!("ClientStdioTransport: Successfully parsed message"); 109 | Ok(Some(message)) 110 | } 111 | 112 | async fn send(&self, message: &Message) -> Result<()> { 113 | debug!("ClientStdioTransport: Starting to send message"); 114 | let mut stdin = self.stdin.lock().await; 115 | let stdin = stdin 116 | .as_mut() 117 | .ok_or_else(|| anyhow::anyhow!("Transport not opened"))?; 118 | 119 | let serialized = serde_json::to_string(message)?; 120 | debug!("ClientStdioTransport: Sending to process: {serialized}"); 121 | stdin.write_all(serialized.as_bytes()).await?; 122 | stdin.write_all(b"\n").await?; 123 | stdin.flush().await?; 124 | debug!("ClientStdioTransport: Successfully sent and flushed message"); 125 | Ok(()) 126 | } 127 | 128 | async fn open(&self) -> Result<()> { 129 | debug!("ClientStdioTransport: Opening transport"); 130 | let mut command = tokio::process::Command::new(&self.program); 131 | 132 | // Set up the command with args and stdio 133 | command 134 | .args(&self.args) 135 | .stdin(Stdio::piped()) 136 | .stdout(Stdio::piped()); 137 | 138 | // Add environment variables 139 | if let Some(env) = &self.env { 140 | for (key, value) in env { 141 | command.env(key, value); 142 | } 143 | } 144 | 145 | let mut child = command.spawn()?; 146 | 147 | debug!("ClientStdioTransport: Child process spawned"); 148 | let stdin = child 149 | .stdin 150 | .take() 151 | .ok_or_else(|| anyhow::anyhow!("Child process stdin not available"))?; 152 | let stdout = child 153 | .stdout 154 | .take() 155 | .ok_or_else(|| anyhow::anyhow!("Child process stdout not available"))?; 156 | 157 | *self.stdin.lock().await = Some(BufWriter::new(stdin)); 158 | *self.stdout.lock().await = Some(BufReader::new(stdout)); 159 | *self.child.lock().await = Some(child); 160 | 161 | Ok(()) 162 | } 163 | 164 | async fn close(&self) -> Result<()> { 165 | const GRACEFUL_TIMEOUT_MS: u64 = 1000; 166 |
const SIGTERM_TIMEOUT_MS: u64 = 500; 167 | debug!("Starting graceful shutdown"); 168 | { 169 | let mut stdin_guard = self.stdin.lock().await; 170 | if let Some(stdin) = stdin_guard.as_mut() { 171 | debug!("Flushing stdin"); 172 | stdin.flush().await?; 173 | } 174 | *stdin_guard = None; 175 | } 176 | 177 | let mut child_guard = self.child.lock().await; 178 | let Some(child) = child_guard.as_mut() else { 179 | debug!("No child process to close"); 180 | return Ok(()); 181 | }; 182 | 183 | debug!("Attempting graceful shutdown"); 184 | match child.try_wait()? { 185 | Some(status) => { 186 | debug!("Process already exited with status: {}", status); 187 | *child_guard = None; 188 | return Ok(()); 189 | } 190 | None => { 191 | debug!("Waiting for process to exit gracefully"); 192 | tokio::time::sleep(tokio::time::Duration::from_millis(GRACEFUL_TIMEOUT_MS)).await; 193 | } 194 | } 195 | 196 | if child.try_wait()?.is_none() { 197 | debug!("Process still running, sending SIGTERM"); 198 | child.kill().await?; 199 | tokio::time::sleep(tokio::time::Duration::from_millis(SIGTERM_TIMEOUT_MS)).await; 200 | } 201 | 202 | if child.try_wait()?.is_none() { 203 | debug!("Process not responding to SIGTERM, forcing kill"); 204 | child.kill().await?; 205 | } 206 | 207 | match child.wait().await { 208 | Ok(status) => debug!("Process exited with status: {}", status), 209 | Err(e) => debug!("Error waiting for process exit: {}", e), 210 | } 211 | 212 | *child_guard = None; 213 | debug!("Shutdown complete"); 214 | Ok(()) 215 | } 216 | } 217 | 218 | #[cfg(test)] 219 | mod tests { 220 | use crate::transport::{JsonRpcMessage, JsonRpcRequest, JsonRpcVersion}; 221 | 222 | use super::*; 223 | use std::time::Duration; 224 | #[tokio::test] 225 | #[cfg(unix)] 226 | async fn test_stdio_transport() -> Result<()> { 227 | // Create transport connected to cat command which will stay alive 228 | let transport = ClientStdioTransport::new("cat", &[], None)?; 229 | 230 | // Create a test message 231 | let test_message = JsonRpcMessage::Request(JsonRpcRequest { 232 | id: 1, 233 | method: "test".to_string(), 234 | params: Some(serde_json::json!({"hello": "world"})), 235 | jsonrpc: JsonRpcVersion::default(), 236 | }); 237 | 238 | // Open transport 239 | transport.open().await?; 240 | 241 | // Send message 242 | transport.send(&test_message).await?; 243 | 244 | // Receive echoed message 245 | let response = transport.receive().await?; 246 | 247 | // Verify the response matches 248 | assert_eq!(Some(test_message), response); 249 | 250 | // Clean up 251 | transport.close().await?; 252 | 253 | Ok(()) 254 | } 255 | 256 | #[tokio::test] 257 | #[cfg(unix)] 258 | async fn test_graceful_shutdown() -> Result<()> { 259 | // Create transport with a sleep command that runs for 5 seconds 260 | let transport = ClientStdioTransport::new("sleep", &["5"], None)?; 261 | transport.open().await?; 262 | 263 | // Spawn a task that will read from the transport 264 | let transport_clone = transport.clone(); 265 | let read_handle = tokio::spawn(async move { 266 | let result = transport_clone.receive().await; 267 | debug!("Receive returned: {:?}", result); 268 | result 269 | }); 270 | 271 | // Wait a bit to ensure the process is running 272 | tokio::time::sleep(Duration::from_millis(100)).await; 273 | 274 | // Initiate graceful shutdown 275 | let start = std::time::Instant::now(); 276 | transport.close().await?; 277 | let shutdown_duration = start.elapsed(); 278 | 279 | // Verify that: 280 | // 1. The read operation was cancelled (returned None) 281 | // 2. 
The shutdown completed in less than 5 seconds (didn't wait for sleep) 282 | // 3. The process was properly terminated 283 | let read_result = read_handle.await?; 284 | assert!(read_result.is_ok()); 285 | assert_eq!(read_result.unwrap(), None); 286 | assert!(shutdown_duration < Duration::from_secs(5)); 287 | 288 | // Verify process is no longer running 289 | let child_guard = transport.child.lock().await; 290 | assert!(child_guard.is_none()); 291 | 292 | Ok(()) 293 | } 294 | 295 | #[tokio::test] 296 | #[cfg(unix)] 297 | async fn test_shutdown_with_pending_io() -> Result<()> { 298 | // Use 'read' command which will wait for input without echoing 299 | let transport = ClientStdioTransport::new("read", &[], None)?; 300 | transport.open().await?; 301 | 302 | // Start a receive operation that will be pending 303 | let transport_clone = transport.clone(); 304 | let read_handle = tokio::spawn(async move { transport_clone.receive().await }); 305 | 306 | // Give some time for read operation to start 307 | tokio::time::sleep(Duration::from_millis(100)).await; 308 | 309 | // Send a message (will be pending since 'read' won't echo) 310 | let test_message = JsonRpcMessage::Request(JsonRpcRequest { 311 | id: 1, 312 | method: "test".to_string(), 313 | params: Some(serde_json::json!({"hello": "world"})), 314 | jsonrpc: JsonRpcVersion::default(), 315 | }); 316 | transport.send(&test_message).await?; 317 | 318 | // Initiate shutdown 319 | transport.close().await?; 320 | 321 | // Verify the read operation was cancelled cleanly 322 | let read_result = read_handle.await?; 323 | assert!(read_result.is_ok()); 324 | assert_eq!(read_result.unwrap(), None); 325 | 326 | Ok(()) 327 | } 328 | } 329 | -------------------------------------------------------------------------------- /src/transport/ws_transport.rs: -------------------------------------------------------------------------------- 1 | use super::{Message, Transport}; 2 | use actix_ws::{Message as WsMessage, Session}; 3 | use anyhow::Result; 4 | use async_trait::async_trait; 5 | use futures::{SinkExt, StreamExt}; 6 | use reqwest::header::{HeaderName, HeaderValue}; 7 | use std::sync::Arc; 8 | use std::{collections::HashMap, str::FromStr}; 9 | use tokio::sync::{broadcast, Mutex}; 10 | use tokio_tungstenite::tungstenite::{client::IntoClientRequest, Message as TungsteniteMessage}; 11 | use tracing::{debug, info}; 12 | 13 | #[derive(Clone)] 14 | pub struct ServerWsTransport { 15 | session: Arc<Mutex<Option<Session>>>, 16 | rx: Arc<Mutex<Option<broadcast::Receiver<Message>>>>, 17 | } 18 | 19 | impl ServerWsTransport { 20 | pub fn new(session: Session, rx: broadcast::Receiver<Message>) -> Self { 21 | Self { 22 | session: Arc::new(Mutex::new(Some(session))), 23 | rx: Arc::new(Mutex::new(Some(rx))), 24 | } 25 | } 26 | } 27 | 28 | #[derive(Clone)] 29 | pub struct ClientWsTransport { 30 | ws_tx: Arc<Mutex<Option<broadcast::Sender<Message>>>>, 31 | ws_rx: Arc<Mutex<Option<broadcast::Receiver<Message>>>>, 32 | url: String, 33 | headers: HashMap<String, String>, 34 | ws_write: Arc< 35 | Mutex< 36 | Option< 37 | futures::stream::SplitSink< 38 | tokio_tungstenite::WebSocketStream< 39 | tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>, 40 | >, 41 | TungsteniteMessage, 42 | >, 43 | >, 44 | >, 45 | >, 46 | } 47 | 48 | impl ClientWsTransport { 49 | pub fn builder(url: String) -> ClientWsTransportBuilder { 50 | ClientWsTransportBuilder::new(url) 51 | } 52 | } 53 | 54 | #[derive(Default)] 55 | pub struct ClientWsTransportBuilder { 56 | url: String, 57 | headers: HashMap<String, String>, 58 | } 59 | 60 | impl ClientWsTransportBuilder { 61 | pub fn new(url: String) -> Self { 62 | Self { 63 | url, 64 | headers: HashMap::new(), 65 | } 66 | } 67 | 68 | pub fn with_header(mut self,
key: impl Into, value: impl Into) -> Self { 69 | self.headers.insert(key.into(), value.into()); 70 | self 71 | } 72 | 73 | pub fn build(self) -> ClientWsTransport { 74 | let (tx, rx) = broadcast::channel(100); 75 | ClientWsTransport { 76 | ws_tx: Arc::new(Mutex::new(Some(tx))), 77 | ws_rx: Arc::new(Mutex::new(Some(rx))), 78 | url: self.url, 79 | headers: self.headers, 80 | ws_write: Arc::new(Mutex::new(None)), 81 | } 82 | } 83 | } 84 | 85 | #[async_trait] 86 | impl Transport for ServerWsTransport { 87 | async fn receive(&self) -> Result> { 88 | if let Some(rx) = self.rx.lock().await.as_mut() { 89 | match rx.recv().await { 90 | Ok(msg) => { 91 | debug!("Server received message: {:?}", msg); 92 | Ok(Some(msg)) 93 | } 94 | Err(e) => { 95 | debug!("Server receive error: {}", e); 96 | Ok(None) 97 | } 98 | } 99 | } else { 100 | debug!("Server receive called but receiver is None"); 101 | Ok(None) 102 | } 103 | } 104 | 105 | async fn send(&self, message: &Message) -> Result<()> { 106 | let text = serde_json::to_string(message)?; 107 | if let Some(session) = self.session.lock().await.as_mut() { 108 | debug!("Server sending message: {}", text); 109 | session.text(text).await?; 110 | } else { 111 | debug!("Server send called but session is None"); 112 | } 113 | Ok(()) 114 | } 115 | 116 | async fn open(&self) -> Result<()> { 117 | Ok(()) 118 | } 119 | 120 | async fn close(&self) -> Result<()> { 121 | info!("Server WebSocket connection closing"); 122 | if let Some(session) = self.session.lock().await.take() { 123 | session.close(None).await?; 124 | } 125 | Ok(()) 126 | } 127 | } 128 | 129 | #[async_trait] 130 | impl Transport for ClientWsTransport { 131 | async fn receive(&self) -> Result> { 132 | if let Some(rx) = self.ws_rx.lock().await.as_mut() { 133 | match rx.recv().await { 134 | Ok(msg) => { 135 | debug!("Client received message: {:?}", msg); 136 | Ok(Some(msg)) 137 | } 138 | Err(e) => { 139 | debug!("Client receive error: {}", e); 140 | Ok(None) 141 | } 142 | } 143 | } else { 144 | debug!("Client receive called but receiver is None"); 145 | Ok(None) 146 | } 147 | } 148 | 149 | async fn send(&self, message: &Message) -> Result<()> { 150 | let text = serde_json::to_string(message)?; 151 | if let Some(write) = self.ws_write.lock().await.as_mut() { 152 | debug!("Client sending message: {}", text); 153 | write.send(TungsteniteMessage::Text(text)).await?; 154 | } else { 155 | debug!("Client send called but writer is None"); 156 | } 157 | Ok(()) 158 | } 159 | 160 | async fn open(&self) -> Result<()> { 161 | info!("Opening WebSocket connection to {}", self.url); 162 | 163 | let mut request = self.url.clone().into_client_request().unwrap(); 164 | // MCP servers seem to be expecting this as protocol 165 | request.headers_mut().insert( 166 | "Sec-WebSocket-Protocol", 167 | HeaderValue::from_str("mcp").unwrap(), 168 | ); 169 | for (k, v) in &self.headers { 170 | request.headers_mut().insert( 171 | HeaderName::from_str(k).unwrap(), 172 | HeaderValue::from_str(v).unwrap(), 173 | ); 174 | } 175 | let (ws_stream, response) = tokio_tungstenite::connect_async(request).await?; 176 | 177 | info!( 178 | "WebSocket connection established. 
Response status: {}", 179 | response.status() 180 | ); 181 | debug!("WebSocket response headers: {:?}", response.headers()); 182 | 183 | let (write, read) = ws_stream.split(); 184 | *self.ws_write.lock().await = Some(write); 185 | 186 | // Get channels for WebSocket communication 187 | let ws_tx = self 188 | .ws_tx 189 | .lock() 190 | .await 191 | .as_ref() 192 | .expect("sender should exist") 193 | .clone(); 194 | 195 | // Handle receiving messages from WebSocket 196 | tokio::spawn(async move { 197 | let mut read = read; 198 | while let Some(result) = read.next().await { 199 | match result { 200 | Ok(msg) => { 201 | if let TungsteniteMessage::Text(text) = msg { 202 | match serde_json::from_str::(&text) { 203 | Ok(message) => { 204 | debug!("Received WebSocket message: {:?}", message); 205 | // Send to the broadcast channel for the transport to receive 206 | let _ = ws_tx.send(message); 207 | } 208 | Err(e) => debug!("Failed to parse WebSocket message: {}", e), 209 | } 210 | } 211 | } 212 | Err(e) => { 213 | info!("WebSocket read error: {}", e); 214 | break; 215 | } 216 | } 217 | } 218 | info!("WebSocket read loop terminated"); 219 | }); 220 | 221 | Ok(()) 222 | } 223 | 224 | async fn close(&self) -> Result<()> { 225 | info!("Closing WebSocket connection"); 226 | self.ws_tx.lock().await.take(); 227 | self.ws_rx.lock().await.take(); 228 | Ok(()) 229 | } 230 | } 231 | 232 | pub async fn handle_ws_connection( 233 | mut session: Session, 234 | mut stream: actix_ws::MessageStream, 235 | tx: broadcast::Sender, 236 | mut rx: broadcast::Receiver, 237 | ) -> Result<()> { 238 | info!("New WebSocket connection established"); 239 | 240 | loop { 241 | tokio::select! { 242 | Some(Ok(msg)) = stream.next() => { 243 | if let WsMessage::Text(text) = msg { 244 | match serde_json::from_str::(&text) { 245 | Ok(message) => { 246 | debug!("Handler received message: {:?}", message); 247 | tx.send(message)?; 248 | } 249 | Err(e) => debug!("Failed to parse message in handler: {}", e), 250 | } 251 | } 252 | } 253 | Ok(message) = rx.recv() => { 254 | debug!("Handler sending message: {:?}", message); 255 | let text = serde_json::to_string(&message)?; 256 | session.text(text).await?; 257 | } 258 | else => { 259 | info!("WebSocket connection terminated"); 260 | break 261 | } 262 | } 263 | } 264 | Ok(()) 265 | } 266 | -------------------------------------------------------------------------------- /src/types.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | 3 | use serde::{Deserialize, Serialize}; 4 | use url::Url; 5 | 6 | pub const LATEST_PROTOCOL_VERSION: &str = "2024-11-05"; 7 | 8 | #[derive(Debug, Clone, Serialize, Deserialize, Default)] 9 | #[serde(rename_all = "camelCase")] 10 | #[serde(default)] 11 | pub struct Implementation { 12 | pub name: String, 13 | pub version: String, 14 | } 15 | 16 | #[derive(Debug, Clone, Serialize, Deserialize, Default)] 17 | #[serde(rename_all = "camelCase")] 18 | #[serde(default)] 19 | pub struct InitializeRequest { 20 | pub protocol_version: String, 21 | pub capabilities: ClientCapabilities, 22 | pub client_info: Implementation, 23 | } 24 | 25 | #[derive(Debug, Clone, Serialize, Deserialize, Default)] 26 | #[serde(rename_all = "camelCase")] 27 | #[serde(default)] 28 | pub struct InitializeResponse { 29 | pub protocol_version: String, 30 | pub capabilities: ServerCapabilities, 31 | pub server_info: Implementation, 32 | } 33 | 34 | #[derive(Debug, Clone, Serialize, Deserialize, Default)] 35 | #[serde(rename_all = 
"camelCase")] 36 | #[serde(default)] 37 | pub struct ServerCapabilities { 38 | #[serde(skip_serializing_if = "Option::is_none")] 39 | pub tools: Option, 40 | #[serde(skip_serializing_if = "Option::is_none")] 41 | pub experimental: Option, 42 | #[serde(skip_serializing_if = "Option::is_none")] 43 | pub logging: Option, 44 | #[serde(skip_serializing_if = "Option::is_none")] 45 | pub prompts: Option, 46 | #[serde(skip_serializing_if = "Option::is_none")] 47 | pub resources: Option, 48 | } 49 | 50 | #[derive(Debug, Clone, Serialize, Deserialize, Default)] 51 | #[serde(rename_all = "camelCase")] 52 | #[serde(default)] 53 | pub struct PromptCapabilities { 54 | pub list_changed: Option, 55 | } 56 | 57 | #[derive(Debug, Clone, Serialize, Deserialize, Default)] 58 | #[serde(rename_all = "camelCase")] 59 | #[serde(default)] 60 | pub struct ResourceCapabilities { 61 | pub subscribe: Option, 62 | pub list_changed: Option, 63 | } 64 | 65 | #[derive(Debug, Clone, Serialize, Deserialize, Default)] 66 | #[serde(rename_all = "camelCase")] 67 | #[serde(default)] 68 | pub struct ClientCapabilities { 69 | pub experimental: Option, 70 | pub sampling: Option, 71 | pub roots: Option, 72 | } 73 | 74 | #[derive(Debug, Clone, Serialize, Deserialize, Default)] 75 | #[serde(rename_all = "camelCase")] 76 | #[serde(default)] 77 | pub struct RootCapabilities { 78 | pub list_changed: Option, 79 | } 80 | 81 | #[derive(Debug, Clone, Serialize, Deserialize)] 82 | #[serde(rename_all = "camelCase")] 83 | pub struct Tool { 84 | pub name: String, 85 | #[serde(skip_serializing_if = "Option::is_none")] 86 | pub description: Option, 87 | pub input_schema: serde_json::Value, 88 | #[serde(skip_serializing_if = "Option::is_none")] 89 | pub output_schema: Option, 90 | } 91 | #[derive(Debug, Clone, Serialize, Deserialize)] 92 | #[serde(rename_all = "camelCase")] 93 | pub struct CallToolRequest { 94 | pub name: String, 95 | #[serde(skip_serializing_if = "Option::is_none")] 96 | pub arguments: Option>, 97 | #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] 98 | pub meta: Option, 99 | } 100 | 101 | #[derive(Debug, Clone, Serialize, Deserialize)] 102 | #[serde(rename_all = "camelCase")] 103 | pub struct CallToolResponse { 104 | pub content: Vec, 105 | #[serde(skip_serializing_if = "Option::is_none")] 106 | pub is_error: Option, 107 | #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] 108 | pub meta: Option, 109 | } 110 | 111 | #[derive(Debug, Clone, Serialize, Deserialize)] 112 | #[serde(tag = "type")] 113 | pub enum ToolResponseContent { 114 | #[serde(rename = "text")] 115 | Text { text: String }, 116 | #[serde(rename = "image")] 117 | Image { data: String, mime_type: String }, 118 | #[serde(rename = "resource")] 119 | Resource { resource: ResourceContents }, 120 | } 121 | 122 | #[derive(Debug, Clone, Serialize, Deserialize)] 123 | #[serde(rename_all = "camelCase")] 124 | pub struct ResourceContents { 125 | pub uri: Url, 126 | #[serde(skip_serializing_if = "Option::is_none")] 127 | pub mime_type: Option, 128 | } 129 | 130 | #[derive(Debug, Clone, Serialize, Deserialize)] 131 | #[serde(rename_all = "camelCase")] 132 | pub struct ReadResourceRequest { 133 | pub uri: Url, 134 | } 135 | 136 | #[derive(Debug, Clone, Serialize, Deserialize)] 137 | #[serde(rename_all = "camelCase")] 138 | pub struct ListRequest { 139 | #[serde(skip_serializing_if = "Option::is_none")] 140 | pub cursor: Option, 141 | #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] 142 | pub meta: Option, 143 | } 144 | 145 | 
#[derive(Debug, Clone, Serialize, Deserialize)] 146 | #[serde(rename_all = "camelCase")] 147 | pub struct ToolsListResponse { 148 | pub tools: Vec, 149 | #[serde(skip_serializing_if = "Option::is_none")] 150 | pub next_cursor: Option, 151 | #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] 152 | pub meta: Option, 153 | } 154 | #[derive(Debug, Clone, Deserialize, Serialize)] 155 | #[serde(rename_all = "camelCase")] 156 | pub struct PromptsListResponse { 157 | pub prompts: Vec, 158 | #[serde(skip_serializing_if = "Option::is_none")] 159 | pub next_cursor: Option, 160 | #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] 161 | pub meta: Option>, 162 | } 163 | 164 | #[derive(Debug, Clone, Deserialize, Serialize)] 165 | #[serde(rename_all = "camelCase")] 166 | pub struct Prompt { 167 | pub name: String, 168 | #[serde(skip_serializing_if = "Option::is_none")] 169 | pub description: Option, 170 | #[serde(skip_serializing_if = "Option::is_none")] 171 | pub arguments: Option>, 172 | } 173 | 174 | #[derive(Debug, Clone, Deserialize, Serialize)] 175 | #[serde(rename_all = "camelCase")] 176 | pub struct PromptArgument { 177 | pub name: String, 178 | #[serde(skip_serializing_if = "Option::is_none")] 179 | pub description: Option, 180 | #[serde(skip_serializing_if = "Option::is_none")] 181 | pub required: Option, 182 | } 183 | 184 | #[derive(Debug, Clone, Deserialize, Serialize)] 185 | #[serde(rename_all = "camelCase")] 186 | pub struct ResourcesListResponse { 187 | pub resources: Vec, 188 | #[serde(skip_serializing_if = "Option::is_none")] 189 | pub next_cursor: Option, 190 | #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] 191 | pub meta: Option>, 192 | } 193 | 194 | #[derive(Debug, Clone, Serialize, Deserialize)] 195 | #[serde(rename_all = "camelCase")] 196 | pub struct Resource { 197 | pub uri: Url, 198 | pub name: String, 199 | #[serde(skip_serializing_if = "Option::is_none")] 200 | pub description: Option, 201 | #[serde(skip_serializing_if = "Option::is_none")] 202 | pub mime_type: Option, 203 | } 204 | 205 | #[derive(Debug, Clone, Copy, PartialEq, Eq)] 206 | pub enum ErrorCode { 207 | // SDK error codes 208 | ConnectionClosed = -1, 209 | RequestTimeout = -2, 210 | 211 | // Standard JSON-RPC error codes 212 | ParseError = -32700, 213 | InvalidRequest = -32600, 214 | MethodNotFound = -32601, 215 | InvalidParams = -32602, 216 | InternalError = -32603, 217 | } 218 | 219 | #[cfg(test)] 220 | mod tests { 221 | use super::*; 222 | 223 | #[test] 224 | fn test_server_capabilities() { 225 | let capabilities = ServerCapabilities::default(); 226 | let json = serde_json::to_string(&capabilities).unwrap(); 227 | assert_eq!(json, "{}"); 228 | } 229 | } 230 | --------------------------------------------------------------------------------
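A quick illustration of how the serde attributes in types.rs behave in practice: `rename_all = "camelCase"` turns `input_schema` into `inputSchema` on the wire, and `skip_serializing_if` drops unset optional fields. This is only a sketch; the `async_mcp::types` import path and the concrete inner types of the stripped `Option<..>` fields above are assumptions, not taken from the repository.

use async_mcp::types::Tool;

fn main() {
    let tool = Tool {
        name: "read_file".to_string(),
        description: Some("Read a file from the file system".to_string()),
        input_schema: serde_json::json!({
            "type": "object",
            "properties": { "path": { "type": "string" } },
            "required": ["path"]
        }),
        output_schema: None,
    };

    // Serializes with camelCase keys; `output_schema` is omitted because it is `None`.
    let json = serde_json::to_string(&tool).unwrap();
    assert!(json.contains("\"inputSchema\""));
    assert!(!json.contains("outputSchema"));
}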
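Taken together, the transports above expose one `Transport` trait surface (open/send/receive/close) over stdio, SSE, and WebSocket. The following is a minimal end-to-end sketch of driving the SSE client transport; the `async_mcp::transport` re-export paths, the server URL, and the `tools/list` method string are illustrative assumptions rather than details confirmed by this repository.

use anyhow::Result;
use async_mcp::transport::{
    ClientSseTransport, JsonRpcMessage, JsonRpcRequest, JsonRpcVersion, Transport,
};

#[tokio::main]
async fn main() -> Result<()> {
    // Build the client transport; with_header/with_auth are optional.
    let transport = ClientSseTransport::builder("http://localhost:3004".to_string())
        .with_header("X-Client", "example")
        .build();

    // open() connects to /sse and waits until the server announces a session ID.
    transport.open().await?;

    // Send a request; it is POSTed to /message?sessionId=... as JSON.
    let request = JsonRpcMessage::Request(JsonRpcRequest {
        id: 1,
        method: "tools/list".to_string(),
        params: None,
        jsonrpc: JsonRpcVersion::default(),
    });
    transport.send(&request).await?;

    // The response comes back over the SSE stream.
    if let Some(response) = transport.receive().await? {
        println!("response: {response:?}");
    }

    transport.close().await?;
    Ok(())
}

The stdio and WebSocket clients follow the same open/send/receive/close pattern; only the constructor or builder differs.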