├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ ├── config.yml │ ├── documentation-improvement.md │ └── feature-request.md ├── pull_request_template.md └── workflows │ ├── ci.yml │ ├── golangci-lint.yml │ ├── pages.yml │ └── release.yml ├── .gitignore ├── .golangci.yml ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── SECURITY.md ├── client ├── client.go ├── http.go ├── http_test.go ├── inprocess.go ├── inprocess_test.go ├── interface.go ├── oauth.go ├── oauth_test.go ├── sse.go ├── sse_test.go ├── stdio.go ├── stdio_test.go └── transport │ ├── inprocess.go │ ├── interface.go │ ├── oauth.go │ ├── oauth_test.go │ ├── oauth_utils.go │ ├── oauth_utils_test.go │ ├── sse.go │ ├── sse_test.go │ ├── stdio.go │ ├── stdio_test.go │ ├── streamable_http.go │ ├── streamable_http_oauth_test.go │ └── streamable_http_test.go ├── examples ├── custom_context │ └── main.go ├── dynamic_path │ └── main.go ├── everything │ └── main.go ├── filesystem_stdio_client │ └── main.go ├── oauth_client │ ├── README.md │ └── main.go ├── simple_client │ └── main.go └── typed_tools │ └── main.go ├── go.mod ├── go.sum ├── logo.png ├── mcp ├── prompts.go ├── resources.go ├── tools.go ├── tools_test.go ├── typed_tools.go ├── typed_tools_test.go ├── types.go ├── types_test.go └── utils.go ├── mcptest ├── mcptest.go └── mcptest_test.go ├── server ├── errors.go ├── hooks.go ├── http_transport_options.go ├── internal │ └── gen │ │ ├── README.md │ │ ├── data.go │ │ ├── hooks.go.tmpl │ │ ├── main.go │ │ └── request_handler.go.tmpl ├── request_handler.go ├── resource_test.go ├── server.go ├── server_race_test.go ├── server_test.go ├── session.go ├── session_test.go ├── sse.go ├── sse_test.go ├── stdio.go ├── stdio_test.go ├── streamable_http.go └── streamable_http_test.go ├── testdata └── mockstdio_server.go ├── util └── logger.go └── www ├── .gitignore ├── README.md ├── bun.lock ├── docs ├── pages │ ├── example.mdx │ ├── getting-started.mdx │ └── index.mdx ├── public │ └── 
logo.png └── styles.css ├── package.json ├── tsconfig.json └── vocs.config.ts /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Report an issue or unexpected behavior 4 | title: 'bug: ' 5 | labels: bug 6 | assignees: '' 7 | --- 8 | 9 | ## Description 10 | 11 | A clear and concise description of the bug, including what happened and what you expected to happen. 12 | 13 | ## Code Sample 14 | 15 | ```go 16 | // Minimum code snippet to reproduce the issue 17 | // Remove if not applicable 18 | ``` 19 | 20 | ## Logs or Error Messages 21 | 22 | ```text 23 | If applicable, include any error messages, stack traces, or logs. Remove if not applicable. 24 | ``` 25 | 26 | ## Environment 27 | 28 | - Go version (see `go.mod`): [e.g. 1.23] 29 | - mcp-go version (see `go.mod`): [e.g. 0.27.0] 30 | - Any other relevant environment details (OS, architecture, etc.) 31 | 32 | ## Additional Context 33 | 34 | Add any other context about the problem here. 35 | 36 | ## Possible Solution 37 | 38 | If you have a suggestion for fixing the issue, please describe it here. 39 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: true 2 | contact_links: 3 | - name: Ask a Question 4 | url: https://github.com/mark3labs/mcp-go/discussions/categories/q-a 5 | about: Ask any question about the project. 6 | - name: Join the Community 7 | url: https://discord.gg/RqSS2NQVsY 8 | about: Join the community on Discord. 
9 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/documentation-improvement.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Documentation improvement 3 | about: Suggest improvements to the documentation 4 | title: 'docs: ' 5 | labels: documentation 6 | assignees: '' 7 | --- 8 | 9 | ## Documentation Issue 10 | 11 | Describe what's unclear, incorrect, or missing in the current documentation. 12 | 13 | ## Location 14 | 15 | Provide a link or description of where this documentation issue exists or should exist (README, code comments, examples, etc.). 16 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature-request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest a new feature or enhancement 4 | title: 'feature: ' 5 | labels: enhancement 6 | assignees: '' 7 | --- 8 | 9 | ## Problem Statement 10 | 11 | A clear and concise description of what the problem is. For example, "I'm always frustrated when [...]" 12 | 13 | ## Proposed Solution 14 | 15 | A clear and concise description of what you want to happen. Include any API design or implementation details you have in mind. 16 | 17 | ## MCP Spec Reference 18 | 19 | If this feature is described in the MCP specification, please provide a link to the relevant section with a brief explanation of how it relates to your request. 20 | 21 | Remove this section if not applicable. 22 | 23 | ## Example Usage 24 | 25 | ```go 26 | // If applicable, provide sample code showing how the proposed feature would be used. 27 | // Remove if not applicable 28 | ``` 29 | 30 | ## Alternatives/Workarounds Considered 31 | 32 | A clear and concise description of any alternative solutions, workarounds, or features you've considered. 
33 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | ## Description 2 | 3 | 4 | Fixes # (if applicable) 5 | 6 | ## Type of Change 7 | 8 | 9 | - [ ] Bug fix (non-breaking change that fixes an issue) 10 | - [ ] New feature (non-breaking change that adds functionality) 11 | - [ ] MCP spec compatibility implementation 12 | - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) 13 | - [ ] Documentation update 14 | - [ ] Code refactoring (no functional changes) 15 | - [ ] Performance improvement 16 | - [ ] Tests only (no functional changes) 17 | - [ ] Other (please describe): 18 | 19 | ## Checklist 20 | 21 | 22 | - [ ] My code follows the code style of this project 23 | - [ ] I have performed a self-review of my own code 24 | - [ ] I have added tests that prove my fix is effective or that my feature works 25 | - [ ] I have updated the documentation accordingly 26 | 27 | ## MCP Spec Compliance 28 | 29 | 30 | 31 | - [ ] This PR implements a feature defined in the MCP specification 32 | - [ ] Link to relevant spec section: [Link text](https://modelcontextprotocol.io/specification/path-to-section) 33 | - [ ] Implementation follows the specification exactly 34 | 35 | ## Additional Information 36 | 37 | 38 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: go 2 | on: 3 | push: 4 | branches: 5 | - main 6 | pull_request: 7 | workflow_dispatch: 8 | 9 | jobs: 10 | test: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v4 14 | - uses: actions/setup-go@v5 15 | with: 16 | go-version-file: 'go.mod' 17 | - run: go test ./... 
-race 18 | 19 | verify-codegen: 20 | runs-on: ubuntu-latest 21 | steps: 22 | - uses: actions/checkout@v4 23 | - uses: actions/setup-go@v5 24 | with: 25 | go-version-file: 'go.mod' 26 | - name: Run code generation 27 | run: go generate ./... 28 | - name: Check for uncommitted changes 29 | run: | 30 | if [[ -n $(git status --porcelain) ]]; then 31 | echo "Error: Generated code is not up to date. Please run 'go generate ./...' and commit the changes." 32 | git status 33 | git diff 34 | exit 1 35 | fi 36 | -------------------------------------------------------------------------------- /.github/workflows/golangci-lint.yml: -------------------------------------------------------------------------------- 1 | name: golangci-lint 2 | on: 3 | push: 4 | branches: 5 | - main 6 | pull_request: 7 | 8 | permissions: 9 | contents: read 10 | 11 | jobs: 12 | golangci: 13 | name: lint 14 | runs-on: ubuntu-latest 15 | steps: 16 | - uses: actions/checkout@v4 17 | - uses: actions/setup-go@v5 18 | with: 19 | go-version: stable 20 | - name: golangci-lint 21 | uses: golangci/golangci-lint-action@v8 22 | with: 23 | version: v2.1 24 | -------------------------------------------------------------------------------- /.github/workflows/pages.yml: -------------------------------------------------------------------------------- 1 | name: Build and Deploy to GitHub Pages 2 | 3 | on: 4 | push: 5 | branches: [ main ] # or your default branch 6 | workflow_dispatch: # Allows manual triggering 7 | 8 | jobs: 9 | build-and-deploy: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - name: Checkout Repository 13 | uses: actions/checkout@v3 14 | 15 | - name: Setup Bun 16 | uses: oven-sh/setup-bun@v1 17 | with: 18 | bun-version: latest # or specify a version like '1.0.0' 19 | 20 | - name: Install Dependencies 21 | working-directory: ./www 22 | run: bun install 23 | 24 | - name: Build 25 | working-directory: ./www 26 | run: bun run build 27 | 28 | - name: Deploy to GitHub Pages 29 | uses: 
JamesIves/github-pages-deploy-action@v4 30 | with: 31 | folder: www/docs/dist # Your build output directory 32 | branch: gh-pages # The branch the action should deploy to 33 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: "Create Release on Tag Push" 2 | on: 3 | push: 4 | tags: 5 | - '*' 6 | jobs: 7 | release: 8 | runs-on: ubuntu-latest 9 | 10 | steps: 11 | - name: Checkout Code 12 | uses: actions/checkout@v3 13 | 14 | - name: Create GitHub Release 15 | uses: actions/create-release@v1 16 | env: 17 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 18 | with: 19 | tag_name: ${{ github.ref }} 20 | release_name: Release ${{ github.ref }} 21 | draft: false 22 | prerelease: false 23 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .aider* 2 | .env 3 | .idea 4 | .opencode 5 | .claude 6 | -------------------------------------------------------------------------------- /.golangci.yml: -------------------------------------------------------------------------------- 1 | version: "2" 2 | linters: 3 | exclusions: 4 | presets: 5 | - std-error-handling 6 | 7 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our 6 | community a harassment-free experience for everyone, regardless of age, body 7 | size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, socio-economic status, 9 | nationality, personal appearance, race, religion, or sexual identity 
10 | and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community. 14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | * Demonstrating empathy and kindness toward other people 21 | * Being respectful of differing opinions, viewpoints, and experiences 22 | * Giving and gracefully accepting constructive feedback 23 | * Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | * Focusing on what is best not just for us as individuals, but for the 26 | overall community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | * The use of sexualized language or imagery, and sexual attention or 31 | advances of any kind 32 | * Trolling, insulting or derogatory comments, and personal or political attacks 33 | * Public or private harassment 34 | * Publishing others' private information, such as a physical or email 35 | address, without their explicit permission 36 | * Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards of 42 | acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 
50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies when 54 | an individual is officially representing the community in public spaces. 55 | Examples of representing our community include using an official e-mail address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement at 63 | [contact@mark3labs.com](mailto:contact@mark3labs.com). 64 | All complaints will be reviewed and investigated promptly and fairly. 65 | 66 | All community leaders are obligated to respect the privacy and security of the 67 | reporter of any incident. 68 | 69 | ## Enforcement Guidelines 70 | 71 | Community leaders will follow these Community Impact Guidelines in determining 72 | the consequences for any action they deem in violation of this Code of Conduct: 73 | 74 | ### 1. Correction 75 | 76 | **Community Impact**: Use of inappropriate language or other behavior deemed 77 | unprofessional or unwelcome in the community. 78 | 79 | **Consequence**: A private, written warning from community leaders, providing 80 | clarity around the nature of the violation and an explanation of why the 81 | behavior was inappropriate. A public apology may be requested. 82 | 83 | ### 2. Warning 84 | 85 | **Community Impact**: A violation through a single incident or series 86 | of actions. 87 | 88 | **Consequence**: A warning with consequences for continued behavior. No 89 | interaction with the people involved, including unsolicited interaction with 90 | those enforcing the Code of Conduct, for a specified period of time. This 91 | includes avoiding interactions in community spaces as well as external channels 92 | like social media. Violating these terms may lead to a temporary or 93 | permanent ban. 
94 | 95 | ### 3. Temporary Ban 96 | 97 | **Community Impact**: A serious violation of community standards, including 98 | sustained inappropriate behavior. 99 | 100 | **Consequence**: A temporary ban from any sort of interaction or public 101 | communication with the community for a specified period of time. No public or 102 | private interaction with the people involved, including unsolicited interaction 103 | with those enforcing the Code of Conduct, is allowed during this period. 104 | Violating these terms may lead to a permanent ban. 105 | 106 | ### 4. Permanent Ban 107 | 108 | **Community Impact**: Demonstrating a pattern of violation of community 109 | standards, including sustained inappropriate behavior, harassment of an 110 | individual, or aggression toward or disparagement of classes of individuals. 111 | 112 | **Consequence**: A permanent ban from any sort of public interaction within 113 | the community. 114 | 115 | ## Attribution 116 | 117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 118 | version 2.0, available at 119 | https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. 120 | 121 | Community Impact Guidelines were inspired by [Mozilla's code of conduct 122 | enforcement ladder](https://github.com/mozilla/diversity). 123 | 124 | [homepage]: https://www.contributor-covenant.org 125 | 126 | For answers to common questions about this code of conduct, see the FAQ at 127 | https://www.contributor-covenant.org/faq. Translations are available at 128 | https://www.contributor-covenant.org/translations. 129 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | Thank you for your interest in contributing to the MCP Go SDK! We welcome contributions of all kinds, including bug fixes, new features, and documentation improvements. 
This document outlines the process for contributing to the project. 4 | 5 | ## Development Guidelines 6 | 7 | ### Prerequisites 8 | 9 | Make sure you have Go 1.23 or later installed on your machine. You can check your Go version by running: 10 | 11 | ```bash 12 | go version 13 | ``` 14 | 15 | ### Setup 16 | 17 | 1. Fork the repository 18 | 2. Clone your fork: 19 | 20 | ```bash 21 | git clone https://github.com/YOUR_USERNAME/mcp-go.git 22 | cd mcp-go 23 | ``` 24 | 3. Install the required packages: 25 | 26 | ```bash 27 | go mod tidy 28 | ``` 29 | 30 | ### Workflow 31 | 32 | 1. Create a new branch. 33 | 2. Make your changes. 34 | 3. Ensure you have added tests for any new functionality. 35 | 4. Run the tests as shown below from the root directory: 36 | 37 | ```bash 38 | go test -v './...' 39 | ``` 40 | 5. Submit a pull request to the main branch. 41 | 42 | Feel free to reach out if you have any questions or need help either by [opening an issue](https://github.com/mark3labs/mcp-go/issues) or by reaching out in the [Discord channel](https://discord.gg/RqSS2NQVsY). 43 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Anthropic, PBC 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | # Security Policy 2 | 3 | Thank you for helping us improve the security of the project. Your contributions are greatly appreciated. 4 | 5 | ## Reporting a Vulnerability 6 | 7 | If you discover a security vulnerability within this project, please email the maintainers at [contact@mark3labs.com](mailto:contact@mark3labs.com). 8 | -------------------------------------------------------------------------------- /client/http.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/mark3labs/mcp-go/client/transport" 7 | ) 8 | 9 | // NewStreamableHttpClient is a convenience method that creates a new streamable-http-based MCP client 10 | // with the given base URL. Returns an error if the URL is invalid. 11 | func NewStreamableHttpClient(baseURL string, options ...transport.StreamableHTTPCOption) (*Client, error) { 12 | trans, err := transport.NewStreamableHTTP(baseURL, options...) 
13 | if err != nil { 14 | return nil, fmt.Errorf("failed to create SSE transport: %w", err) 15 | } 16 | return NewClient(trans), nil 17 | } 18 | -------------------------------------------------------------------------------- /client/http_test.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "github.com/mark3labs/mcp-go/mcp" 7 | "github.com/mark3labs/mcp-go/server" 8 | "testing" 9 | "time" 10 | ) 11 | 12 | func TestHTTPClient(t *testing.T) { 13 | hooks := &server.Hooks{} 14 | hooks.AddAfterCallTool(func(ctx context.Context, id any, message *mcp.CallToolRequest, result *mcp.CallToolResult) { 15 | clientSession := server.ClientSessionFromContext(ctx) 16 | // wait until all the notifications are handled 17 | for len(clientSession.NotificationChannel()) > 0 { 18 | } 19 | time.Sleep(time.Millisecond * 50) 20 | }) 21 | 22 | // Create MCP server with capabilities 23 | mcpServer := server.NewMCPServer( 24 | "test-server", 25 | "1.0.0", 26 | server.WithToolCapabilities(true), 27 | server.WithHooks(hooks), 28 | ) 29 | 30 | mcpServer.AddTool( 31 | mcp.NewTool("notify"), 32 | func( 33 | ctx context.Context, 34 | request mcp.CallToolRequest, 35 | ) (*mcp.CallToolResult, error) { 36 | server := server.ServerFromContext(ctx) 37 | err := server.SendNotificationToClient( 38 | ctx, 39 | "notifications/progress", 40 | map[string]any{ 41 | "progress": 10, 42 | "total": 10, 43 | "progressToken": 0, 44 | }, 45 | ) 46 | if err != nil { 47 | return nil, fmt.Errorf("failed to send notification: %w", err) 48 | } 49 | 50 | return &mcp.CallToolResult{ 51 | Content: []mcp.Content{ 52 | mcp.TextContent{ 53 | Type: "text", 54 | Text: "notification sent successfully", 55 | }, 56 | }, 57 | }, nil 58 | }, 59 | ) 60 | 61 | testServer := server.NewTestStreamableHTTPServer(mcpServer) 62 | defer testServer.Close() 63 | 64 | t.Run("Can receive notification from server", func(t *testing.T) { 65 | client, 
err := NewStreamableHttpClient(testServer.URL) 66 | if err != nil { 67 | t.Fatalf("create client failed %v", err) 68 | return 69 | } 70 | 71 | notificationNum := 0 72 | client.OnNotification(func(notification mcp.JSONRPCNotification) { 73 | notificationNum += 1 74 | }) 75 | 76 | ctx := context.Background() 77 | 78 | if err := client.Start(ctx); err != nil { 79 | t.Fatalf("Failed to start client: %v", err) 80 | return 81 | } 82 | 83 | // Initialize 84 | initRequest := mcp.InitializeRequest{} 85 | initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION 86 | initRequest.Params.ClientInfo = mcp.Implementation{ 87 | Name: "test-client", 88 | Version: "1.0.0", 89 | } 90 | 91 | _, err = client.Initialize(ctx, initRequest) 92 | if err != nil { 93 | t.Fatalf("Failed to initialize: %v\n", err) 94 | } 95 | 96 | request := mcp.CallToolRequest{} 97 | request.Params.Name = "notify" 98 | result, err := client.CallTool(ctx, request) 99 | if err != nil { 100 | t.Fatalf("CallTool failed: %v", err) 101 | } 102 | 103 | if len(result.Content) != 1 { 104 | t.Errorf("Expected 1 content item, got %d", len(result.Content)) 105 | } 106 | 107 | if notificationNum != 1 { 108 | t.Errorf("Expected 1 notification item, got %d", notificationNum) 109 | } 110 | }) 111 | } 112 | -------------------------------------------------------------------------------- /client/inprocess.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "github.com/mark3labs/mcp-go/client/transport" 5 | "github.com/mark3labs/mcp-go/server" 6 | ) 7 | 8 | // NewInProcessClient connect directly to a mcp server object in the same process 9 | func NewInProcessClient(server *server.MCPServer) (*Client, error) { 10 | inProcessTransport := transport.NewInProcessTransport(server) 11 | return NewClient(inProcessTransport), nil 12 | } 13 | -------------------------------------------------------------------------------- /client/interface.go: 
-------------------------------------------------------------------------------- 1 | // Package client provides MCP (Model Context Protocol) client implementations. 2 | package client 3 | 4 | import ( 5 | "context" 6 | 7 | "github.com/mark3labs/mcp-go/mcp" 8 | ) 9 | 10 | // MCPClient represents an MCP client interface 11 | type MCPClient interface { 12 | // Initialize sends the initial connection request to the server 13 | Initialize( 14 | ctx context.Context, 15 | request mcp.InitializeRequest, 16 | ) (*mcp.InitializeResult, error) 17 | 18 | // Ping checks if the server is alive 19 | Ping(ctx context.Context) error 20 | 21 | // ListResourcesByPage manually list resources by page. 22 | ListResourcesByPage( 23 | ctx context.Context, 24 | request mcp.ListResourcesRequest, 25 | ) (*mcp.ListResourcesResult, error) 26 | 27 | // ListResources requests a list of available resources from the server 28 | ListResources( 29 | ctx context.Context, 30 | request mcp.ListResourcesRequest, 31 | ) (*mcp.ListResourcesResult, error) 32 | 33 | // ListResourceTemplatesByPage manually list resource templates by page. 
34 | ListResourceTemplatesByPage( 35 | ctx context.Context, 36 | request mcp.ListResourceTemplatesRequest, 37 | ) (*mcp.ListResourceTemplatesResult, 38 | error) 39 | 40 | // ListResourceTemplates requests a list of available resource templates from the server 41 | ListResourceTemplates( 42 | ctx context.Context, 43 | request mcp.ListResourceTemplatesRequest, 44 | ) (*mcp.ListResourceTemplatesResult, 45 | error) 46 | 47 | // ReadResource reads a specific resource from the server 48 | ReadResource( 49 | ctx context.Context, 50 | request mcp.ReadResourceRequest, 51 | ) (*mcp.ReadResourceResult, error) 52 | 53 | // Subscribe requests notifications for changes to a specific resource 54 | Subscribe(ctx context.Context, request mcp.SubscribeRequest) error 55 | 56 | // Unsubscribe cancels notifications for a specific resource 57 | Unsubscribe(ctx context.Context, request mcp.UnsubscribeRequest) error 58 | 59 | // ListPromptsByPage manually list prompts by page. 60 | ListPromptsByPage( 61 | ctx context.Context, 62 | request mcp.ListPromptsRequest, 63 | ) (*mcp.ListPromptsResult, error) 64 | 65 | // ListPrompts requests a list of available prompts from the server 66 | ListPrompts( 67 | ctx context.Context, 68 | request mcp.ListPromptsRequest, 69 | ) (*mcp.ListPromptsResult, error) 70 | 71 | // GetPrompt retrieves a specific prompt from the server 72 | GetPrompt( 73 | ctx context.Context, 74 | request mcp.GetPromptRequest, 75 | ) (*mcp.GetPromptResult, error) 76 | 77 | // ListToolsByPage manually list tools by page. 
78 | ListToolsByPage( 79 | ctx context.Context, 80 | request mcp.ListToolsRequest, 81 | ) (*mcp.ListToolsResult, error) 82 | 83 | // ListTools requests a list of available tools from the server 84 | ListTools( 85 | ctx context.Context, 86 | request mcp.ListToolsRequest, 87 | ) (*mcp.ListToolsResult, error) 88 | 89 | // CallTool invokes a specific tool on the server 90 | CallTool( 91 | ctx context.Context, 92 | request mcp.CallToolRequest, 93 | ) (*mcp.CallToolResult, error) 94 | 95 | // SetLevel sets the logging level for the server 96 | SetLevel(ctx context.Context, request mcp.SetLevelRequest) error 97 | 98 | // Complete requests completion options for a given argument 99 | Complete( 100 | ctx context.Context, 101 | request mcp.CompleteRequest, 102 | ) (*mcp.CompleteResult, error) 103 | 104 | // Close client connection and cleanup resources 105 | Close() error 106 | 107 | // OnNotification registers a handler for notifications 108 | OnNotification(handler func(notification mcp.JSONRPCNotification)) 109 | } 110 | -------------------------------------------------------------------------------- /client/oauth.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | 7 | "github.com/mark3labs/mcp-go/client/transport" 8 | ) 9 | 10 | // OAuthConfig is a convenience type that wraps transport.OAuthConfig 11 | type OAuthConfig = transport.OAuthConfig 12 | 13 | // Token is a convenience type that wraps transport.Token 14 | type Token = transport.Token 15 | 16 | // TokenStore is a convenience type that wraps transport.TokenStore 17 | type TokenStore = transport.TokenStore 18 | 19 | // MemoryTokenStore is a convenience type that wraps transport.MemoryTokenStore 20 | type MemoryTokenStore = transport.MemoryTokenStore 21 | 22 | // NewMemoryTokenStore is a convenience function that wraps transport.NewMemoryTokenStore 23 | var NewMemoryTokenStore = transport.NewMemoryTokenStore 24 | 25 | // 
NewOAuthStreamableHttpClient creates a new streamable-http-based MCP client with OAuth support. 26 | // Returns an error if the URL is invalid. 27 | func NewOAuthStreamableHttpClient(baseURL string, oauthConfig OAuthConfig, options ...transport.StreamableHTTPCOption) (*Client, error) { 28 | // Add OAuth option to the list of options 29 | options = append(options, transport.WithOAuth(oauthConfig)) 30 | 31 | trans, err := transport.NewStreamableHTTP(baseURL, options...) 32 | if err != nil { 33 | return nil, fmt.Errorf("failed to create HTTP transport: %w", err) 34 | } 35 | return NewClient(trans), nil 36 | } 37 | 38 | // GenerateCodeVerifier generates a code verifier for PKCE 39 | var GenerateCodeVerifier = transport.GenerateCodeVerifier 40 | 41 | // GenerateCodeChallenge generates a code challenge from a code verifier 42 | var GenerateCodeChallenge = transport.GenerateCodeChallenge 43 | 44 | // GenerateState generates a state parameter for OAuth 45 | var GenerateState = transport.GenerateState 46 | 47 | // OAuthAuthorizationRequiredError is returned when OAuth authorization is required 48 | type OAuthAuthorizationRequiredError = transport.OAuthAuthorizationRequiredError 49 | 50 | // IsOAuthAuthorizationRequiredError reports whether err is, or wraps, an OAuthAuthorizationRequiredError 51 | func IsOAuthAuthorizationRequiredError(err error) bool { 52 | var target *OAuthAuthorizationRequiredError 53 | return errors.As(err, &target) 54 | } 55 | 56 | // GetOAuthHandler extracts the OAuthHandler from an OAuthAuthorizationRequiredError in err's chain; it returns nil if there is none 57 | func GetOAuthHandler(err error) *transport.OAuthHandler { 58 | var oauthErr *OAuthAuthorizationRequiredError 59 | if errors.As(err, &oauthErr) { 60 | return oauthErr.Handler 61 | } 62 | return nil 63 | } 64 | -------------------------------------------------------------------------------- /client/oauth_test.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "context" 5 | 
"encoding/json" 6 | "fmt" 7 | "net/http" 8 | "net/http/httptest" 9 | "testing" 10 | "time" 11 | 12 | "github.com/mark3labs/mcp-go/client/transport" 13 | ) 14 | 15 | func TestNewOAuthStreamableHttpClient(t *testing.T) { 16 | // Create a test server 17 | server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 18 | // Check for Authorization header 19 | authHeader := r.Header.Get("Authorization") 20 | if authHeader != "Bearer test-token" { 21 | w.WriteHeader(http.StatusUnauthorized) 22 | return 23 | } 24 | 25 | // Return a successful response (headers must be set before WriteHeader, otherwise they are ignored) 26 | w.Header().Set("Content-Type", "application/json") 27 | w.WriteHeader(http.StatusOK) 28 | if err := json.NewEncoder(w).Encode(map[string]any{ 29 | "jsonrpc": "2.0", 30 | "id": 1, 31 | "result": map[string]any{ 32 | "protocolVersion": "2024-11-05", 33 | "serverInfo": map[string]any{ 34 | "name": "test-server", 35 | "version": "1.0.0", 36 | }, 37 | "capabilities": map[string]any{}, 38 | }, 39 | }); err != nil { 40 | t.Errorf("Failed to encode JSON response: %v", err) 41 | } 42 | })) 43 | defer server.Close() 44 | 45 | // Create a token store with a valid token 46 | tokenStore := NewMemoryTokenStore() 47 | validToken := &Token{ 48 | AccessToken: "test-token", 49 | TokenType: "Bearer", 50 | RefreshToken: "refresh-token", 51 | ExpiresIn: 3600, 52 | ExpiresAt: time.Now().Add(1 * time.Hour), // Valid for 1 hour 53 | } 54 | if err := tokenStore.SaveToken(validToken); err != nil { 55 | t.Fatalf("Failed to save token: %v", err) 56 | } 57 | 58 | // Create OAuth config 59 | oauthConfig := OAuthConfig{ 60 | ClientID: "test-client", 61 | RedirectURI: "http://localhost:8085/callback", 62 | Scopes: []string{"mcp.read", "mcp.write"}, 63 | TokenStore: tokenStore, 64 | PKCEEnabled: true, 65 | } 66 | 67 | // Create client with OAuth 68 | client, err := NewOAuthStreamableHttpClient(server.URL, oauthConfig) 69 | if err != nil { 70 | 
t.Fatalf("Failed to create client: %v", err) 71 | } 72 | 73 | // Start the
client 74 | if err := client.Start(context.Background()); err != nil { 75 | t.Fatalf("Failed to start client: %v", err) 76 | } 77 | defer client.Close() 78 | 79 | // Verify that the client was created successfully 80 | trans := client.GetTransport() 81 | streamableHTTP, ok := trans.(*transport.StreamableHTTP) 82 | if !ok { 83 | t.Fatalf("Expected transport to be *transport.StreamableHTTP, got %T", trans) 84 | } 85 | 86 | // Verify OAuth is enabled 87 | if !streamableHTTP.IsOAuthEnabled() { 88 | t.Errorf("Expected IsOAuthEnabled() to return true") 89 | } 90 | 91 | // Verify the OAuth handler is set 92 | if streamableHTTP.GetOAuthHandler() == nil { 93 | t.Errorf("Expected GetOAuthHandler() to return a handler") 94 | } 95 | } 96 | 97 | func TestIsOAuthAuthorizationRequiredError(t *testing.T) { 98 | // Create a test error 99 | err := &transport.OAuthAuthorizationRequiredError{ 100 | Handler: transport.NewOAuthHandler(transport.OAuthConfig{}), 101 | } 102 | 103 | // Verify IsOAuthAuthorizationRequiredError returns true 104 | if !IsOAuthAuthorizationRequiredError(err) { 105 | t.Errorf("Expected IsOAuthAuthorizationRequiredError to return true") 106 | } 107 | 108 | // Verify GetOAuthHandler returns the handler 109 | handler := GetOAuthHandler(err) 110 | if handler == nil { 111 | t.Errorf("Expected GetOAuthHandler to return a handler") 112 | } 113 | 114 | // Test with a different error 115 | err2 := fmt.Errorf("some other error") 116 | 117 | // Verify IsOAuthAuthorizationRequiredError returns false 118 | if IsOAuthAuthorizationRequiredError(err2) { 119 | t.Errorf("Expected IsOAuthAuthorizationRequiredError to return false") 120 | } 121 | 122 | // Verify GetOAuthHandler returns nil 123 | handler = GetOAuthHandler(err2) 124 | if handler != nil { 125 | t.Errorf("Expected GetOAuthHandler to return nil") 126 | } 127 | } 128 | -------------------------------------------------------------------------------- /client/sse.go: 
-------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "net/url" 7 | 8 | "github.com/mark3labs/mcp-go/client/transport" 9 | ) 10 | // WithHeaders returns a transport option (forwarding to transport.WithHeaders) that attaches the given static headers to outgoing HTTP requests. 11 | func WithHeaders(headers map[string]string) transport.ClientOption { 12 | return transport.WithHeaders(headers) 13 | } 14 | // WithHeaderFunc returns a transport option (forwarding to transport.WithHeaderFunc) whose headerFunc derives extra HTTP headers from the request context. 15 | func WithHeaderFunc(headerFunc transport.HTTPHeaderFunc) transport.ClientOption { 16 | return transport.WithHeaderFunc(headerFunc) 17 | } 18 | // WithHTTPClient returns a transport option that forwards httpClient to transport.WithHTTPClient, presumably so requests go through it instead of a default client (confirm in transport). 19 | func WithHTTPClient(httpClient *http.Client) transport.ClientOption { 20 | return transport.WithHTTPClient(httpClient) 21 | } 22 | 23 | // NewSSEMCPClient creates a new SSE-based MCP client with the given base URL. 24 | // Returns an error if the URL is invalid. 25 | func NewSSEMCPClient(baseURL string, options ...transport.ClientOption) (*Client, error) { 26 | 27 | sseTransport, err := transport.NewSSE(baseURL, options...) 28 | if err != nil { 29 | return nil, fmt.Errorf("failed to create SSE transport: %w", err) 30 | } 31 | 32 | return NewClient(sseTransport), nil 33 | } 34 | 35 | // GetEndpoint returns the current endpoint URL for the SSE connection. 36 | // 37 | // Note: This method only works with SSE transport, or it will panic.
38 | func GetEndpoint(c *Client) *url.URL { 39 | t := c.GetTransport() 40 | sse := t.(*transport.SSE) 41 | return sse.GetEndpoint() 42 | } 43 | -------------------------------------------------------------------------------- /client/sse_test.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | "testing" 7 | "time" 8 | 9 | "github.com/mark3labs/mcp-go/client/transport" 10 | 11 | "github.com/mark3labs/mcp-go/mcp" 12 | "github.com/mark3labs/mcp-go/server" 13 | ) 14 | 15 | type contextKey string 16 | 17 | const ( 18 | testHeaderKey contextKey = "X-Test-Header" 19 | testHeaderFuncKey contextKey = "X-Test-Header-Func" 20 | ) 21 | 22 | func TestSSEMCPClient(t *testing.T) { 23 | // Create MCP server with capabilities 24 | mcpServer := server.NewMCPServer( 25 | "test-server", 26 | "1.0.0", 27 | server.WithResourceCapabilities(true, true), 28 | server.WithPromptCapabilities(true), 29 | server.WithToolCapabilities(true), 30 | ) 31 | 32 | // Add a test tool 33 | mcpServer.AddTool(mcp.NewTool( 34 | "test-tool", 35 | mcp.WithDescription("Test tool"), 36 | mcp.WithString("parameter-1", mcp.Description("A string tool parameter")), 37 | mcp.WithTitleAnnotation("Test Tool Annotation Title"), 38 | mcp.WithReadOnlyHintAnnotation(true), 39 | mcp.WithDestructiveHintAnnotation(false), 40 | mcp.WithIdempotentHintAnnotation(true), 41 | mcp.WithOpenWorldHintAnnotation(false), 42 | ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { 43 | return &mcp.CallToolResult{ 44 | Content: []mcp.Content{ 45 | mcp.TextContent{ 46 | Type: "text", 47 | Text: "Input parameter: " + request.GetArguments()["parameter-1"].(string), 48 | }, 49 | }, 50 | }, nil 51 | }) 52 | mcpServer.AddTool(mcp.NewTool( 53 | "test-tool-for-http-header", 54 | mcp.WithDescription("Test tool for http header"), 55 | ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, 
error) { 56 | // , X-Test-Header-Func 57 | return &mcp.CallToolResult{ 58 | Content: []mcp.Content{ 59 | mcp.TextContent{ 60 | Type: "text", 61 | Text: "context from header: " + ctx.Value(testHeaderKey).(string) + ", " + ctx.Value(testHeaderFuncKey).(string), 62 | }, 63 | }, 64 | }, nil 65 | }) 66 | 67 | // Initialize 68 | testServer := server.NewTestServer(mcpServer, 69 | server.WithSSEContextFunc(func(ctx context.Context, r *http.Request) context.Context { 70 | ctx = context.WithValue(ctx, testHeaderKey, r.Header.Get("X-Test-Header")) 71 | ctx = context.WithValue(ctx, testHeaderFuncKey, r.Header.Get("X-Test-Header-Func")) 72 | return ctx 73 | }), 74 | ) 75 | defer testServer.Close() 76 | 77 | t.Run("Can create client", func(t *testing.T) { 78 | client, err := NewSSEMCPClient(testServer.URL + "/sse") 79 | if err != nil { 80 | t.Fatalf("Failed to create client: %v", err) 81 | } 82 | defer client.Close() 83 | 84 | sseTransport := client.GetTransport().(*transport.SSE) 85 | if sseTransport.GetBaseURL() == nil { 86 | t.Error("Base URL should not be nil") 87 | } 88 | }) 89 | 90 | t.Run("Can initialize and make requests", func(t *testing.T) { 91 | client, err := NewSSEMCPClient(testServer.URL + "/sse") 92 | if err != nil { 93 | t.Fatalf("Failed to create client: %v", err) 94 | } 95 | defer client.Close() 96 | 97 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 98 | defer cancel() 99 | 100 | // Start the client 101 | if err := client.Start(ctx); err != nil { 102 | t.Fatalf("Failed to start client: %v", err) 103 | } 104 | 105 | // Initialize 106 | initRequest := mcp.InitializeRequest{} 107 | initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION 108 | initRequest.Params.ClientInfo = mcp.Implementation{ 109 | Name: "test-client", 110 | Version: "1.0.0", 111 | } 112 | 113 | result, err := client.Initialize(ctx, initRequest) 114 | if err != nil { 115 | t.Fatalf("Failed to initialize: %v", err) 116 | } 117 | 118 | if result.ServerInfo.Name 
!= "test-server" { 119 | t.Errorf( 120 | "Expected server name 'test-server', got '%s'", 121 | result.ServerInfo.Name, 122 | ) 123 | } 124 | 125 | // Test Ping 126 | if err := client.Ping(ctx); err != nil { 127 | t.Errorf("Ping failed: %v", err) 128 | } 129 | 130 | // Test ListTools 131 | toolsRequest := mcp.ListToolsRequest{} 132 | toolListResult, err := client.ListTools(ctx, toolsRequest) 133 | if err != nil { 134 | t.Errorf("ListTools failed: %v", err) 135 | } 136 | if toolListResult == nil || len((*toolListResult).Tools) == 0 { 137 | t.Errorf("Expected one tool") 138 | } 139 | testToolAnnotations := (*toolListResult).Tools[0].Annotations 140 | if testToolAnnotations.Title != "Test Tool Annotation Title" || 141 | *testToolAnnotations.ReadOnlyHint != true || 142 | *testToolAnnotations.DestructiveHint != false || 143 | *testToolAnnotations.IdempotentHint != true || 144 | *testToolAnnotations.OpenWorldHint != false { 145 | t.Errorf("The annotations of the tools are invalid") 146 | } 147 | }) 148 | 149 | // t.Run("Can handle notifications", func(t *testing.T) { 150 | // client, err := NewSSEMCPClient(testServer.URL + "/sse") 151 | // if err != nil { 152 | // t.Fatalf("Failed to create client: %v", err) 153 | // } 154 | // defer client.Close() 155 | 156 | // notificationReceived := make(chan mcp.JSONRPCNotification, 1) 157 | // client.OnNotification(func(notification mcp.JSONRPCNotification) { 158 | // notificationReceived <- notification 159 | // }) 160 | 161 | // ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 162 | // defer cancel() 163 | 164 | // if err := client.Start(ctx); err != nil { 165 | // t.Fatalf("Failed to start client: %v", err) 166 | // } 167 | 168 | // // Initialize first 169 | // initRequest := mcp.InitializeRequest{} 170 | // initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION 171 | // initRequest.Params.ClientInfo = mcp.Implementation{ 172 | // Name: "test-client", 173 | // Version: "1.0.0", 174 | // } 175 | 
176 | // _, err = client.Initialize(ctx, initRequest) 177 | // if err != nil { 178 | // t.Fatalf("Failed to initialize: %v", err) 179 | // } 180 | 181 | // // Subscribe to a resource to test notifications 182 | // subRequest := mcp.SubscribeRequest{} 183 | // subRequest.Params.URI = "test://resource" 184 | // if err := client.Subscribe(ctx, subRequest); err != nil { 185 | // t.Fatalf("Failed to subscribe: %v", err) 186 | // } 187 | 188 | // select { 189 | // case <-notificationReceived: 190 | // // Success 191 | // case <-time.After(time.Second): 192 | // t.Error("Timeout waiting for notification") 193 | // } 194 | // }) 195 | 196 | t.Run("Handles errors properly", func(t *testing.T) { 197 | client, err := NewSSEMCPClient(testServer.URL + "/sse") 198 | if err != nil { 199 | t.Fatalf("Failed to create client: %v", err) 200 | } 201 | defer client.Close() 202 | 203 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 204 | defer cancel() 205 | 206 | if err := client.Start(ctx); err != nil { 207 | t.Fatalf("Failed to start client: %v", err) 208 | } 209 | 210 | // Try to make a request without initializing 211 | toolsRequest := mcp.ListToolsRequest{} 212 | _, err = client.ListTools(ctx, toolsRequest) 213 | if err == nil { 214 | t.Error("Expected error when making request before initialization") 215 | } 216 | }) 217 | 218 | // t.Run("Handles context cancellation", func(t *testing.T) { 219 | // client, err := NewSSEMCPClient(testServer.URL + "/sse") 220 | // if err != nil { 221 | // t.Fatalf("Failed to create client: %v", err) 222 | // } 223 | // defer client.Close() 224 | 225 | // if err := client.Start(context.Background()); err != nil { 226 | // t.Fatalf("Failed to start client: %v", err) 227 | // } 228 | 229 | // ctx, cancel := context.WithCancel(context.Background()) 230 | // cancel() // Cancel immediately 231 | 232 | // toolsRequest := mcp.ListToolsRequest{} 233 | // _, err = client.ListTools(ctx, toolsRequest) 234 | // if err == nil { 235 | // 
t.Error("Expected error when context is cancelled") 236 | // } 237 | // }) 238 | 239 | t.Run("CallTool", func(t *testing.T) { 240 | client, err := NewSSEMCPClient(testServer.URL + "/sse") 241 | if err != nil { 242 | t.Fatalf("Failed to create client: %v", err) 243 | } 244 | defer client.Close() 245 | 246 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 247 | defer cancel() 248 | 249 | if err := client.Start(ctx); err != nil { 250 | t.Fatalf("Failed to start client: %v", err) 251 | } 252 | 253 | // Initialize 254 | initRequest := mcp.InitializeRequest{} 255 | initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION 256 | initRequest.Params.ClientInfo = mcp.Implementation{ 257 | Name: "test-client", 258 | Version: "1.0.0", 259 | } 260 | 261 | _, err = client.Initialize(ctx, initRequest) 262 | if err != nil { 263 | t.Fatalf("Failed to initialize: %v", err) 264 | } 265 | 266 | request := mcp.CallToolRequest{} 267 | request.Params.Name = "test-tool" 268 | request.Params.Arguments = map[string]any{ 269 | "parameter-1": "value1", 270 | } 271 | 272 | result, err := client.CallTool(ctx, request) 273 | if err != nil { 274 | t.Fatalf("CallTool failed: %v", err) 275 | } 276 | 277 | if len(result.Content) != 1 { 278 | t.Errorf("Expected 1 content item, got %d", len(result.Content)) 279 | } 280 | }) 281 | 282 | t.Run("CallTool with customized header", func(t *testing.T) { 283 | client, err := NewSSEMCPClient(testServer.URL+"/sse", 284 | WithHeaders(map[string]string{ 285 | "X-Test-Header": "test-header-value", 286 | }), 287 | WithHeaderFunc(func(ctx context.Context) map[string]string { 288 | return map[string]string{ 289 | "X-Test-Header-Func": "test-header-func-value", 290 | } 291 | }), 292 | ) 293 | if err != nil { 294 | t.Fatalf("Failed to create client: %v", err) 295 | } 296 | defer client.Close() 297 | 298 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 299 | defer cancel() 300 | 301 | if err := client.Start(ctx); 
err != nil { 302 | t.Fatalf("Failed to start client: %v", err) 303 | } 304 | 305 | // Initialize 306 | initRequest := mcp.InitializeRequest{} 307 | initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION 308 | initRequest.Params.ClientInfo = mcp.Implementation{ 309 | Name: "test-client", 310 | Version: "1.0.0", 311 | } 312 | 313 | _, err = client.Initialize(ctx, initRequest) 314 | if err != nil { 315 | t.Fatalf("Failed to initialize: %v", err) 316 | } 317 | 318 | request := mcp.CallToolRequest{} 319 | request.Params.Name = "test-tool-for-http-header" 320 | 321 | result, err := client.CallTool(ctx, request) 322 | if err != nil { 323 | t.Fatalf("CallTool failed: %v", err) 324 | } 325 | 326 | if len(result.Content) != 1 { 327 | t.Errorf("Expected 1 content item, got %d", len(result.Content)) 328 | } 329 | if result.Content[0].(mcp.TextContent).Text != "context from header: test-header-value, test-header-func-value" { 330 | t.Errorf("Got %q, want %q", result.Content[0].(mcp.TextContent).Text, "context from header: test-header-value, test-header-func-value") 331 | } 332 | }) 333 | } 334 | -------------------------------------------------------------------------------- /client/stdio.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "io" 7 | 8 | "github.com/mark3labs/mcp-go/client/transport" 9 | ) 10 | 11 | // NewStdioMCPClient creates a new stdio-based MCP client that communicates with a subprocess. 12 | // It launches the specified command with given arguments and sets up stdin/stdout pipes for communication. 13 | // Returns an error if the subprocess cannot be started or the pipes cannot be created. 14 | // 15 | // NOTICE: NewStdioMCPClient will start the connection automatically. Don't call the Start method manually. 16 | // This is for backward compatibility. 
17 | func NewStdioMCPClient( 18 | command string, 19 | env []string, 20 | args ...string, 21 | ) (*Client, error) { 22 | 23 | stdioTransport := transport.NewStdio(command, env, args...) 24 | err := stdioTransport.Start(context.Background()) 25 | if err != nil { 26 | return nil, fmt.Errorf("failed to start stdio transport: %w", err) 27 | } 28 | 29 | return NewClient(stdioTransport), nil 30 | } 31 | 32 | // GetStderr returns a reader for the stderr output of the subprocess. 33 | // This can be used to capture error messages or logs from the subprocess. The boolean is false (and the reader nil) when c does not use a *transport.Stdio transport. 34 | func GetStderr(c *Client) (io.Reader, bool) { 35 | t := c.GetTransport() 36 | 37 | stdio, ok := t.(*transport.Stdio) 38 | if !ok { 39 | return nil, false 40 | } 41 | 42 | return stdio.Stderr(), true 43 | } 44 | -------------------------------------------------------------------------------- /client/stdio_test.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | "log/slog" 8 | "os" 9 | "os/exec" 10 | "runtime" 11 | "sync" 12 | "testing" 13 | "time" 14 | 15 | "github.com/mark3labs/mcp-go/mcp" 16 | ) 17 | 18 | func compileTestServer(outputPath string) error { 19 | cmd := exec.Command( 20 | "go", 21 | "build", 22 | "-buildmode=pie", 23 | "-o", 24 | outputPath, 25 | "../testdata/mockstdio_server.go", 26 | ) 27 | tmpCache, _ := os.MkdirTemp("", "gocache") 28 | cmd.Env = append(os.Environ(), "GOCACHE="+tmpCache) 29 | defer os.RemoveAll(tmpCache) // the throwaway build cache is not needed once the build returns 30 | if output, err := cmd.CombinedOutput(); err != nil { 31 | return fmt.Errorf("compilation failed: %v\nOutput: %s", err, output) 32 | } 33 | // Verify the binary was actually created 34 | if _, err := os.Stat(outputPath); os.IsNotExist(err) { 35 | return fmt.Errorf("mock server binary not found at %s after compilation", 
outputPath) 36 | } 37 | return nil 38 | } 39 | 40 | func TestStdioMCPClient(t *testing.T) { 41 | // Create a temporary file for the mock server 42 | tempFile, err :=
os.CreateTemp("", "mockstdio_server") 43 | if err != nil { 44 | t.Fatalf("Failed to create temp file: %v", err) 45 | } 46 | tempFile.Close() 47 | mockServerPath := tempFile.Name() 48 | 49 | // Add .exe suffix on Windows 50 | if runtime.GOOS == "windows" { 51 | os.Remove(mockServerPath) // Remove the empty file first 52 | mockServerPath += ".exe" 53 | } 54 | 55 | defer os.Remove(mockServerPath) // registered before compiling so a failed build still cleans up 56 | if compileErr := compileTestServer(mockServerPath); compileErr != nil { 57 | t.Fatalf("Failed to compile mock server: %v", compileErr) 58 | } 59 | 60 | client, err := NewStdioMCPClient(mockServerPath, []string{}) 61 | if err != nil { 62 | t.Fatalf("Failed to create client: %v", err) 63 | } 64 | var logRecords []map[string]any 65 | var logRecordsMu sync.RWMutex 66 | var wg sync.WaitGroup 67 | wg.Add(1) 68 | go func() { 69 | defer wg.Done() 70 | 71 | stderr, ok := GetStderr(client) 72 | if !ok { 73 | return 74 | } 75 | 76 | dec := json.NewDecoder(stderr) 77 | for { 78 | var record map[string]any 79 | if err := dec.Decode(&record); err != nil { 80 | return 81 | } 82 | logRecordsMu.Lock() 83 | logRecords = append(logRecords, record) 84 | logRecordsMu.Unlock() 85 | } 86 | }() 87 | 88 | t.Run("Initialize", func(t *testing.T) { 89 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 90 | defer cancel() 91 | 92 | request := mcp.InitializeRequest{} 93 | request.Params.ProtocolVersion = "1.0" 94 | request.Params.ClientInfo = mcp.Implementation{ 95 | Name: "test-client", 96 | Version: "1.0.0", 97 | } 98 | request.Params.Capabilities = mcp.ClientCapabilities{ 99 | Roots: &struct { 100 | ListChanged bool `json:"listChanged,omitempty"` 101 | }{ 102 | ListChanged: true, 103 | }, 104 | } 105 | 106 | result, err := client.Initialize(ctx, request) 107 | if err != nil { 108 | t.Fatalf("Initialize failed: %v", err) 109 | } 110 | 111 | if result.ServerInfo.Name != "mock-server" { 112 | t.Errorf( 113 | "Expected 
server name 'mock-server', got '%s'", 114 | 
result.ServerInfo.Name, 115 | ) 116 | } 117 | }) 118 | 119 | t.Run("Ping", func(t *testing.T) { 120 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 121 | defer cancel() 122 | 123 | err := client.Ping(ctx) 124 | if err != nil { 125 | t.Errorf("Ping failed: %v", err) 126 | } 127 | }) 128 | 129 | t.Run("ListResources", func(t *testing.T) { 130 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 131 | defer cancel() 132 | 133 | request := mcp.ListResourcesRequest{} 134 | result, err := client.ListResources(ctx, request) 135 | if err != nil { 136 | t.Fatalf("ListResources failed: %v", err) 137 | } 138 | 139 | if len(result.Resources) != 1 { 140 | t.Errorf("Expected 1 resource, got %d", len(result.Resources)) 141 | } 142 | }) 143 | 144 | t.Run("ReadResource", func(t *testing.T) { 145 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 146 | defer cancel() 147 | 148 | request := mcp.ReadResourceRequest{} 149 | request.Params.URI = "test://resource" 150 | 151 | result, err := client.ReadResource(ctx, request) 152 | if err != nil { 153 | t.Fatalf("ReadResource failed: %v", err) 154 | } 155 | 156 | if len(result.Contents) != 1 { 157 | t.Errorf("Expected 1 content item, got %d", len(result.Contents)) 158 | } 159 | }) 160 | 161 | t.Run("Subscribe and Unsubscribe", func(t *testing.T) { 162 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 163 | defer cancel() 164 | 165 | // Test Subscribe 166 | subRequest := mcp.SubscribeRequest{} 167 | subRequest.Params.URI = "test://resource" 168 | err := client.Subscribe(ctx, subRequest) 169 | if err != nil { 170 | t.Errorf("Subscribe failed: %v", err) 171 | } 172 | 173 | // Test Unsubscribe 174 | unsubRequest := mcp.UnsubscribeRequest{} 175 | unsubRequest.Params.URI = "test://resource" 176 | err = client.Unsubscribe(ctx, unsubRequest) 177 | if err != nil { 178 | t.Errorf("Unsubscribe failed: %v", err) 179 | } 180 | }) 181 | 182 | 
t.Run("ListPrompts", func(t *testing.T) { 183 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 184 | defer cancel() 185 | 186 | request := mcp.ListPromptsRequest{} 187 | result, err := client.ListPrompts(ctx, request) 188 | if err != nil { 189 | t.Fatalf("ListPrompts failed: %v", err) 190 | } 191 | 192 | if len(result.Prompts) != 1 { 193 | t.Errorf("Expected 1 prompt, got %d", len(result.Prompts)) 194 | } 195 | }) 196 | 197 | t.Run("GetPrompt", func(t *testing.T) { 198 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 199 | defer cancel() 200 | 201 | request := mcp.GetPromptRequest{} 202 | request.Params.Name = "test-prompt" 203 | 204 | result, err := client.GetPrompt(ctx, request) 205 | if err != nil { 206 | t.Fatalf("GetPrompt failed: %v", err) 207 | } 208 | 209 | if len(result.Messages) != 1 { 210 | t.Errorf("Expected 1 message, got %d", len(result.Messages)) 211 | } 212 | }) 213 | 214 | t.Run("ListTools", func(t *testing.T) { 215 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 216 | defer cancel() 217 | 218 | request := mcp.ListToolsRequest{} 219 | result, err := client.ListTools(ctx, request) 220 | if err != nil { 221 | t.Fatalf("ListTools failed: %v", err) 222 | } 223 | 224 | if len(result.Tools) != 1 { 225 | t.Errorf("Expected 1 tool, got %d", len(result.Tools)) 226 | } 227 | }) 228 | 229 | t.Run("CallTool", func(t *testing.T) { 230 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 231 | defer cancel() 232 | 233 | request := mcp.CallToolRequest{} 234 | request.Params.Name = "test-tool" 235 | request.Params.Arguments = map[string]any{ 236 | "param1": "value1", 237 | } 238 | 239 | result, err := client.CallTool(ctx, request) 240 | if err != nil { 241 | t.Fatalf("CallTool failed: %v", err) 242 | } 243 | 244 | if len(result.Content) != 1 { 245 | t.Errorf("Expected 1 content item, got %d", len(result.Content)) 246 | } 247 | }) 248 | 249 | t.Run("SetLevel", func(t
*testing.T) { 250 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 251 | defer cancel() 252 | 253 | request := mcp.SetLevelRequest{} 254 | request.Params.Level = mcp.LoggingLevelInfo 255 | 256 | err := client.SetLevel(ctx, request) 257 | if err != nil { 258 | t.Errorf("SetLevel failed: %v", err) 259 | } 260 | }) 261 | 262 | t.Run("Complete", func(t *testing.T) { 263 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 264 | defer cancel() 265 | 266 | request := mcp.CompleteRequest{} 267 | request.Params.Ref = mcp.PromptReference{ 268 | Type: "ref/prompt", 269 | Name: "test-prompt", 270 | } 271 | request.Params.Argument.Name = "test-arg" 272 | request.Params.Argument.Value = "test-value" 273 | 274 | result, err := client.Complete(ctx, request) 275 | if err != nil { 276 | t.Fatalf("Complete failed: %v", err) 277 | } 278 | 279 | if len(result.Completion.Values) != 1 { 280 | t.Errorf( 281 | "Expected 1 completion value, got %d", 282 | len(result.Completion.Values), 283 | ) 284 | } 285 | }) 286 | 287 | client.Close() 288 | wg.Wait() 289 | 290 | t.Run("CheckLogs", func(t *testing.T) { 291 | logRecordsMu.RLock() 292 | defer logRecordsMu.RUnlock() 293 | 294 | if len(logRecords) != 1 { 295 | t.Errorf("Expected 1 log record, got %d", len(logRecords)) 296 | return 297 | } 298 | 299 | msg, ok := logRecords[0][slog.MessageKey].(string) 300 | if !ok { 301 | t.Errorf("Expected log record to have message key") 302 | } 303 | if msg != "launch successful" { 304 | t.Errorf("Expected log message 'launch successful', got '%s'", msg) 305 | } 306 | }) 307 | } 308 | -------------------------------------------------------------------------------- /client/transport/inprocess.go: -------------------------------------------------------------------------------- 1 | package transport 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | "sync" 8 | 9 | "github.com/mark3labs/mcp-go/mcp" 10 | "github.com/mark3labs/mcp-go/server" 11 | ) 12 | // 
InProcessTransport implements the transport Interface by handing JSON-encoded messages directly to an in-process server.MCPServer via HandleMessage.
13 | type InProcessTransport struct { 14 | server *server.MCPServer 15 | // onNotification is set by SetNotificationHandler under notifyMu. NOTE(review): nothing in this file invokes it, so server-initiated notifications appear to be dropped — confirm. 16 | onNotification func(mcp.JSONRPCNotification) 17 | notifyMu sync.RWMutex 18 | } 19 | // NewInProcessTransport returns a transport that dispatches messages directly to the given in-process server. 20 | func NewInProcessTransport(server *server.MCPServer) *InProcessTransport { 21 | return &InProcessTransport{ 22 | server: server, 23 | } 24 | } 25 | // Start is a no-op: there is no connection to open for an in-process server. 26 | func (c *InProcessTransport) Start(ctx context.Context) error { 27 | return nil 28 | } 29 | // SendRequest JSON-encodes request, hands it to server.HandleMessage, and decodes the returned message into a JSONRPCResponse. 30 | func (c *InProcessTransport) SendRequest(ctx context.Context, request JSONRPCRequest) (*JSONRPCResponse, error) { 31 | requestBytes, err := json.Marshal(request) 32 | if err != nil { 33 | return nil, fmt.Errorf("failed to marshal request: %w", err) 34 | } 35 | requestBytes = append(requestBytes, '\n') 36 | // HandleMessage returns a typed message, so round-trip it through JSON to obtain the generic JSONRPCResponse shape. 37 | respMessage := c.server.HandleMessage(ctx, requestBytes) 38 | respByte, err := json.Marshal(respMessage) 39 | if err != nil { 40 | return nil, fmt.Errorf("failed to marshal response message: %w", err) 41 | } 42 | rpcResp := JSONRPCResponse{} 43 | err = json.Unmarshal(respByte, &rpcResp) 44 | if err != nil { 45 | return nil, fmt.Errorf("failed to unmarshal response message: %w", err) 46 | } 47 | 48 | return &rpcResp, nil 49 | } 50 | // SendNotification JSON-encodes notification and hands it to server.HandleMessage. 
51 | func (c *InProcessTransport) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error { 52 | notificationBytes, err := json.Marshal(notification) 53 | if err != nil { 54 | return fmt.Errorf("failed to marshal notification: %w", err) 55 | } 56 | notificationBytes = append(notificationBytes, '\n') 57 | c.server.HandleMessage(ctx, notificationBytes) 58 | // Whatever HandleMessage returns for a notification is discarded. 59 | return nil 60 | } 61 | // SetNotificationHandler stores handler, guarded by notifyMu. 62 | func (c *InProcessTransport) SetNotificationHandler(handler func(notification mcp.JSONRPCNotification)) { 63 | c.notifyMu.Lock() 64 | defer c.notifyMu.Unlock() 65 | c.onNotification = handler 66 | } 67 | // Close is a no-op; the transport holds no resources of its own. 68 | func (*InProcessTransport) Close() error { 69 | return nil 70 | } 71 | -------------------------------------------------------------------------------- /client/transport/interface.go:
-------------------------------------------------------------------------------- 1 | package transport 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | 7 | "github.com/mark3labs/mcp-go/mcp" 8 | ) 9 | 10 | // HTTPHeaderFunc is a function that extracts header entries from the given context 11 | // and returns them as key-value pairs. This is typically used to add context values 12 | // as HTTP headers in outgoing requests. 13 | type HTTPHeaderFunc func(context.Context) map[string]string 14 | 15 | // Interface is the contract a client transport implements: it carries JSON-RPC requests and notifications to an MCP server. 16 | type Interface interface { 17 | // Start opens the connection. Start should only be called once. 18 | Start(ctx context.Context) error 19 | 20 | // SendRequest sends a JSON-RPC request and returns the response synchronously. 21 | SendRequest(ctx context.Context, request JSONRPCRequest) (*JSONRPCResponse, error) 22 | 23 | // SendNotification sends a JSON-RPC notification to the server. 24 | SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error 25 | 26 | // SetNotificationHandler sets the handler for notifications. 27 | // Any notification before the handler is set will be discarded. 28 | SetNotificationHandler(handler func(notification mcp.JSONRPCNotification)) 29 | 30 | // Close closes the connection.
31 | Close() error 32 | } 33 | // JSONRPCRequest is the JSON-RPC request envelope passed to Interface.SendRequest. 34 | type JSONRPCRequest struct { 35 | JSONRPC string `json:"jsonrpc"` 36 | ID mcp.RequestId `json:"id"` 37 | Method string `json:"method"` 38 | Params any `json:"params,omitempty"` 39 | } 40 | // JSONRPCResponse is the reply returned by Interface.SendRequest; Error is a pointer, so it stays nil when the reply has no (or a null) "error" member. 41 | type JSONRPCResponse struct { 42 | JSONRPC string `json:"jsonrpc"` 43 | ID mcp.RequestId `json:"id"` 44 | Result json.RawMessage `json:"result"` 45 | Error *struct { 46 | Code int `json:"code"` 47 | Message string `json:"message"` 48 | Data json.RawMessage `json:"data"` 49 | } `json:"error"` 50 | } 51 | -------------------------------------------------------------------------------- /client/transport/oauth_test.go: -------------------------------------------------------------------------------- 1 | package transport 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "strings" 7 | "testing" 8 | "time" 9 | ) 10 | 11 | func TestToken_IsExpired(t *testing.T) { 12 | // Test cases 13 | testCases := []struct { 14 | name string 15 | token Token 16 | expected bool 17 | }{ 18 | { 19 | name: "Valid token", 20 | token: Token{ 21 | AccessToken: "valid-token", 22 | ExpiresAt: time.Now().Add(1 * time.Hour), 23 | }, 24 | expected: false, 25 | }, 26 | { 27 | name: "Expired token", 28 | token: Token{ 29 | AccessToken: "expired-token", 30 | ExpiresAt: time.Now().Add(-1 * time.Hour), 31 | }, 32 | expected: true, 33 | }, 34 | { 35 | name: "Token with no expiration", 36 | token: Token{ 37 | AccessToken: "no-expiration-token", 38 | }, 39 | expected: false, 40 | }, 41 | } 42 | 43 | // Run test cases 44 | for _, tc := range testCases { 45 | t.Run(tc.name, func(t *testing.T) { 46 | result := tc.token.IsExpired() 47 | if result != tc.expected { 48 | t.Errorf("Expected IsExpired() to return %v, got %v", tc.expected, result) 49 | } 50 | }) 51 | } 52 | } 
53 | 54 | func TestMemoryTokenStore(t *testing.T) { 55 | // Create a token store 56 | store := NewMemoryTokenStore() 57 | 58 | // Test getting token from empty store 59 | _, err := store.GetToken() 60 | if err == nil { 61
| t.Errorf("Expected error when getting token from empty store") 62 | } 63 | 64 | // Create a test token 65 | token := &Token{ 66 | AccessToken: "test-token", 67 | TokenType: "Bearer", 68 | RefreshToken: "refresh-token", 69 | ExpiresIn: 3600, 70 | ExpiresAt: time.Now().Add(1 * time.Hour), 71 | } 72 | 73 | // Save the token 74 | err = store.SaveToken(token) 75 | if err != nil { 76 | t.Fatalf("Failed to save token: %v", err) 77 | } 78 | 79 | // Get the token 80 | retrievedToken, err := store.GetToken() 81 | if err != nil { 82 | t.Fatalf("Failed to get token: %v", err) 83 | } 84 | 85 | // Verify the token 86 | if retrievedToken.AccessToken != token.AccessToken { 87 | t.Errorf("Expected access token to be %s, got %s", token.AccessToken, retrievedToken.AccessToken) 88 | } 89 | if retrievedToken.TokenType != token.TokenType { 90 | t.Errorf("Expected token type to be %s, got %s", token.TokenType, retrievedToken.TokenType) 91 | } 92 | if retrievedToken.RefreshToken != token.RefreshToken { 93 | t.Errorf("Expected refresh token to be %s, got %s", token.RefreshToken, retrievedToken.RefreshToken) 94 | } 95 | } 96 | 97 | func TestValidateRedirectURI(t *testing.T) { 98 | // Test cases 99 | testCases := []struct { 100 | name string 101 | redirectURI string 102 | expectError bool 103 | }{ 104 | { 105 | name: "Valid HTTPS URI", 106 | redirectURI: "https://example.com/callback", 107 | expectError: false, 108 | }, 109 | { 110 | name: "Valid localhost URI", 111 | redirectURI: "http://localhost:8085/callback", 112 | expectError: false, 113 | }, 114 | { 115 | name: "Valid localhost URI with 127.0.0.1", 116 | redirectURI: "http://127.0.0.1:8085/callback", 117 | expectError: false, 118 | }, 119 | { 120 | name: "Invalid HTTP URI (non-localhost)", 121 | redirectURI: "http://example.com/callback", 122 | expectError: true, 123 | }, 124 | { 125 | name: "Invalid HTTP URI with 'local' in domain", 126 | redirectURI: "http://localdomain.com/callback", 127 | expectError: true, 128 | }, 129 | { 130
| name: "Empty URI", 131 | redirectURI: "", 132 | expectError: true, 133 | }, 134 | { 135 | name: "Invalid scheme", 136 | redirectURI: "ftp://example.com/callback", 137 | expectError: true, 138 | }, 139 | { 140 | name: "IPv6 localhost", 141 | redirectURI: "http://[::1]:8080/callback", 142 | expectError: false, // IPv6 localhost is valid 143 | }, 144 | } 145 | 146 | // Run test cases 147 | for _, tc := range testCases { 148 | t.Run(tc.name, func(t *testing.T) { 149 | err := ValidateRedirectURI(tc.redirectURI) 150 | if tc.expectError && err == nil { 151 | t.Errorf("Expected error for redirect URI %s, got nil", tc.redirectURI) 152 | } else if !tc.expectError && err != nil { 153 | t.Errorf("Expected no error for redirect URI %s, got %v", tc.redirectURI, err) 154 | } 155 | }) 156 | } 157 | } 158 | 159 | func TestOAuthHandler_GetAuthorizationHeader_EmptyAccessToken(t *testing.T) { 160 | // Create a token store with a token that has an empty access token 161 | tokenStore := NewMemoryTokenStore() 162 | invalidToken := &Token{ 163 | AccessToken: "", // Empty access token 164 | TokenType: "Bearer", 165 | RefreshToken: "refresh-token", 166 | ExpiresIn: 3600, 167 | ExpiresAt: time.Now().Add(1 * time.Hour), // Valid for 1 hour 168 | } 169 | if err := tokenStore.SaveToken(invalidToken); err != nil { 170 | t.Fatalf("Failed to save token: %v", err) 171 | } 172 | 173 | // Create an OAuth handler 174 | config := OAuthConfig{ 175 | ClientID: "test-client", 176 | RedirectURI: "http://localhost:8085/callback", 177 | Scopes: []string{"mcp.read", "mcp.write"}, 178 | TokenStore: tokenStore, 179 | PKCEEnabled: true, 180 | } 181 | 182 | handler := NewOAuthHandler(config) 183 | 184 | // Test getting authorization header with empty access token 185 | _, err := handler.GetAuthorizationHeader(context.Background()) 186 | if err == nil { 187 | t.Fatalf("Expected error when getting authorization header with empty access token") 188 | } 189 | 190 | // Verify the error wraps ErrOAuthAuthorizationRequired 191 
| if
!errors.Is(err, ErrOAuthAuthorizationRequired) { 192 | t.Errorf("Expected error to be ErrOAuthAuthorizationRequired, got %v", err) 193 | } 194 | } 195 | 196 | func TestOAuthHandler_GetServerMetadata_EmptyURL(t *testing.T) { 197 | // Create an OAuth handler with an empty AuthServerMetadataURL 198 | config := OAuthConfig{ 199 | ClientID: "test-client", 200 | RedirectURI: "http://localhost:8085/callback", 201 | Scopes: []string{"mcp.read"}, 202 | TokenStore: NewMemoryTokenStore(), 203 | AuthServerMetadataURL: "", // Empty URL 204 | PKCEEnabled: true, 205 | } 206 | 207 | handler := NewOAuthHandler(config) 208 | 209 | // Test getting server metadata with empty URL 210 | _, err := handler.GetServerMetadata(context.Background()) 211 | if err == nil { 212 | t.Fatalf("Expected error when getting server metadata with empty URL") 213 | } 214 | 215 | // Verify the error message contains something about a connection error 216 | // since we're now trying to connect to the well-known endpoint 217 | if !strings.Contains(err.Error(), "connection refused") && 218 | !strings.Contains(err.Error(), "failed to send protected resource request") { 219 | t.Errorf("Expected error message to contain connection error, got %s", err.Error()) 220 | } 221 | } 222 | 223 | func TestOAuthError(t *testing.T) { 224 | testCases := []struct { 225 | name string 226 | errorCode string 227 | description string 228 | uri string 229 | expected string 230 | }{ 231 | { 232 | name: "Error with description", 233 | errorCode: "invalid_request", 234 | description: "The request is missing a required parameter", 235 | uri: "https://example.com/errors/invalid_request", 236 | expected: "OAuth error: invalid_request - The request is missing a required parameter", 237 | }, 238 | { 239 | name: "Error without description", 240 | errorCode: "unauthorized_client", 241 | description: "", 242 | uri: "", 243 | expected: "OAuth error: unauthorized_client", 244 | }, 245 | } 246 | 247 | for _, tc := range testCases { 248 | 
t.Run(tc.name, func(t *testing.T) { 249 | oauthErr := OAuthError{ 250 | ErrorCode: tc.errorCode, 251 | ErrorDescription: tc.description, 252 | ErrorURI: tc.uri, 253 | } 254 | 255 | if oauthErr.Error() != tc.expected { 256 | t.Errorf("Expected error message %q, got %q", tc.expected, oauthErr.Error()) 257 | } 258 | }) 259 | } 260 | } 261 | 262 | func TestOAuthHandler_ProcessAuthorizationResponse_StateValidation(t *testing.T) { 263 | // Create an OAuth handler 264 | config := OAuthConfig{ 265 | ClientID: "test-client", 266 | RedirectURI: "http://localhost:8085/callback", 267 | Scopes: []string{"mcp.read", "mcp.write"}, 268 | TokenStore: NewMemoryTokenStore(), 269 | AuthServerMetadataURL: "http://example.com/.well-known/oauth-authorization-server", 270 | PKCEEnabled: true, 271 | } 272 | 273 | handler := NewOAuthHandler(config) 274 | 275 | // Mock the server metadata to avoid nil pointer dereference 276 | handler.serverMetadata = &AuthServerMetadata{ 277 | Issuer: "http://example.com", 278 | AuthorizationEndpoint: "http://example.com/authorize", 279 | TokenEndpoint: "http://example.com/token", 280 | } 281 | 282 | // Set the expected state 283 | expectedState := "test-state-123" 284 | handler.expectedState = expectedState 285 | 286 | // Test with non-matching state - this should fail immediately with ErrInvalidState 287 | // before trying to connect to any server 288 | err := handler.ProcessAuthorizationResponse(context.Background(), "test-code", "wrong-state", "test-code-verifier") 289 | if !errors.Is(err, ErrInvalidState) { 290 | t.Errorf("Expected ErrInvalidState, got %v", err) 291 | } 292 | 293 | // Test with empty expected state 294 | handler.expectedState = "" 295 | err = handler.ProcessAuthorizationResponse(context.Background(), "test-code", expectedState, "test-code-verifier") 296 | if err == nil { 297 | t.Errorf("Expected error with empty expected state, got nil") 298 | } 299 | if errors.Is(err, ErrInvalidState) { 300 | t.Errorf("Got ErrInvalidState when 
expected a different error for empty expected state") 301 | } 302 | } 303 | -------------------------------------------------------------------------------- /client/transport/oauth_utils.go: -------------------------------------------------------------------------------- 1 | package transport 2 | 3 | import ( 4 | "crypto/rand" 5 | "crypto/sha256" 6 | "encoding/base64" 7 | "fmt" 8 | "net/url" 9 | ) 10 | 11 | // GenerateRandomString generates a random string of the specified length 12 | func GenerateRandomString(length int) (string, error) { 13 | bytes := make([]byte, length) 14 | if _, err := rand.Read(bytes); err != nil { 15 | return "", err 16 | } 17 | return base64.RawURLEncoding.EncodeToString(bytes)[:length], nil 18 | } 19 | 20 | // GenerateCodeVerifier generates a code verifier for PKCE 21 | func GenerateCodeVerifier() (string, error) { 22 | // According to RFC 7636, the code verifier should be between 43 and 128 characters 23 | return GenerateRandomString(64) 24 | } 25 | 26 | // GenerateCodeChallenge generates a code challenge from a code verifier 27 | func GenerateCodeChallenge(codeVerifier string) string { 28 | // SHA256 hash the code verifier 29 | hash := sha256.Sum256([]byte(codeVerifier)) 30 | // Base64url encode the hash 31 | return base64.RawURLEncoding.EncodeToString(hash[:]) 32 | } 33 | 34 | // GenerateState generates a state parameter for OAuth 35 | func GenerateState() (string, error) { 36 | return GenerateRandomString(32) 37 | } 38 | 39 | // ValidateRedirectURI validates that a redirect URI is secure 40 | func ValidateRedirectURI(redirectURI string) error { 41 | // According to the spec, redirect URIs must be either localhost URLs or HTTPS URLs 42 | if redirectURI == "" { 43 | return fmt.Errorf("redirect URI cannot be empty") 44 | } 45 | 46 | // Parse the URL 47 | parsedURL, err := url.Parse(redirectURI) 48 | if err != nil { 49 | return fmt.Errorf("invalid redirect URI: %w", err) 50 | } 51 | 52 | // Check if it's a localhost URL 53 | if 
parsedURL.Scheme == "http" { 54 | hostname := parsedURL.Hostname() 55 | // Check for various forms of localhost 56 | if hostname == "localhost" || hostname == "127.0.0.1" || hostname == "::1" || hostname == "[::1]" { 57 | return nil 58 | } 59 | return fmt.Errorf("HTTP redirect URI must use localhost or 127.0.0.1") 60 | } 61 | 62 | // Check if it's an HTTPS URL 63 | if parsedURL.Scheme == "https" { 64 | return nil 65 | } 66 | 67 | return fmt.Errorf("redirect URI must use either HTTP with localhost or HTTPS") 68 | } 69 | -------------------------------------------------------------------------------- /client/transport/oauth_utils_test.go: -------------------------------------------------------------------------------- 1 | package transport 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | ) 7 | 8 | func TestGenerateRandomString(t *testing.T) { 9 | // Test generating strings of different lengths 10 | lengths := []int{10, 32, 64, 128} 11 | for _, length := range lengths { 12 | t.Run(fmt.Sprintf("Length_%d", length), func(t *testing.T) { 13 | str, err := GenerateRandomString(length) 14 | if err != nil { 15 | t.Fatalf("Failed to generate random string: %v", err) 16 | } 17 | if len(str) != length { 18 | t.Errorf("Expected string of length %d, got %d", length, len(str)) 19 | } 20 | 21 | // Generate another string to ensure they're different 22 | str2, err := GenerateRandomString(length) 23 | if err != nil { 24 | t.Fatalf("Failed to generate second random string: %v", err) 25 | } 26 | if str == str2 { 27 | t.Errorf("Generated identical random strings: %s", str) 28 | } 29 | }) 30 | } 31 | } 32 | 33 | func TestGenerateCodeVerifierAndChallenge(t *testing.T) { 34 | // Generate a code verifier 35 | verifier, err := GenerateCodeVerifier() 36 | if err != nil { 37 | t.Fatalf("Failed to generate code verifier: %v", err) 38 | } 39 | 40 | // Verify the length (should be 64 characters) 41 | if len(verifier) != 64 { 42 | t.Errorf("Expected code verifier of length 64, got %d", len(verifier)) 
43 | } 44 | 45 | // Generate a code challenge 46 | challenge := GenerateCodeChallenge(verifier) 47 | 48 | // Verify the challenge is not empty 49 | if challenge == "" { 50 | t.Errorf("Generated empty code challenge") 51 | } 52 | 53 | // Generate another verifier and challenge to ensure they're different 54 | verifier2, _ := GenerateCodeVerifier() 55 | challenge2 := GenerateCodeChallenge(verifier2) 56 | 57 | if verifier == verifier2 { 58 | t.Errorf("Generated identical code verifiers: %s", verifier) 59 | } 60 | if challenge == challenge2 { 61 | t.Errorf("Generated identical code challenges: %s", challenge) 62 | } 63 | 64 | // Verify the same verifier always produces the same challenge 65 | challenge3 := GenerateCodeChallenge(verifier) 66 | if challenge != challenge3 { 67 | t.Errorf("Same verifier produced different challenges: %s and %s", challenge, challenge3) 68 | } 69 | } 70 | 71 | func TestGenerateState(t *testing.T) { 72 | // Generate a state parameter 73 | state, err := GenerateState() 74 | if err != nil { 75 | t.Fatalf("Failed to generate state: %v", err) 76 | } 77 | 78 | // Verify the length (should be 32 characters) 79 | if len(state) != 32 { 80 | t.Errorf("Expected state of length 32, got %d", len(state)) 81 | } 82 | 83 | // Generate another state to ensure they're different 84 | state2, _ := GenerateState() 85 | if state == state2 { 86 | t.Errorf("Generated identical states: %s", state) 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /client/transport/stdio.go: -------------------------------------------------------------------------------- 1 | package transport 2 | 3 | import ( 4 | "bufio" 5 | "context" 6 | "encoding/json" 7 | "fmt" 8 | "io" 9 | "os" 10 | "os/exec" 11 | "sync" 12 | 13 | "github.com/mark3labs/mcp-go/mcp" 14 | ) 15 | 16 | // Stdio implements the transport layer of the MCP protocol using stdio communication. 
17 | // It launches a subprocess and communicates with it via standard input/output streams 18 | // using JSON-RPC messages. The client handles message routing between requests and 19 | // responses, and supports asynchronous notifications. 20 | type Stdio struct { 21 | command string 22 | args []string 23 | env []string 24 | 25 | cmd *exec.Cmd 26 | stdin io.WriteCloser 27 | stdout *bufio.Reader 28 | stderr io.ReadCloser 29 | responses map[string]chan *JSONRPCResponse 30 | mu sync.RWMutex 31 | done chan struct{} 32 | onNotification func(mcp.JSONRPCNotification) 33 | notifyMu sync.RWMutex 34 | } 35 | 36 | // NewIO returns a new stdio-based transport using existing input, output, and 37 | // logging streams instead of spawning a subprocess. 38 | // This is useful for testing and simulating client behavior. 39 | func NewIO(input io.Reader, output io.WriteCloser, logging io.ReadCloser) *Stdio { 40 | return &Stdio{ 41 | stdin: output, 42 | stdout: bufio.NewReader(input), 43 | stderr: logging, 44 | 45 | responses: make(map[string]chan *JSONRPCResponse), 46 | done: make(chan struct{}), 47 | } 48 | } 49 | 50 | // NewStdio creates a new stdio transport to communicate with a subprocess. 51 | // It launches the specified command with given arguments and sets up stdin/stdout pipes for communication. 52 | // Returns an error if the subprocess cannot be started or the pipes cannot be created. 
53 | func NewStdio( 54 | command string, 55 | env []string, 56 | args ...string, 57 | ) *Stdio { 58 | 59 | client := &Stdio{ 60 | command: command, 61 | args: args, 62 | env: env, 63 | 64 | responses: make(map[string]chan *JSONRPCResponse), 65 | done: make(chan struct{}), 66 | } 67 | 68 | return client 69 | } 70 | 71 | func (c *Stdio) Start(ctx context.Context) error { 72 | if err := c.spawnCommand(ctx); err != nil { 73 | return err 74 | } 75 | 76 | ready := make(chan struct{}) 77 | go func() { 78 | close(ready) 79 | c.readResponses() 80 | }() 81 | <-ready 82 | 83 | return nil 84 | } 85 | 86 | // spawnCommand spawns a new process running c.command. 87 | func (c *Stdio) spawnCommand(ctx context.Context) error { 88 | if c.command == "" { 89 | return nil 90 | } 91 | 92 | cmd := exec.CommandContext(ctx, c.command, c.args...) 93 | 94 | mergedEnv := os.Environ() 95 | mergedEnv = append(mergedEnv, c.env...) 96 | 97 | cmd.Env = mergedEnv 98 | 99 | stdin, err := cmd.StdinPipe() 100 | if err != nil { 101 | return fmt.Errorf("failed to create stdin pipe: %w", err) 102 | } 103 | 104 | stdout, err := cmd.StdoutPipe() 105 | if err != nil { 106 | return fmt.Errorf("failed to create stdout pipe: %w", err) 107 | } 108 | 109 | stderr, err := cmd.StderrPipe() 110 | if err != nil { 111 | return fmt.Errorf("failed to create stderr pipe: %w", err) 112 | } 113 | 114 | c.cmd = cmd 115 | c.stdin = stdin 116 | c.stderr = stderr 117 | c.stdout = bufio.NewReader(stdout) 118 | 119 | if err := cmd.Start(); err != nil { 120 | return fmt.Errorf("failed to start command: %w", err) 121 | } 122 | 123 | return nil 124 | } 125 | 126 | // Close shuts down the stdio client, closing the stdin pipe and waiting for the subprocess to exit. 127 | // Returns an error if there are issues closing stdin or waiting for the subprocess to terminate. 
128 | func (c *Stdio) Close() error { 129 | select { 130 | case <-c.done: 131 | return nil 132 | default: 133 | } 134 | // cancel all in-flight request 135 | close(c.done) 136 | 137 | if err := c.stdin.Close(); err != nil { 138 | return fmt.Errorf("failed to close stdin: %w", err) 139 | } 140 | if err := c.stderr.Close(); err != nil { 141 | return fmt.Errorf("failed to close stderr: %w", err) 142 | } 143 | 144 | if c.cmd != nil { 145 | return c.cmd.Wait() 146 | } 147 | 148 | return nil 149 | } 150 | 151 | // SetNotificationHandler sets the handler function to be called when a notification is received. 152 | // Only one handler can be set at a time; setting a new one replaces the previous handler. 153 | func (c *Stdio) SetNotificationHandler( 154 | handler func(notification mcp.JSONRPCNotification), 155 | ) { 156 | c.notifyMu.Lock() 157 | defer c.notifyMu.Unlock() 158 | c.onNotification = handler 159 | } 160 | 161 | // readResponses continuously reads and processes responses from the server's stdout. 162 | // It handles both responses to requests and notifications, routing them appropriately. 163 | // Runs until the done channel is closed or an error occurs reading from stdout. 
164 | func (c *Stdio) readResponses() { 165 | for { 166 | select { 167 | case <-c.done: 168 | return 169 | default: 170 | line, err := c.stdout.ReadString('\n') 171 | if err != nil { 172 | if err != io.EOF { 173 | fmt.Printf("Error reading response: %v\n", err) 174 | } 175 | return 176 | } 177 | 178 | var baseMessage JSONRPCResponse 179 | if err := json.Unmarshal([]byte(line), &baseMessage); err != nil { 180 | continue 181 | } 182 | 183 | // Handle notification 184 | if baseMessage.ID.IsNil() { 185 | var notification mcp.JSONRPCNotification 186 | if err := json.Unmarshal([]byte(line), ¬ification); err != nil { 187 | continue 188 | } 189 | c.notifyMu.RLock() 190 | if c.onNotification != nil { 191 | c.onNotification(notification) 192 | } 193 | c.notifyMu.RUnlock() 194 | continue 195 | } 196 | 197 | // Create string key for map lookup 198 | idKey := baseMessage.ID.String() 199 | 200 | c.mu.RLock() 201 | ch, exists := c.responses[idKey] 202 | c.mu.RUnlock() 203 | 204 | if exists { 205 | ch <- &baseMessage 206 | c.mu.Lock() 207 | delete(c.responses, idKey) 208 | c.mu.Unlock() 209 | } 210 | } 211 | } 212 | } 213 | 214 | // SendRequest sends a JSON-RPC request to the server and waits for a response. 215 | // It creates a unique request ID, sends the request over stdin, and waits for 216 | // the corresponding response or context cancellation. 217 | // Returns the raw JSON response message or an error if the request fails. 
218 | func (c *Stdio) SendRequest( 219 | ctx context.Context, 220 | request JSONRPCRequest, 221 | ) (*JSONRPCResponse, error) { 222 | if c.stdin == nil { 223 | return nil, fmt.Errorf("stdio client not started") 224 | } 225 | 226 | // Marshal request 227 | requestBytes, err := json.Marshal(request) 228 | if err != nil { 229 | return nil, fmt.Errorf("failed to marshal request: %w", err) 230 | } 231 | requestBytes = append(requestBytes, '\n') 232 | 233 | // Create string key for map lookup 234 | idKey := request.ID.String() 235 | 236 | // Register response channel 237 | responseChan := make(chan *JSONRPCResponse, 1) 238 | c.mu.Lock() 239 | c.responses[idKey] = responseChan 240 | c.mu.Unlock() 241 | deleteResponseChan := func() { 242 | c.mu.Lock() 243 | delete(c.responses, idKey) 244 | c.mu.Unlock() 245 | } 246 | 247 | // Send request 248 | if _, err := c.stdin.Write(requestBytes); err != nil { 249 | deleteResponseChan() 250 | return nil, fmt.Errorf("failed to write request: %w", err) 251 | } 252 | 253 | select { 254 | case <-ctx.Done(): 255 | deleteResponseChan() 256 | return nil, ctx.Err() 257 | case response := <-responseChan: 258 | return response, nil 259 | } 260 | } 261 | 262 | // SendNotification sends a json RPC Notification to the server. 263 | func (c *Stdio) SendNotification( 264 | ctx context.Context, 265 | notification mcp.JSONRPCNotification, 266 | ) error { 267 | if c.stdin == nil { 268 | return fmt.Errorf("stdio client not started") 269 | } 270 | 271 | notificationBytes, err := json.Marshal(notification) 272 | if err != nil { 273 | return fmt.Errorf("failed to marshal notification: %w", err) 274 | } 275 | notificationBytes = append(notificationBytes, '\n') 276 | 277 | if _, err := c.stdin.Write(notificationBytes); err != nil { 278 | return fmt.Errorf("failed to write notification: %w", err) 279 | } 280 | 281 | return nil 282 | } 283 | 284 | // Stderr returns a reader for the stderr output of the subprocess. 
285 | // This can be used to capture error messages or logs from the subprocess. 286 | func (c *Stdio) Stderr() io.Reader { 287 | return c.stderr 288 | } 289 | -------------------------------------------------------------------------------- /client/transport/streamable_http_oauth_test.go: -------------------------------------------------------------------------------- 1 | package transport 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "errors" 7 | "net/http" 8 | "net/http/httptest" 9 | "testing" 10 | "time" 11 | 12 | "github.com/mark3labs/mcp-go/mcp" 13 | ) 14 | 15 | func TestStreamableHTTP_WithOAuth(t *testing.T) { 16 | // Track request count to simulate 401 on first request, then success 17 | requestCount := 0 18 | authHeaderReceived := "" 19 | 20 | // Create a test server that requires OAuth 21 | server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 22 | // Capture the Authorization header 23 | authHeaderReceived = r.Header.Get("Authorization") 24 | 25 | // Check for Authorization header 26 | if requestCount == 0 { 27 | // First request - simulate 401 to test error handling 28 | requestCount++ 29 | w.WriteHeader(http.StatusUnauthorized) 30 | return 31 | } 32 | 33 | // Subsequent requests - verify the Authorization header 34 | if authHeaderReceived != "Bearer test-token" { 35 | t.Errorf("Expected Authorization header 'Bearer test-token', got '%s'", authHeaderReceived) 36 | w.WriteHeader(http.StatusUnauthorized) 37 | return 38 | } 39 | 40 | // Return a successful response 41 | w.Header().Set("Content-Type", "application/json") 42 | w.WriteHeader(http.StatusOK) 43 | if err := json.NewEncoder(w).Encode(map[string]any{ 44 | "jsonrpc": "2.0", 45 | "id": 1, 46 | "result": "success", 47 | }); err != nil { 48 | t.Errorf("Failed to encode JSON response: %v", err) 49 | } 50 | })) 51 | defer server.Close() 52 | 53 | // Create a token store with a valid token 54 | tokenStore := NewMemoryTokenStore() 55 | validToken := &Token{ 
56 | AccessToken: "test-token", 57 | TokenType: "Bearer", 58 | RefreshToken: "refresh-token", 59 | ExpiresIn: 3600, 60 | ExpiresAt: time.Now().Add(1 * time.Hour), // Valid for 1 hour 61 | } 62 | if err := tokenStore.SaveToken(validToken); err != nil { 63 | t.Fatalf("Failed to save token: %v", err) 64 | } 65 | 66 | // Create OAuth config 67 | oauthConfig := OAuthConfig{ 68 | ClientID: "test-client", 69 | RedirectURI: "http://localhost:8085/callback", 70 | Scopes: []string{"mcp.read", "mcp.write"}, 71 | TokenStore: tokenStore, 72 | PKCEEnabled: true, 73 | } 74 | 75 | // Create StreamableHTTP with OAuth 76 | transport, err := NewStreamableHTTP(server.URL, WithOAuth(oauthConfig)) 77 | if err != nil { 78 | t.Fatalf("Failed to create StreamableHTTP: %v", err) 79 | } 80 | 81 | // Verify that OAuth is enabled 82 | if !transport.IsOAuthEnabled() { 83 | t.Errorf("Expected IsOAuthEnabled() to return true") 84 | } 85 | 86 | // Verify the OAuth handler is set 87 | if transport.GetOAuthHandler() == nil { 88 | t.Errorf("Expected GetOAuthHandler() to return a handler") 89 | } 90 | 91 | // First request should fail with OAuthAuthorizationRequiredError 92 | _, err = transport.SendRequest(context.Background(), JSONRPCRequest{ 93 | JSONRPC: "2.0", 94 | ID: mcp.NewRequestId(1), 95 | Method: "test", 96 | }) 97 | 98 | // Verify the error is an OAuthAuthorizationRequiredError 99 | if err == nil { 100 | t.Fatalf("Expected error on first request, got nil") 101 | } 102 | 103 | var oauthErr *OAuthAuthorizationRequiredError 104 | if !errors.As(err, &oauthErr) { 105 | t.Fatalf("Expected OAuthAuthorizationRequiredError, got %T: %v", err, err) 106 | } 107 | 108 | // Verify the error has the handler 109 | if oauthErr.Handler == nil { 110 | t.Errorf("Expected OAuthAuthorizationRequiredError to have a handler") 111 | } 112 | 113 | // Verify the server received the first request 114 | if requestCount != 1 { 115 | t.Errorf("Expected server to receive 1 request, got %d", requestCount) 116 | } 117 | 118 
| // Second request should succeed 119 | response, err := transport.SendRequest(context.Background(), JSONRPCRequest{ 120 | JSONRPC: "2.0", 121 | ID: mcp.NewRequestId(2), 122 | Method: "test", 123 | }) 124 | 125 | if err != nil { 126 | t.Fatalf("Failed to send second request: %v", err) 127 | } 128 | 129 | // Verify the response 130 | var resultStr string 131 | if err := json.Unmarshal(response.Result, &resultStr); err != nil { 132 | t.Fatalf("Failed to unmarshal result: %v", err) 133 | } 134 | 135 | if resultStr != "success" { 136 | t.Errorf("Expected result to be 'success', got %v", resultStr) 137 | } 138 | 139 | // Verify the server received the Authorization header 140 | if authHeaderReceived != "Bearer test-token" { 141 | t.Errorf("Expected server to receive Authorization header 'Bearer test-token', got '%s'", authHeaderReceived) 142 | } 143 | } 144 | 145 | func TestStreamableHTTP_WithOAuth_Unauthorized(t *testing.T) { 146 | // Create a test server that requires OAuth 147 | server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 148 | // Always return unauthorized 149 | w.WriteHeader(http.StatusUnauthorized) 150 | })) 151 | defer server.Close() 152 | 153 | // Create an empty token store 154 | tokenStore := NewMemoryTokenStore() 155 | 156 | // Create OAuth config 157 | oauthConfig := OAuthConfig{ 158 | ClientID: "test-client", 159 | RedirectURI: "http://localhost:8085/callback", 160 | Scopes: []string{"mcp.read", "mcp.write"}, 161 | TokenStore: tokenStore, 162 | PKCEEnabled: true, 163 | } 164 | 165 | // Create StreamableHTTP with OAuth 166 | transport, err := NewStreamableHTTP(server.URL, WithOAuth(oauthConfig)) 167 | if err != nil { 168 | t.Fatalf("Failed to create StreamableHTTP: %v", err) 169 | } 170 | 171 | // Send a request 172 | _, err = transport.SendRequest(context.Background(), JSONRPCRequest{ 173 | JSONRPC: "2.0", 174 | ID: mcp.NewRequestId(1), 175 | Method: "test", 176 | }) 177 | 178 | // Verify the error is an 
OAuthAuthorizationRequiredError 179 | if err == nil { 180 | t.Fatalf("Expected error, got nil") 181 | } 182 | 183 | var oauthErr *OAuthAuthorizationRequiredError 184 | if !errors.As(err, &oauthErr) { 185 | t.Fatalf("Expected OAuthAuthorizationRequiredError, got %T: %v", err, err) 186 | } 187 | 188 | // Verify the error has the handler 189 | if oauthErr.Handler == nil { 190 | t.Errorf("Expected OAuthAuthorizationRequiredError to have a handler") 191 | } 192 | } 193 | 194 | func TestStreamableHTTP_IsOAuthEnabled(t *testing.T) { 195 | // Create StreamableHTTP without OAuth 196 | transport1, err := NewStreamableHTTP("http://example.com") 197 | if err != nil { 198 | t.Fatalf("Failed to create StreamableHTTP: %v", err) 199 | } 200 | 201 | // Verify OAuth is not enabled 202 | if transport1.IsOAuthEnabled() { 203 | t.Errorf("Expected IsOAuthEnabled() to return false") 204 | } 205 | 206 | // Create StreamableHTTP with OAuth 207 | transport2, err := NewStreamableHTTP("http://example.com", WithOAuth(OAuthConfig{ 208 | ClientID: "test-client", 209 | })) 210 | if err != nil { 211 | t.Fatalf("Failed to create StreamableHTTP: %v", err) 212 | } 213 | 214 | // Verify OAuth is enabled 215 | if !transport2.IsOAuthEnabled() { 216 | t.Errorf("Expected IsOAuthEnabled() to return true") 217 | } 218 | } 219 | -------------------------------------------------------------------------------- /examples/custom_context/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "flag" 7 | "fmt" 8 | "io" 9 | "log" 10 | "net/http" 11 | "os" 12 | 13 | "github.com/mark3labs/mcp-go/mcp" 14 | "github.com/mark3labs/mcp-go/server" 15 | ) 16 | 17 | // authKey is a custom context key for storing the auth token. 18 | type authKey struct{} 19 | 20 | // withAuthKey adds an auth key to the context. 
21 | func withAuthKey(ctx context.Context, auth string) context.Context { 22 | return context.WithValue(ctx, authKey{}, auth) 23 | } 24 | 25 | // authFromRequest extracts the auth token from the request headers. 26 | func authFromRequest(ctx context.Context, r *http.Request) context.Context { 27 | return withAuthKey(ctx, r.Header.Get("Authorization")) 28 | } 29 | 30 | // authFromEnv extracts the auth token from the environment 31 | func authFromEnv(ctx context.Context) context.Context { 32 | return withAuthKey(ctx, os.Getenv("API_KEY")) 33 | } 34 | 35 | // tokenFromContext extracts the auth token from the context. 36 | // This can be used by tools to extract the token regardless of the 37 | // transport being used by the server. 38 | func tokenFromContext(ctx context.Context) (string, error) { 39 | auth, ok := ctx.Value(authKey{}).(string) 40 | if !ok { 41 | return "", fmt.Errorf("missing auth") 42 | } 43 | return auth, nil 44 | } 45 | 46 | type response struct { 47 | Args map[string]any `json:"args"` 48 | Headers map[string]string `json:"headers"` 49 | } 50 | 51 | // makeRequest makes a request to httpbin.org including the auth token in the request 52 | // headers and the message in the query string. 
53 | func makeRequest(ctx context.Context, message, token string) (*response, error) { 54 | req, err := http.NewRequestWithContext(ctx, "GET", "https://httpbin.org/anything", nil) 55 | if err != nil { 56 | return nil, err 57 | } 58 | req.Header.Set("Authorization", token) 59 | query := req.URL.Query() 60 | query.Add("message", message) 61 | req.URL.RawQuery = query.Encode() 62 | resp, err := http.DefaultClient.Do(req) 63 | if err != nil { 64 | return nil, err 65 | } 66 | defer resp.Body.Close() 67 | body, err := io.ReadAll(resp.Body) 68 | if err != nil { 69 | return nil, err 70 | } 71 | var r *response 72 | if err := json.Unmarshal(body, &r); err != nil { 73 | return nil, err 74 | } 75 | return r, nil 76 | } 77 | 78 | // handleMakeAuthenticatedRequestTool is a tool that makes an authenticated request 79 | // using the token from the context. 80 | func handleMakeAuthenticatedRequestTool( 81 | ctx context.Context, 82 | request mcp.CallToolRequest, 83 | ) (*mcp.CallToolResult, error) { 84 | message, ok := request.GetArguments()["message"].(string) 85 | if !ok { 86 | return nil, fmt.Errorf("missing message") 87 | } 88 | token, err := tokenFromContext(ctx) 89 | if err != nil { 90 | return nil, fmt.Errorf("missing token: %v", err) 91 | } 92 | // Now our tool can make a request with the token, irrespective of where it came from. 
93 | resp, err := makeRequest(ctx, message, token) 94 | if err != nil { 95 | return nil, err 96 | } 97 | return mcp.NewToolResultText(fmt.Sprintf("%+v", resp)), nil 98 | } 99 | 100 | type MCPServer struct { 101 | server *server.MCPServer 102 | } 103 | 104 | func NewMCPServer() *MCPServer { 105 | mcpServer := server.NewMCPServer( 106 | "example-server", 107 | "1.0.0", 108 | server.WithResourceCapabilities(true, true), 109 | server.WithPromptCapabilities(true), 110 | server.WithToolCapabilities(true), 111 | ) 112 | mcpServer.AddTool(mcp.NewTool("make_authenticated_request", 113 | mcp.WithDescription("Makes an authenticated request"), 114 | mcp.WithString("message", 115 | mcp.Description("Message to echo"), 116 | mcp.Required(), 117 | ), 118 | ), handleMakeAuthenticatedRequestTool) 119 | 120 | return &MCPServer{ 121 | server: mcpServer, 122 | } 123 | } 124 | 125 | func (s *MCPServer) ServeHTTP() *server.StreamableHTTPServer { 126 | return server.NewStreamableHTTPServer(s.server, 127 | server.WithHTTPContextFunc(authFromRequest), 128 | ) 129 | } 130 | 131 | func (s *MCPServer) ServeStdio() error { 132 | return server.ServeStdio(s.server, server.WithStdioContextFunc(authFromEnv)) 133 | } 134 | 135 | func main() { 136 | var transport string 137 | flag.StringVar(&transport, "t", "stdio", "Transport type (stdio or http)") 138 | flag.StringVar( 139 | &transport, 140 | "transport", 141 | "stdio", 142 | "Transport type (stdio or http)", 143 | ) 144 | flag.Parse() 145 | 146 | s := NewMCPServer() 147 | 148 | switch transport { 149 | case "stdio": 150 | if err := s.ServeStdio(); err != nil { 151 | log.Fatalf("Server error: %v", err) 152 | } 153 | case "http": 154 | httpServer := s.ServeHTTP() 155 | log.Printf("HTTP server listening on :8080") 156 | if err := httpServer.Start(":8080"); err != nil { 157 | log.Fatalf("Server error: %v", err) 158 | } 159 | default: 160 | log.Fatalf( 161 | "Invalid transport type: %s. 
Must be 'stdio' or 'http'", 162 | transport, 163 | ) 164 | } 165 | } 166 | -------------------------------------------------------------------------------- /examples/dynamic_path/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "flag" 6 | "fmt" 7 | "log" 8 | "net/http" 9 | 10 | "github.com/mark3labs/mcp-go/mcp" 11 | "github.com/mark3labs/mcp-go/server" 12 | ) 13 | 14 | func main() { 15 | var addr string 16 | flag.StringVar(&addr, "addr", ":8080", "address to listen on") 17 | flag.Parse() 18 | 19 | mcpServer := server.NewMCPServer("dynamic-path-example", "1.0.0") 20 | 21 | // Add a trivial tool for demonstration 22 | mcpServer.AddTool(mcp.NewTool("echo"), func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { 23 | return mcp.NewToolResultText(fmt.Sprintf("Echo: %v", req.GetArguments()["message"])), nil 24 | }) 25 | 26 | // Use a dynamic base path based on a path parameter (Go 1.22+) 27 | sseServer := server.NewSSEServer( 28 | mcpServer, 29 | server.WithDynamicBasePath(func(r *http.Request, sessionID string) string { 30 | tenant := r.PathValue("tenant") 31 | return "/api/" + tenant 32 | }), 33 | server.WithBaseURL(fmt.Sprintf("http://localhost%s", addr)), 34 | server.WithUseFullURLForMessageEndpoint(true), 35 | ) 36 | 37 | mux := http.NewServeMux() 38 | mux.Handle("/api/{tenant}/sse", sseServer.SSEHandler()) 39 | mux.Handle("/api/{tenant}/message", sseServer.MessageHandler()) 40 | 41 | log.Printf("Dynamic SSE server listening on %s", addr) 42 | if err := http.ListenAndServe(addr, mux); err != nil { 43 | log.Fatalf("Server error: %v", err) 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /examples/filesystem_stdio_client/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | "log" 8 | 
"time" 9 | 10 | "github.com/mark3labs/mcp-go/client" 11 | "github.com/mark3labs/mcp-go/mcp" 12 | ) 13 | 14 | func main() { 15 | c, err := client.NewStdioMCPClient( 16 | "npx", 17 | []string{}, // Empty ENV 18 | "-y", 19 | "@modelcontextprotocol/server-filesystem", 20 | "/tmp", 21 | ) 22 | if err != nil { 23 | log.Fatalf("Failed to create client: %v", err) 24 | } 25 | defer c.Close() 26 | 27 | // Create context with timeout 28 | ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) 29 | defer cancel() 30 | 31 | // Initialize the client 32 | fmt.Println("Initializing client...") 33 | initRequest := mcp.InitializeRequest{} 34 | initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION 35 | initRequest.Params.ClientInfo = mcp.Implementation{ 36 | Name: "example-client", 37 | Version: "1.0.0", 38 | } 39 | 40 | initResult, err := c.Initialize(ctx, initRequest) 41 | if err != nil { 42 | log.Fatalf("Failed to initialize: %v", err) 43 | } 44 | fmt.Printf( 45 | "Initialized with server: %s %s\n\n", 46 | initResult.ServerInfo.Name, 47 | initResult.ServerInfo.Version, 48 | ) 49 | 50 | // List Tools 51 | fmt.Println("Listing available tools...") 52 | toolsRequest := mcp.ListToolsRequest{} 53 | tools, err := c.ListTools(ctx, toolsRequest) 54 | if err != nil { 55 | log.Fatalf("Failed to list tools: %v", err) 56 | } 57 | for _, tool := range tools.Tools { 58 | fmt.Printf("- %s: %s\n", tool.Name, tool.Description) 59 | } 60 | fmt.Println() 61 | 62 | // List allowed directories 63 | fmt.Println("Listing allowed directories...") 64 | listDirRequest := mcp.CallToolRequest{ 65 | Request: mcp.Request{ 66 | Method: "tools/call", 67 | }, 68 | } 69 | listDirRequest.Params.Name = "list_allowed_directories" 70 | 71 | result, err := c.CallTool(ctx, listDirRequest) 72 | if err != nil { 73 | log.Fatalf("Failed to list allowed directories: %v", err) 74 | } 75 | printToolResult(result) 76 | fmt.Println() 77 | 78 | // List /tmp 79 | fmt.Println("Listing /tmp 
directory...") 80 | listTmpRequest := mcp.CallToolRequest{} 81 | listTmpRequest.Params.Name = "list_directory" 82 | listTmpRequest.Params.Arguments = map[string]any{ 83 | "path": "/tmp", 84 | } 85 | 86 | result, err = c.CallTool(ctx, listTmpRequest) 87 | if err != nil { 88 | log.Fatalf("Failed to list directory: %v", err) 89 | } 90 | printToolResult(result) 91 | fmt.Println() 92 | 93 | // Create mcp directory 94 | fmt.Println("Creating /tmp/mcp directory...") 95 | createDirRequest := mcp.CallToolRequest{} 96 | createDirRequest.Params.Name = "create_directory" 97 | createDirRequest.Params.Arguments = map[string]any{ 98 | "path": "/tmp/mcp", 99 | } 100 | 101 | result, err = c.CallTool(ctx, createDirRequest) 102 | if err != nil { 103 | log.Fatalf("Failed to create directory: %v", err) 104 | } 105 | printToolResult(result) 106 | fmt.Println() 107 | 108 | // Create hello.txt 109 | fmt.Println("Creating /tmp/mcp/hello.txt...") 110 | writeFileRequest := mcp.CallToolRequest{} 111 | writeFileRequest.Params.Name = "write_file" 112 | writeFileRequest.Params.Arguments = map[string]any{ 113 | "path": "/tmp/mcp/hello.txt", 114 | "content": "Hello World", 115 | } 116 | 117 | result, err = c.CallTool(ctx, writeFileRequest) 118 | if err != nil { 119 | log.Fatalf("Failed to create file: %v", err) 120 | } 121 | printToolResult(result) 122 | fmt.Println() 123 | 124 | // Verify file contents 125 | fmt.Println("Reading /tmp/mcp/hello.txt...") 126 | readFileRequest := mcp.CallToolRequest{} 127 | readFileRequest.Params.Name = "read_file" 128 | readFileRequest.Params.Arguments = map[string]any{ 129 | "path": "/tmp/mcp/hello.txt", 130 | } 131 | 132 | result, err = c.CallTool(ctx, readFileRequest) 133 | if err != nil { 134 | log.Fatalf("Failed to read file: %v", err) 135 | } 136 | printToolResult(result) 137 | 138 | // Get file info 139 | fmt.Println("Getting info for /tmp/mcp/hello.txt...") 140 | fileInfoRequest := mcp.CallToolRequest{} 141 | fileInfoRequest.Params.Name = "get_file_info" 
142 | fileInfoRequest.Params.Arguments = map[string]any{ 143 | "path": "/tmp/mcp/hello.txt", 144 | } 145 | 146 | result, err = c.CallTool(ctx, fileInfoRequest) 147 | if err != nil { 148 | log.Fatalf("Failed to get file info: %v", err) 149 | } 150 | printToolResult(result) 151 | } 152 | 153 | // Helper function to print tool results 154 | func printToolResult(result *mcp.CallToolResult) { 155 | for _, content := range result.Content { 156 | if textContent, ok := content.(mcp.TextContent); ok { 157 | fmt.Println(textContent.Text) 158 | } else { 159 | jsonBytes, _ := json.MarshalIndent(content, "", " ") 160 | fmt.Println(string(jsonBytes)) 161 | } 162 | } 163 | } 164 | -------------------------------------------------------------------------------- /examples/oauth_client/README.md: -------------------------------------------------------------------------------- 1 | # OAuth Client Example 2 | 3 | This example demonstrates how to use the OAuth capabilities of the MCP Go client to authenticate with an MCP server that requires OAuth authentication. 4 | 5 | ## Features 6 | 7 | - OAuth 2.1 authentication with PKCE support 8 | - Dynamic client registration 9 | - Authorization code flow 10 | - Token refresh 11 | - Local callback server for handling OAuth redirects 12 | 13 | ## Usage 14 | 15 | ```bash 16 | # Set environment variables (optional) 17 | export MCP_CLIENT_ID=your_client_id 18 | export MCP_CLIENT_SECRET=your_client_secret 19 | 20 | # Run the example 21 | go run main.go 22 | ``` 23 | 24 | ## How it Works 25 | 26 | 1. The client attempts to initialize a connection to the MCP server 27 | 2. If the server requires OAuth authentication, it will return a 401 Unauthorized response 28 | 3. The client detects this and starts the OAuth flow: 29 | - Generates PKCE code verifier and challenge 30 | - Generates a state parameter for security 31 | - Opens a browser to the authorization URL 32 | - Starts a local server to handle the callback 33 | 4. 
The user authorizes the application in their browser 34 | 5. The authorization server redirects back to the local callback server 35 | 6. The client exchanges the authorization code for an access token 36 | 7. The client retries the initialization with the access token 37 | 8. The client can now make authenticated requests to the MCP server 38 | 39 | ## Configuration 40 | 41 | Edit the following constants in `main.go` to match your environment: 42 | 43 | ```go 44 | const ( 45 | // Replace with your MCP server URL 46 | serverURL = "https://api.example.com/v1/mcp" 47 | // Use a localhost redirect URI for this example 48 | redirectURI = "http://localhost:8085/oauth/callback" 49 | ) 50 | ``` 51 | 52 | ## OAuth Scopes 53 | 54 | The example requests the following scopes: 55 | 56 | - `mcp.read` - Read access to MCP resources 57 | - `mcp.write` - Write access to MCP resources 58 | 59 | You can modify the scopes in the `oauthConfig` to match the requirements of your MCP server. -------------------------------------------------------------------------------- /examples/oauth_client/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "log" 7 | "net/http" 8 | "os" 9 | "os/exec" 10 | "runtime" 11 | 12 | "github.com/mark3labs/mcp-go/client" 13 | "github.com/mark3labs/mcp-go/mcp" 14 | ) 15 | 16 | const ( 17 | // Replace with your MCP server URL 18 | serverURL = "https://api.example.com/v1/mcp" 19 | // Use a localhost redirect URI for this example 20 | redirectURI = "http://localhost:8085/oauth/callback" 21 | ) 22 | 23 | func main() { 24 | // Create a token store to persist tokens 25 | tokenStore := client.NewMemoryTokenStore() 26 | 27 | // Create OAuth configuration 28 | oauthConfig := client.OAuthConfig{ 29 | // Client ID can be empty if using dynamic registration 30 | ClientID: os.Getenv("MCP_CLIENT_ID"), 31 | ClientSecret: os.Getenv("MCP_CLIENT_SECRET"), 32 | RedirectURI: 
redirectURI, 33 | Scopes: []string{"mcp.read", "mcp.write"}, 34 | TokenStore: tokenStore, 35 | PKCEEnabled: true, // Enable PKCE for public clients 36 | } 37 | 38 | // Create the client with OAuth support 39 | c, err := client.NewOAuthStreamableHttpClient(serverURL, oauthConfig) 40 | if err != nil { 41 | log.Fatalf("Failed to create client: %v", err) 42 | } 43 | 44 | // Start the client 45 | if err := c.Start(context.Background()); err != nil { 46 | log.Fatalf("Failed to start client: %v", err) 47 | } 48 | defer c.Close() 49 | 50 | // Try to initialize the client 51 | result, err := c.Initialize(context.Background(), mcp.InitializeRequest{ 52 | Params: struct { 53 | ProtocolVersion string `json:"protocolVersion"` 54 | Capabilities mcp.ClientCapabilities `json:"capabilities"` 55 | ClientInfo mcp.Implementation `json:"clientInfo"` 56 | }{ 57 | ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, 58 | ClientInfo: mcp.Implementation{ 59 | Name: "mcp-go-oauth-example", 60 | Version: "0.1.0", 61 | }, 62 | }, 63 | }) 64 | 65 | // Check if we need OAuth authorization 66 | if client.IsOAuthAuthorizationRequiredError(err) { 67 | fmt.Println("OAuth authorization required. 
Starting authorization flow...") 68 | 69 | // Get the OAuth handler from the error 70 | oauthHandler := client.GetOAuthHandler(err) 71 | 72 | // Start a local server to handle the OAuth callback 73 | callbackChan := make(chan map[string]string) 74 | server := startCallbackServer(callbackChan) 75 | defer server.Close() 76 | 77 | // Generate PKCE code verifier and challenge 78 | codeVerifier, err := client.GenerateCodeVerifier() 79 | if err != nil { 80 | log.Fatalf("Failed to generate code verifier: %v", err) 81 | } 82 | codeChallenge := client.GenerateCodeChallenge(codeVerifier) 83 | 84 | // Generate state parameter 85 | state, err := client.GenerateState() 86 | if err != nil { 87 | log.Fatalf("Failed to generate state: %v", err) 88 | } 89 | 90 | // Get the authorization URL 91 | authURL, err := oauthHandler.GetAuthorizationURL(context.Background(), state, codeChallenge) 92 | if err != nil { 93 | log.Fatalf("Failed to get authorization URL: %v", err) 94 | } 95 | 96 | // Open the browser to the authorization URL 97 | fmt.Printf("Opening browser to: %s\n", authURL) 98 | openBrowser(authURL) 99 | 100 | // Wait for the callback 101 | fmt.Println("Waiting for authorization callback...") 102 | params := <-callbackChan 103 | 104 | // Verify state parameter 105 | if params["state"] != state { 106 | log.Fatalf("State mismatch: expected %s, got %s", state, params["state"]) 107 | } 108 | 109 | // Exchange the authorization code for a token 110 | code := params["code"] 111 | if code == "" { 112 | log.Fatalf("No authorization code received") 113 | } 114 | 115 | fmt.Println("Exchanging authorization code for token...") 116 | err = oauthHandler.ProcessAuthorizationResponse(context.Background(), code, state, codeVerifier) 117 | if err != nil { 118 | log.Fatalf("Failed to process authorization response: %v", err) 119 | } 120 | 121 | fmt.Println("Authorization successful!") 122 | 123 | // Try to initialize again with the token 124 | result, err = c.Initialize(context.Background(), 
mcp.InitializeRequest{ 125 | Params: struct { 126 | ProtocolVersion string `json:"protocolVersion"` 127 | Capabilities mcp.ClientCapabilities `json:"capabilities"` 128 | ClientInfo mcp.Implementation `json:"clientInfo"` 129 | }{ 130 | ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, 131 | ClientInfo: mcp.Implementation{ 132 | Name: "mcp-go-oauth-example", 133 | Version: "0.1.0", 134 | }, 135 | }, 136 | }) 137 | if err != nil { 138 | log.Fatalf("Failed to initialize client after authorization: %v", err) 139 | } 140 | } else if err != nil { 141 | log.Fatalf("Failed to initialize client: %v", err) 142 | } 143 | 144 | fmt.Printf("Client initialized successfully! Server: %s %s\n", 145 | result.ServerInfo.Name, 146 | result.ServerInfo.Version) 147 | 148 | // Now you can use the client as usual 149 | // For example, list resources 150 | resources, err := c.ListResources(context.Background(), mcp.ListResourcesRequest{}) 151 | if err != nil { 152 | log.Fatalf("Failed to list resources: %v", err) 153 | } 154 | 155 | fmt.Println("Available resources:") 156 | for _, resource := range resources.Resources { 157 | fmt.Printf("- %s\n", resource.URI) 158 | } 159 | } 160 | 161 | // startCallbackServer starts a local HTTP server to handle the OAuth callback 162 | func startCallbackServer(callbackChan chan<- map[string]string) *http.Server { 163 | server := &http.Server{ 164 | Addr: ":8085", 165 | } 166 | 167 | http.HandleFunc("/oauth/callback", func(w http.ResponseWriter, r *http.Request) { 168 | // Extract query parameters 169 | params := make(map[string]string) 170 | for key, values := range r.URL.Query() { 171 | if len(values) > 0 { 172 | params[key] = values[0] 173 | } 174 | } 175 | 176 | // Send parameters to the channel 177 | callbackChan <- params 178 | 179 | // Respond to the user 180 | w.Header().Set("Content-Type", "text/html") 181 | _, err := w.Write([]byte(` 182 | 183 | 184 |

Authorization Successful

185 |

You can now close this window and return to the application.

186 | 187 | 188 | 189 | `)) 190 | if err != nil { 191 | log.Printf("Error writing response: %v", err) 192 | } 193 | }) 194 | 195 | go func() { 196 | if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { 197 | log.Printf("HTTP server error: %v", err) 198 | } 199 | }() 200 | 201 | return server 202 | } 203 | 204 | // openBrowser opens the default browser to the specified URL 205 | func openBrowser(url string) { 206 | var err error 207 | 208 | switch runtime.GOOS { 209 | case "linux": 210 | err = exec.Command("xdg-open", url).Start() 211 | case "windows": 212 | err = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start() 213 | case "darwin": 214 | err = exec.Command("open", url).Start() 215 | default: 216 | err = fmt.Errorf("unsupported platform") 217 | } 218 | 219 | if err != nil { 220 | log.Printf("Failed to open browser: %v", err) 221 | fmt.Printf("Please open the following URL in your browser: %s\n", url) 222 | } 223 | } 224 | -------------------------------------------------------------------------------- /examples/simple_client/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "flag" 6 | "fmt" 7 | "io" 8 | "log" 9 | "os" 10 | "time" 11 | 12 | "github.com/mark3labs/mcp-go/client" 13 | "github.com/mark3labs/mcp-go/client/transport" 14 | "github.com/mark3labs/mcp-go/mcp" 15 | ) 16 | 17 | func main() { 18 | // Define command line flags 19 | stdioCmd := flag.String("stdio", "", "Command to execute for stdio transport (e.g. 'python server.py')") 20 | httpURL := flag.String("http", "", "URL for HTTP transport (e.g. 
'http://localhost:8080/mcp')") 21 | flag.Parse() 22 | 23 | // Validate flags 24 | if (*stdioCmd == "" && *httpURL == "") || (*stdioCmd != "" && *httpURL != "") { 25 | fmt.Println("Error: You must specify exactly one of --stdio or --http") 26 | flag.Usage() 27 | os.Exit(1) 28 | } 29 | 30 | // Create a context with timeout 31 | ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) 32 | defer cancel() 33 | 34 | // Create client based on transport type 35 | var c *client.Client 36 | var err error 37 | 38 | if *stdioCmd != "" { 39 | fmt.Println("Initializing stdio client...") 40 | // Parse command and arguments 41 | args := parseCommand(*stdioCmd) 42 | if len(args) == 0 { 43 | fmt.Println("Error: Invalid stdio command") 44 | os.Exit(1) 45 | } 46 | 47 | // Create command and stdio transport 48 | command := args[0] 49 | cmdArgs := args[1:] 50 | 51 | // Create stdio transport with verbose logging 52 | stdioTransport := transport.NewStdio(command, nil, cmdArgs...) 53 | 54 | // Create client with the transport 55 | c = client.NewClient(stdioTransport) 56 | 57 | // Set up logging for stderr if available 58 | if stderr, ok := client.GetStderr(c); ok { 59 | go func() { 60 | buf := make([]byte, 4096) 61 | for { 62 | n, err := stderr.Read(buf) 63 | if err != nil { 64 | if err != io.EOF { 65 | log.Printf("Error reading stderr: %v", err) 66 | } 67 | return 68 | } 69 | if n > 0 { 70 | fmt.Fprintf(os.Stderr, "[Server] %s", buf[:n]) 71 | } 72 | } 73 | }() 74 | } 75 | } else { 76 | fmt.Println("Initializing HTTP client...") 77 | // Create HTTP transport 78 | httpTransport, err := transport.NewStreamableHTTP(*httpURL) 79 | if err != nil { 80 | log.Fatalf("Failed to create HTTP transport: %v", err) 81 | } 82 | 83 | // Create client with the transport 84 | c = client.NewClient(httpTransport) 85 | } 86 | 87 | // Start the client 88 | if err := c.Start(ctx); err != nil { 89 | log.Fatalf("Failed to start client: %v", err) 90 | } 91 | 92 | // Set up notification handler 93 
| c.OnNotification(func(notification mcp.JSONRPCNotification) { 94 | fmt.Printf("Received notification: %s\n", notification.Method) 95 | }) 96 | 97 | // Initialize the client 98 | fmt.Println("Initializing client...") 99 | initRequest := mcp.InitializeRequest{} 100 | initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION 101 | initRequest.Params.ClientInfo = mcp.Implementation{ 102 | Name: "MCP-Go Simple Client Example", 103 | Version: "1.0.0", 104 | } 105 | initRequest.Params.Capabilities = mcp.ClientCapabilities{} 106 | 107 | serverInfo, err := c.Initialize(ctx, initRequest) 108 | if err != nil { 109 | log.Fatalf("Failed to initialize: %v", err) 110 | } 111 | 112 | // Display server information 113 | fmt.Printf("Connected to server: %s (version %s)\n", 114 | serverInfo.ServerInfo.Name, 115 | serverInfo.ServerInfo.Version) 116 | fmt.Printf("Server capabilities: %+v\n", serverInfo.Capabilities) 117 | 118 | // List available tools if the server supports them 119 | if serverInfo.Capabilities.Tools != nil { 120 | fmt.Println("Fetching available tools...") 121 | toolsRequest := mcp.ListToolsRequest{} 122 | toolsResult, err := c.ListTools(ctx, toolsRequest) 123 | if err != nil { 124 | log.Printf("Failed to list tools: %v", err) 125 | } else { 126 | fmt.Printf("Server has %d tools available\n", len(toolsResult.Tools)) 127 | for i, tool := range toolsResult.Tools { 128 | fmt.Printf(" %d. 
%s - %s\n", i+1, tool.Name, tool.Description) 129 | } 130 | } 131 | } 132 | 133 | // List available resources if the server supports them 134 | if serverInfo.Capabilities.Resources != nil { 135 | fmt.Println("Fetching available resources...") 136 | resourcesRequest := mcp.ListResourcesRequest{} 137 | resourcesResult, err := c.ListResources(ctx, resourcesRequest) 138 | if err != nil { 139 | log.Printf("Failed to list resources: %v", err) 140 | } else { 141 | fmt.Printf("Server has %d resources available\n", len(resourcesResult.Resources)) 142 | for i, resource := range resourcesResult.Resources { 143 | fmt.Printf(" %d. %s - %s\n", i+1, resource.URI, resource.Name) 144 | } 145 | } 146 | } 147 | 148 | fmt.Println("Client initialized successfully. Shutting down...") 149 | c.Close() 150 | } 151 | 152 | // parseCommand splits a command string into command and arguments 153 | func parseCommand(cmd string) []string { 154 | // This is a simple implementation that doesn't handle quotes or escapes 155 | // For a more robust solution, consider using a shell parser library 156 | var result []string 157 | var current string 158 | var inQuote bool 159 | var quoteChar rune 160 | 161 | for _, r := range cmd { 162 | switch { 163 | case r == ' ' && !inQuote: 164 | if current != "" { 165 | result = append(result, current) 166 | current = "" 167 | } 168 | case (r == '"' || r == '\''): 169 | if inQuote && r == quoteChar { 170 | inQuote = false 171 | quoteChar = 0 172 | } else if !inQuote { 173 | inQuote = true 174 | quoteChar = r 175 | } else { 176 | current += string(r) 177 | } 178 | default: 179 | current += string(r) 180 | } 181 | } 182 | 183 | if current != "" { 184 | result = append(result, current) 185 | } 186 | 187 | return result 188 | } 189 | -------------------------------------------------------------------------------- /examples/typed_tools/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | 
"context" 5 | "fmt" 6 | 7 | "github.com/mark3labs/mcp-go/mcp" 8 | "github.com/mark3labs/mcp-go/server" 9 | ) 10 | 11 | // Define a struct for our typed arguments 12 | type GreetingArgs struct { 13 | Name string `json:"name"` 14 | Age int `json:"age"` 15 | IsVIP bool `json:"is_vip"` 16 | Languages []string `json:"languages"` 17 | Metadata struct { 18 | Location string `json:"location"` 19 | Timezone string `json:"timezone"` 20 | } `json:"metadata"` 21 | } 22 | 23 | func main() { 24 | // Create a new MCP server 25 | s := server.NewMCPServer( 26 | "Typed Tools Demo 🚀", 27 | "1.0.0", 28 | server.WithToolCapabilities(false), 29 | ) 30 | 31 | // Add tool with complex schema 32 | tool := mcp.NewTool("greeting", 33 | mcp.WithDescription("Generate a personalized greeting"), 34 | mcp.WithString("name", 35 | mcp.Required(), 36 | mcp.Description("Name of the person to greet"), 37 | ), 38 | mcp.WithNumber("age", 39 | mcp.Description("Age of the person"), 40 | mcp.Min(0), 41 | mcp.Max(150), 42 | ), 43 | mcp.WithBoolean("is_vip", 44 | mcp.Description("Whether the person is a VIP"), 45 | mcp.DefaultBool(false), 46 | ), 47 | mcp.WithArray("languages", 48 | mcp.Description("Languages the person speaks"), 49 | mcp.Items(map[string]any{"type": "string"}), 50 | ), 51 | mcp.WithObject("metadata", 52 | mcp.Description("Additional information about the person"), 53 | mcp.Properties(map[string]any{ 54 | "location": map[string]any{ 55 | "type": "string", 56 | "description": "Current location", 57 | }, 58 | "timezone": map[string]any{ 59 | "type": "string", 60 | "description": "Timezone", 61 | }, 62 | }), 63 | ), 64 | ) 65 | 66 | // Add tool handler using the typed handler 67 | s.AddTool(tool, mcp.NewTypedToolHandler(typedGreetingHandler)) 68 | 69 | // Start the stdio server 70 | if err := server.ServeStdio(s); err != nil { 71 | fmt.Printf("Server error: %v\n", err) 72 | } 73 | } 74 | 75 | // Our typed handler function that receives strongly-typed arguments 76 | func typedGreetingHandler(ctx 
context.Context, request mcp.CallToolRequest, args GreetingArgs) (*mcp.CallToolResult, error) { 77 | if args.Name == "" { 78 | return mcp.NewToolResultError("name is required"), nil 79 | } 80 | 81 | // Build a personalized greeting based on the complex arguments 82 | greeting := fmt.Sprintf("Hello, %s!", args.Name) 83 | 84 | if args.Age > 0 { 85 | greeting += fmt.Sprintf(" You are %d years old.", args.Age) 86 | } 87 | 88 | if args.IsVIP { 89 | greeting += " Welcome back, valued VIP customer!" 90 | } 91 | 92 | if len(args.Languages) > 0 { 93 | greeting += fmt.Sprintf(" You speak %d languages: %v.", len(args.Languages), args.Languages) 94 | } 95 | 96 | if args.Metadata.Location != "" { 97 | greeting += fmt.Sprintf(" I see you're from %s.", args.Metadata.Location) 98 | 99 | if args.Metadata.Timezone != "" { 100 | greeting += fmt.Sprintf(" Your timezone is %s.", args.Metadata.Timezone) 101 | } 102 | } 103 | 104 | return mcp.NewToolResultText(greeting), nil 105 | } 106 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/mark3labs/mcp-go 2 | 3 | go 1.23 4 | 5 | require ( 6 | github.com/google/uuid v1.6.0 7 | github.com/spf13/cast v1.7.1 8 | github.com/stretchr/testify v1.9.0 9 | github.com/yosida95/uritemplate/v3 v3.0.2 10 | ) 11 | 12 | require ( 13 | github.com/davecgh/go-spew v1.1.1 // indirect 14 | github.com/pmezard/go-difflib v1.0.0 // indirect 15 | gopkg.in/yaml.v3 v3.0.1 // indirect 16 | ) 17 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 2 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 3 | github.com/frankban/quicktest v1.14.6 
h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= 4 | github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= 5 | github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= 6 | github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= 7 | github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= 8 | github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= 9 | github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= 10 | github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= 11 | github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= 12 | github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= 13 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 14 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 15 | github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= 16 | github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= 17 | github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y= 18 | github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= 19 | github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= 20 | github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= 21 | github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= 22 | github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= 23 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 24 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 25 | gopkg.in/yaml.v3 
v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 26 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 27 | -------------------------------------------------------------------------------- /logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mark3labs/mcp-go/2cbaebf51e3629d9409d14296dc0f3410b2013e2/logo.png -------------------------------------------------------------------------------- /mcp/prompts.go: -------------------------------------------------------------------------------- 1 | package mcp 2 | 3 | /* Prompts */ 4 | 5 | // ListPromptsRequest is sent from the client to request a list of prompts and 6 | // prompt templates the server has. 7 | type ListPromptsRequest struct { 8 | PaginatedRequest 9 | } 10 | 11 | // ListPromptsResult is the server's response to a prompts/list request from 12 | // the client. 13 | type ListPromptsResult struct { 14 | PaginatedResult 15 | Prompts []Prompt `json:"prompts"` 16 | } 17 | 18 | // GetPromptRequest is used by the client to get a prompt provided by the 19 | // server. 20 | type GetPromptRequest struct { 21 | Request 22 | Params GetPromptParams `json:"params"` 23 | } 24 | 25 | type GetPromptParams struct { 26 | // The name of the prompt or prompt template. 27 | Name string `json:"name"` 28 | // Arguments to use for templating the prompt. 29 | Arguments map[string]string `json:"arguments,omitempty"` 30 | } 31 | 32 | // GetPromptResult is the server's response to a prompts/get request from the 33 | // client. 34 | type GetPromptResult struct { 35 | Result 36 | // An optional description for the prompt. 37 | Description string `json:"description,omitempty"` 38 | Messages []PromptMessage `json:"messages"` 39 | } 40 | 41 | // Prompt represents a prompt or prompt template that the server offers. 
42 | // If Arguments is non-nil and non-empty, this indicates the prompt is a template 43 | // that requires argument values to be provided when calling prompts/get. 44 | // If Arguments is nil or empty, this is a static prompt that takes no arguments. 45 | type Prompt struct { 46 | // The name of the prompt or prompt template. 47 | Name string `json:"name"` 48 | // An optional description of what this prompt provides 49 | Description string `json:"description,omitempty"` 50 | // A list of arguments to use for templating the prompt. 51 | // The presence of arguments indicates this is a template prompt. 52 | Arguments []PromptArgument `json:"arguments,omitempty"` 53 | } 54 | 55 | // GetName returns the name of the prompt. 56 | func (p Prompt) GetName() string { 57 | return p.Name 58 | } 59 | 60 | // PromptArgument describes an argument that a prompt template can accept. 61 | // When a prompt includes arguments, clients must provide values for all 62 | // required arguments when making a prompts/get request. 63 | type PromptArgument struct { 64 | // The name of the argument. 65 | Name string `json:"name"` 66 | // A human-readable description of the argument. 67 | Description string `json:"description,omitempty"` 68 | // Whether this argument must be provided. 69 | // If true, clients must include this argument when calling prompts/get. 70 | Required bool `json:"required,omitempty"` 71 | } 72 | 73 | // Role represents the sender or recipient of messages and data in a 74 | // conversation. 75 | type Role string 76 | 77 | const ( 78 | RoleUser Role = "user" 79 | RoleAssistant Role = "assistant" 80 | ) 81 | 82 | // PromptMessage describes a message returned as part of a prompt. 83 | // 84 | // This is similar to `SamplingMessage`, but also supports the embedding of 85 | // resources from the MCP server. 
86 | type PromptMessage struct { 87 | Role Role `json:"role"` 88 | Content Content `json:"content"` // Can be TextContent, ImageContent, AudioContent or EmbeddedResource 89 | } 90 | 91 | // PromptListChangedNotification is an optional notification from the server 92 | // to the client, informing it that the list of prompts it offers has changed. This 93 | // may be issued by servers without any previous subscription from the client. 94 | type PromptListChangedNotification struct { 95 | Notification 96 | } 97 | 98 | // PromptOption is a function that configures a Prompt. 99 | // It provides a flexible way to set various properties of a Prompt using the functional options pattern. 100 | type PromptOption func(*Prompt) 101 | 102 | // ArgumentOption is a function that configures a PromptArgument. 103 | // It allows for flexible configuration of prompt arguments using the functional options pattern. 104 | type ArgumentOption func(*PromptArgument) 105 | 106 | // 107 | // Core Prompt Functions 108 | // 109 | 110 | // NewPrompt creates a new Prompt with the given name and options. 111 | // The prompt will be configured based on the provided options. 112 | // Options are applied in order, allowing for flexible prompt configuration. 113 | func NewPrompt(name string, opts ...PromptOption) Prompt { 114 | prompt := Prompt{ 115 | Name: name, 116 | } 117 | 118 | for _, opt := range opts { 119 | opt(&prompt) 120 | } 121 | 122 | return prompt 123 | } 124 | 125 | // WithPromptDescription adds a description to the Prompt. 126 | // The description should provide a clear, human-readable explanation of what the prompt does. 127 | func WithPromptDescription(description string) PromptOption { 128 | return func(p *Prompt) { 129 | p.Description = description 130 | } 131 | } 132 | 133 | // WithArgument adds an argument to the prompt's argument list. 134 | // The argument will be configured based on the provided options. 
135 | func WithArgument(name string, opts ...ArgumentOption) PromptOption { 136 | return func(p *Prompt) { 137 | arg := PromptArgument{ 138 | Name: name, 139 | } 140 | 141 | for _, opt := range opts { 142 | opt(&arg) 143 | } 144 | 145 | if p.Arguments == nil { 146 | p.Arguments = make([]PromptArgument, 0) 147 | } 148 | p.Arguments = append(p.Arguments, arg) 149 | } 150 | } 151 | 152 | // 153 | // Argument Options 154 | // 155 | 156 | // ArgumentDescription adds a description to a prompt argument. 157 | // The description should explain the purpose and expected values of the argument. 158 | func ArgumentDescription(desc string) ArgumentOption { 159 | return func(arg *PromptArgument) { 160 | arg.Description = desc 161 | } 162 | } 163 | 164 | // RequiredArgument marks an argument as required in the prompt. 165 | // Required arguments must be provided when getting the prompt. 166 | func RequiredArgument() ArgumentOption { 167 | return func(arg *PromptArgument) { 168 | arg.Required = true 169 | } 170 | } 171 | -------------------------------------------------------------------------------- /mcp/resources.go: -------------------------------------------------------------------------------- 1 | package mcp 2 | 3 | import "github.com/yosida95/uritemplate/v3" 4 | 5 | // ResourceOption is a function that configures a Resource. 6 | // It provides a flexible way to set various properties of a Resource using the functional options pattern. 7 | type ResourceOption func(*Resource) 8 | 9 | // NewResource creates a new Resource with the given URI, name and options. 10 | // The resource will be configured based on the provided options. 11 | // Options are applied in order, allowing for flexible resource configuration. 
12 | func NewResource(uri string, name string, opts ...ResourceOption) Resource { 13 | resource := Resource{ 14 | URI: uri, 15 | Name: name, 16 | } 17 | 18 | for _, opt := range opts { 19 | opt(&resource) 20 | } 21 | 22 | return resource 23 | } 24 | 25 | // WithResourceDescription adds a description to the Resource. 26 | // The description should provide a clear, human-readable explanation of what the resource represents. 27 | func WithResourceDescription(description string) ResourceOption { 28 | return func(r *Resource) { 29 | r.Description = description 30 | } 31 | } 32 | 33 | // WithMIMEType sets the MIME type for the Resource. 34 | // This should indicate the format of the resource's contents. 35 | func WithMIMEType(mimeType string) ResourceOption { 36 | return func(r *Resource) { 37 | r.MIMEType = mimeType 38 | } 39 | } 40 | 41 | // WithAnnotations adds annotations to the Resource. 42 | // Annotations can provide additional metadata about the resource's intended use. 43 | func WithAnnotations(audience []Role, priority float64) ResourceOption { 44 | return func(r *Resource) { 45 | if r.Annotations == nil { 46 | r.Annotations = &Annotations{} 47 | } 48 | r.Annotations.Audience = audience 49 | r.Annotations.Priority = priority 50 | } 51 | } 52 | 53 | // ResourceTemplateOption is a function that configures a ResourceTemplate. 54 | // It provides a flexible way to set various properties of a ResourceTemplate using the functional options pattern. 55 | type ResourceTemplateOption func(*ResourceTemplate) 56 | 57 | // NewResourceTemplate creates a new ResourceTemplate with the given URI template, name and options. 58 | // The template will be configured based on the provided options. 59 | // Options are applied in order, allowing for flexible template configuration. 
60 | func NewResourceTemplate(uriTemplate string, name string, opts ...ResourceTemplateOption) ResourceTemplate { 61 | template := ResourceTemplate{ 62 | URITemplate: &URITemplate{Template: uritemplate.MustNew(uriTemplate)}, 63 | Name: name, 64 | } 65 | 66 | for _, opt := range opts { 67 | opt(&template) 68 | } 69 | 70 | return template 71 | } 72 | 73 | // WithTemplateDescription adds a description to the ResourceTemplate. 74 | // The description should provide a clear, human-readable explanation of what resources this template represents. 75 | func WithTemplateDescription(description string) ResourceTemplateOption { 76 | return func(t *ResourceTemplate) { 77 | t.Description = description 78 | } 79 | } 80 | 81 | // WithTemplateMIMEType sets the MIME type for the ResourceTemplate. 82 | // This should only be set if all resources matching this template will have the same type. 83 | func WithTemplateMIMEType(mimeType string) ResourceTemplateOption { 84 | return func(t *ResourceTemplate) { 85 | t.MIMEType = mimeType 86 | } 87 | } 88 | 89 | // WithTemplateAnnotations adds annotations to the ResourceTemplate. 90 | // Annotations can provide additional metadata about the template's intended use. 
func WithTemplateAnnotations(audience []Role, priority float64) ResourceTemplateOption {
	return func(t *ResourceTemplate) {
		// Allocate the Annotations struct on demand when none has been set yet.
		if t.Annotations == nil {
			t.Annotations = &Annotations{}
		}
		// Only Audience and Priority are overwritten; any other Annotations
		// fields keep the values they already had.
		t.Annotations.Audience = audience
		t.Annotations.Priority = priority
	}
}
--------------------------------------------------------------------------------
/mcp/typed_tools.go:
--------------------------------------------------------------------------------
package mcp

import (
	"context"
	"fmt"
)

// TypedToolHandlerFunc handles a tool call whose arguments have already been
// decoded into a value of type T.
type TypedToolHandlerFunc[T any] func(ctx context.Context, request CallToolRequest, args T) (*CallToolResult, error)

// NewTypedToolHandler adapts handler into a plain tool handler. The returned
// function decodes the request's arguments into a fresh T using
// CallToolRequest.BindArguments and then calls handler with the result.
//
// If binding fails, handler is not called. Instead the returned function
// reports the failure as a tool error result (NewToolResultError, with the
// message prefixed by "failed to bind arguments: ") together with a nil Go
// error.
func NewTypedToolHandler[T any](handler TypedToolHandlerFunc[T]) func(ctx context.Context, request CallToolRequest) (*CallToolResult, error) {
	return func(ctx context.Context, request CallToolRequest) (*CallToolResult, error) {
		var args T
		if err := request.BindArguments(&args); err != nil {
			return NewToolResultError(fmt.Sprintf("failed to bind arguments: %v", err)), nil
		}
		return handler(ctx, request, args)
	}
}
--------------------------------------------------------------------------------
/mcp/typed_tools_test.go:
--------------------------------------------------------------------------------
package mcp

import (
	"context"
	"encoding/json"
	"fmt"
	"testing"

	"github.com/stretchr/testify/assert"
)

// TestTypedToolHandler checks that NewTypedToolHandler binds a flat argument
// map into a struct, and how it behaves for mistyped, missing and non-map
// arguments. The same req is mutated between calls, so statement order matters.
func TestTypedToolHandler(t *testing.T) {
	// Define a test struct for arguments
	type HelloArgs struct {
		Name    string `json:"name"`
		Age     int    `json:"age"`
		IsAdmin bool   `json:"is_admin"`
	}

	// Create a typed handler function
	typedHandler := func(ctx context.Context, request CallToolRequest, args HelloArgs) (*CallToolResult, error) {
		return NewToolResultText(args.Name), nil
	}

	// Create a wrapped handler
	wrappedHandler := NewTypedToolHandler(typedHandler)

	// Create a test request
	req := CallToolRequest{}
	req.Params.Name = "test-tool"
	req.Params.Arguments = map[string]any{
		"name":     "John Doe",
		"age":      30,
		"is_admin": true,
	}

	// Call the wrapped handler
	result, err := wrappedHandler(context.Background(), req)

	// Verify results
	assert.NoError(t, err)
	assert.NotNil(t, result)
	assert.Equal(t, "John Doe", result.Content[0].(TextContent).Text)

	// Test with invalid arguments
	req.Params.Arguments = map[string]any{
		"name":     123, // Wrong type
		"age":      "thirty",
		"is_admin": "yes",
	}

	// Mistyped values must not produce a Go error.
	// NOTE(review): only NoError/NotNil are asserted here, and
	// NewTypedToolHandler also returns a nil error when binding fails, so this
	// does not verify that the values are actually type-converted; consider
	// asserting on result.IsError to pin the intended behavior.
	result, err = wrappedHandler(context.Background(), req)
	assert.NoError(t, err)
	assert.NotNil(t, result)

	// Test with a missing field: HelloArgs has no required-field validation,
	// so the absent "name" simply leaves Name at its zero value.
	req.Params.Arguments = map[string]any{
		"age":      30,
		"is_admin": true,
		// Name is missing
	}

	// This should still work but name will be empty
	result, err = wrappedHandler(context.Background(), req)
	assert.NoError(t, err)
	assert.NotNil(t, result)
	assert.Equal(t, "", result.Content[0].(TextContent).Text)

	// Test with completely invalid arguments
	req.Params.Arguments = "not a map"
	result, err = wrappedHandler(context.Background(), req)
	assert.NoError(t, err) // Error is wrapped in the result
	assert.NotNil(t, result)
	assert.True(t, result.IsError)
}

// TestTypedToolHandlerWithValidation checks that validation done inside a
// typed handler (a small calculator) is reported as a tool error result
// rather than a Go error.
func TestTypedToolHandlerWithValidation(t *testing.T) {
	// Define a test struct for arguments with validation
	type CalculatorArgs struct {
		Operation string  `json:"operation"`
		X         float64 `json:"x"`
		Y         float64 `json:"y"`
	}

	// Create a typed handler function with validation
	typedHandler := func(ctx context.Context, request CallToolRequest, args CalculatorArgs) (*CallToolResult, error) {
		// Validate operation
		if args.Operation == "" {
			return NewToolResultError("operation is required"), nil
		}

		var result float64
		switch args.Operation {
		case "add":
			result = args.X + args.Y
		case "subtract":
			result = args.X - args.Y
		case "multiply":
			result = args.X * args.Y
		case "divide":
			if args.Y == 0 {
				return NewToolResultError("division by zero"), nil
			}
			result = args.X / args.Y
		default:
			return NewToolResultError("invalid operation"), nil
		}

		// %.0f prints no decimals, so 10.5 + 5.5 renders as "16".
		return NewToolResultText(fmt.Sprintf("%.0f", result)), nil
	}

	// Create a wrapped handler
	wrappedHandler := NewTypedToolHandler(typedHandler)

	// Create a test request
	req := CallToolRequest{}
	req.Params.Name = "calculator"
	req.Params.Arguments = map[string]any{
		"operation": "add",
		"x":         10.5,
		"y":         5.5,
	}

	// Call the wrapped handler
	result, err := wrappedHandler(context.Background(), req)

	// Verify results
	assert.NoError(t, err)
	assert.NotNil(t, result)
	assert.Equal(t, "16", result.Content[0].(TextContent).Text)

	// Test division by zero
	req.Params.Arguments = map[string]any{
		"operation": "divide",
		"x":         10.0,
		"y":         0.0,
	}

	result, err = wrappedHandler(context.Background(), req)
	assert.NoError(t, err)
	assert.NotNil(t, result)
	assert.True(t, result.IsError)
	assert.Contains(t, result.Content[0].(TextContent).Text, "division by zero")
}

// TestTypedToolHandlerWithComplexObjects checks binding of nested structs and
// slices, first from map[string]any arguments and then from a json.RawMessage.
func TestTypedToolHandlerWithComplexObjects(t *testing.T) {
	// Define a complex test struct with nested objects
	type Address struct {
		Street  string `json:"street"`
		City    string `json:"city"`
		Country string `json:"country"`
		ZipCode string `json:"zip_code"`
	}

	type UserPreferences struct {
		Theme       string   `json:"theme"`
		Timezone    string   `json:"timezone"`
		Newsletters []string `json:"newsletters"`
	}

	type UserProfile struct {
		Name        string          `json:"name"`
		Email       string          `json:"email"`
		Age         int             `json:"age"`
		IsVerified  bool            `json:"is_verified"`
		Address     Address         `json:"address"`
		Preferences UserPreferences `json:"preferences"`
		Tags        []string        `json:"tags"`
	}

	// Create a typed handler function
	typedHandler := func(ctx context.Context, request CallToolRequest, profile UserProfile) (*CallToolResult, error) {
		// Validate required fields
		if profile.Name == "" {
			return NewToolResultError("name is required"), nil
		}
		if profile.Email == "" {
			return NewToolResultError("email is required"), nil
		}

		// Build a response that includes nested object data
		response := fmt.Sprintf("User: %s (%s)", profile.Name, profile.Email)

		if profile.Age > 0 {
			response += fmt.Sprintf(", Age: %d", profile.Age)
		}

		if profile.IsVerified {
			response += ", Verified: Yes"
		} else {
			response += ", Verified: No"
		}

		// Include address information if available
		if profile.Address.City != "" && profile.Address.Country != "" {
			response += fmt.Sprintf(", Location: %s, %s", profile.Address.City, profile.Address.Country)
		}

		// Include preferences if available
		if profile.Preferences.Theme != "" {
			response += fmt.Sprintf(", Theme: %s", profile.Preferences.Theme)
		}

		if len(profile.Preferences.Newsletters) > 0 {
			response += fmt.Sprintf(", Subscribed to %d newsletters", len(profile.Preferences.Newsletters))
		}

		if len(profile.Tags) > 0 {
			response += fmt.Sprintf(", Tags: %v", profile.Tags)
		}

		return NewToolResultText(response), nil
	}

	// Create a wrapped handler
	wrappedHandler := NewTypedToolHandler(typedHandler)

	// Test with complete complex object
	req := CallToolRequest{}
	req.Params.Name = "user_profile"
	req.Params.Arguments = map[string]any{
		"name":        "John Doe",
		"email":       "john@example.com",
		"age":         35,
		"is_verified": true,
		"address": map[string]any{
			"street":   "123 Main St",
			"city":     "San Francisco",
			"country":  "USA",
			"zip_code": "94105",
		},
		"preferences": map[string]any{
			"theme":       "dark",
			"timezone":    "America/Los_Angeles",
			"newsletters": []string{"weekly", "product_updates"},
		},
		"tags": []string{"premium", "early_adopter"},
	}

	// Call the wrapped handler
	result, err := wrappedHandler(context.Background(), req)

	// Verify results
	assert.NoError(t, err)
	assert.NotNil(t, result)
	assert.Contains(t, result.Content[0].(TextContent).Text, "John Doe")
	assert.Contains(t, result.Content[0].(TextContent).Text, "San Francisco, USA")
	assert.Contains(t, result.Content[0].(TextContent).Text, "Theme: dark")
	assert.Contains(t, result.Content[0].(TextContent).Text, "Subscribed to 2 newsletters")
	assert.Contains(t, result.Content[0].(TextContent).Text, "Tags: [premium early_adopter]")

	// Test with partial data (missing some nested fields)
	req.Params.Arguments = map[string]any{
		"name":        "Jane Smith",
		"email":       "jane@example.com",
		"age":         28,
		"is_verified": false,
		"address": map[string]any{
			"city":    "London",
			"country": "UK",
		},
		"preferences": map[string]any{
			"theme": "light",
		},
	}

	result, err = wrappedHandler(context.Background(), req)
	assert.NoError(t, err)
	assert.NotNil(t, result)
	assert.Contains(t, result.Content[0].(TextContent).Text, "Jane Smith")
	assert.Contains(t, result.Content[0].(TextContent).Text, "London, UK")
	assert.Contains(t, result.Content[0].(TextContent).Text, "Theme: light")
	assert.NotContains(t, result.Content[0].(TextContent).Text, "newsletters")

	// Test with JSON string input (simulating raw JSON from client)
	jsonInput := `{
		"name": "Bob Johnson",
		"email": "bob@example.com",
		"age": 42,
		"is_verified": true,
		"address": {
			"street": "456 Park Ave",
			"city": "New York",
			"country": "USA",
			"zip_code": "10022"
		},
		"preferences": {
			"theme": "system",
			"timezone": "America/New_York",
			"newsletters": ["monthly"]
		},
		"tags": ["business"]
	}`

	req.Params.Arguments = json.RawMessage(jsonInput)
	result, err = wrappedHandler(context.Background(), req)
	assert.NoError(t, err)
	assert.NotNil(t, result)
	assert.Contains(t, result.Content[0].(TextContent).Text, "Bob Johnson")
	assert.Contains(t, result.Content[0].(TextContent).Text, "New York, USA")
	assert.Contains(t, result.Content[0].(TextContent).Text, "Theme: system")
	assert.Contains(t, result.Content[0].(TextContent).Text, "Subscribed to 1 newsletters")
}
--------------------------------------------------------------------------------
/mcp/types_test.go:
--------------------------------------------------------------------------------
package mcp

import (
	"encoding/json"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestMetaMarshalling round-trips Meta through encoding/json. Each case first
// checks the marshalled JSON text, then unmarshals that text and compares the
// result with expMeta. The expected values show that unmarshalling yields a
// non-nil (possibly empty) AdditionalFields map even when meta had none.
func TestMetaMarshalling(t *testing.T) {
	tests := []struct {
		name    string
		json    string
		meta    *Meta
		expMeta *Meta
	}{
		{
			name:    "empty",
			json:    "{}",
			meta:    &Meta{},
			expMeta: &Meta{AdditionalFields: map[string]any{}},
		},
		{
			name:    "empty additional fields",
			json:    "{}",
			meta:    &Meta{AdditionalFields: map[string]any{}},
			expMeta: &Meta{AdditionalFields: map[string]any{}},
		},
		{
			name:    "string token only",
			json:    `{"progressToken":"123"}`,
			meta:    &Meta{ProgressToken: "123"},
			expMeta: &Meta{ProgressToken: "123", AdditionalFields: map[string]any{}},
		},
		{
			name:    "string token only, empty additional fields",
			json:    `{"progressToken":"123"}`,
			meta:    &Meta{ProgressToken: "123", AdditionalFields: map[string]any{}},
			expMeta: &Meta{ProgressToken: "123", AdditionalFields: map[string]any{}},
		},
		{
			name: "additional fields only",
			json: `{"a":2,"b":"1"}`,
			meta: &Meta{AdditionalFields: map[string]any{"a": 2, "b": "1"}},
			// For untyped map, numbers are always float64
			expMeta: &Meta{AdditionalFields: map[string]any{"a": float64(2), "b": "1"}},
		},
		{
			name: "progress token and additional fields",
			json: `{"a":2,"b":"1","progressToken":"123"}`,
			meta: &Meta{ProgressToken: "123", AdditionalFields: map[string]any{"a": 2, "b": "1"}},
			// For untyped map, numbers are always float64
			expMeta: &Meta{ProgressToken: "123", AdditionalFields: map[string]any{"a": float64(2), "b": "1"}},
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			data, err := json.Marshal(tc.meta)
			require.NoError(t, err)
			assert.Equal(t, tc.json, string(data))

			meta := &Meta{}
			err = json.Unmarshal([]byte(tc.json), meta)
			require.NoError(t, err)
			assert.Equal(t, tc.expMeta, meta)
		})
	}
}
--------------------------------------------------------------------------------
/mcptest/mcptest.go:
--------------------------------------------------------------------------------
// Package mcptest implements helper functions for testing MCP servers.
2 | package mcptest 3 | 4 | import ( 5 | "bytes" 6 | "context" 7 | "fmt" 8 | "io" 9 | "log" 10 | "sync" 11 | "testing" 12 | 13 | "github.com/mark3labs/mcp-go/client" 14 | "github.com/mark3labs/mcp-go/client/transport" 15 | "github.com/mark3labs/mcp-go/mcp" 16 | "github.com/mark3labs/mcp-go/server" 17 | ) 18 | 19 | // Server encapsulates an MCP server and manages resources like pipes and context. 20 | type Server struct { 21 | name string 22 | 23 | tools []server.ServerTool 24 | prompts []server.ServerPrompt 25 | resources []server.ServerResource 26 | 27 | ctx context.Context 28 | cancel func() 29 | 30 | serverReader *io.PipeReader 31 | serverWriter *io.PipeWriter 32 | clientReader *io.PipeReader 33 | clientWriter *io.PipeWriter 34 | 35 | logBuffer bytes.Buffer 36 | 37 | transport transport.Interface 38 | client *client.Client 39 | 40 | wg sync.WaitGroup 41 | } 42 | 43 | // NewServer starts a new MCP server with the provided tools and returns the server instance. 44 | func NewServer(t *testing.T, tools ...server.ServerTool) (*Server, error) { 45 | server := NewUnstartedServer(t) 46 | server.AddTools(tools...) 47 | 48 | if err := server.Start(); err != nil { 49 | return nil, err 50 | } 51 | 52 | return server, nil 53 | } 54 | 55 | // NewUnstartedServer creates a new MCP server instance with the given name, but does not start the server. 56 | // Useful for tests where you need to add tools before starting the server. 
57 | func NewUnstartedServer(t *testing.T) *Server { 58 | server := &Server{ 59 | name: t.Name(), 60 | } 61 | 62 | // Use t.Context() once we switch to go >= 1.24 63 | ctx := context.TODO() 64 | 65 | // Set up context with cancellation, used to stop the server 66 | server.ctx, server.cancel = context.WithCancel(ctx) 67 | 68 | // Set up pipes for client-server communication 69 | server.serverReader, server.clientWriter = io.Pipe() 70 | server.clientReader, server.serverWriter = io.Pipe() 71 | 72 | // Return the configured server 73 | return server 74 | } 75 | 76 | // AddTools adds multiple tools to an unstarted server. 77 | func (s *Server) AddTools(tools ...server.ServerTool) { 78 | s.tools = append(s.tools, tools...) 79 | } 80 | 81 | // AddTool adds a tool to an unstarted server. 82 | func (s *Server) AddTool(tool mcp.Tool, handler server.ToolHandlerFunc) { 83 | s.tools = append(s.tools, server.ServerTool{ 84 | Tool: tool, 85 | Handler: handler, 86 | }) 87 | } 88 | 89 | // AddPrompt adds a prompt to an unstarted server. 90 | func (s *Server) AddPrompt(prompt mcp.Prompt, handler server.PromptHandlerFunc) { 91 | s.prompts = append(s.prompts, server.ServerPrompt{ 92 | Prompt: prompt, 93 | Handler: handler, 94 | }) 95 | } 96 | 97 | // AddPrompts adds multiple prompts to an unstarted server. 98 | func (s *Server) AddPrompts(prompts ...server.ServerPrompt) { 99 | s.prompts = append(s.prompts, prompts...) 100 | } 101 | 102 | // AddResource adds a resource to an unstarted server. 103 | func (s *Server) AddResource(resource mcp.Resource, handler server.ResourceHandlerFunc) { 104 | s.resources = append(s.resources, server.ServerResource{ 105 | Resource: resource, 106 | Handler: handler, 107 | }) 108 | } 109 | 110 | // AddResources adds multiple resources to an unstarted server. 111 | func (s *Server) AddResources(resources ...server.ServerResource) { 112 | s.resources = append(s.resources, resources...) 113 | } 114 | 115 | // Start starts the server in a goroutine. 
Make sure to defer Close() after Start(). 116 | // When using NewServer(), the returned server is already started. 117 | func (s *Server) Start() error { 118 | s.wg.Add(1) 119 | 120 | // Start the MCP server in a goroutine 121 | go func() { 122 | defer s.wg.Done() 123 | 124 | mcpServer := server.NewMCPServer(s.name, "1.0.0") 125 | 126 | mcpServer.AddTools(s.tools...) 127 | mcpServer.AddPrompts(s.prompts...) 128 | mcpServer.AddResources(s.resources...) 129 | 130 | logger := log.New(&s.logBuffer, "", 0) 131 | 132 | stdioServer := server.NewStdioServer(mcpServer) 133 | stdioServer.SetErrorLogger(logger) 134 | 135 | if err := stdioServer.Listen(s.ctx, s.serverReader, s.serverWriter); err != nil { 136 | logger.Println("StdioServer.Listen failed:", err) 137 | } 138 | }() 139 | 140 | s.transport = transport.NewIO(s.clientReader, s.clientWriter, io.NopCloser(&s.logBuffer)) 141 | if err := s.transport.Start(s.ctx); err != nil { 142 | return fmt.Errorf("transport.Start(): %w", err) 143 | } 144 | 145 | s.client = client.NewClient(s.transport) 146 | 147 | var initReq mcp.InitializeRequest 148 | initReq.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION 149 | if _, err := s.client.Initialize(s.ctx, initReq); err != nil { 150 | return fmt.Errorf("client.Initialize(): %w", err) 151 | } 152 | 153 | return nil 154 | } 155 | 156 | // Close stops the server and cleans up resources like temporary directories. 
157 | func (s *Server) Close() { 158 | if s.transport != nil { 159 | s.transport.Close() 160 | s.transport = nil 161 | s.client = nil 162 | } 163 | 164 | if s.cancel != nil { 165 | s.cancel() 166 | s.cancel = nil 167 | } 168 | 169 | // Wait for server goroutine to finish 170 | s.wg.Wait() 171 | 172 | s.serverWriter.Close() 173 | s.serverReader.Close() 174 | s.serverReader, s.serverWriter = nil, nil 175 | 176 | s.clientWriter.Close() 177 | s.clientReader.Close() 178 | s.clientReader, s.clientWriter = nil, nil 179 | } 180 | 181 | // Client returns an MCP client connected to the server. 182 | // The client is already initialized, i.e. you do _not_ need to call Client.Initialize(). 183 | func (s *Server) Client() *client.Client { 184 | return s.client 185 | } 186 | -------------------------------------------------------------------------------- /mcptest/mcptest_test.go: -------------------------------------------------------------------------------- 1 | package mcptest_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "strings" 7 | "testing" 8 | 9 | "github.com/mark3labs/mcp-go/mcp" 10 | "github.com/mark3labs/mcp-go/mcptest" 11 | "github.com/mark3labs/mcp-go/server" 12 | ) 13 | 14 | func TestServerWithTool(t *testing.T) { 15 | ctx := context.Background() 16 | 17 | srv, err := mcptest.NewServer(t, server.ServerTool{ 18 | Tool: mcp.NewTool("hello", 19 | mcp.WithDescription("Says hello to the provided name, or world."), 20 | mcp.WithString("name", mcp.Description("The name to say hello to.")), 21 | ), 22 | Handler: helloWorldHandler, 23 | }) 24 | if err != nil { 25 | t.Fatal(err) 26 | } 27 | defer srv.Close() 28 | 29 | client := srv.Client() 30 | 31 | var req mcp.CallToolRequest 32 | req.Params.Name = "hello" 33 | req.Params.Arguments = map[string]any{ 34 | "name": "Claude", 35 | } 36 | 37 | result, err := client.CallTool(ctx, req) 38 | if err != nil { 39 | t.Fatal("CallTool:", err) 40 | } 41 | 42 | got, err := resultToString(result) 43 | if err != nil { 44 | t.Fatal(err) 
45 | } 46 | 47 | want := "Hello, Claude!" 48 | if got != want { 49 | t.Errorf("Got %q, want %q", got, want) 50 | } 51 | } 52 | 53 | func helloWorldHandler(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { 54 | // Extract name from request arguments 55 | name, ok := request.GetArguments()["name"].(string) 56 | if !ok { 57 | name = "World" 58 | } 59 | 60 | return mcp.NewToolResultText(fmt.Sprintf("Hello, %s!", name)), nil 61 | } 62 | 63 | func resultToString(result *mcp.CallToolResult) (string, error) { 64 | var b strings.Builder 65 | 66 | for _, content := range result.Content { 67 | text, ok := content.(mcp.TextContent) 68 | if !ok { 69 | return "", fmt.Errorf("unsupported content type: %T", content) 70 | } 71 | b.WriteString(text.Text) 72 | } 73 | 74 | if result.IsError { 75 | return "", fmt.Errorf("%s", b.String()) 76 | } 77 | 78 | return b.String(), nil 79 | } 80 | 81 | func TestServerWithPrompt(t *testing.T) { 82 | ctx := context.Background() 83 | 84 | srv := mcptest.NewUnstartedServer(t) 85 | defer srv.Close() 86 | 87 | prompt := mcp.Prompt{ 88 | Name: "greeting", 89 | Description: "A greeting prompt", 90 | Arguments: []mcp.PromptArgument{ 91 | { 92 | Name: "name", 93 | Description: "The name to greet", 94 | Required: true, 95 | }, 96 | }, 97 | } 98 | handler := func(ctx context.Context, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { 99 | return &mcp.GetPromptResult{ 100 | Description: "A greeting prompt", 101 | Messages: []mcp.PromptMessage{ 102 | { 103 | Role: mcp.RoleUser, 104 | Content: mcp.NewTextContent(fmt.Sprintf("Hello, %s!", request.Params.Arguments["name"])), 105 | }, 106 | }, 107 | }, nil 108 | } 109 | 110 | srv.AddPrompt(prompt, handler) 111 | 112 | err := srv.Start() 113 | if err != nil { 114 | t.Fatal(err) 115 | } 116 | 117 | var getReq mcp.GetPromptRequest 118 | getReq.Params.Name = "greeting" 119 | getReq.Params.Arguments = map[string]string{"name": "John"} 120 | getResult, err := 
srv.Client().GetPrompt(ctx, getReq) 121 | if err != nil { 122 | t.Fatal("GetPrompt:", err) 123 | } 124 | if getResult.Description != "A greeting prompt" { 125 | t.Errorf("Expected prompt description 'A greeting prompt', got %q", getResult.Description) 126 | } 127 | if len(getResult.Messages) != 1 { 128 | t.Fatalf("Expected 1 message, got %d", len(getResult.Messages)) 129 | } 130 | if getResult.Messages[0].Role != mcp.RoleUser { 131 | t.Errorf("Expected message role 'user', got %q", getResult.Messages[0].Role) 132 | } 133 | content, ok := getResult.Messages[0].Content.(mcp.TextContent) 134 | if !ok { 135 | t.Fatalf("Expected TextContent, got %T", getResult.Messages[0].Content) 136 | } 137 | if content.Text != "Hello, John!" { 138 | t.Errorf("Expected message content 'Hello, John!', got %q", content.Text) 139 | } 140 | } 141 | 142 | func TestServerWithResource(t *testing.T) { 143 | ctx := context.Background() 144 | 145 | srv := mcptest.NewUnstartedServer(t) 146 | defer srv.Close() 147 | 148 | resource := mcp.Resource{ 149 | URI: "test://resource", 150 | Name: "Test Resource", 151 | Description: "A test resource", 152 | MIMEType: "text/plain", 153 | } 154 | 155 | handler := func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { 156 | return []mcp.ResourceContents{ 157 | mcp.TextResourceContents{ 158 | URI: "test://resource", 159 | MIMEType: "text/plain", 160 | Text: "This is a test resource content.", 161 | }, 162 | }, nil 163 | } 164 | 165 | srv.AddResource(resource, handler) 166 | 167 | err := srv.Start() 168 | if err != nil { 169 | t.Fatal(err) 170 | } 171 | 172 | var readReq mcp.ReadResourceRequest 173 | readReq.Params.URI = "test://resource" 174 | readResult, err := srv.Client().ReadResource(ctx, readReq) 175 | if err != nil { 176 | t.Fatal("ReadResource:", err) 177 | } 178 | if len(readResult.Contents) != 1 { 179 | t.Fatalf("Expected 1 content, got %d", len(readResult.Contents)) 180 | } 181 | textContent, ok := 
readResult.Contents[0].(mcp.TextResourceContents) 182 | if !ok { 183 | t.Fatalf("Expected TextResourceContents, got %T", readResult.Contents[0]) 184 | } 185 | want := "This is a test resource content." 186 | if textContent.Text != want { 187 | t.Errorf("Got %q, want %q", textContent.Text, want) 188 | } 189 | } 190 | -------------------------------------------------------------------------------- /server/errors.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | ) 7 | 8 | var ( 9 | // Common server errors 10 | ErrUnsupported = errors.New("not supported") 11 | ErrResourceNotFound = errors.New("resource not found") 12 | ErrPromptNotFound = errors.New("prompt not found") 13 | ErrToolNotFound = errors.New("tool not found") 14 | 15 | // Session-related errors 16 | ErrSessionNotFound = errors.New("session not found") 17 | ErrSessionExists = errors.New("session already exists") 18 | ErrSessionNotInitialized = errors.New("session not properly initialized") 19 | ErrSessionDoesNotSupportTools = errors.New("session does not support per-session tools") 20 | ErrSessionDoesNotSupportLogging = errors.New("session does not support setting logging level") 21 | 22 | // Notification-related errors 23 | ErrNotificationNotInitialized = errors.New("notification channel not initialized") 24 | ErrNotificationChannelBlocked = errors.New("notification channel full or blocked") 25 | ) 26 | 27 | // ErrDynamicPathConfig is returned when attempting to use static path methods with dynamic path configuration 28 | type ErrDynamicPathConfig struct { 29 | Method string 30 | } 31 | 32 | func (e *ErrDynamicPathConfig) Error() string { 33 | return fmt.Sprintf("%s cannot be used with WithDynamicBasePath. 
Use dynamic path logic in your router.", e.Method) 34 | } 35 | -------------------------------------------------------------------------------- /server/http_transport_options.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | ) 7 | 8 | // HTTPContextFunc is a function that takes an existing context and the current 9 | // request and returns a potentially modified context based on the request 10 | // content. This can be used to inject context values from headers, for example. 11 | type HTTPContextFunc func(ctx context.Context, r *http.Request) context.Context 12 | -------------------------------------------------------------------------------- /server/internal/gen/README.md: -------------------------------------------------------------------------------- 1 | # Readme for Codegen 2 | 3 | This internal module contains code generation for producing a few repetitive 4 | constructs, namely: 5 | 6 | - The switch statement that handles the request dispatch 7 | - The hook function types and the methods on the Hook struct 8 | 9 | To invoke the code generation: 10 | 11 | ``` 12 | go generate ./... 13 | ``` 14 | 15 | ## Development 16 | 17 | - `request_handler.go.tmpl` generates `server/request_handler.go`, and 18 | - `hooks.go.tmpl` generates `server/hooks.go` 19 | 20 | Inside of `data.go` there is a struct with the inputs to both templates. 21 | 22 | Note that the driver in `main.go` generates code and also pipes it through 23 | `goimports` for formatting and imports cleanup. 
24 | 25 | -------------------------------------------------------------------------------- /server/internal/gen/data.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | type MCPRequestType struct { 4 | MethodName string 5 | ParamType string 6 | ResultType string 7 | HookName string 8 | Group string 9 | GroupName string 10 | GroupHookName string 11 | UnmarshalError string 12 | HandlerFunc string 13 | } 14 | 15 | var MCPRequestTypes = []MCPRequestType{ 16 | { 17 | MethodName: "MethodInitialize", 18 | ParamType: "InitializeRequest", 19 | ResultType: "InitializeResult", 20 | HookName: "Initialize", 21 | UnmarshalError: "invalid initialize request", 22 | HandlerFunc: "handleInitialize", 23 | }, { 24 | MethodName: "MethodPing", 25 | ParamType: "PingRequest", 26 | ResultType: "EmptyResult", 27 | HookName: "Ping", 28 | UnmarshalError: "invalid ping request", 29 | HandlerFunc: "handlePing", 30 | }, { 31 | MethodName: "MethodSetLogLevel", 32 | ParamType: "SetLevelRequest", 33 | ResultType: "EmptyResult", 34 | Group: "logging", 35 | GroupName: "Logging", 36 | GroupHookName: "Logging", 37 | HookName: "SetLevel", 38 | UnmarshalError: "invalid set level request", 39 | HandlerFunc: "handleSetLevel", 40 | }, { 41 | MethodName: "MethodResourcesList", 42 | ParamType: "ListResourcesRequest", 43 | ResultType: "ListResourcesResult", 44 | Group: "resources", 45 | GroupName: "Resources", 46 | GroupHookName: "Resource", 47 | HookName: "ListResources", 48 | UnmarshalError: "invalid list resources request", 49 | HandlerFunc: "handleListResources", 50 | }, { 51 | MethodName: "MethodResourcesTemplatesList", 52 | ParamType: "ListResourceTemplatesRequest", 53 | ResultType: "ListResourceTemplatesResult", 54 | Group: "resources", 55 | GroupName: "Resources", 56 | GroupHookName: "Resource", 57 | HookName: "ListResourceTemplates", 58 | UnmarshalError: "invalid list resource templates request", 59 | HandlerFunc: "handleListResourceTemplates", 
60 | }, { 61 | MethodName: "MethodResourcesRead", 62 | ParamType: "ReadResourceRequest", 63 | ResultType: "ReadResourceResult", 64 | Group: "resources", 65 | GroupName: "Resources", 66 | GroupHookName: "Resource", 67 | HookName: "ReadResource", 68 | UnmarshalError: "invalid read resource request", 69 | HandlerFunc: "handleReadResource", 70 | }, { 71 | MethodName: "MethodPromptsList", 72 | ParamType: "ListPromptsRequest", 73 | ResultType: "ListPromptsResult", 74 | Group: "prompts", 75 | GroupName: "Prompts", 76 | GroupHookName: "Prompt", 77 | HookName: "ListPrompts", 78 | UnmarshalError: "invalid list prompts request", 79 | HandlerFunc: "handleListPrompts", 80 | }, { 81 | MethodName: "MethodPromptsGet", 82 | ParamType: "GetPromptRequest", 83 | ResultType: "GetPromptResult", 84 | Group: "prompts", 85 | GroupName: "Prompts", 86 | GroupHookName: "Prompt", 87 | HookName: "GetPrompt", 88 | UnmarshalError: "invalid get prompt request", 89 | HandlerFunc: "handleGetPrompt", 90 | }, { 91 | MethodName: "MethodToolsList", 92 | ParamType: "ListToolsRequest", 93 | ResultType: "ListToolsResult", 94 | Group: "tools", 95 | GroupName: "Tools", 96 | GroupHookName: "Tool", 97 | HookName: "ListTools", 98 | UnmarshalError: "invalid list tools request", 99 | HandlerFunc: "handleListTools", 100 | }, { 101 | MethodName: "MethodToolsCall", 102 | ParamType: "CallToolRequest", 103 | ResultType: "CallToolResult", 104 | Group: "tools", 105 | GroupName: "Tools", 106 | GroupHookName: "Tool", 107 | HookName: "CallTool", 108 | UnmarshalError: "invalid call tool request", 109 | HandlerFunc: "handleToolCall", 110 | }, 111 | } 112 | -------------------------------------------------------------------------------- /server/internal/gen/hooks.go.tmpl: -------------------------------------------------------------------------------- 1 | // Code generated by `go generate`. DO NOT EDIT. 
2 | // source: server/internal/gen/hooks.go.tmpl 3 | package server 4 | 5 | import ( 6 | "context" 7 | "encoding/json" 8 | "errors" 9 | "fmt" 10 | 11 | "github.com/mark3labs/mcp-go/mcp" 12 | ) 13 | 14 | // OnRegisterSessionHookFunc is a hook that will be called when a new session is registered. 15 | type OnRegisterSessionHookFunc func(ctx context.Context, session ClientSession) 16 | 17 | // OnUnregisterSessionHookFunc is a hook that will be called when a session is being unregistered. 18 | type OnUnregisterSessionHookFunc func(ctx context.Context, session ClientSession) 19 | 20 | // BeforeAnyHookFunc is a function that is called after the request is 21 | // parsed but before the method is called. 22 | type BeforeAnyHookFunc func(ctx context.Context, id any, method mcp.MCPMethod, message any) 23 | 24 | // OnSuccessHookFunc is a hook that will be called after the request 25 | // successfully generates a result, but before the result is sent to the client. 26 | type OnSuccessHookFunc func(ctx context.Context, id any, method mcp.MCPMethod, message any, result any) 27 | 28 | // OnErrorHookFunc is a hook that will be called when an error occurs, 29 | // either during the request parsing or the method execution. 
30 | // 31 | // Example usage: 32 | // ``` 33 | // hooks.AddOnError(func(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) { 34 | // // Check for specific error types using errors.Is 35 | // if errors.Is(err, ErrUnsupported) { 36 | // // Handle capability not supported errors 37 | // log.Printf("Capability not supported: %v", err) 38 | // } 39 | // 40 | // // Use errors.As to get specific error types 41 | // var parseErr = &UnparsableMessageError{} 42 | // if errors.As(err, &parseErr) { 43 | // // Access specific methods/fields of the error type 44 | // log.Printf("Failed to parse message for method %s: %v", 45 | // parseErr.GetMethod(), parseErr.Unwrap()) 46 | // // Access the raw message that failed to parse 47 | // rawMsg := parseErr.GetMessage() 48 | // } 49 | // 50 | // // Check for specific resource/prompt/tool errors 51 | // switch { 52 | // case errors.Is(err, ErrResourceNotFound): 53 | // log.Printf("Resource not found: %v", err) 54 | // case errors.Is(err, ErrPromptNotFound): 55 | // log.Printf("Prompt not found: %v", err) 56 | // case errors.Is(err, ErrToolNotFound): 57 | // log.Printf("Tool not found: %v", err) 58 | // } 59 | // }) 60 | type OnErrorHookFunc func(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) 61 | 62 | // OnRequestInitializationFunc is a function that called before handle diff request method 63 | // Should any errors arise during func execution, the service will promptly return the corresponding error message. 
64 | type OnRequestInitializationFunc func(ctx context.Context, id any, message any) error 65 | 66 | 67 | {{range .}} 68 | type OnBefore{{.HookName}}Func func(ctx context.Context, id any, message *mcp.{{.ParamType}}) 69 | type OnAfter{{.HookName}}Func func(ctx context.Context, id any, message *mcp.{{.ParamType}}, result *mcp.{{.ResultType}}) 70 | {{end}} 71 | 72 | type Hooks struct { 73 | OnRegisterSession []OnRegisterSessionHookFunc 74 | OnUnregisterSession []OnUnregisterSessionHookFunc 75 | OnBeforeAny []BeforeAnyHookFunc 76 | OnSuccess []OnSuccessHookFunc 77 | OnError []OnErrorHookFunc 78 | OnRequestInitialization []OnRequestInitializationFunc 79 | {{- range .}} 80 | OnBefore{{.HookName}} []OnBefore{{.HookName}}Func 81 | OnAfter{{.HookName}} []OnAfter{{.HookName}}Func 82 | {{- end}} 83 | } 84 | 85 | func (c *Hooks) AddBeforeAny(hook BeforeAnyHookFunc) { 86 | c.OnBeforeAny = append(c.OnBeforeAny, hook) 87 | } 88 | 89 | func (c *Hooks) AddOnSuccess(hook OnSuccessHookFunc) { 90 | c.OnSuccess = append(c.OnSuccess, hook) 91 | } 92 | 93 | // AddOnError registers a hook function that will be called when an error occurs. 94 | // The error parameter contains the actual error object, which can be interrogated 95 | // using Go's error handling patterns like errors.Is and errors.As. 
96 | // 97 | // Example: 98 | // ``` 99 | // // Create a channel to receive errors for testing 100 | // errChan := make(chan error, 1) 101 | // 102 | // // Register hook to capture and inspect errors 103 | // hooks := &Hooks{} 104 | // hooks.AddOnError(func(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) { 105 | // // For capability-related errors 106 | // if errors.Is(err, ErrUnsupported) { 107 | // // Handle capability not supported 108 | // errChan <- err 109 | // return 110 | // } 111 | // 112 | // // For parsing errors 113 | // var parseErr = &UnparsableMessageError{} 114 | // if errors.As(err, &parseErr) { 115 | // // Handle unparsable message errors 116 | // fmt.Printf("Failed to parse %s request: %v\n", 117 | // parseErr.GetMethod(), parseErr.Unwrap()) 118 | // errChan <- parseErr 119 | // return 120 | // } 121 | // 122 | // // For resource/prompt/tool not found errors 123 | // if errors.Is(err, ErrResourceNotFound) || 124 | // errors.Is(err, ErrPromptNotFound) || 125 | // errors.Is(err, ErrToolNotFound) { 126 | // // Handle not found errors 127 | // errChan <- err 128 | // return 129 | // } 130 | // 131 | // // For other errors 132 | // errChan <- err 133 | // }) 134 | // 135 | // server := NewMCPServer("test-server", "1.0.0", WithHooks(hooks)) 136 | // ``` 137 | func (c *Hooks) AddOnError(hook OnErrorHookFunc) { 138 | c.OnError = append(c.OnError, hook) 139 | } 140 | 141 | func (c *Hooks) beforeAny(ctx context.Context, id any, method mcp.MCPMethod, message any) { 142 | if c == nil { 143 | return 144 | } 145 | for _, hook := range c.OnBeforeAny { 146 | hook(ctx, id, method, message) 147 | } 148 | } 149 | 150 | func (c *Hooks) onSuccess(ctx context.Context, id any, method mcp.MCPMethod, message any, result any) { 151 | if c == nil { 152 | return 153 | } 154 | for _, hook := range c.OnSuccess { 155 | hook(ctx, id, method, message, result) 156 | } 157 | } 158 | 159 | // onError calls all registered error hooks with the error object. 
160 | // The err parameter contains the actual error that occurred, which implements 161 | // the standard error interface and may be a wrapped error or custom error type. 162 | // 163 | // This allows consumer code to use Go's error handling patterns: 164 | // - errors.Is(err, ErrUnsupported) to check for specific sentinel errors 165 | // - errors.As(err, &customErr) to extract custom error types 166 | // 167 | // Common error types include: 168 | // - ErrUnsupported: When a capability is not enabled 169 | // - UnparsableMessageError: When request parsing fails 170 | // - ErrResourceNotFound: When a resource is not found 171 | // - ErrPromptNotFound: When a prompt is not found 172 | // - ErrToolNotFound: When a tool is not found 173 | func (c *Hooks) onError(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) { 174 | if c == nil { 175 | return 176 | } 177 | for _, hook := range c.OnError { 178 | hook(ctx, id, method, message, err) 179 | } 180 | } 181 | 182 | func (c *Hooks) AddOnRegisterSession(hook OnRegisterSessionHookFunc) { 183 | c.OnRegisterSession = append(c.OnRegisterSession, hook) 184 | } 185 | 186 | func (c *Hooks) RegisterSession(ctx context.Context, session ClientSession) { 187 | if c == nil { 188 | return 189 | } 190 | for _, hook := range c.OnRegisterSession { 191 | hook(ctx, session) 192 | } 193 | } 194 | 195 | func (c *Hooks) AddOnUnregisterSession(hook OnUnregisterSessionHookFunc) { 196 | c.OnUnregisterSession = append(c.OnUnregisterSession, hook) 197 | } 198 | 199 | func (c *Hooks) UnregisterSession(ctx context.Context, session ClientSession) { 200 | if c == nil { 201 | return 202 | } 203 | for _, hook := range c.OnUnregisterSession { 204 | hook(ctx, session) 205 | } 206 | } 207 | 208 | func (c *Hooks) AddOnRequestInitialization(hook OnRequestInitializationFunc) { 209 | c.OnRequestInitialization = append(c.OnRequestInitialization, hook) 210 | } 211 | 212 | func (c *Hooks) onRequestInitialization(ctx context.Context, id any, 
message any) error { 213 | if c == nil { 214 | return nil 215 | } 216 | for _, hook := range c.OnRequestInitialization { 217 | err := hook(ctx, id, message) 218 | if err != nil { 219 | return err 220 | } 221 | } 222 | return nil 223 | } 224 | 225 | {{- range .}} 226 | func (c *Hooks) AddBefore{{.HookName}}(hook OnBefore{{.HookName}}Func) { 227 | c.OnBefore{{.HookName}} = append(c.OnBefore{{.HookName}}, hook) 228 | } 229 | 230 | func (c *Hooks) AddAfter{{.HookName}}(hook OnAfter{{.HookName}}Func) { 231 | c.OnAfter{{.HookName}} = append(c.OnAfter{{.HookName}}, hook) 232 | } 233 | 234 | func (c *Hooks) before{{.HookName}}(ctx context.Context, id any, message *mcp.{{.ParamType}}) { 235 | c.beforeAny(ctx, id, mcp.{{.MethodName}}, message) 236 | if c == nil { 237 | return 238 | } 239 | for _, hook := range c.OnBefore{{.HookName}} { 240 | hook(ctx, id, message) 241 | } 242 | } 243 | 244 | func (c *Hooks) after{{.HookName}}(ctx context.Context, id any, message *mcp.{{.ParamType}}, result *mcp.{{.ResultType}}) { 245 | c.onSuccess(ctx, id, mcp.{{.MethodName}}, message, result) 246 | if c == nil { 247 | return 248 | } 249 | for _, hook := range c.OnAfter{{.HookName}} { 250 | hook(ctx, id, message, result) 251 | } 252 | } 253 | {{- end -}} 254 | -------------------------------------------------------------------------------- /server/internal/gen/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | _ "embed" 5 | "fmt" 6 | "io" 7 | "log" 8 | "os" 9 | "os/exec" 10 | "path/filepath" 11 | "strings" 12 | "text/template" 13 | ) 14 | 15 | //go:generate go run . ../.. 
16 | 17 | //go:embed hooks.go.tmpl 18 | var hooksTemplate string 19 | 20 | //go:embed request_handler.go.tmpl 21 | var requestHandlerTemplate string 22 | 23 | func RenderTemplateToFile(templateContent, destPath, fileName string, data any) error { 24 | // Create temp file for initial output 25 | tempFile, err := os.CreateTemp("", "hooks-*.go") 26 | if err != nil { 27 | return err 28 | } 29 | tempFilePath := tempFile.Name() 30 | defer os.Remove(tempFilePath) // Clean up temp file when done 31 | defer tempFile.Close() 32 | 33 | // Parse and execute template to temp file 34 | tmpl, err := template.New(fileName).Funcs(template.FuncMap{ 35 | "toLower": strings.ToLower, 36 | }).Parse(templateContent) 37 | if err != nil { 38 | return err 39 | } 40 | 41 | if err := tmpl.Execute(tempFile, data); err != nil { 42 | return err 43 | } 44 | 45 | // Run goimports on the temp file 46 | cmd := exec.Command("go", "run", "golang.org/x/tools/cmd/goimports@latest", "-w", tempFilePath) 47 | if output, err := cmd.CombinedOutput(); err != nil { 48 | return fmt.Errorf("goimports failed: %v\n%s", err, output) 49 | } 50 | 51 | // Read the processed content 52 | processedContent, err := os.ReadFile(tempFilePath) 53 | if err != nil { 54 | return err 55 | } 56 | 57 | // Write the processed content to the destination 58 | var destWriter io.Writer 59 | if destPath == "-" { 60 | destWriter = os.Stdout 61 | } else { 62 | destFile, err := os.Create(filepath.Join(destPath, fileName)) 63 | if err != nil { 64 | return err 65 | } 66 | defer destFile.Close() 67 | destWriter = destFile 68 | } 69 | 70 | _, err = destWriter.Write(processedContent) 71 | return err 72 | } 73 | 74 | func main() { 75 | if len(os.Args) < 2 { 76 | log.Fatal("usage: gen ") 77 | } 78 | destPath := os.Args[1] 79 | 80 | if err := RenderTemplateToFile(hooksTemplate, destPath, "hooks.go", MCPRequestTypes); err != nil { 81 | log.Fatal(err) 82 | } 83 | 84 | if err := RenderTemplateToFile(requestHandlerTemplate, destPath, 
"request_handler.go", MCPRequestTypes); err != nil { 85 | log.Fatal(err) 86 | } 87 | } 88 | -------------------------------------------------------------------------------- /server/internal/gen/request_handler.go.tmpl: -------------------------------------------------------------------------------- 1 | // Code generated by `go generate`. DO NOT EDIT. 2 | // source: server/internal/gen/request_handler.go.tmpl 3 | package server 4 | 5 | import ( 6 | "context" 7 | "encoding/json" 8 | "errors" 9 | "fmt" 10 | 11 | "github.com/mark3labs/mcp-go/mcp" 12 | ) 13 | 14 | // HandleMessage processes an incoming JSON-RPC message and returns an appropriate response 15 | func (s *MCPServer) HandleMessage( 16 | ctx context.Context, 17 | message json.RawMessage, 18 | ) mcp.JSONRPCMessage { 19 | // Add server to context 20 | ctx = context.WithValue(ctx, serverKey{}, s) 21 | var err *requestError 22 | 23 | var baseMessage struct { 24 | JSONRPC string `json:"jsonrpc"` 25 | Method mcp.MCPMethod `json:"method"` 26 | ID any `json:"id,omitempty"` 27 | Result any `json:"result,omitempty"` 28 | } 29 | 30 | if err := json.Unmarshal(message, &baseMessage); err != nil { 31 | return createErrorResponse( 32 | nil, 33 | mcp.PARSE_ERROR, 34 | "Failed to parse message", 35 | ) 36 | } 37 | 38 | // Check for valid JSONRPC version 39 | if baseMessage.JSONRPC != mcp.JSONRPC_VERSION { 40 | return createErrorResponse( 41 | baseMessage.ID, 42 | mcp.INVALID_REQUEST, 43 | "Invalid JSON-RPC version", 44 | ) 45 | } 46 | 47 | if baseMessage.ID == nil { 48 | var notification mcp.JSONRPCNotification 49 | if err := json.Unmarshal(message, ¬ification); err != nil { 50 | return createErrorResponse( 51 | nil, 52 | mcp.PARSE_ERROR, 53 | "Failed to parse notification", 54 | ) 55 | } 56 | s.handleNotification(ctx, notification) 57 | return nil // Return nil for notifications 58 | } 59 | 60 | if baseMessage.Result != nil { 61 | // this is a response to a request sent by the server (e.g. 
from a ping 62 | // sent due to WithKeepAlive option) 63 | return nil 64 | } 65 | 66 | handleErr := s.hooks.onRequestInitialization(ctx, baseMessage.ID, message) 67 | if handleErr != nil { 68 | return createErrorResponse( 69 | baseMessage.ID, 70 | mcp.INVALID_REQUEST, 71 | handleErr.Error(), 72 | ) 73 | } 74 | 75 | switch baseMessage.Method { 76 | {{- range .}} 77 | case mcp.{{.MethodName}}: 78 | var request mcp.{{.ParamType}} 79 | var result *mcp.{{.ResultType}} 80 | {{ if .Group }}if s.capabilities.{{.Group}} == nil { 81 | err = &requestError{ 82 | id: baseMessage.ID, 83 | code: mcp.METHOD_NOT_FOUND, 84 | err: fmt.Errorf("{{toLower .GroupName}} %w", ErrUnsupported), 85 | } 86 | } else{{ end }} if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { 87 | err = &requestError{ 88 | id: baseMessage.ID, 89 | code: mcp.INVALID_REQUEST, 90 | err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, 91 | } 92 | } else { 93 | s.hooks.before{{.HookName}}(ctx, baseMessage.ID, &request) 94 | result, err = s.{{.HandlerFunc}}(ctx, baseMessage.ID, request) 95 | } 96 | if err != nil { 97 | s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) 98 | return err.ToJSONRPCError() 99 | } 100 | s.hooks.after{{.HookName}}(ctx, baseMessage.ID, &request, result) 101 | return createResponse(baseMessage.ID, *result) 102 | {{- end }} 103 | default: 104 | return createErrorResponse( 105 | baseMessage.ID, 106 | mcp.METHOD_NOT_FOUND, 107 | fmt.Sprintf("Method %s not found", baseMessage.Method), 108 | ) 109 | } 110 | } 111 | -------------------------------------------------------------------------------- /server/resource_test.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | "time" 7 | 8 | "github.com/mark3labs/mcp-go/mcp" 9 | "github.com/stretchr/testify/assert" 10 | "github.com/stretchr/testify/require" 11 | ) 12 | 13 
| func TestMCPServer_RemoveResource(t *testing.T) { 14 | tests := []struct { 15 | name string 16 | action func(*testing.T, *MCPServer, chan mcp.JSONRPCNotification) 17 | expectedNotifications int 18 | validate func(*testing.T, []mcp.JSONRPCNotification, mcp.JSONRPCMessage) 19 | }{ 20 | { 21 | name: "RemoveResource removes the resource from the server", 22 | action: func(t *testing.T, server *MCPServer, notificationChannel chan mcp.JSONRPCNotification) { 23 | // Add a test resource 24 | server.AddResource( 25 | mcp.NewResource( 26 | "test://resource1", 27 | "Resource 1", 28 | mcp.WithResourceDescription("Test resource 1"), 29 | mcp.WithMIMEType("text/plain"), 30 | ), 31 | func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { 32 | return []mcp.ResourceContents{ 33 | mcp.TextResourceContents{ 34 | URI: "test://resource1", 35 | MIMEType: "text/plain", 36 | Text: "test content 1", 37 | }, 38 | }, nil 39 | }, 40 | ) 41 | 42 | // Add a second resource 43 | server.AddResource( 44 | mcp.NewResource( 45 | "test://resource2", 46 | "Resource 2", 47 | mcp.WithResourceDescription("Test resource 2"), 48 | mcp.WithMIMEType("text/plain"), 49 | ), 50 | func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { 51 | return []mcp.ResourceContents{ 52 | mcp.TextResourceContents{ 53 | URI: "test://resource2", 54 | MIMEType: "text/plain", 55 | Text: "test content 2", 56 | }, 57 | }, nil 58 | }, 59 | ) 60 | 61 | // First, verify we have two resources 62 | response := server.HandleMessage(context.Background(), []byte(`{ 63 | "jsonrpc": "2.0", 64 | "id": 1, 65 | "method": "resources/list" 66 | }`)) 67 | resp, ok := response.(mcp.JSONRPCResponse) 68 | assert.True(t, ok) 69 | result, ok := resp.Result.(mcp.ListResourcesResult) 70 | assert.True(t, ok) 71 | assert.Len(t, result.Resources, 2) 72 | 73 | // Now register session to receive notifications 74 | err := server.RegisterSession(context.TODO(), &fakeSession{ 75 | 
sessionID: "test", 76 | notificationChannel: notificationChannel, 77 | initialized: true, 78 | }) 79 | require.NoError(t, err) 80 | 81 | // Now remove one resource 82 | server.RemoveResource("test://resource1") 83 | }, 84 | expectedNotifications: 1, 85 | validate: func(t *testing.T, notifications []mcp.JSONRPCNotification, resourcesList mcp.JSONRPCMessage) { 86 | // Check that we received a list_changed notification 87 | assert.Equal(t, mcp.MethodNotificationResourcesListChanged, notifications[0].Method) 88 | 89 | // Verify we now have only one resource 90 | resp, ok := resourcesList.(mcp.JSONRPCResponse) 91 | assert.True(t, ok, "Expected JSONRPCResponse, got %T", resourcesList) 92 | 93 | result, ok := resp.Result.(mcp.ListResourcesResult) 94 | assert.True(t, ok, "Expected ListResourcesResult, got %T", resp.Result) 95 | 96 | assert.Len(t, result.Resources, 1) 97 | assert.Equal(t, "Resource 2", result.Resources[0].Name) 98 | }, 99 | }, 100 | { 101 | name: "RemoveResource with non-existent resource does nothing and not receives notifications from MCPServer", 102 | action: func(t *testing.T, server *MCPServer, notificationChannel chan mcp.JSONRPCNotification) { 103 | // Add a test resource 104 | server.AddResource( 105 | mcp.NewResource( 106 | "test://resource1", 107 | "Resource 1", 108 | mcp.WithResourceDescription("Test resource 1"), 109 | mcp.WithMIMEType("text/plain"), 110 | ), 111 | func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { 112 | return []mcp.ResourceContents{ 113 | mcp.TextResourceContents{ 114 | URI: "test://resource1", 115 | MIMEType: "text/plain", 116 | Text: "test content 1", 117 | }, 118 | }, nil 119 | }, 120 | ) 121 | 122 | // Register session to receive notifications 123 | err := server.RegisterSession(context.TODO(), &fakeSession{ 124 | sessionID: "test", 125 | notificationChannel: notificationChannel, 126 | initialized: true, 127 | }) 128 | require.NoError(t, err) 129 | 130 | // Remove a non-existent 
resource 131 | server.RemoveResource("test://nonexistent") 132 | }, 133 | expectedNotifications: 0, // No notifications expected 134 | validate: func(t *testing.T, notifications []mcp.JSONRPCNotification, resourcesList mcp.JSONRPCMessage) { 135 | // verify that no notifications were sent 136 | assert.Empty(t, notifications) 137 | 138 | // The original resource should still be there 139 | resp, ok := resourcesList.(mcp.JSONRPCResponse) 140 | assert.True(t, ok) 141 | 142 | result, ok := resp.Result.(mcp.ListResourcesResult) 143 | assert.True(t, ok) 144 | 145 | assert.Len(t, result.Resources, 1) 146 | assert.Equal(t, "Resource 1", result.Resources[0].Name) 147 | }, 148 | }, 149 | { 150 | name: "RemoveResource with no listChanged capability doesn't send notification", 151 | action: func(t *testing.T, server *MCPServer, notificationChannel chan mcp.JSONRPCNotification) { 152 | // Create a new server without listChanged capability 153 | noListChangedServer := NewMCPServer( 154 | "test-server", 155 | "1.0.0", 156 | WithResourceCapabilities(true, false), // Subscribe but not listChanged 157 | ) 158 | 159 | // Add a resource 160 | noListChangedServer.AddResource( 161 | mcp.NewResource( 162 | "test://resource1", 163 | "Resource 1", 164 | mcp.WithResourceDescription("Test resource 1"), 165 | mcp.WithMIMEType("text/plain"), 166 | ), 167 | func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { 168 | return []mcp.ResourceContents{ 169 | mcp.TextResourceContents{ 170 | URI: "test://resource1", 171 | MIMEType: "text/plain", 172 | Text: "test content 1", 173 | }, 174 | }, nil 175 | }, 176 | ) 177 | 178 | // Register session to receive notifications 179 | err := noListChangedServer.RegisterSession(context.TODO(), &fakeSession{ 180 | sessionID: "test", 181 | notificationChannel: notificationChannel, 182 | initialized: true, 183 | }) 184 | require.NoError(t, err) 185 | 186 | // Remove the resource 187 | 
noListChangedServer.RemoveResource("test://resource1") 188 | 189 | // The test can now proceed without waiting for notifications 190 | // since we don't expect any 191 | }, 192 | expectedNotifications: 0, // No notifications expected 193 | validate: func(t *testing.T, notifications []mcp.JSONRPCNotification, resourcesList mcp.JSONRPCMessage) { 194 | // Nothing to do here, we're just verifying that no notifications were sent 195 | assert.Empty(t, notifications) 196 | }, 197 | }, 198 | } 199 | 200 | for _, tt := range tests { 201 | t.Run(tt.name, func(t *testing.T) { 202 | ctx := context.Background() 203 | server := NewMCPServer( 204 | "test-server", 205 | "1.0.0", 206 | WithResourceCapabilities(true, true), 207 | ) 208 | 209 | // Initialize the server 210 | _ = server.HandleMessage(ctx, []byte(`{ 211 | "jsonrpc": "2.0", 212 | "id": 1, 213 | "method": "initialize" 214 | }`)) 215 | 216 | notificationChannel := make(chan mcp.JSONRPCNotification, 100) 217 | notifications := make([]mcp.JSONRPCNotification, 0) 218 | 219 | tt.action(t, server, notificationChannel) 220 | 221 | // Collect notifications with a timeout 222 | if tt.expectedNotifications > 0 { 223 | for i := 0; i < tt.expectedNotifications; i++ { 224 | select { 225 | case notification := <-notificationChannel: 226 | notifications = append(notifications, notification) 227 | case <-time.After(1 * time.Second): 228 | t.Fatalf("Expected %d notifications but only received %d", tt.expectedNotifications, len(notifications)) 229 | } 230 | } 231 | } else { 232 | // If no notifications expected, wait a brief period to ensure none are sent 233 | select { 234 | case notification := <-notificationChannel: 235 | notifications = append(notifications, notification) 236 | case <-time.After(100 * time.Millisecond): 237 | // This is the expected path - no notifications 238 | } 239 | } 240 | 241 | // Get final resources list 242 | listMessage := `{ 243 | "jsonrpc": "2.0", 244 | "id": 1, 245 | "method": "resources/list" 246 | }` 247 
| resourcesList := server.HandleMessage(ctx, []byte(listMessage)) 248 | 249 | // Validate the results 250 | tt.validate(t, notifications, resourcesList) 251 | }) 252 | } 253 | } 254 | -------------------------------------------------------------------------------- /server/server_race_test.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "sync" 7 | "testing" 8 | "time" 9 | 10 | "github.com/mark3labs/mcp-go/mcp" 11 | "github.com/stretchr/testify/assert" 12 | "github.com/stretchr/testify/require" 13 | ) 14 | 15 | // TestRaceConditions attempts to trigger race conditions by performing 16 | // concurrent operations on different resources of the MCPServer. 17 | func TestRaceConditions(t *testing.T) { 18 | // Create a server with all capabilities 19 | srv := NewMCPServer("test-server", "1.0.0", 20 | WithResourceCapabilities(true, true), 21 | WithPromptCapabilities(true), 22 | WithToolCapabilities(true), 23 | WithLogging(), 24 | WithRecovery(), 25 | ) 26 | 27 | // Create a context 28 | ctx := context.Background() 29 | 30 | // Create a sync.WaitGroup to coordinate test goroutines 31 | var wg sync.WaitGroup 32 | 33 | // Define test duration 34 | testDuration := 300 * time.Millisecond 35 | 36 | // Start goroutines to perform concurrent operations 37 | runConcurrentOperation(&wg, testDuration, "add-prompts", func() { 38 | name := fmt.Sprintf("prompt-%d", time.Now().UnixNano()) 39 | srv.AddPrompt(mcp.Prompt{ 40 | Name: name, 41 | Description: "Test prompt", 42 | }, func(ctx context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { 43 | return &mcp.GetPromptResult{}, nil 44 | }) 45 | }) 46 | 47 | runConcurrentOperation(&wg, testDuration, "delete-prompts", func() { 48 | name := fmt.Sprintf("delete-prompt-%d", time.Now().UnixNano()) 49 | srv.AddPrompt(mcp.Prompt{ 50 | Name: name, 51 | Description: "Temporary prompt", 52 | }, func(ctx context.Context, req 
mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { 53 | return &mcp.GetPromptResult{}, nil 54 | }) 55 | srv.DeletePrompts(name) 56 | }) 57 | 58 | runConcurrentOperation(&wg, testDuration, "add-tools", func() { 59 | name := fmt.Sprintf("tool-%d", time.Now().UnixNano()) 60 | srv.AddTool(mcp.Tool{ 61 | Name: name, 62 | Description: "Test tool", 63 | }, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { 64 | return &mcp.CallToolResult{}, nil 65 | }) 66 | }) 67 | 68 | runConcurrentOperation(&wg, testDuration, "delete-tools", func() { 69 | name := fmt.Sprintf("delete-tool-%d", time.Now().UnixNano()) 70 | // Add and immediately delete 71 | srv.AddTool(mcp.Tool{ 72 | Name: name, 73 | Description: "Temporary tool", 74 | }, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { 75 | return &mcp.CallToolResult{}, nil 76 | }) 77 | srv.DeleteTools(name) 78 | }) 79 | 80 | runConcurrentOperation(&wg, testDuration, "add-middleware", func() { 81 | middleware := func(next ToolHandlerFunc) ToolHandlerFunc { 82 | return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { 83 | return next(ctx, req) 84 | } 85 | } 86 | WithToolHandlerMiddleware(middleware)(srv) 87 | }) 88 | 89 | runConcurrentOperation(&wg, testDuration, "list-tools", func() { 90 | result, reqErr := srv.handleListTools(ctx, "123", mcp.ListToolsRequest{}) 91 | require.Nil(t, reqErr, "List tools operation should not return an error") 92 | require.NotNil(t, result, "List tools result should not be nil") 93 | }) 94 | 95 | runConcurrentOperation(&wg, testDuration, "list-prompts", func() { 96 | result, reqErr := srv.handleListPrompts(ctx, "123", mcp.ListPromptsRequest{}) 97 | require.Nil(t, reqErr, "List prompts operation should not return an error") 98 | require.NotNil(t, result, "List prompts result should not be nil") 99 | }) 100 | 101 | // Add a persistent tool for testing tool calls 102 | srv.AddTool(mcp.Tool{ 103 | Name: 
"persistent-tool", 104 | Description: "Test tool that always exists", 105 | }, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { 106 | return &mcp.CallToolResult{}, nil 107 | }) 108 | 109 | runConcurrentOperation(&wg, testDuration, "call-tools", func() { 110 | req := mcp.CallToolRequest{} 111 | req.Params.Name = "persistent-tool" 112 | req.Params.Arguments = map[string]any{"param": "test"} 113 | result, reqErr := srv.handleToolCall(ctx, "123", req) 114 | require.Nil(t, reqErr, "Tool call operation should not return an error") 115 | require.NotNil(t, result, "Tool call result should not be nil") 116 | }) 117 | 118 | runConcurrentOperation(&wg, testDuration, "add-resources", func() { 119 | uri := fmt.Sprintf("resource-%d", time.Now().UnixNano()) 120 | srv.AddResource(mcp.Resource{ 121 | URI: uri, 122 | Name: uri, 123 | Description: "Test resource", 124 | }, func(ctx context.Context, req mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { 125 | return []mcp.ResourceContents{ 126 | mcp.TextResourceContents{ 127 | URI: uri, 128 | Text: "Test content", 129 | }, 130 | }, nil 131 | }) 132 | }) 133 | 134 | // Wait for all operations to complete 135 | wg.Wait() 136 | t.Log("No race conditions detected") 137 | } 138 | 139 | // Helper function to run an operation concurrently for a specified duration 140 | func runConcurrentOperation( 141 | wg *sync.WaitGroup, 142 | duration time.Duration, 143 | _ string, 144 | operation func(), 145 | ) { 146 | wg.Add(1) 147 | go func() { 148 | defer wg.Done() 149 | 150 | done := time.After(duration) 151 | for { 152 | select { 153 | case <-done: 154 | return 155 | default: 156 | operation() 157 | } 158 | } 159 | }() 160 | } 161 | 162 | // TestConcurrentPromptAdd specifically tests for the deadlock scenario where adding a prompt 163 | // from a goroutine can cause a deadlock 164 | func TestConcurrentPromptAdd(t *testing.T) { 165 | srv := NewMCPServer("test-server", "1.0.0", WithPromptCapabilities(true)) 
166 | ctx := context.Background() 167 | 168 | // Add a prompt with a handler that adds another prompt in a goroutine 169 | srv.AddPrompt(mcp.Prompt{ 170 | Name: "initial-prompt", 171 | Description: "Initial prompt", 172 | }, func(ctx context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { 173 | go func() { 174 | srv.AddPrompt(mcp.Prompt{ 175 | Name: fmt.Sprintf("new-prompt-%d", time.Now().UnixNano()), 176 | Description: "Added from handler", 177 | }, func(ctx context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { 178 | return &mcp.GetPromptResult{}, nil 179 | }) 180 | }() 181 | return &mcp.GetPromptResult{}, nil 182 | }) 183 | 184 | // Create request and channel to track completion 185 | req := mcp.GetPromptRequest{} 186 | req.Params.Name = "initial-prompt" 187 | done := make(chan struct{}) 188 | 189 | // The outcome is recorded here and checked on the test goroutine below: 190 | // require calls t.FailNow, which must not be called from other goroutines. 191 | var ( 192 | result any 193 | reqErr any 194 | ) 195 | 196 | // Try to get the prompt - this would deadlock with a single mutex 197 | go func() { 198 | defer close(done) 199 | result, reqErr = srv.handleGetPrompt(ctx, "123", req) 200 | }() 201 | 202 | // Assert the operation completes without deadlock. done is closed only after 203 | // result and reqErr are written, so reading them afterwards is race-free. 204 | completed := assert.Eventually(t, func() bool { 205 | select { 206 | case <-done: 207 | return true 208 | default: 209 | return false 210 | } 211 | }, 1*time.Second, 10*time.Millisecond, "Deadlock detected: operation did not complete in time") 212 | if completed { 213 | require.Nil(t, reqErr, "Get prompt operation should not return an error") 214 | require.NotNil(t, result, "Get prompt result should not be nil") 215 | } 216 | } 217 | -------------------------------------------------------------------------------- /server/stdio.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "bufio" 5 | "context" 6 | "encoding/json" 7 | "fmt" 8 | "io" 9 | "log" 10 | "os" 11 | "os/signal" 12 | "sync/atomic" 13 | "syscall" 14 | 15 | "github.com/mark3labs/mcp-go/mcp" 16 | ) 17 | 18 | // StdioContextFunc is a function that takes an existing context and returns 19 | // a
potentially modified context. 20 | // This can be used to inject context values from environment variables, 21 | // for example. 22 | type StdioContextFunc func(ctx context.Context) context.Context 23 | 24 | // StdioServer wraps an MCPServer and handles stdio communication. 25 | // It provides a simple way to create command-line MCP servers that 26 | // communicate via standard input/output streams using newline-delimited JSON-RPC messages. 27 | type StdioServer struct { 28 | server *MCPServer // handles every decoded JSON-RPC message 29 | errLogger *log.Logger // receives read, write and message-handling errors 30 | contextFunc StdioContextFunc // optional; applied to the Listen context when non-nil 31 | } 32 | 33 | // StdioOption configures a StdioServer; ServeStdio applies the options it is given in order. 34 | type StdioOption func(*StdioServer) 35 | 36 | // WithErrorLogger replaces the logger that receives read, write and message-handling errors (by default, a logger writing to os.Stderr). 37 | func WithErrorLogger(logger *log.Logger) StdioOption { 38 | return func(s *StdioServer) { 39 | s.errLogger = logger 40 | } 41 | } 42 | 43 | // WithStdioContextFunc sets a function that will be called to customise the context 44 | // to the server. Note that the stdio server uses the same context for all requests, 45 | // so this function is called once per call to Listen, not once per request. 46 | func WithStdioContextFunc(fn StdioContextFunc) StdioOption { 47 | return func(s *StdioServer) { 48 | s.contextFunc = fn 49 | } 50 | } 51 | 52 | // stdioSession is a static client session, since stdio has only one client.
53 | type stdioSession struct { 54 | notifications chan mcp.JSONRPCNotification // drained by StdioServer.handleNotifications; stdioSessionInstance creates it with capacity 100 55 | initialized atomic.Bool // set by Initialize 56 | loggingLevel atomic.Value // holds an mcp.LoggingLevel 57 | clientInfo atomic.Value // stores session-specific client info 58 | } 59 | // SessionID returns the fixed ID "stdio"; a stdio transport only ever serves one client. 60 | func (s *stdioSession) SessionID() string { 61 | return "stdio" 62 | } 63 | // NotificationChannel returns the send side of the session's notification channel. 64 | func (s *stdioSession) NotificationChannel() chan<- mcp.JSONRPCNotification { 65 | return s.notifications 66 | } 67 | // Initialize marks the session as initialized and resets its log level to mcp.LoggingLevelError. 68 | func (s *stdioSession) Initialize() { 69 | // set default logging level 70 | s.loggingLevel.Store(mcp.LoggingLevelError) 71 | s.initialized.Store(true) 72 | } 73 | // Initialized reports whether Initialize has been called. 74 | func (s *stdioSession) Initialized() bool { 75 | return s.initialized.Load() 76 | } 77 | // GetClientInfo returns the stored client info, or a zero mcp.Implementation if none has been set. 78 | func (s *stdioSession) GetClientInfo() mcp.Implementation { 79 | if value := s.clientInfo.Load(); value != nil { 80 | if clientInfo, ok := value.(mcp.Implementation); ok { 81 | return clientInfo 82 | } 83 | } 84 | return mcp.Implementation{} 85 | } 86 | // SetClientInfo stores the client's implementation info for later GetClientInfo calls. 87 | func (s *stdioSession) SetClientInfo(clientInfo mcp.Implementation) { 88 | s.clientInfo.Store(clientInfo) 89 | } 90 | // SetLogLevel stores the session's logging level. 91 | func (s *stdioSession) SetLogLevel(level mcp.LoggingLevel) { 92 | s.loggingLevel.Store(level) 93 | } 94 | // GetLogLevel returns the stored logging level, or mcp.LoggingLevelError if none has been set. 95 | func (s *stdioSession) GetLogLevel() mcp.LoggingLevel { 96 | level := s.loggingLevel.Load() 97 | if level == nil { 98 | return mcp.LoggingLevelError 99 | } 100 | return level.(mcp.LoggingLevel) 101 | } 102 | // Compile-time checks that *stdioSession implements these session interfaces. 103 | var ( 104 | _ ClientSession = (*stdioSession)(nil) 105 | _ SessionWithLogging = (*stdioSession)(nil) 106 | _ SessionWithClientInfo = (*stdioSession)(nil) 107 | ) 108 | // stdioSessionInstance is the single package-level session that StdioServer.Listen registers. 109 | var stdioSessionInstance = stdioSession{ 110 | notifications: make(chan mcp.JSONRPCNotification, 100), 111 | } 112 | 113 | // NewStdioServer creates a new stdio server wrapper around an MCPServer. 114 | // Its default error logger writes to os.Stderr with log.LstdFlags; replace it with SetErrorLogger or WithErrorLogger.
115 | func NewStdioServer(server *MCPServer) *StdioServer { 116 | return &StdioServer{ 117 | server: server, 118 | errLogger: log.New( 119 | os.Stderr, 120 | "", 121 | log.LstdFlags, 122 | ), // Default to logging errors on stderr 123 | } 124 | } 125 | 126 | // SetErrorLogger configures where error messages from the StdioServer are logged. 127 | // The provided logger will receive all error messages generated during server operation. 128 | func (s *StdioServer) SetErrorLogger(logger *log.Logger) { 129 | s.errLogger = logger 130 | } 131 | 132 | // SetContextFunc sets a function that will be called to customise the context 133 | // to the server. Note that the stdio server uses the same context for all requests, 134 | // so this function is called once per call to Listen, not once per request. 135 | func (s *StdioServer) SetContextFunc(fn StdioContextFunc) { 136 | s.contextFunc = fn 137 | } 138 | 139 | // handleNotifications continuously processes notifications from the session's notification channel 140 | // and writes them to the provided output. It runs until the context is cancelled. 141 | // Any errors encountered while writing notifications are logged but do not stop the handler. 142 | func (s *StdioServer) handleNotifications(ctx context.Context, stdout io.Writer) { 143 | for { 144 | select { 145 | case notification := <-stdioSessionInstance.notifications: 146 | if err := s.writeResponse(notification, stdout); err != nil { 147 | s.errLogger.Printf("Error writing notification: %v", err) 148 | } 149 | case <-ctx.Done(): 150 | return 151 | } 152 | } 153 | } 154 | 155 | // processInputStream continuously reads and processes messages from the input stream. 156 | // It handles EOF gracefully as a normal termination condition.
157 | // The function returns when either: 158 | // - The context is cancelled (returns context.Err()) 159 | // - EOF is encountered (returns nil); a final line that lacks a trailing newline 160 | // is processed before returning 161 | // - An error occurs while reading or processing messages (returns the error) 162 | func (s *StdioServer) processInputStream(ctx context.Context, reader *bufio.Reader, stdout io.Writer) error { 163 | for { 164 | if err := ctx.Err(); err != nil { 165 | return err 166 | } 167 | 168 | line, err := s.readNextLine(ctx, reader) 169 | atEOF := err == io.EOF 170 | if err != nil && !atEOF { 171 | // Cancellation is the normal shutdown path rather than a read failure, 172 | // so it is returned without being logged. 173 | if ctxErr := ctx.Err(); ctxErr != nil { 174 | return ctxErr 175 | } 176 | s.errLogger.Printf("Error reading input: %v", err) 177 | return err 178 | } 179 | 180 | // ReadString returns the bytes read before EOF together with io.EOF, so a 181 | // final message without a trailing newline must not be dropped. 182 | if line != "" { 183 | if err := s.processMessage(ctx, line, stdout); err != nil { 184 | if err == io.EOF { 185 | return nil 186 | } 187 | s.errLogger.Printf("Error handling message: %v", err) 188 | return err 189 | } 190 | } 191 | 192 | if atEOF { 193 | return nil 194 | } 195 | } 196 | } 197 | 198 | // readNextLine reads a single line from the input reader in a context-aware manner. 199 | // The blocking ReadString runs in its own goroutine so the wait can be abandoned when 200 | // the context is cancelled; the buffered result channel lets that goroutine exit once 201 | // ReadString returns, even if nobody is waiting for the result any more. 202 | // Returns the read line and any error encountered. If the context is cancelled, 203 | // returns an empty string and the context's error. EOF is returned when the input 204 | // stream is closed. 205 | func (s *StdioServer) readNextLine(ctx context.Context, reader *bufio.Reader) (string, error) { 206 | type result struct { 207 | line string 208 | err error 209 | } 210 | 211 | resultCh := make(chan result, 1) 212 | 213 | go func() { 214 | line, err := reader.ReadString('\n') 215 | resultCh <- result{line: line, err: err} 216 | }() 217 | 218 | select { 219 | case <-ctx.Done(): 220 | // Report the cancellation explicitly: an empty line with a nil error would 221 | // be processed as a (malformed) message and answered with a parse error. 222 | return "", ctx.Err() 223 | case res := <-resultCh: 224 | return res.line, res.err 225 | } 226 | } 227 | 228 | // Listen starts listening for JSON-RPC messages on the provided input and writes responses to the provided output. 229 | // It runs until the context is cancelled, the input reaches EOF, or an error occurs. 230 | // Returns an error if there are issues with reading input or writing output.
231 | func (s *StdioServer) Listen( 232 | ctx context.Context, 233 | stdin io.Reader, 234 | stdout io.Writer, 235 | ) error { 236 | // Set a static client context since stdio only has one client 237 | if err := s.server.RegisterSession(ctx, &stdioSessionInstance); err != nil { 238 | return fmt.Errorf("register session: %w", err) 239 | } 240 | defer s.server.UnregisterSession(ctx, stdioSessionInstance.SessionID()) 241 | ctx = s.server.WithContext(ctx, &stdioSessionInstance) 242 | 243 | // Add in any custom context. 244 | if s.contextFunc != nil { 245 | ctx = s.contextFunc(ctx) 246 | } 247 | 248 | // Cancel on return as well (e.g. after EOF), so the notification goroutine 249 | // stops instead of outliving Listen and writing to stdout afterwards. 250 | ctx, cancel := context.WithCancel(ctx) 251 | defer cancel() 252 | 253 | // Responses and notifications are written from different goroutines, and an 254 | // io.Writer is not in general safe for concurrent use, so serialize writes. 255 | stdout = newStdioLockedWriter(stdout) 256 | 257 | reader := bufio.NewReader(stdin) 258 | 259 | // Start notification handler 260 | go s.handleNotifications(ctx, stdout) 261 | return s.processInputStream(ctx, reader, stdout) 262 | } 263 | 264 | // stdioLockedWriter serializes Write calls to an underlying io.Writer so that 265 | // concurrent writers cannot interleave or race. The one-slot channel acts as a mutex. 266 | type stdioLockedWriter struct { 267 | w io.Writer 268 | lock chan struct{} 269 | } 270 | 271 | // newStdioLockedWriter wraps w so that at most one Write runs at a time. 272 | func newStdioLockedWriter(w io.Writer) *stdioLockedWriter { 273 | return &stdioLockedWriter{w: w, lock: make(chan struct{}, 1)} 274 | } 275 | 276 | // Write forwards p to the wrapped writer while holding the lock. 277 | func (lw *stdioLockedWriter) Write(p []byte) (int, error) { 278 | lw.lock <- struct{}{} 279 | defer func() { <-lw.lock }() 280 | return lw.w.Write(p) 281 | } 282 | 283 | // processMessage handles a single JSON-RPC message and writes the response. 284 | // It parses the message, processes it through the wrapped MCPServer, and writes any response. 285 | // Returns an error if there are issues with message processing or response writing. 286 | func (s *StdioServer) processMessage( 287 | ctx context.Context, 288 | line string, 289 | writer io.Writer, 290 | ) error { 291 | // Parse the message as raw JSON 292 | var rawMessage json.RawMessage 293 | if err := json.Unmarshal([]byte(line), &rawMessage); err != nil { 294 | response := createErrorResponse(nil, mcp.PARSE_ERROR, "Parse error") 295 | return s.writeResponse(response, writer) 296 | } 297 | 298 | // Handle the message using the wrapped server 299 | response := s.server.HandleMessage(ctx, rawMessage) 300 | 301 | // Only write response if there is one (not for notifications) 302 | if response != nil { 303 | if err := s.writeResponse(response, writer); err != nil { 304 | return fmt.Errorf("failed to write response: %w", err) 305 | } 306 | } 307 | 308 | return nil 309 | } 310 | 311 | // writeResponse marshals and writes a JSON-RPC response message followed by a newline.
312 | // Returns an error if marshaling or writing fails. 313 | func (s *StdioServer) writeResponse( 314 | response mcp.JSONRPCMessage, 315 | writer io.Writer, 316 | ) error { 317 | responseBytes, err := json.Marshal(response) 318 | if err != nil { 319 | return err 320 | } 321 | 322 | // Write the message and its newline in a single Write call (fmt.Fprintf 323 | // buffers first), so a serializing writer keeps each message on its own line. 324 | if _, err := fmt.Fprintf(writer, "%s\n", responseBytes); err != nil { 325 | return err 326 | } 327 | 328 | return nil 329 | } 330 | 331 | // ServeStdio is a convenience function that creates and starts a StdioServer with os.Stdin and os.Stdout. 332 | // It sets up signal handling for graceful shutdown on SIGTERM and SIGINT. 333 | // Returns an error if the server encounters any issues during operation. 334 | func ServeStdio(server *MCPServer, opts ...StdioOption) error { 335 | s := NewStdioServer(server) 336 | 337 | for _, opt := range opts { 338 | opt(s) 339 | } 340 | 341 | ctx, cancel := context.WithCancel(context.Background()) 342 | defer cancel() 343 | 344 | // Set up signal handling, and restore default handling once serving ends 345 | sigChan := make(chan os.Signal, 1) 346 | signal.Notify(sigChan, syscall.SIGTERM, syscall.SIGINT) 347 | defer signal.Stop(sigChan) 348 | 349 | // Cancel on the first signal; also exit once ServeStdio returns (ctx is then 350 | // cancelled) so this goroutine does not leak. 351 | go func() { 352 | select { 353 | case <-sigChan: 354 | cancel() 355 | case <-ctx.Done(): 356 | } 357 | }() 358 | 359 | return s.Listen(ctx, os.Stdin, os.Stdout) 360 | } 361 | -------------------------------------------------------------------------------- /server/stdio_test.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "bufio" 5 | "context" 6 | "encoding/json" 7 | "io" 8 | "log" 9 | "os" 10 | "testing" 11 | 12 | "github.com/mark3labs/mcp-go/mcp" 13 | ) 14 | 15 | func TestStdioServer(t *testing.T) { 16 | t.Run("Can instantiate", func(t *testing.T) { 17 | mcpServer := NewMCPServer("test", "1.0.0") 18 | stdioServer := NewStdioServer(mcpServer) 19 | 20 | if stdioServer.server == nil { 21 | t.Error("MCPServer should not be nil") 22 | } 23 | if stdioServer.errLogger == nil { 24 | t.Error("errLogger should not be nil") 25 | }
26 | }) 27 | 28 | t.Run("Can send and receive messages", func(t *testing.T) { 29 | // Create pipes for stdin and stdout 30 | stdinReader, stdinWriter := io.Pipe() 31 | stdoutReader, stdoutWriter := io.Pipe() 32 | 33 | // Create server 34 | mcpServer := NewMCPServer("test", "1.0.0", 35 | WithResourceCapabilities(true, true), 36 | ) 37 | stdioServer := NewStdioServer(mcpServer) 38 | stdioServer.SetErrorLogger(log.New(io.Discard, "", 0)) 39 | 40 | // Create context with cancel 41 | ctx, cancel := context.WithCancel(context.Background()) 42 | defer cancel() 43 | 44 | // Create error channel to catch server errors 45 | serverErrCh := make(chan error, 1) 46 | 47 | // Start server in goroutine 48 | go func() { 49 | err := stdioServer.Listen(ctx, stdinReader, stdoutWriter) 50 | if err != nil && err != io.EOF && err != context.Canceled { 51 | serverErrCh <- err 52 | } 53 | close(serverErrCh) 54 | }() 55 | 56 | // Create test message 57 | initRequest := map[string]any{ 58 | "jsonrpc": "2.0", 59 | "id": 1, 60 | "method": "initialize", 61 | "params": map[string]any{ 62 | "protocolVersion": "2024-11-05", 63 | "clientInfo": map[string]any{ 64 | "name": "test-client", 65 | "version": "1.0.0", 66 | }, 67 | }, 68 | } 69 | 70 | // Send request 71 | requestBytes, err := json.Marshal(initRequest) 72 | if err != nil { 73 | t.Fatal(err) 74 | } 75 | _, err = stdinWriter.Write(append(requestBytes, '\n')) 76 | if err != nil { 77 | t.Fatal(err) 78 | } 79 | 80 | // Read response 81 | scanner := bufio.NewScanner(stdoutReader) 82 | if !scanner.Scan() { 83 | t.Fatal("failed to read response") 84 | } 85 | responseBytes := scanner.Bytes() 86 | 87 | var response map[string]any 88 | if err := json.Unmarshal(responseBytes, &response); err != nil { 89 | t.Fatalf("failed to unmarshal response: %v", err) 90 | } 91 | 92 | // Verify response structure 93 | if response["jsonrpc"] != "2.0" { 94 | t.Errorf("expected jsonrpc version 2.0, got %v", response["jsonrpc"]) 95 | } 96 | if id, ok := response["id"].(float64); !ok || id != 1 {
97 | t.Errorf("expected id 1, got %v", response["id"]) 98 | } 99 | if response["error"] != nil { 100 | t.Errorf("unexpected error in response: %v", response["error"]) 101 | } 102 | if response["result"] == nil { 103 | t.Error("expected result in response") 104 | } 105 | 106 | // Clean up 107 | cancel() 108 | stdinWriter.Close() 109 | stdoutWriter.Close() 110 | 111 | // Check for server errors 112 | if err := <-serverErrCh; err != nil { 113 | t.Errorf("unexpected server error: %v", err) 114 | } 115 | }) 116 | 117 | t.Run("Can use a custom context function", func(t *testing.T) { 118 | // Use a custom context key to store a test value. 119 | type testContextKey struct{} 120 | testValFromContext := func(ctx context.Context) string { 121 | val := ctx.Value(testContextKey{}) 122 | if val == nil { 123 | return "" 124 | } 125 | return val.(string) 126 | } 127 | // Create a context function that sets a test value from the environment. 128 | // In real life this could be used to send configuration in a similar way, 129 | // or from a config file. 130 | const testEnvVar = "TEST_ENV_VAR" 131 | setTestValFromEnv := func(ctx context.Context) context.Context { 132 | return context.WithValue(ctx, testContextKey{}, os.Getenv(testEnvVar)) 133 | } 134 | t.Setenv(testEnvVar, "test_value") 135 | 136 | // Create pipes for stdin and stdout 137 | stdinReader, stdinWriter := io.Pipe() 138 | stdoutReader, stdoutWriter := io.Pipe() 139 | 140 | // Create server 141 | mcpServer := NewMCPServer("test", "1.0.0") 142 | // Add a tool which uses the context function. 143 | mcpServer.AddTool(mcp.NewTool("test_tool"), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { 144 | // Note this is agnostic to the transport type i.e. doesn't know about request headers.
145 | testVal := testValFromContext(ctx) 146 | return mcp.NewToolResultText(testVal), nil 147 | }) 148 | stdioServer := NewStdioServer(mcpServer) 149 | stdioServer.SetErrorLogger(log.New(io.Discard, "", 0)) 150 | stdioServer.SetContextFunc(setTestValFromEnv) 151 | 152 | // Create context with cancel 153 | ctx, cancel := context.WithCancel(context.Background()) 154 | defer cancel() 155 | 156 | // Create error channel to catch server errors 157 | serverErrCh := make(chan error, 1) 158 | 159 | // Start server in goroutine 160 | go func() { 161 | err := stdioServer.Listen(ctx, stdinReader, stdoutWriter) 162 | if err != nil && err != io.EOF && err != context.Canceled { 163 | serverErrCh <- err 164 | } 165 | close(serverErrCh) 166 | }() 167 | 168 | // Create test message 169 | initRequest := map[string]any{ 170 | "jsonrpc": "2.0", 171 | "id": 1, 172 | "method": "initialize", 173 | "params": map[string]any{ 174 | "protocolVersion": "2024-11-05", 175 | "clientInfo": map[string]any{ 176 | "name": "test-client", 177 | "version": "1.0.0", 178 | }, 179 | }, 180 | } 181 | 182 | // Send request 183 | requestBytes, err := json.Marshal(initRequest) 184 | if err != nil { 185 | t.Fatal(err) 186 | } 187 | _, err = stdinWriter.Write(append(requestBytes, '\n')) 188 | if err != nil { 189 | t.Fatal(err) 190 | } 191 | 192 | // Read response 193 | scanner := bufio.NewScanner(stdoutReader) 194 | if !scanner.Scan() { 195 | t.Fatal("failed to read response") 196 | } 197 | responseBytes := scanner.Bytes() 198 | 199 | var response map[string]any 200 | if err := json.Unmarshal(responseBytes, &response); err != nil { 201 | t.Fatalf("failed to unmarshal response: %v", err) 202 | } 203 | 204 | // Verify response structure 205 | if response["jsonrpc"] != "2.0" { 206 | t.Errorf("expected jsonrpc version 2.0, got %v", response["jsonrpc"]) 207 | } 208 | if id, ok := response["id"].(float64); !ok || id != 1 { 209 | t.Errorf("expected id 1, got %v", response["id"]) 210 | } 211 | if response["error"] != nil { 212 |
t.Errorf("unexpected error in response: %v", response["error"]) 213 | } 214 | if response["result"] == nil { 215 | t.Error("expected result in response") 216 | } 217 | 218 | // Call the tool. 219 | toolRequest := map[string]any{ 220 | "jsonrpc": "2.0", 221 | "id": 2, 222 | "method": "tools/call", 223 | "params": map[string]any{ 224 | "name": "test_tool", 225 | }, 226 | } 227 | requestBytes, err = json.Marshal(toolRequest) 228 | if err != nil { 229 | t.Fatalf("Failed to marshal tool request: %v", err) 230 | } 231 | 232 | _, err = stdinWriter.Write(append(requestBytes, '\n')) 233 | if err != nil { 234 | t.Fatal(err) 235 | } 236 | 237 | if !scanner.Scan() { 238 | t.Fatal("failed to read response") 239 | } 240 | responseBytes = scanner.Bytes() 241 | 242 | response = map[string]any{} 243 | if err := json.Unmarshal(responseBytes, &response); err != nil { 244 | t.Fatalf("failed to unmarshal response: %v", err) 245 | } 246 | 247 | if response["jsonrpc"] != "2.0" { 248 | t.Errorf("Expected jsonrpc 2.0, got %v", response["jsonrpc"]) 249 | } 250 | if id, ok := response["id"].(float64); !ok || id != 2 { 251 | t.Errorf("Expected id 2, got %v", response["id"]) 252 | } 253 | if stdioTestFirstText(response) != "test_value" { 254 | t.Errorf("Expected result 'test_value', got %v", response["result"]) 255 | } 256 | if response["error"] != nil { 257 | t.Errorf("Expected no error, got %v", response["error"]) 258 | } 259 | 260 | // Clean up 261 | cancel() 262 | stdinWriter.Close() 263 | stdoutWriter.Close() 264 | 265 | // Check for server errors 266 | if err := <-serverErrCh; err != nil { 267 | t.Errorf("unexpected server error: %v", err) 268 | } 269 | }) 270 | } 271 | 272 | // stdioTestFirstText returns the "text" field of the first content item in a decoded 273 | // tools/call response, or nil when the response does not have that shape, so that a 274 | // malformed response fails the assertion instead of panicking the test. 275 | func stdioTestFirstText(response map[string]any) any { 276 | result, _ := response["result"].(map[string]any) 277 | content, _ := result["content"].([]any) 278 | if len(content) == 0 { 279 | return nil 280 | } 281 | first, _ := content[0].(map[string]any) 282 | return first["text"] 283 | } 284 | -------------------------------------------------------------------------------- /testdata/mockstdio_server.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bufio" 5 | "encoding/json" 6 | "fmt" 7 | "log/slog" 8 |
"os" 9 | 10 | "github.com/mark3labs/mcp-go/mcp" 11 | ) 12 | 13 | type JSONRPCRequest struct { 14 | JSONRPC string `json:"jsonrpc"` 15 | ID *mcp.RequestId `json:"id,omitempty"` 16 | Method string `json:"method"` 17 | Params json.RawMessage `json:"params"` 18 | } 19 | 20 | type JSONRPCResponse struct { 21 | JSONRPC string `json:"jsonrpc"` 22 | ID *mcp.RequestId `json:"id,omitempty"` 23 | Result any `json:"result,omitempty"` 24 | Error *struct { 25 | Code int `json:"code"` 26 | Message string `json:"message"` 27 | } `json:"error,omitempty"` 28 | } 29 | 30 | func main() { 31 | logger := slog.New(slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{})) 32 | logger.Info("launch successful") 33 | scanner := bufio.NewScanner(os.Stdin) 34 | for scanner.Scan() { 35 | var request JSONRPCRequest 36 | if err := json.Unmarshal(scanner.Bytes(), &request); err != nil { 37 | continue 38 | } 39 | 40 | response := handleRequest(request) 41 | responseBytes, _ := json.Marshal(response) 42 | fmt.Fprintf(os.Stdout, "%s\n", responseBytes) 43 | } 44 | } 45 | 46 | func handleRequest(request JSONRPCRequest) JSONRPCResponse { 47 | response := JSONRPCResponse{ 48 | JSONRPC: "2.0", 49 | ID: request.ID, 50 | } 51 | 52 | switch request.Method { 53 | case "initialize": 54 | response.Result = map[string]any{ 55 | "protocolVersion": "1.0", 56 | "serverInfo": map[string]any{ 57 | "name": "mock-server", 58 | "version": "1.0.0", 59 | }, 60 | "capabilities": map[string]any{ 61 | "prompts": map[string]any{ 62 | "listChanged": true, 63 | }, 64 | "resources": map[string]any{ 65 | "listChanged": true, 66 | "subscribe": true, 67 | }, 68 | "tools": map[string]any{ 69 | "listChanged": true, 70 | }, 71 | }, 72 | } 73 | case "ping": 74 | response.Result = struct{}{} 75 | case "resources/list": 76 | response.Result = map[string]any{ 77 | "resources": []map[string]any{ 78 | { 79 | "name": "test-resource", 80 | "uri": "test://resource", 81 | }, 82 | }, 83 | } 84 | case "resources/read": 85 | response.Result = 
map[string]any{ 86 | "contents": []map[string]any{ 87 | { 88 | "text": "test content", 89 | "uri": "test://resource", 90 | }, 91 | }, 92 | } 93 | case "resources/subscribe", "resources/unsubscribe": 94 | response.Result = struct{}{} 95 | case "prompts/list": 96 | response.Result = map[string]any{ 97 | "prompts": []map[string]any{ 98 | { 99 | "name": "test-prompt", 100 | }, 101 | }, 102 | } 103 | case "prompts/get": 104 | response.Result = map[string]any{ 105 | "messages": []map[string]any{ 106 | { 107 | "role": "assistant", 108 | "content": map[string]any{ 109 | "type": "text", 110 | "text": "test message", 111 | }, 112 | }, 113 | }, 114 | } 115 | case "tools/list": 116 | response.Result = map[string]any{ 117 | "tools": []map[string]any{ 118 | { 119 | "name": "test-tool", 120 | "inputSchema": map[string]any{ 121 | "type": "object", 122 | }, 123 | }, 124 | }, 125 | } 126 | case "tools/call": 127 | response.Result = map[string]any{ 128 | "content": []map[string]any{ 129 | { 130 | "type": "text", 131 | "text": "tool result", 132 | }, 133 | }, 134 | } 135 | case "logging/setLevel": 136 | response.Result = struct{}{} 137 | case "completion/complete": 138 | response.Result = map[string]any{ 139 | "completion": map[string]any{ 140 | "values": []string{"test completion"}, 141 | }, 142 | } 143 | 144 | // Debug methods for testing transport. 
145 | case "debug/echo": 146 | response.Result = request 147 | case "debug/echo_notification": 148 | response.Result = request 149 | 150 | // send notification to client 151 | responseBytes, _ := json.Marshal(map[string]any{ 152 | "jsonrpc": "2.0", 153 | "method": "debug/test", 154 | "params": request, 155 | }) 156 | fmt.Fprintf(os.Stdout, "%s\n", responseBytes) 157 | 158 | case "debug/echo_error_string": 159 | all, _ := json.Marshal(request) 160 | response.Error = &struct { 161 | Code int `json:"code"` 162 | Message string `json:"message"` 163 | }{ 164 | Code: -32601, 165 | Message: string(all), 166 | } 167 | default: 168 | response.Error = &struct { 169 | Code int `json:"code"` 170 | Message string `json:"message"` 171 | }{ 172 | Code: -32601, 173 | Message: "Method not found", 174 | } 175 | } 176 | 177 | return response 178 | } 179 | -------------------------------------------------------------------------------- /util/logger.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "log" 5 | ) 6 | 7 | // Logger defines a minimal logging interface 8 | type Logger interface { 9 | Infof(format string, v ...any) 10 | Errorf(format string, v ...any) 11 | } 12 | 13 | // --- Standard Library Logger Wrapper --- 14 | 15 | // DefaultStdLogger implements Logger using the standard library's log.Logger. 16 | func DefaultLogger() Logger { 17 | return &stdLogger{ 18 | logger: log.Default(), 19 | } 20 | } 21 | 22 | // stdLogger wraps the standard library's log.Logger. 23 | type stdLogger struct { 24 | logger *log.Logger 25 | } 26 | 27 | func (l *stdLogger) Infof(format string, v ...any) { 28 | l.logger.Printf("INFO: "+format, v...) 29 | } 30 | 31 | func (l *stdLogger) Errorf(format string, v ...any) { 32 | l.logger.Printf("ERROR: "+format, v...) 
33 | } 34 | -------------------------------------------------------------------------------- /www/.gitignore: -------------------------------------------------------------------------------- 1 | node_modules 2 | docs/dist 3 | -------------------------------------------------------------------------------- /www/README.md: -------------------------------------------------------------------------------- 1 | This is a [Vocs](https://vocs.dev) project bootstrapped with the Vocs CLI. 2 | -------------------------------------------------------------------------------- /www/docs/pages/example.mdx: -------------------------------------------------------------------------------- 1 | # Example 2 | 3 | This is an example page. -------------------------------------------------------------------------------- /www/docs/pages/getting-started.mdx: -------------------------------------------------------------------------------- 1 | # Getting Started 2 | 3 | MCP-Go makes it easy to build Model Context Protocol (MCP) servers in Go. This guide will help you create your first MCP server in just a few minutes. 
4 | 5 | ## Installation 6 | 7 | Add MCP-Go to your Go project: 8 | 9 | ```bash 10 | go get github.com/mark3labs/mcp-go 11 | ``` 12 | 13 | ## Your First MCP Server 14 | 15 | Let's create a simple MCP server with a "hello world" tool: 16 | 17 | ```go 18 | package main 19 | 20 | import ( 21 | "context" 22 | "fmt" 23 | 24 | "github.com/mark3labs/mcp-go/mcp" 25 | "github.com/mark3labs/mcp-go/server" 26 | ) 27 | 28 | func main() { 29 | // Create a new MCP server 30 | s := server.NewMCPServer( 31 | "Demo 🚀", 32 | "1.0.0", 33 | server.WithToolCapabilities(false), 34 | ) 35 | 36 | // Add tool 37 | tool := mcp.NewTool("hello_world", 38 | mcp.WithDescription("Say hello to someone"), 39 | mcp.WithString("name", 40 | mcp.Required(), 41 | mcp.Description("Name of the person to greet"), 42 | ), 43 | ) 44 | 45 | // Add tool handler 46 | s.AddTool(tool, helloHandler) 47 | 48 | // Start the stdio server 49 | if err := server.ServeStdio(s); err != nil { 50 | fmt.Printf("Server error: %v\n", err) 51 | } 52 | } 53 | 54 | func helloHandler(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { 55 | name, err := request.RequireString("name") 56 | if err != nil { 57 | return mcp.NewToolResultError(err.Error()), nil 58 | } 59 | 60 | return mcp.NewToolResultText(fmt.Sprintf("Hello, %s!", name)), nil 61 | } 62 | ``` 63 | 64 | ## Running Your Server 65 | 66 | 1. Save the code above to a file (e.g., `main.go`) 67 | 2. Run it with: 68 | ```bash 69 | go run main.go 70 | ``` 71 | 72 | Your MCP server is now running and ready to accept connections via stdio! 73 | 74 | ## What's Next? 75 | 76 | Now that you have a basic server running, you can: 77 | 78 | - **Add more tools** - Create tools for calculations, file operations, API calls, etc. 
79 | - **Add resources** - Expose data sources like files, databases, or APIs 80 | - **Add prompts** - Create reusable prompt templates for better LLM interactions 81 | - **Explore examples** - Check out the `examples/` directory for more complex use cases 82 | 83 | ## Key Concepts 84 | 85 | ### Tools 86 | Tools let LLMs take actions through your server. They're like functions that the LLM can call: 87 | 88 | ```go 89 | calculatorTool := mcp.NewTool("calculate", 90 | mcp.WithDescription("Perform basic arithmetic operations"), 91 | mcp.WithString("operation", 92 | mcp.Required(), 93 | mcp.Enum("add", "subtract", "multiply", "divide"), 94 | ), 95 | mcp.WithNumber("x", mcp.Required()), 96 | mcp.WithNumber("y", mcp.Required()), 97 | ) 98 | ``` 99 | 100 | ### Resources 101 | Resources expose data to LLMs. They can be static files or dynamic data: 102 | 103 | ```go 104 | resource := mcp.NewResource( 105 | "docs://readme", 106 | "Project README", 107 | mcp.WithResourceDescription("The project's README file"), 108 | mcp.WithMIMEType("text/markdown"), 109 | ) 110 | ``` 111 | 112 | ### Server Options 113 | Customize your server with various options: 114 | 115 | ```go 116 | s := server.NewMCPServer( 117 | "My Server", 118 | "1.0.0", 119 | server.WithToolCapabilities(true), 120 | server.WithRecovery(), 121 | server.WithHooks(myHooks), 122 | ) 123 | ``` 124 | 125 | ## Transport Options 126 | 127 | MCP-Go supports multiple transport methods: 128 | 129 | - **Stdio** (most common): `server.ServeStdio(s)` 130 | - **Streamable HTTP**: `server.NewStreamableHTTPServer(s).Start(":8080")` 131 | - **Server-Sent Events**: `server.NewSSEServer(s).Start(":8080")` 132 | 133 | ## Need Help?
134 | 135 | - Check out the [examples](https://github.com/mark3labs/mcp-go/tree/main/examples) for more complex use cases 136 | - Join the discussion on [Discord](https://discord.gg/RqSS2NQVsY) 137 | - Read the full documentation in the [README](https://github.com/mark3labs/mcp-go/blob/main/README.md) -------------------------------------------------------------------------------- /www/docs/pages/index.mdx: -------------------------------------------------------------------------------- 1 | --- 2 | layout: landing 3 | --- 4 | 5 | import { HomePage } from 'vocs/components' 6 | 7 | 8 | 9 | MCP-Go 10 | 11 | A Go implementation of the Model Context Protocol (MCP), enabling seamless integration between LLM applications and external data sources and tools. Build powerful MCP servers with minimal boilerplate and focus on creating great tools. 12 | 13 | 14 | Get started 15 | GitHub 16 | 17 | -------------------------------------------------------------------------------- /www/docs/public/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mark3labs/mcp-go/2cbaebf51e3629d9409d14296dc0f3410b2013e2/www/docs/public/logo.png -------------------------------------------------------------------------------- /www/docs/styles.css: -------------------------------------------------------------------------------- 1 | .vocs_HomePage_logo { 2 | height: auto; 3 | max-width: 100%; 4 | object-fit: contain; 5 | } 6 | -------------------------------------------------------------------------------- /www/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "mcp-go", 3 | "private": true, 4 | "version": "0.0.0", 5 | "type": "module", 6 | "scripts": { 7 | "dev": "vocs dev", 8 | "build": "vocs build", 9 | "preview": "vocs preview" 10 | }, 11 | "dependencies": { 12 | "react": "latest", 13 | "react-dom": "latest", 14 | "vocs": "latest" 15 | }, 16 | "devDependencies": { 
17 | "@types/react": "latest", 18 | "typescript": "latest" 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /www/tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "target": "ES2020", 4 | "useDefineForClassFields": true, 5 | "lib": ["ES2020", "DOM", "DOM.Iterable"], 6 | "module": "ESNext", 7 | "skipLibCheck": true, 8 | 9 | /* Bundler mode */ 10 | "moduleResolution": "bundler", 11 | "allowImportingTsExtensions": true, 12 | "resolveJsonModule": true, 13 | "isolatedModules": true, 14 | "noEmit": true, 15 | "jsx": "react-jsx", 16 | 17 | /* Linting */ 18 | "strict": true, 19 | "noUnusedLocals": true, 20 | "noUnusedParameters": true, 21 | "noFallthroughCasesInSwitch": true 22 | }, 23 | "include": ["**/*.ts", "**/*.tsx"] 24 | } 25 | -------------------------------------------------------------------------------- /www/vocs.config.ts: -------------------------------------------------------------------------------- 1 | import { defineConfig } from 'vocs' 2 | // Vocs site configuration for the MCP-Go docs; each sidebar link maps to a page in docs/pages (e.g. /getting-started -> getting-started.mdx). 3 | export default defineConfig({ 4 | title: 'MCP-Go', 5 | baseUrl: 'https://mcp-go.dev', 6 | basePath: '/', 7 | logoUrl: '/logo.png', 8 | description: 'A Go implementation of the Model Context Protocol (MCP), enabling seamless integration between LLM applications and external data sources and tools.', 9 | sidebar: [ 10 | { 11 | text: 'Getting Started', 12 | link: '/getting-started', 13 | }, 14 | { 15 | text: 'Example', // NOTE(review): docs/pages/example.mdx still contains only the scaffold placeholder text 16 | link: '/example', 17 | }, 18 | ], 19 | socials: [ 20 | { 21 | icon: 'github', 22 | link: 'https://github.com/mark3labs/mcp-go', 23 | }, 24 | ], 25 | }) 26 | --------------------------------------------------------------------------------