├── .dockerignore
├── .gitignore
├── ui.png
├── go.mod
├── go.sum
├── prometheus.yml
├── .github
└── workflows
│ └── pull-request.yaml
├── docker-compose.yml
├── Dockerfile
├── LICENSE
├── router_config.yaml
├── pkg
├── session
│ ├── store.go
│ └── store_test.go
├── logger
│ └── logger.go
├── config
│ ├── config.go
│ └── config_test.go
├── ui
│ ├── ui_test.go
│ └── ui.go
└── router
│ ├── router_test.go
│ └── router.go
├── cmd
└── main.go
└── README.md
/.dockerignore:
--------------------------------------------------------------------------------
1 | vendor/
2 | .idea/
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | vendor/
3 | *.log
--------------------------------------------------------------------------------
/ui.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mclenhard/catie-mcp/HEAD/ui.png
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/mclenhard/catie-mcp
2 |
3 | go 1.21.2
4 |
5 | require gopkg.in/yaml.v3 v3.0.1
6 |
--------------------------------------------------------------------------------
/go.sum:
--------------------------------------------------------------------------------
1 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
2 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
3 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
4 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
5 |
--------------------------------------------------------------------------------
/prometheus.yml:
--------------------------------------------------------------------------------
1 | global:
2 | scrape_interval: 15s
3 | evaluation_interval: 15s
4 |
5 | scrape_configs:
6 | - job_name: 'mcp-router-proxy'
7 | scrape_interval: 5s
8 | metrics_path: /metrics
9 | # If you're using basic auth, uncomment and configure these lines:
10 | basic_auth:
11 | username: 'admin'
12 | password: 'your_secure_password'
13 | static_configs:
14 | - targets: ['mcp-router-proxy:80']
--------------------------------------------------------------------------------
/.github/workflows/pull-request.yaml:
--------------------------------------------------------------------------------
1 | name: Go Tests
2 |
3 | on:
4 | pull_request:
5 | branches: [ main ]
6 |
7 | jobs:
8 | test:
9 | name: Run Go Tests
10 | runs-on: ubuntu-latest
11 |
12 | steps:
13 | - name: Check out code
14 | uses: actions/checkout@v3
15 |
16 | - name: Set up Go
17 | uses: actions/setup-go@v4
18 | with:
19 | go-version: '1.21' # Adjust this to your Go version
20 |
21 | - name: Install dependencies
22 | run: go mod download
23 |
24 | - name: Run tests
25 | run: go test -v ./...
26 |
--------------------------------------------------------------------------------
/docker-compose.yml:
--------------------------------------------------------------------------------
1 | version: '3.8'
2 |
3 | services:
4 | mcp-router-proxy:
5 | build:
6 | context: .
7 | dockerfile: Dockerfile
8 | container_name: mcp-router-proxy
9 | ports:
10 | - "80:80"
11 | volumes:
12 | - ./router_config.yaml:/root/router_config.yaml
13 | restart: unless-stopped
14 | environment:
15 | - GOTOOLCHAIN=auto
16 | networks:
17 | - mcp-network
18 |
19 | prometheus:
20 | image: prom/prometheus:latest
21 | container_name: prometheus
22 | ports:
23 | - "9090:9090"
24 | volumes:
25 | - ./prometheus.yml:/etc/prometheus/prometheus.yml
26 | restart: unless-stopped
27 | networks:
28 | - mcp-network
29 | depends_on:
30 | - mcp-router-proxy
31 |
32 | networks:
33 | mcp-network:
34 | driver: bridge
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | # Start from a Go base image
2 | FROM golang:1.21-alpine AS builder
3 |
4 | # Set working directory
5 | WORKDIR /app
6 |
7 | # Copy go.mod and go.sum files
8 | COPY go.mod ./
9 | # COPY go.sum ./ # Uncomment if you have a go.sum file
10 |
11 | # Download dependencies
12 | RUN go mod download
13 |
14 | # Copy the source code
15 | COPY . .
16 |
17 | # Build the application
18 | RUN CGO_ENABLED=0 GOOS=linux go build -o mcp-router-proxy ./cmd/main.go
19 |
20 | # Use a minimal alpine image for the final container
21 | FROM alpine:latest
22 |
23 | # Install ca-certificates for HTTPS requests
24 | RUN apk --no-cache add ca-certificates
25 |
26 | WORKDIR /root/
27 |
28 | # Copy the binary from the builder stage
29 | COPY --from=builder /app/mcp-router-proxy .
30 | # Copy the configuration file
31 | COPY router_config.yaml .
32 |
33 | # Expose the port the app runs on
34 | EXPOSE 80
35 |
36 | # Command to run the executable
37 | CMD ["./mcp-router-proxy"]
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 MCP Router Proxy
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/router_config.yaml:
--------------------------------------------------------------------------------
1 | # router_config.yaml
2 | resources:
3 | # Route file resources to file service
4 | "^file://.*": "http://file-service:8080/rpc"
5 | # Route database resources to DB service
6 | "^db://.*": "http://database-service:8080/rpc"
7 | # Route API resources to API gateway
8 | "^api://.*": "http://api-gateway:8080/rpc"
9 | # Route S3 resources to storage service
10 | "^s3://.*": "http://storage-service:8080/rpc"
11 | "^weather/.*": "http://weather-service:8080/mcp"
12 | "^database/.*": "http://database-service:8080/mcp"
13 |
14 | tools:
15 | # Route code generation tools to AI service
16 | "^code-.*": "http://ai-service:8080/rpc"
17 | # Route data processing tools to data service
18 | "^data-.*": "http://data-processing:8080/rpc"
19 | # Route search tools to search service
20 | "^search.*": "http://search-service:8080/rpc"
21 | # Route authentication tools to auth service
22 | "^auth.*": "http://auth-service:8080/rpc"
23 | "^calculator$": "http://calculator-service:8080/mcp"
24 | "^translator$": "http://translator-service:8080/mcp"
25 |
26 | # Default service to route to if no patterns match
27 | default: "http://default-service:8080/mcp"
28 |
29 | # UI authentication settings
30 | ui:
31 | username: "admin"
32 | password: "your_secure_password"
33 |
34 | toolMappings:
35 | - originalName: "weather"
36 | targetName: "getWeather"
37 | target: "http://weather-service:8080/mcp"
38 | - originalName: "search"
39 | targetName: "googleSearch"
40 | target: "http://search-service:8080/mcp"
41 |
--------------------------------------------------------------------------------
/pkg/session/store.go:
--------------------------------------------------------------------------------
1 | package session
2 |
3 | import (
4 | "sync"
5 | "time"
6 | )
7 |
8 | // Store manages mappings between session IDs and target URLs
9 | type Store struct {
10 | sessions map[string]sessionInfo
11 | mu sync.RWMutex
12 | }
13 |
14 | type sessionInfo struct {
15 | target string
16 | lastUsed time.Time
17 | }
18 |
19 | // NewStore creates a new session store
20 | func NewStore() *Store {
21 | store := &Store{
22 | sessions: make(map[string]sessionInfo),
23 | }
24 |
25 | // Start a background goroutine to clean up expired sessions
26 | go store.cleanupLoop()
27 |
28 | return store
29 | }
30 |
31 | // Get retrieves the target URL for a session ID
32 | func (s *Store) Get(sessionID string) (string, bool) {
33 | s.mu.RLock()
34 | info, exists := s.sessions[sessionID]
35 | s.mu.RUnlock()
36 |
37 | if exists {
38 | // Update the last used time
39 | s.mu.Lock()
40 | info.lastUsed = time.Now()
41 | s.sessions[sessionID] = info
42 | s.mu.Unlock()
43 | }
44 |
45 | return info.target, exists
46 | }
47 |
48 | // Set stores a mapping between a session ID and a target URL
49 | func (s *Store) Set(sessionID, target string) {
50 | s.mu.Lock()
51 | defer s.mu.Unlock()
52 |
53 | s.sessions[sessionID] = sessionInfo{
54 | target: target,
55 | lastUsed: time.Now(),
56 | }
57 | }
58 |
59 | // Remove deletes a session mapping
60 | func (s *Store) Remove(sessionID string) {
61 | s.mu.Lock()
62 | defer s.mu.Unlock()
63 |
64 | delete(s.sessions, sessionID)
65 | }
66 |
67 | // cleanupLoop periodically removes expired sessions
68 | func (s *Store) cleanupLoop() {
69 | ticker := time.NewTicker(10 * time.Minute)
70 | defer ticker.Stop()
71 |
72 | for range ticker.C {
73 | s.cleanup()
74 | }
75 | }
76 |
77 | // cleanup removes sessions that haven't been used for a while
78 | func (s *Store) cleanup() {
79 | s.mu.Lock()
80 | defer s.mu.Unlock()
81 |
82 | expireTime := time.Now().Add(-24 * time.Hour) // Sessions expire after 24 hours of inactivity
83 |
84 | for id, info := range s.sessions {
85 | if info.lastUsed.Before(expireTime) {
86 | delete(s.sessions, id)
87 | }
88 | }
89 | }
90 |
--------------------------------------------------------------------------------
/pkg/logger/logger.go:
--------------------------------------------------------------------------------
1 | package logger
2 |
3 | import (
4 | "fmt"
5 | "log"
6 | "os"
7 | "time"
8 | )
9 |
10 | // Level represents the severity level of a log message
11 | type Level int
12 |
13 | const (
14 | // Debug level for detailed information
15 | Debug Level = iota
16 | // Info level for general operational information
17 | Info
18 | // Warn level for warning conditions
19 | Warn
20 | // Error level for error conditions
21 | Error
22 | // Fatal level for fatal conditions
23 | Fatal
24 | )
25 |
26 | var levelNames = map[Level]string{
27 | Debug: "DEBUG",
28 | Info: "INFO",
29 | Warn: "WARN",
30 | Error: "ERROR",
31 | Fatal: "FATAL",
32 | }
33 |
34 | // Logger is a simple structured logger
35 | type Logger struct {
36 | level Level
37 | logger *log.Logger
38 | }
39 |
40 | // New creates a new logger with the specified level
41 | func New(level Level) *Logger {
42 | return &Logger{
43 | level: level,
44 | logger: log.New(os.Stdout, "", 0),
45 | }
46 | }
47 |
48 | // SetLevel changes the logger's level
49 | func (l *Logger) SetLevel(level Level) {
50 | l.level = level
51 | }
52 |
53 | // log logs a message at the specified level
54 | func (l *Logger) log(level Level, format string, args ...interface{}) {
55 | if level < l.level {
56 | return
57 | }
58 |
59 | timestamp := time.Now().Format("2006-01-02 15:04:05.000")
60 | levelName := levelNames[level]
61 | message := fmt.Sprintf(format, args...)
62 |
63 | l.logger.Printf("[%s] %s: %s", timestamp, levelName, message)
64 |
65 | if level == Fatal {
66 | os.Exit(1)
67 | }
68 | }
69 |
70 | // Debug logs a debug message
71 | func (l *Logger) Debug(format string, args ...interface{}) {
72 | l.log(Debug, format, args...)
73 | }
74 |
75 | // Info logs an info message
76 | func (l *Logger) Info(format string, args ...interface{}) {
77 | l.log(Info, format, args...)
78 | }
79 |
80 | // Warn logs a warning message
81 | func (l *Logger) Warn(format string, args ...interface{}) {
82 | l.log(Warn, format, args...)
83 | }
84 |
85 | // Error logs an error message
86 | func (l *Logger) Error(format string, args ...interface{}) {
87 | l.log(Error, format, args...)
88 | }
89 |
90 | // Fatal logs a fatal message and exits
91 | func (l *Logger) Fatal(format string, args ...interface{}) {
92 | l.log(Fatal, format, args...)
93 | }
94 |
--------------------------------------------------------------------------------
/cmd/main.go:
--------------------------------------------------------------------------------
1 | // Main entry point for the MCP router proxy
2 | package main
3 |
4 | import (
5 | "context"
6 | "flag"
7 | "log"
8 | "net/http"
9 | "os"
10 | "os/signal"
11 | "syscall"
12 | "time"
13 |
14 | "github.com/mclenhard/catie-mcp/pkg/config"
15 | "github.com/mclenhard/catie-mcp/pkg/router"
16 | "github.com/mclenhard/catie-mcp/pkg/ui"
17 | )
18 |
19 | var (
20 | configPath = flag.String("config", "router_config.yaml", "Path to configuration file")
21 | serverPort = flag.String("port", ":80", "Server port")
22 | configInterval = flag.Duration("config-interval", 30*time.Second, "Configuration reload interval")
23 | shutdownTimeout = flag.Duration("shutdown-timeout", 30*time.Second, "Graceful shutdown timeout")
24 | )
25 |
26 | func main() {
27 | flag.Parse()
28 |
29 | // Initialize configuration
30 | cfg := config.New(*configPath)
31 |
32 | // Start config watcher in background
33 | go cfg.WatchConfig(*configPath, *configInterval)
34 |
35 | // Initialize UI
36 | uiHandler := ui.New(cfg)
37 |
38 | // Initialize router with UI
39 | r := router.New(cfg, uiHandler)
40 |
41 | // Set up HTTP server
42 | mux := http.NewServeMux()
43 | mux.HandleFunc("/mcp", r.HandleMCPRequest)
44 | mux.HandleFunc("/health", handleHealth)
45 |
46 | // Register UI handlers with our mux
47 | uiHandler.RegisterHandlers(mux)
48 |
49 | server := &http.Server{
50 | Addr: *serverPort,
51 | Handler: mux,
52 | }
53 |
54 | // Start server in a goroutine
55 | go func() {
56 | log.Printf("MCP router proxy listening on %s", *serverPort)
57 | log.Printf("Stats UI available at http://localhost%s/stats", *serverPort)
58 | if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
59 | log.Fatalf("Server error: %v", err)
60 | }
61 | }()
62 |
63 | // Set up graceful shutdown
64 | quit := make(chan os.Signal, 1)
65 | signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
66 | <-quit
67 | log.Println("Shutting down server...")
68 |
69 | // Create a deadline for shutdown
70 | ctx, cancel := context.WithTimeout(context.Background(), *shutdownTimeout)
71 | defer cancel()
72 |
73 | // Attempt graceful shutdown
74 | if err := server.Shutdown(ctx); err != nil {
75 | log.Fatalf("Server forced to shutdown: %v", err)
76 | }
77 |
78 | log.Println("Server exited gracefully")
79 | }
80 |
81 | // handleHealth is a simple health check endpoint
82 | func handleHealth(w http.ResponseWriter, r *http.Request) {
83 | w.WriteHeader(http.StatusOK)
84 | w.Write([]byte("OK"))
85 | }
86 |
--------------------------------------------------------------------------------
/pkg/session/store_test.go:
--------------------------------------------------------------------------------
1 | package session
2 |
3 | import (
4 | "testing"
5 | "time"
6 | )
7 |
8 | func TestStore_SetAndGet(t *testing.T) {
9 | store := NewStore()
10 |
11 | // Test setting and getting a session
12 | sessionID := "test-session-id"
13 | targetURL := "https://example.com"
14 |
15 | store.Set(sessionID, targetURL)
16 |
17 | got, exists := store.Get(sessionID)
18 | if !exists {
19 | t.Errorf("Get(%q) returned exists=false, want true", sessionID)
20 | }
21 | if got != targetURL {
22 | t.Errorf("Get(%q) = %q, want %q", sessionID, got, targetURL)
23 | }
24 | }
25 |
26 | func TestStore_Remove(t *testing.T) {
27 | store := NewStore()
28 |
29 | // Set up a session
30 | sessionID := "test-session-id"
31 | targetURL := "https://example.com"
32 | store.Set(sessionID, targetURL)
33 |
34 | // Verify it exists
35 | _, exists := store.Get(sessionID)
36 | if !exists {
37 | t.Fatalf("Session should exist before removal")
38 | }
39 |
40 | // Remove the session
41 | store.Remove(sessionID)
42 |
43 | // Verify it no longer exists
44 | _, exists = store.Get(sessionID)
45 | if exists {
46 | t.Errorf("Session should not exist after removal")
47 | }
48 | }
49 |
50 | func TestStore_Cleanup(t *testing.T) {
51 | store := NewStore()
52 |
53 | // Set up a session
54 | sessionID := "test-session-id"
55 | targetURL := "https://example.com"
56 | store.Set(sessionID, targetURL)
57 |
58 | // Manually modify the lastUsed time to be older than the expiration time
59 | store.mu.Lock()
60 | info := store.sessions[sessionID]
61 | info.lastUsed = time.Now().Add(-25 * time.Hour) // Older than the 24-hour expiration
62 | store.sessions[sessionID] = info
63 | store.mu.Unlock()
64 |
65 | // Run cleanup
66 | store.cleanup()
67 |
68 | // Verify the session was removed
69 | _, exists := store.Get(sessionID)
70 | if exists {
71 | t.Errorf("Session should have been cleaned up")
72 | }
73 | }
74 |
75 | func TestStore_GetUpdatesLastUsed(t *testing.T) {
76 | store := NewStore()
77 |
78 | // Set up a session
79 | sessionID := "test-session-id"
80 | targetURL := "https://example.com"
81 | store.Set(sessionID, targetURL)
82 |
83 | // Get the initial last used time
84 | store.mu.RLock()
85 | initialTime := store.sessions[sessionID].lastUsed
86 | store.mu.RUnlock()
87 |
88 | // Wait a small amount of time
89 | time.Sleep(10 * time.Millisecond)
90 |
91 | // Get the session, which should update the last used time
92 | store.Get(sessionID)
93 |
94 | // Check that the last used time was updated
95 | store.mu.RLock()
96 | updatedTime := store.sessions[sessionID].lastUsed
97 | store.mu.RUnlock()
98 |
99 | if !updatedTime.After(initialTime) {
100 | t.Errorf("Last used time was not updated: initial=%v, updated=%v", initialTime, updatedTime)
101 | }
102 | }
--------------------------------------------------------------------------------
/pkg/config/config.go:
--------------------------------------------------------------------------------
1 | // Package config provides configuration handling for the MCP router proxy
2 | package config
3 |
4 | import (
5 | "log"
6 | "os"
7 | "regexp"
8 | "sync"
9 | "time"
10 |
11 | "gopkg.in/yaml.v3"
12 | )
13 |
14 | // RouteConfig structure for route mapping
15 | type RouteConfig struct {
16 | Resources map[string]string `yaml:"resources"`
17 | Tools map[string]string `yaml:"tools"`
18 | Default string `yaml:"default"`
19 | UI struct {
20 | Username string `yaml:"username"`
21 | Password string `yaml:"password"`
22 | } `yaml:"ui"`
23 | }
24 |
25 | // RouteRule represents a compiled regex pattern and its target
26 | type RouteRule struct {
27 | Pattern *regexp.Regexp
28 | Target string
29 | }
30 |
31 | // ToolMapping structure for tool mapping
32 | type ToolMapping struct {
33 | OriginalName string `yaml:"originalName"`
34 | TargetName string `yaml:"targetName"`
35 | Target string `yaml:"target"`
36 | }
37 |
38 | // Config holds the application configuration and related data
39 | type Config struct {
40 | RouteConfig RouteConfig
41 | ResourceRegexes []RouteRule
42 | ToolRegexes []RouteRule
43 | ToolMappings []ToolMapping
44 | ConfigMutex sync.RWMutex
45 | }
46 |
47 | // New creates a new Config instance and loads the initial configuration
48 | func New(configPath string) *Config {
49 | c := &Config{}
50 | c.Load(configPath)
51 | return c
52 | }
53 |
54 | // Load reads and parses the configuration file
55 | func (c *Config) Load(configPath string) {
56 | configData, err := os.ReadFile(configPath)
57 | if err != nil {
58 | log.Fatalf("Failed to load config: %v", err)
59 | }
60 |
61 | var newConfig RouteConfig
62 | if err := yaml.Unmarshal(configData, &newConfig); err != nil {
63 | log.Fatalf("Invalid config format: %v", err)
64 | }
65 |
66 | var newResources []RouteRule
67 | for pattern, url := range newConfig.Resources {
68 | r, err := regexp.Compile(pattern)
69 | if err != nil {
70 | log.Fatalf("Invalid resource regex pattern '%s': %v", pattern, err)
71 | }
72 | newResources = append(newResources, RouteRule{r, url})
73 | }
74 |
75 | var newTools []RouteRule
76 | for pattern, url := range newConfig.Tools {
77 | r, err := regexp.Compile(pattern)
78 | if err != nil {
79 | log.Fatalf("Invalid tool regex pattern '%s': %v", pattern, err)
80 | }
81 | newTools = append(newTools, RouteRule{r, url})
82 | }
83 |
84 | c.ConfigMutex.Lock()
85 | c.RouteConfig = newConfig
86 | c.ResourceRegexes = newResources
87 | c.ToolRegexes = newTools
88 | c.ConfigMutex.Unlock()
89 |
90 | log.Println("Router config loaded")
91 | }
92 |
93 | // WatchConfig monitors the config file for changes and reloads when necessary
94 | func (c *Config) WatchConfig(filename string, interval time.Duration) {
95 | lastMod := time.Time{}
96 | for {
97 | info, err := os.Stat(filename)
98 | if err == nil && info.ModTime().After(lastMod) {
99 | lastMod = info.ModTime()
100 | c.Load(filename)
101 | log.Println("Router config reloaded")
102 | }
103 | time.Sleep(interval)
104 | }
105 | }
106 |
107 | // GetDefault returns the default route
108 | func (c *Config) GetDefault() string {
109 | c.ConfigMutex.RLock()
110 | defer c.ConfigMutex.RUnlock()
111 | return c.RouteConfig.Default
112 | }
113 |
114 | // GetResourceRegexes returns a copy of the resource regexes with read lock
115 | func (c *Config) GetResourceRegexes() []RouteRule {
116 | c.ConfigMutex.RLock()
117 | defer c.ConfigMutex.RUnlock()
118 | return c.ResourceRegexes
119 | }
120 |
121 | // GetToolRegexes returns a copy of the tool regexes with read lock
122 | func (c *Config) GetToolRegexes() []RouteRule {
123 | c.ConfigMutex.RLock()
124 | defer c.ConfigMutex.RUnlock()
125 | return c.ToolRegexes
126 | }
127 |
128 | // GetAllTargets returns all configured target URLs
129 | func (c *Config) GetAllTargets() []string {
130 | c.ConfigMutex.RLock()
131 | defer c.ConfigMutex.RUnlock()
132 |
133 | targets := make([]string, 0)
134 |
135 | // Add the default target
136 | if c.RouteConfig.Default != "" {
137 | targets = append(targets, c.RouteConfig.Default)
138 | }
139 |
140 | // Add targets from resource rules
141 | for _, rule := range c.ResourceRegexes {
142 | if !contains(targets, rule.Target) {
143 | targets = append(targets, rule.Target)
144 | }
145 | }
146 |
147 | // Add targets from tool rules
148 | for _, rule := range c.ToolRegexes {
149 | if !contains(targets, rule.Target) {
150 | targets = append(targets, rule.Target)
151 | }
152 | }
153 |
154 | return targets
155 | }
156 |
157 | // Helper function to check if a slice contains a string
158 | func contains(slice []string, item string) bool {
159 | for _, s := range slice {
160 | if s == item {
161 | return true
162 | }
163 | }
164 | return false
165 | }
166 |
167 | // GetToolMappings returns a copy of the tool mappings with read lock
168 | func (c *Config) GetToolMappings() []ToolMapping {
169 | c.ConfigMutex.RLock()
170 | defer c.ConfigMutex.RUnlock()
171 | return c.ToolMappings
172 | }
173 |
--------------------------------------------------------------------------------
/pkg/config/config_test.go:
--------------------------------------------------------------------------------
1 | package config
2 |
3 | import (
4 | "os"
5 | "path/filepath"
6 | "testing"
7 | "time"
8 | )
9 |
10 | func TestNew(t *testing.T) {
11 | // Create a temporary config file
12 | tempDir := t.TempDir()
13 | configPath := filepath.Join(tempDir, "config.yaml")
14 |
15 | configContent := `
16 | default: http://default-service:8080
17 | resources:
18 | "^/resource/([^/]+)$": "http://resource-service:8080"
19 | tools:
20 | "^/tool/([^/]+)$": "http://tool-service:8080"
21 | `
22 | if err := os.WriteFile(configPath, []byte(configContent), 0644); err != nil {
23 | t.Fatalf("Failed to write test config: %v", err)
24 | }
25 |
26 | // Test creating a new config
27 | cfg := New(configPath)
28 |
29 | // Verify the config was loaded correctly
30 | if cfg.GetDefault() != "http://default-service:8080" {
31 | t.Errorf("Expected default to be 'http://default-service:8080', got '%s'", cfg.GetDefault())
32 | }
33 |
34 | if len(cfg.GetResourceRegexes()) != 1 {
35 | t.Errorf("Expected 1 resource regex, got %d", len(cfg.GetResourceRegexes()))
36 | }
37 |
38 | if len(cfg.GetToolRegexes()) != 1 {
39 | t.Errorf("Expected 1 tool regex, got %d", len(cfg.GetToolRegexes()))
40 | }
41 | }
42 |
43 | func TestLoad(t *testing.T) {
44 | // Table-driven test
45 | tests := []struct {
46 | name string
47 | configContent string
48 | expectedDefault string
49 | expectedResourceCount int
50 | expectedToolCount int
51 | }{
52 | {
53 | name: "basic config",
54 | configContent: `
55 | default: http://default-service:8080
56 | resources:
57 | "^/resource/([^/]+)$": "http://resource-service:8080"
58 | tools:
59 | "^/tool/([^/]+)$": "http://tool-service:8080"
60 | `,
61 | expectedDefault: "http://default-service:8080",
62 | expectedResourceCount: 1,
63 | expectedToolCount: 1,
64 | },
65 | {
66 | name: "multiple resources and tools",
67 | configContent: `
68 | default: http://default-service:8080
69 | resources:
70 | "^/resource/([^/]+)$": "http://resource-service:8080"
71 | "^/api/v1/([^/]+)$": "http://api-service:8080"
72 | tools:
73 | "^/tool/([^/]+)$": "http://tool-service:8080"
74 | "^/utility/([^/]+)$": "http://utility-service:8080"
75 | `,
76 | expectedDefault: "http://default-service:8080",
77 | expectedResourceCount: 2,
78 | expectedToolCount: 2,
79 | },
80 | {
81 | name: "no default",
82 | configContent: `
83 | resources:
84 | "^/resource/([^/]+)$": "http://resource-service:8080"
85 | tools:
86 | "^/tool/([^/]+)$": "http://tool-service:8080"
87 | `,
88 | expectedDefault: "",
89 | expectedResourceCount: 1,
90 | expectedToolCount: 1,
91 | },
92 | }
93 |
94 | for _, tt := range tests {
95 | t.Run(tt.name, func(t *testing.T) {
96 | // Create temp file with test content
97 | tempDir := t.TempDir()
98 | configPath := filepath.Join(tempDir, "config.yaml")
99 | if err := os.WriteFile(configPath, []byte(tt.configContent), 0644); err != nil {
100 | t.Fatalf("Failed to write test config: %v", err)
101 | }
102 |
103 | // Create config and test loading
104 | cfg := &Config{}
105 | cfg.Load(configPath)
106 |
107 | if cfg.GetDefault() != tt.expectedDefault {
108 | t.Errorf("Expected default to be '%s', got '%s'", tt.expectedDefault, cfg.GetDefault())
109 | }
110 |
111 | if len(cfg.GetResourceRegexes()) != tt.expectedResourceCount {
112 | t.Errorf("Expected %d resource regexes, got %d", tt.expectedResourceCount, len(cfg.GetResourceRegexes()))
113 | }
114 |
115 | if len(cfg.GetToolRegexes()) != tt.expectedToolCount {
116 | t.Errorf("Expected %d tool regexes, got %d", tt.expectedToolCount, len(cfg.GetToolRegexes()))
117 | }
118 | })
119 | }
120 | }
121 |
122 | func TestGetAllTargets(t *testing.T) {
123 | // Create a config with known values
124 | cfg := &Config{
125 | RouteConfig: RouteConfig{
126 | Default: "http://default:8080",
127 | Resources: map[string]string{
128 | "pattern1": "http://resource1:8080",
129 | "pattern2": "http://resource2:8080",
130 | },
131 | Tools: map[string]string{
132 | "pattern3": "http://tool1:8080",
133 | "pattern4": "http://resource1:8080", // Duplicate target
134 | },
135 | },
136 | }
137 |
138 | // Manually set up the regexes (normally done by Load)
139 | cfg.ResourceRegexes = []RouteRule{
140 | {Pattern: nil, Target: "http://resource1:8080"},
141 | {Pattern: nil, Target: "http://resource2:8080"},
142 | }
143 |
144 | cfg.ToolRegexes = []RouteRule{
145 | {Pattern: nil, Target: "http://tool1:8080"},
146 | {Pattern: nil, Target: "http://resource1:8080"},
147 | }
148 |
149 | targets := cfg.GetAllTargets()
150 |
151 | // Update the expected number of unique targets
152 | expectedTargets := 4 // default + resource1 + resource2 + tool1
153 | if len(targets) != expectedTargets {
154 | t.Errorf("Expected %d unique targets, got %d: %v", expectedTargets, len(targets), targets)
155 | }
156 |
157 | // Check that all expected targets are present
158 | expectedURLs := []string{
159 | "http://default:8080",
160 | "http://resource1:8080",
161 | "http://resource2:8080",
162 | "http://tool1:8080",
163 | }
164 |
165 | for _, url := range expectedURLs[:expectedTargets] {
166 | if !contains(targets, url) {
167 | t.Errorf("Expected target '%s' not found in results: %v", url, targets)
168 | }
169 | }
170 | }
171 |
172 | func TestWatchConfig(t *testing.T) {
173 | // Create a temporary config file
174 | tempDir := t.TempDir()
175 | configPath := filepath.Join(tempDir, "config.yaml")
176 |
177 | initialConfig := `
178 | default: http://initial-default:8080
179 | resources:
180 | "^/resource/([^/]+)$": "http://initial-resource:8080"
181 | tools:
182 | "^/tool/([^/]+)$": "http://initial-tool:8080"
183 | `
184 | if err := os.WriteFile(configPath, []byte(initialConfig), 0644); err != nil {
185 | t.Fatalf("Failed to write initial test config: %v", err)
186 | }
187 |
188 | // Create config
189 | cfg := New(configPath)
190 |
191 | // Start watching in a goroutine with a short interval
192 | go cfg.WatchConfig(configPath, 100*time.Millisecond)
193 |
194 | // Verify initial config
195 | if cfg.GetDefault() != "http://initial-default:8080" {
196 | t.Errorf("Initial default incorrect, got '%s'", cfg.GetDefault())
197 | }
198 |
199 | // Wait a moment to ensure watcher is running
200 | time.Sleep(200 * time.Millisecond)
201 |
202 | // Update the config file
203 | updatedConfig := `
204 | default: http://updated-default:8080
205 | resources:
206 | "^/resource/([^/]+)$": "http://updated-resource:8080"
207 | "^/new-resource/([^/]+)$": "http://new-resource:8080"
208 | tools:
209 | "^/tool/([^/]+)$": "http://updated-tool:8080"
210 | `
211 | if err := os.WriteFile(configPath, []byte(updatedConfig), 0644); err != nil {
212 | t.Fatalf("Failed to write updated test config: %v", err)
213 | }
214 |
215 | // Wait for the watcher to detect the change
216 | time.Sleep(300 * time.Millisecond)
217 |
218 | // Verify config was updated
219 | if cfg.GetDefault() != "http://updated-default:8080" {
220 | t.Errorf("Updated default incorrect, got '%s'", cfg.GetDefault())
221 | }
222 |
223 | if len(cfg.GetResourceRegexes()) != 2 {
224 | t.Errorf("Expected 2 resource regexes after update, got %d", len(cfg.GetResourceRegexes()))
225 | }
226 | }
227 |
--------------------------------------------------------------------------------
/pkg/ui/ui_test.go:
--------------------------------------------------------------------------------
1 | package ui
2 |
3 | import (
4 | "encoding/base64"
5 | "encoding/json"
6 | "net/http"
7 | "net/http/httptest"
8 | "strings"
9 | "testing"
10 | "time"
11 |
12 | "github.com/mclenhard/catie-mcp/pkg/config"
13 | )
14 |
15 | func TestRecordRequest(t *testing.T) {
16 | stats := &Stats{
17 | RequestsByMethod: make(map[string]int64),
18 | RequestsByEndpoint: make(map[string]int64),
19 | ResponseTimes: make(map[string][]int),
20 | StartTime: time.Now(),
21 | }
22 |
23 | // Record some test requests
24 | stats.RecordRequest("GET", "/api/users", 150*time.Millisecond, false)
25 | stats.RecordRequest("POST", "/api/users", 200*time.Millisecond, false)
26 | stats.RecordRequest("GET", "/api/products", 100*time.Millisecond, true)
27 |
28 | // Verify stats were recorded correctly
29 | if stats.TotalRequests != 3 {
30 | t.Errorf("Expected 3 total requests, got %d", stats.TotalRequests)
31 | }
32 |
33 | if stats.RequestsByMethod["GET"] != 2 {
34 | t.Errorf("Expected 2 GET requests, got %d", stats.RequestsByMethod["GET"])
35 | }
36 |
37 | if stats.RequestsByMethod["POST"] != 1 {
38 | t.Errorf("Expected 1 POST request, got %d", stats.RequestsByMethod["POST"])
39 | }
40 |
41 | if stats.RequestsByEndpoint["/api/users"] != 2 {
42 | t.Errorf("Expected 2 /api/users requests, got %d", stats.RequestsByEndpoint["/api/users"])
43 | }
44 |
45 | if stats.ErrorCount != 1 {
46 | t.Errorf("Expected 1 error, got %d", stats.ErrorCount)
47 | }
48 |
49 | // Check response times
50 | if len(stats.ResponseTimes["GET"]) != 2 {
51 | t.Errorf("Expected 2 response times for GET, got %d", len(stats.ResponseTimes["GET"]))
52 | }
53 | }
54 |
55 | func TestGetStats(t *testing.T) {
56 | stats := &Stats{
57 | TotalRequests: 10,
58 | RequestsByMethod: map[string]int64{"GET": 7, "POST": 3},
59 | RequestsByEndpoint: map[string]int64{"/api/users": 5, "/api/products": 5},
60 | ResponseTimes: map[string][]int{"GET": {100, 150}, "POST": {200}},
61 | ErrorCount: 2,
62 | StartTime: time.Now(),
63 | }
64 |
65 | // Get a copy of the stats
66 | statsCopy := stats.GetStats()
67 |
68 | // Verify the copy has the correct values
69 | if statsCopy.TotalRequests != 10 {
70 | t.Errorf("Expected 10 total requests in copy, got %d", statsCopy.TotalRequests)
71 | }
72 |
73 | // Modify the original stats
74 | stats.TotalRequests = 15
75 | stats.RequestsByMethod["GET"] = 10
76 |
77 | // Verify the copy wasn't affected
78 | if statsCopy.TotalRequests != 10 {
79 | t.Errorf("Stats copy was affected by changes to original")
80 | }
81 |
82 | if statsCopy.RequestsByMethod["GET"] != 7 {
83 | t.Errorf("Stats copy was affected by changes to original")
84 | }
85 | }
86 |
87 | func TestHandleStats(t *testing.T) {
88 | // Create a UI instance with test data
89 | cfg := &config.Config{
90 | RouteConfig: config.RouteConfig{
91 | UI: struct {
92 | Username string `yaml:"username"`
93 | Password string `yaml:"password"`
94 | }{},
95 | },
96 | }
97 | ui := New(cfg)
98 |
99 | // Add some test data
100 | ui.Stats.RecordRequest("GET", "/api/users", 150*time.Millisecond, false)
101 | ui.Stats.RecordRequest("POST", "/api/users", 200*time.Millisecond, false)
102 |
103 | // Test HTML response
104 | req, _ := http.NewRequest("GET", "/stats", nil)
105 | rr := httptest.NewRecorder()
106 |
107 | ui.HandleStats(rr, req)
108 |
109 | if status := rr.Code; status != http.StatusOK {
110 | t.Errorf("Handler returned wrong status code: got %v want %v", status, http.StatusOK)
111 | }
112 |
113 | if ctype := rr.Header().Get("Content-Type"); ctype != "text/html" {
114 | t.Errorf("Content-Type header does not match: got %v want %v", ctype, "text/html")
115 | }
116 |
117 | // Test JSON response
118 | req, _ = http.NewRequest("GET", "/stats", nil)
119 | req.Header.Set("Accept", "application/json")
120 | rr = httptest.NewRecorder()
121 |
122 | ui.HandleStats(rr, req)
123 |
124 | if status := rr.Code; status != http.StatusOK {
125 | t.Errorf("Handler returned wrong status code: got %v want %v", status, http.StatusOK)
126 | }
127 |
128 | if ctype := rr.Header().Get("Content-Type"); ctype != "application/json" {
129 | t.Errorf("Content-Type header does not match: got %v want %v", ctype, "application/json")
130 | }
131 |
132 | // Verify we can parse the JSON response
133 | var data map[string]interface{}
134 | if err := json.Unmarshal(rr.Body.Bytes(), &data); err != nil {
135 | t.Errorf("Failed to parse JSON response: %v", err)
136 | }
137 | }
138 |
139 | func TestHandleMetrics(t *testing.T) {
140 | // Create a UI instance with test data
141 | cfg := &config.Config{
142 | RouteConfig: config.RouteConfig{
143 | UI: struct {
144 | Username string `yaml:"username"`
145 | Password string `yaml:"password"`
146 | }{},
147 | },
148 | }
149 | ui := New(cfg)
150 |
151 | // Add some test data
152 | ui.Stats.RecordRequest("GET", "/api/users", 150*time.Millisecond, false)
153 | ui.Stats.RecordRequest("POST", "/api/users", 200*time.Millisecond, true)
154 |
155 | // Test metrics response
156 | req, _ := http.NewRequest("GET", "/metrics", nil)
157 | rr := httptest.NewRecorder()
158 |
159 | ui.HandleMetrics(rr, req)
160 |
161 | if status := rr.Code; status != http.StatusOK {
162 | t.Errorf("Handler returned wrong status code: got %v want %v", status, http.StatusOK)
163 | }
164 |
165 | if ctype := rr.Header().Get("Content-Type"); ctype != "text/plain" {
166 | t.Errorf("Content-Type header does not match: got %v want %v", ctype, "text/plain")
167 | }
168 |
169 | // Check for expected metrics in the response
170 | body := rr.Body.String()
171 | expectedMetrics := []string{
172 | "mcp_router_requests_total 2",
173 | "mcp_router_errors_total 1",
174 | "mcp_router_requests_by_method{method=\"GET\"} 1",
175 | "mcp_router_requests_by_method{method=\"POST\"} 1",
176 | "mcp_router_requests_by_endpoint{endpoint=\"/api/users\"} 2",
177 | }
178 |
179 | for _, metric := range expectedMetrics {
180 | if !strings.Contains(body, metric) {
181 | t.Errorf("Expected metric not found in response: %s", metric)
182 | }
183 | }
184 | }
185 |
186 | func TestBasicAuth(t *testing.T) {
187 | // Test with auth credentials
188 | cfg := &config.Config{
189 | RouteConfig: config.RouteConfig{
190 | UI: struct {
191 | Username string `yaml:"username"`
192 | Password string `yaml:"password"`
193 | }{
194 | Username: "admin",
195 | Password: "secret",
196 | },
197 | },
198 | }
199 | ui := New(cfg)
200 |
201 | // Create a simple handler for testing
202 | testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
203 | w.WriteHeader(http.StatusOK)
204 | w.Write([]byte("Success"))
205 | })
206 |
207 | // Wrap it with basic auth
208 | authHandler := ui.basicAuth(testHandler)
209 |
210 | // Test with no credentials
211 | req, _ := http.NewRequest("GET", "/stats", nil)
212 | rr := httptest.NewRecorder()
213 |
214 | authHandler(rr, req)
215 |
216 | if status := rr.Code; status != http.StatusUnauthorized {
217 | t.Errorf("Handler should return 401 without credentials, got: %v", status)
218 | }
219 |
220 | // Test with invalid credentials
221 | req, _ = http.NewRequest("GET", "/stats", nil)
222 | req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("admin:wrongpass")))
223 | rr = httptest.NewRecorder()
224 |
225 | authHandler(rr, req)
226 |
227 | if status := rr.Code; status != http.StatusUnauthorized {
228 | t.Errorf("Handler should return 401 with invalid credentials, got: %v", status)
229 | }
230 |
231 | // Test with valid credentials
232 | req, _ = http.NewRequest("GET", "/stats", nil)
233 | req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("admin:secret")))
234 | rr = httptest.NewRecorder()
235 |
236 | authHandler(rr, req)
237 |
238 | if status := rr.Code; status != http.StatusOK {
239 | t.Errorf("Handler should return 200 with valid credentials, got: %v", status)
240 | }
241 |
242 | // Test with no auth required
243 | cfg = &config.Config{
244 | RouteConfig: config.RouteConfig{
245 | UI: struct {
246 | Username string `yaml:"username"`
247 | Password string `yaml:"password"`
248 | }{},
249 | },
250 | }
251 | ui = New(cfg)
252 |
253 | authHandler = ui.basicAuth(testHandler)
254 | req, _ = http.NewRequest("GET", "/stats", nil)
255 | rr = httptest.NewRecorder()
256 |
257 | authHandler(rr, req)
258 |
259 | if status := rr.Code; status != http.StatusOK {
260 | t.Errorf("Handler should return 200 when no auth required, got: %v", status)
261 | }
262 | }
263 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MCP Catie - Context Aware Traffic Ingress Engine
2 |
3 | A lightweight, configurable reverse proxy for routing and load balancing MCP (Model Context Protocol) requests to appropriate backend services based on request content.
4 |
5 | For detailed documentation, visit [Catie MCP Documentation](https://www.catiemcp.com/docs/).
6 |
7 | ## Features
8 |
9 | - Dynamic routing of MCP JSON-RPC requests based on tool call
10 | - Concatenate tools from multiple different MCP Servers so the client gets a unified view of all tools without the user installing multiple servers.
11 | - Session-aware routing to maintain client connections to the same backend
12 | - Support for Streamable HTTP transport with SSE (Server-Sent Events)
13 | - Tool name mapping and namespacing to resolve naming conflicts between different backends
14 | - Prometheus metrics integration for observability
15 | - Containerized deployment with Docker
16 | - Basic authentication for monitoring UI
17 |
18 | ## Architecture
19 |
20 | The application is structured into several packages:
21 |
22 | - `cmd/main.go` - Application entry point with server setup
23 | - `pkg/config` - Configuration loading and management
24 | - `pkg/router` - Request routing and proxy logic
25 | - `pkg/session` - Session management for maintaining client connections
26 | - `pkg/logger` - Structured logging system
27 | - `pkg/ui` - Simple web UI for monitoring
28 |
29 | ## Configuration
30 |
31 | The router is configured using a YAML file (`router_config.yaml`). Here's an example configuration:
32 |
33 | ```yaml
34 | resources:
35 | "^weather/.*": "http://weather-service:8080/mcp"
36 | "^database/.*": "http://database-service:8080/mcp"
37 | tools:
38 | "^calculator$": "http://calculator-service:8080/mcp"
39 | "^translator$": "http://translator-service:8080/mcp"
40 | toolMappings:
41 | - originalName: "weather"
42 | targetName: "getWeather"
43 | target: "http://weather-service:8080/mcp"
44 | - originalName: "search"
45 | targetName: "googleSearch"
46 | target: "http://search-service:8080/mcp"
47 | default: "http://default-service:8080/mcp"
48 | ui:
49 | username: "admin"
50 | password: "your_secure_password"
51 | ```
52 |
53 | The configuration consists of:
54 |
55 | - `resources`: Regex patterns for resource URIs and their target endpoints
56 | - `tools`: Regex patterns for tool names and their target endpoints
57 | - `toolMappings`: Mappings for tool name transformations to resolve naming conflicts
58 | - `default`: Fallback endpoint for requests that don't match any pattern
59 | - `ui`: Authentication credentials for the monitoring UI
60 |
61 | The configuration file is automatically reloaded when changes are detected.
62 |
63 | ## Tool Name Mapping
64 |
65 | The tool name mapping feature allows you to present a unified tool interface to clients while handling naming differences across backend MCP servers. This is useful when:
66 |
67 | - Different backends use different names for similar functionality
68 | - You want to present a simplified or standardized naming scheme to clients
69 | - You need to avoid naming conflicts between tools from different backends
70 |
71 | For each tool mapping, specify:
72 | - `originalName`: The name presented to clients
73 | - `targetName`: The actual name expected by the backend server
74 | - `target`: The URL of the target backend server
75 |
76 | When a client makes a tool call with the original name, MCProute automatically transforms it to the target name before forwarding the request to the appropriate backend.
77 |
78 | ## Installation
79 |
80 | ### Prerequisites
81 |
82 | - Go 1.18 or higher
83 | - Docker (optional, for containerized deployment)
84 |
85 | ### Building from Source
86 |
87 | 1. Clone the repository:
88 | ```bash
89 | git clone https://github.com/mclenhard/mcp-router-proxy.git
90 | cd mcp-router-proxy
91 | ```
92 |
93 | 2. Build the application:
94 | ```bash
95 | go build -o mcp-router-proxy ./cmd/main.go
96 | ```
97 |
98 | 3. Edit router_config.yaml to match your environment
99 |
100 | 4. Run the application:
101 | ```bash
102 | ./mcp-router-proxy
103 | ```
104 |
105 | ### Using Docker
106 |
107 | 1. Build the Docker image:
108 | ```bash
109 | docker build -t mcp-router-proxy .
110 | ```
111 |
112 | 2. Run the container:
113 | ```bash
114 | docker run -p 80:80 -v $(pwd)/router_config.yaml:/root/router_config.yaml mcp-router-proxy
115 | ```
116 |
117 | ## Usage
118 |
119 | The proxy listens for MCP requests on the `/mcp` endpoint. Requests are routed based on their method and parameters:
120 |
121 | - `resources/read` requests are routed based on the `uri` parameter
122 | - `tools/call` requests are routed based on the `name` parameter
123 | - Other requests are sent to the default endpoint
124 |
125 | The proxy supports both GET and POST methods according to the MCP Streamable HTTP transport specification:
126 |
127 | - POST requests are used to send JSON-RPC messages to the server
128 | - GET requests are used to establish SSE streams for server-to-client communication
129 |
130 | ### Session Management
131 |
132 | The proxy maintains session state by tracking the `Mcp-Session-Id` header. When a client establishes a session with an MCP server through the proxy, subsequent requests with the same session ID are routed to the same backend server.
133 |
134 | ### Health Check
135 |
136 | A health check endpoint is available at `/health` which returns a 200 OK response when the service is running.
137 |
138 | ### Monitoring
139 |
140 | A simple monitoring UI is available at `/stats` which shows request statistics and routing information. This interface is protected by basic authentication using the credentials specified in the configuration file.
141 |
142 | 
143 |
144 | ### Prometheus Integration
145 |
146 | The service exposes Prometheus-compatible metrics at the `/metrics` endpoint. These metrics include:
147 |
148 | - `mcp_router_requests_total`: Total number of requests processed
149 | - `mcp_router_errors_total`: Total number of request errors
150 | - `mcp_router_requests_by_method`: Number of requests broken down by method
151 | - `mcp_router_requests_by_endpoint`: Number of requests broken down by target endpoint
152 | - `mcp_router_response_time_ms`: Average response time in milliseconds by method
153 | - `mcp_router_uptime_seconds`: Time since the router started in seconds
154 |
155 | You can configure Prometheus to scrape these metrics by adding the following to your Prometheus configuration:
156 |
157 | ```yaml
158 | scrape_configs:
159 | - job_name: 'mcp-router'
160 | scrape_interval: 15s
161 | static_configs:
162 | - targets: ['your-router-host:80']
163 | ```
164 |
165 | This endpoint is also protected by the same basic authentication as the stats UI.
166 |
167 | ## Development
168 |
169 | ### Project Structure
170 |
171 | ```
172 | mcp-router-proxy/
173 | ├── cmd/
174 | │ └── main.go
175 | ├── pkg/
176 | │ ├── config/
177 | │ │ └── config.go
178 | │ ├── router/
179 | │ │ └── router.go
180 | │ ├── session/
181 | │ │ └── store.go
182 | │ ├── logger/
183 | │ │ └── logger.go
184 | │ └── ui/
185 | │ └── ui.go
186 | ├── Dockerfile
187 | ├── go.mod
188 | ├── go.sum
189 | ├── README.md
190 | └── router_config.yaml
191 | ```
192 |
193 | ### Adding New Features
194 |
195 | 1. Fork the repository
196 | 2. Create a feature branch
197 | 3. Add your changes
198 | 4. Submit a pull request
199 |
200 | ## Roadmap
201 |
202 | The following features are planned for upcoming releases:
203 | - ~~**Add SSE Support**: Add support for SSE (Server-Sent Events) to the proxy~~
204 | - **Complete Message Forwarding**: Ensure all MCP message types (including roots and sampling) are properly forwarded without interference
205 | - **Intelligent Caching**: Response caching with configurable TTL, cache invalidation, and support for memory and Redis backends
206 | - **Rate Limiting**: Configurable rate limiting with multiple strategies, response headers, and distributed rate limiting support
207 | - **Circuit Breaking**: Automatic detection of backend failures with fallback responses
208 | - **Request Transformation**: Modify requests before forwarding to backends
209 | - **Response Transformation**: Transform backend responses before returning to clients
210 |
211 | Development priorities are based on community feedback. Please open an issue to request features or contribute to the roadmap discussion.
212 |
213 | ## License
214 |
215 | [MIT License](LICENSE)
216 |
217 | ## Contributing
218 |
219 | Contributions are welcome! Please feel free to submit a Pull Request.
220 |
221 | ## Support
222 |
223 | For support, please open an issue in the GitHub repository or contact me at [mclenhard@gmail.com](mailto:mclenhard@gmail.com).
--------------------------------------------------------------------------------
/pkg/router/router_test.go:
--------------------------------------------------------------------------------
1 | package router
2 |
3 | import (
4 | "bytes"
5 | "encoding/json"
6 | "net/http"
7 | "net/http/httptest"
8 | "regexp"
9 | "testing"
10 |
11 | "github.com/mclenhard/catie-mcp/pkg/config"
12 | "github.com/mclenhard/catie-mcp/pkg/logger"
13 | "github.com/mclenhard/catie-mcp/pkg/session"
14 | "github.com/mclenhard/catie-mcp/pkg/ui"
15 | )
16 |
17 | // MockConfig implements a simple config for testing
18 | type MockConfig struct {
19 | DefaultURL string
20 | ResourceRegexes []config.RouteRule
21 | ToolRegexes []config.RouteRule
22 | AllTargets []string
23 | }
24 |
25 | func (m *MockConfig) GetDefault() string {
26 | return m.DefaultURL
27 | }
28 |
29 | func (m *MockConfig) GetResourceRegexes() []config.RouteRule {
30 | return m.ResourceRegexes
31 | }
32 |
33 | func (m *MockConfig) GetToolRegexes() []config.RouteRule {
34 | return m.ToolRegexes
35 | }
36 |
37 | func (m *MockConfig) GetAllTargets() []string {
38 | return m.AllTargets
39 | }
40 |
41 | func (m *MockConfig) GetToolMappings() []config.ToolMapping {
42 | // Return an empty map for testing purposes
43 | return []config.ToolMapping{}
44 | }
45 |
46 | func TestRouteByContext(t *testing.T) {
47 | // Create a mock config
48 | mockConfig := &MockConfig{
49 | DefaultURL: "http://default:8080",
50 | ResourceRegexes: []config.RouteRule{
51 | {Pattern: compileRegex(t, "^/resource/([^/]+)$"), Target: "http://resource:8080"},
52 | {Pattern: compileRegex(t, "^/api/v1/([^/]+)$"), Target: "http://api:8080"},
53 | },
54 | ToolRegexes: []config.RouteRule{
55 | {Pattern: compileRegex(t, "^/tool/([^/]+)$"), Target: "http://tool:8080"},
56 | {Pattern: compileRegex(t, "^/utility/([^/]+)$"), Target: "http://utility:8080"},
57 | },
58 | AllTargets: []string{
59 | "http://default:8080",
60 | "http://resource:8080",
61 | "http://api:8080",
62 | "http://tool:8080",
63 | "http://utility:8080",
64 | },
65 | }
66 |
67 | // Create a router with the mock config
68 | r := &Router{
69 | Config: mockConfig,
70 | Logger: logger.New(logger.Debug),
71 | }
72 |
73 | // Test cases
74 | tests := []struct {
75 | name string
76 | request JSONRPCRequest
77 | expectedTarget string
78 | }{
79 | {
80 | name: "resources/read with matching URI",
81 | request: JSONRPCRequest{
82 | Method: "resources/read",
83 | Params: map[string]interface{}{
84 | "uri": "/resource/test",
85 | },
86 | },
87 | expectedTarget: "http://resource:8080",
88 | },
89 | {
90 | name: "resources/read with different matching URI",
91 | request: JSONRPCRequest{
92 | Method: "resources/read",
93 | Params: map[string]interface{}{
94 | "uri": "/api/v1/test",
95 | },
96 | },
97 | expectedTarget: "http://api:8080",
98 | },
99 | {
100 | name: "resources/read with non-matching URI",
101 | request: JSONRPCRequest{
102 | Method: "resources/read",
103 | Params: map[string]interface{}{
104 | "uri": "/nonmatching/test",
105 | },
106 | },
107 | expectedTarget: "http://default:8080",
108 | },
109 | {
110 | name: "tools/call with matching name",
111 | request: JSONRPCRequest{
112 | Method: "tools/call",
113 | Params: map[string]interface{}{
114 | "name": "/tool/test",
115 | },
116 | },
117 | expectedTarget: "http://tool:8080",
118 | },
119 | {
120 | name: "tools/call with different matching name",
121 | request: JSONRPCRequest{
122 | Method: "tools/call",
123 | Params: map[string]interface{}{
124 | "name": "/utility/test",
125 | },
126 | },
127 | expectedTarget: "http://utility:8080",
128 | },
129 | {
130 | name: "tools/call with non-matching name",
131 | request: JSONRPCRequest{
132 | Method: "tools/call",
133 | Params: map[string]interface{}{
134 | "name": "/nonmatching/test",
135 | },
136 | },
137 | expectedTarget: "http://default:8080",
138 | },
139 | {
140 | name: "unknown method",
141 | request: JSONRPCRequest{
142 | Method: "unknown/method",
143 | Params: map[string]interface{}{},
144 | },
145 | expectedTarget: "http://default:8080",
146 | },
147 | }
148 |
149 | for _, tt := range tests {
150 | t.Run(tt.name, func(t *testing.T) {
151 | target := r.RouteByContext(tt.request)
152 | if target != tt.expectedTarget {
153 | t.Errorf("Expected target %s, got %s", tt.expectedTarget, target)
154 | }
155 | })
156 | }
157 | }
158 |
159 | func TestDetermineTargetForSession(t *testing.T) {
160 | // Create a mock config
161 | mockConfig := &MockConfig{
162 | DefaultURL: "http://default:8080",
163 | AllTargets: []string{
164 | "http://default:8080",
165 | "http://target1:8080",
166 | "http://target2:8080",
167 | },
168 | }
169 |
170 | // Create a router with the mock config
171 | r := &Router{
172 | Config: mockConfig,
173 | SessionStore: NewTestSessionStore(),
174 | Logger: logger.New(logger.Debug),
175 | }
176 |
177 | // Test with no session ID
178 | target := r.determineTargetForSession("")
179 | if target != "http://default:8080" {
180 | t.Errorf("Expected default target for empty session ID, got %s", target)
181 | }
182 |
183 | // Test with a known session ID
184 | r.SessionStore.Set("known-session", "http://target1:8080")
185 | target = r.determineTargetForSession("known-session")
186 | if target != "http://target1:8080" {
187 | t.Errorf("Expected target1 for known session ID, got %s", target)
188 | }
189 |
190 | // Test with an unknown session ID
191 | // This should hash consistently to one of the targets
192 | target1 := r.determineTargetForSession("unknown-session-1")
193 | target2 := r.determineTargetForSession("unknown-session-1") // Same ID should give same target
194 | if target1 != target2 {
195 | t.Errorf("Same session ID gave different targets: %s and %s", target1, target2)
196 | }
197 |
198 | // Different unknown session IDs might hash to different targets
199 | // but we can at least verify they're in our list of targets
200 | target3 := r.determineTargetForSession("unknown-session-2")
201 | found := false
202 | for _, t := range mockConfig.AllTargets {
203 | if t == target3 {
204 | found = true
205 | break
206 | }
207 | }
208 | if !found {
209 | t.Errorf("Target %s not found in list of valid targets", target3)
210 | }
211 | }
212 |
213 | func TestHandleMCPRequest_Initialize(t *testing.T) {
214 | // Create a mock config
215 | mockConfig := &MockConfig{
216 | DefaultURL: "http://default:8080",
217 | }
218 |
219 | // Create a router with the mock config
220 | r := &Router{
221 | Config: mockConfig,
222 | UI: &ui.UI{Stats: &ui.Stats{RequestsByMethod: make(map[string]int64), RequestsByEndpoint: make(map[string]int64), ResponseTimes: make(map[string][]int)}},
223 | SessionStore: NewTestSessionStore(),
224 | Logger: logger.New(logger.Debug),
225 | }
226 |
227 | // Create a test server that will respond to our proxied requests
228 | testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
229 | // Add a session ID to the response
230 | w.Header().Set("Mcp-Session-Id", "test-session-123")
231 | w.Header().Set("Content-Type", "application/json")
232 | w.WriteHeader(http.StatusOK)
233 | w.Write([]byte(`{"jsonrpc":"2.0","id":1,"result":{"status":"ok"}}`))
234 | }))
235 | defer testServer.Close()
236 |
237 | // Override the default URL to point to our test server
238 | mockConfig.DefaultURL = testServer.URL
239 |
240 | // Create an initialize request
241 | initRequest := JSONRPCRequest{
242 | JSONRPC: "2.0",
243 | ID: 1,
244 | Method: "initialize",
245 | Params: map[string]interface{}{},
246 | }
247 | requestBody, _ := json.Marshal(initRequest)
248 |
249 | // Create a test request
250 | req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(requestBody))
251 | req.Header.Set("Content-Type", "application/json")
252 | req.Header.Set("Accept", "application/json")
253 |
254 | // Create a response recorder
255 | rr := httptest.NewRecorder()
256 |
257 | // Call the handler
258 | r.HandleMCPRequest(rr, req)
259 |
260 | // Check the response
261 | if rr.Code != http.StatusOK {
262 | t.Errorf("Expected status code %d, got %d", http.StatusOK, rr.Code)
263 | }
264 |
265 | // Check that the session ID was stored
266 | target, exists := r.SessionStore.Get("test-session-123")
267 | if !exists {
268 | t.Errorf("Session ID was not stored")
269 | }
270 | if target != testServer.URL {
271 | t.Errorf("Expected target %s, got %s", testServer.URL, target)
272 | }
273 |
274 | // Check the response body
275 | var response map[string]interface{}
276 | if err := json.Unmarshal(rr.Body.Bytes(), &response); err != nil {
277 | t.Errorf("Failed to parse response body: %v", err)
278 | }
279 | if response["result"].(map[string]interface{})["status"] != "ok" {
280 | t.Errorf("Expected status 'ok', got %v", response["result"])
281 | }
282 | }
283 |
284 | // Helper function to compile regex patterns for testing
285 | func compileRegex(t *testing.T, pattern string) *regexp.Regexp {
286 | r, err := regexp.Compile(pattern)
287 | if err != nil {
288 | t.Fatalf("Failed to compile regex pattern '%s': %v", pattern, err)
289 | }
290 | return r
291 | }
292 |
293 | // NewTestSessionStore creates a simple session store for testing
294 | func NewTestSessionStore() *session.Store {
295 | return session.NewStore()
296 | }
297 |
--------------------------------------------------------------------------------
/pkg/ui/ui.go:
--------------------------------------------------------------------------------
1 | // Package ui provides a web interface for monitoring the MCP router proxy
2 | package ui
3 |
4 | import (
5 | "encoding/json"
6 | "fmt"
7 | "html/template"
8 | "net/http"
9 | "strconv"
10 | "sync"
11 | "time"
12 |
13 | "github.com/mclenhard/catie-mcp/pkg/config"
14 | "github.com/mclenhard/catie-mcp/pkg/logger"
15 | )
16 |
17 | // Stats tracks various metrics about the router proxy
18 | type Stats struct {
19 | TotalRequests int64 `json:"totalRequests"`
20 | RequestsByMethod map[string]int64 `json:"requestsByMethod"`
21 | RequestsByEndpoint map[string]int64 `json:"requestsByEndpoint"`
22 | ResponseTimes map[string][]int `json:"responseTimes"` // in milliseconds
23 | ErrorCount int64 `json:"errorCount"`
24 | StartTime time.Time `json:"startTime"`
25 | mu sync.RWMutex
26 | }
27 |
28 | // UI handles the web interface for the router proxy
29 | type UI struct {
30 | Stats *Stats
31 | template *template.Template
32 | username string
33 | password string
34 | logger *logger.Logger
35 | }
36 |
37 | // New creates a new UI instance
38 | func New(config *config.Config) *UI {
39 | tmpl := template.Must(template.New("stats").Parse(statsTemplate))
40 |
41 | return &UI{
42 | Stats: &Stats{
43 | RequestsByMethod: make(map[string]int64),
44 | RequestsByEndpoint: make(map[string]int64),
45 | ResponseTimes: make(map[string][]int),
46 | StartTime: time.Now(),
47 | },
48 | template: tmpl,
49 | username: config.RouteConfig.UI.Username,
50 | password: config.RouteConfig.UI.Password,
51 | logger: logger.New(logger.Info), // Initialize with Info level
52 | }
53 | }
54 |
55 | // RecordRequest records statistics for a request
56 | func (s *Stats) RecordRequest(method, endpoint string, responseTime time.Duration, isError bool) {
57 | s.mu.Lock()
58 | defer s.mu.Unlock()
59 |
60 | s.TotalRequests++
61 | s.RequestsByMethod[method]++
62 | s.RequestsByEndpoint[endpoint]++
63 |
64 | // Store response time (convert to milliseconds)
65 | respTimeMs := int(responseTime.Milliseconds())
66 | s.ResponseTimes[method] = append(s.ResponseTimes[method], respTimeMs)
67 |
68 | // Limit the number of stored response times to prevent memory issues
69 | if len(s.ResponseTimes[method]) > 1000 {
70 | s.ResponseTimes[method] = s.ResponseTimes[method][len(s.ResponseTimes[method])-1000:]
71 | }
72 |
73 | if isError {
74 | s.ErrorCount++
75 | }
76 | }
77 |
78 | // GetStats returns a copy of the current stats
79 | func (s *Stats) GetStats() Stats {
80 | s.mu.RLock()
81 | defer s.mu.RUnlock()
82 |
83 | // Create a deep copy to avoid race conditions
84 | statsCopy := Stats{
85 | TotalRequests: s.TotalRequests,
86 | RequestsByMethod: make(map[string]int64),
87 | RequestsByEndpoint: make(map[string]int64),
88 | ResponseTimes: make(map[string][]int),
89 | ErrorCount: s.ErrorCount,
90 | StartTime: s.StartTime,
91 |
92 | }
93 |
94 | for k, v := range s.RequestsByMethod {
95 | statsCopy.RequestsByMethod[k] = v
96 | }
97 |
98 | for k, v := range s.RequestsByEndpoint {
99 | statsCopy.RequestsByEndpoint[k] = v
100 | }
101 |
102 | for k, v := range s.ResponseTimes {
103 | statsCopy.ResponseTimes[k] = make([]int, len(v))
104 | copy(statsCopy.ResponseTimes[k], v)
105 | }
106 |
107 | return statsCopy
108 | }
109 |
110 | // HandleStats serves the stats page
111 | func (ui *UI) HandleStats(w http.ResponseWriter, r *http.Request) {
112 | stats := ui.Stats.GetStats()
113 |
114 | // Calculate uptime
115 | uptime := time.Since(stats.StartTime).String()
116 |
117 | // Calculate average response times
118 | avgResponseTimes := make(map[string]int)
119 | for method, times := range stats.ResponseTimes {
120 | if len(times) == 0 {
121 | continue
122 | }
123 |
124 | sum := 0
125 | for _, t := range times {
126 | sum += t
127 | }
128 | avgResponseTimes[method] = sum / len(times)
129 | }
130 |
131 | data := struct {
132 | Stats Stats
133 | Uptime string
134 | AvgResponseTimes map[string]int
135 | }{
136 | Stats: stats,
137 | Uptime: uptime,
138 | AvgResponseTimes: avgResponseTimes,
139 | }
140 |
141 | if r.Header.Get("Accept") == "application/json" {
142 | w.Header().Set("Content-Type", "application/json")
143 | json.NewEncoder(w).Encode(data)
144 | return
145 | }
146 |
147 | w.Header().Set("Content-Type", "text/html")
148 | ui.template.Execute(w, data)
149 | }
150 |
151 | // RegisterHandlers registers the UI handlers with the provided mux
152 | func (ui *UI) RegisterHandlers(mux *http.ServeMux) {
153 | fmt.Println("Registering UI handlers with provided mux")
154 | mux.HandleFunc("/stats", ui.basicAuth(ui.HandleStats))
155 | mux.HandleFunc("/metrics", ui.basicAuth(ui.HandleMetrics))
156 | fmt.Println("UI handlers registered for /stats and /metrics")
157 | }
158 |
159 | // basicAuth wraps an http.HandlerFunc with basic authentication
160 | func (ui *UI) basicAuth(handler http.HandlerFunc) http.HandlerFunc {
161 | return func(w http.ResponseWriter, r *http.Request) {
162 | ui.logger.Debug("Received request for %s from %s", r.URL.Path, r.RemoteAddr)
163 |
164 | // Skip auth if credentials aren't configured
165 | if ui.username == "" && ui.password == "" {
166 | ui.logger.Info("No auth credentials configured, skipping authentication for %s", r.URL.Path)
167 | handler(w, r)
168 | return
169 | }
170 |
171 | ui.logger.Debug("Auth required for %s - username: '%s', password: [hidden]", r.URL.Path, ui.username)
172 | user, pass, ok := r.BasicAuth()
173 |
174 | if !ok {
175 | ui.logger.Warn("No auth credentials provided for %s from %s", r.URL.Path, r.RemoteAddr)
176 | w.Header().Set("WWW-Authenticate", `Basic realm="MCP Router Stats"`)
177 | w.WriteHeader(http.StatusUnauthorized)
178 | w.Write([]byte("Unauthorized - No credentials provided"))
179 | return
180 | }
181 |
182 | ui.logger.Debug("Auth attempt for %s - provided username: '%s'", r.URL.Path, user)
183 |
184 | if user != ui.username || pass != ui.password {
185 | ui.logger.Warn("Invalid credentials for %s from %s (user: %s)", r.URL.Path, r.RemoteAddr, user)
186 | w.Header().Set("WWW-Authenticate", `Basic realm="MCP Router Stats"`)
187 | w.WriteHeader(http.StatusUnauthorized)
188 | w.Write([]byte("Unauthorized - Invalid credentials"))
189 | return
190 | }
191 |
192 | ui.logger.Info("Authorized access to %s from %s (user: %s)", r.URL.Path, r.RemoteAddr, user)
193 | handler(w, r)
194 | }
195 | }
196 |
197 | // HandleMetrics serves Prometheus-compatible metrics
198 | func (ui *UI) HandleMetrics(w http.ResponseWriter, r *http.Request) {
199 | stats := ui.Stats.GetStats()
200 |
201 | w.Header().Set("Content-Type", "text/plain")
202 |
203 | // Basic Prometheus-style metrics
204 | w.Write([]byte("# HELP mcp_router_requests_total Total number of requests processed\n"))
205 | w.Write([]byte("# TYPE mcp_router_requests_total counter\n"))
206 | w.Write([]byte("mcp_router_requests_total " + strconv.FormatInt(stats.TotalRequests, 10) + "\n\n"))
207 |
208 | w.Write([]byte("# HELP mcp_router_errors_total Total number of request errors\n"))
209 | w.Write([]byte("# TYPE mcp_router_errors_total counter\n"))
210 | w.Write([]byte("mcp_router_errors_total " + strconv.FormatInt(stats.ErrorCount, 10) + "\n\n"))
211 |
212 | // Method-specific metrics
213 | w.Write([]byte("# HELP mcp_router_requests_by_method Number of requests by method\n"))
214 | w.Write([]byte("# TYPE mcp_router_requests_by_method counter\n"))
215 | for method, count := range stats.RequestsByMethod {
216 | w.Write([]byte("mcp_router_requests_by_method{method=\"" + method + "\"} " + strconv.FormatInt(count, 10) + "\n"))
217 | }
218 |
219 | // Endpoint-specific metrics
220 | w.Write([]byte("\n# HELP mcp_router_requests_by_endpoint Number of requests by endpoint\n"))
221 | w.Write([]byte("# TYPE mcp_router_requests_by_endpoint counter\n"))
222 | for endpoint, count := range stats.RequestsByEndpoint {
223 | w.Write([]byte("mcp_router_requests_by_endpoint{endpoint=\"" + endpoint + "\"} " + strconv.FormatInt(count, 10) + "\n"))
224 | }
225 |
226 | // Response time metrics
227 | w.Write([]byte("\n# HELP mcp_router_response_time_ms Average response time in milliseconds\n"))
228 | w.Write([]byte("# TYPE mcp_router_response_time_ms gauge\n"))
229 | for method, times := range stats.ResponseTimes {
230 | if len(times) == 0 {
231 | continue
232 | }
233 |
234 | sum := 0
235 | for _, t := range times {
236 | sum += t
237 | }
238 | avg := sum / len(times)
239 | w.Write([]byte("mcp_router_response_time_ms{method=\"" + method + "\"} " + strconv.Itoa(avg) + "\n"))
240 | }
241 |
242 | // Uptime metric
243 | uptime := time.Since(stats.StartTime).Seconds()
244 | w.Write([]byte("\n# HELP mcp_router_uptime_seconds Time since the router started in seconds\n"))
245 | w.Write([]byte("# TYPE mcp_router_uptime_seconds counter\n"))
246 | w.Write([]byte("mcp_router_uptime_seconds " + strconv.FormatFloat(uptime, 'f', 1, 64) + "\n"))
247 | }
248 |
249 | // HTML template for the stats page
250 | const statsTemplate = `
251 |
252 |
253 |
254 | Request Statistics
255 |
256 |
313 |
314 |
315 | MCP Router Proxy Statistics
316 |
317 |
318 |
319 |
320 |
321 |
General Stats
322 |
323 |
324 | | Total Requests |
325 | {{.Stats.TotalRequests}} |
326 |
327 |
328 | | Error Count |
329 | {{.Stats.ErrorCount}} |
330 |
331 |
332 | | Uptime |
333 | {{.Uptime}} |
334 |
335 |
336 |
337 |
338 |
339 |
Requests by Method
340 |
341 |
342 | | Method |
343 | Count |
344 | Avg Response Time (ms) |
345 |
346 | {{range $method, $count := .Stats.RequestsByMethod}}
347 |
348 | | {{$method}} |
349 | {{$count}} |
350 | {{index $.AvgResponseTimes $method}} |
351 |
352 | {{end}}
353 |
354 |
355 |
356 |
357 |
Requests by Endpoint
358 |
359 |
360 | | Endpoint |
361 | Count |
362 |
363 | {{range $endpoint, $count := .Stats.RequestsByEndpoint}}
364 |
365 | | {{$endpoint}} |
366 | {{$count}} |
367 |
368 | {{end}}
369 |
370 |
371 |
372 |
373 |
379 |
380 |
381 | `
382 |
--------------------------------------------------------------------------------
/pkg/router/router.go:
--------------------------------------------------------------------------------
1 | package router
2 |
3 | import (
4 | "bufio"
5 | "bytes"
6 | "context"
7 | "encoding/json"
8 | "fmt"
9 | "hash/fnv"
10 | "io"
11 | "net/http"
12 | "strings"
13 | "sync"
14 | "time"
15 |
16 | "github.com/mclenhard/catie-mcp/pkg/config"
17 | "github.com/mclenhard/catie-mcp/pkg/logger"
18 | "github.com/mclenhard/catie-mcp/pkg/session"
19 | "github.com/mclenhard/catie-mcp/pkg/ui"
20 | )
21 |
22 | // JSONRPCRequest structure
23 | type JSONRPCRequest struct {
24 | JSONRPC string `json:"jsonrpc"`
25 | ID interface{} `json:"id,omitempty"`
26 | Method string `json:"method"`
27 | Params map[string]interface{} `json:"params,omitempty"`
28 | }
29 |
30 | // ConfigInterface is an interface that both MockConfig and config.Config can implement
31 | type ConfigInterface interface {
32 | GetDefault() string
33 | GetResourceRegexes() []config.RouteRule
34 | GetToolRegexes() []config.RouteRule
35 | GetAllTargets() []string
36 | GetToolMappings() []config.ToolMapping
37 | }
38 |
39 | // Router handles the routing of MCP requests
40 | type Router struct {
41 | Config ConfigInterface // Changed from config.Config
42 | UI *ui.UI
43 | SessionStore *session.Store
44 | Logger *logger.Logger
45 | }
46 |
47 | // New creates a new Router instance
48 | func New(cfg *config.Config, ui *ui.UI) *Router {
49 | return &Router{
50 | Config: cfg,
51 | UI: ui,
52 | SessionStore: session.NewStore(),
53 | Logger: logger.New(logger.Info), // Default to Info level
54 | }
55 | }
56 |
57 | // Add this constant at the top of the file
58 | const (
59 | // DefaultTimeout is the default timeout for HTTP requests
60 | DefaultTimeout = 30 * time.Second
61 | )
62 |
63 | // HandleMCPRequest processes incoming MCP requests and routes them to the appropriate target
64 | func (r *Router) HandleMCPRequest(w http.ResponseWriter, req *http.Request) {
65 | startTime := time.Now()
66 | var isError bool
67 |
68 | r.Logger.Info("Received request: Method=%s, Path=%s, ContentType=%s",
69 | req.Method, req.URL.Path, req.Header.Get("Content-Type"))
70 |
71 | // Handle OPTIONS requests (CORS preflight)
72 | if req.Method == http.MethodOptions {
73 | r.Logger.Debug("Handling OPTIONS request")
74 | // Set CORS headers
75 | w.Header().Set("Access-Control-Allow-Origin", "*")
76 | w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
77 | w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Accept, Mcp-Session-Id, Authorization, MCP-Protocol-Version")
78 | w.Header().Set("Access-Control-Max-Age", "86400") // 24 hours
79 |
80 | // Respond with 200 OK for OPTIONS requests
81 | w.WriteHeader(http.StatusOK)
82 | r.UI.Stats.RecordRequest("OPTIONS", "cors", time.Since(startTime), false)
83 | return
84 | }
85 |
86 | // Check if this is a GET request (for SSE streaming)
87 | if req.Method == http.MethodGet {
88 | r.Logger.Debug("Handling GET request for SSE")
89 | // Extract session ID if present
90 | sessionID := req.Header.Get("Mcp-Session-Id")
91 | r.Logger.Debug("Session ID: %s", sessionID)
92 |
93 | // Determine target based on session ID or other routing logic
94 | targetURL := r.determineTargetForSession(sessionID)
95 | r.Logger.Info("Routing SSE stream to target: %s", targetURL)
96 |
97 | // Set up SSE headers for client
98 | w.Header().Set("Content-Type", "text/event-stream")
99 | w.Header().Set("Cache-Control", "no-cache")
100 | w.Header().Set("Connection", "keep-alive")
101 | w.Header().Set("Access-Control-Allow-Origin", "*")
102 |
103 | // Create a new GET request to the target server
104 | proxyReq, err := http.NewRequest(http.MethodGet, targetURL, nil)
105 | if err != nil {
106 | r.Logger.Error("Error creating proxy request: %v", err)
107 | http.Error(w, "error creating proxy request: "+err.Error(), http.StatusInternalServerError)
108 | isError = true
109 | r.UI.Stats.RecordRequest("GET", targetURL, time.Since(startTime), isError)
110 | return
111 | }
112 |
113 | // Copy relevant headers from the original request
114 | proxyReq.Header.Set("Accept", "text/event-stream")
115 | if sessionID != "" {
116 | proxyReq.Header.Set("Mcp-Session-Id", sessionID)
117 | }
118 |
119 | // Copy Last-Event-ID header if present (for resuming streams)
120 | if lastEventID := req.Header.Get("Last-Event-ID"); lastEventID != "" {
121 | proxyReq.Header.Set("Last-Event-ID", lastEventID)
122 | }
123 |
124 | // Make the request to the target server
125 | client := &http.Client{
126 | // No timeout for SSE connections
127 | }
128 |
129 | resp, err := client.Do(proxyReq)
130 | if err != nil {
131 | r.Logger.Error("Error connecting to target server: %v", err)
132 | http.Error(w, "error connecting to target server: "+err.Error(), http.StatusBadGateway)
133 | isError = true
134 | r.UI.Stats.RecordRequest("GET", targetURL, time.Since(startTime), isError)
135 | return
136 | }
137 | defer resp.Body.Close()
138 |
139 | // Check if the target server returned an error
140 | if resp.StatusCode != http.StatusOK {
141 | // Read error body and forward it to the client
142 | errorBody, _ := io.ReadAll(resp.Body)
143 | r.Logger.Error("Target server returned error status %d: %s", resp.StatusCode, string(errorBody))
144 | http.Error(w, string(errorBody), resp.StatusCode)
145 | isError = true
146 | r.UI.Stats.RecordRequest("GET", targetURL, time.Since(startTime), isError)
147 | return
148 | }
149 |
150 | r.Logger.Debug("Successfully connected to SSE stream from target")
151 |
152 | // Create a context that's canceled when the client connection closes
153 | ctx, cancel := context.WithCancel(req.Context())
154 | defer cancel()
155 |
156 | // Create a done channel to signal when to close the connection
157 | done := make(chan bool)
158 | var once sync.Once // Add this to ensure the channel is only closed once
159 |
160 | // Handle client disconnection using context
161 | go func() {
162 | <-ctx.Done()
163 | once.Do(func() { close(done) }) // Use sync.Once to safely close the channel
164 | // Also close the response body to terminate the connection to the target
165 | resp.Body.Close()
166 | r.Logger.Debug("Client disconnected, closing SSE stream")
167 | }()
168 |
169 | // Stream SSE events from target to client
170 | go func() {
171 | scanner := bufio.NewScanner(resp.Body)
172 | // Set a larger buffer for the scanner to handle large SSE events
173 | const maxScanTokenSize = 1024 * 1024 // 1MB
174 | buf := make([]byte, maxScanTokenSize)
175 | scanner.Buffer(buf, maxScanTokenSize)
176 |
177 | eventCount := 0
178 | for scanner.Scan() {
179 | select {
180 | case <-done:
181 | return
182 | default:
183 | line := scanner.Text()
184 | fmt.Fprintf(w, "%s\n", line)
185 | // If this is the end of an event, flush the buffer
186 | if line == "" {
187 | eventCount++
188 | if eventCount%100 == 0 {
189 | r.Logger.Debug("Streamed %d SSE events so far", eventCount)
190 | }
191 | if flusher, ok := w.(http.Flusher); ok {
192 | flusher.Flush()
193 | }
194 | }
195 | }
196 | }
197 |
198 | if err := scanner.Err(); err != nil {
199 | r.Logger.Error("Error scanning SSE stream: %v", err)
200 | }
201 |
202 | r.Logger.Info("SSE stream closed after sending %d events", eventCount)
203 | once.Do(func() { close(done) }) // Safely close the channel if not already closed
204 | }()
205 |
206 | // Set up heartbeat ticker
207 | heartbeatTicker := time.NewTicker(30 * time.Second)
208 | defer heartbeatTicker.Stop()
209 |
210 | // Send heartbeat events to keep the connection alive
211 | go func() {
212 | for {
213 | select {
214 | case <-heartbeatTicker.C:
215 | // Send a heartbeat event
216 | fmt.Fprintf(w, "event: heartbeat\ndata: %d\n\n", time.Now().Unix())
217 | if flusher, ok := w.(http.Flusher); ok {
218 | flusher.Flush()
219 | }
220 | r.Logger.Debug("Sent heartbeat event")
221 | case <-done:
222 | return
223 | }
224 | }
225 | }()
226 |
227 | // Wait until done
228 | <-done
229 | r.UI.Stats.RecordRequest("GET", targetURL, time.Since(startTime), isError)
230 | return
231 | }
232 |
233 | // Handle POST requests (client sending messages to server)
234 | if req.Method == http.MethodPost {
235 | r.Logger.Debug("Handling POST request")
236 |
237 | // Read the request body
238 | body, err := io.ReadAll(req.Body)
239 | if err != nil {
240 | r.Logger.Error("Error reading request body: %v", err)
241 | http.Error(w, "error reading request body: "+err.Error(), http.StatusBadRequest)
242 | isError = true
243 | r.UI.Stats.RecordRequest("POST", "unknown", time.Since(startTime), isError)
244 | return
245 | }
246 |
247 | // Parse the JSON-RPC request to determine routing
248 | var jsonRPCRequest JSONRPCRequest
249 | if err := json.Unmarshal(body, &jsonRPCRequest); err != nil {
250 | // Try to parse as a batch request
251 | var batchRequest []JSONRPCRequest
252 | if batchErr := json.Unmarshal(body, &batchRequest); batchErr != nil {
253 | r.Logger.Error("Error parsing JSON-RPC request: %v", err)
254 | http.Error(w, "error parsing JSON-RPC request: "+err.Error(), http.StatusBadRequest)
255 | isError = true
256 | r.UI.Stats.RecordRequest("POST", "invalid_json", time.Since(startTime), isError)
257 | return
258 | }
259 | // Use the first request in the batch for routing
260 | if len(batchRequest) > 0 {
261 | jsonRPCRequest = batchRequest[0]
262 | }
263 | }
264 |
265 | // Extract session ID if present
266 | sessionID := req.Header.Get("Mcp-Session-Id")
267 | r.Logger.Debug("Session ID: %s", sessionID)
268 |
269 | // Check if this is an initialization request
270 | isInitialize := jsonRPCRequest.Method == "initialize"
271 |
272 | // Determine target URL based on the request content
273 | targetURL := r.determineTargetForSession(sessionID)
274 | r.Logger.Info("Routing request to target: %s", targetURL)
275 |
276 | // Apply tool name transformation if needed
277 | r.transformToolCall(jsonRPCRequest, targetURL)
278 |
279 | // Re-encode the possibly modified request
280 | modifiedBody, err := json.Marshal(jsonRPCRequest)
281 | if err != nil {
282 | http.Error(w, "error encoding request: "+err.Error(), http.StatusInternalServerError)
283 | return
284 | }
285 |
286 | // Create a new POST request to the target server
287 | proxyReq, err := http.NewRequest(http.MethodPost, targetURL, bytes.NewReader(modifiedBody))
288 | if err != nil {
289 | r.Logger.Error("Error creating proxy request: %v", err)
290 | http.Error(w, "error creating proxy request: "+err.Error(), http.StatusInternalServerError)
291 | isError = true
292 | r.UI.Stats.RecordRequest("POST", targetURL, time.Since(startTime), isError)
293 | return
294 | }
295 |
296 | // Copy headers from the original request
297 | proxyReq.Header.Set("Content-Type", "application/json")
298 |
299 | // Set Accept header to support both JSON and SSE responses
300 | proxyReq.Header.Set("Accept", req.Header.Get("Accept"))
301 | if proxyReq.Header.Get("Accept") == "" {
302 | proxyReq.Header.Set("Accept", "application/json, text/event-stream")
303 | }
304 |
305 | // Forward session ID if present
306 | if sessionID != "" {
307 | proxyReq.Header.Set("Mcp-Session-Id", sessionID)
308 | }
309 |
310 | // Make the request to the target server with a reasonable timeout
311 | client := &http.Client{
312 | Timeout: 30 * time.Second, // Use a longer timeout for initialization
313 | }
314 |
315 | resp, err := client.Do(proxyReq)
316 | if err != nil {
317 | r.Logger.Error("Error connecting to target server: %v", err)
318 | http.Error(w, "error connecting to target server: "+err.Error(), http.StatusBadGateway)
319 | isError = true
320 | r.UI.Stats.RecordRequest("POST", targetURL, time.Since(startTime), isError)
321 | return
322 | }
323 | defer resp.Body.Close()
324 |
325 | // Check if the target server returned an error
326 | if resp.StatusCode >= 400 {
327 | // Read error body and forward it to the client
328 | errorBody, _ := io.ReadAll(resp.Body)
329 | r.Logger.Error("Target server returned error status %d: %s", resp.StatusCode, string(errorBody))
330 | for name, values := range resp.Header {
331 | for _, value := range values {
332 | w.Header().Add(name, value)
333 | }
334 | }
335 | w.WriteHeader(resp.StatusCode)
336 | w.Write(errorBody)
337 | isError = true
338 | r.UI.Stats.RecordRequest("POST", targetURL, time.Since(startTime), isError)
339 | return
340 | }
341 |
342 | // Check if this is an initialization response with a session ID
343 | if isInitialize && resp.Header.Get("Mcp-Session-Id") != "" {
344 | newSessionID := resp.Header.Get("Mcp-Session-Id")
345 | r.Logger.Info("Received new session ID: %s", newSessionID)
346 |
347 | // Store the session ID and target URL mapping
348 | r.SessionStore.Set(newSessionID, targetURL)
349 | }
350 |
351 | // Check content type to determine how to handle the response
352 | contentType := resp.Header.Get("Content-Type")
353 |
354 | // Copy all response headers to the client
355 | for name, values := range resp.Header {
356 | for _, value := range values {
357 | w.Header().Add(name, value)
358 | }
359 | }
360 |
361 | // If the response is an SSE stream, handle it accordingly
362 | if strings.Contains(contentType, "text/event-stream") {
363 | r.Logger.Debug("Target server returned SSE stream")
364 |
365 | // Set SSE headers
366 | w.Header().Set("Content-Type", "text/event-stream")
367 | w.Header().Set("Cache-Control", "no-cache")
368 | w.Header().Set("Connection", "keep-alive")
369 |
370 | // Create a context that's canceled when the client connection closes
371 | ctx, cancel := context.WithCancel(req.Context())
372 | defer cancel()
373 |
374 | // Create a done channel to signal when to close the connection
375 | done := make(chan bool)
376 | var once sync.Once // Add this to ensure the channel is only closed once
377 |
378 | // Handle client disconnection
379 | go func() {
380 | <-ctx.Done()
381 | once.Do(func() { close(done) }) // Use sync.Once to safely close the channel
382 | r.Logger.Debug("Client disconnected, closing SSE stream")
383 | }()
384 |
385 | // Stream SSE events from target to client
386 | scanner := bufio.NewScanner(resp.Body)
387 | // Set a larger buffer for the scanner to handle large SSE events
388 | const maxScanTokenSize = 1024 * 1024 // 1MB
389 | buf := make([]byte, maxScanTokenSize)
390 | scanner.Buffer(buf, maxScanTokenSize)
391 |
392 | go func() {
393 | for scanner.Scan() {
394 | select {
395 | case <-done:
396 | return
397 | default:
398 | line := scanner.Text()
399 | fmt.Fprintf(w, "%s\n", line)
400 | // If this is the end of an event, flush the buffer
401 | if line == "" {
402 | if flusher, ok := w.(http.Flusher); ok {
403 | flusher.Flush()
404 | }
405 | }
406 | }
407 | }
408 |
409 | if err := scanner.Err(); err != nil {
410 | r.Logger.Error("Error scanning SSE stream: %v", err)
411 | }
412 |
413 | once.Do(func() { close(done) }) // Safely close the channel if not already closed
414 | }()
415 |
416 | // Set up heartbeat ticker
417 | heartbeatTicker := time.NewTicker(30 * time.Second)
418 | defer heartbeatTicker.Stop()
419 |
420 | // Send heartbeat events to keep the connection alive
421 | go func() {
422 | for {
423 | select {
424 | case <-heartbeatTicker.C:
425 | // Send a heartbeat event
426 | fmt.Fprintf(w, "event: heartbeat\ndata: %d\n\n", time.Now().Unix())
427 | if flusher, ok := w.(http.Flusher); ok {
428 | flusher.Flush()
429 | }
430 | r.Logger.Debug("Sent heartbeat event")
431 | case <-done:
432 | return
433 | }
434 | }
435 | }()
436 |
437 | // Wait until done
438 | <-done
439 | r.UI.Stats.RecordRequest("POST", targetURL, time.Since(startTime), isError)
440 | return
441 | }
442 |
443 | // For regular JSON responses, just copy the response body
444 | responseBody, err := io.ReadAll(resp.Body)
445 | if err != nil {
446 | r.Logger.Error("Error reading response body: %v", err)
447 | http.Error(w, "error reading response body: "+err.Error(), http.StatusInternalServerError)
448 | isError = true
449 | r.UI.Stats.RecordRequest("POST", targetURL, time.Since(startTime), isError)
450 | return
451 | }
452 |
453 | // Write the response status code and body
454 | w.WriteHeader(resp.StatusCode)
455 | w.Write(responseBody)
456 |
457 | // If this was an initialization request, log the response
458 | if isInitialize {
459 | r.Logger.Info("Successfully processed initialization request")
460 |
461 | // Parse the response to extract any relevant information
462 | var initResponse map[string]interface{}
463 | if err := json.Unmarshal(responseBody, &initResponse); err == nil {
464 | r.Logger.Debug("Initialization response: %v", initResponse)
465 | }
466 | }
467 |
468 | r.UI.Stats.RecordRequest("POST", targetURL, time.Since(startTime), isError)
469 | return
470 | }
471 |
472 | // Handle POST requests (client sending messages to server)
473 | if req.Method != http.MethodPost {
474 | r.Logger.Warn("Received unsupported method: %s", req.Method)
475 | http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
476 | isError = true
477 | r.UI.Stats.RecordRequest("unknown", "error", time.Since(startTime), isError)
478 | return
479 | }
480 |
481 | // Check Accept header
482 | acceptHeader := req.Header.Get("Accept")
483 | if !strings.Contains(acceptHeader, "application/json") && !strings.Contains(acceptHeader, "text/event-stream") {
484 | r.Logger.Warn("Invalid Accept header: %s", acceptHeader)
485 | http.Error(w, "Invalid Accept header", http.StatusBadRequest)
486 | isError = true
487 | r.UI.Stats.RecordRequest("unknown", "error", time.Since(startTime), isError)
488 | return
489 | }
490 |
491 | // Extract session ID if present
492 | sessionID := req.Header.Get("Mcp-Session-Id")
493 | r.Logger.Debug("Received POST request with session ID: %s", sessionID)
494 |
495 | body, err := io.ReadAll(req.Body)
496 | if err != nil {
497 | r.Logger.Error("Failed to read request body: %v", err)
498 | http.Error(w, "failed to read request body", http.StatusBadRequest)
499 | isError = true
500 | r.UI.Stats.RecordRequest("unknown", "error", time.Since(startTime), isError)
501 | return
502 | }
503 |
504 | // Try to parse as a single request or as a batch
505 | var method string
506 | var targetURL string
507 |
508 | // First try to parse as a single request
509 | var rpcReq JSONRPCRequest
510 | if err := json.Unmarshal(body, &rpcReq); err == nil {
511 | // Successfully parsed as a single request
512 | method = rpcReq.Method
513 | r.Logger.Debug("Parsed single JSON-RPC request with method: %s", method)
514 |
515 | // Determine target URL based on the request
516 | if rpcReq.Method == "initialize" {
517 | // For initialize requests, use the default target
518 | targetURL = r.Config.GetDefault()
519 | r.Logger.Info("Initialize request, using default target: %s", targetURL)
520 | } else {
521 | // For other requests, use the routing logic
522 | targetURL = r.RouteByContext(rpcReq)
523 | r.Logger.Info("Routing method '%s' to target: %s", method, targetURL)
524 | }
525 | } else {
526 | // Try to parse as a batch
527 | var batchReq []JSONRPCRequest
528 | if err := json.Unmarshal(body, &batchReq); err == nil {
529 | // Successfully parsed as a batch
530 | r.Logger.Debug("Parsed batch request with %d methods", len(batchReq))
531 |
532 | // For simplicity, use the first request's method for logging
533 | if len(batchReq) > 0 {
534 | method = batchReq[0].Method
535 |
536 | // For batch requests, we need a routing strategy
537 | // Here we use the first request to determine the target
538 | targetURL = r.RouteByContext(batchReq[0])
539 | r.Logger.Info("Routing batch request (first method: '%s') to target: %s", method, targetURL)
540 | } else {
541 | method = "batch"
542 | targetURL = r.Config.GetDefault()
543 | r.Logger.Info("Empty batch request, using default target: %s", targetURL)
544 | }
545 | } else {
546 | // Failed to parse as either single request or batch
547 | r.Logger.Error("Failed to parse JSON-RPC request: %v", err)
548 | http.Error(w, "invalid JSON-RPC request", http.StatusBadRequest)
549 | isError = true
550 | r.UI.Stats.RecordRequest("unknown", "error", time.Since(startTime), isError)
551 | return
552 | }
553 | }
554 |
555 | // Create a new POST request to the target server
556 | proxyReq, err := http.NewRequest(http.MethodPost, targetURL, bytes.NewReader(body))
557 | if err != nil {
558 | r.Logger.Error("Error creating proxy request: %v", err)
559 | http.Error(w, "error creating proxy request: "+err.Error(), http.StatusInternalServerError)
560 | isError = true
561 | r.UI.Stats.RecordRequest(method, targetURL, time.Since(startTime), isError)
562 | return
563 | }
564 |
565 | // Copy relevant headers from the original request
566 | proxyReq.Header.Set("Content-Type", "application/json")
567 | proxyReq.Header.Set("Accept", acceptHeader)
568 | if sessionID != "" {
569 | proxyReq.Header.Set("Mcp-Session-Id", sessionID)
570 | }
571 |
572 | // Forward the Authorization header
573 | if authHeader := req.Header.Get("Authorization"); authHeader != "" {
574 | proxyReq.Header.Set("Authorization", authHeader)
575 | r.Logger.Debug("Forwarding Authorization header")
576 | }
577 |
578 | // Forward MCP-Protocol-Version header if present
579 | if protocolVersion := req.Header.Get("MCP-Protocol-Version"); protocolVersion != "" {
580 | proxyReq.Header.Set("MCP-Protocol-Version", protocolVersion)
581 | }
582 |
583 | // Make the request to the target server
584 | client := &http.Client{
585 | Timeout: DefaultTimeout,
586 | }
587 | ctx, cancel := context.WithTimeout(req.Context(), DefaultTimeout)
588 | defer cancel()
589 | proxyReq = proxyReq.WithContext(ctx)
590 | resp, err := client.Do(proxyReq)
591 | if err != nil {
592 | r.Logger.Error("Error forwarding request to target: %v", err)
593 | http.Error(w, "error forwarding request: "+err.Error(), http.StatusBadGateway)
594 | isError = true
595 | r.UI.Stats.RecordRequest(method, targetURL, time.Since(startTime), isError)
596 | return
597 | }
598 | defer resp.Body.Close()
599 |
600 | // Check if the response contains a session ID
601 | sessionID = resp.Header.Get("Mcp-Session-Id")
602 | if sessionID != "" {
603 | // Store the session ID -> target mapping
604 | r.SessionStore.Set(sessionID, targetURL)
605 | r.Logger.Debug("Stored session mapping: %s -> %s", sessionID, targetURL)
606 | }
607 |
608 | // Copy response headers
609 | for name, values := range resp.Header {
610 | for _, value := range values {
611 | w.Header().Add(name, value)
612 | }
613 | }
614 |
615 | // Copy response status code
616 | w.WriteHeader(resp.StatusCode)
617 |
618 | // Copy response body
619 | respBody, err := io.ReadAll(resp.Body)
620 | if err != nil {
621 | r.Logger.Error("Error reading response body: %v", err)
622 | // We've already written headers, so we can't send an HTTP error
623 | return
624 | }
625 | w.Write(respBody)
626 |
627 | // Record stats
628 | duration := time.Since(startTime)
629 | isError = resp.StatusCode >= 400
630 | r.UI.Stats.RecordRequest(method, targetURL, duration, isError)
631 | }
632 |
633 | // determineTargetForSession returns the target URL for a given session ID
634 | func (r *Router) determineTargetForSession(sessionID string) string {
635 | // If no session ID is provided, return the default target
636 | if sessionID == "" {
637 | r.Logger.Debug("No session ID provided, using default target")
638 | return r.Config.GetDefault()
639 | }
640 |
641 | // Check if we have this session ID in our session store
642 | if target, exists := r.SessionStore.Get(sessionID); exists {
643 | r.Logger.Debug("Found existing session mapping: %s -> %s", sessionID, target)
644 | return target
645 | }
646 |
647 | // If the session ID is not recognized, we have a few options:
648 | // 1. Return the default target
649 | // 2. Use a consistent hashing algorithm to map the session ID to a target
650 | // 3. Use a round-robin or load-balancing approach
651 |
652 | // For now, we'll use a simple approach - hash the session ID to consistently
653 | // map it to one of our configured targets
654 | targets := r.Config.GetAllTargets()
655 | if len(targets) == 0 {
656 | return r.Config.GetDefault()
657 | }
658 |
659 | // Use a hash of the session ID to pick a target
660 | h := fnv.New32a()
661 | h.Write([]byte(sessionID))
662 | index := int(h.Sum32()) % len(targets)
663 | target := targets[index]
664 |
665 | r.Logger.Debug("Created new session mapping via hashing: %s -> %s", sessionID, target)
666 | // Store this mapping for future reference
667 | r.SessionStore.Set(sessionID, target)
668 |
669 | return target
670 | }
671 |
672 | // RouteByContext determines the target URL based on the request method and parameters
673 | func (r *Router) RouteByContext(req JSONRPCRequest) string {
674 | switch req.Method {
675 | case "resources/read":
676 | uri, ok := req.Params["uri"].(string)
677 | if ok {
678 | r.Logger.Debug("Routing resources/read for URI: %s", uri)
679 | for _, rule := range r.Config.GetResourceRegexes() {
680 | if rule.Pattern.MatchString(uri) {
681 | r.Logger.Debug("Matched resource rule, using target: %s", rule.Target)
682 | return rule.Target
683 | }
684 | }
685 | }
686 | case "tools/call":
687 | name, ok := req.Params["name"].(string)
688 | if ok {
689 | r.Logger.Debug("Routing tools/call for tool: %s", name)
690 | for _, rule := range r.Config.GetToolRegexes() {
691 | if rule.Pattern.MatchString(name) {
692 | r.Logger.Debug("Matched tool rule, using target: %s", rule.Target)
693 | return rule.Target
694 | }
695 | }
696 | }
697 | }
698 | r.Logger.Debug("No specific routing rule matched, using default target")
699 | return r.Config.GetDefault()
700 | }
701 |
702 | // transformToolCall applies tool name transformation if needed
703 | func (r *Router) transformToolCall(req JSONRPCRequest, targetURL string) {
704 | // Access fields directly
705 | if req.Method != "tools/call" {
706 | return
707 | }
708 |
709 | // Check if name exists in params
710 | name, ok := req.Params["name"].(string)
711 | if !ok {
712 | return
713 | }
714 |
715 | // Rest of function remains the same
716 | for _, mapping := range r.Config.GetToolMappings() {
717 | if mapping.OriginalName == name && mapping.Target == targetURL {
718 | // Transform the tool name
719 | req.Params["name"] = mapping.TargetName
720 | return
721 | }
722 | }
723 | }
724 |
--------------------------------------------------------------------------------