├── src
└── zeromcp
│ ├── py.typed
│ ├── __init__.py
│ ├── jsonrpc.py
│ └── mcp.py
├── .python-version
├── .editorconfig
├── .gitignore
├── CONTRIBUTING.md
├── pyproject.toml
├── LICENSE
├── .github
└── workflows
│ └── ci.yml
├── examples
└── mcp_example.py
├── README.md
└── tests
├── server_test.py
├── mcp_test.py
└── jsonrpc_test.py
/src/zeromcp/py.typed:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.python-version:
--------------------------------------------------------------------------------
1 | 3.11
2 |
--------------------------------------------------------------------------------
/src/zeromcp/__init__.py:
--------------------------------------------------------------------------------
1 | from .mcp import McpRpcRegistry, McpToolError, McpServer, McpHttpRequestHandler
2 |
3 | __all__ = ["McpRpcRegistry", "McpToolError", "McpServer", "McpHttpRequestHandler"]
4 |
--------------------------------------------------------------------------------
/.editorconfig:
--------------------------------------------------------------------------------
1 | root = true
2 |
3 | [*]
4 | indent_style = space
5 | indent_size = 4
6 | tab_width = 4
7 | end_of_line = lf
8 | charset = utf-8
9 | trim_trailing_whitespace = true
10 | insert_final_newline = true
11 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Python-generated files
2 | __pycache__/
3 | *.py[oc]
4 | build/
5 | dist/
6 | wheels/
7 | *.egg-info
8 |
9 | # Virtual environments
10 | .venv
11 |
12 | # Coverage
13 | htmlcov/
14 | .coverage*
15 | coverage.xml
16 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Development notes
2 |
3 | Run the tests with coverage:
4 |
5 | ```sh
6 | uv run coverage run --data-file=.coverage.mcp tests/mcp_test.py
7 | uv run coverage run --data-file=.coverage.jsonrpc tests/jsonrpc_test.py
8 | uv run coverage run --data-file=.coverage.server tests/server_test.py
9 | ```
10 |
11 | Combine coverage and generate report:
12 |
13 | ```sh
14 | uv run coverage combine
15 | uv run coverage report
16 | uv run coverage html
17 | ```
18 |
19 | Generate report for just `jsonrpc_test.py`:
20 |
21 | ```sh
22 | uv run coverage html --data-file=.coverage.jsonrpc
23 | ```
24 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "zeromcp"
3 | version = "0.0.0"
4 | description = "Zero-dependency MCP server implementation"
5 | readme = "README.md"
6 | requires-python = ">=3.11"
7 | dependencies = []
8 | license = "MIT"
9 |
10 | [project.urls]
11 | Homepage = "https://github.com/mrexodia/zeromcp"
12 | Repository = "https://github.com/mrexodia/zeromcp"
13 | Issues = "https://github.com/mrexodia/zeromcp/issues"
14 |
15 | [build-system]
16 | requires = ["hatchling"]
17 | build-backend = "hatchling.build"
18 |
19 | [dependency-groups]
20 | dev = [
21 | "coverage>=7.11.3",
22 | "mcp>=1.21.2",
23 | "requests>=2.32.5",
24 | ]
25 |
26 | [tool.coverage.report]
27 | omit = ["tests/*"]
28 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Duncan Ogilvie
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | name: CI
2 |
3 | on: [push, pull_request]
4 |
5 | # Automatically cancel previous runs of this workflow on the same branch
6 | concurrency:
7 | group: ${{ github.workflow }}-${{ github.ref }}
8 | cancel-in-progress: true
9 |
10 | jobs:
11 | linux:
12 | # Skip building pull requests from the same repository
13 | if: ${{ github.event_name == 'push' || (github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name != github.repository) }}
14 | runs-on: ubuntu-latest
15 | permissions:
16 | id-token: write
17 | contents: read
18 | env:
19 | PYTHONUNBUFFERED: 1
20 | steps:
21 | - name: Checkout
22 | uses: actions/checkout@v5
23 |
24 | - name: Update pyproject.toml version
25 | if: ${{ startsWith(github.ref, 'refs/tags/') }}
26 | shell: bash
27 | run: |
28 | # Extract version from tag (strip 'v' prefix if present)
29 | VERSION=${GITHUB_REF#refs/tags/}
30 | VERSION=${VERSION#v}
31 | echo "Extracted version: $VERSION"
32 |
33 | # Update version in pyproject.toml (works on both GNU and BSD sed)
34 | sed -i.bak "s/^version = .*/version = \"$VERSION\"/" pyproject.toml
35 | rm pyproject.toml.bak
36 |
37 | - name: Install uv
38 | uses: astral-sh/setup-uv@v6
39 | with:
40 | version: "0.9.3"
41 |
42 | - name: Create venv
43 | run: uv sync
44 |
45 | - name: Type check
46 | run: uvx ty check
47 |
48 | - name: Ruff check
49 | run: uvx ruff check
50 |
51 | - name: Test JSON-RPC
52 | run: uv run coverage run --data-file=.coverage.jsonrpc tests/jsonrpc_test.py
53 |
54 | - name: Test MCP
55 | run: uv run coverage run --data-file=.coverage.mcp tests/mcp_test.py
56 |
57 | - name: Test Server
58 | run: uv run coverage run --data-file=.coverage.server tests/server_test.py
59 |
60 | - name: Generate coverage report
61 | run: |
62 | uv run coverage combine
63 | uv run coverage report
64 |
65 | - name: Publish
66 | if: ${{ startsWith(github.ref, 'refs/tags/') }}
67 | run: uv build && uv publish
68 |
--------------------------------------------------------------------------------
/examples/mcp_example.py:
--------------------------------------------------------------------------------
1 | """Example MCP server with test tools"""
2 |
3 | import time
4 | import argparse
5 | from urllib.parse import urlparse
6 | from typing import Annotated, Optional, TypedDict, NotRequired
7 | from zeromcp import McpToolError, McpServer
8 |
9 | mcp = McpServer("example")
10 |
11 |
@mcp.tool
def divide(
    numerator: Annotated[float, "Numerator"],
    denominator: Annotated[float, "Denominator"],
) -> float:
    """Divide two numbers (no zero check - tests natural exceptions)"""
    # Deliberately unguarded: a zero denominator raises ZeroDivisionError so
    # the example can demonstrate how uncaught exceptions are reported.
    quotient = numerator / denominator
    return quotient
19 |
20 |
# Shape of the value returned by `greet` and `greeting_resource`.
# The string inside each `Annotated[...]` is a human-readable field description.
class GreetingResponse(TypedDict):
    message: Annotated[str, "Greeting message"]
    name: Annotated[str, "Name that was greeted"]
    # NotRequired: `greet` only sets this key when an age was supplied.
    age: Annotated[NotRequired[int], "Age if provided"]
25 |
26 |
@mcp.tool
def greet(
    name: Annotated[str, "Name to greet"],
    age: Annotated[Optional[int], "Age of person"] = None,
) -> GreetingResponse:
    """Generate a greeting message"""
    # Guard clause: without an age the optional "age" key is omitted entirely.
    if age is None:
        return {"message": f"Hello, {name}!", "name": name}

    return {
        "message": f"Hello, {name}! You are {age} years old.",
        "name": name,
        "age": age,
    }
40 |
41 |
# Shape of the value returned by `get_system_info` and `system_info_resource`.
class SystemInfo(TypedDict):
    platform: Annotated[str, "Operating system platform"]
    python_version: Annotated[str, "Python version"]
    machine: Annotated[str, "Machine architecture"]
    # Filled from time.time() by `get_system_info`.
    timestamp: Annotated[float, "Current timestamp"]
47 |
48 |
@mcp.tool
def get_system_info() -> SystemInfo:
    """Get system information"""
    # Function-scope import: `platform` is only needed by this tool.
    import platform

    info: SystemInfo = {
        "platform": platform.system(),
        "python_version": platform.python_version(),
        "machine": platform.machine(),
        "timestamp": time.time(),
    }
    return info
60 |
61 |
@mcp.tool
def failing_tool(message: Annotated[str, "Error message to raise"]) -> str:
    """Tool that always fails (for testing error handling)"""
    # Never returns: the caller-supplied text becomes the tool error message.
    error = McpToolError(message)
    raise error
66 |
67 |
# Element type of the list returned by `struct_get`.
class StructInfo(TypedDict):
    name: Annotated[str, "Structure name"]
    size: Annotated[int, "Structure size in bytes"]
    fields: Annotated[list[str], "List of field names"]
72 |
73 |
@mcp.tool
def struct_get(
    names: Annotated[list[str], "Array of structure names"]
    | Annotated[str, "Single structure name"],
) -> list[StructInfo]:
    """Retrieve structure information by names"""
    # Normalise the union input: a lone string becomes a one-element list.
    requested = names if isinstance(names, list) else [names]

    results: list[StructInfo] = []
    for struct_name in requested:
        results.append(
            StructInfo(
                {
                    "name": struct_name,
                    "size": 128,  # Dummy size
                    "fields": ["field1", "field2", "field3"],  # Dummy fields
                }
            )
        )
    return results
90 |
91 |
@mcp.tool
def random_dict(param: dict[str, int] | None) -> dict:
    """Return a random dictionary for testing serialization"""
    # Start from a copy of the caller's entries (if any); the fixed keys
    # below overwrite same-named entries without changing their position.
    merged = {**(param or {})}
    merged.update(x=42, y=7, z=99)
    return merged
101 |
102 |
@mcp.resource("example://system_info")
def system_info_resource() -> SystemInfo:
    """Resource providing system information"""
    # Same payload as the `get_system_info` tool, exposed under a fixed URI.
    info = get_system_info()
    return info
107 |
108 |
@mcp.resource("example://greeting/{name}")
def greeting_resource(
    name: Annotated[str, "Name to greet from resource"],
) -> GreetingResponse:
    """Resource providing greeting message"""
    # Delegates to the `greet` tool without an age, so "age" is never set.
    greeting = greet(name)
    return greeting
115 |
116 |
@mcp.resource("example://error")
def error_resource() -> None:
    """Resource that always fails (for testing error handling)"""
    failure = McpToolError("This is a resource error for testing purposes.")
    raise failure
121 |
122 |
@mcp.prompt
def code_review(
    code: Annotated[str, "Code to review"],
    language: Annotated[str, "Programming language"] = "python",
) -> str:
    """Review code for bugs and improvements"""
    # The code is wrapped in a Markdown fence tagged with the language name.
    prompt = f"Please review this {language} code for bugs and improvements:\n\n```{language}\n{code}\n```"
    return prompt
130 |
131 |
@mcp.prompt
def summarize(
    text: Annotated[str, "Text to summarize"],
    max_sentences: Annotated[int, "Maximum sentences"] = 3,
) -> str:
    """Summarize text concisely"""
    prompt = f"Summarize the following in {max_sentences} sentences or fewer:\n\n{text}"
    return prompt
139 |
140 |
if __name__ == "__main__":
    # Command-line entry point: run the example server over stdio, or over
    # HTTP on the host/port given by --transport.
    parser = argparse.ArgumentParser(description="MCP Example Server")
    parser.add_argument(
        "--transport",
        help="Transport (stdio or http://host:port)",
        default="http://127.0.0.1:5001",
    )
    args = parser.parse_args()
    if args.transport == "stdio":
        mcp.stdio()
    else:
        url = urlparse(args.transport)
        try:
            host, port = url.hostname, url.port
        except ValueError:
            # `port` is parsed lazily and raises for a non-numeric or
            # out-of-range value; treat that like any other bad URL.
            host = port = None
        if host is None or port is None:
            # Report a usage error (exit status 2) instead of a raw traceback.
            parser.error(f"Invalid transport URL: {args.transport}")

        print("Starting MCP Example Server...")

        # List everything registered on the server, one section per registry.
        registries = (
            ("tools", mcp.tools),
            ("resources", mcp.resources),
            ("prompts", mcp.prompts),
        )
        for label, registry in registries:
            print(f"\nAvailable {label}:")
            for name, func in registry.methods.items():
                print(f" - {name}: {func.__doc__}")
        print()

        mcp.serve(host, port)

        try:
            input("\nServer is running, press Enter or Ctrl+C to stop...")
        except (KeyboardInterrupt, EOFError):
            print("\n\nStopping server...")
        finally:
            # Always shut the server down, whether the user pressed Enter,
            # hit Ctrl+C, or stdin was closed.
            mcp.stop()
181 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # zeromcp
2 |
3 | **Minimal MCP server implementation in pure Python.**
4 |
5 | A lightweight, handcrafted implementation of the [Model Context Protocol](https://modelcontextprotocol.io/) focused on what most users actually need: exposing tools with clean Python type annotations.
6 |
7 | ## Features
8 |
9 | - ✨ **Zero dependencies** - Pure Python, standard library only
10 | - 🎯 **Type-safe** - Native Python type annotations for everything
11 | - 🚀 **Fast** - Minimal overhead, maximum performance
12 | - 🛠️ **Handcrafted** - Written by a human<sup>[1](#ai-usage)</sup>, verified against the spec
13 | - 🌐 **HTTP/SSE transport** - Streamable responses
14 | - 📡 **Stdio transport** - For legacy clients
15 | - 📦 **Tiny** - Less than 1,000 lines of code
16 |
17 | ## Installation
18 |
19 | ```bash
20 | pip install zeromcp
21 | ```
22 |
23 | Or with uv:
24 |
25 | ```bash
26 | uv add zeromcp
27 | ```
28 |
29 | ## Quick Start
30 |
31 | ```python
32 | from typing import Annotated
33 | from zeromcp import McpServer
34 |
35 | mcp = McpServer("my-server")
36 |
37 | @mcp.tool
38 | def greet(
39 | name: Annotated[str, "Name to greet"],
40 | age: Annotated[int | None, "Age of person"] = None
41 | ) -> str:
42 | """Generate a greeting message"""
43 | if age:
44 | return f"Hello, {name}! You are {age} years old."
45 | return f"Hello, {name}!"
46 |
47 | if __name__ == "__main__":
48 | mcp.serve("127.0.0.1", 8000)
49 | ```
50 |
51 | Then manually test your MCP server with the [inspector](https://github.com/modelcontextprotocol/inspector):
52 |
53 | ```bash
54 | npx -y @modelcontextprotocol/inspector
55 | ```
56 |
57 | Once things are working you can configure the `mcp.json`:
58 |
59 | ```json
60 | {
61 | "mcpServers": {
62 | "my-server": {
63 | "type": "http",
64 | "url": "http://127.0.0.1:8000/mcp"
65 | }
66 | }
67 | }
68 | ```
69 |
70 | ## Stdio Transport
71 |
72 | For MCP clients that only support stdio transport:
73 |
74 | ```python
75 | from zeromcp import McpServer
76 |
77 | mcp = McpServer("my-server")
78 |
79 | @mcp.tool
80 | def greet(name: str) -> str:
81 | """Generate a greeting"""
82 | return f"Hello, {name}!"
83 |
84 | if __name__ == "__main__":
85 | mcp.stdio()
86 | ```
87 |
88 | Then configure in `mcp.json` (different for every client):
89 |
90 | ```json
91 | {
92 | "mcpServers": {
93 | "my-server": {
94 | "command": "python",
95 | "args": ["path/to/server.py"]
96 | }
97 | }
98 | }
99 | ```
100 |
101 | ## Type Annotations
102 |
103 | zeromcp uses native Python `Annotated` types for schema generation:
104 |
105 | ```python
106 | from typing import Annotated, Optional, TypedDict, NotRequired
107 |
108 | class GreetingResponse(TypedDict):
109 | message: Annotated[str, "Greeting message"]
110 | name: Annotated[str, "Name that was greeted"]
111 | age: Annotated[NotRequired[int], "Age if provided"]
112 |
113 | @mcp.tool
114 | def greet(
115 | name: Annotated[str, "Name to greet"],
116 | age: Annotated[Optional[int], "Age of person"] = None
117 | ) -> GreetingResponse:
118 | """Generate a greeting message"""
119 | if age is not None:
120 | return {
121 | "message": f"Hello, {name}! You are {age} years old.",
122 | "name": name,
123 | "age": age
124 | }
125 | return {
126 | "message": f"Hello, {name}!",
127 | "name": name
128 | }
129 | ```
130 |
131 | ## Union Types
132 |
133 | Tools can accept multiple input types:
134 |
135 | ```python
136 | from typing import Annotated, TypedDict
137 |
138 | class StructInfo(TypedDict):
139 | name: Annotated[str, "Structure name"]
140 | size: Annotated[int, "Structure size in bytes"]
141 | fields: Annotated[list[str], "List of field names"]
142 |
143 | @mcp.tool
144 | def struct_get(
145 | names: Annotated[list[str], "Array of structure names"]
146 | | Annotated[str, "Single structure name"]
147 | ) -> list[StructInfo]:
148 | """Retrieve structure information by names"""
149 | return [
150 | {
151 | "name": name,
152 | "size": 128,
153 | "fields": ["field1", "field2", "field3"]
154 | }
155 | for name in (names if isinstance(names, list) else [names])
156 | ]
157 | ```
158 |
159 | ## Error Handling
160 |
161 | ```python
162 | from zeromcp import McpToolError
163 |
164 | @mcp.tool
165 | def divide(
166 | numerator: Annotated[float, "Numerator"],
167 | denominator: Annotated[float, "Denominator"]
168 | ) -> float:
169 | """Divide two numbers"""
170 | if denominator == 0:
171 | raise McpToolError("Division by zero")
172 | return numerator / denominator
173 | ```
174 |
175 | ## Resources
176 |
177 | Expose read-only data via URI patterns. Resources are serialized as JSON.
178 |
179 | ```python
180 | from typing import Annotated
181 |
182 | @mcp.resource("file://data.txt")
183 | def read_file() -> dict:
184 | """Get information about data.txt"""
185 | return {"name": "data.txt", "size": 1024}
186 |
187 | @mcp.resource("file://{filename}")
188 | def read_any_file(
189 | filename: Annotated[str, "Name of file to read"]
190 | ) -> dict:
191 | """Get information about any file"""
192 | return {"name": filename, "size": 2048}
193 | ```
194 |
195 | ## Prompts
196 |
197 | Expose reusable prompt templates with typed arguments.
198 |
199 | ```python
200 | from typing import Annotated
201 |
202 | @mcp.prompt
203 | def code_review(
204 | code: Annotated[str, "Code to review"],
205 | language: Annotated[str, "Programming language"] = "python"
206 | ) -> str:
207 | """Review code for bugs and improvements"""
208 | return f"Please review this {language} code:\n\n```{language}\n{code}\n```"
209 | ```
210 |
211 | ## CORS
212 |
213 | By default, zeromcp allows CORS requests from localhost origins (`localhost`, `127.0.0.1`, `::1`) on **any port**. This allows tools like the MCP Inspector or local AI tools to communicate with your MCP server.
214 |
215 | ```python
216 | from zeromcp import McpServer
217 |
218 | mcp = McpServer("my-server")
219 |
220 | # Default: allow localhost on any port
221 | mcp.cors_allowed_origins = mcp.cors_localhost
222 |
223 | # Allow all origins (use with caution)
224 | mcp.cors_allowed_origins = "*"
225 |
226 | # Allow specific origins
227 | mcp.cors_allowed_origins = [
228 | "http://localhost:3000",
229 | "https://myapp.example.com",
230 | ]
231 |
232 | # Disable CORS (blocks all browser cross-origin requests)
233 | mcp.cors_allowed_origins = None
234 |
235 | # Custom logic
236 | mcp.cors_allowed_origins = lambda origin: origin.endswith(".example.com")
237 | ```
238 |
239 | Note: CORS only affects browser-based requests. Non-browser clients like `curl` or MCP desktop apps are unaffected by this setting.
240 |
241 | ## Supported clients
242 |
243 | The following clients have been tested:
244 |
245 | - [Claude Code](https://code.claude.com/docs/en/mcp#installing-mcp-servers)
246 | - [Claude Desktop](https://modelcontextprotocol.io/docs/develop/connect-local-servers#installing-the-filesystem-server) (_stdio only_)
247 | - [Visual Studio Code](https://code.visualstudio.com/docs/copilot/customization/mcp-servers)
248 | - [Roo Code](https://docs.roocode.com/features/mcp/using-mcp-in-roo) / [Cline](https://docs.cline.bot/mcp/configuring-mcp-servers) / [Kilo Code](https://kilocode.ai/docs/features/mcp/using-mcp-in-kilo-code)
249 | - [LM Studio](https://lmstudio.ai/docs/app/mcp)
250 | - [Jan](https://www.jan.ai/docs/desktop/mcp#configure-and-use-mcps-within-jan)
251 | - [Gemini CLI](https://geminicli.com/docs/tools/mcp-server/#how-to-set-up-your-mcp-server)
252 | - [Cursor](https://cursor.com/docs/context/mcp)
253 | - [Windsurf](https://docs.windsurf.com/windsurf/cascade/mcp)
254 | - [Zed](https://zed.dev/docs/ai/mcp) (_stdio only_)
255 | - [Warp](https://docs.warp.dev/knowledge-and-collaboration/mcp#adding-an-mcp-server)
256 |
257 | _Note_: generally the `/mcp` endpoint is preferred, but not all clients support it correctly.
258 |
259 |
260 | <sub id="ai-usage">1</sub> README and some of the tests written by Claude
261 |
--------------------------------------------------------------------------------
/tests/server_test.py:
--------------------------------------------------------------------------------
1 | import requests
2 | import sys
3 | import socket
4 | from contextlib import contextmanager
5 | from zeromcp import McpServer
6 |
def find_free_port():
    """Ask the OS for a currently unused TCP port and return its number."""
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        # Binding to port 0 lets the OS pick a free port.
        sock.bind(("", 0))
        _host, port = sock.getsockname()
        return port
    finally:
        sock.close()
11 |
@contextmanager
def run_server(name="test", **kwargs):
    """Start an McpServer on a free local port for the duration of a ``with`` block.

    Every keyword argument is set as an attribute on the server before it
    starts serving. Yields ``(base_url, server)`` and stops the server on exit,
    even if the body raises.
    """
    port = find_free_port()
    server = McpServer(name)
    for attribute, value in kwargs.items():
        setattr(server, attribute, value)

    host = "127.0.0.1"
    server.serve(host, port, background=True)
    try:
        yield f"http://{host}:{port}", server
    finally:
        server.stop()
24 |
25 | PING_JSON = {"jsonrpc": "2.0", "method": "ping", "id": 1}
26 |
def test_cors_permissive():
    """With ``cors_allowed_origins="*"`` any origin is echoed back on OPTIONS and POST."""
    print("Testing CORS permissive (cors_allowed_origins='*')...")
    origin = "http://example.com"
    origin_header = {"Origin": origin}
    with run_server(cors_allowed_origins="*") as (base_url, _):
        endpoint = f"{base_url}/mcp"

        # Preflight request
        preflight = requests.options(endpoint, headers=origin_header)
        assert preflight.headers.get("Access-Control-Allow-Origin") == origin, "OPTIONS should have CORS header"

        # Actual JSON-RPC request
        posted = requests.post(endpoint, headers=origin_header, json=PING_JSON)
        assert posted.headers.get("Access-Control-Allow-Origin") == origin, "POST should have CORS header"
    print("✓ PASS")
39 |
def test_cors_restrictive():
    """With ``cors_allowed_origins=None`` no response carries a CORS allow header."""
    print("Testing CORS restrictive (cors_allowed_origins=None)...")
    origin_header = {"Origin": "http://example.com"}
    with run_server(cors_allowed_origins=None) as (base_url, _):
        endpoint = f"{base_url}/mcp"

        # Preflight request
        preflight = requests.options(endpoint, headers=origin_header)
        assert "Access-Control-Allow-Origin" not in preflight.headers, "OPTIONS should NOT have CORS header"

        # Actual JSON-RPC request
        posted = requests.post(endpoint, headers=origin_header, json=PING_JSON)
        assert "Access-Control-Allow-Origin" not in posted.headers, "POST should NOT have CORS header"
    print("✓ PASS")
51 |
def test_cors_local():
    """A server with no CORS override echoes localhost origins and ignores others.

    Covers ``localhost``, ``127.0.0.1`` (over https) and ``[::1]`` on arbitrary
    ports, plus a non-local origin that must not receive a CORS header.
    """
    print("Testing CORS localhost...")
    with run_server() as (base_url, _):
        # Test OPTIONS
        resp = requests.options(f"{base_url}/mcp", headers={"Origin": "http://localhost:1234"})
        assert resp.headers.get("Access-Control-Allow-Origin") == "http://localhost:1234", "OPTIONS should have CORS header"

        # Test POST
        resp = requests.post(f"{base_url}/mcp", headers={"Origin": "https://127.0.0.1:4321"}, json=PING_JSON)
        assert resp.headers.get("Access-Control-Allow-Origin") == "https://127.0.0.1:4321", "POST should have CORS header (HTTPS)"

        resp = requests.post(f"{base_url}/mcp", headers={"Origin": "http://[::1]:4321"}, json=PING_JSON)
        assert resp.headers.get("Access-Control-Allow-Origin") == "http://[::1]:4321", "POST should have CORS header (IPv6)"

        # Test OPTIONS with wrong origin
        resp = requests.options(f"{base_url}/mcp", headers={"Origin": "http://example.com"})
        assert "Access-Control-Allow-Origin" not in resp.headers, "OPTIONS should NOT have CORS header for wrong origin"
    # Report success like every other test in this file does.
    print("✓ PASS")
69 |
def test_cors_list():
    """With a list of allowed origins, only the listed origins get a CORS header."""
    print("Testing CORS list...")
    allowed_origins = ["http://example.com", "https://example.org"]
    with run_server(cors_allowed_origins=allowed_origins) as (base_url, _):
        # Test allowed origins
        for origin in allowed_origins:
            resp = requests.options(f"{base_url}/mcp", headers={"Origin": origin})
            assert resp.headers.get("Access-Control-Allow-Origin") == origin, f"OPTIONS should have CORS header for {origin}"

            resp = requests.post(f"{base_url}/mcp", headers={"Origin": origin}, json=PING_JSON)
            assert resp.headers.get("Access-Control-Allow-Origin") == origin, f"POST should have CORS header for {origin}"

        # Test disallowed origin
        resp = requests.options(f"{base_url}/mcp", headers={"Origin": "http://notallowed.com"})
        assert "Access-Control-Allow-Origin" not in resp.headers, "OPTIONS should NOT have CORS header for disallowed origin"

        resp = requests.post(f"{base_url}/mcp", headers={"Origin": "http://notallowed.com"}, json=PING_JSON)
        assert "Access-Control-Allow-Origin" not in resp.headers, "POST should NOT have CORS header for disallowed origin"
    # Report success like every other test in this file does.
    print("✓ PASS")
88 |
def test_body_limit():
    """POST bodies above ``post_body_limit`` are rejected with 413; smaller ones pass."""
    print("Testing body limit...")
    limit = 100  # bytes
    with run_server(post_body_limit=limit) as (base_url, _):
        endpoint = f"{base_url}/mcp"

        # A ping request fits comfortably under the limit.
        accepted = requests.post(endpoint, json=PING_JSON)
        assert accepted.status_code == 200, "Small request should pass"

        # 200 bytes of raw data is twice the configured limit.
        oversized_body = "x" * 200
        rejected = requests.post(endpoint, data=oversized_body)
        assert rejected.status_code == 413, "Large request should fail with 413"
        assert "Payload Too Large" in rejected.text, "Error message should mention payload size"
    print("✓ PASS")
103 |
def test_exception_redaction():
    """With ``redact_exceptions`` on, a failing tool reports only the message, no traceback."""
    print("Testing exception redaction...")
    with run_server() as (base_url, server):
        server.tools.redact_exceptions = True

        @server.tool
        def fail():
            raise ValueError("Secret internal info")

        # Invoke the tool through tools/call
        request_body = {
            "jsonrpc": "2.0",
            "method": "tools/call",
            "params": {"name": "fail", "arguments": {}},
            "id": 1,
        }
        reply = requests.post(f"{base_url}/mcp", json=request_body).json()

        # The JSON-RPC envelope is a success; the failure is reported inside it.
        assert "result" in reply, f"Expected result, got error: {reply.get('error')}"
        tool_result = reply["result"]
        assert tool_result["isError"] is True, "Tool execution should be an error"

        error_text = tool_result["content"][0]["text"]
        assert error_text == "Internal Error: Secret internal info", f"Should show redacted message, got: {error_text}"
        assert "Traceback" not in error_text, "Should NOT show traceback"
    print("✓ PASS")
134 |
def test_exception_exposure():
    """With ``redact_exceptions`` off, a failing tool reports the message and a traceback."""
    print("Testing exception exposure (default)...")
    with run_server() as (base_url, server):
        server.tools.redact_exceptions = False

        @server.tool
        def fail():
            raise ValueError("Secret internal info")

        # Call via tools/call
        payload = {
            "jsonrpc": "2.0",
            "method": "tools/call",
            "params": {"name": "fail", "arguments": {}},
            "id": 1,
        }
        resp = requests.post(f"{base_url}/mcp", json=payload)
        data = resp.json()

        # The outer JSON-RPC call succeeds
        assert "result" in data, f"Expected result, got error: {data.get('error')}"
        result = data["result"]

        # The tool execution failed
        assert result["isError"] is True, "Tool execution should be an error"
        error_text = result["content"][0]["text"]

        assert "Secret internal info" in error_text, "Should show exception message"
        assert "Traceback" in error_text, "Should show traceback"
    # Printed once; the original reported PASS twice for this test.
    print("✓ PASS")
167 |
def test_http_errors():
    """Wrong methods and unknown paths map to 405 and 404 respectively."""
    print("Testing HTTP errors...")
    with run_server() as (base_url, _):
        # GET /mcp -> 405 Method Not Allowed
        get_mcp = requests.get(f"{base_url}/mcp")
        assert get_mcp.status_code == 405, f"GET /mcp should return 405, got {get_mcp.status_code}"

        # Unknown path -> 404 Not Found, for GET and POST alike
        unknown_url = f"{base_url}/invalid"
        get_unknown = requests.get(unknown_url)
        assert get_unknown.status_code == 404, f"GET /invalid should return 404, got {get_unknown.status_code}"

        post_unknown = requests.post(unknown_url, json={})
        assert post_unknown.status_code == 404, f"POST /invalid should return 404, got {post_unknown.status_code}"
    print("✓ PASS")
183 |
def test_sse_errors():
    """POST /sse without a session, or with an unknown one, yields 400 and an explanation."""
    print("Testing SSE errors...")
    with run_server() as (base_url, _):
        sse_url = f"{base_url}/sse"

        # No ?session query parameter at all
        no_session = requests.post(sse_url, json={})
        assert no_session.status_code == 400, f"POST /sse without session should return 400, got {no_session.status_code}"
        assert "Missing ?session" in no_session.text

        # A session id that has no open SSE stream
        bad_session = requests.post(f"{sse_url}?session=invalid-uuid", json={})
        assert bad_session.status_code == 400, f"POST /sse with invalid session should return 400, got {bad_session.status_code}"
        assert "No active SSE connection" in bad_session.text
    print("✓ PASS")
197 |
def test_mcp_tool_error():
    """A tool raising McpToolError produces an ``isError`` result carrying its message."""
    print("Testing McpToolError...")
    from zeromcp import McpToolError

    with run_server() as (base_url, server):

        @server.tool
        def fail_custom():
            raise McpToolError("Custom tool error")

        request_body = {
            "jsonrpc": "2.0",
            "method": "tools/call",
            "params": {"name": "fail_custom", "arguments": {}},
            "id": 1,
        }
        reply = requests.post(f"{base_url}/mcp", json=request_body).json()

        tool_result = reply["result"]
        assert tool_result["isError"] is True
        assert "Custom tool error" in tool_result["content"][0]["text"]
    print("✓ PASS")
217 |
def run_all_tests():
    """Run every server test in order, exiting with status 1 on the first failure."""
    banner = "=" * 60
    print(banner)
    print("SERVER TESTS")
    print(banner)

    try:
        ordered_tests = (
            test_cors_permissive,
            test_cors_restrictive,
            test_cors_local,
            test_cors_list,
            test_body_limit,
            test_exception_redaction,
            test_exception_exposure,
            test_http_errors,
            test_sse_errors,
            test_mcp_tool_error,
        )
        for test in ordered_tests:
            test()
        print("\n" + banner)
        print("ALL SERVER TESTS PASSED! ✓")
        print(banner)
    except AssertionError as failure:
        # A failed expectation: the assertion message is enough.
        print(f"\n❌ FAIL: {failure}")
        sys.exit(1)
    except Exception as error:
        # Anything else is unexpected, so include the traceback.
        print(f"\n❌ ERROR: {error}")
        import traceback

        traceback.print_exc()
        sys.exit(1)
245 |
246 | if __name__ == "__main__":
247 | run_all_tests()
248 |
--------------------------------------------------------------------------------
/src/zeromcp/jsonrpc.py:
--------------------------------------------------------------------------------
1 | import json
2 | import inspect
3 | import traceback
4 | from typing import Any, Callable, get_type_hints, get_origin, get_args, Union, TypedDict, TypeAlias, NotRequired, is_typeddict
5 | from types import UnionType
6 |
7 | JsonRpcId: TypeAlias = str | int | float | None
8 | JsonRpcParams: TypeAlias = dict[str, Any] | list[Any] | None
9 |
# Typed shape of an incoming JSON-RPC request object.
class JsonRpcRequest(TypedDict):
    jsonrpc: str
    method: str
    params: NotRequired[JsonRpcParams]
    # `dispatch` treats a request with no "id" key as a notification and
    # returns no response for it.
    id: NotRequired[JsonRpcId]
15 |
# Typed shape of a JSON-RPC error object: numeric code, message, optional data.
class JsonRpcError(TypedDict):
    code: int
    message: str
    data: NotRequired[Any]
20 |
# Typed shape of a JSON-RPC response; both "result" and "error" are
# declared optional, while "id" is always present.
class JsonRpcResponse(TypedDict):
    jsonrpc: str
    result: NotRequired[Any]
    error: NotRequired[JsonRpcError]
    id: JsonRpcId
26 |
class JsonRpcException(Exception):
    """Exception carrying a JSON-RPC error code, message and optional data.

    ``JsonRpcRegistry.dispatch`` catches this exception and builds the error
    response from ``code``, ``message`` and ``data``.
    """

    def __init__(self, code: int, message: str, data: Any = None):
        # Pass the message to Exception so str(exc), repr(exc) and tracebacks
        # show it instead of an empty string.
        super().__init__(message)
        self.code = code
        self.message = message
        self.data = data
32 |
33 | class JsonRpcRegistry:
def __init__(self):
    """Create an empty registry with exception redaction switched off."""
    # When True, map_exception reports only "Internal Error: <message>"
    # instead of a full traceback.
    self.redact_exceptions = False
    # NOTE(review): presumably per-callable signature metadata filled in by
    # code outside this view — confirm against the rest of the class.
    self._cache: dict[Callable, tuple[inspect.Signature, dict, list[str]]] = {}
    # Registered callables keyed by their method name.
    self.methods: dict[str, Callable] = {}
38 |
39 | def method(self, func: Callable, name: str | None = None) -> Callable:
40 | self.methods[name or func.__name__] = func # type: ignore
41 | return func
42 |
43 | def dispatch(self, request: dict | str | bytes | bytearray) -> JsonRpcResponse | None:
44 | try:
45 | if not isinstance(request, dict):
46 | request = json.loads(request)
47 | if not isinstance(request, dict):
48 | return self._error(None, -32600, "Invalid request: must be a JSON object")
49 | except Exception as e:
50 | return self._error(None, -32700, "JSON parse error", str(e))
51 |
52 | if request.get("jsonrpc") != "2.0":
53 | return self._error(None, -32600, "Invalid request: 'jsonrpc' must be '2.0'")
54 |
55 | method = request.get("method")
56 | if method is None:
57 | return self._error(None, -32600, "Invalid request: 'method' is required")
58 | if not isinstance(method, str):
59 | return self._error(None, -32600, "Invalid request: 'method' must be a string")
60 |
61 | request_id: JsonRpcId = request.get("id")
62 | is_notification = "id" not in request
63 | params: JsonRpcParams = request.get("params")
64 | try:
65 | result = self._call(method, params)
66 | if is_notification:
67 | return None
68 | return {
69 | "jsonrpc": "2.0",
70 | "result": result,
71 | "id": request_id,
72 | }
73 | except JsonRpcException as e:
74 | if is_notification:
75 | return None
76 | return self._error(request_id, e.code, e.message, e.data)
77 | except Exception as e:
78 | if is_notification:
79 | return None
80 | error = self.map_exception(e)
81 | return self._error(request_id, error["code"], error["message"], error.get("data"))
82 |
83 | def map_exception(self, e: Exception) -> JsonRpcError:
84 | if self.redact_exceptions:
85 | return {
86 | "code": -32603,
87 | "message": f"Internal Error: {str(e)}",
88 | }
89 | return {
90 | "code": -32603,
91 | "message": "\n".join(traceback.format_exception(e)).strip() + "\n\nPlease report a bug!",
92 | }
93 |
94 | def _call(self, method: str, params: Any) -> Any:
95 | if method not in self.methods:
96 | raise JsonRpcException(-32601, f"Method '{method}' not found")
97 |
98 | func = self.methods[method]
99 |
100 | # Check for cached reflection data
101 | if func not in self._cache:
102 | sig = inspect.signature(func)
103 | hints = get_type_hints(func)
104 | hints.pop("return", None)
105 |
106 | # Determine required vs optional parameters
107 | required_params = []
108 | for param_name, param in sig.parameters.items():
109 | if param.default is inspect.Parameter.empty:
110 | required_params.append(param_name)
111 |
112 | self._cache[func] = (sig, hints, required_params)
113 |
114 | sig, hints, required_params = self._cache[func]
115 |
116 | # Handle None params
117 | if params is None:
118 | if len(required_params) == 0:
119 | return func()
120 | else:
121 | raise JsonRpcException(-32602, "Missing required params")
122 |
123 | # Convert list params to dict by parameter names
124 | if isinstance(params, list):
125 | if len(params) < len(required_params):
126 | raise JsonRpcException(
127 | -32602,
128 | f"Invalid params: expected at least {len(required_params)} arguments, got {len(params)}"
129 | )
130 | if len(params) > len(sig.parameters):
131 | raise JsonRpcException(
132 | -32602,
133 | f"Invalid params: expected at most {len(sig.parameters)} arguments, got {len(params)}"
134 | )
135 | params = dict(zip(sig.parameters.keys(), params))
136 |
137 | # Validate dict params
138 | if isinstance(params, dict):
139 | # Check all required params are present
140 | missing = set(required_params) - set(params.keys())
141 | if missing:
142 | raise JsonRpcException(
143 | -32602,
144 | f"Invalid params: missing required parameters: {list(missing)}"
145 | )
146 |
147 | # Check no extra params
148 | extra = set(params.keys()) - set(sig.parameters.keys())
149 | if extra:
150 | raise JsonRpcException(
151 | -32602,
152 | f"Invalid params: unexpected parameters: {list(extra)}"
153 | )
154 |
155 | validated_params = {}
156 | for param_name, value in params.items():
157 | # If no type hint, pass through without validation
158 | if param_name not in hints:
159 | validated_params[param_name] = value
160 | continue
161 |
162 | # Has type hint, validate
163 | expected_type = hints[param_name]
164 |
165 | # Inline type validation
166 | origin = get_origin(expected_type)
167 | args = get_args(expected_type)
168 |
169 | # Handle None/null
170 | if value is None:
171 | if expected_type is not type(None):
172 | # Check if None is allowed in a Union
173 | if not (origin in (Union, UnionType) and type(None) in args):
174 | raise JsonRpcException(-32602, f"Invalid params: {param_name} cannot be null")
175 | validated_params[param_name] = None
176 | continue
177 |
178 | # Handle Union types (int | str, Optional[int], etc.)
179 | if origin in (Union, UnionType):
180 | type_matched = False
181 | for arg_type in args:
182 | if arg_type is type(None):
183 | continue
184 |
185 | arg_origin = get_origin(arg_type)
186 | check_type = arg_origin if arg_origin is not None else arg_type
187 |
188 | # TypedDict cannot be used with isinstance - check for dict instead
189 | if is_typeddict(arg_type):
190 | check_type = dict
191 |
192 | if isinstance(value, check_type):
193 | type_matched = True
194 | break
195 |
196 | if not type_matched:
197 | raise JsonRpcException(-32602, f"Invalid params: {param_name} union does not contain {type(value).__name__}")
198 | validated_params[param_name] = value
199 | continue
200 |
201 | # Handle generic types (list[X], dict[K,V])
202 | if origin is not None:
203 | if not isinstance(value, origin):
204 | raise JsonRpcException(
205 | -32602,
206 | f"Invalid params: {param_name} expected {origin.__name__}, got {type(value).__name__}"
207 | )
208 | validated_params[param_name] = value
209 | continue
210 |
211 | # Handle TypedDict (must check before basic types)
212 | if is_typeddict(expected_type):
213 | if not isinstance(value, dict):
214 | raise JsonRpcException(
215 | -32602,
216 | f"Invalid params: {param_name} expected dict, got {type(value).__name__}"
217 | )
218 | validated_params[param_name] = value
219 | continue
220 |
221 | # Handle Any
222 | if expected_type is Any:
223 | validated_params[param_name] = value
224 | continue
225 |
226 | # Handle basic types
227 | if isinstance(expected_type, type):
228 | # Allow int -> float conversion
229 | if expected_type is float and isinstance(value, int):
230 | validated_params[param_name] = float(value)
231 | continue
232 | if not isinstance(value, expected_type):
233 | raise JsonRpcException(
234 | -32602,
235 | f"Invalid params: {param_name} expected {expected_type.__name__}, got {type(value).__name__}"
236 | )
237 | validated_params[param_name] = value
238 | continue
239 |
240 | return func(**validated_params)
241 |
242 | else:
243 | raise JsonRpcException(-32602, "Invalid params: must be array or object")
244 |
245 | def _error(self, request_id: JsonRpcId, code: int, message: str, data: Any = None) -> JsonRpcResponse | None:
246 | error: JsonRpcError = {
247 | "code": code,
248 | "message": message,
249 | }
250 | if data is not None:
251 | error["data"] = data
252 | return {
253 | "jsonrpc": "2.0",
254 | "error": error,
255 | "id": request_id,
256 | }
257 |
--------------------------------------------------------------------------------
/tests/mcp_test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import socket
4 | import asyncio
5 | import subprocess
6 |
7 | from pydantic import AnyUrl
8 | from mcp import ClientSession, StdioServerParameters, McpError, types
9 | from mcp.client.stdio import stdio_client
10 | from mcp.client.sse import sse_client
11 | from mcp.client.streamable_http import streamablehttp_client
12 |
# Path to the example server script, resolved relative to this test file.
example_mcp = os.path.join(
    os.path.dirname(__file__),
    os.pardir,
    "examples",
    "mcp_example.py",
)
# Fail at import time if the script the tests launch is missing.
assert os.path.exists(example_mcp), f"not found: {example_mcp}"
15 |
async def test_example_server(prefix: str, session: ClientSession) -> None:
    """Exercise the example server end to end through ``session``.

    Initializes the session, pings, lists resources / resource templates /
    tools / prompts, fetches both prompts, reads a static and a templated
    resource, and calls every tool, asserting on each result.

    Args:
        prefix: Label put in front of every log line (names the transport).
        session: Client session connected to the example server; it is
            initialized here.
    """
    # Initialize the connection
    await session.initialize()

    # Test ping
    ping_result = await session.send_ping()
    print(f"[{prefix}] Ping result: {ping_result}")
    assert isinstance(ping_result, types.EmptyResult), "ping should return EmptyResult"

    # List available resources
    resources = await session.list_resources()
    print(f"[{prefix}] Available resources: {[r.uri for r in resources.resources]}")
    assert len(resources.resources) == 2, "expected 2 static resources"
    assert str(resources.resources[0].uri) == "example://system_info", "expected system_info resource"

    # List available resource templates
    template_resources = await session.list_resource_templates()
    print(f"[{prefix}] Available resource templates: {[r.uriTemplate for r in template_resources.resourceTemplates]}")
    assert len(template_resources.resourceTemplates) == 1, "expected 1 resource template"
    assert template_resources.resourceTemplates[0].uriTemplate == "example://greeting/{name}", "expected greeting template"

    # List available tools
    tools = await session.list_tools()
    print(f"[{prefix}] Available tools: {[t.name for t in tools.tools]}")
    tool_names = {t.name for t in tools.tools}
    assert tool_names == {"divide", "greet", "random_dict", "get_system_info", "failing_tool", "struct_get"}, f"unexpected tools: {tool_names}"

    # List available prompts
    prompts = await session.list_prompts()
    print(f"[{prefix}] Available prompts: {[p.name for p in prompts.prompts]}")
    prompt_names = {p.name for p in prompts.prompts}
    assert prompt_names == {"code_review", "summarize"}, (
        f"unexpected prompts: {prompt_names}"
    )

    # Get prompt with required argument only
    result = await session.get_prompt("summarize", arguments={"text": "Hello world"})
    assert result.messages, "expected messages"
    assert result.messages[0].role == "user"
    content = result.messages[0].content
    assert isinstance(content, types.TextContent), "expected TextContent"
    print(f"[{prefix}] Summarize prompt result: {content.text[:50]}...")
    assert "Hello world" in content.text, "expected text in prompt"

    # Get prompt with optional argument
    result = await session.get_prompt(
        "code_review", arguments={"code": "print('hi')", "language": "javascript"}
    )
    assert result.messages, "expected messages"
    content = result.messages[0].content
    assert isinstance(content, types.TextContent), "expected TextContent"
    print(f"[{prefix}] Code review prompt result: {content.text[:50]}...")
    assert "javascript" in content.text, "expected language in prompt"
    assert "print('hi')" in content.text, "expected code in prompt"

    # Read a resource - assert content
    resource_content = await session.read_resource(AnyUrl("example://system_info"))
    content_block = resource_content.contents[0]
    assert isinstance(content_block, types.TextResourceContents), "expected TextResourceContents"
    print(f"[{prefix}] Resource content: {content_block.text}")
    assert "platform" in content_block.text, "expected platform in system_info"
    assert "python_version" in content_block.text, "expected python_version in system_info"

    # Read template resource - assert content
    template_content = await session.read_resource(AnyUrl("example://greeting/Pythonista"))
    template_block = template_content.contents[0]
    assert isinstance(template_block, types.TextResourceContents), "expected TextResourceContents"
    print(f"[{prefix}] Template resource content: {template_block.text}")
    assert "Pythonista" in template_block.text, "expected name in greeting"
    assert "message" in template_block.text, "expected message field in greeting"

    # Call divide tool
    result = await session.call_tool("divide", arguments={"numerator": 42, "denominator": 2})
    assert not result.isError, "divide should succeed"
    result_unstructured = result.content[0]
    assert isinstance(result_unstructured, types.TextContent), "expected TextContent"
    print(f"[{prefix}] Divide result: {result_unstructured.text}")
    assert "21" in result_unstructured.text, "42/2 should be 21"

    # Call greet tool without age
    result = await session.call_tool("greet", arguments={"name": "Alice"})
    assert not result.isError, "greet should succeed"
    assert isinstance(result.content[0], types.TextContent), "expected TextContent"
    content = result.content[0].text
    print(f"[{prefix}] Greet result: {content}")
    assert "Alice" in content, "expected name in greeting"
    assert "message" in content, "expected message field"
    # Passes when "age" is absent or occurs exactly once in the rendered text.
    assert "age" not in content or content.count("age") == 1, "age should not have value"

    # Call greet tool with age
    result = await session.call_tool("greet", arguments={"name": "Bob", "age": 30})
    assert not result.isError, "greet with age should succeed"
    assert isinstance(result.content[0], types.TextContent), "expected TextContent"
    content = result.content[0].text
    print(f"[{prefix}] Greet with age result: {content}")
    assert "Bob" in content and "30" in content, "expected name and age"

    # Call get_system_info tool
    result = await session.call_tool("get_system_info", arguments={})
    assert not result.isError, "get_system_info should succeed"
    assert isinstance(result.content[0], types.TextContent), "expected TextContent"
    content = result.content[0].text
    print(f"[{prefix}] System info result: {content}")
    assert "platform" in content, "expected platform"
    assert "python_version" in content, "expected python_version"
    assert "timestamp" in content, "expected timestamp"

    # Call struct_get with list
    result = await session.call_tool("struct_get", arguments={"names": ["foo", "bar"]})
    assert not result.isError, "struct_get with list should succeed"
    assert isinstance(result.content[0], types.TextContent), "expected TextContent"
    content = result.content[0].text
    print(f"[{prefix}] Struct_get (list) result: {content}")
    assert "foo" in content and "bar" in content, "expected both struct names"

    # Call struct_get with string
    result = await session.call_tool("struct_get", arguments={"names": "baz"})
    assert not result.isError, "struct_get with string should succeed"
    assert isinstance(result.content[0], types.TextContent), "expected TextContent"
    content = result.content[0].text
    print(f"[{prefix}] Struct_get (string) result: {content}")
    assert "baz" in content, "expected struct name"

    # Call failing tool
    result = await session.call_tool("failing_tool", arguments={"message": "This is a test error"})
    print(f"[{prefix}] Failing tool result: {result}")
    assert result.isError, "expected tool call to fail"
    assert isinstance(result.content[0], types.TextContent), "expected text content in tool call result"
    assert "test error" in result.content[0].text, "expected error message in tool call result"

    # Call random_dict tool
    result = await session.call_tool("random_dict", arguments={"param": {"x": 112, "other": "yes"}})
    assert not result.isError, "random_dict should succeed"
    assert isinstance(result.content[0], types.TextContent), "expected TextContent"
    content = result.content[0].text
    print(f"[{prefix}] Random dict result: {content}")

    # Call random_dict tool with null
    result = await session.call_tool("random_dict", arguments={"param": None})
    assert not result.isError, "random_dict with null should succeed"
156 |
async def test_edge_cases(prefix: str, session: ClientSession) -> None:
    """Test edge cases and error conditions.

    Tool failures are expected to come back as results with ``isError`` set,
    while resource and prompt failures are expected to raise ``McpError``.

    Args:
        prefix: Label put in front of every log line (names the transport).
        session: Client session connected to the example server.
    """
    await session.initialize()

    # Test non-existent tool
    result = await session.call_tool("nonexistent_tool", arguments={})
    assert result.isError, "should error on non-existent tool"
    print(f"[{prefix}] Non-existent tool error: {result.content[0] if result.content else 'no content'}")

    # Test missing required parameter
    result = await session.call_tool("divide", arguments={"numerator": 42})
    assert result.isError, "should error on missing denominator"
    print(f"[{prefix}] Missing param error: {result.content[0] if result.content else 'no content'}")

    # Test division by zero (natural exception)
    result = await session.call_tool("divide", arguments={"numerator": 1, "denominator": 0})
    assert result.isError, "division by zero should error"
    print(f"[{prefix}] Division by zero error: {result.content[0] if result.content else 'no content'}")

    # Test invalid resource URI
    # (`assert False` raises AssertionError, which the `except McpError`
    # clause does not catch, so a missing error still fails the test.)
    try:
        await session.read_resource(AnyUrl("example://invalid_resource"))
        assert False, "should have raised on invalid resource"
    except McpError as e:
        assert "Resource not found" in e.error.message, "expected invalid resource error"

    # Test resource template with missing substitution
    try:
        await session.read_resource(AnyUrl("example://greeting/"))
        assert False, "should have raised on missing template parameter"
    except McpError as e:
        assert "Resource not found" in e.error.message, "expected missing substitution error"

    # Test resource that raises an error
    try:
        await session.read_resource(AnyUrl("example://error"))
        assert False, "should have raised on error resource"
    except McpError as e:
        assert "This is a resource error for testing purposes." in e.error.message, (
            "expected resource error message"
        )

    # Test non-existent prompt
    try:
        await session.get_prompt("nonexistent_prompt", arguments={})
        assert False, "should have raised on non-existent prompt"
    except McpError as e:
        assert "Method 'nonexistent_prompt' not found" in e.error.message, (
            "expected method not found error"
        )

    print(f"[{prefix}] Edge case tests passed!")
209 |
210 |
def coverage_wrap(name: str, args: list[str]) -> list[str]:
    """Prefix interpreter arguments with a ``coverage run`` invocation.

    When the ``COVERAGE_RUN`` environment variable is set to a non-empty
    value, the returned list runs ``args`` under ``python -m coverage run``
    and writes to a data file named after ``name``. Otherwise ``args`` is
    returned as is.
    """
    if not os.environ.get("COVERAGE_RUN"):
        return args
    coverage_prefix = ["-m", "coverage", "run", f"--data-file=.coverage.{name}"]
    return coverage_prefix + args
215 |
async def test_stdio():
    """Launch the example server over stdio and run both test groups against it."""
    print("[stdio] Testing...")
    launch_args = coverage_wrap("stdio", [example_mcp, "--transport", "stdio"])
    server_params = StdioServerParameters(command=sys.executable, args=launch_args)
    async with (
        stdio_client(server_params) as (read, write),
        ClientSession(read, write) as session,
    ):
        await test_example_server("stdio", session)
        await test_edge_cases("stdio", session)
226 |
async def test_sse(address: str):
    """Run both test groups over the SSE endpoint (``/sse``) of ``address``."""
    print("[sse] Testing...")
    endpoint = f"{address}/sse"
    async with (
        sse_client(endpoint) as (read, write),
        ClientSession(read, write) as session,
    ):
        await test_example_server("sse", session)
        await test_edge_cases("sse", session)
233 |
async def test_streamablehttp(address: str):
    """Run both test groups over the streamable HTTP endpoint (``/mcp``) of ``address``."""
    print("[streamable] Testing...")
    endpoint = f"{address}/mcp"
    # The third item yielded by the client is not needed by these tests.
    async with (
        streamablehttp_client(endpoint) as (read, write, _unused),
        ClientSession(read, write) as session,
    ):
        await test_example_server("streamable", session)
        await test_edge_cases("streamable", session)
240 |
def find_available_port() -> int:
    """Return a TCP port on 127.0.0.1 that was free at the time of the call.

    The port is chosen by the OS (bind to port 0) and released again before
    returning, so another process could in principle grab it before the
    caller does.
    """
    # The context manager closes the socket even if bind()/getsockname()
    # raises. Probe the loopback address because that is where test_serve()
    # starts the server.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("127.0.0.1", 0))
        return sock.getsockname()[1]
247 |
async def test_serve():
    """Start the example server on a free local port and test its HTTP transports.

    Runs the SSE tests and then the streamable HTTP tests against the same
    server process, which is always shut down afterwards.
    """
    print("[serve] Testing...")

    # Start example MCP server as subprocess
    address = f"http://127.0.0.1:{find_available_port()}"
    process = subprocess.Popen(
        [sys.executable] + coverage_wrap("serve", [example_mcp, "--transport", address]),
        stdin=subprocess.PIPE,
        text=True,
        encoding="utf-8",
        bufsize=1,
    )
    try:
        await asyncio.sleep(0.5)  # Wait for server to start
        await test_sse(address)
        await test_streamablehttp(address)
    finally:
        print("[serve] Terminating example MCP server")
        # Presumably the example server exits once its stdin is closed -
        # verify against examples/mcp_example.py.
        process.stdin.close()  # type: ignore
        try:
            # Bounded wait so a server that fails to exit cannot hang the run.
            process.wait(timeout=10)
        except subprocess.TimeoutExpired:
            process.kill()
            process.wait()
269 |
async def main():
    """Run the HTTP transport tests first, then the stdio transport test."""
    for scenario in (test_serve, test_stdio):
        await scenario()
273 |
# Script entry point: run every transport test when executed directly.
if __name__ == "__main__":
    asyncio.run(main())
277 |
--------------------------------------------------------------------------------
/tests/jsonrpc_test.py:
--------------------------------------------------------------------------------
1 | """
2 | Comprehensive JSON-RPC 2.0 test suite for MCP implementation
3 | """
4 | import json
5 | import sys
6 | import traceback
7 | import re
8 | from typing import Optional, Any, TypedDict
9 |
10 | from zeromcp.jsonrpc import JsonRpcRegistry
11 |
# Create registry and register test methods.
# Module-level on purpose: the @jsonrpc.method decorators below and
# test_rpc() all use this one instance.
jsonrpc = JsonRpcRegistry()
14 |
class Point(TypedDict):
    """TypedDict fixture used as a parameter type by union_test and point_pretty."""
    x: int
    y: int
18 |
@jsonrpc.method
def subtract(minuend: int, subtrahend: int) -> int:
    """Fixture with two required int parameters; returns their difference."""
    return minuend - subtrahend
22 |
@jsonrpc.method
def update(a: int, b: int, c: int, d: int, e: int) -> str:
    """Fixture with five required int parameters; ignores them and returns a fixed string."""
    return "updated"
26 |
@jsonrpc.method
def foobar() -> str:
    """Parameterless fixture; always returns "bar"."""
    return "bar"
30 |
@jsonrpc.method
def get_data() -> list:
    """Parameterless fixture returning a list with mixed element types."""
    return ["hello", 5]
34 |
@jsonrpc.method
def greet(name: str, greeting: str = "Hello") -> str:
    """Fixture with one required and one defaulted parameter."""
    return f"{greeting}, {name}!"
38 |
@jsonrpc.method
def process_optional(value: Optional[int]) -> str:
    """Fixture with an Optional[int] parameter that has no default."""
    return f"Got: {value}"
42 |
@jsonrpc.method
def union_test(id: int | str | None | Point) -> str:
    """Fixture with a union parameter that includes None and a TypedDict.

    Any falsy id (None, 0, "", {}) is rendered as an empty string.
    """
    return f"ID: {id or ''}"
46 |
@jsonrpc.method
def list_test(items: list[str]) -> int:
    """Fixture with a parameterised list hint; returns the number of items."""
    return len(items)
50 |
@jsonrpc.method
def exception():
    """Fixture that always raises a plain Exception."""
    raise Exception("Python exception")
54 |
@jsonrpc.method
def point_pretty(p: Point) -> str:
    """Fixture with a TypedDict parameter; formats its "x" and "y" entries."""
    return f"Point(x={p['x']}, y={p['y']})"
58 |
@jsonrpc.method
def round_float(value: float) -> int:
    """Fixture with a float parameter; returns round(value)."""
    return round(value)
62 |
@jsonrpc.method
def python_repr(value: Any) -> str:
    """Fixture with an Any parameter; returns repr(value)."""
    return repr(value)
66 |
@jsonrpc.method
def unknown(x, y):
    """Fixture without any type hints; returns x + y."""
    return x + y
70 |
71 | def matches_response(actual: dict | None, expected: dict | None) -> bool:
72 | """Check if actual response matches expected, with regex support for error messages."""
73 | if actual is None and expected is None:
74 | return True
75 | if actual is None or expected is None:
76 | return False
77 |
78 | # Check top-level keys
79 | if set(actual.keys()) != set(expected.keys()):
80 | return False
81 |
82 | for key in expected.keys():
83 | actual_val = actual[key]
84 | expected_val = expected[key]
85 |
86 | # Handle error object specially for regex matching
87 | if key == "error" and isinstance(expected_val, dict) and isinstance(actual_val, dict):
88 | # Check code exactly
89 | if actual_val.get("code") != expected_val.get("code"):
90 | return False
91 |
92 | # Check message with regex support
93 | expected_msg = expected_val.get("message", "")
94 | actual_msg = actual_val.get("message", "")
95 | if isinstance(expected_msg, str) and expected_msg.startswith("regex:"):
96 | pattern = expected_msg[6:] # Remove "regex:" prefix
97 | if not re.search(pattern, actual_msg):
98 | return False
99 | else:
100 | if actual_msg != expected_msg:
101 | return False
102 |
103 | # For data field, support regex patterns
104 | if "data" in expected_val:
105 | expected_data = expected_val["data"]
106 | actual_data = actual_val.get("data", "")
107 |
108 | # If expected_data starts with "regex:", treat it as a regex pattern
109 | if isinstance(expected_data, str) and expected_data.startswith("regex:"):
110 | pattern = expected_data[6:] # Remove "regex:" prefix
111 | if not re.search(pattern, str(actual_data)):
112 | return False
113 | else:
114 | # Exact match
115 | if actual_data != expected_data:
116 | return False
117 | elif "data" in actual_val:
118 | # Actual has data but expected doesn't - that's ok
119 | pass
120 | else:
121 | # Exact match for other fields
122 | if actual_val != expected_val:
123 | return False
124 |
125 | return True
126 |
def test_rpc(request: Any, expected_response: dict | None = None, description: str = ""):
    """Dispatch ``request`` through the module-level registry and check the reply.

    Prints the request and the response. On a mismatch, or if ``dispatch``
    itself raises, prints a diagnostic and terminates the process with exit
    status 1.

    Args:
        request: Request object or raw JSON text handed to ``jsonrpc.dispatch``.
        expected_response: Template checked with ``matches_response``. Note
            that None means "nothing to verify": a non-None response is then
            accepted unchecked.
        description: Label printed in the test header.

    Returns:
        Whatever ``jsonrpc.dispatch`` returned.
    """
    print(f"\n{'='*60}")
    print(f"Test: {description}")
    print(f"--> {request}")

    try:
        result = jsonrpc.dispatch(request)
    except Exception:
        print("\n❌ UNEXPECTED EXCEPTION:")
        traceback.print_exc()
        sys.exit(1)

    if result is None:
        print("<-- (no response - notification)")
        if expected_response is not None:
            print("\n❌ FAIL: Expected response but got None")
            # Must be an f-string, otherwise the template text itself is printed.
            print(f"Expected: {json.dumps(expected_response, indent=2)}")
            sys.exit(1)
    else:
        result_json = json.dumps(result, indent=2)
        print(f"<-- {result_json}")

        if expected_response is not None:
            if not matches_response(result, expected_response):  # type: ignore
                print("\n❌ FAIL: Response mismatch")
                print(f"Expected: {json.dumps(expected_response, indent=2)}")
                print(f"Got: {result_json}")
                sys.exit(1)

    print("✓ PASS")
    return result
159 |
160 |
161 | def run_all_tests():
162 | print("="*60)
163 | print("JSON-RPC 2.0 COMPLIANCE TESTS")
164 | print("="*60)
165 |
166 | # ========================================
167 | # SPEC EXAMPLES
168 | # ========================================
169 |
170 | # Positional parameters
171 | test_rpc(
172 | {"jsonrpc": "2.0", "method": "subtract", "params": [42, 23], "id": 1},
173 | {"jsonrpc": "2.0", "result": 19, "id": 1},
174 | "Positional params - subtract(42, 23)"
175 | )
176 |
177 | test_rpc(
178 | '{"jsonrpc": "2.0", "method": "subtract", "params": [23, 42], "id": 2}',
179 | {"jsonrpc": "2.0", "result": -19, "id": 2},
180 | "Positional params - subtract(23, 42)"
181 | )
182 |
183 | # Named parameters
184 | test_rpc(
185 | '{"jsonrpc": "2.0", "method": "subtract", "params": {"subtrahend": 23, "minuend": 42}, "id": 3}',
186 | {"jsonrpc": "2.0", "result": 19, "id": 3},
187 | "Named params - order independent (1)"
188 | )
189 |
190 | test_rpc(
191 | '{"jsonrpc": "2.0", "method": "subtract", "params": {"minuend": 42, "subtrahend": 23}, "id": 4}',
192 | {"jsonrpc": "2.0", "result": 19, "id": 4},
193 | "Named params - order independent (2)"
194 | )
195 |
196 | # Notifications (no response)
197 | test_rpc(
198 | '{"jsonrpc": "2.0", "method": "update", "params": [1,2,3,4,5]}',
199 | None,
200 | "Notification - update (no id)"
201 | )
202 |
203 | test_rpc(
204 | '{"jsonrpc": "2.0", "method": "foobar"}',
205 | None,
206 | "Notification - foobar (no params, no id)"
207 | )
208 |
209 | # Non-existent method
210 | test_rpc(
211 | '{"jsonrpc": "2.0", "method": "does_not_exist", "id": "1"}',
212 | {"jsonrpc": "2.0", "error": {"code": -32601, "message": "regex:Method.*not found"}, "id": "1"},
213 | "Non-existent method error"
214 | )
215 |
216 | # Invalid JSON - use regex to match error since different parsers give different messages
217 | test_rpc(
218 | '{"jsonrpc": "2.0", "method": "foobar, "params": "bar", "baz]',
219 | {"jsonrpc": "2.0", "error": {"code": -32700, "message": "JSON parse error", "data": "regex:Expecting"}, "id": None},
220 | "Parse error - invalid JSON"
221 | )
222 |
223 | test_rpc(
224 | 1234,
225 | {"jsonrpc": "2.0", "error": {"code": -32700, "message": "JSON parse error", "data": "regex:object must be"}, "id": None},
226 | "Parse error - invalid JSON"
227 | )
228 |
229 | # Invalid Request object - method is not a string
230 | test_rpc(
231 | '{"jsonrpc": "2.0", "method": 1, "params": "bar"}',
232 | {"jsonrpc": "2.0", "error": {"code": -32600, "message": "Invalid request: 'method' must be a string"}, "id": None},
233 | "Invalid Request - method is number"
234 | )
235 |
236 | # Missing jsonrpc version
237 | test_rpc(
238 | '{"method": "subtract", "params": [1, 2], "id": 1}',
239 | {"jsonrpc": "2.0", "error": {"code": -32600, "message": "Invalid request: 'jsonrpc' must be '2.0'"}, "id": None},
240 | "Invalid Request - missing jsonrpc field"
241 | )
242 |
243 | # Wrong jsonrpc version
244 | test_rpc(
245 | '{"jsonrpc": "1.0", "method": "subtract", "params": [1, 2], "id": 1}',
246 | {"jsonrpc": "2.0", "error": {"code": -32600, "message": "Invalid request: 'jsonrpc' must be '2.0'"}, "id": None},
247 | "Invalid Request - wrong jsonrpc version"
248 | )
249 |
250 | # Missing method
251 | test_rpc(
252 | '{"jsonrpc": "2.0", "params": [1, 2], "id": 1}',
253 | {"jsonrpc": "2.0", "error": {"code": -32600, "message": "Invalid request: 'method' is required"}, "id": None},
254 | "Invalid Request - missing method"
255 | )
256 |
257 | # Empty array (not valid single request)
258 | test_rpc(
259 | '[]',
260 | {"jsonrpc": "2.0", "error": {"code": -32600, "message": "Invalid request: must be a JSON object"}, "id": None},
261 | "Invalid Request - empty array"
262 | )
263 |
264 | # Non-object request
265 | test_rpc(
266 | '"not an object"',
267 | {"jsonrpc": "2.0", "error": {"code": -32600, "message": "Invalid request: must be a JSON object"}, "id": None},
268 | "Invalid Request - string instead of object"
269 | )
270 |
271 | test_rpc(
272 | '123',
273 | {"jsonrpc": "2.0", "error": {"code": -32600, "message": "Invalid request: must be a JSON object"}, "id": None},
274 | "Invalid Request - number instead of object"
275 | )
276 |
277 | # Request with id: null (valid request, not a notification)
278 | test_rpc(
279 | '{"jsonrpc": "2.0", "method": "foobar", "id": null}',
280 | {"jsonrpc": "2.0", "result": "bar", "id": None},
281 | "Valid request with id: null"
282 | )
283 |
284 | # ========================================
285 | # PARAMETER VALIDATION TESTS
286 | # ========================================
287 |
288 | # Wrong number of positional params - too few
289 | test_rpc(
290 | '{"jsonrpc": "2.0", "method": "subtract", "params": [42], "id": 1}',
291 | {"jsonrpc": "2.0", "error": {"code": -32602, "message": "Invalid params: expected at least 2 arguments, got 1"}, "id": 1},
292 | "Invalid params - too few positional arguments"
293 | )
294 |
295 | # Wrong number of positional params - too many
296 | test_rpc(
297 | '{"jsonrpc": "2.0", "method": "subtract", "params": [42, 23, 10], "id": 1}',
298 | {"jsonrpc": "2.0", "error": {"code": -32602, "message": "Invalid params: expected at most 2 arguments, got 3"}, "id": 1},
299 | "Invalid params - too many positional arguments"
300 | )
301 |
302 | # Missing required named param
303 | test_rpc(
304 | '{"jsonrpc": "2.0", "method": "subtract", "params": {"minuend": 42}, "id": 1}',
305 | {"jsonrpc": "2.0", "error": {"code": -32602, "message": "Invalid params: missing required parameters: ['subtrahend']"}, "id": 1},
306 | "Invalid params - missing required parameter"
307 | )
308 |
309 | # Extra named param
310 | test_rpc(
311 | '{"jsonrpc": "2.0", "method": "subtract", "params": {"minuend": 42, "subtrahend": 23, "extra": 1}, "id": 1}',
312 | {"jsonrpc": "2.0", "error": {"code": -32602, "message": "Invalid params: unexpected parameters: ['extra']"}, "id": 1},
313 | "Invalid params - unexpected parameter"
314 | )
315 |
316 | # Wrong type - string instead of int
317 | test_rpc(
318 | '{"jsonrpc": "2.0", "method": "subtract", "params": [42, "not a number"], "id": 1}',
319 | {"jsonrpc": "2.0", "error": {"code": -32602, "message": "Invalid params: subtrahend expected int, got str"}, "id": 1},
320 | "Invalid params - wrong type (str instead of int)"
321 | )
322 |
323 | # Wrong type - list instead of int
324 | test_rpc(
325 | '{"jsonrpc": "2.0", "method": "subtract", "params": {"minuend": 42, "subtrahend": [1, 2]}, "id": 1}',
326 | {"jsonrpc": "2.0", "error": {"code": -32602, "message": "Invalid params: subtrahend expected int, got list"}, "id": 1},
327 | "Invalid params - wrong type (list instead of int)"
328 | )
329 |
330 | # Null for non-optional param
331 | test_rpc(
332 | '{"jsonrpc": "2.0", "method": "subtract", "params": [42, null], "id": 1}',
333 | {"jsonrpc": "2.0", "error": {"code": -32602, "message": "Invalid params: subtrahend cannot be null"}, "id": 1},
334 | "Invalid params - null for non-optional parameter"
335 | )
336 |
337 | # Params is invalid type (not array or object)
338 | test_rpc(
339 | '{"jsonrpc": "2.0", "method": "subtract", "params": "invalid", "id": 1}',
340 | {"jsonrpc": "2.0", "error": {"code": -32602, "message": "Invalid params: must be array or object"}, "id": 1},
341 | "Invalid params - string instead of array/object"
342 | )
343 |
344 | test_rpc(
345 | '{"jsonrpc": "2.0", "method": "subtract", "params": 123, "id": 1}',
346 | {"jsonrpc": "2.0", "error": {"code": -32602, "message": "Invalid params: must be array or object"}, "id": 1},
347 | "Invalid params - number instead of array/object"
348 | )
349 |
350 | test_rpc(
351 | '{"jsonrpc": "2.0", "method": "subtract", "params": null, "id": 1}',
352 | {"jsonrpc": "2.0", "error": {"code": -32602, "message": "Missing required params"}, "id": 1},
353 | "Invalid params - null for required params"
354 | )
355 |
356 | # ========================================
357 | # DEFAULT PARAMETERS TESTS
358 | # ========================================
359 |
360 | # Function with default param - omit optional param (positional)
361 | test_rpc(
362 | '{"jsonrpc": "2.0", "method": "greet", "params": ["Alice"], "id": 1}',
363 | {"jsonrpc": "2.0", "result": "Hello, Alice!", "id": 1},
364 | "Default param - omit optional (positional)"
365 | )
366 |
367 | # Function with default param - provide optional param (positional)
368 | test_rpc(
369 | '{"jsonrpc": "2.0", "method": "greet", "params": ["Alice", "Hi"], "id": 1}',
370 | {"jsonrpc": "2.0", "result": "Hi, Alice!", "id": 1},
371 | "Default param - provide optional (positional)"
372 | )
373 |
374 | # Function with default param - omit optional param (named)
375 | test_rpc(
376 | '{"jsonrpc": "2.0", "method": "greet", "params": {"name": "Bob"}, "id": 1}',
377 | {"jsonrpc": "2.0", "result": "Hello, Bob!", "id": 1},
378 | "Default param - omit optional (named)"
379 | )
380 |
381 | # Function with default param - provide optional param (named)
382 | test_rpc(
383 | '{"jsonrpc": "2.0", "method": "greet", "params": {"name": "Bob", "greeting": "Hey"}, "id": 1}',
384 | {"jsonrpc": "2.0", "result": "Hey, Bob!", "id": 1},
385 | "Default param - provide optional (named)"
386 | )
387 |
388 | # Function with no params - omit params field
389 | test_rpc(
390 | '{"jsonrpc": "2.0", "method": "get_data", "id": 1}',
391 | {"jsonrpc": "2.0", "result": ["hello", 5], "id": 1},
392 | "No params function - params field omitted"
393 | )
394 |
395 | # Function with no params - empty array
396 | test_rpc(
397 | '{"jsonrpc": "2.0", "method": "get_data", "params": [], "id": 1}',
398 | {"jsonrpc": "2.0", "result": ["hello", 5], "id": 1},
399 | "No params function - empty array"
400 | )
401 |
402 | # Function with no params - empty object
403 | test_rpc(
404 | '{"jsonrpc": "2.0", "method": "get_data", "params": {}, "id": 1}',
405 | {"jsonrpc": "2.0", "result": ["hello", 5], "id": 1},
406 | "No params function - empty object"
407 | )
408 |
409 | # ========================================
410 | # UNION TYPE TESTS
411 | # ========================================
412 |
413 | # Union type - int
414 | test_rpc(
415 | '{"jsonrpc": "2.0", "method": "union_test", "params": [123], "id": 1}',
416 | {"jsonrpc": "2.0", "result": "ID: 123", "id": 1},
417 | "Union type (int | str) - int value"
418 | )
419 |
420 | # Union type - str
421 | test_rpc(
422 | '{"jsonrpc": "2.0", "method": "union_test", "params": ["abc"], "id": 1}',
423 | {"jsonrpc": "2.0", "result": "ID: abc", "id": 1},
424 | "Union type (int | str) - str value"
425 | )
426 |
427 | # Union type - invalid type
428 | test_rpc(
429 | '{"jsonrpc": "2.0", "method": "union_test", "params": [[1, 2, 3]], "id": 1}',
430 | {"jsonrpc": "2.0", "error": {"code": -32602, "message": "Invalid params: id union does not contain list"}, "id": 1},
431 | "Union type (int | str) - invalid list"
432 | )
433 |
434 | # Union type - null
435 | test_rpc(
436 | '{"jsonrpc": "2.0", "method": "union_test", "params": [null], "id": 1}',
437 | {"jsonrpc": "2.0", "result": "ID: ", "id": 1},
438 | "Union type (int | str | None) - null value"
439 | )
440 |
441 | # ========================================
442 | # OPTIONAL TYPE TESTS
443 | # ========================================
444 |
445 | # Optional - provide value
446 | test_rpc(
447 | '{"jsonrpc": "2.0", "method": "process_optional", "params": [42], "id": 1}',
448 | {"jsonrpc": "2.0", "result": "Got: 42", "id": 1},
449 | "Optional type - provide value"
450 | )
451 |
452 | # Optional - provide null
453 | test_rpc(
454 | '{"jsonrpc": "2.0", "method": "process_optional", "params": [null], "id": 1}',
455 | {"jsonrpc": "2.0", "result": "Got: None", "id": 1},
456 | "Optional type - provide null"
457 | )
458 |
459 | # ========================================
460 | # GENERIC TYPE TESTS
461 | # ========================================
462 |
463 | # list[T] - valid list
464 | test_rpc(
465 | '{"jsonrpc": "2.0", "method": "list_test", "params": [["a", "b", "c"]], "id": 1}',
466 | {"jsonrpc": "2.0", "result": 3, "id": 1},
467 | "Generic type list[str] - valid list (no inner validation)"
468 | )
469 |
470 | # list[T] - wrong outer type
471 | test_rpc(
472 | '{"jsonrpc": "2.0", "method": "list_test", "params": ["not a list"], "id": 1}',
473 | {"jsonrpc": "2.0", "error": {"code": -32602, "message": "Invalid params: items expected list, got str"}, "id": 1},
474 | "Generic type list[str] - wrong outer type"
475 | )
476 |
477 | # Point TypedDict - valid dict
478 | test_rpc(
479 | '{"jsonrpc": "2.0", "method": "point_pretty", "params": [{"x": 10, "y": 20}], "id": 1}',
480 | {"jsonrpc": "2.0", "result": "Point(x=10, y=20)", "id": 1},
481 | "TypedDict Point - valid dict"
482 | )
483 |
484 | # Point TypedDict - wrong outer type
485 | test_rpc(
486 | '{"jsonrpc": "2.0", "method": "point_pretty", "params": ["not a dict"], "id": 1}',
487 | {"jsonrpc": "2.0", "error": {"code": -32602, "message": "Invalid params: p expected dict, got str"}, "id": 1},
488 | "TypedDict Point - wrong outer type"
489 | )
490 |
491 | # Convert from int to float
492 | test_rpc(
493 | '{"jsonrpc": "2.0", "method": "round_float", "params": [3], "id": 1}',
494 | {"jsonrpc": "2.0", "result": 3, "id": 1},
495 | "Convert int to float for float parameter"
496 | )
497 |
498 | # Any type - various inputs
499 | test_rpc(
500 | '{"jsonrpc": "2.0", "method": "python_repr", "params": [42], "id": 1}',
501 | {"jsonrpc": "2.0", "result": "42", "id": 1},
502 | "Any type - int value"
503 | )
504 |
505 | test_rpc(
506 | '{"jsonrpc": "2.0", "method": "python_repr", "params": ["hello"], "id": 1}',
507 | {"jsonrpc": "2.0", "result": "'hello'", "id": 1},
508 | "Any type - str value"
509 | )
510 |
511 | # Unspecified types (unknown) - should accept anything
512 | test_rpc(
513 | '{"jsonrpc": "2.0", "method": "unknown", "params": [10, 20], "id": 1}',
514 | {"jsonrpc": "2.0", "result": 30, "id": 1},
515 | "Unknown parameter types - accept any"
516 | )
517 |
518 | # ========================================
519 | # NOTIFICATION ERROR HANDLING
520 | # ========================================
521 |
522 | # Notification with error (should return None, no response)
523 | test_rpc(
524 | '{"jsonrpc": "2.0", "method": "does_not_exist"}',
525 | None,
526 | "Notification - error does not produce response"
527 | )
528 |
529 | # Notification with invalid params (should return None, no response)
530 | test_rpc(
531 | '{"jsonrpc": "2.0", "method": "subtract", "params": [1]}',
532 | None,
533 | "Notification - invalid params does not produce response"
534 | )
535 |
536 | test_rpc(
537 | '{"jsonrpc": "2.0", "method": "exception", "id": 1}',
538 | {"jsonrpc": "2.0", "error": {"code": -32603, "message": "regex:Python exception"}, "id": 1},
539 | "Method that raises python exception"
540 | )
541 |
542 | test_rpc(
543 | '{"jsonrpc": "2.0", "method": "exception"}',
544 | None,
545 | "Notification - method that raises python exception"
546 | )
547 |
548 | print("\n" + "="*60)
549 | print("ALL TESTS PASSED! ✓")
550 | print("="*60)
551 |
# Script entry point: run the test suite only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    run_all_tests()
554 |
--------------------------------------------------------------------------------
/src/zeromcp/mcp.py:
--------------------------------------------------------------------------------
1 | import re
2 | import sys
3 | import time
4 | import uuid
5 | import json
6 | import inspect
7 | import threading
8 | import traceback
9 | from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer, HTTPServer
10 | from typing import Any, Callable, Union, Annotated, BinaryIO, NotRequired, get_origin, get_args, get_type_hints, is_typeddict
11 | from types import UnionType
12 | from urllib.parse import urlparse, parse_qs
13 | from io import BufferedIOBase
14 |
15 | from .jsonrpc import JsonRpcRegistry, JsonRpcError, JsonRpcException
16 |
class McpToolError(Exception):
    """Exception a tool raises to report a failure with a readable message.

    The message is stored as the only exception argument (``args[0]``).
    """

    def __init__(self, message: str) -> None:
        super(McpToolError, self).__init__(message)
20 |
class McpRpcRegistry(JsonRpcRegistry):
    """JSON-RPC registry that gives McpToolError its own error mapping."""

    def map_exception(self, e: Exception) -> JsonRpcError:
        """Translate an exception into a JSON-RPC error object.

        A McpToolError becomes error code -32000 carrying the tool's own
        message (or a generic fallback when that message is empty).  Every
        other exception is handed to the base registry's mapping.
        """
        if not isinstance(e, McpToolError):
            return super().map_exception(e)

        tool_message = e.args[0] or "MCP Tool Error"
        return {"code": -32000, "message": tool_message}
30 |
31 | class _McpSseConnection:
32 | """Manages a single SSE client connection"""
33 | def __init__(self, wfile):
34 | self.wfile: BufferedIOBase = wfile
35 | self.session_id = str(uuid.uuid4())
36 | self.alive = True
37 |
38 | def send_event(self, event_type: str, data):
39 | """Send an SSE event to the client
40 |
41 | Args:
42 | event_type: Type of event (e.g., "endpoint", "message", "ping")
43 | data: Event data - can be string (sent as-is) or dict (JSON-encoded)
44 | """
45 | if not self.alive:
46 | return False
47 |
48 | try:
49 | # SSE format: "event: type\ndata: content\n\n"
50 | if isinstance(data, str):
51 | data_str = f"data: {data}\n\n"
52 | else:
53 | data_str = f"data: {json.dumps(data)}\n\n"
54 | message = f"event: {event_type}\n{data_str}".encode("utf-8")
55 | self.wfile.write(message)
56 | self.wfile.flush() # Ensure data is sent immediately
57 | return True
58 | except (BrokenPipeError, OSError):
59 | self.alive = False
60 | return False
61 |
62 | class McpHttpRequestHandler(BaseHTTPRequestHandler):
63 | server_version = "zeromcp/1.3.0"
64 | error_message_format = "%(code)d - %(message)s"
65 | error_content_type = "text/plain"
66 |
67 | def __init__(self, request, client_address, server):
68 | self.mcp_server: "McpServer" = getattr(server, "mcp_server")
69 | super().__init__(request, client_address, server)
70 |
71 | def log_message(self, format, *args):
72 | """Override to suppress default logging or customize"""
73 | pass
74 |
75 | def send_cors_headers(self, *, preflight = False):
76 | origin = self.headers.get("Origin", "")
77 | if not origin:
78 | return
79 | def is_allowed():
80 | allowed = self.mcp_server.cors_allowed_origins
81 | if allowed is None:
82 | return False
83 | if callable(allowed):
84 | return allowed(origin)
85 | if isinstance(allowed, str):
86 | allowed = [allowed]
87 | assert isinstance(allowed, list)
88 | return "*" in allowed or origin in allowed
89 | if not is_allowed():
90 | return
91 | self.send_header("Access-Control-Allow-Origin", origin)
92 | if preflight:
93 | self.send_header("Access-Control-Allow-Methods", "POST, GET, OPTIONS")
94 | self.send_header("Access-Control-Allow-Headers", "Content-Type, Accept, X-Requested-With, Mcp-Session-Id, Mcp-Protocol-Version")
95 | if self.headers.get("Access-Control-Request-Private-Network") == "true":
96 | self.send_header("Access-Control-Allow-Private-Network", "true")
97 |
98 | def send_error(self, code, message=None, explain=None):
99 | self.send_response(code)
100 | self.send_header("Content-Type", "text/plain")
101 | self.send_cors_headers()
102 | self.end_headers()
103 | self.wfile.write(f"{message}\n".encode("utf-8"))
104 |
105 | def handle(self):
106 | """Override to add error handling for connection errors"""
107 | try:
108 | super().handle()
109 | except (ConnectionAbortedError, ConnectionResetError, BrokenPipeError):
110 | # Client disconnected - normal, suppress traceback
111 | pass
112 |
113 | def do_GET(self):
114 | match urlparse(self.path).path:
115 | case "/sse":
116 | self._handle_sse_get()
117 | case "/mcp":
118 | self.send_error(405, "Method Not Allowed")
119 | case _:
120 | self.send_error(404, "Not Found")
121 |
122 | def do_POST(self):
123 | # Read request body
124 | content_length = int(self.headers.get("Content-Length", 0))
125 |
126 | if content_length > self.mcp_server.post_body_limit:
127 | self.send_error(413, f"Payload Too Large: exceeds {self.mcp_server.post_body_limit} bytes")
128 | return
129 |
130 | body = self.rfile.read(content_length) if content_length > 0 else b""
131 |
132 | match urlparse(self.path).path:
133 | case "/sse":
134 | self._handle_sse_post(body)
135 | case "/mcp":
136 | self._handle_mcp_post(body)
137 | case _:
138 | self.send_error(404, "Not Found")
139 |
140 | def do_OPTIONS(self):
141 | """Handle CORS preflight requests"""
142 | self.send_response(200)
143 | self.send_cors_headers(preflight=True)
144 | self.end_headers()
145 |
146 | def _handle_sse_get(self):
147 | # Create SSE connection wrapper
148 | conn = _McpSseConnection(self.wfile)
149 | self.mcp_server._sse_connections[conn.session_id] = conn
150 |
151 | try:
152 | # Send SSE headers
153 | self.send_response(200)
154 | self.send_header("Content-Type", "text/event-stream")
155 | self.send_header("Cache-Control", "no-cache")
156 | self.send_header("Connection", "keep-alive")
157 | self.send_cors_headers()
158 | self.end_headers()
159 |
160 | # Send endpoint event with session ID for routing
161 | conn.send_event("endpoint", f"/sse?session={conn.session_id}")
162 |
163 | # Keep connection alive with periodic pings
164 | last_ping = time.time()
165 | while conn.alive and self.mcp_server._running:
166 | now = time.time()
167 | if now - last_ping > 30: # Ping every 30 seconds
168 | if not conn.send_event("ping", {}):
169 | break
170 | last_ping = now
171 | time.sleep(1)
172 |
173 | finally:
174 | conn.alive = False
175 | if conn.session_id in self.mcp_server._sse_connections:
176 | del self.mcp_server._sse_connections[conn.session_id]
177 |
178 | def _handle_sse_post(self, body: bytes):
179 | query_params = parse_qs(urlparse(self.path).query)
180 | session_id = query_params.get("session", [None])[0]
181 | if session_id is None:
182 | self.send_error(400, "Missing ?session for SSE POST")
183 | return
184 |
185 | # Dispatch to MCP registry
186 | setattr(self.mcp_server._protocol_version, "data", "2024-11-05")
187 | response = self.mcp_server.registry.dispatch(body)
188 |
189 | # Send SSE response if necessary
190 | if response is not None:
191 | sse_conn = self.mcp_server._sse_connections.get(session_id)
192 | if sse_conn is None or not sse_conn.alive:
193 | # No SSE connection found
194 | self.send_error(400, f"No active SSE connection found for session {session_id}")
195 | return
196 |
197 | # Send response via SSE event stream
198 | sse_conn.send_event("message", response)
199 |
200 | # Return 202 Accepted to acknowledge POST
201 | self.send_response(202)
202 | self.send_header("Content-Type", "application/json")
203 | self.send_header("Content-Length", str(len(body)))
204 | self.send_cors_headers()
205 | self.end_headers()
206 | self.wfile.write(body)
207 |
208 | def _handle_mcp_post(self, body: bytes):
209 | # Dispatch to MCP registry
210 | setattr(self.mcp_server._protocol_version, "data", "2025-06-18")
211 | response = self.mcp_server.registry.dispatch(body)
212 |
213 | def send_response(status: int, body: bytes):
214 | self.send_response(status)
215 | self.send_header("Content-Type", "application/json")
216 | self.send_header("Content-Length", str(len(body)))
217 | self.send_cors_headers()
218 | self.end_headers()
219 | self.wfile.write(body)
220 |
221 | # Check if notification (returns None)
222 | if response is None:
223 | send_response(202, b"Accepted")
224 | else:
225 | send_response(200, json.dumps(response).encode("utf-8"))
226 |
class McpServer:
    """MCP server exposing tools, resources and prompts.

    Functions are registered with the ``tool``, ``resource`` and ``prompt``
    decorators, each kind in its own registry.  A separate protocol registry
    (``self.registry``) maps MCP method names such as "tools/call" to the
    ``_mcp_*`` handlers of this class.  Requests reach that registry through
    ``serve`` (HTTP and SSE) or ``stdio`` (line-delimited JSON).
    """

    def __init__(self, name: str, version = "1.0.0"):
        """Create a server.

        Args:
            name: Server name reported by ``initialize``.
            version: Server version reported by ``initialize``.
        """
        self.name = name
        self.version = version
        # Largest POST body, in bytes, accepted by the HTTP handler
        self.post_body_limit = 10 * 1024 * 1024
        # CORS policy: predicate, list of origins, single origin, or None
        self.cors_allowed_origins: Callable[[str], bool] | list[str] | str | None = self.cors_localhost
        self.tools = McpRpcRegistry()
        self.resources = McpRpcRegistry()
        self.prompts = McpRpcRegistry()

        self._http_server: HTTPServer | None = None
        self._server_thread: threading.Thread | None = None
        self._running = False
        self._sse_connections: dict[str, _McpSseConnection] = {}
        # Per-thread protocol version, set by the HTTP handler before dispatch
        self._protocol_version = threading.local()

        # Register MCP protocol methods with correct names
        self.registry = JsonRpcRegistry()
        self.registry.methods["ping"] = self._mcp_ping
        self.registry.methods["initialize"] = self._mcp_initialize
        self.registry.methods["tools/list"] = self._mcp_tools_list
        self.registry.methods["tools/call"] = self._mcp_tools_call
        self.registry.methods["resources/list"] = self._mcp_resources_list
        self.registry.methods["resources/templates/list"] = self._mcp_resource_templates_list
        self.registry.methods["resources/read"] = self._mcp_resources_read
        self.registry.methods["prompts/list"] = self._mcp_prompts_list
        self.registry.methods["prompts/get"] = self._mcp_prompts_get

    def tool(self, func: Callable) -> Callable:
        """Decorator: register ``func`` in the tools registry."""
        return self.tools.method(func)

    def prompt(self, func: Callable) -> Callable:
        """Decorator: register ``func`` in the prompts registry."""
        return self.prompts.method(func)

    def resource(self, uri: str) -> Callable[[Callable], Callable]:
        """Decorator factory: register a function as the resource at ``uri``.

        ``uri`` may contain ``{name}`` placeholders, which makes it a
        resource template.  The URI is stored on the function as
        ``__resource_uri__``.
        """
        def decorator(func: Callable) -> Callable:
            setattr(func, "__resource_uri__", uri)
            return self.resources.method(func)
        return decorator

    def serve(self, host: str, port: int, *, background = True, request_handler = McpHttpRequestHandler):
        """Start the HTTP server on ``host:port``.

        Args:
            host: Interface to bind.
            port: TCP port to bind.
            background: If true, serve from a daemon thread with a
                ThreadingHTTPServer and return immediately; otherwise serve
                with a plain HTTPServer in the calling thread.
            request_handler: Handler class; must subclass
                McpHttpRequestHandler.

        Raises:
            OSError: If binding or activating the listening socket fails.
        """
        if self._running:
            print("[MCP] Server is already running")
            return

        # Create server with deferred binding
        assert issubclass(request_handler, McpHttpRequestHandler)
        self._http_server = (ThreadingHTTPServer if background else HTTPServer)(
            (host, port), request_handler, bind_and_activate=False
        )
        self._http_server.allow_reuse_address = False

        # Set the MCPServer instance on the handler class
        setattr(self._http_server, "mcp_server", self)

        try:
            # Bind and activate in main thread - errors propagate synchronously
            self._http_server.server_bind()
            self._http_server.server_activate()
        except OSError:
            # Cleanup on binding failure
            self._http_server.server_close()
            self._http_server = None
            raise

        # Only start thread after successful bind
        self._running = True

        print("[MCP] Server started:")
        print(f" Streamable HTTP: http://{host}:{port}/mcp")
        print(f" SSE: http://{host}:{port}/sse")

        def serve_forever():
            try:
                self._http_server.serve_forever()  # type: ignore
            except Exception as e:
                print(f"[MCP] Server error: {e}")
                traceback.print_exc()
            finally:
                self._running = False

        if background:
            self._server_thread = threading.Thread(target=serve_forever, daemon=True)
            self._server_thread.start()
        else:
            serve_forever()

    def stop(self):
        """Stop the HTTP server; a no-op when it is not running."""
        if not self._running:
            return

        self._running = False

        # Close all SSE connections.  Iterate over a snapshot: handler
        # threads remove their own entry when they exit, and mutating the
        # dict during iteration would raise RuntimeError.
        for conn in list(self._sse_connections.values()):
            conn.alive = False
        self._sse_connections.clear()

        # Shutdown the HTTP server
        if self._http_server:
            # shutdown() must be called from a different thread
            # than the one running serve_forever()
            self._http_server.shutdown()
            self._http_server.server_close()
            self._http_server = None

        if self._server_thread:
            self._server_thread.join()
            self._server_thread = None

        print("[MCP] Server stopped")

    def stdio(self, stdin: BinaryIO | None = None, stdout: BinaryIO | None = None):
        """Serve line-delimited JSON-RPC over binary streams until EOF.

        Args:
            stdin: Stream to read requests from (default: ``sys.stdin.buffer``).
            stdout: Stream to write responses to (default: ``sys.stdout.buffer``).
        """
        stdin = stdin or sys.stdin.buffer
        stdout = stdout or sys.stdout.buffer
        while True:
            try:
                request = stdin.readline()
                if not request:  # EOF
                    break

                # Strip whitespace (trailing newline) before parsing
                request = request.strip()
                if not request:
                    continue

                response = self.registry.dispatch(request)
                if response is not None:
                    stdout.write(json.dumps(response).encode("utf-8") + b"\n")
                    stdout.flush()
            except (BrokenPipeError, KeyboardInterrupt):  # Client disconnected
                break

    def cors_localhost(self, origin: str) -> bool:
        """Allow CORS requests from localhost on ANY port."""
        try:
            hostname = urlparse(origin).hostname
        except ValueError:
            # urlparse rejects malformed origins (e.g. an unbalanced "[");
            # the header is client-controlled, so treat that as "not allowed".
            return False
        return hostname in ("localhost", "127.0.0.1", "::1")

    def _mcp_ping(self, _meta: dict | None = None) -> dict:
        """MCP ping method"""
        return {}

    def _mcp_initialize(self, protocolVersion: str, capabilities: dict, clientInfo: dict, _meta: dict | None = None) -> dict:
        """MCP initialize method.

        Reports the protocol version recorded for the current thread by the
        HTTP handler, or echoes the client's ``protocolVersion`` when none
        was recorded.
        """
        return {
            "protocolVersion": getattr(self._protocol_version, "data", protocolVersion),
            "capabilities": {
                "tools": {},
                "resources": {
                    "subscribe": False,
                    "listChanged": False,
                },
                "prompts": {},
            },
            "serverInfo": {
                "name": self.name,
                "version": self.version,
            },
        }

    def _mcp_tools_list(self, _meta: dict | None = None) -> dict:
        """MCP tools/list method"""
        return {
            "tools": [
                self._generate_tool_schema(func_name, func)
                for func_name, func in self.tools.methods.items()
            ],
        }

    def _mcp_tools_call(self, name: str, arguments: dict | None = None, _meta: dict | None = None) -> dict:
        """MCP tools/call method.

        Tool failures are reported in-band (``isError: True``) rather than
        as a JSON-RPC error.
        """
        # Wrap tool call in JSON-RPC request
        tool_response = self.tools.dispatch({
            "jsonrpc": "2.0",
            "method": name,
            "params": arguments,
            "id": None,
        })
        assert tool_response is not None, "Only notification requests return None"

        # Check for error response
        if "error" in tool_response:
            error = tool_response["error"]
            return {
                "content": [{"type": "text", "text": error["message"] or "Unknown error"}],
                "isError": True,
            }

        result = tool_response.get("result")
        return {
            "content": [{"type": "text", "text": json.dumps(result, indent=2)}],
            "structuredContent": result if isinstance(result, dict) else {"result": result},
            "isError": False,
        }

    def _enumerate_resources(self):
        """Yield ``(uri, name, description)`` for every registered resource."""
        for name, func in self.resources.methods.items():
            uri: str = getattr(func, "__resource_uri__")
            description = (func.__doc__ or f"Read {uri}").strip()
            yield uri, name, description

    def _mcp_resources_list(self, _meta: dict | None = None) -> dict:
        """MCP resources/list method - returns static resources only (no URI parameters)"""
        return {
            "resources": [
                {
                    "uri": uri,
                    "name": name,
                    "description": description,
                    "mimeType": "application/json",
                }
                for uri, name, description in self._enumerate_resources()
                if "{" not in uri
            ]
        }

    def _mcp_resource_templates_list(self, _meta: dict | None = None) -> dict:
        """MCP resources/templates/list method - returns parameterized resource templates"""
        return {
            "resourceTemplates": [
                {
                    "uriTemplate": uri,
                    "name": name,
                    "description": description,
                    "mimeType": "application/json",
                }
                for uri, name, description in self._enumerate_resources()
                if "{" in uri
            ]
        }

    def _mcp_resources_read(self, uri: str, _meta: dict | None = None) -> dict:
        """MCP resources/read method.

        Matches ``uri`` against every registered resource URI, treating
        ``{name}`` placeholders as single path segments, and calls the first
        resource that matches with the captured values as positional
        arguments.

        Raises:
            JsonRpcException: With the resource's own error when the call
                fails, or with code -32002 when no resource matches.
        """
        # Try to match URI against all registered resource patterns
        for pattern, name, _ in self._enumerate_resources():
            # Splitting on the placeholder (one capture group) alternates
            # literal text (even indices) with placeholder names (odd
            # indices).  Literal text is escaped so characters such as
            # ".", "?" or "+" in a resource URI match only themselves.
            pieces = re.split(r"\{(\w+)\}", pattern)
            regex_pattern = "".join(
                re.escape(piece) if index % 2 == 0 else f"(?P<{piece}>[^/]+)"
                for index, piece in enumerate(pieces)
            )
            regex_pattern = f"^{regex_pattern}$"

            match = re.match(regex_pattern, uri)
            if match:
                # Found matching resource - call it via JSON-RPC
                params = list(match.groupdict().values())

                resource_response = self.resources.dispatch({
                    "jsonrpc": "2.0",
                    "method": name,
                    "params": params,
                    "id": None,
                })
                assert resource_response is not None, "Only notification requests return None"

                if "error" in resource_response:
                    error = resource_response["error"]
                    raise JsonRpcException(error["code"], error["message"], error.get("data"))

                return {
                    "contents": [{
                        "uri": uri,
                        "mimeType": "application/json",
                        "text": json.dumps(resource_response.get("result"), indent=2),
                    }]
                }

        raise JsonRpcException(-32002, "Resource not found", {"uri": uri})

    def _mcp_prompts_list(self, _meta: dict | None = None) -> dict:
        """MCP prompts/list method"""
        return {
            "prompts": [
                self._generate_prompt_schema(func_name, func)
                for func_name, func in self.prompts.methods.items()
            ],
        }

    def _mcp_prompts_get(
        self, name: str, arguments: dict | None = None, _meta: dict | None = None
    ) -> dict:
        """MCP prompts/get method.

        A list result is passed through as the message list; any other
        result becomes the text of a single user message (non-strings are
        JSON-encoded first).

        Raises:
            JsonRpcException: With the prompt's own error when the call fails.
        """
        # Dispatch to prompts registry
        prompt_response = self.prompts.dispatch(
            {
                "jsonrpc": "2.0",
                "method": name,
                "params": arguments,
                "id": None,
            }
        )
        assert prompt_response is not None, "Only notification requests return None"

        # Check for error response
        if "error" in prompt_response:
            error = prompt_response["error"]
            raise JsonRpcException(error["code"], error["message"], error.get("data"))

        result = prompt_response.get("result")

        # Pass through list of messages directly
        if isinstance(result, list):
            return {"messages": result}

        # Convert non-string results to JSON
        if not isinstance(result, str):
            result = json.dumps(result, indent=2)
        return {
            "messages": [
                {
                    "role": "user",
                    "content": {"type": "text", "text": result},
                },
            ],
        }

    def _generate_prompt_schema(self, func_name: str, func: Callable) -> dict:
        """Generate MCP prompt schema from a function.

        Only parameters that carry a type hint are listed.  The description
        of a parameter is the last metadata item of its ``Annotated`` hint.
        """
        hints = get_type_hints(func, include_extras=True)
        hints.pop("return", None)
        sig = inspect.signature(func)

        # Build arguments list (PromptArgument format)
        arguments = []
        for param_name, param_type in hints.items():
            arg: dict[str, Any] = {"name": param_name}

            # Extract description from Annotated
            origin = get_origin(param_type)
            if origin is Annotated:
                args = get_args(param_type)
                arg["description"] = str(args[-1])

            # Check if required (no default value)
            param = sig.parameters.get(param_name)
            if not param or param.default is inspect.Parameter.empty:
                arg["required"] = True

            arguments.append(arg)

        schema: dict[str, Any] = {
            "name": func_name,
            "description": (func.__doc__ or f"Prompt {func_name}").strip(),
        }

        if arguments:
            schema["arguments"] = arguments

        return schema

    def _type_to_json_schema(self, py_type: Any) -> dict:
        """Convert Python type hint to JSON schema object.

        Types that are not recognised fall back to ``{"type": "object"}``.
        """
        origin = get_origin(py_type)
        # Annotated[T, "description"]
        if origin is Annotated:
            args = get_args(py_type)
            return {
                **self._type_to_json_schema(args[0]),
                "description": str(args[-1]),
            }

        # NotRequired[T]
        if origin is NotRequired:
            return self._type_to_json_schema(get_args(py_type)[0])

        # Union[Ts..], Optional[T] and T1 | T2
        if origin in (Union, UnionType):
            return {"anyOf": [self._type_to_json_schema(t) for t in get_args(py_type)]}

        # list[T]
        if origin is list:
            return {
                "type": "array",
                "items": self._type_to_json_schema(get_args(py_type)[0]),
            }

        # dict[str, T]
        if origin is dict:
            return {
                "type": "object",
                "additionalProperties": self._type_to_json_schema(get_args(py_type)[1]),
            }

        # TypedDict
        if is_typeddict(py_type):
            return self._typed_dict_to_schema(py_type)

        # Primitives
        return {
            "type": {
                int: "integer",
                float: "number",
                str: "string",
                bool: "boolean",
                list: "array",
                dict: "object",
                type(None): "null",
            }.get(py_type, "object"),
        }

    def _typed_dict_to_schema(self, typed_dict_class) -> dict:
        """Convert TypedDict to JSON schema"""
        hints = get_type_hints(typed_dict_class, include_extras=True)
        required_keys = getattr(typed_dict_class, "__required_keys__", set(hints.keys()))

        return {
            "type": "object",
            "properties": {
                field_name: self._type_to_json_schema(field_type)
                for field_name, field_type in hints.items()
            },
            "required": [key for key in hints.keys() if key in required_keys],
            "additionalProperties": False,
        }

    def _generate_tool_schema(self, func_name: str, func: Callable) -> dict:
        """Generate MCP tool schema from a function.

        Only parameters that carry a type hint appear in ``inputSchema``; a
        parameter is required when it has no default value.  An
        ``outputSchema`` is added when the function declares a non-None
        return type.
        """
        hints = get_type_hints(func, include_extras=True)
        return_type = hints.pop("return", None)
        sig = inspect.signature(func)

        # Build parameter schema
        properties = {}
        required = []

        for param_name, param_type in hints.items():
            properties[param_name] = self._type_to_json_schema(param_type)

            # Add to required if no default value
            param = sig.parameters.get(param_name)
            if not param or param.default is inspect.Parameter.empty:
                required.append(param_name)

        schema: dict[str, Any] = {
            "name": func_name,
            "description": (func.__doc__ or f"Call {func_name}").strip(),
            "inputSchema": {
                "type": "object",
                "properties": properties,
                "required": required,
            },
        }

        # Add outputSchema if return type exists and is not None
        if return_type and return_type is not type(None):
            return_schema = self._type_to_json_schema(return_type)

            # Wrap non-object returns in a "result" property
            if return_schema.get("type") != "object":
                return_schema = {
                    "type": "object",
                    "properties": {"result": return_schema},
                    "required": ["result"],
                }

            schema["outputSchema"] = return_schema

        return schema
682 |
--------------------------------------------------------------------------------