├── .github └── workflows │ ├── ci.yml │ └── release.yml ├── .gitignore ├── LICENSE ├── README.md ├── client ├── client.go ├── http.go ├── inprocess.go ├── inprocess_test.go ├── interface.go ├── sse.go ├── sse_test.go ├── stdio.go ├── stdio_test.go └── transport │ ├── inprocess.go │ ├── interface.go │ ├── sse.go │ ├── sse_test.go │ ├── stdio.go │ ├── stdio_test.go │ ├── streamable_http.go │ └── streamable_http_test.go ├── examples ├── custom_context │ └── main.go ├── everything │ └── main.go └── filesystem_stdio_client │ └── main.go ├── go.mod ├── go.sum ├── mcp ├── prompts.go ├── resources.go ├── tools.go ├── tools_test.go ├── types.go └── utils.go ├── server ├── hooks.go ├── internal │ └── gen │ │ ├── README.md │ │ ├── data.go │ │ ├── hooks.go.tmpl │ │ ├── main.go │ │ └── request_handler.go.tmpl ├── request_handler.go ├── resource_test.go ├── server.go ├── server_race_test.go ├── server_test.go ├── sse.go ├── sse_test.go ├── stdio.go └── stdio_test.go └── testdata └── mockstdio_server.go /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: go 2 | on: 3 | push: 4 | branches: 5 | - main 6 | pull_request: 7 | jobs: 8 | test: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v4 12 | - uses: actions/setup-go@v5 13 | with: 14 | go-version-file: 'go.mod' 15 | - run: go test ./... 
-race 16 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: "Create Release on Tag Push" 2 | on: 3 | push: 4 | tags: 5 | - '*' 6 | jobs: 7 | release: 8 | runs-on: ubuntu-latest 9 | 10 | steps: 11 | - name: Checkout Code 12 | uses: actions/checkout@v3 13 | 14 | - name: Create GitHub Release 15 | uses: actions/create-release@v1 16 | env: 17 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 18 | with: 19 | tag_name: ${{ github.ref }} 20 | release_name: Release ${{ github.ref }} 21 | draft: false 22 | prerelease: false 23 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .aider* 2 | .env 3 | .idea -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Anthropic, PBC 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /client/client.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "errors" 7 | "fmt" 8 | "sync" 9 | "sync/atomic" 10 | 11 | "github.com/mark3labs/mcp-go/client/transport" 12 | "github.com/mark3labs/mcp-go/mcp" 13 | ) 14 | 15 | // Client implements the MCP client. 16 | type Client struct { 17 | transport transport.Interface 18 | 19 | initialized bool 20 | notifications []func(mcp.JSONRPCNotification) 21 | notifyMu sync.RWMutex 22 | requestID atomic.Int64 23 | clientCapabilities mcp.ClientCapabilities 24 | serverCapabilities mcp.ServerCapabilities 25 | } 26 | 27 | type ClientOption func(*Client) 28 | 29 | // WithClientCapabilities sets the client capabilities for the client. 30 | func WithClientCapabilities(capabilities mcp.ClientCapabilities) ClientOption { 31 | return func(c *Client) { 32 | c.clientCapabilities = capabilities 33 | } 34 | } 35 | 36 | // NewClient creates a new MCP client with the given transport. 37 | // Usage: 38 | // 39 | // stdio := transport.NewStdio("./mcp_server", nil, "xxx") 40 | // client, err := NewClient(stdio) 41 | // if err != nil { 42 | // log.Fatalf("Failed to create client: %v", err) 43 | // } 44 | func NewClient(transport transport.Interface, options ...ClientOption) *Client { 45 | client := &Client{ 46 | transport: transport, 47 | } 48 | 49 | for _, opt := range options { 50 | opt(client) 51 | } 52 | 53 | return client 54 | } 55 | 56 | // Start initiates the connection to the server. 57 | // Must be called before using the client. 
58 | func (c *Client) Start(ctx context.Context) error { 59 | if c.transport == nil { 60 | return fmt.Errorf("transport is nil") 61 | } 62 | err := c.transport.Start(ctx) 63 | if err != nil { 64 | return err 65 | } 66 | 67 | c.transport.SetNotificationHandler(func(notification mcp.JSONRPCNotification) { 68 | c.notifyMu.RLock() 69 | defer c.notifyMu.RUnlock() 70 | for _, handler := range c.notifications { 71 | handler(notification) 72 | } 73 | }) 74 | return nil 75 | } 76 | 77 | // Close shuts down the client and closes the transport. 78 | func (c *Client) Close() error { 79 | return c.transport.Close() 80 | } 81 | 82 | // OnNotification registers a handler function to be called when notifications are received. 83 | // Multiple handlers can be registered and will be called in the order they were added. 84 | func (c *Client) OnNotification( 85 | handler func(notification mcp.JSONRPCNotification), 86 | ) { 87 | c.notifyMu.Lock() 88 | defer c.notifyMu.Unlock() 89 | c.notifications = append(c.notifications, handler) 90 | } 91 | 92 | // sendRequest sends a JSON-RPC request to the server and waits for a response. 93 | // Returns the raw JSON response message or an error if the request fails. 
94 | func (c *Client) sendRequest( 95 | ctx context.Context, 96 | method string, 97 | params interface{}, 98 | ) (*json.RawMessage, error) { 99 | if !c.initialized && method != "initialize" { 100 | return nil, fmt.Errorf("client not initialized") 101 | } 102 | 103 | id := c.requestID.Add(1) 104 | 105 | request := transport.JSONRPCRequest{ 106 | JSONRPC: mcp.JSONRPC_VERSION, 107 | ID: id, 108 | Method: method, 109 | Params: params, 110 | } 111 | 112 | response, err := c.transport.SendRequest(ctx, request) 113 | if err != nil { 114 | return nil, fmt.Errorf("transport error: %w", err) 115 | } 116 | 117 | if response.Error != nil { 118 | return nil, errors.New(response.Error.Message) 119 | } 120 | 121 | return &response.Result, nil 122 | } 123 | 124 | // Initialize negotiates with the server. 125 | // Must be called after Start, and before any request methods. 126 | func (c *Client) Initialize( 127 | ctx context.Context, 128 | request mcp.InitializeRequest, 129 | ) (*mcp.InitializeResult, error) { 130 | // Ensure we send a params object with all required fields 131 | params := struct { 132 | ProtocolVersion string `json:"protocolVersion"` 133 | ClientInfo mcp.Implementation `json:"clientInfo"` 134 | Capabilities mcp.ClientCapabilities `json:"capabilities"` 135 | }{ 136 | ProtocolVersion: request.Params.ProtocolVersion, 137 | ClientInfo: request.Params.ClientInfo, 138 | Capabilities: request.Params.Capabilities, // Will be empty struct if not set 139 | } 140 | 141 | response, err := c.sendRequest(ctx, "initialize", params) 142 | if err != nil { 143 | return nil, err 144 | } 145 | 146 | var result mcp.InitializeResult 147 | if err := json.Unmarshal(*response, &result); err != nil { 148 | return nil, fmt.Errorf("failed to unmarshal response: %w", err) 149 | } 150 | 151 | // Store serverCapabilities 152 | c.serverCapabilities = result.Capabilities 153 | 154 | // Send initialized notification 155 | notification := mcp.JSONRPCNotification{ 156 | JSONRPC: mcp.JSONRPC_VERSION, 
157 | Notification: mcp.Notification{ 158 | Method: "notifications/initialized", 159 | }, 160 | } 161 | 162 | err = c.transport.SendNotification(ctx, notification) 163 | if err != nil { 164 | return nil, fmt.Errorf( 165 | "failed to send initialized notification: %w", 166 | err, 167 | ) 168 | } 169 | 170 | c.initialized = true 171 | return &result, nil 172 | } 173 | 174 | func (c *Client) Ping(ctx context.Context) error { 175 | _, err := c.sendRequest(ctx, "ping", nil) 176 | return err 177 | } 178 | 179 | // ListResourcesByPage manually list resources by page. 180 | func (c *Client) ListResourcesByPage( 181 | ctx context.Context, 182 | request mcp.ListResourcesRequest, 183 | ) (*mcp.ListResourcesResult, error) { 184 | result, err := listByPage[mcp.ListResourcesResult](ctx, c, request.PaginatedRequest, "resources/list") 185 | if err != nil { 186 | return nil, err 187 | } 188 | return result, nil 189 | } 190 | 191 | func (c *Client) ListResources( 192 | ctx context.Context, 193 | request mcp.ListResourcesRequest, 194 | ) (*mcp.ListResourcesResult, error) { 195 | result, err := c.ListResourcesByPage(ctx, request) 196 | if err != nil { 197 | return nil, err 198 | } 199 | for result.NextCursor != "" { 200 | select { 201 | case <-ctx.Done(): 202 | return nil, ctx.Err() 203 | default: 204 | request.Params.Cursor = result.NextCursor 205 | newPageRes, err := c.ListResourcesByPage(ctx, request) 206 | if err != nil { 207 | return nil, err 208 | } 209 | result.Resources = append(result.Resources, newPageRes.Resources...) 
210 | result.NextCursor = newPageRes.NextCursor 211 | } 212 | } 213 | return result, nil 214 | } 215 | 216 | func (c *Client) ListResourceTemplatesByPage( 217 | ctx context.Context, 218 | request mcp.ListResourceTemplatesRequest, 219 | ) (*mcp.ListResourceTemplatesResult, error) { 220 | result, err := listByPage[mcp.ListResourceTemplatesResult](ctx, c, request.PaginatedRequest, "resources/templates/list") 221 | if err != nil { 222 | return nil, err 223 | } 224 | return result, nil 225 | } 226 | 227 | func (c *Client) ListResourceTemplates( 228 | ctx context.Context, 229 | request mcp.ListResourceTemplatesRequest, 230 | ) (*mcp.ListResourceTemplatesResult, error) { 231 | result, err := c.ListResourceTemplatesByPage(ctx, request) 232 | if err != nil { 233 | return nil, err 234 | } 235 | for result.NextCursor != "" { 236 | select { 237 | case <-ctx.Done(): 238 | return nil, ctx.Err() 239 | default: 240 | request.Params.Cursor = result.NextCursor 241 | newPageRes, err := c.ListResourceTemplatesByPage(ctx, request) 242 | if err != nil { 243 | return nil, err 244 | } 245 | result.ResourceTemplates = append(result.ResourceTemplates, newPageRes.ResourceTemplates...) 
246 | result.NextCursor = newPageRes.NextCursor 247 | } 248 | } 249 | return result, nil 250 | } 251 | 252 | func (c *Client) ReadResource( 253 | ctx context.Context, 254 | request mcp.ReadResourceRequest, 255 | ) (*mcp.ReadResourceResult, error) { 256 | response, err := c.sendRequest(ctx, "resources/read", request.Params) 257 | if err != nil { 258 | return nil, err 259 | } 260 | 261 | return mcp.ParseReadResourceResult(response) 262 | } 263 | 264 | func (c *Client) Subscribe( 265 | ctx context.Context, 266 | request mcp.SubscribeRequest, 267 | ) error { 268 | _, err := c.sendRequest(ctx, "resources/subscribe", request.Params) 269 | return err 270 | } 271 | 272 | func (c *Client) Unsubscribe( 273 | ctx context.Context, 274 | request mcp.UnsubscribeRequest, 275 | ) error { 276 | _, err := c.sendRequest(ctx, "resources/unsubscribe", request.Params) 277 | return err 278 | } 279 | 280 | func (c *Client) ListPromptsByPage( 281 | ctx context.Context, 282 | request mcp.ListPromptsRequest, 283 | ) (*mcp.ListPromptsResult, error) { 284 | result, err := listByPage[mcp.ListPromptsResult](ctx, c, request.PaginatedRequest, "prompts/list") 285 | if err != nil { 286 | return nil, err 287 | } 288 | return result, nil 289 | } 290 | 291 | func (c *Client) ListPrompts( 292 | ctx context.Context, 293 | request mcp.ListPromptsRequest, 294 | ) (*mcp.ListPromptsResult, error) { 295 | result, err := c.ListPromptsByPage(ctx, request) 296 | if err != nil { 297 | return nil, err 298 | } 299 | for result.NextCursor != "" { 300 | select { 301 | case <-ctx.Done(): 302 | return nil, ctx.Err() 303 | default: 304 | request.Params.Cursor = result.NextCursor 305 | newPageRes, err := c.ListPromptsByPage(ctx, request) 306 | if err != nil { 307 | return nil, err 308 | } 309 | result.Prompts = append(result.Prompts, newPageRes.Prompts...) 
310 | result.NextCursor = newPageRes.NextCursor 311 | } 312 | } 313 | return result, nil 314 | } 315 | 316 | func (c *Client) GetPrompt( 317 | ctx context.Context, 318 | request mcp.GetPromptRequest, 319 | ) (*mcp.GetPromptResult, error) { 320 | response, err := c.sendRequest(ctx, "prompts/get", request.Params) 321 | if err != nil { 322 | return nil, err 323 | } 324 | 325 | return mcp.ParseGetPromptResult(response) 326 | } 327 | 328 | func (c *Client) ListToolsByPage( 329 | ctx context.Context, 330 | request mcp.ListToolsRequest, 331 | ) (*mcp.ListToolsResult, error) { 332 | result, err := listByPage[mcp.ListToolsResult](ctx, c, request.PaginatedRequest, "tools/list") 333 | if err != nil { 334 | return nil, err 335 | } 336 | return result, nil 337 | } 338 | 339 | func (c *Client) ListTools( 340 | ctx context.Context, 341 | request mcp.ListToolsRequest, 342 | ) (*mcp.ListToolsResult, error) { 343 | result, err := c.ListToolsByPage(ctx, request) 344 | if err != nil { 345 | return nil, err 346 | } 347 | for result.NextCursor != "" { 348 | select { 349 | case <-ctx.Done(): 350 | return nil, ctx.Err() 351 | default: 352 | request.Params.Cursor = result.NextCursor 353 | newPageRes, err := c.ListToolsByPage(ctx, request) 354 | if err != nil { 355 | return nil, err 356 | } 357 | result.Tools = append(result.Tools, newPageRes.Tools...) 
358 | result.NextCursor = newPageRes.NextCursor 359 | } 360 | } 361 | return result, nil 362 | } 363 | 364 | func (c *Client) CallTool( 365 | ctx context.Context, 366 | request mcp.CallToolRequest, 367 | ) (*mcp.CallToolResult, error) { 368 | response, err := c.sendRequest(ctx, "tools/call", request.Params) 369 | if err != nil { 370 | return nil, err 371 | } 372 | 373 | return mcp.ParseCallToolResult(response) 374 | } 375 | 376 | func (c *Client) SetLevel( 377 | ctx context.Context, 378 | request mcp.SetLevelRequest, 379 | ) error { 380 | _, err := c.sendRequest(ctx, "logging/setLevel", request.Params) 381 | return err 382 | } 383 | 384 | func (c *Client) Complete( 385 | ctx context.Context, 386 | request mcp.CompleteRequest, 387 | ) (*mcp.CompleteResult, error) { 388 | response, err := c.sendRequest(ctx, "completion/complete", request.Params) 389 | if err != nil { 390 | return nil, err 391 | } 392 | 393 | var result mcp.CompleteResult 394 | if err := json.Unmarshal(*response, &result); err != nil { 395 | return nil, fmt.Errorf("failed to unmarshal response: %w", err) 396 | } 397 | 398 | return &result, nil 399 | } 400 | 401 | func listByPage[T any]( 402 | ctx context.Context, 403 | client *Client, 404 | request mcp.PaginatedRequest, 405 | method string, 406 | ) (*T, error) { 407 | response, err := client.sendRequest(ctx, method, request.Params) 408 | if err != nil { 409 | return nil, err 410 | } 411 | var result T 412 | if err := json.Unmarshal(*response, &result); err != nil { 413 | return nil, fmt.Errorf("failed to unmarshal response: %w", err) 414 | } 415 | return &result, nil 416 | } 417 | 418 | // Helper methods 419 | 420 | // GetTransport gives access to the underlying transport layer. 421 | // Cast it to the specific transport type and obtain the other helper methods. 422 | func (c *Client) GetTransport() transport.Interface { 423 | return c.transport 424 | } 425 | 426 | // GetServerCapabilities returns the server capabilities. 
427 | func (c *Client) GetServerCapabilities() mcp.ServerCapabilities { 428 | return c.serverCapabilities 429 | } 430 | 431 | // GetClientCapabilities returns the client capabilities. 432 | func (c *Client) GetClientCapabilities() mcp.ClientCapabilities { 433 | return c.clientCapabilities 434 | } 435 | -------------------------------------------------------------------------------- /client/http.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/mark3labs/mcp-go/client/transport" 7 | ) 8 | 9 | // NewStreamableHttpClient is a convenience method that creates a new streamable-http-based MCP client 10 | // with the given base URL. Returns an error if the URL is invalid. 11 | func NewStreamableHttpClient(baseURL string, options ...transport.StreamableHTTPCOption) (*Client, error) { 12 | trans, err := transport.NewStreamableHTTP(baseURL, options...) 13 | if err != nil { 14 | return nil, fmt.Errorf("failed to create streamable HTTP transport: %w", err) 15 | } 16 | return NewClient(trans), nil 17 | } 18 | -------------------------------------------------------------------------------- /client/inprocess.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "github.com/mark3labs/mcp-go/client/transport" 5 | "github.com/mark3labs/mcp-go/server" 6 | ) 7 | 8 | // NewInProcessClient connects directly to an MCP server object running in the same process. 9 | func NewInProcessClient(server *server.MCPServer) (*Client, error) { 10 | inProcessTransport := transport.NewInProcessTransport(server) 11 | return NewClient(inProcessTransport), nil 12 | } 13 | -------------------------------------------------------------------------------- /client/inprocess_test.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | 
"github.com/mark3labs/mcp-go/mcp" 8 | "github.com/mark3labs/mcp-go/server" 9 | ) 10 | 11 | func TestInProcessMCPClient(t *testing.T) { 12 | mcpServer := server.NewMCPServer( 13 | "test-server", 14 | "1.0.0", 15 | server.WithResourceCapabilities(true, true), 16 | server.WithPromptCapabilities(true), 17 | server.WithToolCapabilities(true), 18 | ) 19 | 20 | // Add a test tool 21 | mcpServer.AddTool(mcp.NewTool( 22 | "test-tool", 23 | mcp.WithDescription("Test tool"), 24 | mcp.WithString("parameter-1", mcp.Description("A string tool parameter")), 25 | mcp.WithToolAnnotation(mcp.ToolAnnotation{ 26 | Title: "Test Tool Annotation Title", 27 | ReadOnlyHint: true, 28 | DestructiveHint: false, 29 | IdempotentHint: true, 30 | OpenWorldHint: false, 31 | }), 32 | ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { 33 | return &mcp.CallToolResult{ 34 | Content: []mcp.Content{ 35 | mcp.TextContent{ 36 | Type: "text", 37 | Text: "Input parameter: " + request.Params.Arguments["parameter-1"].(string), 38 | }, 39 | }, 40 | }, nil 41 | }) 42 | 43 | mcpServer.AddResource( 44 | mcp.Resource{ 45 | URI: "resource://testresource", 46 | Name: "My Resource", 47 | }, 48 | func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { 49 | return []mcp.ResourceContents{ 50 | mcp.TextResourceContents{ 51 | URI: "resource://testresource", 52 | MIMEType: "text/plain", 53 | Text: "test content", 54 | }, 55 | }, nil 56 | }, 57 | ) 58 | 59 | mcpServer.AddPrompt( 60 | mcp.Prompt{ 61 | Name: "test-prompt", 62 | Description: "A test prompt", 63 | Arguments: []mcp.PromptArgument{ 64 | { 65 | Name: "arg1", 66 | Description: "First argument", 67 | }, 68 | }, 69 | }, 70 | func(ctx context.Context, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { 71 | return &mcp.GetPromptResult{ 72 | Messages: []mcp.PromptMessage{ 73 | { 74 | Role: mcp.RoleAssistant, 75 | Content: mcp.TextContent{ 76 | Type: "text", 77 | Text: "Test prompt with 
arg1: " + request.Params.Arguments["arg1"], 78 | }, 79 | }, 80 | }, 81 | }, nil 82 | }, 83 | ) 84 | 85 | t.Run("Can initialize and make requests", func(t *testing.T) { 86 | client, err := NewInProcessClient(mcpServer) 87 | if err != nil { 88 | t.Fatalf("Failed to create client: %v", err) 89 | } 90 | defer client.Close() 91 | 92 | // Start the client 93 | if err := client.Start(context.Background()); err != nil { 94 | t.Fatalf("Failed to start client: %v", err) 95 | } 96 | 97 | // Initialize 98 | initRequest := mcp.InitializeRequest{} 99 | initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION 100 | initRequest.Params.ClientInfo = mcp.Implementation{ 101 | Name: "test-client", 102 | Version: "1.0.0", 103 | } 104 | 105 | result, err := client.Initialize(context.Background(), initRequest) 106 | if err != nil { 107 | t.Fatalf("Failed to initialize: %v", err) 108 | } 109 | 110 | if result.ServerInfo.Name != "test-server" { 111 | t.Errorf( 112 | "Expected server name 'test-server', got '%s'", 113 | result.ServerInfo.Name, 114 | ) 115 | } 116 | 117 | // Test Ping 118 | if err := client.Ping(context.Background()); err != nil { 119 | t.Errorf("Ping failed: %v", err) 120 | } 121 | 122 | // Test ListTools 123 | toolsRequest := mcp.ListToolsRequest{} 124 | toolListResult, err := client.ListTools(context.Background(), toolsRequest) 125 | if err != nil { 126 | t.Fatalf("ListTools failed: %v", err) 127 | } 128 | if toolListResult == nil || len((*toolListResult).Tools) == 0 { 129 | t.Fatalf("Expected one tool") 130 | } 131 | testToolAnnotations := (*toolListResult).Tools[0].Annotations 132 | if testToolAnnotations.Title != "Test Tool Annotation Title" || 133 | testToolAnnotations.ReadOnlyHint != true || 134 | testToolAnnotations.DestructiveHint != false || 135 | testToolAnnotations.IdempotentHint != true || 136 | testToolAnnotations.OpenWorldHint != false { 137 | t.Errorf("The annotations of the tools are invalid") 138 | } 139 | }) 140 | 141 | t.Run("Handles errors
properly", func(t *testing.T) { 142 | client, err := NewInProcessClient(mcpServer) 143 | if err != nil { 144 | t.Fatalf("Failed to create client: %v", err) 145 | } 146 | defer client.Close() 147 | 148 | if err := client.Start(context.Background()); err != nil { 149 | t.Fatalf("Failed to start client: %v", err) 150 | } 151 | 152 | // Try to make a request without initializing 153 | toolsRequest := mcp.ListToolsRequest{} 154 | _, err = client.ListTools(context.Background(), toolsRequest) 155 | if err == nil { 156 | t.Error("Expected error when making request before initialization") 157 | } 158 | }) 159 | 160 | t.Run("CallTool", func(t *testing.T) { 161 | client, err := NewInProcessClient(mcpServer) 162 | if err != nil { 163 | t.Fatalf("Failed to create client: %v", err) 164 | } 165 | defer client.Close() 166 | 167 | if err := client.Start(context.Background()); err != nil { 168 | t.Fatalf("Failed to start client: %v", err) 169 | } 170 | 171 | // Initialize 172 | initRequest := mcp.InitializeRequest{} 173 | initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION 174 | initRequest.Params.ClientInfo = mcp.Implementation{ 175 | Name: "test-client", 176 | Version: "1.0.0", 177 | } 178 | 179 | _, err = client.Initialize(context.Background(), initRequest) 180 | if err != nil { 181 | t.Fatalf("Failed to initialize: %v", err) 182 | } 183 | 184 | request := mcp.CallToolRequest{} 185 | request.Params.Name = "test-tool" 186 | request.Params.Arguments = map[string]interface{}{ 187 | "parameter-1": "value1", 188 | } 189 | 190 | result, err := client.CallTool(context.Background(), request) 191 | if err != nil { 192 | t.Fatalf("CallTool failed: %v", err) 193 | } 194 | 195 | if len(result.Content) != 1 { 196 | t.Errorf("Expected 1 content item, got %d", len(result.Content)) 197 | } 198 | }) 199 | 200 | t.Run("Ping", func(t *testing.T) { 201 | client, err := NewInProcessClient(mcpServer) 202 | if err != nil { 203 | t.Fatalf("Failed to create client: %v", err) 204 | } 205 | 
defer client.Close() 206 | 207 | if err := client.Start(context.Background()); err != nil { 208 | t.Fatalf("Failed to start client: %v", err) 209 | } 210 | 211 | // Initialize 212 | initRequest := mcp.InitializeRequest{} 213 | initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION 214 | initRequest.Params.ClientInfo = mcp.Implementation{ 215 | Name: "test-client", 216 | Version: "1.0.0", 217 | } 218 | 219 | _, err = client.Initialize(context.Background(), initRequest) 220 | if err != nil { 221 | t.Fatalf("Failed to initialize: %v", err) 222 | } 223 | 224 | err = client.Ping(context.Background()) 225 | if err != nil { 226 | t.Errorf("Ping failed: %v", err) 227 | } 228 | }) 229 | 230 | t.Run("ListResources", func(t *testing.T) { 231 | client, err := NewInProcessClient(mcpServer) 232 | if err != nil { 233 | t.Fatalf("Failed to create client: %v", err) 234 | } 235 | defer client.Close() 236 | 237 | if err := client.Start(context.Background()); err != nil { 238 | t.Fatalf("Failed to start client: %v", err) 239 | } 240 | 241 | // Initialize 242 | initRequest := mcp.InitializeRequest{} 243 | initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION 244 | initRequest.Params.ClientInfo = mcp.Implementation{ 245 | Name: "test-client", 246 | Version: "1.0.0", 247 | } 248 | 249 | _, err = client.Initialize(context.Background(), initRequest) 250 | if err != nil { 251 | t.Fatalf("Failed to initialize: %v", err) 252 | } 253 | 254 | request := mcp.ListResourcesRequest{} 255 | result, err := client.ListResources(context.Background(), request) 256 | if err != nil { 257 | t.Fatalf("ListResources failed: %v", err) 258 | } 259 | 260 | if len(result.Resources) != 1 { 261 | t.Errorf("Expected 1 resource, got %d", len(result.Resources)) 262 | } 263 | }) 264 | 265 | t.Run("ReadResource", func(t *testing.T) { 266 | client, err := NewInProcessClient(mcpServer) 267 | if err != nil { 268 | t.Fatalf("Failed to create client: %v", err) 269 | } 270 | defer client.Close() 271 | 
272 | if err := client.Start(context.Background()); err != nil { 273 | t.Fatalf("Failed to start client: %v", err) 274 | } 275 | 276 | // Initialize 277 | initRequest := mcp.InitializeRequest{} 278 | initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION 279 | initRequest.Params.ClientInfo = mcp.Implementation{ 280 | Name: "test-client", 281 | Version: "1.0.0", 282 | } 283 | 284 | _, err = client.Initialize(context.Background(), initRequest) 285 | if err != nil { 286 | t.Fatalf("Failed to initialize: %v", err) 287 | } 288 | 289 | request := mcp.ReadResourceRequest{} 290 | request.Params.URI = "resource://testresource" 291 | 292 | result, err := client.ReadResource(context.Background(), request) 293 | if err != nil { 294 | t.Fatalf("ReadResource failed: %v", err) 295 | } 296 | 297 | if len(result.Contents) != 1 { 298 | t.Errorf("Expected 1 content item, got %d", len(result.Contents)) 299 | } 300 | }) 301 | 302 | t.Run("ListPrompts", func(t *testing.T) { 303 | client, err := NewInProcessClient(mcpServer) 304 | if err != nil { 305 | t.Fatalf("Failed to create client: %v", err) 306 | } 307 | defer client.Close() 308 | 309 | if err := client.Start(context.Background()); err != nil { 310 | t.Fatalf("Failed to start client: %v", err) 311 | } 312 | 313 | // Initialize 314 | initRequest := mcp.InitializeRequest{} 315 | initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION 316 | initRequest.Params.ClientInfo = mcp.Implementation{ 317 | Name: "test-client", 318 | Version: "1.0.0", 319 | } 320 | 321 | _, err = client.Initialize(context.Background(), initRequest) 322 | if err != nil { 323 | t.Fatalf("Failed to initialize: %v", err) 324 | } 325 | request := mcp.ListPromptsRequest{} 326 | result, err := client.ListPrompts(context.Background(), request) 327 | if err != nil { 328 | t.Fatalf("ListPrompts failed: %v", err) 329 | } 330 | 331 | if len(result.Prompts) != 1 { 332 | t.Errorf("Expected 1 prompt, got %d", len(result.Prompts)) 333 | } 334 | }) 335 | 336 
| t.Run("GetPrompt", func(t *testing.T) { 337 | client, err := NewInProcessClient(mcpServer) 338 | if err != nil { 339 | t.Fatalf("Failed to create client: %v", err) 340 | } 341 | defer client.Close() 342 | 343 | if err := client.Start(context.Background()); err != nil { 344 | t.Fatalf("Failed to start client: %v", err) 345 | } 346 | 347 | // Initialize 348 | initRequest := mcp.InitializeRequest{} 349 | initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION 350 | initRequest.Params.ClientInfo = mcp.Implementation{ 351 | Name: "test-client", 352 | Version: "1.0.0", 353 | } 354 | 355 | _, err = client.Initialize(context.Background(), initRequest) 356 | if err != nil { 357 | t.Fatalf("Failed to initialize: %v", err) 358 | } 359 | 360 | request := mcp.GetPromptRequest{} 361 | request.Params.Name = "test-prompt" 362 | 363 | result, err := client.GetPrompt(context.Background(), request) 364 | if err != nil { 365 | t.Fatalf("GetPrompt failed: %v", err) 366 | } 367 | 368 | if len(result.Messages) != 1 { 369 | t.Errorf("Expected 1 message, got %d", len(result.Messages)) 370 | } 371 | }) 372 | 373 | t.Run("ListTools", func(t *testing.T) { 374 | client, err := NewInProcessClient(mcpServer) 375 | if err != nil { 376 | t.Fatalf("Failed to create client: %v", err) 377 | } 378 | defer client.Close() 379 | 380 | if err := client.Start(context.Background()); err != nil { 381 | t.Fatalf("Failed to start client: %v", err) 382 | } 383 | 384 | // Initialize 385 | initRequest := mcp.InitializeRequest{} 386 | initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION 387 | initRequest.Params.ClientInfo = mcp.Implementation{ 388 | Name: "test-client", 389 | Version: "1.0.0", 390 | } 391 | 392 | _, err = client.Initialize(context.Background(), initRequest) 393 | if err != nil { 394 | t.Fatalf("Failed to initialize: %v", err) 395 | } 396 | 397 | request := mcp.ListToolsRequest{} 398 | result, err := client.ListTools(context.Background(), request) 399 | if err != nil { 400 | 
t.Fatalf("ListTools failed: %v", err) 401 | } 402 | 403 | if len(result.Tools) != 1 { 404 | t.Errorf("Expected 1 tool, got %d", len(result.Tools)) 405 | } 406 | }) 407 | } 408 | -------------------------------------------------------------------------------- /client/interface.go: -------------------------------------------------------------------------------- 1 | // Package client provides MCP (Model Context Protocol) client implementations. 2 | package client 3 | 4 | import ( 5 | "context" 6 | 7 | "github.com/mark3labs/mcp-go/mcp" 8 | ) 9 | 10 | // MCPClient represents an MCP client interface. 11 | type MCPClient interface { 12 | // Initialize sends the initial connection request to the server 13 | Initialize( 14 | ctx context.Context, 15 | request mcp.InitializeRequest, 16 | ) (*mcp.InitializeResult, error) 17 | 18 | // Ping checks if the server is alive 19 | Ping(ctx context.Context) error 20 | 21 | // ListResourcesByPage lists a single page of resources, selected by the request's cursor. 22 | ListResourcesByPage( 23 | ctx context.Context, 24 | request mcp.ListResourcesRequest, 25 | ) (*mcp.ListResourcesResult, error) 26 | 27 | // ListResources requests a list of available resources from the server 28 | ListResources( 29 | ctx context.Context, 30 | request mcp.ListResourcesRequest, 31 | ) (*mcp.ListResourcesResult, error) 32 | 33 | // ListResourceTemplatesByPage lists a single page of resource templates, selected by the request's cursor.
34 | ListResourceTemplatesByPage( 35 | ctx context.Context, 36 | request mcp.ListResourceTemplatesRequest, 37 | ) (*mcp.ListResourceTemplatesResult, 38 | error) 39 | 40 | // ListResourceTemplates requests a list of available resource templates from the server 41 | ListResourceTemplates( 42 | ctx context.Context, 43 | request mcp.ListResourceTemplatesRequest, 44 | ) (*mcp.ListResourceTemplatesResult, 45 | error) 46 | 47 | // ReadResource reads a specific resource from the server 48 | ReadResource( 49 | ctx context.Context, 50 | request mcp.ReadResourceRequest, 51 | ) (*mcp.ReadResourceResult, error) 52 | 53 | // Subscribe requests notifications for changes to a specific resource 54 | Subscribe(ctx context.Context, request mcp.SubscribeRequest) error 55 | 56 | // Unsubscribe cancels notifications for a specific resource 57 | Unsubscribe(ctx context.Context, request mcp.UnsubscribeRequest) error 58 | 59 | // ListPromptsByPage manually list prompts by page. 60 | ListPromptsByPage( 61 | ctx context.Context, 62 | request mcp.ListPromptsRequest, 63 | ) (*mcp.ListPromptsResult, error) 64 | 65 | // ListPrompts requests a list of available prompts from the server 66 | ListPrompts( 67 | ctx context.Context, 68 | request mcp.ListPromptsRequest, 69 | ) (*mcp.ListPromptsResult, error) 70 | 71 | // GetPrompt retrieves a specific prompt from the server 72 | GetPrompt( 73 | ctx context.Context, 74 | request mcp.GetPromptRequest, 75 | ) (*mcp.GetPromptResult, error) 76 | 77 | // ListToolsByPage manually list tools by page. 
78 | ListToolsByPage( 79 | ctx context.Context, 80 | request mcp.ListToolsRequest, 81 | ) (*mcp.ListToolsResult, error) 82 | 83 | // ListTools requests a list of available tools from the server 84 | ListTools( 85 | ctx context.Context, 86 | request mcp.ListToolsRequest, 87 | ) (*mcp.ListToolsResult, error) 88 | 89 | // CallTool invokes a specific tool on the server 90 | CallTool( 91 | ctx context.Context, 92 | request mcp.CallToolRequest, 93 | ) (*mcp.CallToolResult, error) 94 | 95 | // SetLevel sets the logging level for the server 96 | SetLevel(ctx context.Context, request mcp.SetLevelRequest) error 97 | 98 | // Complete requests completion options for a given argument 99 | Complete( 100 | ctx context.Context, 101 | request mcp.CompleteRequest, 102 | ) (*mcp.CompleteResult, error) 103 | 104 | // Close client connection and cleanup resources 105 | Close() error 106 | 107 | // OnNotification registers a handler for notifications 108 | OnNotification(handler func(notification mcp.JSONRPCNotification)) 109 | } 110 | -------------------------------------------------------------------------------- /client/sse.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "fmt" 5 | "github.com/mark3labs/mcp-go/client/transport" 6 | "net/url" 7 | ) 8 | 9 | func WithHeaders(headers map[string]string) transport.ClientOption { 10 | return transport.WithHeaders(headers) 11 | } 12 | 13 | // NewSSEMCPClient creates a new SSE-based MCP client with the given base URL. 14 | // Returns an error if the URL is invalid. 15 | func NewSSEMCPClient(baseURL string, options ...transport.ClientOption) (*Client, error) { 16 | 17 | sseTransport, err := transport.NewSSE(baseURL, options...) 18 | if err != nil { 19 | return nil, fmt.Errorf("failed to create SSE transport: %w", err) 20 | } 21 | 22 | return NewClient(sseTransport), nil 23 | } 24 | 25 | // GetEndpoint returns the current endpoint URL for the SSE connection. 
26 | // 27 | // Note: This method only works with SSE transport, or it will panic. 28 | func GetEndpoint(c *Client) *url.URL { 29 | t := c.GetTransport() 30 | sse := t.(*transport.SSE) 31 | return sse.GetEndpoint() 32 | } 33 | -------------------------------------------------------------------------------- /client/sse_test.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "context" 5 | "github.com/mark3labs/mcp-go/client/transport" 6 | "testing" 7 | "time" 8 | 9 | "github.com/mark3labs/mcp-go/mcp" 10 | "github.com/mark3labs/mcp-go/server" 11 | ) 12 | 13 | func TestSSEMCPClient(t *testing.T) { 14 | // Create MCP server with capabilities 15 | mcpServer := server.NewMCPServer( 16 | "test-server", 17 | "1.0.0", 18 | server.WithResourceCapabilities(true, true), 19 | server.WithPromptCapabilities(true), 20 | server.WithToolCapabilities(true), 21 | ) 22 | 23 | // Add a test tool 24 | mcpServer.AddTool(mcp.NewTool( 25 | "test-tool", 26 | mcp.WithDescription("Test tool"), 27 | mcp.WithString("parameter-1", mcp.Description("A string tool parameter")), 28 | mcp.WithToolAnnotation(mcp.ToolAnnotation{ 29 | Title: "Test Tool Annotation Title", 30 | ReadOnlyHint: true, 31 | DestructiveHint: false, 32 | IdempotentHint: true, 33 | OpenWorldHint: false, 34 | }), 35 | ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { 36 | return &mcp.CallToolResult{ 37 | Content: []mcp.Content{ 38 | mcp.TextContent{ 39 | Type: "text", 40 | Text: "Input parameter: " + request.Params.Arguments["parameter-1"].(string), 41 | }, 42 | }, 43 | }, nil 44 | }) 45 | 46 | // Initialize 47 | testServer := server.NewTestServer(mcpServer) 48 | defer testServer.Close() 49 | 50 | t.Run("Can create client", func(t *testing.T) { 51 | client, err := NewSSEMCPClient(testServer.URL + "/sse") 52 | if err != nil { 53 | t.Fatalf("Failed to create client: %v", err) 54 | } 55 | defer client.Close() 56 | 57 | 
sseTransport := client.GetTransport().(*transport.SSE) 58 | if sseTransport.GetBaseURL() == nil { 59 | t.Error("Base URL should not be nil") 60 | } 61 | }) 62 | 63 | t.Run("Can initialize and make requests", func(t *testing.T) { 64 | client, err := NewSSEMCPClient(testServer.URL + "/sse") 65 | if err != nil { 66 | t.Fatalf("Failed to create client: %v", err) 67 | } 68 | defer client.Close() 69 | 70 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 71 | defer cancel() 72 | 73 | // Start the client 74 | if err := client.Start(ctx); err != nil { 75 | t.Fatalf("Failed to start client: %v", err) 76 | } 77 | 78 | // Initialize 79 | initRequest := mcp.InitializeRequest{} 80 | initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION 81 | initRequest.Params.ClientInfo = mcp.Implementation{ 82 | Name: "test-client", 83 | Version: "1.0.0", 84 | } 85 | 86 | result, err := client.Initialize(ctx, initRequest) 87 | if err != nil { 88 | t.Fatalf("Failed to initialize: %v", err) 89 | } 90 | 91 | if result.ServerInfo.Name != "test-server" { 92 | t.Errorf( 93 | "Expected server name 'test-server', got '%s'", 94 | result.ServerInfo.Name, 95 | ) 96 | } 97 | 98 | // Test Ping 99 | if err := client.Ping(ctx); err != nil { 100 | t.Errorf("Ping failed: %v", err) 101 | } 102 | 103 | // Test ListTools 104 | toolsRequest := mcp.ListToolsRequest{} 105 | toolListResult, err := client.ListTools(ctx, toolsRequest) 106 | if err != nil { 107 | t.Errorf("ListTools failed: %v", err) 108 | } 109 | if toolListResult == nil || len((*toolListResult).Tools) == 0 { 110 | t.Errorf("Expected one tool") 111 | } 112 | testToolAnnotations := (*toolListResult).Tools[0].Annotations 113 | if testToolAnnotations.Title != "Test Tool Annotation Title" || 114 | testToolAnnotations.ReadOnlyHint != true || 115 | testToolAnnotations.DestructiveHint != false || 116 | testToolAnnotations.IdempotentHint != true || 117 | testToolAnnotations.OpenWorldHint != false { 118 | t.Errorf("The 
annotations of the tools are invalid") 119 | } 120 | }) 121 | 122 | // t.Run("Can handle notifications", func(t *testing.T) { 123 | // client, err := NewSSEMCPClient(testServer.URL + "/sse") 124 | // if err != nil { 125 | // t.Fatalf("Failed to create client: %v", err) 126 | // } 127 | // defer client.Close() 128 | 129 | // notificationReceived := make(chan mcp.JSONRPCNotification, 1) 130 | // client.OnNotification(func(notification mcp.JSONRPCNotification) { 131 | // notificationReceived <- notification 132 | // }) 133 | 134 | // ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 135 | // defer cancel() 136 | 137 | // if err := client.Start(ctx); err != nil { 138 | // t.Fatalf("Failed to start client: %v", err) 139 | // } 140 | 141 | // // Initialize first 142 | // initRequest := mcp.InitializeRequest{} 143 | // initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION 144 | // initRequest.Params.ClientInfo = mcp.Implementation{ 145 | // Name: "test-client", 146 | // Version: "1.0.0", 147 | // } 148 | 149 | // _, err = client.Initialize(ctx, initRequest) 150 | // if err != nil { 151 | // t.Fatalf("Failed to initialize: %v", err) 152 | // } 153 | 154 | // // Subscribe to a resource to test notifications 155 | // subRequest := mcp.SubscribeRequest{} 156 | // subRequest.Params.URI = "test://resource" 157 | // if err := client.Subscribe(ctx, subRequest); err != nil { 158 | // t.Fatalf("Failed to subscribe: %v", err) 159 | // } 160 | 161 | // select { 162 | // case <-notificationReceived: 163 | // // Success 164 | // case <-time.After(time.Second): 165 | // t.Error("Timeout waiting for notification") 166 | // } 167 | // }) 168 | 169 | t.Run("Handles errors properly", func(t *testing.T) { 170 | client, err := NewSSEMCPClient(testServer.URL + "/sse") 171 | if err != nil { 172 | t.Fatalf("Failed to create client: %v", err) 173 | } 174 | defer client.Close() 175 | 176 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 177 
| defer cancel() 178 | 179 | if err := client.Start(ctx); err != nil { 180 | t.Fatalf("Failed to start client: %v", err) 181 | } 182 | 183 | // Try to make a request without initializing 184 | toolsRequest := mcp.ListToolsRequest{} 185 | _, err = client.ListTools(ctx, toolsRequest) 186 | if err == nil { 187 | t.Error("Expected error when making request before initialization") 188 | } 189 | }) 190 | 191 | // t.Run("Handles context cancellation", func(t *testing.T) { 192 | // client, err := NewSSEMCPClient(testServer.URL + "/sse") 193 | // if err != nil { 194 | // t.Fatalf("Failed to create client: %v", err) 195 | // } 196 | // defer client.Close() 197 | 198 | // if err := client.Start(context.Background()); err != nil { 199 | // t.Fatalf("Failed to start client: %v", err) 200 | // } 201 | 202 | // ctx, cancel := context.WithCancel(context.Background()) 203 | // cancel() // Cancel immediately 204 | 205 | // toolsRequest := mcp.ListToolsRequest{} 206 | // _, err = client.ListTools(ctx, toolsRequest) 207 | // if err == nil { 208 | // t.Error("Expected error when context is cancelled") 209 | // } 210 | // }) 211 | 212 | t.Run("CallTool", func(t *testing.T) { 213 | client, err := NewSSEMCPClient(testServer.URL + "/sse") 214 | if err != nil { 215 | t.Fatalf("Failed to create client: %v", err) 216 | } 217 | defer client.Close() 218 | 219 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 220 | defer cancel() 221 | 222 | if err := client.Start(ctx); err != nil { 223 | t.Fatalf("Failed to start client: %v", err) 224 | } 225 | 226 | // Initialize 227 | initRequest := mcp.InitializeRequest{} 228 | initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION 229 | initRequest.Params.ClientInfo = mcp.Implementation{ 230 | Name: "test-client", 231 | Version: "1.0.0", 232 | } 233 | 234 | _, err = client.Initialize(ctx, initRequest) 235 | if err != nil { 236 | t.Fatalf("Failed to initialize: %v", err) 237 | } 238 | 239 | request := mcp.CallToolRequest{} 
240 | request.Params.Name = "test-tool" 241 | request.Params.Arguments = map[string]interface{}{ 242 | "parameter-1": "value1", 243 | } 244 | 245 | result, err := client.CallTool(ctx, request) 246 | if err != nil { 247 | t.Fatalf("CallTool failed: %v", err) 248 | } 249 | 250 | if len(result.Content) != 1 { 251 | t.Errorf("Expected 1 content item, got %d", len(result.Content)) 252 | } 253 | }) 254 | } 255 | -------------------------------------------------------------------------------- /client/stdio.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "io" 7 | 8 | "github.com/mark3labs/mcp-go/client/transport" 9 | ) 10 | 11 | // NewStdioMCPClient creates a new stdio-based MCP client that communicates with a subprocess. 12 | // It launches the specified command with given arguments and sets up stdin/stdout pipes for communication. 13 | // Returns an error if the subprocess cannot be started or the pipes cannot be created. 14 | // 15 | // NOTICE: NewStdioMCPClient will start the connection automatically. Don't call the Start method manually. 16 | // This is for backward compatibility. 17 | func NewStdioMCPClient( 18 | command string, 19 | env []string, 20 | args ...string, 21 | ) (*Client, error) { 22 | 23 | stdioTransport := transport.NewStdio(command, env, args...) 24 | err := stdioTransport.Start(context.Background()) 25 | if err != nil { 26 | return nil, fmt.Errorf("failed to start stdio transport: %w", err) 27 | } 28 | 29 | return NewClient(stdioTransport), nil 30 | } 31 | 32 | // GetStderr returns a reader for the stderr output of the subprocess. 33 | // This can be used to capture error messages or logs from the subprocess. 
34 | func GetStderr(c *Client) (io.Reader, bool) { 35 | t := c.GetTransport() 36 | 37 | stdio, ok := t.(*transport.Stdio) 38 | if !ok { 39 | return nil, false 40 | } 41 | 42 | return stdio.Stderr(), true 43 | } 44 | -------------------------------------------------------------------------------- /client/stdio_test.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | "log/slog" 8 | "os" 9 | "os/exec" 10 | "path/filepath" 11 | "sync" 12 | "testing" 13 | "time" 14 | 15 | "github.com/mark3labs/mcp-go/mcp" 16 | ) 17 | 18 | func compileTestServer(outputPath string) error { 19 | cmd := exec.Command( 20 | "go", 21 | "build", 22 | "-o", 23 | outputPath, 24 | "../testdata/mockstdio_server.go", 25 | ) 26 | if output, err := cmd.CombinedOutput(); err != nil { 27 | return fmt.Errorf("compilation failed: %v\nOutput: %s", err, output) 28 | } 29 | return nil 30 | } 31 | 32 | func TestStdioMCPClient(t *testing.T) { 33 | // Compile mock server 34 | mockServerPath := filepath.Join(os.TempDir(), "mockstdio_server") 35 | if err := compileTestServer(mockServerPath); err != nil { 36 | t.Fatalf("Failed to compile mock server: %v", err) 37 | } 38 | defer os.Remove(mockServerPath) 39 | 40 | client, err := NewStdioMCPClient(mockServerPath, []string{}) 41 | if err != nil { 42 | t.Fatalf("Failed to create client: %v", err) 43 | } 44 | var logRecords []map[string]any 45 | var logRecordsMu sync.RWMutex 46 | var wg sync.WaitGroup 47 | wg.Add(1) 48 | go func() { 49 | defer wg.Done() 50 | 51 | stderr, ok := GetStderr(client) 52 | if !ok { 53 | return 54 | } 55 | 56 | dec := json.NewDecoder(stderr) 57 | for { 58 | var record map[string]any 59 | if err := dec.Decode(&record); err != nil { 60 | return 61 | } 62 | logRecordsMu.Lock() 63 | logRecords = append(logRecords, record) 64 | logRecordsMu.Unlock() 65 | } 66 | }() 67 | 68 | t.Run("Initialize", func(t *testing.T) { 69 | ctx, cancel 
:= context.WithTimeout(context.Background(), 5*time.Second) 70 | defer cancel() 71 | 72 | request := mcp.InitializeRequest{} 73 | request.Params.ProtocolVersion = "1.0" 74 | request.Params.ClientInfo = mcp.Implementation{ 75 | Name: "test-client", 76 | Version: "1.0.0", 77 | } 78 | request.Params.Capabilities = mcp.ClientCapabilities{ 79 | Roots: &struct { 80 | ListChanged bool `json:"listChanged,omitempty"` 81 | }{ 82 | ListChanged: true, 83 | }, 84 | } 85 | 86 | result, err := client.Initialize(ctx, request) 87 | if err != nil { 88 | t.Fatalf("Initialize failed: %v", err) 89 | } 90 | 91 | if result.ServerInfo.Name != "mock-server" { 92 | t.Errorf( 93 | "Expected server name 'mock-server', got '%s'", 94 | result.ServerInfo.Name, 95 | ) 96 | } 97 | }) 98 | 99 | t.Run("Ping", func(t *testing.T) { 100 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 101 | defer cancel() 102 | 103 | err := client.Ping(ctx) 104 | if err != nil { 105 | t.Errorf("Ping failed: %v", err) 106 | } 107 | }) 108 | 109 | t.Run("ListResources", func(t *testing.T) { 110 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 111 | defer cancel() 112 | 113 | request := mcp.ListResourcesRequest{} 114 | result, err := client.ListResources(ctx, request) 115 | if err != nil { 116 | t.Errorf("ListResources failed: %v", err) 117 | } 118 | 119 | if len(result.Resources) != 1 { 120 | t.Errorf("Expected 1 resource, got %d", len(result.Resources)) 121 | } 122 | }) 123 | 124 | t.Run("ReadResource", func(t *testing.T) { 125 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 126 | defer cancel() 127 | 128 | request := mcp.ReadResourceRequest{} 129 | request.Params.URI = "test://resource" 130 | 131 | result, err := client.ReadResource(ctx, request) 132 | if err != nil { 133 | t.Errorf("ReadResource failed: %v", err) 134 | } 135 | 136 | if len(result.Contents) != 1 { 137 | t.Errorf("Expected 1 content item, got %d", len(result.Contents)) 138 | } 
139 | }) 140 | 141 | t.Run("Subscribe and Unsubscribe", func(t *testing.T) { 142 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 143 | defer cancel() 144 | 145 | // Test Subscribe 146 | subRequest := mcp.SubscribeRequest{} 147 | subRequest.Params.URI = "test://resource" 148 | err := client.Subscribe(ctx, subRequest) 149 | if err != nil { 150 | t.Errorf("Subscribe failed: %v", err) 151 | } 152 | 153 | // Test Unsubscribe 154 | unsubRequest := mcp.UnsubscribeRequest{} 155 | unsubRequest.Params.URI = "test://resource" 156 | err = client.Unsubscribe(ctx, unsubRequest) 157 | if err != nil { 158 | t.Errorf("Unsubscribe failed: %v", err) 159 | } 160 | }) 161 | 162 | t.Run("ListPrompts", func(t *testing.T) { 163 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 164 | defer cancel() 165 | 166 | request := mcp.ListPromptsRequest{} 167 | result, err := client.ListPrompts(ctx, request) 168 | if err != nil { 169 | t.Errorf("ListPrompts failed: %v", err) 170 | } 171 | 172 | if len(result.Prompts) != 1 { 173 | t.Errorf("Expected 1 prompt, got %d", len(result.Prompts)) 174 | } 175 | }) 176 | 177 | t.Run("GetPrompt", func(t *testing.T) { 178 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 179 | defer cancel() 180 | 181 | request := mcp.GetPromptRequest{} 182 | request.Params.Name = "test-prompt" 183 | 184 | result, err := client.GetPrompt(ctx, request) 185 | if err != nil { 186 | t.Errorf("GetPrompt failed: %v", err) 187 | } 188 | 189 | if len(result.Messages) != 1 { 190 | t.Errorf("Expected 1 message, got %d", len(result.Messages)) 191 | } 192 | }) 193 | 194 | t.Run("ListTools", func(t *testing.T) { 195 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 196 | defer cancel() 197 | 198 | request := mcp.ListToolsRequest{} 199 | result, err := client.ListTools(ctx, request) 200 | if err != nil { 201 | t.Errorf("ListTools failed: %v", err) 202 | } 203 | 204 | if len(result.Tools) != 1 { 
205 | t.Errorf("Expected 1 tool, got %d", len(result.Tools)) 206 | } 207 | }) 208 | 209 | t.Run("CallTool", func(t *testing.T) { 210 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 211 | defer cancel() 212 | 213 | request := mcp.CallToolRequest{} 214 | request.Params.Name = "test-tool" 215 | request.Params.Arguments = map[string]interface{}{ 216 | "param1": "value1", 217 | } 218 | 219 | result, err := client.CallTool(ctx, request) 220 | if err != nil { 221 | t.Errorf("CallTool failed: %v", err) 222 | } 223 | 224 | if len(result.Content) != 1 { 225 | t.Errorf("Expected 1 content item, got %d", len(result.Content)) 226 | } 227 | }) 228 | 229 | t.Run("SetLevel", func(t *testing.T) { 230 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 231 | defer cancel() 232 | 233 | request := mcp.SetLevelRequest{} 234 | request.Params.Level = mcp.LoggingLevelInfo 235 | 236 | err := client.SetLevel(ctx, request) 237 | if err != nil { 238 | t.Errorf("SetLevel failed: %v", err) 239 | } 240 | }) 241 | 242 | t.Run("Complete", func(t *testing.T) { 243 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 244 | defer cancel() 245 | 246 | request := mcp.CompleteRequest{} 247 | request.Params.Ref = mcp.PromptReference{ 248 | Type: "ref/prompt", 249 | Name: "test-prompt", 250 | } 251 | request.Params.Argument.Name = "test-arg" 252 | request.Params.Argument.Value = "test-value" 253 | 254 | result, err := client.Complete(ctx, request) 255 | if err != nil { 256 | t.Errorf("Complete failed: %v", err) 257 | } 258 | 259 | if len(result.Completion.Values) != 1 { 260 | t.Errorf( 261 | "Expected 1 completion value, got %d", 262 | len(result.Completion.Values), 263 | ) 264 | } 265 | }) 266 | 267 | client.Close() 268 | wg.Wait() 269 | 270 | t.Run("CheckLogs", func(t *testing.T) { 271 | logRecordsMu.RLock() 272 | defer logRecordsMu.RUnlock() 273 | 274 | if len(logRecords) != 1 { 275 | t.Errorf("Expected 1 log record, got %d", 
len(logRecords)) 276 | return 277 | } 278 | 279 | msg, ok := logRecords[0][slog.MessageKey].(string) 280 | if !ok { 281 | t.Errorf("Expected log record to have message key") 282 | } 283 | if msg != "launch successful" { 284 | t.Errorf("Expected log message 'launch successful', got '%s'", msg) 285 | } 286 | }) 287 | } 288 | -------------------------------------------------------------------------------- /client/transport/inprocess.go: -------------------------------------------------------------------------------- 1 | package transport 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | "sync" 8 | 9 | "github.com/mark3labs/mcp-go/mcp" 10 | "github.com/mark3labs/mcp-go/server" 11 | ) 12 | 13 | type InProcessTransport struct { 14 | server *server.MCPServer 15 | 16 | onNotification func(mcp.JSONRPCNotification) 17 | notifyMu sync.RWMutex 18 | } 19 | 20 | func NewInProcessTransport(server *server.MCPServer) *InProcessTransport { 21 | return &InProcessTransport{ 22 | server: server, 23 | } 24 | } 25 | 26 | func (c *InProcessTransport) Start(ctx context.Context) error { 27 | return nil 28 | } 29 | 30 | func (c *InProcessTransport) SendRequest(ctx context.Context, request JSONRPCRequest) (*JSONRPCResponse, error) { 31 | requestBytes, err := json.Marshal(request) 32 | if err != nil { 33 | return nil, fmt.Errorf("failed to marshal request: %w", err) 34 | } 35 | requestBytes = append(requestBytes, '\n') 36 | 37 | respMessage := c.server.HandleMessage(ctx, requestBytes) 38 | respByte, err := json.Marshal(respMessage) 39 | if err != nil { 40 | return nil, fmt.Errorf("failed to marshal response message: %w", err) 41 | } 42 | rpcResp := JSONRPCResponse{} 43 | err = json.Unmarshal(respByte, &rpcResp) 44 | if err != nil { 45 | return nil, fmt.Errorf("failed to unmarshal response message: %w", err) 46 | } 47 | 48 | return &rpcResp, nil 49 | } 50 | 51 | func (c *InProcessTransport) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error { 52 | 
notificationBytes, err := json.Marshal(notification) 53 | if err != nil { 54 | return fmt.Errorf("failed to marshal notification: %w", err) 55 | } 56 | notificationBytes = append(notificationBytes, '\n') 57 | c.server.HandleMessage(ctx, notificationBytes) 58 | 59 | return nil 60 | } 61 | 62 | func (c *InProcessTransport) SetNotificationHandler(handler func(notification mcp.JSONRPCNotification)) { 63 | c.notifyMu.Lock() 64 | defer c.notifyMu.Unlock() 65 | c.onNotification = handler 66 | } 67 | 68 | func (*InProcessTransport) Close() error { 69 | return nil 70 | } 71 | -------------------------------------------------------------------------------- /client/transport/interface.go: -------------------------------------------------------------------------------- 1 | package transport 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | 7 | "github.com/mark3labs/mcp-go/mcp" 8 | ) 9 | 10 | // Interface for the transport layer. 11 | type Interface interface { 12 | // Start the connection. Start should only be called once. 13 | Start(ctx context.Context) error 14 | 15 | // SendRequest sends a json RPC request and returns the response synchronously. 16 | SendRequest(ctx context.Context, request JSONRPCRequest) (*JSONRPCResponse, error) 17 | 18 | // SendNotification sends a json RPC Notification to the server. 19 | SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error 20 | 21 | // SetNotificationHandler sets the handler for notifications. 22 | // Any notification before the handler is set will be discarded. 23 | SetNotificationHandler(handler func(notification mcp.JSONRPCNotification)) 24 | 25 | // Close the connection. 
26 | Close() error 27 | } 28 | 29 | type JSONRPCRequest struct { 30 | JSONRPC string `json:"jsonrpc"` 31 | ID int64 `json:"id"` 32 | Method string `json:"method"` 33 | Params any `json:"params,omitempty"` 34 | } 35 | 36 | type JSONRPCResponse struct { 37 | JSONRPC string `json:"jsonrpc"` 38 | ID *int64 `json:"id"` 39 | Result json.RawMessage `json:"result"` 40 | Error *struct { 41 | Code int `json:"code"` 42 | Message string `json:"message"` 43 | Data json.RawMessage `json:"data"` 44 | } `json:"error"` 45 | } 46 | -------------------------------------------------------------------------------- /client/transport/sse.go: -------------------------------------------------------------------------------- 1 | package transport 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "context" 7 | "encoding/json" 8 | "fmt" 9 | "io" 10 | "net/http" 11 | "net/url" 12 | "strings" 13 | "sync" 14 | "sync/atomic" 15 | "time" 16 | 17 | "github.com/mark3labs/mcp-go/mcp" 18 | ) 19 | 20 | // SSE implements the transport layer of the MCP protocol using Server-Sent Events (SSE). 21 | // It maintains a persistent HTTP connection to receive server-pushed events 22 | // while sending requests over regular HTTP POST calls. The client handles 23 | // automatic reconnection and message routing between requests and responses. 24 | type SSE struct { 25 | baseURL *url.URL 26 | endpoint *url.URL 27 | httpClient *http.Client 28 | responses map[int64]chan *JSONRPCResponse 29 | mu sync.RWMutex 30 | onNotification func(mcp.JSONRPCNotification) 31 | notifyMu sync.RWMutex 32 | endpointChan chan struct{} 33 | headers map[string]string 34 | 35 | started atomic.Bool 36 | closed atomic.Bool 37 | cancelSSEStream context.CancelFunc 38 | } 39 | 40 | type ClientOption func(*SSE) 41 | 42 | func WithHeaders(headers map[string]string) ClientOption { 43 | return func(sc *SSE) { 44 | sc.headers = headers 45 | } 46 | } 47 | 48 | // NewSSE creates a new SSE-based MCP client with the given base URL. 
49 | // Returns an error if the URL is invalid. 50 | func NewSSE(baseURL string, options ...ClientOption) (*SSE, error) { 51 | parsedURL, err := url.Parse(baseURL) 52 | if err != nil { 53 | return nil, fmt.Errorf("invalid URL: %w", err) 54 | } 55 | 56 | smc := &SSE{ 57 | baseURL: parsedURL, 58 | httpClient: &http.Client{}, 59 | responses: make(map[int64]chan *JSONRPCResponse), 60 | endpointChan: make(chan struct{}), 61 | headers: make(map[string]string), 62 | } 63 | 64 | for _, opt := range options { 65 | opt(smc) 66 | } 67 | 68 | return smc, nil 69 | } 70 | 71 | // Start initiates the SSE connection to the server and waits for the endpoint information. 72 | // Returns an error if the connection fails or times out waiting for the endpoint. 73 | func (c *SSE) Start(ctx context.Context) error { 74 | 75 | if c.started.Load() { 76 | return fmt.Errorf("has already started") 77 | } 78 | 79 | ctx, cancel := context.WithCancel(ctx) 80 | c.cancelSSEStream = cancel 81 | 82 | req, err := http.NewRequestWithContext(ctx, "GET", c.baseURL.String(), nil) 83 | 84 | if err != nil { 85 | return fmt.Errorf("failed to create request: %w", err) 86 | } 87 | 88 | req.Header.Set("Accept", "text/event-stream") 89 | req.Header.Set("Cache-Control", "no-cache") 90 | req.Header.Set("Connection", "keep-alive") 91 | 92 | // set custom http headers 93 | for k, v := range c.headers { 94 | req.Header.Set(k, v) 95 | } 96 | 97 | resp, err := c.httpClient.Do(req) 98 | if err != nil { 99 | return fmt.Errorf("failed to connect to SSE stream: %w", err) 100 | } 101 | 102 | if resp.StatusCode != http.StatusOK { 103 | resp.Body.Close() 104 | return fmt.Errorf("unexpected status code: %d", resp.StatusCode) 105 | } 106 | 107 | go c.readSSE(resp.Body) 108 | 109 | // Wait for the endpoint to be received 110 | timeout := time.NewTimer(30 * time.Second) 111 | defer timeout.Stop() 112 | select { 113 | case <-c.endpointChan: 114 | // Endpoint received, proceed 115 | case <-ctx.Done(): 116 | return 
fmt.Errorf("context cancelled while waiting for endpoint") 117 | case <-timeout.C: // Add a timeout 118 | cancel() 119 | return fmt.Errorf("timeout waiting for endpoint") 120 | } 121 | 122 | c.started.Store(true) 123 | return nil 124 | } 125 | 126 | // readSSE continuously reads the SSE stream and processes events. 127 | // It runs until the connection is closed or an error occurs. 128 | func (c *SSE) readSSE(reader io.ReadCloser) { 129 | defer reader.Close() 130 | 131 | br := bufio.NewReader(reader) 132 | var event, data string 133 | 134 | for { 135 | // when close or start's ctx cancel, the reader will be closed 136 | // and the for loop will break. 137 | line, err := br.ReadString('\n') 138 | if err != nil { 139 | if err == io.EOF { 140 | // Process any pending event before exit 141 | if event != "" && data != "" { 142 | c.handleSSEEvent(event, data) 143 | } 144 | break 145 | } 146 | if !c.closed.Load() { 147 | fmt.Printf("SSE stream error: %v\n", err) 148 | } 149 | return 150 | } 151 | 152 | // Remove only newline markers 153 | line = strings.TrimRight(line, "\r\n") 154 | if line == "" { 155 | // Empty line means end of event 156 | if event != "" && data != "" { 157 | c.handleSSEEvent(event, data) 158 | event = "" 159 | data = "" 160 | } 161 | continue 162 | } 163 | 164 | if strings.HasPrefix(line, "event:") { 165 | event = strings.TrimSpace(strings.TrimPrefix(line, "event:")) 166 | } else if strings.HasPrefix(line, "data:") { 167 | data = strings.TrimSpace(strings.TrimPrefix(line, "data:")) 168 | } 169 | } 170 | } 171 | 172 | // handleSSEEvent processes SSE events based on their type. 173 | // Handles 'endpoint' events for connection setup and 'message' events for JSON-RPC communication. 
174 | func (c *SSE) handleSSEEvent(event, data string) { 175 | switch event { 176 | case "endpoint": 177 | endpoint, err := c.baseURL.Parse(data) 178 | if err != nil { 179 | fmt.Printf("Error parsing endpoint URL: %v\n", err) 180 | return 181 | } 182 | if endpoint.Host != c.baseURL.Host { 183 | fmt.Printf("Endpoint origin does not match connection origin\n") 184 | return 185 | } 186 | c.endpoint = endpoint 187 | close(c.endpointChan) 188 | 189 | case "message": 190 | var baseMessage JSONRPCResponse 191 | if err := json.Unmarshal([]byte(data), &baseMessage); err != nil { 192 | fmt.Printf("Error unmarshaling message: %v\n", err) 193 | return 194 | } 195 | 196 | // Handle notification 197 | if baseMessage.ID == nil { 198 | var notification mcp.JSONRPCNotification 199 | if err := json.Unmarshal([]byte(data), ¬ification); err != nil { 200 | return 201 | } 202 | c.notifyMu.RLock() 203 | if c.onNotification != nil { 204 | c.onNotification(notification) 205 | } 206 | c.notifyMu.RUnlock() 207 | return 208 | } 209 | 210 | c.mu.RLock() 211 | ch, ok := c.responses[*baseMessage.ID] 212 | c.mu.RUnlock() 213 | 214 | if ok { 215 | ch <- &baseMessage 216 | c.mu.Lock() 217 | delete(c.responses, *baseMessage.ID) 218 | c.mu.Unlock() 219 | } 220 | } 221 | } 222 | 223 | func (c *SSE) SetNotificationHandler(handler func(notification mcp.JSONRPCNotification)) { 224 | c.notifyMu.Lock() 225 | defer c.notifyMu.Unlock() 226 | c.onNotification = handler 227 | } 228 | 229 | // sendRequest sends a JSON-RPC request to the server and waits for a response. 230 | // Returns the raw JSON response message or an error if the request fails. 
231 | func (c *SSE) SendRequest( 232 | ctx context.Context, 233 | request JSONRPCRequest, 234 | ) (*JSONRPCResponse, error) { 235 | 236 | if !c.started.Load() { 237 | return nil, fmt.Errorf("transport not started yet") 238 | } 239 | if c.closed.Load() { 240 | return nil, fmt.Errorf("transport has been closed") 241 | } 242 | if c.endpoint == nil { 243 | return nil, fmt.Errorf("endpoint not received") 244 | } 245 | 246 | requestBytes, err := json.Marshal(request) 247 | if err != nil { 248 | return nil, fmt.Errorf("failed to marshal request: %w", err) 249 | } 250 | 251 | responseChan := make(chan *JSONRPCResponse, 1) 252 | c.mu.Lock() 253 | c.responses[request.ID] = responseChan 254 | c.mu.Unlock() 255 | 256 | req, err := http.NewRequestWithContext( 257 | ctx, 258 | "POST", 259 | c.endpoint.String(), 260 | bytes.NewReader(requestBytes), 261 | ) 262 | if err != nil { 263 | return nil, fmt.Errorf("failed to create request: %w", err) 264 | } 265 | 266 | req.Header.Set("Content-Type", "application/json") 267 | // set custom http headers 268 | for k, v := range c.headers { 269 | req.Header.Set(k, v) 270 | } 271 | 272 | resp, err := c.httpClient.Do(req) 273 | if err != nil { 274 | return nil, fmt.Errorf("failed to send request: %w", err) 275 | } 276 | defer resp.Body.Close() 277 | 278 | if resp.StatusCode != http.StatusOK && 279 | resp.StatusCode != http.StatusAccepted { 280 | body, _ := io.ReadAll(resp.Body) 281 | return nil, fmt.Errorf( 282 | "request failed with status %d: %s", 283 | resp.StatusCode, 284 | body, 285 | ) 286 | } 287 | 288 | select { 289 | case <-ctx.Done(): 290 | c.mu.Lock() 291 | delete(c.responses, request.ID) 292 | c.mu.Unlock() 293 | return nil, ctx.Err() 294 | case response := <-responseChan: 295 | return response, nil 296 | } 297 | } 298 | 299 | // Close shuts down the SSE client connection and cleans up any pending responses. 300 | // Returns an error if the shutdown process fails. 
301 | func (c *SSE) Close() error { 302 | if !c.closed.CompareAndSwap(false, true) { 303 | return nil // Already closed 304 | } 305 | 306 | if c.cancelSSEStream != nil { 307 | // It could stop the sse stream body, to quit the readSSE loop immediately 308 | // Also, it could quit start() immediately if not receiving the endpoint 309 | c.cancelSSEStream() 310 | } 311 | 312 | // Clean up any pending responses 313 | c.mu.Lock() 314 | for _, ch := range c.responses { 315 | close(ch) 316 | } 317 | c.responses = make(map[int64]chan *JSONRPCResponse) 318 | c.mu.Unlock() 319 | 320 | return nil 321 | } 322 | 323 | // SendNotification sends a JSON-RPC notification to the server without expecting a response. 324 | func (c *SSE) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error { 325 | if c.endpoint == nil { 326 | return fmt.Errorf("endpoint not received") 327 | } 328 | 329 | notificationBytes, err := json.Marshal(notification) 330 | if err != nil { 331 | return fmt.Errorf("failed to marshal notification: %w", err) 332 | } 333 | 334 | req, err := http.NewRequestWithContext( 335 | ctx, 336 | "POST", 337 | c.endpoint.String(), 338 | bytes.NewReader(notificationBytes), 339 | ) 340 | if err != nil { 341 | return fmt.Errorf("failed to create notification request: %w", err) 342 | } 343 | 344 | req.Header.Set("Content-Type", "application/json") 345 | // Set custom HTTP headers 346 | for k, v := range c.headers { 347 | req.Header.Set(k, v) 348 | } 349 | 350 | resp, err := c.httpClient.Do(req) 351 | if err != nil { 352 | return fmt.Errorf("failed to send notification: %w", err) 353 | } 354 | defer resp.Body.Close() 355 | 356 | if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted { 357 | body, _ := io.ReadAll(resp.Body) 358 | return fmt.Errorf( 359 | "notification failed with status %d: %s", 360 | resp.StatusCode, 361 | body, 362 | ) 363 | } 364 | 365 | return nil 366 | } 367 | 368 | // GetEndpoint returns the current endpoint URL 
for the SSE connection. 369 | func (c *SSE) GetEndpoint() *url.URL { 370 | return c.endpoint 371 | } 372 | 373 | // GetBaseURL returns the base URL set in the SSE constructor. 374 | func (c *SSE) GetBaseURL() *url.URL { 375 | return c.baseURL 376 | } 377 | -------------------------------------------------------------------------------- /client/transport/sse_test.go: -------------------------------------------------------------------------------- 1 | package transport 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "errors" 7 | "sync" 8 | "testing" 9 | "time" 10 | 11 | "fmt" 12 | "net/http" 13 | "net/http/httptest" 14 | 15 | "github.com/mark3labs/mcp-go/mcp" 16 | ) 17 | 18 | // startMockSSEEchoServer starts a test HTTP server that implements 19 | // a minimal SSE-based echo server for testing purposes. 20 | // It returns the server URL and a function to close the server. 21 | func startMockSSEEchoServer() (string, func()) { 22 | // State shared by the SSE and message handlers: the open stream's writer and flush func, guarded by mu 23 | var sseWriter http.ResponseWriter 24 | var flush func() 25 | var mu sync.Mutex 26 | sseHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 27 | // On return, clear the shared writer and flush func under mu; the message handler's echo goroutine checks them for nil before writing 28 | defer func() { 29 | mu.Lock() // for passing race test 30 | sseWriter = nil 31 | flush = nil 32 | mu.Unlock() 33 | fmt.Printf("SSEHandler ends: %v\n", r.Context().Err()) 34 | }() 35 | 36 | w.Header().Set("Content-Type", "text/event-stream") 37 | flusher, ok := w.(http.Flusher) 38 | if !ok { 39 | http.Error(w, "Streaming unsupported", http.StatusInternalServerError) 40 | return 41 | } 42 | 43 | mu.Lock() 44 | sseWriter = w 45 | flush = flusher.Flush 46 | mu.Unlock() 47 | 48 | // Send initial endpoint event with message endpoint URL 49 | mu.Lock() 50 | fmt.Fprintf(w, "event: endpoint\ndata: %s\n\n", "/message") 51 | flusher.Flush() 52 | mu.Unlock() 53 | 54 | // Keep connection open 55 | <-r.Context().Done() 56 | }) 57 | 58 | // Create handler for message endpoint 59 | messageHandler := 
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 60 | // Handle only POST requests 61 | if r.Method != http.MethodPost { 62 | http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) 63 | return 64 | } 65 | 66 | // Parse incoming JSON-RPC request 67 | var request map[string]interface{} 68 | decoder := json.NewDecoder(r.Body) 69 | if err := decoder.Decode(&request); err != nil { 70 | http.Error(w, fmt.Sprintf("Invalid JSON: %v", err), http.StatusBadRequest) 71 | return 72 | } 73 | 74 | // Echo back the request as the response result 75 | response := map[string]interface{}{ 76 | "jsonrpc": "2.0", 77 | "id": request["id"], 78 | "result": request, 79 | } 80 | 81 | method := request["method"] 82 | switch method { 83 | case "debug/echo": 84 | response["result"] = request 85 | case "debug/echo_notification": 86 | response["result"] = request 87 | // send notification to client (NOTE(review): unlike the echo goroutine below, this writes without checking sseWriter/flush for nil) 88 | responseBytes, _ := json.Marshal(map[string]any{ 89 | "jsonrpc": "2.0", 90 | "method": "debug/test", 91 | "params": request, 92 | }) 93 | mu.Lock() 94 | fmt.Fprintf(sseWriter, "event: message\ndata: %s\n\n", responseBytes) 95 | flush() 96 | mu.Unlock() 97 | case "debug/echo_error_string": 98 | data, _ := json.Marshal(request) 99 | response["error"] = map[string]interface{}{ 100 | "code": -1, 101 | "message": string(data), 102 | } 103 | } 104 | 105 | // Acknowledge with 202 Accepted; the JSON-RPC response itself is sent on the SSE stream by the goroutine below 106 | w.Header().Set("Content-Type", "application/json") 107 | w.WriteHeader(http.StatusAccepted) 108 | 109 | go func() { 110 | data, _ := json.Marshal(response) 111 | mu.Lock() 112 | defer mu.Unlock() 113 | if sseWriter != nil && flush != nil { 114 | fmt.Fprintf(sseWriter, "event: message\ndata: %s\n\n", data) 115 | flush() 116 | } 117 | }() 118 | 119 | }) 120 | 121 | // Create a router to handle different endpoints 122 | mux := http.NewServeMux() 123 | mux.Handle("/", sseHandler) 124 | mux.Handle("/message", messageHandler) 125 | 126 | // Start test server 127 | testServer := httptest.NewServer(mux)
128 | 129 | return testServer.URL, testServer.Close 130 | } 131 | 132 | func TestSSE(t *testing.T) { 133 | // Start the in-process mock SSE echo server 134 | url, closeF := startMockSSEEchoServer() 135 | defer closeF() 136 | 137 | trans, err := NewSSE(url) 138 | if err != nil { 139 | t.Fatal(err) 140 | } 141 | 142 | // Start the transport 143 | ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) 144 | defer cancel() 145 | 146 | err = trans.Start(ctx) 147 | if err != nil { 148 | t.Fatalf("Failed to start transport: %v", err) 149 | } 150 | defer trans.Close() 151 | 152 | t.Run("SendRequest", func(t *testing.T) { 153 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 154 | defer cancel() 155 | 156 | params := map[string]interface{}{ 157 | "string": "hello world", 158 | "array": []interface{}{1, 2, 3}, 159 | } 160 | 161 | request := JSONRPCRequest{ 162 | JSONRPC: "2.0", 163 | ID: 1, 164 | Method: "debug/echo", 165 | Params: params, 166 | } 167 | 168 | // Send the request 169 | response, err := trans.SendRequest(ctx, request) 170 | if err != nil { 171 | t.Fatalf("SendRequest failed: %v", err) 172 | } 173 | 174 | // Parse the result to verify echo 175 | var result struct { 176 | JSONRPC string `json:"jsonrpc"` 177 | ID int64 `json:"id"` 178 | Method string `json:"method"` 179 | Params map[string]interface{} `json:"params"` 180 | } 181 | 182 | if err := json.Unmarshal(response.Result, &result); err != nil { 183 | t.Fatalf("Failed to unmarshal result: %v", err) 184 | } 185 | 186 | // Verify response data matches what was sent 187 | if result.JSONRPC != "2.0" { 188 | t.Errorf("Expected JSONRPC value '2.0', got '%s'", result.JSONRPC) 189 | } 190 | if result.ID != 1 { 191 | t.Errorf("Expected ID 1, got %d", result.ID) 192 | } 193 | if result.Method != "debug/echo" { 194 | t.Errorf("Expected method 'debug/echo', got '%s'", result.Method) 195 | } 196 | 197 | if str, ok := result.Params["string"].(string); !ok || str != "hello world" { 198 | 
t.Errorf("Expected string 'hello world', got %v", result.Params["string"]) 199 | } 200 | 201 | if arr, ok := result.Params["array"].([]interface{}); !ok || len(arr) != 3 { 202 | t.Errorf("Expected array with 3 items, got %v", result.Params["array"]) 203 | } 204 | }) 205 | 206 | t.Run("SendRequestWithTimeout", func(t *testing.T) { 207 | // Create a context that's already canceled 208 | ctx, cancel := context.WithCancel(context.Background()) 209 | cancel() // Cancel the context immediately 210 | 211 | // Prepare a request 212 | request := JSONRPCRequest{ 213 | JSONRPC: "2.0", 214 | ID: 3, 215 | Method: "debug/echo", 216 | } 217 | 218 | // The request should fail because the context is canceled 219 | _, err := trans.SendRequest(ctx, request) 220 | if err == nil { 221 | t.Errorf("Expected context canceled error, got nil") 222 | } else if !errors.Is(err, context.Canceled) { 223 | t.Errorf("Expected context.Canceled error, got: %v", err) 224 | } 225 | }) 226 | 227 | t.Run("SendNotification & NotificationHandler", func(t *testing.T) { 228 | 229 | var wg sync.WaitGroup 230 | notificationChan := make(chan mcp.JSONRPCNotification, 1) 231 | 232 | // Set notification handler 233 | trans.SetNotificationHandler(func(notification mcp.JSONRPCNotification) { 234 | notificationChan <- notification 235 | }) 236 | 237 | // Send a notification 238 | // The mock server answers debug/echo_notification by pushing a debug/test notification whose params are this request 239 | ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) 240 | defer cancel() 241 | 242 | notification := mcp.JSONRPCNotification{ 243 | JSONRPC: "2.0", 244 | Notification: mcp.Notification{ 245 | Method: "debug/echo_notification", 246 | Params: mcp.NotificationParams{ 247 | AdditionalFields: map[string]interface{}{"test": "value"}, 248 | }, 249 | }, 250 | } 251 | err := trans.SendNotification(ctx, notification) 252 | if err != nil { 253 | t.Fatalf("SendNotification failed: %v", err) 254 | } 255 | 256 | wg.Add(1) 257 | go func() { 258 | defer wg.Done() 259 | select {
260 | case nt := <-notificationChan: 261 | // The server's notification carries the original notification as params; compare them as JSON 262 | responseJson, _ := json.Marshal(nt.Params.AdditionalFields) 263 | requestJson, _ := json.Marshal(notification) 264 | if string(responseJson) != string(requestJson) { 265 | t.Errorf("Notification handler did not send the expected notification: \ngot %s\nexpect %s", responseJson, requestJson) 266 | } 267 | 268 | case <-time.After(1 * time.Second): 269 | t.Errorf("Expected notification, got none") 270 | } 271 | }() 272 | 273 | wg.Wait() 274 | }) 275 | 276 | t.Run("MultipleRequests", func(t *testing.T) { 277 | var wg sync.WaitGroup 278 | const numRequests = 5 279 | 280 | // Send multiple requests concurrently 281 | mu := sync.Mutex{} 282 | responses := make([]*JSONRPCResponse, numRequests) 283 | errors := make([]error, numRequests) // NOTE: shadows the errors package within this subtest 284 | 285 | for i := 0; i < numRequests; i++ { 286 | wg.Add(1) 287 | go func(idx int) { 288 | defer wg.Done() 289 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 290 | defer cancel() 291 | 292 | // Each request has a unique ID and payload 293 | request := JSONRPCRequest{ 294 | JSONRPC: "2.0", 295 | ID: int64(100 + idx), 296 | Method: "debug/echo", 297 | Params: map[string]interface{}{ 298 | "requestIndex": idx, 299 | "timestamp": time.Now().UnixNano(), 300 | }, 301 | } 302 | 303 | resp, err := trans.SendRequest(ctx, request) 304 | mu.Lock() 305 | responses[idx] = resp 306 | errors[idx] = err 307 | mu.Unlock() 308 | }(i) 309 | } 310 | 311 | wg.Wait() 312 | 313 | // Check results 314 | for i := 0; i < numRequests; i++ { 315 | if errors[i] != nil { 316 | t.Errorf("Request %d failed: %v", i, errors[i]) 317 | continue 318 | } 319 | 320 | if responses[i] == nil || responses[i].ID == nil || *responses[i].ID != int64(100+i) { 321 | t.Errorf("Request %d: Expected ID %d, got %v", i, 100+i, responses[i]) 322 | continue 323 | } 324 | 325 | // Parse the result to verify echo 326 | var result struct { 327 | JSONRPC string `json:"jsonrpc"` 328 | ID 
int64 `json:"id"` 329 | Method string `json:"method"` 330 | Params map[string]interface{} `json:"params"` 331 | } 332 | 333 | if err := json.Unmarshal(responses[i].Result, &result); err != nil { 334 | t.Errorf("Request %d: Failed to unmarshal result: %v", i, err) 335 | continue 336 | } 337 | 338 | // Verify data matches what was sent 339 | if result.ID != int64(100+i) { 340 | t.Errorf("Request %d: Expected echoed ID %d, got %d", i, 100+i, result.ID) 341 | } 342 | 343 | if result.Method != "debug/echo" { 344 | t.Errorf("Request %d: Expected method 'debug/echo', got '%s'", i, result.Method) 345 | } 346 | 347 | // Verify the requestIndex parameter 348 | if idx, ok := result.Params["requestIndex"].(float64); !ok || int(idx) != i { 349 | t.Errorf("Request %d: Expected requestIndex %d, got %v", i, i, result.Params["requestIndex"]) 350 | } 351 | } 352 | }) 353 | 354 | t.Run("ResponseError", func(t *testing.T) { 355 | 356 | // Prepare a request 357 | request := JSONRPCRequest{ 358 | JSONRPC: "2.0", 359 | ID: 100, 360 | Method: "debug/echo_error_string", 361 | } 362 | 363 | // The mock server replies with a JSON-RPC error whose message is the JSON-encoded request (this uses the parent test's ctx). NOTE(review): t.Errorf does not stop the test, so a nil reps or reps.Error below would panic 364 | reps, err := trans.SendRequest(ctx, request) 365 | if err != nil { 366 | t.Errorf("SendRequest failed: %v", err) 367 | } 368 | 369 | if reps.Error == nil { 370 | t.Errorf("Expected error, got nil") 371 | } 372 | 373 | var responseError JSONRPCRequest 374 | if err := json.Unmarshal([]byte(reps.Error.Message), &responseError); err != nil { 375 | t.Errorf("Failed to unmarshal result: %v", err) 376 | } 377 | 378 | if responseError.Method != "debug/echo_error_string" { 379 | t.Errorf("Expected method 'debug/echo_error_string', got '%s'", responseError.Method) 380 | } 381 | if responseError.ID != 100 { 382 | t.Errorf("Expected ID 100, got %d", responseError.ID) 383 | } 384 | if responseError.JSONRPC != "2.0" { 385 | t.Errorf("Expected JSONRPC '2.0', got '%s'", responseError.JSONRPC) 386 | } 387 | }) 388 | 389 | } 390 | 391 | func TestSSEErrors(t
*testing.T) { 392 | t.Run("InvalidURL", func(t *testing.T) { 393 | // Create a new SSE transport with an invalid URL 394 | _, err := NewSSE("://invalid-url") 395 | if err == nil { 396 | t.Errorf("Expected error when creating with invalid URL, got nil") 397 | } 398 | }) 399 | 400 | t.Run("NonExistentURL", func(t *testing.T) { 401 | // Create a new SSE transport pointing at localhost:1, assumed to have no listener so that Start fails 402 | sse, err := NewSSE("http://localhost:1") 403 | if err != nil { 404 | t.Fatalf("Failed to create SSE transport: %v", err) 405 | } 406 | 407 | // Start should fail 408 | ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) 409 | defer cancel() 410 | 411 | err = sse.Start(ctx) 412 | if err == nil { 413 | t.Errorf("Expected error when starting with non-existent URL, got nil") 414 | sse.Close() 415 | } 416 | }) 417 | 418 | t.Run("RequestBeforeStart", func(t *testing.T) { 419 | url, closeF := startMockSSEEchoServer() 420 | defer closeF() 421 | 422 | // Create a new SSE instance without calling Start method 423 | sse, err := NewSSE(url) 424 | if err != nil { 425 | t.Fatalf("Failed to create SSE transport: %v", err) 426 | } 427 | 428 | // Prepare a request 429 | request := JSONRPCRequest{ 430 | JSONRPC: "2.0", 431 | ID: 99, 432 | Method: "ping", 433 | } 434 | 435 | ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) 436 | defer cancel() 437 | 438 | _, err = sse.SendRequest(ctx, request) 439 | if err == nil { 440 | t.Errorf("Expected SendRequest to fail before Start(), but it didn't") 441 | } 442 | }) 443 | 444 | t.Run("RequestAfterClose", func(t *testing.T) { 445 | // Start a mock server 446 | url, closeF := startMockSSEEchoServer() 447 | defer closeF() 448 | 449 | // Create a new SSE transport 450 | sse, err := NewSSE(url) 451 | if err != nil { 452 | t.Fatalf("Failed to create SSE transport: %v", err) 453 | } 454 | 455 | // Start the transport 456 | ctx := context.Background() 457 | if err := sse.Start(ctx); err != nil { 458 | 
t.Fatalf("Failed to start SSE transport: %v", err) 459 | } 460 | 461 | // Close the transport 462 | sse.Close() 463 | 464 | // Wait a bit to ensure connection has closed 465 | time.Sleep(100 * time.Millisecond) 466 | 467 | // Try to send a request after close 468 | request := JSONRPCRequest{ 469 | JSONRPC: "2.0", 470 | ID: 1, 471 | Method: "ping", 472 | } 473 | 474 | _, err = sse.SendRequest(ctx, request) 475 | if err == nil { 476 | t.Errorf("Expected error when sending request after close, got nil") 477 | } 478 | }) 479 | 480 | } 481 | -------------------------------------------------------------------------------- /client/transport/stdio.go: -------------------------------------------------------------------------------- 1 | package transport 2 | 3 | import ( 4 | "bufio" 5 | "context" 6 | "encoding/json" 7 | "fmt" 8 | "io" 9 | "os" 10 | "os/exec" 11 | "sync" 12 | 13 | "github.com/mark3labs/mcp-go/mcp" 14 | ) 15 | 16 | // Stdio implements the transport layer of the MCP protocol using stdio communication. 17 | // It launches a subprocess and communicates with it via standard input/output streams 18 | // using JSON-RPC messages. The client handles message routing between requests and 19 | // responses, and supports asynchronous notifications. 20 | type Stdio struct { 21 | command string 22 | args []string 23 | env []string 24 | 25 | cmd *exec.Cmd 26 | stdin io.WriteCloser 27 | stdout *bufio.Reader 28 | stderr io.ReadCloser // exposed via Stderr; nothing in this file reads it 29 | responses map[int64]chan *JSONRPCResponse // pending SendRequest calls keyed by request ID; guarded by mu 30 | mu sync.RWMutex 31 | done chan struct{} // closed by Close; readResponses checks it between lines 32 | onNotification func(mcp.JSONRPCNotification) // guarded by notifyMu 33 | notifyMu sync.RWMutex 34 | } 35 | 36 | // NewStdio creates a new stdio transport for the given command, extra environment entries and arguments. 37 | // It does not launch anything: Start runs the command with env appended to the current process environment and creates the stdin/stdout/stderr pipes. 38 | // NewStdio cannot fail; errors from launching the subprocess are reported by Start.
39 | func NewStdio( 40 | command string, 41 | env []string, 42 | args ...string, 43 | ) *Stdio { 44 | 45 | client := &Stdio{ 46 | command: command, 47 | args: args, 48 | env: env, 49 | 50 | responses: make(map[int64]chan *JSONRPCResponse), 51 | done: make(chan struct{}), 52 | } 53 | 54 | return client 55 | } 56 | 57 | func (c *Stdio) Start(ctx context.Context) error { 58 | cmd := exec.CommandContext(ctx, c.command, c.args...) 59 | 60 | mergedEnv := os.Environ() 61 | mergedEnv = append(mergedEnv, c.env...) 62 | 63 | cmd.Env = mergedEnv 64 | 65 | stdin, err := cmd.StdinPipe() 66 | if err != nil { 67 | return fmt.Errorf("failed to create stdin pipe: %w", err) 68 | } 69 | 70 | stdout, err := cmd.StdoutPipe() 71 | if err != nil { 72 | return fmt.Errorf("failed to create stdout pipe: %w", err) 73 | } 74 | 75 | stderr, err := cmd.StderrPipe() 76 | if err != nil { 77 | return fmt.Errorf("failed to create stderr pipe: %w", err) 78 | } 79 | 80 | c.cmd = cmd 81 | c.stdin = stdin 82 | c.stderr = stderr 83 | c.stdout = bufio.NewReader(stdout) 84 | 85 | if err := cmd.Start(); err != nil { 86 | return fmt.Errorf("failed to start command: %w", err) 87 | } 88 | 89 | // Start reading responses in a goroutine and wait for it to be ready 90 | ready := make(chan struct{}) 91 | go func() { 92 | close(ready) 93 | c.readResponses() 94 | }() 95 | <-ready 96 | 97 | return nil 98 | } 99 | 100 | // Close shuts down the stdio client, closing the stdin pipe and waiting for the subprocess to exit. 101 | // Returns an error if there are issues closing stdin or waiting for the subprocess to terminate. 
102 | func (c *Stdio) Close() error { 103 | close(c.done) 104 | if err := c.stdin.Close(); err != nil { 105 | return fmt.Errorf("failed to close stdin: %w", err) 106 | } 107 | if err := c.stderr.Close(); err != nil { 108 | return fmt.Errorf("failed to close stderr: %w", err) 109 | } 110 | return c.cmd.Wait() 111 | } 112 | 113 | // OnNotification registers a handler function to be called when notifications are received. 114 | // Multiple handlers can be registered and will be called in the order they were added. 115 | func (c *Stdio) SetNotificationHandler( 116 | handler func(notification mcp.JSONRPCNotification), 117 | ) { 118 | c.notifyMu.Lock() 119 | defer c.notifyMu.Unlock() 120 | c.onNotification = handler 121 | } 122 | 123 | // readResponses continuously reads and processes responses from the server's stdout. 124 | // It handles both responses to requests and notifications, routing them appropriately. 125 | // Runs until the done channel is closed or an error occurs reading from stdout. 
126 | func (c *Stdio) readResponses() { 127 | for { 128 | select { 129 | case <-c.done: 130 | return 131 | default: 132 | line, err := c.stdout.ReadString('\n') 133 | if err != nil { 134 | if err != io.EOF { 135 | fmt.Printf("Error reading response: %v\n", err) 136 | } 137 | return 138 | } 139 | 140 | var baseMessage JSONRPCResponse 141 | if err := json.Unmarshal([]byte(line), &baseMessage); err != nil { 142 | continue 143 | } 144 | 145 | // Handle notification 146 | if baseMessage.ID == nil { 147 | var notification mcp.JSONRPCNotification 148 | if err := json.Unmarshal([]byte(line), ¬ification); err != nil { 149 | continue 150 | } 151 | c.notifyMu.RLock() 152 | if c.onNotification != nil { 153 | c.onNotification(notification) 154 | } 155 | c.notifyMu.RUnlock() 156 | continue 157 | } 158 | 159 | c.mu.RLock() 160 | ch, ok := c.responses[*baseMessage.ID] 161 | c.mu.RUnlock() 162 | 163 | if ok { 164 | ch <- &baseMessage 165 | c.mu.Lock() 166 | delete(c.responses, *baseMessage.ID) 167 | c.mu.Unlock() 168 | } 169 | } 170 | } 171 | } 172 | 173 | // sendRequest sends a JSON-RPC request to the server and waits for a response. 174 | // It creates a unique request ID, sends the request over stdin, and waits for 175 | // the corresponding response or context cancellation. 176 | // Returns the raw JSON response message or an error if the request fails. 
177 | func (c *Stdio) SendRequest( 178 | ctx context.Context, 179 | request JSONRPCRequest, 180 | ) (*JSONRPCResponse, error) { 181 | if c.stdin == nil { 182 | return nil, fmt.Errorf("stdio client not started") 183 | } 184 | 185 | // Create the complete request structure 186 | responseChan := make(chan *JSONRPCResponse, 1) 187 | c.mu.Lock() 188 | c.responses[request.ID] = responseChan 189 | c.mu.Unlock() 190 | 191 | requestBytes, err := json.Marshal(request) 192 | if err != nil { 193 | return nil, fmt.Errorf("failed to marshal request: %w", err) 194 | } 195 | requestBytes = append(requestBytes, '\n') 196 | 197 | if _, err := c.stdin.Write(requestBytes); err != nil { 198 | return nil, fmt.Errorf("failed to write request: %w", err) 199 | } 200 | 201 | select { 202 | case <-ctx.Done(): 203 | c.mu.Lock() 204 | delete(c.responses, request.ID) 205 | c.mu.Unlock() 206 | return nil, ctx.Err() 207 | case response := <-responseChan: 208 | return response, nil 209 | } 210 | } 211 | 212 | // SendNotification sends a json RPC Notification to the server. 213 | func (c *Stdio) SendNotification( 214 | ctx context.Context, 215 | notification mcp.JSONRPCNotification, 216 | ) error { 217 | if c.stdin == nil { 218 | return fmt.Errorf("stdio client not started") 219 | } 220 | 221 | notificationBytes, err := json.Marshal(notification) 222 | if err != nil { 223 | return fmt.Errorf("failed to marshal notification: %w", err) 224 | } 225 | notificationBytes = append(notificationBytes, '\n') 226 | 227 | if _, err := c.stdin.Write(notificationBytes); err != nil { 228 | return fmt.Errorf("failed to write notification: %w", err) 229 | } 230 | 231 | return nil 232 | } 233 | 234 | // Stderr returns a reader for the stderr output of the subprocess. 235 | // This can be used to capture error messages or logs from the subprocess. 
236 | func (c *Stdio) Stderr() io.Reader { 237 | return c.stderr 238 | } 239 | -------------------------------------------------------------------------------- /client/transport/stdio_test.go: -------------------------------------------------------------------------------- 1 | package transport 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | "os" 8 | "os/exec" 9 | "path/filepath" 10 | "runtime" 11 | "sync" 12 | "testing" 13 | "time" 14 | 15 | "github.com/mark3labs/mcp-go/mcp" 16 | ) 17 | 18 | func compileTestServer(outputPath string) error { // builds ../../testdata/mockstdio_server.go into outputPath with the go toolchain 19 | cmd := exec.Command( 20 | "go", 21 | "build", 22 | "-o", 23 | outputPath, 24 | "../../testdata/mockstdio_server.go", 25 | ) 26 | if output, err := cmd.CombinedOutput(); err != nil { 27 | return fmt.Errorf("compilation failed: %v\nOutput: %s", err, output) 28 | } 29 | return nil 30 | } 31 | 32 | func TestStdio(t *testing.T) { 33 | // Compile mock server 34 | mockServerPath := filepath.Join(os.TempDir(), "mockstdio_server") 35 | // Add .exe suffix on Windows 36 | if runtime.GOOS == "windows" { 37 | mockServerPath += ".exe" 38 | } 39 | if err := compileTestServer(mockServerPath); err != nil { 40 | t.Fatalf("Failed to compile mock server: %v", err) 41 | } 42 | defer os.Remove(mockServerPath) 43 | 44 | // Create a new Stdio transport 45 | stdio := NewStdio(mockServerPath, nil) 46 | 47 | // Start the transport (Start uses exec.CommandContext, so the subprocess is killed when this 5s ctx expires) 48 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 49 | defer cancel() 50 | 51 | err := stdio.Start(ctx) 52 | if err != nil { 53 | t.Fatalf("Failed to start Stdio transport: %v", err) 54 | } 55 | defer stdio.Close() 56 | 57 | t.Run("SendRequest", func(t *testing.T) { 58 | ctx, cancel := context.WithTimeout(context.Background(), 5000000000*time.Second) // NOTE(review): about 158 years, i.e. effectively no deadline; likely meant 5*time.Second 59 | defer cancel() 60 | 61 | params := map[string]interface{}{ 62 | "string": "hello world", 63 | "array": []interface{}{1, 2, 3}, 64 | } 65 | 66 | request := JSONRPCRequest{ 67 | JSONRPC: "2.0", 68 | ID: 1, 69 | Method: 
"debug/echo", 70 | Params: params, 71 | } 72 | 73 | // Send the request 74 | response, err := stdio.SendRequest(ctx, request) 75 | if err != nil { 76 | t.Fatalf("SendRequest failed: %v", err) 77 | } 78 | 79 | // Parse the result to verify echo 80 | var result struct { 81 | JSONRPC string `json:"jsonrpc"` 82 | ID int64 `json:"id"` 83 | Method string `json:"method"` 84 | Params map[string]interface{} `json:"params"` 85 | } 86 | 87 | if err := json.Unmarshal(response.Result, &result); err != nil { 88 | t.Fatalf("Failed to unmarshal result: %v", err) 89 | } 90 | 91 | // Verify response data matches what was sent 92 | if result.JSONRPC != "2.0" { 93 | t.Errorf("Expected JSONRPC value '2.0', got '%s'", result.JSONRPC) 94 | } 95 | if result.ID != 1 { 96 | t.Errorf("Expected ID 1, got %d", result.ID) 97 | } 98 | if result.Method != "debug/echo" { 99 | t.Errorf("Expected method 'debug/echo', got '%s'", result.Method) 100 | } 101 | 102 | if str, ok := result.Params["string"].(string); !ok || str != "hello world" { 103 | t.Errorf("Expected string 'hello world', got %v", result.Params["string"]) 104 | } 105 | 106 | if arr, ok := result.Params["array"].([]interface{}); !ok || len(arr) != 3 { 107 | t.Errorf("Expected array with 3 items, got %v", result.Params["array"]) 108 | } 109 | }) 110 | 111 | t.Run("SendRequestWithTimeout", func(t *testing.T) { 112 | // Create a context that's already canceled 113 | ctx, cancel := context.WithCancel(context.Background()) 114 | cancel() // Cancel the context immediately 115 | 116 | // Prepare a request 117 | request := JSONRPCRequest{ 118 | JSONRPC: "2.0", 119 | ID: 3, 120 | Method: "debug/echo", 121 | } 122 | 123 | // The request should fail because the context is canceled 124 | _, err := stdio.SendRequest(ctx, request) 125 | if err == nil { 126 | t.Errorf("Expected context canceled error, got nil") 127 | } else if err != context.Canceled { 128 | t.Errorf("Expected context.Canceled error, got: %v", err) 129 | } 130 | }) 131 | 132 | 
t.Run("SendNotification & NotificationHandler", func(t *testing.T) { 133 | 134 | var wg sync.WaitGroup 135 | notificationChan := make(chan mcp.JSONRPCNotification, 1) 136 | 137 | // Set notification handler 138 | stdio.SetNotificationHandler(func(notification mcp.JSONRPCNotification) { 139 | notificationChan <- notification 140 | }) 141 | 142 | // Send a notification 143 | // This would trigger a notification from the server 144 | ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) 145 | defer cancel() 146 | 147 | notification := mcp.JSONRPCNotification{ 148 | JSONRPC: "2.0", 149 | Notification: mcp.Notification{ 150 | Method: "debug/echo_notification", 151 | Params: mcp.NotificationParams{ 152 | AdditionalFields: map[string]interface{}{"test": "value"}, 153 | }, 154 | }, 155 | } 156 | err := stdio.SendNotification(ctx, notification) 157 | if err != nil { 158 | t.Fatalf("SendNotification failed: %v", err) 159 | } 160 | 161 | wg.Add(1) 162 | go func() { 163 | defer wg.Done() 164 | select { 165 | case nt := <-notificationChan: 166 | // We received a notification 167 | responseJson, _ := json.Marshal(nt.Params.AdditionalFields) 168 | requestJson, _ := json.Marshal(notification) 169 | if string(responseJson) != string(requestJson) { 170 | t.Errorf("Notification handler did not send the expected notification: \ngot %s\nexpect %s", responseJson, requestJson) 171 | } 172 | 173 | case <-time.After(1 * time.Second): 174 | t.Errorf("Expected notification, got none") 175 | } 176 | }() 177 | 178 | wg.Wait() 179 | }) 180 | 181 | t.Run("MultipleRequests", func(t *testing.T) { 182 | var wg sync.WaitGroup 183 | const numRequests = 5 184 | 185 | // Send multiple requests concurrently 186 | responses := make([]*JSONRPCResponse, numRequests) 187 | errors := make([]error, numRequests) 188 | mu := sync.Mutex{} 189 | for i := 0; i < numRequests; i++ { 190 | wg.Add(1) 191 | go func(idx int) { 192 | defer wg.Done() 193 | ctx, cancel := 
context.WithTimeout(context.Background(), 5*time.Second) 194 | defer cancel() 195 | 196 | // Each request has a unique ID and payload 197 | request := JSONRPCRequest{ 198 | JSONRPC: "2.0", 199 | ID: int64(100 + idx), 200 | Method: "debug/echo", 201 | Params: map[string]interface{}{ 202 | "requestIndex": idx, 203 | "timestamp": time.Now().UnixNano(), 204 | }, 205 | } 206 | 207 | resp, err := stdio.SendRequest(ctx, request) 208 | mu.Lock() 209 | responses[idx] = resp 210 | errors[idx] = err 211 | mu.Unlock() 212 | }(i) 213 | } 214 | 215 | wg.Wait() 216 | 217 | // Check results 218 | for i := 0; i < numRequests; i++ { 219 | if errors[i] != nil { 220 | t.Errorf("Request %d failed: %v", i, errors[i]) 221 | continue 222 | } 223 | 224 | if responses[i] == nil || responses[i].ID == nil || *responses[i].ID != int64(100+i) { 225 | t.Errorf("Request %d: Expected ID %d, got %v", i, 100+i, responses[i]) 226 | continue 227 | } 228 | 229 | // Parse the result to verify echo 230 | var result struct { 231 | JSONRPC string `json:"jsonrpc"` 232 | ID int64 `json:"id"` 233 | Method string `json:"method"` 234 | Params map[string]interface{} `json:"params"` 235 | } 236 | 237 | if err := json.Unmarshal(responses[i].Result, &result); err != nil { 238 | t.Errorf("Request %d: Failed to unmarshal result: %v", i, err) 239 | continue 240 | } 241 | 242 | // Verify data matches what was sent 243 | if result.ID != int64(100+i) { 244 | t.Errorf("Request %d: Expected echoed ID %d, got %d", i, 100+i, result.ID) 245 | } 246 | 247 | if result.Method != "debug/echo" { 248 | t.Errorf("Request %d: Expected method 'debug/echo', got '%s'", i, result.Method) 249 | } 250 | 251 | // Verify the requestIndex parameter 252 | if idx, ok := result.Params["requestIndex"].(float64); !ok || int(idx) != i { 253 | t.Errorf("Request %d: Expected requestIndex %d, got %v", i, i, result.Params["requestIndex"]) 254 | } 255 | } 256 | }) 257 | 258 | t.Run("ResponseError", func(t *testing.T) { 259 | 260 | // Prepare a request
261 | request := JSONRPCRequest{ 262 | JSONRPC: "2.0", 263 | ID: 100, 264 | Method: "debug/echo_error_string", 265 | } 266 | 267 | // The mock server is expected to reply with a JSON-RPC error whose message is the JSON-encoded request (this uses the parent test's ctx). NOTE(review): t.Errorf does not stop the test, so a nil reps or reps.Error below would panic 268 | reps, err := stdio.SendRequest(ctx, request) 269 | if err != nil { 270 | t.Errorf("SendRequest failed: %v", err) 271 | } 272 | 273 | if reps.Error == nil { 274 | t.Errorf("Expected error, got nil") 275 | } 276 | 277 | var responseError JSONRPCRequest 278 | if err := json.Unmarshal([]byte(reps.Error.Message), &responseError); err != nil { 279 | t.Errorf("Failed to unmarshal result: %v", err) 280 | } 281 | 282 | if responseError.Method != "debug/echo_error_string" { 283 | t.Errorf("Expected method 'debug/echo_error_string', got '%s'", responseError.Method) 284 | } 285 | if responseError.ID != 100 { 286 | t.Errorf("Expected ID 100, got %d", responseError.ID) 287 | } 288 | if responseError.JSONRPC != "2.0" { 289 | t.Errorf("Expected JSONRPC '2.0', got '%s'", responseError.JSONRPC) 290 | } 291 | }) 292 | 293 | } 294 | 295 | func TestStdioErrors(t *testing.T) { 296 | t.Run("InvalidCommand", func(t *testing.T) { 297 | // Create a new Stdio transport with a non-existent command 298 | stdio := NewStdio("non_existent_command", nil) 299 | 300 | // Start should fail 301 | ctx := context.Background() 302 | err := stdio.Start(ctx) 303 | if err == nil { 304 | t.Errorf("Expected error when starting with invalid command, got nil") 305 | stdio.Close() 306 | } 307 | }) 308 | 309 | t.Run("RequestBeforeStart", func(t *testing.T) { 310 | mockServerPath := filepath.Join(os.TempDir(), "mockstdio_server") 311 | // Add .exe suffix on Windows 312 | if runtime.GOOS == "windows" { 313 | mockServerPath += ".exe" 314 | } 315 | if err := compileTestServer(mockServerPath); err != nil { 316 | t.Fatalf("Failed to compile mock server: %v", err) 317 | } 318 | defer os.Remove(mockServerPath) 319 | 320 | uninitiatedStdio := NewStdio(mockServerPath, nil) // never started, so the compiled binary is not executed in this subtest 321 | 322 | // Prepare a request 323 | request := 
JSONRPCRequest{ 324 | JSONRPC: "2.0", 325 | ID: 99, 326 | Method: "ping", 327 | } 328 | 329 | ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) 330 | defer cancel() 331 | _, err := uninitiatedStdio.SendRequest(ctx, request) 332 | if err == nil { 333 | t.Errorf("Expected SendRequest to panic before Start(), but it didn't") 334 | } else if err.Error() != "stdio client not started" { 335 | t.Errorf("Expected error 'stdio client not started', got: %v", err) 336 | } 337 | }) 338 | 339 | t.Run("RequestAfterClose", func(t *testing.T) { 340 | // Compile mock server 341 | mockServerPath := filepath.Join(os.TempDir(), "mockstdio_server") 342 | // Add .exe suffix on Windows 343 | if runtime.GOOS == "windows" { 344 | mockServerPath += ".exe" 345 | } 346 | if err := compileTestServer(mockServerPath); err != nil { 347 | t.Fatalf("Failed to compile mock server: %v", err) 348 | } 349 | defer os.Remove(mockServerPath) 350 | 351 | // Create a new Stdio transport 352 | stdio := NewStdio(mockServerPath, nil) 353 | 354 | // Start the transport 355 | ctx := context.Background() 356 | if err := stdio.Start(ctx); err != nil { 357 | t.Fatalf("Failed to start Stdio transport: %v", err) 358 | } 359 | 360 | // Close the transport - ignore errors like "broken pipe" since the process might exit already 361 | stdio.Close() 362 | 363 | // Wait a bit to ensure process has exited 364 | time.Sleep(100 * time.Millisecond) 365 | 366 | // Try to send a request after close 367 | request := JSONRPCRequest{ 368 | JSONRPC: "2.0", 369 | ID: 1, 370 | Method: "ping", 371 | } 372 | 373 | _, err := stdio.SendRequest(ctx, request) 374 | if err == nil { 375 | t.Errorf("Expected error when sending request after close, got nil") 376 | } 377 | }) 378 | 379 | } 380 | -------------------------------------------------------------------------------- /client/transport/streamable_http.go: -------------------------------------------------------------------------------- 1 | package transport 2 |
3 | import ( 4 | "bufio" 5 | "bytes" 6 | "context" 7 | "encoding/json" 8 | "fmt" 9 | "io" 10 | "net/http" 11 | "net/url" 12 | "strings" 13 | "sync" 14 | "sync/atomic" 15 | "time" 16 | 17 | "github.com/mark3labs/mcp-go/mcp" 18 | ) 19 | 20 | type StreamableHTTPCOption func(*StreamableHTTP) 21 | 22 | func WithHTTPHeaders(headers map[string]string) StreamableHTTPCOption { 23 | return func(sc *StreamableHTTP) { 24 | sc.headers = headers 25 | } 26 | } 27 | 28 | // WithHTTPTimeout sets the timeout for a HTTP request and stream. 29 | func WithHTTPTimeout(timeout time.Duration) StreamableHTTPCOption { 30 | return func(sc *StreamableHTTP) { 31 | sc.httpClient.Timeout = timeout 32 | } 33 | } 34 | 35 | // StreamableHTTP implements Streamable HTTP transport. 36 | // 37 | // It transmits JSON-RPC messages over individual HTTP requests. One message per request. 38 | // The HTTP response body can either be a single JSON-RPC response, 39 | // or an upgraded SSE stream that concludes with a JSON-RPC response for the same request. 40 | // 41 | // https://modelcontextprotocol.io/specification/2025-03-26/basic/transports 42 | // 43 | // The current implementation does not support the following features: 44 | // - batching 45 | // - continuously listening for server notifications when no request is in flight 46 | // (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server) 47 | // - resuming stream 48 | // (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#resumability-and-redelivery) 49 | // - server -> client request 50 | type StreamableHTTP struct { 51 | baseURL *url.URL 52 | httpClient *http.Client 53 | headers map[string]string 54 | 55 | sessionID atomic.Value // string 56 | 57 | notificationHandler func(mcp.JSONRPCNotification) 58 | notifyMu sync.RWMutex 59 | 60 | closed chan struct{} 61 | } 62 | 63 | // NewStreamableHTTP creates a new Streamable HTTP transport with the given base URL. 
64 | // Returns an error if the URL is invalid. 65 | func NewStreamableHTTP(baseURL string, options ...StreamableHTTPCOption) (*StreamableHTTP, error) { 66 | parsedURL, err := url.Parse(baseURL) 67 | if err != nil { 68 | return nil, fmt.Errorf("invalid URL: %w", err) 69 | } 70 | 71 | smc := &StreamableHTTP{ 72 | baseURL: parsedURL, 73 | httpClient: &http.Client{}, 74 | headers: make(map[string]string), 75 | closed: make(chan struct{}), 76 | } 77 | smc.sessionID.Store("") // set initial value to simplify later usage 78 | 79 | for _, opt := range options { 80 | opt(smc) 81 | } 82 | 83 | return smc, nil 84 | } 85 | 86 | // Start initiates the HTTP connection to the server. 87 | func (c *StreamableHTTP) Start(ctx context.Context) error { 88 | // For Streamable HTTP, we don't need to establish a persistent connection 89 | return nil 90 | } 91 | 92 | // Close closes the all the HTTP connections to the server. 93 | func (c *StreamableHTTP) Close() error { 94 | select { 95 | case <-c.closed: 96 | return nil 97 | default: 98 | } 99 | // Cancel all in-flight requests 100 | close(c.closed) 101 | 102 | sessionId := c.sessionID.Load().(string) 103 | if sessionId != "" { 104 | c.sessionID.Store("") 105 | 106 | // notify server session closed 107 | go func() { 108 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 109 | defer cancel() 110 | req, err := http.NewRequestWithContext(ctx, http.MethodDelete, c.baseURL.String(), nil) 111 | if err != nil { 112 | fmt.Printf("failed to create close request\n: %v", err) 113 | return 114 | } 115 | req.Header.Set(headerKeySessionID, sessionId) 116 | res, err := c.httpClient.Do(req) 117 | if err != nil { 118 | fmt.Printf("failed to send close request\n: %v", err) 119 | return 120 | } 121 | res.Body.Close() 122 | }() 123 | } 124 | 125 | return nil 126 | } 127 | 128 | const ( 129 | initializeMethod = "initialize" 130 | headerKeySessionID = "Mcp-Session-Id" 131 | ) 132 | 133 | // sendRequest sends a JSON-RPC request to the 
server and waits for a response. 134 | // Returns the raw JSON response message or an error if the request fails. 135 | func (c *StreamableHTTP) SendRequest( 136 | ctx context.Context, 137 | request JSONRPCRequest, 138 | ) (*JSONRPCResponse, error) { 139 | 140 | // Create a combined context that could be canceled when the client is closed 141 | newCtx, cancel := context.WithCancel(ctx) 142 | defer cancel() 143 | go func() { 144 | select { 145 | case <-c.closed: 146 | cancel() 147 | case <-newCtx.Done(): 148 | // The original context was canceled, no need to do anything 149 | } 150 | }() 151 | ctx = newCtx 152 | 153 | // Marshal request 154 | requestBody, err := json.Marshal(request) 155 | if err != nil { 156 | return nil, fmt.Errorf("failed to marshal request: %w", err) 157 | } 158 | 159 | // Create HTTP request 160 | req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL.String(), bytes.NewReader(requestBody)) 161 | if err != nil { 162 | return nil, fmt.Errorf("failed to create request: %w", err) 163 | } 164 | 165 | // Set headers 166 | req.Header.Set("Content-Type", "application/json") 167 | req.Header.Set("Accept", "application/json, text/event-stream") 168 | sessionID := c.sessionID.Load() 169 | if sessionID != "" { 170 | req.Header.Set(headerKeySessionID, sessionID.(string)) 171 | } 172 | for k, v := range c.headers { 173 | req.Header.Set(k, v) 174 | } 175 | 176 | // Send request 177 | resp, err := c.httpClient.Do(req) 178 | if err != nil { 179 | return nil, fmt.Errorf("failed to send request: %w", err) 180 | } 181 | defer resp.Body.Close() 182 | 183 | // Check if we got an error response 184 | if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted { 185 | // handle session closed 186 | if resp.StatusCode == http.StatusNotFound { 187 | c.sessionID.CompareAndSwap(sessionID, "") 188 | return nil, fmt.Errorf("session terminated (404). 
need to re-initialize") 189 | } 190 | 191 | // handle error response 192 | var errResponse JSONRPCResponse 193 | body, _ := io.ReadAll(resp.Body) 194 | if err := json.Unmarshal(body, &errResponse); err == nil { 195 | return &errResponse, nil 196 | } 197 | return nil, fmt.Errorf("request failed with status %d: %s", resp.StatusCode, body) 198 | } 199 | 200 | if request.Method == initializeMethod { 201 | // saved the received session ID in the response 202 | // empty session ID is allowed 203 | if sessionID := resp.Header.Get(headerKeySessionID); sessionID != "" { 204 | c.sessionID.Store(sessionID) 205 | } 206 | } 207 | 208 | // Handle different response types 209 | switch resp.Header.Get("Content-Type") { 210 | case "application/json": 211 | // Single response 212 | var response JSONRPCResponse 213 | if err := json.NewDecoder(resp.Body).Decode(&response); err != nil { 214 | return nil, fmt.Errorf("failed to decode response: %w", err) 215 | } 216 | 217 | // should not be a notification 218 | if response.ID == nil { 219 | return nil, fmt.Errorf("response should contain RPC id: %v", response) 220 | } 221 | 222 | return &response, nil 223 | 224 | case "text/event-stream": 225 | // Server is using SSE for streaming responses 226 | return c.handleSSEResponse(ctx, resp.Body) 227 | 228 | default: 229 | return nil, fmt.Errorf("unexpected content type: %s", resp.Header.Get("Content-Type")) 230 | } 231 | } 232 | 233 | // handleSSEResponse processes an SSE stream for a specific request. 234 | // It returns the final result for the request once received, or an error. 
235 | func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCloser) (*JSONRPCResponse, error) { 236 | 237 | // Create a channel for this specific request 238 | responseChan := make(chan *JSONRPCResponse, 1) 239 | 240 | ctx, cancel := context.WithCancel(ctx) 241 | defer cancel() 242 | 243 | // Start a goroutine to process the SSE stream 244 | go func() { 245 | // only close responseChan after readingSSE() 246 | defer close(responseChan) 247 | 248 | c.readSSE(ctx, reader, func(event, data string) { 249 | 250 | // (unsupported: batching) 251 | 252 | var message JSONRPCResponse 253 | if err := json.Unmarshal([]byte(data), &message); err != nil { 254 | fmt.Printf("failed to unmarshal message: %v\n", err) 255 | return 256 | } 257 | 258 | // Handle notification 259 | if message.ID == nil { 260 | var notification mcp.JSONRPCNotification 261 | if err := json.Unmarshal([]byte(data), ¬ification); err != nil { 262 | fmt.Printf("failed to unmarshal notification: %v\n", err) 263 | return 264 | } 265 | c.notifyMu.RLock() 266 | if c.notificationHandler != nil { 267 | c.notificationHandler(notification) 268 | } 269 | c.notifyMu.RUnlock() 270 | return 271 | } 272 | 273 | responseChan <- &message 274 | }) 275 | }() 276 | 277 | // Wait for the response or context cancellation 278 | select { 279 | case response := <-responseChan: 280 | if response == nil { 281 | return nil, fmt.Errorf("unexpected nil response") 282 | } 283 | return response, nil 284 | case <-ctx.Done(): 285 | return nil, ctx.Err() 286 | } 287 | } 288 | 289 | // readSSE reads the SSE stream(reader) and calls the handler for each event and data pair. 290 | // It will end when the reader is closed (or the context is done). 
291 | func (c *StreamableHTTP) readSSE(ctx context.Context, reader io.ReadCloser, handler func(event, data string)) { 292 | defer reader.Close() 293 | 294 | br := bufio.NewReader(reader) 295 | var event, data string 296 | 297 | for { 298 | select { 299 | case <-ctx.Done(): 300 | return 301 | default: 302 | line, err := br.ReadString('\n') 303 | if err != nil { 304 | if err == io.EOF { 305 | // Process any pending event before exit 306 | if event != "" && data != "" { 307 | handler(event, data) 308 | } 309 | return 310 | } 311 | select { 312 | case <-ctx.Done(): 313 | return 314 | default: 315 | fmt.Printf("SSE stream error: %v\n", err) 316 | return 317 | } 318 | } 319 | 320 | // Remove only newline markers 321 | line = strings.TrimRight(line, "\r\n") 322 | if line == "" { 323 | // Empty line means end of event 324 | if event != "" && data != "" { 325 | handler(event, data) 326 | event = "" 327 | data = "" 328 | } 329 | continue 330 | } 331 | 332 | if strings.HasPrefix(line, "event:") { 333 | event = strings.TrimSpace(strings.TrimPrefix(line, "event:")) 334 | } else if strings.HasPrefix(line, "data:") { 335 | data = strings.TrimSpace(strings.TrimPrefix(line, "data:")) 336 | } 337 | } 338 | } 339 | } 340 | 341 | func (c *StreamableHTTP) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error { 342 | 343 | // Marshal request 344 | requestBody, err := json.Marshal(notification) 345 | if err != nil { 346 | return fmt.Errorf("failed to marshal notification: %w", err) 347 | } 348 | 349 | // Create HTTP request 350 | req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL.String(), bytes.NewReader(requestBody)) 351 | if err != nil { 352 | return fmt.Errorf("failed to create request: %w", err) 353 | } 354 | 355 | // Set headers 356 | req.Header.Set("Content-Type", "application/json") 357 | if sessionID := c.sessionID.Load(); sessionID != "" { 358 | req.Header.Set(headerKeySessionID, sessionID.(string)) 359 | } 360 | for k, v := range 
c.headers { 361 | req.Header.Set(k, v) 362 | } 363 | 364 | // Send request 365 | resp, err := c.httpClient.Do(req) 366 | if err != nil { 367 | return fmt.Errorf("failed to send request: %w", err) 368 | } 369 | defer resp.Body.Close() 370 | 371 | if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted { 372 | body, _ := io.ReadAll(resp.Body) 373 | return fmt.Errorf( 374 | "notification failed with status %d: %s", 375 | resp.StatusCode, 376 | body, 377 | ) 378 | } 379 | 380 | return nil 381 | } 382 | 383 | func (c *StreamableHTTP) SetNotificationHandler(handler func(mcp.JSONRPCNotification)) { 384 | c.notifyMu.Lock() 385 | defer c.notifyMu.Unlock() 386 | c.notificationHandler = handler 387 | } 388 | 389 | func (c *StreamableHTTP) GetSessionId() string { 390 | return c.sessionID.Load().(string) 391 | } 392 | -------------------------------------------------------------------------------- /client/transport/streamable_http_test.go: -------------------------------------------------------------------------------- 1 | package transport 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "errors" 7 | "fmt" 8 | "net/http" 9 | "net/http/httptest" 10 | "sync" 11 | "testing" 12 | "time" 13 | 14 | "github.com/mark3labs/mcp-go/mcp" 15 | ) 16 | 17 | // startMockStreamableHTTPServer starts a test HTTP server that implements 18 | // a minimal Streamable HTTP server for testing purposes. 19 | // It returns the server URL and a function to close the server. 
20 | func startMockStreamableHTTPServer() (string, func()) { 21 | var sessionID string 22 | var mu sync.Mutex 23 | 24 | handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 25 | // Handle only POST requests 26 | if r.Method != http.MethodPost { 27 | http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) 28 | return 29 | } 30 | 31 | // Parse incoming JSON-RPC request 32 | var request map[string]any 33 | decoder := json.NewDecoder(r.Body) 34 | if err := decoder.Decode(&request); err != nil { 35 | http.Error(w, fmt.Sprintf("Invalid JSON: %v", err), http.StatusBadRequest) 36 | return 37 | } 38 | 39 | method := request["method"] 40 | switch method { 41 | case "initialize": 42 | // Generate a new session ID 43 | mu.Lock() 44 | sessionID = fmt.Sprintf("test-session-%d", time.Now().UnixNano()) 45 | mu.Unlock() 46 | w.Header().Set("Mcp-Session-Id", sessionID) 47 | w.Header().Set("Content-Type", "application/json") 48 | w.WriteHeader(http.StatusAccepted) 49 | json.NewEncoder(w).Encode(map[string]interface{}{ 50 | "jsonrpc": "2.0", 51 | "id": request["id"], 52 | "result": "initialized", 53 | }) 54 | 55 | case "debug/echo": 56 | // Check session ID 57 | if r.Header.Get("Mcp-Session-Id") != sessionID { 58 | http.Error(w, "Invalid session ID", http.StatusNotFound) 59 | return 60 | } 61 | 62 | // Echo back the request as the response result 63 | w.Header().Set("Content-Type", "application/json") 64 | w.WriteHeader(http.StatusOK) 65 | json.NewEncoder(w).Encode(map[string]interface{}{ 66 | "jsonrpc": "2.0", 67 | "id": request["id"], 68 | "result": request, 69 | }) 70 | 71 | case "debug/echo_notification": 72 | // Check session ID 73 | if r.Header.Get("Mcp-Session-Id") != sessionID { 74 | http.Error(w, "Invalid session ID", http.StatusNotFound) 75 | return 76 | } 77 | 78 | // Send response and notification 79 | w.Header().Set("Content-Type", "text/event-stream") 80 | w.WriteHeader(http.StatusOK) 81 | notification := map[string]any{ 82 | "jsonrpc": 
"2.0", 83 | "method": "debug/test", 84 | "params": request, 85 | } 86 | notificationData, _ := json.Marshal(notification) 87 | fmt.Fprintf(w, "event: message\ndata: %s\n\n", notificationData) 88 | response := map[string]any{ 89 | "jsonrpc": "2.0", 90 | "id": request["id"], 91 | "result": request, 92 | } 93 | responseData, _ := json.Marshal(response) 94 | fmt.Fprintf(w, "event: message\ndata: %s\n\n", responseData) 95 | 96 | case "debug/echo_error_string": 97 | // Check session ID 98 | if r.Header.Get("Mcp-Session-Id") != sessionID { 99 | http.Error(w, "Invalid session ID", http.StatusNotFound) 100 | return 101 | } 102 | 103 | // Return an error response 104 | w.Header().Set("Content-Type", "application/json") 105 | w.WriteHeader(http.StatusOK) 106 | data, _ := json.Marshal(request) 107 | json.NewEncoder(w).Encode(map[string]interface{}{ 108 | "jsonrpc": "2.0", 109 | "id": request["id"], 110 | "error": map[string]interface{}{ 111 | "code": -1, 112 | "message": string(data), 113 | }, 114 | }) 115 | } 116 | }) 117 | 118 | // Start test server 119 | testServer := httptest.NewServer(handler) 120 | return testServer.URL, testServer.Close 121 | } 122 | 123 | func TestStreamableHTTP(t *testing.T) { 124 | // Start mock server 125 | url, closeF := startMockStreamableHTTPServer() 126 | defer closeF() 127 | 128 | // Create transport 129 | trans, err := NewStreamableHTTP(url) 130 | if err != nil { 131 | t.Fatal(err) 132 | } 133 | defer trans.Close() 134 | 135 | // Initialize the transport first 136 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 137 | defer cancel() 138 | 139 | initRequest := JSONRPCRequest{ 140 | JSONRPC: "2.0", 141 | ID: 1, 142 | Method: "initialize", 143 | } 144 | 145 | _, err = trans.SendRequest(ctx, initRequest) 146 | if err != nil { 147 | t.Fatal(err) 148 | } 149 | 150 | // Now run the tests 151 | t.Run("SendRequest", func(t *testing.T) { 152 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 153 | defer 
cancel() 154 | 155 | params := map[string]interface{}{ 156 | "string": "hello world", 157 | "array": []interface{}{1, 2, 3}, 158 | } 159 | 160 | request := JSONRPCRequest{ 161 | JSONRPC: "2.0", 162 | ID: 1, 163 | Method: "debug/echo", 164 | Params: params, 165 | } 166 | 167 | // Send the request 168 | response, err := trans.SendRequest(ctx, request) 169 | if err != nil { 170 | t.Fatalf("SendRequest failed: %v", err) 171 | } 172 | 173 | // Parse the result to verify echo 174 | var result struct { 175 | JSONRPC string `json:"jsonrpc"` 176 | ID int64 `json:"id"` 177 | Method string `json:"method"` 178 | Params map[string]interface{} `json:"params"` 179 | } 180 | 181 | if err := json.Unmarshal(response.Result, &result); err != nil { 182 | t.Fatalf("Failed to unmarshal result: %v", err) 183 | } 184 | 185 | // Verify response data matches what was sent 186 | if result.JSONRPC != "2.0" { 187 | t.Errorf("Expected JSONRPC value '2.0', got '%s'", result.JSONRPC) 188 | } 189 | if result.ID != 1 { 190 | t.Errorf("Expected ID 1, got %d", result.ID) 191 | } 192 | if result.Method != "debug/echo" { 193 | t.Errorf("Expected method 'debug/echo', got '%s'", result.Method) 194 | } 195 | 196 | if str, ok := result.Params["string"].(string); !ok || str != "hello world" { 197 | t.Errorf("Expected string 'hello world', got %v", result.Params["string"]) 198 | } 199 | 200 | if arr, ok := result.Params["array"].([]interface{}); !ok || len(arr) != 3 { 201 | t.Errorf("Expected array with 3 items, got %v", result.Params["array"]) 202 | } 203 | }) 204 | 205 | t.Run("SendRequestWithTimeout", func(t *testing.T) { 206 | // Create a context that's already canceled 207 | ctx, cancel := context.WithCancel(context.Background()) 208 | cancel() // Cancel the context immediately 209 | 210 | // Prepare a request 211 | request := JSONRPCRequest{ 212 | JSONRPC: "2.0", 213 | ID: 3, 214 | Method: "debug/echo", 215 | } 216 | 217 | // The request should fail because the context is canceled 218 | _, err := 
trans.SendRequest(ctx, request) 219 | if err == nil { 220 | t.Errorf("Expected context canceled error, got nil") 221 | } else if !errors.Is(err, context.Canceled) { 222 | t.Errorf("Expected context.Canceled error, got: %v", err) 223 | } 224 | }) 225 | 226 | t.Run("SendNotification & NotificationHandler", func(t *testing.T) { 227 | var wg sync.WaitGroup 228 | notificationChan := make(chan mcp.JSONRPCNotification, 1) 229 | 230 | // Set notification handler 231 | trans.SetNotificationHandler(func(notification mcp.JSONRPCNotification) { 232 | notificationChan <- notification 233 | }) 234 | 235 | // Send a request that triggers a notification 236 | ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) 237 | defer cancel() 238 | 239 | request := JSONRPCRequest{ 240 | JSONRPC: "2.0", 241 | ID: 1, 242 | Method: "debug/echo_notification", 243 | } 244 | 245 | _, err := trans.SendRequest(ctx, request) 246 | if err != nil { 247 | t.Fatalf("SendRequest failed: %v", err) 248 | } 249 | 250 | wg.Add(1) 251 | go func() { 252 | defer wg.Done() 253 | select { 254 | case notification := <-notificationChan: 255 | // We received a notification 256 | got := notification.Params.AdditionalFields 257 | if got == nil { 258 | t.Errorf("Notification handler did not send the expected notification: got nil") 259 | } 260 | if int64(got["id"].(float64)) != request.ID || 261 | got["jsonrpc"] != request.JSONRPC || 262 | got["method"] != request.Method { 263 | 264 | responseJson, _ := json.Marshal(got) 265 | requestJson, _ := json.Marshal(request) 266 | t.Errorf("Notification handler did not send the expected notification: \ngot %s\nexpect %s", responseJson, requestJson) 267 | } 268 | 269 | case <-time.After(1 * time.Second): 270 | t.Errorf("Expected notification, got none") 271 | } 272 | }() 273 | 274 | wg.Wait() 275 | }) 276 | 277 | t.Run("MultipleRequests", func(t *testing.T) { 278 | var wg sync.WaitGroup 279 | const numRequests = 5 280 | 281 | // Send multiple requests 
concurrently 282 | mu := sync.Mutex{} 283 | responses := make([]*JSONRPCResponse, numRequests) 284 | errors := make([]error, numRequests) 285 | 286 | for i := 0; i < numRequests; i++ { 287 | wg.Add(1) 288 | go func(idx int) { 289 | defer wg.Done() 290 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 291 | defer cancel() 292 | 293 | // Each request has a unique ID and payload 294 | request := JSONRPCRequest{ 295 | JSONRPC: "2.0", 296 | ID: int64(100 + idx), 297 | Method: "debug/echo", 298 | Params: map[string]interface{}{ 299 | "requestIndex": idx, 300 | "timestamp": time.Now().UnixNano(), 301 | }, 302 | } 303 | 304 | resp, err := trans.SendRequest(ctx, request) 305 | mu.Lock() 306 | responses[idx] = resp 307 | errors[idx] = err 308 | mu.Unlock() 309 | }(i) 310 | } 311 | 312 | wg.Wait() 313 | 314 | // Check results 315 | for i := 0; i < numRequests; i++ { 316 | if errors[i] != nil { 317 | t.Errorf("Request %d failed: %v", i, errors[i]) 318 | continue 319 | } 320 | 321 | if responses[i] == nil || responses[i].ID == nil || *responses[i].ID != int64(100+i) { 322 | t.Errorf("Request %d: Expected ID %d, got %v", i, 100+i, responses[i]) 323 | continue 324 | } 325 | 326 | // Parse the result to verify echo 327 | var result struct { 328 | JSONRPC string `json:"jsonrpc"` 329 | ID int64 `json:"id"` 330 | Method string `json:"method"` 331 | Params map[string]interface{} `json:"params"` 332 | } 333 | 334 | if err := json.Unmarshal(responses[i].Result, &result); err != nil { 335 | t.Errorf("Request %d: Failed to unmarshal result: %v", i, err) 336 | continue 337 | } 338 | 339 | // Verify data matches what was sent 340 | if result.ID != int64(100+i) { 341 | t.Errorf("Request %d: Expected echoed ID %d, got %d", i, 100+i, result.ID) 342 | } 343 | 344 | if result.Method != "debug/echo" { 345 | t.Errorf("Request %d: Expected method 'debug/echo', got '%s'", i, result.Method) 346 | } 347 | 348 | // Verify the requestIndex parameter 349 | if idx, ok := 
result.Params["requestIndex"].(float64); !ok || int(idx) != i { 350 | t.Errorf("Request %d: Expected requestIndex %d, got %v", i, i, result.Params["requestIndex"]) 351 | } 352 | } 353 | }) 354 | 355 | t.Run("ResponseError", func(t *testing.T) { 356 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 357 | defer cancel() 358 | 359 | // Prepare a request 360 | request := JSONRPCRequest{ 361 | JSONRPC: "2.0", 362 | ID: 100, 363 | Method: "debug/echo_error_string", 364 | } 365 | 366 | reps, err := trans.SendRequest(ctx, request) 367 | if err != nil { 368 | t.Errorf("SendRequest failed: %v", err) 369 | } 370 | 371 | if reps.Error == nil { 372 | t.Errorf("Expected error, got nil") 373 | } 374 | 375 | var responseError JSONRPCRequest 376 | if err := json.Unmarshal([]byte(reps.Error.Message), &responseError); err != nil { 377 | t.Errorf("Failed to unmarshal result: %v", err) 378 | return 379 | } 380 | 381 | if responseError.Method != "debug/echo_error_string" { 382 | t.Errorf("Expected method 'debug/echo_error_string', got '%s'", responseError.Method) 383 | } 384 | if responseError.ID != 100 { 385 | t.Errorf("Expected ID 100, got %d", responseError.ID) 386 | } 387 | if responseError.JSONRPC != "2.0" { 388 | t.Errorf("Expected JSONRPC '2.0', got '%s'", responseError.JSONRPC) 389 | } 390 | }) 391 | } 392 | 393 | func TestStreamableHTTPErrors(t *testing.T) { 394 | t.Run("InvalidURL", func(t *testing.T) { 395 | // Create a new StreamableHTTP transport with an invalid URL 396 | _, err := NewStreamableHTTP("://invalid-url") 397 | if err == nil { 398 | t.Errorf("Expected error when creating with invalid URL, got nil") 399 | } 400 | }) 401 | 402 | t.Run("NonExistentURL", func(t *testing.T) { 403 | // Create a new StreamableHTTP transport with a non-existent URL 404 | trans, err := NewStreamableHTTP("http://localhost:1") 405 | if err != nil { 406 | t.Fatalf("Failed to create StreamableHTTP transport: %v", err) 407 | } 408 | 409 | // Send request should fail 
410 | ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) 411 | defer cancel() 412 | 413 | request := JSONRPCRequest{ 414 | JSONRPC: "2.0", 415 | ID: 1, 416 | Method: "initialize", 417 | } 418 | 419 | _, err = trans.SendRequest(ctx, request) 420 | if err == nil { 421 | t.Errorf("Expected error when sending request to non-existent URL, got nil") 422 | } 423 | }) 424 | 425 | } 426 | -------------------------------------------------------------------------------- /examples/custom_context/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "flag" 7 | "fmt" 8 | "io" 9 | "log" 10 | "net/http" 11 | "os" 12 | 13 | "github.com/mark3labs/mcp-go/mcp" 14 | "github.com/mark3labs/mcp-go/server" 15 | ) 16 | 17 | // authKey is a custom context key for storing the auth token. 18 | type authKey struct{} 19 | 20 | // withAuthKey adds an auth key to the context. 21 | func withAuthKey(ctx context.Context, auth string) context.Context { 22 | return context.WithValue(ctx, authKey{}, auth) 23 | } 24 | 25 | // authFromRequest extracts the auth token from the request headers. 26 | func authFromRequest(ctx context.Context, r *http.Request) context.Context { 27 | return withAuthKey(ctx, r.Header.Get("Authorization")) 28 | } 29 | 30 | // authFromEnv extracts the auth token from the environment 31 | func authFromEnv(ctx context.Context) context.Context { 32 | return withAuthKey(ctx, os.Getenv("API_KEY")) 33 | } 34 | 35 | // tokenFromContext extracts the auth token from the context. 36 | // This can be used by tools to extract the token regardless of the 37 | // transport being used by the server. 
38 | func tokenFromContext(ctx context.Context) (string, error) { 39 | auth, ok := ctx.Value(authKey{}).(string) 40 | if !ok { 41 | return "", fmt.Errorf("missing auth") 42 | } 43 | return auth, nil 44 | } 45 | 46 | type response struct { 47 | Args map[string]interface{} `json:"args"` 48 | Headers map[string]string `json:"headers"` 49 | } 50 | 51 | // makeRequest makes a request to httpbin.org including the auth token in the request 52 | // headers and the message in the query string. 53 | func makeRequest(ctx context.Context, message, token string) (*response, error) { 54 | req, err := http.NewRequestWithContext(ctx, "GET", "https://httpbin.org/anything", nil) 55 | if err != nil { 56 | return nil, err 57 | } 58 | req.Header.Set("Authorization", token) 59 | query := req.URL.Query() 60 | query.Add("message", message) 61 | req.URL.RawQuery = query.Encode() 62 | resp, err := http.DefaultClient.Do(req) 63 | if err != nil { 64 | return nil, err 65 | } 66 | defer resp.Body.Close() 67 | body, err := io.ReadAll(resp.Body) 68 | if err != nil { 69 | return nil, err 70 | } 71 | var r *response 72 | if err := json.Unmarshal(body, &r); err != nil { 73 | return nil, err 74 | } 75 | return r, nil 76 | } 77 | 78 | // handleMakeAuthenticatedRequestTool is a tool that makes an authenticated request 79 | // using the token from the context. 80 | func handleMakeAuthenticatedRequestTool( 81 | ctx context.Context, 82 | request mcp.CallToolRequest, 83 | ) (*mcp.CallToolResult, error) { 84 | message, ok := request.Params.Arguments["message"].(string) 85 | if !ok { 86 | return nil, fmt.Errorf("missing message") 87 | } 88 | token, err := tokenFromContext(ctx) 89 | if err != nil { 90 | return nil, fmt.Errorf("missing token: %v", err) 91 | } 92 | // Now our tool can make a request with the token, irrespective of where it came from. 
93 | resp, err := makeRequest(ctx, message, token) 94 | if err != nil { 95 | return nil, err 96 | } 97 | return mcp.NewToolResultText(fmt.Sprintf("%+v", resp)), nil 98 | } 99 | 100 | type MCPServer struct { 101 | server *server.MCPServer 102 | } 103 | 104 | func NewMCPServer() *MCPServer { 105 | mcpServer := server.NewMCPServer( 106 | "example-server", 107 | "1.0.0", 108 | server.WithResourceCapabilities(true, true), 109 | server.WithPromptCapabilities(true), 110 | server.WithToolCapabilities(true), 111 | ) 112 | mcpServer.AddTool(mcp.NewTool("make_authenticated_request", 113 | mcp.WithDescription("Makes an authenticated request"), 114 | mcp.WithString("message", 115 | mcp.Description("Message to echo"), 116 | mcp.Required(), 117 | ), 118 | ), handleMakeAuthenticatedRequestTool) 119 | 120 | return &MCPServer{ 121 | server: mcpServer, 122 | } 123 | } 124 | 125 | func (s *MCPServer) ServeSSE(addr string) *server.SSEServer { 126 | return server.NewSSEServer(s.server, 127 | server.WithBaseURL(fmt.Sprintf("http://%s", addr)), 128 | server.WithSSEContextFunc(authFromRequest), 129 | ) 130 | } 131 | 132 | func (s *MCPServer) ServeStdio() error { 133 | return server.ServeStdio(s.server, server.WithStdioContextFunc(authFromEnv)) 134 | } 135 | 136 | func main() { 137 | var transport string 138 | flag.StringVar(&transport, "t", "stdio", "Transport type (stdio or sse)") 139 | flag.StringVar( 140 | &transport, 141 | "transport", 142 | "stdio", 143 | "Transport type (stdio or sse)", 144 | ) 145 | flag.Parse() 146 | 147 | s := NewMCPServer() 148 | 149 | switch transport { 150 | case "stdio": 151 | if err := s.ServeStdio(); err != nil { 152 | log.Fatalf("Server error: %v", err) 153 | } 154 | case "sse": 155 | sseServer := s.ServeSSE("localhost:8080") 156 | log.Printf("SSE server listening on :8080") 157 | if err := sseServer.Start(":8080"); err != nil { 158 | log.Fatalf("Server error: %v", err) 159 | } 160 | default: 161 | log.Fatalf( 162 | "Invalid transport type: %s. 
Must be 'stdio' or 'sse'", 163 | transport, 164 | ) 165 | } 166 | } 167 | -------------------------------------------------------------------------------- /examples/filesystem_stdio_client/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | "log" 8 | "time" 9 | 10 | "github.com/mark3labs/mcp-go/client" 11 | "github.com/mark3labs/mcp-go/mcp" 12 | ) 13 | 14 | func main() { 15 | c, err := client.NewStdioMCPClient( 16 | "npx", 17 | []string{}, // Empty ENV 18 | "-y", 19 | "@modelcontextprotocol/server-filesystem", 20 | "/tmp", 21 | ) 22 | if err != nil { 23 | log.Fatalf("Failed to create client: %v", err) 24 | } 25 | defer c.Close() 26 | 27 | // Create context with timeout 28 | ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) 29 | defer cancel() 30 | 31 | // Initialize the client 32 | fmt.Println("Initializing client...") 33 | initRequest := mcp.InitializeRequest{} 34 | initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION 35 | initRequest.Params.ClientInfo = mcp.Implementation{ 36 | Name: "example-client", 37 | Version: "1.0.0", 38 | } 39 | 40 | initResult, err := c.Initialize(ctx, initRequest) 41 | if err != nil { 42 | log.Fatalf("Failed to initialize: %v", err) 43 | } 44 | fmt.Printf( 45 | "Initialized with server: %s %s\n\n", 46 | initResult.ServerInfo.Name, 47 | initResult.ServerInfo.Version, 48 | ) 49 | 50 | // List Tools 51 | fmt.Println("Listing available tools...") 52 | toolsRequest := mcp.ListToolsRequest{} 53 | tools, err := c.ListTools(ctx, toolsRequest) 54 | if err != nil { 55 | log.Fatalf("Failed to list tools: %v", err) 56 | } 57 | for _, tool := range tools.Tools { 58 | fmt.Printf("- %s: %s\n", tool.Name, tool.Description) 59 | } 60 | fmt.Println() 61 | 62 | // List allowed directories 63 | fmt.Println("Listing allowed directories...") 64 | listDirRequest := mcp.CallToolRequest{ 65 | Request: 
mcp.Request{ 66 | Method: "tools/call", 67 | }, 68 | } 69 | listDirRequest.Params.Name = "list_allowed_directories" 70 | 71 | result, err := c.CallTool(ctx, listDirRequest) 72 | if err != nil { 73 | log.Fatalf("Failed to list allowed directories: %v", err) 74 | } 75 | printToolResult(result) 76 | fmt.Println() 77 | 78 | // List /tmp 79 | fmt.Println("Listing /tmp directory...") 80 | listTmpRequest := mcp.CallToolRequest{} 81 | listTmpRequest.Params.Name = "list_directory" 82 | listTmpRequest.Params.Arguments = map[string]interface{}{ 83 | "path": "/tmp", 84 | } 85 | 86 | result, err = c.CallTool(ctx, listTmpRequest) 87 | if err != nil { 88 | log.Fatalf("Failed to list directory: %v", err) 89 | } 90 | printToolResult(result) 91 | fmt.Println() 92 | 93 | // Create mcp directory 94 | fmt.Println("Creating /tmp/mcp directory...") 95 | createDirRequest := mcp.CallToolRequest{} 96 | createDirRequest.Params.Name = "create_directory" 97 | createDirRequest.Params.Arguments = map[string]interface{}{ 98 | "path": "/tmp/mcp", 99 | } 100 | 101 | result, err = c.CallTool(ctx, createDirRequest) 102 | if err != nil { 103 | log.Fatalf("Failed to create directory: %v", err) 104 | } 105 | printToolResult(result) 106 | fmt.Println() 107 | 108 | // Create hello.txt 109 | fmt.Println("Creating /tmp/mcp/hello.txt...") 110 | writeFileRequest := mcp.CallToolRequest{} 111 | writeFileRequest.Params.Name = "write_file" 112 | writeFileRequest.Params.Arguments = map[string]interface{}{ 113 | "path": "/tmp/mcp/hello.txt", 114 | "content": "Hello World", 115 | } 116 | 117 | result, err = c.CallTool(ctx, writeFileRequest) 118 | if err != nil { 119 | log.Fatalf("Failed to create file: %v", err) 120 | } 121 | printToolResult(result) 122 | fmt.Println() 123 | 124 | // Verify file contents 125 | fmt.Println("Reading /tmp/mcp/hello.txt...") 126 | readFileRequest := mcp.CallToolRequest{} 127 | readFileRequest.Params.Name = "read_file" 128 | readFileRequest.Params.Arguments = map[string]interface{}{ 129 
| "path": "/tmp/mcp/hello.txt", 130 | } 131 | 132 | result, err = c.CallTool(ctx, readFileRequest) 133 | if err != nil { 134 | log.Fatalf("Failed to read file: %v", err) 135 | } 136 | printToolResult(result) 137 | 138 | // Get file info 139 | fmt.Println("Getting info for /tmp/mcp/hello.txt...") 140 | fileInfoRequest := mcp.CallToolRequest{} 141 | fileInfoRequest.Params.Name = "get_file_info" 142 | fileInfoRequest.Params.Arguments = map[string]interface{}{ 143 | "path": "/tmp/mcp/hello.txt", 144 | } 145 | 146 | result, err = c.CallTool(ctx, fileInfoRequest) 147 | if err != nil { 148 | log.Fatalf("Failed to get file info: %v", err) 149 | } 150 | printToolResult(result) 151 | } 152 | 153 | // Helper function to print tool results 154 | func printToolResult(result *mcp.CallToolResult) { 155 | for _, content := range result.Content { 156 | if textContent, ok := content.(mcp.TextContent); ok { 157 | fmt.Println(textContent.Text) 158 | } else { 159 | jsonBytes, _ := json.MarshalIndent(content, "", " ") 160 | fmt.Println(string(jsonBytes)) 161 | } 162 | } 163 | } 164 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/mark3labs/mcp-go 2 | 3 | go 1.23 4 | 5 | require ( 6 | github.com/google/uuid v1.6.0 7 | github.com/spf13/cast v1.7.1 8 | github.com/stretchr/testify v1.9.0 9 | github.com/yosida95/uritemplate/v3 v3.0.2 10 | ) 11 | 12 | require ( 13 | github.com/davecgh/go-spew v1.1.1 // indirect 14 | github.com/pmezard/go-difflib v1.0.0 // indirect 15 | gopkg.in/yaml.v3 v3.0.1 // indirect 16 | ) 17 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 2 | github.com/davecgh/go-spew v1.1.1/go.mod 
h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 3 | github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= 4 | github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= 5 | github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= 6 | github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= 7 | github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= 8 | github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= 9 | github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= 10 | github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= 11 | github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= 12 | github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= 13 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 14 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 15 | github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= 16 | github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= 17 | github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y= 18 | github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= 19 | github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= 20 | github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= 21 | github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= 22 | github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= 23 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 24 | gopkg.in/check.v1 
v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 25 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 26 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 27 | -------------------------------------------------------------------------------- /mcp/prompts.go: -------------------------------------------------------------------------------- 1 | package mcp 2 | 3 | /* Prompts */ 4 | 5 | // ListPromptsRequest is sent from the client to request a list of prompts and 6 | // prompt templates the server has. 7 | type ListPromptsRequest struct { 8 | PaginatedRequest 9 | } 10 | 11 | // ListPromptsResult is the server's response to a prompts/list request from 12 | // the client. 13 | type ListPromptsResult struct { 14 | PaginatedResult 15 | Prompts []Prompt `json:"prompts"` 16 | } 17 | 18 | // GetPromptRequest is used by the client to get a prompt provided by the 19 | // server. 20 | type GetPromptRequest struct { 21 | Request 22 | Params struct { 23 | // The name of the prompt or prompt template. 24 | Name string `json:"name"` 25 | // Arguments to use for templating the prompt. 26 | Arguments map[string]string `json:"arguments,omitempty"` 27 | } `json:"params"` 28 | } 29 | 30 | // GetPromptResult is the server's response to a prompts/get request from the 31 | // client. 32 | type GetPromptResult struct { 33 | Result 34 | // An optional description for the prompt. 35 | Description string `json:"description,omitempty"` 36 | Messages []PromptMessage `json:"messages"` 37 | } 38 | 39 | // Prompt represents a prompt or prompt template that the server offers. 40 | // If Arguments is non-nil and non-empty, this indicates the prompt is a template 41 | // that requires argument values to be provided when calling prompts/get. 42 | // If Arguments is nil or empty, this is a static prompt that takes no arguments. 43 | type Prompt struct { 44 | // The name of the prompt or prompt template. 
45 | Name string `json:"name"` 46 | // An optional description of what this prompt provides 47 | Description string `json:"description,omitempty"` 48 | // A list of arguments to use for templating the prompt. 49 | // The presence of arguments indicates this is a template prompt. 50 | Arguments []PromptArgument `json:"arguments,omitempty"` 51 | } 52 | 53 | // PromptArgument describes an argument that a prompt template can accept. 54 | // When a prompt includes arguments, clients must provide values for all 55 | // required arguments when making a prompts/get request. 56 | type PromptArgument struct { 57 | // The name of the argument. 58 | Name string `json:"name"` 59 | // A human-readable description of the argument. 60 | Description string `json:"description,omitempty"` 61 | // Whether this argument must be provided. 62 | // If true, clients must include this argument when calling prompts/get. 63 | Required bool `json:"required,omitempty"` 64 | } 65 | 66 | // Role represents the sender or recipient of messages and data in a 67 | // conversation. 68 | type Role string 69 | 70 | const ( 71 | RoleUser Role = "user" 72 | RoleAssistant Role = "assistant" 73 | ) 74 | 75 | // PromptMessage describes a message returned as part of a prompt. 76 | // 77 | // This is similar to `SamplingMessage`, but also supports the embedding of 78 | // resources from the MCP server. 79 | type PromptMessage struct { 80 | Role Role `json:"role"` 81 | Content Content `json:"content"` // Can be TextContent, ImageContent, or EmbeddedResource 82 | } 83 | 84 | // PromptListChangedNotification is an optional notification from the server 85 | // to the client, informing it that the list of prompts it offers has changed. This 86 | // may be issued by servers without any previous subscription from the client. 87 | type PromptListChangedNotification struct { 88 | Notification 89 | } 90 | 91 | // PromptOption is a function that configures a Prompt. 
92 | // It provides a flexible way to set various properties of a Prompt using the functional options pattern. 93 | type PromptOption func(*Prompt) 94 | 95 | // ArgumentOption is a function that configures a PromptArgument. 96 | // It allows for flexible configuration of prompt arguments using the functional options pattern. 97 | type ArgumentOption func(*PromptArgument) 98 | 99 | // 100 | // Core Prompt Functions 101 | // 102 | 103 | // NewPrompt creates a new Prompt with the given name and options. 104 | // The prompt will be configured based on the provided options. 105 | // Options are applied in order, allowing for flexible prompt configuration. 106 | func NewPrompt(name string, opts ...PromptOption) Prompt { 107 | prompt := Prompt{ 108 | Name: name, 109 | } 110 | 111 | for _, opt := range opts { 112 | opt(&prompt) 113 | } 114 | 115 | return prompt 116 | } 117 | 118 | // WithPromptDescription adds a description to the Prompt. 119 | // The description should provide a clear, human-readable explanation of what the prompt does. 120 | func WithPromptDescription(description string) PromptOption { 121 | return func(p *Prompt) { 122 | p.Description = description 123 | } 124 | } 125 | 126 | // WithArgument adds an argument to the prompt's argument list. 127 | // The argument will be configured based on the provided options. 128 | func WithArgument(name string, opts ...ArgumentOption) PromptOption { 129 | return func(p *Prompt) { 130 | arg := PromptArgument{ 131 | Name: name, 132 | } 133 | 134 | for _, opt := range opts { 135 | opt(&arg) 136 | } 137 | 138 | if p.Arguments == nil { 139 | p.Arguments = make([]PromptArgument, 0) 140 | } 141 | p.Arguments = append(p.Arguments, arg) 142 | } 143 | } 144 | 145 | // 146 | // Argument Options 147 | // 148 | 149 | // ArgumentDescription adds a description to a prompt argument. 150 | // The description should explain the purpose and expected values of the argument. 
151 | func ArgumentDescription(desc string) ArgumentOption { 152 | return func(arg *PromptArgument) { 153 | arg.Description = desc 154 | } 155 | } 156 | 157 | // RequiredArgument marks an argument as required in the prompt. 158 | // Required arguments must be provided when getting the prompt. 159 | func RequiredArgument() ArgumentOption { 160 | return func(arg *PromptArgument) { 161 | arg.Required = true 162 | } 163 | } 164 | -------------------------------------------------------------------------------- /mcp/resources.go: -------------------------------------------------------------------------------- 1 | package mcp 2 | 3 | import "github.com/yosida95/uritemplate/v3" 4 | 5 | // ResourceOption is a function that configures a Resource. 6 | // It provides a flexible way to set various properties of a Resource using the functional options pattern. 7 | type ResourceOption func(*Resource) 8 | 9 | // NewResource creates a new Resource with the given URI, name and options. 10 | // The resource will be configured based on the provided options. 11 | // Options are applied in order, allowing for flexible resource configuration. 12 | func NewResource(uri string, name string, opts ...ResourceOption) Resource { 13 | resource := Resource{ 14 | URI: uri, 15 | Name: name, 16 | } 17 | 18 | for _, opt := range opts { 19 | opt(&resource) 20 | } 21 | 22 | return resource 23 | } 24 | 25 | // WithResourceDescription adds a description to the Resource. 26 | // The description should provide a clear, human-readable explanation of what the resource represents. 27 | func WithResourceDescription(description string) ResourceOption { 28 | return func(r *Resource) { 29 | r.Description = description 30 | } 31 | } 32 | 33 | // WithMIMEType sets the MIME type for the Resource. 34 | // This should indicate the format of the resource's contents. 
35 | func WithMIMEType(mimeType string) ResourceOption { 36 | return func(r *Resource) { 37 | r.MIMEType = mimeType 38 | } 39 | } 40 | 41 | // WithAnnotations adds annotations to the Resource. 42 | // Annotations can provide additional metadata about the resource's intended use. 43 | func WithAnnotations(audience []Role, priority float64) ResourceOption { 44 | return func(r *Resource) { 45 | if r.Annotations == nil { 46 | r.Annotations = &struct { 47 | Audience []Role `json:"audience,omitempty"` 48 | Priority float64 `json:"priority,omitempty"` 49 | }{} 50 | } 51 | r.Annotations.Audience = audience 52 | r.Annotations.Priority = priority 53 | } 54 | } 55 | 56 | // ResourceTemplateOption is a function that configures a ResourceTemplate. 57 | // It provides a flexible way to set various properties of a ResourceTemplate using the functional options pattern. 58 | type ResourceTemplateOption func(*ResourceTemplate) 59 | 60 | // NewResourceTemplate creates a new ResourceTemplate with the given URI template, name and options. 61 | // The template will be configured based on the provided options. 62 | // Options are applied in order, allowing for flexible template configuration. 63 | func NewResourceTemplate(uriTemplate string, name string, opts ...ResourceTemplateOption) ResourceTemplate { 64 | template := ResourceTemplate{ 65 | URITemplate: &URITemplate{Template: uritemplate.MustNew(uriTemplate)}, 66 | Name: name, 67 | } 68 | 69 | for _, opt := range opts { 70 | opt(&template) 71 | } 72 | 73 | return template 74 | } 75 | 76 | // WithTemplateDescription adds a description to the ResourceTemplate. 77 | // The description should provide a clear, human-readable explanation of what resources this template represents. 78 | func WithTemplateDescription(description string) ResourceTemplateOption { 79 | return func(t *ResourceTemplate) { 80 | t.Description = description 81 | } 82 | } 83 | 84 | // WithTemplateMIMEType sets the MIME type for the ResourceTemplate. 
85 | // This should only be set if all resources matching this template will have the same type. 86 | func WithTemplateMIMEType(mimeType string) ResourceTemplateOption { 87 | return func(t *ResourceTemplate) { 88 | t.MIMEType = mimeType 89 | } 90 | } 91 | 92 | // WithTemplateAnnotations adds annotations to the ResourceTemplate. 93 | // Annotations can provide additional metadata about the template's intended use. 94 | func WithTemplateAnnotations(audience []Role, priority float64) ResourceTemplateOption { 95 | return func(t *ResourceTemplate) { 96 | if t.Annotations == nil { 97 | t.Annotations = &struct { 98 | Audience []Role `json:"audience,omitempty"` 99 | Priority float64 `json:"priority,omitempty"` 100 | }{} 101 | } 102 | t.Annotations.Audience = audience 103 | t.Annotations.Priority = priority 104 | } 105 | } 106 | -------------------------------------------------------------------------------- /mcp/tools_test.go: -------------------------------------------------------------------------------- 1 | package mcp 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | // TestToolWithBothSchemasError verifies that there will be feedback if the 12 | // developer mixes raw schema with a schema provided via DSL. 
13 | func TestToolWithBothSchemasError(t *testing.T) { 14 | // Create a tool with both schemas set 15 | tool := NewTool("dual-schema-tool", 16 | WithDescription("A tool with both schemas set"), 17 | WithString("input", Description("Test input")), 18 | ) 19 | 20 | _, err := json.Marshal(tool) 21 | assert.Nil(t, err) 22 | 23 | // Set the RawInputSchema as well - this should conflict with the InputSchema 24 | // Note: InputSchema.Type is explicitly set to "object" in NewTool 25 | tool.RawInputSchema = json.RawMessage(`{"type":"string"}`) 26 | 27 | // Attempt to marshal to JSON 28 | _, err = json.Marshal(tool) 29 | 30 | // Should return an error 31 | assert.ErrorIs(t, err, errToolSchemaConflict) 32 | } 33 | 34 | func TestToolWithRawSchema(t *testing.T) { 35 | // Create a complex raw schema 36 | rawSchema := json.RawMessage(`{ 37 | "type": "object", 38 | "properties": { 39 | "query": {"type": "string", "description": "Search query"}, 40 | "limit": {"type": "integer", "minimum": 1, "maximum": 50} 41 | }, 42 | "required": ["query"] 43 | }`) 44 | 45 | // Create a tool with raw schema 46 | tool := NewToolWithRawSchema("search-tool", "Search API", rawSchema) 47 | 48 | // Marshal to JSON 49 | data, err := json.Marshal(tool) 50 | assert.NoError(t, err) 51 | 52 | // Unmarshal to verify the structure 53 | var result map[string]interface{} 54 | err = json.Unmarshal(data, &result) 55 | assert.NoError(t, err) 56 | 57 | // Verify tool properties 58 | assert.Equal(t, "search-tool", result["name"]) 59 | assert.Equal(t, "Search API", result["description"]) 60 | 61 | // Verify schema was properly included 62 | schema, ok := result["inputSchema"].(map[string]interface{}) 63 | assert.True(t, ok) 64 | assert.Equal(t, "object", schema["type"]) 65 | 66 | properties, ok := schema["properties"].(map[string]interface{}) 67 | assert.True(t, ok) 68 | 69 | query, ok := properties["query"].(map[string]interface{}) 70 | assert.True(t, ok) 71 | assert.Equal(t, "string", query["type"]) 72 | 73 | 
required, ok := schema["required"].([]interface{}) 74 | assert.True(t, ok) 75 | assert.Contains(t, required, "query") 76 | } 77 | 78 | func TestUnmarshalToolWithRawSchema(t *testing.T) { 79 | // Create a complex raw schema 80 | rawSchema := json.RawMessage(`{ 81 | "type": "object", 82 | "properties": { 83 | "query": {"type": "string", "description": "Search query"}, 84 | "limit": {"type": "integer", "minimum": 1, "maximum": 50} 85 | }, 86 | "required": ["query"] 87 | }`) 88 | 89 | // Create a tool with raw schema 90 | tool := NewToolWithRawSchema("search-tool", "Search API", rawSchema) 91 | 92 | // Marshal to JSON 93 | data, err := json.Marshal(tool) 94 | assert.NoError(t, err) 95 | 96 | // Unmarshal to verify the structure 97 | var toolUnmarshalled Tool 98 | err = json.Unmarshal(data, &toolUnmarshalled) 99 | assert.NoError(t, err) 100 | 101 | // Verify tool properties 102 | assert.Equal(t, tool.Name, toolUnmarshalled.Name) 103 | assert.Equal(t, tool.Description, toolUnmarshalled.Description) 104 | 105 | // Verify schema was properly included 106 | assert.Equal(t, "object", toolUnmarshalled.InputSchema.Type) 107 | assert.Contains(t, toolUnmarshalled.InputSchema.Properties, "query") 108 | assert.Subset(t, toolUnmarshalled.InputSchema.Properties["query"], map[string]interface{}{ 109 | "type": "string", 110 | "description": "Search query", 111 | }) 112 | assert.Contains(t, toolUnmarshalled.InputSchema.Properties, "limit") 113 | assert.Subset(t, toolUnmarshalled.InputSchema.Properties["limit"], map[string]interface{}{ 114 | "type": "integer", 115 | "minimum": 1.0, 116 | "maximum": 50.0, 117 | }) 118 | assert.Subset(t, toolUnmarshalled.InputSchema.Required, []string{"query"}) 119 | } 120 | 121 | func TestUnmarshalToolWithoutRawSchema(t *testing.T) { 122 | // Create a tool with both schemas set 123 | tool := NewTool("dual-schema-tool", 124 | WithDescription("A tool with both schemas set"), 125 | WithString("input", Description("Test input")), 126 | ) 127 | 128 | data, 
err := json.Marshal(tool) 129 | assert.Nil(t, err) 130 | 131 | // Unmarshal to verify the structure 132 | var toolUnmarshalled Tool 133 | err = json.Unmarshal(data, &toolUnmarshalled) 134 | assert.NoError(t, err) 135 | 136 | // Verify tool properties 137 | assert.Equal(t, tool.Name, toolUnmarshalled.Name) 138 | assert.Equal(t, tool.Description, toolUnmarshalled.Description) 139 | assert.Subset(t, toolUnmarshalled.InputSchema.Properties["input"], map[string]interface{}{ 140 | "type": "string", 141 | "description": "Test input", 142 | }) 143 | assert.Empty(t, toolUnmarshalled.InputSchema.Required) 144 | assert.Empty(t, toolUnmarshalled.RawInputSchema) 145 | } 146 | 147 | func TestToolWithObjectAndArray(t *testing.T) { 148 | // Create a tool with both object and array properties 149 | tool := NewTool("reading-list", 150 | WithDescription("A tool for managing reading lists"), 151 | WithObject("preferences", 152 | Description("User preferences for the reading list"), 153 | Properties(map[string]interface{}{ 154 | "theme": map[string]interface{}{ 155 | "type": "string", 156 | "description": "UI theme preference", 157 | "enum": []string{"light", "dark"}, 158 | }, 159 | "maxItems": map[string]interface{}{ 160 | "type": "number", 161 | "description": "Maximum number of items in the list", 162 | "minimum": 1, 163 | "maximum": 100, 164 | }, 165 | })), 166 | WithArray("books", 167 | Description("List of books to read"), 168 | Required(), 169 | Items(map[string]interface{}{ 170 | "type": "object", 171 | "properties": map[string]interface{}{ 172 | "title": map[string]interface{}{ 173 | "type": "string", 174 | "description": "Book title", 175 | "required": true, 176 | }, 177 | "author": map[string]interface{}{ 178 | "type": "string", 179 | "description": "Book author", 180 | }, 181 | "year": map[string]interface{}{ 182 | "type": "number", 183 | "description": "Publication year", 184 | "minimum": 1000, 185 | }, 186 | }, 187 | }))) 188 | 189 | // Marshal to JSON 190 | data, err := 
json.Marshal(tool) 191 | assert.NoError(t, err) 192 | 193 | // Unmarshal to verify the structure 194 | var result map[string]interface{} 195 | err = json.Unmarshal(data, &result) 196 | assert.NoError(t, err) 197 | 198 | // Verify tool properties 199 | assert.Equal(t, "reading-list", result["name"]) 200 | assert.Equal(t, "A tool for managing reading lists", result["description"]) 201 | 202 | // Verify schema was properly included 203 | schema, ok := result["inputSchema"].(map[string]interface{}) 204 | assert.True(t, ok) 205 | assert.Equal(t, "object", schema["type"]) 206 | 207 | // Verify properties 208 | properties, ok := schema["properties"].(map[string]interface{}) 209 | assert.True(t, ok) 210 | 211 | // Verify preferences object 212 | preferences, ok := properties["preferences"].(map[string]interface{}) 213 | assert.True(t, ok) 214 | assert.Equal(t, "object", preferences["type"]) 215 | assert.Equal(t, "User preferences for the reading list", preferences["description"]) 216 | 217 | prefProps, ok := preferences["properties"].(map[string]interface{}) 218 | assert.True(t, ok) 219 | assert.Contains(t, prefProps, "theme") 220 | assert.Contains(t, prefProps, "maxItems") 221 | 222 | // Verify books array 223 | books, ok := properties["books"].(map[string]interface{}) 224 | assert.True(t, ok) 225 | assert.Equal(t, "array", books["type"]) 226 | assert.Equal(t, "List of books to read", books["description"]) 227 | 228 | // Verify array items schema 229 | items, ok := books["items"].(map[string]interface{}) 230 | assert.True(t, ok) 231 | assert.Equal(t, "object", items["type"]) 232 | 233 | itemProps, ok := items["properties"].(map[string]interface{}) 234 | assert.True(t, ok) 235 | assert.Contains(t, itemProps, "title") 236 | assert.Contains(t, itemProps, "author") 237 | assert.Contains(t, itemProps, "year") 238 | 239 | // Verify required fields 240 | required, ok := schema["required"].([]interface{}) 241 | assert.True(t, ok) 242 | assert.Contains(t, required, "books") 243 | 
} 244 | 245 | func TestParseToolCallToolRequest(t *testing.T) { 246 | request := CallToolRequest{} 247 | request.Params.Name = "test-tool" 248 | request.Params.Arguments = map[string]interface{}{ 249 | "bool_value": "true", 250 | "int64_value": "123456789", 251 | "int32_value": "123456789", 252 | "int16_value": "123456789", 253 | "int8_value": "123456789", 254 | "int_value": "123456789", 255 | "uint_value": "123456789", 256 | "uint64_value": "123456789", 257 | "uint32_value": "123456789", 258 | "uint16_value": "123456789", 259 | "uint8_value": "123456789", 260 | "float32_value": "3.14", 261 | "float64_value": "3.1415926", 262 | "string_value": "hello", 263 | } 264 | param1 := ParseBoolean(request, "bool_value", false) 265 | assert.Equal(t, "bool", fmt.Sprintf("%T", param1)) 266 | 267 | param2 := ParseInt64(request, "int64_value", 1) 268 | assert.Equal(t, "int64", fmt.Sprintf("%T", param2)) 269 | 270 | param3 := ParseInt32(request, "int32_value", 1) 271 | assert.Equal(t, "int32", fmt.Sprintf("%T", param3)) 272 | 273 | param4 := ParseInt16(request, "int16_value", 1) 274 | assert.Equal(t, "int16", fmt.Sprintf("%T", param4)) 275 | 276 | param5 := ParseInt8(request, "int8_value", 1) 277 | assert.Equal(t, "int8", fmt.Sprintf("%T", param5)) 278 | 279 | param6 := ParseInt(request, "int_value", 1) 280 | assert.Equal(t, "int", fmt.Sprintf("%T", param6)) 281 | 282 | param7 := ParseUInt(request, "uint_value", 1) 283 | assert.Equal(t, "uint", fmt.Sprintf("%T", param7)) 284 | 285 | param8 := ParseUInt64(request, "uint64_value", 1) 286 | assert.Equal(t, "uint64", fmt.Sprintf("%T", param8)) 287 | 288 | param9 := ParseUInt32(request, "uint32_value", 1) 289 | assert.Equal(t, "uint32", fmt.Sprintf("%T", param9)) 290 | 291 | param10 := ParseUInt16(request, "uint16_value", 1) 292 | assert.Equal(t, "uint16", fmt.Sprintf("%T", param10)) 293 | 294 | param11 := ParseUInt8(request, "uint8_value", 1) 295 | assert.Equal(t, "uint8", fmt.Sprintf("%T", param11)) 296 | 297 | param12 :=
ParseFloat32(request, "float32_value", 1.0) 298 | assert.Equal(t, "float32", fmt.Sprintf("%T", param12)) 299 | 300 | param13 := ParseFloat64(request, "float64_value", 1.0) 301 | assert.Equal(t, "float64", fmt.Sprintf("%T", param13)) 302 | 303 | param14 := ParseString(request, "string_value", "") 304 | assert.Equal(t, "string", fmt.Sprintf("%T", param14)) 305 | 306 | param15 := ParseInt64(request, "string_value", 1) 307 | assert.Equal(t, "int64", fmt.Sprintf("%T", param15)) 308 | t.Logf("param15 type: %T,value:%v", param15, param15) 309 | 310 | } 311 | -------------------------------------------------------------------------------- /server/internal/gen/README.md: -------------------------------------------------------------------------------- 1 | # Readme for Codegen 2 | 3 | This internal module contains code generation for producing a few repetitive 4 | constructs, namely: 5 | 6 | - The switch statement that handles the request dispatch 7 | - The hook function types and the methods on the Hook struct 8 | 9 | To invoke the code generation: 10 | 11 | ``` 12 | go generate ./... 13 | ``` 14 | 15 | ## Development 16 | 17 | - `request_handler.go.tmpl` generates `server/request_handler.go`, and 18 | - `hooks.go.tmpl` generates `server/hooks.go` 19 | 20 | Inside of `data.go` there is a struct with the inputs to both templates. 21 | 22 | Note that the driver in `main.go` generates code and also pipes it through 23 | `goimports` for formatting and imports cleanup.
24 | 25 | -------------------------------------------------------------------------------- /server/internal/gen/data.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | type MCPRequestType struct { 4 | MethodName string 5 | ParamType string 6 | ResultType string 7 | HookName string 8 | Group string 9 | GroupName string 10 | GroupHookName string 11 | UnmarshalError string 12 | HandlerFunc string 13 | } 14 | 15 | var MCPRequestTypes = []MCPRequestType{ 16 | { 17 | MethodName: "MethodInitialize", 18 | ParamType: "InitializeRequest", 19 | ResultType: "InitializeResult", 20 | HookName: "Initialize", 21 | UnmarshalError: "invalid initialize request", 22 | HandlerFunc: "handleInitialize", 23 | }, { 24 | MethodName: "MethodPing", 25 | ParamType: "PingRequest", 26 | ResultType: "EmptyResult", 27 | HookName: "Ping", 28 | UnmarshalError: "invalid ping request", 29 | HandlerFunc: "handlePing", 30 | }, { 31 | MethodName: "MethodResourcesList", 32 | ParamType: "ListResourcesRequest", 33 | ResultType: "ListResourcesResult", 34 | Group: "resources", 35 | GroupName: "Resources", 36 | GroupHookName: "Resource", 37 | HookName: "ListResources", 38 | UnmarshalError: "invalid list resources request", 39 | HandlerFunc: "handleListResources", 40 | }, { 41 | MethodName: "MethodResourcesTemplatesList", 42 | ParamType: "ListResourceTemplatesRequest", 43 | ResultType: "ListResourceTemplatesResult", 44 | Group: "resources", 45 | GroupName: "Resources", 46 | GroupHookName: "Resource", 47 | HookName: "ListResourceTemplates", 48 | UnmarshalError: "invalid list resource templates request", 49 | HandlerFunc: "handleListResourceTemplates", 50 | }, { 51 | MethodName: "MethodResourcesRead", 52 | ParamType: "ReadResourceRequest", 53 | ResultType: "ReadResourceResult", 54 | Group: "resources", 55 | GroupName: "Resources", 56 | GroupHookName: "Resource", 57 | HookName: "ReadResource", 58 | UnmarshalError: "invalid read resource request", 59 | 
HandlerFunc: "handleReadResource", 60 | }, { 61 | MethodName: "MethodPromptsList", 62 | ParamType: "ListPromptsRequest", 63 | ResultType: "ListPromptsResult", 64 | Group: "prompts", 65 | GroupName: "Prompts", 66 | GroupHookName: "Prompt", 67 | HookName: "ListPrompts", 68 | UnmarshalError: "invalid list prompts request", 69 | HandlerFunc: "handleListPrompts", 70 | }, { 71 | MethodName: "MethodPromptsGet", 72 | ParamType: "GetPromptRequest", 73 | ResultType: "GetPromptResult", 74 | Group: "prompts", 75 | GroupName: "Prompts", 76 | GroupHookName: "Prompt", 77 | HookName: "GetPrompt", 78 | UnmarshalError: "invalid get prompt request", 79 | HandlerFunc: "handleGetPrompt", 80 | }, { 81 | MethodName: "MethodToolsList", 82 | ParamType: "ListToolsRequest", 83 | ResultType: "ListToolsResult", 84 | Group: "tools", 85 | GroupName: "Tools", 86 | GroupHookName: "Tool", 87 | HookName: "ListTools", 88 | UnmarshalError: "invalid list tools request", 89 | HandlerFunc: "handleListTools", 90 | }, { 91 | MethodName: "MethodToolsCall", 92 | ParamType: "CallToolRequest", 93 | ResultType: "CallToolResult", 94 | Group: "tools", 95 | GroupName: "Tools", 96 | GroupHookName: "Tool", 97 | HookName: "CallTool", 98 | UnmarshalError: "invalid call tool request", 99 | HandlerFunc: "handleToolCall", 100 | }, 101 | } 102 | -------------------------------------------------------------------------------- /server/internal/gen/hooks.go.tmpl: -------------------------------------------------------------------------------- 1 | // Code generated by `go generate`. DO NOT EDIT. 2 | // source: server/internal/gen/hooks.go.tmpl 3 | package server 4 | 5 | import ( 6 | "context" 7 | "encoding/json" 8 | "errors" 9 | "fmt" 10 | 11 | "github.com/mark3labs/mcp-go/mcp" 12 | ) 13 | 14 | // OnRegisterSessionHookFunc is a hook that will be called when a new session is registered. 
15 | type OnRegisterSessionHookFunc func(ctx context.Context, session ClientSession) 16 | 17 | // OnUnregisterSessionHookFunc is a hook that will be called when a session is being unregistered. 18 | type OnUnregisterSessionHookFunc func(ctx context.Context, session ClientSession) 19 | 20 | // BeforeAnyHookFunc is a function that is called after the request is 21 | // parsed but before the method is called. 22 | type BeforeAnyHookFunc func(ctx context.Context, id any, method mcp.MCPMethod, message any) 23 | 24 | // OnSuccessHookFunc is a hook that will be called after the request 25 | // successfully generates a result, but before the result is sent to the client. 26 | type OnSuccessHookFunc func(ctx context.Context, id any, method mcp.MCPMethod, message any, result any) 27 | 28 | // OnErrorHookFunc is a hook that will be called when an error occurs, 29 | // either during the request parsing or the method execution. 30 | // 31 | // Example usage: 32 | // ``` 33 | // hooks.AddOnError(func(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) { 34 | // // Check for specific error types using errors.Is 35 | // if errors.Is(err, ErrUnsupported) { 36 | // // Handle capability not supported errors 37 | // log.Printf("Capability not supported: %v", err) 38 | // } 39 | // 40 | // // Use errors.As to get specific error types 41 | // var parseErr = &UnparsableMessageError{} 42 | // if errors.As(err, &parseErr) { 43 | // // Access specific methods/fields of the error type 44 | // log.Printf("Failed to parse message for method %s: %v", 45 | // parseErr.GetMethod(), parseErr.Unwrap()) 46 | // // Access the raw message that failed to parse 47 | // rawMsg := parseErr.GetMessage() 48 | // } 49 | // 50 | // // Check for specific resource/prompt/tool errors 51 | // switch { 52 | // case errors.Is(err, ErrResourceNotFound): 53 | // log.Printf("Resource not found: %v", err) 54 | // case errors.Is(err, ErrPromptNotFound): 55 | // log.Printf("Prompt not found: %v", 
err) 56 | // case errors.Is(err, ErrToolNotFound): 57 | // log.Printf("Tool not found: %v", err) 58 | // } 59 | // }) 60 | type OnErrorHookFunc func(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) 61 | 62 | {{range .}} 63 | type OnBefore{{.HookName}}Func func(ctx context.Context, id any, message *mcp.{{.ParamType}}) 64 | type OnAfter{{.HookName}}Func func(ctx context.Context, id any, message *mcp.{{.ParamType}}, result *mcp.{{.ResultType}}) 65 | {{end}} 66 | 67 | type Hooks struct { 68 | OnRegisterSession []OnRegisterSessionHookFunc 69 | OnUnregisterSession []OnUnregisterSessionHookFunc 70 | OnBeforeAny []BeforeAnyHookFunc 71 | OnSuccess []OnSuccessHookFunc 72 | OnError []OnErrorHookFunc 73 | {{- range .}} 74 | OnBefore{{.HookName}} []OnBefore{{.HookName}}Func 75 | OnAfter{{.HookName}} []OnAfter{{.HookName}}Func 76 | {{- end}} 77 | } 78 | 79 | func (c *Hooks) AddBeforeAny(hook BeforeAnyHookFunc) { 80 | c.OnBeforeAny = append(c.OnBeforeAny, hook) 81 | } 82 | 83 | func (c *Hooks) AddOnSuccess(hook OnSuccessHookFunc) { 84 | c.OnSuccess = append(c.OnSuccess, hook) 85 | } 86 | 87 | // AddOnError registers a hook function that will be called when an error occurs. 88 | // The error parameter contains the actual error object, which can be interrogated 89 | // using Go's error handling patterns like errors.Is and errors.As. 
90 | // 91 | // Example: 92 | // ``` 93 | // // Create a channel to receive errors for testing 94 | // errChan := make(chan error, 1) 95 | // 96 | // // Register hook to capture and inspect errors 97 | // hooks := &Hooks{} 98 | // hooks.AddOnError(func(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) { 99 | // // For capability-related errors 100 | // if errors.Is(err, ErrUnsupported) { 101 | // // Handle capability not supported 102 | // errChan <- err 103 | // return 104 | // } 105 | // 106 | // // For parsing errors 107 | // var parseErr = &UnparsableMessageError{} 108 | // if errors.As(err, &parseErr) { 109 | // // Handle unparsable message errors 110 | // fmt.Printf("Failed to parse %s request: %v\n", 111 | // parseErr.GetMethod(), parseErr.Unwrap()) 112 | // errChan <- parseErr 113 | // return 114 | // } 115 | // 116 | // // For resource/prompt/tool not found errors 117 | // if errors.Is(err, ErrResourceNotFound) || 118 | // errors.Is(err, ErrPromptNotFound) || 119 | // errors.Is(err, ErrToolNotFound) { 120 | // // Handle not found errors 121 | // errChan <- err 122 | // return 123 | // } 124 | // 125 | // // For other errors 126 | // errChan <- err 127 | // }) 128 | // 129 | // server := NewMCPServer("test-server", "1.0.0", WithHooks(hooks)) 130 | // ``` 131 | func (c *Hooks) AddOnError(hook OnErrorHookFunc) { 132 | c.OnError = append(c.OnError, hook) 133 | } 134 | 135 | func (c *Hooks) beforeAny(ctx context.Context, id any, method mcp.MCPMethod, message any) { 136 | if c == nil { 137 | return 138 | } 139 | for _, hook := range c.OnBeforeAny { 140 | hook(ctx, id, method, message) 141 | } 142 | } 143 | 144 | func (c *Hooks) onSuccess(ctx context.Context, id any, method mcp.MCPMethod, message any, result any) { 145 | if c == nil { 146 | return 147 | } 148 | for _, hook := range c.OnSuccess { 149 | hook(ctx, id, method, message, result) 150 | } 151 | } 152 | 153 | // onError calls all registered error hooks with the error object. 
154 | // The err parameter contains the actual error that occurred, which implements 155 | // the standard error interface and may be a wrapped error or custom error type. 156 | // 157 | // This allows consumer code to use Go's error handling patterns: 158 | // - errors.Is(err, ErrUnsupported) to check for specific sentinel errors 159 | // - errors.As(err, &customErr) to extract custom error types 160 | // 161 | // Common error types include: 162 | // - ErrUnsupported: When a capability is not enabled 163 | // - UnparsableMessageError: When request parsing fails 164 | // - ErrResourceNotFound: When a resource is not found 165 | // - ErrPromptNotFound: When a prompt is not found 166 | // - ErrToolNotFound: When a tool is not found 167 | func (c *Hooks) onError(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) { 168 | if c == nil { 169 | return 170 | } 171 | for _, hook := range c.OnError { 172 | hook(ctx, id, method, message, err) 173 | } 174 | } 175 | 176 | func (c *Hooks) AddOnRegisterSession(hook OnRegisterSessionHookFunc) { 177 | c.OnRegisterSession = append(c.OnRegisterSession, hook) 178 | } 179 | 180 | func (c *Hooks) RegisterSession(ctx context.Context, session ClientSession) { 181 | if c == nil { 182 | return 183 | } 184 | for _, hook := range c.OnRegisterSession { 185 | hook(ctx, session) 186 | } 187 | } 188 | 189 | func (c *Hooks) AddOnUnregisterSession(hook OnUnregisterSessionHookFunc) { 190 | c.OnUnregisterSession = append(c.OnUnregisterSession, hook) 191 | } 192 | 193 | func (c *Hooks) UnregisterSession(ctx context.Context, session ClientSession) { 194 | if c == nil { 195 | return 196 | } 197 | for _, hook := range c.OnUnregisterSession { 198 | hook(ctx, session) 199 | } 200 | } 201 | 202 | {{- range .}} 203 | func (c *Hooks) AddBefore{{.HookName}}(hook OnBefore{{.HookName}}Func) { 204 | c.OnBefore{{.HookName}} = append(c.OnBefore{{.HookName}}, hook) 205 | } 206 | 207 | func (c *Hooks) AddAfter{{.HookName}}(hook 
OnAfter{{.HookName}}Func) { 208 | c.OnAfter{{.HookName}} = append(c.OnAfter{{.HookName}}, hook) 209 | } 210 | 211 | func (c *Hooks) before{{.HookName}}(ctx context.Context, id any, message *mcp.{{.ParamType}}) { 212 | c.beforeAny(ctx, id, mcp.{{.MethodName}}, message) 213 | if c == nil { 214 | return 215 | } 216 | for _, hook := range c.OnBefore{{.HookName}} { 217 | hook(ctx, id, message) 218 | } 219 | } 220 | 221 | func (c *Hooks) after{{.HookName}}(ctx context.Context, id any, message *mcp.{{.ParamType}}, result *mcp.{{.ResultType}}) { 222 | c.onSuccess(ctx, id, mcp.{{.MethodName}}, message, result) 223 | if c == nil { 224 | return 225 | } 226 | for _, hook := range c.OnAfter{{.HookName}} { 227 | hook(ctx, id, message, result) 228 | } 229 | } 230 | {{- end -}} 231 | -------------------------------------------------------------------------------- /server/internal/gen/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | _ "embed" 5 | "fmt" 6 | "io" 7 | "log" 8 | "os" 9 | "os/exec" 10 | "path/filepath" 11 | "strings" 12 | "text/template" 13 | ) 14 | 15 | //go:generate go run . ../.. 
16 | 17 | //go:embed hooks.go.tmpl 18 | var hooksTemplate string 19 | 20 | //go:embed request_handler.go.tmpl 21 | var requestHandlerTemplate string 22 | 23 | func RenderTemplateToFile(templateContent, destPath, fileName string, data any) error { 24 | // Create temp file for initial output 25 | tempFile, err := os.CreateTemp("", "hooks-*.go") 26 | if err != nil { 27 | return err 28 | } 29 | tempFilePath := tempFile.Name() 30 | defer os.Remove(tempFilePath) // Clean up temp file when done 31 | defer tempFile.Close() 32 | 33 | // Parse and execute template to temp file 34 | tmpl, err := template.New(fileName).Funcs(template.FuncMap{ 35 | "toLower": strings.ToLower, 36 | }).Parse(templateContent) 37 | if err != nil { 38 | return err 39 | } 40 | 41 | if err := tmpl.Execute(tempFile, data); err != nil { 42 | return err 43 | } 44 | 45 | // Run goimports on the temp file 46 | cmd := exec.Command("go", "run", "golang.org/x/tools/cmd/goimports@latest", "-w", tempFilePath) 47 | if output, err := cmd.CombinedOutput(); err != nil { 48 | return fmt.Errorf("goimports failed: %v\n%s", err, output) 49 | } 50 | 51 | // Read the processed content 52 | processedContent, err := os.ReadFile(tempFilePath) 53 | if err != nil { 54 | return err 55 | } 56 | 57 | // Write the processed content to the destination 58 | var destWriter io.Writer 59 | if destPath == "-" { 60 | destWriter = os.Stdout 61 | } else { 62 | destFile, err := os.Create(filepath.Join(destPath, fileName)) 63 | if err != nil { 64 | return err 65 | } 66 | defer destFile.Close() 67 | destWriter = destFile 68 | } 69 | 70 | _, err = destWriter.Write(processedContent) 71 | return err 72 | } 73 | 74 | func main() { 75 | if len(os.Args) < 2 { 76 | log.Fatal("usage: gen ") 77 | } 78 | destPath := os.Args[1] 79 | 80 | if err := RenderTemplateToFile(hooksTemplate, destPath, "hooks.go", MCPRequestTypes); err != nil { 81 | log.Fatal(err) 82 | } 83 | 84 | if err := RenderTemplateToFile(requestHandlerTemplate, destPath, 
"request_handler.go", MCPRequestTypes); err != nil { 85 | log.Fatal(err) 86 | } 87 | } 88 | -------------------------------------------------------------------------------- /server/internal/gen/request_handler.go.tmpl: -------------------------------------------------------------------------------- 1 | // Code generated by `go generate`. DO NOT EDIT. 2 | // source: server/internal/gen/request_handler.go.tmpl 3 | package server 4 | 5 | import ( 6 | "context" 7 | "encoding/json" 8 | "errors" 9 | "fmt" 10 | 11 | "github.com/mark3labs/mcp-go/mcp" 12 | ) 13 | 14 | // HandleMessage processes an incoming JSON-RPC message and returns an appropriate response 15 | func (s *MCPServer) HandleMessage( 16 | ctx context.Context, 17 | message json.RawMessage, 18 | ) mcp.JSONRPCMessage { 19 | // Add server to context 20 | ctx = context.WithValue(ctx, serverKey{}, s) 21 | var err *requestError 22 | 23 | var baseMessage struct { 24 | JSONRPC string `json:"jsonrpc"` 25 | Method mcp.MCPMethod `json:"method"` 26 | ID any `json:"id,omitempty"` 27 | Result any `json:"result,omitempty"` 28 | } 29 | 30 | if err := json.Unmarshal(message, &baseMessage); err != nil { 31 | return createErrorResponse( 32 | nil, 33 | mcp.PARSE_ERROR, 34 | "Failed to parse message", 35 | ) 36 | } 37 | 38 | // Check for valid JSONRPC version 39 | if baseMessage.JSONRPC != mcp.JSONRPC_VERSION { 40 | return createErrorResponse( 41 | baseMessage.ID, 42 | mcp.INVALID_REQUEST, 43 | "Invalid JSON-RPC version", 44 | ) 45 | } 46 | 47 | if baseMessage.ID == nil { 48 | var notification mcp.JSONRPCNotification 49 | if err := json.Unmarshal(message, ¬ification); err != nil { 50 | return createErrorResponse( 51 | nil, 52 | mcp.PARSE_ERROR, 53 | "Failed to parse notification", 54 | ) 55 | } 56 | s.handleNotification(ctx, notification) 57 | return nil // Return nil for notifications 58 | } 59 | 60 | if baseMessage.Result != nil { 61 | // this is a response to a request sent by the server (e.g. 
from a ping 62 | // sent due to WithKeepAlive option) 63 | return nil 64 | } 65 | 66 | switch baseMessage.Method { 67 | {{- range .}} 68 | case mcp.{{.MethodName}}: 69 | var request mcp.{{.ParamType}} 70 | var result *mcp.{{.ResultType}} 71 | {{ if .Group }}if s.capabilities.{{.Group}} == nil { 72 | err = &requestError{ 73 | id: baseMessage.ID, 74 | code: mcp.METHOD_NOT_FOUND, 75 | err: fmt.Errorf("{{toLower .GroupName}} %w", ErrUnsupported), 76 | } 77 | } else{{ end }} if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { 78 | err = &requestError{ 79 | id: baseMessage.ID, 80 | code: mcp.INVALID_REQUEST, 81 | err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, 82 | } 83 | } else { 84 | s.hooks.before{{.HookName}}(ctx, baseMessage.ID, &request) 85 | result, err = s.{{.HandlerFunc}}(ctx, baseMessage.ID, request) 86 | } 87 | if err != nil { 88 | s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) 89 | return err.ToJSONRPCError() 90 | } 91 | s.hooks.after{{.HookName}}(ctx, baseMessage.ID, &request, result) 92 | return createResponse(baseMessage.ID, *result) 93 | {{- end }} 94 | default: 95 | return createErrorResponse( 96 | baseMessage.ID, 97 | mcp.METHOD_NOT_FOUND, 98 | fmt.Sprintf("Method %s not found", baseMessage.Method), 99 | ) 100 | } 101 | } 102 | -------------------------------------------------------------------------------- /server/request_handler.go: -------------------------------------------------------------------------------- 1 | // Code generated by `go generate`. DO NOT EDIT. 
2 | // source: server/internal/gen/request_handler.go.tmpl 3 | package server 4 | 5 | import ( 6 | "context" 7 | "encoding/json" 8 | "fmt" 9 | 10 | "github.com/mark3labs/mcp-go/mcp" 11 | ) 12 | 13 | // HandleMessage processes an incoming JSON-RPC message and returns an appropriate response 14 | func (s *MCPServer) HandleMessage( 15 | ctx context.Context, 16 | message json.RawMessage, 17 | ) mcp.JSONRPCMessage { 18 | // Add server to context 19 | ctx = context.WithValue(ctx, serverKey{}, s) 20 | var err *requestError 21 | 22 | var baseMessage struct { 23 | JSONRPC string `json:"jsonrpc"` 24 | Method mcp.MCPMethod `json:"method"` 25 | ID any `json:"id,omitempty"` 26 | Result any `json:"result,omitempty"` 27 | } 28 | 29 | if err := json.Unmarshal(message, &baseMessage); err != nil { 30 | return createErrorResponse( 31 | nil, 32 | mcp.PARSE_ERROR, 33 | "Failed to parse message", 34 | ) 35 | } 36 | 37 | // Check for valid JSONRPC version 38 | if baseMessage.JSONRPC != mcp.JSONRPC_VERSION { 39 | return createErrorResponse( 40 | baseMessage.ID, 41 | mcp.INVALID_REQUEST, 42 | "Invalid JSON-RPC version", 43 | ) 44 | } 45 | 46 | if baseMessage.ID == nil { 47 | var notification mcp.JSONRPCNotification 48 | if err := json.Unmarshal(message, ¬ification); err != nil { 49 | return createErrorResponse( 50 | nil, 51 | mcp.PARSE_ERROR, 52 | "Failed to parse notification", 53 | ) 54 | } 55 | s.handleNotification(ctx, notification) 56 | return nil // Return nil for notifications 57 | } 58 | 59 | if baseMessage.Result != nil { 60 | // this is a response to a request sent by the server (e.g. 
from a ping 61 | // sent due to WithKeepAlive option) 62 | return nil 63 | } 64 | 65 | switch baseMessage.Method { 66 | case mcp.MethodInitialize: 67 | var request mcp.InitializeRequest 68 | var result *mcp.InitializeResult 69 | if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { 70 | err = &requestError{ 71 | id: baseMessage.ID, 72 | code: mcp.INVALID_REQUEST, 73 | err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, 74 | } 75 | } else { 76 | s.hooks.beforeInitialize(ctx, baseMessage.ID, &request) 77 | result, err = s.handleInitialize(ctx, baseMessage.ID, request) 78 | } 79 | if err != nil { 80 | s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) 81 | return err.ToJSONRPCError() 82 | } 83 | s.hooks.afterInitialize(ctx, baseMessage.ID, &request, result) 84 | return createResponse(baseMessage.ID, *result) 85 | case mcp.MethodPing: 86 | var request mcp.PingRequest 87 | var result *mcp.EmptyResult 88 | if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { 89 | err = &requestError{ 90 | id: baseMessage.ID, 91 | code: mcp.INVALID_REQUEST, 92 | err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, 93 | } 94 | } else { 95 | s.hooks.beforePing(ctx, baseMessage.ID, &request) 96 | result, err = s.handlePing(ctx, baseMessage.ID, request) 97 | } 98 | if err != nil { 99 | s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) 100 | return err.ToJSONRPCError() 101 | } 102 | s.hooks.afterPing(ctx, baseMessage.ID, &request, result) 103 | return createResponse(baseMessage.ID, *result) 104 | case mcp.MethodResourcesList: 105 | var request mcp.ListResourcesRequest 106 | var result *mcp.ListResourcesResult 107 | if s.capabilities.resources == nil { 108 | err = &requestError{ 109 | id: baseMessage.ID, 110 | code: mcp.METHOD_NOT_FOUND, 111 | err: fmt.Errorf("resources %w", ErrUnsupported), 112 | } 113 | } else if 
unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { 114 | err = &requestError{ 115 | id: baseMessage.ID, 116 | code: mcp.INVALID_REQUEST, 117 | err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, 118 | } 119 | } else { 120 | s.hooks.beforeListResources(ctx, baseMessage.ID, &request) 121 | result, err = s.handleListResources(ctx, baseMessage.ID, request) 122 | } 123 | if err != nil { 124 | s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) 125 | return err.ToJSONRPCError() 126 | } 127 | s.hooks.afterListResources(ctx, baseMessage.ID, &request, result) 128 | return createResponse(baseMessage.ID, *result) 129 | case mcp.MethodResourcesTemplatesList: 130 | var request mcp.ListResourceTemplatesRequest 131 | var result *mcp.ListResourceTemplatesResult 132 | if s.capabilities.resources == nil { 133 | err = &requestError{ 134 | id: baseMessage.ID, 135 | code: mcp.METHOD_NOT_FOUND, 136 | err: fmt.Errorf("resources %w", ErrUnsupported), 137 | } 138 | } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { 139 | err = &requestError{ 140 | id: baseMessage.ID, 141 | code: mcp.INVALID_REQUEST, 142 | err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, 143 | } 144 | } else { 145 | s.hooks.beforeListResourceTemplates(ctx, baseMessage.ID, &request) 146 | result, err = s.handleListResourceTemplates(ctx, baseMessage.ID, request) 147 | } 148 | if err != nil { 149 | s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) 150 | return err.ToJSONRPCError() 151 | } 152 | s.hooks.afterListResourceTemplates(ctx, baseMessage.ID, &request, result) 153 | return createResponse(baseMessage.ID, *result) 154 | case mcp.MethodResourcesRead: 155 | var request mcp.ReadResourceRequest 156 | var result *mcp.ReadResourceResult 157 | if s.capabilities.resources == nil { 158 | err = &requestError{ 159 | id: baseMessage.ID, 160 | code: 
mcp.METHOD_NOT_FOUND, 161 | err: fmt.Errorf("resources %w", ErrUnsupported), 162 | } 163 | } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { 164 | err = &requestError{ 165 | id: baseMessage.ID, 166 | code: mcp.INVALID_REQUEST, 167 | err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, 168 | } 169 | } else { 170 | s.hooks.beforeReadResource(ctx, baseMessage.ID, &request) 171 | result, err = s.handleReadResource(ctx, baseMessage.ID, request) 172 | } 173 | if err != nil { 174 | s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) 175 | return err.ToJSONRPCError() 176 | } 177 | s.hooks.afterReadResource(ctx, baseMessage.ID, &request, result) 178 | return createResponse(baseMessage.ID, *result) 179 | case mcp.MethodPromptsList: 180 | var request mcp.ListPromptsRequest 181 | var result *mcp.ListPromptsResult 182 | if s.capabilities.prompts == nil { 183 | err = &requestError{ 184 | id: baseMessage.ID, 185 | code: mcp.METHOD_NOT_FOUND, 186 | err: fmt.Errorf("prompts %w", ErrUnsupported), 187 | } 188 | } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { 189 | err = &requestError{ 190 | id: baseMessage.ID, 191 | code: mcp.INVALID_REQUEST, 192 | err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, 193 | } 194 | } else { 195 | s.hooks.beforeListPrompts(ctx, baseMessage.ID, &request) 196 | result, err = s.handleListPrompts(ctx, baseMessage.ID, request) 197 | } 198 | if err != nil { 199 | s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) 200 | return err.ToJSONRPCError() 201 | } 202 | s.hooks.afterListPrompts(ctx, baseMessage.ID, &request, result) 203 | return createResponse(baseMessage.ID, *result) 204 | case mcp.MethodPromptsGet: 205 | var request mcp.GetPromptRequest 206 | var result *mcp.GetPromptResult 207 | if s.capabilities.prompts == nil { 208 | err = &requestError{ 209 | id: 
baseMessage.ID, 210 | code: mcp.METHOD_NOT_FOUND, 211 | err: fmt.Errorf("prompts %w", ErrUnsupported), 212 | } 213 | } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { 214 | err = &requestError{ 215 | id: baseMessage.ID, 216 | code: mcp.INVALID_REQUEST, 217 | err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, 218 | } 219 | } else { 220 | s.hooks.beforeGetPrompt(ctx, baseMessage.ID, &request) 221 | result, err = s.handleGetPrompt(ctx, baseMessage.ID, request) 222 | } 223 | if err != nil { 224 | s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) 225 | return err.ToJSONRPCError() 226 | } 227 | s.hooks.afterGetPrompt(ctx, baseMessage.ID, &request, result) 228 | return createResponse(baseMessage.ID, *result) 229 | case mcp.MethodToolsList: 230 | var request mcp.ListToolsRequest 231 | var result *mcp.ListToolsResult 232 | if s.capabilities.tools == nil { 233 | err = &requestError{ 234 | id: baseMessage.ID, 235 | code: mcp.METHOD_NOT_FOUND, 236 | err: fmt.Errorf("tools %w", ErrUnsupported), 237 | } 238 | } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { 239 | err = &requestError{ 240 | id: baseMessage.ID, 241 | code: mcp.INVALID_REQUEST, 242 | err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, 243 | } 244 | } else { 245 | s.hooks.beforeListTools(ctx, baseMessage.ID, &request) 246 | result, err = s.handleListTools(ctx, baseMessage.ID, request) 247 | } 248 | if err != nil { 249 | s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) 250 | return err.ToJSONRPCError() 251 | } 252 | s.hooks.afterListTools(ctx, baseMessage.ID, &request, result) 253 | return createResponse(baseMessage.ID, *result) 254 | case mcp.MethodToolsCall: 255 | var request mcp.CallToolRequest 256 | var result *mcp.CallToolResult 257 | if s.capabilities.tools == nil { 258 | err = &requestError{ 259 | id: 
baseMessage.ID, 260 | code: mcp.METHOD_NOT_FOUND, 261 | err: fmt.Errorf("tools %w", ErrUnsupported), 262 | } 263 | } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { 264 | err = &requestError{ 265 | id: baseMessage.ID, 266 | code: mcp.INVALID_REQUEST, 267 | err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, 268 | } 269 | } else { 270 | s.hooks.beforeCallTool(ctx, baseMessage.ID, &request) 271 | result, err = s.handleToolCall(ctx, baseMessage.ID, request) 272 | } 273 | if err != nil { 274 | s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) 275 | return err.ToJSONRPCError() 276 | } 277 | s.hooks.afterCallTool(ctx, baseMessage.ID, &request, result) 278 | return createResponse(baseMessage.ID, *result) 279 | default: 280 | return createErrorResponse( 281 | baseMessage.ID, 282 | mcp.METHOD_NOT_FOUND, 283 | fmt.Sprintf("Method %s not found", baseMessage.Method), 284 | ) 285 | } 286 | } 287 | -------------------------------------------------------------------------------- /server/resource_test.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | "time" 7 | 8 | "github.com/mark3labs/mcp-go/mcp" 9 | "github.com/stretchr/testify/assert" 10 | "github.com/stretchr/testify/require" 11 | ) 12 | 13 | func TestMCPServer_RemoveResource(t *testing.T) { 14 | tests := []struct { 15 | name string 16 | action func(*testing.T, *MCPServer, chan mcp.JSONRPCNotification) 17 | expectedNotifications int 18 | validate func(*testing.T, []mcp.JSONRPCNotification, mcp.JSONRPCMessage) 19 | }{ 20 | { 21 | name: "RemoveResource removes the resource from the server", 22 | action: func(t *testing.T, server *MCPServer, notificationChannel chan mcp.JSONRPCNotification) { 23 | // Add a test resource 24 | server.AddResource( 25 | mcp.NewResource( 26 | "test://resource1", 27 | "Resource 1", 28 | 
mcp.WithResourceDescription("Test resource 1"), 29 | mcp.WithMIMEType("text/plain"), 30 | ), 31 | func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { 32 | return []mcp.ResourceContents{ 33 | mcp.TextResourceContents{ 34 | URI: "test://resource1", 35 | MIMEType: "text/plain", 36 | Text: "test content 1", 37 | }, 38 | }, nil 39 | }, 40 | ) 41 | 42 | // Add a second resource 43 | server.AddResource( 44 | mcp.NewResource( 45 | "test://resource2", 46 | "Resource 2", 47 | mcp.WithResourceDescription("Test resource 2"), 48 | mcp.WithMIMEType("text/plain"), 49 | ), 50 | func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { 51 | return []mcp.ResourceContents{ 52 | mcp.TextResourceContents{ 53 | URI: "test://resource2", 54 | MIMEType: "text/plain", 55 | Text: "test content 2", 56 | }, 57 | }, nil 58 | }, 59 | ) 60 | 61 | // First, verify we have two resources 62 | response := server.HandleMessage(context.Background(), []byte(`{ 63 | "jsonrpc": "2.0", 64 | "id": 1, 65 | "method": "resources/list" 66 | }`)) 67 | resp, ok := response.(mcp.JSONRPCResponse) 68 | assert.True(t, ok) 69 | result, ok := resp.Result.(mcp.ListResourcesResult) 70 | assert.True(t, ok) 71 | assert.Len(t, result.Resources, 2) 72 | 73 | // Now register session to receive notifications 74 | err := server.RegisterSession(context.TODO(), &fakeSession{ 75 | sessionID: "test", 76 | notificationChannel: notificationChannel, 77 | initialized: true, 78 | }) 79 | require.NoError(t, err) 80 | 81 | // Now remove one resource 82 | server.RemoveResource("test://resource1") 83 | }, 84 | expectedNotifications: 1, 85 | validate: func(t *testing.T, notifications []mcp.JSONRPCNotification, resourcesList mcp.JSONRPCMessage) { 86 | // Check that we received a list_changed notification 87 | assert.Equal(t, "resources/list_changed", notifications[0].Method) 88 | 89 | // Verify we now have only one resource 90 | resp, ok := 
resourcesList.(mcp.JSONRPCResponse) 91 | assert.True(t, ok, "Expected JSONRPCResponse, got %T", resourcesList) 92 | 93 | result, ok := resp.Result.(mcp.ListResourcesResult) 94 | assert.True(t, ok, "Expected ListResourcesResult, got %T", resp.Result) 95 | 96 | assert.Len(t, result.Resources, 1) 97 | assert.Equal(t, "Resource 2", result.Resources[0].Name) 98 | }, 99 | }, 100 | { 101 | name: "RemoveResource with non-existent resource does nothing", 102 | action: func(t *testing.T, server *MCPServer, notificationChannel chan mcp.JSONRPCNotification) { 103 | // Add a test resource 104 | server.AddResource( 105 | mcp.NewResource( 106 | "test://resource1", 107 | "Resource 1", 108 | mcp.WithResourceDescription("Test resource 1"), 109 | mcp.WithMIMEType("text/plain"), 110 | ), 111 | func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { 112 | return []mcp.ResourceContents{ 113 | mcp.TextResourceContents{ 114 | URI: "test://resource1", 115 | MIMEType: "text/plain", 116 | Text: "test content 1", 117 | }, 118 | }, nil 119 | }, 120 | ) 121 | 122 | // Register session to receive notifications 123 | err := server.RegisterSession(context.TODO(), &fakeSession{ 124 | sessionID: "test", 125 | notificationChannel: notificationChannel, 126 | initialized: true, 127 | }) 128 | require.NoError(t, err) 129 | 130 | // Remove a non-existent resource 131 | server.RemoveResource("test://nonexistent") 132 | }, 133 | expectedNotifications: 1, // Still sends a notification 134 | validate: func(t *testing.T, notifications []mcp.JSONRPCNotification, resourcesList mcp.JSONRPCMessage) { 135 | // Check that we received a list_changed notification 136 | assert.Equal(t, "resources/list_changed", notifications[0].Method) 137 | 138 | // The original resource should still be there 139 | resp, ok := resourcesList.(mcp.JSONRPCResponse) 140 | assert.True(t, ok) 141 | 142 | result, ok := resp.Result.(mcp.ListResourcesResult) 143 | assert.True(t, ok) 144 | 145 | 
assert.Len(t, result.Resources, 1) 146 | assert.Equal(t, "Resource 1", result.Resources[0].Name) 147 | }, 148 | }, 149 | { 150 | name: "RemoveResource with no listChanged capability doesn't send notification", 151 | action: func(t *testing.T, server *MCPServer, notificationChannel chan mcp.JSONRPCNotification) { 152 | // Create a new server without listChanged capability 153 | noListChangedServer := NewMCPServer( 154 | "test-server", 155 | "1.0.0", 156 | WithResourceCapabilities(true, false), // Subscribe but not listChanged 157 | ) 158 | 159 | // Add a resource 160 | noListChangedServer.AddResource( 161 | mcp.NewResource( 162 | "test://resource1", 163 | "Resource 1", 164 | mcp.WithResourceDescription("Test resource 1"), 165 | mcp.WithMIMEType("text/plain"), 166 | ), 167 | func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { 168 | return []mcp.ResourceContents{ 169 | mcp.TextResourceContents{ 170 | URI: "test://resource1", 171 | MIMEType: "text/plain", 172 | Text: "test content 1", 173 | }, 174 | }, nil 175 | }, 176 | ) 177 | 178 | // Register session to receive notifications 179 | err := noListChangedServer.RegisterSession(context.TODO(), &fakeSession{ 180 | sessionID: "test", 181 | notificationChannel: notificationChannel, 182 | initialized: true, 183 | }) 184 | require.NoError(t, err) 185 | 186 | // Remove the resource 187 | noListChangedServer.RemoveResource("test://resource1") 188 | 189 | // The test can now proceed without waiting for notifications 190 | // since we don't expect any 191 | }, 192 | expectedNotifications: 0, // No notifications expected 193 | validate: func(t *testing.T, notifications []mcp.JSONRPCNotification, resourcesList mcp.JSONRPCMessage) { 194 | // Nothing to do here, we're just verifying that no notifications were sent 195 | assert.Empty(t, notifications) 196 | }, 197 | }, 198 | } 199 | 200 | for _, tt := range tests { 201 | t.Run(tt.name, func(t *testing.T) { 202 | ctx := context.Background() 203 
| server := NewMCPServer( 204 | "test-server", 205 | "1.0.0", 206 | WithResourceCapabilities(true, true), 207 | ) 208 | 209 | // Initialize the server 210 | _ = server.HandleMessage(ctx, []byte(`{ 211 | "jsonrpc": "2.0", 212 | "id": 1, 213 | "method": "initialize" 214 | }`)) 215 | 216 | notificationChannel := make(chan mcp.JSONRPCNotification, 100) 217 | notifications := make([]mcp.JSONRPCNotification, 0) 218 | 219 | tt.action(t, server, notificationChannel) 220 | 221 | // Collect notifications with a timeout 222 | if tt.expectedNotifications > 0 { 223 | for i := 0; i < tt.expectedNotifications; i++ { 224 | select { 225 | case notification := <-notificationChannel: 226 | notifications = append(notifications, notification) 227 | case <-time.After(1 * time.Second): 228 | t.Fatalf("Expected %d notifications but only received %d", tt.expectedNotifications, len(notifications)) 229 | } 230 | } 231 | } else { 232 | // If no notifications expected, wait a brief period to ensure none are sent 233 | select { 234 | case notification := <-notificationChannel: 235 | notifications = append(notifications, notification) 236 | case <-time.After(100 * time.Millisecond): 237 | // This is the expected path - no notifications 238 | } 239 | } 240 | 241 | // Get final resources list 242 | listMessage := `{ 243 | "jsonrpc": "2.0", 244 | "id": 1, 245 | "method": "resources/list" 246 | }` 247 | resourcesList := server.HandleMessage(ctx, []byte(listMessage)) 248 | 249 | // Validate the results 250 | tt.validate(t, notifications, resourcesList) 251 | }) 252 | } 253 | } 254 | -------------------------------------------------------------------------------- /server/server_race_test.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "sync" 7 | "testing" 8 | "time" 9 | 10 | "github.com/mark3labs/mcp-go/mcp" 11 | "github.com/stretchr/testify/assert" 12 | "github.com/stretchr/testify/require" 13 | ) 14 
| 15 | // TestRaceConditions attempts to trigger race conditions by performing 16 | // concurrent operations on different resources of the MCPServer. 17 | func TestRaceConditions(t *testing.T) { 18 | // Create a server with all capabilities 19 | srv := NewMCPServer("test-server", "1.0.0", 20 | WithResourceCapabilities(true, true), 21 | WithPromptCapabilities(true), 22 | WithToolCapabilities(true), 23 | WithLogging(), 24 | WithRecovery(), 25 | ) 26 | 27 | // Create a context 28 | ctx := context.Background() 29 | 30 | // Create a sync.WaitGroup to coordinate test goroutines 31 | var wg sync.WaitGroup 32 | 33 | // Define test duration 34 | testDuration := 300 * time.Millisecond 35 | 36 | // Start goroutines to perform concurrent operations 37 | runConcurrentOperation(&wg, testDuration, "add-prompts", func() { 38 | name := fmt.Sprintf("prompt-%d", time.Now().UnixNano()) 39 | srv.AddPrompt(mcp.Prompt{ 40 | Name: name, 41 | Description: "Test prompt", 42 | }, func(ctx context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { 43 | return &mcp.GetPromptResult{}, nil 44 | }) 45 | }) 46 | 47 | runConcurrentOperation(&wg, testDuration, "add-tools", func() { 48 | name := fmt.Sprintf("tool-%d", time.Now().UnixNano()) 49 | srv.AddTool(mcp.Tool{ 50 | Name: name, 51 | Description: "Test tool", 52 | }, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { 53 | return &mcp.CallToolResult{}, nil 54 | }) 55 | }) 56 | 57 | runConcurrentOperation(&wg, testDuration, "delete-tools", func() { 58 | name := fmt.Sprintf("delete-tool-%d", time.Now().UnixNano()) 59 | // Add and immediately delete 60 | srv.AddTool(mcp.Tool{ 61 | Name: name, 62 | Description: "Temporary tool", 63 | }, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { 64 | return &mcp.CallToolResult{}, nil 65 | }) 66 | srv.DeleteTools(name) 67 | }) 68 | 69 | runConcurrentOperation(&wg, testDuration, "add-middleware", func() { 70 | middleware := func(next 
ToolHandlerFunc) ToolHandlerFunc { 71 | return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { 72 | return next(ctx, req) 73 | } 74 | } 75 | WithToolHandlerMiddleware(middleware)(srv) 76 | }) 77 | 78 | runConcurrentOperation(&wg, testDuration, "list-tools", func() { 79 | result, reqErr := srv.handleListTools(ctx, "123", mcp.ListToolsRequest{}) 80 | require.Nil(t, reqErr, "List tools operation should not return an error") 81 | require.NotNil(t, result, "List tools result should not be nil") 82 | }) 83 | 84 | runConcurrentOperation(&wg, testDuration, "list-prompts", func() { 85 | result, reqErr := srv.handleListPrompts(ctx, "123", mcp.ListPromptsRequest{}) 86 | require.Nil(t, reqErr, "List prompts operation should not return an error") 87 | require.NotNil(t, result, "List prompts result should not be nil") 88 | }) 89 | 90 | // Add a persistent tool for testing tool calls 91 | srv.AddTool(mcp.Tool{ 92 | Name: "persistent-tool", 93 | Description: "Test tool that always exists", 94 | }, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { 95 | return &mcp.CallToolResult{}, nil 96 | }) 97 | 98 | runConcurrentOperation(&wg, testDuration, "call-tools", func() { 99 | req := mcp.CallToolRequest{} 100 | req.Params.Name = "persistent-tool" 101 | req.Params.Arguments = map[string]interface{}{"param": "test"} 102 | result, reqErr := srv.handleToolCall(ctx, "123", req) 103 | require.Nil(t, reqErr, "Tool call operation should not return an error") 104 | require.NotNil(t, result, "Tool call result should not be nil") 105 | }) 106 | 107 | runConcurrentOperation(&wg, testDuration, "add-resources", func() { 108 | uri := fmt.Sprintf("resource-%d", time.Now().UnixNano()) 109 | srv.AddResource(mcp.Resource{ 110 | URI: uri, 111 | Name: uri, 112 | Description: "Test resource", 113 | }, func(ctx context.Context, req mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { 114 | return []mcp.ResourceContents{ 115 | 
mcp.TextResourceContents{ 116 | URI: uri, 117 | Text: "Test content", 118 | }, 119 | }, nil 120 | }) 121 | }) 122 | 123 | // Wait for all operations to complete 124 | wg.Wait() 125 | t.Log("No race conditions detected") 126 | } 127 | 128 | // Helper function to run an operation concurrently for a specified duration 129 | func runConcurrentOperation(wg *sync.WaitGroup, duration time.Duration, name string, operation func()) { 130 | wg.Add(1) 131 | go func() { 132 | defer wg.Done() 133 | 134 | done := time.After(duration) 135 | for { 136 | select { 137 | case <-done: 138 | return 139 | default: 140 | operation() 141 | } 142 | } 143 | }() 144 | } 145 | 146 | // TestConcurrentPromptAdd specifically tests for the deadlock scenario where adding a prompt 147 | // from a goroutine can cause a deadlock 148 | func TestConcurrentPromptAdd(t *testing.T) { 149 | srv := NewMCPServer("test-server", "1.0.0", WithPromptCapabilities(true)) 150 | ctx := context.Background() 151 | 152 | // Add a prompt with a handler that adds another prompt in a goroutine 153 | srv.AddPrompt(mcp.Prompt{ 154 | Name: "initial-prompt", 155 | Description: "Initial prompt", 156 | }, func(ctx context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { 157 | go func() { 158 | srv.AddPrompt(mcp.Prompt{ 159 | Name: fmt.Sprintf("new-prompt-%d", time.Now().UnixNano()), 160 | Description: "Added from handler", 161 | }, func(ctx context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { 162 | return &mcp.GetPromptResult{}, nil 163 | }) 164 | }() 165 | return &mcp.GetPromptResult{}, nil 166 | }) 167 | 168 | // Create request and channel to track completion 169 | req := mcp.GetPromptRequest{} 170 | req.Params.Name = "initial-prompt" 171 | done := make(chan struct{}) 172 | 173 | // Try to get the prompt - this would deadlock with a single mutex 174 | go func() { 175 | result, reqErr := srv.handleGetPrompt(ctx, "123", req) 176 | require.Nil(t, reqErr, "Get prompt operation should not 
return an error") 177 | require.NotNil(t, result, "Get prompt result should not be nil") 178 | close(done) 179 | }() 180 | 181 | // Assert the operation completes without deadlock 182 | assert.Eventually(t, func() bool { 183 | select { 184 | case <-done: 185 | return true 186 | default: 187 | return false 188 | } 189 | }, 1*time.Second, 10*time.Millisecond, "Deadlock detected: operation did not complete in time") 190 | } 191 | -------------------------------------------------------------------------------- /server/stdio.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "bufio" 5 | "context" 6 | "encoding/json" 7 | "fmt" 8 | "io" 9 | "log" 10 | "os" 11 | "os/signal" 12 | "sync/atomic" 13 | "syscall" 14 | 15 | "github.com/mark3labs/mcp-go/mcp" 16 | ) 17 | 18 | // StdioContextFunc is a function that takes an existing context and returns 19 | // a potentially modified context. 20 | // This can be used to inject context values from environment variables, 21 | // for example. 22 | type StdioContextFunc func(ctx context.Context) context.Context 23 | 24 | // StdioServer wraps a MCPServer and handles stdio communication. 25 | // It provides a simple way to create command-line MCP servers that 26 | // communicate via standard input/output streams using JSON-RPC messages. 27 | type StdioServer struct { 28 | server *MCPServer 29 | errLogger *log.Logger 30 | contextFunc StdioContextFunc 31 | } 32 | 33 | // StdioOption defines a function type for configuring StdioServer 34 | type StdioOption func(*StdioServer) 35 | 36 | // WithErrorLogger sets the error logger for the server 37 | func WithErrorLogger(logger *log.Logger) StdioOption { 38 | return func(s *StdioServer) { 39 | s.errLogger = logger 40 | } 41 | } 42 | 43 | // WithContextFunc sets a function that will be called to customise the context 44 | // to the server. 
Note that the stdio server uses the same context for all requests, 45 | // so this function will only be called once per server instance. 46 | func WithStdioContextFunc(fn StdioContextFunc) StdioOption { 47 | return func(s *StdioServer) { 48 | s.contextFunc = fn 49 | } 50 | } 51 | 52 | // stdioSession is a static client session, since stdio has only one client. 53 | type stdioSession struct { 54 | notifications chan mcp.JSONRPCNotification 55 | initialized atomic.Bool 56 | } 57 | 58 | func (s *stdioSession) SessionID() string { 59 | return "stdio" 60 | } 61 | 62 | func (s *stdioSession) NotificationChannel() chan<- mcp.JSONRPCNotification { 63 | return s.notifications 64 | } 65 | 66 | func (s *stdioSession) Initialize() { 67 | s.initialized.Store(true) 68 | } 69 | 70 | func (s *stdioSession) Initialized() bool { 71 | return s.initialized.Load() 72 | } 73 | 74 | var _ ClientSession = (*stdioSession)(nil) 75 | // stdioSessionInstance is a single package-level session, so every StdioServer in the process shares it (and its notification channel, buffered to 100 entries). 76 | var stdioSessionInstance = stdioSession{ 77 | notifications: make(chan mcp.JSONRPCNotification, 100), 78 | } 79 | 80 | // NewStdioServer creates a new stdio server wrapper around an MCPServer. 81 | // It initializes the server with a default error logger that writes to os.Stderr with standard timestamp flags. 82 | func NewStdioServer(server *MCPServer) *StdioServer { 83 | return &StdioServer{ 84 | server: server, 85 | errLogger: log.New( 86 | os.Stderr, 87 | "", 88 | log.LstdFlags, 89 | ), // Default to logging to stderr 90 | } 91 | } 92 | 93 | // SetErrorLogger configures where error messages from the StdioServer are logged. 94 | // The provided logger will receive all error messages generated during server operation. 95 | func (s *StdioServer) SetErrorLogger(logger *log.Logger) { 96 | s.errLogger = logger 97 | } 98 | 99 | // SetContextFunc sets a function that will be called to customise the context 100 | // to the server. Note that the stdio server uses the same context for all requests, 101 | // so this function will only be called once per server instance.
102 | func (s *StdioServer) SetContextFunc(fn StdioContextFunc) { 103 | s.contextFunc = fn 104 | } 105 | 106 | // handleNotifications continuously processes notifications from the session's notification channel 107 | // and writes them to the provided output. It runs until the context is cancelled. 108 | // Any errors encountered while writing notifications are logged but do not stop the handler. NOTE(review): these writes are not serialized with the responses processMessage writes to the same output, so on writers that do not make each Write atomic the two streams may interleave. 109 | func (s *StdioServer) handleNotifications(ctx context.Context, stdout io.Writer) { 110 | for { 111 | select { 112 | case notification := <-stdioSessionInstance.notifications: 113 | if err := s.writeResponse(notification, stdout); err != nil { 114 | s.errLogger.Printf("Error writing notification: %v", err) 115 | } 116 | case <-ctx.Done(): 117 | return 118 | } 119 | } 120 | } 121 | 122 | // processInputStream continuously reads and processes messages from the input stream. 123 | // It handles EOF gracefully as a normal termination condition. 124 | // The function returns when either: 125 | // - The context is cancelled (returns context.Err()) 126 | // - EOF is encountered (returns nil) 127 | // - An error occurs while reading or processing messages (returns the error) 128 | func (s *StdioServer) processInputStream(ctx context.Context, reader *bufio.Reader, stdout io.Writer) error { 129 | for { 130 | if err := ctx.Err(); err != nil { 131 | return err 132 | } 133 | 134 | line, err := s.readNextLine(ctx, reader) 135 | if err != nil { 136 | if err == io.EOF { 137 | return nil 138 | } 139 | s.errLogger.Printf("Error reading input: %v", err) 140 | return err 141 | } 142 | 143 | if err := s.processMessage(ctx, line, stdout); err != nil { 144 | if err == io.EOF { 145 | return nil 146 | } 147 | s.errLogger.Printf("Error handling message: %v", err) 148 | return err 149 | } 150 | } 151 | } 152 | 153 | // readNextLine reads a single line from the input reader in a context-aware manner. 154 | // It uses channels to make the read operation cancellable via context.
155 | // Returns the read line and any error encountered. If the context is cancelled, 156 | // returns an empty string and the context's error. EOF is returned when the input 157 | // stream is closed; a final line without a trailing newline is returned first, with a nil error. If ctx is cancelled mid-read, the background ReadString call is not interrupted: it keeps blocking until input arrives or the stream closes, and whatever it reads is then discarded. 158 | func (s *StdioServer) readNextLine(ctx context.Context, reader *bufio.Reader) (string, error) { 159 | readChan := make(chan string, 1) 160 | errChan := make(chan error, 1) 161 | done := make(chan struct{}) 162 | defer close(done) 163 | 164 | go func() { 165 | select { 166 | case <-done: 167 | return 168 | default: 169 | line, err := reader.ReadString('\n') 170 | if err != nil && (err != io.EOF || line == "") { // deliver a final unterminated line; the EOF is reported by the next read 171 | select { 172 | case errChan <- err: 173 | case <-done: 174 | 175 | } 176 | return 177 | } 178 | select { 179 | case readChan <- line: 180 | case <-done: 181 | } 182 | } 183 | }() 184 | 185 | select { 186 | case <-ctx.Done(): 187 | return "", ctx.Err() 188 | case err := <-errChan: 189 | return "", err 190 | case line := <-readChan: 191 | return line, nil 192 | } 193 | } 194 | 195 | // Listen starts listening for JSON-RPC messages on the provided input and writes responses to the provided output. 196 | // It runs until the context is cancelled, the input reaches EOF (returning nil), or an error occurs. 197 | // Returns an error if there are issues with reading input or writing output. The notification handler goroutine started here is stopped before Listen returns. 198 | func (s *StdioServer) Listen( 199 | ctx context.Context, 200 | stdin io.Reader, 201 | stdout io.Writer, 202 | ) error { 203 | // Set a static client context since stdio only has one client 204 | if err := s.server.RegisterSession(ctx, &stdioSessionInstance); err != nil { 205 | return fmt.Errorf("register session: %w", err) 206 | } 207 | defer s.server.UnregisterSession(ctx, stdioSessionInstance.SessionID()) 208 | ctx = s.server.WithContext(ctx, &stdioSessionInstance) 209 | 210 | // Add in any custom context.
211 | if s.contextFunc != nil { 212 | ctx = s.contextFunc(ctx) 213 | } 214 | 215 | reader := bufio.NewReader(stdin) 216 | ctx, cancel := context.WithCancel(ctx) 217 | defer cancel() // stops the notification handler below when Listen returns, including on EOF when the caller never cancels ctx 218 | go s.handleNotifications(ctx, stdout) 219 | return s.processInputStream(ctx, reader, stdout) 220 | } 221 | 222 | // processMessage handles a single JSON-RPC message and writes the response. 223 | // It parses the message, processes it through the wrapped MCPServer, and writes any response. 224 | // Returns an error if there are issues with message processing or response writing. 225 | func (s *StdioServer) processMessage( 226 | ctx context.Context, 227 | line string, 228 | writer io.Writer, 229 | ) error { 230 | // Parse the message as raw JSON 231 | var rawMessage json.RawMessage 232 | if err := json.Unmarshal([]byte(line), &rawMessage); err != nil { 233 | response := createErrorResponse(nil, mcp.PARSE_ERROR, "Parse error") 234 | return s.writeResponse(response, writer) 235 | } 236 | 237 | // Handle the message using the wrapped server 238 | response := s.server.HandleMessage(ctx, rawMessage) 239 | 240 | // Only write response if there is one (not for notifications) 241 | if response != nil { 242 | if err := s.writeResponse(response, writer); err != nil { 243 | return fmt.Errorf("failed to write response: %w", err) 244 | } 245 | } 246 | 247 | return nil 248 | } 249 | 250 | // writeResponse marshals and writes a JSON-RPC response message followed by a newline. 251 | // Returns an error if marshaling or writing fails.
252 | func (s *StdioServer) writeResponse( 253 | response mcp.JSONRPCMessage, 254 | writer io.Writer, 255 | ) error { 256 | responseBytes, err := json.Marshal(response) 257 | if err != nil { 258 | return err 259 | } 260 | 261 | // Write response followed by newline 262 | if _, err := fmt.Fprintf(writer, "%s\n", responseBytes); err != nil { 263 | return err 264 | } 265 | 266 | return nil 267 | } 268 | 269 | // ServeStdio is a convenience function that creates and starts a StdioServer with os.Stdin and os.Stdout. 270 | // It sets up signal handling for graceful shutdown on SIGTERM and SIGINT. 271 | // Returns an error if the server encounters any issues during operation. 272 | func ServeStdio(server *MCPServer, opts ...StdioOption) error { 273 | s := NewStdioServer(server) 274 | s.SetErrorLogger(log.New(os.Stderr, "", log.LstdFlags)) 275 | 276 | for _, opt := range opts { 277 | opt(s) 278 | } 279 | 280 | // Cancel the context on SIGTERM/SIGINT so Listen shuts down gracefully. 281 | // stop unregisters the signal handling and releases its resources once Listen returns. 282 | ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT) 283 | defer stop() 284 | 285 | return s.Listen(ctx, os.Stdin, os.Stdout) 286 | } 287 | -------------------------------------------------------------------------------- /server/stdio_test.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "bufio" 5 | "context" 6 | "encoding/json" 7 | "io" 8 | "log" 9 | "os" 10 | "testing" 11 | 12 | "github.com/mark3labs/mcp-go/mcp" 13 | ) 14 | 15 | func TestStdioServer(t *testing.T) { 16 | t.Run("Can instantiate", func(t *testing.T) { 17 | mcpServer := NewMCPServer("test", "1.0.0") 18 | stdioServer := NewStdioServer(mcpServer) 19 | 20 | if stdioServer.server == nil { 21 | t.Error("MCPServer should not be nil") 22 | } 23 | if stdioServer.errLogger == nil { 24 | t.Error("errLogger should not be nil") 25 |
} 26 | }) 27 | 28 | t.Run("Can send and receive messages", func(t *testing.T) { 29 | // Create pipes for stdin and stdout 30 | stdinReader, stdinWriter := io.Pipe() 31 | stdoutReader, stdoutWriter := io.Pipe() 32 | 33 | // Create server 34 | mcpServer := NewMCPServer("test", "1.0.0", 35 | WithResourceCapabilities(true, true), 36 | ) 37 | stdioServer := NewStdioServer(mcpServer) 38 | stdioServer.SetErrorLogger(log.New(io.Discard, "", 0)) 39 | 40 | // Create context with cancel 41 | ctx, cancel := context.WithCancel(context.Background()) 42 | defer cancel() 43 | 44 | // Create error channel to catch server errors 45 | serverErrCh := make(chan error, 1) 46 | 47 | // Start server in goroutine 48 | go func() { 49 | err := stdioServer.Listen(ctx, stdinReader, stdoutWriter) 50 | if err != nil && err != io.EOF && err != context.Canceled { 51 | serverErrCh <- err 52 | } 53 | close(serverErrCh) 54 | }() 55 | 56 | // Create test message 57 | initRequest := map[string]interface{}{ 58 | "jsonrpc": "2.0", 59 | "id": 1, 60 | "method": "initialize", 61 | "params": map[string]interface{}{ 62 | "protocolVersion": "2024-11-05", 63 | "clientInfo": map[string]interface{}{ 64 | "name": "test-client", 65 | "version": "1.0.0", 66 | }, 67 | }, 68 | } 69 | 70 | // Send request 71 | requestBytes, err := json.Marshal(initRequest) 72 | if err != nil { 73 | t.Fatal(err) 74 | } 75 | _, err = stdinWriter.Write(append(requestBytes, '\n')) 76 | if err != nil { 77 | t.Fatal(err) 78 | } 79 | 80 | // Read response 81 | scanner := bufio.NewScanner(stdoutReader) 82 | if !scanner.Scan() { 83 | t.Fatal("failed to read response") 84 | } 85 | responseBytes := scanner.Bytes() 86 | 87 | var response map[string]interface{} 88 | if err := json.Unmarshal(responseBytes, &response); err != nil { 89 | t.Fatalf("failed to unmarshal response: %v", err) 90 | } 91 | 92 | // Verify response structure 93 | if response["jsonrpc"] != "2.0" { 94 | t.Errorf("expected jsonrpc version 2.0, got %v", response["jsonrpc"]) 95 | }
96 | if response["id"].(float64) != 1 { 97 | t.Errorf("expected id 1, got %v", response["id"]) 98 | } 99 | if response["error"] != nil { 100 | t.Errorf("unexpected error in response: %v", response["error"]) 101 | } 102 | if response["result"] == nil { 103 | t.Error("expected result in response") 104 | } 105 | 106 | // Clean up 107 | cancel() 108 | stdinWriter.Close() 109 | stdoutWriter.Close() 110 | 111 | // Check for server errors 112 | if err := <-serverErrCh; err != nil { 113 | t.Errorf("unexpected server error: %v", err) 114 | } 115 | }) 116 | 117 | t.Run("Can use a custom context function", func(t *testing.T) { 118 | // Use a custom context key to store a test value. 119 | type testContextKey struct{} 120 | testValFromContext := func(ctx context.Context) string { 121 | val := ctx.Value(testContextKey{}) 122 | if val == nil { 123 | return "" 124 | } 125 | return val.(string) 126 | } 127 | // Create a context function that sets a test value from the environment. 128 | // In real life this could be used to send configuration in a similar way, 129 | // or from a config file. 130 | const testEnvVar = "TEST_ENV_VAR" 131 | setTestValFromEnv := func(ctx context.Context) context.Context { 132 | return context.WithValue(ctx, testContextKey{}, os.Getenv(testEnvVar)) 133 | } 134 | t.Setenv(testEnvVar, "test_value") 135 | 136 | // Create pipes for stdin and stdout 137 | stdinReader, stdinWriter := io.Pipe() 138 | stdoutReader, stdoutWriter := io.Pipe() 139 | 140 | // Create server 141 | mcpServer := NewMCPServer("test", "1.0.0") 142 | // Add a tool which uses the context function. 143 | mcpServer.AddTool(mcp.NewTool("test_tool"), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { 144 | // Note this is agnostic to the transport type i.e. doesn't know about request headers.
145 | testVal := testValFromContext(ctx) 146 | return mcp.NewToolResultText(testVal), nil 147 | }) 148 | stdioServer := NewStdioServer(mcpServer) 149 | stdioServer.SetErrorLogger(log.New(io.Discard, "", 0)) 150 | stdioServer.SetContextFunc(setTestValFromEnv) 151 | 152 | // Create context with cancel 153 | ctx, cancel := context.WithCancel(context.Background()) 154 | defer cancel() 155 | 156 | // Create error channel to catch server errors 157 | serverErrCh := make(chan error, 1) 158 | 159 | // Start server in goroutine 160 | go func() { 161 | err := stdioServer.Listen(ctx, stdinReader, stdoutWriter) 162 | if err != nil && err != io.EOF && err != context.Canceled { 163 | serverErrCh <- err 164 | } 165 | close(serverErrCh) 166 | }() 167 | 168 | // Create test message 169 | initRequest := map[string]interface{}{ 170 | "jsonrpc": "2.0", 171 | "id": 1, 172 | "method": "initialize", 173 | "params": map[string]interface{}{ 174 | "protocolVersion": "2024-11-05", 175 | "clientInfo": map[string]interface{}{ 176 | "name": "test-client", 177 | "version": "1.0.0", 178 | }, 179 | }, 180 | } 181 | 182 | // Send request 183 | requestBytes, err := json.Marshal(initRequest) 184 | if err != nil { 185 | t.Fatal(err) 186 | } 187 | _, err = stdinWriter.Write(append(requestBytes, '\n')) 188 | if err != nil { 189 | t.Fatal(err) 190 | } 191 | 192 | // Read response 193 | scanner := bufio.NewScanner(stdoutReader) 194 | if !scanner.Scan() { 195 | t.Fatal("failed to read response") 196 | } 197 | responseBytes := scanner.Bytes() 198 | 199 | var response map[string]interface{} 200 | if err := json.Unmarshal(responseBytes, &response); err != nil { 201 | t.Fatalf("failed to unmarshal response: %v", err) 202 | } 203 | 204 | // Verify response structure 205 | if response["jsonrpc"] != "2.0" { 206 | t.Errorf("expected jsonrpc version 2.0, got %v", response["jsonrpc"]) 207 | } 208 | if response["id"].(float64) != 1 { 209 | t.Errorf("expected id 1, got %v", response["id"]) 210 | } 211 | if
response["error"] != nil { 212 | t.Errorf("unexpected error in response: %v", response["error"]) 213 | } 214 | if response["result"] == nil { 215 | t.Error("expected result in response") 216 | } 217 | 218 | // Call the tool. 219 | toolRequest := map[string]interface{}{ 220 | "jsonrpc": "2.0", 221 | "id": 2, 222 | "method": "tools/call", 223 | "params": map[string]interface{}{ 224 | "name": "test_tool", 225 | }, 226 | } 227 | requestBytes, err = json.Marshal(toolRequest) 228 | if err != nil { 229 | t.Fatalf("Failed to marshal tool request: %v", err) 230 | } 231 | 232 | _, err = stdinWriter.Write(append(requestBytes, '\n')) 233 | if err != nil { 234 | t.Fatal(err) 235 | } 236 | 237 | if !scanner.Scan() { 238 | t.Fatal("failed to read response") 239 | } 240 | responseBytes = scanner.Bytes() 241 | 242 | response = map[string]interface{}{} 243 | if err := json.Unmarshal(responseBytes, &response); err != nil { 244 | t.Fatalf("failed to unmarshal response: %v", err) 245 | } 246 | 247 | if response["jsonrpc"] != "2.0" { 248 | t.Errorf("Expected jsonrpc 2.0, got %v", response["jsonrpc"]) 249 | } 250 | if response["id"].(float64) != 2 { 251 | t.Errorf("Expected id 2, got %v", response["id"]) 252 | } 253 | if response["result"].(map[string]interface{})["content"].([]interface{})[0].(map[string]interface{})["text"] != "test_value" { 254 | t.Errorf("Expected result 'test_value', got %v", response["result"]) 255 | } 256 | if response["error"] != nil { 257 | t.Errorf("Expected no error, got %v", response["error"]) 258 | } 259 | 260 | // Clean up 261 | cancel() 262 | stdinWriter.Close() 263 | stdoutWriter.Close() 264 | 265 | // Check for server errors 266 | if err := <-serverErrCh; err != nil { 267 | t.Errorf("unexpected server error: %v", err) 268 | } 269 | }) 270 | } 271 | -------------------------------------------------------------------------------- /testdata/mockstdio_server.go: -------------------------------------------------------------------------------- 1 | package main
2 | 3 | import ( 4 | "bufio" 5 | "encoding/json" 6 | "fmt" 7 | "log/slog" 8 | "os" 9 | ) 10 | 11 | type JSONRPCRequest struct { 12 | JSONRPC string `json:"jsonrpc"` 13 | ID *int64 `json:"id,omitempty"` 14 | Method string `json:"method"` 15 | Params json.RawMessage `json:"params"` 16 | } 17 | 18 | type JSONRPCResponse struct { 19 | JSONRPC string `json:"jsonrpc"` 20 | ID *int64 `json:"id,omitempty"` 21 | Result interface{} `json:"result,omitempty"` 22 | Error *struct { 23 | Code int `json:"code"` 24 | Message string `json:"message"` 25 | } `json:"error,omitempty"` 26 | } 27 | 28 | func main() { 29 | logger := slog.New(slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{})) 30 | logger.Info("launch successful") 31 | scanner := bufio.NewScanner(os.Stdin) 32 | for scanner.Scan() { 33 | var request JSONRPCRequest 34 | if err := json.Unmarshal(scanner.Bytes(), &request); err != nil { 35 | continue 36 | } 37 | writeMessage(logger, handleRequest(request)) 38 | } 39 | if err := scanner.Err(); err != nil { // e.g. a request line longer than bufio.Scanner's default 64 KiB token limit 40 | logger.Error("reading stdin", "error", err) 41 | } 42 | } 43 | 44 | func handleRequest(request JSONRPCRequest) JSONRPCResponse { 45 | response := JSONRPCResponse{ 46 | JSONRPC: "2.0", 47 | ID: request.ID, 48 | } 49 | 50 | switch request.Method { 51 | case "initialize": 52 | response.Result = map[string]interface{}{ 53 | "protocolVersion": "1.0", 54 | "serverInfo": map[string]interface{}{ 55 | "name": "mock-server", 56 | "version": "1.0.0", 57 | }, 58 | "capabilities": map[string]interface{}{ 59 | "prompts": map[string]interface{}{ 60 | "listChanged": true, 61 | }, 62 | "resources": map[string]interface{}{ 63 | "listChanged": true, 64 | "subscribe": true, 65 | }, 66 | "tools": map[string]interface{}{ 67 | "listChanged": true, 68 | }, 69 | }, 70 | } 71 | case "ping": 72 | response.Result = struct{}{} 73 | case "resources/list": 74 | response.Result = map[string]interface{}{ 75 | "resources": []map[string]interface{}{ 76 | { 77 | "name": "test-resource", 78 | "uri":
"test://resource", 79 | }, 80 | }, 81 | } 82 | case "resources/read": 83 | response.Result = map[string]interface{}{ 84 | "contents": []map[string]interface{}{ 85 | { 86 | "text": "test content", 87 | "uri": "test://resource", 88 | }, 89 | }, 90 | } 91 | case "resources/subscribe", "resources/unsubscribe": 92 | response.Result = struct{}{} 93 | case "prompts/list": 94 | response.Result = map[string]interface{}{ 95 | "prompts": []map[string]interface{}{ 96 | { 97 | "name": "test-prompt", 98 | }, 99 | }, 100 | } 101 | case "prompts/get": 102 | response.Result = map[string]interface{}{ 103 | "messages": []map[string]interface{}{ 104 | { 105 | "role": "assistant", 106 | "content": map[string]interface{}{ 107 | "type": "text", 108 | "text": "test message", 109 | }, 110 | }, 111 | }, 112 | } 113 | case "tools/list": 114 | response.Result = map[string]interface{}{ 115 | "tools": []map[string]interface{}{ 116 | { 117 | "name": "test-tool", 118 | "inputSchema": map[string]interface{}{ 119 | "type": "object", 120 | }, 121 | }, 122 | }, 123 | } 124 | case "tools/call": 125 | response.Result = map[string]interface{}{ 126 | "content": []map[string]interface{}{ 127 | { 128 | "type": "text", 129 | "text": "tool result", 130 | }, 131 | }, 132 | } 133 | case "logging/setLevel": 134 | response.Result = struct{}{} 135 | case "completion/complete": 136 | response.Result = map[string]interface{}{ 137 | "completion": map[string]interface{}{ 138 | "values": []string{"test completion"}, 139 | }, 140 | } 141 | 142 | // Debug methods for testing transport.
143 | case "debug/echo": 144 | response.Result = request 145 | case "debug/echo_notification": 146 | response.Result = request 147 | 148 | // send notification to client 149 | responseBytes, _ := json.Marshal(map[string]any{ 150 | "jsonrpc": "2.0", 151 | "method": "debug/test", 152 | "params": request, 153 | }) 154 | fmt.Fprintf(os.Stdout, "%s\n", responseBytes) 155 | 156 | case "debug/echo_error_string": 157 | all, _ := json.Marshal(request) 158 | response.Error = &struct { 159 | Code int `json:"code"` 160 | Message string `json:"message"` 161 | }{ 162 | Code: -32601, 163 | Message: string(all), 164 | } 165 | default: 166 | response.Error = &struct { 167 | Code int `json:"code"` 168 | Message string `json:"message"` 169 | }{ 170 | Code: -32601, 171 | Message: "Method not found", 172 | } 173 | } 174 | 175 | return response 176 | } 177 | 178 | // writeMessage writes msg to stdout as a single line of JSON. If marshaling 179 | // fails, the error is logged and nothing is written to stdout. 180 | func writeMessage(logger *slog.Logger, msg any) { 181 | b, err := json.Marshal(msg) 182 | if err != nil { 183 | logger.Error("marshaling message", "error", err) 184 | return 185 | } 186 | fmt.Fprintf(os.Stdout, "%s\n", b) 187 | } 188 | --------------------------------------------------------------------------------