/go-mcp.git'
18 | ```
19 | You can replace "upstream" with any name you like, such as your username, nickname, or simply "me". Remember to make corresponding replacements in subsequent commands.
20 |
21 | ## 3 Sync the Code
22 | Unless you have just cloned the repository, sync with the remote repositories first:
23 | `git fetch`
24 |
25 | Without a remote name, `git fetch` fetches from `origin` by default. To sync your forked repository as well, pass its remote name:
26 | `git fetch upstream`
27 |
28 | ## 4 Create a Feature Branch
29 | Before creating a feature branch, decide which branch it should be based on.
30 | Assuming the new feature is based on `main` and will be merged back into `main`, run:
31 | ```bash
32 | git checkout -b feature/my-feature origin/main
33 | ```
34 | This creates a branch that is identical to the code on `origin/main`.
35 |
36 | ## 5 Lint
37 | ```bash
38 | golint $(go list ./... | grep -v /examples/)
39 | golangci-lint run $(go list ./... | grep -v /examples/)
40 | ```
41 |
42 | ## 6 Go Test
43 | ```bash
44 | go test -v -race $(go list ./... | grep -v /examples/) -coverprofile=coverage.txt -covermode=atomic
45 | ```
46 |
47 | ## 7 Submit Commit
48 | ```bash
49 | git add .
50 | git commit
51 | git push upstream feature/my-feature
52 | ```
53 |
54 | ## 8 Submit PR
55 | Visit https://github.com/ThinkInAIXYZ/go-mcp or your fork on GitHub.
56 | Click "Compare" to review the changes between your branch and `main`, then click "Create pull request" to submit the PR.
57 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Anthropic, PBC
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
user := $(shell whoami)
rev := $(shell git rev-parse --short HEAD)
# First five characters of the kernel name: "Linux", "Darwi" (macOS) or
# "MINGW" (Git Bash / MSYS on Windows).
os := $(shell uname -s | cut -c1-5)

# Work out where Go binaries are installed. Priority: GOBIN > GOPATH/bin.
# Both variables may hold a path list; only the first entry is kept.

# macOS and Linux: list entries are separated by ':'.
ifneq ($(filter Darwi Linux,$(os)),)
GOBIN := $(shell echo $(GOBIN) | cut -d':' -f1)
GOPATH := $(shell echo $(GOPATH) | cut -d':' -f1)
endif

# Windows (MSYS): entries are separated by ';' and use drive-letter paths such
# as C:\Go\bin, which are rewritten to the MSYS form /C/Go/bin.
ifeq ($(os),MINGW)
GOBIN := $(subst \,/,$(GOBIN))
GOPATH := $(subst \,/,$(GOPATH))
GOBIN := $(shell echo "$(GOBIN)" | cut -d';' -f1 | sed 's/://g')
GOPATH := $(shell echo "$(GOPATH)" | cut -d';' -f1 | sed 's/://g')
# Prefix '/' only when a value is present: an unset GOBIN must stay empty
# (not "/") so that the GOPATH fallback below still applies.
GOBIN := $(if $(GOBIN),/$(GOBIN))
GOPATH := $(if $(GOPATH),/$(GOPATH))
endif

# Directory where Go binaries are installed; stays empty (not the literal
# string "") when neither GOBIN nor GOPATH is set.
BIN :=
ifneq ($(GOBIN),)
BIN := $(GOBIN)
else ifneq ($(GOPATH),)
BIN := $(GOPATH)/bin
endif
36 |
TOOLS_SHELL="./hack/tools.sh"
# golangci-lint, installed into ./bin at a pinned version.
GOLANGCI_LINT_VERSION := v2.1.1
LINTER := bin/golangci-lint

# curl -f fails on HTTP errors, but `sh` would still exit 0 on the resulting
# empty script, so the installed binary is verified explicitly afterwards.
$(LINTER):
	curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/HEAD/install.sh | sh -s -- -b $(@D) $(GOLANGCI_LINT_VERSION)
	@test -x $@ || { echo "failed to install $@" >&2; exit 1; }
43 |
.PHONY: init-dev
# install dev tools (git-chglog, gofumpt, goimports)
init-dev:
	go install github.com/git-chglog/git-chglog/cmd/git-chglog@latest
	go install mvdan.cc/gofumpt@latest
	go install golang.org/x/tools/cmd/goimports@latest

.PHONY: inspector
# launch the MCP inspector via npx
inspector:
	npx -y @modelcontextprotocol/inspector

.PHONY: fmt
# format Go sources with gofumpt and goimports
fmt:
	gofumpt -w -l .
	goimports -w -l .
60 |
61 |
.PHONY: clean
# run the tidy step of hack/tools.sh
clean:
	@${TOOLS_SHELL} tidy
	@echo "clean finished"

.PHONY: fix
# apply lint fixes via hack/tools.sh
fix: $(LINTER)
	@${TOOLS_SHELL} fix
	@echo "lint fix finished"

.PHONY: test
# run go tests
test:
	@${TOOLS_SHELL} test
	@echo "go test finished"

.PHONY: test-coverage
# run go tests with coverage
test-coverage:
	@${TOOLS_SHELL} test_coverage
	@echo "go test with coverage finished"

.PHONY: lint
# run lint checks
lint: $(LINTER)
	@${TOOLS_SHELL} lint
	@echo "lint check finished"
87 |
.PHONY: changelog
# generate CHANGELOG.md with git-chglog
changelog:
	git-chglog -o ./CHANGELOG.md
92 |
.PHONY: help
# show help
# Lists every target whose rule line is directly preceded by a "# description"
# comment. awk strings are 1-indexed: substr() must start at 1, otherwise the
# last character of the target name is cut off.
help:
	@echo ''
	@echo 'Usage:'
	@echo ' make [target]'
	@echo ''
	@echo 'Targets:'
	@awk '/^[a-zA-Z\-_0-9]+:/ { \
		helpMessage = match(lastLine, /^# (.*)/); \
		if (helpMessage) { \
			helpCommand = substr($$1, 1, index($$1, ":")-1); \
			helpMessage = substr(lastLine, RSTART + 2, RLENGTH); \
			printf "\033[36m%-22s\033[0m %s\n", helpCommand,helpMessage; \
		} \
	} \
	{ lastLine = $$0 }' $(MAKEFILE_LIST)

.DEFAULT_GOAL := help
111 |
--------------------------------------------------------------------------------
/README_TW.md:
--------------------------------------------------------------------------------
1 | # Go-MCP
2 |
3 |
4 |

5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 | English
20 |
21 |
22 | ## 🚀 概述
23 |
24 | Go-MCP 是一個強大的 Go 語言版本 MCP SDK,實現 Model Context Protocol (MCP),協助外部系統與 AI 應用之間的無縫溝通。基於 Go 語言的強型別與效能優勢,提供簡潔且符合習慣的 API,方便您將外部系統整合進 AI 應用程式。
25 |
26 | ### ✨ 主要特色
27 |
28 | - 🔄 **完整協議實作**:全面實現 MCP 規範,確保與所有相容服務無縫對接
29 | - 🏗️ **優雅的架構設計**:採用清晰的三層架構,支援雙向通訊,確保程式碼模組化與可擴充性
30 | - 🔌 **與 Web 框架無縫整合**:提供符合 MCP 協議的 http.Handler,讓開發者能將 MCP 整合進服務框架
31 | - 🛡️ **型別安全**:善用 Go 的強型別系統,確保程式碼清晰且高度可維護
32 | - 📦 **簡易部署**:受惠於 Go 的靜態編譯特性,無需複雜的相依管理
33 | - ⚡ **高效能設計**:充分發揮 Go 的並行能力,在各種場景下皆能維持優異效能與低資源消耗
34 |
35 | ## 🛠️ 安裝
36 |
37 | ```bash
38 | go get github.com/ThinkInAIXYZ/go-mcp
39 | ```
40 |
41 | 需 Go 1.18 或更高版本。
42 |
43 | ## 🎯 快速開始
44 |
45 | ### 客戶端範例
46 |
47 | ```go
48 | package main
49 |
50 | import (
51 | "context"
52 | "log"
53 |
54 | "github.com/ThinkInAIXYZ/go-mcp/client"
55 | "github.com/ThinkInAIXYZ/go-mcp/transport"
56 | )
57 |
58 | func main() {
59 | // 建立 SSE 傳輸客戶端
60 | transportClient, err := transport.NewSSEClientTransport("http://127.0.0.1:8080/sse")
61 | if err != nil {
62 | log.Fatalf("建立傳輸客戶端失敗: %v", err)
63 | }
64 |
65 | // 初始化 MCP 客戶端
66 | mcpClient, err := client.NewClient(transportClient)
67 | if err != nil {
68 | log.Fatalf("建立 MCP 客戶端失敗: %v", err)
69 | }
70 | defer mcpClient.Close()
71 |
72 | // 取得可用工具列表
73 | tools, err := mcpClient.ListTools(context.Background())
74 | if err != nil {
75 | log.Fatalf("取得工具列表失敗: %v", err)
76 | }
77 | log.Printf("可用工具: %+v", tools)
78 | }
79 | ```
80 |
81 | ### 伺服器範例
82 |
83 | ```go
84 | package main
85 |
86 | import (
87 | "context"
88 | "fmt"
89 | "log"
90 | "time"
91 |
92 | "github.com/ThinkInAIXYZ/go-mcp/protocol"
93 | "github.com/ThinkInAIXYZ/go-mcp/server"
94 | "github.com/ThinkInAIXYZ/go-mcp/transport"
95 | )
96 |
97 | type TimeRequest struct {
98 | Timezone string `json:"timezone" description:"時區" required:"true"` // 使用 field tag 描述輸入結構
99 | }
100 |
101 | func main() {
102 | // 建立 SSE 傳輸伺服器
103 | transportServer, err := transport.NewSSEServerTransport("127.0.0.1:8080")
104 | if err != nil {
105 | log.Fatalf("建立傳輸伺服器失敗: %v", err)
106 | }
107 |
108 | // 初始化 MCP 伺服器
109 | mcpServer, err := server.NewServer(transportServer)
110 | if err != nil {
111 | log.Fatalf("建立 MCP 伺服器失敗: %v", err)
112 | }
113 |
114 | // 註冊時間查詢工具
115 | tool, err := protocol.NewTool("current_time", "取得指定時區的目前時間", TimeRequest{})
116 | if err != nil {
117 | log.Fatalf("建立工具失敗: %v", err)
118 | return
119 | }
120 | mcpServer.RegisterTool(tool, handleTimeRequest)
121 |
122 | // 啟動伺服器
123 | if err = mcpServer.Run(); err != nil {
124 | log.Fatalf("伺服器啟動失敗: %v", err)
125 | }
126 | }
127 |
128 | func handleTimeRequest(ctx context.Context, req *protocol.CallToolRequest) (*protocol.CallToolResult, error) {
129 | var timeReq TimeRequest
130 | if err := protocol.VerifyAndUnmarshal(req.RawArguments, &timeReq); err != nil {
131 | return nil, err
132 | }
133 |
134 | loc, err := time.LoadLocation(timeReq.Timezone)
135 | if err != nil {
136 | return nil, fmt.Errorf("無效的時區: %v", err)
137 | }
138 |
139 | return &protocol.CallToolResult{
140 | Content: []protocol.Content{
141 | &protocol.TextContent{
142 | Type: "text",
143 | Text: time.Now().In(loc).String(),
144 | },
145 | },
146 | }, nil
147 | }
148 | ```
149 |
150 | ### 與 Gin 框架整合
151 |
152 | ```go
153 | package main
154 |
155 | import (
156 | "context"
157 | "log"
158 |
159 | "github.com/ThinkInAIXYZ/go-mcp/protocol"
160 | "github.com/ThinkInAIXYZ/go-mcp/server"
161 | "github.com/ThinkInAIXYZ/go-mcp/transport"
162 | "github.com/gin-gonic/gin"
163 | )
164 |
165 | func main() {
166 | messageEndpointURL := "/message"
167 |
168 | sseTransport, mcpHandler, err := transport.NewSSEServerTransportAndHandler(messageEndpointURL)
169 | if err != nil {
170 | log.Panicf("建立 SSE 傳輸與處理器失敗: %v", err)
171 | }
172 |
173 | // 建立 MCP 伺服器
174 | mcpServer, _ := server.NewServer(sseTransport)
175 |
176 | // 註冊工具
177 | // mcpServer.RegisterTool(tool, toolHandler)
178 |
179 | // 啟動 MCP 伺服器
180 | go func() {
181 | mcpServer.Run()
182 | }()
183 |
184 | defer mcpServer.Shutdown(context.Background())
185 |
186 | r := gin.Default()
187 | r.GET("/sse", func(ctx *gin.Context) {
188 | mcpHandler.HandleSSE().ServeHTTP(ctx.Writer, ctx.Request)
189 | })
190 | r.POST(messageEndpointURL, func(ctx *gin.Context) {
191 | mcpHandler.HandleMessage().ServeHTTP(ctx.Writer, ctx.Request)
192 | })
193 |
194 | if err = r.Run(":8080"); err != nil {
195 | return
196 | }
197 | }
198 | ```
199 |
200 | [參考:更完整的範例](https://github.com/ThinkInAIXYZ/go-mcp/blob/main/examples/http_handler/main.go)
201 |
202 | ## 🏗️ 架構設計
203 |
204 | Go-MCP 採用優雅的三層架構設計:
205 |
206 | 
207 |
208 | 1. **傳輸層**:負責底層通訊實作,支援多種傳輸協定
209 | 2. **協議層**:處理 MCP 協議的編解碼與資料結構定義
210 | 3. **使用者層**:提供友善的客戶端與伺服器 API
211 |
212 | 目前支援的傳輸方式:
213 |
214 | 
215 |
216 | - **HTTP SSE/POST**:基於 HTTP 的伺服器推播與客戶端請求,適用於 Web 場景
217 | - **Streamable HTTP**:支援 HTTP POST/GET 請求,具備 stateless 與 stateful 兩種模式,stateful 模式利用 SSE 進行多訊息串流傳輸,支援伺服器主動通知與請求
218 | - **Stdio**:基於標準輸入輸出流,適合本地進程間通訊
219 |
220 | 傳輸層採用統一介面抽象,讓新增傳輸方式(如 WebSocket、gRPC)變得簡單直接,且不影響上層程式碼。
221 |
222 | ## 🤝 貢獻方式
223 |
224 | 歡迎各種形式的貢獻!詳情請參閱 [CONTRIBUTING.md](CONTRIBUTING.md)。
225 |
226 | ## 📄 授權條款
227 |
228 | 本專案採用 MIT 授權條款 - 詳見 [LICENSE](LICENSE) 檔案
229 |
230 | ## 📞 聯絡我們
231 |
232 | - **GitHub Issues**:[提交問題](https://github.com/ThinkInAIXYZ/go-mcp/issues)
233 | - **Discord**:點擊[這裡](https://discord.gg/4CSU8HYt)加入用戶群
234 | - **微信社群**:
235 |
236 | 
237 |
238 | ## ✨ 貢獻者
239 |
240 | 感謝所有為本專案做出貢獻的開發者!
241 |
242 |
243 |
244 |
245 |
246 | ## 📈 專案趨勢
247 |
248 | [](https://www.star-history.com/#ThinkInAIXYZ/go-mcp&Date)
249 |
--------------------------------------------------------------------------------
/client/client.go:
--------------------------------------------------------------------------------
1 | package client
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "sync"
7 | "time"
8 |
9 | cmap "github.com/orcaman/concurrent-map/v2"
10 |
11 | "github.com/ThinkInAIXYZ/go-mcp/pkg"
12 | "github.com/ThinkInAIXYZ/go-mcp/protocol"
13 | "github.com/ThinkInAIXYZ/go-mcp/transport"
14 | )
15 |
16 | type Option func(*Client)
17 |
18 | func WithNotifyHandler(handler NotifyHandler) Option {
19 | return func(s *Client) {
20 | s.notifyHandler = handler
21 | }
22 | }
23 |
24 | func WithSamplingHandler(handler SamplingHandler) Option {
25 | return func(s *Client) {
26 | s.samplingHandler = handler
27 | }
28 | }
29 |
30 | func WithClientInfo(info *protocol.Implementation) Option {
31 | return func(s *Client) {
32 | s.clientInfo = info
33 | }
34 | }
35 |
36 | func WithInitTimeout(timeout time.Duration) Option {
37 | return func(s *Client) {
38 | s.initTimeout = timeout
39 | }
40 | }
41 |
42 | func WithLogger(logger pkg.Logger) Option {
43 | return func(s *Client) {
44 | s.logger = logger
45 | }
46 | }
47 |
// Client is an MCP client bound to a single transport. It is created with
// NewClient, which also performs the initialize handshake.
type Client struct {
	// transport carries outgoing messages and delivers incoming ones to receive.
	transport transport.ClientTransport

	// reqID2respChan maps an in-flight request ID (as fmt.Sprint(id)) to the
	// channel that receives its response.
	reqID2respChan cmap.ConcurrentMap[string, chan *protocol.JSONRPCResponse]

	// progressChanRW guards progressToken2notifyChan.
	progressChanRW sync.RWMutex
	// progressToken2notifyChan maps a progress token (as fmt.Sprint(token)) to
	// the channel that receives its progress notifications.
	progressToken2notifyChan map[string]chan<- *protocol.ProgressNotification

	// samplingHandler serves sampling create-message requests; nil disables sampling.
	samplingHandler SamplingHandler

	// notifyHandler receives list-changed and resource-updated notifications.
	notifyHandler NotifyHandler

	// requestID presumably generates outgoing request IDs — TODO confirm at its use site.
	requestID int64

	// ready is cleared and set again by againInitialization around a re-initialize.
	ready *pkg.AtomicBool
	// initializationMu serializes re-initialization in againInitialization.
	initializationMu sync.Mutex

	// clientInfo and clientCapabilities are sent in the initialize request.
	clientInfo *protocol.Implementation
	clientCapabilities *protocol.ClientCapabilities

	// Server details returned by the getters; presumably filled in by the
	// initialize handshake — confirm in initialization.
	serverCapabilities *protocol.ServerCapabilities
	serverInfo *protocol.Implementation
	serverInstructions string

	// initTimeout bounds the initialize handshake in NewClient.
	initTimeout time.Duration

	// closed is closed by Close to stop the background ping loop.
	closed chan struct{}

	logger pkg.Logger
}
78 |
79 | func NewClient(t transport.ClientTransport, opts ...Option) (*Client, error) {
80 | client := &Client{
81 | transport: t,
82 | reqID2respChan: cmap.New[chan *protocol.JSONRPCResponse](),
83 | progressToken2notifyChan: make(map[string]chan<- *protocol.ProgressNotification),
84 | ready: pkg.NewAtomicBool(),
85 | clientInfo: &protocol.Implementation{},
86 | clientCapabilities: &protocol.ClientCapabilities{},
87 | initTimeout: time.Second * 30,
88 | closed: make(chan struct{}),
89 | logger: pkg.DefaultLogger,
90 | }
91 | t.SetReceiver(transport.NewClientReceiver(client.receive, client.receiveInterrupt))
92 |
93 | for _, opt := range opts {
94 | opt(client)
95 | }
96 |
97 | if client.notifyHandler == nil {
98 | h := NewBaseNotifyHandler()
99 | h.Logger = client.logger
100 | client.notifyHandler = h
101 | }
102 |
103 | if client.samplingHandler != nil {
104 | client.clientCapabilities.Sampling = struct{}{}
105 | }
106 |
107 | ctx, cancel := context.WithTimeout(context.Background(), client.initTimeout)
108 | defer cancel()
109 |
110 | if err := client.transport.Start(); err != nil {
111 | return nil, fmt.Errorf("init mcp client transpor start fail: %w", err)
112 | }
113 |
114 | if _, err := client.initialization(ctx, protocol.NewInitializeRequest(client.clientInfo, client.clientCapabilities)); err != nil {
115 | return nil, err
116 | }
117 |
118 | go func() {
119 | defer pkg.Recover()
120 |
121 | ticker := time.NewTicker(time.Minute)
122 | defer ticker.Stop()
123 |
124 | for {
125 | select {
126 | case <-client.closed:
127 | return
128 | case <-ticker.C:
129 | client.sessionDetection()
130 | }
131 | }
132 | }()
133 |
134 | return client, nil
135 | }
136 |
137 | func (client *Client) GetServerCapabilities() protocol.ServerCapabilities {
138 | return *client.serverCapabilities
139 | }
140 |
141 | func (client *Client) GetServerInfo() protocol.Implementation {
142 | return *client.serverInfo
143 | }
144 |
145 | func (client *Client) GetServerInstructions() string {
146 | return client.serverInstructions
147 | }
148 |
149 | func (client *Client) Close() error {
150 | close(client.closed)
151 |
152 | return client.transport.Close()
153 | }
154 |
155 | func (client *Client) sessionDetection() {
156 | ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
157 | defer cancel()
158 |
159 | if _, err := client.Ping(ctx, protocol.NewPingRequest()); err != nil {
160 | client.logger.Warnf("mcp client ping server fail: %v", err)
161 | }
162 | }
163 |
--------------------------------------------------------------------------------
/client/handle.go:
--------------------------------------------------------------------------------
1 | package client
2 |
3 | import (
4 | "context"
5 | "encoding/json"
6 | "fmt"
7 | "time"
8 |
9 | "github.com/ThinkInAIXYZ/go-mcp/pkg"
10 | "github.com/ThinkInAIXYZ/go-mcp/protocol"
11 | )
12 |
13 | func (client *Client) handleRequestWithPing() (*protocol.PingResult, error) {
14 | return protocol.NewPingResult(), nil
15 | }
16 |
17 | func (client *Client) handleRequestWithCreateMessagesSampling(ctx context.Context, rawParams json.RawMessage) (*protocol.CreateMessageResult, error) {
18 | if client.clientCapabilities.Sampling == nil {
19 | return nil, pkg.ErrClientNotSupport
20 | }
21 |
22 | var request *protocol.CreateMessageRequest
23 | if err := pkg.JSONUnmarshal(rawParams, &request); err != nil {
24 | return nil, err
25 | }
26 |
27 | return client.samplingHandler.CreateMessage(ctx, request)
28 | }
29 |
30 | func (client *Client) handleNotifyWithToolsListChanged(ctx context.Context, rawParams json.RawMessage) error {
31 | notify := &protocol.ToolListChangedNotification{}
32 | if len(rawParams) > 0 {
33 | if err := pkg.JSONUnmarshal(rawParams, notify); err != nil {
34 | return err
35 | }
36 | }
37 | return client.notifyHandler.ToolsListChanged(ctx, notify)
38 | }
39 |
40 | func (client *Client) handleNotifyWithPromptsListChanged(ctx context.Context, rawParams json.RawMessage) error {
41 | notify := &protocol.PromptListChangedNotification{}
42 | if len(rawParams) > 0 {
43 | if err := pkg.JSONUnmarshal(rawParams, notify); err != nil {
44 | return err
45 | }
46 | }
47 | return client.notifyHandler.PromptListChanged(ctx, notify)
48 | }
49 |
50 | func (client *Client) handleNotifyWithResourcesListChanged(ctx context.Context, rawParams json.RawMessage) error {
51 | notify := &protocol.ResourceListChangedNotification{}
52 | if len(rawParams) > 0 {
53 | if err := pkg.JSONUnmarshal(rawParams, notify); err != nil {
54 | return err
55 | }
56 | }
57 | return client.notifyHandler.ResourceListChanged(ctx, notify)
58 | }
59 |
60 | func (client *Client) handleNotifyWithResourcesUpdated(ctx context.Context, rawParams json.RawMessage) error {
61 | notify := &protocol.ResourceUpdatedNotification{}
62 | if len(rawParams) > 0 {
63 | if err := pkg.JSONUnmarshal(rawParams, notify); err != nil {
64 | return err
65 | }
66 | }
67 | return client.notifyHandler.ResourcesUpdated(ctx, notify)
68 | }
69 |
70 | func (client *Client) handleNotifyWithProgress(ctx context.Context, rawParams json.RawMessage) error {
71 | notify := &protocol.ProgressNotification{}
72 | if len(rawParams) > 0 {
73 | if err := pkg.JSONUnmarshal(rawParams, notify); err != nil {
74 | return err
75 | }
76 | }
77 | client.progressChanRW.RLock()
78 | defer client.progressChanRW.RUnlock()
79 |
80 | ch, ok := client.progressToken2notifyChan[fmt.Sprint(notify.ProgressToken)]
81 | if !ok {
82 | return fmt.Errorf("progress token not found")
83 | }
84 |
85 | ctx, cancel := context.WithTimeout(ctx, time.Second*1)
86 | defer cancel()
87 |
88 | select {
89 | case ch <- notify:
90 | case <-ctx.Done():
91 | return ctx.Err()
92 | }
93 | return nil
94 | }
95 |
--------------------------------------------------------------------------------
/client/notify_handler.go:
--------------------------------------------------------------------------------
1 | package client
2 |
3 | import (
4 | "context"
5 | "encoding/json"
6 |
7 | "github.com/ThinkInAIXYZ/go-mcp/pkg"
8 | "github.com/ThinkInAIXYZ/go-mcp/protocol"
9 | )
10 |
// SamplingHandler serves sampling create-message requests
// (protocol.SamplingCreateMessage) sent by the server. Configuring one with
// WithSamplingHandler makes the client advertise the sampling capability.
type SamplingHandler interface {
	// CreateMessage produces the result that is sent back to the server.
	CreateMessage(ctx context.Context, request *protocol.CreateMessageRequest) (*protocol.CreateMessageResult, error)
}

// NotifyHandler receives list-changed and resource-updated notifications from
// the server. Progress notifications are delivered to per-token channels
// instead and never reach this interface. Errors returned by these methods are
// logged by the client.
// When implementing a custom NotifyHandler, you can embed BaseNotifyHandler and
// implement only the methods you need.
type NotifyHandler interface {
	ToolsListChanged(ctx context.Context, request *protocol.ToolListChangedNotification) error
	PromptListChanged(ctx context.Context, request *protocol.PromptListChangedNotification) error
	ResourceListChanged(ctx context.Context, request *protocol.ResourceListChangedNotification) error
	ResourcesUpdated(ctx context.Context, request *protocol.ResourceUpdatedNotification) error
}
23 |
24 | type BaseNotifyHandler struct {
25 | Logger pkg.Logger
26 | }
27 |
28 | func NewBaseNotifyHandler() *BaseNotifyHandler {
29 | return &BaseNotifyHandler{pkg.DefaultLogger}
30 | }
31 |
32 | func (handler *BaseNotifyHandler) ToolsListChanged(_ context.Context, request *protocol.ToolListChangedNotification) error {
33 | return handler.defaultNotifyHandler(protocol.NotificationToolsListChanged, request)
34 | }
35 |
36 | func (handler *BaseNotifyHandler) PromptListChanged(_ context.Context, request *protocol.PromptListChangedNotification) error {
37 | return handler.defaultNotifyHandler(protocol.NotificationPromptsListChanged, request)
38 | }
39 |
40 | func (handler *BaseNotifyHandler) ResourceListChanged(_ context.Context, request *protocol.ResourceListChangedNotification) error {
41 | return handler.defaultNotifyHandler(protocol.NotificationResourcesListChanged, request)
42 | }
43 |
44 | func (handler *BaseNotifyHandler) ResourcesUpdated(_ context.Context, request *protocol.ResourceUpdatedNotification) error {
45 | return handler.defaultNotifyHandler(protocol.NotificationResourcesUpdated, request)
46 | }
47 |
48 | func (handler *BaseNotifyHandler) defaultNotifyHandler(method protocol.Method, notify interface{}) error {
49 | b, err := json.Marshal(notify)
50 | if err != nil {
51 | return err
52 | }
53 | handler.Logger.Infof("receive notify: method=%s, notify=%s", method, b)
54 | return nil
55 | }
56 |
--------------------------------------------------------------------------------
/client/receive.go:
--------------------------------------------------------------------------------
1 | package client
2 |
3 | import (
4 | "context"
5 | "errors"
6 | "fmt"
7 |
8 | "github.com/tidwall/gjson"
9 |
10 | "github.com/ThinkInAIXYZ/go-mcp/pkg"
11 | "github.com/ThinkInAIXYZ/go-mcp/protocol"
12 | )
13 |
14 | func (client *Client) receive(ctx context.Context, msg []byte) error {
15 | defer pkg.Recover()
16 |
17 | ctx = pkg.NewCancelShieldContext(ctx)
18 |
19 | if !gjson.GetBytes(msg, "id").Exists() {
20 | notify := &protocol.JSONRPCNotification{}
21 | if err := pkg.JSONUnmarshal(msg, ¬ify); err != nil {
22 | return err
23 | }
24 | if notify.Method == protocol.NotificationProgress { // need sync handle
25 | if err := client.receiveNotify(ctx, notify); err != nil {
26 | notify.RawParams = nil // simplified log
27 | client.logger.Errorf("receive notify:%+v error: %s", notify, err.Error())
28 | return err
29 | }
30 | return nil
31 | }
32 | go func() {
33 | defer pkg.Recover()
34 |
35 | if err := client.receiveNotify(ctx, notify); err != nil {
36 | notify.RawParams = nil // simplified log
37 | client.logger.Errorf("receive notify:%+v error: %s", notify, err.Error())
38 | return
39 | }
40 | }()
41 | return nil
42 | }
43 |
44 | // Determine if it's a request or response
45 | if !gjson.GetBytes(msg, "method").Exists() {
46 | resp := &protocol.JSONRPCResponse{}
47 | if err := pkg.JSONUnmarshal(msg, &resp); err != nil {
48 | return err
49 | }
50 | if err := client.receiveResponse(resp); err != nil {
51 | resp.RawResult = nil // simplified log
52 | client.logger.Errorf("receive response:%+v error: %s", resp, err.Error())
53 | return err
54 | }
55 | return nil
56 | }
57 |
58 | req := &protocol.JSONRPCRequest{}
59 | if err := pkg.JSONUnmarshal(msg, &req); err != nil {
60 | return err
61 | }
62 | if !req.IsValid() {
63 | return pkg.ErrRequestInvalid
64 | }
65 | go func() {
66 | defer pkg.Recover()
67 |
68 | if err := client.receiveRequest(ctx, req); err != nil {
69 | req.RawParams = nil // simplified log
70 | client.logger.Errorf("receive request:%+v error: %s", req, err.Error())
71 | return
72 | }
73 | }()
74 | return nil
75 | }
76 |
77 | func (client *Client) receiveRequest(ctx context.Context, request *protocol.JSONRPCRequest) error {
78 | var (
79 | result protocol.ClientResponse
80 | err error
81 | )
82 |
83 | switch request.Method {
84 | case protocol.Ping:
85 | result, err = client.handleRequestWithPing()
86 | // case protocol.RootsList:
87 | // result, err = client.handleRequestWithListRoots(ctx, request.RawParams)
88 | case protocol.SamplingCreateMessage:
89 | result, err = client.handleRequestWithCreateMessagesSampling(ctx, request.RawParams)
90 | default:
91 | err = fmt.Errorf("%w: method=%s", pkg.ErrMethodNotSupport, request.Method)
92 | }
93 |
94 | if err != nil {
95 | switch {
96 | case errors.Is(err, pkg.ErrMethodNotSupport):
97 | return client.sendMsgWithError(ctx, request.ID, protocol.MethodNotFound, err.Error())
98 | case errors.Is(err, pkg.ErrRequestInvalid):
99 | return client.sendMsgWithError(ctx, request.ID, protocol.InvalidRequest, err.Error())
100 | case errors.Is(err, pkg.ErrJSONUnmarshal):
101 | return client.sendMsgWithError(ctx, request.ID, protocol.ParseError, err.Error())
102 | default:
103 | return client.sendMsgWithError(ctx, request.ID, protocol.InternalError, err.Error())
104 | }
105 | }
106 | return client.sendMsgWithResponse(ctx, request.ID, result)
107 | }
108 |
109 | func (client *Client) receiveNotify(ctx context.Context, notify *protocol.JSONRPCNotification) error {
110 | switch notify.Method {
111 | case protocol.NotificationToolsListChanged:
112 | return client.handleNotifyWithToolsListChanged(ctx, notify.RawParams)
113 | case protocol.NotificationPromptsListChanged:
114 | return client.handleNotifyWithPromptsListChanged(ctx, notify.RawParams)
115 | case protocol.NotificationResourcesListChanged:
116 | return client.handleNotifyWithResourcesListChanged(ctx, notify.RawParams)
117 | case protocol.NotificationResourcesUpdated:
118 | return client.handleNotifyWithResourcesUpdated(ctx, notify.RawParams)
119 | case protocol.NotificationProgress:
120 | return client.handleNotifyWithProgress(ctx, notify.RawParams)
121 | default:
122 | return fmt.Errorf("%w: method=%s", pkg.ErrMethodNotSupport, notify.Method)
123 | }
124 | }
125 |
126 | func (client *Client) receiveResponse(response *protocol.JSONRPCResponse) error {
127 | respChan, ok := client.reqID2respChan.Get(fmt.Sprint(response.ID))
128 | if !ok {
129 | return fmt.Errorf("%w: requestID=%+v", pkg.ErrLackResponseChan, response.ID)
130 | }
131 |
132 | select {
133 | case respChan <- response:
134 | default:
135 | return fmt.Errorf("%w: response=%+v", pkg.ErrDuplicateResponseReceived, response)
136 | }
137 | return nil
138 | }
139 |
140 | func (client *Client) receiveInterrupt(err error) {
141 | for reqID, respChan := range client.reqID2respChan.Items() {
142 | select {
143 | case respChan <- protocol.NewJSONRPCErrorResponse(reqID, protocol.ConnectionError, err.Error()):
144 | default:
145 | }
146 | }
147 | }
148 |
--------------------------------------------------------------------------------
/client/send.go:
--------------------------------------------------------------------------------
1 | package client
2 |
3 | import (
4 | "context"
5 | "encoding/json"
6 | "errors"
7 | "fmt"
8 |
9 | "github.com/ThinkInAIXYZ/go-mcp/pkg"
10 | "github.com/ThinkInAIXYZ/go-mcp/protocol"
11 | )
12 |
13 | func (client *Client) sendMsgWithRequest(ctx context.Context, requestID protocol.RequestID, method protocol.Method, params protocol.ClientRequest) error {
14 | if requestID == nil {
15 | return fmt.Errorf("requestID can't is nil")
16 | }
17 |
18 | req := protocol.NewJSONRPCRequest(requestID, method, params)
19 |
20 | message, err := json.Marshal(req)
21 | if err != nil {
22 | return err
23 | }
24 |
25 | if err = client.transport.Send(ctx, message); err != nil {
26 | if !errors.Is(err, pkg.ErrSessionClosed) {
27 | return fmt.Errorf("sendRequest: transport send: %w", err)
28 | }
29 | if err = client.againInitialization(ctx); err != nil {
30 | return err
31 | }
32 | }
33 | return nil
34 | }
35 |
36 | func (client *Client) sendMsgWithResponse(ctx context.Context, requestID protocol.RequestID, result protocol.ClientResponse) error {
37 | if requestID == nil {
38 | return fmt.Errorf("requestID can't is nil")
39 | }
40 |
41 | resp := protocol.NewJSONRPCSuccessResponse(requestID, result)
42 |
43 | message, err := json.Marshal(resp)
44 | if err != nil {
45 | return err
46 | }
47 |
48 | if err = client.transport.Send(ctx, message); err != nil {
49 | return fmt.Errorf("sendResponse: transport send: %w", err)
50 | }
51 | return nil
52 | }
53 |
54 | func (client *Client) sendMsgWithNotification(ctx context.Context, method protocol.Method, params protocol.ClientNotify) error {
55 | notify := protocol.NewJSONRPCNotification(method, params)
56 |
57 | message, err := json.Marshal(notify)
58 | if err != nil {
59 | return err
60 | }
61 |
62 | if err = client.transport.Send(ctx, message); err != nil {
63 | return fmt.Errorf("sendNotification: transport send: %w", err)
64 | }
65 | return nil
66 | }
67 |
68 | func (client *Client) sendMsgWithError(ctx context.Context, requestID protocol.RequestID, code int, msg string) error {
69 | if requestID == nil {
70 | return fmt.Errorf("requestID can't is nil")
71 | }
72 |
73 | resp := protocol.NewJSONRPCErrorResponse(requestID, code, msg)
74 |
75 | message, err := json.Marshal(resp)
76 | if err != nil {
77 | return err
78 | }
79 |
80 | if err = client.transport.Send(ctx, message); err != nil {
81 | return fmt.Errorf("sendResponse: transport send: %w", err)
82 | }
83 | return nil
84 | }
85 |
// againInitialization re-runs the MCP initialize handshake; sendMsgWithRequest
// calls it after the transport reports pkg.ErrSessionClosed.
//
// ready is cleared before initializationMu is taken. A caller that was queued
// on the lock can therefore see, once it holds the lock, that another caller
// has already re-initialized (ready set back to true) and skip the handshake.
// NOTE(review): a caller whose Store(false) lands after another caller's
// Store(true) still runs one extra handshake. initializationMu is not
// reentrant, so if the initialize request itself ends up back here the call
// would deadlock — confirm initialization cannot take that path.
func (client *Client) againInitialization(ctx context.Context) error {
	client.ready.Store(false)

	client.initializationMu.Lock()
	defer client.initializationMu.Unlock()

	// Another caller finished re-initializing while we waited for the lock.
	if client.ready.Load() {
		return nil
	}

	if _, err := client.initialization(ctx, protocol.NewInitializeRequest(client.clientInfo, client.clientCapabilities)); err != nil {
		return err
	}
	client.ready.Store(true)
	return nil
}
102 |
--------------------------------------------------------------------------------
/codecov.yml:
--------------------------------------------------------------------------------
1 | codecov:
2 | require_ci_to_pass: true
3 | notify:
4 | wait_for_ci: true
5 |
6 | coverage:
7 | precision: 2
8 | round: down
9 | range: "70...100"
10 | status:
11 | project:
12 | default:
13 | # Adjust based on your expectations - fail if overall project coverage drops more than 1%
14 | target: auto
15 | threshold: 1%
16 | patch:
17 | default:
18 | target: 80%
19 | changes: true
20 |
21 | comment:
22 | layout: "reach, diff, flags, files"
23 | behavior: default
24 | require_changes: false # if true: comment only if coverage changes
25 |
26 | # Ignore test files and generated code if applicable
27 | ignore:
28 | - "**/*_test.go"
29 | - "**/mock_*.go"
30 | - "**/mocks/**"
31 | - "**/vendor/**"
32 | - "**/testdata/**"
33 | - "examples/**"
34 |
35 | # GitHub features
36 | github_checks:
37 | annotations: true
--------------------------------------------------------------------------------
/docs/design.md:
--------------------------------------------------------------------------------
1 | # MCP Go SDK Design Document
2 |
3 | MCP Go SDK is a Go implementation of the Model Context Protocol (MCP). It provides client and server APIs that let AI applications communicate with external systems through tools, prompts, and resources. The table below lists the MCP protocol messages the SDK is designed around.
4 |
5 | ## Design Philosophy
6 |
7 | - MCP Protocol Messages
8 |
9 | | Capability Provider | Capability | Protocol Messages (Client Send) | Protocol Messages (Server Send) |
10 | | ------------------ | --------------- | ----------------------------------------------------------------------------------------------------- | ----------------------------------------------------------- |
11 | | Client&Server | Initialization | • Initialize<br>• Initialized notifications | (None) |
12 | | Client&Server | Ping | • Ping | • Ping |
13 | | Client&Server | Cancellation | • Cancelled Notifications | • Cancelled Notifications |
14 | | Client&Server | Progress | • Progress Notifications | • Progress Notifications |
15 | | Client | roots | • Root List Changes | • Listing Roots |
16 | | Client | sampling | (None) | • Creating Messages |
17 | | Server | prompts | • Listing Prompts<br>• Getting a Prompt | • List Changed Notification |
18 | | Server | resources | • Listing Resources<br>• Reading Resources<br>• Resource Templates<br>• Subscriptions: Request<br>• Unsubscriptions: Request | • List Changed Notification<br>• Subscriptions: Update Notification |
19 | | Server | tools | • Listing Tools<br>• Calling Tools | • List Changed Notification |
20 | | Server | Completion | • Requesting Completions | (None) |
21 | | Server | logging | • Setting Log Level | • Log Message Notifications |
22 |
23 | - Interaction Details
24 | 
25 | - Both client and server need to have send and receive capabilities
26 | - Messages can be abstracted into three types: request, response, and notification
27 | - The architecture can be abstracted into three layers: transport layer, protocol layer, and user layer (server, client)
28 |
29 | - Design Principles
30 | - Protocol layer and transport layer are decoupled through the transport interface
31 | - Protocol layer contains all MCP protocol-related definitions, including data structures, request construction, and response parsing
32 | - Both server and client layers have send and receive capabilities. Send capabilities include sending messages (request, response, notification) and matching requests with responses. Receive capabilities include routing messages (request, response, notification) and handling them asynchronously/synchronously
33 | - Server and client layers pair each request with its response, so that from the user's perspective sending a request, handling it, and returning the response all appear synchronous
34 |
35 | # Architecture Design
36 | 
37 |
38 | # Project Structure
39 |
40 | - transport
41 | - sse_client.go
42 | - sse_server.go
43 | - stdio_client.go
44 | - stdio_server.go
45 | - transport.go // transport interface definition
46 | - pkg
47 | - errors.go // error definitions
48 | - log.go // log interface definition
49 | - protocol // contains all MCP protocol-related definitions, including data structures, request construction, and response parsing
50 | - initialize.go
51 | - ping.go
52 | - cancellation.go
53 | - progress.go
54 | - roots.go
55 | - sampling.go
56 | - prompts.go
57 | - resources.go
58 | - tools.go
59 | - completion.go
60 | - logging.go
61 | - pagination.go
62 | - jsonrpc.go
63 | - server
64 | - server.go
65 | - call.go // send messages (request, notification) to client
66 | - handle.go // handle messages (request, notification) from client, return response or not
67 | - send.go // send messages (request, response, notification) to client
68 | - receive.go // receive messages (request, response, notification) from client
69 | - client
70 | - client.go
71 | - call.go // send messages (request, notification) to server
72 | - handle.go // handle messages (request, notification) from server, return response or not
73 | - send.go // send messages (request, response, notification) to server
74 | - receive.go // receive messages (request, response, notification) from server
75 |
--------------------------------------------------------------------------------
/docs/design_cn.md:
--------------------------------------------------------------------------------
1 | MCP Go SDK是一个功能强大且易于使用的Go语言Model Context Protocol(MCP)库,同时提供客户端和服务端实现,覆盖初始化、tools、resources、prompts、sampling、roots、logging等协议核心能力。
2 |
3 | # 设计思路
4 |
5 | - MCP 协议消息
6 |
7 | | 能力提供方 | 能力 | 协议消息(客户端发送) | 协议消息(服务端发送) |
8 | | ------------- | ---------------- | -------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------ |
9 | | Client&Server | Initialization | • Initialize<br>• Initialized notifications | (无) |
10 | | Client&Server | Ping | • Ping | • Ping |
11 | | Client&Server | Cancellation | • Cancelled Notifications | • Cancelled Notifications |
12 | | Client&Server | Progress | • Progress Notifications | • Progress Notifications |
13 | | Client | roots | • Root List Changes | • Listing Roots |
14 | | Client | sampling | (无) | • Creating Messages |
15 | | Server | prompts | • Listing Prompts<br>• Getting a Prompt | • List Changed Notification |
16 | | Server | resources | • Listing Resources<br>• Reading Resources<br>• Resource Templates<br>• Subscriptions: Request<br>• Unsubscriptions: Request | • List Changed Notification<br>• Subscriptions: Update Notification |
17 | | Server | tools | • Listing Tools<br>• Calling Tools | • List Changed Notification |
18 | | Server | Completion | • Requesting Completions | (无) |
19 | | Server | logging | • Setting Log Level | • Log Message Notifications |
20 |
21 | - 交互细节
22 | 
23 | - 客户端和服务端都需要具备收发功能;
24 | - 可以将消息类型抽象为 message,具体实现包括 request、response、notification 三种;
25 | - 可以将架构抽象为三层:传输层、协议层、用户层(server、client)
26 |
27 |
28 | - 设计思想
29 | - 协议层与传输层通过 transport 接口进行解耦;
30 | - protocol 层完成 MCP 协议相关的全部定义,包括数据结构定义、请求结构构造、响应结构解析;
31 | - server 层与 client 层都具备发送(send)和接收(receive)的能力,发送能力包括发送 message(request、response、notification)以及请求和响应的匹配,接收能力包括对 message(request、response、notification) 的路由、异步/同步处理;
32 | - server 层与 client 层实现对 request 和 response 的组合,用户侧使用时表现为同步请求、同步处理、同步返回。
33 |
34 | # 架构设计
35 | 
36 |
37 | # 项目目录
38 |
39 | - transport
40 | - sse_client.go
41 | - sse_server.go
42 | - stdio_client.go
43 | - stdio_server.go
44 | - transport.go // transport 接口定义
45 | - pkg
46 | - errors.go // error 定义
47 | - log.go // log 接口定义
48 | - protocol // 放置 mcp 协议相关的全部定义,包括数据结构定义、请求结构构造、响应结构解析;
49 | - initialize.go
50 | - ping.go
51 | - cancellation.go
52 | - progress.go
53 | - roots.go
54 | - sampling.go
55 | - prompts.go
56 | - resources.go
57 | - tools.go
58 | - completion.go
59 | - logging.go
60 | - pagination.go
61 | - jsonrpc.go
62 | - server
63 | - server.go
64 | - call.go // 向客户端发送 message(request、notification)
65 | - handle.go // 对来自客户端的 message(request、notification) 进行处理,返回或不返回 response
66 | - send.go // 向客户端发送 message(request、response、notification)
67 | - receive.go // 对来自客户端的 message(request、response、notification)进行接收
68 | - client
69 | - client.go
70 | - call.go // 向服务端发送 message(request、notification)
71 | - handle.go // 对来自服务端的 message(request、notification) 进行处理,返回或不返回 response
72 | - send.go // 向服务端发送 message(request、response、notification)
73 | - receive.go // 对来自服务端的 message(request、response、notification)进行接收
74 |
--------------------------------------------------------------------------------
/docs/images/img.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ThinkInAIXYZ/go-mcp/c7a0eb1f7e4a288220d3a3375006802558f473a2/docs/images/img.png
--------------------------------------------------------------------------------
/docs/images/img_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ThinkInAIXYZ/go-mcp/c7a0eb1f7e4a288220d3a3375006802558f473a2/docs/images/img_1.png
--------------------------------------------------------------------------------
/docs/images/img_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ThinkInAIXYZ/go-mcp/c7a0eb1f7e4a288220d3a3375006802558f473a2/docs/images/img_2.png
--------------------------------------------------------------------------------
/docs/images/wechat_qrcode.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ThinkInAIXYZ/go-mcp/c7a0eb1f7e4a288220d3a3375006802558f473a2/docs/images/wechat_qrcode.png
--------------------------------------------------------------------------------
/examples/README.md:
--------------------------------------------------------------------------------
1 | # More Examples
2 | References: https://github.com/ThinkInAIXYZ/mcp-servers
3 |
--------------------------------------------------------------------------------
/examples/auth_tool/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "context"
5 | "errors"
6 | "fmt"
7 | "log"
8 | "net/http"
9 | "os"
10 | "os/signal"
11 | "syscall"
12 | "time"
13 |
14 | "github.com/ThinkInAIXYZ/go-mcp/protocol"
15 | "github.com/ThinkInAIXYZ/go-mcp/server"
16 | "github.com/ThinkInAIXYZ/go-mcp/transport"
17 | )
18 |
19 | type userIDKey struct{}
20 |
21 | func setUserIDToCtx(ctx context.Context, userID string) context.Context {
22 | return context.WithValue(ctx, userIDKey{}, userID)
23 | }
24 |
25 | func getUserIDFromCtx(ctx context.Context) (string, error) {
26 | userID := ctx.Value(userIDKey{})
27 | if userID == nil {
28 | return "", errors.New("no userID found")
29 | }
30 | return userID.(string), nil
31 | }
32 |
33 | type currentTimeReq struct {
34 | Timezone string `json:"timezone" description:"current time timezone"`
35 | }
36 |
37 | func main() {
38 | messageEndpointURL := "/message"
39 |
40 | userParamKey := "user_id"
41 | paramKeysOpt := transport.WithSSEServerTransportAndHandlerOptionCopyParamKeys([]string{userParamKey})
42 | sseTransport, mcpHandler, err := transport.NewSSEServerTransportAndHandler(messageEndpointURL, paramKeysOpt)
43 | if err != nil {
44 | log.Panicf("new sse transport and hander with error: %v", err)
45 | }
46 |
47 | mcpServer, err := server.NewServer(sseTransport,
48 | server.WithServerInfo(protocol.Implementation{
49 | Name: "mcp-example",
50 | Version: "1.0.0",
51 | }),
52 | )
53 | if err != nil {
54 | panic(err)
55 | }
56 |
57 | tool, err := protocol.NewTool("current_time", "Get current time with timezone, Asia/Shanghai is default", currentTimeReq{})
58 | if err != nil {
59 | panic(fmt.Sprintf("Failed to create tool: %v", err))
60 | }
61 |
62 | authentication := authenticationMiddleware(map[string][]string{
63 | tool.Name: {"test_1"},
64 | })
65 | mcpServer.RegisterTool(tool, currentTime, authentication)
66 |
67 | router := http.NewServeMux()
68 | router.HandleFunc("/sse", mcpHandler.HandleSSE().ServeHTTP)
69 | router.HandleFunc(messageEndpointURL, func(w http.ResponseWriter, r *http.Request) {
70 | userID := r.URL.Query().Get(userParamKey)
71 | if userID == "" {
72 | w.Header().Set("Content-Type", "text/plain")
73 | w.WriteHeader(http.StatusBadRequest)
74 | if _, e := w.Write([]byte("lack user_id")); e != nil {
75 | fmt.Printf("writeError: %+v", e)
76 | }
77 | return
78 | }
79 |
80 | r = r.WithContext(setUserIDToCtx(r.Context(), userID))
81 |
82 | mcpHandler.HandleMessage().ServeHTTP(w, r)
83 | })
84 |
85 | // Can be replaced by using gin framework
86 | // router := gin.Default()
87 | // router.GET("/sse", func(ctx *gin.Context) {
88 | // mcpHandler.HandleSSE().ServeHTTP(ctx.Writer, ctx.Request)
89 | // })
90 | // router.POST(messageEndpointURL, func(ctx *gin.Context) {
91 | // mcpHandler.HandleMessage().ServeHTTP(ctx.Writer, ctx.Request)
92 | // })
93 |
94 | httpServer := &http.Server{
95 | Addr: ":8080",
96 | Handler: router,
97 | IdleTimeout: time.Minute,
98 | }
99 |
100 | errCh := make(chan error, 3)
101 | go func() {
102 | errCh <- mcpServer.Run()
103 | }()
104 |
105 | go func() {
106 | if err = httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
107 | errCh <- err
108 | }
109 | }()
110 |
111 | if err = signalWaiter(errCh); err != nil {
112 | panic(fmt.Sprintf("signal waiter: %v", err))
113 | }
114 |
115 | ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
116 | defer cancel()
117 |
118 | httpServer.RegisterOnShutdown(func() {
119 | if err = mcpServer.Shutdown(ctx); err != nil {
120 | panic(err)
121 | }
122 | })
123 |
124 | if err = httpServer.Shutdown(ctx); err != nil {
125 | panic(err)
126 | }
127 | }
128 |
129 | func authenticationMiddleware(toolName2UserID map[string][]string) server.ToolMiddleware {
130 | return func(next server.ToolHandlerFunc) server.ToolHandlerFunc {
131 | return func(ctx context.Context, req *protocol.CallToolRequest) (*protocol.CallToolResult, error) {
132 | userID, err := getUserIDFromCtx(ctx)
133 | if err != nil {
134 | return nil, err
135 | }
136 |
137 | for _, id := range toolName2UserID[req.Name] {
138 | if userID == id {
139 | return next(ctx, req)
140 | }
141 | }
142 | return nil, fmt.Errorf("user %s not authorized", userID)
143 | }
144 | }
145 | }
146 |
147 | func currentTime(_ context.Context, request *protocol.CallToolRequest) (*protocol.CallToolResult, error) {
148 | req := new(currentTimeReq)
149 | if err := protocol.VerifyAndUnmarshal(request.RawArguments, &req); err != nil {
150 | return nil, err
151 | }
152 |
153 | loc, err := time.LoadLocation(req.Timezone)
154 | if err != nil {
155 | return nil, fmt.Errorf("parse timezone with error: %v", err)
156 | }
157 | text := fmt.Sprintf(`current time is %s`, time.Now().In(loc))
158 |
159 | return &protocol.CallToolResult{
160 | Content: []protocol.Content{
161 | &protocol.TextContent{
162 | Type: "text",
163 | Text: text,
164 | },
165 | },
166 | }, nil
167 | }
168 |
169 | func signalWaiter(errCh chan error) error {
170 | signalToNotify := []os.Signal{syscall.SIGINT, syscall.SIGHUP, syscall.SIGTERM}
171 | if signal.Ignored(syscall.SIGHUP) {
172 | signalToNotify = []os.Signal{syscall.SIGINT, syscall.SIGTERM}
173 | }
174 |
175 | signals := make(chan os.Signal, 1)
176 | signal.Notify(signals, signalToNotify...)
177 |
178 | select {
179 | case sig := <-signals:
180 | switch sig {
181 | case syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM:
182 | log.Printf("Received signal: %s\n", sig)
183 | // graceful shutdown
184 | return nil
185 | }
186 | case err := <-errCh:
187 | return err
188 | }
189 |
190 | return nil
191 | }
192 |
--------------------------------------------------------------------------------
/examples/current_time_server/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "context"
5 | "flag"
6 | "fmt"
7 | "log"
8 | "os"
9 | "os/signal"
10 | "syscall"
11 | "time"
12 |
13 | "github.com/ThinkInAIXYZ/go-mcp/pkg"
14 | "github.com/ThinkInAIXYZ/go-mcp/protocol"
15 | "github.com/ThinkInAIXYZ/go-mcp/server"
16 | "github.com/ThinkInAIXYZ/go-mcp/transport"
17 | )
18 |
19 | type currentTimeReq struct {
20 | Timezone string `json:"timezone" description:"current time timezone"`
21 | }
22 |
23 | func main() {
24 | // new mcp server with stdio or sse transport
25 | srv, err := server.NewServer(
26 | getTransport(),
27 | server.WithServerInfo(protocol.Implementation{
28 | Name: "current-time-v2-server",
29 | Version: "1.0.0",
30 | }),
31 | )
32 | if err != nil {
33 | log.Fatalf("Failed to create server: %v", err)
34 | }
35 |
36 | // new protocol tool with name, descipriton and properties
37 | tool, err := protocol.NewTool("current_time", "Get current time with timezone, Asia/Shanghai is default", currentTimeReq{})
38 | if err != nil {
39 | log.Fatalf("Failed to create tool: %v", err)
40 | return
41 | }
42 |
43 | // register tool and start mcp server
44 | srv.RegisterTool(tool, currentTime,
45 | server.RateLimitMiddleware(pkg.NewTokenBucketLimiter(pkg.Rate{
46 | Limit: 10.0, // 每秒10个请求
47 | Burst: 20, // 最多允许20个请求的突发
48 | })))
49 | // srv.RegisterResource()
50 | // srv.RegisterPrompt()
51 | // srv.RegisterResourceTemplate()
52 |
53 | errCh := make(chan error)
54 | go func() {
55 | errCh <- srv.Run()
56 | }()
57 |
58 | if err = signalWaiter(errCh); err != nil {
59 | log.Fatalf("signal waiter: %v", err)
60 | return
61 | }
62 |
63 | ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
64 | defer cancel()
65 |
66 | if err := srv.Shutdown(ctx); err != nil {
67 | log.Fatalf("Shutdown error: %v", err)
68 | }
69 | }
70 |
71 | func getTransport() (t transport.ServerTransport) {
72 | var (
73 | mode string
74 | addr = "127.0.0.1:8080"
75 | )
76 |
77 | flag.StringVar(&mode, "transport", "stdio", "The transport to use, should be \"stdio\" or \"sse\" or \"streamable_http\"")
78 | flag.Parse()
79 |
80 | switch mode {
81 | case "stdio":
82 | log.Println("start current time mcp server with stdio transport")
83 | t = transport.NewStdioServerTransport()
84 | case "sse":
85 | log.Printf("start current time mcp server with sse transport, listen %s", addr)
86 | t, _ = transport.NewSSEServerTransport(addr)
87 | case "streamable_http":
88 | log.Printf("start current time mcp server with streamable_http transport, listen %s", addr)
89 | t = transport.NewStreamableHTTPServerTransport(addr)
90 | default:
91 | panic(fmt.Errorf("unknown mode: %s", mode))
92 | }
93 |
94 | return t
95 | }
96 |
97 | func currentTime(_ context.Context, request *protocol.CallToolRequest) (*protocol.CallToolResult, error) {
98 | req := new(currentTimeReq)
99 | if err := protocol.VerifyAndUnmarshal(request.RawArguments, &req); err != nil {
100 | return nil, err
101 | }
102 |
103 | loc, err := time.LoadLocation(req.Timezone)
104 | if err != nil {
105 | return nil, fmt.Errorf("parse timezone with error: %v", err)
106 | }
107 | text := fmt.Sprintf(`current time is %s`, time.Now().In(loc))
108 |
109 | return &protocol.CallToolResult{
110 | Content: []protocol.Content{
111 | &protocol.TextContent{
112 | Type: "text",
113 | Text: text,
114 | },
115 | },
116 | }, nil
117 | }
118 |
119 | func signalWaiter(errCh chan error) error {
120 | signalToNotify := []os.Signal{syscall.SIGINT, syscall.SIGHUP, syscall.SIGTERM}
121 | if signal.Ignored(syscall.SIGHUP) {
122 | signalToNotify = []os.Signal{syscall.SIGINT, syscall.SIGTERM}
123 | }
124 |
125 | signals := make(chan os.Signal, 1)
126 | signal.Notify(signals, signalToNotify...)
127 |
128 | select {
129 | case sig := <-signals:
130 | switch sig {
131 | case syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM:
132 | log.Printf("Received signal: %s\n", sig)
133 | // graceful shutdown
134 | return nil
135 | }
136 | case err := <-errCh:
137 | return err
138 | }
139 |
140 | return nil
141 | }
142 |
--------------------------------------------------------------------------------
/examples/filesystem_client/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "context"
5 | "encoding/json"
6 | "log"
7 | "time"
8 |
9 | "github.com/ThinkInAIXYZ/go-mcp/client"
10 | "github.com/ThinkInAIXYZ/go-mcp/protocol"
11 | "github.com/ThinkInAIXYZ/go-mcp/transport"
12 | )
13 |
14 | func main() {
15 | t, err := transport.NewStdioClientTransport("npx", []string{"-y", "@modelcontextprotocol/server-filesystem", "~/tmp"})
16 | if err != nil {
17 | log.Fatal(err)
18 | }
19 |
20 | cli, err := client.NewClient(t, client.WithClientInfo(&protocol.Implementation{
21 | Name: "test",
22 | Version: "1.0.0",
23 | }))
24 | if err != nil {
25 | log.Fatalf("Failed to new client: %v", err)
26 | }
27 | defer func() {
28 | if err = cli.Close(); err != nil {
29 | log.Fatalf("Failed to close client: %v", err)
30 | }
31 | }()
32 |
33 | // Create context with timeout
34 | ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
35 | defer cancel()
36 |
37 | // List Tools
38 | log.Println("Listing available tools...")
39 | tools, err := cli.ListTools(ctx)
40 | if err != nil {
41 | log.Fatalf("Failed to list tools: %v", err)
42 | }
43 | for _, tool := range tools.Tools {
44 | log.Printf("- %s: %s\n", tool.Name, tool.Description)
45 | }
46 |
47 | // List allowed directories
48 | log.Println("Listing allowed directories...")
49 | listDirRequest := &protocol.CallToolRequest{
50 | Name: "list_allowed_directories",
51 | }
52 | result, err := cli.CallTool(ctx, listDirRequest)
53 | if err != nil {
54 | log.Fatalf("Failed to list allowed directories: %v", err)
55 | }
56 | printToolResult(result)
57 | log.Println()
58 |
59 | // List ~/tmp
60 | log.Println("Listing ~/tmp directory...")
61 | listTmpRequest := &protocol.CallToolRequest{
62 | Name: "list_directory",
63 | Arguments: map[string]interface{}{"path": "~/tmp"},
64 | }
65 | result, err = cli.CallTool(ctx, listTmpRequest)
66 | if err != nil {
67 | log.Fatalf("Failed to list directory: %v", err)
68 | }
69 | printToolResult(result)
70 | log.Println()
71 |
72 | // Create mcp directory
73 | log.Println("Creating ~/tmp/mcp directory...")
74 | createDirRequest := &protocol.CallToolRequest{
75 | Name: "create_directory",
76 | Arguments: map[string]interface{}{"path": "~/tmp/mcp"},
77 | }
78 | result, err = cli.CallTool(ctx, createDirRequest)
79 | if err != nil {
80 | log.Fatalf("Failed to create directory: %v", err)
81 | }
82 | printToolResult(result)
83 | log.Println()
84 |
85 | // Create hello.txt
86 | log.Println("Creating ~/tmp/mcp/hello.txt...")
87 | writeFileRequest := &protocol.CallToolRequest{
88 | Name: "write_file",
89 | Arguments: map[string]interface{}{
90 | "path": "~/tmp/mcp/hello.txt",
91 | "content": "Hello World",
92 | },
93 | }
94 | result, err = cli.CallTool(ctx, writeFileRequest)
95 | if err != nil {
96 | log.Fatalf("Failed to create file: %v", err)
97 | }
98 | printToolResult(result)
99 | log.Println()
100 |
101 | // Verify file contents
102 | log.Println("Reading ~/tmp/mcp/hello.txt...")
103 | readFileRequest := &protocol.CallToolRequest{
104 | Name: "read_file",
105 | Arguments: map[string]interface{}{
106 | "path": "~/tmp/mcp/hello.txt",
107 | },
108 | }
109 | result, err = cli.CallTool(ctx, readFileRequest)
110 | if err != nil {
111 | log.Fatalf("Failed to read file: %v", err)
112 | }
113 | printToolResult(result)
114 |
115 | // Get file info
116 | log.Println("Getting info for ~/tmp/mcp/hello.txt...")
117 | fileInfoRequest := &protocol.CallToolRequest{
118 | Name: "get_file_info",
119 | Arguments: map[string]interface{}{
120 | "path": "~/tmp/mcp/hello.txt",
121 | },
122 | }
123 | result, err = cli.CallTool(ctx, fileInfoRequest)
124 | if err != nil {
125 | log.Fatalf("Failed to get file info: %v", err)
126 | }
127 | printToolResult(result)
128 | }
129 |
130 | // Helper function to print tool results
131 | func printToolResult(result *protocol.CallToolResult) {
132 | for _, content := range result.Content {
133 | if textContent, ok := content.(*protocol.TextContent); ok {
134 | log.Println(textContent.Text)
135 | } else {
136 | jsonBytes, _ := json.MarshalIndent(content, "", " ")
137 | log.Println(string(jsonBytes))
138 | }
139 | }
140 | }
141 |
--------------------------------------------------------------------------------
/examples/http_handler/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "context"
5 | "errors"
6 | "fmt"
7 | "log"
8 | "net/http"
9 | "os"
10 | "os/signal"
11 | "syscall"
12 | "time"
13 |
14 | "github.com/ThinkInAIXYZ/go-mcp/protocol"
15 | "github.com/ThinkInAIXYZ/go-mcp/server"
16 | "github.com/ThinkInAIXYZ/go-mcp/transport"
17 | )
18 |
19 | type currentTimeReq struct {
20 | Timezone string `json:"timezone" description:"current time timezone"`
21 | }
22 |
23 | func main() {
24 | messageEndpointURL := "/message"
25 |
26 | sseTransport, mcpHandler, err := transport.NewSSEServerTransportAndHandler(messageEndpointURL)
27 | if err != nil {
28 | log.Panicf("new sse transport and hander with error: %v", err)
29 | }
30 |
31 | mcpServer, err := server.NewServer(sseTransport,
32 | server.WithServerInfo(protocol.Implementation{
33 | Name: "mcp-example",
34 | Version: "1.0.0",
35 | }),
36 | )
37 | if err != nil {
38 | panic(err)
39 | }
40 |
41 | tool, err := protocol.NewTool("current_time", "Get current time with timezone, Asia/Shanghai is default", currentTimeReq{})
42 | if err != nil {
43 | panic(fmt.Sprintf("Failed to create tool: %v", err))
44 | }
45 |
46 | mcpServer.RegisterTool(tool, currentTime)
47 |
48 | router := http.NewServeMux()
49 | router.HandleFunc("/sse", mcpHandler.HandleSSE().ServeHTTP)
50 | router.HandleFunc(messageEndpointURL, mcpHandler.HandleMessage().ServeHTTP)
51 |
52 | // Can be replaced by using gin framework
53 | // router := gin.Default()
54 | // router.GET("/sse", func(ctx *gin.Context) {
55 | // mcpHandler.HandleSSE().ServeHTTP(ctx.Writer, ctx.Request)
56 | // })
57 | // router.POST(messageEndpointURL, func(ctx *gin.Context) {
58 | // mcpHandler.HandleMessage().ServeHTTP(ctx.Writer, ctx.Request)
59 | // })
60 |
61 | httpServer := &http.Server{
62 | Addr: ":8080",
63 | Handler: router,
64 | IdleTimeout: time.Minute,
65 | }
66 |
67 | errCh := make(chan error, 3)
68 | go func() {
69 | errCh <- mcpServer.Run()
70 | }()
71 |
72 | go func() {
73 | if err = httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
74 | errCh <- err
75 | }
76 | }()
77 |
78 | if err = signalWaiter(errCh); err != nil {
79 | panic(fmt.Sprintf("signal waiter: %v", err))
80 | }
81 |
82 | ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
83 | defer cancel()
84 |
85 | httpServer.RegisterOnShutdown(func() {
86 | if err = mcpServer.Shutdown(ctx); err != nil {
87 | panic(err)
88 | }
89 | })
90 |
91 | if err = httpServer.Shutdown(ctx); err != nil {
92 | panic(err)
93 | }
94 | }
95 |
96 | func currentTime(_ context.Context, request *protocol.CallToolRequest) (*protocol.CallToolResult, error) {
97 | req := new(currentTimeReq)
98 | if err := protocol.VerifyAndUnmarshal(request.RawArguments, &req); err != nil {
99 | return nil, err
100 | }
101 |
102 | loc, err := time.LoadLocation(req.Timezone)
103 | if err != nil {
104 | return nil, fmt.Errorf("parse timezone with error: %v", err)
105 | }
106 | text := fmt.Sprintf(`current time is %s`, time.Now().In(loc))
107 |
108 | return &protocol.CallToolResult{
109 | Content: []protocol.Content{
110 | &protocol.TextContent{
111 | Type: "text",
112 | Text: text,
113 | },
114 | },
115 | }, nil
116 | }
117 |
118 | func signalWaiter(errCh chan error) error {
119 | signalToNotify := []os.Signal{syscall.SIGINT, syscall.SIGHUP, syscall.SIGTERM}
120 | if signal.Ignored(syscall.SIGHUP) {
121 | signalToNotify = []os.Signal{syscall.SIGINT, syscall.SIGTERM}
122 | }
123 |
124 | signals := make(chan os.Signal, 1)
125 | signal.Notify(signals, signalToNotify...)
126 |
127 | select {
128 | case sig := <-signals:
129 | switch sig {
130 | case syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM:
131 | log.Printf("Received signal: %s\n", sig)
132 | // graceful shutdown
133 | return nil
134 | }
135 | case err := <-errCh:
136 | return err
137 | }
138 |
139 | return nil
140 | }
141 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/ThinkInAIXYZ/go-mcp
2 |
3 | go 1.18
4 |
5 | require (
6 | github.com/google/uuid v1.6.0
7 | github.com/orcaman/concurrent-map/v2 v2.0.1
8 | github.com/tidwall/gjson v1.18.0
9 | github.com/yosida95/uritemplate/v3 v3.0.2
10 | )
11 |
12 | require (
13 | github.com/tidwall/match v1.1.1 // indirect
14 | github.com/tidwall/pretty v1.2.0 // indirect
15 | )
16 |
--------------------------------------------------------------------------------
/go.sum:
--------------------------------------------------------------------------------
1 | github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
2 | github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
3 | github.com/orcaman/concurrent-map/v2 v2.0.1 h1:jOJ5Pg2w1oeB6PeDurIYf6k9PQ+aTITr/6lP/L/zp6c=
4 | github.com/orcaman/concurrent-map/v2 v2.0.1/go.mod h1:9Eq3TG2oBe5FirmYWQfYO5iH1q0Jv47PLaNK++uCdOM=
5 | github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
6 | github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
7 | github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
8 | github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
9 | github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
10 | github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
11 | github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
12 | github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
13 |
--------------------------------------------------------------------------------
/hack/.lintcheck_failures:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ThinkInAIXYZ/go-mcp/c7a0eb1f7e4a288220d3a3375006802558f473a2/hack/.lintcheck_failures
--------------------------------------------------------------------------------
/hack/.test_ignored_files:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ThinkInAIXYZ/go-mcp/c7a0eb1f7e4a288220d3a3375006802558f473a2/hack/.test_ignored_files
--------------------------------------------------------------------------------
/hack/resolve-modules.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash

# This is used by the linter action.
# Recursively finds all directories with a go.mod file and publishes them as a
# GitHub Actions job matrix: {"include":[{"workdir":"<module dir>"},...]}.
# Modules listed in hack/.lintcheck_failures are left out of the matrix.

set -o errexit

echo "Resolving modules in $(pwd)"

PROJECT_HOME=$(
	cd "$(dirname "${BASH_SOURCE[0]}")" &&
		cd .. &&
		pwd
)

source "${PROJECT_HOME}/hack/util.sh"

FAILURE_FILE=${PROJECT_HOME}/hack/.lintcheck_failures

all_modules=$(util::find_modules)
failing_modules=()
# `|| [[ -n $line ]]` keeps a final entry that has no trailing newline.
while IFS='' read -r line || [[ -n "$line" ]]; do failing_modules+=("$line"); done < <(cat "$FAILURE_FILE")

echo "Ignored failing modules:"
echo "${failing_modules[*]}"
echo

PATHS=""

for mod in $all_modules; do
	echo "Checking module: $mod"
	util::array_contains "$mod" "${failing_modules[*]}" && in_failing=$? || in_failing=$?
	if [[ "$in_failing" -ne "0" ]]; then
		PATHS+=$(printf '{"workdir":"%s"},' "${mod}")
	fi
done

# ${PATHS%?} drops the trailing comma appended by the loop above.
matrix="{\"include\":[${PATHS%?}]}"
# The ::set-output workflow command is deprecated. Write to the $GITHUB_OUTPUT
# file when the runner provides it, and fall back to the legacy command
# otherwise.
if [[ -n "${GITHUB_OUTPUT:-}" ]]; then
	echo "matrix=${matrix}" >>"$GITHUB_OUTPUT"
else
	echo "::set-output name=matrix::${matrix}"
fi
40 |
--------------------------------------------------------------------------------
/hack/tools.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash

# This is a tools shell script
# used by Makefile commands.
# Usage: hack/tools.sh {lint|test|test_coverage|fix|tidy}

set -o errexit
set -o nounset
set -o pipefail

# Exported so that the `go` commands run below actually see it. A plain
# assignment stays local to this shell unless the variable was already in the
# environment.
export GO111MODULE=on
PROJECT_HOME=$(
	cd "$(dirname "${BASH_SOURCE[0]}")" &&
		cd .. &&
		pwd
)

source "${PROJECT_HOME}/hack/util.sh"

LINTER=${PROJECT_HOME}/bin/golangci-lint
LINTER_CONFIG=${PROJECT_HOME}/.golangci.yml
# Modules listed here (one per line) are skipped by lint and fix.
FAILURE_FILE=${PROJECT_HOME}/hack/.lintcheck_failures
# Modules listed here (one per line) are skipped by test and test_coverage.
IGNORED_FILE=${PROJECT_HOME}/hack/.test_ignored_files

all_modules=$(util::find_modules)
# In both loops, `|| [[ -n $line ]]` keeps a final entry that lacks a trailing
# newline.
failing_modules=()
while IFS='' read -r line || [[ -n "$line" ]]; do failing_modules+=("$line"); done < <(cat "$FAILURE_FILE")
ignored_modules=()
while IFS='' read -r line || [[ -n "$line" ]]; do ignored_modules+=("$line"); done < <(cat "$IGNORED_FILE")
30 | # functions
# lint runs golangci-lint with the repository config in every module that is
# not listed in FAILURE_FILE.
# The linter is invoked directly rather than through `eval`, so paths
# containing spaces or shell metacharacters are passed intact.
function lint() {
	for mod in $all_modules; do
		local in_failing
		util::array_contains "$mod" "${failing_modules[*]}" && in_failing=$? || in_failing=$?
		if [[ "$in_failing" -ne "0" ]]; then
			pushd "$mod" >/dev/null &&
				echo "golangci lint $(sed -n 1p go.mod | cut -d ' ' -f2)" &&
				"${LINTER}" run --timeout=5m --config="${LINTER_CONFIG}"
			popd >/dev/null || exit
		fi
	done
}
44 |
# Run `go test -race ./...` in each module, skipping the ones listed in
# IGNORED_FILE.
function test() {
	for mod in $all_modules; do
		if util::array_contains "$mod" "${ignored_modules[*]}"; then
			continue
		fi
		pushd "$mod" >/dev/null &&
			echo "go test $(sed -n 1p go.mod | cut -d ' ' -f2)" &&
			go test -race ./...
		popd >/dev/null || exit
	done
}
58 |
# Runs `go test -race` with atomic coverage in every module not listed in
# IGNORED_FILE and gathers all per-module profiles into ./coverage.txt.
function test_coverage() {
	echo "" >coverage.txt
	local base mod
	base=$(pwd)
	for mod in $all_modules; do
		if ! util::array_contains "$mod" "${ignored_modules[*]}"; then
			pushd "$mod" >/dev/null &&
				echo "go test $(sed -n 1p go.mod | cut -d ' ' -f2)" &&
				go test -race -coverprofile=profile.out -covermode=atomic ./...
			if [ -f profile.out ]; then
				# Append, not overwrite: with `>` only the last module's profile
				# would survive in the combined report.
				cat profile.out >>"${base}/coverage.txt"
				rm profile.out
			fi
			popd >/dev/null || exit
		fi
	done
}
78 |
# try to fix all mod with golangci-lint
# Runs `golangci-lint run --fix` in every module that is not listed in FAILURE_FILE.
function fix() {
	local mod
	for mod in $all_modules; do
		if ! util::array_contains "$mod" "${failing_modules[*]}"; then
			# Call the linter directly; `eval` would re-split the expanded paths.
			pushd "$mod" >/dev/null &&
				echo "golangci fix $(sed -n 1p go.mod | cut -d ' ' -f2)" &&
				"${LINTER}" run -v --fix --timeout=5m --config="${LINTER_CONFIG}"
			popd >/dev/null || exit
		fi
	done
}
92 |
# Runs `go mod tidy` in every module; neither skip list applies here.
function tidy() {
	local module_dir
	for module_dir in $all_modules; do
		pushd "$module_dir" >/dev/null &&
			echo "go mod tidy $(sed -n 1p go.mod | cut -d ' ' -f2)" &&
			go mod tidy
		popd >/dev/null || exit
	done
}
101 |
# Prints the list of supported sub-commands.
function help() {
	printf '%s\n' "use: lint, test, test_coverage, fix, tidy"
}
105 |
# Dispatch on the first argument. `${1:-}` keeps `set -o nounset` from
# aborting with "unbound variable" when the script is run without arguments;
# that case (like any unknown command) now falls through to `help`.
case ${1:-} in
lint)
	lint
	;;
test)
	test
	;;
test_coverage)
	test_coverage
	;;
tidy)
	tidy
	;;
fix)
	fix
	;;
*)
	help
	;;
esac
126 |
--------------------------------------------------------------------------------
/hack/util.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # This is a common util functions shell script
4 |
# arguments: target, item1, item2, item3, ...
# Items may also be passed pre-joined in one argument (e.g. "${arr[*]}"); they
# are split on IFS but, unlike an unquoted expansion, never glob-expanded, so an
# item such as `*` is compared literally instead of matching files in $PWD.
# returns 0 if target is in the given items, 1 otherwise.
function util::array_contains() {
	local target="$1"
	shift
	local -a items=()
	local item
	# read returns 1 at EOF (no NUL delimiter); that is expected, not an error.
	read -r -d '' -a items <<<"$*" || true
	# ${arr[@]+...} avoids "unbound variable" on empty arrays under nounset (bash < 4.4).
	for item in ${items[@]+"${items[@]}"}; do
		if [[ "${item}" == "${target}" ]]; then
			return 0
		fi
	done
	return 1
}
18 |
# find all go mod path
# Prints, one per line, the directory of every go.mod found below the current
# directory. ./output and ./.git are pruned, as is anything under a
# third_party/ or vendor/ directory.
function util::find_modules() {
	local -a excluded=(
		-path './output'
		-o -path './.git'
		-o -path '*/third_party/*'
		-o -path '*/vendor/*'
	)
	find . \( "${excluded[@]}" \) -prune -o -name 'go.mod' -print0 |
		xargs -0 -I {} dirname {}
}
31 |
--------------------------------------------------------------------------------
/pkg/atomic.go:
--------------------------------------------------------------------------------
1 | package pkg
2 |
3 | import "sync/atomic"
4 |
// AtomicBool is a bool that can be loaded and stored concurrently.
// The zero value is ready to use and reports false.
type AtomicBool struct {
	b atomic.Value
}

// NewAtomicBool returns an AtomicBool initialised to false.
func NewAtomicBool() *AtomicBool {
	b := &AtomicBool{}
	b.b.Store(false)
	return b
}

// Store atomically sets the value.
func (b *AtomicBool) Store(value bool) {
	b.b.Store(value)
}

// Load atomically returns the current value. If Store was never called
// (e.g. a zero-value AtomicBool{}), it returns false instead of panicking
// on a type assertion against a nil interface.
func (b *AtomicBool) Load() bool {
	v, _ := b.b.Load().(bool)
	return v
}
22 |
// AtomicString is a string that can be loaded and stored concurrently.
// The zero value is ready to use and reports "".
type AtomicString struct {
	b atomic.Value
}

// NewAtomicString returns an AtomicString initialised to "".
func NewAtomicString() *AtomicString {
	b := &AtomicString{}
	b.b.Store("")
	return b
}

// Store atomically sets the value.
func (b *AtomicString) Store(value string) {
	b.b.Store(value)
}

// Load atomically returns the current value. If Store was never called
// (e.g. a zero-value AtomicString{}), it returns "" instead of panicking.
func (b *AtomicString) Load() string {
	v, _ := b.b.Load().(string)
	return v
}
40 |
--------------------------------------------------------------------------------
/pkg/context.go:
--------------------------------------------------------------------------------
1 | package pkg
2 |
3 | import (
4 | "context"
5 | "time"
6 | )
7 |
8 | type CancelShieldContext struct {
9 | context.Context
10 | }
11 |
12 | func NewCancelShieldContext(ctx context.Context) context.Context {
13 | return CancelShieldContext{Context: ctx}
14 | }
15 |
16 | func (v CancelShieldContext) Deadline() (deadline time.Time, ok bool) {
17 | return
18 | }
19 |
20 | func (v CancelShieldContext) Done() <-chan struct{} {
21 | return nil
22 | }
23 |
24 | func (v CancelShieldContext) Err() error {
25 | return nil
26 | }
27 |
--------------------------------------------------------------------------------
/pkg/errors.go:
--------------------------------------------------------------------------------
1 | package pkg
2 |
3 | import (
4 | "errors"
5 | "fmt"
6 | )
7 |
// Sentinel errors shared across the package. Some are returned wrapped (for
// example JSONUnmarshal wraps ErrJSONUnmarshal with %w), so callers should
// match them with errors.Is rather than ==.
var (
	ErrClientNotSupport          = errors.New("this feature client not support")
	ErrServerNotSupport          = errors.New("this feature server not support")
	ErrRequestInvalid            = errors.New("request invalid")
	ErrLackResponseChan          = errors.New("lack response chan")
	ErrDuplicateResponseReceived = errors.New("duplicate response received")
	ErrMethodNotSupport          = errors.New("method not support")
	ErrJSONUnmarshal             = errors.New("json unmarshal error")
	ErrSessionHasNotInitialized  = errors.New("the session has not been initialized")
	ErrLackSession               = errors.New("lack session")
	ErrSessionClosed             = errors.New("session closed")
	ErrSendEOF                   = errors.New("send EOF")
	ErrRateLimitExceeded         = errors.New("rate limit exceeded")
)
22 |
// ResponseError is an error made of a numeric code, a message and optional
// extra data (presumably mirroring a JSON-RPC error object — confirm with callers).
type ResponseError struct {
	Code    int
	Message string
	Data    interface{}
}

// NewResponseError builds a *ResponseError from its three parts.
func NewResponseError(code int, message string, data interface{}) *ResponseError {
	e := new(ResponseError)
	e.Code, e.Message, e.Data = code, message, data
	return e
}

// Error renders the error as "code=<code> message=<message> data=<data>".
func (e *ResponseError) Error() string {
	const layout = "code=%d message=%s data=%+v"
	return fmt.Sprintf(layout, e.Code, e.Message, e.Data)
}
36 |
--------------------------------------------------------------------------------
/pkg/helper.go:
--------------------------------------------------------------------------------
1 | package pkg
2 |
3 | import (
4 | "errors"
5 | "log"
6 | "runtime/debug"
7 | "strings"
8 | "unsafe"
9 | )
10 |
// Recover is meant to be deferred directly (`defer Recover()`). It stops a
// panic in progress and logs the panic value together with the current stack
// trace via the standard log package. recover() only has an effect when called
// directly by the deferred function, so wrapping Recover in another function
// call disables it.
func Recover() {
	if r := recover(); r != nil {
		log.Printf("panic: %v\nstack: %s", r, debug.Stack())
	}
}

// RecoverWithFunc is like Recover but first passes the panic value to f and
// then logs it. It must also be deferred directly. If f itself panics, the
// log line below is never written.
func RecoverWithFunc(f func(r any)) {
	if r := recover(); r != nil {
		f(r)
		log.Printf("panic: %v\nstack: %s", r, debug.Stack())
	}
}
23 |
// B2S converts b to a string without copying by reinterpreting the slice
// header as a string header. The returned string shares b's backing array, so
// b must not be modified afterwards: Go assumes strings are immutable, and any
// later write to b would be visible through the string.
// NOTE(review): Go 1.20+ provides unsafe.String(unsafe.SliceData(b), len(b))
// as the supported form of this conversion — consider it if the module's Go
// version allows.
func B2S(b []byte) string {
	return *(*string)(unsafe.Pointer(&b))
}
27 |
28 | func JoinErrors(errs []error) error {
29 | if len(errs) == 0 {
30 | return nil
31 | }
32 | messages := make([]string, len(errs))
33 | for i, err := range errs {
34 | messages[i] = err.Error()
35 | }
36 | return errors.New(strings.Join(messages, "; "))
37 | }
38 |
--------------------------------------------------------------------------------
/pkg/json.go:
--------------------------------------------------------------------------------
1 | package pkg
2 |
3 | import (
4 | "encoding/json"
5 | "fmt"
6 | )
7 |
8 | // var sonicAPI = sonic.Config{UseInt64: true}.Froze() // Effectively prevents integer overflow
9 |
10 | func JSONUnmarshal(data []byte, v interface{}) error {
11 | if err := json.Unmarshal(data, v); err != nil {
12 | return fmt.Errorf("%w: data=%s, error: %+v", ErrJSONUnmarshal, data, err)
13 | }
14 | return nil
15 | }
16 |
--------------------------------------------------------------------------------
/pkg/limiter.go:
--------------------------------------------------------------------------------
1 | package pkg
2 |
3 | import (
4 | "sync"
5 | "time"
6 | )
7 |
8 | // RateLimiter 定义速率限制接口
9 | type RateLimiter interface {
10 | Allow(toolName string) bool
11 | }
12 |
13 | // TokenBucketLimiter 令牌桶限速器实现
14 | type TokenBucketLimiter struct {
15 | mu sync.RWMutex
16 | buckets map[string]*bucket
17 | defaultLimit Rate
18 | toolLimits map[string]Rate
19 | }
20 |
21 | // Rate 定义速率限制参数
22 | type Rate struct {
23 | Limit float64 // 每秒允许的请求数
24 | Burst int // 突发请求上限
25 | }
26 |
27 | // bucket 令牌桶
28 | type bucket struct {
29 | tokens float64
30 | lastTimestamp time.Time
31 | rate Rate
32 | }
33 |
34 | // NewTokenBucketLimiter 创建新的令牌桶限速器
35 | func NewTokenBucketLimiter(defaultRate Rate) *TokenBucketLimiter {
36 | return &TokenBucketLimiter{
37 | buckets: make(map[string]*bucket),
38 | defaultLimit: defaultRate,
39 | toolLimits: make(map[string]Rate),
40 | }
41 | }
42 |
43 | // SetToolLimit 为特定工具设置限制
44 | func (l *TokenBucketLimiter) SetToolLimit(toolName string, rate Rate) {
45 | l.mu.Lock()
46 | defer l.mu.Unlock()
47 |
48 | l.toolLimits[toolName] = rate
49 | // 如果已有桶,更新其速率
50 | if b, exists := l.buckets[toolName]; exists {
51 | b.rate = rate
52 | }
53 | }
54 |
55 | // Allow 检查请求是否被允许
56 | func (l *TokenBucketLimiter) Allow(toolName string) bool {
57 | l.mu.RLock()
58 | defer l.mu.RUnlock()
59 |
60 | now := time.Now()
61 |
62 | // 获取或创建桶
63 | b, exists := l.buckets[toolName]
64 | if !exists {
65 | // 查找工具特定的限制,如果没有则使用默认限制
66 | rate, exists := l.toolLimits[toolName]
67 | if !exists {
68 | rate = l.defaultLimit
69 | }
70 |
71 | b = &bucket{
72 | tokens: float64(rate.Burst),
73 | lastTimestamp: now,
74 | rate: rate,
75 | }
76 | l.buckets[toolName] = b
77 | }
78 |
79 | // 计算从上次请求到现在应该添加的令牌
80 | elapsed := now.Sub(b.lastTimestamp).Seconds()
81 | b.lastTimestamp = now
82 |
83 | // 添加令牌,但不超过最大值
84 | b.tokens += elapsed * b.rate.Limit
85 | if b.tokens > float64(b.rate.Burst) {
86 | b.tokens = float64(b.rate.Burst)
87 | }
88 |
89 | if b.tokens >= 1.0 {
90 | b.tokens -= 1.0
91 | return true
92 | }
93 | return false
94 | }
95 |
--------------------------------------------------------------------------------
/pkg/log.go:
--------------------------------------------------------------------------------
1 | package pkg
2 |
3 | import "log"
4 |
// Logger is the leveled, printf-style logging interface used by the package.
type Logger interface {
	Debugf(format string, a ...any)
	Infof(format string, a ...any)
	Warnf(format string, a ...any)
	Errorf(format string, a ...any)
}

// LogLevel is the minimum level a defaultLogger prints; larger values are
// more severe.
type LogLevel uint32

const (
	LogLevelDebug LogLevel = iota // 0
	LogLevelInfo                  // 1
	LogLevelWarn                  // 2
	LogLevelError                 // 3
)

// DefaultLogger prints Info, Warn and Error messages.
var DefaultLogger Logger = &defaultLogger{
	logLevel: LogLevelInfo,
}

// DebugLogger prints every level, Debug included.
var DebugLogger Logger = &defaultLogger{
	logLevel: LogLevelDebug,
}

// defaultLogger writes through the standard log package, tagging each line
// with its level and dropping messages below logLevel.
type defaultLogger struct {
	logLevel LogLevel
}

func (l *defaultLogger) Debugf(format string, a ...any) {
	if l.logLevel <= LogLevelDebug {
		log.Printf("[Debug] "+format+"\n", a...)
	}
}

func (l *defaultLogger) Infof(format string, a ...any) {
	if l.logLevel <= LogLevelInfo {
		log.Printf("[Info] "+format+"\n", a...)
	}
}

func (l *defaultLogger) Warnf(format string, a ...any) {
	if l.logLevel <= LogLevelWarn {
		log.Printf("[Warn] "+format+"\n", a...)
	}
}

func (l *defaultLogger) Errorf(format string, a ...any) {
	if l.logLevel <= LogLevelError {
		log.Printf("[Error] "+format+"\n", a...)
	}
}
59 |
--------------------------------------------------------------------------------
/pkg/sync_map.go:
--------------------------------------------------------------------------------
1 | package pkg
2 |
3 | import "sync"
4 |
// SyncMap is a type-safe wrapper around sync.Map with string keys and values
// of type V. The zero value is ready to use.
type SyncMap[V any] struct {
	m sync.Map
}

// syncMapValue converts a value read from the underlying sync.Map back to V.
// When V is an interface type and a nil value was stored (e.g. a nil error),
// a plain v.(V) assertion panics; the comma-ok form yields V's zero value.
func syncMapValue[V any](v any) V {
	val, _ := v.(V)
	return val
}

// Delete removes key.
func (m *SyncMap[V]) Delete(key string) {
	m.m.Delete(key)
}

// Load returns the value stored for key and whether it was present.
func (m *SyncMap[V]) Load(key string) (value V, ok bool) {
	v, ok := m.m.Load(key)
	if !ok {
		return value, ok
	}
	return syncMapValue[V](v), ok
}

// LoadAndDelete removes key and returns its previous value, if any.
func (m *SyncMap[V]) LoadAndDelete(key string) (value V, loaded bool) {
	v, loaded := m.m.LoadAndDelete(key)
	if !loaded {
		return value, loaded
	}
	return syncMapValue[V](v), loaded
}

// LoadOrStore returns the existing value for key if present (loaded=true);
// otherwise it stores value and returns it (loaded=false).
func (m *SyncMap[V]) LoadOrStore(key string, value V) (actual V, loaded bool) {
	a, loaded := m.m.LoadOrStore(key, value)
	return syncMapValue[V](a), loaded
}

// Range calls f for each entry until f returns false.
func (m *SyncMap[V]) Range(f func(key string, value V) bool) {
	m.m.Range(func(key, value any) bool {
		return f(key.(string), syncMapValue[V](value))
	})
}

// Store sets the value for key.
func (m *SyncMap[V]) Store(key string, value V) {
	m.m.Store(key, value)
}
41 |
--------------------------------------------------------------------------------
/protocol/cancellation.go:
--------------------------------------------------------------------------------
1 | package protocol
2 |
3 | // CancelledNotification represents a notification that a request has been canceled
4 | type CancelledNotification struct {
5 | RequestID RequestID `json:"requestId"`
6 | Reason string `json:"reason,omitempty"`
7 | }
8 |
9 | // NewCancelledNotification creates a new canceled notification
10 | func NewCancelledNotification(requestID RequestID, reason string) *CancelledNotification {
11 | return &CancelledNotification{
12 | RequestID: requestID,
13 | Reason: reason,
14 | }
15 | }
16 |
--------------------------------------------------------------------------------
/protocol/completion.go:
--------------------------------------------------------------------------------
1 | package protocol
2 |
// CompleteRequest represents a request for completion options
type CompleteRequest struct {
	Argument struct {
		Name  string `json:"name"`
		Value string `json:"value"`
	} `json:"argument"`
	Ref interface{} `json:"ref"` // Can be PromptReference or ResourceReference
}

// PromptReference points a completion request at a prompt by name.
type PromptReference struct {
	Type string `json:"type"`
	Name string `json:"name"`
}

// ResourceReference points a completion request at a resource by URI.
type ResourceReference struct {
	Type string `json:"type"`
	URI  string `json:"uri"`
}

// CompleteResult represents the response to a completion request
type CompleteResult struct {
	Completion *Complete `json:"completion"`
}

// Complete holds the suggested values plus optional paging hints.
type Complete struct {
	Values  []string `json:"values"`
	HasMore bool     `json:"hasMore,omitempty"`
	Total   int      `json:"total,omitempty"`
}

// NewCompleteRequest builds a completion request for the argument
// argName=argValue against ref (a PromptReference or ResourceReference).
func NewCompleteRequest(argName string, argValue string, ref interface{}) *CompleteRequest {
	req := &CompleteRequest{Ref: ref}
	req.Argument.Name = argName
	req.Argument.Value = argValue
	return req
}

// NewCompleteResult wraps the completion values and paging hints in a result.
func NewCompleteResult(values []string, hasMore bool, total int) *CompleteResult {
	completion := &Complete{Values: values, HasMore: hasMore, Total: total}
	return &CompleteResult{Completion: completion}
}
58 |
--------------------------------------------------------------------------------
/protocol/initialize.go:
--------------------------------------------------------------------------------
1 | package protocol
2 |
3 | import (
4 | "encoding/json"
5 |
6 | "github.com/tidwall/gjson"
7 | )
8 |
// InitializeRequest represents the initialize request sent from client to server
type InitializeRequest struct {
	ClientInfo      *Implementation     `json:"clientInfo"`
	Capabilities    *ClientCapabilities `json:"capabilities"`
	ProtocolVersion string              `json:"protocolVersion"`
}

// InitializeResult represents the server's response to an initialize request
// Instructions is omitted from the JSON when empty.
type InitializeResult struct {
	ServerInfo      *Implementation     `json:"serverInfo"`
	Capabilities    *ServerCapabilities `json:"capabilities"`
	ProtocolVersion string              `json:"protocolVersion"`
	Instructions    string              `json:"instructions,omitempty"`
}

// Implementation describes the name and version of an MCP implementation
type Implementation struct {
	Name    string `json:"name"`
	Version string `json:"version"`
}

// ClientCapabilities lists the optional features a client advertises.
// Only Sampling is modelled at the moment; Experimental and Roots are kept
// commented out below and are therefore never serialized.
type ClientCapabilities struct {
	// Experimental map[string]interface{} `json:"experimental,omitempty"`
	// Roots        *RootsCapability       `json:"roots,omitempty"`
	Sampling interface{} `json:"sampling,omitempty"`
}

// RootsCapability describes roots support. It is currently not referenced by
// ClientCapabilities (its Roots field is commented out).
type RootsCapability struct {
	ListChanged bool `json:"listChanged,omitempty"`
}

// ServerCapabilities lists the features a server advertises. Only prompts,
// resources and tools are modelled; a nil pointer is omitted from the JSON.
type ServerCapabilities struct {
	// Experimental map[string]interface{} `json:"experimental,omitempty"`
	// Logging      interface{}            `json:"logging,omitempty"`
	Prompts   *PromptsCapability   `json:"prompts,omitempty"`
	Resources *ResourcesCapability `json:"resources,omitempty"`
	Tools     *ToolsCapability     `json:"tools,omitempty"`
}

// PromptsCapability describes prompt support.
type PromptsCapability struct {
	ListChanged bool `json:"listChanged,omitempty"`
}

// ResourcesCapability describes resource support, including subscriptions.
type ResourcesCapability struct {
	ListChanged bool `json:"listChanged,omitempty"`
	Subscribe   bool `json:"subscribe,omitempty"`
}

// ToolsCapability describes tool support.
type ToolsCapability struct {
	ListChanged bool `json:"listChanged,omitempty"`
}

// InitializedNotification represents the notification sent from client to server after initialization
// Meta is serialized under the "_meta" key and omitted when empty.
type InitializedNotification struct {
	Meta map[string]interface{} `json:"_meta,omitempty"`
}
66 |
67 | // NewInitializeRequest creates a new initialize request
68 | func NewInitializeRequest(clientInfo *Implementation, capabilities *ClientCapabilities) *InitializeRequest {
69 | return &InitializeRequest{
70 | ClientInfo: clientInfo,
71 | Capabilities: capabilities,
72 | ProtocolVersion: Version,
73 | }
74 | }
75 |
76 | // NewInitializeResult creates a new initialize response
77 | func NewInitializeResult(serverInfo *Implementation, capabilities *ServerCapabilities, version string, instructions string) *InitializeResult {
78 | return &InitializeResult{
79 | ServerInfo: serverInfo,
80 | Capabilities: capabilities,
81 | ProtocolVersion: version,
82 | Instructions: instructions,
83 | }
84 | }
85 |
86 | // NewInitializedNotification creates a new initialized notification
87 | func NewInitializedNotification() *InitializedNotification {
88 | return &InitializedNotification{}
89 | }
90 |
// IsInitializedRequest reports whether the raw JSON message has a top-level
// "method" field whose string value equals Initialize.
// NOTE(review): despite the parameter name, this inspects a whole JSON-RPC
// message (it reads "method"), not a params object. It also compares against
// Initialize (presumably the "initialize" request method, not the
// "initialized" notification) — confirm against callers.
func IsInitializedRequest(rawParams json.RawMessage) bool {
	return gjson.ParseBytes(rawParams).Get("method").String() == string(Initialize)
}
94 |
--------------------------------------------------------------------------------
/protocol/jsonrpc.go:
--------------------------------------------------------------------------------
1 | package protocol
2 |
3 | import (
4 | "encoding/json"
5 |
6 | "github.com/ThinkInAIXYZ/go-mcp/pkg"
7 | )
8 |
// jsonrpcVersion is the protocol version stamped on outgoing messages by the
// constructors and required by JSONRPCRequest.IsValid.
const jsonrpcVersion = "2.0"

// Standard JSON-RPC error codes
const (
	ParseError     = -32700 // Invalid JSON
	InvalidRequest = -32600 // The JSON sent is not a valid Request object
	MethodNotFound = -32601 // The method does not exist / is not available
	InvalidParams  = -32602 // Invalid method parameter(s)
	InternalError  = -32603 // Internal JSON-RPC error

	// Custom error codes can be defined here, in the range above -32000.
	// NOTE(review): JSON-RPC 2.0 reserves -32768..-32000 for predefined errors
	// and sets aside -32000..-32099 for implementation-defined server errors;
	// ConnectionError (-32400) falls inside the reserved range — confirm this
	// is intended before relying on it.
	ConnectionError = -32400
)

// RequestID identifies a request; it holds a string or a number.
// NOTE(review): when decoded by encoding/json into this interface, numeric IDs
// arrive as float64 — account for that when comparing IDs.
type RequestID interface{} // string / number
24 |
// JSONRPCRequest is a JSON-RPC 2.0 request. When decoded, the raw "params"
// bytes are kept in RawParams (never serialized) and also decoded generically
// into Params.
type JSONRPCRequest struct {
	JSONRPC   string          `json:"jsonrpc"`
	ID        RequestID       `json:"id"`
	Method    Method          `json:"method"`
	Params    interface{}     `json:"params,omitempty"`
	RawParams json.RawMessage `json:"-"`
}

// UnmarshalJSON decodes a request while preserving the raw params.
//
// The anonymous wrapper declares its own Params field of type json.RawMessage.
// Being shallower than the embedded alias's Params, it wins during decoding,
// so "params" is captured verbatim. The alias type has no UnmarshalJSON method,
// which prevents infinite recursion. Decoding errors are returned unchanged.
func (r *JSONRPCRequest) UnmarshalJSON(data []byte) error {
	type alias JSONRPCRequest
	temp := &struct {
		Params json.RawMessage `json:"params,omitempty"`
		*alias
	}{
		alias: (*alias)(r),
	}

	if err := pkg.JSONUnmarshal(data, temp); err != nil {
		return err
	}

	r.RawParams = temp.Params

	// Decode params generically only when present; when absent, r.Params keeps
	// whatever value it already had (nil for a freshly allocated request).
	if len(r.RawParams) != 0 {
		if err := pkg.JSONUnmarshal(r.RawParams, &r.Params); err != nil {
			return err
		}
	}

	return nil
}

// IsValid checks if the request is valid according to JSON-RPC 2.0 spec
// It requires jsonrpc == "2.0", a non-empty method and a non-nil ID. A missing
// or JSON null "id" both decode to nil, so such messages are rejected.
func (r *JSONRPCRequest) IsValid() bool {
	return r.JSONRPC == jsonrpcVersion && r.Method != "" && r.ID != nil
}
61 |
// JSONRPCResponse represents a response to a request.
// The constructors in this file set either Result or Error, not both. When
// decoded, the raw "result" bytes are kept in RawResult (never serialized)
// and also decoded generically into Result.
type JSONRPCResponse struct {
	JSONRPC   string          `json:"jsonrpc"`
	ID        RequestID       `json:"id"`
	Result    interface{}     `json:"result,omitempty"`
	RawResult json.RawMessage `json:"-"`
	Error     *responseErr    `json:"error,omitempty"`
}

// responseErr is the JSON-RPC error object carried in JSONRPCResponse.Error.
type responseErr struct {
	// The error type that occurred.
	Code int `json:"code"`
	// A short description of the error. The message SHOULD be limited
	// to a concise single sentence.
	Message string `json:"message"`
	// Additional information about the error. The value of this member
	// is defined by the sender (e.g. detailed error information, nested errors etc.).
	Data interface{} `json:"data,omitempty"`
}

// UnmarshalJSON decodes a response while preserving the raw result.
//
// As with JSONRPCRequest, the wrapper's own Result field (json.RawMessage)
// shadows the embedded alias's Result, so "result" is captured verbatim. The
// alias type has no UnmarshalJSON method, so there is no recursion.
func (r *JSONRPCResponse) UnmarshalJSON(data []byte) error {
	type alias JSONRPCResponse
	temp := &struct {
		Result json.RawMessage `json:"result,omitempty"`
		*alias
	}{
		alias: (*alias)(r),
	}

	if err := pkg.JSONUnmarshal(data, temp); err != nil {
		return err
	}

	r.RawResult = temp.Result

	// Decode the result generically only when one was present.
	if len(r.RawResult) != 0 {
		if err := pkg.JSONUnmarshal(r.RawResult, &r.Result); err != nil {
			return err
		}
	}

	return nil
}
105 |
// JSONRPCNotification is a JSON-RPC 2.0 notification: a method call without
// an ID. When decoded, the raw "params" bytes are kept in RawParams (never
// serialized) and also decoded generically into Params.
type JSONRPCNotification struct {
	JSONRPC   string          `json:"jsonrpc"`
	Method    Method          `json:"method"`
	Params    interface{}     `json:"params,omitempty"`
	RawParams json.RawMessage `json:"-"`
}

// UnmarshalJSON decodes a notification while preserving the raw params.
// The wrapper's Params field (json.RawMessage) shadows the embedded alias's
// Params during decoding; the alias type avoids recursing into this method.
func (r *JSONRPCNotification) UnmarshalJSON(data []byte) error {
	type alias JSONRPCNotification
	temp := &struct {
		Params json.RawMessage `json:"params,omitempty"`
		*alias
	}{
		alias: (*alias)(r),
	}

	if err := pkg.JSONUnmarshal(data, temp); err != nil {
		return err
	}

	r.RawParams = temp.Params

	// Decode params generically only when present.
	if len(r.RawParams) != 0 {
		if err := pkg.JSONUnmarshal(r.RawParams, &r.Params); err != nil {
			return err
		}
	}

	return nil
}
136 |
137 | // NewJSONRPCRequest creates a new JSON-RPC request
138 | func NewJSONRPCRequest(id RequestID, method Method, params interface{}) *JSONRPCRequest {
139 | return &JSONRPCRequest{
140 | JSONRPC: jsonrpcVersion,
141 | ID: id,
142 | Method: method,
143 | Params: params,
144 | }
145 | }
146 |
147 | // NewJSONRPCSuccessResponse creates a new JSON-RPC response
148 | func NewJSONRPCSuccessResponse(id RequestID, result interface{}) *JSONRPCResponse {
149 | return &JSONRPCResponse{
150 | JSONRPC: jsonrpcVersion,
151 | ID: id,
152 | Result: result,
153 | }
154 | }
155 |
156 | // NewJSONRPCErrorResponse NewError creates a new JSON-RPC error response
157 | func NewJSONRPCErrorResponse(id RequestID, code int, message string) *JSONRPCResponse {
158 | err := &JSONRPCResponse{
159 | JSONRPC: jsonrpcVersion,
160 | ID: id,
161 | Error: &responseErr{
162 | Code: code,
163 | Message: message,
164 | },
165 | }
166 | return err
167 | }
168 |
169 | // NewJSONRPCNotification creates a new JSON-RPC notification
170 | func NewJSONRPCNotification(method Method, params interface{}) *JSONRPCNotification {
171 | return &JSONRPCNotification{
172 | JSONRPC: jsonrpcVersion,
173 | Method: method,
174 | Params: params,
175 | }
176 | }
177 |
--------------------------------------------------------------------------------
/protocol/logging.go:
--------------------------------------------------------------------------------
1 | package protocol
2 |
// LoggingLevel is the severity of a log message, carried as its lowercase name.
type LoggingLevel string

// Supported logging levels, listed from emergency down to debug.
const (
	LogEmergency LoggingLevel = "emergency"
	LogAlert     LoggingLevel = "alert"
	LogCritical  LoggingLevel = "critical"
	LogError     LoggingLevel = "error"
	LogWarning   LoggingLevel = "warning"
	LogNotice    LoggingLevel = "notice"
	LogInfo      LoggingLevel = "info"
	LogDebug     LoggingLevel = "debug"
)

// SetLoggingLevelRequest asks the receiver to change its logging level.
type SetLoggingLevelRequest struct {
	Level LoggingLevel `json:"level"`
}

// SetLoggingLevelResult reports whether a set-level request succeeded.
type SetLoggingLevelResult struct {
	Success bool `json:"success"`
}

// LogMessageNotification carries a single log message; Meta is omitted when empty.
// NOTE(review): the MCP "notifications/message" schema uses level/logger/data;
// this struct uses level/message/meta — confirm against the targeted spec version.
type LogMessageNotification struct {
	Level   LoggingLevel           `json:"level"`
	Message string                 `json:"message"`
	Meta    map[string]interface{} `json:"meta,omitempty"`
}

// NewSetLoggingLevelRequest builds a request for the given level.
func NewSetLoggingLevelRequest(level LoggingLevel) *SetLoggingLevelRequest {
	req := new(SetLoggingLevelRequest)
	req.Level = level
	return req
}

// NewSetLoggingLevelResult builds a result carrying the success flag.
func NewSetLoggingLevelResult(success bool) *SetLoggingLevelResult {
	res := new(SetLoggingLevelResult)
	res.Success = success
	return res
}

// NewLogMessageNotification builds a log message notification.
func NewLogMessageNotification(level LoggingLevel, message string, meta map[string]interface{}) *LogMessageNotification {
	n := new(LogMessageNotification)
	n.Level, n.Message, n.Meta = level, message, meta
	return n
}
56 |
--------------------------------------------------------------------------------
/protocol/pagination.go:
--------------------------------------------------------------------------------
1 | package protocol
2 |
3 | import (
4 | "encoding/base64"
5 | "sort"
6 | )
7 |
// Cursor is an opaque token used to represent a cursor for pagination.
type Cursor string

// Named is implemented by elements that can be paginated by name.
type Named interface {
	GetName() string
}

// PaginationLimit returns one page of allElements ordered by GetName.
//
// cursor is the base64 (StdEncoding) name of the last element of the previous
// page, or "" for the first page; a cursor that cannot be decoded yields an
// error. limit is the page size. A non-positive limit returns every remaining
// element with an empty next cursor, where it used to panic. When limit > 0
// and the page is full, the returned cursor points after its last element.
//
// allElements is not modified: sorting happens on a copy, so callers may pass
// a slice that is shared or read concurrently.
func PaginationLimit[T Named](allElements []T, cursor Cursor, limit int) ([]T, Cursor, error) {
	sorted := make([]T, len(allElements))
	copy(sorted, allElements)
	sort.Slice(sorted, func(i, j int) bool {
		return sorted[i].GetName() < sorted[j].GetName()
	})

	startPos := 0
	if cursor != "" {
		c, err := base64.StdEncoding.DecodeString(string(cursor))
		if err != nil {
			return nil, "", err
		}
		last := string(c)
		// The page starts at the first name strictly after the previous page.
		startPos = sort.Search(len(sorted), func(i int) bool {
			return sorted[i].GetName() > last
		})
	}

	endPos := len(sorted)
	if limit > 0 && endPos-startPos > limit {
		endPos = startPos + limit
	}
	page := sorted[startPos:endPos]

	// A short (or unbounded) page means there is nothing more to fetch. This
	// also guarantees page is non-empty before indexing its last element.
	if limit <= 0 || len(page) < limit {
		return page, "", nil
	}
	next := base64.StdEncoding.EncodeToString([]byte(page[len(page)-1].GetName()))
	return page, Cursor(next), nil
}

// PaginatedRequest represents a request that supports pagination
type PaginatedRequest struct {
	Cursor Cursor `json:"cursor,omitempty"`
}

// PaginatedResult represents a response that supports pagination
type PaginatedResult struct {
	NextCursor Cursor `json:"nextCursor,omitempty"`
}
58 |
--------------------------------------------------------------------------------
/protocol/pagination_test.go:
--------------------------------------------------------------------------------
1 | package protocol
2 |
3 | import (
4 | "encoding/base64"
5 | "fmt"
6 | "reflect"
7 | "sort"
8 | "testing"
9 | )
10 |
11 | func BenchmarkPaginationLimitForReflect(b *testing.B) {
12 | list := getTools(10000)
13 | for i := 0; i < b.N; i++ {
14 | _, _, _ = PaginationLimitForReflect[*Tool](list, "dG9vbDMz", 10)
15 | }
16 | }
17 |
18 | func BenchmarkPaginationLimitForTool(b *testing.B) {
19 | list := getTools(10000)
20 | for i := 0; i < b.N; i++ {
21 | _, _, _ = PaginationLimitForTool(list, "dG9vbDMz", 10)
22 | }
23 | }
24 |
25 | func BenchmarkPaginationLimit(b *testing.B) {
26 | list := getTools(10000)
27 | for i := 0; i < b.N; i++ {
28 | _, _, _ = PaginationLimit(list, "dG9vbDMz", 10)
29 | }
30 | }
31 |
32 | func getTools(length int) []*Tool {
33 | list := make([]*Tool, 0, 10000)
34 | for i := 0; i < length; i++ {
35 | list = append(list, &Tool{
36 | Name: fmt.Sprintf("tool%d", i),
37 | Description: fmt.Sprintf("tool%d", i),
38 | })
39 | }
40 | return list
41 | }
42 |
// PaginationLimitForTool is a benchmark baseline: the paging logic of
// PaginationLimit specialised to *Tool, reading Name directly instead of
// calling GetName. It is kept statement-for-statement parallel to the
// production algorithm so the benchmark comparison stays fair.
// Unlike PaginationLimit it does not sort, so allElements must already be
// sorted by Name for sort.Search to be meaningful.
// NOTE(review): with limit == 0 and elements remaining, elementsToReturn is
// empty and the cursor closure indexes [-1] and panics — benchmark-only code.
func PaginationLimitForTool(allElements []*Tool, cursor Cursor, limit int) ([]*Tool, Cursor, error) {
	startPos := 0
	if cursor != "" {
		c, err := base64.StdEncoding.DecodeString(string(cursor))
		if err != nil {
			return nil, "", err
		}
		cString := string(c)
		// First element whose name sorts strictly after the cursor name.
		startPos = sort.Search(len(allElements), func(i int) bool {
			nc := allElements[i].Name
			return nc > cString
		})
	}
	endPos := len(allElements)
	if len(allElements) > startPos+limit {
		endPos = startPos + limit
	}
	elementsToReturn := allElements[startPos:endPos]
	// set the next cursor
	nextCursor := func() Cursor {
		if len(elementsToReturn) < limit {
			return ""
		}
		element := elementsToReturn[len(elementsToReturn)-1]
		nc := element.Name
		toString := base64.StdEncoding.EncodeToString([]byte(nc))
		return Cursor(toString)
	}()
	return elementsToReturn, nextCursor, nil
}
73 |
// PaginationLimitForReflect is a benchmark baseline: the paging logic of
// PaginationLimit, but reading each element's "Name" field via reflection
// (dereferencing pointers first) instead of the Named interface. It is kept
// parallel to the production algorithm so the benchmark comparison is fair.
// It does not sort, so allElements must already be sorted by Name. If the
// element type has no Name field, FieldByName returns the zero Value, whose
// String() is "<invalid Value>" rather than a panic.
// NOTE(review): like PaginationLimitForTool, limit == 0 with elements
// remaining panics on the [-1] index — benchmark-only code.
func PaginationLimitForReflect[T any](allElements []T, cursor Cursor, limit int) ([]T, Cursor, error) {
	startPos := 0
	if cursor != "" {
		c, err := base64.StdEncoding.DecodeString(string(cursor))
		if err != nil {
			return nil, "", err
		}
		cString := string(c)
		startPos = sort.Search(len(allElements), func(i int) bool {
			val := reflect.ValueOf(allElements[i])
			var nc string
			// Elements may be pointers to structs; read the pointee's field.
			if val.Kind() == reflect.Ptr {
				val = val.Elem()
			}
			nc = val.FieldByName("Name").String()
			return nc > cString
		})
	}
	endPos := len(allElements)
	if len(allElements) > startPos+limit {
		endPos = startPos + limit
	}
	elementsToReturn := allElements[startPos:endPos]
	// set the next cursor
	nextCursor := func() Cursor {
		if len(elementsToReturn) < limit {
			return ""
		}
		element := elementsToReturn[len(elementsToReturn)-1]
		val := reflect.ValueOf(element)
		var nc string
		if val.Kind() == reflect.Ptr {
			val = val.Elem()
		}
		nc = val.FieldByName("Name").String()
		toString := base64.StdEncoding.EncodeToString([]byte(nc))
		return Cursor(toString)
	}()
	return elementsToReturn, nextCursor, nil
}
114 |
--------------------------------------------------------------------------------
/protocol/ping.go:
--------------------------------------------------------------------------------
1 | package protocol
2 |
// PingRequest is the (empty) payload of a ping request.
type PingRequest struct{}

// PingResult is the (empty) payload of a ping response.
type PingResult struct{}

// NewPingRequest returns a new, empty ping request.
func NewPingRequest() *PingRequest {
	return new(PingRequest)
}

// NewPingResult returns a new, empty ping response.
func NewPingResult() *PingResult {
	return new(PingResult)
}
16 |
--------------------------------------------------------------------------------
/protocol/progress.go:
--------------------------------------------------------------------------------
1 | package protocol
2 |
// ProgressTokenKey is the key name "progressToken" (presumably looked up in a
// request's metadata by callers — confirm at call sites).
const ProgressTokenKey = "progressToken"

// ProgressNotification reports progress on a long-running request. Total and
// Message are omitted from the JSON when zero/empty.
type ProgressNotification struct {
	ProgressToken ProgressToken `json:"progressToken"`
	Progress      float64       `json:"progress"`
	Total         float64       `json:"total,omitempty"`
	Message       string        `json:"message,omitempty"`
}

// ProgressToken represents a token used to associate progress notifications with the original request
type ProgressToken interface{} // can be string or integer

// NewProgressNotification builds a progress notification with the given
// progress, total and message. ProgressToken is left nil, so callers need to
// set it themselves.
func NewProgressNotification(progress float64, total float64, message string) *ProgressNotification {
	n := new(ProgressNotification)
	n.Progress, n.Total, n.Message = progress, total, message
	return n
}
24 |
--------------------------------------------------------------------------------
/protocol/prompts.go:
--------------------------------------------------------------------------------
1 | package protocol
2 |
3 | import (
4 | "encoding/json"
5 | "fmt"
6 |
7 | "github.com/ThinkInAIXYZ/go-mcp/pkg"
8 | )
9 |
10 | // ListPromptsRequest represents a request to list available prompts
11 | type ListPromptsRequest struct {
12 | Cursor Cursor `json:"cursor,omitempty"`
13 | }
14 |
15 | // ListPromptsResult represents the response to a list prompts request
16 | type ListPromptsResult struct {
17 | Prompts []*Prompt `json:"prompts"`
18 | NextCursor Cursor `json:"nextCursor,omitempty"`
19 | }
20 |
// Prompt holds a prompt's metadata: its name plus an optional description
// and optional argument definitions.
type Prompt struct {
	Name        string            `json:"name"`
	Description string            `json:"description,omitempty"`
	Arguments   []*PromptArgument `json:"arguments,omitempty"`
}

// GetName returns the prompt's Name.
func (p *Prompt) GetName() string {
	name := p.Name
	return name
}

// PromptArgument describes one named argument a prompt accepts.
// Description and Required are omitted from JSON when zero.
type PromptArgument struct {
	Name        string `json:"name"`
	Description string `json:"description,omitempty"`
	Required    bool   `json:"required,omitempty"`
}
37 |
38 | // GetPromptRequest represents a request to get a specific prompt
39 | type GetPromptRequest struct {
40 | Name string `json:"name"`
41 | Arguments map[string]string `json:"arguments,omitempty"`
42 | }
43 |
44 | // GetPromptResult represents the response to a get prompt request
45 | type GetPromptResult struct {
46 | Messages []*PromptMessage `json:"messages"`
47 | Description string `json:"description,omitempty"`
48 | }
49 |
50 | type PromptMessage struct {
51 | Role Role `json:"role"`
52 | Content Content `json:"content"`
53 | }
54 |
55 | // UnmarshalJSON implements the json.Unmarshaler interface for PromptMessage
56 | func (m *PromptMessage) UnmarshalJSON(data []byte) error {
57 | type Alias PromptMessage
58 | aux := &struct {
59 | Content json.RawMessage `json:"content"`
60 | *Alias
61 | }{
62 | Alias: (*Alias)(m),
63 | }
64 | if err := pkg.JSONUnmarshal(data, &aux); err != nil {
65 | return err
66 | }
67 |
68 | // Try to unmarshal content as TextContent first
69 | var textContent *TextContent
70 | if err := pkg.JSONUnmarshal(aux.Content, &textContent); err == nil {
71 | m.Content = textContent
72 | return nil
73 | }
74 |
75 | // Try to unmarshal content as ImageContent
76 | var imageContent *ImageContent
77 | if err := pkg.JSONUnmarshal(aux.Content, &imageContent); err == nil {
78 | m.Content = imageContent
79 | return nil
80 | }
81 |
82 | // Try to unmarshal content as AudioContent
83 | var audioContent *AudioContent
84 | if err := pkg.JSONUnmarshal(aux.Content, &audioContent); err == nil {
85 | m.Content = audioContent
86 | return nil
87 | }
88 |
89 | // Try to unmarshal content as embeddedResource
90 | var embeddedResource *EmbeddedResource
91 | if err := pkg.JSONUnmarshal(aux.Content, &embeddedResource); err == nil {
92 | m.Content = embeddedResource
93 | return nil
94 | }
95 |
96 | return fmt.Errorf("unknown content type")
97 | }
98 |
99 | // PromptListChangedNotification represents a notification that the prompt list has changed
100 | type PromptListChangedNotification struct {
101 | Meta map[string]interface{} `json:"_meta,omitempty"`
102 | }
103 |
104 | // NewListPromptsRequest creates a new list prompts request
105 | func NewListPromptsRequest() *ListPromptsRequest {
106 | return &ListPromptsRequest{}
107 | }
108 |
109 | // NewListPromptsResult creates a new list prompts response
110 | func NewListPromptsResult(prompts []*Prompt, nextCursor Cursor) *ListPromptsResult {
111 | return &ListPromptsResult{
112 | Prompts: prompts,
113 | NextCursor: nextCursor,
114 | }
115 | }
116 |
117 | // NewGetPromptRequest creates a new get prompt request
118 | func NewGetPromptRequest(name string, arguments map[string]string) *GetPromptRequest {
119 | return &GetPromptRequest{
120 | Name: name,
121 | Arguments: arguments,
122 | }
123 | }
124 |
125 | // NewGetPromptResult creates a new get prompt response
126 | func NewGetPromptResult(messages []*PromptMessage, description string) *GetPromptResult {
127 | return &GetPromptResult{
128 | Messages: messages,
129 | Description: description,
130 | }
131 | }
132 |
133 | // NewPromptListChangedNotification creates a new prompt list changed notification
134 | func NewPromptListChangedNotification() *PromptListChangedNotification {
135 | return &PromptListChangedNotification{}
136 | }
137 |
--------------------------------------------------------------------------------
/protocol/roots.go:
--------------------------------------------------------------------------------
1 | package protocol
2 |
// ListRootsRequest represents a request to list root directories; it has no parameters.
type ListRootsRequest struct{}

// ListRootsResult represents the response to a list roots request.
type ListRootsResult struct {
	Roots []*Root `json:"roots"`
}

// Root represents a root directory or file that the server can operate on.
// URI is always serialised; Name is omitted from JSON when empty.
type Root struct {
	Name string `json:"name,omitempty"`
	URI  string `json:"uri"`
}

// RootsListChangedNotification represents a notification that the roots list has changed.
// Meta carries optional "_meta" data and is omitted from JSON when empty.
type RootsListChangedNotification struct {
	Meta map[string]interface{} `json:"_meta,omitempty"`
}

// NewListRootsRequest creates a new list roots request
func NewListRootsRequest() *ListRootsRequest {
	req := new(ListRootsRequest)
	return req
}

// NewListRootsResult creates a new list roots response wrapping roots as-is.
func NewListRootsResult(roots []*Root) *ListRootsResult {
	res := new(ListRootsResult)
	res.Roots = roots
	return res
}

// NewRootsListChangedNotification creates a new roots list changed notification
// with no _meta data.
func NewRootsListChangedNotification() *RootsListChangedNotification {
	n := new(RootsListChangedNotification)
	return n
}
38 |
--------------------------------------------------------------------------------
/protocol/sampling.go:
--------------------------------------------------------------------------------
1 | package protocol
2 |
3 | import (
4 | "encoding/json"
5 | "fmt"
6 |
7 | "github.com/ThinkInAIXYZ/go-mcp/pkg"
8 | )
9 |
10 | // CreateMessageRequest represents a request to create a message through sampling
11 | type CreateMessageRequest struct {
12 | Messages []*SamplingMessage `json:"messages"`
13 | MaxTokens int `json:"maxTokens"`
14 | Temperature float64 `json:"temperature,omitempty"`
15 | StopSequences []string `json:"stopSequences,omitempty"`
16 | SystemPrompt string `json:"systemPrompt,omitempty"`
17 | ModelPreferences *ModelPreferences `json:"modelPreferences,omitempty"`
18 | IncludeContext string `json:"includeContext,omitempty"`
19 | Metadata map[string]interface{} `json:"metadata,omitempty"`
20 | }
21 |
22 | type SamplingMessage struct {
23 | Role Role `json:"role"`
24 | Content Content `json:"content"`
25 | }
26 |
27 | // UnmarshalJSON implements the json.Unmarshaler interface for SamplingMessage
28 | func (r *SamplingMessage) UnmarshalJSON(data []byte) error {
29 | type Alias SamplingMessage
30 | aux := &struct {
31 | Content json.RawMessage `json:"content"`
32 | *Alias
33 | }{
34 | Alias: (*Alias)(r),
35 | }
36 | if err := pkg.JSONUnmarshal(data, &aux); err != nil {
37 | return err
38 | }
39 |
40 | // Try to unmarshal content as TextContent first
41 | var textContent *TextContent
42 | if err := pkg.JSONUnmarshal(aux.Content, &textContent); err == nil {
43 | r.Content = textContent
44 | return nil
45 | }
46 |
47 | // Try to unmarshal content as ImageContent
48 | var imageContent *ImageContent
49 | if err := pkg.JSONUnmarshal(aux.Content, &imageContent); err == nil {
50 | r.Content = imageContent
51 | return nil
52 | }
53 |
54 | // Try to unmarshal content as AudioContent
55 | var audioContent *AudioContent
56 | if err := pkg.JSONUnmarshal(aux.Content, &audioContent); err == nil {
57 | r.Content = audioContent
58 | return nil
59 | }
60 |
61 | return fmt.Errorf("unknown content type, content=%s", aux.Content)
62 | }
63 |
64 | // CreateMessageResult represents the response to a create message request
65 | type CreateMessageResult struct {
66 | Content Content `json:"content"`
67 | Role Role `json:"role"`
68 | Model string `json:"model"`
69 | StopReason string `json:"stopReason,omitempty"`
70 | }
71 |
72 | // UnmarshalJSON implements the json.Unmarshaler interface for CreateMessageResult
73 | func (r *CreateMessageResult) UnmarshalJSON(data []byte) error {
74 | type Alias CreateMessageResult
75 | aux := &struct {
76 | Content json.RawMessage `json:"content"`
77 | *Alias
78 | }{
79 | Alias: (*Alias)(r),
80 | }
81 | if err := pkg.JSONUnmarshal(data, &aux); err != nil {
82 | return err
83 | }
84 |
85 | // Try to unmarshal content as TextContent first
86 | var textContent *TextContent
87 | if err := pkg.JSONUnmarshal(aux.Content, &textContent); err == nil {
88 | r.Content = textContent
89 | return nil
90 | }
91 |
92 | // Try to unmarshal content as ImageContent
93 | var imageContent *ImageContent
94 | if err := pkg.JSONUnmarshal(aux.Content, &imageContent); err == nil {
95 | r.Content = imageContent
96 | return nil
97 | }
98 |
99 | // Try to unmarshal content as AudioContent
100 | var audioContent *AudioContent
101 | if err := pkg.JSONUnmarshal(aux.Content, &audioContent); err == nil {
102 | r.Content = audioContent
103 | return nil
104 | }
105 |
106 | return fmt.Errorf("unknown content type, content=%s", aux.Content)
107 | }
108 |
109 | // NewCreateMessageRequest creates a new create message request
110 | func NewCreateMessageRequest(messages []*SamplingMessage, maxTokens int, opts ...CreateMessageOption) *CreateMessageRequest {
111 | req := &CreateMessageRequest{
112 | Messages: messages,
113 | MaxTokens: maxTokens,
114 | }
115 |
116 | for _, opt := range opts {
117 | opt(req)
118 | }
119 |
120 | return req
121 | }
122 |
123 | // NewCreateMessageResult creates a new create message response
124 | func NewCreateMessageResult(content Content, role Role, model string, stopReason string) *CreateMessageResult {
125 | return &CreateMessageResult{
126 | Content: content,
127 | Role: role,
128 | Model: model,
129 | StopReason: stopReason,
130 | }
131 | }
132 |
133 | // CreateMessageOption represents an option for creating a message
134 | type CreateMessageOption func(*CreateMessageRequest)
135 |
136 | // WithTemperature sets the temperature for the request
137 | func WithTemperature(temp float64) CreateMessageOption {
138 | return func(r *CreateMessageRequest) {
139 | r.Temperature = temp
140 | }
141 | }
142 |
143 | // WithStopSequences sets the stop sequences for the request
144 | func WithStopSequences(sequences []string) CreateMessageOption {
145 | return func(r *CreateMessageRequest) {
146 | r.StopSequences = sequences
147 | }
148 | }
149 |
150 | // WithSystemPrompt sets the system prompt for the request
151 | func WithSystemPrompt(prompt string) CreateMessageOption {
152 | return func(r *CreateMessageRequest) {
153 | r.SystemPrompt = prompt
154 | }
155 | }
156 |
157 | // WithModelPreferences sets the model preferences for the request
158 | func WithModelPreferences(prefs *ModelPreferences) CreateMessageOption {
159 | return func(r *CreateMessageRequest) {
160 | r.ModelPreferences = prefs
161 | }
162 | }
163 |
164 | // WithIncludeContext sets the include context option for the request
165 | func WithIncludeContext(ctx string) CreateMessageOption {
166 | return func(r *CreateMessageRequest) {
167 | r.IncludeContext = ctx
168 | }
169 | }
170 |
171 | // WithMetadata sets the metadata for the request
172 | func WithMetadata(metadata map[string]interface{}) CreateMessageOption {
173 | return func(r *CreateMessageRequest) {
174 | r.Metadata = metadata
175 | }
176 | }
177 |
--------------------------------------------------------------------------------
/protocol/schema_generate.go:
--------------------------------------------------------------------------------
1 | package protocol
2 |
3 | import (
4 | "fmt"
5 | "reflect"
6 | "strconv"
7 | "strings"
8 |
9 | "github.com/ThinkInAIXYZ/go-mcp/pkg"
10 | )
11 |
// DataType is a JSON-schema primitive type name.
type DataType string

// JSON-schema type names. ObjectT carries a suffix so it does not collide
// with other identifiers in this package.
const (
	ObjectT DataType = "object"
	Number  DataType = "number"
	Integer DataType = "integer"
	String  DataType = "string"
	Array   DataType = "array"
	Null    DataType = "null"
	Boolean DataType = "boolean"
)

// Property is a (subset of a) JSON schema describing one value.
// Field order is kept stable because it determines JSON key order.
type Property struct {
	Type DataType `json:"type"`
	// Description is the description of the schema.
	Description string `json:"description,omitempty"`
	// Items specifies which data type an array contains, if the schema type is Array.
	Items *Property `json:"items,omitempty"`
	// Properties describes the properties of an object, if the schema type is Object.
	Properties map[string]*Property `json:"properties,omitempty"`
	// Required lists the property names that must be present in an object.
	Required []string `json:"required,omitempty"`
	// Enum, when non-empty, restricts the value to these (string-encoded) literals.
	Enum []string `json:"enum,omitempty"`
}
35 |
36 | var schemaCache = pkg.SyncMap[*InputSchema]{}
37 |
38 | func generateSchemaFromReqStruct(v any) (*InputSchema, error) {
39 | t := reflect.TypeOf(v)
40 | for t.Kind() != reflect.Struct {
41 | if t.Kind() != reflect.Ptr {
42 | return nil, fmt.Errorf("invalid type %v", t)
43 | }
44 | t = t.Elem()
45 | }
46 |
47 | typeUID := getTypeUUID(t)
48 | if schema, ok := schemaCache.Load(typeUID); ok {
49 | return schema, nil
50 | }
51 |
52 | schema := &InputSchema{Type: Object}
53 |
54 | property, err := reflectSchemaByObject(t)
55 | if err != nil {
56 | return nil, err
57 | }
58 |
59 | schema.Properties = property.Properties
60 | schema.Required = property.Required
61 |
62 | schemaCache.Store(typeUID, schema)
63 | return schema, nil
64 | }
65 |
66 | func getTypeUUID(t reflect.Type) string {
67 | if t.PkgPath() != "" && t.Name() != "" {
68 | return t.PkgPath() + "." + t.Name()
69 | }
70 | // fallback for unnamed types (like anonymous struct)
71 | return t.String()
72 | }
73 |
74 | func reflectSchemaByObject(t reflect.Type) (*Property, error) {
75 | var (
76 | properties = make(map[string]*Property)
77 | requiredFields = make([]string, 0)
78 | anonymousFields = make([]reflect.StructField, 0)
79 | )
80 |
81 | for i := 0; i < t.NumField(); i++ {
82 | field := t.Field(i)
83 |
84 | if field.Anonymous {
85 | anonymousFields = append(anonymousFields, field)
86 | continue
87 | }
88 |
89 | if !field.IsExported() {
90 | continue
91 | }
92 |
93 | jsonTag := field.Tag.Get("json")
94 | if jsonTag == "-" {
95 | continue
96 | }
97 | required := true
98 | if jsonTag == "" {
99 | jsonTag = field.Name
100 | }
101 | if strings.HasSuffix(jsonTag, ",omitempty") {
102 | jsonTag = strings.TrimSuffix(jsonTag, ",omitempty")
103 | required = false
104 | }
105 |
106 | item, err := reflectSchemaByType(field.Type)
107 | if err != nil {
108 | return nil, err
109 | }
110 |
111 | if description := field.Tag.Get("description"); description != "" {
112 | item.Description = description
113 | }
114 | properties[jsonTag] = item
115 |
116 | if s := field.Tag.Get("required"); s != "" {
117 | required, err = strconv.ParseBool(s)
118 | if err != nil {
119 | return nil, fmt.Errorf("invalid required field %v: %v", jsonTag, err)
120 | }
121 | }
122 | if required {
123 | requiredFields = append(requiredFields, jsonTag)
124 | }
125 |
126 | if v := field.Tag.Get("enum"); v != "" {
127 | enumValues := strings.Split(v, ",")
128 | for j, value := range enumValues {
129 | enumValues[j] = strings.TrimSpace(value)
130 | }
131 |
132 | // Check if enum values are consistent with the field type
133 | for _, value := range enumValues {
134 | switch field.Type.Kind() {
135 | case reflect.String:
136 | // No additional processing required for string type
137 | case reflect.Int, reflect.Int64:
138 | if _, err := strconv.Atoi(value); err != nil {
139 | return nil, fmt.Errorf("enum value %q is not compatible with type %v", value, field.Type)
140 | }
141 | case reflect.Float64:
142 | if _, err := strconv.ParseFloat(value, 64); err != nil {
143 | return nil, fmt.Errorf("enum value %q is not compatible with type %v", value, field.Type)
144 | }
145 | default:
146 | return nil, fmt.Errorf("unsupported type %v for enum validation", field.Type)
147 | }
148 | }
149 | item.Enum = enumValues
150 | }
151 | }
152 |
153 | for _, field := range anonymousFields {
154 | object, err := reflectSchemaByObject(field.Type)
155 | if err != nil {
156 | return nil, err
157 | }
158 | for propName, propValue := range object.Properties {
159 | if _, ok := properties[propName]; ok {
160 | return nil, fmt.Errorf("duplicate property name %s in anonymous struct", propName)
161 | }
162 | properties[propName] = propValue
163 | }
164 | requiredFields = append(requiredFields, object.Required...)
165 | }
166 |
167 | property := &Property{
168 | Type: ObjectT,
169 | Properties: properties,
170 | Required: requiredFields,
171 | }
172 | return property, nil
173 | }
174 |
175 | func reflectSchemaByType(t reflect.Type) (*Property, error) {
176 | s := &Property{}
177 |
178 | switch t.Kind() {
179 | case reflect.String:
180 | s.Type = String
181 | case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
182 | reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
183 | s.Type = Integer
184 | case reflect.Float32, reflect.Float64:
185 | s.Type = Number
186 | case reflect.Bool:
187 | s.Type = Boolean
188 | case reflect.Slice, reflect.Array:
189 | s.Type = Array
190 | items, err := reflectSchemaByType(t.Elem())
191 | if err != nil {
192 | return nil, err
193 | }
194 | s.Items = items
195 | case reflect.Struct:
196 | object, err := reflectSchemaByObject(t)
197 | if err != nil {
198 | return nil, err
199 | }
200 | object.Type = ObjectT
201 | s = object
202 | case reflect.Map:
203 | if t.Key().Kind() != reflect.String {
204 | return nil, fmt.Errorf("map key type %s is not supported", t.Key().Kind())
205 | }
206 | object := &Property{
207 | Type: ObjectT,
208 | }
209 | s = object
210 | case reflect.Ptr:
211 | p, err := reflectSchemaByType(t.Elem())
212 | if err != nil {
213 | return nil, err
214 | }
215 | s = p
216 | case reflect.Invalid, reflect.Uintptr, reflect.Complex64, reflect.Complex128,
217 | reflect.Chan, reflect.Func, reflect.Interface,
218 | reflect.UnsafePointer:
219 | return nil, fmt.Errorf("unsupported type: %s", t.Kind().String())
220 | default:
221 | }
222 | return s, nil
223 | }
224 |
--------------------------------------------------------------------------------
/protocol/schema_validate.go:
--------------------------------------------------------------------------------
1 | package protocol
2 |
import (
	"encoding/json"
	"errors"
	"fmt"
	"math"
	"reflect"
	"strconv"

	"github.com/ThinkInAIXYZ/go-mcp/pkg"
)
12 |
13 | func VerifyAndUnmarshal(content json.RawMessage, v any) error {
14 | if len(content) == 0 {
15 | return fmt.Errorf("request arguments is empty")
16 | }
17 |
18 | t := reflect.TypeOf(v)
19 | for t.Kind() != reflect.Struct {
20 | if t.Kind() != reflect.Ptr {
21 | return fmt.Errorf("invalid type %v, plz use func `pkg.JSONUnmarshal` instead", t)
22 | }
23 | t = t.Elem()
24 | }
25 |
26 | typeUID := getTypeUUID(t)
27 | schema, ok := schemaCache.Load(typeUID)
28 | if !ok {
29 | return fmt.Errorf("schema has not been generated,unable to verify: plz use func `pkg.JSONUnmarshal` instead")
30 | }
31 |
32 | return verifySchemaAndUnmarshal(Property{
33 | Type: ObjectT,
34 | Properties: schema.Properties,
35 | Required: schema.Required,
36 | }, content, v)
37 | }
38 |
39 | func verifySchemaAndUnmarshal(schema Property, content []byte, v any) error {
40 | var data any
41 | err := pkg.JSONUnmarshal(content, &data)
42 | if err != nil {
43 | return err
44 | }
45 | if !validate(schema, data) {
46 | return errors.New("data validation failed against the provided schema")
47 | }
48 | return pkg.JSONUnmarshal(content, &v)
49 | }
50 |
51 | func validate(schema Property, data any) bool {
52 | switch schema.Type {
53 | case ObjectT:
54 | return validateObject(schema, data)
55 | case Array:
56 | return validateArray(schema, data)
57 | case String:
58 | str, ok := data.(string)
59 | if ok {
60 | return validateEnumProperty[string](str, schema.Enum, func(value string, enumValue string) bool {
61 | return value == enumValue
62 | })
63 | }
64 | return false
65 | case Number: // float64 and int
66 | if num, ok := data.(float64); ok {
67 | return validateEnumProperty[float64](num, schema.Enum, func(value float64, enumValue string) bool {
68 | if enumNum, err := strconv.ParseFloat(enumValue, 64); err == nil && value == enumNum {
69 | return true
70 | }
71 | return false
72 | })
73 | }
74 | if num, ok := data.(int); ok {
75 | return validateEnumProperty[int](num, schema.Enum, func(value int, enumValue string) bool {
76 | if enumNum, err := strconv.Atoi(enumValue); err == nil && value == enumNum {
77 | return true
78 | }
79 | return false
80 | })
81 | }
82 | return false
83 | case Boolean:
84 | _, ok := data.(bool)
85 | return ok
86 | case Integer:
87 | // Golang unmarshals all numbers as float64, so we need to check if the float64 is an integer
88 | if num, ok := data.(float64); ok {
89 | if num == float64(int64(num)) {
90 | return validateEnumProperty[float64](num, schema.Enum, func(value float64, enumValue string) bool {
91 | if enumNum, err := strconv.ParseFloat(enumValue, 64); err == nil && value == enumNum {
92 | return true
93 | }
94 | return false
95 | })
96 | }
97 | return false
98 | }
99 |
100 | if num, ok := data.(int); ok {
101 | return validateEnumProperty[int](num, schema.Enum, func(value int, enumValue string) bool {
102 | if enumNum, err := strconv.Atoi(enumValue); err == nil && value == enumNum {
103 | return true
104 | }
105 | return false
106 | })
107 | }
108 |
109 | if num, ok := data.(int64); ok {
110 | return validateEnumProperty[int64](num, schema.Enum, func(value int64, enumValue string) bool {
111 | if enumNum, err := strconv.Atoi(enumValue); err == nil && value == int64(enumNum) {
112 | return true
113 | }
114 | return false
115 | })
116 | }
117 | return false
118 | case Null:
119 | return data == nil
120 | default:
121 | return false
122 | }
123 | }
124 |
125 | func validateObject(schema Property, data any) bool {
126 | dataMap, ok := data.(map[string]any)
127 | if !ok {
128 | return false
129 | }
130 | for _, field := range schema.Required {
131 | if _, exists := dataMap[field]; !exists {
132 | return false
133 | }
134 | }
135 | for key, valueSchema := range schema.Properties {
136 | value, exists := dataMap[key]
137 | if exists && !validate(*valueSchema, value) {
138 | return false
139 | }
140 | }
141 | return true
142 | }
143 |
144 | func validateArray(schema Property, data any) bool {
145 | dataArray, ok := data.([]any)
146 | if !ok {
147 | return false
148 | }
149 | for _, item := range dataArray {
150 | if !validate(*schema.Items, item) {
151 | return false
152 | }
153 | }
154 | return true
155 | }
156 |
// validateEnumProperty reports whether data matches at least one enum value
// according to compareFunc. An empty enum means "no restriction" and always
// succeeds, without calling compareFunc.
func validateEnumProperty[T any](data T, enum []string, compareFunc func(T, string) bool) bool {
	if len(enum) == 0 {
		return true
	}
	for i := range enum {
		if compareFunc(data, enum[i]) {
			return true
		}
	}
	return false
}
165 |
--------------------------------------------------------------------------------
/protocol/types.go:
--------------------------------------------------------------------------------
1 | package protocol
2 |
// Version is the protocol version this package speaks by default.
const Version = "2025-03-26"

// SupportedVersion is the set of protocol versions accepted by this package.
var SupportedVersion = map[string]struct{}{
	"2024-11-05": {},
	"2025-03-26": {},
}

// Method represents the JSON-RPC method name
type Method string

// Core methods
const (
	Ping                    Method = "ping"
	Initialize              Method = "initialize"
	NotificationInitialized Method = "notifications/initialized"
)

// Root related methods
const (
	RootsList                    Method = "roots/list"
	NotificationRootsListChanged Method = "notifications/roots/list_changed"
)

// Resource related methods
const (
	ResourcesList                    Method = "resources/list"
	ResourceListTemplates            Method = "resources/templates/list"
	ResourcesRead                    Method = "resources/read"
	ResourcesSubscribe               Method = "resources/subscribe"
	ResourcesUnsubscribe             Method = "resources/unsubscribe"
	NotificationResourcesListChanged Method = "notifications/resources/list_changed"
	NotificationResourcesUpdated     Method = "notifications/resources/updated"
)

// Tool related methods
const (
	ToolsList                    Method = "tools/list"
	ToolsCall                    Method = "tools/call"
	NotificationToolsListChanged Method = "notifications/tools/list_changed"
)

// Prompt related methods
const (
	PromptsList                    Method = "prompts/list"
	PromptsGet                     Method = "prompts/get"
	NotificationPromptsListChanged Method = "notifications/prompts/list_changed"
)

// Sampling, logging and completion methods
const (
	SamplingCreateMessage  Method = "sampling/createMessage"
	LoggingSetLevel        Method = "logging/setLevel"
	NotificationLogMessage Method = "notifications/message"
	CompletionComplete     Method = "completion/complete"
)

// progress related methods
const (
	NotificationProgress  Method = "notifications/progress"
	NotificationCancelled Method = "notifications/cancelled" // nolint:misspell
)
56 |
// Role represents the sender or recipient of messages and data in a conversation.
type Role string

// Roles defined by this package.
const (
	RoleUser      Role = "user"
	RoleAssistant Role = "assistant"
)
64 |
65 | type ClientRequest interface{}
66 |
67 | var (
68 | _ ClientRequest = &InitializeRequest{}
69 | _ ClientRequest = &PingRequest{}
70 | _ ClientRequest = &ListPromptsRequest{}
71 | _ ClientRequest = &GetPromptRequest{}
72 | _ ClientRequest = &ListResourcesRequest{}
73 | _ ClientRequest = &ReadResourceRequest{}
74 | _ ClientRequest = &ListResourceTemplatesRequest{}
75 | _ ClientRequest = &SubscribeRequest{}
76 | _ ClientRequest = &UnsubscribeRequest{}
77 | _ ClientRequest = &ListToolsRequest{}
78 | _ ClientRequest = &CallToolRequest{}
79 | _ ClientRequest = &CompleteRequest{}
80 | _ ClientRequest = &SetLoggingLevelRequest{}
81 | )
82 |
83 | type ClientResponse interface{}
84 |
85 | var (
86 | _ ClientResponse = &PingResult{}
87 | _ ClientResponse = &ListToolsResult{}
88 | _ ClientResponse = &CreateMessageResult{}
89 | )
90 |
91 | type ClientNotify interface{}
92 |
93 | var (
94 | _ ClientNotify = &InitializedNotification{}
95 | _ ClientNotify = &CancelledNotification{}
96 | _ ClientNotify = &ProgressNotification{}
97 | _ ClientNotify = &RootsListChangedNotification{}
98 | )
99 |
100 | type ServerRequest interface{}
101 |
102 | var (
103 | _ ServerRequest = &PingRequest{}
104 | _ ServerRequest = &ListRootsRequest{}
105 | _ ServerRequest = &CreateMessageRequest{}
106 | )
107 |
108 | type ServerResponse interface{}
109 |
110 | var (
111 | _ ServerResponse = &InitializeResult{}
112 | _ ServerResponse = &PingResult{}
113 | _ ServerResponse = &ListPromptsResult{}
114 | _ ServerResponse = &GetPromptResult{}
115 | _ ServerResponse = &ListResourcesResult{}
116 | _ ServerResponse = &ReadResourceResult{}
117 | _ ServerResponse = &ListResourceTemplatesResult{}
118 | _ ServerResponse = &SubscribeResult{}
119 | _ ServerResponse = &UnsubscribeResult{}
120 | _ ServerResponse = &ListToolsResult{}
121 | _ ServerResponse = &CallToolResult{}
122 | _ ServerResponse = &CompleteResult{}
123 | _ ServerResponse = &SetLoggingLevelResult{}
124 | )
125 |
126 | type ServerNotify interface{}
127 |
128 | var (
129 | _ ServerNotify = &CancelledNotification{}
130 | _ ServerNotify = &ProgressNotification{}
131 | _ ServerNotify = &ToolListChangedNotification{}
132 | _ ServerNotify = &PromptListChangedNotification{}
133 | _ ServerNotify = &ResourceListChangedNotification{}
134 | _ ServerNotify = &ResourceUpdatedNotification{}
135 | _ ServerNotify = &LogMessageNotification{}
136 | )
137 |
--------------------------------------------------------------------------------
/server/call.go:
--------------------------------------------------------------------------------
1 | package server
2 |
3 | import (
4 | "context"
5 | "encoding/json"
6 | "fmt"
7 | "strconv"
8 |
9 | "github.com/ThinkInAIXYZ/go-mcp/pkg"
10 | "github.com/ThinkInAIXYZ/go-mcp/protocol"
11 | "github.com/ThinkInAIXYZ/go-mcp/server/session"
12 | )
13 |
14 | func (server *Server) Ping(ctx context.Context, request *protocol.PingRequest) (*protocol.PingResult, error) {
15 | sessionID, err := GetSessionIDFromCtx(ctx)
16 | if err != nil {
17 | return nil, err
18 | }
19 |
20 | response, err := server.callClient(ctx, sessionID, protocol.Ping, request)
21 | if err != nil {
22 | return nil, err
23 | }
24 |
25 | var result protocol.PingResult
26 | if err = pkg.JSONUnmarshal(response, &result); err != nil {
27 | return nil, fmt.Errorf("failed to unmarshal response: %w", err)
28 | }
29 | return &result, nil
30 | }
31 |
32 | func (server *Server) Sampling(ctx context.Context, request *protocol.CreateMessageRequest) (*protocol.CreateMessageResult, error) {
33 | sessionID, err := GetSessionIDFromCtx(ctx)
34 | if err != nil {
35 | return nil, err
36 | }
37 |
38 | s, ok := server.sessionManager.GetSession(sessionID)
39 | if !ok {
40 | return nil, pkg.ErrLackSession
41 | }
42 |
43 | if s.GetClientCapabilities() == nil || s.GetClientCapabilities().Sampling == nil {
44 | return nil, pkg.ErrClientNotSupport
45 | }
46 |
47 | response, err := server.callClient(ctx, sessionID, protocol.SamplingCreateMessage, request)
48 | if err != nil {
49 | return nil, err
50 | }
51 |
52 | var result protocol.CreateMessageResult
53 | if err = pkg.JSONUnmarshal(response, &result); err != nil {
54 | return nil, fmt.Errorf("failed to unmarshal response: %w", err)
55 | }
56 | return &result, nil
57 | }
58 |
59 | func (server *Server) SendProgressNotification(ctx context.Context, notify *protocol.ProgressNotification) error {
60 | progressToken, err := getProgressTokenFromCtx(ctx)
61 | if err != nil {
62 | return err
63 | }
64 | notify.ProgressToken = progressToken
65 |
66 | if err = server.sendMsgWithNotification(ctx, "", protocol.NotificationProgress, notify); err != nil {
67 | return err
68 | }
69 |
70 | return nil
71 | }
72 |
73 | func (server *Server) sendNotification4ToolListChanges(ctx context.Context) error {
74 | if server.capabilities.Tools == nil || !server.capabilities.Tools.ListChanged {
75 | return pkg.ErrServerNotSupport
76 | }
77 |
78 | var errList []error
79 | server.sessionManager.RangeSessions(func(sessionID string, _ *session.State) bool {
80 | if err := server.sendMsgWithNotification(ctx, sessionID, protocol.NotificationToolsListChanged, protocol.NewToolListChangedNotification()); err != nil {
81 | errList = append(errList, fmt.Errorf("sessionID=%s, err: %w", sessionID, err))
82 | }
83 | return true
84 | })
85 | return pkg.JoinErrors(errList)
86 | }
87 |
88 | func (server *Server) sendNotification4PromptListChanges(ctx context.Context) error {
89 | if server.capabilities.Prompts == nil || !server.capabilities.Prompts.ListChanged {
90 | return pkg.ErrServerNotSupport
91 | }
92 |
93 | var errList []error
94 | server.sessionManager.RangeSessions(func(sessionID string, _ *session.State) bool {
95 | if err := server.sendMsgWithNotification(ctx, sessionID, protocol.NotificationPromptsListChanged, protocol.NewPromptListChangedNotification()); err != nil {
96 | errList = append(errList, fmt.Errorf("sessionID=%s, err: %w", sessionID, err))
97 | }
98 | return true
99 | })
100 | return pkg.JoinErrors(errList)
101 | }
102 |
103 | func (server *Server) sendNotification4ResourceListChanges(ctx context.Context) error {
104 | if server.capabilities.Resources == nil || !server.capabilities.Resources.ListChanged {
105 | return pkg.ErrServerNotSupport
106 | }
107 |
108 | var errList []error
109 | server.sessionManager.RangeSessions(func(sessionID string, _ *session.State) bool {
110 | if err := server.sendMsgWithNotification(ctx, sessionID, protocol.NotificationResourcesListChanged,
111 | protocol.NewResourceListChangedNotification()); err != nil {
112 | errList = append(errList, fmt.Errorf("sessionID=%s, err: %w", sessionID, err))
113 | }
114 | return true
115 | })
116 | return pkg.JoinErrors(errList)
117 | }
118 |
119 | func (server *Server) SendNotification4ResourcesUpdated(ctx context.Context, notify *protocol.ResourceUpdatedNotification) error {
120 | if server.capabilities.Resources == nil || !server.capabilities.Resources.Subscribe {
121 | return pkg.ErrServerNotSupport
122 | }
123 |
124 | var errList []error
125 | server.sessionManager.RangeSessions(func(sessionID string, s *session.State) bool {
126 | if _, ok := s.GetSubscribedResources().Get(notify.URI); !ok {
127 | return true
128 | }
129 |
130 | if err := server.sendMsgWithNotification(ctx, sessionID, protocol.NotificationResourcesUpdated, notify); err != nil {
131 | errList = append(errList, fmt.Errorf("sessionID=%s, err: %w", sessionID, err))
132 | }
133 | return true
134 | })
135 | return pkg.JoinErrors(errList)
136 | }
137 |
138 | // Responsible for request and response assembly
139 | func (server *Server) callClient(ctx context.Context, sessionID string, method protocol.Method, params protocol.ServerRequest) (json.RawMessage, error) {
140 | session, ok := server.sessionManager.GetSession(sessionID)
141 | if !ok {
142 | return nil, fmt.Errorf("callClient: %w", pkg.ErrLackSession)
143 | }
144 |
145 | requestID := strconv.FormatInt(session.IncRequestID(), 10)
146 | respChan := make(chan *protocol.JSONRPCResponse, 1)
147 | session.GetServerReqID2respChan().Set(requestID, respChan)
148 | defer session.GetServerReqID2respChan().Remove(requestID)
149 |
150 | if err := server.sendMsgWithRequest(ctx, sessionID, requestID, method, params); err != nil {
151 | return nil, fmt.Errorf("callClient: %w", err)
152 | }
153 |
154 | select {
155 | case <-ctx.Done():
156 | return nil, ctx.Err()
157 | case response := <-respChan:
158 | if err := response.Error; err != nil {
159 | return nil, pkg.NewResponseError(err.Code, err.Message, err.Data)
160 | }
161 | return response.RawResult, nil
162 | }
163 | }
164 |
--------------------------------------------------------------------------------
/server/context.go:
--------------------------------------------------------------------------------
1 | package server
2 |
3 | import (
4 | "context"
5 | "errors"
6 | )
7 |
8 | type sessionIDKey struct{}
9 |
10 | func setSessionIDToCtx(ctx context.Context, sessionID string) context.Context {
11 | return context.WithValue(ctx, sessionIDKey{}, sessionID)
12 | }
13 |
14 | func GetSessionIDFromCtx(ctx context.Context) (string, error) {
15 | sessionID := ctx.Value(sessionIDKey{})
16 | if sessionID == nil {
17 | return "", errors.New("no session id found")
18 | }
19 | return sessionID.(string), nil
20 | }
21 |
22 | type sendChanKey struct{}
23 |
24 | func setSendChanToCtx(ctx context.Context, sendCh chan<- []byte) context.Context {
25 | return context.WithValue(ctx, sendChanKey{}, sendCh)
26 | }
27 |
28 | func getSendChanFromCtx(ctx context.Context) (chan<- []byte, error) {
29 | ch := ctx.Value(sendChanKey{})
30 | if ch == nil {
31 | return nil, errors.New("no send chan found")
32 | }
33 | return ch.(chan<- []byte), nil
34 | }
35 |
36 | type progressTokenKey struct{}
37 |
38 | func setProgressTokenToCtx(ctx context.Context, progressToken interface{}) context.Context {
39 | return context.WithValue(ctx, progressTokenKey{}, progressToken)
40 | }
41 |
42 | func getProgressTokenFromCtx(ctx context.Context) (interface{}, error) {
43 | progressToken := ctx.Value(progressTokenKey{})
44 | if progressToken == nil {
45 | return "", errors.New("no progress token found")
46 | }
47 | return progressToken, nil
48 | }
49 |
--------------------------------------------------------------------------------
/server/receive.go:
--------------------------------------------------------------------------------
1 | package server
2 |
3 | import (
4 | "context"
5 | "encoding/json"
6 | "errors"
7 | "fmt"
8 |
9 | "github.com/tidwall/gjson"
10 |
11 | "github.com/ThinkInAIXYZ/go-mcp/pkg"
12 | "github.com/ThinkInAIXYZ/go-mcp/protocol"
13 | )
14 |
15 | func (server *Server) receive(ctx context.Context, sessionID string, msg []byte) (<-chan []byte, error) {
16 | if sessionID != "" && !server.sessionManager.IsActiveSession(sessionID) {
17 | if server.sessionManager.IsClosedSession(sessionID) {
18 | return nil, pkg.ErrSessionClosed
19 | }
20 | return nil, pkg.ErrLackSession
21 | }
22 |
23 | if !gjson.GetBytes(msg, "id").Exists() {
24 | notify := &protocol.JSONRPCNotification{}
25 | if err := pkg.JSONUnmarshal(msg, ¬ify); err != nil {
26 | return nil, err
27 | }
28 | if err := server.receiveNotify(sessionID, notify); err != nil {
29 | notify.RawParams = nil // simplified log
30 | server.logger.Errorf("receive notify:%+v error: %s", notify, err.Error())
31 | return nil, err
32 | }
33 | return nil, nil
34 | }
35 |
36 | // case request or response
37 | if !gjson.GetBytes(msg, "method").Exists() {
38 | resp := &protocol.JSONRPCResponse{}
39 | if err := pkg.JSONUnmarshal(msg, &resp); err != nil {
40 | return nil, err
41 | }
42 |
43 | if err := server.receiveResponse(sessionID, resp); err != nil {
44 | resp.RawResult = nil // simplified log
45 | server.logger.Errorf("receive response:%+v error: %s", resp, err.Error())
46 | return nil, err
47 | }
48 | return nil, nil
49 | }
50 |
51 | req := &protocol.JSONRPCRequest{}
52 | if err := pkg.JSONUnmarshal(msg, &req); err != nil {
53 | return nil, err
54 | }
55 | if !req.IsValid() {
56 | return nil, pkg.ErrRequestInvalid
57 | }
58 |
59 | // if sessionID != "" && req.Method != protocol.Initialize && req.Method != protocol.Ping {
60 | // if s, ok := server.sessionManager.GetSession(sessionID); !ok {
61 | // return nil, pkg.ErrLackSession
62 | // } else if !s.GetReady() {
63 | // return nil, pkg.ErrSessionHasNotInitialized
64 | // }
65 | // }
66 |
67 | server.inFlyRequest.Add(1)
68 |
69 | if server.inShutdown.Load() {
70 | server.inFlyRequest.Done()
71 | return nil, errors.New("server already shutdown")
72 | }
73 |
74 | ch := make(chan []byte, 5)
75 | go func(ctx context.Context) {
76 | defer pkg.Recover()
77 | defer server.inFlyRequest.Done()
78 | defer close(ch)
79 |
80 | if s, ok := server.sessionManager.GetSession(sessionID); ok && req.Method != protocol.Initialize {
81 | var cancel context.CancelFunc
82 | ctx, cancel = context.WithCancel(ctx)
83 | requestID := fmt.Sprint(req.ID)
84 | s.GetClientReqID2cancelFunc().Set(requestID, cancel)
85 | defer s.GetClientReqID2cancelFunc().Remove(requestID)
86 | }
87 |
88 | if r := gjson.GetBytes(req.RawParams, fmt.Sprintf("_meta.%s", protocol.ProgressTokenKey)); r.Exists() {
89 | ctx = setProgressTokenToCtx(ctx, r.Value())
90 | }
91 |
92 | ctx = setSendChanToCtx(ctx, ch)
93 |
94 | resp := server.receiveRequest(ctx, sessionID, req)
95 | if errors.Is(ctx.Err(), context.Canceled) {
96 | return
97 | }
98 | message, err := json.Marshal(resp)
99 | if err != nil {
100 | server.logger.Errorf("receive json marshal response:%+v error: %s", resp, err.Error())
101 | return
102 | }
103 | ch <- message
104 | }(pkg.NewCancelShieldContext(ctx))
105 | return ch, nil
106 | }
107 |
108 | func (server *Server) receiveRequest(ctx context.Context, sessionID string, request *protocol.JSONRPCRequest) *protocol.JSONRPCResponse {
109 | if sessionID != "" {
110 | ctx = setSessionIDToCtx(ctx, sessionID)
111 | }
112 |
113 | if request.Method != protocol.Ping {
114 | server.sessionManager.UpdateSessionLastActiveAt(sessionID)
115 | }
116 |
117 | var (
118 | result protocol.ServerResponse
119 | err error
120 | )
121 |
122 | switch request.Method {
123 | case protocol.Ping:
124 | result, err = server.handleRequestWithPing()
125 | case protocol.Initialize:
126 | result, err = server.handleRequestWithInitialize(ctx, sessionID, request.RawParams)
127 | case protocol.PromptsList:
128 | result, err = server.handleRequestWithListPrompts(request.RawParams)
129 | case protocol.PromptsGet:
130 | result, err = server.handleRequestWithGetPrompt(ctx, request.RawParams)
131 | case protocol.ResourcesList:
132 | result, err = server.handleRequestWithListResources(request.RawParams)
133 | case protocol.ResourceListTemplates:
134 | result, err = server.handleRequestWithListResourceTemplates(request.RawParams)
135 | case protocol.ResourcesRead:
136 | result, err = server.handleRequestWithReadResource(ctx, request.RawParams)
137 | case protocol.ResourcesSubscribe:
138 | result, err = server.handleRequestWithSubscribeResourceChange(sessionID, request.RawParams)
139 | case protocol.ResourcesUnsubscribe:
140 | result, err = server.handleRequestWithUnSubscribeResourceChange(sessionID, request.RawParams)
141 | case protocol.ToolsList:
142 | result, err = server.handleRequestWithListTools(request.RawParams)
143 | case protocol.ToolsCall:
144 | result, err = server.handleRequestWithCallTool(ctx, request.RawParams)
145 | default:
146 | err = fmt.Errorf("%w: method=%s", pkg.ErrMethodNotSupport, request.Method)
147 | }
148 |
149 | if err != nil {
150 | var code int
151 | switch {
152 | case errors.Is(err, pkg.ErrMethodNotSupport):
153 | code = protocol.MethodNotFound
154 | case errors.Is(err, pkg.ErrRequestInvalid):
155 | code = protocol.InvalidRequest
156 | case errors.Is(err, pkg.ErrJSONUnmarshal):
157 | code = protocol.ParseError
158 | default:
159 | code = protocol.InternalError
160 | }
161 | return protocol.NewJSONRPCErrorResponse(request.ID, code, err.Error())
162 | }
163 | return protocol.NewJSONRPCSuccessResponse(request.ID, result)
164 | }
165 |
166 | func (server *Server) receiveNotify(sessionID string, notify *protocol.JSONRPCNotification) error {
167 | // if sessionID != "" {
168 | // if s, ok := server.sessionManager.GetSession(sessionID); !ok {
169 | // return pkg.ErrLackSession
170 | // } else if notify.Method != protocol.NotificationInitialized && !s.GetReady() {
171 | // return pkg.ErrSessionHasNotInitialized
172 | // }
173 | // }
174 |
175 | switch notify.Method {
176 | case protocol.NotificationInitialized:
177 | return server.handleNotifyWithInitialized(sessionID, notify.RawParams)
178 | case protocol.NotificationCancelled:
179 | return server.handleNotifyWithCancelled(sessionID, notify.RawParams)
180 | default:
181 | return fmt.Errorf("%w: method=%s", pkg.ErrMethodNotSupport, notify.Method)
182 | }
183 | }
184 |
185 | func (server *Server) receiveResponse(sessionID string, response *protocol.JSONRPCResponse) error {
186 | s, ok := server.sessionManager.GetSession(sessionID)
187 | if !ok {
188 | return pkg.ErrLackSession
189 | }
190 |
191 | respChan, ok := s.GetServerReqID2respChan().Get(fmt.Sprint(response.ID))
192 | if !ok {
193 | return fmt.Errorf("%w: sessionID=%+v, requestID=%+v", pkg.ErrLackResponseChan, sessionID, response.ID)
194 | }
195 |
196 | select {
197 | case respChan <- response:
198 | default:
199 | return fmt.Errorf("%w: sessionID=%+v, response=%+v", pkg.ErrDuplicateResponseReceived, sessionID, response)
200 | }
201 | return nil
202 | }
203 |
--------------------------------------------------------------------------------
/server/send.go:
--------------------------------------------------------------------------------
1 | package server
2 |
3 | import (
4 | "context"
5 | "encoding/json"
6 | "fmt"
7 |
8 | "github.com/ThinkInAIXYZ/go-mcp/protocol"
9 | )
10 |
11 | func (server *Server) sendMsgWithRequest(ctx context.Context, sessionID string, requestID protocol.RequestID,
12 | method protocol.Method, params protocol.ServerRequest,
13 | ) error { //nolint:whitespace
14 | if requestID == nil {
15 | return fmt.Errorf("requestID can't is nil")
16 | }
17 |
18 | req := protocol.NewJSONRPCRequest(requestID, method, params)
19 |
20 | message, err := json.Marshal(req)
21 | if err != nil {
22 | return err
23 | }
24 |
25 | if ch, err := getSendChanFromCtx(ctx); err == nil {
26 | ch <- message
27 | return nil
28 | }
29 |
30 | if err := server.transport.Send(ctx, sessionID, message); err != nil {
31 | return fmt.Errorf("sendRequest: transport send: %w", err)
32 | }
33 | return nil
34 | }
35 |
36 | func (server *Server) sendMsgWithNotification(ctx context.Context, sessionID string, method protocol.Method, params protocol.ServerNotify) error {
37 | notify := protocol.NewJSONRPCNotification(method, params)
38 |
39 | message, err := json.Marshal(notify)
40 | if err != nil {
41 | return err
42 | }
43 |
44 | if ch, err := getSendChanFromCtx(ctx); err == nil {
45 | ch <- message
46 | return nil
47 | }
48 |
49 | if err := server.transport.Send(ctx, sessionID, message); err != nil {
50 | return fmt.Errorf("sendNotification: transport send: %w", err)
51 | }
52 | return nil
53 | }
54 |
--------------------------------------------------------------------------------
/server/session/manager.go:
--------------------------------------------------------------------------------
1 | package session
2 |
3 | import (
4 | "context"
5 | "time"
6 |
7 | "github.com/ThinkInAIXYZ/go-mcp/pkg"
8 | )
9 |
10 | type Manager struct {
11 | activeSessions pkg.SyncMap[*State]
12 | closedSessions pkg.SyncMap[struct{}]
13 |
14 | stopHeartbeat chan struct{}
15 |
16 | genSessionID func(ctx context.Context) string
17 |
18 | logger pkg.Logger
19 |
20 | detection func(ctx context.Context, sessionID string) error
21 | maxIdleTime time.Duration
22 | }
23 |
24 | func NewManager(detection func(ctx context.Context, sessionID string) error, genSessionID func(ctx context.Context) string) *Manager {
25 | return &Manager{
26 | genSessionID: genSessionID,
27 | detection: detection,
28 | stopHeartbeat: make(chan struct{}),
29 | logger: pkg.DefaultLogger,
30 | }
31 | }
32 |
33 | func (m *Manager) SetMaxIdleTime(d time.Duration) {
34 | m.maxIdleTime = d
35 | }
36 |
37 | func (m *Manager) SetLogger(logger pkg.Logger) {
38 | m.logger = logger
39 | }
40 |
41 | func (m *Manager) CreateSession(ctx context.Context) string {
42 | sessionID := m.genSessionID(ctx)
43 | state := NewState()
44 | m.activeSessions.Store(sessionID, state)
45 | return sessionID
46 | }
47 |
48 | func (m *Manager) IsActiveSession(sessionID string) bool {
49 | _, has := m.activeSessions.Load(sessionID)
50 | return has
51 | }
52 |
53 | func (m *Manager) IsClosedSession(sessionID string) bool {
54 | _, has := m.closedSessions.Load(sessionID)
55 | return has
56 | }
57 |
58 | func (m *Manager) GetSession(sessionID string) (*State, bool) {
59 | if sessionID == "" {
60 | return nil, false
61 | }
62 | state, has := m.activeSessions.Load(sessionID)
63 | if !has {
64 | return nil, false
65 | }
66 | return state, true
67 | }
68 |
69 | func (m *Manager) OpenMessageQueueForSend(sessionID string) error {
70 | state, has := m.GetSession(sessionID)
71 | if !has {
72 | return pkg.ErrLackSession
73 | }
74 | state.openMessageQueueForSend()
75 | return nil
76 | }
77 |
78 | func (m *Manager) EnqueueMessageForSend(ctx context.Context, sessionID string, message []byte) error {
79 | state, has := m.GetSession(sessionID)
80 | if !has {
81 | return pkg.ErrLackSession
82 | }
83 | return state.enqueueMessage(ctx, message)
84 | }
85 |
86 | func (m *Manager) DequeueMessageForSend(ctx context.Context, sessionID string) ([]byte, error) {
87 | state, has := m.GetSession(sessionID)
88 | if !has {
89 | return nil, pkg.ErrLackSession
90 | }
91 | return state.dequeueMessage(ctx)
92 | }
93 |
94 | func (m *Manager) UpdateSessionLastActiveAt(sessionID string) {
95 | state, ok := m.activeSessions.Load(sessionID)
96 | if !ok {
97 | return
98 | }
99 | state.updateLastActiveAt()
100 | }
101 |
102 | func (m *Manager) CloseSession(sessionID string) {
103 | state, ok := m.activeSessions.LoadAndDelete(sessionID)
104 | if !ok {
105 | return
106 | }
107 | state.Close()
108 | m.closedSessions.Store(sessionID, struct{}{})
109 | }
110 |
111 | func (m *Manager) CloseAllSessions() {
112 | m.activeSessions.Range(func(sessionID string, _ *State) bool {
113 | // Here we load the session again to prevent concurrency conflicts with CloseSession, which may cause repeated close chan
114 | m.CloseSession(sessionID)
115 | return true
116 | })
117 | }
118 |
119 | func (m *Manager) StartHeartbeatAndCleanInvalidSessions() {
120 | ticker := time.NewTicker(time.Minute)
121 | defer ticker.Stop()
122 |
123 | for {
124 | select {
125 | case <-m.stopHeartbeat:
126 | return
127 | case <-ticker.C:
128 | now := time.Now()
129 | m.activeSessions.Range(func(sessionID string, state *State) bool {
130 | if m.maxIdleTime != 0 && now.Sub(state.lastActiveAt) > m.maxIdleTime {
131 | m.logger.Infof("session expire, session id: %v", sessionID)
132 | m.CloseSession(sessionID)
133 | return true
134 | }
135 |
136 | var err error
137 | for i := 0; i < 3; i++ {
138 | if err = m.detection(context.Background(), sessionID); err == nil {
139 | return true
140 | }
141 | }
142 | m.logger.Infof("session detection fail, session id: %v, fail reason: %+v", sessionID, err)
143 | m.CloseSession(sessionID)
144 | return true
145 | })
146 | }
147 | }
148 | }
149 |
150 | func (m *Manager) StopHeartbeat() {
151 | close(m.stopHeartbeat)
152 | }
153 |
154 | func (m *Manager) RangeSessions(f func(sessionID string, state *State) bool) {
155 | m.activeSessions.Range(f)
156 | }
157 |
158 | func (m *Manager) IsEmpty() bool {
159 | isEmpty := true
160 | m.activeSessions.Range(func(string, *State) bool {
161 | isEmpty = false
162 | return false
163 | })
164 | return isEmpty
165 | }
166 |
--------------------------------------------------------------------------------
/server/session/state.go:
--------------------------------------------------------------------------------
1 | package session
2 |
3 | import (
4 | "context"
5 | "errors"
6 | "sync"
7 | "sync/atomic"
8 | "time"
9 |
10 | cmap "github.com/orcaman/concurrent-map/v2"
11 |
12 | "github.com/ThinkInAIXYZ/go-mcp/pkg"
13 | "github.com/ThinkInAIXYZ/go-mcp/protocol"
14 | )
15 |
16 | var ErrQueueNotOpened = errors.New("queue has not been opened")
17 |
18 | type State struct {
19 | lastActiveAt time.Time
20 |
21 | mu sync.RWMutex
22 | sendChan chan []byte
23 |
24 | requestID int64
25 |
26 | serverReqID2respChan cmap.ConcurrentMap[string, chan *protocol.JSONRPCResponse]
27 |
28 | clientReqID2cancelFunc cmap.ConcurrentMap[string, context.CancelFunc]
29 |
30 | // cache client initialize request info
31 | clientInfo *protocol.Implementation
32 | clientCapabilities *protocol.ClientCapabilities
33 |
34 | // subscribed resources
35 | subscribedResources cmap.ConcurrentMap[string, struct{}]
36 |
37 | receivedInitRequest *pkg.AtomicBool
38 | ready *pkg.AtomicBool
39 | closed *pkg.AtomicBool
40 | }
41 |
42 | func NewState() *State {
43 | return &State{
44 | lastActiveAt: time.Now(),
45 | serverReqID2respChan: cmap.New[chan *protocol.JSONRPCResponse](),
46 | clientReqID2cancelFunc: cmap.New[context.CancelFunc](),
47 | subscribedResources: cmap.New[struct{}](),
48 | receivedInitRequest: pkg.NewAtomicBool(),
49 | ready: pkg.NewAtomicBool(),
50 | closed: pkg.NewAtomicBool(),
51 | }
52 | }
53 |
54 | func (s *State) SetClientInfo(ClientInfo *protocol.Implementation, ClientCapabilities *protocol.ClientCapabilities) {
55 | s.clientInfo = ClientInfo
56 | s.clientCapabilities = ClientCapabilities
57 | }
58 |
59 | func (s *State) GetClientCapabilities() *protocol.ClientCapabilities {
60 | return s.clientCapabilities
61 | }
62 |
63 | func (s *State) SetReceivedInitRequest() {
64 | s.receivedInitRequest.Store(true)
65 | }
66 |
67 | func (s *State) GetReceivedInitRequest() bool {
68 | return s.receivedInitRequest.Load()
69 | }
70 |
71 | func (s *State) SetReady() {
72 | s.ready.Store(true)
73 | }
74 |
75 | func (s *State) GetReady() bool {
76 | return s.ready.Load()
77 | }
78 |
79 | func (s *State) IncRequestID() int64 {
80 | return atomic.AddInt64(&s.requestID, 1)
81 | }
82 |
83 | func (s *State) GetServerReqID2respChan() cmap.ConcurrentMap[string, chan *protocol.JSONRPCResponse] {
84 | return s.serverReqID2respChan
85 | }
86 |
87 | func (s *State) GetClientReqID2cancelFunc() cmap.ConcurrentMap[string, context.CancelFunc] {
88 | return s.clientReqID2cancelFunc
89 | }
90 |
91 | func (s *State) GetSubscribedResources() cmap.ConcurrentMap[string, struct{}] {
92 | return s.subscribedResources
93 | }
94 |
95 | func (s *State) Close() {
96 | s.mu.Lock()
97 | defer s.mu.Unlock()
98 |
99 | s.closed.Store(true)
100 |
101 | if s.sendChan != nil {
102 | close(s.sendChan)
103 | }
104 | }
105 |
106 | func (s *State) updateLastActiveAt() {
107 | s.lastActiveAt = time.Now()
108 | }
109 |
110 | func (s *State) openMessageQueueForSend() {
111 | s.mu.Lock()
112 | defer s.mu.Unlock()
113 |
114 | if s.sendChan == nil {
115 | s.sendChan = make(chan []byte, 64)
116 | }
117 | }
118 |
119 | func (s *State) enqueueMessage(ctx context.Context, message []byte) error {
120 | s.mu.RLock()
121 | defer s.mu.RUnlock()
122 |
123 | if s.closed.Load() {
124 | return errors.New("session already closed")
125 | }
126 |
127 | if s.sendChan == nil {
128 | return ErrQueueNotOpened
129 | }
130 |
131 | select {
132 | case s.sendChan <- message:
133 | return nil
134 | case <-ctx.Done():
135 | return ctx.Err()
136 | }
137 | }
138 |
139 | func (s *State) dequeueMessage(ctx context.Context) ([]byte, error) {
140 | s.mu.RLock()
141 | if s.sendChan == nil {
142 | s.mu.RUnlock()
143 | return nil, ErrQueueNotOpened
144 | }
145 | s.mu.RUnlock()
146 |
147 | select {
148 | case <-ctx.Done():
149 | return nil, ctx.Err()
150 | case msg, ok := <-s.sendChan:
151 | if msg == nil && !ok {
152 | // There are no new messages and the chan has been closed, indicating that the request may need to be terminated.
153 | return nil, pkg.ErrSendEOF
154 | }
155 | return msg, nil
156 | }
157 | }
158 |
--------------------------------------------------------------------------------
/testdata/mock_block_server.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "fmt"
5 | "os"
6 | )
7 |
// main blocks until one byte can be read from stdin, then exits. A read error
// (e.g. stdin closed) is printed before exiting. Presumably used by tests as a
// child process that stays alive until its stdin is written to or closed.
func main() {
	buf := make([]byte, 1)
	_, err := os.Stdin.Read(buf)
	if err == nil {
		return
	}
	fmt.Println(err)
}
14 |
--------------------------------------------------------------------------------
/tests/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ThinkInAIXYZ/go-mcp/c7a0eb1f7e4a288220d3a3375006802558f473a2/tests/.DS_Store
--------------------------------------------------------------------------------
/tests/sse_test.go:
--------------------------------------------------------------------------------
1 | package tests
2 |
3 | import (
4 | "fmt"
5 | "net"
6 | "os"
7 | "os/exec"
8 | "strconv"
9 | "testing"
10 |
11 | "github.com/ThinkInAIXYZ/go-mcp/transport"
12 | )
13 |
14 | func TestSSE(t *testing.T) {
15 | port, err := getAvailablePort()
16 | if err != nil {
17 | t.Fatalf("Failed to get available port: %v", err)
18 | }
19 |
20 | transportClient, err := transport.NewSSEClientTransport(fmt.Sprintf("http://127.0.0.1:%d/sse", port))
21 | if err != nil {
22 | t.Fatalf("Failed to create transport client: %v", err)
23 | }
24 |
25 | test(t, func() error { return runSSEServer(port) }, transportClient, transport.Stateful)
26 | }
27 |
// getAvailablePort returns a TCP port on 127.0.0.1 that was free when called.
//
// It binds an ephemeral port, reads the assigned number, and releases the
// listener. This is best-effort: another process may take the port before the
// caller binds it. Listen failures are wrapped with %w so callers can inspect
// the cause with errors.Is/As.
func getAvailablePort() (int, error) {
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		return 0, fmt.Errorf("failed to get available port: %w", err)
	}
	defer func() {
		if closeErr := ln.Close(); closeErr != nil {
			fmt.Println(closeErr)
		}
	}()

	return ln.Addr().(*net.TCPAddr).Port, nil
}
43 |
44 | func runSSEServer(port int) error {
45 | mockServerTrPath, err := compileMockStdioServerTr()
46 | if err != nil {
47 | return err
48 | }
49 | fmt.Println(mockServerTrPath)
50 |
51 | defer func(name string) {
52 | if err := os.Remove(name); err != nil {
53 | fmt.Printf("failed to remove mock server: %v\n", err)
54 | }
55 | }(mockServerTrPath)
56 |
57 | return exec.Command(mockServerTrPath, "-transport", "sse", "-port", strconv.Itoa(port)).Run()
58 | }
59 |
--------------------------------------------------------------------------------
/tests/stdio_test.go:
--------------------------------------------------------------------------------
1 | package tests
2 |
3 | import (
4 | "fmt"
5 | "os"
6 | "testing"
7 |
8 | "github.com/ThinkInAIXYZ/go-mcp/transport"
9 | )
10 |
11 | func TestStdio(t *testing.T) {
12 | mockServerTrPath, err := compileMockStdioServerTr()
13 | if err != nil {
14 | t.Fatal(err)
15 | }
16 | defer func(name string) {
17 | if err = os.Remove(name); err != nil {
18 | fmt.Printf("Failed to remove mock server: %v\n", err)
19 | }
20 | }(mockServerTrPath)
21 |
22 | fmt.Println(mockServerTrPath)
23 | transportClient, err := transport.NewStdioClientTransport(mockServerTrPath, []string{"-transport", "stdio"})
24 | if err != nil {
25 | t.Fatalf("Failed to create transport client: %v", err)
26 | }
27 |
28 | test(t, func() error {
29 | <-make(chan error)
30 | return nil
31 | }, transportClient, transport.Stateful)
32 | }
33 |
--------------------------------------------------------------------------------
/tests/streamable_http_test.go:
--------------------------------------------------------------------------------
1 | package tests
2 |
3 | import (
4 | "fmt"
5 | "os"
6 | "os/exec"
7 | "strconv"
8 | "testing"
9 |
10 | "github.com/ThinkInAIXYZ/go-mcp/transport"
11 | )
12 |
13 | func TestStreamableHTTPWithStateless(t *testing.T) {
14 | port, err := getAvailablePort()
15 | if err != nil {
16 | t.Fatalf("Failed to get available port: %v", err)
17 | }
18 |
19 | transportClient, err := transport.NewStreamableHTTPClientTransport(fmt.Sprintf("http://127.0.0.1:%d/mcp", port))
20 | if err != nil {
21 | t.Fatalf("Failed to create transport client: %v", err)
22 | }
23 |
24 | test(t, func() error { return runStreamableHTTPServer(port, transport.Stateless) }, transportClient, transport.Stateless)
25 | }
26 |
27 | func TestStreamableHTTPWithStateful(t *testing.T) {
28 | port, err := getAvailablePort()
29 | if err != nil {
30 | t.Fatalf("Failed to get available port: %v", err)
31 | }
32 |
33 | transportClient, err := transport.NewStreamableHTTPClientTransport(fmt.Sprintf("http://127.0.0.1:%d/mcp", port))
34 | if err != nil {
35 | t.Fatalf("Failed to create transport client: %v", err)
36 | }
37 |
38 | test(t, func() error { return runStreamableHTTPServer(port, transport.Stateful) }, transportClient, transport.Stateful)
39 | }
40 |
41 | func runStreamableHTTPServer(port int, stateful transport.StateMode) error {
42 | mockServerTrPath, err := compileMockStdioServerTr()
43 | if err != nil {
44 | return err
45 | }
46 | fmt.Println(mockServerTrPath)
47 |
48 | defer func(name string) {
49 | if err := os.Remove(name); err != nil {
50 | fmt.Printf("failed to remove mock server: %v\n", err)
51 | }
52 | }(mockServerTrPath)
53 |
54 | return exec.Command(mockServerTrPath, "-transport", "streamable_http", "-port", strconv.Itoa(port), "-state_mode", string(stateful)).Run()
55 | }
56 |
--------------------------------------------------------------------------------
/tests/utils.go:
--------------------------------------------------------------------------------
1 | package tests
2 |
3 | import (
4 | "context"
5 | "encoding/json"
6 | "fmt"
7 | "math/rand"
8 | "os"
9 | "os/exec"
10 | "path/filepath"
11 | "strconv"
12 | "testing"
13 | "time"
14 |
15 | "github.com/ThinkInAIXYZ/go-mcp/client"
16 | "github.com/ThinkInAIXYZ/go-mcp/protocol"
17 | "github.com/ThinkInAIXYZ/go-mcp/transport"
18 | )
19 |
20 | func test(t *testing.T, runServer func() error, transportClient transport.ClientTransport, mode transport.StateMode) {
21 | errCh := make(chan error, 1)
22 | go func() {
23 | errCh <- runServer()
24 | }()
25 |
26 | // Use select to handle potential errors
27 | select {
28 | case err := <-errCh:
29 | t.Fatalf("server.Run() failed: %v", err)
30 | case <-time.After(time.Second * 3):
31 | // Server started normally
32 | }
33 |
34 | // Create MCP client using transport
35 | mcpClient, err := client.NewClient(transportClient, client.WithClientInfo(&protocol.Implementation{
36 | Name: "Example MCP Client",
37 | Version: "1.0.0",
38 | }), client.WithSamplingHandler(&sampling{}))
39 | if err != nil {
40 | t.Fatalf("Failed to create MCP client: %v", err)
41 | }
42 | defer func() {
43 | if err = mcpClient.Close(); err != nil {
44 | t.Fatalf("Failed to close MCP client: %v", err)
45 | return
46 | }
47 | }()
48 |
49 | // List available tools
50 | toolsResult, err := mcpClient.ListTools(context.Background())
51 | if err != nil {
52 | t.Fatalf("Failed to list tools: %v", err)
53 | }
54 | bytes, _ := json.Marshal(toolsResult)
55 | fmt.Printf("Available tools: %s\n", bytes)
56 |
57 | callResult, err := mcpClient.CallTool(
58 | context.Background(),
59 | protocol.NewCallToolRequestWithRawArguments("current_time", json.RawMessage(`{"timezone": "UTC"}`)))
60 | if err != nil {
61 | t.Fatalf("Failed to call tool: %v", err)
62 | }
63 | bytes, _ = json.Marshal(callResult)
64 | fmt.Printf("Tool call result: %s\n", bytes)
65 |
66 | progressCh := make(chan *protocol.ProgressNotification)
67 | go func() {
68 | for progress := range progressCh {
69 | fmt.Printf("Progress: %+v\n", progress)
70 | }
71 | }()
72 | callResult, err = mcpClient.CallToolWithProgressChan(context.Background(),
73 | protocol.NewCallToolRequestWithRawArguments("generate_ppt", json.RawMessage(`{"ppt_description": "test"}`)), progressCh)
74 | if err != nil {
75 | t.Fatalf("Failed to call tool: %v", err)
76 | }
77 | bytes, _ = json.Marshal(callResult)
78 | fmt.Printf("Tool call result: %s\n", bytes)
79 |
80 | if mode == transport.Stateful {
81 | // if streamable_http transport, need wait streamable_http connection start
82 | time.Sleep(time.Second)
83 |
84 | callResult, err = mcpClient.CallTool(
85 | context.Background(),
86 | protocol.NewCallToolRequestWithRawArguments("delete_file", json.RawMessage(`{"file_name": "test_file.txt"}`)))
87 | if err != nil {
88 | t.Fatalf("Failed to call tool: %v", err)
89 | }
90 | bytes, _ = json.Marshal(callResult)
91 | fmt.Printf("Tool call result: %s\n", bytes)
92 | }
93 | }
94 |
95 | type sampling struct{}
96 |
97 | func (s *sampling) CreateMessage(_ context.Context, request *protocol.CreateMessageRequest) (*protocol.CreateMessageResult, error) {
98 | var lastUserMessages protocol.Content
99 | for _, message := range request.Messages {
100 | if message.Role == "user" {
101 | lastUserMessages = message.Content
102 | }
103 | }
104 |
105 | if lastUserMessages.GetType() != "text" {
106 | return nil, fmt.Errorf("expected 'text', got %s", lastUserMessages.GetType())
107 | }
108 |
109 | return &protocol.CreateMessageResult{
110 | Content: &protocol.TextContent{
111 | Annotated: protocol.Annotated{},
112 | Type: "text",
113 | Text: strconv.FormatBool(true),
114 | },
115 | Role: "assistant",
116 | Model: "stub-model",
117 | StopReason: "endTurn",
118 | }, nil
119 | }
120 |
// compileMockStdioServerTr builds ../examples/everything/main.go (relative to
// the test working directory) into a randomly named executable under the OS
// temp dir and returns its path. The caller must delete the file. The binary
// serves every transport variant, selected by its flags.
func compileMockStdioServerTr() (string, error) {
	rng := rand.New(rand.NewSource(time.Now().UnixNano()))
	outPath := filepath.Join(os.TempDir(), fmt.Sprintf("mock_server_tr_%d", rng.Int()))

	build := exec.Command("go", "build", "-o", outPath, "../examples/everything/main.go")
	output, err := build.CombinedOutput()
	if err != nil {
		return "", fmt.Errorf("compilation failed: %v\nOutput: %s", err, output)
	}
	return outPath, nil
}
134 |
--------------------------------------------------------------------------------
/transport/mock_client.go:
--------------------------------------------------------------------------------
1 | package transport
2 |
3 | import (
4 | "bufio"
5 | "bytes"
6 | "context"
7 | "errors"
8 | "fmt"
9 | "io"
10 |
11 | "github.com/ThinkInAIXYZ/go-mcp/pkg"
12 | )
13 |
// mockClientTransport is an in-memory ClientTransport used in tests. It writes
// newline-delimited messages to out and reads newline-delimited messages
// from in.
type mockClientTransport struct {
	// receiver gets every line read from in; set via SetReceiver.
	receiver clientReceiver
	// in is the stream peer messages are read from; closed by Close.
	in io.ReadCloser
	// out is the stream outgoing messages are written to.
	out io.Writer

	logger pkg.Logger

	// cancel stops the receive loop; set by Start.
	cancel context.CancelFunc
	// receiveShutDone is closed once the receive goroutine started by Start
	// has returned from startReceive.
	receiveShutDone chan struct{}
}
24 |
25 | func NewMockClientTransport(in io.ReadCloser, out io.Writer) ClientTransport {
26 | return &mockClientTransport{
27 | in: in,
28 | out: out,
29 | logger: pkg.DefaultLogger,
30 | receiveShutDone: make(chan struct{}),
31 | }
32 | }
33 |
34 | func (t *mockClientTransport) Start() error {
35 | ctx, cancel := context.WithCancel(context.Background())
36 | t.cancel = cancel
37 |
38 | go func() {
39 | defer pkg.Recover()
40 |
41 | t.startReceive(ctx)
42 |
43 | close(t.receiveShutDone)
44 | }()
45 |
46 | return nil
47 | }
48 |
49 | func (t *mockClientTransport) Send(_ context.Context, msg Message) error {
50 | if _, err := t.out.Write(append(msg, mcpMessageDelimiter)); err != nil {
51 | return fmt.Errorf("failed to write: %w", err)
52 | }
53 | return nil
54 | }
55 |
56 | func (t *mockClientTransport) SetReceiver(receiver clientReceiver) {
57 | t.receiver = receiver
58 | }
59 |
60 | func (t *mockClientTransport) Close() error {
61 | t.cancel()
62 |
63 | if err := t.in.Close(); err != nil {
64 | return fmt.Errorf("failed to close writer: %w", err)
65 | }
66 |
67 | <-t.receiveShutDone
68 |
69 | return nil
70 | }
71 |
72 | func (t *mockClientTransport) startReceive(ctx context.Context) {
73 | s := bufio.NewReader(t.in)
74 |
75 | for {
76 | line, err := s.ReadBytes('\n')
77 | if err != nil {
78 | t.receiver.Interrupt(fmt.Errorf("reader read error: %w", err))
79 |
80 | if errors.Is(err, io.ErrClosedPipe) || // This error occurs during unit tests, suppressing it here
81 | errors.Is(err, io.EOF) {
82 | return
83 | }
84 | t.logger.Errorf("reader read error: %+v", err)
85 | return
86 | }
87 |
88 | line = bytes.TrimRight(line, "\n")
89 |
90 | select {
91 | case <-ctx.Done():
92 | return
93 | default:
94 | if err = t.receiver.Receive(ctx, line); err != nil {
95 | t.logger.Errorf("receiver failed: %v", err)
96 | }
97 | }
98 | }
99 | }
100 |
--------------------------------------------------------------------------------
/transport/mock_server.go:
--------------------------------------------------------------------------------
1 | package transport
2 |
3 | import (
4 | "bufio"
5 | "bytes"
6 | "context"
7 | "errors"
8 | "fmt"
9 | "io"
10 |
11 | "github.com/ThinkInAIXYZ/go-mcp/pkg"
12 | )
13 |
// mockServerTransport is an in-memory ServerTransport used in tests. It serves
// a single session, reading newline-delimited messages from in and writing
// responses to out.
type mockServerTransport struct {
	// receiver handles each inbound line; set via SetReceiver.
	receiver serverReceiver
	// in is the stream inbound messages are read from; closed by Shutdown.
	in io.ReadCloser
	// out is the stream outbound messages are written to.
	out io.Writer

	// sessionID identifies the single session created by Run.
	sessionID string

	// sessionManager creates the session in Run; set via SetSessionManager.
	sessionManager sessionManager

	logger pkg.Logger

	// cancel stops the receive loop; set by Run and invoked by Shutdown.
	cancel context.CancelFunc
	// receiveShutDone is closed by Run once the receive loop has returned.
	receiveShutDone chan struct{}
}
28 |
29 | func NewMockServerTransport(in io.ReadCloser, out io.Writer) ServerTransport {
30 | return &mockServerTransport{
31 | in: in,
32 | out: out,
33 | logger: pkg.DefaultLogger,
34 |
35 | receiveShutDone: make(chan struct{}),
36 | }
37 | }
38 |
39 | func (t *mockServerTransport) Run() error {
40 | ctx, cancel := context.WithCancel(context.Background())
41 | t.cancel = cancel
42 |
43 | t.sessionID = t.sessionManager.CreateSession(context.Background())
44 |
45 | t.startReceive(ctx)
46 |
47 | close(t.receiveShutDone)
48 | return nil
49 | }
50 |
51 | func (t *mockServerTransport) Send(_ context.Context, _ string, msg Message) error {
52 | if _, err := t.out.Write(append(msg, mcpMessageDelimiter)); err != nil {
53 | return fmt.Errorf("failed to write: %w", err)
54 | }
55 | return nil
56 | }
57 |
58 | func (t *mockServerTransport) SetReceiver(receiver serverReceiver) {
59 | t.receiver = receiver
60 | }
61 |
62 | func (t *mockServerTransport) SetSessionManager(m sessionManager) {
63 | t.sessionManager = m
64 | }
65 |
66 | func (t *mockServerTransport) Shutdown(userCtx context.Context, serverCtx context.Context) error {
67 | t.cancel()
68 |
69 | if err := t.in.Close(); err != nil {
70 | return err
71 | }
72 |
73 | <-t.receiveShutDone
74 |
75 | select {
76 | case <-serverCtx.Done():
77 | return nil
78 | case <-userCtx.Done():
79 | return userCtx.Err()
80 | }
81 | }
82 |
83 | func (t *mockServerTransport) startReceive(ctx context.Context) {
84 | s := bufio.NewReader(t.in)
85 |
86 | for {
87 | line, err := s.ReadBytes('\n')
88 | if err != nil {
89 | if errors.Is(err, io.ErrClosedPipe) || // This error occurs during unit tests, suppressing it here
90 | errors.Is(err, io.EOF) {
91 | return
92 | }
93 | t.logger.Errorf("client receive unexpected error reading input: %v", err)
94 | return
95 | }
96 |
97 | line = bytes.TrimRight(line, "\n")
98 |
99 | select {
100 | case <-ctx.Done():
101 | return
102 | default:
103 | t.receive(ctx, line)
104 | }
105 | }
106 | }
107 |
108 | func (t *mockServerTransport) receive(ctx context.Context, line []byte) {
109 | outputMsgCh, err := t.receiver.Receive(ctx, t.sessionID, line)
110 | if err != nil {
111 | t.logger.Errorf("receiver failed: %v", err)
112 | return
113 | }
114 |
115 | if outputMsgCh == nil {
116 | return
117 | }
118 |
119 | go func() {
120 | defer pkg.Recover()
121 |
122 | for msg := range outputMsgCh {
123 | if e := t.Send(context.Background(), t.sessionID, msg); e != nil {
124 | t.logger.Errorf("Failed to send message: %v", e)
125 | }
126 | }
127 | }()
128 | }
129 |
--------------------------------------------------------------------------------
/transport/mock_test.go:
--------------------------------------------------------------------------------
1 | package transport
2 |
3 | import (
4 | "io"
5 | "testing"
6 | )
7 |
8 | func TestMockTransport(t *testing.T) {
9 | reader1, writer1 := io.Pipe()
10 | reader2, writer2 := io.Pipe()
11 |
12 | serverTransport := NewMockServerTransport(reader2, writer1)
13 | clientTransport := NewMockClientTransport(reader1, writer2)
14 |
15 | testTransport(t, clientTransport, serverTransport)
16 | }
17 |
--------------------------------------------------------------------------------
/transport/sse_client.go:
--------------------------------------------------------------------------------
1 | package transport
2 |
3 | import (
4 | "bufio"
5 | "bytes"
6 | "context"
7 | "errors"
8 | "fmt"
9 | "io"
10 | "net/http"
11 | "net/url"
12 | "strings"
13 | "time"
14 |
15 | "github.com/ThinkInAIXYZ/go-mcp/pkg"
16 | )
17 |
18 | type SSEClientTransportOption func(*sseClientTransport)
19 |
20 | func WithSSEClientOptionReceiveTimeout(timeout time.Duration) SSEClientTransportOption {
21 | return func(t *sseClientTransport) {
22 | t.receiveTimeout = timeout
23 | }
24 | }
25 |
26 | func WithSSEClientOptionHTTPClient(client *http.Client) SSEClientTransportOption {
27 | return func(t *sseClientTransport) {
28 | t.client = client
29 | }
30 | }
31 |
32 | func WithSSEClientOptionLogger(log pkg.Logger) SSEClientTransportOption {
33 | return func(t *sseClientTransport) {
34 | t.logger = log
35 | }
36 | }
37 |
38 | func WithRetryFunc(retry func(func() error)) SSEClientTransportOption {
39 | return func(t *sseClientTransport) {
40 | t.retry = retry
41 | }
42 | }
43 |
// sseClientTransport is a ClientTransport that receives server messages over a
// Server-Sent Events stream (GET serverURL) and sends client messages as HTTP
// POST requests to the endpoint announced by the server.
type sseClientTransport struct {
	// ctx/cancel are created by Start and bound the SSE connection's lifetime;
	// cancel is invoked by Close (and by Start on failure).
	ctx    context.Context
	cancel context.CancelFunc

	// serverURL is the SSE stream URL; endpoint URLs are resolved against it.
	serverURL *url.URL

	// endpointChan is signalled when an "endpoint" event has been received.
	endpointChan chan struct{}
	// messageEndpoint is the URL Send posts to; nil until the endpoint event arrives.
	messageEndpoint *url.URL
	receiver        clientReceiver

	// options
	logger         pkg.Logger
	receiveTimeout time.Duration
	client         *http.Client

	// retry drives the (re)connection attempts of the SSE stream.
	retry func(func() error)

	// sseConnectClose is closed when the connection goroutine started by Start exits.
	sseConnectClose chan struct{}
}
63 |
64 | func NewSSEClientTransport(serverURL string, opts ...SSEClientTransportOption) (ClientTransport, error) {
65 | parsedURL, err := url.Parse(serverURL)
66 | if err != nil {
67 | return nil, fmt.Errorf("failed to parse server URL: %w", err)
68 | }
69 |
70 | t := &sseClientTransport{
71 | serverURL: parsedURL,
72 | endpointChan: make(chan struct{}, 1),
73 | messageEndpoint: nil,
74 | receiver: nil,
75 | logger: pkg.DefaultLogger,
76 | receiveTimeout: time.Second * 30,
77 | client: http.DefaultClient,
78 | sseConnectClose: make(chan struct{}),
79 | retry: func(operation func() error) {
80 | for {
81 | if e := operation(); e == nil {
82 | return
83 | }
84 | time.Sleep(100 * time.Millisecond)
85 | }
86 | },
87 | }
88 |
89 | for _, opt := range opts {
90 | opt(t)
91 | }
92 |
93 | return t, nil
94 | }
95 |
96 | func (t *sseClientTransport) Start() (err error) {
97 | ctx, cancel := context.WithCancel(context.Background())
98 | t.ctx = ctx
99 | t.cancel = cancel
100 |
101 | defer func() {
102 | if err != nil {
103 | t.cancel()
104 | }
105 | }()
106 |
107 | errChan := make(chan error, 1)
108 | go func() {
109 | defer pkg.Recover()
110 | defer close(t.sseConnectClose)
111 |
112 | t.retry(func() error {
113 | if e := t.startSSE(); e != nil {
114 | if errors.Is(e, context.Canceled) {
115 | return nil
116 | }
117 | t.logger.Errorf("startSSE: %+v", e)
118 | t.receiver.Interrupt(fmt.Errorf("SSE connection disconnection: %w", e))
119 | return e
120 | }
121 | return nil
122 | })
123 | }()
124 |
125 | // Wait for the endpoint to be received
126 | select {
127 | case <-t.endpointChan:
128 | // Endpoint received, proceed
129 | case err = <-errChan:
130 | return fmt.Errorf("error in SSE stream: %w", err)
131 | case <-time.After(10 * time.Second): // Add a timeout
132 | return fmt.Errorf("timeout waiting for endpoint")
133 | }
134 |
135 | return nil
136 | }
137 |
138 | func (t *sseClientTransport) startSSE() error {
139 | req, err := http.NewRequestWithContext(t.ctx, http.MethodGet, t.serverURL.String(), nil)
140 | if err != nil {
141 | return fmt.Errorf("failed to create request: %w", err)
142 | }
143 |
144 | req.Header.Set("Accept", "text/event-stream")
145 | req.Header.Set("Cache-Control", "no-cache")
146 | req.Header.Set("Connection", "keep-alive")
147 |
148 | resp, err := t.client.Do(req) //nolint:bodyclose
149 | if err != nil {
150 | return fmt.Errorf("failed to connect to SSE stream: %w", err)
151 | }
152 | defer resp.Body.Close()
153 |
154 | if resp.StatusCode != http.StatusOK {
155 | return fmt.Errorf("unexpected status code: %d, status: %s", resp.StatusCode, resp.Status)
156 | }
157 |
158 | return t.readSSE(resp.Body)
159 | }
160 |
161 | // readSSE continuously reads the SSE stream and processes events.
162 | // It runs until the connection is closed or an error occurs.
163 | func (t *sseClientTransport) readSSE(reader io.ReadCloser) error {
164 | defer func() {
165 | _ = reader.Close()
166 | }()
167 |
168 | br := bufio.NewReader(reader)
169 | var event, data string
170 |
171 | for {
172 | line, err := br.ReadString('\n')
173 | if err != nil {
174 | if err == io.EOF {
175 | // Process any pending event before exit
176 | if event != "" && data != "" {
177 | t.handleSSEEvent(event, data)
178 | }
179 | }
180 | select {
181 | case <-t.ctx.Done():
182 | return t.ctx.Err()
183 | default:
184 | return fmt.Errorf("SSE stream error: %w", err)
185 | }
186 | }
187 |
188 | // Remove only newline markers
189 | line = strings.TrimRight(line, "\r\n")
190 | if line == "" {
191 | // Empty line means end of event
192 | if event != "" && data != "" {
193 | t.handleSSEEvent(event, data)
194 | event = ""
195 | data = ""
196 | }
197 | continue
198 | }
199 |
200 | if strings.HasPrefix(line, "event:") {
201 | event = strings.TrimSpace(strings.TrimPrefix(line, "event:"))
202 | } else if strings.HasPrefix(line, "data:") {
203 | data = strings.TrimSpace(strings.TrimPrefix(line, "data:"))
204 | }
205 | }
206 | }
207 |
208 | // handleSSEEvent processes SSE events based on their type.
209 | // Handles 'endpoint' events for connection setup and 'message' events for JSON-RPC communication.
210 | func (t *sseClientTransport) handleSSEEvent(event, data string) {
211 | switch event {
212 | case "endpoint":
213 | endpoint, err := t.serverURL.Parse(data)
214 | if err != nil {
215 | t.logger.Errorf("Error parsing endpoint URL: %v", err)
216 | return
217 | }
218 | t.logger.Debugf("Received endpoint: %s", endpoint.String())
219 | t.messageEndpoint = endpoint
220 | select {
221 | case t.endpointChan <- struct{}{}:
222 | default:
223 | }
224 | case "message":
225 | ctx, cancel := context.WithTimeout(t.ctx, t.receiveTimeout)
226 | defer cancel()
227 | if err := t.receiver.Receive(ctx, []byte(data)); err != nil {
228 | t.logger.Errorf("Error receive message: %v", err)
229 | return
230 | }
231 | }
232 | }
233 |
234 | func (t *sseClientTransport) Send(ctx context.Context, msg Message) error {
235 | t.logger.Debugf("Sending message: %s to %s", msg, t.messageEndpoint.String())
236 |
237 | var (
238 | err error
239 | req *http.Request
240 | resp *http.Response
241 | )
242 |
243 | req, err = http.NewRequestWithContext(ctx, http.MethodPost, t.messageEndpoint.String(), bytes.NewReader(msg))
244 | if err != nil {
245 | return fmt.Errorf("failed to create request: %w", err)
246 | }
247 |
248 | req.Header.Set("Content-Type", "application/json")
249 |
250 | if resp, err = t.client.Do(req); err != nil {
251 | return fmt.Errorf("failed to send message: %w", err)
252 | }
253 | defer resp.Body.Close()
254 |
255 | if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
256 | return fmt.Errorf("unexpected status code: %d, status: %s", resp.StatusCode, resp.Status)
257 | }
258 |
259 | return nil
260 | }
261 |
262 | func (t *sseClientTransport) SetReceiver(receiver clientReceiver) {
263 | t.receiver = receiver
264 | }
265 |
266 | func (t *sseClientTransport) Close() error {
267 | t.cancel()
268 |
269 | <-t.sseConnectClose
270 |
271 | return nil
272 | }
273 |
--------------------------------------------------------------------------------
/transport/sse_test.go:
--------------------------------------------------------------------------------
1 | package transport
2 |
3 | import (
4 | "fmt"
5 | "log"
6 | "net"
7 | "net/http"
8 | "net/url"
9 | "testing"
10 | "time"
11 |
12 | "github.com/ThinkInAIXYZ/go-mcp/pkg"
13 | )
14 |
15 | func TestSSE(t *testing.T) {
16 | var (
17 | err error
18 | svr ServerTransport
19 | client ClientTransport
20 | )
21 |
22 | // Get an available port
23 | port, err := getAvailablePort()
24 | if err != nil {
25 | t.Fatalf("Failed to get available port: %v", err)
26 | }
27 |
28 | serverAddr := fmt.Sprintf("127.0.0.1:%d", port)
29 | serverURL := fmt.Sprintf("http://%s/sse", serverAddr)
30 |
31 | if svr, err = NewSSEServerTransport(serverAddr); err != nil {
32 | t.Fatalf("NewSSEServerTransport failed: %v", err)
33 | }
34 |
35 | if client, err = NewSSEClientTransport(serverURL); err != nil {
36 | t.Fatalf("NewSSEClientTransport failed: %v", err)
37 | }
38 |
39 | testTransport(t, client, svr)
40 | }
41 |
42 | func TestSSEHandler(t *testing.T) {
43 | var (
44 | messageURL = "/message"
45 | port int
46 |
47 | err error
48 | svr ServerTransport
49 | client ClientTransport
50 | )
51 |
52 | // Get an available port
53 | port, err = getAvailablePort()
54 | if err != nil {
55 | t.Fatalf("Failed to get available port: %v", err)
56 | }
57 |
58 | serverAddr := fmt.Sprintf("http://127.0.0.1:%d", port)
59 | serverURL := fmt.Sprintf("%s/sse", serverAddr)
60 |
61 | svr, handler, err := NewSSEServerTransportAndHandler(fmt.Sprintf("%s%s", serverAddr, messageURL))
62 | if err != nil {
63 | t.Fatalf("NewSSEServerTransport failed: %v", err)
64 | }
65 |
66 | // 设置 HTTP 路由
67 | http.Handle("/sse", handler.HandleSSE())
68 | http.Handle(messageURL, handler.HandleMessage())
69 |
70 | errCh := make(chan error, 1)
71 | go func() {
72 | if e := http.ListenAndServe(fmt.Sprintf(":%d", port), nil); e != nil {
73 | log.Fatalf("Failed to start HTTP server: %v", e)
74 | }
75 | }()
76 |
77 | // Use select to handle potential errors
78 | select {
79 | case err = <-errCh:
80 | t.Fatalf("http.ListenAndServe() failed: %v", err)
81 | case <-time.After(time.Second):
82 | // Server started normally
83 | }
84 |
85 | if client, err = NewSSEClientTransport(serverURL); err != nil {
86 | t.Fatalf("NewSSEClientTransport failed: %v", err)
87 | }
88 |
89 | testTransport(t, client, svr)
90 | }
91 |
// getAvailablePort asks the OS for a free TCP port on 127.0.0.1 and returns
// it. The probe listener is closed before returning, so the port is free at
// that moment but is not reserved: another process could still grab it.
func getAvailablePort() (int, error) {
	ln, listenErr := net.Listen("tcp", "127.0.0.1:0")
	if listenErr != nil {
		return 0, fmt.Errorf("failed to get available port: %v", listenErr)
	}
	defer func() {
		if closeErr := ln.Close(); closeErr != nil {
			fmt.Println(closeErr)
		}
	}()

	tcpAddr := ln.Addr().(*net.TCPAddr)
	return tcpAddr.Port, nil
}
107 |
108 | func Test_joinPath(t *testing.T) {
109 | type args struct {
110 | u *url.URL
111 | elem []string
112 | }
113 | tests := []struct {
114 | name string
115 | args args
116 | want string
117 | }{
118 | {
119 | name: "1",
120 | args: args{
121 | u: func() *url.URL {
122 | uri, err := url.Parse("https://google.com/api/v1")
123 | if err != nil {
124 | panic(err)
125 | }
126 | return uri
127 | }(),
128 | elem: []string{"/test"},
129 | },
130 | want: "https://google.com/api/v1/test",
131 | },
132 | {
133 | name: "2",
134 | args: args{
135 | u: func() *url.URL {
136 | uri, err := url.Parse("/api/v1")
137 | if err != nil {
138 | panic(err)
139 | }
140 | return uri
141 | }(),
142 | elem: []string{"/test"},
143 | },
144 | want: "/api/v1/test",
145 | },
146 | }
147 | for _, tt := range tests {
148 | t.Run(tt.name, func(t *testing.T) {
149 | joinPath(tt.args.u, tt.args.elem...)
150 | if got := tt.args.u.String(); got != tt.want {
151 | t.Errorf("joinPath() = %v, want %v", got, tt.want)
152 | }
153 | })
154 | }
155 | }
156 |
157 | func Test_sseClientTransport_handleSSEEvent(t1 *testing.T) {
158 | type fields struct {
159 | serverURL *url.URL
160 | logger pkg.Logger
161 | }
162 | type args struct {
163 | event string
164 | data string
165 | }
166 | tests := []struct {
167 | name string
168 | fields fields
169 | args args
170 | want string
171 | }{
172 | {
173 | name: "1",
174 | fields: fields{
175 | serverURL: func() *url.URL {
176 | uri, err := url.Parse("https://api.baidu.com/mcp")
177 | if err != nil {
178 | panic(err)
179 | }
180 | return uri
181 | }(),
182 | logger: pkg.DefaultLogger,
183 | },
184 | args: args{
185 | event: "endpoint",
186 | data: "/sse/messages",
187 | },
188 | want: "https://api.baidu.com/sse/messages",
189 | },
190 | {
191 | name: "2",
192 | fields: fields{
193 | serverURL: func() *url.URL {
194 | uri, err := url.Parse("https://api.baidu.com/mcp")
195 | if err != nil {
196 | panic(err)
197 | }
198 | return uri
199 | }(),
200 | logger: pkg.DefaultLogger,
201 | },
202 | args: args{
203 | event: "endpoint",
204 | data: "https://api.google.com/sse/messages",
205 | },
206 | want: "https://api.google.com/sse/messages",
207 | },
208 | }
209 | for _, tt := range tests {
210 | t1.Run(tt.name, func(t1 *testing.T) {
211 | t := &sseClientTransport{
212 | serverURL: tt.fields.serverURL,
213 | logger: tt.fields.logger,
214 | endpointChan: make(chan struct{}),
215 | }
216 | t.handleSSEEvent(tt.args.event, tt.args.data)
217 | if t.messageEndpoint.String() != tt.want {
218 | t1.Errorf("handleSSEEvent() = %v, want %v", t.messageEndpoint.String(), tt.want)
219 | }
220 | })
221 | }
222 | }
223 |
--------------------------------------------------------------------------------
/transport/stdio_client.go:
--------------------------------------------------------------------------------
1 | package transport
2 |
3 | import (
4 | "bufio"
5 | "bytes"
6 | "context"
7 | "errors"
8 | "fmt"
9 | "io"
10 | "os"
11 | "os/exec"
12 | "sync"
13 |
14 | "github.com/ThinkInAIXYZ/go-mcp/pkg"
15 | )
16 |
17 | type StdioClientTransportOption func(*stdioClientTransport)
18 |
19 | func WithStdioClientOptionLogger(log pkg.Logger) StdioClientTransportOption {
20 | return func(t *stdioClientTransport) {
21 | t.logger = log
22 | }
23 | }
24 |
25 | func WithStdioClientOptionEnv(env ...string) StdioClientTransportOption {
26 | return func(t *stdioClientTransport) {
27 | t.cmd.Env = append(t.cmd.Env, env...)
28 | }
29 | }
30 |
// mcpMessageDelimiter terminates every message the stdio and mock transports
// write; their readers split the input on '\n'.
const mcpMessageDelimiter = '\n'

// stdioClientTransport is a ClientTransport that launches the server as a
// child process. It sends messages to the child's stdin, receives them from
// the child's stdout, and logs lines the child writes to stderr.
type stdioClientTransport struct {
	// cmd is the child process; started by Start and reaped by Close.
	cmd *exec.Cmd
	// receiver gets each non-blank line read from reader.
	receiver clientReceiver
	// reader is the child's stdout by default.
	reader io.Reader
	// writer is the child's stdin by default; closed by Close.
	writer io.WriteCloser
	// errReader is the child's stderr; its lines are only logged.
	errReader io.Reader

	logger pkg.Logger

	// wg tracks the stdout and stderr reader goroutines started by Start.
	wg sync.WaitGroup
	// cancel stops the reader goroutines; set by Start.
	cancel context.CancelFunc
}
45 |
46 | func NewStdioClientTransport(command string, args []string, opts ...StdioClientTransportOption) (ClientTransport, error) {
47 | cmd := exec.Command(command, args...)
48 |
49 | cmd.Env = os.Environ()
50 |
51 | stdin, err := cmd.StdinPipe()
52 | if err != nil {
53 | return nil, fmt.Errorf("failed to create stdin pipe: %w", err)
54 | }
55 |
56 | stdout, err := cmd.StdoutPipe()
57 | if err != nil {
58 | return nil, fmt.Errorf("failed to create stdout pipe: %w", err)
59 | }
60 |
61 | stderr, err := cmd.StderrPipe()
62 | if err != nil {
63 | return nil, fmt.Errorf("failed to create stdout pipe: %w", err)
64 | }
65 |
66 | t := &stdioClientTransport{
67 | cmd: cmd,
68 | reader: stdout,
69 | writer: stdin,
70 | errReader: stderr,
71 |
72 | logger: pkg.DefaultLogger,
73 | }
74 |
75 | for _, opt := range opts {
76 | opt(t)
77 | }
78 | return t, nil
79 | }
80 |
81 | func (t *stdioClientTransport) Start() error {
82 | if err := t.cmd.Start(); err != nil {
83 | return fmt.Errorf("failed to start command: %w", err)
84 | }
85 |
86 | innerCtx, cancel := context.WithCancel(context.Background())
87 | t.cancel = cancel
88 |
89 | t.wg.Add(1)
90 | go func() {
91 | defer pkg.Recover()
92 | defer t.wg.Done()
93 |
94 | t.startReceive(innerCtx)
95 | }()
96 |
97 | t.wg.Add(1)
98 | go func() {
99 | defer pkg.Recover()
100 | defer t.wg.Done()
101 |
102 | t.startReceiveErr(innerCtx)
103 | }()
104 |
105 | return nil
106 | }
107 |
108 | func (t *stdioClientTransport) Send(_ context.Context, msg Message) error {
109 | _, err := t.writer.Write(append(msg, mcpMessageDelimiter))
110 | return err
111 | }
112 |
113 | func (t *stdioClientTransport) SetReceiver(receiver clientReceiver) {
114 | t.receiver = receiver
115 | }
116 |
117 | func (t *stdioClientTransport) Close() error {
118 | t.cancel()
119 |
120 | if err := t.writer.Close(); err != nil {
121 | return fmt.Errorf("failed to close writer: %w", err)
122 | }
123 |
124 | if err := t.cmd.Wait(); err != nil {
125 | return err
126 | }
127 |
128 | t.wg.Wait()
129 |
130 | return nil
131 | }
132 |
133 | func (t *stdioClientTransport) startReceive(ctx context.Context) {
134 | s := bufio.NewReader(t.reader)
135 |
136 | for {
137 | line, err := s.ReadBytes('\n')
138 | if err != nil {
139 | t.receiver.Interrupt(fmt.Errorf("stdout read error: %w", err))
140 |
141 | if errors.Is(err, io.ErrClosedPipe) || // This error occurs during unit tests, suppressing it here
142 | errors.Is(err, io.EOF) {
143 | return
144 | }
145 | t.logger.Errorf("stdout read error: %+v", err)
146 | return
147 | }
148 |
149 | line = bytes.TrimRight(line, "\n")
150 | // filter empty messages
151 | // filter space messages and \t messages
152 | if len(bytes.TrimFunc(line, func(r rune) bool { return r == ' ' || r == '\t' })) == 0 {
153 | t.logger.Debugf("skipping empty message")
154 | continue
155 | }
156 |
157 | select {
158 | case <-ctx.Done():
159 | return
160 | default:
161 | if err = t.receiver.Receive(ctx, line); err != nil {
162 | t.logger.Errorf("receiver failed: %v", err)
163 | }
164 | }
165 | }
166 | }
167 |
168 | func (t *stdioClientTransport) startReceiveErr(ctx context.Context) {
169 | s := bufio.NewReader(t.errReader)
170 |
171 | for {
172 | line, err := s.ReadBytes('\n')
173 | if err != nil {
174 | if errors.Is(err, io.ErrClosedPipe) || // This error occurs during unit tests, suppressing it here
175 | errors.Is(err, io.EOF) {
176 | return
177 | }
178 | t.logger.Errorf("client receive unexpected server error reading input: %v", err)
179 | return
180 | }
181 |
182 | line = bytes.TrimRight(line, "\n")
183 | // filter empty messages
184 | // filter space messages and \t messages
185 | if len(bytes.TrimFunc(line, func(r rune) bool { return r == ' ' || r == '\t' })) == 0 {
186 | t.logger.Debugf("skipping empty message")
187 | continue
188 | }
189 |
190 | select {
191 | case <-ctx.Done():
192 | return
193 | default:
194 | t.logger.Infof("receive server info: %s", line)
195 | }
196 | }
197 | }
198 |
--------------------------------------------------------------------------------
/transport/stdio_server.go:
--------------------------------------------------------------------------------
1 | package transport
2 |
3 | import (
4 | "bufio"
5 | "bytes"
6 | "context"
7 | "errors"
8 | "fmt"
9 | "io"
10 | "os"
11 |
12 | "github.com/ThinkInAIXYZ/go-mcp/pkg"
13 | )
14 |
15 | type StdioServerTransportOption func(*stdioServerTransport)
16 |
17 | func WithStdioServerOptionLogger(log pkg.Logger) StdioServerTransportOption {
18 | return func(t *stdioServerTransport) {
19 | t.logger = log
20 | }
21 | }
22 |
// stdioServerTransport is a ServerTransport that serves a single client over
// the process's standard streams by default: newline-delimited messages are
// read from reader and replies are written to writer.
type stdioServerTransport struct {
	// receiver handles each inbound line; set via SetReceiver.
	receiver serverReceiver
	// reader defaults to os.Stdin; closed by Shutdown.
	reader io.ReadCloser
	// writer defaults to os.Stdout.
	writer io.Writer

	// sessionManager creates the session in Run; set via SetSessionManager.
	sessionManager sessionManager
	// sessionID identifies the single session created by Run.
	sessionID string

	logger pkg.Logger

	// cancel stops the receive loop; set by Run and invoked by Shutdown.
	cancel context.CancelFunc
	// receiveShutDone is closed by Run once the receive loop has returned.
	receiveShutDone chan struct{}
}
36 |
37 | func NewStdioServerTransport(opts ...StdioServerTransportOption) ServerTransport {
38 | t := &stdioServerTransport{
39 | reader: os.Stdin,
40 | writer: os.Stdout,
41 | logger: pkg.DefaultLogger,
42 |
43 | receiveShutDone: make(chan struct{}),
44 | }
45 |
46 | for _, opt := range opts {
47 | opt(t)
48 | }
49 | return t
50 | }
51 |
52 | func (t *stdioServerTransport) Run() error {
53 | ctx, cancel := context.WithCancel(context.Background())
54 | t.cancel = cancel
55 |
56 | t.sessionID = t.sessionManager.CreateSession(context.Background())
57 |
58 | t.startReceive(ctx)
59 |
60 | close(t.receiveShutDone)
61 | return nil
62 | }
63 |
64 | func (t *stdioServerTransport) Send(_ context.Context, _ string, msg Message) error {
65 | if _, err := t.writer.Write(append(msg, mcpMessageDelimiter)); err != nil {
66 | return fmt.Errorf("failed to write: %w", err)
67 | }
68 | return nil
69 | }
70 |
71 | func (t *stdioServerTransport) SetReceiver(receiver serverReceiver) {
72 | t.receiver = receiver
73 | }
74 |
75 | func (t *stdioServerTransport) SetSessionManager(m sessionManager) {
76 | t.sessionManager = m
77 | }
78 |
79 | func (t *stdioServerTransport) Shutdown(userCtx context.Context, serverCtx context.Context) error {
80 | t.cancel()
81 |
82 | if err := t.reader.Close(); err != nil {
83 | return err
84 | }
85 |
86 | select {
87 | case <-t.receiveShutDone:
88 | return nil
89 | case <-serverCtx.Done():
90 | return nil
91 | case <-userCtx.Done():
92 | return userCtx.Err()
93 | }
94 | }
95 |
96 | func (t *stdioServerTransport) startReceive(ctx context.Context) {
97 | s := bufio.NewReader(t.reader)
98 |
99 | for {
100 | line, err := s.ReadBytes('\n')
101 | if err != nil {
102 | if errors.Is(err, io.ErrClosedPipe) || // This error occurs during unit tests, suppressing it here
103 | errors.Is(err, io.EOF) {
104 | return
105 | }
106 | t.logger.Errorf("client receive unexpected error reading input: %v", err)
107 | }
108 | line = bytes.TrimRight(line, "\n")
109 |
110 | select {
111 | case <-ctx.Done():
112 | return
113 | default:
114 | t.receive(ctx, line)
115 | }
116 | }
117 | }
118 |
119 | func (t *stdioServerTransport) receive(ctx context.Context, line []byte) {
120 | outputMsgCh, err := t.receiver.Receive(ctx, t.sessionID, line)
121 | if err != nil {
122 | t.logger.Errorf("receiver failed: %v", err)
123 | return
124 | }
125 |
126 | if outputMsgCh == nil {
127 | return
128 | }
129 |
130 | go func() {
131 | defer pkg.Recover()
132 |
133 | for msg := range outputMsgCh {
134 | if e := t.Send(context.Background(), t.sessionID, msg); e != nil {
135 | t.logger.Errorf("Failed to send message: %v", e)
136 | }
137 | }
138 | }()
139 | }
140 |
--------------------------------------------------------------------------------
/transport/stdio_test.go:
--------------------------------------------------------------------------------
1 | package transport
2 |
3 | import (
4 | "fmt"
5 | "io"
6 | "math/rand"
7 | "os"
8 | "os/exec"
9 | "path/filepath"
10 | "strconv"
11 | "testing"
12 | "time"
13 | )
14 |
15 | type mock struct {
16 | reader *io.PipeReader
17 | writer *io.PipeWriter
18 | closer io.Closer
19 | }
20 |
21 | func (m *mock) Write(p []byte) (n int, err error) {
22 | return m.writer.Write(p)
23 | }
24 |
25 | func (m *mock) Close() error {
26 | if err := m.writer.Close(); err != nil {
27 | return err
28 | }
29 | if err := m.reader.Close(); err != nil {
30 | return err
31 | }
32 | if err := m.closer.Close(); err != nil {
33 | return err
34 | }
35 | return nil
36 | }
37 |
38 | func TestStdioTransport(t *testing.T) {
39 | var (
40 | err error
41 | server *stdioServerTransport
42 | client *stdioClientTransport
43 | )
44 |
45 | r := rand.New(rand.NewSource(time.Now().UnixNano()))
46 |
47 | mockServerTrPath := filepath.Join(os.TempDir(), "mock_server_tr_"+strconv.Itoa(r.Int()))
48 | if err = compileMockStdioServerTr(mockServerTrPath); err != nil {
49 | t.Fatalf("Failed to compile mock server: %v", err)
50 | }
51 |
52 | defer func(name string) {
53 | if err = os.Remove(name); err != nil {
54 | fmt.Printf("Failed to remove mock server: %v\n", err)
55 | }
56 | }(mockServerTrPath)
57 |
58 | clientT, err := NewStdioClientTransport(mockServerTrPath, []string{})
59 | if err != nil {
60 | t.Fatalf("NewStdioClientTransport failed: %v", err)
61 | }
62 |
63 | client = clientT.(*stdioClientTransport)
64 | server = NewStdioServerTransport().(*stdioServerTransport)
65 |
66 | // Create pipes for communication
67 | reader1, writer1 := io.Pipe()
68 | reader2, writer2 := io.Pipe()
69 |
70 | // Set up the communication channels
71 | server.reader = reader2
72 | server.writer = writer1
73 | client.reader = reader1
74 | client.writer = &mock{
75 | reader: reader1,
76 | writer: writer2,
77 | closer: client.writer,
78 | }
79 |
80 | testTransport(t, client, server)
81 | }
82 |
// compileMockStdioServerTr builds ../testdata/mock_block_server.go into an
// executable at outputPath. On failure the compiler output is included in the
// returned error.
func compileMockStdioServerTr(outputPath string) error {
	out, err := exec.Command("go", "build", "-o", outputPath, "../testdata/mock_block_server.go").CombinedOutput()
	if err == nil {
		return nil
	}
	return fmt.Errorf("compilation failed: %v\nOutput: %s", err, out)
}
92 |
--------------------------------------------------------------------------------
/transport/streamable_http_test.go:
--------------------------------------------------------------------------------
1 | package transport
2 |
3 | import (
4 | "fmt"
5 | "testing"
6 | )
7 |
8 | func TestStreamableHTTP(t *testing.T) {
9 | var (
10 | err error
11 | svr ServerTransport
12 | client ClientTransport
13 | )
14 |
15 | // Get an available port
16 | port, err := getAvailablePort()
17 | if err != nil {
18 | t.Fatalf("Failed to get available port: %v", err)
19 | }
20 |
21 | serverAddr := fmt.Sprintf("127.0.0.1:%d", port)
22 | serverURL := fmt.Sprintf("http://%s/mcp", serverAddr)
23 |
24 | svr = NewStreamableHTTPServerTransport(serverAddr)
25 |
26 | if client, err = NewStreamableHTTPClientTransport(serverURL); err != nil {
27 | t.Fatalf("NewStreamableHTTPClientTransport failed: %v", err)
28 | }
29 |
30 | testTransport(t, client, svr)
31 | }
32 |
--------------------------------------------------------------------------------
/transport/transport.go:
--------------------------------------------------------------------------------
1 | package transport
2 |
3 | import (
4 | "context"
5 |
6 | "github.com/ThinkInAIXYZ/go-mcp/pkg"
7 | )
8 |
9 | /*
10 | * Transport is an abstraction of the underlying transport layer.
11 | * GO-MCP needs to be able to transmit JSON-RPC messages between server and client.
12 | */
13 |
14 | // Message defines the basic message interface
15 | type Message []byte
16 |
17 | func (msg Message) String() string {
18 | return pkg.B2S(msg)
19 | }
20 |
// ClientTransport is the client side of a transport. It connects to a server,
// sends encoded messages to it, and delivers inbound messages to a
// clientReceiver.
type ClientTransport interface {
	// Start initiates the transport connection. The implementations in this
	// package start background readers that call the receiver, so SetReceiver
	// should be called before Start.
	Start() error

	// Send transmits a message to the server.
	Send(ctx context.Context, msg Message) error

	// SetReceiver sets the handler for messages from the peer and for
	// connection interruptions.
	SetReceiver(receiver clientReceiver)

	// Close terminates the transport connection.
	Close() error
}
34 |
35 | type clientReceiver interface {
36 | Receive(ctx context.Context, msg []byte) error
37 | Interrupt(err error)
38 | }
39 |
40 | type ClientReceiver struct {
41 | receive func(ctx context.Context, msg []byte) error
42 | interrupt func(err error)
43 | }
44 |
45 | func (r *ClientReceiver) Receive(ctx context.Context, msg []byte) error {
46 | return r.receive(ctx, msg)
47 | }
48 |
49 | func (r *ClientReceiver) Interrupt(err error) {
50 | r.interrupt(err)
51 | }
52 |
53 | func NewClientReceiver(receive func(ctx context.Context, msg []byte) error, interrupt func(err error)) clientReceiver {
54 | r := &ClientReceiver{
55 | receive: receive,
56 | interrupt: interrupt,
57 | }
58 | return r
59 | }
60 |
61 | type ServerTransport interface {
62 | // Run starts listening for requests, this is synchronous, and cannot return before Shutdown is called
63 | Run() error
64 |
65 | // Send transmits a message
66 | Send(ctx context.Context, sessionID string, msg Message) error
67 |
68 | // SetReceiver sets the handler for messages from the peer
69 | SetReceiver(serverReceiver)
70 |
71 | SetSessionManager(manager sessionManager)
72 |
73 | // Shutdown gracefully closes, the internal implementation needs to stop receiving messages first,
74 | // then wait for serverCtx to be canceled, while using userCtx to control timeout.
75 | // userCtx is used to control the timeout of the server shutdown.
76 | // serverCtx is used to coordinate the internal cleanup sequence:
77 | // 1. turn off message listen
78 | // 2. Wait for serverCtx to be done (indicating server shutdown is complete)
79 | // 3. Cancel the transport's context to stop all ongoing operations
80 | // 4. Wait for all in-flight sends to complete
81 | // 5. Close all session
82 | Shutdown(userCtx context.Context, serverCtx context.Context) error
83 | }
84 |
85 | type serverReceiver interface {
86 | Receive(ctx context.Context, sessionID string, msg []byte) (<-chan []byte, error)
87 | }
88 |
89 | type ServerReceiverF func(ctx context.Context, sessionID string, msg []byte) (<-chan []byte, error)
90 |
91 | func (f ServerReceiverF) Receive(ctx context.Context, sessionID string, msg []byte) (<-chan []byte, error) {
92 | return f(ctx, sessionID, msg)
93 | }
94 |
// sessionManager tracks sessions and their outbound message queues for a
// ServerTransport (installed via ServerTransport.SetSessionManager).
type sessionManager interface {
	// CreateSession registers a new session and returns its ID.
	CreateSession(ctx context.Context) string
	// OpenMessageQueueForSend prepares the outbound queue of sessionID.
	OpenMessageQueueForSend(sessionID string) error
	// EnqueueMessageForSend queues message for delivery on sessionID.
	EnqueueMessageForSend(ctx context.Context, sessionID string, message []byte) error
	// DequeueMessageForSend takes the next queued message of sessionID.
	DequeueMessageForSend(ctx context.Context, sessionID string) ([]byte, error)
	// CloseSession releases sessionID and its queue.
	CloseSession(sessionID string)
	// CloseAllSessions releases every tracked session.
	CloseAllSessions()
}
103 |
--------------------------------------------------------------------------------
/transport/transport_test.go:
--------------------------------------------------------------------------------
1 | package transport
2 |
3 | import (
4 | "context"
5 | "reflect"
6 | "testing"
7 | "time"
8 |
9 | "github.com/google/uuid"
10 |
11 | "github.com/ThinkInAIXYZ/go-mcp/pkg"
12 | )
13 |
14 | type mockSessionManager struct {
15 | pkg.SyncMap[chan []byte]
16 | }
17 |
18 | func newMockSessionManager() *mockSessionManager {
19 | return &mockSessionManager{}
20 | }
21 |
22 | func (m *mockSessionManager) CreateSession(context.Context) string {
23 | sessionID := uuid.NewString()
24 | m.Store(sessionID, nil)
25 | return sessionID
26 | }
27 |
28 | func (m *mockSessionManager) OpenMessageQueueForSend(sessionID string) error {
29 | _, ok := m.Load(sessionID)
30 | if !ok {
31 | return pkg.ErrLackSession
32 | }
33 | m.Store(sessionID, make(chan []byte))
34 | return nil
35 | }
36 |
37 | func (m *mockSessionManager) IsExistSession(sessionID string) bool {
38 | _, has := m.Load(sessionID)
39 | return has
40 | }
41 |
42 | func (m *mockSessionManager) EnqueueMessageForSend(ctx context.Context, sessionID string, message []byte) error {
43 | ch, has := m.Load(sessionID)
44 | if !has {
45 | return pkg.ErrLackSession
46 | }
47 |
48 | select {
49 | case ch <- message:
50 | return nil
51 | case <-ctx.Done():
52 | return ctx.Err()
53 | }
54 | }
55 |
56 | func (m *mockSessionManager) DequeueMessageForSend(ctx context.Context, sessionID string) ([]byte, error) {
57 | ch, has := m.Load(sessionID)
58 | if !has {
59 | return nil, pkg.ErrLackSession
60 | }
61 |
62 | select {
63 | case <-ctx.Done():
64 | return nil, ctx.Err()
65 | case msg, ok := <-ch:
66 | if msg == nil && !ok {
67 | // There are no new messages and the chan has been closed, indicating that the request may need to be terminated.
68 | return nil, pkg.ErrSendEOF
69 | }
70 | return msg, nil
71 | }
72 | }
73 |
74 | func (m *mockSessionManager) CloseSession(sessionID string) {
75 | ch, ok := m.LoadAndDelete(sessionID)
76 | if !ok {
77 | return
78 | }
79 | close(ch)
80 | }
81 |
82 | func (m *mockSessionManager) CloseAllSessions() {
83 | m.Range(func(key string, value chan []byte) bool {
84 | m.Delete(key)
85 | close(value)
86 | return true
87 | })
88 | }
89 |
90 | func testTransport(t *testing.T, client ClientTransport, server ServerTransport) {
91 | testMsg := "hello server"
92 | expectedMsgWithServerCh := make(chan string, 1)
93 | server.SetReceiver(ServerReceiverF(func(_ context.Context, _ string, msg []byte) (<-chan []byte, error) {
94 | expectedMsgWithServerCh <- string(msg)
95 | msgCh := make(chan []byte, 1)
96 | go func() {
97 | defer close(msgCh)
98 | msgCh <- msg
99 | }()
100 | return msgCh, nil
101 | }))
102 | server.SetSessionManager(newMockSessionManager())
103 |
104 | expectedMsgWithClientCh := make(chan string, 1)
105 | client.SetReceiver(NewClientReceiver(func(_ context.Context, msg []byte) error {
106 | expectedMsgWithClientCh <- string(msg)
107 | return nil
108 | }, func(_ error) {
109 | close(expectedMsgWithClientCh)
110 | }))
111 |
112 | errCh := make(chan error, 1)
113 | go func() {
114 | errCh <- server.Run()
115 | }()
116 |
117 | // Use select to handle potential errors
118 | select {
119 | case err := <-errCh:
120 | t.Fatalf("server.Run() failed: %v", err)
121 | case <-time.After(time.Second):
122 | // Server started normally
123 | }
124 |
125 | defer func() {
126 | if _, ok := server.(*stdioServerTransport); ok { // stdioServerTransport not support shutdown
127 | return
128 | }
129 |
130 | userCtx, cancel := context.WithTimeout(context.Background(), time.Second*1)
131 | defer cancel()
132 |
133 | serverCtx, cancel := context.WithCancel(userCtx)
134 | cancel()
135 |
136 | if err := server.Shutdown(userCtx, serverCtx); err != nil {
137 | t.Fatalf("server.Shutdown() failed: %v", err)
138 | }
139 | }()
140 |
141 | if err := client.Start(); err != nil {
142 | t.Fatalf("client.Run() failed: %v", err)
143 | }
144 |
145 | defer func() {
146 | if err := client.Close(); err != nil {
147 | t.Fatalf("client.Close() failed: %v", err)
148 | }
149 | }()
150 |
151 | if err := client.Send(context.Background(), Message(testMsg)); err != nil {
152 | t.Fatalf("client.Send() failed: %v", err)
153 | }
154 | expectedMsg := <-expectedMsgWithServerCh
155 | if !reflect.DeepEqual(expectedMsg, testMsg) {
156 | t.Fatalf("client.Send() got %v, want %v", expectedMsg, testMsg)
157 | }
158 | expectedMsg = <-expectedMsgWithClientCh
159 | if !reflect.DeepEqual(expectedMsg, testMsg) {
160 | t.Fatalf("server.Send() failed: got %v, want %v", expectedMsg, testMsg)
161 | }
162 | }
163 |
--------------------------------------------------------------------------------