├── .dockerignore ├── .gitignore ├── ui.png ├── go.mod ├── go.sum ├── prometheus.yml ├── .github └── workflows │ └── pull-request.yaml ├── docker-compose.yml ├── Dockerfile ├── LICENSE ├── router_config.yaml ├── pkg ├── session │ ├── store.go │ └── store_test.go ├── logger │ └── logger.go ├── config │ ├── config.go │ └── config_test.go ├── ui │ ├── ui_test.go │ └── ui.go └── router │ ├── router_test.go │ └── router.go ├── cmd └── main.go └── README.md /.dockerignore: -------------------------------------------------------------------------------- 1 | vendor/ 2 | .idea/ -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | vendor/ 3 | *.log -------------------------------------------------------------------------------- /ui.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mclenhard/catie-mcp/HEAD/ui.png -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/mclenhard/catie-mcp 2 | 3 | go 1.21.2 4 | 5 | require gopkg.in/yaml.v3 v3.0.1 6 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 2 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 3 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 4 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 5 | -------------------------------------------------------------------------------- /prometheus.yml: 
-------------------------------------------------------------------------------- 1 | global: 2 | scrape_interval: 15s 3 | evaluation_interval: 15s 4 | 5 | scrape_configs: 6 | - job_name: 'mcp-router-proxy' 7 | scrape_interval: 5s 8 | metrics_path: /metrics 9 | # Basic auth is enabled for this scrape job: set these to the credentials the /metrics endpoint expects, or remove the basic_auth block if the endpoint is unauthenticated. 10 | basic_auth: 11 | username: 'admin' 12 | password: 'your_secure_password' 13 | static_configs: 14 | - targets: ['mcp-router-proxy:80'] -------------------------------------------------------------------------------- /.github/workflows/pull-request.yaml: -------------------------------------------------------------------------------- 1 | name: Go Tests 2 | 3 | on: 4 | pull_request: 5 | branches: [ main ] 6 | 7 | jobs: 8 | test: 9 | name: Run Go Tests 10 | runs-on: ubuntu-latest 11 | 12 | steps: 13 | - name: Check out code 14 | uses: actions/checkout@v3 15 | 16 | - name: Set up Go 17 | uses: actions/setup-go@v4 18 | with: 19 | go-version: '1.21' # Adjust this to your Go version 20 | 21 | - name: Install dependencies 22 | run: go mod download 23 | 24 | - name: Run tests 25 | run: go test -v ./... 26 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '3.8' 2 | 3 | services: 4 | mcp-router-proxy: 5 | build: 6 | context: .
7 | dockerfile: Dockerfile 8 | container_name: mcp-router-proxy 9 | ports: 10 | - "80:80" 11 | volumes: 12 | - ./router_config.yaml:/root/router_config.yaml 13 | restart: unless-stopped 14 | environment: 15 | - GOTOOLCHAIN=auto 16 | networks: 17 | - mcp-network 18 | 19 | prometheus: 20 | image: prom/prometheus:latest 21 | container_name: prometheus 22 | ports: 23 | - "9090:9090" 24 | volumes: 25 | - ./prometheus.yml:/etc/prometheus/prometheus.yml 26 | restart: unless-stopped 27 | networks: 28 | - mcp-network 29 | depends_on: 30 | - mcp-router-proxy 31 | 32 | networks: 33 | mcp-network: 34 | driver: bridge -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # Start from a Go base image 2 | FROM golang:1.21-alpine AS builder 3 | 4 | # Set working directory 5 | WORKDIR /app 6 | 7 | # Copy go.mod and go.sum files 8 | COPY go.mod ./ 9 | # COPY go.sum ./ # Uncomment if you have a go.sum file 10 | 11 | # Download dependencies 12 | RUN go mod download 13 | 14 | # Copy the source code 15 | COPY . . 16 | 17 | # Build the application 18 | RUN CGO_ENABLED=0 GOOS=linux go build -o mcp-router-proxy ./cmd/main.go 19 | 20 | # Use a minimal alpine image for the final container 21 | FROM alpine:latest 22 | 23 | # Install ca-certificates for HTTPS requests 24 | RUN apk --no-cache add ca-certificates 25 | 26 | WORKDIR /root/ 27 | 28 | # Copy the binary from the builder stage 29 | COPY --from=builder /app/mcp-router-proxy . 30 | # Copy the configuration file 31 | COPY router_config.yaml . 
32 | 33 | # Expose the port the app runs on 34 | EXPOSE 80 35 | 36 | # Command to run the executable 37 | CMD ["./mcp-router-proxy"] -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 MCP Router Proxy 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /router_config.yaml: -------------------------------------------------------------------------------- 1 | # router_config.yaml 2 | resources: 3 | # Route file resources to file service 4 | "^file://.*": "http://file-service:8080/rpc" 5 | # Route database resources to DB service 6 | "^db://.*": "http://database-service:8080/rpc" 7 | # Route API resources to API gateway 8 | "^api://.*": "http://api-gateway:8080/rpc" 9 | # Route S3 resources to storage service 10 | "^s3://.*": "http://storage-service:8080/rpc" 11 | "^weather/.*": "http://weather-service:8080/mcp" 12 | "^database/.*": "http://database-service:8080/mcp" 13 | 14 | tools: 15 | # Route code generation tools to AI service 16 | "^code-.*": "http://ai-service:8080/rpc" 17 | # Route data processing tools to data service 18 | "^data-.*": "http://data-processing:8080/rpc" 19 | # Route search tools to search service 20 | "^search.*": "http://search-service:8080/rpc" 21 | # Route authentication tools to auth service 22 | "^auth.*": "http://auth-service:8080/rpc" 23 | "^calculator$": "http://calculator-service:8080/mcp" 24 | "^translator$": "http://translator-service:8080/mcp" 25 | 26 | # Default service to route to if no patterns match 27 | default: "http://default-service:8080/mcp" 28 | 29 | # UI authentication settings 30 | ui: 31 | username: "admin" 32 | password: "your_secure_password" 33 | 34 | toolMappings: 35 | - originalName: "weather" 36 | targetName: "getWeather" 37 | target: "http://weather-service:8080/mcp" 38 | - originalName: "search" 39 | targetName: "googleSearch" 40 | target: "http://search-service:8080/mcp" 41 | -------------------------------------------------------------------------------- /pkg/session/store.go: -------------------------------------------------------------------------------- 1 | package session 2 | 3 | import ( 4 | "sync" 5 | "time" 6 | ) 7 | 8 | // Store manages mappings between session IDs and 
// target URLs. A mapping expires after sessionTTL without a Get or Set.
//
// All exported methods are safe for concurrent use: every access to the
// sessions map happens under mu. A Store created with NewStore runs a
// background cleanup goroutine; call Close to stop it.
type Store struct {
	sessions map[string]sessionInfo
	mu       sync.RWMutex

	// done is closed by Close to stop the cleanup goroutine started by NewStore.
	done      chan struct{}
	closeOnce sync.Once
}

// sessionInfo is the per-session record kept in Store.sessions.
type sessionInfo struct {
	target   string    // backend URL the session is pinned to
	lastUsed time.Time // refreshed by Get and Set; compared against sessionTTL in cleanup
}

const (
	// sessionTTL is how long a session may go unused before cleanup evicts it.
	sessionTTL = 24 * time.Hour
	// cleanupInterval is how often the background goroutine sweeps for expired sessions.
	cleanupInterval = 10 * time.Minute
)

// NewStore creates an empty session store and starts a background goroutine
// that evicts expired sessions every cleanupInterval until Close is called.
func NewStore() *Store {
	store := &Store{
		sessions: make(map[string]sessionInfo),
		done:     make(chan struct{}),
	}

	go store.cleanupLoop()

	return store
}

// Get returns the target URL pinned to sessionID and whether the session
// exists. A successful lookup refreshes the session's last-used time.
func (s *Store) Get(sessionID string) (string, bool) {
	// Hold the write lock across the lookup and the touch. Releasing a read
	// lock and re-acquiring a write lock in between would let a concurrent
	// Remove be undone (the stale copy is written back) or a concurrent Set
	// be reverted to the old target.
	s.mu.Lock()
	defer s.mu.Unlock()

	info, exists := s.sessions[sessionID]
	if !exists {
		return "", false
	}

	info.lastUsed = time.Now()
	s.sessions[sessionID] = info
	return info.target, true
}

// Set pins sessionID to target, replacing any existing mapping and marking
// the session as used now.
func (s *Store) Set(sessionID, target string) {
	s.mu.Lock()
	defer s.mu.Unlock()

	s.sessions[sessionID] = sessionInfo{
		target:   target,
		lastUsed: time.Now(),
	}
}

// Remove deletes the mapping for sessionID. Removing an unknown ID is a no-op.
func (s *Store) Remove(sessionID string) {
	s.mu.Lock()
	defer s.mu.Unlock()

	delete(s.sessions, sessionID)
}

// Close stops the background cleanup goroutine. It is safe to call more than
// once. The store stays usable afterwards, but expired sessions are no longer
// evicted automatically.
func (s *Store) Close() {
	s.closeOnce.Do(func() {
		if s.done != nil {
			close(s.done)
		}
	})
}

// cleanupLoop sweeps expired sessions every cleanupInterval until Close is called.
func (s *Store) cleanupLoop() {
	ticker := time.NewTicker(cleanupInterval)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			s.cleanup()
		case <-s.done:
			return
		}
	}
}

// cleanup removes every session whose last use is older than sessionTTL.
func (s *Store) cleanup() {
	s.mu.Lock()
	defer s.mu.Unlock()

	expireTime := time.Now().Add(-sessionTTL)

	for id, info := range s.sessions {
		if info.lastUsed.Before(expireTime) {
			delete(s.sessions, id)
		}
	}
}
-------------------------------------------------------------------------------- /pkg/logger/logger.go: -------------------------------------------------------------------------------- 1 | package logger 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | "os" 7 | "time" 8 | ) 9 | 10 | // Level represents the severity level of a log message 11 | type Level int 12 | 13 | const ( 14 | // Debug level for detailed information 15 | Debug Level = iota 16 | // Info level for general operational information 17 | Info 18 | // Warn level for warning conditions 19 | Warn 20 | // Error level for error conditions 21 | Error 22 | // Fatal level for fatal conditions 23 | Fatal 24 | ) 25 | 26 | var levelNames = map[Level]string{ 27 | Debug: "DEBUG", 28 | Info: "INFO", 29 | Warn: "WARN", 30 | Error: "ERROR", 31 | Fatal: "FATAL", 32 | } 33 | 34 | // Logger is a simple structured logger 35 | type Logger struct { 36 | level Level 37 | logger *log.Logger 38 | } 39 | 40 | // New creates a new logger with the specified level 41 | func New(level Level) *Logger { 42 | return &Logger{ 43 | level: level, 44 | logger: log.New(os.Stdout, "", 0), 45 | } 46 | } 47 | 48 | // SetLevel changes the logger's level 49 | func (l *Logger) SetLevel(level Level) { 50 | l.level = level 51 | } 52 | 53 | // log logs a message at the specified level 54 | func (l *Logger) log(level Level, format string, args ...interface{}) { 55 | if level < l.level { 56 | return 57 | } 58 | 59 | timestamp := time.Now().Format("2006-01-02 15:04:05.000") 60 | levelName := levelNames[level] 61 | message := fmt.Sprintf(format, args...) 62 | 63 | l.logger.Printf("[%s] %s: %s", timestamp, levelName, message) 64 | 65 | if level == Fatal { 66 | os.Exit(1) 67 | } 68 | } 69 | 70 | // Debug logs a debug message 71 | func (l *Logger) Debug(format string, args ...interface{}) { 72 | l.log(Debug, format, args...) 
73 | } 74 | 75 | // Info logs an info message 76 | func (l *Logger) Info(format string, args ...interface{}) { 77 | l.log(Info, format, args...) 78 | } 79 | 80 | // Warn logs a warning message 81 | func (l *Logger) Warn(format string, args ...interface{}) { 82 | l.log(Warn, format, args...) 83 | } 84 | 85 | // Error logs an error message 86 | func (l *Logger) Error(format string, args ...interface{}) { 87 | l.log(Error, format, args...) 88 | } 89 | 90 | // Fatal logs a fatal message and exits 91 | func (l *Logger) Fatal(format string, args ...interface{}) { 92 | l.log(Fatal, format, args...) 93 | } 94 | -------------------------------------------------------------------------------- /cmd/main.go: -------------------------------------------------------------------------------- 1 | // Main entry point for the MCP router proxy 2 | package main 3 | 4 | import ( 5 | "context" 6 | "flag" 7 | "log" 8 | "net/http" 9 | "os" 10 | "os/signal" 11 | "syscall" 12 | "time" 13 | 14 | "github.com/mclenhard/catie-mcp/pkg/config" 15 | "github.com/mclenhard/catie-mcp/pkg/router" 16 | "github.com/mclenhard/catie-mcp/pkg/ui" 17 | ) 18 | 19 | var ( 20 | configPath = flag.String("config", "router_config.yaml", "Path to configuration file") 21 | serverPort = flag.String("port", ":80", "Server port") 22 | configInterval = flag.Duration("config-interval", 30*time.Second, "Configuration reload interval") 23 | shutdownTimeout = flag.Duration("shutdown-timeout", 30*time.Second, "Graceful shutdown timeout") 24 | ) 25 | 26 | func main() { 27 | flag.Parse() 28 | 29 | // Initialize configuration 30 | cfg := config.New(*configPath) 31 | 32 | // Start config watcher in background 33 | go cfg.WatchConfig(*configPath, *configInterval) 34 | 35 | // Initialize UI 36 | uiHandler := ui.New(cfg) 37 | 38 | // Initialize router with UI 39 | r := router.New(cfg, uiHandler) 40 | 41 | // Set up HTTP server 42 | mux := http.NewServeMux() 43 | mux.HandleFunc("/mcp", r.HandleMCPRequest) 44 | 
mux.HandleFunc("/health", handleHealth) 45 | 46 | // Register UI handlers with our mux 47 | uiHandler.RegisterHandlers(mux) 48 | 49 | server := &http.Server{ 50 | Addr: *serverPort, 51 | Handler: mux, 52 | } 53 | 54 | // Start server in a goroutine 55 | go func() { 56 | log.Printf("MCP router proxy listening on %s", *serverPort) 57 | log.Printf("Stats UI available at http://localhost%s/stats", *serverPort) 58 | if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { 59 | log.Fatalf("Server error: %v", err) 60 | } 61 | }() 62 | 63 | // Set up graceful shutdown 64 | quit := make(chan os.Signal, 1) 65 | signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) 66 | <-quit 67 | log.Println("Shutting down server...") 68 | 69 | // Create a deadline for shutdown 70 | ctx, cancel := context.WithTimeout(context.Background(), *shutdownTimeout) 71 | defer cancel() 72 | 73 | // Attempt graceful shutdown 74 | if err := server.Shutdown(ctx); err != nil { 75 | log.Fatalf("Server forced to shutdown: %v", err) 76 | } 77 | 78 | log.Println("Server exited gracefully") 79 | } 80 | 81 | // handleHealth is a simple health check endpoint 82 | func handleHealth(w http.ResponseWriter, r *http.Request) { 83 | w.WriteHeader(http.StatusOK) 84 | w.Write([]byte("OK")) 85 | } 86 | -------------------------------------------------------------------------------- /pkg/session/store_test.go: -------------------------------------------------------------------------------- 1 | package session 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | ) 7 | 8 | func TestStore_SetAndGet(t *testing.T) { 9 | store := NewStore() 10 | 11 | // Test setting and getting a session 12 | sessionID := "test-session-id" 13 | targetURL := "https://example.com" 14 | 15 | store.Set(sessionID, targetURL) 16 | 17 | got, exists := store.Get(sessionID) 18 | if !exists { 19 | t.Errorf("Get(%q) returned exists=false, want true", sessionID) 20 | } 21 | if got != targetURL { 22 | t.Errorf("Get(%q) = %q, want %q", 
sessionID, got, targetURL) 23 | } 24 | } 25 | 26 | func TestStore_Remove(t *testing.T) { 27 | store := NewStore() 28 | 29 | // Set up a session 30 | sessionID := "test-session-id" 31 | targetURL := "https://example.com" 32 | store.Set(sessionID, targetURL) 33 | 34 | // Verify it exists 35 | _, exists := store.Get(sessionID) 36 | if !exists { 37 | t.Fatalf("Session should exist before removal") 38 | } 39 | 40 | // Remove the session 41 | store.Remove(sessionID) 42 | 43 | // Verify it no longer exists 44 | _, exists = store.Get(sessionID) 45 | if exists { 46 | t.Errorf("Session should not exist after removal") 47 | } 48 | } 49 | 50 | func TestStore_Cleanup(t *testing.T) { 51 | store := NewStore() 52 | 53 | // Set up a session 54 | sessionID := "test-session-id" 55 | targetURL := "https://example.com" 56 | store.Set(sessionID, targetURL) 57 | 58 | // Manually modify the lastUsed time to be older than the expiration time 59 | store.mu.Lock() 60 | info := store.sessions[sessionID] 61 | info.lastUsed = time.Now().Add(-25 * time.Hour) // Older than the 24-hour expiration 62 | store.sessions[sessionID] = info 63 | store.mu.Unlock() 64 | 65 | // Run cleanup 66 | store.cleanup() 67 | 68 | // Verify the session was removed 69 | _, exists := store.Get(sessionID) 70 | if exists { 71 | t.Errorf("Session should have been cleaned up") 72 | } 73 | } 74 | 75 | func TestStore_GetUpdatesLastUsed(t *testing.T) { 76 | store := NewStore() 77 | 78 | // Set up a session 79 | sessionID := "test-session-id" 80 | targetURL := "https://example.com" 81 | store.Set(sessionID, targetURL) 82 | 83 | // Get the initial last used time 84 | store.mu.RLock() 85 | initialTime := store.sessions[sessionID].lastUsed 86 | store.mu.RUnlock() 87 | 88 | // Wait a small amount of time 89 | time.Sleep(10 * time.Millisecond) 90 | 91 | // Get the session, which should update the last used time 92 | store.Get(sessionID) 93 | 94 | // Check that the last used time was updated 95 | store.mu.RLock() 96 | updatedTime := 
store.sessions[sessionID].lastUsed 97 | store.mu.RUnlock() 98 | 99 | if !updatedTime.After(initialTime) { 100 | t.Errorf("Last used time was not updated: initial=%v, updated=%v", initialTime, updatedTime) 101 | } 102 | } -------------------------------------------------------------------------------- /pkg/config/config.go: -------------------------------------------------------------------------------- 1 | // Package config provides configuration handling for the MCP router proxy 2 | package config 3 | 4 | import ( 5 | "log" 6 | "os" 7 | "regexp" 8 | "sync" 9 | "time" 10 | 11 | "gopkg.in/yaml.v3" 12 | ) 13 | 14 | // RouteConfig structure for route mapping 15 | type RouteConfig struct { 16 | Resources map[string]string `yaml:"resources"` 17 | Tools map[string]string `yaml:"tools"` 18 | Default string `yaml:"default"` 19 | UI struct { 20 | Username string `yaml:"username"` 21 | Password string `yaml:"password"` 22 | } `yaml:"ui"` 23 | } 24 | 25 | // RouteRule represents a compiled regex pattern and its target 26 | type RouteRule struct { 27 | Pattern *regexp.Regexp 28 | Target string 29 | } 30 | 31 | // ToolMapping structure for tool mapping 32 | type ToolMapping struct { 33 | OriginalName string `yaml:"originalName"` 34 | TargetName string `yaml:"targetName"` 35 | Target string `yaml:"target"` 36 | } 37 | 38 | // Config holds the application configuration and related data 39 | type Config struct { 40 | RouteConfig RouteConfig 41 | ResourceRegexes []RouteRule 42 | ToolRegexes []RouteRule 43 | ToolMappings []ToolMapping 44 | ConfigMutex sync.RWMutex 45 | } 46 | 47 | // New creates a new Config instance and loads the initial configuration 48 | func New(configPath string) *Config { 49 | c := &Config{} 50 | c.Load(configPath) 51 | return c 52 | } 53 | 54 | // Load reads and parses the configuration file 55 | func (c *Config) Load(configPath string) { 56 | configData, err := os.ReadFile(configPath) 57 | if err != nil { 58 | log.Fatalf("Failed to load config: %v", err) 59 
| } 60 | 61 | var newConfig RouteConfig 62 | if err := yaml.Unmarshal(configData, &newConfig); err != nil { 63 | log.Fatalf("Invalid config format: %v", err) 64 | } 65 | 66 | var newResources []RouteRule 67 | for pattern, url := range newConfig.Resources { 68 | r, err := regexp.Compile(pattern) 69 | if err != nil { 70 | log.Fatalf("Invalid resource regex pattern '%s': %v", pattern, err) 71 | } 72 | newResources = append(newResources, RouteRule{r, url}) 73 | } 74 | 75 | var newTools []RouteRule 76 | for pattern, url := range newConfig.Tools { 77 | r, err := regexp.Compile(pattern) 78 | if err != nil { 79 | log.Fatalf("Invalid tool regex pattern '%s': %v", pattern, err) 80 | } 81 | newTools = append(newTools, RouteRule{r, url}) 82 | } 83 | 84 | c.ConfigMutex.Lock() 85 | c.RouteConfig = newConfig 86 | c.ResourceRegexes = newResources 87 | c.ToolRegexes = newTools 88 | c.ConfigMutex.Unlock() 89 | 90 | log.Println("Router config loaded") 91 | } 92 | 93 | // WatchConfig monitors the config file for changes and reloads when necessary 94 | func (c *Config) WatchConfig(filename string, interval time.Duration) { 95 | lastMod := time.Time{} 96 | for { 97 | info, err := os.Stat(filename) 98 | if err == nil && info.ModTime().After(lastMod) { 99 | lastMod = info.ModTime() 100 | c.Load(filename) 101 | log.Println("Router config reloaded") 102 | } 103 | time.Sleep(interval) 104 | } 105 | } 106 | 107 | // GetDefault returns the default route 108 | func (c *Config) GetDefault() string { 109 | c.ConfigMutex.RLock() 110 | defer c.ConfigMutex.RUnlock() 111 | return c.RouteConfig.Default 112 | } 113 | 114 | // GetResourceRegexes returns a copy of the resource regexes with read lock 115 | func (c *Config) GetResourceRegexes() []RouteRule { 116 | c.ConfigMutex.RLock() 117 | defer c.ConfigMutex.RUnlock() 118 | return c.ResourceRegexes 119 | } 120 | 121 | // GetToolRegexes returns a copy of the tool regexes with read lock 122 | func (c *Config) GetToolRegexes() []RouteRule { 123 | 
c.ConfigMutex.RLock() 124 | defer c.ConfigMutex.RUnlock() 125 | return c.ToolRegexes 126 | } 127 | 128 | // GetAllTargets returns all configured target URLs 129 | func (c *Config) GetAllTargets() []string { 130 | c.ConfigMutex.RLock() 131 | defer c.ConfigMutex.RUnlock() 132 | 133 | targets := make([]string, 0) 134 | 135 | // Add the default target 136 | if c.RouteConfig.Default != "" { 137 | targets = append(targets, c.RouteConfig.Default) 138 | } 139 | 140 | // Add targets from resource rules 141 | for _, rule := range c.ResourceRegexes { 142 | if !contains(targets, rule.Target) { 143 | targets = append(targets, rule.Target) 144 | } 145 | } 146 | 147 | // Add targets from tool rules 148 | for _, rule := range c.ToolRegexes { 149 | if !contains(targets, rule.Target) { 150 | targets = append(targets, rule.Target) 151 | } 152 | } 153 | 154 | return targets 155 | } 156 | 157 | // Helper function to check if a slice contains a string 158 | func contains(slice []string, item string) bool { 159 | for _, s := range slice { 160 | if s == item { 161 | return true 162 | } 163 | } 164 | return false 165 | } 166 | 167 | // GetToolMappings returns a copy of the tool mappings with read lock 168 | func (c *Config) GetToolMappings() []ToolMapping { 169 | c.ConfigMutex.RLock() 170 | defer c.ConfigMutex.RUnlock() 171 | return c.ToolMappings 172 | } 173 | -------------------------------------------------------------------------------- /pkg/config/config_test.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "os" 5 | "path/filepath" 6 | "testing" 7 | "time" 8 | ) 9 | 10 | func TestNew(t *testing.T) { 11 | // Create a temporary config file 12 | tempDir := t.TempDir() 13 | configPath := filepath.Join(tempDir, "config.yaml") 14 | 15 | configContent := ` 16 | default: http://default-service:8080 17 | resources: 18 | "^/resource/([^/]+)$": "http://resource-service:8080" 19 | tools: 20 | "^/tool/([^/]+)$": 
"http://tool-service:8080" 21 | ` 22 | if err := os.WriteFile(configPath, []byte(configContent), 0644); err != nil { 23 | t.Fatalf("Failed to write test config: %v", err) 24 | } 25 | 26 | // Test creating a new config 27 | cfg := New(configPath) 28 | 29 | // Verify the config was loaded correctly 30 | if cfg.GetDefault() != "http://default-service:8080" { 31 | t.Errorf("Expected default to be 'http://default-service:8080', got '%s'", cfg.GetDefault()) 32 | } 33 | 34 | if len(cfg.GetResourceRegexes()) != 1 { 35 | t.Errorf("Expected 1 resource regex, got %d", len(cfg.GetResourceRegexes())) 36 | } 37 | 38 | if len(cfg.GetToolRegexes()) != 1 { 39 | t.Errorf("Expected 1 tool regex, got %d", len(cfg.GetToolRegexes())) 40 | } 41 | } 42 | 43 | func TestLoad(t *testing.T) { 44 | // Table-driven test 45 | tests := []struct { 46 | name string 47 | configContent string 48 | expectedDefault string 49 | expectedResourceCount int 50 | expectedToolCount int 51 | }{ 52 | { 53 | name: "basic config", 54 | configContent: ` 55 | default: http://default-service:8080 56 | resources: 57 | "^/resource/([^/]+)$": "http://resource-service:8080" 58 | tools: 59 | "^/tool/([^/]+)$": "http://tool-service:8080" 60 | `, 61 | expectedDefault: "http://default-service:8080", 62 | expectedResourceCount: 1, 63 | expectedToolCount: 1, 64 | }, 65 | { 66 | name: "multiple resources and tools", 67 | configContent: ` 68 | default: http://default-service:8080 69 | resources: 70 | "^/resource/([^/]+)$": "http://resource-service:8080" 71 | "^/api/v1/([^/]+)$": "http://api-service:8080" 72 | tools: 73 | "^/tool/([^/]+)$": "http://tool-service:8080" 74 | "^/utility/([^/]+)$": "http://utility-service:8080" 75 | `, 76 | expectedDefault: "http://default-service:8080", 77 | expectedResourceCount: 2, 78 | expectedToolCount: 2, 79 | }, 80 | { 81 | name: "no default", 82 | configContent: ` 83 | resources: 84 | "^/resource/([^/]+)$": "http://resource-service:8080" 85 | tools: 86 | "^/tool/([^/]+)$": 
"http://tool-service:8080" 87 | `, 88 | expectedDefault: "", 89 | expectedResourceCount: 1, 90 | expectedToolCount: 1, 91 | }, 92 | } 93 | 94 | for _, tt := range tests { 95 | t.Run(tt.name, func(t *testing.T) { 96 | // Create temp file with test content 97 | tempDir := t.TempDir() 98 | configPath := filepath.Join(tempDir, "config.yaml") 99 | if err := os.WriteFile(configPath, []byte(tt.configContent), 0644); err != nil { 100 | t.Fatalf("Failed to write test config: %v", err) 101 | } 102 | 103 | // Create config and test loading 104 | cfg := &Config{} 105 | cfg.Load(configPath) 106 | 107 | if cfg.GetDefault() != tt.expectedDefault { 108 | t.Errorf("Expected default to be '%s', got '%s'", tt.expectedDefault, cfg.GetDefault()) 109 | } 110 | 111 | if len(cfg.GetResourceRegexes()) != tt.expectedResourceCount { 112 | t.Errorf("Expected %d resource regexes, got %d", tt.expectedResourceCount, len(cfg.GetResourceRegexes())) 113 | } 114 | 115 | if len(cfg.GetToolRegexes()) != tt.expectedToolCount { 116 | t.Errorf("Expected %d tool regexes, got %d", tt.expectedToolCount, len(cfg.GetToolRegexes())) 117 | } 118 | }) 119 | } 120 | } 121 | 122 | func TestGetAllTargets(t *testing.T) { 123 | // Create a config with known values 124 | cfg := &Config{ 125 | RouteConfig: RouteConfig{ 126 | Default: "http://default:8080", 127 | Resources: map[string]string{ 128 | "pattern1": "http://resource1:8080", 129 | "pattern2": "http://resource2:8080", 130 | }, 131 | Tools: map[string]string{ 132 | "pattern3": "http://tool1:8080", 133 | "pattern4": "http://resource1:8080", // Duplicate target 134 | }, 135 | }, 136 | } 137 | 138 | // Manually set up the regexes (normally done by Load) 139 | cfg.ResourceRegexes = []RouteRule{ 140 | {Pattern: nil, Target: "http://resource1:8080"}, 141 | {Pattern: nil, Target: "http://resource2:8080"}, 142 | } 143 | 144 | cfg.ToolRegexes = []RouteRule{ 145 | {Pattern: nil, Target: "http://tool1:8080"}, 146 | {Pattern: nil, Target: "http://resource1:8080"}, 147 | } 
148 | 149 | targets := cfg.GetAllTargets() 150 | 151 | // Update the expected number of unique targets 152 | expectedTargets := 4 // default + resource1 + resource2 + tool1 153 | if len(targets) != expectedTargets { 154 | t.Errorf("Expected %d unique targets, got %d: %v", expectedTargets, len(targets), targets) 155 | } 156 | 157 | // Check that all expected targets are present 158 | expectedURLs := []string{ 159 | "http://default:8080", 160 | "http://resource1:8080", 161 | "http://resource2:8080", 162 | "http://tool1:8080", 163 | } 164 | 165 | for _, url := range expectedURLs[:expectedTargets] { 166 | if !contains(targets, url) { 167 | t.Errorf("Expected target '%s' not found in results: %v", url, targets) 168 | } 169 | } 170 | } 171 | 172 | func TestWatchConfig(t *testing.T) { 173 | // Create a temporary config file 174 | tempDir := t.TempDir() 175 | configPath := filepath.Join(tempDir, "config.yaml") 176 | 177 | initialConfig := ` 178 | default: http://initial-default:8080 179 | resources: 180 | "^/resource/([^/]+)$": "http://initial-resource:8080" 181 | tools: 182 | "^/tool/([^/]+)$": "http://initial-tool:8080" 183 | ` 184 | if err := os.WriteFile(configPath, []byte(initialConfig), 0644); err != nil { 185 | t.Fatalf("Failed to write initial test config: %v", err) 186 | } 187 | 188 | // Create config 189 | cfg := New(configPath) 190 | 191 | // Start watching in a goroutine with a short interval 192 | go cfg.WatchConfig(configPath, 100*time.Millisecond) 193 | 194 | // Verify initial config 195 | if cfg.GetDefault() != "http://initial-default:8080" { 196 | t.Errorf("Initial default incorrect, got '%s'", cfg.GetDefault()) 197 | } 198 | 199 | // Wait a moment to ensure watcher is running 200 | time.Sleep(200 * time.Millisecond) 201 | 202 | // Update the config file 203 | updatedConfig := ` 204 | default: http://updated-default:8080 205 | resources: 206 | "^/resource/([^/]+)$": "http://updated-resource:8080" 207 | "^/new-resource/([^/]+)$": "http://new-resource:8080" 
208 | tools: 209 | "^/tool/([^/]+)$": "http://updated-tool:8080" 210 | ` 211 | if err := os.WriteFile(configPath, []byte(updatedConfig), 0644); err != nil { 212 | t.Fatalf("Failed to write updated test config: %v", err) 213 | } 214 | 215 | // Wait for the watcher to detect the change 216 | time.Sleep(300 * time.Millisecond) 217 | 218 | // Verify config was updated 219 | if cfg.GetDefault() != "http://updated-default:8080" { 220 | t.Errorf("Updated default incorrect, got '%s'", cfg.GetDefault()) 221 | } 222 | 223 | if len(cfg.GetResourceRegexes()) != 2 { 224 | t.Errorf("Expected 2 resource regexes after update, got %d", len(cfg.GetResourceRegexes())) 225 | } 226 | } 227 | -------------------------------------------------------------------------------- /pkg/ui/ui_test.go: -------------------------------------------------------------------------------- 1 | package ui 2 | 3 | import ( 4 | "encoding/base64" 5 | "encoding/json" 6 | "net/http" 7 | "net/http/httptest" 8 | "strings" 9 | "testing" 10 | "time" 11 | 12 | "github.com/mclenhard/catie-mcp/pkg/config" 13 | ) 14 | 15 | func TestRecordRequest(t *testing.T) { 16 | stats := &Stats{ 17 | RequestsByMethod: make(map[string]int64), 18 | RequestsByEndpoint: make(map[string]int64), 19 | ResponseTimes: make(map[string][]int), 20 | StartTime: time.Now(), 21 | } 22 | 23 | // Record some test requests 24 | stats.RecordRequest("GET", "/api/users", 150*time.Millisecond, false) 25 | stats.RecordRequest("POST", "/api/users", 200*time.Millisecond, false) 26 | stats.RecordRequest("GET", "/api/products", 100*time.Millisecond, true) 27 | 28 | // Verify stats were recorded correctly 29 | if stats.TotalRequests != 3 { 30 | t.Errorf("Expected 3 total requests, got %d", stats.TotalRequests) 31 | } 32 | 33 | if stats.RequestsByMethod["GET"] != 2 { 34 | t.Errorf("Expected 2 GET requests, got %d", stats.RequestsByMethod["GET"]) 35 | } 36 | 37 | if stats.RequestsByMethod["POST"] != 1 { 38 | t.Errorf("Expected 1 POST request, got %d", 
stats.RequestsByMethod["POST"]) 39 | } 40 | 41 | if stats.RequestsByEndpoint["/api/users"] != 2 { 42 | t.Errorf("Expected 2 /api/users requests, got %d", stats.RequestsByEndpoint["/api/users"]) 43 | } 44 | 45 | if stats.ErrorCount != 1 { 46 | t.Errorf("Expected 1 error, got %d", stats.ErrorCount) 47 | } 48 | 49 | // Check response times 50 | if len(stats.ResponseTimes["GET"]) != 2 { 51 | t.Errorf("Expected 2 response times for GET, got %d", len(stats.ResponseTimes["GET"])) 52 | } 53 | } 54 | 55 | func TestGetStats(t *testing.T) { 56 | stats := &Stats{ 57 | TotalRequests: 10, 58 | RequestsByMethod: map[string]int64{"GET": 7, "POST": 3}, 59 | RequestsByEndpoint: map[string]int64{"/api/users": 5, "/api/products": 5}, 60 | ResponseTimes: map[string][]int{"GET": {100, 150}, "POST": {200}}, 61 | ErrorCount: 2, 62 | StartTime: time.Now(), 63 | } 64 | 65 | // Get a copy of the stats 66 | statsCopy := stats.GetStats() 67 | 68 | // Verify the copy has the correct values 69 | if statsCopy.TotalRequests != 10 { 70 | t.Errorf("Expected 10 total requests in copy, got %d", statsCopy.TotalRequests) 71 | } 72 | 73 | // Modify the original stats 74 | stats.TotalRequests = 15 75 | stats.RequestsByMethod["GET"] = 10 76 | 77 | // Verify the copy wasn't affected 78 | if statsCopy.TotalRequests != 10 { 79 | t.Errorf("Stats copy was affected by changes to original") 80 | } 81 | 82 | if statsCopy.RequestsByMethod["GET"] != 7 { 83 | t.Errorf("Stats copy was affected by changes to original") 84 | } 85 | } 86 | 87 | func TestHandleStats(t *testing.T) { 88 | // Create a UI instance with test data 89 | cfg := &config.Config{ 90 | RouteConfig: config.RouteConfig{ 91 | UI: struct { 92 | Username string `yaml:"username"` 93 | Password string `yaml:"password"` 94 | }{}, 95 | }, 96 | } 97 | ui := New(cfg) 98 | 99 | // Add some test data 100 | ui.Stats.RecordRequest("GET", "/api/users", 150*time.Millisecond, false) 101 | ui.Stats.RecordRequest("POST", "/api/users", 200*time.Millisecond, false) 102 | 
103 | // Test HTML response 104 | req, _ := http.NewRequest("GET", "/stats", nil) 105 | rr := httptest.NewRecorder() 106 | 107 | ui.HandleStats(rr, req) 108 | 109 | if status := rr.Code; status != http.StatusOK { 110 | t.Errorf("Handler returned wrong status code: got %v want %v", status, http.StatusOK) 111 | } 112 | 113 | if ctype := rr.Header().Get("Content-Type"); ctype != "text/html" { 114 | t.Errorf("Content-Type header does not match: got %v want %v", ctype, "text/html") 115 | } 116 | 117 | // Test JSON response 118 | req, _ = http.NewRequest("GET", "/stats", nil) 119 | req.Header.Set("Accept", "application/json") 120 | rr = httptest.NewRecorder() 121 | 122 | ui.HandleStats(rr, req) 123 | 124 | if status := rr.Code; status != http.StatusOK { 125 | t.Errorf("Handler returned wrong status code: got %v want %v", status, http.StatusOK) 126 | } 127 | 128 | if ctype := rr.Header().Get("Content-Type"); ctype != "application/json" { 129 | t.Errorf("Content-Type header does not match: got %v want %v", ctype, "application/json") 130 | } 131 | 132 | // Verify we can parse the JSON response 133 | var data map[string]interface{} 134 | if err := json.Unmarshal(rr.Body.Bytes(), &data); err != nil { 135 | t.Errorf("Failed to parse JSON response: %v", err) 136 | } 137 | } 138 | 139 | func TestHandleMetrics(t *testing.T) { 140 | // Create a UI instance with test data 141 | cfg := &config.Config{ 142 | RouteConfig: config.RouteConfig{ 143 | UI: struct { 144 | Username string `yaml:"username"` 145 | Password string `yaml:"password"` 146 | }{}, 147 | }, 148 | } 149 | ui := New(cfg) 150 | 151 | // Add some test data 152 | ui.Stats.RecordRequest("GET", "/api/users", 150*time.Millisecond, false) 153 | ui.Stats.RecordRequest("POST", "/api/users", 200*time.Millisecond, true) 154 | 155 | // Test metrics response 156 | req, _ := http.NewRequest("GET", "/metrics", nil) 157 | rr := httptest.NewRecorder() 158 | 159 | ui.HandleMetrics(rr, req) 160 | 161 | if status := rr.Code; status != 
http.StatusOK { 162 | t.Errorf("Handler returned wrong status code: got %v want %v", status, http.StatusOK) 163 | } 164 | 165 | if ctype := rr.Header().Get("Content-Type"); ctype != "text/plain" { 166 | t.Errorf("Content-Type header does not match: got %v want %v", ctype, "text/plain") 167 | } 168 | 169 | // Check for expected metrics in the response 170 | body := rr.Body.String() 171 | expectedMetrics := []string{ 172 | "mcp_router_requests_total 2", 173 | "mcp_router_errors_total 1", 174 | "mcp_router_requests_by_method{method=\"GET\"} 1", 175 | "mcp_router_requests_by_method{method=\"POST\"} 1", 176 | "mcp_router_requests_by_endpoint{endpoint=\"/api/users\"} 2", 177 | } 178 | 179 | for _, metric := range expectedMetrics { 180 | if !strings.Contains(body, metric) { 181 | t.Errorf("Expected metric not found in response: %s", metric) 182 | } 183 | } 184 | } 185 | 186 | func TestBasicAuth(t *testing.T) { 187 | // Test with auth credentials 188 | cfg := &config.Config{ 189 | RouteConfig: config.RouteConfig{ 190 | UI: struct { 191 | Username string `yaml:"username"` 192 | Password string `yaml:"password"` 193 | }{ 194 | Username: "admin", 195 | Password: "secret", 196 | }, 197 | }, 198 | } 199 | ui := New(cfg) 200 | 201 | // Create a simple handler for testing 202 | testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 203 | w.WriteHeader(http.StatusOK) 204 | w.Write([]byte("Success")) 205 | }) 206 | 207 | // Wrap it with basic auth 208 | authHandler := ui.basicAuth(testHandler) 209 | 210 | // Test with no credentials 211 | req, _ := http.NewRequest("GET", "/stats", nil) 212 | rr := httptest.NewRecorder() 213 | 214 | authHandler(rr, req) 215 | 216 | if status := rr.Code; status != http.StatusUnauthorized { 217 | t.Errorf("Handler should return 401 without credentials, got: %v", status) 218 | } 219 | 220 | // Test with invalid credentials 221 | req, _ = http.NewRequest("GET", "/stats", nil) 222 | req.Header.Set("Authorization", "Basic 
"+base64.StdEncoding.EncodeToString([]byte("admin:wrongpass"))) 223 | rr = httptest.NewRecorder() 224 | 225 | authHandler(rr, req) 226 | 227 | if status := rr.Code; status != http.StatusUnauthorized { 228 | t.Errorf("Handler should return 401 with invalid credentials, got: %v", status) 229 | } 230 | 231 | // Test with valid credentials 232 | req, _ = http.NewRequest("GET", "/stats", nil) 233 | req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("admin:secret"))) 234 | rr = httptest.NewRecorder() 235 | 236 | authHandler(rr, req) 237 | 238 | if status := rr.Code; status != http.StatusOK { 239 | t.Errorf("Handler should return 200 with valid credentials, got: %v", status) 240 | } 241 | 242 | // Test with no auth required 243 | cfg = &config.Config{ 244 | RouteConfig: config.RouteConfig{ 245 | UI: struct { 246 | Username string `yaml:"username"` 247 | Password string `yaml:"password"` 248 | }{}, 249 | }, 250 | } 251 | ui = New(cfg) 252 | 253 | authHandler = ui.basicAuth(testHandler) 254 | req, _ = http.NewRequest("GET", "/stats", nil) 255 | rr = httptest.NewRecorder() 256 | 257 | authHandler(rr, req) 258 | 259 | if status := rr.Code; status != http.StatusOK { 260 | t.Errorf("Handler should return 200 when no auth required, got: %v", status) 261 | } 262 | } 263 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MCP Catie - Context Aware Traffic Ingress Engine 2 | 3 | A lightweight, configurable reverse proxy for routing and load balancing MCP (Model Context Protocol) requests to appropriate backend services based on request content. 4 | 5 | For detailed documentation, visit [Catie MCP Documentation](https://www.catiemcp.com/docs/). 
6 | 7 | ## Features 8 | 9 | - Dynamic routing of MCP JSON-RPC requests based on tool call 10 | - Concatenate tools from multiple different MCP Servers so the client gets a unified view of all tools without the user installing multiple servers. 11 | - Session-aware routing to maintain client connections to the same backend 12 | - Support for Streamable HTTP transport with SSE (Server-Sent Events) 13 | - Tool name mapping and namespacing to resolve naming conflicts between different backends 14 | - Prometheus metrics integration for observability 15 | - Containerized deployment with Docker 16 | - Basic authentication for monitoring UI 17 | 18 | ## Architecture 19 | 20 | The application is structured into several packages: 21 | 22 | - `cmd/main.go` - Application entry point with server setup 23 | - `pkg/config` - Configuration loading and management 24 | - `pkg/router` - Request routing and proxy logic 25 | - `pkg/session` - Session management for maintaining client connections 26 | - `pkg/logger` - Structured logging system 27 | - `pkg/ui` - Simple web UI for monitoring 28 | 29 | ## Configuration 30 | 31 | The router is configured using a YAML file (`router_config.yaml`). 
Here's an example configuration: 32 | 33 | ```yaml 34 | resources: 35 | "^weather/.*": "http://weather-service:8080/mcp" 36 | "^database/.*": "http://database-service:8080/mcp" 37 | tools: 38 | "^calculator$": "http://calculator-service:8080/mcp" 39 | "^translator$": "http://translator-service:8080/mcp" 40 | toolMappings: 41 | - originalName: "weather" 42 | targetName: "getWeather" 43 | target: "http://weather-service:8080/mcp" 44 | - originalName: "search" 45 | targetName: "googleSearch" 46 | target: "http://search-service:8080/mcp" 47 | default: "http://default-service:8080/mcp" 48 | ui: 49 | username: "admin" 50 | password: "your_secure_password" 51 | ``` 52 | 53 | The configuration consists of: 54 | 55 | - `resources`: Regex patterns for resource URIs and their target endpoints 56 | - `tools`: Regex patterns for tool names and their target endpoints 57 | - `toolMappings`: Mappings for tool name transformations to resolve naming conflicts 58 | - `default`: Fallback endpoint for requests that don't match any pattern 59 | - `ui`: Authentication credentials for the monitoring UI 60 | 61 | The configuration file is automatically reloaded when changes are detected. 62 | 63 | ## Tool Name Mapping 64 | 65 | The tool name mapping feature allows you to present a unified tool interface to clients while handling naming differences across backend MCP servers. 
This is useful when: 66 | 67 | - Different backends use different names for similar functionality 68 | - You want to present a simplified or standardized naming scheme to clients 69 | - You need to avoid naming conflicts between tools from different backends 70 | 71 | For each tool mapping, specify: 72 | - `originalName`: The name presented to clients 73 | - `targetName`: The actual name expected by the backend server 74 | - `target`: The URL of the target backend server 75 | 76 | When a client makes a tool call with the original name, Catie automatically transforms it to the target name before forwarding the request to the appropriate backend. 77 | 78 | ## Installation 79 | 80 | ### Prerequisites 81 | 82 | - Go 1.21 or higher (the version declared in `go.mod`) 83 | - Docker (optional, for containerized deployment) 84 | 85 | ### Building from Source 86 | 87 | 1. Clone the repository: 88 | ```bash 89 | git clone https://github.com/mclenhard/mcp-router-proxy.git 90 | cd mcp-router-proxy 91 | ``` 92 | 93 | 2. Build the application: 94 | ```bash 95 | go build -o mcp-router-proxy ./cmd/main.go 96 | ``` 97 | 98 | 3. Edit router_config.yaml to match your environment 99 | 100 | 4. Run the application: 101 | ```bash 102 | ./mcp-router-proxy 103 | ``` 104 | 105 | ### Using Docker 106 | 107 | 1. Build the Docker image: 108 | ```bash 109 | docker build -t mcp-router-proxy . 110 | ``` 111 | 112 | 2. Run the container: 113 | ```bash 114 | docker run -p 80:80 -v $(pwd)/router_config.yaml:/root/router_config.yaml mcp-router-proxy 115 | ``` 116 | 117 | ## Usage 118 | 119 | The proxy listens for MCP requests on the `/mcp` endpoint.
Requests are routed based on their method and parameters: 120 | 121 | - `resources/read` requests are routed based on the `uri` parameter 122 | - `tools/call` requests are routed based on the `name` parameter 123 | - Other requests are sent to the default endpoint 124 | 125 | The proxy supports both GET and POST methods according to the MCP Streamable HTTP transport specification: 126 | 127 | - POST requests are used to send JSON-RPC messages to the server 128 | - GET requests are used to establish SSE streams for server-to-client communication 129 | 130 | ### Session Management 131 | 132 | The proxy maintains session state by tracking the `Mcp-Session-Id` header. When a client establishes a session with an MCP server through the proxy, subsequent requests with the same session ID are routed to the same backend server. 133 | 134 | ### Health Check 135 | 136 | A health check endpoint is available at `/health` which returns a 200 OK response when the service is running. 137 | 138 | ### Monitoring 139 | 140 | A simple monitoring UI is available at `/stats` which shows request statistics and routing information. This interface is protected by basic authentication using the credentials specified in the configuration file. 141 | 142 | ![Monitoring UI Screenshot](ui.png) 143 | 144 | ### Prometheus Integration 145 | 146 | The service exposes Prometheus-compatible metrics at the `/metrics` endpoint. 
These metrics include: 147 | 148 | - `mcp_router_requests_total`: Total number of requests processed 149 | - `mcp_router_errors_total`: Total number of request errors 150 | - `mcp_router_requests_by_method`: Number of requests broken down by method 151 | - `mcp_router_requests_by_endpoint`: Number of requests broken down by target endpoint 152 | - `mcp_router_response_time_ms`: Average response time in milliseconds by method 153 | - `mcp_router_uptime_seconds`: Time since the router started in seconds 154 | 155 | You can configure Prometheus to scrape these metrics by adding the following to your Prometheus configuration: 156 | 157 | ```yaml 158 | scrape_configs: 159 | - job_name: 'mcp-router' 160 | scrape_interval: 15s 161 | static_configs: 162 | - targets: ['your-router-host:80'] 163 | ``` 164 | 165 | This endpoint is also protected by the same basic authentication as the stats UI. If `ui` credentials are configured, add a matching `basic_auth` section (`username` and `password`) to the scrape job, as shown in the bundled `prometheus.yml`; otherwise Prometheus will receive 401 responses. 166 | 167 | ## Development 168 | 169 | ### Project Structure 170 | 171 | ``` 172 | mcp-router-proxy/ 173 | ├── cmd/ 174 | │ └── main.go 175 | ├── pkg/ 176 | │ ├── config/ 177 | │ │ └── config.go 178 | │ ├── router/ 179 | │ │ └── router.go 180 | │ ├── session/ 181 | │ │ └── store.go 182 | │ ├── logger/ 183 | │ │ └── logger.go 184 | │ └── ui/ 185 | │ └── ui.go 186 | ├── Dockerfile 187 | ├── go.mod 188 | ├── go.sum 189 | ├── README.md 190 | └── router_config.yaml 191 | ``` 192 | 193 | ### Adding New Features 194 | 195 | 1. Fork the repository 196 | 2. Create a feature branch 197 | 3. Add your changes 198 | 4.
Submit a pull request 199 | 200 | ## Roadmap 201 | 202 | The following features are planned for upcoming releases: 203 | - ~~**Add SSE Support**: Add support for SSE (Server-Sent Events) to the proxy~~ 204 | - **Complete Message Forwarding**: Ensure all MCP message types (including roots and sampling) are properly forwarded without interference 205 | - **Intelligent Caching**: Response caching with configurable TTL, cache invalidation, and support for memory and Redis backends 206 | - **Rate Limiting**: Configurable rate limiting with multiple strategies, response headers, and distributed rate limiting support 207 | - **Circuit Breaking**: Automatic detection of backend failures with fallback responses 208 | - **Request Transformation**: Modify requests before forwarding to backends 209 | - **Response Transformation**: Transform backend responses before returning to clients 210 | 211 | Development priorities are based on community feedback. Please open an issue to request features or contribute to the roadmap discussion. 212 | 213 | ## License 214 | 215 | [MIT License](LICENSE) 216 | 217 | ## Contributing 218 | 219 | Contributions are welcome! Please feel free to submit a Pull Request. 220 | 221 | ## Support 222 | 223 | For support, please open an issue in the GitHub repository or contact me at [mclenhard@gmail.com](mailto:mclenhard@gmail.com). 
-------------------------------------------------------------------------------- /pkg/router/router_test.go: -------------------------------------------------------------------------------- 1 | package router 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "net/http" 7 | "net/http/httptest" 8 | "regexp" 9 | "testing" 10 | 11 | "github.com/mclenhard/catie-mcp/pkg/config" 12 | "github.com/mclenhard/catie-mcp/pkg/logger" 13 | "github.com/mclenhard/catie-mcp/pkg/session" 14 | "github.com/mclenhard/catie-mcp/pkg/ui" 15 | ) 16 | 17 | // MockConfig implements a simple config for testing 18 | type MockConfig struct { 19 | DefaultURL string 20 | ResourceRegexes []config.RouteRule 21 | ToolRegexes []config.RouteRule 22 | AllTargets []string 23 | } 24 | 25 | func (m *MockConfig) GetDefault() string { 26 | return m.DefaultURL 27 | } 28 | 29 | func (m *MockConfig) GetResourceRegexes() []config.RouteRule { 30 | return m.ResourceRegexes 31 | } 32 | 33 | func (m *MockConfig) GetToolRegexes() []config.RouteRule { 34 | return m.ToolRegexes 35 | } 36 | 37 | func (m *MockConfig) GetAllTargets() []string { 38 | return m.AllTargets 39 | } 40 | 41 | func (m *MockConfig) GetToolMappings() []config.ToolMapping { 42 | // Return an empty map for testing purposes 43 | return []config.ToolMapping{} 44 | } 45 | 46 | func TestRouteByContext(t *testing.T) { 47 | // Create a mock config 48 | mockConfig := &MockConfig{ 49 | DefaultURL: "http://default:8080", 50 | ResourceRegexes: []config.RouteRule{ 51 | {Pattern: compileRegex(t, "^/resource/([^/]+)$"), Target: "http://resource:8080"}, 52 | {Pattern: compileRegex(t, "^/api/v1/([^/]+)$"), Target: "http://api:8080"}, 53 | }, 54 | ToolRegexes: []config.RouteRule{ 55 | {Pattern: compileRegex(t, "^/tool/([^/]+)$"), Target: "http://tool:8080"}, 56 | {Pattern: compileRegex(t, "^/utility/([^/]+)$"), Target: "http://utility:8080"}, 57 | }, 58 | AllTargets: []string{ 59 | "http://default:8080", 60 | "http://resource:8080", 61 | "http://api:8080", 62 | 
"http://tool:8080", 63 | "http://utility:8080", 64 | }, 65 | } 66 | 67 | // Create a router with the mock config 68 | r := &Router{ 69 | Config: mockConfig, 70 | Logger: logger.New(logger.Debug), 71 | } 72 | 73 | // Test cases 74 | tests := []struct { 75 | name string 76 | request JSONRPCRequest 77 | expectedTarget string 78 | }{ 79 | { 80 | name: "resources/read with matching URI", 81 | request: JSONRPCRequest{ 82 | Method: "resources/read", 83 | Params: map[string]interface{}{ 84 | "uri": "/resource/test", 85 | }, 86 | }, 87 | expectedTarget: "http://resource:8080", 88 | }, 89 | { 90 | name: "resources/read with different matching URI", 91 | request: JSONRPCRequest{ 92 | Method: "resources/read", 93 | Params: map[string]interface{}{ 94 | "uri": "/api/v1/test", 95 | }, 96 | }, 97 | expectedTarget: "http://api:8080", 98 | }, 99 | { 100 | name: "resources/read with non-matching URI", 101 | request: JSONRPCRequest{ 102 | Method: "resources/read", 103 | Params: map[string]interface{}{ 104 | "uri": "/nonmatching/test", 105 | }, 106 | }, 107 | expectedTarget: "http://default:8080", 108 | }, 109 | { 110 | name: "tools/call with matching name", 111 | request: JSONRPCRequest{ 112 | Method: "tools/call", 113 | Params: map[string]interface{}{ 114 | "name": "/tool/test", 115 | }, 116 | }, 117 | expectedTarget: "http://tool:8080", 118 | }, 119 | { 120 | name: "tools/call with different matching name", 121 | request: JSONRPCRequest{ 122 | Method: "tools/call", 123 | Params: map[string]interface{}{ 124 | "name": "/utility/test", 125 | }, 126 | }, 127 | expectedTarget: "http://utility:8080", 128 | }, 129 | { 130 | name: "tools/call with non-matching name", 131 | request: JSONRPCRequest{ 132 | Method: "tools/call", 133 | Params: map[string]interface{}{ 134 | "name": "/nonmatching/test", 135 | }, 136 | }, 137 | expectedTarget: "http://default:8080", 138 | }, 139 | { 140 | name: "unknown method", 141 | request: JSONRPCRequest{ 142 | Method: "unknown/method", 143 | Params: 
map[string]interface{}{}, 144 | }, 145 | expectedTarget: "http://default:8080", 146 | }, 147 | } 148 | 149 | for _, tt := range tests { 150 | t.Run(tt.name, func(t *testing.T) { 151 | target := r.RouteByContext(tt.request) 152 | if target != tt.expectedTarget { 153 | t.Errorf("Expected target %s, got %s", tt.expectedTarget, target) 154 | } 155 | }) 156 | } 157 | } 158 | 159 | func TestDetermineTargetForSession(t *testing.T) { 160 | // Create a mock config 161 | mockConfig := &MockConfig{ 162 | DefaultURL: "http://default:8080", 163 | AllTargets: []string{ 164 | "http://default:8080", 165 | "http://target1:8080", 166 | "http://target2:8080", 167 | }, 168 | } 169 | 170 | // Create a router with the mock config 171 | r := &Router{ 172 | Config: mockConfig, 173 | SessionStore: NewTestSessionStore(), 174 | Logger: logger.New(logger.Debug), 175 | } 176 | 177 | // Test with no session ID 178 | target := r.determineTargetForSession("") 179 | if target != "http://default:8080" { 180 | t.Errorf("Expected default target for empty session ID, got %s", target) 181 | } 182 | 183 | // Test with a known session ID 184 | r.SessionStore.Set("known-session", "http://target1:8080") 185 | target = r.determineTargetForSession("known-session") 186 | if target != "http://target1:8080" { 187 | t.Errorf("Expected target1 for known session ID, got %s", target) 188 | } 189 | 190 | // Test with an unknown session ID 191 | // This should hash consistently to one of the targets 192 | target1 := r.determineTargetForSession("unknown-session-1") 193 | target2 := r.determineTargetForSession("unknown-session-1") // Same ID should give same target 194 | if target1 != target2 { 195 | t.Errorf("Same session ID gave different targets: %s and %s", target1, target2) 196 | } 197 | 198 | // Different unknown session IDs might hash to different targets 199 | // but we can at least verify they're in our list of targets 200 | target3 := r.determineTargetForSession("unknown-session-2") 201 | found := false 202 | 
for _, t := range mockConfig.AllTargets { 203 | if t == target3 { 204 | found = true 205 | break 206 | } 207 | } 208 | if !found { 209 | t.Errorf("Target %s not found in list of valid targets", target3) 210 | } 211 | } 212 | 213 | func TestHandleMCPRequest_Initialize(t *testing.T) { 214 | // Create a mock config 215 | mockConfig := &MockConfig{ 216 | DefaultURL: "http://default:8080", 217 | } 218 | 219 | // Create a router with the mock config 220 | r := &Router{ 221 | Config: mockConfig, 222 | UI: &ui.UI{Stats: &ui.Stats{RequestsByMethod: make(map[string]int64), RequestsByEndpoint: make(map[string]int64), ResponseTimes: make(map[string][]int)}}, 223 | SessionStore: NewTestSessionStore(), 224 | Logger: logger.New(logger.Debug), 225 | } 226 | 227 | // Create a test server that will respond to our proxied requests 228 | testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 229 | // Add a session ID to the response 230 | w.Header().Set("Mcp-Session-Id", "test-session-123") 231 | w.Header().Set("Content-Type", "application/json") 232 | w.WriteHeader(http.StatusOK) 233 | w.Write([]byte(`{"jsonrpc":"2.0","id":1,"result":{"status":"ok"}}`)) 234 | })) 235 | defer testServer.Close() 236 | 237 | // Override the default URL to point to our test server 238 | mockConfig.DefaultURL = testServer.URL 239 | 240 | // Create an initialize request 241 | initRequest := JSONRPCRequest{ 242 | JSONRPC: "2.0", 243 | ID: 1, 244 | Method: "initialize", 245 | Params: map[string]interface{}{}, 246 | } 247 | requestBody, _ := json.Marshal(initRequest) 248 | 249 | // Create a test request 250 | req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(requestBody)) 251 | req.Header.Set("Content-Type", "application/json") 252 | req.Header.Set("Accept", "application/json") 253 | 254 | // Create a response recorder 255 | rr := httptest.NewRecorder() 256 | 257 | // Call the handler 258 | r.HandleMCPRequest(rr, req) 259 | 260 | // Check the response 
261 | if rr.Code != http.StatusOK { 262 | t.Errorf("Expected status code %d, got %d", http.StatusOK, rr.Code) 263 | } 264 | 265 | // Check that the session ID was stored 266 | target, exists := r.SessionStore.Get("test-session-123") 267 | if !exists { 268 | t.Errorf("Session ID was not stored") 269 | } 270 | if target != testServer.URL { 271 | t.Errorf("Expected target %s, got %s", testServer.URL, target) 272 | } 273 | 274 | // Check the response body 275 | var response map[string]interface{} 276 | if err := json.Unmarshal(rr.Body.Bytes(), &response); err != nil { 277 | t.Errorf("Failed to parse response body: %v", err) 278 | } 279 | if response["result"].(map[string]interface{})["status"] != "ok" { 280 | t.Errorf("Expected status 'ok', got %v", response["result"]) 281 | } 282 | } 283 | 284 | // Helper function to compile regex patterns for testing 285 | func compileRegex(t *testing.T, pattern string) *regexp.Regexp { 286 | r, err := regexp.Compile(pattern) 287 | if err != nil { 288 | t.Fatalf("Failed to compile regex pattern '%s': %v", pattern, err) 289 | } 290 | return r 291 | } 292 | 293 | // NewTestSessionStore creates a simple session store for testing 294 | func NewTestSessionStore() *session.Store { 295 | return session.NewStore() 296 | } 297 | -------------------------------------------------------------------------------- /pkg/ui/ui.go: -------------------------------------------------------------------------------- 1 | // Package ui provides a web interface for monitoring the MCP router proxy 2 | package ui 3 | 4 | import ( 5 | "encoding/json" 6 | "fmt" 7 | "html/template" 8 | "net/http" 9 | "strconv" 10 | "sync" 11 | "time" 12 | 13 | "github.com/mclenhard/catie-mcp/pkg/config" 14 | "github.com/mclenhard/catie-mcp/pkg/logger" 15 | ) 16 | 17 | // Stats tracks various metrics about the router proxy 18 | type Stats struct { 19 | TotalRequests int64 `json:"totalRequests"` 20 | RequestsByMethod map[string]int64 `json:"requestsByMethod"` 21 | RequestsByEndpoint 
map[string]int64 `json:"requestsByEndpoint"` 22 | ResponseTimes map[string][]int `json:"responseTimes"` // in milliseconds 23 | ErrorCount int64 `json:"errorCount"` 24 | StartTime time.Time `json:"startTime"` 25 | mu sync.RWMutex 26 | } 27 | 28 | // UI handles the web interface for the router proxy 29 | type UI struct { 30 | Stats *Stats 31 | template *template.Template 32 | username string 33 | password string 34 | logger *logger.Logger 35 | } 36 | 37 | // New creates a new UI instance 38 | func New(config *config.Config) *UI { 39 | tmpl := template.Must(template.New("stats").Parse(statsTemplate)) 40 | 41 | return &UI{ 42 | Stats: &Stats{ 43 | RequestsByMethod: make(map[string]int64), 44 | RequestsByEndpoint: make(map[string]int64), 45 | ResponseTimes: make(map[string][]int), 46 | StartTime: time.Now(), 47 | }, 48 | template: tmpl, 49 | username: config.RouteConfig.UI.Username, 50 | password: config.RouteConfig.UI.Password, 51 | logger: logger.New(logger.Info), // Initialize with Info level 52 | } 53 | } 54 | 55 | // RecordRequest records statistics for a request 56 | func (s *Stats) RecordRequest(method, endpoint string, responseTime time.Duration, isError bool) { 57 | s.mu.Lock() 58 | defer s.mu.Unlock() 59 | 60 | s.TotalRequests++ 61 | s.RequestsByMethod[method]++ 62 | s.RequestsByEndpoint[endpoint]++ 63 | 64 | // Store response time (convert to milliseconds) 65 | respTimeMs := int(responseTime.Milliseconds()) 66 | s.ResponseTimes[method] = append(s.ResponseTimes[method], respTimeMs) 67 | 68 | // Limit the number of stored response times to prevent memory issues 69 | if len(s.ResponseTimes[method]) > 1000 { 70 | s.ResponseTimes[method] = s.ResponseTimes[method][len(s.ResponseTimes[method])-1000:] 71 | } 72 | 73 | if isError { 74 | s.ErrorCount++ 75 | } 76 | } 77 | 78 | // GetStats returns a copy of the current stats 79 | func (s *Stats) GetStats() Stats { 80 | s.mu.RLock() 81 | defer s.mu.RUnlock() 82 | 83 | // Create a deep copy to avoid race conditions 84 | 
statsCopy := Stats{ 85 | TotalRequests: s.TotalRequests, 86 | RequestsByMethod: make(map[string]int64), 87 | RequestsByEndpoint: make(map[string]int64), 88 | ResponseTimes: make(map[string][]int), 89 | ErrorCount: s.ErrorCount, 90 | StartTime: s.StartTime, 91 | 92 | } 93 | 94 | for k, v := range s.RequestsByMethod { 95 | statsCopy.RequestsByMethod[k] = v 96 | } 97 | 98 | for k, v := range s.RequestsByEndpoint { 99 | statsCopy.RequestsByEndpoint[k] = v 100 | } 101 | 102 | for k, v := range s.ResponseTimes { 103 | statsCopy.ResponseTimes[k] = make([]int, len(v)) 104 | copy(statsCopy.ResponseTimes[k], v) 105 | } 106 | 107 | return statsCopy 108 | } 109 | 110 | // HandleStats serves the stats page 111 | func (ui *UI) HandleStats(w http.ResponseWriter, r *http.Request) { 112 | stats := ui.Stats.GetStats() 113 | 114 | // Calculate uptime 115 | uptime := time.Since(stats.StartTime).String() 116 | 117 | // Calculate average response times 118 | avgResponseTimes := make(map[string]int) 119 | for method, times := range stats.ResponseTimes { 120 | if len(times) == 0 { 121 | continue 122 | } 123 | 124 | sum := 0 125 | for _, t := range times { 126 | sum += t 127 | } 128 | avgResponseTimes[method] = sum / len(times) 129 | } 130 | 131 | data := struct { 132 | Stats Stats 133 | Uptime string 134 | AvgResponseTimes map[string]int 135 | }{ 136 | Stats: stats, 137 | Uptime: uptime, 138 | AvgResponseTimes: avgResponseTimes, 139 | } 140 | 141 | if r.Header.Get("Accept") == "application/json" { 142 | w.Header().Set("Content-Type", "application/json") 143 | json.NewEncoder(w).Encode(data) 144 | return 145 | } 146 | 147 | w.Header().Set("Content-Type", "text/html") 148 | ui.template.Execute(w, data) 149 | } 150 | 151 | // RegisterHandlers registers the UI handlers with the provided mux 152 | func (ui *UI) RegisterHandlers(mux *http.ServeMux) { 153 | fmt.Println("Registering UI handlers with provided mux") 154 | mux.HandleFunc("/stats", ui.basicAuth(ui.HandleStats)) 155 | 
mux.HandleFunc("/metrics", ui.basicAuth(ui.HandleMetrics)) 156 | fmt.Println("UI handlers registered for /stats and /metrics") 157 | } 158 | 159 | // basicAuth wraps an http.HandlerFunc with basic authentication 160 | func (ui *UI) basicAuth(handler http.HandlerFunc) http.HandlerFunc { 161 | return func(w http.ResponseWriter, r *http.Request) { 162 | ui.logger.Debug("Received request for %s from %s", r.URL.Path, r.RemoteAddr) 163 | 164 | // Skip auth if credentials aren't configured 165 | if ui.username == "" && ui.password == "" { 166 | ui.logger.Info("No auth credentials configured, skipping authentication for %s", r.URL.Path) 167 | handler(w, r) 168 | return 169 | } 170 | 171 | ui.logger.Debug("Auth required for %s - username: '%s', password: [hidden]", r.URL.Path, ui.username) 172 | user, pass, ok := r.BasicAuth() 173 | 174 | if !ok { 175 | ui.logger.Warn("No auth credentials provided for %s from %s", r.URL.Path, r.RemoteAddr) 176 | w.Header().Set("WWW-Authenticate", `Basic realm="MCP Router Stats"`) 177 | w.WriteHeader(http.StatusUnauthorized) 178 | w.Write([]byte("Unauthorized - No credentials provided")) 179 | return 180 | } 181 | 182 | ui.logger.Debug("Auth attempt for %s - provided username: '%s'", r.URL.Path, user) 183 | 184 | if user != ui.username || pass != ui.password { 185 | ui.logger.Warn("Invalid credentials for %s from %s (user: %s)", r.URL.Path, r.RemoteAddr, user) 186 | w.Header().Set("WWW-Authenticate", `Basic realm="MCP Router Stats"`) 187 | w.WriteHeader(http.StatusUnauthorized) 188 | w.Write([]byte("Unauthorized - Invalid credentials")) 189 | return 190 | } 191 | 192 | ui.logger.Info("Authorized access to %s from %s (user: %s)", r.URL.Path, r.RemoteAddr, user) 193 | handler(w, r) 194 | } 195 | } 196 | 197 | // HandleMetrics serves Prometheus-compatible metrics 198 | func (ui *UI) HandleMetrics(w http.ResponseWriter, r *http.Request) { 199 | stats := ui.Stats.GetStats() 200 | 201 | w.Header().Set("Content-Type", "text/plain") 202 | 203 | // 
Basic Prometheus-style metrics 204 | w.Write([]byte("# HELP mcp_router_requests_total Total number of requests processed\n")) 205 | w.Write([]byte("# TYPE mcp_router_requests_total counter\n")) 206 | w.Write([]byte("mcp_router_requests_total " + strconv.FormatInt(stats.TotalRequests, 10) + "\n\n")) 207 | 208 | w.Write([]byte("# HELP mcp_router_errors_total Total number of request errors\n")) 209 | w.Write([]byte("# TYPE mcp_router_errors_total counter\n")) 210 | w.Write([]byte("mcp_router_errors_total " + strconv.FormatInt(stats.ErrorCount, 10) + "\n\n")) 211 | 212 | // Method-specific metrics 213 | w.Write([]byte("# HELP mcp_router_requests_by_method Number of requests by method\n")) 214 | w.Write([]byte("# TYPE mcp_router_requests_by_method counter\n")) 215 | for method, count := range stats.RequestsByMethod { 216 | w.Write([]byte("mcp_router_requests_by_method{method=\"" + method + "\"} " + strconv.FormatInt(count, 10) + "\n")) 217 | } 218 | 219 | // Endpoint-specific metrics 220 | w.Write([]byte("\n# HELP mcp_router_requests_by_endpoint Number of requests by endpoint\n")) 221 | w.Write([]byte("# TYPE mcp_router_requests_by_endpoint counter\n")) 222 | for endpoint, count := range stats.RequestsByEndpoint { 223 | w.Write([]byte("mcp_router_requests_by_endpoint{endpoint=\"" + endpoint + "\"} " + strconv.FormatInt(count, 10) + "\n")) 224 | } 225 | 226 | // Response time metrics 227 | w.Write([]byte("\n# HELP mcp_router_response_time_ms Average response time in milliseconds\n")) 228 | w.Write([]byte("# TYPE mcp_router_response_time_ms gauge\n")) 229 | for method, times := range stats.ResponseTimes { 230 | if len(times) == 0 { 231 | continue 232 | } 233 | 234 | sum := 0 235 | for _, t := range times { 236 | sum += t 237 | } 238 | avg := sum / len(times) 239 | w.Write([]byte("mcp_router_response_time_ms{method=\"" + method + "\"} " + strconv.Itoa(avg) + "\n")) 240 | } 241 | 242 | // Uptime metric 243 | uptime := time.Since(stats.StartTime).Seconds() 244 | 
w.Write([]byte("\n# HELP mcp_router_uptime_seconds Time since the router started in seconds\n")) 245 | w.Write([]byte("# TYPE mcp_router_uptime_seconds counter\n")) 246 | w.Write([]byte("mcp_router_uptime_seconds " + strconv.FormatFloat(uptime, 'f', 1, 64) + "\n")) 247 | } 248 | 249 | // HTML template for the stats page 250 | const statsTemplate = ` 251 | 252 | 253 | 254 | Request Statistics 255 | 256 | 313 | 314 | 315 |

MCP Router Proxy Statistics

316 | 317 | 318 | 319 |
320 |
321 |

General Stats

322 | 323 | 324 | 325 | 326 | 327 | 328 | 329 | 330 | 331 | 332 | 333 | 334 | 335 |
Total Requests{{.Stats.TotalRequests}}
Error Count{{.Stats.ErrorCount}}
Uptime{{.Uptime}}
336 |
337 | 338 |
339 |

Requests by Method

340 | 341 | 342 | 343 | 344 | 345 | 346 | {{range $method, $count := .Stats.RequestsByMethod}} 347 | 348 | 349 | 350 | 351 | 352 | {{end}} 353 |
MethodCountAvg Response Time (ms)
{{$method}}{{$count}}{{index $.AvgResponseTimes $method}}
354 |
355 | 356 |
357 |

Requests by Endpoint

358 | 359 | 360 | 361 | 362 | 363 | {{range $endpoint, $count := .Stats.RequestsByEndpoint}} 364 | 365 | 366 | 367 | 368 | {{end}} 369 |
EndpointCount
{{$endpoint}}{{$count}}
370 |
371 |
372 | 373 | 379 | 380 | 381 | ` 382 | -------------------------------------------------------------------------------- /pkg/router/router.go: -------------------------------------------------------------------------------- 1 | package router 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "context" 7 | "encoding/json" 8 | "fmt" 9 | "hash/fnv" 10 | "io" 11 | "net/http" 12 | "strings" 13 | "sync" 14 | "time" 15 | 16 | "github.com/mclenhard/catie-mcp/pkg/config" 17 | "github.com/mclenhard/catie-mcp/pkg/logger" 18 | "github.com/mclenhard/catie-mcp/pkg/session" 19 | "github.com/mclenhard/catie-mcp/pkg/ui" 20 | ) 21 | 22 | // JSONRPCRequest structure 23 | type JSONRPCRequest struct { 24 | JSONRPC string `json:"jsonrpc"` 25 | ID interface{} `json:"id,omitempty"` 26 | Method string `json:"method"` 27 | Params map[string]interface{} `json:"params,omitempty"` 28 | } 29 | 30 | // ConfigInterface is an interface that both MockConfig and config.Config can implement 31 | type ConfigInterface interface { 32 | GetDefault() string 33 | GetResourceRegexes() []config.RouteRule 34 | GetToolRegexes() []config.RouteRule 35 | GetAllTargets() []string 36 | GetToolMappings() []config.ToolMapping 37 | } 38 | 39 | // Router handles the routing of MCP requests 40 | type Router struct { 41 | Config ConfigInterface // Changed from config.Config 42 | UI *ui.UI 43 | SessionStore *session.Store 44 | Logger *logger.Logger 45 | } 46 | 47 | // New creates a new Router instance 48 | func New(cfg *config.Config, ui *ui.UI) *Router { 49 | return &Router{ 50 | Config: cfg, 51 | UI: ui, 52 | SessionStore: session.NewStore(), 53 | Logger: logger.New(logger.Info), // Default to Info level 54 | } 55 | } 56 | 57 | // Add this constant at the top of the file 58 | const ( 59 | // DefaultTimeout is the default timeout for HTTP requests 60 | DefaultTimeout = 30 * time.Second 61 | ) 62 | 63 | // HandleMCPRequest processes incoming MCP requests and routes them to the appropriate target 64 | func (r *Router) 
HandleMCPRequest(w http.ResponseWriter, req *http.Request) { 65 | startTime := time.Now() 66 | var isError bool 67 | 68 | r.Logger.Info("Received request: Method=%s, Path=%s, ContentType=%s", 69 | req.Method, req.URL.Path, req.Header.Get("Content-Type")) 70 | 71 | // Handle OPTIONS requests (CORS preflight) 72 | if req.Method == http.MethodOptions { 73 | r.Logger.Debug("Handling OPTIONS request") 74 | // Set CORS headers 75 | w.Header().Set("Access-Control-Allow-Origin", "*") 76 | w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS") 77 | w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Accept, Mcp-Session-Id, Authorization, MCP-Protocol-Version") 78 | w.Header().Set("Access-Control-Max-Age", "86400") // 24 hours 79 | 80 | // Respond with 200 OK for OPTIONS requests 81 | w.WriteHeader(http.StatusOK) 82 | r.UI.Stats.RecordRequest("OPTIONS", "cors", time.Since(startTime), false) 83 | return 84 | } 85 | 86 | // Check if this is a GET request (for SSE streaming) 87 | if req.Method == http.MethodGet { 88 | r.Logger.Debug("Handling GET request for SSE") 89 | // Extract session ID if present 90 | sessionID := req.Header.Get("Mcp-Session-Id") 91 | r.Logger.Debug("Session ID: %s", sessionID) 92 | 93 | // Determine target based on session ID or other routing logic 94 | targetURL := r.determineTargetForSession(sessionID) 95 | r.Logger.Info("Routing SSE stream to target: %s", targetURL) 96 | 97 | // Set up SSE headers for client 98 | w.Header().Set("Content-Type", "text/event-stream") 99 | w.Header().Set("Cache-Control", "no-cache") 100 | w.Header().Set("Connection", "keep-alive") 101 | w.Header().Set("Access-Control-Allow-Origin", "*") 102 | 103 | // Create a new GET request to the target server 104 | proxyReq, err := http.NewRequest(http.MethodGet, targetURL, nil) 105 | if err != nil { 106 | r.Logger.Error("Error creating proxy request: %v", err) 107 | http.Error(w, "error creating proxy request: "+err.Error(), http.StatusInternalServerError) 108 
| isError = true 109 | r.UI.Stats.RecordRequest("GET", targetURL, time.Since(startTime), isError) 110 | return 111 | } 112 | 113 | // Copy relevant headers from the original request 114 | proxyReq.Header.Set("Accept", "text/event-stream") 115 | if sessionID != "" { 116 | proxyReq.Header.Set("Mcp-Session-Id", sessionID) 117 | } 118 | 119 | // Copy Last-Event-ID header if present (for resuming streams) 120 | if lastEventID := req.Header.Get("Last-Event-ID"); lastEventID != "" { 121 | proxyReq.Header.Set("Last-Event-ID", lastEventID) 122 | } 123 | 124 | // Make the request to the target server 125 | client := &http.Client{ 126 | // No timeout for SSE connections 127 | } 128 | 129 | resp, err := client.Do(proxyReq) 130 | if err != nil { 131 | r.Logger.Error("Error connecting to target server: %v", err) 132 | http.Error(w, "error connecting to target server: "+err.Error(), http.StatusBadGateway) 133 | isError = true 134 | r.UI.Stats.RecordRequest("GET", targetURL, time.Since(startTime), isError) 135 | return 136 | } 137 | defer resp.Body.Close() 138 | 139 | // Check if the target server returned an error 140 | if resp.StatusCode != http.StatusOK { 141 | // Read error body and forward it to the client 142 | errorBody, _ := io.ReadAll(resp.Body) 143 | r.Logger.Error("Target server returned error status %d: %s", resp.StatusCode, string(errorBody)) 144 | http.Error(w, string(errorBody), resp.StatusCode) 145 | isError = true 146 | r.UI.Stats.RecordRequest("GET", targetURL, time.Since(startTime), isError) 147 | return 148 | } 149 | 150 | r.Logger.Debug("Successfully connected to SSE stream from target") 151 | 152 | // Create a context that's canceled when the client connection closes 153 | ctx, cancel := context.WithCancel(req.Context()) 154 | defer cancel() 155 | 156 | // Create a done channel to signal when to close the connection 157 | done := make(chan bool) 158 | var once sync.Once // Add this to ensure the channel is only closed once 159 | 160 | // Handle client 
disconnection using context 161 | go func() { 162 | <-ctx.Done() 163 | once.Do(func() { close(done) }) // Use sync.Once to safely close the channel 164 | // Also close the response body to terminate the connection to the target 165 | resp.Body.Close() 166 | r.Logger.Debug("Client disconnected, closing SSE stream") 167 | }() 168 | 169 | // Stream SSE events from target to client 170 | go func() { 171 | scanner := bufio.NewScanner(resp.Body) 172 | // Set a larger buffer for the scanner to handle large SSE events 173 | const maxScanTokenSize = 1024 * 1024 // 1MB 174 | buf := make([]byte, maxScanTokenSize) 175 | scanner.Buffer(buf, maxScanTokenSize) 176 | 177 | eventCount := 0 178 | for scanner.Scan() { 179 | select { 180 | case <-done: 181 | return 182 | default: 183 | line := scanner.Text() 184 | fmt.Fprintf(w, "%s\n", line) 185 | // If this is the end of an event, flush the buffer 186 | if line == "" { 187 | eventCount++ 188 | if eventCount%100 == 0 { 189 | r.Logger.Debug("Streamed %d SSE events so far", eventCount) 190 | } 191 | if flusher, ok := w.(http.Flusher); ok { 192 | flusher.Flush() 193 | } 194 | } 195 | } 196 | } 197 | 198 | if err := scanner.Err(); err != nil { 199 | r.Logger.Error("Error scanning SSE stream: %v", err) 200 | } 201 | 202 | r.Logger.Info("SSE stream closed after sending %d events", eventCount) 203 | once.Do(func() { close(done) }) // Safely close the channel if not already closed 204 | }() 205 | 206 | // Set up heartbeat ticker 207 | heartbeatTicker := time.NewTicker(30 * time.Second) 208 | defer heartbeatTicker.Stop() 209 | 210 | // Send heartbeat events to keep the connection alive 211 | go func() { 212 | for { 213 | select { 214 | case <-heartbeatTicker.C: 215 | // Send a heartbeat event 216 | fmt.Fprintf(w, "event: heartbeat\ndata: %d\n\n", time.Now().Unix()) 217 | if flusher, ok := w.(http.Flusher); ok { 218 | flusher.Flush() 219 | } 220 | r.Logger.Debug("Sent heartbeat event") 221 | case <-done: 222 | return 223 | } 224 | } 225 | }() 
226 | 227 | // Wait until done 228 | <-done 229 | r.UI.Stats.RecordRequest("GET", targetURL, time.Since(startTime), isError) 230 | return 231 | } 232 | 233 | // Handle POST requests (client sending messages to server) 234 | if req.Method == http.MethodPost { 235 | r.Logger.Debug("Handling POST request") 236 | 237 | // Read the request body 238 | body, err := io.ReadAll(req.Body) 239 | if err != nil { 240 | r.Logger.Error("Error reading request body: %v", err) 241 | http.Error(w, "error reading request body: "+err.Error(), http.StatusBadRequest) 242 | isError = true 243 | r.UI.Stats.RecordRequest("POST", "unknown", time.Since(startTime), isError) 244 | return 245 | } 246 | 247 | // Parse the JSON-RPC request to determine routing 248 | var jsonRPCRequest JSONRPCRequest 249 | if err := json.Unmarshal(body, &jsonRPCRequest); err != nil { 250 | // Try to parse as a batch request 251 | var batchRequest []JSONRPCRequest 252 | if batchErr := json.Unmarshal(body, &batchRequest); batchErr != nil { 253 | r.Logger.Error("Error parsing JSON-RPC request: %v", err) 254 | http.Error(w, "error parsing JSON-RPC request: "+err.Error(), http.StatusBadRequest) 255 | isError = true 256 | r.UI.Stats.RecordRequest("POST", "invalid_json", time.Since(startTime), isError) 257 | return 258 | } 259 | // Use the first request in the batch for routing 260 | if len(batchRequest) > 0 { 261 | jsonRPCRequest = batchRequest[0] 262 | } 263 | } 264 | 265 | // Extract session ID if present 266 | sessionID := req.Header.Get("Mcp-Session-Id") 267 | r.Logger.Debug("Session ID: %s", sessionID) 268 | 269 | // Check if this is an initialization request 270 | isInitialize := jsonRPCRequest.Method == "initialize" 271 | 272 | // Determine target URL based on the request content 273 | targetURL := r.determineTargetForSession(sessionID) 274 | r.Logger.Info("Routing request to target: %s", targetURL) 275 | 276 | // Apply tool name transformation if needed 277 | r.transformToolCall(jsonRPCRequest, targetURL) 278 | 
279 | // Re-encode the possibly modified request 280 | modifiedBody, err := json.Marshal(jsonRPCRequest) 281 | if err != nil { 282 | http.Error(w, "error encoding request: "+err.Error(), http.StatusInternalServerError) 283 | return 284 | } 285 | 286 | // Create a new POST request to the target server 287 | proxyReq, err := http.NewRequest(http.MethodPost, targetURL, bytes.NewReader(modifiedBody)) 288 | if err != nil { 289 | r.Logger.Error("Error creating proxy request: %v", err) 290 | http.Error(w, "error creating proxy request: "+err.Error(), http.StatusInternalServerError) 291 | isError = true 292 | r.UI.Stats.RecordRequest("POST", targetURL, time.Since(startTime), isError) 293 | return 294 | } 295 | 296 | // Copy headers from the original request 297 | proxyReq.Header.Set("Content-Type", "application/json") 298 | 299 | // Set Accept header to support both JSON and SSE responses 300 | proxyReq.Header.Set("Accept", req.Header.Get("Accept")) 301 | if proxyReq.Header.Get("Accept") == "" { 302 | proxyReq.Header.Set("Accept", "application/json, text/event-stream") 303 | } 304 | 305 | // Forward session ID if present 306 | if sessionID != "" { 307 | proxyReq.Header.Set("Mcp-Session-Id", sessionID) 308 | } 309 | 310 | // Make the request to the target server with a reasonable timeout 311 | client := &http.Client{ 312 | Timeout: 30 * time.Second, // Use a longer timeout for initialization 313 | } 314 | 315 | resp, err := client.Do(proxyReq) 316 | if err != nil { 317 | r.Logger.Error("Error connecting to target server: %v", err) 318 | http.Error(w, "error connecting to target server: "+err.Error(), http.StatusBadGateway) 319 | isError = true 320 | r.UI.Stats.RecordRequest("POST", targetURL, time.Since(startTime), isError) 321 | return 322 | } 323 | defer resp.Body.Close() 324 | 325 | // Check if the target server returned an error 326 | if resp.StatusCode >= 400 { 327 | // Read error body and forward it to the client 328 | errorBody, _ := io.ReadAll(resp.Body) 329 | 
r.Logger.Error("Target server returned error status %d: %s", resp.StatusCode, string(errorBody)) 330 | for name, values := range resp.Header { 331 | for _, value := range values { 332 | w.Header().Add(name, value) 333 | } 334 | } 335 | w.WriteHeader(resp.StatusCode) 336 | w.Write(errorBody) 337 | isError = true 338 | r.UI.Stats.RecordRequest("POST", targetURL, time.Since(startTime), isError) 339 | return 340 | } 341 | 342 | // Check if this is an initialization response with a session ID 343 | if isInitialize && resp.Header.Get("Mcp-Session-Id") != "" { 344 | newSessionID := resp.Header.Get("Mcp-Session-Id") 345 | r.Logger.Info("Received new session ID: %s", newSessionID) 346 | 347 | // Store the session ID and target URL mapping 348 | r.SessionStore.Set(newSessionID, targetURL) 349 | } 350 | 351 | // Check content type to determine how to handle the response 352 | contentType := resp.Header.Get("Content-Type") 353 | 354 | // Copy all response headers to the client 355 | for name, values := range resp.Header { 356 | for _, value := range values { 357 | w.Header().Add(name, value) 358 | } 359 | } 360 | 361 | // If the response is an SSE stream, handle it accordingly 362 | if strings.Contains(contentType, "text/event-stream") { 363 | r.Logger.Debug("Target server returned SSE stream") 364 | 365 | // Set SSE headers 366 | w.Header().Set("Content-Type", "text/event-stream") 367 | w.Header().Set("Cache-Control", "no-cache") 368 | w.Header().Set("Connection", "keep-alive") 369 | 370 | // Create a context that's canceled when the client connection closes 371 | ctx, cancel := context.WithCancel(req.Context()) 372 | defer cancel() 373 | 374 | // Create a done channel to signal when to close the connection 375 | done := make(chan bool) 376 | var once sync.Once // Add this to ensure the channel is only closed once 377 | 378 | // Handle client disconnection 379 | go func() { 380 | <-ctx.Done() 381 | once.Do(func() { close(done) }) // Use sync.Once to safely close the channel 
382 | r.Logger.Debug("Client disconnected, closing SSE stream") 383 | }() 384 | 385 | // Stream SSE events from target to client 386 | scanner := bufio.NewScanner(resp.Body) 387 | // Set a larger buffer for the scanner to handle large SSE events 388 | const maxScanTokenSize = 1024 * 1024 // 1MB 389 | buf := make([]byte, maxScanTokenSize) 390 | scanner.Buffer(buf, maxScanTokenSize) 391 | 392 | go func() { 393 | for scanner.Scan() { 394 | select { 395 | case <-done: 396 | return 397 | default: 398 | line := scanner.Text() 399 | fmt.Fprintf(w, "%s\n", line) 400 | // If this is the end of an event, flush the buffer 401 | if line == "" { 402 | if flusher, ok := w.(http.Flusher); ok { 403 | flusher.Flush() 404 | } 405 | } 406 | } 407 | } 408 | 409 | if err := scanner.Err(); err != nil { 410 | r.Logger.Error("Error scanning SSE stream: %v", err) 411 | } 412 | 413 | once.Do(func() { close(done) }) // Safely close the channel if not already closed 414 | }() 415 | 416 | // Set up heartbeat ticker 417 | heartbeatTicker := time.NewTicker(30 * time.Second) 418 | defer heartbeatTicker.Stop() 419 | 420 | // Send heartbeat events to keep the connection alive 421 | go func() { 422 | for { 423 | select { 424 | case <-heartbeatTicker.C: 425 | // Send a heartbeat event 426 | fmt.Fprintf(w, "event: heartbeat\ndata: %d\n\n", time.Now().Unix()) 427 | if flusher, ok := w.(http.Flusher); ok { 428 | flusher.Flush() 429 | } 430 | r.Logger.Debug("Sent heartbeat event") 431 | case <-done: 432 | return 433 | } 434 | } 435 | }() 436 | 437 | // Wait until done 438 | <-done 439 | r.UI.Stats.RecordRequest("POST", targetURL, time.Since(startTime), isError) 440 | return 441 | } 442 | 443 | // For regular JSON responses, just copy the response body 444 | responseBody, err := io.ReadAll(resp.Body) 445 | if err != nil { 446 | r.Logger.Error("Error reading response body: %v", err) 447 | http.Error(w, "error reading response body: "+err.Error(), http.StatusInternalServerError) 448 | isError = true 449 | 
r.UI.Stats.RecordRequest("POST", targetURL, time.Since(startTime), isError) 450 | return 451 | } 452 | 453 | // Write the response status code and body 454 | w.WriteHeader(resp.StatusCode) 455 | w.Write(responseBody) 456 | 457 | // If this was an initialization request, log the response 458 | if isInitialize { 459 | r.Logger.Info("Successfully processed initialization request") 460 | 461 | // Parse the response to extract any relevant information 462 | var initResponse map[string]interface{} 463 | if err := json.Unmarshal(responseBody, &initResponse); err == nil { 464 | r.Logger.Debug("Initialization response: %v", initResponse) 465 | } 466 | } 467 | 468 | r.UI.Stats.RecordRequest("POST", targetURL, time.Since(startTime), isError) 469 | return 470 | } 471 | 472 | // Handle POST requests (client sending messages to server) 473 | if req.Method != http.MethodPost { 474 | r.Logger.Warn("Received unsupported method: %s", req.Method) 475 | http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) 476 | isError = true 477 | r.UI.Stats.RecordRequest("unknown", "error", time.Since(startTime), isError) 478 | return 479 | } 480 | 481 | // Check Accept header 482 | acceptHeader := req.Header.Get("Accept") 483 | if !strings.Contains(acceptHeader, "application/json") && !strings.Contains(acceptHeader, "text/event-stream") { 484 | r.Logger.Warn("Invalid Accept header: %s", acceptHeader) 485 | http.Error(w, "Invalid Accept header", http.StatusBadRequest) 486 | isError = true 487 | r.UI.Stats.RecordRequest("unknown", "error", time.Since(startTime), isError) 488 | return 489 | } 490 | 491 | // Extract session ID if present 492 | sessionID := req.Header.Get("Mcp-Session-Id") 493 | r.Logger.Debug("Received POST request with session ID: %s", sessionID) 494 | 495 | body, err := io.ReadAll(req.Body) 496 | if err != nil { 497 | r.Logger.Error("Failed to read request body: %v", err) 498 | http.Error(w, "failed to read request body", http.StatusBadRequest) 499 | isError = true 500 | 
r.UI.Stats.RecordRequest("unknown", "error", time.Since(startTime), isError) 501 | return 502 | } 503 | 504 | // Try to parse as a single request or as a batch 505 | var method string 506 | var targetURL string 507 | 508 | // First try to parse as a single request 509 | var rpcReq JSONRPCRequest 510 | if err := json.Unmarshal(body, &rpcReq); err == nil { 511 | // Successfully parsed as a single request 512 | method = rpcReq.Method 513 | r.Logger.Debug("Parsed single JSON-RPC request with method: %s", method) 514 | 515 | // Determine target URL based on the request 516 | if rpcReq.Method == "initialize" { 517 | // For initialize requests, use the default target 518 | targetURL = r.Config.GetDefault() 519 | r.Logger.Info("Initialize request, using default target: %s", targetURL) 520 | } else { 521 | // For other requests, use the routing logic 522 | targetURL = r.RouteByContext(rpcReq) 523 | r.Logger.Info("Routing method '%s' to target: %s", method, targetURL) 524 | } 525 | } else { 526 | // Try to parse as a batch 527 | var batchReq []JSONRPCRequest 528 | if err := json.Unmarshal(body, &batchReq); err == nil { 529 | // Successfully parsed as a batch 530 | r.Logger.Debug("Parsed batch request with %d methods", len(batchReq)) 531 | 532 | // For simplicity, use the first request's method for logging 533 | if len(batchReq) > 0 { 534 | method = batchReq[0].Method 535 | 536 | // For batch requests, we need a routing strategy 537 | // Here we use the first request to determine the target 538 | targetURL = r.RouteByContext(batchReq[0]) 539 | r.Logger.Info("Routing batch request (first method: '%s') to target: %s", method, targetURL) 540 | } else { 541 | method = "batch" 542 | targetURL = r.Config.GetDefault() 543 | r.Logger.Info("Empty batch request, using default target: %s", targetURL) 544 | } 545 | } else { 546 | // Failed to parse as either single request or batch 547 | r.Logger.Error("Failed to parse JSON-RPC request: %v", err) 548 | http.Error(w, "invalid JSON-RPC 
request", http.StatusBadRequest) 549 | isError = true 550 | r.UI.Stats.RecordRequest("unknown", "error", time.Since(startTime), isError) 551 | return 552 | } 553 | } 554 | 555 | // Create a new POST request to the target server 556 | proxyReq, err := http.NewRequest(http.MethodPost, targetURL, bytes.NewReader(body)) 557 | if err != nil { 558 | r.Logger.Error("Error creating proxy request: %v", err) 559 | http.Error(w, "error creating proxy request: "+err.Error(), http.StatusInternalServerError) 560 | isError = true 561 | r.UI.Stats.RecordRequest(method, targetURL, time.Since(startTime), isError) 562 | return 563 | } 564 | 565 | // Copy relevant headers from the original request 566 | proxyReq.Header.Set("Content-Type", "application/json") 567 | proxyReq.Header.Set("Accept", acceptHeader) 568 | if sessionID != "" { 569 | proxyReq.Header.Set("Mcp-Session-Id", sessionID) 570 | } 571 | 572 | // Forward the Authorization header 573 | if authHeader := req.Header.Get("Authorization"); authHeader != "" { 574 | proxyReq.Header.Set("Authorization", authHeader) 575 | r.Logger.Debug("Forwarding Authorization header") 576 | } 577 | 578 | // Forward MCP-Protocol-Version header if present 579 | if protocolVersion := req.Header.Get("MCP-Protocol-Version"); protocolVersion != "" { 580 | proxyReq.Header.Set("MCP-Protocol-Version", protocolVersion) 581 | } 582 | 583 | // Make the request to the target server 584 | client := &http.Client{ 585 | Timeout: DefaultTimeout, 586 | } 587 | ctx, cancel := context.WithTimeout(req.Context(), DefaultTimeout) 588 | defer cancel() 589 | proxyReq = proxyReq.WithContext(ctx) 590 | resp, err := client.Do(proxyReq) 591 | if err != nil { 592 | r.Logger.Error("Error forwarding request to target: %v", err) 593 | http.Error(w, "error forwarding request: "+err.Error(), http.StatusBadGateway) 594 | isError = true 595 | r.UI.Stats.RecordRequest(method, targetURL, time.Since(startTime), isError) 596 | return 597 | } 598 | defer resp.Body.Close() 599 | 600 | 
// Check if the response contains a session ID 601 | sessionID = resp.Header.Get("Mcp-Session-Id") 602 | if sessionID != "" { 603 | // Store the session ID -> target mapping 604 | r.SessionStore.Set(sessionID, targetURL) 605 | r.Logger.Debug("Stored session mapping: %s -> %s", sessionID, targetURL) 606 | } 607 | 608 | // Copy response headers 609 | for name, values := range resp.Header { 610 | for _, value := range values { 611 | w.Header().Add(name, value) 612 | } 613 | } 614 | 615 | // Copy response status code 616 | w.WriteHeader(resp.StatusCode) 617 | 618 | // Copy response body 619 | respBody, err := io.ReadAll(resp.Body) 620 | if err != nil { 621 | r.Logger.Error("Error reading response body: %v", err) 622 | // We've already written headers, so we can't send an HTTP error 623 | return 624 | } 625 | w.Write(respBody) 626 | 627 | // Record stats 628 | duration := time.Since(startTime) 629 | isError = resp.StatusCode >= 400 630 | r.UI.Stats.RecordRequest(method, targetURL, duration, isError) 631 | } 632 | 633 | // determineTargetForSession returns the target URL for a given session ID 634 | func (r *Router) determineTargetForSession(sessionID string) string { 635 | // If no session ID is provided, return the default target 636 | if sessionID == "" { 637 | r.Logger.Debug("No session ID provided, using default target") 638 | return r.Config.GetDefault() 639 | } 640 | 641 | // Check if we have this session ID in our session store 642 | if target, exists := r.SessionStore.Get(sessionID); exists { 643 | r.Logger.Debug("Found existing session mapping: %s -> %s", sessionID, target) 644 | return target 645 | } 646 | 647 | // If the session ID is not recognized, we have a few options: 648 | // 1. Return the default target 649 | // 2. Use a consistent hashing algorithm to map the session ID to a target 650 | // 3. 
Use a round-robin or load-balancing approach 651 | 652 | // For now, we'll use a simple approach - hash the session ID to consistently 653 | // map it to one of our configured targets 654 | targets := r.Config.GetAllTargets() 655 | if len(targets) == 0 { 656 | return r.Config.GetDefault() 657 | } 658 | 659 | // Use a hash of the session ID to pick a target 660 | h := fnv.New32a() 661 | h.Write([]byte(sessionID)) 662 | index := int(h.Sum32()) % len(targets) 663 | target := targets[index] 664 | 665 | r.Logger.Debug("Created new session mapping via hashing: %s -> %s", sessionID, target) 666 | // Store this mapping for future reference 667 | r.SessionStore.Set(sessionID, target) 668 | 669 | return target 670 | } 671 | 672 | // RouteByContext determines the target URL based on the request method and parameters 673 | func (r *Router) RouteByContext(req JSONRPCRequest) string { 674 | switch req.Method { 675 | case "resources/read": 676 | uri, ok := req.Params["uri"].(string) 677 | if ok { 678 | r.Logger.Debug("Routing resources/read for URI: %s", uri) 679 | for _, rule := range r.Config.GetResourceRegexes() { 680 | if rule.Pattern.MatchString(uri) { 681 | r.Logger.Debug("Matched resource rule, using target: %s", rule.Target) 682 | return rule.Target 683 | } 684 | } 685 | } 686 | case "tools/call": 687 | name, ok := req.Params["name"].(string) 688 | if ok { 689 | r.Logger.Debug("Routing tools/call for tool: %s", name) 690 | for _, rule := range r.Config.GetToolRegexes() { 691 | if rule.Pattern.MatchString(name) { 692 | r.Logger.Debug("Matched tool rule, using target: %s", rule.Target) 693 | return rule.Target 694 | } 695 | } 696 | } 697 | } 698 | r.Logger.Debug("No specific routing rule matched, using default target") 699 | return r.Config.GetDefault() 700 | } 701 | 702 | // transformToolCall applies tool name transformation if needed 703 | func (r *Router) transformToolCall(req JSONRPCRequest, targetURL string) { 704 | // Access fields directly 705 | if req.Method != 
"tools/call" { 706 | return 707 | } 708 | 709 | // Check if name exists in params 710 | name, ok := req.Params["name"].(string) 711 | if !ok { 712 | return 713 | } 714 | 715 | // Rest of function remains the same 716 | for _, mapping := range r.Config.GetToolMappings() { 717 | if mapping.OriginalName == name && mapping.Target == targetURL { 718 | // Transform the tool name 719 | req.Params["name"] = mapping.TargetName 720 | return 721 | } 722 | } 723 | } 724 | --------------------------------------------------------------------------------