├── .chglog ├── CHANGELOG.tpl.md └── config.yml ├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md ├── PULL_REQUEST_TEMPLATE.md └── workflows │ ├── codeql.yml │ ├── lint.yml │ └── test.yml ├── .gitignore ├── .golangci.yml ├── .pre-commit-config.yaml ├── CHANGELOG.md ├── CONTRIBUTING.md ├── LICENSE ├── Makefile ├── README.md ├── README_CN.md ├── README_TW.md ├── README_VI.md ├── client ├── call.go ├── client.go ├── client_test.go ├── handle.go ├── notify_handler.go ├── receive.go └── send.go ├── codecov.yml ├── docs ├── design.md ├── design_cn.md └── images │ ├── img.png │ ├── img_1.png │ ├── img_2.png │ └── wechat_qrcode.png ├── examples ├── README.md ├── auth_tool │ └── main.go ├── current_time_server │ └── main.go ├── everything │ └── main.go ├── filesystem_client │ └── main.go └── http_handler │ └── main.go ├── go.mod ├── go.sum ├── hack ├── .lintcheck_failures ├── .test_ignored_files ├── resolve-modules.sh ├── tools.sh └── util.sh ├── pkg ├── atomic.go ├── context.go ├── errors.go ├── helper.go ├── json.go ├── limiter.go ├── log.go └── sync_map.go ├── protocol ├── cancellation.go ├── completion.go ├── initialize.go ├── jsonrpc.go ├── logging.go ├── pagination.go ├── pagination_test.go ├── ping.go ├── progress.go ├── prompts.go ├── resources.go ├── roots.go ├── sampling.go ├── schema_generate.go ├── schema_generate_test.go ├── schema_validate.go ├── schema_validate_test.go ├── tools.go └── types.go ├── server ├── call.go ├── context.go ├── handle.go ├── receive.go ├── send.go ├── server.go ├── server_test.go └── session │ ├── manager.go │ └── state.go ├── testdata └── mock_block_server.go ├── tests ├── .DS_Store ├── sse_test.go ├── stdio_test.go ├── streamable_http_test.go └── utils.go └── transport ├── mock_client.go ├── mock_server.go ├── mock_test.go ├── sse_client.go ├── sse_server.go ├── sse_test.go ├── stdio_client.go ├── stdio_server.go ├── stdio_test.go ├── streamable_http_client.go ├── streamable_http_server.go ├── 
streamable_http_test.go ├── transport.go └── transport_test.go /.chglog/CHANGELOG.tpl.md: -------------------------------------------------------------------------------- 1 | {{ range .Versions }} 2 | 3 | ## {{ if .Tag.Previous }}[{{ .Tag.Name }}]({{ $.Info.RepositoryURL }}/compare/{{ .Tag.Previous.Name }}...{{ .Tag.Name }}){{ else }}{{ .Tag.Name }}{{ end }} ({{ datetime "2006-01-02" .Tag.Date }}) 4 | 5 | {{ range .CommitGroups -}} 6 | ### {{ .Title }} 7 | 8 | {{ range .Commits -}} 9 | * {{ if .Scope }}**{{ .Scope }}:** {{ end }}{{ .Subject }} ({{ .Hash.Short }}) 10 | {{ end }} 11 | {{ end -}} 12 | 13 | {{- if .RevertCommits -}} 14 | ### Reverts 15 | 16 | {{ range .RevertCommits -}} 17 | * {{ .Revert.Header }} 18 | {{ end }} 19 | {{ end -}} 20 | 21 | {{- if .NoteGroups -}} 22 | {{ range .NoteGroups -}} 23 | ### {{ .Title }} 24 | 25 | {{ range .Notes }} 26 | {{ .Body }} 27 | {{ end }} 28 | {{ end -}} 29 | {{ end -}} 30 | {{ end -}} -------------------------------------------------------------------------------- /.chglog/config.yml: -------------------------------------------------------------------------------- 1 | style: gitlab 2 | template: CHANGELOG.tpl.md 3 | info: 4 | title: CHANGELOG 5 | repository_url: https://github.com/ThinkInAIXYZ/go-mcp 6 | options: 7 | commits: 8 | # filters: 9 | # Type: 10 | # - feat 11 | # - fix 12 | # - perf 13 | # - refactor 14 | commit_groups: 15 | # title_maps: 16 | # feat: Features 17 | # fix: Bug Fixes 18 | # perf: Performance Improvements 19 | # refactor: Code Refactoring 20 | header: 21 | pattern: "^(\\w*)(?:\\(([\\S\\s]*)\\))?\\:(.*)$" 22 | pattern_maps: 23 | - Type 24 | - Scope 25 | - Subject 26 | notes: 27 | keywords: 28 | - BREAKING CHANGE -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: 
'[BUG] ' 5 | labels: 'bug' 6 | assignees: '' 7 | --- 8 | 9 | ## Describe the bug 10 | A clear and concise description of what the bug is. 11 | 12 | ## To Reproduce 13 | Steps to reproduce the behavior: 14 | 1. Initialize with '...' 15 | 2. Call function '....' 16 | 3. Pass parameter '....' 17 | 4. See error 18 | 19 | ## Expected behavior 20 | A clear and concise description of what you expected to happen. 21 | 22 | ## Environment 23 | - Go version: [e.g. 1.21.0] 24 | - go-mcp version: [e.g. v1.0.0] 25 | - OS: [e.g. Ubuntu 22.04, macOS 14.0] 26 | 27 | ## Additional context 28 | Add any other context about the problem here, such as logs or error messages. 29 | 30 | ## Possible Solution 31 | If you have ideas on how to fix this issue, please share them here. -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '[FEATURE] ' 5 | labels: 'enhancement' 6 | assignees: '' 7 | --- 8 | 9 | ## Is your feature request related to a problem? Please describe. 10 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 11 | 12 | ## Describe the solution you'd like 13 | A clear and concise description of what you want to happen. 14 | 15 | ## Describe alternatives you've considered 16 | A clear and concise description of any alternative solutions or features you've considered. 17 | 18 | ## How would this feature benefit the Go-MCP project and its users? 19 | Explain the value this feature would bring to the project and the community. 20 | 21 | ## Additional context 22 | Add any other context or screenshots about the feature request here.
-------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ## Description 2 | Please provide a brief description of the changes in this pull request. 3 | 4 | ## Related Issue 5 | Fixes #(issue) 6 | 7 | ## Type of change 8 | Please delete options that are not relevant. 9 | 10 | - [ ] Bug fix (non-breaking change which fixes an issue) 11 | - [ ] New feature (non-breaking change which adds functionality) 12 | - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) 13 | - [ ] Documentation update 14 | - [ ] Performance improvement 15 | - [ ] Code refactoring (no functional changes) 16 | 17 | ## How Has This Been Tested? 18 | Please describe the tests that you ran to verify your changes. Provide instructions so we can reproduce. 19 | 20 | ## Checklist: 21 | - [ ] My code follows the style guidelines of this project 22 | - [ ] I have performed a self-review of my own code 23 | - [ ] I have commented my code, particularly in hard-to-understand areas 24 | - [ ] I have made corresponding changes to the documentation 25 | - [ ] My changes generate no new warnings 26 | - [ ] I have added tests that prove my fix is effective or that my feature works 27 | - [ ] New and existing unit tests pass locally with my changes -------------------------------------------------------------------------------- /.github/workflows/codeql.yml: -------------------------------------------------------------------------------- 1 | # For most projects, this workflow file will not need changing; you simply need 2 | # to commit it to your repository. 3 | # 4 | # You may wish to alter this file to override the set of languages analyzed, 5 | # or to provide custom queries or build logic. 6 | # 7 | # ******** NOTE ******** 8 | # We have attempted to detect the languages in your repository. 
Please check 9 | # the `language` matrix defined below to confirm you have the correct set of 10 | # supported CodeQL languages. 11 | # 12 | name: "CodeQL Advanced" 13 | permissions: 14 | contents: read 15 | pull-requests: write 16 | 17 | on: 18 | push: 19 | branches: [ "main" ] 20 | pull_request: 21 | branches: [ "main" ] 22 | schedule: 23 | - cron: '29 5 * * 3' 24 | 25 | jobs: 26 | analyze: 27 | name: Analyze (${{ matrix.language }}) 28 | # Runner size impacts CodeQL analysis time. To learn more, please see: 29 | # - https://gh.io/recommended-hardware-resources-for-running-codeql 30 | # - https://gh.io/supported-runners-and-hardware-resources 31 | # - https://gh.io/using-larger-runners (GitHub.com only) 32 | # Consider using larger runners or machines with greater resources for possible analysis time improvements. 33 | runs-on: ${{ (matrix.language == 'swift' && 'macos-latest') || 'ubuntu-latest' }} 34 | permissions: 35 | # required for all workflows 36 | security-events: write 37 | 38 | # required to fetch internal or private CodeQL packs 39 | packages: read 40 | 41 | # only required for workflows in private repositories 42 | actions: read 43 | contents: read 44 | 45 | strategy: 46 | fail-fast: false 47 | matrix: 48 | include: 49 | - language: go 50 | build-mode: autobuild 51 | # CodeQL supports the following values keywords for 'language': 'c-cpp', 'csharp', 'go', 'java-kotlin', 'javascript-typescript', 'python', 'ruby', 'swift' 52 | # Use `c-cpp` to analyze code written in C, C++ or both 53 | # Use 'java-kotlin' to analyze code written in Java, Kotlin or both 54 | # Use 'javascript-typescript' to analyze code written in JavaScript, TypeScript or both 55 | # To learn more about changing the languages that are analyzed or customizing the build mode for your analysis, 56 | # see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/customizing-your-advanced-setup-for-code-scanning. 
57 | # If you are analyzing a compiled language, you can modify the 'build-mode' for that language to customize how 58 | # your codebase is analyzed, see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/codeql-code-scanning-for-compiled-languages 59 | steps: 60 | - name: Checkout repository 61 | uses: actions/checkout@v4 62 | 63 | # Add any setup steps before running the `github/codeql-action/init` action. 64 | # This includes steps like installing compilers or runtimes (`actions/setup-node` 65 | # or others). This is typically only required for manual builds. 66 | # - name: Setup runtime (example) 67 | # uses: actions/setup-example@v1 68 | 69 | # Initializes the CodeQL tools for scanning. 70 | - name: Initialize CodeQL 71 | uses: github/codeql-action/init@v3 72 | with: 73 | languages: ${{ matrix.language }} 74 | build-mode: ${{ matrix.build-mode }} 75 | # If you wish to specify custom queries, you can do so here or in a config file. 76 | # By default, queries listed here will override any specified in a config file. 77 | # Prefix the list here with "+" to use these queries and those in the config file. 78 | 79 | # For more details on CodeQL's query packs, refer to: https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs 80 | # queries: security-extended,security-and-quality 81 | 82 | # If the analyze step fails for one of the languages you are analyzing with 83 | # "We were unable to automatically build your code", modify the matrix above 84 | # to set the build mode to "manual" for that language. Then modify this step 85 | # to build your code. 86 | # ℹ️ Command-line programs to run using the OS shell. 
87 | # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun 88 | - if: matrix.build-mode == 'manual' 89 | shell: bash 90 | run: | 91 | echo 'If you are using a "manual" build mode for one or more of the' \ 92 | 'languages you are analyzing, replace this with the commands to build' \ 93 | 'your code, for example:' 94 | echo ' make bootstrap' 95 | echo ' make release' 96 | exit 1 97 | 98 | - name: Perform CodeQL Analysis 99 | uses: github/codeql-action/analyze@v3 100 | with: 101 | category: "/language:${{matrix.language}}" 102 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | permissions: 3 | contents: read 4 | pull-requests: write 5 | on: 6 | push: 7 | pull_request: 8 | branches: 9 | - main 10 | workflow_dispatch: 11 | 12 | jobs: 13 | resolve-modules: 14 | name: resolve module 15 | runs-on: ubuntu-latest 16 | outputs: 17 | matrix: ${{ steps.set-matrix.outputs.matrix }} 18 | steps: 19 | - name: Checkout Repo 20 | uses: actions/checkout@v4 21 | 22 | - id: set-matrix 23 | run: ./hack/resolve-modules.sh 24 | 25 | lint: 26 | name: lint module 27 | runs-on: ubuntu-latest 28 | needs: resolve-modules 29 | strategy: 30 | matrix: ${{ fromJson(needs.resolve-modules.outputs.matrix) }} 31 | steps: 32 | - uses: actions/checkout@v4 33 | - name: Lint 34 | uses: golangci/golangci-lint-action@v7 35 | with: 36 | version: v2.1 37 | working-directory: ${{ matrix.workdir }} 38 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Test 2 | on: 3 | push: 4 | branches: 5 | - main 6 | - develop 7 | paths: 8 | - "**/*.go" 9 | - "go.mod" 10 | - "go.sum" 11 | - ".github/workflows/test.yml" 12 | pull_request: 13 | types: [ opened, 
synchronize, reopened ] 14 | branches: 15 | - main 16 | - develop 17 | paths: 18 | - "**/*.go" 19 | - "go.mod" 20 | - "go.sum" 21 | - ".github/workflows/test.yml" 22 | permissions: 23 | contents: read 24 | jobs: 25 | test: 26 | strategy: 27 | fail-fast: false 28 | matrix: 29 | go-version: [1.18,1.22] 30 | runs-on: ubuntu-latest 31 | steps: 32 | - name: Set up Go 33 | uses: actions/setup-go@v5 34 | with: 35 | go-version: ${{ matrix.go-version }} 36 | - name: Checkout codebase 37 | uses: actions/checkout@v2 38 | with: 39 | fetch-depth: 0 40 | - name: Build 41 | run: go build -v ./... 42 | - name: Test 43 | run: make test-coverage 44 | 45 | - name: Upload coverage reports to Codecov 46 | uses: codecov/codecov-action@v5 47 | with: 48 | token: ${{ secrets.CODECOV_TOKEN }} 49 | fail_ci_if_error: true 50 | verbose: true 51 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # If you prefer the allow list template instead of the deny list, see community template: 2 | # https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore 3 | # 4 | # Binaries for programs and plugins 5 | *.exe 6 | *.exe~ 7 | *.dll 8 | *.so 9 | *.dylib 10 | 11 | # Test binary, built with `go test -c` 12 | *.test 13 | 14 | # Output of the go coverage tool, specifically when used with LiteIDE 15 | *.out 16 | 17 | # Dependency directories (remove the comment below to include it) 18 | vendor/ 19 | 20 | # Go workspace file 21 | go.work 22 | go.work.sum 23 | 24 | # env file 25 | .env 26 | .idea 27 | 28 | bin/ -------------------------------------------------------------------------------- /.golangci.yml: -------------------------------------------------------------------------------- 1 | version: "2" 2 | run: 3 | modules-download-mode: readonly 4 | linters: 5 | default: none 6 | enable: 7 | - bodyclose 8 | - copyloopvar 9 | - dogsled 10 | - durationcheck 
11 | - errcheck 12 | - goconst 13 | - gocyclo 14 | - govet 15 | - ineffassign 16 | - lll 17 | - misspell 18 | - mnd 19 | - prealloc 20 | - revive 21 | - staticcheck 22 | - unconvert 23 | - unused 24 | - wastedassign 25 | - whitespace 26 | settings: 27 | gocyclo: 28 | min-complexity: 50 29 | govet: 30 | enable: 31 | - shadow 32 | lll: 33 | line-length: 160 34 | misspell: 35 | locale: US 36 | mnd: 37 | checks: 38 | - case 39 | - condition 40 | - return 41 | whitespace: 42 | multi-func: true 43 | exclusions: 44 | generated: lax 45 | presets: 46 | - comments 47 | - common-false-positives 48 | - legacy 49 | - std-error-handling 50 | rules: 51 | - linters: 52 | - goconst 53 | path: (.+)_test\.go 54 | paths: 55 | - third_party$ 56 | - builtin$ 57 | - examples$ 58 | formatters: 59 | enable: 60 | - gofmt 61 | - gofumpt 62 | - goimports 63 | settings: 64 | goimports: 65 | local-prefixes: 66 | - github.com/ThinkInAIXYZ 67 | exclusions: 68 | generated: lax 69 | paths: 70 | - third_party$ 71 | - builtin$ 72 | - examples$ 73 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.5.0 4 | hooks: 5 | - id: trailing-whitespace 6 | - id: end-of-file-fixer 7 | - id: check-yaml 8 | - id: check-added-large-files 9 | - id: check-ast 10 | - id: check-json 11 | - id: check-merge-conflict 12 | - id: detect-private-key 13 | 14 | - repo: https://github.com/dnephin/pre-commit-golang 15 | rev: v0.5.1 16 | hooks: 17 | - id: go-fmt 18 | # - id: go-imports 19 | - id: no-go-testing 20 | - id: golangci-lint 21 | - id: go-unit-tests 22 | - id: go-build 23 | - id: go-mod-tidy 24 | 25 | # go install golang.org/x/tools/cmd/goimports@latest 26 | # brew install golangci-lint 27 | -------------------------------------------------------------------------------- /CHANGELOG.md: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | ## [v0.1.6](https://github.com/ThinkInAIXYZ/go-mcp/compare/v0.1.5...v0.1.6) (2025-04-11) 4 | 5 | ### Feat 6 | 7 | * update wechat_qrcode image (6fd298c) 8 | * add client to server Heartbeat (1141f58) 9 | * add server to client Heartbeat (7b04354) 10 | 11 | 12 | 13 | ## [v0.1.5](https://github.com/ThinkInAIXYZ/go-mcp/compare/v0.1.4...v0.1.5) (2025-04-11) 14 | 15 | ### Fix 16 | 17 | * Add Debug Logger to display debug information (65a3ad8) 18 | * Avoid service exit due to incorrect input formats in stdio (36cc48d) 19 | 20 | 21 | 22 | ## [v0.1.4](https://github.com/ThinkInAIXYZ/go-mcp/compare/v0.1.3...v0.1.4) (2025-04-10) 23 | 24 | ### Docs 25 | 26 | * add star history (622d582) 27 | * **protocol:** add VerifyAndUnmarshal function (c1ad68c) 28 | 29 | ### Feat 30 | 31 | * **protocol:** add VerifyAndUnmarshal function (42a5071) 32 | 33 | 34 | 35 | ## [v0.1.3](https://github.com/ThinkInAIXYZ/go-mcp/compare/v0.1.2...v0.1.3) (2025-04-09) 36 | 37 | ### Docs 38 | 39 | * reduce wechat code (44cead0) 40 | 41 | ### Examples 42 | 43 | * optimization (51cb740) 44 | 45 | ### Feat 46 | 47 | * readme add wechat qrcode (519f864) 48 | * readme add wechat qrcode (2b19574) 49 | * add feishu link (5481535) 50 | 51 | ### Fix 52 | 53 | * readme image link (c2458b4) 54 | 55 | ### Reverts 56 | 57 | * feat: add feishu link 58 | 59 | 60 | 61 | ## [v0.1.2](https://github.com/ThinkInAIXYZ/go-mcp/compare/v0.1.1...v0.1.2) (2025-04-09) 62 | 63 | ### Feat 64 | 65 | * optimization new tool (f05273e) 66 | 67 | 68 | 69 | ## [v0.1.1](https://github.com/ThinkInAIXYZ/go-mcp/compare/v0.1.0...v0.1.1) (2025-04-08) 70 | 71 | ### Fix 72 | 73 | * json JsonUnmarshal UseInt64 true (0b4d0d9) 74 | 75 | ### Refactor 76 | 77 | * stdio_shutdown (4bc2908) 78 | 79 | 80 | 81 | ## v0.1.0 (2025-04-04) 82 | 83 | ### Docs 84 | 85 | * update readme (f613144) 86 | * update readme (fdfe616) 87 | * update readme (be624d3) 88 | * 
design.md (8697968) 89 | * translate annotate and readme.md (ae7528f) 90 | * Readme.md add why to do go-mcp (20af48b) 91 | * add architecture design to Readme.md (c27606e) 92 | 93 | ### Fead 94 | 95 | * **stdio_transport:** update stdio transport client (8917ef0) 96 | 97 | ### Feat 98 | 99 | * perfecting the framework (5fcf7f7) 100 | * add pre-commit (ee5da68) 101 | * optimization tool struct (4e2ec36) 102 | * optimization test (e0be9cd) 103 | * init project framework (912b969) 104 | * modify test.yml (0156186) 105 | * downgrade go version (aa02032) 106 | * refactor ServerTransport writeError (298e2ab) 107 | * delete cursor (0e80394) 108 | * refactor ServerTransport close to shutdown (0f717d9) 109 | * client add capabilities check (d384dcf) 110 | * stdio add option (6727a78) 111 | * add (b3fd05c) 112 | * add annotate (1a6e06a) 113 | * replace server sessionID2session (998783c) 114 | * add server gracefully shut down logic (33046bc) 115 | * readme add contributors (15798b9) 116 | * sse start bug (cea2766) 117 | * solve conflict“ (cc374ca) 118 | * merge main (30779d6) 119 | * build part package (de069c0) 120 | * add request response matching (9fbd923) 121 | * add request response matching (9af60e6) 122 | * perfecting the framework (b7e62cd) 123 | * add e2e test (41daa6e) 124 | * perfecting the framework (15ff251) 125 | * add test and example (7a54b78) 126 | * add defer recover (315f5a8) 127 | * optimization desgin.md (df8b27d) 128 | * add logger (e8dace1) 129 | * build server package (debb86a) 130 | * build client package (f0b54d2) 131 | * **stdio_transport:** add stdio client/server transport impl (e4ee22e) 132 | * **transport:** sse transport (7e531b0) 133 | 134 | ### Fix 135 | 136 | * client test (ceef300) 137 | * empty param parse (1d000aa) 138 | * stdio server shutdown bug (28cd9db) 139 | * server and client test (5ccc35a) 140 | * stdio server run (52b31d0) 141 | * sse and client bug (2aaba60) 142 | * sse and server bug (f44e83f) 143 | * test (b43e3e6) 144 | * 
test (d5dd2ee) 145 | * test: (1cb8a97) 146 | * server test (0a2710e) 147 | * read resource (ba1d4dc) 148 | * ServerTransport interface (52e474c) 149 | * Prevent memory leaks in JSON-RPC client (0c7495c) 150 | * empty param parse (8d0b208) 151 | * transport test (a17e834) 152 | * Shutdown logic (539035c) 153 | * Shutdown cancel (f91a652) 154 | * receive some bug (5aa2bf0) 155 | * code conflict, Merge branch 'main' into feature/stdio (27a9a66) 156 | * some bug (a2897bc) 157 | * update test and sseClient ctx param name (3a2b7ba) 158 | * modify transport_test (773d4de) 159 | * modify transport_test and replace sync.Map with sessionStore in pkg (bd5eb93) 160 | * server listtools (9caede5) 161 | * **stdio_transport:** delete testdata (e0d6e04) 162 | 163 | ### Refactor 164 | 165 | * trasnport detail (6872e82) 166 | * server and client receive to not import (400b3af) 167 | * protocol response to result (0094f7c) 168 | * Simplified logging (e3987b5) 169 | * package name (4cb773e) 170 | * server handle (d8d7f66) 171 | * server Register (bdabd75) 172 | * server initialize (7acf660) 173 | * server receive (9c5f12d) 174 | * client receive (5a085d6) 175 | * part particulars (13c2073) 176 | * client call (1ba8a14) 177 | * client call (6b5245d) 178 | * server call (5b28da4) 179 | * stdio test (8df59dd) 180 | * stdio (61e25f6) 181 | * part particulars (a6884bd) 182 | * test logic (3bf8ebe) 183 | * pkg.JsonUnmarshal add error info (5ff76cc) 184 | 185 | ### Test 186 | 187 | * add TestServerNotify (6b59459) 188 | * add client_test (4f0a4f0) 189 | * add server_test (e39b373) 190 | 191 | ### Reverts 192 | 193 | * Update call.go 194 | 195 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Git Command Introduction 2 | Preparation: If you don't have a GitHub account, you need to create one before proceeding to the next step. 
3 | 4 | ## 1 Fork the Code 5 | 1. Visit https://github.com/thinkinaixyz/go-mcp 6 | 2. Click the "Fork" button (located at the top right of the page) 7 | 8 | ## 2 Clone the Code 9 | We generally recommend setting the origin as the official repository and setting up your own upstream. 10 | 11 | If you have enabled SSH on GitHub, we recommend using SSH; otherwise, use HTTPS. The difference between the two is that when using HTTPS, you need to enter authentication information every time you push code to the remote repository. 12 | We strongly recommend always using HTTPS for the official repository to avoid accidental operations. 13 | 14 | ```bash 15 | git clone https://github.com/thinkinaixyz/go-mcp.git 16 | cd go-mcp 17 | git remote add upstream 'git@github.com:<your-username>/go-mcp.git' 18 | ``` 19 | You can replace "upstream" with any name you like, such as your username, nickname, or simply "me". Remember to make corresponding replacements in subsequent commands. 20 | 21 | ## 3 Sync the Code 22 | Unless you've just cloned the repository, you need to sync with the remote repository first: 23 | `git fetch` 24 | 25 | Without a remote name, this command fetches only from `origin`. To fetch from your fork as well, pass its remote name: 26 | `git fetch upstream` 27 | 28 | ## 4 Create a Feature Branch 29 | When creating a new feature branch, first decide which branch to base it on. 30 | For example, if the new feature should be based on `main` and merged back into it, run: 31 | ```bash 32 | git checkout -b feature/my-feature origin/main 33 | ``` 34 | This creates a branch that is identical to the code on `origin/main`. 35 | 36 | ## 5 Golint 37 | ```bash 38 | golint $(go list ./... | grep -v /examples/) 39 | golangci-lint run $(go list ./... | grep -v /examples/) 40 | ``` 41 | 42 | ## 6 Go Test 43 | ```bash 44 | go test -v -race $(go list ./...
| grep -v /examples/) -coverprofile=coverage.txt -covermode=atomic 45 | ``` 46 | 47 | ## 7 Submit Commit 48 | ```bash 49 | git add . 50 | git commit 51 | git push upstream feature/my-feature 52 | ``` 53 | 54 | ## 8 Submit PR 55 | Visit https://github.com/thinkinaixyz/go-mcp, 56 | click "Compare" to review your changes, then click "Pull request" to submit the PR 57 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Anthropic, PBC 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | user := $(shell whoami) 2 | rev := $(shell git rev-parse --short HEAD) 3 | os := $(shell sh -c 'echo $$(uname -s) | cut -c1-5') 4 | 5 | # GOBIN > GOPATH > INSTALLDIR 6 | # Mac OS X 7 | ifeq ($(shell uname),Darwin) 8 | GOBIN := $(shell echo ${GOBIN} | cut -d':' -f1) 9 | GOPATH := $(shell echo $(GOPATH) | cut -d':' -f1) 10 | endif 11 | 12 | # Linux 13 | ifeq ($(os),Linux) 14 | GOBIN := $(shell echo ${GOBIN} | cut -d':' -f1) 15 | GOPATH := $(shell echo $(GOPATH) | cut -d':' -f1) 16 | endif 17 | 18 | # Windows 19 | ifeq ($(os),MINGW) 20 | GOBIN := $(subst \,/,$(GOBIN)) 21 | GOPATH := $(subst \,/,$(GOPATH)) 22 | GOBIN :=/$(shell echo "$(GOBIN)" | cut -d';' -f1 | sed 's/://g') 23 | GOPATH :=/$(shell echo "$(GOPATH)" | cut -d';' -f1 | sed 's/://g') 24 | endif 25 | BIN := "" 26 | 27 | # check GOBIN 28 | ifneq ($(GOBIN),) 29 | BIN=$(GOBIN) 30 | else 31 | # check GOPATH 32 | ifneq ($(GOPATH),) 33 | BIN=$(GOPATH)/bin 34 | endif 35 | endif 36 | 37 | TOOLS_SHELL="./hack/tools.sh" 38 | # golangci-lint 39 | LINTER := bin/golangci-lint 40 | 41 | $(LINTER): 42 | curl -SL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s v2.1.1 43 | 44 | .PHONY: init-dev 45 | init-dev: 46 | go install github.com/git-chglog/git-chglog/cmd/git-chglog@latest 47 | go install mvdan.cc/gofumpt@latest 48 | go install golang.org/x/tools/cmd/goimports@latest 49 | 50 | 51 | .PHONY: inspector 52 | inspector: 53 | npx -y @modelcontextprotocol/inspector 54 | 55 | 56 | .PHONY: fmt 57 | fmt: 58 | gofumpt -w -l . 59 | goimports -w -l . 
60 | 61 | 62 | .PHONY: clean 63 | clean: 64 | @${TOOLS_SHELL} tidy 65 | @echo "clean finished" 66 | 67 | .PHONY: fix 68 | fix: $(LINTER) 69 | @${TOOLS_SHELL} fix 70 | @echo "lint fix finished" 71 | 72 | .PHONY: test 73 | test: 74 | @${TOOLS_SHELL} test 75 | @echo "go test finished" 76 | 77 | .PHONY: test-coverage 78 | test-coverage: 79 | @${TOOLS_SHELL} test_coverage 80 | @echo "go test with coverage finished" 81 | 82 | .PHONY: lint 83 | lint: $(LINTER) 84 | echo $(os) 85 | @${TOOLS_SHELL} lint 86 | @echo "lint check finished" 87 | 88 | .PHONY: changelog 89 | # generate CHANGELOG.md from git history 90 | changelog: 91 | git-chglog -o ./CHANGELOG.md 92 | 93 | # show help 94 | help: 95 | @echo '' 96 | @echo 'Usage:' 97 | @echo ' make [target]' 98 | @echo '' 99 | @echo 'Targets:' 100 | @awk '/^[a-zA-Z\-_0-9]+:/ { \ 101 | helpMessage = match(lastLine, /^# (.*)/); \ 102 | if (helpMessage) { \ 103 | helpCommand = substr($$1, 0, index($$1, ":")-1); \ 104 | helpMessage = substr(lastLine, RSTART + 2, RLENGTH); \ 105 | printf "\033[36m%-22s\033[0m %s\n", helpCommand,helpMessage; \ 106 | } \ 107 | } \ 108 | { lastLine = $$0 }' $(MAKEFILE_LIST) 109 | 110 | .DEFAULT_GOAL := help 111 | -------------------------------------------------------------------------------- /README_TW.md: -------------------------------------------------------------------------------- 1 | # Go-MCP 2 | 3 |
4 | Statusphere logo 5 |
6 |
7 | 8 |

9 | Release 10 | Stars 11 | Forks 12 | Issues 13 | Pull Requests 14 | License 15 | Contributors 16 | Last Commit 17 |

18 |

19 | English 20 |

21 | 22 | ## 🚀 概述 23 | 24 | Go-MCP 是一個強大的 Go 語言版本 MCP SDK,實現 Model Context Protocol (MCP),協助外部系統與 AI 應用之間的無縫溝通。基於 Go 語言的強型別與效能優勢,提供簡潔且符合習慣的 API,方便您將外部系統整合進 AI 應用程式。 25 | 26 | ### ✨ 主要特色 27 | 28 | - 🔄 **完整協議實作**:全面實現 MCP 規範,確保與所有相容服務無縫對接 29 | - 🏗️ **優雅的架構設計**:採用清晰的三層架構,支援雙向通訊,確保程式碼模組化與可擴充性 30 | - 🔌 **與 Web 框架無縫整合**:提供符合 MCP 協議的 http.Handler,讓開發者能將 MCP 整合進服務框架 31 | - 🛡️ **型別安全**:善用 Go 的強型別系統,確保程式碼清晰且高度可維護 32 | - 📦 **簡易部署**:受惠於 Go 的靜態編譯特性,無需複雜的相依管理 33 | - ⚡ **高效能設計**:充分發揮 Go 的並行能力,在各種場景下皆能維持優異效能與低資源消耗 34 | 35 | ## 🛠️ 安裝 36 | 37 | ```bash 38 | go get github.com/ThinkInAIXYZ/go-mcp 39 | ``` 40 | 41 | 需 Go 1.18 或更高版本。 42 | 43 | ## 🎯 快速開始 44 | 45 | ### 客戶端範例 46 | 47 | ```go 48 | package main 49 | 50 | import ( 51 | "context" 52 | "log" 53 | 54 | "github.com/ThinkInAIXYZ/go-mcp/client" 55 | "github.com/ThinkInAIXYZ/go-mcp/transport" 56 | ) 57 | 58 | func main() { 59 | // 建立 SSE 傳輸客戶端 60 | transportClient, err := transport.NewSSEClientTransport("http://127.0.0.1:8080/sse") 61 | if err != nil { 62 | log.Fatalf("建立傳輸客戶端失敗: %v", err) 63 | } 64 | 65 | // 初始化 MCP 客戶端 66 | mcpClient, err := client.NewClient(transportClient) 67 | if err != nil { 68 | log.Fatalf("建立 MCP 客戶端失敗: %v", err) 69 | } 70 | defer mcpClient.Close() 71 | 72 | // 取得可用工具列表 73 | tools, err := mcpClient.ListTools(context.Background()) 74 | if err != nil { 75 | log.Fatalf("取得工具列表失敗: %v", err) 76 | } 77 | log.Printf("可用工具: %+v", tools) 78 | } 79 | ``` 80 | 81 | ### 伺服器範例 82 | 83 | ```go 84 | package main 85 | 86 | import ( 87 | "context" 88 | "fmt" 89 | "log" 90 | "time" 91 | 92 | "github.com/ThinkInAIXYZ/go-mcp/protocol" 93 | "github.com/ThinkInAIXYZ/go-mcp/server" 94 | "github.com/ThinkInAIXYZ/go-mcp/transport" 95 | ) 96 | 97 | type TimeRequest struct { 98 | Timezone string `json:"timezone" description:"時區" required:"true"` // 使用 field tag 描述輸入結構 99 | } 100 | 101 | func main() { 102 | // 建立 SSE 傳輸伺服器 103 | transportServer, err := transport.NewSSEServerTransport("127.0.0.1:8080") 104 | if err != nil { 105 | 
log.Fatalf("建立傳輸伺服器失敗: %v", err) 106 | } 107 | 108 | // 初始化 MCP 伺服器 109 | mcpServer, err := server.NewServer(transportServer) 110 | if err != nil { 111 | log.Fatalf("建立 MCP 伺服器失敗: %v", err) 112 | } 113 | 114 | // 註冊時間查詢工具 115 | tool, err := protocol.NewTool("current_time", "取得指定時區的目前時間", TimeRequest{}) 116 | if err != nil { 117 | log.Fatalf("建立工具失敗: %v", err) 118 | return 119 | } 120 | mcpServer.RegisterTool(tool, handleTimeRequest) 121 | 122 | // 啟動伺服器 123 | if err = mcpServer.Run(); err != nil { 124 | log.Fatalf("伺服器啟動失敗: %v", err) 125 | } 126 | } 127 | 128 | func handleTimeRequest(ctx context.Context, req *protocol.CallToolRequest) (*protocol.CallToolResult, error) { 129 | var timeReq TimeRequest 130 | if err := protocol.VerifyAndUnmarshal(req.RawArguments, &timeReq); err != nil { 131 | return nil, err 132 | } 133 | 134 | loc, err := time.LoadLocation(timeReq.Timezone) 135 | if err != nil { 136 | return nil, fmt.Errorf("無效的時區: %v", err) 137 | } 138 | 139 | return &protocol.CallToolResult{ 140 | Content: []protocol.Content{ 141 | &protocol.TextContent{ 142 | Type: "text", 143 | Text: time.Now().In(loc).String(), 144 | }, 145 | }, 146 | }, nil 147 | } 148 | ``` 149 | 150 | ### 與 Gin 框架整合 151 | 152 | ```go 153 | package main 154 | 155 | import ( 156 | "context" 157 | "log" 158 | 159 | "github.com/ThinkInAIXYZ/go-mcp/protocol" 160 | "github.com/ThinkInAIXYZ/go-mcp/server" 161 | "github.com/ThinkInAIXYZ/go-mcp/transport" 162 | "github.com/gin-gonic/gin" 163 | ) 164 | 165 | func main() { 166 | messageEndpointURL := "/message" 167 | 168 | sseTransport, mcpHandler, err := transport.NewSSEServerTransportAndHandler(messageEndpointURL) 169 | if err != nil { 170 | log.Panicf("建立 SSE 傳輸與處理器失敗: %v", err) 171 | } 172 | 173 | // 建立 MCP 伺服器 174 | mcpServer, _ := server.NewServer(sseTransport) 175 | 176 | // 註冊工具 177 | // mcpServer.RegisterTool(tool, toolHandler) 178 | 179 | // 啟動 MCP 伺服器 180 | go func() { 181 | mcpServer.Run() 182 | }() 183 | 184 | defer 
mcpServer.Shutdown(context.Background()) 185 | 186 | r := gin.Default() 187 | r.GET("/sse", func(ctx *gin.Context) { 188 | mcpHandler.HandleSSE().ServeHTTP(ctx.Writer, ctx.Request) 189 | }) 190 | r.POST(messageEndpointURL, func(ctx *gin.Context) { 191 | mcpHandler.HandleMessage().ServeHTTP(ctx.Writer, ctx.Request) 192 | }) 193 | 194 | if err = r.Run(":8080"); err != nil { 195 | return 196 | } 197 | } 198 | ``` 199 | 200 | [參考:更完整的範例](https://github.com/ThinkInAIXYZ/go-mcp/blob/main/examples/http_handler/main.go) 201 | 202 | ## 🏗️ 架構設計 203 | 204 | Go-MCP 採用優雅的三層架構設計: 205 | 206 | ![架構總覽](docs/images/img.png) 207 | 208 | 1. **傳輸層**:負責底層通訊實作,支援多種傳輸協定 209 | 2. **協議層**:處理 MCP 協議的編解碼與資料結構定義 210 | 3. **使用者層**:提供友善的客戶端與伺服器 API 211 | 212 | 目前支援的傳輸方式: 213 | 214 | ![傳輸方式](docs/images/img_1.png) 215 | 216 | - **HTTP SSE/POST**:基於 HTTP 的伺服器推播與客戶端請求,適用於 Web 場景 217 | - **Streamable HTTP**:支援 HTTP POST/GET 請求,具備 stateless 與 stateful 兩種模式,stateful 模式利用 SSE 進行多訊息串流傳輸,支援伺服器主動通知與請求 218 | - **Stdio**:基於標準輸入輸出流,適合本地進程間通訊 219 | 220 | 傳輸層採用統一介面抽象,讓新增其他傳輸方式(如 WebSocket、gRPC)變得簡單直接,且不影響上層程式碼。 221 | 222 | ## 🤝 貢獻方式 223 | 224 | 歡迎各種形式的貢獻!詳情請參閱 [CONTRIBUTING.md](CONTRIBUTING.md)。 225 | 226 | ## 📄 授權條款 227 | 228 | 本專案採用 MIT 授權條款 - 詳見 [LICENSE](LICENSE) 檔案 229 | 230 | ## 📞 聯絡我們 231 | 232 | - **GitHub Issues**:[提交問題](https://github.com/ThinkInAIXYZ/go-mcp/issues) 233 | - **Discord**:點擊[這裡](https://discord.gg/4CSU8HYt)加入用戶群 234 | - **微信社群**: 235 | 236 | ![微信 QR Code](docs/images/wechat_qrcode.png) 237 | 238 | ## ✨ 貢獻者 239 | 240 | 感謝所有為本專案做出貢獻的開發者! 
241 | 242 | 243 | Contributors 244 | 245 | 246 | ## 📈 專案趨勢 247 | 248 | [![Star 歷史](https://api.star-history.com/svg?repos=ThinkInAIXYZ/go-mcp&type=Date)](https://www.star-history.com/#ThinkInAIXYZ/go-mcp&Date) 249 | -------------------------------------------------------------------------------- /client/client.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "sync" 7 | "time" 8 | 9 | cmap "github.com/orcaman/concurrent-map/v2" 10 | 11 | "github.com/ThinkInAIXYZ/go-mcp/pkg" 12 | "github.com/ThinkInAIXYZ/go-mcp/protocol" 13 | "github.com/ThinkInAIXYZ/go-mcp/transport" 14 | ) 15 | 16 | type Option func(*Client) 17 | 18 | func WithNotifyHandler(handler NotifyHandler) Option { 19 | return func(s *Client) { 20 | s.notifyHandler = handler 21 | } 22 | } 23 | 24 | func WithSamplingHandler(handler SamplingHandler) Option { 25 | return func(s *Client) { 26 | s.samplingHandler = handler 27 | } 28 | } 29 | 30 | func WithClientInfo(info *protocol.Implementation) Option { 31 | return func(s *Client) { 32 | s.clientInfo = info 33 | } 34 | } 35 | 36 | func WithInitTimeout(timeout time.Duration) Option { 37 | return func(s *Client) { 38 | s.initTimeout = timeout 39 | } 40 | } 41 | 42 | func WithLogger(logger pkg.Logger) Option { 43 | return func(s *Client) { 44 | s.logger = logger 45 | } 46 | } 47 | 48 | type Client struct { 49 | transport transport.ClientTransport 50 | 51 | reqID2respChan cmap.ConcurrentMap[string, chan *protocol.JSONRPCResponse] 52 | 53 | progressChanRW sync.RWMutex 54 | progressToken2notifyChan map[string]chan<- *protocol.ProgressNotification 55 | 56 | samplingHandler SamplingHandler 57 | 58 | notifyHandler NotifyHandler 59 | 60 | requestID int64 61 | 62 | ready *pkg.AtomicBool 63 | initializationMu sync.Mutex 64 | 65 | clientInfo *protocol.Implementation 66 | clientCapabilities *protocol.ClientCapabilities 67 | 68 | serverCapabilities *protocol.ServerCapabilities 
69 | serverInfo *protocol.Implementation 70 | serverInstructions string 71 | 72 | initTimeout time.Duration 73 | 74 | closed chan struct{} 75 | 76 | logger pkg.Logger 77 | } 78 | 79 | func NewClient(t transport.ClientTransport, opts ...Option) (*Client, error) { 80 | client := &Client{ 81 | transport: t, 82 | reqID2respChan: cmap.New[chan *protocol.JSONRPCResponse](), 83 | progressToken2notifyChan: make(map[string]chan<- *protocol.ProgressNotification), 84 | ready: pkg.NewAtomicBool(), 85 | clientInfo: &protocol.Implementation{}, 86 | clientCapabilities: &protocol.ClientCapabilities{}, 87 | initTimeout: time.Second * 30, 88 | closed: make(chan struct{}), 89 | logger: pkg.DefaultLogger, 90 | } 91 | t.SetReceiver(transport.NewClientReceiver(client.receive, client.receiveInterrupt)) 92 | 93 | for _, opt := range opts { 94 | opt(client) 95 | } 96 | 97 | if client.notifyHandler == nil { 98 | h := NewBaseNotifyHandler() 99 | h.Logger = client.logger 100 | client.notifyHandler = h 101 | } 102 | 103 | if client.samplingHandler != nil { 104 | client.clientCapabilities.Sampling = struct{}{} 105 | } 106 | 107 | ctx, cancel := context.WithTimeout(context.Background(), client.initTimeout) 108 | defer cancel() 109 | 110 | if err := client.transport.Start(); err != nil { 111 | return nil, fmt.Errorf("init mcp client transpor start fail: %w", err) 112 | } 113 | 114 | if _, err := client.initialization(ctx, protocol.NewInitializeRequest(client.clientInfo, client.clientCapabilities)); err != nil { 115 | return nil, err 116 | } 117 | 118 | go func() { 119 | defer pkg.Recover() 120 | 121 | ticker := time.NewTicker(time.Minute) 122 | defer ticker.Stop() 123 | 124 | for { 125 | select { 126 | case <-client.closed: 127 | return 128 | case <-ticker.C: 129 | client.sessionDetection() 130 | } 131 | } 132 | }() 133 | 134 | return client, nil 135 | } 136 | 137 | func (client *Client) GetServerCapabilities() protocol.ServerCapabilities { 138 | return *client.serverCapabilities 139 | } 140 | 
141 | func (client *Client) GetServerInfo() protocol.Implementation { 142 | return *client.serverInfo 143 | } 144 | 145 | func (client *Client) GetServerInstructions() string { 146 | return client.serverInstructions 147 | } 148 | 149 | func (client *Client) Close() error { 150 | close(client.closed) 151 | 152 | return client.transport.Close() 153 | } 154 | 155 | func (client *Client) sessionDetection() { 156 | ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) 157 | defer cancel() 158 | 159 | if _, err := client.Ping(ctx, protocol.NewPingRequest()); err != nil { 160 | client.logger.Warnf("mcp client ping server fail: %v", err) 161 | } 162 | } 163 | -------------------------------------------------------------------------------- /client/handle.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | "time" 8 | 9 | "github.com/ThinkInAIXYZ/go-mcp/pkg" 10 | "github.com/ThinkInAIXYZ/go-mcp/protocol" 11 | ) 12 | 13 | func (client *Client) handleRequestWithPing() (*protocol.PingResult, error) { 14 | return protocol.NewPingResult(), nil 15 | } 16 | 17 | func (client *Client) handleRequestWithCreateMessagesSampling(ctx context.Context, rawParams json.RawMessage) (*protocol.CreateMessageResult, error) { 18 | if client.clientCapabilities.Sampling == nil { 19 | return nil, pkg.ErrClientNotSupport 20 | } 21 | 22 | var request *protocol.CreateMessageRequest 23 | if err := pkg.JSONUnmarshal(rawParams, &request); err != nil { 24 | return nil, err 25 | } 26 | 27 | return client.samplingHandler.CreateMessage(ctx, request) 28 | } 29 | 30 | func (client *Client) handleNotifyWithToolsListChanged(ctx context.Context, rawParams json.RawMessage) error { 31 | notify := &protocol.ToolListChangedNotification{} 32 | if len(rawParams) > 0 { 33 | if err := pkg.JSONUnmarshal(rawParams, notify); err != nil { 34 | return err 35 | } 36 | } 37 | return 
client.notifyHandler.ToolsListChanged(ctx, notify) 38 | } 39 | 40 | func (client *Client) handleNotifyWithPromptsListChanged(ctx context.Context, rawParams json.RawMessage) error { 41 | notify := &protocol.PromptListChangedNotification{} 42 | if len(rawParams) > 0 { 43 | if err := pkg.JSONUnmarshal(rawParams, notify); err != nil { 44 | return err 45 | } 46 | } 47 | return client.notifyHandler.PromptListChanged(ctx, notify) 48 | } 49 | 50 | func (client *Client) handleNotifyWithResourcesListChanged(ctx context.Context, rawParams json.RawMessage) error { 51 | notify := &protocol.ResourceListChangedNotification{} 52 | if len(rawParams) > 0 { 53 | if err := pkg.JSONUnmarshal(rawParams, notify); err != nil { 54 | return err 55 | } 56 | } 57 | return client.notifyHandler.ResourceListChanged(ctx, notify) 58 | } 59 | 60 | func (client *Client) handleNotifyWithResourcesUpdated(ctx context.Context, rawParams json.RawMessage) error { 61 | notify := &protocol.ResourceUpdatedNotification{} 62 | if len(rawParams) > 0 { 63 | if err := pkg.JSONUnmarshal(rawParams, notify); err != nil { 64 | return err 65 | } 66 | } 67 | return client.notifyHandler.ResourcesUpdated(ctx, notify) 68 | } 69 | 70 | func (client *Client) handleNotifyWithProgress(ctx context.Context, rawParams json.RawMessage) error { 71 | notify := &protocol.ProgressNotification{} 72 | if len(rawParams) > 0 { 73 | if err := pkg.JSONUnmarshal(rawParams, notify); err != nil { 74 | return err 75 | } 76 | } 77 | client.progressChanRW.RLock() 78 | defer client.progressChanRW.RUnlock() 79 | 80 | ch, ok := client.progressToken2notifyChan[fmt.Sprint(notify.ProgressToken)] 81 | if !ok { 82 | return fmt.Errorf("progress token not found") 83 | } 84 | 85 | ctx, cancel := context.WithTimeout(ctx, time.Second*1) 86 | defer cancel() 87 | 88 | select { 89 | case ch <- notify: 90 | case <-ctx.Done(): 91 | return ctx.Err() 92 | } 93 | return nil 94 | } 95 | -------------------------------------------------------------------------------- 
/client/notify_handler.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | 7 | "github.com/ThinkInAIXYZ/go-mcp/pkg" 8 | "github.com/ThinkInAIXYZ/go-mcp/protocol" 9 | ) 10 | 11 | type SamplingHandler interface { 12 | CreateMessage(ctx context.Context, request *protocol.CreateMessageRequest) (*protocol.CreateMessageResult, error) 13 | } 14 | 15 | // NotifyHandler 16 | // When implementing a custom NotifyHandler, you can combine it with BaseNotifyHandler to implement it on demand without implementing extra methods. 17 | type NotifyHandler interface { 18 | ToolsListChanged(ctx context.Context, request *protocol.ToolListChangedNotification) error 19 | PromptListChanged(ctx context.Context, request *protocol.PromptListChangedNotification) error 20 | ResourceListChanged(ctx context.Context, request *protocol.ResourceListChangedNotification) error 21 | ResourcesUpdated(ctx context.Context, request *protocol.ResourceUpdatedNotification) error 22 | } 23 | 24 | type BaseNotifyHandler struct { 25 | Logger pkg.Logger 26 | } 27 | 28 | func NewBaseNotifyHandler() *BaseNotifyHandler { 29 | return &BaseNotifyHandler{pkg.DefaultLogger} 30 | } 31 | 32 | func (handler *BaseNotifyHandler) ToolsListChanged(_ context.Context, request *protocol.ToolListChangedNotification) error { 33 | return handler.defaultNotifyHandler(protocol.NotificationToolsListChanged, request) 34 | } 35 | 36 | func (handler *BaseNotifyHandler) PromptListChanged(_ context.Context, request *protocol.PromptListChangedNotification) error { 37 | return handler.defaultNotifyHandler(protocol.NotificationPromptsListChanged, request) 38 | } 39 | 40 | func (handler *BaseNotifyHandler) ResourceListChanged(_ context.Context, request *protocol.ResourceListChangedNotification) error { 41 | return handler.defaultNotifyHandler(protocol.NotificationResourcesListChanged, request) 42 | } 43 | 44 | func (handler 
*BaseNotifyHandler) ResourcesUpdated(_ context.Context, request *protocol.ResourceUpdatedNotification) error { 45 | return handler.defaultNotifyHandler(protocol.NotificationResourcesUpdated, request) 46 | } 47 | 48 | func (handler *BaseNotifyHandler) defaultNotifyHandler(method protocol.Method, notify interface{}) error { 49 | b, err := json.Marshal(notify) 50 | if err != nil { 51 | return err 52 | } 53 | handler.Logger.Infof("receive notify: method=%s, notify=%s", method, b) 54 | return nil 55 | } 56 | -------------------------------------------------------------------------------- /client/receive.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | 8 | "github.com/tidwall/gjson" 9 | 10 | "github.com/ThinkInAIXYZ/go-mcp/pkg" 11 | "github.com/ThinkInAIXYZ/go-mcp/protocol" 12 | ) 13 | 14 | func (client *Client) receive(ctx context.Context, msg []byte) error { 15 | defer pkg.Recover() 16 | 17 | ctx = pkg.NewCancelShieldContext(ctx) 18 | 19 | if !gjson.GetBytes(msg, "id").Exists() { 20 | notify := &protocol.JSONRPCNotification{} 21 | if err := pkg.JSONUnmarshal(msg, &notify); err != nil { 22 | return err 23 | } 24 | if notify.Method == protocol.NotificationProgress { // need sync handle 25 | if err := client.receiveNotify(ctx, notify); err != nil { 26 | notify.RawParams = nil // simplified log 27 | client.logger.Errorf("receive notify:%+v error: %s", notify, err.Error()) 28 | return err 29 | } 30 | return nil 31 | } 32 | go func() { 33 | defer pkg.Recover() 34 | 35 | if err := client.receiveNotify(ctx, notify); err != nil { 36 | notify.RawParams = nil // simplified log 37 | client.logger.Errorf("receive notify:%+v error: %s", notify, err.Error()) 38 | return 39 | } 40 | }() 41 | return nil 42 | } 43 | 44 | // Determine if it's a request or response 45 | if !gjson.GetBytes(msg, "method").Exists() { 46 | resp := &protocol.JSONRPCResponse{} 47 | if err := 
pkg.JSONUnmarshal(msg, &resp); err != nil { 48 | return err 49 | } 50 | if err := client.receiveResponse(resp); err != nil { 51 | resp.RawResult = nil // simplified log 52 | client.logger.Errorf("receive response:%+v error: %s", resp, err.Error()) 53 | return err 54 | } 55 | return nil 56 | } 57 | 58 | req := &protocol.JSONRPCRequest{} 59 | if err := pkg.JSONUnmarshal(msg, &req); err != nil { 60 | return err 61 | } 62 | if !req.IsValid() { 63 | return pkg.ErrRequestInvalid 64 | } 65 | go func() { 66 | defer pkg.Recover() 67 | 68 | if err := client.receiveRequest(ctx, req); err != nil { 69 | req.RawParams = nil // simplified log 70 | client.logger.Errorf("receive request:%+v error: %s", req, err.Error()) 71 | return 72 | } 73 | }() 74 | return nil 75 | } 76 | 77 | func (client *Client) receiveRequest(ctx context.Context, request *protocol.JSONRPCRequest) error { 78 | var ( 79 | result protocol.ClientResponse 80 | err error 81 | ) 82 | 83 | switch request.Method { 84 | case protocol.Ping: 85 | result, err = client.handleRequestWithPing() 86 | // case protocol.RootsList: 87 | // result, err = client.handleRequestWithListRoots(ctx, request.RawParams) 88 | case protocol.SamplingCreateMessage: 89 | result, err = client.handleRequestWithCreateMessagesSampling(ctx, request.RawParams) 90 | default: 91 | err = fmt.Errorf("%w: method=%s", pkg.ErrMethodNotSupport, request.Method) 92 | } 93 | 94 | if err != nil { 95 | switch { 96 | case errors.Is(err, pkg.ErrMethodNotSupport): 97 | return client.sendMsgWithError(ctx, request.ID, protocol.MethodNotFound, err.Error()) 98 | case errors.Is(err, pkg.ErrRequestInvalid): 99 | return client.sendMsgWithError(ctx, request.ID, protocol.InvalidRequest, err.Error()) 100 | case errors.Is(err, pkg.ErrJSONUnmarshal): 101 | return client.sendMsgWithError(ctx, request.ID, protocol.ParseError, err.Error()) 102 | default: 103 | return client.sendMsgWithError(ctx, request.ID, protocol.InternalError, err.Error()) 104 | } 105 | } 106 | return 
client.sendMsgWithResponse(ctx, request.ID, result) 107 | } 108 | 109 | func (client *Client) receiveNotify(ctx context.Context, notify *protocol.JSONRPCNotification) error { 110 | switch notify.Method { 111 | case protocol.NotificationToolsListChanged: 112 | return client.handleNotifyWithToolsListChanged(ctx, notify.RawParams) 113 | case protocol.NotificationPromptsListChanged: 114 | return client.handleNotifyWithPromptsListChanged(ctx, notify.RawParams) 115 | case protocol.NotificationResourcesListChanged: 116 | return client.handleNotifyWithResourcesListChanged(ctx, notify.RawParams) 117 | case protocol.NotificationResourcesUpdated: 118 | return client.handleNotifyWithResourcesUpdated(ctx, notify.RawParams) 119 | case protocol.NotificationProgress: 120 | return client.handleNotifyWithProgress(ctx, notify.RawParams) 121 | default: 122 | return fmt.Errorf("%w: method=%s", pkg.ErrMethodNotSupport, notify.Method) 123 | } 124 | } 125 | 126 | func (client *Client) receiveResponse(response *protocol.JSONRPCResponse) error { 127 | respChan, ok := client.reqID2respChan.Get(fmt.Sprint(response.ID)) 128 | if !ok { 129 | return fmt.Errorf("%w: requestID=%+v", pkg.ErrLackResponseChan, response.ID) 130 | } 131 | 132 | select { 133 | case respChan <- response: 134 | default: 135 | return fmt.Errorf("%w: response=%+v", pkg.ErrDuplicateResponseReceived, response) 136 | } 137 | return nil 138 | } 139 | 140 | func (client *Client) receiveInterrupt(err error) { 141 | for reqID, respChan := range client.reqID2respChan.Items() { 142 | select { 143 | case respChan <- protocol.NewJSONRPCErrorResponse(reqID, protocol.ConnectionError, err.Error()): 144 | default: 145 | } 146 | } 147 | } 148 | -------------------------------------------------------------------------------- /client/send.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "errors" 7 | "fmt" 8 | 9 | 
"github.com/ThinkInAIXYZ/go-mcp/pkg" 10 | "github.com/ThinkInAIXYZ/go-mcp/protocol" 11 | ) 12 | 13 | func (client *Client) sendMsgWithRequest(ctx context.Context, requestID protocol.RequestID, method protocol.Method, params protocol.ClientRequest) error { 14 | if requestID == nil { 15 | return fmt.Errorf("requestID can't is nil") 16 | } 17 | 18 | req := protocol.NewJSONRPCRequest(requestID, method, params) 19 | 20 | message, err := json.Marshal(req) 21 | if err != nil { 22 | return err 23 | } 24 | 25 | if err = client.transport.Send(ctx, message); err != nil { 26 | if !errors.Is(err, pkg.ErrSessionClosed) { 27 | return fmt.Errorf("sendRequest: transport send: %w", err) 28 | } 29 | if err = client.againInitialization(ctx); err != nil { 30 | return err 31 | } 32 | } 33 | return nil 34 | } 35 | 36 | func (client *Client) sendMsgWithResponse(ctx context.Context, requestID protocol.RequestID, result protocol.ClientResponse) error { 37 | if requestID == nil { 38 | return fmt.Errorf("requestID can't is nil") 39 | } 40 | 41 | resp := protocol.NewJSONRPCSuccessResponse(requestID, result) 42 | 43 | message, err := json.Marshal(resp) 44 | if err != nil { 45 | return err 46 | } 47 | 48 | if err = client.transport.Send(ctx, message); err != nil { 49 | return fmt.Errorf("sendResponse: transport send: %w", err) 50 | } 51 | return nil 52 | } 53 | 54 | func (client *Client) sendMsgWithNotification(ctx context.Context, method protocol.Method, params protocol.ClientNotify) error { 55 | notify := protocol.NewJSONRPCNotification(method, params) 56 | 57 | message, err := json.Marshal(notify) 58 | if err != nil { 59 | return err 60 | } 61 | 62 | if err = client.transport.Send(ctx, message); err != nil { 63 | return fmt.Errorf("sendNotification: transport send: %w", err) 64 | } 65 | return nil 66 | } 67 | 68 | func (client *Client) sendMsgWithError(ctx context.Context, requestID protocol.RequestID, code int, msg string) error { 69 | if requestID == nil { 70 | return fmt.Errorf("requestID 
can't is nil") 71 | } 72 | 73 | resp := protocol.NewJSONRPCErrorResponse(requestID, code, msg) 74 | 75 | message, err := json.Marshal(resp) 76 | if err != nil { 77 | return err 78 | } 79 | 80 | if err = client.transport.Send(ctx, message); err != nil { 81 | return fmt.Errorf("sendResponse: transport send: %w", err) 82 | } 83 | return nil 84 | } 85 | 86 | func (client *Client) againInitialization(ctx context.Context) error { 87 | client.ready.Store(false) 88 | 89 | client.initializationMu.Lock() 90 | defer client.initializationMu.Unlock() 91 | 92 | if client.ready.Load() { 93 | return nil 94 | } 95 | 96 | if _, err := client.initialization(ctx, protocol.NewInitializeRequest(client.clientInfo, client.clientCapabilities)); err != nil { 97 | return err 98 | } 99 | client.ready.Store(true) 100 | return nil 101 | } 102 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | codecov: 2 | require_ci_to_pass: true 3 | notify: 4 | wait_for_ci: true 5 | 6 | coverage: 7 | precision: 2 8 | round: down 9 | range: "70...100" 10 | status: 11 | project: 12 | default: 13 | # Adjust based on your expectations - fail if overall project coverage drops more than 1% 14 | target: auto 15 | threshold: 1% 16 | patch: 17 | default: 18 | target: 80% 19 | changes: true 20 | 21 | comment: 22 | layout: "reach, diff, flags, files" 23 | behavior: default 24 | require_changes: false # if true: comment only if coverage changes 25 | 26 | # Ignore test files and generated code if applicable 27 | ignore: 28 | - "**/*_test.go" 29 | - "**/mock_*.go" 30 | - "**/mocks/**" 31 | - "**/vendor/**" 32 | - "**/testdata/**" 33 | - "examples/**" 34 | 35 | # GitHub features 36 | github_checks: 37 | annotations: true -------------------------------------------------------------------------------- /docs/design.md: -------------------------------------------------------------------------------- 
1 | # MCP Go SDK Design Document 2 | 3 | MCP Go SDK is a powerful and easy-to-use Go SDK that implements the Model Context Protocol (MCP), providing both client and server sides so that external systems can communicate seamlessly with AI applications. The protocol capabilities it covers are summarized in the table below. 4 | 5 | # Design Philosophy 6 | 7 | - MCP Protocol Messages 8 | 9 | | Capability Provider | Capability | Protocol Messages (Client Send) | Protocol Messages (Server Send) | 10 | | ------------------ | --------------- | ----------------------------------------------------------------------------------------------------- | ----------------------------------------------------------- | 11 | | Client&Server | Initialization | • Initialize
• Initialized notifications | (None) | 12 | | Client&Server | Ping | • Ping | • Ping | 13 | | Client&Server | Cancellation | • Cancelled Notifications | • Cancelled Notifications | 14 | | Client&Server | Progress | • Progress Notifications | • Progress Notifications | 15 | | Client | roots | • Root List Changes | • Listing Roots | 16 | | Client | sampling | (None) | • Creating Messages | 17 | | Server | prompts | • Listing Prompts
• Getting a Prompt | • List Changed Notification | 18 | | Server | resources | • Listing Resources
• Reading Resources
• Resource Templates
• Subscriptions: Request
• UnSubscriptions: Request | • List Changed Notification
• Subscriptions: Update Notification | 19 | | Server | tools | • Listing Tools
• Calling Tools | • List Changed Notification | 20 | | Server | Completion | • Requesting Completions | (None) | 21 | | Server | logging | • Setting Log Level | • Log Message Notifications | 22 | 23 | - Interaction Details 24 | ![img_1.png](images/img_1.png) 25 | - Both client and server need to have send and receive capabilities 26 | - Messages can be abstracted into three types: request, response, and notification 27 | - The architecture can be abstracted into three layers: transport layer, protocol layer, and user layer (server, client) 28 | 29 | - Design Principles 30 | - Protocol layer and transport layer are decoupled through the transport interface 31 | - Protocol layer contains all MCP protocol-related definitions, including data structures, request construction, and response parsing 32 | - Both server and client layers have send and receive capabilities. Send capabilities include sending messages (request, response, notification) and matching requests with responses. Receive capabilities include routing messages (request, response, notification) and handling them asynchronously/synchronously 33 | - Server and client layers implement the combination of requests and responses, presenting as synchronous request, processing, and response from the user's perspective 34 | 35 | # Architecture Design 36 | ![img.png](images/img.png) 37 | 38 | # Project Structure 39 | 40 | - transport 41 | - sse_client.go 42 | - sse_server.go 43 | - stdio_client.go 44 | - stdio_server.go 45 | - transport.go // transport interface definition 46 | - pkg 47 | - errors.go // error definitions 48 | - log.go // log interface definition 49 | - protocol // contains all MCP protocol-related definitions, including data structures, request construction, and response parsing 50 | - initialize.go 51 | - ping.go 52 | - cancellation.go 53 | - progress.go 54 | - roots.go 55 | - sampling.go 56 | - prompts.go 57 | - resources.go 58 | - tools.go 59 | - completion.go 60 | - logging.go 61 | -
pagination.go 62 | - jsonrpc.go 63 | - server 64 | - server.go 65 | - call.go // send messages (request, notification) to client 66 | - handle.go // handle messages (request, notification) from client, return response or not 67 | - send.go // send messages (request, response, notification) to client 68 | - receive.go // receive messages (request, response, notification) from client 69 | - client 70 | - client.go 71 | - call.go // send messages (request, notification) to server 72 | - handle.go // handle messages (request, notification) from server, return response or not 73 | - send.go // send messages (request, response, notification) to server 74 | - receive.go // receive messages (request, response, notification) from server 75 | -------------------------------------------------------------------------------- /docs/design_cn.md: -------------------------------------------------------------------------------- 1 | MCP Go SDK是一个功能强大且易于使用的Go语言SDK,实现了 Model Context Protocol (MCP),同时提供客户端与服务端两侧的能力,帮助外部系统与 AI 应用之间进行无缝通信。该SDK所覆盖的协议能力见下表。 2 | 3 | # 设计思路 4 | 5 | - MCP 协议消息 6 | 7 | | 能力提供方 | 能力 | 协议消息(客户端发送) | 协议消息(服务端发送) | 8 | | ------------- | ---------------- | -------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------ | 9 | | Client&Server | Initialization | • Initialize
• Initialized notifications | (无) | 10 | | Client&Server | Ping | • Ping | • Ping | 11 | | Client&Server | Cancellation | • Cancelled Notifications | • Cancelled Notifications | 12 | | Client&Server | Progress | • Progress Notifications | • Progress Notifications | 13 | | Client | roots | • Root List Changes | • Listing Roots | 14 | | Client | sampling | (无) | • Creating Messages | 15 | | Server | prompts | • Listing Prompts
• Getting a Prompt | • List Changed Notification | 16 | | Server | resources | • Listing Resources
• Reading Resources
• Resource Templates
• Subscriptions: Request
• UnSubscriptions: Request | • List Changed Notification
• Subscriptions: Update Notification | 17 | | Server | tools | • Listing Tools
• Calling Tools | • List Changed Notification | 18 | | Server | Completion | • Requesting Completions | (无) | 19 | | Server | logging | • Setting Log Level | • Log Message Notifications | 20 | 21 | - 交互细节 22 | ![img_1.png](images/img_1.png) 23 | - 客户端和服务端都需要具备收发功能; 24 | - 可以将消息类型抽象为 message,具体实现包括 request、response、notification 三种; 25 | - 可以将架构抽象为三层:传输层、协议层、用户层(server、client) 26 | 27 | 28 | - 设计思想 29 | - 协议层与传输层通过 transport 接口进行解耦; 30 | - protocol 层完成 MCP 协议相关的全部定义,包括数据结构定义、请求结构构造、响应结构解析; 31 | - server 层与 client 层都具备发送(send)和接收(receive)的能力,发送能力包括发送 message(request、response、notification)以及请求与响应的匹配,接收能力包括对 message(request、response、notification) 的路由、异步/同步处理; 32 | - server 层与 client 层实现对 request 和 response 的组合,用户侧使用时表现为同步请求、同步处理、同步返回。 33 | 34 | # 架构设计 35 | ![img.png](images/img.png) 36 | 37 | # 项目目录 38 | 39 | - transport 40 | - sse_client.go 41 | - sse_server.go 42 | - stdio_client.go 43 | - stdio_server.go 44 | - transport.go // transport 接口定义 45 | - pkg 46 | - errors.go // error 定义 47 | - log.go // log 接口定义 48 | - protocol // 放置 mcp 协议相关的全部定义,包括数据结构定义、请求结构构造、响应结构解析; 49 | - initialize.go 50 | - ping.go 51 | - cancellation.go 52 | - progress.go 53 | - roots.go 54 | - sampling.go 55 | - prompts.go 56 | - resources.go 57 | - tools.go 58 | - completion.go 59 | - logging.go 60 | - pagination.go 61 | - jsonrpc.go 62 | - server 63 | - server.go 64 | - call.go // 向客户端发送 message(request、notification) 65 | - handle.go // 对来自客户端的 message(request、notification) 进行处理,返回或不返回 response 66 | - send.go // 向客户端发送 message(request、response、notification) 67 | - receive.go // 对来自客户端的 message(request、response、notification)进行接收 68 | - client 69 | - client.go 70 | - call.go // 向服务端发送 message(request、notification) 71 | - handle.go // 对来自服务端的 message(request、notification) 进行处理,返回或不返回 response 72 | - send.go // 向服务端发送 message(request、response、notification) 73 | - receive.go // 对来自服务端的 message(request、response、notification)进行接收 74 | 
-------------------------------------------------------------------------------- /docs/images/img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ThinkInAIXYZ/go-mcp/c7a0eb1f7e4a288220d3a3375006802558f473a2/docs/images/img.png -------------------------------------------------------------------------------- /docs/images/img_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ThinkInAIXYZ/go-mcp/c7a0eb1f7e4a288220d3a3375006802558f473a2/docs/images/img_1.png -------------------------------------------------------------------------------- /docs/images/img_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ThinkInAIXYZ/go-mcp/c7a0eb1f7e4a288220d3a3375006802558f473a2/docs/images/img_2.png -------------------------------------------------------------------------------- /docs/images/wechat_qrcode.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ThinkInAIXYZ/go-mcp/c7a0eb1f7e4a288220d3a3375006802558f473a2/docs/images/wechat_qrcode.png -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # More Examples 2 | References: https://github.com/ThinkInAIXYZ/mcp-servers 3 | -------------------------------------------------------------------------------- /examples/auth_tool/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "log" 8 | "net/http" 9 | "os" 10 | "os/signal" 11 | "syscall" 12 | "time" 13 | 14 | "github.com/ThinkInAIXYZ/go-mcp/protocol" 15 | "github.com/ThinkInAIXYZ/go-mcp/server" 16 | "github.com/ThinkInAIXYZ/go-mcp/transport" 17 | ) 18 | 19 | type 
userIDKey struct{} 20 | 21 | func setUserIDToCtx(ctx context.Context, userID string) context.Context { 22 | return context.WithValue(ctx, userIDKey{}, userID) 23 | } 24 | 25 | func getUserIDFromCtx(ctx context.Context) (string, error) { 26 | userID := ctx.Value(userIDKey{}) 27 | if userID == nil { 28 | return "", errors.New("no userID found") 29 | } 30 | return userID.(string), nil 31 | } 32 | 33 | type currentTimeReq struct { 34 | Timezone string `json:"timezone" description:"current time timezone"` 35 | } 36 | 37 | func main() { 38 | messageEndpointURL := "/message" 39 | 40 | userParamKey := "user_id" 41 | paramKeysOpt := transport.WithSSEServerTransportAndHandlerOptionCopyParamKeys([]string{userParamKey}) 42 | sseTransport, mcpHandler, err := transport.NewSSEServerTransportAndHandler(messageEndpointURL, paramKeysOpt) 43 | if err != nil { 44 | log.Panicf("new sse transport and hander with error: %v", err) 45 | } 46 | 47 | mcpServer, err := server.NewServer(sseTransport, 48 | server.WithServerInfo(protocol.Implementation{ 49 | Name: "mcp-example", 50 | Version: "1.0.0", 51 | }), 52 | ) 53 | if err != nil { 54 | panic(err) 55 | } 56 | 57 | tool, err := protocol.NewTool("current_time", "Get current time with timezone, Asia/Shanghai is default", currentTimeReq{}) 58 | if err != nil { 59 | panic(fmt.Sprintf("Failed to create tool: %v", err)) 60 | } 61 | 62 | authentication := authenticationMiddleware(map[string][]string{ 63 | tool.Name: {"test_1"}, 64 | }) 65 | mcpServer.RegisterTool(tool, currentTime, authentication) 66 | 67 | router := http.NewServeMux() 68 | router.HandleFunc("/sse", mcpHandler.HandleSSE().ServeHTTP) 69 | router.HandleFunc(messageEndpointURL, func(w http.ResponseWriter, r *http.Request) { 70 | userID := r.URL.Query().Get(userParamKey) 71 | if userID == "" { 72 | w.Header().Set("Content-Type", "text/plain") 73 | w.WriteHeader(http.StatusBadRequest) 74 | if _, e := w.Write([]byte("lack user_id")); e != nil { 75 | fmt.Printf("writeError: %+v", e) 76 | 
} 77 | return 78 | } 79 | 80 | r = r.WithContext(setUserIDToCtx(r.Context(), userID)) 81 | 82 | mcpHandler.HandleMessage().ServeHTTP(w, r) 83 | }) 84 | 85 | // Can be replaced by using gin framework 86 | // router := gin.Default() 87 | // router.GET("/sse", func(ctx *gin.Context) { 88 | // mcpHandler.HandleSSE().ServeHTTP(ctx.Writer, ctx.Request) 89 | // }) 90 | // router.POST(messageEndpointURL, func(ctx *gin.Context) { 91 | // mcpHandler.HandleMessage().ServeHTTP(ctx.Writer, ctx.Request) 92 | // }) 93 | 94 | httpServer := &http.Server{ 95 | Addr: ":8080", 96 | Handler: router, 97 | IdleTimeout: time.Minute, 98 | } 99 | 100 | errCh := make(chan error, 3) 101 | go func() { 102 | errCh <- mcpServer.Run() 103 | }() 104 | 105 | go func() { 106 | if err = httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { 107 | errCh <- err 108 | } 109 | }() 110 | 111 | if err = signalWaiter(errCh); err != nil { 112 | panic(fmt.Sprintf("signal waiter: %v", err)) 113 | } 114 | 115 | ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) 116 | defer cancel() 117 | 118 | httpServer.RegisterOnShutdown(func() { 119 | if err = mcpServer.Shutdown(ctx); err != nil { 120 | panic(err) 121 | } 122 | }) 123 | 124 | if err = httpServer.Shutdown(ctx); err != nil { 125 | panic(err) 126 | } 127 | } 128 | 129 | func authenticationMiddleware(toolName2UserID map[string][]string) server.ToolMiddleware { 130 | return func(next server.ToolHandlerFunc) server.ToolHandlerFunc { 131 | return func(ctx context.Context, req *protocol.CallToolRequest) (*protocol.CallToolResult, error) { 132 | userID, err := getUserIDFromCtx(ctx) 133 | if err != nil { 134 | return nil, err 135 | } 136 | 137 | for _, id := range toolName2UserID[req.Name] { 138 | if userID == id { 139 | return next(ctx, req) 140 | } 141 | } 142 | return nil, fmt.Errorf("user %s not authorized", userID) 143 | } 144 | } 145 | } 146 | 147 | func currentTime(_ context.Context, request 
*protocol.CallToolRequest) (*protocol.CallToolResult, error) { 148 | req := new(currentTimeReq) 149 | if err := protocol.VerifyAndUnmarshal(request.RawArguments, &req); err != nil { 150 | return nil, err 151 | } 152 | 153 | loc, err := time.LoadLocation(req.Timezone) 154 | if err != nil { 155 | return nil, fmt.Errorf("parse timezone with error: %v", err) 156 | } 157 | text := fmt.Sprintf(`current time is %s`, time.Now().In(loc)) 158 | 159 | return &protocol.CallToolResult{ 160 | Content: []protocol.Content{ 161 | &protocol.TextContent{ 162 | Type: "text", 163 | Text: text, 164 | }, 165 | }, 166 | }, nil 167 | } 168 | 169 | func signalWaiter(errCh chan error) error { 170 | signalToNotify := []os.Signal{syscall.SIGINT, syscall.SIGHUP, syscall.SIGTERM} 171 | if signal.Ignored(syscall.SIGHUP) { 172 | signalToNotify = []os.Signal{syscall.SIGINT, syscall.SIGTERM} 173 | } 174 | 175 | signals := make(chan os.Signal, 1) 176 | signal.Notify(signals, signalToNotify...) 177 | 178 | select { 179 | case sig := <-signals: 180 | switch sig { 181 | case syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM: 182 | log.Printf("Received signal: %s\n", sig) 183 | // graceful shutdown 184 | return nil 185 | } 186 | case err := <-errCh: 187 | return err 188 | } 189 | 190 | return nil 191 | } 192 | -------------------------------------------------------------------------------- /examples/current_time_server/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "flag" 6 | "fmt" 7 | "log" 8 | "os" 9 | "os/signal" 10 | "syscall" 11 | "time" 12 | 13 | "github.com/ThinkInAIXYZ/go-mcp/pkg" 14 | "github.com/ThinkInAIXYZ/go-mcp/protocol" 15 | "github.com/ThinkInAIXYZ/go-mcp/server" 16 | "github.com/ThinkInAIXYZ/go-mcp/transport" 17 | ) 18 | 19 | type currentTimeReq struct { 20 | Timezone string `json:"timezone" description:"current time timezone"` 21 | } 22 | 23 | func main() { 24 | // new mcp server with stdio or 
sse transport 25 | srv, err := server.NewServer( 26 | getTransport(), 27 | server.WithServerInfo(protocol.Implementation{ 28 | Name: "current-time-v2-server", 29 | Version: "1.0.0", 30 | }), 31 | ) 32 | if err != nil { 33 | log.Fatalf("Failed to create server: %v", err) 34 | } 35 | 36 | // new protocol tool with name, descipriton and properties 37 | tool, err := protocol.NewTool("current_time", "Get current time with timezone, Asia/Shanghai is default", currentTimeReq{}) 38 | if err != nil { 39 | log.Fatalf("Failed to create tool: %v", err) 40 | return 41 | } 42 | 43 | // register tool and start mcp server 44 | srv.RegisterTool(tool, currentTime, 45 | server.RateLimitMiddleware(pkg.NewTokenBucketLimiter(pkg.Rate{ 46 | Limit: 10.0, // 每秒10个请求 47 | Burst: 20, // 最多允许20个请求的突发 48 | }))) 49 | // srv.RegisterResource() 50 | // srv.RegisterPrompt() 51 | // srv.RegisterResourceTemplate() 52 | 53 | errCh := make(chan error) 54 | go func() { 55 | errCh <- srv.Run() 56 | }() 57 | 58 | if err = signalWaiter(errCh); err != nil { 59 | log.Fatalf("signal waiter: %v", err) 60 | return 61 | } 62 | 63 | ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) 64 | defer cancel() 65 | 66 | if err := srv.Shutdown(ctx); err != nil { 67 | log.Fatalf("Shutdown error: %v", err) 68 | } 69 | } 70 | 71 | func getTransport() (t transport.ServerTransport) { 72 | var ( 73 | mode string 74 | addr = "127.0.0.1:8080" 75 | ) 76 | 77 | flag.StringVar(&mode, "transport", "stdio", "The transport to use, should be \"stdio\" or \"sse\" or \"streamable_http\"") 78 | flag.Parse() 79 | 80 | switch mode { 81 | case "stdio": 82 | log.Println("start current time mcp server with stdio transport") 83 | t = transport.NewStdioServerTransport() 84 | case "sse": 85 | log.Printf("start current time mcp server with sse transport, listen %s", addr) 86 | t, _ = transport.NewSSEServerTransport(addr) 87 | case "streamable_http": 88 | log.Printf("start current time mcp server with streamable_http transport, 
listen %s", addr) 89 | t = transport.NewStreamableHTTPServerTransport(addr) 90 | default: 91 | panic(fmt.Errorf("unknown mode: %s", mode)) 92 | } 93 | 94 | return t 95 | } 96 | 97 | func currentTime(_ context.Context, request *protocol.CallToolRequest) (*protocol.CallToolResult, error) { 98 | req := new(currentTimeReq) 99 | if err := protocol.VerifyAndUnmarshal(request.RawArguments, &req); err != nil { 100 | return nil, err 101 | } 102 | 103 | loc, err := time.LoadLocation(req.Timezone) 104 | if err != nil { 105 | return nil, fmt.Errorf("parse timezone with error: %v", err) 106 | } 107 | text := fmt.Sprintf(`current time is %s`, time.Now().In(loc)) 108 | 109 | return &protocol.CallToolResult{ 110 | Content: []protocol.Content{ 111 | &protocol.TextContent{ 112 | Type: "text", 113 | Text: text, 114 | }, 115 | }, 116 | }, nil 117 | } 118 | 119 | func signalWaiter(errCh chan error) error { 120 | signalToNotify := []os.Signal{syscall.SIGINT, syscall.SIGHUP, syscall.SIGTERM} 121 | if signal.Ignored(syscall.SIGHUP) { 122 | signalToNotify = []os.Signal{syscall.SIGINT, syscall.SIGTERM} 123 | } 124 | 125 | signals := make(chan os.Signal, 1) 126 | signal.Notify(signals, signalToNotify...) 
127 | 128 | select { 129 | case sig := <-signals: 130 | switch sig { 131 | case syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM: 132 | log.Printf("Received signal: %s\n", sig) 133 | // graceful shutdown 134 | return nil 135 | } 136 | case err := <-errCh: 137 | return err 138 | } 139 | 140 | return nil 141 | } 142 | -------------------------------------------------------------------------------- /examples/filesystem_client/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "log" 7 | "time" 8 | 9 | "github.com/ThinkInAIXYZ/go-mcp/client" 10 | "github.com/ThinkInAIXYZ/go-mcp/protocol" 11 | "github.com/ThinkInAIXYZ/go-mcp/transport" 12 | ) 13 | 14 | func main() { 15 | t, err := transport.NewStdioClientTransport("npx", []string{"-y", "@modelcontextprotocol/server-filesystem", "~/tmp"}) 16 | if err != nil { 17 | log.Fatal(err) 18 | } 19 | 20 | cli, err := client.NewClient(t, client.WithClientInfo(&protocol.Implementation{ 21 | Name: "test", 22 | Version: "1.0.0", 23 | })) 24 | if err != nil { 25 | log.Fatalf("Failed to new client: %v", err) 26 | } 27 | defer func() { 28 | if err = cli.Close(); err != nil { 29 | log.Fatalf("Failed to close client: %v", err) 30 | } 31 | }() 32 | 33 | // Create context with timeout 34 | ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) 35 | defer cancel() 36 | 37 | // List Tools 38 | log.Println("Listing available tools...") 39 | tools, err := cli.ListTools(ctx) 40 | if err != nil { 41 | log.Fatalf("Failed to list tools: %v", err) 42 | } 43 | for _, tool := range tools.Tools { 44 | log.Printf("- %s: %s\n", tool.Name, tool.Description) 45 | } 46 | 47 | // List allowed directories 48 | log.Println("Listing allowed directories...") 49 | listDirRequest := &protocol.CallToolRequest{ 50 | Name: "list_allowed_directories", 51 | } 52 | result, err := cli.CallTool(ctx, listDirRequest) 53 | if err != nil { 54 | 
log.Fatalf("Failed to list allowed directories: %v", err) 55 | } 56 | printToolResult(result) 57 | log.Println() 58 | 59 | // List ~/tmp 60 | log.Println("Listing ~/tmp directory...") 61 | listTmpRequest := &protocol.CallToolRequest{ 62 | Name: "list_directory", 63 | Arguments: map[string]interface{}{"path": "~/tmp"}, 64 | } 65 | result, err = cli.CallTool(ctx, listTmpRequest) 66 | if err != nil { 67 | log.Fatalf("Failed to list directory: %v", err) 68 | } 69 | printToolResult(result) 70 | log.Println() 71 | 72 | // Create mcp directory 73 | log.Println("Creating ~/tmp/mcp directory...") 74 | createDirRequest := &protocol.CallToolRequest{ 75 | Name: "create_directory", 76 | Arguments: map[string]interface{}{"path": "~/tmp/mcp"}, 77 | } 78 | result, err = cli.CallTool(ctx, createDirRequest) 79 | if err != nil { 80 | log.Fatalf("Failed to create directory: %v", err) 81 | } 82 | printToolResult(result) 83 | log.Println() 84 | 85 | // Create hello.txt 86 | log.Println("Creating ~/tmp/mcp/hello.txt...") 87 | writeFileRequest := &protocol.CallToolRequest{ 88 | Name: "write_file", 89 | Arguments: map[string]interface{}{ 90 | "path": "~/tmp/mcp/hello.txt", 91 | "content": "Hello World", 92 | }, 93 | } 94 | result, err = cli.CallTool(ctx, writeFileRequest) 95 | if err != nil { 96 | log.Fatalf("Failed to create file: %v", err) 97 | } 98 | printToolResult(result) 99 | log.Println() 100 | 101 | // Verify file contents 102 | log.Println("Reading ~/tmp/mcp/hello.txt...") 103 | readFileRequest := &protocol.CallToolRequest{ 104 | Name: "read_file", 105 | Arguments: map[string]interface{}{ 106 | "path": "~/tmp/mcp/hello.txt", 107 | }, 108 | } 109 | result, err = cli.CallTool(ctx, readFileRequest) 110 | if err != nil { 111 | log.Fatalf("Failed to read file: %v", err) 112 | } 113 | printToolResult(result) 114 | 115 | // Get file info 116 | log.Println("Getting info for ~/tmp/mcp/hello.txt...") 117 | fileInfoRequest := &protocol.CallToolRequest{ 118 | Name: "get_file_info", 119 | 
Arguments: map[string]interface{}{ 120 | "path": "~/tmp/mcp/hello.txt", 121 | }, 122 | } 123 | result, err = cli.CallTool(ctx, fileInfoRequest) 124 | if err != nil { 125 | log.Fatalf("Failed to get file info: %v", err) 126 | } 127 | printToolResult(result) 128 | } 129 | 130 | // Helper function to print tool results 131 | func printToolResult(result *protocol.CallToolResult) { 132 | for _, content := range result.Content { 133 | if textContent, ok := content.(*protocol.TextContent); ok { 134 | log.Println(textContent.Text) 135 | } else { 136 | jsonBytes, _ := json.MarshalIndent(content, "", " ") 137 | log.Println(string(jsonBytes)) 138 | } 139 | } 140 | } 141 | -------------------------------------------------------------------------------- /examples/http_handler/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "log" 8 | "net/http" 9 | "os" 10 | "os/signal" 11 | "syscall" 12 | "time" 13 | 14 | "github.com/ThinkInAIXYZ/go-mcp/protocol" 15 | "github.com/ThinkInAIXYZ/go-mcp/server" 16 | "github.com/ThinkInAIXYZ/go-mcp/transport" 17 | ) 18 | 19 | type currentTimeReq struct { 20 | Timezone string `json:"timezone" description:"current time timezone"` 21 | } 22 | 23 | func main() { 24 | messageEndpointURL := "/message" 25 | 26 | sseTransport, mcpHandler, err := transport.NewSSEServerTransportAndHandler(messageEndpointURL) 27 | if err != nil { 28 | log.Panicf("new sse transport and hander with error: %v", err) 29 | } 30 | 31 | mcpServer, err := server.NewServer(sseTransport, 32 | server.WithServerInfo(protocol.Implementation{ 33 | Name: "mcp-example", 34 | Version: "1.0.0", 35 | }), 36 | ) 37 | if err != nil { 38 | panic(err) 39 | } 40 | 41 | tool, err := protocol.NewTool("current_time", "Get current time with timezone, Asia/Shanghai is default", currentTimeReq{}) 42 | if err != nil { 43 | panic(fmt.Sprintf("Failed to create tool: %v", err)) 44 | } 45 | 46 | 
mcpServer.RegisterTool(tool, currentTime) 47 | 48 | router := http.NewServeMux() 49 | router.HandleFunc("/sse", mcpHandler.HandleSSE().ServeHTTP) 50 | router.HandleFunc(messageEndpointURL, mcpHandler.HandleMessage().ServeHTTP) 51 | 52 | // Can be replaced by using gin framework 53 | // router := gin.Default() 54 | // router.GET("/sse", func(ctx *gin.Context) { 55 | // mcpHandler.HandleSSE().ServeHTTP(ctx.Writer, ctx.Request) 56 | // }) 57 | // router.POST(messageEndpointURL, func(ctx *gin.Context) { 58 | // mcpHandler.HandleMessage().ServeHTTP(ctx.Writer, ctx.Request) 59 | // }) 60 | 61 | httpServer := &http.Server{ 62 | Addr: ":8080", 63 | Handler: router, 64 | IdleTimeout: time.Minute, 65 | } 66 | 67 | errCh := make(chan error, 3) 68 | go func() { 69 | errCh <- mcpServer.Run() 70 | }() 71 | 72 | go func() { 73 | if err = httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { 74 | errCh <- err 75 | } 76 | }() 77 | 78 | if err = signalWaiter(errCh); err != nil { 79 | panic(fmt.Sprintf("signal waiter: %v", err)) 80 | } 81 | 82 | ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) 83 | defer cancel() 84 | 85 | httpServer.RegisterOnShutdown(func() { 86 | if err = mcpServer.Shutdown(ctx); err != nil { 87 | panic(err) 88 | } 89 | }) 90 | 91 | if err = httpServer.Shutdown(ctx); err != nil { 92 | panic(err) 93 | } 94 | } 95 | 96 | func currentTime(_ context.Context, request *protocol.CallToolRequest) (*protocol.CallToolResult, error) { 97 | req := new(currentTimeReq) 98 | if err := protocol.VerifyAndUnmarshal(request.RawArguments, &req); err != nil { 99 | return nil, err 100 | } 101 | 102 | loc, err := time.LoadLocation(req.Timezone) 103 | if err != nil { 104 | return nil, fmt.Errorf("parse timezone with error: %v", err) 105 | } 106 | text := fmt.Sprintf(`current time is %s`, time.Now().In(loc)) 107 | 108 | return &protocol.CallToolResult{ 109 | Content: []protocol.Content{ 110 | &protocol.TextContent{ 111 | Type: "text", 
112 | Text: text, 113 | }, 114 | }, 115 | }, nil 116 | } 117 | 118 | func signalWaiter(errCh chan error) error { 119 | signalToNotify := []os.Signal{syscall.SIGINT, syscall.SIGHUP, syscall.SIGTERM} 120 | if signal.Ignored(syscall.SIGHUP) { 121 | signalToNotify = []os.Signal{syscall.SIGINT, syscall.SIGTERM} 122 | } 123 | 124 | signals := make(chan os.Signal, 1) 125 | signal.Notify(signals, signalToNotify...) 126 | 127 | select { 128 | case sig := <-signals: 129 | switch sig { 130 | case syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM: 131 | log.Printf("Received signal: %s\n", sig) 132 | // graceful shutdown 133 | return nil 134 | } 135 | case err := <-errCh: 136 | return err 137 | } 138 | 139 | return nil 140 | } 141 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/ThinkInAIXYZ/go-mcp 2 | 3 | go 1.18 4 | 5 | require ( 6 | github.com/google/uuid v1.6.0 7 | github.com/orcaman/concurrent-map/v2 v2.0.1 8 | github.com/tidwall/gjson v1.18.0 9 | github.com/yosida95/uritemplate/v3 v3.0.2 10 | ) 11 | 12 | require ( 13 | github.com/tidwall/match v1.1.1 // indirect 14 | github.com/tidwall/pretty v1.2.0 // indirect 15 | ) 16 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= 2 | github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= 3 | github.com/orcaman/concurrent-map/v2 v2.0.1 h1:jOJ5Pg2w1oeB6PeDurIYf6k9PQ+aTITr/6lP/L/zp6c= 4 | github.com/orcaman/concurrent-map/v2 v2.0.1/go.mod h1:9Eq3TG2oBe5FirmYWQfYO5iH1q0Jv47PLaNK++uCdOM= 5 | github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= 6 | github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= 7 | 
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= 8 | github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= 9 | github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= 10 | github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= 11 | github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= 12 | github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= 13 | -------------------------------------------------------------------------------- /hack/.lintcheck_failures: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ThinkInAIXYZ/go-mcp/c7a0eb1f7e4a288220d3a3375006802558f473a2/hack/.lintcheck_failures -------------------------------------------------------------------------------- /hack/.test_ignored_files: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ThinkInAIXYZ/go-mcp/c7a0eb1f7e4a288220d3a3375006802558f473a2/hack/.test_ignored_files -------------------------------------------------------------------------------- /hack/resolve-modules.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # This is used by the linter action. 4 | # Recursively finds all directories with a go.mod file and creates 5 | # a GitHub Actions JSON output option. 6 | 7 | set -o errexit 8 | 9 | echo "Resolving modules in $(pwd)" 10 | 11 | PROJECT_HOME=$( 12 | cd "$(dirname "${BASH_SOURCE[0]}")" && 13 | cd .. 
&& 14 | pwd 15 | ) 16 | 17 | source "${PROJECT_HOME}/hack/util.sh" 18 | 19 | FAILURE_FILE=${PROJECT_HOME}/hack/.lintcheck_failures 20 | 21 | all_modules=$(util::find_modules) 22 | failing_modules=() 23 | while IFS='' read -r line; do failing_modules+=("$line"); done < <(cat "$FAILURE_FILE") 24 | 25 | echo "Ignored failing modules:" 26 | echo "${failing_modules[*]}" 27 | echo 28 | 29 | PATHS="" 30 | 31 | for mod in $all_modules; do 32 | echo "Checking module: $mod" 33 | util::array_contains "$mod" "${failing_modules[*]}" && in_failing=$? || in_failing=$? 34 | if [[ "$in_failing" -ne "0" ]]; then 35 | PATHS+=$(printf '{"workdir":"%s"},' ${mod}) 36 | fi 37 | done 38 | 39 | echo "::set-output name=matrix::{\"include\":[${PATHS%?}]}" 40 | -------------------------------------------------------------------------------- /hack/tools.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # This is a tools shell script 4 | # used by Makefile commands 5 | 6 | set -o errexit 7 | set -o nounset 8 | set -o pipefail 9 | 10 | GO111MODULE=on 11 | PROJECT_HOME=$( 12 | cd "$(dirname "${BASH_SOURCE[0]}")" && 13 | cd .. && 14 | pwd 15 | ) 16 | 17 | source "${PROJECT_HOME}/hack/util.sh" 18 | 19 | LINTER=${PROJECT_HOME}/bin/golangci-lint 20 | LINTER_CONFIG=${PROJECT_HOME}/.golangci.yml 21 | FAILURE_FILE=${PROJECT_HOME}/hack/.lintcheck_failures 22 | IGNORED_FILE=${PROJECT_HOME}/hack/.test_ignored_files 23 | 24 | all_modules=$(util::find_modules) 25 | failing_modules=() 26 | while IFS='' read -r line; do failing_modules+=("$line"); done < <(cat "$FAILURE_FILE") 27 | ignored_modules=() 28 | while IFS='' read -r line; do ignored_modules+=("$line"); done < <(cat "$IGNORED_FILE") 29 | 30 | # functions 31 | # lint all mod 32 | function lint() { 33 | for mod in $all_modules; do 34 | local in_failing 35 | util::array_contains "$mod" "${failing_modules[*]}" && in_failing=$? || in_failing=$? 
36 | if [[ "$in_failing" -ne "0" ]]; then 37 | pushd "$mod" >/dev/null && 38 | echo "golangci lint $(sed -n 1p go.mod | cut -d ' ' -f2)" && 39 | eval "${LINTER} run --timeout=5m --config=${LINTER_CONFIG}" 40 | popd >/dev/null || exit 41 | fi 42 | done 43 | } 44 | 45 | # test all mod 46 | function test() { 47 | for mod in $all_modules; do 48 | local in_failing 49 | util::array_contains "$mod" "${ignored_modules[*]}" && in_failing=$? || in_failing=$? 50 | if [[ "$in_failing" -ne "0" ]]; then 51 | pushd "$mod" >/dev/null && 52 | echo "go test $(sed -n 1p go.mod | cut -d ' ' -f2)" && 53 | go test -race ./... 54 | popd >/dev/null || exit 55 | fi 56 | done 57 | } 58 | 59 | function test_coverage() { 60 | echo "" > coverage.txt 61 | local base 62 | base=$(pwd) 63 | for mod in $all_modules; do 64 | local in_failing 65 | util::array_contains "$mod" "${ignored_modules[*]}" && in_failing=$? || in_failing=$? 66 | if [[ "$in_failing" -ne "0" ]]; then 67 | pushd "$mod" >/dev/null && 68 | echo "go test $(sed -n 1p go.mod | cut -d ' ' -f2)" && 69 | go test -race -coverprofile=profile.out -covermode=atomic ./... 70 | if [ -f profile.out ]; then 71 | cat profile.out > "${base}/coverage.txt" 72 | rm profile.out 73 | fi 74 | popd >/dev/null || exit 75 | fi 76 | done 77 | } 78 | 79 | # try to fix all mod with golangci-lint 80 | function fix() { 81 | for mod in $all_modules; do 82 | local in_failing 83 | util::array_contains "$mod" "${failing_modules[*]}" && in_failing=$? || in_failing=$? 
84 | if [[ "$in_failing" -ne "0" ]]; then 85 | pushd "$mod" >/dev/null && 86 | echo "golangci fix $(sed -n 1p go.mod | cut -d ' ' -f2)" && 87 | eval "${LINTER} run -v --fix --timeout=5m --config=${LINTER_CONFIG}" 88 | popd >/dev/null || exit 89 | fi 90 | done 91 | } 92 | 93 | function tidy() { 94 | for mod in $all_modules; do 95 | pushd "$mod" >/dev/null && 96 | echo "go mod tidy $(sed -n 1p go.mod | cut -d ' ' -f2)" && 97 | go mod tidy 98 | popd >/dev/null || exit 99 | done 100 | } 101 | 102 | function help() { 103 | echo "use: lint, test, test_coverage, fix, tidy" 104 | } 105 | 106 | case $1 in 107 | lint) 108 | lint 109 | ;; 110 | test) 111 | test 112 | ;; 113 | test_coverage) 114 | test_coverage 115 | ;; 116 | tidy) 117 | tidy 118 | ;; 119 | fix) 120 | fix 121 | ;; 122 | *) 123 | help 124 | ;; 125 | esac 126 | -------------------------------------------------------------------------------- /hack/util.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # This is a common util functions shell script 4 | 5 | # arguments: target, item1, item2, item3, ... 6 | # returns 0 if target is in the given items, 1 otherwise. 7 | function util::array_contains() { 8 | local target="$1" 9 | shift 10 | local items="$*" 11 | for item in ${items[*]}; do 12 | if [[ "${item}" == "${target}" ]]; then 13 | return 0 14 | fi 15 | done 16 | return 1 17 | } 18 | 19 | # find all go mod path 20 | # returns an array contains mod path 21 | function util::find_modules() { 22 | find . 
-not \( \ 23 | \( \ 24 | -path './output' \ 25 | -o -path './.git' \ 26 | -o -path '*/third_party/*' \ 27 | -o -path '*/vendor/*' \ 28 | \) -prune \ 29 | \) -name 'go.mod' -print0 | xargs -0 -I {} dirname {} 30 | } 31 | -------------------------------------------------------------------------------- /pkg/atomic.go: -------------------------------------------------------------------------------- 1 | package pkg 2 | 3 | import "sync/atomic" 4 | 5 | type AtomicBool struct { 6 | b atomic.Value 7 | } 8 | 9 | func NewAtomicBool() *AtomicBool { 10 | b := &AtomicBool{} 11 | b.b.Store(false) 12 | return b 13 | } 14 | 15 | func (b *AtomicBool) Store(value bool) { 16 | b.b.Store(value) 17 | } 18 | 19 | func (b *AtomicBool) Load() bool { 20 | return b.b.Load().(bool) 21 | } 22 | 23 | type AtomicString struct { 24 | b atomic.Value 25 | } 26 | 27 | func NewAtomicString() *AtomicString { 28 | b := &AtomicString{} 29 | b.b.Store("") 30 | return b 31 | } 32 | 33 | func (b *AtomicString) Store(value string) { 34 | b.b.Store(value) 35 | } 36 | 37 | func (b *AtomicString) Load() string { 38 | return b.b.Load().(string) 39 | } 40 | -------------------------------------------------------------------------------- /pkg/context.go: -------------------------------------------------------------------------------- 1 | package pkg 2 | 3 | import ( 4 | "context" 5 | "time" 6 | ) 7 | 8 | type CancelShieldContext struct { 9 | context.Context 10 | } 11 | 12 | func NewCancelShieldContext(ctx context.Context) context.Context { 13 | return CancelShieldContext{Context: ctx} 14 | } 15 | 16 | func (v CancelShieldContext) Deadline() (deadline time.Time, ok bool) { 17 | return 18 | } 19 | 20 | func (v CancelShieldContext) Done() <-chan struct{} { 21 | return nil 22 | } 23 | 24 | func (v CancelShieldContext) Err() error { 25 | return nil 26 | } 27 | -------------------------------------------------------------------------------- /pkg/errors.go: 
-------------------------------------------------------------------------------- 1 | package pkg 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | ) 7 | 8 | var ( 9 | ErrClientNotSupport = errors.New("this feature client not support") 10 | ErrServerNotSupport = errors.New("this feature server not support") 11 | ErrRequestInvalid = errors.New("request invalid") 12 | ErrLackResponseChan = errors.New("lack response chan") 13 | ErrDuplicateResponseReceived = errors.New("duplicate response received") 14 | ErrMethodNotSupport = errors.New("method not support") 15 | ErrJSONUnmarshal = errors.New("json unmarshal error") 16 | ErrSessionHasNotInitialized = errors.New("the session has not been initialized") 17 | ErrLackSession = errors.New("lack session") 18 | ErrSessionClosed = errors.New("session closed") 19 | ErrSendEOF = errors.New("send EOF") 20 | ErrRateLimitExceeded = errors.New("rate limit exceeded") 21 | ) 22 | 23 | type ResponseError struct { 24 | Code int 25 | Message string 26 | Data interface{} 27 | } 28 | 29 | func NewResponseError(code int, message string, data interface{}) *ResponseError { 30 | return &ResponseError{Code: code, Message: message, Data: data} 31 | } 32 | 33 | func (e *ResponseError) Error() string { 34 | return fmt.Sprintf("code=%d message=%s data=%+v", e.Code, e.Message, e.Data) 35 | } 36 | -------------------------------------------------------------------------------- /pkg/helper.go: -------------------------------------------------------------------------------- 1 | package pkg 2 | 3 | import ( 4 | "errors" 5 | "log" 6 | "runtime/debug" 7 | "strings" 8 | "unsafe" 9 | ) 10 | 11 | func Recover() { 12 | if r := recover(); r != nil { 13 | log.Printf("panic: %v\nstack: %s", r, debug.Stack()) 14 | } 15 | } 16 | 17 | func RecoverWithFunc(f func(r any)) { 18 | if r := recover(); r != nil { 19 | f(r) 20 | log.Printf("panic: %v\nstack: %s", r, debug.Stack()) 21 | } 22 | } 23 | 24 | func B2S(b []byte) string { 25 | return *(*string)(unsafe.Pointer(&b)) 26 | 
} 27 | 28 | func JoinErrors(errs []error) error { 29 | if len(errs) == 0 { 30 | return nil 31 | } 32 | messages := make([]string, len(errs)) 33 | for i, err := range errs { 34 | messages[i] = err.Error() 35 | } 36 | return errors.New(strings.Join(messages, "; ")) 37 | } 38 | -------------------------------------------------------------------------------- /pkg/json.go: -------------------------------------------------------------------------------- 1 | package pkg 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | ) 7 | 8 | // var sonicAPI = sonic.Config{UseInt64: true}.Froze() // Effectively prevents integer overflow 9 | 10 | func JSONUnmarshal(data []byte, v interface{}) error { 11 | if err := json.Unmarshal(data, v); err != nil { 12 | return fmt.Errorf("%w: data=%s, error: %+v", ErrJSONUnmarshal, data, err) 13 | } 14 | return nil 15 | } 16 | -------------------------------------------------------------------------------- /pkg/limiter.go: -------------------------------------------------------------------------------- 1 | package pkg 2 | 3 | import ( 4 | "sync" 5 | "time" 6 | ) 7 | 8 | // RateLimiter 定义速率限制接口 9 | type RateLimiter interface { 10 | Allow(toolName string) bool 11 | } 12 | 13 | // TokenBucketLimiter 令牌桶限速器实现 14 | type TokenBucketLimiter struct { 15 | mu sync.RWMutex 16 | buckets map[string]*bucket 17 | defaultLimit Rate 18 | toolLimits map[string]Rate 19 | } 20 | 21 | // Rate 定义速率限制参数 22 | type Rate struct { 23 | Limit float64 // 每秒允许的请求数 24 | Burst int // 突发请求上限 25 | } 26 | 27 | // bucket 令牌桶 28 | type bucket struct { 29 | tokens float64 30 | lastTimestamp time.Time 31 | rate Rate 32 | } 33 | 34 | // NewTokenBucketLimiter 创建新的令牌桶限速器 35 | func NewTokenBucketLimiter(defaultRate Rate) *TokenBucketLimiter { 36 | return &TokenBucketLimiter{ 37 | buckets: make(map[string]*bucket), 38 | defaultLimit: defaultRate, 39 | toolLimits: make(map[string]Rate), 40 | } 41 | } 42 | 43 | // SetToolLimit 为特定工具设置限制 44 | func (l *TokenBucketLimiter) SetToolLimit(toolName 
string, rate Rate) { 45 | l.mu.Lock() 46 | defer l.mu.Unlock() 47 | 48 | l.toolLimits[toolName] = rate 49 | // 如果已有桶,更新其速率 50 | if b, exists := l.buckets[toolName]; exists { 51 | b.rate = rate 52 | } 53 | } 54 | 55 | // Allow 检查请求是否被允许 56 | func (l *TokenBucketLimiter) Allow(toolName string) bool { 57 | l.mu.RLock() 58 | defer l.mu.RUnlock() 59 | 60 | now := time.Now() 61 | 62 | // 获取或创建桶 63 | b, exists := l.buckets[toolName] 64 | if !exists { 65 | // 查找工具特定的限制,如果没有则使用默认限制 66 | rate, exists := l.toolLimits[toolName] 67 | if !exists { 68 | rate = l.defaultLimit 69 | } 70 | 71 | b = &bucket{ 72 | tokens: float64(rate.Burst), 73 | lastTimestamp: now, 74 | rate: rate, 75 | } 76 | l.buckets[toolName] = b 77 | } 78 | 79 | // 计算从上次请求到现在应该添加的令牌 80 | elapsed := now.Sub(b.lastTimestamp).Seconds() 81 | b.lastTimestamp = now 82 | 83 | // 添加令牌,但不超过最大值 84 | b.tokens += elapsed * b.rate.Limit 85 | if b.tokens > float64(b.rate.Burst) { 86 | b.tokens = float64(b.rate.Burst) 87 | } 88 | 89 | if b.tokens >= 1.0 { 90 | b.tokens -= 1.0 91 | return true 92 | } 93 | return false 94 | } 95 | -------------------------------------------------------------------------------- /pkg/log.go: -------------------------------------------------------------------------------- 1 | package pkg 2 | 3 | import "log" 4 | 5 | type Logger interface { 6 | Debugf(format string, a ...any) 7 | Infof(format string, a ...any) 8 | Warnf(format string, a ...any) 9 | Errorf(format string, a ...any) 10 | } 11 | type LogLevel uint32 12 | 13 | const ( 14 | LogLevelDebug = LogLevel(0) 15 | LogLevelInfo = LogLevel(1) 16 | LogLevelWarn = LogLevel(2) 17 | LogLevelError = LogLevel(3) 18 | ) 19 | 20 | var DefaultLogger Logger = &defaultLogger{ 21 | logLevel: LogLevelInfo, 22 | } 23 | 24 | var DebugLogger Logger = &defaultLogger{ 25 | logLevel: LogLevelDebug, 26 | } 27 | 28 | type defaultLogger struct { 29 | logLevel LogLevel 30 | } 31 | 32 | func (l *defaultLogger) Debugf(format string, a ...any) { 33 | if l.logLevel > 
LogLevelDebug { 34 | return 35 | } 36 | log.Printf("[Debug] "+format+"\n", a...) 37 | } 38 | 39 | func (l *defaultLogger) Infof(format string, a ...any) { 40 | if l.logLevel > LogLevelInfo { 41 | return 42 | } 43 | log.Printf("[Info] "+format+"\n", a...) 44 | } 45 | 46 | func (l *defaultLogger) Warnf(format string, a ...any) { 47 | if l.logLevel > LogLevelWarn { 48 | return 49 | } 50 | log.Printf("[Warn] "+format+"\n", a...) 51 | } 52 | 53 | func (l *defaultLogger) Errorf(format string, a ...any) { 54 | if l.logLevel > LogLevelError { 55 | return 56 | } 57 | log.Printf("[Error] "+format+"\n", a...) 58 | } 59 | -------------------------------------------------------------------------------- /pkg/sync_map.go: -------------------------------------------------------------------------------- 1 | package pkg 2 | 3 | import "sync" 4 | 5 | type SyncMap[V any] struct { 6 | m sync.Map 7 | } 8 | 9 | func (m *SyncMap[V]) Delete(key string) { 10 | m.m.Delete(key) 11 | } 12 | 13 | func (m *SyncMap[V]) Load(key string) (value V, ok bool) { 14 | v, ok := m.m.Load(key) 15 | if !ok { 16 | return value, ok 17 | } 18 | return v.(V), ok 19 | } 20 | 21 | func (m *SyncMap[V]) LoadAndDelete(key string) (value V, loaded bool) { 22 | v, loaded := m.m.LoadAndDelete(key) 23 | if !loaded { 24 | return value, loaded 25 | } 26 | return v.(V), loaded 27 | } 28 | 29 | func (m *SyncMap[V]) LoadOrStore(key string, value V) (actual V, loaded bool) { 30 | a, loaded := m.m.LoadOrStore(key, value) 31 | return a.(V), loaded 32 | } 33 | 34 | func (m *SyncMap[V]) Range(f func(key string, value V) bool) { 35 | m.m.Range(func(key, value any) bool { return f(key.(string), value.(V)) }) 36 | } 37 | 38 | func (m *SyncMap[V]) Store(key string, value V) { 39 | m.m.Store(key, value) 40 | } 41 | -------------------------------------------------------------------------------- /protocol/cancellation.go: -------------------------------------------------------------------------------- 1 | package protocol 2 | 3 | // 
CancelledNotification represents a notification that a request has been canceled 4 | type CancelledNotification struct { 5 | RequestID RequestID `json:"requestId"` 6 | Reason string `json:"reason,omitempty"` 7 | } 8 | 9 | // NewCancelledNotification creates a new canceled notification 10 | func NewCancelledNotification(requestID RequestID, reason string) *CancelledNotification { 11 | return &CancelledNotification{ 12 | RequestID: requestID, 13 | Reason: reason, 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /protocol/completion.go: -------------------------------------------------------------------------------- 1 | package protocol 2 | 3 | // CompleteRequest represents a request for completion options 4 | type CompleteRequest struct { 5 | Argument struct { 6 | Name string `json:"name"` 7 | Value string `json:"value"` 8 | } `json:"argument"` 9 | Ref interface{} `json:"ref"` // Can be PromptReference or ResourceReference 10 | } 11 | 12 | // Reference types 13 | type PromptReference struct { 14 | Type string `json:"type"` 15 | Name string `json:"name"` 16 | } 17 | 18 | type ResourceReference struct { 19 | Type string `json:"type"` 20 | URI string `json:"uri"` 21 | } 22 | 23 | // CompleteResult represents the response to a completion request 24 | type CompleteResult struct { 25 | Completion *Complete `json:"completion"` 26 | } 27 | 28 | type Complete struct { 29 | Values []string `json:"values"` 30 | HasMore bool `json:"hasMore,omitempty"` 31 | Total int `json:"total,omitempty"` 32 | } 33 | 34 | // NewCompleteRequest creates a new completion request 35 | func NewCompleteRequest(argName string, argValue string, ref interface{}) *CompleteRequest { 36 | return &CompleteRequest{ 37 | Argument: struct { 38 | Name string `json:"name"` 39 | Value string `json:"value"` 40 | }{ 41 | Name: argName, 42 | Value: argValue, 43 | }, 44 | Ref: ref, 45 | } 46 | } 47 | 48 | // NewCompleteResult creates a new completion response 49 | 
func NewCompleteResult(values []string, hasMore bool, total int) *CompleteResult { 50 | return &CompleteResult{ 51 | Completion: &Complete{ 52 | Values: values, 53 | HasMore: hasMore, 54 | Total: total, 55 | }, 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /protocol/initialize.go: -------------------------------------------------------------------------------- 1 | package protocol 2 | 3 | import ( 4 | "encoding/json" 5 | 6 | "github.com/tidwall/gjson" 7 | ) 8 | 9 | // InitializeRequest represents the initialize request sent from client to server 10 | type InitializeRequest struct { 11 | ClientInfo *Implementation `json:"clientInfo"` 12 | Capabilities *ClientCapabilities `json:"capabilities"` 13 | ProtocolVersion string `json:"protocolVersion"` 14 | } 15 | 16 | // InitializeResult represents the server's response to an initialize request 17 | type InitializeResult struct { 18 | ServerInfo *Implementation `json:"serverInfo"` 19 | Capabilities *ServerCapabilities `json:"capabilities"` 20 | ProtocolVersion string `json:"protocolVersion"` 21 | Instructions string `json:"instructions,omitempty"` 22 | } 23 | 24 | // Implementation describes the name and version of an MCP implementation 25 | type Implementation struct { 26 | Name string `json:"name"` 27 | Version string `json:"version"` 28 | } 29 | 30 | // ClientCapabilities capabilities 31 | type ClientCapabilities struct { 32 | // Experimental map[string]interface{} `json:"experimental,omitempty"` 33 | // Roots *RootsCapability `json:"roots,omitempty"` 34 | Sampling interface{} `json:"sampling,omitempty"` 35 | } 36 | 37 | type RootsCapability struct { 38 | ListChanged bool `json:"listChanged,omitempty"` 39 | } 40 | 41 | type ServerCapabilities struct { 42 | // Experimental map[string]interface{} `json:"experimental,omitempty"` 43 | // Logging interface{} `json:"logging,omitempty"` 44 | Prompts *PromptsCapability `json:"prompts,omitempty"` 45 | Resources 
*ResourcesCapability `json:"resources,omitempty"` 46 | Tools *ToolsCapability `json:"tools,omitempty"` 47 | } 48 | 49 | type PromptsCapability struct { 50 | ListChanged bool `json:"listChanged,omitempty"` 51 | } 52 | 53 | type ResourcesCapability struct { 54 | ListChanged bool `json:"listChanged,omitempty"` 55 | Subscribe bool `json:"subscribe,omitempty"` 56 | } 57 | 58 | type ToolsCapability struct { 59 | ListChanged bool `json:"listChanged,omitempty"` 60 | } 61 | 62 | // InitializedNotification represents the notification sent from client to server after initialization 63 | type InitializedNotification struct { 64 | Meta map[string]interface{} `json:"_meta,omitempty"` 65 | } 66 | 67 | // NewInitializeRequest creates a new initialize request 68 | func NewInitializeRequest(clientInfo *Implementation, capabilities *ClientCapabilities) *InitializeRequest { 69 | return &InitializeRequest{ 70 | ClientInfo: clientInfo, 71 | Capabilities: capabilities, 72 | ProtocolVersion: Version, 73 | } 74 | } 75 | 76 | // NewInitializeResult creates a new initialize response 77 | func NewInitializeResult(serverInfo *Implementation, capabilities *ServerCapabilities, version string, instructions string) *InitializeResult { 78 | return &InitializeResult{ 79 | ServerInfo: serverInfo, 80 | Capabilities: capabilities, 81 | ProtocolVersion: version, 82 | Instructions: instructions, 83 | } 84 | } 85 | 86 | // NewInitializedNotification creates a new initialized notification 87 | func NewInitializedNotification() *InitializedNotification { 88 | return &InitializedNotification{} 89 | } 90 | 91 | func IsInitializedRequest(rawParams json.RawMessage) bool { 92 | return gjson.ParseBytes(rawParams).Get("method").String() == string(Initialize) 93 | } 94 | -------------------------------------------------------------------------------- /protocol/jsonrpc.go: -------------------------------------------------------------------------------- 1 | package protocol 2 | 3 | import ( 4 | "encoding/json" 5 | 6 
| "github.com/ThinkInAIXYZ/go-mcp/pkg" 7 | ) 8 | 9 | const jsonrpcVersion = "2.0" 10 | 11 | // Standard JSON-RPC error codes 12 | const ( 13 | ParseError = -32700 // Invalid JSON 14 | InvalidRequest = -32600 // The JSON sent is not a valid Request object 15 | MethodNotFound = -32601 // The method does not exist / is not available 16 | InvalidParams = -32602 // Invalid method parameter(s) 17 | InternalError = -32603 // Internal JSON-RPC error 18 | 19 | // 可以定义自己的错误代码,范围在-32000 以上。 20 | ConnectionError = -32400 21 | ) 22 | 23 | type RequestID interface{} // 字符串/数值 24 | 25 | type JSONRPCRequest struct { 26 | JSONRPC string `json:"jsonrpc"` 27 | ID RequestID `json:"id"` 28 | Method Method `json:"method"` 29 | Params interface{} `json:"params,omitempty"` 30 | RawParams json.RawMessage `json:"-"` 31 | } 32 | 33 | func (r *JSONRPCRequest) UnmarshalJSON(data []byte) error { 34 | type alias JSONRPCRequest 35 | temp := &struct { 36 | Params json.RawMessage `json:"params,omitempty"` 37 | *alias 38 | }{ 39 | alias: (*alias)(r), 40 | } 41 | 42 | if err := pkg.JSONUnmarshal(data, temp); err != nil { 43 | return err 44 | } 45 | 46 | r.RawParams = temp.Params 47 | 48 | if len(r.RawParams) != 0 { 49 | if err := pkg.JSONUnmarshal(r.RawParams, &r.Params); err != nil { 50 | return err 51 | } 52 | } 53 | 54 | return nil 55 | } 56 | 57 | // IsValid checks if the request is valid according to JSON-RPC 2.0 spec 58 | func (r *JSONRPCRequest) IsValid() bool { 59 | return r.JSONRPC == jsonrpcVersion && r.Method != "" && r.ID != nil 60 | } 61 | 62 | // JSONRPCResponse represents a response to a request. 63 | type JSONRPCResponse struct { 64 | JSONRPC string `json:"jsonrpc"` 65 | ID RequestID `json:"id"` 66 | Result interface{} `json:"result,omitempty"` 67 | RawResult json.RawMessage `json:"-"` 68 | Error *responseErr `json:"error,omitempty"` 69 | } 70 | 71 | type responseErr struct { 72 | // The error type that occurred. 73 | Code int `json:"code"` 74 | // A short description of the error. 
The message SHOULD be limited 75 | // to a concise single sentence. 76 | Message string `json:"message"` 77 | // Additional information about the error. The value of this member 78 | // is defined by the sender (e.g. detailed error information, nested errors etc.). 79 | Data interface{} `json:"data,omitempty"` 80 | } 81 | 82 | func (r *JSONRPCResponse) UnmarshalJSON(data []byte) error { 83 | type alias JSONRPCResponse 84 | temp := &struct { 85 | Result json.RawMessage `json:"result,omitempty"` 86 | *alias 87 | }{ 88 | alias: (*alias)(r), 89 | } 90 | 91 | if err := pkg.JSONUnmarshal(data, temp); err != nil { 92 | return err 93 | } 94 | 95 | r.RawResult = temp.Result 96 | 97 | if len(r.RawResult) != 0 { 98 | if err := pkg.JSONUnmarshal(r.RawResult, &r.Result); err != nil { 99 | return err 100 | } 101 | } 102 | 103 | return nil 104 | } 105 | 106 | type JSONRPCNotification struct { 107 | JSONRPC string `json:"jsonrpc"` 108 | Method Method `json:"method"` 109 | Params interface{} `json:"params,omitempty"` 110 | RawParams json.RawMessage `json:"-"` 111 | } 112 | 113 | func (r *JSONRPCNotification) UnmarshalJSON(data []byte) error { 114 | type alias JSONRPCNotification 115 | temp := &struct { 116 | Params json.RawMessage `json:"params,omitempty"` 117 | *alias 118 | }{ 119 | alias: (*alias)(r), 120 | } 121 | 122 | if err := pkg.JSONUnmarshal(data, temp); err != nil { 123 | return err 124 | } 125 | 126 | r.RawParams = temp.Params 127 | 128 | if len(r.RawParams) != 0 { 129 | if err := pkg.JSONUnmarshal(r.RawParams, &r.Params); err != nil { 130 | return err 131 | } 132 | } 133 | 134 | return nil 135 | } 136 | 137 | // NewJSONRPCRequest creates a new JSON-RPC request 138 | func NewJSONRPCRequest(id RequestID, method Method, params interface{}) *JSONRPCRequest { 139 | return &JSONRPCRequest{ 140 | JSONRPC: jsonrpcVersion, 141 | ID: id, 142 | Method: method, 143 | Params: params, 144 | } 145 | } 146 | 147 | // NewJSONRPCSuccessResponse creates a new JSON-RPC response 148 | func 
NewJSONRPCSuccessResponse(id RequestID, result interface{}) *JSONRPCResponse { 149 | return &JSONRPCResponse{ 150 | JSONRPC: jsonrpcVersion, 151 | ID: id, 152 | Result: result, 153 | } 154 | } 155 | 156 | // NewJSONRPCErrorResponse NewError creates a new JSON-RPC error response 157 | func NewJSONRPCErrorResponse(id RequestID, code int, message string) *JSONRPCResponse { 158 | err := &JSONRPCResponse{ 159 | JSONRPC: jsonrpcVersion, 160 | ID: id, 161 | Error: &responseErr{ 162 | Code: code, 163 | Message: message, 164 | }, 165 | } 166 | return err 167 | } 168 | 169 | // NewJSONRPCNotification creates a new JSON-RPC notification 170 | func NewJSONRPCNotification(method Method, params interface{}) *JSONRPCNotification { 171 | return &JSONRPCNotification{ 172 | JSONRPC: jsonrpcVersion, 173 | Method: method, 174 | Params: params, 175 | } 176 | } 177 | -------------------------------------------------------------------------------- /protocol/logging.go: -------------------------------------------------------------------------------- 1 | package protocol 2 | 3 | // LoggingLevel represents the severity of a log message 4 | type LoggingLevel string 5 | 6 | const ( 7 | LogEmergency LoggingLevel = "emergency" 8 | LogAlert LoggingLevel = "alert" 9 | LogCritical LoggingLevel = "critical" 10 | LogError LoggingLevel = "error" 11 | LogWarning LoggingLevel = "warning" 12 | LogNotice LoggingLevel = "notice" 13 | LogInfo LoggingLevel = "info" 14 | LogDebug LoggingLevel = "debug" 15 | ) 16 | 17 | // SetLoggingLevelRequest represents a request to set the logging level 18 | type SetLoggingLevelRequest struct { 19 | Level LoggingLevel `json:"level"` 20 | } 21 | 22 | // SetLoggingLevelResult represents the response to a set logging level request 23 | type SetLoggingLevelResult struct { 24 | Success bool `json:"success"` 25 | } 26 | 27 | // LogMessageNotification represents a log message notification 28 | type LogMessageNotification struct { 29 | Level LoggingLevel `json:"level"` 30 | 
Message string `json:"message"` 31 | Meta map[string]interface{} `json:"meta,omitempty"` 32 | } 33 | 34 | // NewSetLoggingLevelRequest creates a new set logging level request 35 | func NewSetLoggingLevelRequest(level LoggingLevel) *SetLoggingLevelRequest { 36 | return &SetLoggingLevelRequest{ 37 | Level: level, 38 | } 39 | } 40 | 41 | // NewSetLoggingLevelResult creates a new set logging level response 42 | func NewSetLoggingLevelResult(success bool) *SetLoggingLevelResult { 43 | return &SetLoggingLevelResult{ 44 | Success: success, 45 | } 46 | } 47 | 48 | // NewLogMessageNotification creates a new log message notification 49 | func NewLogMessageNotification(level LoggingLevel, message string, meta map[string]interface{}) *LogMessageNotification { 50 | return &LogMessageNotification{ 51 | Level: level, 52 | Message: message, 53 | Meta: meta, 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /protocol/pagination.go: -------------------------------------------------------------------------------- 1 | package protocol 2 | 3 | import ( 4 | "encoding/base64" 5 | "sort" 6 | ) 7 | 8 | // Cursor is an opaque token used to represent a cursor for pagination. 
9 | type Cursor string 10 | 11 | type Named interface { 12 | GetName() string 13 | } 14 | 15 | func PaginationLimit[T Named](allElements []T, cursor Cursor, limit int) ([]T, Cursor, error) { 16 | sort.Slice(allElements, func(i, j int) bool { 17 | return allElements[i].GetName() < allElements[j].GetName() 18 | }) 19 | startPos := 0 20 | if cursor != "" { 21 | c, err := base64.StdEncoding.DecodeString(string(cursor)) 22 | if err != nil { 23 | return nil, "", err 24 | } 25 | cString := string(c) 26 | startPos = sort.Search(len(allElements), func(i int) bool { 27 | nc := allElements[i].GetName() 28 | return nc > cString 29 | }) 30 | } 31 | endPos := len(allElements) 32 | if len(allElements) > startPos+limit { 33 | endPos = startPos + limit 34 | } 35 | elementsToReturn := allElements[startPos:endPos] 36 | // set the next cursor 37 | nextCursor := func() Cursor { 38 | if len(elementsToReturn) < limit { 39 | return "" 40 | } 41 | element := elementsToReturn[len(elementsToReturn)-1] 42 | nc := element.GetName() 43 | toString := base64.StdEncoding.EncodeToString([]byte(nc)) 44 | return Cursor(toString) 45 | }() 46 | return elementsToReturn, nextCursor, nil 47 | } 48 | 49 | // PaginatedRequest represents a request that supports pagination 50 | type PaginatedRequest struct { 51 | Cursor Cursor `json:"cursor,omitempty"` 52 | } 53 | 54 | // PaginatedResult represents a response that supports pagination 55 | type PaginatedResult struct { 56 | NextCursor Cursor `json:"nextCursor,omitempty"` 57 | } 58 | -------------------------------------------------------------------------------- /protocol/pagination_test.go: -------------------------------------------------------------------------------- 1 | package protocol 2 | 3 | import ( 4 | "encoding/base64" 5 | "fmt" 6 | "reflect" 7 | "sort" 8 | "testing" 9 | ) 10 | 11 | func BenchmarkPaginationLimitForReflect(b *testing.B) { 12 | list := getTools(10000) 13 | for i := 0; i < b.N; i++ { 14 | _, _, _ = 
PaginationLimitForReflect[*Tool](list, "dG9vbDMz", 10) 15 | } 16 | } 17 | 18 | func BenchmarkPaginationLimitForTool(b *testing.B) { 19 | list := getTools(10000) 20 | for i := 0; i < b.N; i++ { 21 | _, _, _ = PaginationLimitForTool(list, "dG9vbDMz", 10) 22 | } 23 | } 24 | 25 | func BenchmarkPaginationLimit(b *testing.B) { 26 | list := getTools(10000) 27 | for i := 0; i < b.N; i++ { 28 | _, _, _ = PaginationLimit(list, "dG9vbDMz", 10) 29 | } 30 | } 31 | 32 | func getTools(length int) []*Tool { 33 | list := make([]*Tool, 0, 10000) 34 | for i := 0; i < length; i++ { 35 | list = append(list, &Tool{ 36 | Name: fmt.Sprintf("tool%d", i), 37 | Description: fmt.Sprintf("tool%d", i), 38 | }) 39 | } 40 | return list 41 | } 42 | 43 | func PaginationLimitForTool(allElements []*Tool, cursor Cursor, limit int) ([]*Tool, Cursor, error) { 44 | startPos := 0 45 | if cursor != "" { 46 | c, err := base64.StdEncoding.DecodeString(string(cursor)) 47 | if err != nil { 48 | return nil, "", err 49 | } 50 | cString := string(c) 51 | startPos = sort.Search(len(allElements), func(i int) bool { 52 | nc := allElements[i].Name 53 | return nc > cString 54 | }) 55 | } 56 | endPos := len(allElements) 57 | if len(allElements) > startPos+limit { 58 | endPos = startPos + limit 59 | } 60 | elementsToReturn := allElements[startPos:endPos] 61 | // set the next cursor 62 | nextCursor := func() Cursor { 63 | if len(elementsToReturn) < limit { 64 | return "" 65 | } 66 | element := elementsToReturn[len(elementsToReturn)-1] 67 | nc := element.Name 68 | toString := base64.StdEncoding.EncodeToString([]byte(nc)) 69 | return Cursor(toString) 70 | }() 71 | return elementsToReturn, nextCursor, nil 72 | } 73 | 74 | func PaginationLimitForReflect[T any](allElements []T, cursor Cursor, limit int) ([]T, Cursor, error) { 75 | startPos := 0 76 | if cursor != "" { 77 | c, err := base64.StdEncoding.DecodeString(string(cursor)) 78 | if err != nil { 79 | return nil, "", err 80 | } 81 | cString := string(c) 82 | startPos = 
sort.Search(len(allElements), func(i int) bool { 83 | val := reflect.ValueOf(allElements[i]) 84 | var nc string 85 | if val.Kind() == reflect.Ptr { 86 | val = val.Elem() 87 | } 88 | nc = val.FieldByName("Name").String() 89 | return nc > cString 90 | }) 91 | } 92 | endPos := len(allElements) 93 | if len(allElements) > startPos+limit { 94 | endPos = startPos + limit 95 | } 96 | elementsToReturn := allElements[startPos:endPos] 97 | // set the next cursor 98 | nextCursor := func() Cursor { 99 | if len(elementsToReturn) < limit { 100 | return "" 101 | } 102 | element := elementsToReturn[len(elementsToReturn)-1] 103 | val := reflect.ValueOf(element) 104 | var nc string 105 | if val.Kind() == reflect.Ptr { 106 | val = val.Elem() 107 | } 108 | nc = val.FieldByName("Name").String() 109 | toString := base64.StdEncoding.EncodeToString([]byte(nc)) 110 | return Cursor(toString) 111 | }() 112 | return elementsToReturn, nextCursor, nil 113 | } 114 | -------------------------------------------------------------------------------- /protocol/ping.go: -------------------------------------------------------------------------------- 1 | package protocol 2 | 3 | type PingRequest struct{} 4 | 5 | type PingResult struct{} 6 | 7 | // NewPingRequest creates a new ping request 8 | func NewPingRequest() *PingRequest { 9 | return &PingRequest{} 10 | } 11 | 12 | // NewPingResult creates a new ping response 13 | func NewPingResult() *PingResult { 14 | return &PingResult{} 15 | } 16 | -------------------------------------------------------------------------------- /protocol/progress.go: -------------------------------------------------------------------------------- 1 | package protocol 2 | 3 | const ProgressTokenKey = "progressToken" 4 | 5 | // ProgressNotification represents a progress notification for a long-running request 6 | type ProgressNotification struct { 7 | ProgressToken ProgressToken `json:"progressToken"` 8 | Progress float64 `json:"progress"` 9 | Total float64 
`json:"total,omitempty"` 10 | Message string `json:"message,omitempty"` 11 | } 12 | 13 | // ProgressToken represents a token used to associate progress notifications with the original request 14 | type ProgressToken interface{} // can be string or integer 15 | 16 | // NewProgressNotification creates a new progress notification 17 | func NewProgressNotification(progress float64, total float64, message string) *ProgressNotification { 18 | return &ProgressNotification{ 19 | Progress: progress, 20 | Total: total, 21 | Message: message, 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /protocol/prompts.go: -------------------------------------------------------------------------------- 1 | package protocol 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | 7 | "github.com/ThinkInAIXYZ/go-mcp/pkg" 8 | ) 9 | 10 | // ListPromptsRequest represents a request to list available prompts 11 | type ListPromptsRequest struct { 12 | Cursor Cursor `json:"cursor,omitempty"` 13 | } 14 | 15 | // ListPromptsResult represents the response to a list prompts request 16 | type ListPromptsResult struct { 17 | Prompts []*Prompt `json:"prompts"` 18 | NextCursor Cursor `json:"nextCursor,omitempty"` 19 | } 20 | 21 | // Prompt related types 22 | type Prompt struct { 23 | Name string `json:"name"` 24 | Description string `json:"description,omitempty"` 25 | Arguments []*PromptArgument `json:"arguments,omitempty"` 26 | } 27 | 28 | func (p *Prompt) GetName() string { 29 | return p.Name 30 | } 31 | 32 | type PromptArgument struct { 33 | Name string `json:"name"` 34 | Description string `json:"description,omitempty"` 35 | Required bool `json:"required,omitempty"` 36 | } 37 | 38 | // GetPromptRequest represents a request to get a specific prompt 39 | type GetPromptRequest struct { 40 | Name string `json:"name"` 41 | Arguments map[string]string `json:"arguments,omitempty"` 42 | } 43 | 44 | // GetPromptResult represents the response to a get prompt 
request 45 | type GetPromptResult struct { 46 | Messages []*PromptMessage `json:"messages"` 47 | Description string `json:"description,omitempty"` 48 | } 49 | 50 | type PromptMessage struct { 51 | Role Role `json:"role"` 52 | Content Content `json:"content"` 53 | } 54 | 55 | // UnmarshalJSON implements the json.Unmarshaler interface for PromptMessage 56 | func (m *PromptMessage) UnmarshalJSON(data []byte) error { 57 | type Alias PromptMessage 58 | aux := &struct { 59 | Content json.RawMessage `json:"content"` 60 | *Alias 61 | }{ 62 | Alias: (*Alias)(m), 63 | } 64 | if err := pkg.JSONUnmarshal(data, &aux); err != nil { 65 | return err 66 | } 67 | 68 | // Try to unmarshal content as TextContent first 69 | var textContent *TextContent 70 | if err := pkg.JSONUnmarshal(aux.Content, &textContent); err == nil { 71 | m.Content = textContent 72 | return nil 73 | } 74 | 75 | // Try to unmarshal content as ImageContent 76 | var imageContent *ImageContent 77 | if err := pkg.JSONUnmarshal(aux.Content, &imageContent); err == nil { 78 | m.Content = imageContent 79 | return nil 80 | } 81 | 82 | // Try to unmarshal content as AudioContent 83 | var audioContent *AudioContent 84 | if err := pkg.JSONUnmarshal(aux.Content, &audioContent); err == nil { 85 | m.Content = audioContent 86 | return nil 87 | } 88 | 89 | // Try to unmarshal content as embeddedResource 90 | var embeddedResource *EmbeddedResource 91 | if err := pkg.JSONUnmarshal(aux.Content, &embeddedResource); err == nil { 92 | m.Content = embeddedResource 93 | return nil 94 | } 95 | 96 | return fmt.Errorf("unknown content type") 97 | } 98 | 99 | // PromptListChangedNotification represents a notification that the prompt list has changed 100 | type PromptListChangedNotification struct { 101 | Meta map[string]interface{} `json:"_meta,omitempty"` 102 | } 103 | 104 | // NewListPromptsRequest creates a new list prompts request 105 | func NewListPromptsRequest() *ListPromptsRequest { 106 | return &ListPromptsRequest{} 107 | } 108 | 
109 | // NewListPromptsResult creates a new list prompts response 110 | func NewListPromptsResult(prompts []*Prompt, nextCursor Cursor) *ListPromptsResult { 111 | return &ListPromptsResult{ 112 | Prompts: prompts, 113 | NextCursor: nextCursor, 114 | } 115 | } 116 | 117 | // NewGetPromptRequest creates a new get prompt request 118 | func NewGetPromptRequest(name string, arguments map[string]string) *GetPromptRequest { 119 | return &GetPromptRequest{ 120 | Name: name, 121 | Arguments: arguments, 122 | } 123 | } 124 | 125 | // NewGetPromptResult creates a new get prompt response 126 | func NewGetPromptResult(messages []*PromptMessage, description string) *GetPromptResult { 127 | return &GetPromptResult{ 128 | Messages: messages, 129 | Description: description, 130 | } 131 | } 132 | 133 | // NewPromptListChangedNotification creates a new prompt list changed notification 134 | func NewPromptListChangedNotification() *PromptListChangedNotification { 135 | return &PromptListChangedNotification{} 136 | } 137 | -------------------------------------------------------------------------------- /protocol/roots.go: -------------------------------------------------------------------------------- 1 | package protocol 2 | 3 | // ListRootsRequest represents a request to list root directories 4 | type ListRootsRequest struct{} 5 | 6 | // ListRootsResult represents the response to a list roots request 7 | type ListRootsResult struct { 8 | Roots []*Root `json:"roots"` 9 | } 10 | 11 | // Root represents a root directory or file that the server can operate on 12 | type Root struct { 13 | Name string `json:"name,omitempty"` 14 | URI string `json:"uri"` 15 | } 16 | 17 | // RootsListChangedNotification represents a notification that the roots list has changed 18 | type RootsListChangedNotification struct { 19 | Meta map[string]interface{} `json:"_meta,omitempty"` 20 | } 21 | 22 | // NewListRootsRequest creates a new list roots request 23 | func NewListRootsRequest() *ListRootsRequest { 24 | 
return &ListRootsRequest{} 25 | } 26 | 27 | // NewListRootsResult creates a new list roots response 28 | func NewListRootsResult(roots []*Root) *ListRootsResult { 29 | return &ListRootsResult{ 30 | Roots: roots, 31 | } 32 | } 33 | 34 | // NewRootsListChangedNotification creates a new roots list changed notification 35 | func NewRootsListChangedNotification() *RootsListChangedNotification { 36 | return &RootsListChangedNotification{} 37 | } 38 | -------------------------------------------------------------------------------- /protocol/sampling.go: -------------------------------------------------------------------------------- 1 | package protocol 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | 7 | "github.com/ThinkInAIXYZ/go-mcp/pkg" 8 | ) 9 | 10 | // CreateMessageRequest represents a request to create a message through sampling 11 | type CreateMessageRequest struct { 12 | Messages []*SamplingMessage `json:"messages"` 13 | MaxTokens int `json:"maxTokens"` 14 | Temperature float64 `json:"temperature,omitempty"` 15 | StopSequences []string `json:"stopSequences,omitempty"` 16 | SystemPrompt string `json:"systemPrompt,omitempty"` 17 | ModelPreferences *ModelPreferences `json:"modelPreferences,omitempty"` 18 | IncludeContext string `json:"includeContext,omitempty"` 19 | Metadata map[string]interface{} `json:"metadata,omitempty"` 20 | } 21 | 22 | type SamplingMessage struct { 23 | Role Role `json:"role"` 24 | Content Content `json:"content"` 25 | } 26 | 27 | // UnmarshalJSON implements the json.Unmarshaler interface for SamplingMessage 28 | func (r *SamplingMessage) UnmarshalJSON(data []byte) error { 29 | type Alias SamplingMessage 30 | aux := &struct { 31 | Content json.RawMessage `json:"content"` 32 | *Alias 33 | }{ 34 | Alias: (*Alias)(r), 35 | } 36 | if err := pkg.JSONUnmarshal(data, &aux); err != nil { 37 | return err 38 | } 39 | 40 | // Try to unmarshal content as TextContent first 41 | var textContent *TextContent 42 | if err := pkg.JSONUnmarshal(aux.Content, 
&textContent); err == nil { 43 | r.Content = textContent 44 | return nil 45 | } 46 | 47 | // Try to unmarshal content as ImageContent 48 | var imageContent *ImageContent 49 | if err := pkg.JSONUnmarshal(aux.Content, &imageContent); err == nil { 50 | r.Content = imageContent 51 | return nil 52 | } 53 | 54 | // Try to unmarshal content as AudioContent 55 | var audioContent *AudioContent 56 | if err := pkg.JSONUnmarshal(aux.Content, &audioContent); err == nil { 57 | r.Content = audioContent 58 | return nil 59 | } 60 | 61 | return fmt.Errorf("unknown content type, content=%s", aux.Content) 62 | } 63 | 64 | // CreateMessageResult represents the response to a create message request 65 | type CreateMessageResult struct { 66 | Content Content `json:"content"` 67 | Role Role `json:"role"` 68 | Model string `json:"model"` 69 | StopReason string `json:"stopReason,omitempty"` 70 | } 71 | 72 | // UnmarshalJSON implements the json.Unmarshaler interface for CreateMessageResult 73 | func (r *CreateMessageResult) UnmarshalJSON(data []byte) error { 74 | type Alias CreateMessageResult 75 | aux := &struct { 76 | Content json.RawMessage `json:"content"` 77 | *Alias 78 | }{ 79 | Alias: (*Alias)(r), 80 | } 81 | if err := pkg.JSONUnmarshal(data, &aux); err != nil { 82 | return err 83 | } 84 | 85 | // Try to unmarshal content as TextContent first 86 | var textContent *TextContent 87 | if err := pkg.JSONUnmarshal(aux.Content, &textContent); err == nil { 88 | r.Content = textContent 89 | return nil 90 | } 91 | 92 | // Try to unmarshal content as ImageContent 93 | var imageContent *ImageContent 94 | if err := pkg.JSONUnmarshal(aux.Content, &imageContent); err == nil { 95 | r.Content = imageContent 96 | return nil 97 | } 98 | 99 | // Try to unmarshal content as AudioContent 100 | var audioContent *AudioContent 101 | if err := pkg.JSONUnmarshal(aux.Content, &audioContent); err == nil { 102 | r.Content = audioContent 103 | return nil 104 | } 105 | 106 | return fmt.Errorf("unknown content type, 
content=%s", aux.Content) 107 | } 108 | 109 | // NewCreateMessageRequest creates a new create message request 110 | func NewCreateMessageRequest(messages []*SamplingMessage, maxTokens int, opts ...CreateMessageOption) *CreateMessageRequest { 111 | req := &CreateMessageRequest{ 112 | Messages: messages, 113 | MaxTokens: maxTokens, 114 | } 115 | 116 | for _, opt := range opts { 117 | opt(req) 118 | } 119 | 120 | return req 121 | } 122 | 123 | // NewCreateMessageResult creates a new create message response 124 | func NewCreateMessageResult(content Content, role Role, model string, stopReason string) *CreateMessageResult { 125 | return &CreateMessageResult{ 126 | Content: content, 127 | Role: role, 128 | Model: model, 129 | StopReason: stopReason, 130 | } 131 | } 132 | 133 | // CreateMessageOption represents an option for creating a message 134 | type CreateMessageOption func(*CreateMessageRequest) 135 | 136 | // WithTemperature sets the temperature for the request 137 | func WithTemperature(temp float64) CreateMessageOption { 138 | return func(r *CreateMessageRequest) { 139 | r.Temperature = temp 140 | } 141 | } 142 | 143 | // WithStopSequences sets the stop sequences for the request 144 | func WithStopSequences(sequences []string) CreateMessageOption { 145 | return func(r *CreateMessageRequest) { 146 | r.StopSequences = sequences 147 | } 148 | } 149 | 150 | // WithSystemPrompt sets the system prompt for the request 151 | func WithSystemPrompt(prompt string) CreateMessageOption { 152 | return func(r *CreateMessageRequest) { 153 | r.SystemPrompt = prompt 154 | } 155 | } 156 | 157 | // WithModelPreferences sets the model preferences for the request 158 | func WithModelPreferences(prefs *ModelPreferences) CreateMessageOption { 159 | return func(r *CreateMessageRequest) { 160 | r.ModelPreferences = prefs 161 | } 162 | } 163 | 164 | // WithIncludeContext sets the include context option for the request 165 | func WithIncludeContext(ctx string) CreateMessageOption { 166 | 
return func(r *CreateMessageRequest) { 167 | r.IncludeContext = ctx 168 | } 169 | } 170 | 171 | // WithMetadata sets the metadata for the request 172 | func WithMetadata(metadata map[string]interface{}) CreateMessageOption { 173 | return func(r *CreateMessageRequest) { 174 | r.Metadata = metadata 175 | } 176 | } 177 | -------------------------------------------------------------------------------- /protocol/schema_generate.go: -------------------------------------------------------------------------------- 1 | package protocol 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "strconv" 7 | "strings" 8 | 9 | "github.com/ThinkInAIXYZ/go-mcp/pkg" 10 | ) 11 | 12 | type DataType string 13 | 14 | const ( 15 | ObjectT DataType = "object" 16 | Number DataType = "number" 17 | Integer DataType = "integer" 18 | String DataType = "string" 19 | Array DataType = "array" 20 | Null DataType = "null" 21 | Boolean DataType = "boolean" 22 | ) 23 | 24 | type Property struct { 25 | Type DataType `json:"type"` 26 | // Description is the description of the schema. 27 | Description string `json:"description,omitempty"` 28 | // Items specifies which data type an array contains, if the schema type is Array. 29 | Items *Property `json:"items,omitempty"` 30 | // Properties describes the properties of an object, if the schema type is Object. 
31 | Properties map[string]*Property `json:"properties,omitempty"` 32 | Required []string `json:"required,omitempty"` 33 | Enum []string `json:"enum,omitempty"` 34 | } 35 | 36 | var schemaCache = pkg.SyncMap[*InputSchema]{} 37 | 38 | func generateSchemaFromReqStruct(v any) (*InputSchema, error) { 39 | t := reflect.TypeOf(v) 40 | for t.Kind() != reflect.Struct { 41 | if t.Kind() != reflect.Ptr { 42 | return nil, fmt.Errorf("invalid type %v", t) 43 | } 44 | t = t.Elem() 45 | } 46 | 47 | typeUID := getTypeUUID(t) 48 | if schema, ok := schemaCache.Load(typeUID); ok { 49 | return schema, nil 50 | } 51 | 52 | schema := &InputSchema{Type: Object} 53 | 54 | property, err := reflectSchemaByObject(t) 55 | if err != nil { 56 | return nil, err 57 | } 58 | 59 | schema.Properties = property.Properties 60 | schema.Required = property.Required 61 | 62 | schemaCache.Store(typeUID, schema) 63 | return schema, nil 64 | } 65 | 66 | func getTypeUUID(t reflect.Type) string { 67 | if t.PkgPath() != "" && t.Name() != "" { 68 | return t.PkgPath() + "." 
+ t.Name() 69 | } 70 | // fallback for unnamed types (like anonymous struct) 71 | return t.String() 72 | } 73 | 74 | func reflectSchemaByObject(t reflect.Type) (*Property, error) { 75 | var ( 76 | properties = make(map[string]*Property) 77 | requiredFields = make([]string, 0) 78 | anonymousFields = make([]reflect.StructField, 0) 79 | ) 80 | 81 | for i := 0; i < t.NumField(); i++ { 82 | field := t.Field(i) 83 | 84 | if field.Anonymous { 85 | anonymousFields = append(anonymousFields, field) 86 | continue 87 | } 88 | 89 | if !field.IsExported() { 90 | continue 91 | } 92 | 93 | jsonTag := field.Tag.Get("json") 94 | if jsonTag == "-" { 95 | continue 96 | } 97 | required := true 98 | if jsonTag == "" { 99 | jsonTag = field.Name 100 | } 101 | if strings.HasSuffix(jsonTag, ",omitempty") { 102 | jsonTag = strings.TrimSuffix(jsonTag, ",omitempty") 103 | required = false 104 | } 105 | 106 | item, err := reflectSchemaByType(field.Type) 107 | if err != nil { 108 | return nil, err 109 | } 110 | 111 | if description := field.Tag.Get("description"); description != "" { 112 | item.Description = description 113 | } 114 | properties[jsonTag] = item 115 | 116 | if s := field.Tag.Get("required"); s != "" { 117 | required, err = strconv.ParseBool(s) 118 | if err != nil { 119 | return nil, fmt.Errorf("invalid required field %v: %v", jsonTag, err) 120 | } 121 | } 122 | if required { 123 | requiredFields = append(requiredFields, jsonTag) 124 | } 125 | 126 | if v := field.Tag.Get("enum"); v != "" { 127 | enumValues := strings.Split(v, ",") 128 | for j, value := range enumValues { 129 | enumValues[j] = strings.TrimSpace(value) 130 | } 131 | 132 | // Check if enum values are consistent with the field type 133 | for _, value := range enumValues { 134 | switch field.Type.Kind() { 135 | case reflect.String: 136 | // No additional processing required for string type 137 | case reflect.Int, reflect.Int64: 138 | if _, err := strconv.Atoi(value); err != nil { 139 | return nil, fmt.Errorf("enum value 
%q is not compatible with type %v", value, field.Type) 140 | } 141 | case reflect.Float64: 142 | if _, err := strconv.ParseFloat(value, 64); err != nil { 143 | return nil, fmt.Errorf("enum value %q is not compatible with type %v", value, field.Type) 144 | } 145 | default: 146 | return nil, fmt.Errorf("unsupported type %v for enum validation", field.Type) 147 | } 148 | } 149 | item.Enum = enumValues 150 | } 151 | } 152 | 153 | for _, field := range anonymousFields { 154 | object, err := reflectSchemaByObject(field.Type) 155 | if err != nil { 156 | return nil, err 157 | } 158 | for propName, propValue := range object.Properties { 159 | if _, ok := properties[propName]; ok { 160 | return nil, fmt.Errorf("duplicate property name %s in anonymous struct", propName) 161 | } 162 | properties[propName] = propValue 163 | } 164 | requiredFields = append(requiredFields, object.Required...) 165 | } 166 | 167 | property := &Property{ 168 | Type: ObjectT, 169 | Properties: properties, 170 | Required: requiredFields, 171 | } 172 | return property, nil 173 | } 174 | 175 | func reflectSchemaByType(t reflect.Type) (*Property, error) { 176 | s := &Property{} 177 | 178 | switch t.Kind() { 179 | case reflect.String: 180 | s.Type = String 181 | case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, 182 | reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: 183 | s.Type = Integer 184 | case reflect.Float32, reflect.Float64: 185 | s.Type = Number 186 | case reflect.Bool: 187 | s.Type = Boolean 188 | case reflect.Slice, reflect.Array: 189 | s.Type = Array 190 | items, err := reflectSchemaByType(t.Elem()) 191 | if err != nil { 192 | return nil, err 193 | } 194 | s.Items = items 195 | case reflect.Struct: 196 | object, err := reflectSchemaByObject(t) 197 | if err != nil { 198 | return nil, err 199 | } 200 | object.Type = ObjectT 201 | s = object 202 | case reflect.Map: 203 | if t.Key().Kind() != reflect.String { 204 | return nil, fmt.Errorf("map 
key type %s is not supported", t.Key().Kind()) 205 | } 206 | object := &Property{ 207 | Type: ObjectT, 208 | } 209 | s = object 210 | case reflect.Ptr: 211 | p, err := reflectSchemaByType(t.Elem()) 212 | if err != nil { 213 | return nil, err 214 | } 215 | s = p 216 | case reflect.Invalid, reflect.Uintptr, reflect.Complex64, reflect.Complex128, 217 | reflect.Chan, reflect.Func, reflect.Interface, 218 | reflect.UnsafePointer: 219 | return nil, fmt.Errorf("unsupported type: %s", t.Kind().String()) 220 | default: 221 | } 222 | return s, nil 223 | } 224 | -------------------------------------------------------------------------------- /protocol/schema_validate.go: -------------------------------------------------------------------------------- 1 | package protocol 2 | 3 | import ( 4 | "encoding/json" 5 | "errors" 6 | "fmt" 7 | "reflect" 8 | "strconv" 9 | 10 | "github.com/ThinkInAIXYZ/go-mcp/pkg" 11 | ) 12 | 13 | func VerifyAndUnmarshal(content json.RawMessage, v any) error { 14 | if len(content) == 0 { 15 | return fmt.Errorf("request arguments is empty") 16 | } 17 | 18 | t := reflect.TypeOf(v) 19 | for t.Kind() != reflect.Struct { 20 | if t.Kind() != reflect.Ptr { 21 | return fmt.Errorf("invalid type %v, plz use func `pkg.JSONUnmarshal` instead", t) 22 | } 23 | t = t.Elem() 24 | } 25 | 26 | typeUID := getTypeUUID(t) 27 | schema, ok := schemaCache.Load(typeUID) 28 | if !ok { 29 | return fmt.Errorf("schema has not been generated,unable to verify: plz use func `pkg.JSONUnmarshal` instead") 30 | } 31 | 32 | return verifySchemaAndUnmarshal(Property{ 33 | Type: ObjectT, 34 | Properties: schema.Properties, 35 | Required: schema.Required, 36 | }, content, v) 37 | } 38 | 39 | func verifySchemaAndUnmarshal(schema Property, content []byte, v any) error { 40 | var data any 41 | err := pkg.JSONUnmarshal(content, &data) 42 | if err != nil { 43 | return err 44 | } 45 | if !validate(schema, data) { 46 | return errors.New("data validation failed against the provided schema") 47 | } 48 | 
return pkg.JSONUnmarshal(content, &v) 49 | } 50 | 51 | func validate(schema Property, data any) bool { 52 | switch schema.Type { 53 | case ObjectT: 54 | return validateObject(schema, data) 55 | case Array: 56 | return validateArray(schema, data) 57 | case String: 58 | str, ok := data.(string) 59 | if ok { 60 | return validateEnumProperty[string](str, schema.Enum, func(value string, enumValue string) bool { 61 | return value == enumValue 62 | }) 63 | } 64 | return false 65 | case Number: // float64 and int 66 | if num, ok := data.(float64); ok { 67 | return validateEnumProperty[float64](num, schema.Enum, func(value float64, enumValue string) bool { 68 | if enumNum, err := strconv.ParseFloat(enumValue, 64); err == nil && value == enumNum { 69 | return true 70 | } 71 | return false 72 | }) 73 | } 74 | if num, ok := data.(int); ok { 75 | return validateEnumProperty[int](num, schema.Enum, func(value int, enumValue string) bool { 76 | if enumNum, err := strconv.Atoi(enumValue); err == nil && value == enumNum { 77 | return true 78 | } 79 | return false 80 | }) 81 | } 82 | return false 83 | case Boolean: 84 | _, ok := data.(bool) 85 | return ok 86 | case Integer: 87 | // Golang unmarshals all numbers as float64, so we need to check if the float64 is an integer 88 | if num, ok := data.(float64); ok { 89 | if num == float64(int64(num)) { 90 | return validateEnumProperty[float64](num, schema.Enum, func(value float64, enumValue string) bool { 91 | if enumNum, err := strconv.ParseFloat(enumValue, 64); err == nil && value == enumNum { 92 | return true 93 | } 94 | return false 95 | }) 96 | } 97 | return false 98 | } 99 | 100 | if num, ok := data.(int); ok { 101 | return validateEnumProperty[int](num, schema.Enum, func(value int, enumValue string) bool { 102 | if enumNum, err := strconv.Atoi(enumValue); err == nil && value == enumNum { 103 | return true 104 | } 105 | return false 106 | }) 107 | } 108 | 109 | if num, ok := data.(int64); ok { 110 | return 
validateEnumProperty[int64](num, schema.Enum, func(value int64, enumValue string) bool { 111 | if enumNum, err := strconv.Atoi(enumValue); err == nil && value == int64(enumNum) { 112 | return true 113 | } 114 | return false 115 | }) 116 | } 117 | return false 118 | case Null: 119 | return data == nil 120 | default: 121 | return false 122 | } 123 | } 124 | 125 | func validateObject(schema Property, data any) bool { 126 | dataMap, ok := data.(map[string]any) 127 | if !ok { 128 | return false 129 | } 130 | for _, field := range schema.Required { 131 | if _, exists := dataMap[field]; !exists { 132 | return false 133 | } 134 | } 135 | for key, valueSchema := range schema.Properties { 136 | value, exists := dataMap[key] 137 | if exists && !validate(*valueSchema, value) { 138 | return false 139 | } 140 | } 141 | return true 142 | } 143 | 144 | func validateArray(schema Property, data any) bool { 145 | dataArray, ok := data.([]any) 146 | if !ok { 147 | return false 148 | } 149 | for _, item := range dataArray { 150 | if !validate(*schema.Items, item) { 151 | return false 152 | } 153 | } 154 | return true 155 | } 156 | 157 | func validateEnumProperty[T any](data T, enum []string, compareFunc func(T, string) bool) bool { 158 | for _, enumValue := range enum { 159 | if compareFunc(data, enumValue) { 160 | return true 161 | } 162 | } 163 | return len(enum) == 0 164 | } 165 | -------------------------------------------------------------------------------- /protocol/types.go: -------------------------------------------------------------------------------- 1 | package protocol 2 | 3 | const Version = "2025-03-26" 4 | 5 | var SupportedVersion = map[string]struct{}{ 6 | "2024-11-05": {}, 7 | "2025-03-26": {}, 8 | } 9 | 10 | // Method represents the JSON-RPC method name 11 | type Method string 12 | 13 | const ( 14 | // Core methods 15 | Ping Method = "ping" 16 | Initialize Method = "initialize" 17 | NotificationInitialized Method = "notifications/initialized" 18 | 19 | // Root 
related methods 20 | RootsList Method = "roots/list" 21 | NotificationRootsListChanged Method = "notifications/roots/list_changed" 22 | 23 | // Resource related methods 24 | ResourcesList Method = "resources/list" 25 | ResourceListTemplates Method = "resources/templates/list" 26 | ResourcesRead Method = "resources/read" 27 | ResourcesSubscribe Method = "resources/subscribe" 28 | ResourcesUnsubscribe Method = "resources/unsubscribe" 29 | NotificationResourcesListChanged Method = "notifications/resources/list_changed" 30 | NotificationResourcesUpdated Method = "notifications/resources/updated" 31 | 32 | // Tool related methods 33 | ToolsList Method = "tools/list" 34 | ToolsCall Method = "tools/call" 35 | NotificationToolsListChanged Method = "notifications/tools/list_changed" 36 | 37 | // Prompt related methods 38 | PromptsList Method = "prompts/list" 39 | PromptsGet Method = "prompts/get" 40 | NotificationPromptsListChanged Method = "notifications/prompts/list_changed" 41 | 42 | // Sampling related methods 43 | SamplingCreateMessage Method = "sampling/createMessage" 44 | 45 | // Logging related methods 46 | LoggingSetLevel Method = "logging/setLevel" 47 | NotificationLogMessage Method = "notifications/message" 48 | 49 | // Completion related methods 50 | CompletionComplete Method = "completion/complete" 51 | 52 | // progress related methods 53 | NotificationProgress Method = "notifications/progress" 54 | NotificationCancelled Method = "notifications/cancelled" // nolint:misspell 55 | ) 56 | 57 | // Role represents the sender or recipient of messages and data in a conversation 58 | type Role string 59 | 60 | const ( 61 | RoleUser Role = "user" 62 | RoleAssistant Role = "assistant" 63 | ) 64 | 65 | type ClientRequest interface{} 66 | 67 | var ( 68 | _ ClientRequest = &InitializeRequest{} 69 | _ ClientRequest = &PingRequest{} 70 | _ ClientRequest = &ListPromptsRequest{} 71 | _ ClientRequest = &GetPromptRequest{} 72 | _ ClientRequest = &ListResourcesRequest{} 73 | _ 
ClientRequest = &ReadResourceRequest{} 74 | _ ClientRequest = &ListResourceTemplatesRequest{} 75 | _ ClientRequest = &SubscribeRequest{} 76 | _ ClientRequest = &UnsubscribeRequest{} 77 | _ ClientRequest = &ListToolsRequest{} 78 | _ ClientRequest = &CallToolRequest{} 79 | _ ClientRequest = &CompleteRequest{} 80 | _ ClientRequest = &SetLoggingLevelRequest{} 81 | ) 82 | 83 | type ClientResponse interface{} 84 | 85 | var ( 86 | _ ClientResponse = &PingResult{} 87 | _ ClientResponse = &ListToolsResult{} 88 | _ ClientResponse = &CreateMessageResult{} 89 | ) 90 | 91 | type ClientNotify interface{} 92 | 93 | var ( 94 | _ ClientNotify = &InitializedNotification{} 95 | _ ClientNotify = &CancelledNotification{} 96 | _ ClientNotify = &ProgressNotification{} 97 | _ ClientNotify = &RootsListChangedNotification{} 98 | ) 99 | 100 | type ServerRequest interface{} 101 | 102 | var ( 103 | _ ServerRequest = &PingRequest{} 104 | _ ServerRequest = &ListRootsRequest{} 105 | _ ServerRequest = &CreateMessageRequest{} 106 | ) 107 | 108 | type ServerResponse interface{} 109 | 110 | var ( 111 | _ ServerResponse = &InitializeResult{} 112 | _ ServerResponse = &PingResult{} 113 | _ ServerResponse = &ListPromptsResult{} 114 | _ ServerResponse = &GetPromptResult{} 115 | _ ServerResponse = &ListResourcesResult{} 116 | _ ServerResponse = &ReadResourceResult{} 117 | _ ServerResponse = &ListResourceTemplatesResult{} 118 | _ ServerResponse = &SubscribeResult{} 119 | _ ServerResponse = &UnsubscribeResult{} 120 | _ ServerResponse = &ListToolsResult{} 121 | _ ServerResponse = &CallToolResult{} 122 | _ ServerResponse = &CompleteResult{} 123 | _ ServerResponse = &SetLoggingLevelResult{} 124 | ) 125 | 126 | type ServerNotify interface{} 127 | 128 | var ( 129 | _ ServerNotify = &CancelledNotification{} 130 | _ ServerNotify = &ProgressNotification{} 131 | _ ServerNotify = &ToolListChangedNotification{} 132 | _ ServerNotify = &PromptListChangedNotification{} 133 | _ ServerNotify = 
&ResourceListChangedNotification{} 134 | _ ServerNotify = &ResourceUpdatedNotification{} 135 | _ ServerNotify = &LogMessageNotification{} 136 | ) 137 | -------------------------------------------------------------------------------- /server/call.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | "strconv" 8 | 9 | "github.com/ThinkInAIXYZ/go-mcp/pkg" 10 | "github.com/ThinkInAIXYZ/go-mcp/protocol" 11 | "github.com/ThinkInAIXYZ/go-mcp/server/session" 12 | ) 13 | 14 | func (server *Server) Ping(ctx context.Context, request *protocol.PingRequest) (*protocol.PingResult, error) { 15 | sessionID, err := GetSessionIDFromCtx(ctx) 16 | if err != nil { 17 | return nil, err 18 | } 19 | 20 | response, err := server.callClient(ctx, sessionID, protocol.Ping, request) 21 | if err != nil { 22 | return nil, err 23 | } 24 | 25 | var result protocol.PingResult 26 | if err = pkg.JSONUnmarshal(response, &result); err != nil { 27 | return nil, fmt.Errorf("failed to unmarshal response: %w", err) 28 | } 29 | return &result, nil 30 | } 31 | 32 | func (server *Server) Sampling(ctx context.Context, request *protocol.CreateMessageRequest) (*protocol.CreateMessageResult, error) { 33 | sessionID, err := GetSessionIDFromCtx(ctx) 34 | if err != nil { 35 | return nil, err 36 | } 37 | 38 | s, ok := server.sessionManager.GetSession(sessionID) 39 | if !ok { 40 | return nil, pkg.ErrLackSession 41 | } 42 | 43 | if s.GetClientCapabilities() == nil || s.GetClientCapabilities().Sampling == nil { 44 | return nil, pkg.ErrClientNotSupport 45 | } 46 | 47 | response, err := server.callClient(ctx, sessionID, protocol.SamplingCreateMessage, request) 48 | if err != nil { 49 | return nil, err 50 | } 51 | 52 | var result protocol.CreateMessageResult 53 | if err = pkg.JSONUnmarshal(response, &result); err != nil { 54 | return nil, fmt.Errorf("failed to unmarshal response: %w", err) 55 | } 56 | return 
&result, nil 57 | } 58 | 59 | func (server *Server) SendProgressNotification(ctx context.Context, notify *protocol.ProgressNotification) error { 60 | progressToken, err := getProgressTokenFromCtx(ctx) 61 | if err != nil { 62 | return err 63 | } 64 | notify.ProgressToken = progressToken 65 | 66 | if err = server.sendMsgWithNotification(ctx, "", protocol.NotificationProgress, notify); err != nil { 67 | return err 68 | } 69 | 70 | return nil 71 | } 72 | 73 | func (server *Server) sendNotification4ToolListChanges(ctx context.Context) error { 74 | if server.capabilities.Tools == nil || !server.capabilities.Tools.ListChanged { 75 | return pkg.ErrServerNotSupport 76 | } 77 | 78 | var errList []error 79 | server.sessionManager.RangeSessions(func(sessionID string, _ *session.State) bool { 80 | if err := server.sendMsgWithNotification(ctx, sessionID, protocol.NotificationToolsListChanged, protocol.NewToolListChangedNotification()); err != nil { 81 | errList = append(errList, fmt.Errorf("sessionID=%s, err: %w", sessionID, err)) 82 | } 83 | return true 84 | }) 85 | return pkg.JoinErrors(errList) 86 | } 87 | 88 | func (server *Server) sendNotification4PromptListChanges(ctx context.Context) error { 89 | if server.capabilities.Prompts == nil || !server.capabilities.Prompts.ListChanged { 90 | return pkg.ErrServerNotSupport 91 | } 92 | 93 | var errList []error 94 | server.sessionManager.RangeSessions(func(sessionID string, _ *session.State) bool { 95 | if err := server.sendMsgWithNotification(ctx, sessionID, protocol.NotificationPromptsListChanged, protocol.NewPromptListChangedNotification()); err != nil { 96 | errList = append(errList, fmt.Errorf("sessionID=%s, err: %w", sessionID, err)) 97 | } 98 | return true 99 | }) 100 | return pkg.JoinErrors(errList) 101 | } 102 | 103 | func (server *Server) sendNotification4ResourceListChanges(ctx context.Context) error { 104 | if server.capabilities.Resources == nil || !server.capabilities.Resources.ListChanged { 105 | return 
pkg.ErrServerNotSupport 106 | } 107 | 108 | var errList []error 109 | server.sessionManager.RangeSessions(func(sessionID string, _ *session.State) bool { 110 | if err := server.sendMsgWithNotification(ctx, sessionID, protocol.NotificationResourcesListChanged, 111 | protocol.NewResourceListChangedNotification()); err != nil { 112 | errList = append(errList, fmt.Errorf("sessionID=%s, err: %w", sessionID, err)) 113 | } 114 | return true 115 | }) 116 | return pkg.JoinErrors(errList) 117 | } 118 | 119 | func (server *Server) SendNotification4ResourcesUpdated(ctx context.Context, notify *protocol.ResourceUpdatedNotification) error { 120 | if server.capabilities.Resources == nil || !server.capabilities.Resources.Subscribe { 121 | return pkg.ErrServerNotSupport 122 | } 123 | 124 | var errList []error 125 | server.sessionManager.RangeSessions(func(sessionID string, s *session.State) bool { 126 | if _, ok := s.GetSubscribedResources().Get(notify.URI); !ok { 127 | return true 128 | } 129 | 130 | if err := server.sendMsgWithNotification(ctx, sessionID, protocol.NotificationResourcesUpdated, notify); err != nil { 131 | errList = append(errList, fmt.Errorf("sessionID=%s, err: %w", sessionID, err)) 132 | } 133 | return true 134 | }) 135 | return pkg.JoinErrors(errList) 136 | } 137 | 138 | // Responsible for request and response assembly 139 | func (server *Server) callClient(ctx context.Context, sessionID string, method protocol.Method, params protocol.ServerRequest) (json.RawMessage, error) { 140 | session, ok := server.sessionManager.GetSession(sessionID) 141 | if !ok { 142 | return nil, fmt.Errorf("callClient: %w", pkg.ErrLackSession) 143 | } 144 | 145 | requestID := strconv.FormatInt(session.IncRequestID(), 10) 146 | respChan := make(chan *protocol.JSONRPCResponse, 1) 147 | session.GetServerReqID2respChan().Set(requestID, respChan) 148 | defer session.GetServerReqID2respChan().Remove(requestID) 149 | 150 | if err := server.sendMsgWithRequest(ctx, sessionID, requestID, method, 
params); err != nil { 151 | return nil, fmt.Errorf("callClient: %w", err) 152 | } 153 | 154 | select { 155 | case <-ctx.Done(): 156 | return nil, ctx.Err() 157 | case response := <-respChan: 158 | if err := response.Error; err != nil { 159 | return nil, pkg.NewResponseError(err.Code, err.Message, err.Data) 160 | } 161 | return response.RawResult, nil 162 | } 163 | } 164 | -------------------------------------------------------------------------------- /server/context.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | ) 7 | 8 | type sessionIDKey struct{} 9 | 10 | func setSessionIDToCtx(ctx context.Context, sessionID string) context.Context { 11 | return context.WithValue(ctx, sessionIDKey{}, sessionID) 12 | } 13 | 14 | func GetSessionIDFromCtx(ctx context.Context) (string, error) { 15 | sessionID := ctx.Value(sessionIDKey{}) 16 | if sessionID == nil { 17 | return "", errors.New("no session id found") 18 | } 19 | return sessionID.(string), nil 20 | } 21 | 22 | type sendChanKey struct{} 23 | 24 | func setSendChanToCtx(ctx context.Context, sendCh chan<- []byte) context.Context { 25 | return context.WithValue(ctx, sendChanKey{}, sendCh) 26 | } 27 | 28 | func getSendChanFromCtx(ctx context.Context) (chan<- []byte, error) { 29 | ch := ctx.Value(sendChanKey{}) 30 | if ch == nil { 31 | return nil, errors.New("no send chan found") 32 | } 33 | return ch.(chan<- []byte), nil 34 | } 35 | 36 | type progressTokenKey struct{} 37 | 38 | func setProgressTokenToCtx(ctx context.Context, progressToken interface{}) context.Context { 39 | return context.WithValue(ctx, progressTokenKey{}, progressToken) 40 | } 41 | 42 | func getProgressTokenFromCtx(ctx context.Context) (interface{}, error) { 43 | progressToken := ctx.Value(progressTokenKey{}) 44 | if progressToken == nil { 45 | return "", errors.New("no progress token found") 46 | } 47 | return progressToken, nil 48 | } 49 | 
-------------------------------------------------------------------------------- /server/receive.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "errors" 7 | "fmt" 8 | 9 | "github.com/tidwall/gjson" 10 | 11 | "github.com/ThinkInAIXYZ/go-mcp/pkg" 12 | "github.com/ThinkInAIXYZ/go-mcp/protocol" 13 | ) 14 | 15 | func (server *Server) receive(ctx context.Context, sessionID string, msg []byte) (<-chan []byte, error) { 16 | if sessionID != "" && !server.sessionManager.IsActiveSession(sessionID) { 17 | if server.sessionManager.IsClosedSession(sessionID) { 18 | return nil, pkg.ErrSessionClosed 19 | } 20 | return nil, pkg.ErrLackSession 21 | } 22 | 23 | if !gjson.GetBytes(msg, "id").Exists() { 24 | notify := &protocol.JSONRPCNotification{} 25 | if err := pkg.JSONUnmarshal(msg, ¬ify); err != nil { 26 | return nil, err 27 | } 28 | if err := server.receiveNotify(sessionID, notify); err != nil { 29 | notify.RawParams = nil // simplified log 30 | server.logger.Errorf("receive notify:%+v error: %s", notify, err.Error()) 31 | return nil, err 32 | } 33 | return nil, nil 34 | } 35 | 36 | // case request or response 37 | if !gjson.GetBytes(msg, "method").Exists() { 38 | resp := &protocol.JSONRPCResponse{} 39 | if err := pkg.JSONUnmarshal(msg, &resp); err != nil { 40 | return nil, err 41 | } 42 | 43 | if err := server.receiveResponse(sessionID, resp); err != nil { 44 | resp.RawResult = nil // simplified log 45 | server.logger.Errorf("receive response:%+v error: %s", resp, err.Error()) 46 | return nil, err 47 | } 48 | return nil, nil 49 | } 50 | 51 | req := &protocol.JSONRPCRequest{} 52 | if err := pkg.JSONUnmarshal(msg, &req); err != nil { 53 | return nil, err 54 | } 55 | if !req.IsValid() { 56 | return nil, pkg.ErrRequestInvalid 57 | } 58 | 59 | // if sessionID != "" && req.Method != protocol.Initialize && req.Method != protocol.Ping { 60 | // if s, ok := 
server.sessionManager.GetSession(sessionID); !ok { 61 | // return nil, pkg.ErrLackSession 62 | // } else if !s.GetReady() { 63 | // return nil, pkg.ErrSessionHasNotInitialized 64 | // } 65 | // } 66 | 67 | server.inFlyRequest.Add(1) 68 | 69 | if server.inShutdown.Load() { 70 | server.inFlyRequest.Done() 71 | return nil, errors.New("server already shutdown") 72 | } 73 | 74 | ch := make(chan []byte, 5) 75 | go func(ctx context.Context) { 76 | defer pkg.Recover() 77 | defer server.inFlyRequest.Done() 78 | defer close(ch) 79 | 80 | if s, ok := server.sessionManager.GetSession(sessionID); ok && req.Method != protocol.Initialize { 81 | var cancel context.CancelFunc 82 | ctx, cancel = context.WithCancel(ctx) 83 | requestID := fmt.Sprint(req.ID) 84 | s.GetClientReqID2cancelFunc().Set(requestID, cancel) 85 | defer s.GetClientReqID2cancelFunc().Remove(requestID) 86 | } 87 | 88 | if r := gjson.GetBytes(req.RawParams, fmt.Sprintf("_meta.%s", protocol.ProgressTokenKey)); r.Exists() { 89 | ctx = setProgressTokenToCtx(ctx, r.Value()) 90 | } 91 | 92 | ctx = setSendChanToCtx(ctx, ch) 93 | 94 | resp := server.receiveRequest(ctx, sessionID, req) 95 | if errors.Is(ctx.Err(), context.Canceled) { 96 | return 97 | } 98 | message, err := json.Marshal(resp) 99 | if err != nil { 100 | server.logger.Errorf("receive json marshal response:%+v error: %s", resp, err.Error()) 101 | return 102 | } 103 | ch <- message 104 | }(pkg.NewCancelShieldContext(ctx)) 105 | return ch, nil 106 | } 107 | 108 | func (server *Server) receiveRequest(ctx context.Context, sessionID string, request *protocol.JSONRPCRequest) *protocol.JSONRPCResponse { 109 | if sessionID != "" { 110 | ctx = setSessionIDToCtx(ctx, sessionID) 111 | } 112 | 113 | if request.Method != protocol.Ping { 114 | server.sessionManager.UpdateSessionLastActiveAt(sessionID) 115 | } 116 | 117 | var ( 118 | result protocol.ServerResponse 119 | err error 120 | ) 121 | 122 | switch request.Method { 123 | case protocol.Ping: 124 | result, err = 
server.handleRequestWithPing() 125 | case protocol.Initialize: 126 | result, err = server.handleRequestWithInitialize(ctx, sessionID, request.RawParams) 127 | case protocol.PromptsList: 128 | result, err = server.handleRequestWithListPrompts(request.RawParams) 129 | case protocol.PromptsGet: 130 | result, err = server.handleRequestWithGetPrompt(ctx, request.RawParams) 131 | case protocol.ResourcesList: 132 | result, err = server.handleRequestWithListResources(request.RawParams) 133 | case protocol.ResourceListTemplates: 134 | result, err = server.handleRequestWithListResourceTemplates(request.RawParams) 135 | case protocol.ResourcesRead: 136 | result, err = server.handleRequestWithReadResource(ctx, request.RawParams) 137 | case protocol.ResourcesSubscribe: 138 | result, err = server.handleRequestWithSubscribeResourceChange(sessionID, request.RawParams) 139 | case protocol.ResourcesUnsubscribe: 140 | result, err = server.handleRequestWithUnSubscribeResourceChange(sessionID, request.RawParams) 141 | case protocol.ToolsList: 142 | result, err = server.handleRequestWithListTools(request.RawParams) 143 | case protocol.ToolsCall: 144 | result, err = server.handleRequestWithCallTool(ctx, request.RawParams) 145 | default: 146 | err = fmt.Errorf("%w: method=%s", pkg.ErrMethodNotSupport, request.Method) 147 | } 148 | 149 | if err != nil { 150 | var code int 151 | switch { 152 | case errors.Is(err, pkg.ErrMethodNotSupport): 153 | code = protocol.MethodNotFound 154 | case errors.Is(err, pkg.ErrRequestInvalid): 155 | code = protocol.InvalidRequest 156 | case errors.Is(err, pkg.ErrJSONUnmarshal): 157 | code = protocol.ParseError 158 | default: 159 | code = protocol.InternalError 160 | } 161 | return protocol.NewJSONRPCErrorResponse(request.ID, code, err.Error()) 162 | } 163 | return protocol.NewJSONRPCSuccessResponse(request.ID, result) 164 | } 165 | 166 | func (server *Server) receiveNotify(sessionID string, notify *protocol.JSONRPCNotification) error { 167 | // if sessionID != 
"" { 168 | // if s, ok := server.sessionManager.GetSession(sessionID); !ok { 169 | // return pkg.ErrLackSession 170 | // } else if notify.Method != protocol.NotificationInitialized && !s.GetReady() { 171 | // return pkg.ErrSessionHasNotInitialized 172 | // } 173 | // } 174 | 175 | switch notify.Method { 176 | case protocol.NotificationInitialized: 177 | return server.handleNotifyWithInitialized(sessionID, notify.RawParams) 178 | case protocol.NotificationCancelled: 179 | return server.handleNotifyWithCancelled(sessionID, notify.RawParams) 180 | default: 181 | return fmt.Errorf("%w: method=%s", pkg.ErrMethodNotSupport, notify.Method) 182 | } 183 | } 184 | 185 | func (server *Server) receiveResponse(sessionID string, response *protocol.JSONRPCResponse) error { 186 | s, ok := server.sessionManager.GetSession(sessionID) 187 | if !ok { 188 | return pkg.ErrLackSession 189 | } 190 | 191 | respChan, ok := s.GetServerReqID2respChan().Get(fmt.Sprint(response.ID)) 192 | if !ok { 193 | return fmt.Errorf("%w: sessionID=%+v, requestID=%+v", pkg.ErrLackResponseChan, sessionID, response.ID) 194 | } 195 | 196 | select { 197 | case respChan <- response: 198 | default: 199 | return fmt.Errorf("%w: sessionID=%+v, response=%+v", pkg.ErrDuplicateResponseReceived, sessionID, response) 200 | } 201 | return nil 202 | } 203 | -------------------------------------------------------------------------------- /server/send.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | 8 | "github.com/ThinkInAIXYZ/go-mcp/protocol" 9 | ) 10 | 11 | func (server *Server) sendMsgWithRequest(ctx context.Context, sessionID string, requestID protocol.RequestID, 12 | method protocol.Method, params protocol.ServerRequest, 13 | ) error { //nolint:whitespace 14 | if requestID == nil { 15 | return fmt.Errorf("requestID can't is nil") 16 | } 17 | 18 | req := protocol.NewJSONRPCRequest(requestID, method, 
params) 19 | 20 | message, err := json.Marshal(req) 21 | if err != nil { 22 | return err 23 | } 24 | 25 | if ch, err := getSendChanFromCtx(ctx); err == nil { 26 | ch <- message 27 | return nil 28 | } 29 | 30 | if err := server.transport.Send(ctx, sessionID, message); err != nil { 31 | return fmt.Errorf("sendRequest: transport send: %w", err) 32 | } 33 | return nil 34 | } 35 | 36 | func (server *Server) sendMsgWithNotification(ctx context.Context, sessionID string, method protocol.Method, params protocol.ServerNotify) error { 37 | notify := protocol.NewJSONRPCNotification(method, params) 38 | 39 | message, err := json.Marshal(notify) 40 | if err != nil { 41 | return err 42 | } 43 | 44 | if ch, err := getSendChanFromCtx(ctx); err == nil { 45 | ch <- message 46 | return nil 47 | } 48 | 49 | if err := server.transport.Send(ctx, sessionID, message); err != nil { 50 | return fmt.Errorf("sendNotification: transport send: %w", err) 51 | } 52 | return nil 53 | } 54 | -------------------------------------------------------------------------------- /server/session/manager.go: -------------------------------------------------------------------------------- 1 | package session 2 | 3 | import ( 4 | "context" 5 | "time" 6 | 7 | "github.com/ThinkInAIXYZ/go-mcp/pkg" 8 | ) 9 | 10 | type Manager struct { 11 | activeSessions pkg.SyncMap[*State] 12 | closedSessions pkg.SyncMap[struct{}] 13 | 14 | stopHeartbeat chan struct{} 15 | 16 | genSessionID func(ctx context.Context) string 17 | 18 | logger pkg.Logger 19 | 20 | detection func(ctx context.Context, sessionID string) error 21 | maxIdleTime time.Duration 22 | } 23 | 24 | func NewManager(detection func(ctx context.Context, sessionID string) error, genSessionID func(ctx context.Context) string) *Manager { 25 | return &Manager{ 26 | genSessionID: genSessionID, 27 | detection: detection, 28 | stopHeartbeat: make(chan struct{}), 29 | logger: pkg.DefaultLogger, 30 | } 31 | } 32 | 33 | func (m *Manager) SetMaxIdleTime(d time.Duration) { 34 | 
m.maxIdleTime = d 35 | } 36 | 37 | func (m *Manager) SetLogger(logger pkg.Logger) { 38 | m.logger = logger 39 | } 40 | 41 | func (m *Manager) CreateSession(ctx context.Context) string { 42 | sessionID := m.genSessionID(ctx) 43 | state := NewState() 44 | m.activeSessions.Store(sessionID, state) 45 | return sessionID 46 | } 47 | 48 | func (m *Manager) IsActiveSession(sessionID string) bool { 49 | _, has := m.activeSessions.Load(sessionID) 50 | return has 51 | } 52 | 53 | func (m *Manager) IsClosedSession(sessionID string) bool { 54 | _, has := m.closedSessions.Load(sessionID) 55 | return has 56 | } 57 | 58 | func (m *Manager) GetSession(sessionID string) (*State, bool) { 59 | if sessionID == "" { 60 | return nil, false 61 | } 62 | state, has := m.activeSessions.Load(sessionID) 63 | if !has { 64 | return nil, false 65 | } 66 | return state, true 67 | } 68 | 69 | func (m *Manager) OpenMessageQueueForSend(sessionID string) error { 70 | state, has := m.GetSession(sessionID) 71 | if !has { 72 | return pkg.ErrLackSession 73 | } 74 | state.openMessageQueueForSend() 75 | return nil 76 | } 77 | 78 | func (m *Manager) EnqueueMessageForSend(ctx context.Context, sessionID string, message []byte) error { 79 | state, has := m.GetSession(sessionID) 80 | if !has { 81 | return pkg.ErrLackSession 82 | } 83 | return state.enqueueMessage(ctx, message) 84 | } 85 | 86 | func (m *Manager) DequeueMessageForSend(ctx context.Context, sessionID string) ([]byte, error) { 87 | state, has := m.GetSession(sessionID) 88 | if !has { 89 | return nil, pkg.ErrLackSession 90 | } 91 | return state.dequeueMessage(ctx) 92 | } 93 | 94 | func (m *Manager) UpdateSessionLastActiveAt(sessionID string) { 95 | state, ok := m.activeSessions.Load(sessionID) 96 | if !ok { 97 | return 98 | } 99 | state.updateLastActiveAt() 100 | } 101 | 102 | func (m *Manager) CloseSession(sessionID string) { 103 | state, ok := m.activeSessions.LoadAndDelete(sessionID) 104 | if !ok { 105 | return 106 | } 107 | state.Close() 108 | 
m.closedSessions.Store(sessionID, struct{}{}) 109 | } 110 | 111 | func (m *Manager) CloseAllSessions() { 112 | m.activeSessions.Range(func(sessionID string, _ *State) bool { 113 | // Here we load the session again to prevent concurrency conflicts with CloseSession, which may cause repeated close chan 114 | m.CloseSession(sessionID) 115 | return true 116 | }) 117 | } 118 | 119 | func (m *Manager) StartHeartbeatAndCleanInvalidSessions() { 120 | ticker := time.NewTicker(time.Minute) 121 | defer ticker.Stop() 122 | 123 | for { 124 | select { 125 | case <-m.stopHeartbeat: 126 | return 127 | case <-ticker.C: 128 | now := time.Now() 129 | m.activeSessions.Range(func(sessionID string, state *State) bool { 130 | if m.maxIdleTime != 0 && now.Sub(state.lastActiveAt) > m.maxIdleTime { 131 | m.logger.Infof("session expire, session id: %v", sessionID) 132 | m.CloseSession(sessionID) 133 | return true 134 | } 135 | 136 | var err error 137 | for i := 0; i < 3; i++ { 138 | if err = m.detection(context.Background(), sessionID); err == nil { 139 | return true 140 | } 141 | } 142 | m.logger.Infof("session detection fail, session id: %v, fail reason: %+v", sessionID, err) 143 | m.CloseSession(sessionID) 144 | return true 145 | }) 146 | } 147 | } 148 | } 149 | 150 | func (m *Manager) StopHeartbeat() { 151 | close(m.stopHeartbeat) 152 | } 153 | 154 | func (m *Manager) RangeSessions(f func(sessionID string, state *State) bool) { 155 | m.activeSessions.Range(f) 156 | } 157 | 158 | func (m *Manager) IsEmpty() bool { 159 | isEmpty := true 160 | m.activeSessions.Range(func(string, *State) bool { 161 | isEmpty = false 162 | return false 163 | }) 164 | return isEmpty 165 | } 166 | -------------------------------------------------------------------------------- /server/session/state.go: -------------------------------------------------------------------------------- 1 | package session 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "sync" 7 | "sync/atomic" 8 | "time" 9 | 10 | cmap 
"github.com/orcaman/concurrent-map/v2" 11 | 12 | "github.com/ThinkInAIXYZ/go-mcp/pkg" 13 | "github.com/ThinkInAIXYZ/go-mcp/protocol" 14 | ) 15 | 16 | var ErrQueueNotOpened = errors.New("queue has not been opened") 17 | 18 | type State struct { 19 | lastActiveAt time.Time 20 | 21 | mu sync.RWMutex 22 | sendChan chan []byte 23 | 24 | requestID int64 25 | 26 | serverReqID2respChan cmap.ConcurrentMap[string, chan *protocol.JSONRPCResponse] 27 | 28 | clientReqID2cancelFunc cmap.ConcurrentMap[string, context.CancelFunc] 29 | 30 | // cache client initialize request info 31 | clientInfo *protocol.Implementation 32 | clientCapabilities *protocol.ClientCapabilities 33 | 34 | // subscribed resources 35 | subscribedResources cmap.ConcurrentMap[string, struct{}] 36 | 37 | receivedInitRequest *pkg.AtomicBool 38 | ready *pkg.AtomicBool 39 | closed *pkg.AtomicBool 40 | } 41 | 42 | func NewState() *State { 43 | return &State{ 44 | lastActiveAt: time.Now(), 45 | serverReqID2respChan: cmap.New[chan *protocol.JSONRPCResponse](), 46 | clientReqID2cancelFunc: cmap.New[context.CancelFunc](), 47 | subscribedResources: cmap.New[struct{}](), 48 | receivedInitRequest: pkg.NewAtomicBool(), 49 | ready: pkg.NewAtomicBool(), 50 | closed: pkg.NewAtomicBool(), 51 | } 52 | } 53 | 54 | func (s *State) SetClientInfo(ClientInfo *protocol.Implementation, ClientCapabilities *protocol.ClientCapabilities) { 55 | s.clientInfo = ClientInfo 56 | s.clientCapabilities = ClientCapabilities 57 | } 58 | 59 | func (s *State) GetClientCapabilities() *protocol.ClientCapabilities { 60 | return s.clientCapabilities 61 | } 62 | 63 | func (s *State) SetReceivedInitRequest() { 64 | s.receivedInitRequest.Store(true) 65 | } 66 | 67 | func (s *State) GetReceivedInitRequest() bool { 68 | return s.receivedInitRequest.Load() 69 | } 70 | 71 | func (s *State) SetReady() { 72 | s.ready.Store(true) 73 | } 74 | 75 | func (s *State) GetReady() bool { 76 | return s.ready.Load() 77 | } 78 | 79 | func (s *State) IncRequestID() int64 { 
80 | return atomic.AddInt64(&s.requestID, 1) 81 | } 82 | 83 | func (s *State) GetServerReqID2respChan() cmap.ConcurrentMap[string, chan *protocol.JSONRPCResponse] { 84 | return s.serverReqID2respChan 85 | } 86 | 87 | func (s *State) GetClientReqID2cancelFunc() cmap.ConcurrentMap[string, context.CancelFunc] { 88 | return s.clientReqID2cancelFunc 89 | } 90 | 91 | func (s *State) GetSubscribedResources() cmap.ConcurrentMap[string, struct{}] { 92 | return s.subscribedResources 93 | } 94 | 95 | func (s *State) Close() { 96 | s.mu.Lock() 97 | defer s.mu.Unlock() 98 | 99 | s.closed.Store(true) 100 | 101 | if s.sendChan != nil { 102 | close(s.sendChan) 103 | } 104 | } 105 | 106 | func (s *State) updateLastActiveAt() { 107 | s.lastActiveAt = time.Now() 108 | } 109 | 110 | func (s *State) openMessageQueueForSend() { 111 | s.mu.Lock() 112 | defer s.mu.Unlock() 113 | 114 | if s.sendChan == nil { 115 | s.sendChan = make(chan []byte, 64) 116 | } 117 | } 118 | 119 | func (s *State) enqueueMessage(ctx context.Context, message []byte) error { 120 | s.mu.RLock() 121 | defer s.mu.RUnlock() 122 | 123 | if s.closed.Load() { 124 | return errors.New("session already closed") 125 | } 126 | 127 | if s.sendChan == nil { 128 | return ErrQueueNotOpened 129 | } 130 | 131 | select { 132 | case s.sendChan <- message: 133 | return nil 134 | case <-ctx.Done(): 135 | return ctx.Err() 136 | } 137 | } 138 | 139 | func (s *State) dequeueMessage(ctx context.Context) ([]byte, error) { 140 | s.mu.RLock() 141 | if s.sendChan == nil { 142 | s.mu.RUnlock() 143 | return nil, ErrQueueNotOpened 144 | } 145 | s.mu.RUnlock() 146 | 147 | select { 148 | case <-ctx.Done(): 149 | return nil, ctx.Err() 150 | case msg, ok := <-s.sendChan: 151 | if msg == nil && !ok { 152 | // There are no new messages and the chan has been closed, indicating that the request may need to be terminated. 
153 | return nil, pkg.ErrSendEOF 154 | } 155 | return msg, nil 156 | } 157 | } 158 | -------------------------------------------------------------------------------- /testdata/mock_block_server.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | ) 7 | 8 | func main() { 9 | if _, err := os.Stdin.Read(make([]byte, 1)); err != nil { 10 | fmt.Println(err) 11 | return 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /tests/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ThinkInAIXYZ/go-mcp/c7a0eb1f7e4a288220d3a3375006802558f473a2/tests/.DS_Store -------------------------------------------------------------------------------- /tests/sse_test.go: -------------------------------------------------------------------------------- 1 | package tests 2 | 3 | import ( 4 | "fmt" 5 | "net" 6 | "os" 7 | "os/exec" 8 | "strconv" 9 | "testing" 10 | 11 | "github.com/ThinkInAIXYZ/go-mcp/transport" 12 | ) 13 | 14 | func TestSSE(t *testing.T) { 15 | port, err := getAvailablePort() 16 | if err != nil { 17 | t.Fatalf("Failed to get available port: %v", err) 18 | } 19 | 20 | transportClient, err := transport.NewSSEClientTransport(fmt.Sprintf("http://127.0.0.1:%d/sse", port)) 21 | if err != nil { 22 | t.Fatalf("Failed to create transport client: %v", err) 23 | } 24 | 25 | test(t, func() error { return runSSEServer(port) }, transportClient, transport.Stateful) 26 | } 27 | 28 | // getAvailablePort returns a port that is available for use 29 | func getAvailablePort() (int, error) { 30 | addr, err := net.Listen("tcp", "127.0.0.1:0") 31 | if err != nil { 32 | return 0, fmt.Errorf("failed to get available port: %v", err) 33 | } 34 | defer func() { 35 | if err = addr.Close(); err != nil { 36 | fmt.Println(err) 37 | } 38 | }() 39 | 40 | port := addr.Addr().(*net.TCPAddr).Port 41 | 
return port, nil 42 | } 43 | 44 | func runSSEServer(port int) error { 45 | mockServerTrPath, err := compileMockStdioServerTr() 46 | if err != nil { 47 | return err 48 | } 49 | fmt.Println(mockServerTrPath) 50 | 51 | defer func(name string) { 52 | if err := os.Remove(name); err != nil { 53 | fmt.Printf("failed to remove mock server: %v\n", err) 54 | } 55 | }(mockServerTrPath) 56 | 57 | return exec.Command(mockServerTrPath, "-transport", "sse", "-port", strconv.Itoa(port)).Run() 58 | } 59 | -------------------------------------------------------------------------------- /tests/stdio_test.go: -------------------------------------------------------------------------------- 1 | package tests 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "testing" 7 | 8 | "github.com/ThinkInAIXYZ/go-mcp/transport" 9 | ) 10 | 11 | func TestStdio(t *testing.T) { 12 | mockServerTrPath, err := compileMockStdioServerTr() 13 | if err != nil { 14 | t.Fatal(err) 15 | } 16 | defer func(name string) { 17 | if err = os.Remove(name); err != nil { 18 | fmt.Printf("Failed to remove mock server: %v\n", err) 19 | } 20 | }(mockServerTrPath) 21 | 22 | fmt.Println(mockServerTrPath) 23 | transportClient, err := transport.NewStdioClientTransport(mockServerTrPath, []string{"-transport", "stdio"}) 24 | if err != nil { 25 | t.Fatalf("Failed to create transport client: %v", err) 26 | } 27 | 28 | test(t, func() error { 29 | <-make(chan error) 30 | return nil 31 | }, transportClient, transport.Stateful) 32 | } 33 | -------------------------------------------------------------------------------- /tests/streamable_http_test.go: -------------------------------------------------------------------------------- 1 | package tests 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "os/exec" 7 | "strconv" 8 | "testing" 9 | 10 | "github.com/ThinkInAIXYZ/go-mcp/transport" 11 | ) 12 | 13 | func TestStreamableHTTPWithStateless(t *testing.T) { 14 | port, err := getAvailablePort() 15 | if err != nil { 16 | t.Fatalf("Failed to get available port: 
%v", err) 17 | } 18 | 19 | transportClient, err := transport.NewStreamableHTTPClientTransport(fmt.Sprintf("http://127.0.0.1:%d/mcp", port)) 20 | if err != nil { 21 | t.Fatalf("Failed to create transport client: %v", err) 22 | } 23 | 24 | test(t, func() error { return runStreamableHTTPServer(port, transport.Stateless) }, transportClient, transport.Stateless) 25 | } 26 | 27 | func TestStreamableHTTPWithStateful(t *testing.T) { 28 | port, err := getAvailablePort() 29 | if err != nil { 30 | t.Fatalf("Failed to get available port: %v", err) 31 | } 32 | 33 | transportClient, err := transport.NewStreamableHTTPClientTransport(fmt.Sprintf("http://127.0.0.1:%d/mcp", port)) 34 | if err != nil { 35 | t.Fatalf("Failed to create transport client: %v", err) 36 | } 37 | 38 | test(t, func() error { return runStreamableHTTPServer(port, transport.Stateful) }, transportClient, transport.Stateful) 39 | } 40 | 41 | func runStreamableHTTPServer(port int, stateful transport.StateMode) error { 42 | mockServerTrPath, err := compileMockStdioServerTr() 43 | if err != nil { 44 | return err 45 | } 46 | fmt.Println(mockServerTrPath) 47 | 48 | defer func(name string) { 49 | if err := os.Remove(name); err != nil { 50 | fmt.Printf("failed to remove mock server: %v\n", err) 51 | } 52 | }(mockServerTrPath) 53 | 54 | return exec.Command(mockServerTrPath, "-transport", "streamable_http", "-port", strconv.Itoa(port), "-state_mode", string(stateful)).Run() 55 | } 56 | -------------------------------------------------------------------------------- /tests/utils.go: -------------------------------------------------------------------------------- 1 | package tests 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | "math/rand" 8 | "os" 9 | "os/exec" 10 | "path/filepath" 11 | "strconv" 12 | "testing" 13 | "time" 14 | 15 | "github.com/ThinkInAIXYZ/go-mcp/client" 16 | "github.com/ThinkInAIXYZ/go-mcp/protocol" 17 | "github.com/ThinkInAIXYZ/go-mcp/transport" 18 | ) 19 | 20 | func test(t *testing.T, 
runServer func() error, transportClient transport.ClientTransport, mode transport.StateMode) { 21 | errCh := make(chan error, 1) 22 | go func() { 23 | errCh <- runServer() 24 | }() 25 | 26 | // Use select to handle potential errors 27 | select { 28 | case err := <-errCh: 29 | t.Fatalf("server.Run() failed: %v", err) 30 | case <-time.After(time.Second * 3): 31 | // Server started normally 32 | } 33 | 34 | // Create MCP client using transport 35 | mcpClient, err := client.NewClient(transportClient, client.WithClientInfo(&protocol.Implementation{ 36 | Name: "Example MCP Client", 37 | Version: "1.0.0", 38 | }), client.WithSamplingHandler(&sampling{})) 39 | if err != nil { 40 | t.Fatalf("Failed to create MCP client: %v", err) 41 | } 42 | defer func() { 43 | if err = mcpClient.Close(); err != nil { 44 | t.Fatalf("Failed to close MCP client: %v", err) 45 | return 46 | } 47 | }() 48 | 49 | // List available tools 50 | toolsResult, err := mcpClient.ListTools(context.Background()) 51 | if err != nil { 52 | t.Fatalf("Failed to list tools: %v", err) 53 | } 54 | bytes, _ := json.Marshal(toolsResult) 55 | fmt.Printf("Available tools: %s\n", bytes) 56 | 57 | callResult, err := mcpClient.CallTool( 58 | context.Background(), 59 | protocol.NewCallToolRequestWithRawArguments("current_time", json.RawMessage(`{"timezone": "UTC"}`))) 60 | if err != nil { 61 | t.Fatalf("Failed to call tool: %v", err) 62 | } 63 | bytes, _ = json.Marshal(callResult) 64 | fmt.Printf("Tool call result: %s\n", bytes) 65 | 66 | progressCh := make(chan *protocol.ProgressNotification) 67 | go func() { 68 | for progress := range progressCh { 69 | fmt.Printf("Progress: %+v\n", progress) 70 | } 71 | }() 72 | callResult, err = mcpClient.CallToolWithProgressChan(context.Background(), 73 | protocol.NewCallToolRequestWithRawArguments("generate_ppt", json.RawMessage(`{"ppt_description": "test"}`)), progressCh) 74 | if err != nil { 75 | t.Fatalf("Failed to call tool: %v", err) 76 | } 77 | bytes, _ = 
json.Marshal(callResult) 78 | fmt.Printf("Tool call result: %s\n", bytes) 79 | 80 | if mode == transport.Stateful { 81 | // if streamable_http transport, need wait streamable_http connection start 82 | time.Sleep(time.Second) 83 | 84 | callResult, err = mcpClient.CallTool( 85 | context.Background(), 86 | protocol.NewCallToolRequestWithRawArguments("delete_file", json.RawMessage(`{"file_name": "test_file.txt"}`))) 87 | if err != nil { 88 | t.Fatalf("Failed to call tool: %v", err) 89 | } 90 | bytes, _ = json.Marshal(callResult) 91 | fmt.Printf("Tool call result: %s\n", bytes) 92 | } 93 | } 94 | 95 | type sampling struct{} 96 | 97 | func (s *sampling) CreateMessage(_ context.Context, request *protocol.CreateMessageRequest) (*protocol.CreateMessageResult, error) { 98 | var lastUserMessages protocol.Content 99 | for _, message := range request.Messages { 100 | if message.Role == "user" { 101 | lastUserMessages = message.Content 102 | } 103 | } 104 | 105 | if lastUserMessages.GetType() != "text" { 106 | return nil, fmt.Errorf("expected 'text', got %s", lastUserMessages.GetType()) 107 | } 108 | 109 | return &protocol.CreateMessageResult{ 110 | Content: &protocol.TextContent{ 111 | Annotated: protocol.Annotated{}, 112 | Type: "text", 113 | Text: strconv.FormatBool(true), 114 | }, 115 | Role: "assistant", 116 | Model: "stub-model", 117 | StopReason: "endTurn", 118 | }, nil 119 | } 120 | 121 | func compileMockStdioServerTr() (string, error) { 122 | r := rand.New(rand.NewSource(time.Now().UnixNano())) 123 | 124 | mockServerTrPath := filepath.Join(os.TempDir(), "mock_server_tr_"+strconv.Itoa(r.Int())) 125 | 126 | cmd := exec.Command("go", "build", "-o", mockServerTrPath, "../examples/everything/main.go") 127 | 128 | if output, err := cmd.CombinedOutput(); err != nil { 129 | return "", fmt.Errorf("compilation failed: %v\nOutput: %s", err, output) 130 | } 131 | 132 | return mockServerTrPath, nil 133 | } 134 | 
-------------------------------------------------------------------------------- /transport/mock_client.go: -------------------------------------------------------------------------------- 1 | package transport 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "context" 7 | "errors" 8 | "fmt" 9 | "io" 10 | 11 | "github.com/ThinkInAIXYZ/go-mcp/pkg" 12 | ) 13 | 14 | type mockClientTransport struct { 15 | receiver clientReceiver 16 | in io.ReadCloser 17 | out io.Writer 18 | 19 | logger pkg.Logger 20 | 21 | cancel context.CancelFunc 22 | receiveShutDone chan struct{} 23 | } 24 | 25 | func NewMockClientTransport(in io.ReadCloser, out io.Writer) ClientTransport { 26 | return &mockClientTransport{ 27 | in: in, 28 | out: out, 29 | logger: pkg.DefaultLogger, 30 | receiveShutDone: make(chan struct{}), 31 | } 32 | } 33 | 34 | func (t *mockClientTransport) Start() error { 35 | ctx, cancel := context.WithCancel(context.Background()) 36 | t.cancel = cancel 37 | 38 | go func() { 39 | defer pkg.Recover() 40 | 41 | t.startReceive(ctx) 42 | 43 | close(t.receiveShutDone) 44 | }() 45 | 46 | return nil 47 | } 48 | 49 | func (t *mockClientTransport) Send(_ context.Context, msg Message) error { 50 | if _, err := t.out.Write(append(msg, mcpMessageDelimiter)); err != nil { 51 | return fmt.Errorf("failed to write: %w", err) 52 | } 53 | return nil 54 | } 55 | 56 | func (t *mockClientTransport) SetReceiver(receiver clientReceiver) { 57 | t.receiver = receiver 58 | } 59 | 60 | func (t *mockClientTransport) Close() error { 61 | t.cancel() 62 | 63 | if err := t.in.Close(); err != nil { 64 | return fmt.Errorf("failed to close writer: %w", err) 65 | } 66 | 67 | <-t.receiveShutDone 68 | 69 | return nil 70 | } 71 | 72 | func (t *mockClientTransport) startReceive(ctx context.Context) { 73 | s := bufio.NewReader(t.in) 74 | 75 | for { 76 | line, err := s.ReadBytes('\n') 77 | if err != nil { 78 | t.receiver.Interrupt(fmt.Errorf("reader read error: %w", err)) 79 | 80 | if errors.Is(err, io.ErrClosedPipe) || 
// This error occurs during unit tests, suppressing it here 81 | errors.Is(err, io.EOF) { 82 | return 83 | } 84 | t.logger.Errorf("reader read error: %+v", err) 85 | return 86 | } 87 | 88 | line = bytes.TrimRight(line, "\n") 89 | 90 | select { 91 | case <-ctx.Done(): 92 | return 93 | default: 94 | if err = t.receiver.Receive(ctx, line); err != nil { 95 | t.logger.Errorf("receiver failed: %v", err) 96 | } 97 | } 98 | } 99 | } 100 | -------------------------------------------------------------------------------- /transport/mock_server.go: -------------------------------------------------------------------------------- 1 | package transport 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "context" 7 | "errors" 8 | "fmt" 9 | "io" 10 | 11 | "github.com/ThinkInAIXYZ/go-mcp/pkg" 12 | ) 13 | 14 | type mockServerTransport struct { 15 | receiver serverReceiver 16 | in io.ReadCloser 17 | out io.Writer 18 | 19 | sessionID string 20 | 21 | sessionManager sessionManager 22 | 23 | logger pkg.Logger 24 | 25 | cancel context.CancelFunc 26 | receiveShutDone chan struct{} 27 | } 28 | 29 | func NewMockServerTransport(in io.ReadCloser, out io.Writer) ServerTransport { 30 | return &mockServerTransport{ 31 | in: in, 32 | out: out, 33 | logger: pkg.DefaultLogger, 34 | 35 | receiveShutDone: make(chan struct{}), 36 | } 37 | } 38 | 39 | func (t *mockServerTransport) Run() error { 40 | ctx, cancel := context.WithCancel(context.Background()) 41 | t.cancel = cancel 42 | 43 | t.sessionID = t.sessionManager.CreateSession(context.Background()) 44 | 45 | t.startReceive(ctx) 46 | 47 | close(t.receiveShutDone) 48 | return nil 49 | } 50 | 51 | func (t *mockServerTransport) Send(_ context.Context, _ string, msg Message) error { 52 | if _, err := t.out.Write(append(msg, mcpMessageDelimiter)); err != nil { 53 | return fmt.Errorf("failed to write: %w", err) 54 | } 55 | return nil 56 | } 57 | 58 | func (t *mockServerTransport) SetReceiver(receiver serverReceiver) { 59 | t.receiver = receiver 60 | } 61 | 62 | 
func (t *mockServerTransport) SetSessionManager(m sessionManager) { 63 | t.sessionManager = m 64 | } 65 | 66 | func (t *mockServerTransport) Shutdown(userCtx context.Context, serverCtx context.Context) error { 67 | t.cancel() 68 | 69 | if err := t.in.Close(); err != nil { 70 | return err 71 | } 72 | 73 | <-t.receiveShutDone 74 | 75 | select { 76 | case <-serverCtx.Done(): 77 | return nil 78 | case <-userCtx.Done(): 79 | return userCtx.Err() 80 | } 81 | } 82 | 83 | func (t *mockServerTransport) startReceive(ctx context.Context) { 84 | s := bufio.NewReader(t.in) 85 | 86 | for { 87 | line, err := s.ReadBytes('\n') 88 | if err != nil { 89 | if errors.Is(err, io.ErrClosedPipe) || // This error occurs during unit tests, suppressing it here 90 | errors.Is(err, io.EOF) { 91 | return 92 | } 93 | t.logger.Errorf("client receive unexpected error reading input: %v", err) 94 | return 95 | } 96 | 97 | line = bytes.TrimRight(line, "\n") 98 | 99 | select { 100 | case <-ctx.Done(): 101 | return 102 | default: 103 | t.receive(ctx, line) 104 | } 105 | } 106 | } 107 | 108 | func (t *mockServerTransport) receive(ctx context.Context, line []byte) { 109 | outputMsgCh, err := t.receiver.Receive(ctx, t.sessionID, line) 110 | if err != nil { 111 | t.logger.Errorf("receiver failed: %v", err) 112 | return 113 | } 114 | 115 | if outputMsgCh == nil { 116 | return 117 | } 118 | 119 | go func() { 120 | defer pkg.Recover() 121 | 122 | for msg := range outputMsgCh { 123 | if e := t.Send(context.Background(), t.sessionID, msg); e != nil { 124 | t.logger.Errorf("Failed to send message: %v", e) 125 | } 126 | } 127 | }() 128 | } 129 | -------------------------------------------------------------------------------- /transport/mock_test.go: -------------------------------------------------------------------------------- 1 | package transport 2 | 3 | import ( 4 | "io" 5 | "testing" 6 | ) 7 | 8 | func TestMockTransport(t *testing.T) { 9 | reader1, writer1 := io.Pipe() 10 | reader2, writer2 := io.Pipe() 11 
| 12 | serverTransport := NewMockServerTransport(reader2, writer1) 13 | clientTransport := NewMockClientTransport(reader1, writer2) 14 | 15 | testTransport(t, clientTransport, serverTransport) 16 | } 17 | -------------------------------------------------------------------------------- /transport/sse_client.go: -------------------------------------------------------------------------------- 1 | package transport 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "context" 7 | "errors" 8 | "fmt" 9 | "io" 10 | "net/http" 11 | "net/url" 12 | "strings" 13 | "time" 14 | 15 | "github.com/ThinkInAIXYZ/go-mcp/pkg" 16 | ) 17 | 18 | type SSEClientTransportOption func(*sseClientTransport) 19 | 20 | func WithSSEClientOptionReceiveTimeout(timeout time.Duration) SSEClientTransportOption { 21 | return func(t *sseClientTransport) { 22 | t.receiveTimeout = timeout 23 | } 24 | } 25 | 26 | func WithSSEClientOptionHTTPClient(client *http.Client) SSEClientTransportOption { 27 | return func(t *sseClientTransport) { 28 | t.client = client 29 | } 30 | } 31 | 32 | func WithSSEClientOptionLogger(log pkg.Logger) SSEClientTransportOption { 33 | return func(t *sseClientTransport) { 34 | t.logger = log 35 | } 36 | } 37 | 38 | func WithRetryFunc(retry func(func() error)) SSEClientTransportOption { 39 | return func(t *sseClientTransport) { 40 | t.retry = retry 41 | } 42 | } 43 | 44 | type sseClientTransport struct { 45 | ctx context.Context 46 | cancel context.CancelFunc 47 | 48 | serverURL *url.URL 49 | 50 | endpointChan chan struct{} 51 | messageEndpoint *url.URL 52 | receiver clientReceiver 53 | 54 | // options 55 | logger pkg.Logger 56 | receiveTimeout time.Duration 57 | client *http.Client 58 | 59 | retry func(func() error) 60 | 61 | sseConnectClose chan struct{} 62 | } 63 | 64 | func NewSSEClientTransport(serverURL string, opts ...SSEClientTransportOption) (ClientTransport, error) { 65 | parsedURL, err := url.Parse(serverURL) 66 | if err != nil { 67 | return nil, fmt.Errorf("failed to parse server 
URL: %w", err) 68 | } 69 | 70 | t := &sseClientTransport{ 71 | serverURL: parsedURL, 72 | endpointChan: make(chan struct{}, 1), 73 | messageEndpoint: nil, 74 | receiver: nil, 75 | logger: pkg.DefaultLogger, 76 | receiveTimeout: time.Second * 30, 77 | client: http.DefaultClient, 78 | sseConnectClose: make(chan struct{}), 79 | retry: func(operation func() error) { 80 | for { 81 | if e := operation(); e == nil { 82 | return 83 | } 84 | time.Sleep(100 * time.Millisecond) 85 | } 86 | }, 87 | } 88 | 89 | for _, opt := range opts { 90 | opt(t) 91 | } 92 | 93 | return t, nil 94 | } 95 | 96 | func (t *sseClientTransport) Start() (err error) { 97 | ctx, cancel := context.WithCancel(context.Background()) 98 | t.ctx = ctx 99 | t.cancel = cancel 100 | 101 | defer func() { 102 | if err != nil { 103 | t.cancel() 104 | } 105 | }() 106 | 107 | errChan := make(chan error, 1) 108 | go func() { 109 | defer pkg.Recover() 110 | defer close(t.sseConnectClose) 111 | 112 | t.retry(func() error { 113 | if e := t.startSSE(); e != nil { 114 | if errors.Is(e, context.Canceled) { 115 | return nil 116 | } 117 | t.logger.Errorf("startSSE: %+v", e) 118 | t.receiver.Interrupt(fmt.Errorf("SSE connection disconnection: %w", e)) 119 | return e 120 | } 121 | return nil 122 | }) 123 | }() 124 | 125 | // Wait for the endpoint to be received 126 | select { 127 | case <-t.endpointChan: 128 | // Endpoint received, proceed 129 | case err = <-errChan: 130 | return fmt.Errorf("error in SSE stream: %w", err) 131 | case <-time.After(10 * time.Second): // Add a timeout 132 | return fmt.Errorf("timeout waiting for endpoint") 133 | } 134 | 135 | return nil 136 | } 137 | 138 | func (t *sseClientTransport) startSSE() error { 139 | req, err := http.NewRequestWithContext(t.ctx, http.MethodGet, t.serverURL.String(), nil) 140 | if err != nil { 141 | return fmt.Errorf("failed to create request: %w", err) 142 | } 143 | 144 | req.Header.Set("Accept", "text/event-stream") 145 | req.Header.Set("Cache-Control", "no-cache") 146 
| req.Header.Set("Connection", "keep-alive") 147 | 148 | resp, err := t.client.Do(req) //nolint:bodyclose 149 | if err != nil { 150 | return fmt.Errorf("failed to connect to SSE stream: %w", err) 151 | } 152 | defer resp.Body.Close() 153 | 154 | if resp.StatusCode != http.StatusOK { 155 | return fmt.Errorf("unexpected status code: %d, status: %s", resp.StatusCode, resp.Status) 156 | } 157 | 158 | return t.readSSE(resp.Body) 159 | } 160 | 161 | // readSSE continuously reads the SSE stream and processes events. 162 | // It runs until the connection is closed or an error occurs. 163 | func (t *sseClientTransport) readSSE(reader io.ReadCloser) error { 164 | defer func() { 165 | _ = reader.Close() 166 | }() 167 | 168 | br := bufio.NewReader(reader) 169 | var event, data string 170 | 171 | for { 172 | line, err := br.ReadString('\n') 173 | if err != nil { 174 | if err == io.EOF { 175 | // Process any pending event before exit 176 | if event != "" && data != "" { 177 | t.handleSSEEvent(event, data) 178 | } 179 | } 180 | select { 181 | case <-t.ctx.Done(): 182 | return t.ctx.Err() 183 | default: 184 | return fmt.Errorf("SSE stream error: %w", err) 185 | } 186 | } 187 | 188 | // Remove only newline markers 189 | line = strings.TrimRight(line, "\r\n") 190 | if line == "" { 191 | // Empty line means end of event 192 | if event != "" && data != "" { 193 | t.handleSSEEvent(event, data) 194 | event = "" 195 | data = "" 196 | } 197 | continue 198 | } 199 | 200 | if strings.HasPrefix(line, "event:") { 201 | event = strings.TrimSpace(strings.TrimPrefix(line, "event:")) 202 | } else if strings.HasPrefix(line, "data:") { 203 | data = strings.TrimSpace(strings.TrimPrefix(line, "data:")) 204 | } 205 | } 206 | } 207 | 208 | // handleSSEEvent processes SSE events based on their type. 209 | // Handles 'endpoint' events for connection setup and 'message' events for JSON-RPC communication. 
210 | func (t *sseClientTransport) handleSSEEvent(event, data string) {
211 | 	// Event types other than "endpoint" and "message" are ignored.
212 | 	switch event {
213 | 	case "endpoint":
214 | 		endpoint, err := t.serverURL.Parse(data)
215 | 		if err != nil {
216 | 			t.logger.Errorf("Error parsing endpoint URL: %v", err)
217 | 			return
218 | 		}
219 | 		t.logger.Debugf("Received endpoint: %s", endpoint.String())
220 | 		// NOTE(review): messageEndpoint is written on the SSE reader goroutine
221 | 		// and read by Send without a lock. The first write is ordered before
222 | 		// Send because Start waits on endpointChan, but an endpoint event that
223 | 		// arrives after a reconnect races with concurrent Sends.
224 | 		t.messageEndpoint = endpoint
225 | 		// Non-blocking: only the first signal is needed to unblock Start.
226 | 		select {
227 | 		case t.endpointChan <- struct{}{}:
228 | 		default:
229 | 		}
230 | 	case "message":
231 | 		ctx, cancel := context.WithTimeout(t.ctx, t.receiveTimeout)
232 | 		defer cancel()
233 | 		if err := t.receiver.Receive(ctx, []byte(data)); err != nil {
234 | 			t.logger.Errorf("Error receive message: %v", err)
235 | 			return
236 | 		}
237 | 	}
238 | }
239 |
240 | // Send POSTs msg (Content-Type application/json) to the message endpoint
241 | // announced by the server's "endpoint" event. It returns an error, instead
242 | // of dereferencing a nil URL, when no endpoint has been received yet
243 | // (Start not called, or Start failed or timed out).
244 | func (t *sseClientTransport) Send(ctx context.Context, msg Message) error {
245 | 	// Read the shared field once so the logged and requested URL agree.
246 | 	endpoint := t.messageEndpoint
247 | 	if endpoint == nil {
248 | 		return errors.New("message endpoint not received yet")
249 | 	}
250 | 	t.logger.Debugf("Sending message: %s to %s", msg, endpoint.String())
251 |
252 | 	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint.String(), bytes.NewReader(msg))
253 | 	if err != nil {
254 | 		return fmt.Errorf("failed to create request: %w", err)
255 | 	}
256 |
257 | 	req.Header.Set("Content-Type", "application/json")
258 |
259 | 	resp, err := t.client.Do(req)
260 | 	if err != nil {
261 | 		return fmt.Errorf("failed to send message: %w", err)
262 | 	}
263 | 	defer resp.Body.Close()
264 | 	// Drain the body so net/http can reuse the keep-alive connection.
265 | 	_, _ = io.Copy(io.Discard, resp.Body)
266 |
267 | 	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
268 | 		return fmt.Errorf("unexpected status code: %d, status: %s", resp.StatusCode, resp.Status)
269 | 	}
270 |
271 | 	return nil
272 | }
273 |
274 | // SetReceiver registers the handler for incoming messages and disconnects.
275 | func (t *sseClientTransport) SetReceiver(receiver clientReceiver) {
276 | 	t.receiver = receiver
277 | }
278 |
279 | // Close cancels the SSE connection and waits for the goroutine started by
280 | // Start to exit. If Start was never called there is nothing to stop, and
281 | // sseConnectClose would never be closed, so Close returns nil at once.
282 | func (t *sseClientTransport) Close() error {
283 | 	if t.cancel == nil {
284 | 		return nil
285 | 	}
286 | 	t.cancel()
287 |
288 | 	<-t.sseConnectClose
289 |
290 | 	return nil
291 | }
292 |
--------------------------------------------------------------------------------
/transport/sse_test.go: -------------------------------------------------------------------------------- 1 | package transport 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | "net" 7 | "net/http" 8 | "net/url" 9 | "testing" 10 | "time" 11 | 12 | "github.com/ThinkInAIXYZ/go-mcp/pkg" 13 | ) 14 | 15 | func TestSSE(t *testing.T) { 16 | var ( 17 | err error 18 | svr ServerTransport 19 | client ClientTransport 20 | ) 21 | 22 | // Get an available port 23 | port, err := getAvailablePort() 24 | if err != nil { 25 | t.Fatalf("Failed to get available port: %v", err) 26 | } 27 | 28 | serverAddr := fmt.Sprintf("127.0.0.1:%d", port) 29 | serverURL := fmt.Sprintf("http://%s/sse", serverAddr) 30 | 31 | if svr, err = NewSSEServerTransport(serverAddr); err != nil { 32 | t.Fatalf("NewSSEServerTransport failed: %v", err) 33 | } 34 | 35 | if client, err = NewSSEClientTransport(serverURL); err != nil { 36 | t.Fatalf("NewSSEClientTransport failed: %v", err) 37 | } 38 | 39 | testTransport(t, client, svr) 40 | } 41 | 42 | func TestSSEHandler(t *testing.T) { 43 | var ( 44 | messageURL = "/message" 45 | port int 46 | 47 | err error 48 | svr ServerTransport 49 | client ClientTransport 50 | ) 51 | 52 | // Get an available port 53 | port, err = getAvailablePort() 54 | if err != nil { 55 | t.Fatalf("Failed to get available port: %v", err) 56 | } 57 | 58 | serverAddr := fmt.Sprintf("http://127.0.0.1:%d", port) 59 | serverURL := fmt.Sprintf("%s/sse", serverAddr) 60 | 61 | svr, handler, err := NewSSEServerTransportAndHandler(fmt.Sprintf("%s%s", serverAddr, messageURL)) 62 | if err != nil { 63 | t.Fatalf("NewSSEServerTransport failed: %v", err) 64 | } 65 | 66 | // 设置 HTTP 路由 67 | http.Handle("/sse", handler.HandleSSE()) 68 | http.Handle(messageURL, handler.HandleMessage()) 69 | 70 | errCh := make(chan error, 1) 71 | go func() { 72 | if e := http.ListenAndServe(fmt.Sprintf(":%d", port), nil); e != nil { 73 | log.Fatalf("Failed to start HTTP server: %v", e) 74 | } 75 | }() 76 | 77 | // Use select to handle 
potential errors 78 | select { 79 | case err = <-errCh: 80 | t.Fatalf("http.ListenAndServe() failed: %v", err) 81 | case <-time.After(time.Second): 82 | // Server started normally 83 | } 84 | 85 | if client, err = NewSSEClientTransport(serverURL); err != nil { 86 | t.Fatalf("NewSSEClientTransport failed: %v", err) 87 | } 88 | 89 | testTransport(t, client, svr) 90 | } 91 | 92 | // getAvailablePort returns a port that is available for use 93 | func getAvailablePort() (int, error) { 94 | addr, err := net.Listen("tcp", "127.0.0.1:0") 95 | if err != nil { 96 | return 0, fmt.Errorf("failed to get available port: %v", err) 97 | } 98 | defer func() { 99 | if err = addr.Close(); err != nil { 100 | fmt.Println(err) 101 | } 102 | }() 103 | 104 | port := addr.Addr().(*net.TCPAddr).Port 105 | return port, nil 106 | } 107 | 108 | func Test_joinPath(t *testing.T) { 109 | type args struct { 110 | u *url.URL 111 | elem []string 112 | } 113 | tests := []struct { 114 | name string 115 | args args 116 | want string 117 | }{ 118 | { 119 | name: "1", 120 | args: args{ 121 | u: func() *url.URL { 122 | uri, err := url.Parse("https://google.com/api/v1") 123 | if err != nil { 124 | panic(err) 125 | } 126 | return uri 127 | }(), 128 | elem: []string{"/test"}, 129 | }, 130 | want: "https://google.com/api/v1/test", 131 | }, 132 | { 133 | name: "2", 134 | args: args{ 135 | u: func() *url.URL { 136 | uri, err := url.Parse("/api/v1") 137 | if err != nil { 138 | panic(err) 139 | } 140 | return uri 141 | }(), 142 | elem: []string{"/test"}, 143 | }, 144 | want: "/api/v1/test", 145 | }, 146 | } 147 | for _, tt := range tests { 148 | t.Run(tt.name, func(t *testing.T) { 149 | joinPath(tt.args.u, tt.args.elem...) 
150 | if got := tt.args.u.String(); got != tt.want { 151 | t.Errorf("joinPath() = %v, want %v", got, tt.want) 152 | } 153 | }) 154 | } 155 | } 156 | 157 | func Test_sseClientTransport_handleSSEEvent(t1 *testing.T) { 158 | type fields struct { 159 | serverURL *url.URL 160 | logger pkg.Logger 161 | } 162 | type args struct { 163 | event string 164 | data string 165 | } 166 | tests := []struct { 167 | name string 168 | fields fields 169 | args args 170 | want string 171 | }{ 172 | { 173 | name: "1", 174 | fields: fields{ 175 | serverURL: func() *url.URL { 176 | uri, err := url.Parse("https://api.baidu.com/mcp") 177 | if err != nil { 178 | panic(err) 179 | } 180 | return uri 181 | }(), 182 | logger: pkg.DefaultLogger, 183 | }, 184 | args: args{ 185 | event: "endpoint", 186 | data: "/sse/messages", 187 | }, 188 | want: "https://api.baidu.com/sse/messages", 189 | }, 190 | { 191 | name: "2", 192 | fields: fields{ 193 | serverURL: func() *url.URL { 194 | uri, err := url.Parse("https://api.baidu.com/mcp") 195 | if err != nil { 196 | panic(err) 197 | } 198 | return uri 199 | }(), 200 | logger: pkg.DefaultLogger, 201 | }, 202 | args: args{ 203 | event: "endpoint", 204 | data: "https://api.google.com/sse/messages", 205 | }, 206 | want: "https://api.google.com/sse/messages", 207 | }, 208 | } 209 | for _, tt := range tests { 210 | t1.Run(tt.name, func(t1 *testing.T) { 211 | t := &sseClientTransport{ 212 | serverURL: tt.fields.serverURL, 213 | logger: tt.fields.logger, 214 | endpointChan: make(chan struct{}), 215 | } 216 | t.handleSSEEvent(tt.args.event, tt.args.data) 217 | if t.messageEndpoint.String() != tt.want { 218 | t1.Errorf("handleSSEEvent() = %v, want %v", t.messageEndpoint.String(), tt.want) 219 | } 220 | }) 221 | } 222 | } 223 | -------------------------------------------------------------------------------- /transport/stdio_client.go: -------------------------------------------------------------------------------- 1 | package transport 2 | 3 | import ( 4 | "bufio" 5 | 
"bytes" 6 | "context" 7 | "errors" 8 | "fmt" 9 | "io" 10 | "os" 11 | "os/exec" 12 | "sync" 13 | 14 | "github.com/ThinkInAIXYZ/go-mcp/pkg" 15 | ) 16 | 17 | type StdioClientTransportOption func(*stdioClientTransport) 18 | 19 | func WithStdioClientOptionLogger(log pkg.Logger) StdioClientTransportOption { 20 | return func(t *stdioClientTransport) { 21 | t.logger = log 22 | } 23 | } 24 | 25 | func WithStdioClientOptionEnv(env ...string) StdioClientTransportOption { 26 | return func(t *stdioClientTransport) { 27 | t.cmd.Env = append(t.cmd.Env, env...) 28 | } 29 | } 30 | 31 | const mcpMessageDelimiter = '\n' 32 | 33 | type stdioClientTransport struct { 34 | cmd *exec.Cmd 35 | receiver clientReceiver 36 | reader io.Reader 37 | writer io.WriteCloser 38 | errReader io.Reader 39 | 40 | logger pkg.Logger 41 | 42 | wg sync.WaitGroup 43 | cancel context.CancelFunc 44 | } 45 | 46 | func NewStdioClientTransport(command string, args []string, opts ...StdioClientTransportOption) (ClientTransport, error) { 47 | cmd := exec.Command(command, args...) 
48 | 49 | cmd.Env = os.Environ() 50 | 51 | stdin, err := cmd.StdinPipe() 52 | if err != nil { 53 | return nil, fmt.Errorf("failed to create stdin pipe: %w", err) 54 | } 55 | 56 | stdout, err := cmd.StdoutPipe() 57 | if err != nil { 58 | return nil, fmt.Errorf("failed to create stdout pipe: %w", err) 59 | } 60 | 61 | stderr, err := cmd.StderrPipe() 62 | if err != nil { 63 | return nil, fmt.Errorf("failed to create stdout pipe: %w", err) 64 | } 65 | 66 | t := &stdioClientTransport{ 67 | cmd: cmd, 68 | reader: stdout, 69 | writer: stdin, 70 | errReader: stderr, 71 | 72 | logger: pkg.DefaultLogger, 73 | } 74 | 75 | for _, opt := range opts { 76 | opt(t) 77 | } 78 | return t, nil 79 | } 80 | 81 | func (t *stdioClientTransport) Start() error { 82 | if err := t.cmd.Start(); err != nil { 83 | return fmt.Errorf("failed to start command: %w", err) 84 | } 85 | 86 | innerCtx, cancel := context.WithCancel(context.Background()) 87 | t.cancel = cancel 88 | 89 | t.wg.Add(1) 90 | go func() { 91 | defer pkg.Recover() 92 | defer t.wg.Done() 93 | 94 | t.startReceive(innerCtx) 95 | }() 96 | 97 | t.wg.Add(1) 98 | go func() { 99 | defer pkg.Recover() 100 | defer t.wg.Done() 101 | 102 | t.startReceiveErr(innerCtx) 103 | }() 104 | 105 | return nil 106 | } 107 | 108 | func (t *stdioClientTransport) Send(_ context.Context, msg Message) error { 109 | _, err := t.writer.Write(append(msg, mcpMessageDelimiter)) 110 | return err 111 | } 112 | 113 | func (t *stdioClientTransport) SetReceiver(receiver clientReceiver) { 114 | t.receiver = receiver 115 | } 116 | 117 | func (t *stdioClientTransport) Close() error { 118 | t.cancel() 119 | 120 | if err := t.writer.Close(); err != nil { 121 | return fmt.Errorf("failed to close writer: %w", err) 122 | } 123 | 124 | if err := t.cmd.Wait(); err != nil { 125 | return err 126 | } 127 | 128 | t.wg.Wait() 129 | 130 | return nil 131 | } 132 | 133 | func (t *stdioClientTransport) startReceive(ctx context.Context) { 134 | s := bufio.NewReader(t.reader) 135 | 136 | 
for { 137 | line, err := s.ReadBytes('\n') 138 | if err != nil { 139 | t.receiver.Interrupt(fmt.Errorf("stdout read error: %w", err)) 140 | 141 | if errors.Is(err, io.ErrClosedPipe) || // This error occurs during unit tests, suppressing it here 142 | errors.Is(err, io.EOF) { 143 | return 144 | } 145 | t.logger.Errorf("stdout read error: %+v", err) 146 | return 147 | } 148 | 149 | line = bytes.TrimRight(line, "\n") 150 | // filter empty messages 151 | // filter space messages and \t messages 152 | if len(bytes.TrimFunc(line, func(r rune) bool { return r == ' ' || r == '\t' })) == 0 { 153 | t.logger.Debugf("skipping empty message") 154 | continue 155 | } 156 | 157 | select { 158 | case <-ctx.Done(): 159 | return 160 | default: 161 | if err = t.receiver.Receive(ctx, line); err != nil { 162 | t.logger.Errorf("receiver failed: %v", err) 163 | } 164 | } 165 | } 166 | } 167 | 168 | func (t *stdioClientTransport) startReceiveErr(ctx context.Context) { 169 | s := bufio.NewReader(t.errReader) 170 | 171 | for { 172 | line, err := s.ReadBytes('\n') 173 | if err != nil { 174 | if errors.Is(err, io.ErrClosedPipe) || // This error occurs during unit tests, suppressing it here 175 | errors.Is(err, io.EOF) { 176 | return 177 | } 178 | t.logger.Errorf("client receive unexpected server error reading input: %v", err) 179 | return 180 | } 181 | 182 | line = bytes.TrimRight(line, "\n") 183 | // filter empty messages 184 | // filter space messages and \t messages 185 | if len(bytes.TrimFunc(line, func(r rune) bool { return r == ' ' || r == '\t' })) == 0 { 186 | t.logger.Debugf("skipping empty message") 187 | continue 188 | } 189 | 190 | select { 191 | case <-ctx.Done(): 192 | return 193 | default: 194 | t.logger.Infof("receive server info: %s", line) 195 | } 196 | } 197 | } 198 | -------------------------------------------------------------------------------- /transport/stdio_server.go: -------------------------------------------------------------------------------- 1 | package transport 
2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "context" 7 | "errors" 8 | "fmt" 9 | "io" 10 | "os" 11 | 12 | "github.com/ThinkInAIXYZ/go-mcp/pkg" 13 | ) 14 | 15 | type StdioServerTransportOption func(*stdioServerTransport) 16 | 17 | func WithStdioServerOptionLogger(log pkg.Logger) StdioServerTransportOption { 18 | return func(t *stdioServerTransport) { 19 | t.logger = log 20 | } 21 | } 22 | 23 | type stdioServerTransport struct { 24 | receiver serverReceiver 25 | reader io.ReadCloser 26 | writer io.Writer 27 | 28 | sessionManager sessionManager 29 | sessionID string 30 | 31 | logger pkg.Logger 32 | 33 | cancel context.CancelFunc 34 | receiveShutDone chan struct{} 35 | } 36 | 37 | func NewStdioServerTransport(opts ...StdioServerTransportOption) ServerTransport { 38 | t := &stdioServerTransport{ 39 | reader: os.Stdin, 40 | writer: os.Stdout, 41 | logger: pkg.DefaultLogger, 42 | 43 | receiveShutDone: make(chan struct{}), 44 | } 45 | 46 | for _, opt := range opts { 47 | opt(t) 48 | } 49 | return t 50 | } 51 | 52 | func (t *stdioServerTransport) Run() error { 53 | ctx, cancel := context.WithCancel(context.Background()) 54 | t.cancel = cancel 55 | 56 | t.sessionID = t.sessionManager.CreateSession(context.Background()) 57 | 58 | t.startReceive(ctx) 59 | 60 | close(t.receiveShutDone) 61 | return nil 62 | } 63 | 64 | func (t *stdioServerTransport) Send(_ context.Context, _ string, msg Message) error { 65 | if _, err := t.writer.Write(append(msg, mcpMessageDelimiter)); err != nil { 66 | return fmt.Errorf("failed to write: %w", err) 67 | } 68 | return nil 69 | } 70 | 71 | func (t *stdioServerTransport) SetReceiver(receiver serverReceiver) { 72 | t.receiver = receiver 73 | } 74 | 75 | func (t *stdioServerTransport) SetSessionManager(m sessionManager) { 76 | t.sessionManager = m 77 | } 78 | 79 | func (t *stdioServerTransport) Shutdown(userCtx context.Context, serverCtx context.Context) error { 80 | t.cancel() 81 | 82 | if err := t.reader.Close(); err != nil { 83 | return err 84 | 
} 85 | 86 | select { 87 | case <-t.receiveShutDone: 88 | return nil 89 | case <-serverCtx.Done(): 90 | return nil 91 | case <-userCtx.Done(): 92 | return userCtx.Err() 93 | } 94 | } 95 | 96 | func (t *stdioServerTransport) startReceive(ctx context.Context) { 97 | s := bufio.NewReader(t.reader) 98 | 99 | for { 100 | line, err := s.ReadBytes('\n') 101 | if err != nil { 102 | if errors.Is(err, io.ErrClosedPipe) || // This error occurs during unit tests, suppressing it here 103 | errors.Is(err, io.EOF) { 104 | return 105 | } 106 | t.logger.Errorf("client receive unexpected error reading input: %v", err) 107 | } 108 | line = bytes.TrimRight(line, "\n") 109 | 110 | select { 111 | case <-ctx.Done(): 112 | return 113 | default: 114 | t.receive(ctx, line) 115 | } 116 | } 117 | } 118 | 119 | func (t *stdioServerTransport) receive(ctx context.Context, line []byte) { 120 | outputMsgCh, err := t.receiver.Receive(ctx, t.sessionID, line) 121 | if err != nil { 122 | t.logger.Errorf("receiver failed: %v", err) 123 | return 124 | } 125 | 126 | if outputMsgCh == nil { 127 | return 128 | } 129 | 130 | go func() { 131 | defer pkg.Recover() 132 | 133 | for msg := range outputMsgCh { 134 | if e := t.Send(context.Background(), t.sessionID, msg); e != nil { 135 | t.logger.Errorf("Failed to send message: %v", e) 136 | } 137 | } 138 | }() 139 | } 140 | -------------------------------------------------------------------------------- /transport/stdio_test.go: -------------------------------------------------------------------------------- 1 | package transport 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "math/rand" 7 | "os" 8 | "os/exec" 9 | "path/filepath" 10 | "strconv" 11 | "testing" 12 | "time" 13 | ) 14 | 15 | type mock struct { 16 | reader *io.PipeReader 17 | writer *io.PipeWriter 18 | closer io.Closer 19 | } 20 | 21 | func (m *mock) Write(p []byte) (n int, err error) { 22 | return m.writer.Write(p) 23 | } 24 | 25 | func (m *mock) Close() error { 26 | if err := m.writer.Close(); err != nil { 
27 | return err 28 | } 29 | if err := m.reader.Close(); err != nil { 30 | return err 31 | } 32 | if err := m.closer.Close(); err != nil { 33 | return err 34 | } 35 | return nil 36 | } 37 | 38 | func TestStdioTransport(t *testing.T) { 39 | var ( 40 | err error 41 | server *stdioServerTransport 42 | client *stdioClientTransport 43 | ) 44 | 45 | r := rand.New(rand.NewSource(time.Now().UnixNano())) 46 | 47 | mockServerTrPath := filepath.Join(os.TempDir(), "mock_server_tr_"+strconv.Itoa(r.Int())) 48 | if err = compileMockStdioServerTr(mockServerTrPath); err != nil { 49 | t.Fatalf("Failed to compile mock server: %v", err) 50 | } 51 | 52 | defer func(name string) { 53 | if err = os.Remove(name); err != nil { 54 | fmt.Printf("Failed to remove mock server: %v\n", err) 55 | } 56 | }(mockServerTrPath) 57 | 58 | clientT, err := NewStdioClientTransport(mockServerTrPath, []string{}) 59 | if err != nil { 60 | t.Fatalf("NewStdioClientTransport failed: %v", err) 61 | } 62 | 63 | client = clientT.(*stdioClientTransport) 64 | server = NewStdioServerTransport().(*stdioServerTransport) 65 | 66 | // Create pipes for communication 67 | reader1, writer1 := io.Pipe() 68 | reader2, writer2 := io.Pipe() 69 | 70 | // Set up the communication channels 71 | server.reader = reader2 72 | server.writer = writer1 73 | client.reader = reader1 74 | client.writer = &mock{ 75 | reader: reader1, 76 | writer: writer2, 77 | closer: client.writer, 78 | } 79 | 80 | testTransport(t, client, server) 81 | } 82 | 83 | func compileMockStdioServerTr(outputPath string) error { 84 | cmd := exec.Command("go", "build", "-o", outputPath, "../testdata/mock_block_server.go") 85 | 86 | if output, err := cmd.CombinedOutput(); err != nil { 87 | return fmt.Errorf("compilation failed: %v\nOutput: %s", err, output) 88 | } 89 | 90 | return nil 91 | } 92 | -------------------------------------------------------------------------------- /transport/streamable_http_test.go: 
-------------------------------------------------------------------------------- 1 | package transport 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | ) 7 | 8 | func TestStreamableHTTP(t *testing.T) { 9 | var ( 10 | err error 11 | svr ServerTransport 12 | client ClientTransport 13 | ) 14 | 15 | // Get an available port 16 | port, err := getAvailablePort() 17 | if err != nil { 18 | t.Fatalf("Failed to get available port: %v", err) 19 | } 20 | 21 | serverAddr := fmt.Sprintf("127.0.0.1:%d", port) 22 | serverURL := fmt.Sprintf("http://%s/mcp", serverAddr) 23 | 24 | svr = NewStreamableHTTPServerTransport(serverAddr) 25 | 26 | if client, err = NewStreamableHTTPClientTransport(serverURL); err != nil { 27 | t.Fatalf("NewStreamableHTTPClientTransport failed: %v", err) 28 | } 29 | 30 | testTransport(t, client, svr) 31 | } 32 | -------------------------------------------------------------------------------- /transport/transport.go: -------------------------------------------------------------------------------- 1 | package transport 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/ThinkInAIXYZ/go-mcp/pkg" 7 | ) 8 | 9 | /* 10 | * Transport is an abstraction of the underlying transport layer. 11 | * GO-MCP needs to be able to transmit JSON-RPC messages between server and client. 
12 | */ 13 | 14 | // Message defines the basic message interface 15 | type Message []byte 16 | 17 | func (msg Message) String() string { 18 | return pkg.B2S(msg) 19 | } 20 | 21 | type ClientTransport interface { 22 | // Start initiates the transport connection 23 | Start() error 24 | 25 | // Send transmits a message 26 | Send(ctx context.Context, msg Message) error 27 | 28 | // SetReceiver sets the handler for messages from the peer 29 | SetReceiver(receiver clientReceiver) 30 | 31 | // Close terminates the transport connection 32 | Close() error 33 | } 34 | 35 | type clientReceiver interface { 36 | Receive(ctx context.Context, msg []byte) error 37 | Interrupt(err error) 38 | } 39 | 40 | type ClientReceiver struct { 41 | receive func(ctx context.Context, msg []byte) error 42 | interrupt func(err error) 43 | } 44 | 45 | func (r *ClientReceiver) Receive(ctx context.Context, msg []byte) error { 46 | return r.receive(ctx, msg) 47 | } 48 | 49 | func (r *ClientReceiver) Interrupt(err error) { 50 | r.interrupt(err) 51 | } 52 | 53 | func NewClientReceiver(receive func(ctx context.Context, msg []byte) error, interrupt func(err error)) clientReceiver { 54 | r := &ClientReceiver{ 55 | receive: receive, 56 | interrupt: interrupt, 57 | } 58 | return r 59 | } 60 | 61 | type ServerTransport interface { 62 | // Run starts listening for requests, this is synchronous, and cannot return before Shutdown is called 63 | Run() error 64 | 65 | // Send transmits a message 66 | Send(ctx context.Context, sessionID string, msg Message) error 67 | 68 | // SetReceiver sets the handler for messages from the peer 69 | SetReceiver(serverReceiver) 70 | 71 | SetSessionManager(manager sessionManager) 72 | 73 | // Shutdown gracefully closes, the internal implementation needs to stop receiving messages first, 74 | // then wait for serverCtx to be canceled, while using userCtx to control timeout. 75 | // userCtx is used to control the timeout of the server shutdown. 
76 | // serverCtx is used to coordinate the internal cleanup sequence: 77 | // 1. turn off message listen 78 | // 2. Wait for serverCtx to be done (indicating server shutdown is complete) 79 | // 3. Cancel the transport's context to stop all ongoing operations 80 | // 4. Wait for all in-flight sends to complete 81 | // 5. Close all session 82 | Shutdown(userCtx context.Context, serverCtx context.Context) error 83 | } 84 | 85 | type serverReceiver interface { 86 | Receive(ctx context.Context, sessionID string, msg []byte) (<-chan []byte, error) 87 | } 88 | 89 | type ServerReceiverF func(ctx context.Context, sessionID string, msg []byte) (<-chan []byte, error) 90 | 91 | func (f ServerReceiverF) Receive(ctx context.Context, sessionID string, msg []byte) (<-chan []byte, error) { 92 | return f(ctx, sessionID, msg) 93 | } 94 | 95 | type sessionManager interface { 96 | CreateSession(context.Context) string 97 | OpenMessageQueueForSend(sessionID string) error 98 | EnqueueMessageForSend(ctx context.Context, sessionID string, message []byte) error 99 | DequeueMessageForSend(ctx context.Context, sessionID string) ([]byte, error) 100 | CloseSession(sessionID string) 101 | CloseAllSessions() 102 | } 103 | -------------------------------------------------------------------------------- /transport/transport_test.go: -------------------------------------------------------------------------------- 1 | package transport 2 | 3 | import ( 4 | "context" 5 | "reflect" 6 | "testing" 7 | "time" 8 | 9 | "github.com/google/uuid" 10 | 11 | "github.com/ThinkInAIXYZ/go-mcp/pkg" 12 | ) 13 | 14 | type mockSessionManager struct { 15 | pkg.SyncMap[chan []byte] 16 | } 17 | 18 | func newMockSessionManager() *mockSessionManager { 19 | return &mockSessionManager{} 20 | } 21 | 22 | func (m *mockSessionManager) CreateSession(context.Context) string { 23 | sessionID := uuid.NewString() 24 | m.Store(sessionID, nil) 25 | return sessionID 26 | } 27 | 28 | func (m *mockSessionManager) 
OpenMessageQueueForSend(sessionID string) error { 29 | _, ok := m.Load(sessionID) 30 | if !ok { 31 | return pkg.ErrLackSession 32 | } 33 | m.Store(sessionID, make(chan []byte)) 34 | return nil 35 | } 36 | 37 | func (m *mockSessionManager) IsExistSession(sessionID string) bool { 38 | _, has := m.Load(sessionID) 39 | return has 40 | } 41 | 42 | func (m *mockSessionManager) EnqueueMessageForSend(ctx context.Context, sessionID string, message []byte) error { 43 | ch, has := m.Load(sessionID) 44 | if !has { 45 | return pkg.ErrLackSession 46 | } 47 | 48 | select { 49 | case ch <- message: 50 | return nil 51 | case <-ctx.Done(): 52 | return ctx.Err() 53 | } 54 | } 55 | 56 | func (m *mockSessionManager) DequeueMessageForSend(ctx context.Context, sessionID string) ([]byte, error) { 57 | ch, has := m.Load(sessionID) 58 | if !has { 59 | return nil, pkg.ErrLackSession 60 | } 61 | 62 | select { 63 | case <-ctx.Done(): 64 | return nil, ctx.Err() 65 | case msg, ok := <-ch: 66 | if msg == nil && !ok { 67 | // There are no new messages and the chan has been closed, indicating that the request may need to be terminated. 
68 | return nil, pkg.ErrSendEOF 69 | } 70 | return msg, nil 71 | } 72 | } 73 | 74 | func (m *mockSessionManager) CloseSession(sessionID string) { 75 | ch, ok := m.LoadAndDelete(sessionID) 76 | if !ok { 77 | return 78 | } 79 | close(ch) 80 | } 81 | 82 | func (m *mockSessionManager) CloseAllSessions() { 83 | m.Range(func(key string, value chan []byte) bool { 84 | m.Delete(key) 85 | close(value) 86 | return true 87 | }) 88 | } 89 | 90 | func testTransport(t *testing.T, client ClientTransport, server ServerTransport) { 91 | testMsg := "hello server" 92 | expectedMsgWithServerCh := make(chan string, 1) 93 | server.SetReceiver(ServerReceiverF(func(_ context.Context, _ string, msg []byte) (<-chan []byte, error) { 94 | expectedMsgWithServerCh <- string(msg) 95 | msgCh := make(chan []byte, 1) 96 | go func() { 97 | defer close(msgCh) 98 | msgCh <- msg 99 | }() 100 | return msgCh, nil 101 | })) 102 | server.SetSessionManager(newMockSessionManager()) 103 | 104 | expectedMsgWithClientCh := make(chan string, 1) 105 | client.SetReceiver(NewClientReceiver(func(_ context.Context, msg []byte) error { 106 | expectedMsgWithClientCh <- string(msg) 107 | return nil 108 | }, func(_ error) { 109 | close(expectedMsgWithClientCh) 110 | })) 111 | 112 | errCh := make(chan error, 1) 113 | go func() { 114 | errCh <- server.Run() 115 | }() 116 | 117 | // Use select to handle potential errors 118 | select { 119 | case err := <-errCh: 120 | t.Fatalf("server.Run() failed: %v", err) 121 | case <-time.After(time.Second): 122 | // Server started normally 123 | } 124 | 125 | defer func() { 126 | if _, ok := server.(*stdioServerTransport); ok { // stdioServerTransport not support shutdown 127 | return 128 | } 129 | 130 | userCtx, cancel := context.WithTimeout(context.Background(), time.Second*1) 131 | defer cancel() 132 | 133 | serverCtx, cancel := context.WithCancel(userCtx) 134 | cancel() 135 | 136 | if err := server.Shutdown(userCtx, serverCtx); err != nil { 137 | t.Fatalf("server.Shutdown() failed: 
%v", err) 138 | } 139 | }() 140 | 141 | if err := client.Start(); err != nil { 142 | t.Fatalf("client.Run() failed: %v", err) 143 | } 144 | 145 | defer func() { 146 | if err := client.Close(); err != nil { 147 | t.Fatalf("client.Close() failed: %v", err) 148 | } 149 | }() 150 | 151 | if err := client.Send(context.Background(), Message(testMsg)); err != nil { 152 | t.Fatalf("client.Send() failed: %v", err) 153 | } 154 | expectedMsg := <-expectedMsgWithServerCh 155 | if !reflect.DeepEqual(expectedMsg, testMsg) { 156 | t.Fatalf("client.Send() got %v, want %v", expectedMsg, testMsg) 157 | } 158 | expectedMsg = <-expectedMsgWithClientCh 159 | if !reflect.DeepEqual(expectedMsg, testMsg) { 160 | t.Fatalf("server.Send() failed: got %v, want %v", expectedMsg, testMsg) 161 | } 162 | } 163 | --------------------------------------------------------------------------------