├── src └── zeromcp │ ├── py.typed │ ├── __init__.py │ ├── jsonrpc.py │ └── mcp.py ├── .python-version ├── .editorconfig ├── .gitignore ├── CONTRIBUTING.md ├── pyproject.toml ├── LICENSE ├── .github └── workflows │ └── ci.yml ├── examples └── mcp_example.py ├── README.md └── tests ├── server_test.py ├── mcp_test.py └── jsonrpc_test.py /src/zeromcp/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.python-version: -------------------------------------------------------------------------------- 1 | 3.11 2 | -------------------------------------------------------------------------------- /src/zeromcp/__init__.py: -------------------------------------------------------------------------------- 1 | from .mcp import McpRpcRegistry, McpToolError, McpServer, McpHttpRequestHandler 2 | 3 | __all__ = ["McpRpcRegistry", "McpToolError", "McpServer", "McpHttpRequestHandler"] 4 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | root = true 2 | 3 | [*] 4 | indent_style = space 5 | indent_size = 4 6 | tab_width = 4 7 | end_of_line = lf 8 | charset = utf-8 9 | trim_trailing_whitespace = true 10 | insert_final_newline = true 11 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python-generated files 2 | __pycache__/ 3 | *.py[oc] 4 | build/ 5 | dist/ 6 | wheels/ 7 | *.egg-info 8 | 9 | # Virtual environments 10 | .venv 11 | 12 | # Coverage 13 | htmlcov/ 14 | .coverage* 15 | coverage.xml 16 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Development notes 2 | 3 | 
Run the tests with coverage: 4 | 5 | ```sh 6 | uv run coverage run --data-file=.coverage.mcp tests/mcp_test.py 7 | uv run coverage run --data-file=.coverage.jsonrpc tests/jsonrpc_test.py 8 | uv run coverage run --data-file=.coverage.server tests/server_test.py 9 | ``` 10 | 11 | Combine coverage and generate report: 12 | 13 | ```sh 14 | uv run coverage combine 15 | uv run coverage report 16 | uv run coverage html 17 | ``` 18 | 19 | Generate a report for just `jsonrpc_test.py`: 20 | 21 | ```sh 22 | uv run coverage html --data-file=.coverage.jsonrpc 23 | ``` 24 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "zeromcp" 3 | version = "0.0.0" 4 | description = "Zero-dependency MCP server implementation" 5 | readme = "README.md" 6 | requires-python = ">=3.11" 7 | dependencies = [] 8 | license = "MIT" 9 | 10 | [project.urls] 11 | Homepage = "https://github.com/mrexodia/zeromcp" 12 | Repository = "https://github.com/mrexodia/zeromcp" 13 | Issues = "https://github.com/mrexodia/zeromcp/issues" 14 | 15 | [build-system] 16 | requires = ["hatchling"] 17 | build-backend = "hatchling.build" 18 | 19 | [dependency-groups] 20 | dev = [ 21 | "coverage>=7.11.3", 22 | "mcp>=1.21.2", 23 | "requests>=2.32.5", 24 | ] 25 | 26 | [tool.coverage.report] 27 | omit = ["tests/*"] 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Duncan Ogilvie 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of
the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: [push, pull_request] 4 | 5 | # Automatically cancel previous runs of this workflow on the same branch 6 | concurrency: 7 | group: ${{ github.workflow }}-${{ github.ref }} 8 | cancel-in-progress: true 9 | 10 | jobs: 11 | linux: 12 | # Skip building pull requests from the same repository 13 | if: ${{ github.event_name == 'push' || (github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name != github.repository) }} 14 | runs-on: ubuntu-latest 15 | permissions: 16 | id-token: write 17 | contents: read 18 | env: 19 | PYTHONUNBUFFERED: 1 20 | steps: 21 | - name: Checkout 22 | uses: actions/checkout@v5 23 | 24 | - name: Update pyproject.toml version 25 | if: ${{ startsWith(github.ref, 'refs/tags/') }} 26 | shell: bash 27 | run: | 28 | # Extract version from tag (strip 'v' prefix if present) 29 | VERSION=${GITHUB_REF#refs/tags/} 30 | VERSION=${VERSION#v} 31 | echo "Extracted version: $VERSION" 32 | 33 | # Update version in pyproject.toml (works on both GNU and BSD sed) 34 | sed 
-i.bak "s/^version = .*/version = \"$VERSION\"/" pyproject.toml 35 | rm pyproject.toml.bak 36 | 37 | - name: Install uv 38 | uses: astral-sh/setup-uv@v6 39 | with: 40 | version: "0.9.3" 41 | 42 | - name: Create venv 43 | run: uv sync 44 | 45 | - name: Type check 46 | run: uvx ty check 47 | 48 | - name: Ruff check 49 | run: uvx ruff check 50 | 51 | - name: Test JSON-RPC 52 | run: uv run coverage run --data-file=.coverage.jsonrpc tests/jsonrpc_test.py 53 | 54 | - name: Test MCP 55 | run: uv run coverage run --data-file=.coverage.mcp tests/mcp_test.py 56 | 57 | - name: Test Server 58 | run: uv run coverage run --data-file=.coverage.server tests/server_test.py 59 | 60 | - name: Generate coverage report 61 | run: | 62 | uv run coverage combine 63 | uv run coverage report 64 | 65 | - name: Publish 66 | if: ${{ startsWith(github.ref, 'refs/tags/') }} 67 | run: uv build && uv publish 68 | -------------------------------------------------------------------------------- /examples/mcp_example.py: -------------------------------------------------------------------------------- 1 | """Example MCP server with test tools""" 2 | 3 | import time 4 | import argparse 5 | from urllib.parse import urlparse 6 | from typing import Annotated, Optional, TypedDict, NotRequired 7 | from zeromcp import McpToolError, McpServer 8 | 9 | mcp = McpServer("example") 10 | 11 | 12 | @mcp.tool 13 | def divide( 14 | numerator: Annotated[float, "Numerator"], 15 | denominator: Annotated[float, "Denominator"], 16 | ) -> float: 17 | """Divide two numbers (no zero check - tests natural exceptions)""" 18 | return numerator / denominator 19 | 20 | 21 | class GreetingResponse(TypedDict): 22 | message: Annotated[str, "Greeting message"] 23 | name: Annotated[str, "Name that was greeted"] 24 | age: Annotated[NotRequired[int], "Age if provided"] 25 | 26 | 27 | @mcp.tool 28 | def greet( 29 | name: Annotated[str, "Name to greet"], 30 | age: Annotated[Optional[int], "Age of person"] = None, 31 | ) -> 
GreetingResponse: 32 | """Generate a greeting message""" 33 | if age is not None: 34 | return { 35 | "message": f"Hello, {name}! You are {age} years old.", 36 | "name": name, 37 | "age": age, 38 | } 39 | return {"message": f"Hello, {name}!", "name": name} 40 | 41 | 42 | class SystemInfo(TypedDict): 43 | platform: Annotated[str, "Operating system platform"] 44 | python_version: Annotated[str, "Python version"] 45 | machine: Annotated[str, "Machine architecture"] 46 | timestamp: Annotated[float, "Current timestamp"] 47 | 48 | 49 | @mcp.tool 50 | def get_system_info() -> SystemInfo: 51 | """Get system information""" 52 | import platform 53 | 54 | return { 55 | "platform": platform.system(), 56 | "python_version": platform.python_version(), 57 | "machine": platform.machine(), 58 | "timestamp": time.time(), 59 | } 60 | 61 | 62 | @mcp.tool 63 | def failing_tool(message: Annotated[str, "Error message to raise"]) -> str: 64 | """Tool that always fails (for testing error handling)""" 65 | raise McpToolError(message) 66 | 67 | 68 | class StructInfo(TypedDict): 69 | name: Annotated[str, "Structure name"] 70 | size: Annotated[int, "Structure size in bytes"] 71 | fields: Annotated[list[str], "List of field names"] 72 | 73 | 74 | @mcp.tool 75 | def struct_get( 76 | names: Annotated[list[str], "Array of structure names"] 77 | | Annotated[str, "Single structure name"], 78 | ) -> list[StructInfo]: 79 | """Retrieve structure information by names""" 80 | return [ 81 | StructInfo( 82 | { 83 | "name": name, 84 | "size": 128, # Dummy size 85 | "fields": ["field1", "field2", "field3"], # Dummy fields 86 | } 87 | ) 88 | for name in (names if isinstance(names, list) else [names]) 89 | ] 90 | 91 | 92 | @mcp.tool 93 | def random_dict(param: dict[str, int] | None) -> dict: 94 | """Return a random dictionary for testing serialization""" 95 | return { 96 | **(param or {}), 97 | "x": 42, 98 | "y": 7, 99 | "z": 99, 100 | } 101 | 102 | 103 | @mcp.resource("example://system_info") 104 | def 
system_info_resource() -> SystemInfo: 105 | """Resource providing system information""" 106 | return get_system_info() 107 | 108 | 109 | @mcp.resource("example://greeting/{name}") 110 | def greeting_resource( 111 | name: Annotated[str, "Name to greet from resource"], 112 | ) -> GreetingResponse: 113 | """Resource providing greeting message""" 114 | return greet(name) 115 | 116 | 117 | @mcp.resource("example://error") 118 | def error_resource() -> None: 119 | """Resource that always fails (for testing error handling)""" 120 | raise McpToolError("This is a resource error for testing purposes.") 121 | 122 | 123 | @mcp.prompt 124 | def code_review( 125 | code: Annotated[str, "Code to review"], 126 | language: Annotated[str, "Programming language"] = "python", 127 | ) -> str: 128 | """Review code for bugs and improvements""" 129 | return f"Please review this {language} code for bugs and improvements:\n\n```{language}\n{code}\n```" 130 | 131 | 132 | @mcp.prompt 133 | def summarize( 134 | text: Annotated[str, "Text to summarize"], 135 | max_sentences: Annotated[int, "Maximum sentences"] = 3, 136 | ) -> str: 137 | """Summarize text concisely""" 138 | return f"Summarize the following in {max_sentences} sentences or fewer:\n\n{text}" 139 | 140 | 141 | if __name__ == "__main__": 142 | parser = argparse.ArgumentParser(description="MCP Example Server") 143 | parser.add_argument( 144 | "--transport", 145 | help="Transport (stdio or http://host:port)", 146 | default="http://127.0.0.1:5001", 147 | ) 148 | args = parser.parse_args() 149 | if args.transport == "stdio": 150 | mcp.stdio() 151 | else: 152 | url = urlparse(args.transport) 153 | if url.hostname is None or url.port is None: 154 | raise Exception(f"Invalid transport URL: {args.transport}") 155 | 156 | print("Starting MCP Example Server...") 157 | 158 | print("\nAvailable tools:") 159 | for name in mcp.tools.methods.keys(): 160 | func = mcp.tools.methods[name] 161 | print(f" - {name}: {func.__doc__}") 162 | 163 | 
print("\nAvailable resources:") 164 | for name in mcp.resources.methods.keys(): 165 | func = mcp.resources.methods[name] 166 | print(f" - {name}: {func.__doc__}") 167 | 168 | print("\nAvailable prompts:") 169 | for name in mcp.prompts.methods.keys(): 170 | func = mcp.prompts.methods[name] 171 | print(f" - {name}: {func.__doc__}") 172 | print() 173 | 174 | mcp.serve(url.hostname, url.port) 175 | 176 | try: 177 | input("\nServer is running, press Enter or Ctrl+C to stop...") 178 | except (KeyboardInterrupt, EOFError): 179 | print("\n\nStopping server...") 180 | mcp.stop() 181 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # zeromcp 2 | 3 | **Minimal MCP server implementation in pure Python.** 4 | 5 | A lightweight, handcrafted implementation of the [Model Context Protocol](https://modelcontextprotocol.io/) focused on what most users actually need: exposing tools with clean Python type annotations. 
6 | 7 | ## Features 8 | 9 | - ✨ **Zero dependencies** - Pure Python, standard library only 10 | - 🎯 **Type-safe** - Native Python type annotations for everything 11 | - 🚀 **Fast** - Minimal overhead, maximum performance 12 | - 🛠️ **Handcrafted** - Written by a human[1](#ai-usage), verified against the spec 13 | - 🌐 **HTTP/SSE transport** - Streamable responses 14 | - 📡 **Stdio transport** - For legacy clients 15 | - 📦 **Tiny** - Fewer than 1,000 lines of code 16 | 17 | ## Installation 18 | 19 | ```bash 20 | pip install zeromcp 21 | ``` 22 | 23 | Or with uv: 24 | 25 | ```bash 26 | uv add zeromcp 27 | ``` 28 | 29 | ## Quick Start 30 | 31 | ```python 32 | from typing import Annotated 33 | from zeromcp import McpServer 34 | 35 | mcp = McpServer("my-server") 36 | 37 | @mcp.tool 38 | def greet( 39 | name: Annotated[str, "Name to greet"], 40 | age: Annotated[int | None, "Age of person"] = None 41 | ) -> str: 42 | """Generate a greeting message""" 43 | if age is not None: 44 | return f"Hello, {name}! You are {age} years old." 45 | return f"Hello, {name}!" 46 | 47 | if __name__ == "__main__": 48 | mcp.serve("127.0.0.1", 8000) 49 | ``` 50 | 51 | Then manually test your MCP server with the [inspector](https://github.com/modelcontextprotocol/inspector): 52 | 53 | ```bash 54 | npx -y @modelcontextprotocol/inspector 55 | ``` 56 | 57 | Once things are working you can configure the `mcp.json`: 58 | 59 | ```json 60 | { 61 | "mcpServers": { 62 | "my-server": { 63 | "type": "http", 64 | "url": "http://127.0.0.1:8000/mcp" 65 | } 66 | } 67 | } 68 | ``` 69 | 70 | ## Stdio Transport 71 | 72 | For MCP clients that only support stdio transport: 73 | 74 | ```python 75 | from zeromcp import McpServer 76 | 77 | mcp = McpServer("my-server") 78 | 79 | @mcp.tool 80 | def greet(name: str) -> str: 81 | """Generate a greeting""" 82 | return f"Hello, {name}!"
83 | 84 | if __name__ == "__main__": 85 | mcp.stdio() 86 | ``` 87 | 88 | Then configure in `mcp.json` (different for every client): 89 | 90 | ```json 91 | { 92 | "mcpServers": { 93 | "my-server": { 94 | "command": "python", 95 | "args": ["path/to/server.py"] 96 | } 97 | } 98 | } 99 | ``` 100 | 101 | ## Type Annotations 102 | 103 | zeromcp uses native Python `Annotated` types for schema generation: 104 | 105 | ```python 106 | from typing import Annotated, Optional, TypedDict, NotRequired 107 | 108 | class GreetingResponse(TypedDict): 109 | message: Annotated[str, "Greeting message"] 110 | name: Annotated[str, "Name that was greeted"] 111 | age: Annotated[NotRequired[int], "Age if provided"] 112 | 113 | @mcp.tool 114 | def greet( 115 | name: Annotated[str, "Name to greet"], 116 | age: Annotated[Optional[int], "Age of person"] = None 117 | ) -> GreetingResponse: 118 | """Generate a greeting message""" 119 | if age is not None: 120 | return { 121 | "message": f"Hello, {name}! You are {age} years old.", 122 | "name": name, 123 | "age": age 124 | } 125 | return { 126 | "message": f"Hello, {name}!", 127 | "name": name 128 | } 129 | ``` 130 | 131 | ## Union Types 132 | 133 | Tools can accept multiple input types: 134 | 135 | ```python 136 | from typing import Annotated, TypedDict 137 | 138 | class StructInfo(TypedDict): 139 | name: Annotated[str, "Structure name"] 140 | size: Annotated[int, "Structure size in bytes"] 141 | fields: Annotated[list[str], "List of field names"] 142 | 143 | @mcp.tool 144 | def struct_get( 145 | names: Annotated[list[str], "Array of structure names"] 146 | | Annotated[str, "Single structure name"] 147 | ) -> list[StructInfo]: 148 | """Retrieve structure information by names""" 149 | return [ 150 | { 151 | "name": name, 152 | "size": 128, 153 | "fields": ["field1", "field2", "field3"] 154 | } 155 | for name in (names if isinstance(names, list) else [names]) 156 | ] 157 | ``` 158 | 159 | ## Error Handling 160 | 161 | ```python 162 | from zeromcp 
import McpToolError 163 | 164 | @mcp.tool 165 | def divide( 166 | numerator: Annotated[float, "Numerator"], 167 | denominator: Annotated[float, "Denominator"] 168 | ) -> float: 169 | """Divide two numbers""" 170 | if denominator == 0: 171 | raise McpToolError("Division by zero") 172 | return numerator / denominator 173 | ``` 174 | 175 | ## Resources 176 | 177 | Expose read-only data via URI patterns. Resources are serialized as JSON. 178 | 179 | ```python 180 | from typing import Annotated 181 | 182 | @mcp.resource("file://data.txt") 183 | def read_file() -> dict: 184 | """Get information about data.txt""" 185 | return {"name": "data.txt", "size": 1024} 186 | 187 | @mcp.resource("file://{filename}") 188 | def read_any_file( 189 | filename: Annotated[str, "Name of file to read"] 190 | ) -> dict: 191 | """Get information about any file""" 192 | return {"name": filename, "size": 2048} 193 | ``` 194 | 195 | ## Prompts 196 | 197 | Expose reusable prompt templates with typed arguments. 198 | 199 | ```python 200 | from typing import Annotated 201 | 202 | @mcp.prompt 203 | def code_review( 204 | code: Annotated[str, "Code to review"], 205 | language: Annotated[str, "Programming language"] = "python" 206 | ) -> str: 207 | """Review code for bugs and improvements""" 208 | return f"Please review this {language} code:\n\n```{language}\n{code}\n```" 209 | ``` 210 | 211 | ## CORS 212 | 213 | By default, zeromcp allows CORS requests from localhost origins (`localhost`, `127.0.0.1`, `::1`) on **any port**. This allows tools like the MCP Inspector or local AI tools to communicate with your MCP server. 
214 | 215 | ```python 216 | from zeromcp import McpServer 217 | 218 | mcp = McpServer("my-server") 219 | 220 | # Default: allow localhost on any port 221 | mcp.cors_allowed_origins = mcp.cors_localhost 222 | 223 | # Allow all origins (use with caution) 224 | mcp.cors_allowed_origins = "*" 225 | 226 | # Allow specific origins 227 | mcp.cors_allowed_origins = [ 228 | "http://localhost:3000", 229 | "https://myapp.example.com", 230 | ] 231 | 232 | # Disable CORS (blocks all browser cross-origin requests) 233 | mcp.cors_allowed_origins = None 234 | 235 | # Custom logic 236 | mcp.cors_allowed_origins = lambda origin: origin.endswith(".example.com") 237 | ``` 238 | 239 | Note: CORS only affects browser-based requests. Non-browser clients like `curl` or MCP desktop apps are unaffected by this setting. 240 | 241 | ## Supported clients 242 | 243 | The following clients have been tested: 244 | 245 | - [Claude Code](https://code.claude.com/docs/en/mcp#installing-mcp-servers) 246 | - [Claude Desktop](https://modelcontextprotocol.io/docs/develop/connect-local-servers#installing-the-filesystem-server) (_stdio only_) 247 | - [Visual Studio Code](https://code.visualstudio.com/docs/copilot/customization/mcp-servers) 248 | - [Roo Code](https://docs.roocode.com/features/mcp/using-mcp-in-roo) / [Cline](https://docs.cline.bot/mcp/configuring-mcp-servers) / [Kilo Code](https://kilocode.ai/docs/features/mcp/using-mcp-in-kilo-code) 249 | - [LM Studio](https://lmstudio.ai/docs/app/mcp) 250 | - [Jan](https://www.jan.ai/docs/desktop/mcp#configure-and-use-mcps-within-jan) 251 | - [Gemini CLI](https://geminicli.com/docs/tools/mcp-server/#how-to-set-up-your-mcp-server) 252 | - [Cursor](https://cursor.com/docs/context/mcp) 253 | - [Windsurf](https://docs.windsurf.com/windsurf/cascade/mcp) 254 | - [Zed](https://zed.dev/docs/ai/mcp) (_stdio only_) 255 | - [Warp](https://docs.warp.dev/knowledge-and-collaboration/mcp#adding-an-mcp-server) 256 | 257 | _Note_: generally the `/mcp` endpoint is 
preferred, but not all clients support it correctly. 258 | 259 | 260 | 1README and some of the tests written by Claude 261 | -------------------------------------------------------------------------------- /tests/server_test.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import sys 3 | import socket 4 | from contextlib import contextmanager 5 | from zeromcp import McpServer 6 | 7 | def find_free_port(): 8 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: 9 | s.bind(("", 0)) 10 | return s.getsockname()[1] 11 | 12 | @contextmanager 13 | def run_server(name="test", **kwargs): 14 | port = find_free_port() 15 | server = McpServer(name) 16 | for k, v in kwargs.items(): 17 | setattr(server, k, v) 18 | server.serve("127.0.0.1", port, background=True) 19 | base_url = f"http://127.0.0.1:{port}" 20 | try: 21 | yield base_url, server 22 | finally: 23 | server.stop() 24 | 25 | PING_JSON = {"jsonrpc": "2.0", "method": "ping", "id": 1} 26 | 27 | def test_cors_permissive(): 28 | print("Testing CORS permissive (cors_allowed_origins='*')...") 29 | with run_server(cors_allowed_origins="*") as (base_url, _): 30 | test_origin = "http://example.com" 31 | # Test OPTIONS 32 | resp = requests.options(f"{base_url}/mcp", headers={"Origin": test_origin}) 33 | assert resp.headers.get("Access-Control-Allow-Origin") == test_origin, "OPTIONS should have CORS header" 34 | 35 | # Test POST 36 | resp = requests.post(f"{base_url}/mcp", headers={"Origin": test_origin}, json=PING_JSON) 37 | assert resp.headers.get("Access-Control-Allow-Origin") == test_origin, "POST should have CORS header" 38 | print("✓ PASS") 39 | 40 | def test_cors_restrictive(): 41 | print("Testing CORS restrictive (cors_allowed_origins=None)...") 42 | with run_server(cors_allowed_origins=None) as (base_url, _): 43 | # Test OPTIONS 44 | resp = requests.options(f"{base_url}/mcp", headers={"Origin": "http://example.com"}) 45 | assert 
"Access-Control-Allow-Origin" not in resp.headers, "OPTIONS should NOT have CORS header" 46 | 47 | # Test POST 48 | resp = requests.post(f"{base_url}/mcp", headers={"Origin": "http://example.com"}, json=PING_JSON) 49 | assert "Access-Control-Allow-Origin" not in resp.headers, "POST should NOT have CORS header" 50 | print("✓ PASS") 51 | 52 | def test_cors_local(): 53 | print("Testing CORS localhost...") 54 | with run_server() as (base_url, _): 55 | # Test OPTIONS 56 | resp = requests.options(f"{base_url}/mcp", headers={"Origin": "http://localhost:1234"}) 57 | assert resp.headers.get("Access-Control-Allow-Origin") == "http://localhost:1234", "OPTIONS should have CORS header" 58 | 59 | # Test POST 60 | resp = requests.post(f"{base_url}/mcp", headers={"Origin": "https://127.0.0.1:4321"}, json=PING_JSON) 61 | assert resp.headers.get("Access-Control-Allow-Origin") == "https://127.0.0.1:4321", "POST should have CORS header (HTTPS)" 62 | 63 | resp = requests.post(f"{base_url}/mcp", headers={"Origin": "http://[::1]:4321"}, json=PING_JSON) 64 | assert resp.headers.get("Access-Control-Allow-Origin") == "http://[::1]:4321", "POST should have CORS header (IPv6)" 65 | 66 | # Test OPTIONS with wrong origin 67 | resp = requests.options(f"{base_url}/mcp", headers={"Origin": "http://example.com"}) 68 | assert "Access-Control-Allow-Origin" not in resp.headers, "OPTIONS should NOT have CORS header for wrong origin" 69 | 70 | def test_cors_list(): 71 | print("Testing CORS list...") 72 | allowed_origins = ["http://example.com", "https://example.org"] 73 | with run_server(cors_allowed_origins=allowed_origins) as (base_url, _): 74 | # Test allowed origins 75 | for origin in allowed_origins: 76 | resp = requests.options(f"{base_url}/mcp", headers={"Origin": origin}) 77 | assert resp.headers.get("Access-Control-Allow-Origin") == origin, f"OPTIONS should have CORS header for {origin}" 78 | 79 | resp = requests.post(f"{base_url}/mcp", headers={"Origin": origin}, json=PING_JSON) 80 | assert 
resp.headers.get("Access-Control-Allow-Origin") == origin, f"POST should have CORS header for {origin}" 81 | 82 | # Test disallowed origin 83 | resp = requests.options(f"{base_url}/mcp", headers={"Origin": "http://notallowed.com"}) 84 | assert "Access-Control-Allow-Origin" not in resp.headers, "OPTIONS should NOT have CORS header for disallowed origin" 85 | 86 | resp = requests.post(f"{base_url}/mcp", headers={"Origin": "http://notallowed.com"}, json=PING_JSON) 87 | assert "Access-Control-Allow-Origin" not in resp.headers, "POST should NOT have CORS header for disallowed origin" 88 | 89 | def test_body_limit(): 90 | print("Testing body limit...") 91 | # Set small limit (100 bytes) 92 | with run_server(post_body_limit=100) as (base_url, _): 93 | # Small request - should pass 94 | resp = requests.post(f"{base_url}/mcp", json=PING_JSON) 95 | assert resp.status_code == 200, "Small request should pass" 96 | 97 | # Large request - should fail 98 | large_payload = "x" * 200 99 | resp = requests.post(f"{base_url}/mcp", data=large_payload) 100 | assert resp.status_code == 413, "Large request should fail with 413" 101 | assert "Payload Too Large" in resp.text, "Error message should mention payload size" 102 | print("✓ PASS") 103 | 104 | def test_exception_redaction(): 105 | print("Testing exception redaction...") 106 | with run_server() as (base_url, server): 107 | server.tools.redact_exceptions = True 108 | 109 | @server.tool 110 | def fail(): 111 | raise ValueError("Secret internal info") 112 | 113 | # Call via tools/call 114 | payload = { 115 | "jsonrpc": "2.0", 116 | "method": "tools/call", 117 | "params": {"name": "fail", "arguments": {}}, 118 | "id": 1 119 | } 120 | resp = requests.post(f"{base_url}/mcp", json=payload) 121 | data = resp.json() 122 | 123 | # The outer JSON-RPC call succeeds 124 | assert "result" in data, f"Expected result, got error: {data.get('error')}" 125 | result = data["result"] 126 | 127 | # The tool execution failed 128 | assert result["isError"] 
is True, "Tool execution should be an error" 129 | error_text = result["content"][0]["text"] 130 | 131 | assert error_text == "Internal Error: Secret internal info", f"Should show redacted message, got: {error_text}" 132 | assert "Traceback" not in error_text, "Should NOT show traceback" 133 | print("✓ PASS") 134 | 135 | def test_exception_exposure(): 136 | print("Testing exception exposure (default)...") 137 | with run_server() as (base_url, server): 138 | server.tools.redact_exceptions = False 139 | 140 | @server.tool 141 | def fail(): 142 | raise ValueError("Secret internal info") 143 | 144 | # Call via tools/call 145 | payload = { 146 | "jsonrpc": "2.0", 147 | "method": "tools/call", 148 | "params": {"name": "fail", "arguments": {}}, 149 | "id": 1 150 | } 151 | resp = requests.post(f"{base_url}/mcp", json=payload) 152 | data = resp.json() 153 | 154 | # The outer JSON-RPC call succeeds 155 | assert "result" in data, f"Expected result, got error: {data.get('error')}" 156 | result = data["result"] 157 | 158 | # The tool execution failed 159 | assert result["isError"] is True, "Tool execution should be an error" 160 | error_text = result["content"][0]["text"] 161 | 162 | assert "Secret internal info" in error_text, "Should show exception message" 163 | assert "Traceback" in error_text, "Should show traceback" 164 | print("✓ PASS") 165 | 166 | print("✓ PASS") 167 | 168 | def test_http_errors(): 169 | print("Testing HTTP errors...") 170 | with run_server() as (base_url, _): 171 | # GET /mcp -> 405 Method Not Allowed 172 | resp = requests.get(f"{base_url}/mcp") 173 | assert resp.status_code == 405, f"GET /mcp should return 405, got {resp.status_code}" 174 | 175 | # GET /invalid -> 404 Not Found 176 | resp = requests.get(f"{base_url}/invalid") 177 | assert resp.status_code == 404, f"GET /invalid should return 404, got {resp.status_code}" 178 | 179 | # POST /invalid -> 404 Not Found 180 | resp = requests.post(f"{base_url}/invalid", json={}) 181 | assert resp.status_code 
== 404, f"POST /invalid should return 404, got {resp.status_code}" 182 | print("✓ PASS") 183 | 184 | def test_sse_errors(): 185 | print("Testing SSE errors...") 186 | with run_server() as (base_url, _): 187 | # POST /sse without session -> 400 Bad Request 188 | resp = requests.post(f"{base_url}/sse", json={}) 189 | assert resp.status_code == 400, f"POST /sse without session should return 400, got {resp.status_code}" 190 | assert "Missing ?session" in resp.text 191 | 192 | # POST /sse with invalid session -> 400 Bad Request 193 | resp = requests.post(f"{base_url}/sse?session=invalid-uuid", json={}) 194 | assert resp.status_code == 400, f"POST /sse with invalid session should return 400, got {resp.status_code}" 195 | assert "No active SSE connection" in resp.text 196 | print("✓ PASS") 197 | 198 | def test_mcp_tool_error(): 199 | print("Testing McpToolError...") 200 | from zeromcp import McpToolError 201 | with run_server() as (base_url, server): 202 | @server.tool 203 | def fail_custom(): 204 | raise McpToolError("Custom tool error") 205 | 206 | resp = requests.post(f"{base_url}/mcp", json={ 207 | "jsonrpc": "2.0", 208 | "method": "tools/call", 209 | "params": {"name": "fail_custom", "arguments": {}}, 210 | "id": 1 211 | }) 212 | data = resp.json() 213 | result = data["result"] 214 | assert result["isError"] is True 215 | assert "Custom tool error" in result["content"][0]["text"] 216 | print("✓ PASS") 217 | 218 | def run_all_tests(): 219 | print("="*60) 220 | print("SERVER TESTS") 221 | print("="*60) 222 | 223 | try: 224 | test_cors_permissive() 225 | test_cors_restrictive() 226 | test_cors_local() 227 | test_cors_list() 228 | test_body_limit() 229 | test_exception_redaction() 230 | test_exception_exposure() 231 | test_http_errors() 232 | test_sse_errors() 233 | test_mcp_tool_error() 234 | print("\n" + "="*60) 235 | print("ALL SERVER TESTS PASSED! 
✓") 236 | print("="*60) 237 | except AssertionError as e: 238 | print(f"\n❌ FAIL: {e}") 239 | sys.exit(1) 240 | except Exception as e: 241 | print(f"\n❌ ERROR: {e}") 242 | import traceback 243 | traceback.print_exc() 244 | sys.exit(1) 245 | 246 | if __name__ == "__main__": 247 | run_all_tests() 248 | -------------------------------------------------------------------------------- /src/zeromcp/jsonrpc.py: -------------------------------------------------------------------------------- 1 | import json 2 | import inspect 3 | import traceback 4 | from typing import Any, Callable, get_type_hints, get_origin, get_args, Union, TypedDict, TypeAlias, NotRequired, is_typeddict 5 | from types import UnionType 6 | 7 | JsonRpcId: TypeAlias = str | int | float | None 8 | JsonRpcParams: TypeAlias = dict[str, Any] | list[Any] | None 9 | 10 | class JsonRpcRequest(TypedDict): 11 | jsonrpc: str 12 | method: str 13 | params: NotRequired[JsonRpcParams] 14 | id: NotRequired[JsonRpcId] 15 | 16 | class JsonRpcError(TypedDict): 17 | code: int 18 | message: str 19 | data: NotRequired[Any] 20 | 21 | class JsonRpcResponse(TypedDict): 22 | jsonrpc: str 23 | result: NotRequired[Any] 24 | error: NotRequired[JsonRpcError] 25 | id: JsonRpcId 26 | 27 | class JsonRpcException(Exception): 28 | def __init__(self, code: int, message: str, data: Any = None): 29 | self.code = code 30 | self.message = message 31 | self.data = data 32 | 33 | class JsonRpcRegistry: 34 | def __init__(self): 35 | self.methods: dict[str, Callable] = {} 36 | self._cache: dict[Callable, tuple[inspect.Signature, dict, list[str]]] = {} 37 | self.redact_exceptions = False 38 | 39 | def method(self, func: Callable, name: str | None = None) -> Callable: 40 | self.methods[name or func.__name__] = func # type: ignore 41 | return func 42 | 43 | def dispatch(self, request: dict | str | bytes | bytearray) -> JsonRpcResponse | None: 44 | try: 45 | if not isinstance(request, dict): 46 | request = json.loads(request) 47 | if not 
isinstance(request, dict): 48 | return self._error(None, -32600, "Invalid request: must be a JSON object") 49 | except Exception as e: 50 | return self._error(None, -32700, "JSON parse error", str(e)) 51 | 52 | if request.get("jsonrpc") != "2.0": 53 | return self._error(None, -32600, "Invalid request: 'jsonrpc' must be '2.0'") 54 | 55 | method = request.get("method") 56 | if method is None: 57 | return self._error(None, -32600, "Invalid request: 'method' is required") 58 | if not isinstance(method, str): 59 | return self._error(None, -32600, "Invalid request: 'method' must be a string") 60 | 61 | request_id: JsonRpcId = request.get("id") 62 | is_notification = "id" not in request 63 | params: JsonRpcParams = request.get("params") 64 | try: 65 | result = self._call(method, params) 66 | if is_notification: 67 | return None 68 | return { 69 | "jsonrpc": "2.0", 70 | "result": result, 71 | "id": request_id, 72 | } 73 | except JsonRpcException as e: 74 | if is_notification: 75 | return None 76 | return self._error(request_id, e.code, e.message, e.data) 77 | except Exception as e: 78 | if is_notification: 79 | return None 80 | error = self.map_exception(e) 81 | return self._error(request_id, error["code"], error["message"], error.get("data")) 82 | 83 | def map_exception(self, e: Exception) -> JsonRpcError: 84 | if self.redact_exceptions: 85 | return { 86 | "code": -32603, 87 | "message": f"Internal Error: {str(e)}", 88 | } 89 | return { 90 | "code": -32603, 91 | "message": "\n".join(traceback.format_exception(e)).strip() + "\n\nPlease report a bug!", 92 | } 93 | 94 | def _call(self, method: str, params: Any) -> Any: 95 | if method not in self.methods: 96 | raise JsonRpcException(-32601, f"Method '{method}' not found") 97 | 98 | func = self.methods[method] 99 | 100 | # Check for cached reflection data 101 | if func not in self._cache: 102 | sig = inspect.signature(func) 103 | hints = get_type_hints(func) 104 | hints.pop("return", None) 105 | 106 | # Determine required vs 
optional parameters 107 | required_params = [] 108 | for param_name, param in sig.parameters.items(): 109 | if param.default is inspect.Parameter.empty: 110 | required_params.append(param_name) 111 | 112 | self._cache[func] = (sig, hints, required_params) 113 | 114 | sig, hints, required_params = self._cache[func] 115 | 116 | # Handle None params 117 | if params is None: 118 | if len(required_params) == 0: 119 | return func() 120 | else: 121 | raise JsonRpcException(-32602, "Missing required params") 122 | 123 | # Convert list params to dict by parameter names 124 | if isinstance(params, list): 125 | if len(params) < len(required_params): 126 | raise JsonRpcException( 127 | -32602, 128 | f"Invalid params: expected at least {len(required_params)} arguments, got {len(params)}" 129 | ) 130 | if len(params) > len(sig.parameters): 131 | raise JsonRpcException( 132 | -32602, 133 | f"Invalid params: expected at most {len(sig.parameters)} arguments, got {len(params)}" 134 | ) 135 | params = dict(zip(sig.parameters.keys(), params)) 136 | 137 | # Validate dict params 138 | if isinstance(params, dict): 139 | # Check all required params are present 140 | missing = set(required_params) - set(params.keys()) 141 | if missing: 142 | raise JsonRpcException( 143 | -32602, 144 | f"Invalid params: missing required parameters: {list(missing)}" 145 | ) 146 | 147 | # Check no extra params 148 | extra = set(params.keys()) - set(sig.parameters.keys()) 149 | if extra: 150 | raise JsonRpcException( 151 | -32602, 152 | f"Invalid params: unexpected parameters: {list(extra)}" 153 | ) 154 | 155 | validated_params = {} 156 | for param_name, value in params.items(): 157 | # If no type hint, pass through without validation 158 | if param_name not in hints: 159 | validated_params[param_name] = value 160 | continue 161 | 162 | # Has type hint, validate 163 | expected_type = hints[param_name] 164 | 165 | # Inline type validation 166 | origin = get_origin(expected_type) 167 | args = 
get_args(expected_type) 168 | 169 | # Handle None/null 170 | if value is None: 171 | if expected_type is not type(None): 172 | # Check if None is allowed in a Union 173 | if not (origin in (Union, UnionType) and type(None) in args): 174 | raise JsonRpcException(-32602, f"Invalid params: {param_name} cannot be null") 175 | validated_params[param_name] = None 176 | continue 177 | 178 | # Handle Union types (int | str, Optional[int], etc.) 179 | if origin in (Union, UnionType): 180 | type_matched = False 181 | for arg_type in args: 182 | if arg_type is type(None): 183 | continue 184 | 185 | arg_origin = get_origin(arg_type) 186 | check_type = arg_origin if arg_origin is not None else arg_type 187 | 188 | # TypedDict cannot be used with isinstance - check for dict instead 189 | if is_typeddict(arg_type): 190 | check_type = dict 191 | 192 | if isinstance(value, check_type): 193 | type_matched = True 194 | break 195 | 196 | if not type_matched: 197 | raise JsonRpcException(-32602, f"Invalid params: {param_name} union does not contain {type(value).__name__}") 198 | validated_params[param_name] = value 199 | continue 200 | 201 | # Handle generic types (list[X], dict[K,V]) 202 | if origin is not None: 203 | if not isinstance(value, origin): 204 | raise JsonRpcException( 205 | -32602, 206 | f"Invalid params: {param_name} expected {origin.__name__}, got {type(value).__name__}" 207 | ) 208 | validated_params[param_name] = value 209 | continue 210 | 211 | # Handle TypedDict (must check before basic types) 212 | if is_typeddict(expected_type): 213 | if not isinstance(value, dict): 214 | raise JsonRpcException( 215 | -32602, 216 | f"Invalid params: {param_name} expected dict, got {type(value).__name__}" 217 | ) 218 | validated_params[param_name] = value 219 | continue 220 | 221 | # Handle Any 222 | if expected_type is Any: 223 | validated_params[param_name] = value 224 | continue 225 | 226 | # Handle basic types 227 | if isinstance(expected_type, type): 228 | # Allow int -> float 
conversion 229 | if expected_type is float and isinstance(value, int): 230 | validated_params[param_name] = float(value) 231 | continue 232 | if not isinstance(value, expected_type): 233 | raise JsonRpcException( 234 | -32602, 235 | f"Invalid params: {param_name} expected {expected_type.__name__}, got {type(value).__name__}" 236 | ) 237 | validated_params[param_name] = value 238 | continue 239 | 240 | return func(**validated_params) 241 | 242 | else: 243 | raise JsonRpcException(-32602, "Invalid params: must be array or object") 244 | 245 | def _error(self, request_id: JsonRpcId, code: int, message: str, data: Any = None) -> JsonRpcResponse | None: 246 | error: JsonRpcError = { 247 | "code": code, 248 | "message": message, 249 | } 250 | if data is not None: 251 | error["data"] = data 252 | return { 253 | "jsonrpc": "2.0", 254 | "error": error, 255 | "id": request_id, 256 | } 257 | -------------------------------------------------------------------------------- /tests/mcp_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import socket 4 | import asyncio 5 | import subprocess 6 | 7 | from pydantic import AnyUrl 8 | from mcp import ClientSession, StdioServerParameters, McpError, types 9 | from mcp.client.stdio import stdio_client 10 | from mcp.client.sse import sse_client 11 | from mcp.client.streamable_http import streamablehttp_client 12 | 13 | example_mcp = os.path.join(os.path.dirname(__file__), "..", "examples", "mcp_example.py") 14 | assert os.path.exists(example_mcp), f"not found: {example_mcp}" 15 | 16 | async def test_example_server(prefix: str, session: ClientSession): 17 | # Initialize the connection 18 | await session.initialize() 19 | 20 | # Test ping 21 | ping_result = await session.send_ping() 22 | print(f"[{prefix}] Ping result: {ping_result}") 23 | assert isinstance(ping_result, types.EmptyResult), "ping should return EmptyResult" 24 | 25 | # List available resources 26 | 
resources = await session.list_resources() 27 | print(f"[{prefix}] Available resources: {[r.uri for r in resources.resources]}") 28 | assert len(resources.resources) == 2, "expected 2 static resources" 29 | assert str(resources.resources[0].uri) == "example://system_info", "expected system_info resource" 30 | 31 | # List available resource templates 32 | template_resources = await session.list_resource_templates() 33 | print(f"[{prefix}] Available resource templates: {[r.uriTemplate for r in template_resources.resourceTemplates]}") 34 | assert len(template_resources.resourceTemplates) == 1, "expected 1 resource template" 35 | assert template_resources.resourceTemplates[0].uriTemplate == "example://greeting/{name}", "expected greeting template" 36 | 37 | # List available tools 38 | tools = await session.list_tools() 39 | print(f"[{prefix}] Available tools: {[t.name for t in tools.tools]}") 40 | tool_names = {t.name for t in tools.tools} 41 | assert tool_names == {"divide", "greet", "random_dict", "get_system_info", "failing_tool", "struct_get"}, f"unexpected tools: {tool_names}" 42 | 43 | # List available prompts 44 | prompts = await session.list_prompts() 45 | print(f"[{prefix}] Available prompts: {[p.name for p in prompts.prompts]}") 46 | prompt_names = {p.name for p in prompts.prompts} 47 | assert prompt_names == {"code_review", "summarize"}, ( 48 | f"unexpected prompts: {prompt_names}" 49 | ) 50 | 51 | # Get prompt with required argument only 52 | result = await session.get_prompt("summarize", arguments={"text": "Hello world"}) 53 | assert result.messages, "expected messages" 54 | assert result.messages[0].role == "user" 55 | content = result.messages[0].content 56 | assert isinstance(content, types.TextContent), "expected TextContent" 57 | print(f"[{prefix}] Summarize prompt result: {content.text[:50]}...") 58 | assert "Hello world" in content.text, "expected text in prompt" 59 | 60 | # Get prompt with optional argument 61 | result = await session.get_prompt( 
62 | "code_review", arguments={"code": "print('hi')", "language": "javascript"} 63 | ) 64 | assert result.messages, "expected messages" 65 | content = result.messages[0].content 66 | assert isinstance(content, types.TextContent), "expected TextContent" 67 | print(f"[{prefix}] Code review prompt result: {content.text[:50]}...") 68 | assert "javascript" in content.text, "expected language in prompt" 69 | assert "print('hi')" in content.text, "expected code in prompt" 70 | 71 | # Read a resource - assert content 72 | resource_content = await session.read_resource(AnyUrl("example://system_info")) 73 | content_block = resource_content.contents[0] 74 | assert isinstance(content_block, types.TextResourceContents), "expected TextResourceContents" 75 | print(f"[{prefix}] Resource content: {content_block.text}") 76 | assert "platform" in content_block.text, "expected platform in system_info" 77 | assert "python_version" in content_block.text, "expected python_version in system_info" 78 | 79 | # Read template resource - assert content 80 | template_content = await session.read_resource(AnyUrl("example://greeting/Pythonista")) 81 | template_block = template_content.contents[0] 82 | assert isinstance(template_block, types.TextResourceContents), "expected TextResourceContents" 83 | print(f"[{prefix}] Template resource content: {template_block.text}") 84 | assert "Pythonista" in template_block.text, "expected name in greeting" 85 | assert "message" in template_block.text, "expected message field in greeting" 86 | 87 | # Call divide tool 88 | result = await session.call_tool("divide", arguments={"numerator": 42, "denominator": 2}) 89 | assert not result.isError, "divide should succeed" 90 | result_unstructured = result.content[0] 91 | assert isinstance(result_unstructured, types.TextContent), "expected TextContent" 92 | print(f"[{prefix}] Divide result: {result_unstructured.text}") 93 | assert "21" in result_unstructured.text, "42/2 should be 21" 94 | 95 | # Call greet tool 
without age 96 | result = await session.call_tool("greet", arguments={"name": "Alice"}) 97 | assert not result.isError, "greet should succeed" 98 | assert isinstance(result.content[0], types.TextContent), "expected TextContent" 99 | content = result.content[0].text 100 | print(f"[{prefix}] Greet result: {content}") 101 | assert "Alice" in content, "expected name in greeting" 102 | assert "message" in content, "expected message field" 103 | assert "age" not in content or content.count("age") == 1, "age should not have value" 104 | 105 | # Call greet tool with age 106 | result = await session.call_tool("greet", arguments={"name": "Bob", "age": 30}) 107 | assert not result.isError, "greet with age should succeed" 108 | assert isinstance(result.content[0], types.TextContent), "expected TextContent" 109 | content = result.content[0].text 110 | print(f"[{prefix}] Greet with age result: {content}") 111 | assert "Bob" in content and "30" in content, "expected name and age" 112 | 113 | # Call get_system_info tool 114 | result = await session.call_tool("get_system_info", arguments={}) 115 | assert not result.isError, "get_system_info should succeed" 116 | assert isinstance(result.content[0], types.TextContent), "expected TextContent" 117 | content = result.content[0].text 118 | print(f"[{prefix}] System info result: {content}") 119 | assert "platform" in content, "expected platform" 120 | assert "python_version" in content, "expected python_version" 121 | assert "timestamp" in content, "expected timestamp" 122 | 123 | # Call struct_get with list 124 | result = await session.call_tool("struct_get", arguments={"names": ["foo", "bar"]}) 125 | assert not result.isError, "struct_get with list should succeed" 126 | assert isinstance(result.content[0], types.TextContent), "expected TextContent" 127 | content = result.content[0].text 128 | print(f"[{prefix}] Struct_get (list) result: {content}") 129 | assert "foo" in content and "bar" in content, "expected both struct names" 130 | 
131 | # Call struct_get with string 132 | result = await session.call_tool("struct_get", arguments={"names": "baz"}) 133 | assert not result.isError, "struct_get with string should succeed" 134 | assert isinstance(result.content[0], types.TextContent), "expected TextContent" 135 | content = result.content[0].text 136 | print(f"[{prefix}] Struct_get (string) result: {content}") 137 | assert "baz" in content, "expected struct name" 138 | 139 | # Call failing tool 140 | result = await session.call_tool("failing_tool", arguments={"message": "This is a test error"}) 141 | print(f"[{prefix}] Failing tool result: {result}") 142 | assert result.isError, "expected tool call to fail" 143 | assert isinstance(result.content[0], types.TextContent), "expected text content in tool call result" 144 | assert "test error" in result.content[0].text, "expected error message in tool call result" 145 | 146 | # Call random_dict tool 147 | result = await session.call_tool("random_dict", arguments={"param": {"x": 112, "other": "yes"}}) 148 | assert not result.isError, "random_dict should succeed" 149 | assert isinstance(result.content[0], types.TextContent), "expected TextContent" 150 | content = result.content[0].text 151 | print(f"[{prefix}] Random dict result: {content}") 152 | 153 | # Call random_dict tool with null 154 | result = await session.call_tool("random_dict", arguments={"param": None}) 155 | assert not result.isError, "random_dict with null should succeed" 156 | 157 | async def test_edge_cases(prefix: str, session: ClientSession): 158 | """Test edge cases and error conditions""" 159 | await session.initialize() 160 | 161 | # Test non-existent tool 162 | result = await session.call_tool("nonexistent_tool", arguments={}) 163 | assert result.isError, "should error on non-existent tool" 164 | print(f"[{prefix}] Non-existent tool error: {result.content[0] if result.content else 'no content'}") 165 | 166 | # Test missing required parameter 167 | result = await 
session.call_tool("divide", arguments={"numerator": 42}) 168 | assert result.isError, "should error on missing denominator" 169 | print(f"[{prefix}] Missing param error: {result.content[0] if result.content else 'no content'}") 170 | 171 | # Test division by zero (natural exception) 172 | result = await session.call_tool("divide", arguments={"numerator": 1, "denominator": 0}) 173 | assert result.isError, "division by zero should error" 174 | print(f"[{prefix}] Division by zero error: {result.content[0] if result.content else 'no content'}") 175 | 176 | # Test invalid resource URI 177 | try: 178 | await session.read_resource(AnyUrl("example://invalid_resource")) 179 | assert False, "should have raised on invalid resource" 180 | except McpError as e: 181 | assert "Resource not found" in e.error.message, "expected invalid resource error" 182 | 183 | # Test resource template with missing substitution 184 | try: 185 | await session.read_resource(AnyUrl("example://greeting/")) 186 | assert False, "should have raised on missing template parameter" 187 | except McpError as e: 188 | assert "Resource not found" in e.error.message, "expected missing substitution error" 189 | 190 | # Test resource that raises an error 191 | try: 192 | await session.read_resource(AnyUrl("example://error")) 193 | assert False, "should have raised on error resource" 194 | except McpError as e: 195 | assert "This is a resource error for testing purposes." 
in e.error.message, ( 196 | "expected resource error message" 197 | ) 198 | 199 | # Test non-existent prompt 200 | try: 201 | await session.get_prompt("nonexistent_prompt", arguments={}) 202 | assert False, "should have raised on non-existent prompt" 203 | except McpError as e: 204 | assert "Method 'nonexistent_prompt' not found" in e.error.message, ( 205 | "expected method not found error" 206 | ) 207 | 208 | print(f"[{prefix}] Edge case tests passed!") 209 | 210 | 211 | def coverage_wrap(name: str, args: list[str]) -> list[str]: 212 | if os.environ.get("COVERAGE_RUN"): 213 | args = ["-m", "coverage", "run", f"--data-file=.coverage.{name}"] + args 214 | return args 215 | 216 | async def test_stdio(): 217 | print("[stdio] Testing...") 218 | server_params = StdioServerParameters( 219 | command=sys.executable, 220 | args=coverage_wrap("stdio", [example_mcp, "--transport", "stdio"]), 221 | ) 222 | async with stdio_client(server_params) as (read, write): 223 | async with ClientSession(read, write) as session: 224 | await test_example_server("stdio", session) 225 | await test_edge_cases("stdio", session) 226 | 227 | async def test_sse(address: str): 228 | print("[sse] Testing...") 229 | async with sse_client(f"{address}/sse") as (read, write): 230 | async with ClientSession(read, write) as session: 231 | await test_example_server("sse", session) 232 | await test_edge_cases("sse", session) 233 | 234 | async def test_streamablehttp(address: str): 235 | print("[streamable] Testing...") 236 | async with streamablehttp_client(f"{address}/mcp") as (read, write, session_callback): 237 | async with ClientSession(read, write) as session: 238 | await test_example_server("streamable", session) 239 | await test_edge_cases("streamable", session) 240 | 241 | def find_available_port(): 242 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 243 | sock.bind(("", 0)) 244 | port = sock.getsockname()[1] 245 | sock.close() 246 | return port 247 | 248 | async def test_serve(): 249 | 
print("[serve] Testing...") 250 | 251 | # Start example MCP server as subprocess 252 | address = f"http://127.0.0.1:{find_available_port()}" 253 | process = subprocess.Popen( 254 | [sys.executable] + coverage_wrap("serve", [example_mcp, "--transport", address]), 255 | stdin=subprocess.PIPE, 256 | text=True, 257 | encoding="utf-8", 258 | bufsize=1, 259 | ) 260 | try: 261 | await asyncio.sleep(0.5) # Wait for server to start 262 | await test_sse(address) 263 | await test_streamablehttp(address) 264 | finally: 265 | print("[serve] Terminating example MCP server") 266 | process.stdin.close() # type: ignore 267 | process.wait() 268 | pass 269 | 270 | async def main(): 271 | await test_serve() 272 | await test_stdio() 273 | 274 | if __name__ == "__main__": 275 | import os 276 | asyncio.run(main()) 277 | -------------------------------------------------------------------------------- /tests/jsonrpc_test.py: -------------------------------------------------------------------------------- 1 | """ 2 | Comprehensive JSON-RPC 2.0 test suite for MCP implementation 3 | """ 4 | import json 5 | import sys 6 | import traceback 7 | import re 8 | from typing import Optional, Any, TypedDict 9 | 10 | from zeromcp.jsonrpc import JsonRpcRegistry 11 | 12 | # Create registry and register test methods 13 | jsonrpc = JsonRpcRegistry() 14 | 15 | class Point(TypedDict): 16 | x: int 17 | y: int 18 | 19 | @jsonrpc.method 20 | def subtract(minuend: int, subtrahend: int) -> int: 21 | return minuend - subtrahend 22 | 23 | @jsonrpc.method 24 | def update(a: int, b: int, c: int, d: int, e: int) -> str: 25 | return "updated" 26 | 27 | @jsonrpc.method 28 | def foobar() -> str: 29 | return "bar" 30 | 31 | @jsonrpc.method 32 | def get_data() -> list: 33 | return ["hello", 5] 34 | 35 | @jsonrpc.method 36 | def greet(name: str, greeting: str = "Hello") -> str: 37 | return f"{greeting}, {name}!" 
38 | 39 | @jsonrpc.method 40 | def process_optional(value: Optional[int]) -> str: 41 | return f"Got: {value}" 42 | 43 | @jsonrpc.method 44 | def union_test(id: int | str | None | Point) -> str: 45 | return f"ID: {id or ''}" 46 | 47 | @jsonrpc.method 48 | def list_test(items: list[str]) -> int: 49 | return len(items) 50 | 51 | @jsonrpc.method 52 | def exception(): 53 | raise Exception("Python exception") 54 | 55 | @jsonrpc.method 56 | def point_pretty(p: Point) -> str: 57 | return f"Point(x={p['x']}, y={p['y']})" 58 | 59 | @jsonrpc.method 60 | def round_float(value: float) -> int: 61 | return round(value) 62 | 63 | @jsonrpc.method 64 | def python_repr(value: Any) -> str: 65 | return repr(value) 66 | 67 | @jsonrpc.method 68 | def unknown(x, y): 69 | return x + y 70 | 71 | def matches_response(actual: dict | None, expected: dict | None) -> bool: 72 | """Check if actual response matches expected, with regex support for error messages.""" 73 | if actual is None and expected is None: 74 | return True 75 | if actual is None or expected is None: 76 | return False 77 | 78 | # Check top-level keys 79 | if set(actual.keys()) != set(expected.keys()): 80 | return False 81 | 82 | for key in expected.keys(): 83 | actual_val = actual[key] 84 | expected_val = expected[key] 85 | 86 | # Handle error object specially for regex matching 87 | if key == "error" and isinstance(expected_val, dict) and isinstance(actual_val, dict): 88 | # Check code exactly 89 | if actual_val.get("code") != expected_val.get("code"): 90 | return False 91 | 92 | # Check message with regex support 93 | expected_msg = expected_val.get("message", "") 94 | actual_msg = actual_val.get("message", "") 95 | if isinstance(expected_msg, str) and expected_msg.startswith("regex:"): 96 | pattern = expected_msg[6:] # Remove "regex:" prefix 97 | if not re.search(pattern, actual_msg): 98 | return False 99 | else: 100 | if actual_msg != expected_msg: 101 | return False 102 | 103 | # For data field, support regex patterns 104 
| if "data" in expected_val: 105 | expected_data = expected_val["data"] 106 | actual_data = actual_val.get("data", "") 107 | 108 | # If expected_data starts with "regex:", treat it as a regex pattern 109 | if isinstance(expected_data, str) and expected_data.startswith("regex:"): 110 | pattern = expected_data[6:] # Remove "regex:" prefix 111 | if not re.search(pattern, str(actual_data)): 112 | return False 113 | else: 114 | # Exact match 115 | if actual_data != expected_data: 116 | return False 117 | elif "data" in actual_val: 118 | # Actual has data but expected doesn't - that's ok 119 | pass 120 | else: 121 | # Exact match for other fields 122 | if actual_val != expected_val: 123 | return False 124 | 125 | return True 126 | 127 | def test_rpc(request: Any, expected_response: dict | None = None, description: str = ""): 128 | """Helper to test RPC calls""" 129 | print(f"\n{'='*60}") 130 | print(f"Test: {description}") 131 | print(f"--> {request}") 132 | 133 | try: 134 | result = jsonrpc.dispatch(request) 135 | except Exception: 136 | print("\n❌ UNEXPECTED EXCEPTION:") 137 | traceback.print_exc() 138 | sys.exit(1) 139 | 140 | if result is None: 141 | print("<-- (no response - notification)") 142 | if expected_response is not None: 143 | print("\n❌ FAIL: Expected response but got None") 144 | print("Expected: {json.dumps(expected_response, indent=2)}") 145 | sys.exit(1) 146 | else: 147 | result_json = json.dumps(result, indent=2) 148 | print(f"<-- {result_json}") 149 | 150 | if expected_response is not None: 151 | if not matches_response(result, expected_response): # type: ignore 152 | print("\n❌ FAIL: Response mismatch") 153 | print(f"Expected: {json.dumps(expected_response, indent=2)}") 154 | print(f"Got: {result_json}") 155 | sys.exit(1) 156 | 157 | print("✓ PASS") 158 | return result 159 | 160 | 161 | def run_all_tests(): 162 | print("="*60) 163 | print("JSON-RPC 2.0 COMPLIANCE TESTS") 164 | print("="*60) 165 | 166 | # ======================================== 167 
| # SPEC EXAMPLES 168 | # ======================================== 169 | 170 | # Positional parameters 171 | test_rpc( 172 | {"jsonrpc": "2.0", "method": "subtract", "params": [42, 23], "id": 1}, 173 | {"jsonrpc": "2.0", "result": 19, "id": 1}, 174 | "Positional params - subtract(42, 23)" 175 | ) 176 | 177 | test_rpc( 178 | '{"jsonrpc": "2.0", "method": "subtract", "params": [23, 42], "id": 2}', 179 | {"jsonrpc": "2.0", "result": -19, "id": 2}, 180 | "Positional params - subtract(23, 42)" 181 | ) 182 | 183 | # Named parameters 184 | test_rpc( 185 | '{"jsonrpc": "2.0", "method": "subtract", "params": {"subtrahend": 23, "minuend": 42}, "id": 3}', 186 | {"jsonrpc": "2.0", "result": 19, "id": 3}, 187 | "Named params - order independent (1)" 188 | ) 189 | 190 | test_rpc( 191 | '{"jsonrpc": "2.0", "method": "subtract", "params": {"minuend": 42, "subtrahend": 23}, "id": 4}', 192 | {"jsonrpc": "2.0", "result": 19, "id": 4}, 193 | "Named params - order independent (2)" 194 | ) 195 | 196 | # Notifications (no response) 197 | test_rpc( 198 | '{"jsonrpc": "2.0", "method": "update", "params": [1,2,3,4,5]}', 199 | None, 200 | "Notification - update (no id)" 201 | ) 202 | 203 | test_rpc( 204 | '{"jsonrpc": "2.0", "method": "foobar"}', 205 | None, 206 | "Notification - foobar (no params, no id)" 207 | ) 208 | 209 | # Non-existent method 210 | test_rpc( 211 | '{"jsonrpc": "2.0", "method": "does_not_exist", "id": "1"}', 212 | {"jsonrpc": "2.0", "error": {"code": -32601, "message": "regex:Method.*not found"}, "id": "1"}, 213 | "Non-existent method error" 214 | ) 215 | 216 | # Invalid JSON - use regex to match error since different parsers give different messages 217 | test_rpc( 218 | '{"jsonrpc": "2.0", "method": "foobar, "params": "bar", "baz]', 219 | {"jsonrpc": "2.0", "error": {"code": -32700, "message": "JSON parse error", "data": "regex:Expecting"}, "id": None}, 220 | "Parse error - invalid JSON" 221 | ) 222 | 223 | test_rpc( 224 | 1234, 225 | {"jsonrpc": "2.0", "error": 
{"code": -32700, "message": "JSON parse error", "data": "regex:object must be"}, "id": None}, 226 | "Parse error - invalid JSON" 227 | ) 228 | 229 | # Invalid Request object - method is not a string 230 | test_rpc( 231 | '{"jsonrpc": "2.0", "method": 1, "params": "bar"}', 232 | {"jsonrpc": "2.0", "error": {"code": -32600, "message": "Invalid request: 'method' must be a string"}, "id": None}, 233 | "Invalid Request - method is number" 234 | ) 235 | 236 | # Missing jsonrpc version 237 | test_rpc( 238 | '{"method": "subtract", "params": [1, 2], "id": 1}', 239 | {"jsonrpc": "2.0", "error": {"code": -32600, "message": "Invalid request: 'jsonrpc' must be '2.0'"}, "id": None}, 240 | "Invalid Request - missing jsonrpc field" 241 | ) 242 | 243 | # Wrong jsonrpc version 244 | test_rpc( 245 | '{"jsonrpc": "1.0", "method": "subtract", "params": [1, 2], "id": 1}', 246 | {"jsonrpc": "2.0", "error": {"code": -32600, "message": "Invalid request: 'jsonrpc' must be '2.0'"}, "id": None}, 247 | "Invalid Request - wrong jsonrpc version" 248 | ) 249 | 250 | # Missing method 251 | test_rpc( 252 | '{"jsonrpc": "2.0", "params": [1, 2], "id": 1}', 253 | {"jsonrpc": "2.0", "error": {"code": -32600, "message": "Invalid request: 'method' is required"}, "id": None}, 254 | "Invalid Request - missing method" 255 | ) 256 | 257 | # Empty array (not valid single request) 258 | test_rpc( 259 | '[]', 260 | {"jsonrpc": "2.0", "error": {"code": -32600, "message": "Invalid request: must be a JSON object"}, "id": None}, 261 | "Invalid Request - empty array" 262 | ) 263 | 264 | # Non-object request 265 | test_rpc( 266 | '"not an object"', 267 | {"jsonrpc": "2.0", "error": {"code": -32600, "message": "Invalid request: must be a JSON object"}, "id": None}, 268 | "Invalid Request - string instead of object" 269 | ) 270 | 271 | test_rpc( 272 | '123', 273 | {"jsonrpc": "2.0", "error": {"code": -32600, "message": "Invalid request: must be a JSON object"}, "id": None}, 274 | "Invalid Request - number instead of 
object" 275 | ) 276 | 277 | # Request with id: null (valid request, not a notification) 278 | test_rpc( 279 | '{"jsonrpc": "2.0", "method": "foobar", "id": null}', 280 | {"jsonrpc": "2.0", "result": "bar", "id": None}, 281 | "Valid request with id: null" 282 | ) 283 | 284 | # ======================================== 285 | # PARAMETER VALIDATION TESTS 286 | # ======================================== 287 | 288 | # Wrong number of positional params - too few 289 | test_rpc( 290 | '{"jsonrpc": "2.0", "method": "subtract", "params": [42], "id": 1}', 291 | {"jsonrpc": "2.0", "error": {"code": -32602, "message": "Invalid params: expected at least 2 arguments, got 1"}, "id": 1}, 292 | "Invalid params - too few positional arguments" 293 | ) 294 | 295 | # Wrong number of positional params - too many 296 | test_rpc( 297 | '{"jsonrpc": "2.0", "method": "subtract", "params": [42, 23, 10], "id": 1}', 298 | {"jsonrpc": "2.0", "error": {"code": -32602, "message": "Invalid params: expected at most 2 arguments, got 3"}, "id": 1}, 299 | "Invalid params - too many positional arguments" 300 | ) 301 | 302 | # Missing required named param 303 | test_rpc( 304 | '{"jsonrpc": "2.0", "method": "subtract", "params": {"minuend": 42}, "id": 1}', 305 | {"jsonrpc": "2.0", "error": {"code": -32602, "message": "Invalid params: missing required parameters: ['subtrahend']"}, "id": 1}, 306 | "Invalid params - missing required parameter" 307 | ) 308 | 309 | # Extra named param 310 | test_rpc( 311 | '{"jsonrpc": "2.0", "method": "subtract", "params": {"minuend": 42, "subtrahend": 23, "extra": 1}, "id": 1}', 312 | {"jsonrpc": "2.0", "error": {"code": -32602, "message": "Invalid params: unexpected parameters: ['extra']"}, "id": 1}, 313 | "Invalid params - unexpected parameter" 314 | ) 315 | 316 | # Wrong type - string instead of int 317 | test_rpc( 318 | '{"jsonrpc": "2.0", "method": "subtract", "params": [42, "not a number"], "id": 1}', 319 | {"jsonrpc": "2.0", "error": {"code": -32602, "message": 
"Invalid params: subtrahend expected int, got str"}, "id": 1}, 320 | "Invalid params - wrong type (str instead of int)" 321 | ) 322 | 323 | # Wrong type - list instead of int 324 | test_rpc( 325 | '{"jsonrpc": "2.0", "method": "subtract", "params": {"minuend": 42, "subtrahend": [1, 2]}, "id": 1}', 326 | {"jsonrpc": "2.0", "error": {"code": -32602, "message": "Invalid params: subtrahend expected int, got list"}, "id": 1}, 327 | "Invalid params - wrong type (list instead of int)" 328 | ) 329 | 330 | # Null for non-optional param 331 | test_rpc( 332 | '{"jsonrpc": "2.0", "method": "subtract", "params": [42, null], "id": 1}', 333 | {"jsonrpc": "2.0", "error": {"code": -32602, "message": "Invalid params: subtrahend cannot be null"}, "id": 1}, 334 | "Invalid params - null for non-optional parameter" 335 | ) 336 | 337 | # Params is invalid type (not array or object) 338 | test_rpc( 339 | '{"jsonrpc": "2.0", "method": "subtract", "params": "invalid", "id": 1}', 340 | {"jsonrpc": "2.0", "error": {"code": -32602, "message": "Invalid params: must be array or object"}, "id": 1}, 341 | "Invalid params - string instead of array/object" 342 | ) 343 | 344 | test_rpc( 345 | '{"jsonrpc": "2.0", "method": "subtract", "params": 123, "id": 1}', 346 | {"jsonrpc": "2.0", "error": {"code": -32602, "message": "Invalid params: must be array or object"}, "id": 1}, 347 | "Invalid params - number instead of array/object" 348 | ) 349 | 350 | test_rpc( 351 | '{"jsonrpc": "2.0", "method": "subtract", "params": null, "id": 1}', 352 | {"jsonrpc": "2.0", "error": {"code": -32602, "message": "Missing required params"}, "id": 1}, 353 | "Invalid params - null for required params" 354 | ) 355 | 356 | # ======================================== 357 | # DEFAULT PARAMETERS TESTS 358 | # ======================================== 359 | 360 | # Function with default param - omit optional param (positional) 361 | test_rpc( 362 | '{"jsonrpc": "2.0", "method": "greet", "params": ["Alice"], "id": 1}', 363 | 
{"jsonrpc": "2.0", "result": "Hello, Alice!", "id": 1}, 364 | "Default param - omit optional (positional)" 365 | ) 366 | 367 | # Function with default param - provide optional param (positional) 368 | test_rpc( 369 | '{"jsonrpc": "2.0", "method": "greet", "params": ["Alice", "Hi"], "id": 1}', 370 | {"jsonrpc": "2.0", "result": "Hi, Alice!", "id": 1}, 371 | "Default param - provide optional (positional)" 372 | ) 373 | 374 | # Function with default param - omit optional param (named) 375 | test_rpc( 376 | '{"jsonrpc": "2.0", "method": "greet", "params": {"name": "Bob"}, "id": 1}', 377 | {"jsonrpc": "2.0", "result": "Hello, Bob!", "id": 1}, 378 | "Default param - omit optional (named)" 379 | ) 380 | 381 | # Function with default param - provide optional param (named) 382 | test_rpc( 383 | '{"jsonrpc": "2.0", "method": "greet", "params": {"name": "Bob", "greeting": "Hey"}, "id": 1}', 384 | {"jsonrpc": "2.0", "result": "Hey, Bob!", "id": 1}, 385 | "Default param - provide optional (named)" 386 | ) 387 | 388 | # Function with no params - omit params field 389 | test_rpc( 390 | '{"jsonrpc": "2.0", "method": "get_data", "id": 1}', 391 | {"jsonrpc": "2.0", "result": ["hello", 5], "id": 1}, 392 | "No params function - params field omitted" 393 | ) 394 | 395 | # Function with no params - empty array 396 | test_rpc( 397 | '{"jsonrpc": "2.0", "method": "get_data", "params": [], "id": 1}', 398 | {"jsonrpc": "2.0", "result": ["hello", 5], "id": 1}, 399 | "No params function - empty array" 400 | ) 401 | 402 | # Function with no params - empty object 403 | test_rpc( 404 | '{"jsonrpc": "2.0", "method": "get_data", "params": {}, "id": 1}', 405 | {"jsonrpc": "2.0", "result": ["hello", 5], "id": 1}, 406 | "No params function - empty object" 407 | ) 408 | 409 | # ======================================== 410 | # UNION TYPE TESTS 411 | # ======================================== 412 | 413 | # Union type - int 414 | test_rpc( 415 | '{"jsonrpc": "2.0", "method": "union_test", "params": 
[123], "id": 1}', 416 | {"jsonrpc": "2.0", "result": "ID: 123", "id": 1}, 417 | "Union type (int | str) - int value" 418 | ) 419 | 420 | # Union type - str 421 | test_rpc( 422 | '{"jsonrpc": "2.0", "method": "union_test", "params": ["abc"], "id": 1}', 423 | {"jsonrpc": "2.0", "result": "ID: abc", "id": 1}, 424 | "Union type (int | str) - str value" 425 | ) 426 | 427 | # Union type - invalid type 428 | test_rpc( 429 | '{"jsonrpc": "2.0", "method": "union_test", "params": [[1, 2, 3]], "id": 1}', 430 | {"jsonrpc": "2.0", "error": {"code": -32602, "message": "Invalid params: id union does not contain list"}, "id": 1}, 431 | "Union type (int | str) - invalid list" 432 | ) 433 | 434 | # Union type - null 435 | test_rpc( 436 | '{"jsonrpc": "2.0", "method": "union_test", "params": [null], "id": 1}', 437 | {"jsonrpc": "2.0", "result": "ID: ", "id": 1}, 438 | "Union type (int | str | None) - null value" 439 | ) 440 | 441 | # ======================================== 442 | # OPTIONAL TYPE TESTS 443 | # ======================================== 444 | 445 | # Optional - provide value 446 | test_rpc( 447 | '{"jsonrpc": "2.0", "method": "process_optional", "params": [42], "id": 1}', 448 | {"jsonrpc": "2.0", "result": "Got: 42", "id": 1}, 449 | "Optional type - provide value" 450 | ) 451 | 452 | # Optional - provide null 453 | test_rpc( 454 | '{"jsonrpc": "2.0", "method": "process_optional", "params": [null], "id": 1}', 455 | {"jsonrpc": "2.0", "result": "Got: None", "id": 1}, 456 | "Optional type - provide null" 457 | ) 458 | 459 | # ======================================== 460 | # GENERIC TYPE TESTS 461 | # ======================================== 462 | 463 | # list[T] - valid list 464 | test_rpc( 465 | '{"jsonrpc": "2.0", "method": "list_test", "params": [["a", "b", "c"]], "id": 1}', 466 | {"jsonrpc": "2.0", "result": 3, "id": 1}, 467 | "Generic type list[str] - valid list (no inner validation)" 468 | ) 469 | 470 | # list[T] - wrong outer type 471 | test_rpc( 472 | '{"jsonrpc": 
"2.0", "method": "list_test", "params": ["not a list"], "id": 1}', 473 | {"jsonrpc": "2.0", "error": {"code": -32602, "message": "Invalid params: items expected list, got str"}, "id": 1}, 474 | "Generic type list[str] - wrong outer type" 475 | ) 476 | 477 | # Point TypedDict - valid dict 478 | test_rpc( 479 | '{"jsonrpc": "2.0", "method": "point_pretty", "params": [{"x": 10, "y": 20}], "id": 1}', 480 | {"jsonrpc": "2.0", "result": "Point(x=10, y=20)", "id": 1}, 481 | "TypedDict Point - valid dict" 482 | ) 483 | 484 | # Point TypedDict - wrong outer type 485 | test_rpc( 486 | '{"jsonrpc": "2.0", "method": "point_pretty", "params": ["not a dict"], "id": 1}', 487 | {"jsonrpc": "2.0", "error": {"code": -32602, "message": "Invalid params: p expected dict, got str"}, "id": 1}, 488 | "TypedDict Point - wrong outer type" 489 | ) 490 | 491 | # Convert from int to float 492 | test_rpc( 493 | '{"jsonrpc": "2.0", "method": "round_float", "params": [3], "id": 1}', 494 | {"jsonrpc": "2.0", "result": 3, "id": 1}, 495 | "Convert int to float for float parameter" 496 | ) 497 | 498 | # Any type - various inputs 499 | test_rpc( 500 | '{"jsonrpc": "2.0", "method": "python_repr", "params": [42], "id": 1}', 501 | {"jsonrpc": "2.0", "result": "42", "id": 1}, 502 | "Any type - int value" 503 | ) 504 | 505 | test_rpc( 506 | '{"jsonrpc": "2.0", "method": "python_repr", "params": ["hello"], "id": 1}', 507 | {"jsonrpc": "2.0", "result": "'hello'", "id": 1}, 508 | "Any type - str value" 509 | ) 510 | 511 | # Unspecified types (unknown) - should accept anything 512 | test_rpc( 513 | '{"jsonrpc": "2.0", "method": "unknown", "params": [10, 20], "id": 1}', 514 | {"jsonrpc": "2.0", "result": 30, "id": 1}, 515 | "Unknown parameter types - accept any" 516 | ) 517 | 518 | # ======================================== 519 | # NOTIFICATION ERROR HANDLING 520 | # ======================================== 521 | 522 | # Notification with error (should return None, no response) 523 | test_rpc( 524 | 
'{"jsonrpc": "2.0", "method": "does_not_exist"}', 525 | None, 526 | "Notification - error does not produce response" 527 | ) 528 | 529 | # Notification with invalid params (should return None, no response) 530 | test_rpc( 531 | '{"jsonrpc": "2.0", "method": "subtract", "params": [1]}', 532 | None, 533 | "Notification - invalid params does not produce response" 534 | ) 535 | 536 | test_rpc( 537 | '{"jsonrpc": "2.0", "method": "exception", "id": 1}', 538 | {"jsonrpc": "2.0", "error": {"code": -32603, "message": "regex:Python exception"}, "id": 1}, 539 | "Method that raises python exception" 540 | ) 541 | 542 | test_rpc( 543 | '{"jsonrpc": "2.0", "method": "exception"}', 544 | None, 545 | "Notification - method that raises python exception" 546 | ) 547 | 548 | print("\n" + "="*60) 549 | print("ALL TESTS PASSED! ✓") 550 | print("="*60) 551 | 552 | if __name__ == "__main__": 553 | run_all_tests() 554 | -------------------------------------------------------------------------------- /src/zeromcp/mcp.py: -------------------------------------------------------------------------------- 1 | import re 2 | import sys 3 | import time 4 | import uuid 5 | import json 6 | import inspect 7 | import threading 8 | import traceback 9 | from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer, HTTPServer 10 | from typing import Any, Callable, Union, Annotated, BinaryIO, NotRequired, get_origin, get_args, get_type_hints, is_typeddict 11 | from types import UnionType 12 | from urllib.parse import urlparse, parse_qs 13 | from io import BufferedIOBase 14 | 15 | from .jsonrpc import JsonRpcRegistry, JsonRpcError, JsonRpcException 16 | 17 | class McpToolError(Exception): 18 | def __init__(self, message: str): 19 | super().__init__(message) 20 | 21 | class McpRpcRegistry(JsonRpcRegistry): 22 | """JSON-RPC registry with custom error handling for MCP tools""" 23 | def map_exception(self, e: Exception) -> JsonRpcError: 24 | if isinstance(e, McpToolError): 25 | return { 26 | "code": 
-32000, 27 | "message": e.args[0] or "MCP Tool Error", 28 | } 29 | return super().map_exception(e) 30 | 31 | class _McpSseConnection: 32 | """Manages a single SSE client connection""" 33 | def __init__(self, wfile): 34 | self.wfile: BufferedIOBase = wfile 35 | self.session_id = str(uuid.uuid4()) 36 | self.alive = True 37 | 38 | def send_event(self, event_type: str, data): 39 | """Send an SSE event to the client 40 | 41 | Args: 42 | event_type: Type of event (e.g., "endpoint", "message", "ping") 43 | data: Event data - can be string (sent as-is) or dict (JSON-encoded) 44 | """ 45 | if not self.alive: 46 | return False 47 | 48 | try: 49 | # SSE format: "event: type\ndata: content\n\n" 50 | if isinstance(data, str): 51 | data_str = f"data: {data}\n\n" 52 | else: 53 | data_str = f"data: {json.dumps(data)}\n\n" 54 | message = f"event: {event_type}\n{data_str}".encode("utf-8") 55 | self.wfile.write(message) 56 | self.wfile.flush() # Ensure data is sent immediately 57 | return True 58 | except (BrokenPipeError, OSError): 59 | self.alive = False 60 | return False 61 | 62 | class McpHttpRequestHandler(BaseHTTPRequestHandler): 63 | server_version = "zeromcp/1.3.0" 64 | error_message_format = "%(code)d - %(message)s" 65 | error_content_type = "text/plain" 66 | 67 | def __init__(self, request, client_address, server): 68 | self.mcp_server: "McpServer" = getattr(server, "mcp_server") 69 | super().__init__(request, client_address, server) 70 | 71 | def log_message(self, format, *args): 72 | """Override to suppress default logging or customize""" 73 | pass 74 | 75 | def send_cors_headers(self, *, preflight = False): 76 | origin = self.headers.get("Origin", "") 77 | if not origin: 78 | return 79 | def is_allowed(): 80 | allowed = self.mcp_server.cors_allowed_origins 81 | if allowed is None: 82 | return False 83 | if callable(allowed): 84 | return allowed(origin) 85 | if isinstance(allowed, str): 86 | allowed = [allowed] 87 | assert isinstance(allowed, list) 88 | return "*" in 
allowed or origin in allowed 89 | if not is_allowed(): 90 | return 91 | self.send_header("Access-Control-Allow-Origin", origin) 92 | if preflight: 93 | self.send_header("Access-Control-Allow-Methods", "POST, GET, OPTIONS") 94 | self.send_header("Access-Control-Allow-Headers", "Content-Type, Accept, X-Requested-With, Mcp-Session-Id, Mcp-Protocol-Version") 95 | if self.headers.get("Access-Control-Request-Private-Network") == "true": 96 | self.send_header("Access-Control-Allow-Private-Network", "true") 97 | 98 | def send_error(self, code, message=None, explain=None): 99 | self.send_response(code) 100 | self.send_header("Content-Type", "text/plain") 101 | self.send_cors_headers() 102 | self.end_headers() 103 | self.wfile.write(f"{message}\n".encode("utf-8")) 104 | 105 | def handle(self): 106 | """Override to add error handling for connection errors""" 107 | try: 108 | super().handle() 109 | except (ConnectionAbortedError, ConnectionResetError, BrokenPipeError): 110 | # Client disconnected - normal, suppress traceback 111 | pass 112 | 113 | def do_GET(self): 114 | match urlparse(self.path).path: 115 | case "/sse": 116 | self._handle_sse_get() 117 | case "/mcp": 118 | self.send_error(405, "Method Not Allowed") 119 | case _: 120 | self.send_error(404, "Not Found") 121 | 122 | def do_POST(self): 123 | # Read request body 124 | content_length = int(self.headers.get("Content-Length", 0)) 125 | 126 | if content_length > self.mcp_server.post_body_limit: 127 | self.send_error(413, f"Payload Too Large: exceeds {self.mcp_server.post_body_limit} bytes") 128 | return 129 | 130 | body = self.rfile.read(content_length) if content_length > 0 else b"" 131 | 132 | match urlparse(self.path).path: 133 | case "/sse": 134 | self._handle_sse_post(body) 135 | case "/mcp": 136 | self._handle_mcp_post(body) 137 | case _: 138 | self.send_error(404, "Not Found") 139 | 140 | def do_OPTIONS(self): 141 | """Handle CORS preflight requests""" 142 | self.send_response(200) 143 | 
self.send_cors_headers(preflight=True) 144 | self.end_headers() 145 | 146 | def _handle_sse_get(self): 147 | # Create SSE connection wrapper 148 | conn = _McpSseConnection(self.wfile) 149 | self.mcp_server._sse_connections[conn.session_id] = conn 150 | 151 | try: 152 | # Send SSE headers 153 | self.send_response(200) 154 | self.send_header("Content-Type", "text/event-stream") 155 | self.send_header("Cache-Control", "no-cache") 156 | self.send_header("Connection", "keep-alive") 157 | self.send_cors_headers() 158 | self.end_headers() 159 | 160 | # Send endpoint event with session ID for routing 161 | conn.send_event("endpoint", f"/sse?session={conn.session_id}") 162 | 163 | # Keep connection alive with periodic pings 164 | last_ping = time.time() 165 | while conn.alive and self.mcp_server._running: 166 | now = time.time() 167 | if now - last_ping > 30: # Ping every 30 seconds 168 | if not conn.send_event("ping", {}): 169 | break 170 | last_ping = now 171 | time.sleep(1) 172 | 173 | finally: 174 | conn.alive = False 175 | if conn.session_id in self.mcp_server._sse_connections: 176 | del self.mcp_server._sse_connections[conn.session_id] 177 | 178 | def _handle_sse_post(self, body: bytes): 179 | query_params = parse_qs(urlparse(self.path).query) 180 | session_id = query_params.get("session", [None])[0] 181 | if session_id is None: 182 | self.send_error(400, "Missing ?session for SSE POST") 183 | return 184 | 185 | # Dispatch to MCP registry 186 | setattr(self.mcp_server._protocol_version, "data", "2024-11-05") 187 | response = self.mcp_server.registry.dispatch(body) 188 | 189 | # Send SSE response if necessary 190 | if response is not None: 191 | sse_conn = self.mcp_server._sse_connections.get(session_id) 192 | if sse_conn is None or not sse_conn.alive: 193 | # No SSE connection found 194 | self.send_error(400, f"No active SSE connection found for session {session_id}") 195 | return 196 | 197 | # Send response via SSE event stream 198 | sse_conn.send_event("message", 
response) 199 | 200 | # Return 202 Accepted to acknowledge POST 201 | self.send_response(202) 202 | self.send_header("Content-Type", "application/json") 203 | self.send_header("Content-Length", str(len(body))) 204 | self.send_cors_headers() 205 | self.end_headers() 206 | self.wfile.write(body) 207 | 208 | def _handle_mcp_post(self, body: bytes): 209 | # Dispatch to MCP registry 210 | setattr(self.mcp_server._protocol_version, "data", "2025-06-18") 211 | response = self.mcp_server.registry.dispatch(body) 212 | 213 | def send_response(status: int, body: bytes): 214 | self.send_response(status) 215 | self.send_header("Content-Type", "application/json") 216 | self.send_header("Content-Length", str(len(body))) 217 | self.send_cors_headers() 218 | self.end_headers() 219 | self.wfile.write(body) 220 | 221 | # Check if notification (returns None) 222 | if response is None: 223 | send_response(202, b"Accepted") 224 | else: 225 | send_response(200, json.dumps(response).encode("utf-8")) 226 | 227 | class McpServer: 228 | def __init__(self, name: str, version = "1.0.0"): 229 | self.name = name 230 | self.version = version 231 | self.post_body_limit = 10 * 1024 * 1024 232 | self.cors_allowed_origins: Callable[[str], bool] | list[str] | str | None = self.cors_localhost 233 | self.tools = McpRpcRegistry() 234 | self.resources = McpRpcRegistry() 235 | self.prompts = McpRpcRegistry() 236 | 237 | self._http_server: HTTPServer | None = None 238 | self._server_thread: threading.Thread | None = None 239 | self._running = False 240 | self._sse_connections: dict[str, _McpSseConnection] = {} 241 | self._protocol_version = threading.local() 242 | 243 | # Register MCP protocol methods with correct names 244 | self.registry = JsonRpcRegistry() 245 | self.registry.methods["ping"] = self._mcp_ping 246 | self.registry.methods["initialize"] = self._mcp_initialize 247 | self.registry.methods["tools/list"] = self._mcp_tools_list 248 | self.registry.methods["tools/call"] = self._mcp_tools_call 249 
| self.registry.methods["resources/list"] = self._mcp_resources_list 250 | self.registry.methods["resources/templates/list"] = self._mcp_resource_templates_list 251 | self.registry.methods["resources/read"] = self._mcp_resources_read 252 | self.registry.methods["prompts/list"] = self._mcp_prompts_list 253 | self.registry.methods["prompts/get"] = self._mcp_prompts_get 254 | 255 | def tool(self, func: Callable) -> Callable: 256 | return self.tools.method(func) 257 | 258 | def prompt(self, func: Callable) -> Callable: 259 | return self.prompts.method(func) 260 | 261 | def resource(self, uri: str) -> Callable[[Callable], Callable]: 262 | def decorator(func: Callable) -> Callable: 263 | setattr(func, "__resource_uri__", uri) 264 | return self.resources.method(func) 265 | return decorator 266 | 267 | def serve(self, host: str, port: int, *, background = True, request_handler = McpHttpRequestHandler): 268 | if self._running: 269 | print("[MCP] Server is already running") 270 | return 271 | 272 | # Create server with deferred binding 273 | assert issubclass(request_handler, McpHttpRequestHandler) 274 | self._http_server = (ThreadingHTTPServer if background else HTTPServer)( 275 | (host, port), request_handler, bind_and_activate=False 276 | ) 277 | self._http_server.allow_reuse_address = False 278 | 279 | # Set the MCPServer instance on the handler class 280 | setattr(self._http_server, "mcp_server", self) 281 | 282 | try: 283 | # Bind and activate in main thread - errors propagate synchronously 284 | self._http_server.server_bind() 285 | self._http_server.server_activate() 286 | except OSError: 287 | # Cleanup on binding failure 288 | self._http_server.server_close() 289 | self._http_server = None 290 | raise 291 | 292 | # Only start thread after successful bind 293 | self._running = True 294 | 295 | print("[MCP] Server started:") 296 | print(f" Streamable HTTP: http://{host}:{port}/mcp") 297 | print(f" SSE: http://{host}:{port}/sse") 298 | 299 | def serve_forever(): 300 | 
try: 301 | self._http_server.serve_forever() # type: ignore 302 | except Exception as e: 303 | print(f"[MCP] Server error: {e}") 304 | traceback.print_exc() 305 | finally: 306 | self._running = False 307 | 308 | if background: 309 | self._server_thread = threading.Thread(target=serve_forever, daemon=True) 310 | self._server_thread.start() 311 | else: 312 | serve_forever() 313 | 314 | def stop(self): 315 | if not self._running: 316 | return 317 | 318 | self._running = False 319 | 320 | # Close all SSE connections 321 | for conn in self._sse_connections.values(): 322 | conn.alive = False 323 | self._sse_connections.clear() 324 | 325 | # Shutdown the HTTP server 326 | if self._http_server: 327 | # shutdown() must be called from a different thread 328 | # than the one running serve_forever() 329 | self._http_server.shutdown() 330 | self._http_server.server_close() 331 | self._http_server = None 332 | 333 | if self._server_thread: 334 | self._server_thread.join() 335 | self._server_thread = None 336 | 337 | print("[MCP] Server stopped") 338 | 339 | def stdio(self, stdin: BinaryIO | None = None, stdout: BinaryIO | None = None): 340 | stdin = stdin or sys.stdin.buffer 341 | stdout = stdout or sys.stdout.buffer 342 | while True: 343 | try: 344 | request = stdin.readline() 345 | if not request: # EOF 346 | break 347 | 348 | # Strip whitespace (trailing newline) before parsing 349 | request = request.strip() 350 | if not request: 351 | continue 352 | 353 | response = self.registry.dispatch(request) 354 | if response is not None: 355 | stdout.write(json.dumps(response).encode("utf-8") + b"\n") 356 | stdout.flush() 357 | except (BrokenPipeError, KeyboardInterrupt): # Client disconnected 358 | break 359 | 360 | def cors_localhost(self, origin: str) -> bool: 361 | """Allow CORS requests from localhost on ANY port.""" 362 | return urlparse(origin).hostname in ("localhost", "127.0.0.1", "::1") 363 | 364 | def _mcp_ping(self, _meta: dict | None = None) -> dict: 365 | """MCP ping 
method""" 366 | return {} 367 | 368 | def _mcp_initialize(self, protocolVersion: str, capabilities: dict, clientInfo: dict, _meta: dict | None = None) -> dict: 369 | """MCP initialize method""" 370 | return { 371 | "protocolVersion": getattr(self._protocol_version, "data", protocolVersion), 372 | "capabilities": { 373 | "tools": {}, 374 | "resources": { 375 | "subscribe": False, 376 | "listChanged": False, 377 | }, 378 | "prompts": {}, 379 | }, 380 | "serverInfo": { 381 | "name": self.name, 382 | "version": self.version, 383 | }, 384 | } 385 | 386 | def _mcp_tools_list(self, _meta: dict | None = None) -> dict: 387 | """MCP tools/list method""" 388 | return { 389 | "tools": [ 390 | self._generate_tool_schema(func_name, func) 391 | for func_name, func in self.tools.methods.items() 392 | ], 393 | } 394 | 395 | def _mcp_tools_call(self, name: str, arguments: dict | None = None, _meta: dict | None = None) -> dict: 396 | """MCP tools/call method""" 397 | # Wrap tool call in JSON-RPC request 398 | tool_response = self.tools.dispatch({ 399 | "jsonrpc": "2.0", 400 | "method": name, 401 | "params": arguments, 402 | "id": None, 403 | }) 404 | assert tool_response is not None, "Only notification requests return None" 405 | 406 | # Check for error response 407 | if "error" in tool_response: 408 | error = tool_response["error"] 409 | return { 410 | "content": [{"type": "text", "text": error["message"] or "Unknown error"}], 411 | "isError": True, 412 | } 413 | 414 | result = tool_response.get("result") 415 | return { 416 | "content": [{"type": "text", "text": json.dumps(result, indent=2)}], 417 | "structuredContent": result if isinstance(result, dict) else {"result": result}, 418 | "isError": False, 419 | } 420 | 421 | def _enumerate_resources(self): 422 | for name, func in self.resources.methods.items(): 423 | uri: str = getattr(func, "__resource_uri__") 424 | description = (func.__doc__ or f"Read {uri}").strip() 425 | yield uri, name, description 426 | 427 | def 
_mcp_resources_list(self, _meta: dict | None = None) -> dict: 428 | """MCP resources/list method - returns static resources only (no URI parameters)""" 429 | return { 430 | "resources": [ 431 | { 432 | "uri": uri, 433 | "name": name, 434 | "description": description, 435 | "mimeType": "application/json", 436 | } 437 | for uri, name, description in self._enumerate_resources() 438 | if "{" not in uri 439 | ] 440 | } 441 | 442 | def _mcp_resource_templates_list(self, _meta: dict | None = None) -> dict: 443 | """MCP resources/templates/list method - returns parameterized resource templates""" 444 | return { 445 | "resourceTemplates": [ 446 | { 447 | "uriTemplate": uri, 448 | "name": name, 449 | "description": description, 450 | "mimeType": "application/json", 451 | } 452 | for uri, name, description in self._enumerate_resources() 453 | if "{" in uri 454 | ] 455 | } 456 | 457 | def _mcp_resources_read(self, uri: str, _meta: dict | None = None) -> dict: 458 | """MCP resources/read method""" 459 | 460 | # Try to match URI against all registered resource patterns 461 | for pattern, name, _ in self._enumerate_resources(): 462 | # Convert pattern to regex, replacing {param} with named capture groups 463 | regex_pattern = re.sub(r"\{(\w+)\}", r"(?P<\1>[^/]+)", pattern) 464 | regex_pattern = f"^{regex_pattern}$" 465 | 466 | match = re.match(regex_pattern, uri) 467 | if match: 468 | # Found matching resource - call it via JSON-RPC 469 | params = list(match.groupdict().values()) 470 | 471 | resource_response = self.resources.dispatch({ 472 | "jsonrpc": "2.0", 473 | "method": name, 474 | "params": params, 475 | "id": None, 476 | }) 477 | assert resource_response is not None, "Only notification requests return None" 478 | 479 | if "error" in resource_response: 480 | error = resource_response["error"] 481 | raise JsonRpcException(error["code"], error["message"], error.get("data")) 482 | 483 | return { 484 | "contents": [{ 485 | "uri": uri, 486 | "mimeType": "application/json", 487 
| "text": json.dumps(resource_response.get("result"), indent=2), 488 | }] 489 | } 490 | 491 | raise JsonRpcException(-32002, "Resource not found", {"uri": uri}) 492 | 493 | def _mcp_prompts_list(self, _meta: dict | None = None) -> dict: 494 | """MCP prompts/list method""" 495 | return { 496 | "prompts": [ 497 | self._generate_prompt_schema(func_name, func) 498 | for func_name, func in self.prompts.methods.items() 499 | ], 500 | } 501 | 502 | def _mcp_prompts_get( 503 | self, name: str, arguments: dict | None = None, _meta: dict | None = None 504 | ) -> dict: 505 | """MCP prompts/get method""" 506 | # Dispatch to prompts registry 507 | prompt_response = self.prompts.dispatch( 508 | { 509 | "jsonrpc": "2.0", 510 | "method": name, 511 | "params": arguments, 512 | "id": None, 513 | } 514 | ) 515 | assert prompt_response is not None, "Only notification requests return None" 516 | 517 | # Check for error response 518 | if "error" in prompt_response: 519 | error = prompt_response["error"] 520 | raise JsonRpcException(error["code"], error["message"], error.get("data")) 521 | 522 | result = prompt_response.get("result") 523 | 524 | # Pass through list of messages directly 525 | if isinstance(result, list): 526 | return {"messages": result} 527 | 528 | # Convert non-string results to JSON 529 | if not isinstance(result, str): 530 | result = json.dumps(result, indent=2) 531 | return { 532 | "messages": [ 533 | { 534 | "role": "user", 535 | "content": {"type": "text", "text": result}, 536 | }, 537 | ], 538 | } 539 | 540 | def _generate_prompt_schema(self, func_name: str, func: Callable) -> dict: 541 | """Generate MCP prompt schema from a function""" 542 | hints = get_type_hints(func, include_extras=True) 543 | hints.pop("return", None) 544 | sig = inspect.signature(func) 545 | 546 | # Build arguments list (PromptArgument format) 547 | arguments = [] 548 | for param_name, param_type in hints.items(): 549 | arg: dict[str, Any] = {"name": param_name} 550 | 551 | # Extract 
description from Annotated 552 | origin = get_origin(param_type) 553 | if origin is Annotated: 554 | args = get_args(param_type) 555 | arg["description"] = str(args[-1]) 556 | 557 | # Check if required (no default value) 558 | param = sig.parameters.get(param_name) 559 | if not param or param.default is inspect.Parameter.empty: 560 | arg["required"] = True 561 | 562 | arguments.append(arg) 563 | 564 | schema: dict[str, Any] = { 565 | "name": func_name, 566 | "description": (func.__doc__ or f"Prompt {func_name}").strip(), 567 | } 568 | 569 | if arguments: 570 | schema["arguments"] = arguments 571 | 572 | return schema 573 | 574 | def _type_to_json_schema(self, py_type: Any) -> dict: 575 | """Convert Python type hint to JSON schema object""" 576 | origin = get_origin(py_type) 577 | # Annotated[T, "description"] 578 | if origin is Annotated: 579 | args = get_args(py_type) 580 | return { 581 | **self._type_to_json_schema(args[0]), 582 | "description": str(args[-1]), 583 | } 584 | 585 | # NotRequired[T] 586 | if origin is NotRequired: 587 | return self._type_to_json_schema(get_args(py_type)[0]) 588 | 589 | # Union[Ts..], Optional[T] and T1 | T2 590 | if origin in (Union, UnionType): 591 | return {"anyOf": [self._type_to_json_schema(t) for t in get_args(py_type)]} 592 | 593 | # list[T] 594 | if origin is list: 595 | return { 596 | "type": "array", 597 | "items": self._type_to_json_schema(get_args(py_type)[0]), 598 | } 599 | 600 | # dict[str, T] 601 | if origin is dict: 602 | return { 603 | "type": "object", 604 | "additionalProperties": self._type_to_json_schema(get_args(py_type)[1]), 605 | } 606 | 607 | # TypedDict 608 | if is_typeddict(py_type): 609 | return self._typed_dict_to_schema(py_type) 610 | 611 | # Primitives 612 | return { 613 | "type": { 614 | int: "integer", 615 | float: "number", 616 | str: "string", 617 | bool: "boolean", 618 | list: "array", 619 | dict: "object", 620 | type(None): "null", 621 | }.get(py_type, "object"), 622 | } 623 | 624 | def 
_typed_dict_to_schema(self, typed_dict_class) -> dict: 625 | """Convert TypedDict to JSON schema""" 626 | hints = get_type_hints(typed_dict_class, include_extras=True) 627 | required_keys = getattr(typed_dict_class, "__required_keys__", set(hints.keys())) 628 | 629 | return { 630 | "type": "object", 631 | "properties": { 632 | field_name: self._type_to_json_schema(field_type) 633 | for field_name, field_type in hints.items() 634 | }, 635 | "required": [key for key in hints.keys() if key in required_keys], 636 | "additionalProperties": False, 637 | } 638 | 639 | def _generate_tool_schema(self, func_name: str, func: Callable) -> dict: 640 | """Generate MCP tool schema from a function""" 641 | hints = get_type_hints(func, include_extras=True) 642 | return_type = hints.pop("return", None) 643 | sig = inspect.signature(func) 644 | 645 | # Build parameter schema 646 | properties = {} 647 | required = [] 648 | 649 | for param_name, param_type in hints.items(): 650 | properties[param_name] = self._type_to_json_schema(param_type) 651 | 652 | # Add to required if no default value 653 | param = sig.parameters.get(param_name) 654 | if not param or param.default is inspect.Parameter.empty: 655 | required.append(param_name) 656 | 657 | schema: dict[str, Any] = { 658 | "name": func_name, 659 | "description": (func.__doc__ or f"Call {func_name}").strip(), 660 | "inputSchema": { 661 | "type": "object", 662 | "properties": properties, 663 | "required": required, 664 | }, 665 | } 666 | 667 | # Add outputSchema if return type exists and is not None 668 | if return_type and return_type is not type(None): 669 | return_schema = self._type_to_json_schema(return_type) 670 | 671 | # Wrap non-object returns in a "result" property 672 | if return_schema.get("type") != "object": 673 | return_schema = { 674 | "type": "object", 675 | "properties": {"result": return_schema}, 676 | "required": ["result"], 677 | } 678 | 679 | schema["outputSchema"] = return_schema 680 | 681 | return schema 682 | 
--------------------------------------------------------------------------------