├── .dockerignore ├── .github └── workflows │ ├── ci.yml │ └── publish.yml ├── .gitignore ├── Dockerfile ├── README.md ├── cmd └── openapi-mcp │ └── main.go ├── example ├── agent_demo.png ├── docker-compose.yml └── weather │ ├── .env.example │ └── weatherbitio-swagger.json ├── go.mod ├── go.sum ├── openapi-mcp.png └── pkg ├── config ├── config.go └── config_test.go ├── mcp └── types.go ├── parser ├── parser.go └── parser_test.go └── server ├── manager.go ├── manager_test.go ├── server.go └── server_test.go /.dockerignore: -------------------------------------------------------------------------------- 1 | # Git files 2 | .git 3 | .gitignore 4 | 5 | # Docker files 6 | .dockerignore 7 | Dockerfile 8 | 9 | # Documentation 10 | *.md 11 | 12 | # Environment files (except example) 13 | .env 14 | *.env 15 | !.env.example 16 | 17 | # Go cache and modules (handled in multi-stage build) 18 | vendor/ 19 | 20 | # Local build artifacts 21 | openapi-mcp 22 | *.exe 23 | *.test 24 | *.out 25 | 26 | # OS generated files 27 | .DS_Store 28 | *~ -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | jobs: 10 | test: 11 | name: Test 12 | environment: CI 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v4 16 | 17 | - name: Set up Go 18 | uses: actions/setup-go@v5 19 | with: 20 | go-version: '1.21' 21 | cache: true 22 | 23 | - name: Install dependencies 24 | run: go mod tidy 25 | working-directory: . 26 | 27 | - name: Run tests with coverage 28 | run: go test -race -coverprofile=coverage.txt -covermode=atomic ./... 29 | working-directory: . 30 | 31 | - name: Upload coverage to Codecov 32 | uses: codecov/codecov-action@v5 33 | with: 34 | token: ${{ secrets.CODECOV_TOKEN }} 35 | slug: ckanthony/openapi-mcp -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish Docker image 2 | 3 | on: 4 | push: 5 | tags: 6 | - 'v*.*.*' # Trigger on version tags like v1.0.0 7 | 8 | jobs: 9 | push_to_registry: 10 | name: Build and push Docker image to Docker Hub 11 | environment: CI 12 | runs-on: ubuntu-latest 13 | steps: 14 | - name: Check out the repo 15 | uses: actions/checkout@v4 16 | 17 | - name: Log in to Docker Hub 18 | uses: docker/login-action@v3 19 | with: 20 | username: ckanthony 21 | password: ${{ secrets.DOCKERHUB_TOKEN }} 22 | 23 | - name: Set up Docker Buildx 24 | uses: docker/setup-buildx-action@v3 25 | 26 | - name: Extract metadata (tags, labels) for Docker 27 | id: meta 28 | uses: docker/metadata-action@v5 29 | with: 30 | images: ckanthony/openapi-mcp 31 | # Add git tag as Docker tag 32 | tags: | 33 | type=semver,pattern={{version}} 34 | type=semver,pattern={{major}}.{{minor}} 35 | type=raw,value=latest,enable={{is_default_branch}} 36 | 37 | - name: Build and push Docker image 38 | uses: docker/build-push-action@v5 39 | with: 40 | context: . 41 | push: true 42 | tags: ${{ steps.meta.outputs.tags }} 43 | labels: ${{ steps.meta.outputs.labels }} 44 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Binaries for programs and plugins 2 | *.exe 3 | *.exe~ 4 | *.dll 5 | *.so 6 | *.dylib 7 | 8 | # Test binary, built with `go test -c` 9 | *.test 10 | 11 | # Output of the go coverage tool, specifically when used with LiteIDE 12 | *.out 13 | 14 | # Dependency directories (remove the comment below to include it) 15 | # vendor/ 16 | 17 | # Go workspace file 18 | go.work 19 | go.work.sum 20 | 21 | # Environment configuration files 22 | .env 23 | *.env 24 | !.env.example -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # --- Build Stage --- 2 | ARG GO_VERSION=1.22 3 | FROM golang:${GO_VERSION}-alpine AS builder 4 | 5 | WORKDIR /app 6 | 7 | # Copy Go modules and download dependencies first 8 | # This layer is cached unless go.mod or go.sum changes 9 | COPY go.mod go.sum ./ 10 | RUN go mod download 11 | 12 | # Copy the rest of the application source code 13 | COPY . . 14 | 15 | # Build the static binary for the command-line tool 16 | # CGO_ENABLED=0 produces a static binary, important for distroless/scratch images 17 | # -ldflags="-s -w" strips debug symbols and DWARF info, reducing binary size 18 | RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -ldflags="-s -w" -o /openapi-mcp ./cmd/openapi-mcp/main.go 19 | 20 | # --- Final Stage --- 21 | # Use a minimal base image. distroless/static is very small and secure. 22 | # alpine is another good option if you need a shell for debugging. 23 | # FROM alpine:latest 24 | FROM gcr.io/distroless/static-debian12 AS final 25 | 26 | # Copy the static binary from the builder stage 27 | COPY --from=builder /openapi-mcp /openapi-mcp 28 | 29 | # Copy example files (optional, but useful for demonstrating) 30 | COPY example /app/example 31 | 32 | WORKDIR /app 33 | 34 | # Define the default command to run when the container starts 35 | # Users can override this command or provide arguments like --spec, --port etc. 36 | ENTRYPOINT ["/openapi-mcp"] 37 | 38 | # Expose the default port (optional, good documentation) 39 | EXPOSE 8080 -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # OpenAPI-MCP: Dockerized MCP Server to allow your AI agent to access any API with existing api docs 2 | 3 | [![Go Reference](https://pkg.go.dev/badge/github.com/ckanthony/openapi-mcp.svg)](https://pkg.go.dev/github.com/ckanthony/openapi-mcp) 4 | [![CI](https://github.com/ckanthony/openapi-mcp/actions/workflows/ci.yml/badge.svg)](https://github.com/ckanthony/openapi-mcp/actions/workflows/ci.yml) 5 | [![codecov](https://codecov.io/gh/ckanthony/openapi-mcp/branch/main/graph/badge.svg)](https://codecov.io/gh/ckanthony/openapi-mcp) 6 | ![](https://badge.mcpx.dev?type=dev 'MCP Dev') 7 | 8 | ![openapi-mcp logo](openapi-mcp.png) 9 | 10 | **Generate MCP tool definitions directly from a Swagger/OpenAPI specification file.** 11 | 12 | OpenAPI-MCP is a dockerized MCP server that reads a `swagger.json` or `openapi.yaml` file and generates a corresponding [Model Context Protocol (MCP)](https://modelcontextprotocol.io/introduction) toolset. This allows MCP-compatible clients like [Cursor](https://cursor.sh/) to interact with APIs described by standard OpenAPI specifications. Now you can enable your AI agent to access any API by simply providing its OpenAPI/Swagger specification - no additional coding required. 13 | 14 | ## Table of Contents 15 | 16 | - [Why OpenAPI-MCP?](#why-openapi-mcp) 17 | - [Features](#features) 18 | - [Installation](#installation) 19 | - [Using the Pre-built Docker Hub Image (Recommended)](#using-the-pre-built-docker-hub-image-recommended) 20 | - [Building Locally (Optional)](#building-locally-optional) 21 | - [Running the Weatherbit Example (Step-by-Step)](#running-the-weatherbit-example-step-by-step) 22 | - [Command-Line Options](#command-line-options) 23 | - [Environment Variables](#environment-variables) 24 | 25 | ## Demo 26 | 27 | Run the demo yourself: [Running the Weatherbit Example (Step-by-Step)](#running-the-weatherbit-example-step-by-step) 28 | 29 | ![demo](https://github.com/user-attachments/assets/4d457137-5da4-422a-b323-afd4b175bd56) 30 | 31 | ## Why OpenAPI-MCP? 32 | 33 | - **Standard Compliance:** Leverage your existing OpenAPI/Swagger documentation. 34 | - **Automatic Tool Generation:** Create MCP tools without manual configuration for each endpoint. 35 | - **Flexible API Key Handling:** Securely manage API key authentication for the proxied API without exposing keys to the MCP client. 36 | - **Local & Remote Specs:** Works with local specification files or remote URLs. 37 | - **Dockerized Tool:** Easily deploy and run as a containerized service with Docker. 38 | 39 | ## Features 40 | 41 | - **OpenAPI v2 (Swagger) & v3 Support:** Parses standard specification formats. 42 | - **Schema Generation:** Creates MCP tool schemas from OpenAPI operation parameters and request/response definitions. 43 | - **Secure API Key Management:** 44 | - Injects API keys into requests (`header`, `query`, `path`, `cookie`) based on command-line configuration. 45 | - Loads API keys directly from flags (`--api-key`), environment variables (`--api-key-env`), or `.env` files located alongside local specs. 46 | - Keeps API keys hidden from the end MCP client (e.g., the AI assistant). 47 | - **Server URL Detection:** Uses server URLs from the spec as the base for tool interactions (can be overridden). 48 | - **Filtering:** Options to include/exclude specific operations or tags (`--include-tag`, `--exclude-tag`, `--include-op`, `--exclude-op`). 49 | - **Request Header Injection:** Pass custom headers (e.g., for additional auth, tracing) via the `REQUEST_HEADERS` environment variable. 50 | 51 | ## Installation 52 | 53 | ### Docker 54 | 55 | The recommended way to run this tool is via [Docker](https://hub.docker.com/r/ckanthony/openapi-mcp). 56 | 57 | #### Using the Pre-built Docker Hub Image (Recommended) 58 | 59 | Alternatively, you can use the pre-built image available on [Docker Hub](https://hub.docker.com/r/ckanthony/openapi-mcp). 60 | 61 | 1. **Pull the Image:** 62 | ```bash 63 | docker pull ckanthony/openapi-mcp:latest 64 | ``` 65 | 2. **Run the Container:** 66 | Follow the `docker run` examples above, but replace `openapi-mcp:latest` with `ckanthony/openapi-mcp:latest`. 67 | 68 | #### Building Locally (Optional) 69 | 70 | 1. **Build the Docker Image Locally:** 71 | ```bash 72 | # Navigate to the repository root 73 | cd openapi-mcp 74 | # Build the Docker image (tag it as you like, e.g., openapi-mcp:latest) 75 | docker build -t openapi-mcp:latest . 76 | ``` 77 | 78 | 2. **Run the Container:** 79 | You need to provide the OpenAPI specification and any necessary API key configuration when running the container. 80 | 81 | * **Example 1: Using a local spec file and `.env` file:** 82 | - Create a directory (e.g., `./my-api`) containing your `openapi.json` or `swagger.yaml`. 83 | - If the API requires a key, create a `.env` file in the *same directory* (e.g., `./my-api/.env`) with `API_KEY=your_actual_key` (replace `API_KEY` if your `--api-key-env` flag is different). 84 | ```bash 85 | docker run -p 8080:8080 --rm \\ 86 | -v $(pwd)/my-api:/app/spec \\ 87 | --env-file $(pwd)/my-api/.env \\ 88 | openapi-mcp:latest \\ 89 | --spec /app/spec/openapi.json \\ 90 | --api-key-env API_KEY \\ 91 | --api-key-name X-API-Key \\ 92 | --api-key-loc header 93 | ``` 94 | *(Adjust `--spec`, `--api-key-env`, `--api-key-name`, `--api-key-loc`, and `-p` as needed.)* 95 | 96 | * **Example 2: Using a remote spec URL and direct environment variable:** 97 | ```bash 98 | docker run -p 8080:8080 --rm \\ 99 | -e SOME_API_KEY="your_actual_key" \\ 100 | openapi-mcp:latest \\ 101 | --spec https://petstore.swagger.io/v2/swagger.json \\ 102 | --api-key-env SOME_API_KEY \\ 103 | --api-key-name api_key \\ 104 | --api-key-loc header 105 | ``` 106 | 107 | * **Key Docker Run Options:** 108 | * `-p :8080`: Map a port on your host to the container's default port 8080. 109 | * `--rm`: Automatically remove the container when it exits. 110 | * `-v :`: Mount a local directory containing your spec into the container. Use absolute paths or `$(pwd)/...`. Common container path: `/app/spec`. 111 | * `--env-file `: Load environment variables from a local file (for API keys, etc.). Path is on the host. 112 | * `-e =""`: Pass a single environment variable directly. 113 | * `openapi-mcp:latest`: The name of the image you built locally. 114 | * `--spec ...`: **Required.** Path to the spec file *inside the container* (e.g., `/app/spec/openapi.json`) or a public URL. 115 | * `--port 8080`: (Optional) Change the internal port the server listens on (must match the container port in `-p`). 116 | * `--api-key-env`, `--api-key-name`, `--api-key-loc`: Required if the target API needs an API key. 117 | * (See `--help` for all command-line options by running `docker run --rm openapi-mcp:latest --help`) 118 | 119 | 120 | ## Running the Weatherbit Example (Step-by-Step) 121 | 122 | This repository includes an example using the [Weatherbit API](https://www.weatherbit.io/). Here's how to run it using the public Docker image: 123 | 124 | 1. **Find OpenAPI Specs (Optional Knowledge):** 125 | Many public APIs have their OpenAPI/Swagger specifications available online. A great resource for discovering them is [APIs.guru](https://apis.guru/). The Weatherbit specification used in this example (`weatherbitio-swagger.json`) was sourced from there. 126 | 127 | 2. **Get a Weatherbit API Key:** 128 | * Go to [Weatherbit.io](https://www.weatherbit.io/) and sign up for an account (they offer a free tier). 129 | * Find your API key in your Weatherbit account dashboard. 130 | 131 | 3. **Clone this Repository:** 132 | You need the example files from this repository. 133 | ```bash 134 | git clone https://github.com/ckanthony/openapi-mcp.git 135 | cd openapi-mcp 136 | ``` 137 | 138 | 4. **Prepare Environment File:** 139 | * Navigate to the example directory: `cd example/weather` 140 | * Copy the example environment file: `cp .env.example .env` 141 | * Edit the new `.env` file and replace `YOUR_WEATHERBIT_API_KEY_HERE` with the actual API key you obtained from Weatherbit. 142 | 143 | 5. **Run the Docker Container:** 144 | From the `openapi-mcp` **root directory** (the one containing the `example` folder), run the following command: 145 | ```bash 146 | docker run -p 8080:8080 --rm \\ 147 | -v $(pwd)/example/weather:/app/spec \\ 148 | --env-file $(pwd)/example/weather/.env \\ 149 | ckanthony/openapi-mcp:latest \\ 150 | --spec /app/spec/weatherbitio-swagger.json \\ 151 | --api-key-env API_KEY \\ 152 | --api-key-name key \\ 153 | --api-key-loc query 154 | ``` 155 | * `-v $(pwd)/example/weather:/app/spec`: Mounts the local `example/weather` directory (containing the spec and `.env` file) to `/app/spec` inside the container. 156 | * `--env-file $(pwd)/example/weather/.env`: Tells Docker to load environment variables (specifically `API_KEY`) from your `.env` file. 157 | * `ckanthony/openapi-mcp:latest`: Uses the public Docker image. 158 | * `--spec /app/spec/weatherbitio-swagger.json`: Points to the spec file inside the container. 159 | * The `--api-key-*` flags configure how the tool should inject the API key (read from the `API_KEY` env var, named `key`, placed in the `query` string). 160 | 161 | 6. **Access the MCP Server:** 162 | The MCP server should now be running and accessible at `http://localhost:8080` for compatible clients. 163 | 164 | **Using Docker Compose (Example):** 165 | 166 | A `docker-compose.yml` file is provided in the `example/` directory to demonstrate running the Weatherbit API example using the *locally built* image. 167 | 168 | 1. **Prepare Environment File:** Copy `example/weather/.env.example` to `example/weather/.env` and add your actual Weatherbit API key: 169 | ```dotenv 170 | # example/weather/.env 171 | API_KEY=YOUR_ACTUAL_WEATHERBIT_KEY 172 | ``` 173 | 174 | 2. **Run with Docker Compose:** Navigate to the `example` directory and run: 175 | ```bash 176 | cd example 177 | # This builds the image locally based on ../Dockerfile 178 | # It does NOT use the public Docker Hub image 179 | docker-compose up --build 180 | ``` 181 | * `--build`: Forces Docker Compose to build the image using the `Dockerfile` in the project root before starting the service. 182 | * Compose will read `example/docker-compose.yml`, build the image, mount `./weather`, read `./weather/.env`, and start the `openapi-mcp` container with the specified command-line arguments. 183 | * The MCP server will be available at `http://localhost:8080`. 184 | 185 | 3. **Stop the service:** Press `Ctrl+C` in the terminal where Compose is running, or run `docker-compose down` from the `example` directory in another terminal. 186 | 187 | ## Command-Line Options 188 | 189 | The `openapi-mcp` command accepts the following flags: 190 | 191 | | Flag | Description | Type | Default | 192 | |----------------------|---------------------------------------------------------------------------------------------------------------------|---------------|----------------------------------| 193 | | `--spec` | **Required.** Path or URL to the OpenAPI specification file. | `string` | (none) | 194 | | `--port` | Port to run the MCP server on. | `int` | `8080` | 195 | | `--api-key` | Direct API key value (use `--api-key-env` or `.env` file instead for security). | `string` | (none) | 196 | | `--api-key-env` | Environment variable name containing the API key. If spec is local, also checks `.env` file in the spec's directory. | `string` | (none) | 197 | | `--api-key-name` | **Required if key used.** Name of the API key parameter (header, query, path, or cookie name). | `string` | (none) | 198 | | `--api-key-loc` | **Required if key used.** Location of API key: `header`, `query`, `path`, or `cookie`. | `string` | (none) | 199 | | `--include-tag` | Tag to include (can be repeated). If include flags are used, only included items are exposed. | `string slice`| (none) | 200 | | `--exclude-tag` | Tag to exclude (can be repeated). Exclusions apply after inclusions. | `string slice`| (none) | 201 | | `--include-op` | Operation ID to include (can be repeated). | `string slice`| (none) | 202 | | `--exclude-op` | Operation ID to exclude (can be repeated). | `string slice`| (none) | 203 | | `--base-url` | Manually override the target API server base URL detected from the spec. | `string` | (none) | 204 | | `--name` | Default name for the generated MCP toolset (used if spec has no title). | `string` | "OpenAPI-MCP Tools" | 205 | | `--desc` | Default description for the generated MCP toolset (used if spec has no description). | `string` | "Tools generated from OpenAPI spec" | 206 | 207 | **Note:** You can get this list by running the tool with the `--help` flag (e.g., `docker run --rm ckanthony/openapi-mcp:latest --help`). 208 | 209 | ### Environment Variables 210 | 211 | * `REQUEST_HEADERS`: Set this environment variable to a JSON string (e.g., `'{"X-Custom": "Value"}'`) to add custom headers to *all* outgoing requests to the target API. 212 | -------------------------------------------------------------------------------- /cmd/openapi-mcp/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "fmt" 6 | "log" 7 | "os" 8 | "path/filepath" 9 | "strings" 10 | 11 | "github.com/ckanthony/openapi-mcp/pkg/config" 12 | "github.com/ckanthony/openapi-mcp/pkg/parser" 13 | "github.com/ckanthony/openapi-mcp/pkg/server" 14 | "github.com/joho/godotenv" 15 | ) 16 | 17 | // stringSliceFlag allows defining a flag that can be repeated to collect multiple string values. 18 | type stringSliceFlag []string 19 | 20 | func (i *stringSliceFlag) String() string { 21 | return strings.Join(*i, ", ") 22 | } 23 | 24 | func (i *stringSliceFlag) Set(value string) error { 25 | *i = append(*i, value) 26 | return nil 27 | } 28 | 29 | func main() { 30 | // --- Flag Definitions First --- 31 | // Define specPath early so we can use it for .env loading 32 | specPath := flag.String("spec", "", "Path or URL to the OpenAPI specification file (required)") 33 | port := flag.Int("port", 8080, "Port to run the MCP server on") 34 | 35 | apiKey := flag.String("api-key", "", "Direct API key value") 36 | apiKeyEnv := flag.String("api-key-env", "", "Environment variable name containing the API key") 37 | apiKeyName := flag.String("api-key-name", "", "Name of the API key header, query parameter, path parameter, or cookie (required if api-key or api-key-env is set)") 38 | apiKeyLocStr := flag.String("api-key-loc", "", "Location of API key: 'header', 'query', 'path', or 'cookie' (required if api-key or api-key-env is set)") 39 | 40 | var includeTags stringSliceFlag 41 | flag.Var(&includeTags, "include-tag", "Tag to include (can be repeated)") 42 | var excludeTags stringSliceFlag 43 | flag.Var(&excludeTags, "exclude-tag", "Tag to exclude (can be repeated)") 44 | var includeOps stringSliceFlag 45 | flag.Var(&includeOps, "include-op", "Operation ID to include (can be repeated)") 46 | var excludeOps stringSliceFlag 47 | flag.Var(&excludeOps, "exclude-op", "Operation ID to exclude (can be repeated)") 48 | 49 | serverBaseURL := flag.String("base-url", "", "Manually override the server base URL") 50 | defaultToolName := flag.String("name", "OpenAPI-MCP Tools", "Default name for the toolset") 51 | defaultToolDesc := flag.String("desc", "Tools generated from OpenAPI spec", "Default description for the toolset") 52 | 53 | // Parse flags *after* defining them all 54 | flag.Parse() 55 | 56 | // --- Load .env after parsing flags --- 57 | if *specPath != "" && !strings.HasPrefix(*specPath, "http://") && !strings.HasPrefix(*specPath, "https://") { 58 | envPath := filepath.Join(filepath.Dir(*specPath), ".env") 59 | log.Printf("Attempting to load .env file from spec directory: %s", envPath) 60 | err := godotenv.Load(envPath) 61 | if err != nil { 62 | // It's okay if the file doesn't exist, log other errors. 63 | if !os.IsNotExist(err) { 64 | log.Printf("Warning: Error loading .env file from %s: %v", envPath, err) 65 | } else { 66 | log.Printf("Info: No .env file found at %s, proceeding without it.", envPath) 67 | } 68 | } else { 69 | log.Printf("Successfully loaded .env file from %s", envPath) 70 | } 71 | } else if *specPath == "" { 72 | log.Println("Skipping .env load because --spec is missing.") 73 | } else { 74 | log.Println("Skipping .env load because spec path appears to be a URL.") 75 | } 76 | 77 | // --- Read REQUEST_HEADERS env var --- 78 | customHeadersEnv := os.Getenv("REQUEST_HEADERS") 79 | if customHeadersEnv != "" { 80 | log.Printf("Found REQUEST_HEADERS environment variable: %s", customHeadersEnv) 81 | } 82 | 83 | // --- Input Validation --- 84 | if *specPath == "" { 85 | log.Println("Error: --spec flag is required.") 86 | flag.Usage() 87 | os.Exit(1) 88 | } 89 | 90 | var apiKeyLocation config.APIKeyLocation 91 | if *apiKeyLocStr != "" { 92 | switch *apiKeyLocStr { 93 | case string(config.APIKeyLocationHeader): 94 | apiKeyLocation = config.APIKeyLocationHeader 95 | case string(config.APIKeyLocationQuery): 96 | apiKeyLocation = config.APIKeyLocationQuery 97 | case string(config.APIKeyLocationPath): 98 | apiKeyLocation = config.APIKeyLocationPath 99 | case string(config.APIKeyLocationCookie): 100 | apiKeyLocation = config.APIKeyLocationCookie 101 | default: 102 | log.Fatalf("Error: invalid --api-key-loc value: %s. Must be 'header', 'query', 'path', or 'cookie'.", *apiKeyLocStr) 103 | } 104 | } 105 | 106 | // --- Configuration Population --- 107 | cfg := &config.Config{ 108 | SpecPath: *specPath, 109 | APIKey: *apiKey, 110 | APIKeyFromEnvVar: *apiKeyEnv, 111 | APIKeyName: *apiKeyName, 112 | APIKeyLocation: apiKeyLocation, 113 | IncludeTags: includeTags, 114 | ExcludeTags: excludeTags, 115 | IncludeOperations: includeOps, 116 | ExcludeOperations: excludeOps, 117 | ServerBaseURL: *serverBaseURL, 118 | DefaultToolName: *defaultToolName, 119 | DefaultToolDesc: *defaultToolDesc, 120 | CustomHeaders: customHeadersEnv, 121 | } 122 | 123 | log.Printf("Configuration loaded: %+v\n", cfg) 124 | log.Println("API Key (resolved):", cfg.GetAPIKey()) 125 | 126 | // --- Call Parser --- 127 | specDoc, version, err := parser.LoadSwagger(cfg.SpecPath) 128 | if err != nil { 129 | log.Fatalf("Failed to load OpenAPI/Swagger spec: %v", err) 130 | } 131 | log.Printf("Spec type %s loaded successfully from %s.\n", version, cfg.SpecPath) 132 | 133 | toolSet, err := parser.GenerateToolSet(specDoc, version, cfg) 134 | if err != nil { 135 | log.Fatalf("Failed to generate MCP toolset: %v", err) 136 | } 137 | log.Printf("MCP toolset generated with %d tools.\n", len(toolSet.Tools)) 138 | 139 | // --- Start Server --- 140 | addr := fmt.Sprintf(":%d", *port) 141 | log.Printf("Starting MCP server on %s...", addr) 142 | err = server.ServeMCP(addr, toolSet, cfg) // Pass cfg to ServeMCP 143 | if err != nil { 144 | log.Fatalf("Failed to start server: %v", err) 145 | } 146 | } 147 | -------------------------------------------------------------------------------- /example/agent_demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ckanthony/openapi-mcp/93303275cecaf94a0fc807d07a81c454ff6c1d4e/example/agent_demo.png -------------------------------------------------------------------------------- /example/docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '3.8' # Specifies the Docker Compose file version 2 | 3 | services: 4 | openapi-mcp: 5 | # Build the image using the Dockerfile located in the parent directory 6 | build: 7 | context: .. # The context is the parent directory (project root) 8 | dockerfile: Dockerfile # Explicitly points to the Dockerfile 9 | image: openapi-mcp-example-weather-compose:latest # Optional: Name the image built by compose 10 | container_name: openapi-mcp-example-weather-service # Sets a specific name for the container 11 | 12 | ports: 13 | # Map port 8080 on the host to port 8080 in the container 14 | - "8080:8080" 15 | 16 | volumes: 17 | # Mount the local './weather' directory (relative to this compose file) 18 | # to '/app/example/weather' inside the container. 19 | # This makes the spec file accessible to the application. 20 | - ./weather:/app/example/weather 21 | 22 | # Load environment variables from the .env file located in ./weather 23 | # This is the recommended way to handle secrets like API keys. 24 | # Ensure 'example/weather/.env' exists and defines API_KEY. 25 | env_file: 26 | - ./weather/.env 27 | 28 | # Define the command to run inside the container, overriding the Dockerfile's CMD/ENTRYPOINT args 29 | # Uses the variables loaded from the env_file. 30 | # Make sure the --spec path matches the volume mount point. 31 | command: > 32 | --spec /app/example/weather/weatherbitio-swagger.json 33 | --api-key-env API_KEY 34 | --api-key-name key 35 | --api-key-loc query 36 | --port 8080 # The port the app listens on inside the container 37 | 38 | # Restart policy: Automatically restart the container unless it was manually stopped. 39 | restart: unless-stopped -------------------------------------------------------------------------------- /example/weather/.env.example: -------------------------------------------------------------------------------- 1 | # Example environment variables for the Weatherbit API example. 2 | # Copy this file to .env in the same directory (example/weather/.env) 3 | # and replace placeholders with your actual values. 4 | 5 | # Required: Your Weatherbit API Key 6 | API_KEY=YOUR_WEATHERBIT_API_KEY_HERE 7 | 8 | # Optional: Custom headers (JSON format) 9 | # REQUEST_HEADERS='{"X-Client-ID": "MyTestClient"}' -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/ckanthony/openapi-mcp 2 | 3 | go 1.22.5 4 | 5 | toolchain go1.23.8 6 | 7 | require ( 8 | github.com/getkin/kin-openapi v0.131.0 9 | github.com/go-openapi/loads v0.22.0 10 | github.com/go-openapi/spec v0.21.0 11 | github.com/google/uuid v1.6.0 12 | github.com/joho/godotenv v1.5.1 13 | github.com/stretchr/testify v1.9.0 14 | ) 15 | 16 | require ( 17 | github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 // indirect 18 | github.com/davecgh/go-spew v1.1.1 // indirect 19 | github.com/go-openapi/analysis v0.23.0 // indirect 20 | github.com/go-openapi/errors v0.22.0 // indirect 21 | github.com/go-openapi/jsonpointer v0.21.0 // indirect 22 | github.com/go-openapi/jsonreference v0.21.0 // indirect 23 | github.com/go-openapi/strfmt v0.23.0 // indirect 24 | github.com/go-openapi/swag v0.23.0 // indirect 25 | github.com/josharian/intern v1.0.0 // indirect 26 | github.com/mailru/easyjson v0.7.7 // indirect 27 | github.com/mitchellh/mapstructure v1.5.0 // indirect 28 | github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 // indirect 29 | github.com/oasdiff/yaml v0.0.0-20250309154309-f31be36b4037 // indirect 30 | github.com/oasdiff/yaml3 v0.0.0-20250309153720-d2182401db90 // indirect 31 | github.com/oklog/ulid v1.3.1 // indirect 32 | github.com/perimeterx/marshmallow v1.1.5 // indirect 33 | github.com/pmezard/go-difflib v1.0.0 // indirect 34 | go.mongodb.org/mongo-driver v1.14.0 // indirect 35 | gopkg.in/yaml.v3 v3.0.1 // indirect 36 | ) 37 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 h1:DklsrG3dyBCFEj5IhUbnKptjxatkF07cF2ak3yi77so= 2 | github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= 3 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 4 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 5 | github.com/getkin/kin-openapi v0.131.0 h1:NO2UeHnFKRYhZ8wg6Nyh5Cq7dHk4suQQr72a4pMrDxE= 6 | github.com/getkin/kin-openapi v0.131.0/go.mod h1:3OlG51PCYNsPByuiMB0t4fjnNlIDnaEDsjiKUV8nL58= 7 | github.com/go-openapi/analysis v0.23.0 h1:aGday7OWupfMs+LbmLZG4k0MYXIANxcuBTYUC03zFCU= 8 | github.com/go-openapi/analysis v0.23.0/go.mod h1:9mz9ZWaSlV8TvjQHLl2mUW2PbZtemkE8yA5v22ohupo= 9 | github.com/go-openapi/errors v0.22.0 h1:c4xY/OLxUBSTiepAg3j/MHuAv5mJhnf53LLMWFB+u/w= 10 | github.com/go-openapi/errors v0.22.0/go.mod h1:J3DmZScxCDufmIMsdOuDHxJbdOGC0xtUynjIx092vXE= 11 | github.com/go-openapi/jsonpointer v0.21.0 h1:YgdVicSA9vH5RiHs9TZW5oyafXZFc6+2Vc1rr/O9oNQ= 12 | github.com/go-openapi/jsonpointer v0.21.0/go.mod h1:IUyH9l/+uyhIYQ/PXVA41Rexl+kOkAPDdXEYns6fzUY= 13 | github.com/go-openapi/jsonreference v0.21.0 h1:Rs+Y7hSXT83Jacb7kFyjn4ijOuVGSvOdF2+tg1TRrwQ= 14 | github.com/go-openapi/jsonreference v0.21.0/go.mod h1:LmZmgsrTkVg9LG4EaHeY8cBDslNPMo06cago5JNLkm4= 15 | github.com/go-openapi/loads v0.22.0 h1:ECPGd4jX1U6NApCGG1We+uEozOAvXvJSF4nnwHZ8Aco= 16 | github.com/go-openapi/loads v0.22.0/go.mod h1:yLsaTCS92mnSAZX5WWoxszLj0u+Ojl+Zs5Stn1oF+rs= 17 | github.com/go-openapi/spec v0.21.0 h1:LTVzPc3p/RzRnkQqLRndbAzjY0d0BCL72A6j3CdL9ZY= 18 | github.com/go-openapi/spec v0.21.0/go.mod h1:78u6VdPw81XU44qEWGhtr982gJ5BWg2c0I5XwVMotYk= 19 | github.com/go-openapi/strfmt v0.23.0 h1:nlUS6BCqcnAk0pyhi9Y+kdDVZdZMHfEKQiS4HaMgO/c= 20 | github.com/go-openapi/strfmt v0.23.0/go.mod h1:NrtIpfKtWIygRkKVsxh7XQMDQW5HKQl6S5ik2elW+K4= 21 | github.com/go-openapi/swag v0.23.0 h1:vsEVJDUo2hPJ2tu0/Xc+4noaxyEffXNIs3cOULZ+GrE= 22 | github.com/go-openapi/swag v0.23.0/go.mod h1:esZ8ITTYEsH1V2trKHjAN8Ai7xHb8RV+YSZ577vPjgQ= 23 | github.com/go-test/deep v1.0.8 h1:TDsG77qcSprGbC6vTN8OuXp5g+J+b5Pcguhf7Zt61VM= 24 | github.com/go-test/deep v1.0.8/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE= 25 | github.com/google/go-cmp v0.5.2 h1:X2ev0eStA3AbceY54o37/0PQ/UWqKEiiO2dKL5OPaFM= 26 | github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= 27 | github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= 28 | github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= 29 | github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= 30 | github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= 31 | github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= 32 | github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= 33 | github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= 34 | github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= 35 | github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= 36 | github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= 37 | github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= 38 | github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= 39 | github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= 40 | github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= 41 | github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 h1:RWengNIwukTxcDr9M+97sNutRR1RKhG96O6jWumTTnw= 42 | github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826/go.mod h1:TaXosZuwdSHYgviHp1DAtfrULt5eUgsSMsZf+YrPgl8= 43 | github.com/oasdiff/yaml v0.0.0-20250309154309-f31be36b4037 h1:G7ERwszslrBzRxj//JalHPu/3yz+De2J+4aLtSRlHiY= 44 | github.com/oasdiff/yaml v0.0.0-20250309154309-f31be36b4037/go.mod h1:2bpvgLBZEtENV5scfDFEtB/5+1M4hkQhDQrccEJ/qGw= 45 | github.com/oasdiff/yaml3 v0.0.0-20250309153720-d2182401db90 h1:bQx3WeLcUWy+RletIKwUIt4x3t8n2SxavmoclizMb8c= 46 | github.com/oasdiff/yaml3 v0.0.0-20250309153720-d2182401db90/go.mod h1:y5+oSEHCPT/DGrS++Wc/479ERge0zTFxaF8PbGKcg2o= 47 | github.com/oklog/ulid v1.3.1 h1:EGfNDEx6MqHz8B3uNV6QAib1UR2Lm97sHi3ocA6ESJ4= 48 | github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U= 49 | github.com/perimeterx/marshmallow v1.1.5 h1:a2LALqQ1BlHM8PZblsDdidgv1mWi1DgC2UmX50IvK2s= 50 | github.com/perimeterx/marshmallow v1.1.5/go.mod h1:dsXbUu8CRzfYP5a87xpp0xq9S3u0Vchtcl8we9tYaXw= 51 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 52 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 53 | github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= 54 | github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= 55 | github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= 56 | github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= 57 | github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0= 58 | github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY= 59 | go.mongodb.org/mongo-driver v1.14.0 h1:P98w8egYRjYe3XDjxhYJagTokP/H6HzlsnojRgZRd80= 60 | go.mongodb.org/mongo-driver v1.14.0/go.mod h1:Vzb0Mk/pa7e6cWw85R4F/endUC3u0U9jGcNU603k65c= 61 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 62 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= 63 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= 64 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 65 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 66 | -------------------------------------------------------------------------------- /openapi-mcp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ckanthony/openapi-mcp/93303275cecaf94a0fc807d07a81c454ff6c1d4e/openapi-mcp.png -------------------------------------------------------------------------------- /pkg/config/config.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "log" 5 | "os" 6 | ) 7 | 8 | // APIKeyLocation specifies where the API key is located for requests. 9 | type APIKeyLocation string 10 | 11 | const ( 12 | APIKeyLocationHeader APIKeyLocation = "header" 13 | APIKeyLocationQuery APIKeyLocation = "query" 14 | APIKeyLocationPath APIKeyLocation = "path" 15 | APIKeyLocationCookie APIKeyLocation = "cookie" 16 | // APIKeyLocationCookie APIKeyLocation = "cookie" // Add if needed 17 | ) 18 | 19 | // Config holds the configuration for generating the MCP toolset. 20 | type Config struct { 21 | SpecPath string // Path or URL to the OpenAPI specification file. 22 | 23 | // API Key details (optional, inferred from spec if possible) 24 | APIKey string // The actual API key value. 25 | APIKeyName string // Name of the header or query parameter for the API key (e.g., "X-API-Key", "api_key"). 26 | APIKeyLocation APIKeyLocation // Where the API key should be placed (header, query, path, or cookie). 27 | APIKeyFromEnvVar string // Environment variable name to read the API key from. 28 | 29 | // Filtering (optional) 30 | IncludeTags []string // Only include operations with these tags. 31 | ExcludeTags []string // Exclude operations with these tags. 32 | IncludeOperations []string // Only include operations with these IDs. 33 | ExcludeOperations []string // Exclude operations with these IDs. 34 | 35 | // Overrides (optional) 36 | ServerBaseURL string // Manually override the base URL for API calls, ignoring the spec's servers field. 37 | DefaultToolName string // Name for the toolset if not specified in the spec's info section. 38 | DefaultToolDesc string // Description for the toolset if not specified in the spec's info section. 39 | 40 | // Server-side request modification 41 | CustomHeaders string // Comma-separated list of headers (e.g., "Header1:Value1,Header2:Value2") to add to outgoing requests. 42 | } 43 | 44 | // GetAPIKey resolves the API key value, prioritizing the environment variable over the direct flag. 45 | func (c *Config) GetAPIKey() string { 46 | log.Println("GetAPIKey: Attempting to resolve API key...") 47 | 48 | // 1. Check environment variable specified by --api-key-env 49 | if c.APIKeyFromEnvVar != "" { 50 | log.Printf("GetAPIKey: Checking environment variable specified by --api-key-env: %s", c.APIKeyFromEnvVar) 51 | val := os.Getenv(c.APIKeyFromEnvVar) 52 | if val != "" { 53 | log.Printf("GetAPIKey: Found key in environment variable %s.", c.APIKeyFromEnvVar) 54 | return val 55 | } 56 | log.Printf("GetAPIKey: Environment variable %s not found or empty.", c.APIKeyFromEnvVar) 57 | } else { 58 | log.Println("GetAPIKey: No --api-key-env variable specified.") 59 | } 60 | 61 | // 2. Check direct flag --api-key 62 | if c.APIKey != "" { 63 | log.Println("GetAPIKey: Found key provided directly via --api-key flag.") 64 | return c.APIKey 65 | } 66 | 67 | // 3. No key found 68 | log.Println("GetAPIKey: No API key found from config (env var or direct flag).") 69 | return "" 70 | } 71 | -------------------------------------------------------------------------------- /pkg/config/config_test.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "os" 5 | "testing" 6 | ) 7 | 8 | func TestConfig_GetAPIKey(t *testing.T) { 9 | tests := []struct { 10 | name string 11 | config Config 12 | envKey string // Environment variable name to set 13 | envValue string // Value to set for the env var 14 | expectedKey string 15 | cleanupEnv bool // Flag to indicate if env var needs cleanup 16 | }{ 17 | { 18 | name: "No key set", 19 | config: Config{}, // Empty config 20 | expectedKey: "", 21 | }, 22 | { 23 | name: "Direct key set only", 24 | config: Config{ 25 | APIKey: "direct-key-123", 26 | }, 27 | expectedKey: "direct-key-123", 28 | }, 29 | { 30 | name: "Env var set only", 31 | config: Config{ 32 | APIKeyFromEnvVar: "TEST_API_KEY_ENV_ONLY", 33 | }, 34 | envKey: "TEST_API_KEY_ENV_ONLY", 35 | envValue: "env-key-456", 36 | expectedKey: "env-key-456", 37 | cleanupEnv: true, 38 | }, 39 | { 40 | name: "Both direct and env var set (env takes precedence)", 41 | config: Config{ 42 | APIKey: "direct-key-789", 43 | APIKeyFromEnvVar: "TEST_API_KEY_BOTH", 44 | }, 45 | envKey: "TEST_API_KEY_BOTH", 46 | envValue: "env-key-abc", 47 | expectedKey: "env-key-abc", 48 | cleanupEnv: true, 49 | }, 50 | { 51 | name: "Direct key set, env var specified but not set", 52 | config: Config{ 53 | APIKey: "direct-key-xyz", 54 | APIKeyFromEnvVar: "TEST_API_KEY_UNSET", 55 | }, 56 | envKey: "TEST_API_KEY_UNSET", // Ensure this is not set 57 | envValue: "", 58 | expectedKey: "direct-key-xyz", // Should fall back to direct key 59 | cleanupEnv: true, // Cleanup in case it was set previously 60 | }, 61 | { 62 | name: "Env var specified but empty string value", 63 | config: Config{ 64 | APIKeyFromEnvVar: "TEST_API_KEY_EMPTY", 65 | }, 66 | envKey: "TEST_API_KEY_EMPTY", 67 | envValue: "", // Explicitly set to empty string 68 | expectedKey: "", // Empty env var should result in empty key 69 | cleanupEnv: true, 70 | }, 71 | } 72 | 73 | for _, tc := range tests { 74 | t.Run(tc.name, func(t *testing.T) { 75 | // Set environment variable if needed for this test case 76 | if tc.envKey != "" { 77 | originalValue, wasSet := os.LookupEnv(tc.envKey) 78 | err := os.Setenv(tc.envKey, tc.envValue) 79 | if err != nil { 80 | t.Fatalf("Failed to set environment variable %s: %v", tc.envKey, err) 81 | } 82 | // Schedule cleanup 83 | if tc.cleanupEnv { 84 | t.Cleanup(func() { 85 | if wasSet { 86 | os.Setenv(tc.envKey, originalValue) 87 | } else { 88 | os.Unsetenv(tc.envKey) 89 | } 90 | }) 91 | } 92 | } else { 93 | // Ensure env var is unset if tc.envKey is empty (for tests like "Direct key set only") 94 | // This prevents interference from previous tests if not cleaned up properly. 95 | os.Unsetenv(tc.config.APIKeyFromEnvVar) // Unset based on config field if relevant 96 | } 97 | 98 | // Call the method under test 99 | actualKey := tc.config.GetAPIKey() 100 | 101 | // Assert the result 102 | if actualKey != tc.expectedKey { 103 | t.Errorf("Expected API key %q, but got %q", tc.expectedKey, actualKey) 104 | } 105 | }) 106 | } 107 | } 108 | -------------------------------------------------------------------------------- /pkg/mcp/types.go: -------------------------------------------------------------------------------- 1 | package mcp 2 | 3 | // Based on the MCP specification: https://modelcontextprotocol.io/spec/ 4 | 5 | // ParameterDetail describes a single parameter for an operation. 6 | type ParameterDetail struct { 7 | Name string `json:"name"` 8 | In string `json:"in"` // Location (query, header, path, cookie) 9 | // Add other details if needed, e.g., required, type 10 | } 11 | 12 | // OperationDetail holds the necessary information to execute a specific API operation. 13 | type OperationDetail struct { 14 | Method string `json:"method"` 15 | Path string `json:"path"` // Path template (e.g., /users/{id}) 16 | BaseURL string `json:"baseUrl"` 17 | Parameters []ParameterDetail `json:"parameters,omitempty"` 18 | // Add RequestBody schema if needed 19 | } 20 | 21 | // ToolSet represents the collection of tools provided by an MCP server. 22 | type ToolSet struct { 23 | MCPVersion string `json:"mcp_version"` 24 | Name string `json:"name"` 25 | Description string `json:"description,omitempty"` 26 | // Auth *AuthInfo `json:"auth,omitempty"` // Removed authentication info 27 | Tools []Tool `json:"tools"` 28 | 29 | // Operations maps Tool.Name (operationId) to its execution details. 30 | // This is internal to the server and not part of the standard MCP JSON response. 31 | Operations map[string]OperationDetail `json:"-"` // Use json:"-" to exclude from JSON 32 | 33 | // Internal fields for server-side auth handling (not exposed in JSON) 34 | apiKeyName string // e.g., "key", "X-API-Key" 35 | apiKeyIn string // e.g., "query", "header" 36 | } 37 | 38 | // SetAPIKeyDetails allows the parser to set internal API key info. 39 | func (ts *ToolSet) SetAPIKeyDetails(name, in string) { 40 | ts.apiKeyName = name 41 | ts.apiKeyIn = in 42 | } 43 | 44 | // GetAPIKeyDetails allows the server to retrieve internal API key info. 45 | // We might need this later when making the request. 46 | func (ts *ToolSet) GetAPIKeyDetails() (name, in string) { 47 | return ts.apiKeyName, ts.apiKeyIn 48 | } 49 | 50 | // Tool represents a single function or capability exposed via MCP. 51 | type Tool struct { 52 | Name string `json:"name"` // Corresponds to OpenAPI operationId or generated name 53 | Description string `json:"description,omitempty"` 54 | InputSchema Schema `json:"inputSchema"` // Renamed from Parameters, consolidate parameters/body here 55 | // Entrypoint string `json:"entrypoint"` // Removed for simplicity, schema should contain enough info? 56 | // RequestBody RequestBody `json:"request_body,omitempty"` // Removed, info should be part of InputSchema 57 | // HTTPMethod string `json:"http_method"` // Removed for simplicity 58 | // TODO: Add Response handling if needed by spec/client 59 | } 60 | 61 | // RequestBody describes the expected request body for a tool. 62 | // This might become redundant if all info is in InputSchema. 63 | // Keeping it for now as the parser might still use it internally. 64 | type RequestBody struct { 65 | Description string `json:"description,omitempty"` 66 | Required bool `json:"required,omitempty"` 67 | Content map[string]Schema `json:"content"` // Keyed by media type (e.g., "application/json") 68 | } 69 | 70 | // Schema defines the structure and constraints of data (parameters or request/response bodies). 71 | // This mirrors a subset of JSON Schema properties. 72 | type Schema struct { 73 | Type string `json:"type,omitempty"` // e.g., "object", "string", "integer", "array" 74 | Description string `json:"description,omitempty"` 75 | Properties map[string]Schema `json:"properties,omitempty"` // For type "object" 76 | Required []string `json:"required,omitempty"` // For type "object" 77 | Items *Schema `json:"items,omitempty"` // For type "array" 78 | Format string `json:"format,omitempty"` // e.g., "int32", "date-time" 79 | Enum []interface{} `json:"enum,omitempty"` 80 | // Add other relevant JSON Schema fields as needed (e.g., minimum, maximum, pattern) 81 | } 82 | -------------------------------------------------------------------------------- /pkg/parser/parser.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | "io" 8 | "log" 9 | "net/http" 10 | "net/url" 11 | "os" 12 | "path/filepath" 13 | "sort" 14 | "strings" 15 | 16 | "github.com/ckanthony/openapi-mcp/pkg/config" 17 | "github.com/ckanthony/openapi-mcp/pkg/mcp" 18 | "github.com/getkin/kin-openapi/openapi3" 19 | "github.com/go-openapi/loads" 20 | "github.com/go-openapi/spec" 21 | ) 22 | 23 | const ( 24 | VersionV2 = "v2" 25 | VersionV3 = "v3" 26 | ) 27 | 28 | // LoadSwagger detects the version and loads an OpenAPI/Swagger specification 29 | // from a local file path or a remote URL. 30 | // It returns the loaded spec document (as interface{}), the detected version (string), and an error. 31 | func LoadSwagger(location string) (interface{}, string, error) { 32 | // Determine if location is URL or file path 33 | locationURL, urlErr := url.ParseRequestURI(location) 34 | isURL := urlErr == nil && locationURL != nil && (locationURL.Scheme == "http" || locationURL.Scheme == "https") 35 | 36 | var data []byte 37 | var err error 38 | var absPath string // Store absolute path if it's a file 39 | 40 | if !isURL { 41 | log.Printf("Detected file path location: %s", location) 42 | absPath, err = filepath.Abs(location) 43 | if err != nil { 44 | return nil, "", fmt.Errorf("failed to get absolute path for '%s': %w", location, err) 45 | } 46 | // Read data first for version detection 47 | data, err = os.ReadFile(absPath) 48 | if err != nil { 49 | return nil, "", fmt.Errorf("failed reading file path '%s': %w", absPath, err) 50 | } 51 | } else { 52 | log.Printf("Detected URL location: %s", location) 53 | // Read data first for version detection 54 | resp, err := http.Get(location) 55 | if err != nil { 56 | return nil, "", fmt.Errorf("failed to fetch URL '%s': %w", location, err) 57 | } 58 | defer resp.Body.Close() 59 | if resp.StatusCode != http.StatusOK { 60 | bodyBytes, _ := io.ReadAll(resp.Body) // Attempt to read body for error context 61 | return nil, "", fmt.Errorf("failed to fetch URL '%s': status code %d, body: %s", location, resp.StatusCode, string(bodyBytes)) 62 | } 63 | data, err = io.ReadAll(resp.Body) 64 | if err != nil { 65 | return nil, "", fmt.Errorf("failed to read response body from URL '%s': %w", location, err) 66 | } 67 | } 68 | 69 | // Detect version from data 70 | var detector map[string]interface{} 71 | if err := json.Unmarshal(data, &detector); err != nil { 72 | return nil, "", fmt.Errorf("failed to parse JSON from '%s' for version detection: %w", location, err) 73 | } 74 | 75 | if _, ok := detector["openapi"]; ok { 76 | // OpenAPI 3.x 77 | loader := openapi3.NewLoader() 78 | loader.IsExternalRefsAllowed = true 79 | var doc *openapi3.T 80 | var loadErr error 81 | 82 | if !isURL { 83 | // Use LoadFromFile for local files 84 | log.Printf("Loading V3 spec using LoadFromFile: %s", absPath) 85 | doc, loadErr = loader.LoadFromFile(absPath) 86 | } else { 87 | // Use LoadFromURI for URLs 88 | log.Printf("Loading V3 spec using LoadFromURI: %s", location) 89 | doc, loadErr = loader.LoadFromURI(locationURL) 90 | } 91 | 92 | if loadErr != nil { 93 | return nil, "", fmt.Errorf("failed to load OpenAPI v3 spec from '%s': %w", location, loadErr) 94 | } 95 | 96 | if err := doc.Validate(context.Background()); err != nil { 97 | return nil, "", fmt.Errorf("OpenAPI v3 spec validation failed for '%s': %w", location, err) 98 | } 99 | return doc, VersionV3, nil 100 | } else if _, ok := detector["swagger"]; ok { 101 | // Swagger 2.0 - Still load from data as loads.Analyzed expects bytes 102 | log.Printf("Loading V2 spec using loads.Analyzed from data (source: %s)", location) 103 | doc, err := loads.Analyzed(data, "2.0") 104 | if err != nil { 105 | return nil, "", fmt.Errorf("failed to load or validate Swagger v2 spec from '%s': %w", location, err) 106 | } 107 | return doc.Spec(), VersionV2, nil 108 | } else { 109 | return nil, "", fmt.Errorf("failed to detect OpenAPI/Swagger version in '%s': missing 'openapi' or 'swagger' key", location) 110 | } 111 | } 112 | 113 | // GenerateToolSet converts a loaded spec (v2 or v3) into an MCP ToolSet. 114 | func GenerateToolSet(specDoc interface{}, version string, cfg *config.Config) (*mcp.ToolSet, error) { 115 | switch version { 116 | case VersionV3: 117 | docV3, ok := specDoc.(*openapi3.T) 118 | if !ok { 119 | return nil, fmt.Errorf("internal error: expected *openapi3.T for v3 spec, got %T", specDoc) 120 | } 121 | return generateToolSetV3(docV3, cfg) 122 | case VersionV2: 123 | docV2, ok := specDoc.(*spec.Swagger) 124 | if !ok { 125 | return nil, fmt.Errorf("internal error: expected *spec.Swagger for v2 spec, got %T", specDoc) 126 | } 127 | return generateToolSetV2(docV2, cfg) 128 | default: 129 | return nil, fmt.Errorf("unsupported specification version: %s", version) 130 | } 131 | } 132 | 133 | // --- V3 Specific Implementation --- 134 | 135 | func generateToolSetV3(doc *openapi3.T, cfg *config.Config) (*mcp.ToolSet, error) { 136 | toolSet := createBaseToolSet(doc.Info.Title, doc.Info.Description, cfg) 137 | toolSet.Operations = make(map[string]mcp.OperationDetail) // Initialize the map 138 | 139 | // Determine Base URL once 140 | baseURL, err := determineBaseURLV3(doc, cfg) 141 | if err != nil { 142 | log.Printf("Warning: Could not determine base URL for V3 spec: %v. Operations might fail if base URL override is not set.", err) 143 | baseURL = "" // Allow proceeding if override is set 144 | } 145 | 146 | // // V3 Handles security differently (Components.SecuritySchemes). Rely on config flags for server-side injection. 147 | // apiKeyName := cfg.APIKeyName 148 | // apiKeyIn := string(cfg.APIKeyLocation) 149 | // // Store detected/configured key details internally - Let config handle this 150 | // toolSet.SetAPIKeyDetails(apiKeyName, apiKeyIn) 151 | 152 | paths := getSortedPathsV3(doc.Paths) 153 | for _, rawPath := range paths { // Rename loop var to rawPath 154 | pathItem := doc.Paths.Value(rawPath) 155 | for method, op := range pathItem.Operations() { 156 | if op == nil || !shouldIncludeOperationV3(op, cfg) { 157 | continue 158 | } 159 | 160 | // Clean the path 161 | cleanPath := rawPath 162 | if queryIndex := strings.Index(rawPath, "?"); queryIndex != -1 { 163 | cleanPath = rawPath[:queryIndex] 164 | } 165 | 166 | toolName := generateToolNameV3(op, method, rawPath) // Still generate name from raw path 167 | toolDesc := getOperationDescriptionV3(op) 168 | 169 | // Convert parameters (query, header, path, cookie) 170 | parametersSchema, opParams, err := parametersToMCPSchemaAndDetailsV3(op.Parameters, cfg) 171 | if err != nil { 172 | return nil, fmt.Errorf("error processing v3 parameters for %s %s: %w", method, rawPath, err) 173 | } 174 | 175 | // Handle request body 176 | requestBody, err := requestBodyToMCPV3(op.RequestBody) 177 | if err != nil { 178 | log.Printf("Warning: skipping request body for %s %s due to error: %v", method, rawPath, err) 179 | } else { 180 | // Merge request body schema into the main parameter schema 181 | if requestBody.Content != nil { 182 | if parametersSchema.Properties == nil { 183 | parametersSchema.Properties = make(map[string]mcp.Schema) 184 | } 185 | for _, mediaTypeSchema := range requestBody.Content { 186 | if mediaTypeSchema.Type == "object" && mediaTypeSchema.Properties != nil { 187 | for propName, propSchema := range mediaTypeSchema.Properties { 188 | parametersSchema.Properties[propName] = propSchema 189 | } 190 | } else { 191 | // If body is not an object, represent as 'requestBody' 192 | log.Printf("Warning: V3 request body for %s %s is not an object schema. Representing as 'requestBody' field.", method, rawPath) 193 | parametersSchema.Properties["requestBody"] = mediaTypeSchema 194 | } 195 | break // Only process the first content type 196 | } 197 | 198 | // Merge required fields from the body *schema* (not the requestBody boolean) 199 | var bodySchemaRequired []string 200 | for _, mediaTypeSchema := range requestBody.Content { 201 | if len(mediaTypeSchema.Required) > 0 { 202 | bodySchemaRequired = mediaTypeSchema.Required 203 | break // Use required from the first content type with a schema 204 | } 205 | } 206 | 207 | if len(bodySchemaRequired) > 0 { 208 | if parametersSchema.Required == nil { 209 | parametersSchema.Required = make([]string, 0) 210 | } 211 | for _, r := range bodySchemaRequired { // Range over the correct schema required list 212 | if !sliceContains(parametersSchema.Required, r) { 213 | parametersSchema.Required = append(parametersSchema.Required, r) 214 | } 215 | } 216 | sort.Strings(parametersSchema.Required) 217 | } 218 | 219 | // Optionally, add a note if the requestBody itself was marked as required 220 | if requestBody.Required { // Check the boolean field 221 | // How to indicate this? Maybe add to description? 222 | log.Printf("Note: Request body for %s %s is marked as required.", method, rawPath) 223 | // Or add all top-level body props to required? Needs decision. 224 | } 225 | } 226 | } 227 | 228 | // Prepend note about API key handling 229 | finalToolDesc := "Note: The API key is handled by the server, no need to provide it. " + toolDesc 230 | 231 | tool := mcp.Tool{ 232 | Name: toolName, 233 | Description: finalToolDesc, // Use potentially modified description 234 | InputSchema: parametersSchema, // Use InputSchema, assuming it contains combined params/body 235 | } 236 | toolSet.Tools = append(toolSet.Tools, tool) 237 | 238 | // Store operation details for execution 239 | toolSet.Operations[toolName] = mcp.OperationDetail{ 240 | Method: method, 241 | Path: cleanPath, // Use the cleaned path here 242 | BaseURL: baseURL, 243 | Parameters: opParams, 244 | } 245 | } 246 | } 247 | return toolSet, nil 248 | } 249 | 250 | func determineBaseURLV3(doc *openapi3.T, cfg *config.Config) (string, error) { 251 | if cfg.ServerBaseURL != "" { 252 | return strings.TrimSuffix(cfg.ServerBaseURL, "/"), nil 253 | } 254 | if len(doc.Servers) > 0 { 255 | baseURL := "" 256 | for _, server := range doc.Servers { 257 | if baseURL == "" { 258 | baseURL = server.URL 259 | } 260 | if strings.HasPrefix(strings.ToLower(server.URL), "https://") { 261 | baseURL = server.URL 262 | break 263 | } 264 | if strings.HasPrefix(strings.ToLower(server.URL), "http://") { 265 | baseURL = server.URL 266 | } 267 | } 268 | if baseURL == "" { 269 | return "", fmt.Errorf("v3: could not determine a suitable base URL from servers list") 270 | } 271 | return strings.TrimSuffix(baseURL, "/"), nil 272 | } 273 | return "", fmt.Errorf("v3: no server base URL specified in config or OpenAPI spec servers list") 274 | } 275 | 276 | func getSortedPathsV3(paths *openapi3.Paths) []string { 277 | if paths == nil { 278 | return []string{} 279 | } 280 | keys := make([]string, 0, len(paths.Map())) 281 | for k := range paths.Map() { 282 | keys = append(keys, k) 283 | } 284 | sort.Strings(keys) 285 | return keys 286 | } 287 | 288 | func generateToolNameV3(op *openapi3.Operation, method, path string) string { 289 | if op.OperationID != "" { 290 | return op.OperationID 291 | } 292 | return generateDefaultToolName(method, path) 293 | } 294 | 295 | func getOperationDescriptionV3(op *openapi3.Operation) string { 296 | if op.Summary != "" { 297 | return op.Summary 298 | } 299 | return op.Description 300 | } 301 | 302 | func shouldIncludeOperationV3(op *openapi3.Operation, cfg *config.Config) bool { 303 | return shouldInclude(op.OperationID, op.Tags, cfg) 304 | } 305 | 306 | // parametersToMCPSchemaAndDetailsV3 converts parameters and also returns the parameter details. 307 | func parametersToMCPSchemaAndDetailsV3(params openapi3.Parameters, cfg *config.Config) (mcp.Schema, []mcp.ParameterDetail, error) { 308 | mcpSchema := mcp.Schema{Type: "object", Properties: make(map[string]mcp.Schema), Required: []string{}} 309 | opParams := []mcp.ParameterDetail{} 310 | for _, paramRef := range params { 311 | if paramRef.Value == nil { 312 | log.Printf("Warning: Skipping parameter with nil value.") 313 | continue 314 | } 315 | param := paramRef.Value 316 | if param.Schema == nil { 317 | log.Printf("Warning: Skipping parameter '%s' with nil schema.", param.Name) 318 | continue 319 | } 320 | 321 | // Skip the API key parameter if configured 322 | if cfg.APIKeyName != "" && param.Name == cfg.APIKeyName && param.In == string(cfg.APIKeyLocation) { 323 | log.Printf("Parser V3: Skipping API key parameter '%s' ('%s') from input schema generation.", param.Name, param.In) 324 | continue 325 | } 326 | 327 | // Store parameter detail (even if skipped for schema, needed for execution?) 328 | // Decision: Keep storing *all* params in opParams for potential server-side use, 329 | // but skip adding the API key to the mcpSchema exposed to the client. 330 | opParams = append(opParams, mcp.ParameterDetail{ 331 | Name: param.Name, 332 | In: param.In, 333 | }) 334 | 335 | propSchema, err := openapiSchemaToMCPSchemaV3(param.Schema) 336 | if err != nil { 337 | return mcp.Schema{}, nil, fmt.Errorf("v3 param '%s': %w", param.Name, err) 338 | } 339 | propSchema.Description = param.Description 340 | mcpSchema.Properties[param.Name] = propSchema 341 | if param.Required { 342 | mcpSchema.Required = append(mcpSchema.Required, param.Name) 343 | } 344 | } 345 | if len(mcpSchema.Required) > 1 { 346 | sort.Strings(mcpSchema.Required) 347 | } 348 | return mcpSchema, opParams, nil 349 | } 350 | 351 | func requestBodyToMCPV3(rbRef *openapi3.RequestBodyRef) (mcp.RequestBody, error) { 352 | mcpRB := mcp.RequestBody{Content: make(map[string]mcp.Schema)} 353 | if rbRef == nil || rbRef.Value == nil { 354 | return mcpRB, nil 355 | } 356 | rb := rbRef.Value 357 | mcpRB.Description = rb.Description 358 | mcpRB.Required = rb.Required 359 | 360 | var mediaType *openapi3.MediaType 361 | var chosenMediaTypeKey string 362 | if mt, ok := rb.Content["application/json"]; ok { 363 | mediaType, chosenMediaTypeKey = mt, "application/json" 364 | } else { 365 | for key, mt := range rb.Content { 366 | mediaType, chosenMediaTypeKey = mt, key 367 | break 368 | } 369 | } 370 | 371 | if mediaType != nil && mediaType.Schema != nil { 372 | contentSchema, err := openapiSchemaToMCPSchemaV3(mediaType.Schema) 373 | if err != nil { 374 | return mcp.RequestBody{}, fmt.Errorf("v3 request body (media type: %s): %w", chosenMediaTypeKey, err) 375 | } 376 | mcpRB.Content["application/json"] = contentSchema 377 | } else if mediaType != nil { 378 | mcpRB.Content["application/json"] = mcp.Schema{Type: "string", Description: fmt.Sprintf("Request body with media type %s (no specific schema defined)", chosenMediaTypeKey)} 379 | } 380 | return mcpRB, nil 381 | } 382 | 383 | func openapiSchemaToMCPSchemaV3(oapiSchemaRef *openapi3.SchemaRef) (mcp.Schema, error) { 384 | if oapiSchemaRef == nil { 385 | return mcp.Schema{Type: "string", Description: "Schema reference was nil"}, nil 386 | } 387 | if oapiSchemaRef.Value == nil { 388 | return mcp.Schema{Type: "string", Description: fmt.Sprintf("Schema reference value was nil (ref: %s)", oapiSchemaRef.Ref)}, nil 389 | } 390 | oapiSchema := oapiSchemaRef.Value 391 | 392 | var primaryType string 393 | if oapiSchema.Type != nil && len(*oapiSchema.Type) > 0 { 394 | primaryType = (*oapiSchema.Type)[0] 395 | } 396 | 397 | mcpSchema := mcp.Schema{ 398 | Type: mapJSONSchemaType(primaryType), 399 | Description: oapiSchema.Description, 400 | Format: oapiSchema.Format, 401 | Enum: oapiSchema.Enum, 402 | } 403 | 404 | switch mcpSchema.Type { 405 | case "object": 406 | mcpSchema.Properties = make(map[string]mcp.Schema) 407 | mcpSchema.Required = oapiSchema.Required 408 | for name, propRef := range oapiSchema.Properties { 409 | propSchema, err := openapiSchemaToMCPSchemaV3(propRef) 410 | if err != nil { 411 | return mcp.Schema{}, fmt.Errorf("v3 object property '%s': %w", name, err) 412 | } 413 | mcpSchema.Properties[name] = propSchema 414 | } 415 | if len(mcpSchema.Required) > 1 { 416 | sort.Strings(mcpSchema.Required) 417 | } 418 | case "array": 419 | if oapiSchema.Items != nil { 420 | itemsSchema, err := openapiSchemaToMCPSchemaV3(oapiSchema.Items) 421 | if err != nil { 422 | return mcp.Schema{}, fmt.Errorf("v3 array items: %w", err) 423 | } 424 | mcpSchema.Items = &itemsSchema 425 | } 426 | case "string", "number", "integer", "boolean", "null": 427 | // Basic types mapped 428 | default: 429 | if mcpSchema.Type == "string" && primaryType != "" && primaryType != "string" { 430 | mcpSchema.Description += fmt.Sprintf(" (Original type '%s' unknown or unsupported)", primaryType) 431 | } 432 | } 433 | return mcpSchema, nil 434 | } 435 | 436 | // --- V2 Specific Implementation --- 437 | 438 | func generateToolSetV2(doc *spec.Swagger, cfg *config.Config) (*mcp.ToolSet, error) { 439 | toolSet := createBaseToolSet(doc.Info.Title, doc.Info.Description, cfg) 440 | toolSet.Operations = make(map[string]mcp.OperationDetail) // Initialize map 441 | 442 | // Determine Base URL once 443 | baseURL, err := determineBaseURLV2(doc, cfg) 444 | if err != nil { 445 | log.Printf("Warning: Could not determine base URL for V2 spec: %v. Operations might fail if base URL override is not set.", err) 446 | baseURL = "" // Allow proceeding if override is set 447 | } 448 | 449 | // Detect API Key (Security Definitions) 450 | apiKeyName := cfg.APIKeyName 451 | apiKeyIn := string(cfg.APIKeyLocation) 452 | 453 | if apiKeyName == "" && apiKeyIn == "" { // Only infer if not provided by config 454 | for name, secDef := range doc.SecurityDefinitions { 455 | if secDef.Type == "apiKey" { 456 | apiKeyName = secDef.Name 457 | apiKeyIn = secDef.In // "query" or "header" 458 | log.Printf("Parser V2: Detected API key from security definition '%s': Name='%s', In='%s'", name, apiKeyName, apiKeyIn) 459 | break // Assume only one apiKey definition for simplicity 460 | } 461 | } 462 | } 463 | // Store detected/configured key details internally 464 | toolSet.SetAPIKeyDetails(apiKeyName, apiKeyIn) 465 | 466 | // --- Iterate through Paths --- 467 | paths := getSortedPathsV2(doc.Paths) 468 | for _, rawPath := range paths { // Rename loop var to rawPath 469 | pathItem := doc.Paths.Paths[rawPath] 470 | ops := map[string]*spec.Operation{ 471 | "GET": pathItem.Get, 472 | "PUT": pathItem.Put, 473 | "POST": pathItem.Post, 474 | "DELETE": pathItem.Delete, 475 | "OPTIONS": pathItem.Options, 476 | "HEAD": pathItem.Head, 477 | "PATCH": pathItem.Patch, 478 | } 479 | 480 | for method, op := range ops { 481 | if op == nil || !shouldIncludeOperationV2(op, cfg) { 482 | continue 483 | } 484 | 485 | // Clean the path 486 | cleanPath := rawPath 487 | if queryIndex := strings.Index(rawPath, "?"); queryIndex != -1 { 488 | cleanPath = rawPath[:queryIndex] 489 | } 490 | 491 | toolName := generateToolNameV2(op, method, rawPath) // Still generate name from raw path 492 | toolDesc := getOperationDescriptionV2(op) 493 | 494 | // Convert parameters and potential body schema 495 | parametersSchema, bodySchema, opParams, err := parametersToMCPSchemaAndDetailsV2(op.Parameters, doc.Definitions, apiKeyName) 496 | if err != nil { 497 | return nil, fmt.Errorf("error processing v2 parameters for %s %s: %w", method, rawPath, err) 498 | } 499 | 500 | // Combine request body into parameters schema if it exists 501 | if bodySchema.Type != "" { // Check if bodySchema was actually populated 502 | if bodySchema.Type == "object" && bodySchema.Properties != nil { 503 | if parametersSchema.Properties == nil { 504 | parametersSchema.Properties = make(map[string]mcp.Schema) 505 | } 506 | for propName, propSchema := range bodySchema.Properties { 507 | parametersSchema.Properties[propName] = propSchema 508 | } 509 | if len(bodySchema.Required) > 0 { 510 | if parametersSchema.Required == nil { 511 | parametersSchema.Required = make([]string, 0) 512 | } 513 | for _, r := range bodySchema.Required { 514 | if !sliceContains(parametersSchema.Required, r) { 515 | parametersSchema.Required = append(parametersSchema.Required, r) 516 | } 517 | } 518 | sort.Strings(parametersSchema.Required) 519 | } 520 | } else { 521 | // If body is not an object, represent as 'requestBody' 522 | log.Printf("Warning: V2 request body for %s %s is not an object schema. Representing as 'requestBody' field.", method, rawPath) 523 | if parametersSchema.Properties == nil { 524 | parametersSchema.Properties = make(map[string]mcp.Schema) 525 | } 526 | parametersSchema.Properties["requestBody"] = bodySchema 527 | } 528 | } 529 | 530 | // Prepend note about API key handling 531 | finalToolDesc := "Note: The API key is handled by the server, no need to provide it. " + toolDesc 532 | 533 | tool := mcp.Tool{ 534 | Name: toolName, 535 | Description: finalToolDesc, // Use potentially modified description 536 | InputSchema: parametersSchema, // Use InputSchema, assuming it contains combined params/body 537 | } 538 | toolSet.Tools = append(toolSet.Tools, tool) 539 | 540 | // Store operation details for execution 541 | toolSet.Operations[toolName] = mcp.OperationDetail{ 542 | Method: method, 543 | Path: cleanPath, // Use the cleaned path here 544 | BaseURL: baseURL, 545 | Parameters: opParams, 546 | } 547 | } 548 | } 549 | 550 | return toolSet, nil 551 | } 552 | 553 | func determineBaseURLV2(doc *spec.Swagger, cfg *config.Config) (string, error) { 554 | if cfg.ServerBaseURL != "" { 555 | return strings.TrimSuffix(cfg.ServerBaseURL, "/"), nil 556 | } 557 | 558 | host := doc.Host 559 | if host == "" { 560 | return "", fmt.Errorf("v2: missing 'host' in spec") 561 | } 562 | 563 | scheme := "https" 564 | if len(doc.Schemes) > 0 { 565 | // Prefer https, then http, then first 566 | preferred := []string{"https", "http"} 567 | found := false 568 | for _, p := range preferred { 569 | for _, s := range doc.Schemes { 570 | if s == p { 571 | scheme = s 572 | found = true 573 | break 574 | } 575 | } 576 | if found { 577 | break 578 | } 579 | } 580 | if !found { 581 | scheme = doc.Schemes[0] 582 | } // fallback to first scheme 583 | } // else default to https 584 | 585 | basePath := doc.BasePath 586 | 587 | return strings.TrimSuffix(scheme+"://"+host+basePath, "/"), nil 588 | } 589 | 590 | func getSortedPathsV2(paths *spec.Paths) []string { 591 | if paths == nil { 592 | return []string{} 593 | } 594 | keys := make([]string, 0, len(paths.Paths)) 595 | for k := range paths.Paths { 596 | keys = append(keys, k) 597 | } 598 | sort.Strings(keys) 599 | return keys 600 | } 601 | 602 | func generateToolNameV2(op *spec.Operation, method, path string) string { 603 | if op.ID != "" { 604 | return op.ID 605 | } 606 | return generateDefaultToolName(method, path) 607 | } 608 | 609 | func getOperationDescriptionV2(op *spec.Operation) string { 610 | if op.Summary != "" { 611 | return op.Summary 612 | } 613 | return op.Description 614 | } 615 | 616 | func shouldIncludeOperationV2(op *spec.Operation, cfg *config.Config) bool { 617 | return shouldInclude(op.ID, op.Tags, cfg) 618 | } 619 | 620 | // parametersToMCPSchemaAndDetailsV2 converts V2 parameters and also returns details and request body. 621 | func parametersToMCPSchemaAndDetailsV2(params []spec.Parameter, definitions spec.Definitions, apiKeyName string) (mcp.Schema, mcp.Schema, []mcp.ParameterDetail, error) { 622 | mcpSchema := mcp.Schema{Type: "object", Properties: make(map[string]mcp.Schema), Required: []string{}} 623 | bodySchema := mcp.Schema{} // Initialize empty 624 | opParams := []mcp.ParameterDetail{} 625 | hasBodyParam := false 626 | var bodyParam *spec.Parameter // Declare bodyParam here to be accessible later 627 | 628 | // First pass: Separate body param, process others 629 | for _, param := range params { 630 | // Skip the API key parameter if it's configured/detected 631 | if apiKeyName != "" && param.Name == apiKeyName && (param.In == "query" || param.In == "header") { 632 | log.Printf("Parser V2: Skipping API key parameter '%s' ('%s') from input schema generation.", param.Name, param.In) 633 | continue 634 | } 635 | 636 | if param.In == "body" { 637 | if hasBodyParam { 638 | return mcp.Schema{}, mcp.Schema{}, nil, fmt.Errorf("v2: multiple 'body' parameters found") 639 | } 640 | hasBodyParam = true 641 | bodyParam = ¶m // Assign to outer scope variable 642 | continue // Don't process body param further in this loop 643 | } 644 | 645 | if param.In != "query" && param.In != "path" && param.In != "header" && param.In != "formData" { 646 | log.Printf("Parser V2: Skipping unsupported parameter type '%s' for parameter '%s'", param.In, param.Name) 647 | continue 648 | } 649 | 650 | // Add non-body param detail 651 | opParams = append(opParams, mcp.ParameterDetail{ 652 | Name: param.Name, 653 | In: param.In, // query, header, path, formData 654 | }) 655 | 656 | // Convert non-body param schema and add to mcpSchema 657 | propSchema, err := swaggerParamToMCPSchema(¶m, definitions) 658 | if err != nil { 659 | return mcp.Schema{}, mcp.Schema{}, nil, fmt.Errorf("v2 param '%s': %w", param.Name, err) 660 | } 661 | mcpSchema.Properties[param.Name] = propSchema 662 | if param.Required { 663 | mcpSchema.Required = append(mcpSchema.Required, param.Name) 664 | } 665 | } 666 | 667 | // Second pass: Process the body parameter if found 668 | if bodyParam != nil { 669 | bodySchema.Description = bodyParam.Description 670 | 671 | if bodyParam.Schema != nil { 672 | // Convert the body schema (resolving $refs) 673 | bodySchemaFields, err := swaggerSchemaToMCPSchemaV2(bodyParam.Schema, definitions) 674 | if err != nil { 675 | return mcp.Schema{}, mcp.Schema{}, nil, fmt.Errorf("v2 request body schema: %w", err) 676 | } 677 | // Update our local bodySchema with the converted fields 678 | bodySchema.Type = bodySchemaFields.Type 679 | bodySchema.Properties = bodySchemaFields.Properties 680 | bodySchema.Items = bodySchemaFields.Items 681 | bodySchema.Format = bodySchemaFields.Format 682 | bodySchema.Enum = bodySchemaFields.Enum 683 | bodySchema.Required = bodySchemaFields.Required // Required fields from the *schema* itself 684 | 685 | // Merge bodySchema properties into the main mcpSchema 686 | if bodySchema.Type == "object" && bodySchema.Properties != nil { 687 | for propName, propSchema := range bodySchema.Properties { 688 | mcpSchema.Properties[propName] = propSchema 689 | } 690 | // Merge required fields from the body's schema into the main required list 691 | if len(bodySchema.Required) > 0 { 692 | mcpSchema.Required = append(mcpSchema.Required, bodySchema.Required...) 693 | } 694 | } else { 695 | // Handle non-object body schema (e.g., array, string) 696 | // Add a single property named after the body parameter 697 | mcpSchema.Properties[bodyParam.Name] = bodySchemaFields // Use the converted schema 698 | if bodyParam.Required { // Check the parameter's required status 699 | mcpSchema.Required = append(mcpSchema.Required, bodyParam.Name) 700 | } 701 | } 702 | 703 | } else { 704 | // Body param defined without a schema? Treat as simple string. 705 | log.Printf("Warning: V2 body parameter '%s' defined without a schema. Treating as string.", bodyParam.Name) 706 | bodySchema.Type = "string" 707 | mcpSchema.Properties[bodyParam.Name] = bodySchema 708 | if bodyParam.Required { 709 | mcpSchema.Required = append(mcpSchema.Required, bodyParam.Name) 710 | } 711 | } 712 | 713 | // Always add the body parameter to the OperationDetail list 714 | opParams = append(opParams, mcp.ParameterDetail{ 715 | Name: bodyParam.Name, 716 | In: bodyParam.In, 717 | }) 718 | } 719 | 720 | // Sort and deduplicate the final required list 721 | if len(mcpSchema.Required) > 1 { 722 | sort.Strings(mcpSchema.Required) 723 | seen := make(map[string]struct{}, len(mcpSchema.Required)) 724 | j := 0 725 | for _, r := range mcpSchema.Required { 726 | if _, ok := seen[r]; !ok { 727 | seen[r] = struct{}{} 728 | mcpSchema.Required[j] = r 729 | j++ 730 | } 731 | } 732 | mcpSchema.Required = mcpSchema.Required[:j] 733 | } 734 | 735 | return mcpSchema, bodySchema, opParams, nil 736 | } 737 | 738 | // swaggerParamToMCPSchema converts a V2 Parameter (non-body) to an MCP Schema. 739 | func swaggerParamToMCPSchema(param *spec.Parameter, definitions spec.Definitions) (mcp.Schema, error) { 740 | // This needs to handle types like string, integer, array based on param.Type, param.Format, param.Items 741 | // Simplified version: 742 | mcpSchema := mcp.Schema{ 743 | Type: mapJSONSchemaType(param.Type), // Use the same mapping 744 | Description: param.Description, 745 | Format: param.Format, 746 | Enum: param.Enum, 747 | // TODO: Map items for array type, map constraints (maximum, etc.) 748 | } 749 | if param.Type == "array" && param.Items != nil { 750 | // Need to convert param.Items (which is *spec.Items) to MCP schema 751 | itemsSchema, err := swaggerItemsToMCPSchema(param.Items, definitions) 752 | if err != nil { 753 | return mcp.Schema{}, fmt.Errorf("v2 array param '%s' items: %w", param.Name, err) 754 | } 755 | mcpSchema.Items = &itemsSchema 756 | } 757 | return mcpSchema, nil 758 | } 759 | 760 | // swaggerItemsToMCPSchema converts V2 Items object 761 | func swaggerItemsToMCPSchema(items *spec.Items, definitions spec.Definitions) (mcp.Schema, error) { 762 | if items == nil { 763 | return mcp.Schema{Type: "string", Description: "nil items"}, nil 764 | } 765 | // Similar logic to swaggerParamToMCPSchema but for Items structure 766 | mcpSchema := mcp.Schema{ 767 | Type: mapJSONSchemaType(items.Type), 768 | Description: "", // Items don't have descriptions typically 769 | Format: items.Format, 770 | Enum: items.Enum, 771 | } 772 | if items.Type == "array" && items.Items != nil { 773 | subItemsSchema, err := swaggerItemsToMCPSchema(items.Items, definitions) 774 | if err != nil { 775 | return mcp.Schema{}, fmt.Errorf("v2 nested array items: %w", err) 776 | } 777 | mcpSchema.Items = &subItemsSchema 778 | } 779 | // TODO: Handle $ref within items? Not directly supported by spec.Items 780 | return mcpSchema, nil 781 | } 782 | 783 | // swaggerSchemaToMCPSchemaV2 converts a Swagger v2 schema (from definitions or body param) to mcp.Schema 784 | func swaggerSchemaToMCPSchemaV2(oapiSchema *spec.Schema, definitions spec.Definitions) (mcp.Schema, error) { 785 | if oapiSchema == nil { 786 | return mcp.Schema{Type: "string", Description: "Schema was nil"}, nil 787 | } 788 | 789 | // Handle $ref 790 | if oapiSchema.Ref.String() != "" { 791 | refSchema, err := resolveRefV2(oapiSchema.Ref, definitions) 792 | if err != nil { 793 | return mcp.Schema{}, err 794 | } 795 | // Recursively convert the resolved schema, careful with cycles 796 | return swaggerSchemaToMCPSchemaV2(refSchema, definitions) 797 | } 798 | 799 | var primaryType string 800 | if len(oapiSchema.Type) > 0 { 801 | primaryType = oapiSchema.Type[0] 802 | } 803 | 804 | mcpSchema := mcp.Schema{ 805 | Type: mapJSONSchemaType(primaryType), 806 | Description: oapiSchema.Description, 807 | Format: oapiSchema.Format, 808 | Enum: oapiSchema.Enum, 809 | // TODO: Map V2 constraints (Maximum, Minimum, etc.) 810 | } 811 | 812 | switch mcpSchema.Type { 813 | case "object": 814 | mcpSchema.Properties = make(map[string]mcp.Schema) 815 | mcpSchema.Required = oapiSchema.Required 816 | for name, propSchema := range oapiSchema.Properties { 817 | // propSchema here is spec.Schema, need recursive call 818 | propMCPSchema, err := swaggerSchemaToMCPSchemaV2(&propSchema, definitions) 819 | if err != nil { 820 | return mcp.Schema{}, fmt.Errorf("v2 object property '%s': %w", name, err) 821 | } 822 | mcpSchema.Properties[name] = propMCPSchema 823 | } 824 | if len(mcpSchema.Required) > 1 { 825 | sort.Strings(mcpSchema.Required) 826 | } 827 | case "array": 828 | if oapiSchema.Items != nil && oapiSchema.Items.Schema != nil { 829 | // V2 Items has a single Schema field 830 | itemsSchema, err := swaggerSchemaToMCPSchemaV2(oapiSchema.Items.Schema, definitions) 831 | if err != nil { 832 | return mcp.Schema{}, fmt.Errorf("v2 array items: %w", err) 833 | } 834 | mcpSchema.Items = &itemsSchema 835 | } else if oapiSchema.Items != nil && len(oapiSchema.Items.Schemas) > 0 { 836 | // Handle tuple-like arrays (less common, maybe simplify to single type?) 837 | // For now, take the first schema 838 | itemsSchema, err := swaggerSchemaToMCPSchemaV2(&oapiSchema.Items.Schemas[0], definitions) 839 | if err != nil { 840 | return mcp.Schema{}, fmt.Errorf("v2 tuple array items: %w", err) 841 | } 842 | mcpSchema.Items = &itemsSchema 843 | mcpSchema.Description += " (Note: original was tuple-like array, showing first type)" 844 | } 845 | case "string", "number", "integer", "boolean", "null": 846 | // Basic types mapped 847 | default: 848 | if mcpSchema.Type == "string" && primaryType != "" && primaryType != "string" { 849 | mcpSchema.Description += fmt.Sprintf(" (Original type '%s' unknown or unsupported)", primaryType) 850 | } 851 | } 852 | return mcpSchema, nil 853 | } 854 | 855 | func resolveRefV2(ref spec.Ref, definitions spec.Definitions) (*spec.Schema, error) { 856 | // Simple local definition resolution 857 | refStr := ref.String() 858 | if !strings.HasPrefix(refStr, "#/definitions/") { 859 | return nil, fmt.Errorf("unsupported $ref format: %s", refStr) 860 | } 861 | defName := strings.TrimPrefix(refStr, "#/definitions/") 862 | schema, ok := definitions[defName] 863 | if !ok { 864 | return nil, fmt.Errorf("$ref '%s' not found in definitions", refStr) 865 | } 866 | return &schema, nil 867 | } 868 | 869 | // --- Common Helper Functions --- 870 | 871 | func createBaseToolSet(title, desc string, cfg *config.Config) *mcp.ToolSet { 872 | // Prioritize config overrides if they are set 873 | toolSetName := title // Default to spec title 874 | if cfg.DefaultToolName != "" { 875 | toolSetName = cfg.DefaultToolName // Use config override if provided 876 | } 877 | 878 | toolSetDesc := desc // Default to spec description 879 | if cfg.DefaultToolDesc != "" { 880 | toolSetDesc = cfg.DefaultToolDesc // Use config override if provided 881 | } 882 | 883 | toolSet := &mcp.ToolSet{ 884 | MCPVersion: "0.1.0", 885 | Name: toolSetName, // Use determined name 886 | Description: toolSetDesc, // Use determined description 887 | Tools: []mcp.Tool{}, 888 | Operations: make(map[string]mcp.OperationDetail), // Initialize map 889 | } 890 | 891 | // The old overwrite logic is removed as it's handled above 892 | // if title != "" { 893 | // toolSet.Name = title 894 | // } 895 | // if desc != "" { 896 | // toolSet.Description = desc 897 | // } 898 | return toolSet 899 | } 900 | 901 | // generateDefaultToolName creates a name if operationId is missing. 902 | func generateDefaultToolName(method, path string) string { 903 | pathParts := strings.Split(strings.Trim(path, "/"), "/") 904 | var nameParts []string 905 | nameParts = append(nameParts, strings.ToUpper(method[:1])+strings.ToLower(method[1:])) 906 | for _, part := range pathParts { 907 | if part == "" { 908 | continue 909 | } 910 | if strings.HasPrefix(part, "{") && strings.HasSuffix(part, "}") { 911 | paramName := strings.Trim(part, "{}") 912 | nameParts = append(nameParts, "By"+strings.ToUpper(paramName[:1])+paramName[1:]) 913 | } else { 914 | sanitizedPart := strings.ReplaceAll(part, "-", "_") 915 | sanitizedPart = strings.Title(sanitizedPart) // Basic capitalization 916 | nameParts = append(nameParts, sanitizedPart) 917 | } 918 | } 919 | return strings.Join(nameParts, "") 920 | } 921 | 922 | // shouldInclude determines if an operation should be included based on config filters. 923 | func shouldInclude(opID string, opTags []string, cfg *config.Config) bool { 924 | // Exclusion rules take precedence 925 | if len(cfg.ExcludeOperations) > 0 && opID != "" && sliceContains(cfg.ExcludeOperations, opID) { 926 | return false 927 | } 928 | if len(cfg.ExcludeTags) > 0 { 929 | for _, tag := range opTags { 930 | if sliceContains(cfg.ExcludeTags, tag) { 931 | return false 932 | } 933 | } 934 | } 935 | 936 | // Inclusion rules 937 | hasInclusionRule := len(cfg.IncludeOperations) > 0 || len(cfg.IncludeTags) > 0 938 | if !hasInclusionRule { 939 | return true 940 | } // No inclusion rules, include by default 941 | 942 | if len(cfg.IncludeOperations) > 0 { 943 | if opID != "" && sliceContains(cfg.IncludeOperations, opID) { 944 | return true 945 | } 946 | } else if len(cfg.IncludeTags) > 0 { 947 | for _, tag := range opTags { 948 | if sliceContains(cfg.IncludeTags, tag) { 949 | return true 950 | } 951 | } 952 | } 953 | return false // Did not match any inclusion rule 954 | } 955 | 956 | // mapJSONSchemaType ensures the type is one recognized by JSON Schema / MCP. 957 | func mapJSONSchemaType(oapiType string) string { 958 | switch strings.ToLower(oapiType) { // Normalize type 959 | case "integer", "number", "string", "boolean", "array", "object": 960 | return strings.ToLower(oapiType) 961 | case "null": 962 | return "string" // Represent null as string for MCP? 963 | case "file": // Swagger 2.0 specific type 964 | return "string" // Represent file uploads as string (e.g., path or content)? 965 | default: 966 | return "string" 967 | } 968 | } 969 | 970 | // sliceContains checks if a string slice contains a specific string. 971 | func sliceContains(slice []string, item string) bool { 972 | for _, s := range slice { 973 | if s == item { 974 | return true 975 | } 976 | } 977 | return false 978 | } 979 | -------------------------------------------------------------------------------- /pkg/parser/parser_test.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "os" 7 | "path/filepath" 8 | "sort" 9 | "strings" 10 | "testing" 11 | 12 | "github.com/getkin/kin-openapi/openapi3" 13 | "github.com/go-openapi/spec" 14 | "github.com/stretchr/testify/assert" 15 | "github.com/stretchr/testify/require" 16 | 17 | "github.com/ckanthony/openapi-mcp/pkg/config" 18 | "github.com/ckanthony/openapi-mcp/pkg/mcp" 19 | ) 20 | 21 | // Minimal valid OpenAPI V3 spec (JSON string) 22 | const minimalV3SpecJSON = `{ 23 | "openapi": "3.0.0", 24 | "info": { 25 | "title": "Minimal V3 API", 26 | "version": "1.0.0" 27 | }, 28 | "paths": { 29 | "/ping": { 30 | "get": { 31 | "summary": "Simple ping endpoint", 32 | "operationId": "getPing", 33 | "responses": { 34 | "200": { 35 | "description": "OK" 36 | } 37 | } 38 | } 39 | } 40 | } 41 | }` 42 | 43 | // Minimal valid Swagger V2 spec (JSON string) 44 | const minimalV2SpecJSON = `{ 45 | "swagger": "2.0", 46 | "info": { 47 | "title": "Minimal V2 API", 48 | "version": "1.0.0" 49 | }, 50 | "paths": { 51 | "/health": { 52 | "get": { 53 | "summary": "Simple health check", 54 | "operationId": "getHealth", 55 | "produces": ["application/json"], 56 | "responses": { 57 | "200": { 58 | "description": "OK" 59 | } 60 | } 61 | } 62 | } 63 | } 64 | }` 65 | 66 | // Malformed JSON 67 | const malformedJSON = `{ 68 | "openapi": "3.0.0", 69 | "info": { 70 | "title": "Missing Version", 71 | } 72 | }` 73 | 74 | // JSON without version key 75 | const noVersionKeyJSON = `{ 76 | "info": { 77 | "title": "No Version Key", 78 | "version": "1.0" 79 | }, 80 | "paths": {} 81 | }` 82 | 83 | // V3 Spec with tags and multiple operations 84 | const complexV3SpecJSON = `{ 85 | "openapi": "3.0.0", 86 | "info": { 87 | "title": "Complex V3 API", 88 | "version": "1.1.0" 89 | }, 90 | "tags": [ 91 | {"name": "tag1", "description": "First Tag"}, 92 | {"name": "tag2", "description": "Second Tag"} 93 | ], 94 | "paths": { 95 | "/items": { 96 | "get": { 97 | "summary": "List Items", 98 | "operationId": "listItems", 99 | "tags": ["tag1"], 100 | "responses": {"200": {"description": "OK"}} 101 | }, 102 | "post": { 103 | "summary": "Create Item", 104 | "operationId": "createItem", 105 | "tags": ["tag1", "tag2"], 106 | "responses": {"201": {"description": "Created"}} 107 | } 108 | }, 109 | "/users": { 110 | "get": { 111 | "summary": "List Users", 112 | "operationId": "listUsers", 113 | "tags": ["tag2"], 114 | "responses": {"200": {"description": "OK"}} 115 | } 116 | }, 117 | "/ping": { 118 | "get": { 119 | "summary": "Simple ping", 120 | "operationId": "getPing", 121 | "responses": {"200": {"description": "OK"}} 122 | } 123 | } 124 | } 125 | }` 126 | 127 | // V2 Spec with tags and multiple operations 128 | const complexV2SpecJSON = `{ 129 | "swagger": "2.0", 130 | "info": { 131 | "title": "Complex V2 API", 132 | "version": "1.1.0" 133 | }, 134 | "tags": [ 135 | {"name": "tag1", "description": "First Tag"}, 136 | {"name": "tag2", "description": "Second Tag"} 137 | ], 138 | "paths": { 139 | "/items": { 140 | "get": { 141 | "summary": "List Items", 142 | "operationId": "listItems", 143 | "tags": ["tag1"], 144 | "produces": ["application/json"], 145 | "responses": {"200": {"description": "OK"}} 146 | }, 147 | "post": { 148 | "summary": "Create Item", 149 | "operationId": "createItem", 150 | "tags": ["tag1", "tag2"], 151 | "produces": ["application/json"], 152 | "responses": {"201": {"description": "Created"}} 153 | } 154 | }, 155 | "/users": { 156 | "get": { 157 | "summary": "List Users", 158 | "operationId": "listUsers", 159 | "tags": ["tag2"], 160 | "produces": ["application/json"], 161 | "responses": {"200": {"description": "OK"}} 162 | } 163 | }, 164 | "/ping": { 165 | "get": { 166 | "summary": "Simple ping", 167 | "operationId": "getPing", 168 | "produces": ["application/json"], 169 | "responses": {"200": {"description": "OK"}} 170 | } 171 | } 172 | } 173 | }` 174 | 175 | // V3 Spec with various parameter types and request body 176 | const paramsV3SpecJSON = `{ 177 | "openapi": "3.0.0", 178 | "info": { 179 | "title": "Params V3 API", 180 | "version": "1.0.0" 181 | }, 182 | "paths": { 183 | "/test/{path_param}": { 184 | "post": { 185 | "summary": "Test various params", 186 | "operationId": "testParams", 187 | "parameters": [ 188 | { 189 | "name": "path_param", 190 | "in": "path", 191 | "required": true, 192 | "schema": {"type": "integer", "format": "int32"} 193 | }, 194 | { 195 | "name": "query_param", 196 | "in": "query", 197 | "required": true, 198 | "schema": {"type": "string", "enum": ["A", "B"]} 199 | }, 200 | { 201 | "name": "optional_query", 202 | "in": "query", 203 | "schema": {"type": "boolean"} 204 | }, 205 | { 206 | "name": "X-Header-Param", 207 | "in": "header", 208 | "required": true, 209 | "schema": {"type": "string"} 210 | }, 211 | { 212 | "name": "CookieParam", 213 | "in": "cookie", 214 | "schema": {"type": "number"} 215 | } 216 | ], 217 | "requestBody": { 218 | "required": true, 219 | "content": { 220 | "application/json": { 221 | "schema": { 222 | "type": "object", 223 | "properties": { 224 | "id": {"type": "string"}, 225 | "value": {"type": "number"} 226 | }, 227 | "required": ["id"] 228 | } 229 | } 230 | } 231 | }, 232 | "responses": { 233 | "200": {"description": "OK"} 234 | } 235 | } 236 | } 237 | } 238 | }` 239 | 240 | // V2 Spec with various parameter types and $ref 241 | const paramsV2SpecJSON = `{ 242 | "swagger": "2.0", 243 | "info": { 244 | "title": "Params V2 API", 245 | "version": "1.0.0" 246 | }, 247 | "definitions": { 248 | "Item": { 249 | "type": "object", 250 | "properties": { 251 | "id": {"type": "string", "format": "uuid"}, 252 | "name": {"type": "string"} 253 | }, 254 | "required": ["id"] 255 | } 256 | }, 257 | "paths": { 258 | "/test/{path_id}": { 259 | "put": { 260 | "summary": "Test V2 params and ref", 261 | "operationId": "testV2Params", 262 | "consumes": ["application/json"], 263 | "produces": ["application/json"], 264 | "parameters": [ 265 | { 266 | "name": "path_id", 267 | "in": "path", 268 | "required": true, 269 | "type": "string" 270 | }, 271 | { 272 | "name": "query_flag", 273 | "in": "query", 274 | "type": "boolean", 275 | "required": true 276 | }, 277 | { 278 | "name": "X-Request-ID", 279 | "in": "header", 280 | "type": "string", 281 | "required": false 282 | }, 283 | { 284 | "name": "body_param", 285 | "in": "body", 286 | "required": true, 287 | "schema": { 288 | "$ref": "#/definitions/Item" 289 | } 290 | } 291 | ], 292 | "responses": { 293 | "200": {"description": "OK"} 294 | } 295 | } 296 | } 297 | } 298 | }` 299 | 300 | // V3 Spec with array types 301 | const arraysV3SpecJSON = `{ 302 | "openapi": "3.0.0", 303 | "info": {"title": "Arrays V3 API", "version": "1.0.0"}, 304 | "paths": { 305 | "/process": { 306 | "post": { 307 | "summary": "Process arrays", 308 | "operationId": "processArrays", 309 | "parameters": [ 310 | { 311 | "name": "string_array_query", 312 | "in": "query", 313 | "schema": { 314 | "type": "array", 315 | "items": {"type": "string"} 316 | } 317 | } 318 | ], 319 | "requestBody": { 320 | "content": { 321 | "application/json": { 322 | "schema": { 323 | "type": "object", 324 | "properties": { 325 | "int_array_body": { 326 | "type": "array", 327 | "items": {"type": "integer", "format": "int64"} 328 | } 329 | } 330 | } 331 | } 332 | } 333 | }, 334 | "responses": {"200": {"description": "OK"}} 335 | } 336 | } 337 | } 338 | }` 339 | 340 | // V2 Spec with array types 341 | const arraysV2SpecJSON = `{ 342 | "swagger": "2.0", 343 | "info": {"title": "Arrays V2 API", "version": "1.0.0"}, 344 | "paths": { 345 | "/process": { 346 | "get": { 347 | "summary": "Get arrays", 348 | "operationId": "getArrays", 349 | "parameters": [ 350 | { 351 | "name": "string_array_query", 352 | "in": "query", 353 | "type": "array", 354 | "items": {"type": "string"}, 355 | "collectionFormat": "csv" 356 | }, 357 | { 358 | "name": "int_array_form", 359 | "in": "formData", 360 | "type": "array", 361 | "items": {"type": "integer", "format": "int32"} 362 | } 363 | ], 364 | "responses": {"200": {"description": "OK"}} 365 | } 366 | } 367 | } 368 | }` 369 | 370 | // V2 Spec with file parameter 371 | const fileV2SpecJSON = `{ 372 | "swagger": "2.0", 373 | "info": {"title": "File V2 API", "version": "1.0.0"}, 374 | "paths": { 375 | "/upload": { 376 | "post": { 377 | "summary": "Upload file", 378 | "operationId": "uploadFile", 379 | "consumes": ["multipart/form-data"], 380 | "parameters": [ 381 | { 382 | "name": "description", 383 | "in": "formData", 384 | "type": "string" 385 | }, 386 | { 387 | "name": "file_upload", 388 | "in": "formData", 389 | "required": true, 390 | "type": "file" 391 | } 392 | ], 393 | "responses": {"200": {"description": "OK"}} 394 | } 395 | } 396 | } 397 | }` 398 | 399 | func TestLoadSwagger(t *testing.T) { 400 | tests := []struct { 401 | name string 402 | content string 403 | fileName string 404 | expectError bool 405 | expectVersion string 406 | containsError string // Substring to check in error message 407 | isURLTest bool // Flag to indicate if the test uses a URL 408 | handler http.HandlerFunc // Handler for mock HTTP server 409 | }{ 410 | { 411 | name: "Valid V3 JSON file", 412 | content: minimalV3SpecJSON, 413 | fileName: "valid_v3.json", 414 | expectError: false, 415 | expectVersion: VersionV3, 416 | }, 417 | { 418 | name: "Valid V2 JSON file", 419 | content: minimalV2SpecJSON, 420 | fileName: "valid_v2.json", 421 | expectError: false, 422 | expectVersion: VersionV2, 423 | }, 424 | { 425 | name: "Malformed JSON file", 426 | content: malformedJSON, 427 | fileName: "malformed.json", 428 | expectError: true, 429 | containsError: "failed to parse JSON", 430 | }, 431 | { 432 | name: "No version key JSON file", 433 | content: noVersionKeyJSON, 434 | fileName: "no_version.json", 435 | expectError: true, 436 | containsError: "missing 'openapi' or 'swagger' key", 437 | }, 438 | { 439 | name: "Non-existent file", 440 | content: "", // No content needed 441 | fileName: "non_existent.json", 442 | expectError: true, 443 | containsError: "failed reading file path", 444 | }, 445 | // --- URL Tests --- 446 | { 447 | name: "Valid V3 JSON URL", 448 | content: minimalV3SpecJSON, 449 | expectError: false, 450 | expectVersion: VersionV3, 451 | isURLTest: true, 452 | handler: func(w http.ResponseWriter, r *http.Request) { 453 | w.WriteHeader(http.StatusOK) 454 | w.Write([]byte(minimalV3SpecJSON)) 455 | }, 456 | }, 457 | { 458 | name: "Valid V2 JSON URL", 459 | content: minimalV2SpecJSON, // Content used by handler 460 | expectError: false, 461 | expectVersion: VersionV2, 462 | isURLTest: true, 463 | handler: func(w http.ResponseWriter, r *http.Request) { 464 | w.WriteHeader(http.StatusOK) 465 | w.Write([]byte(minimalV2SpecJSON)) 466 | }, 467 | }, 468 | { 469 | name: "Malformed JSON URL", 470 | content: malformedJSON, 471 | expectError: true, 472 | containsError: "failed to parse JSON", 473 | isURLTest: true, 474 | handler: func(w http.ResponseWriter, r *http.Request) { 475 | w.WriteHeader(http.StatusOK) 476 | w.Write([]byte(malformedJSON)) 477 | }, 478 | }, 479 | { 480 | name: "No version key JSON URL", 481 | content: noVersionKeyJSON, 482 | expectError: true, 483 | containsError: "missing 'openapi' or 'swagger' key", 484 | isURLTest: true, 485 | handler: func(w http.ResponseWriter, r *http.Request) { 486 | w.WriteHeader(http.StatusOK) 487 | w.Write([]byte(noVersionKeyJSON)) 488 | }, 489 | }, 490 | { 491 | name: "URL Not Found (404)", 492 | expectError: true, 493 | containsError: "failed to fetch URL", // Check for fetch error 494 | isURLTest: true, 495 | handler: func(w http.ResponseWriter, r *http.Request) { 496 | http.NotFound(w, r) // Use standard http.NotFound 497 | }, 498 | }, 499 | { 500 | name: "URL Internal Server Error (500)", 501 | expectError: true, 502 | containsError: "failed to fetch URL", // Check for fetch error 503 | isURLTest: true, 504 | handler: func(w http.ResponseWriter, r *http.Request) { 505 | http.Error(w, "Internal Server Error", http.StatusInternalServerError) // Use standard http.Error 506 | }, 507 | }, 508 | } 509 | 510 | for _, tc := range tests { 511 | t.Run(tc.name, func(t *testing.T) { 512 | var location string 513 | var server *httptest.Server // Declare server variable 514 | 515 | if tc.isURLTest { 516 | // Set up mock HTTP server 517 | require.NotNil(t, tc.handler, "URL test case must provide a handler") 518 | server = httptest.NewServer(tc.handler) 519 | defer server.Close() 520 | location = server.URL // Use the mock server's URL 521 | } else { 522 | // Existing file path logic 523 | tempDir := t.TempDir() 524 | filePath := filepath.Join(tempDir, tc.fileName) 525 | 526 | // Create the file only if content is provided 527 | if tc.content != "" { 528 | err := os.WriteFile(filePath, []byte(tc.content), 0644) 529 | require.NoError(t, err, "Failed to write temp spec file") 530 | } 531 | 532 | // For the non-existent file case, ensure it really doesn't exist 533 | if tc.name == "Non-existent file" { 534 | filePath = filepath.Join(tempDir, "definitely_not_here.json") 535 | } 536 | location = filePath 537 | } 538 | 539 | specDoc, version, err := LoadSwagger(location) 540 | 541 | if tc.expectError { 542 | assert.Error(t, err) 543 | if tc.containsError != "" { 544 | assert.True(t, strings.Contains(err.Error(), tc.containsError), 545 | "Error message %q does not contain expected substring %q", err.Error(), tc.containsError) 546 | } 547 | assert.Nil(t, specDoc) 548 | assert.Empty(t, version) 549 | } else { 550 | assert.NoError(t, err) 551 | assert.NotNil(t, specDoc) 552 | assert.Equal(t, tc.expectVersion, version) 553 | // Basic type assertion based on expected version 554 | if version == VersionV3 { 555 | assert.IsType(t, &openapi3.T{}, specDoc) // Expecting a pointer 556 | } else if version == VersionV2 { 557 | assert.IsType(t, &spec.Swagger{}, specDoc) // Expecting a pointer 558 | } 559 | } 560 | }) 561 | } 562 | } 563 | 564 | // TODO: Add tests for GenerateToolSet 565 | func TestGenerateToolSet(t *testing.T) { 566 | // --- Load Specs Once --- 567 | // Load V3 spec (error checked in TestLoadSwagger) 568 | tempDirV3 := t.TempDir() 569 | filePathV3 := filepath.Join(tempDirV3, "minimal_v3.json") 570 | err := os.WriteFile(filePathV3, []byte(minimalV3SpecJSON), 0644) 571 | require.NoError(t, err) 572 | docV3, versionV3, err := LoadSwagger(filePathV3) 573 | require.NoError(t, err) 574 | require.Equal(t, VersionV3, versionV3) 575 | specV3 := docV3.(*openapi3.T) 576 | 577 | // Load V2 spec (error checked in TestLoadSwagger) 578 | tempDirV2 := t.TempDir() 579 | filePathV2 := filepath.Join(tempDirV2, "minimal_v2.json") 580 | err = os.WriteFile(filePathV2, []byte(minimalV2SpecJSON), 0644) 581 | require.NoError(t, err) 582 | docV2, versionV2, err := LoadSwagger(filePathV2) 583 | require.NoError(t, err) 584 | require.Equal(t, VersionV2, versionV2) 585 | specV2 := docV2.(*spec.Swagger) 586 | 587 | // Load Complex V3 spec 588 | tempDirComplexV3 := t.TempDir() 589 | filePathComplexV3 := filepath.Join(tempDirComplexV3, "complex_v3.json") 590 | err = os.WriteFile(filePathComplexV3, []byte(complexV3SpecJSON), 0644) 591 | require.NoError(t, err) 592 | docComplexV3, versionComplexV3, err := LoadSwagger(filePathComplexV3) 593 | require.NoError(t, err) 594 | require.Equal(t, VersionV3, versionComplexV3) 595 | specComplexV3 := docComplexV3.(*openapi3.T) 596 | 597 | // Load Complex V2 spec 598 | tempDirComplexV2 := t.TempDir() 599 | filePathComplexV2 := filepath.Join(tempDirComplexV2, "complex_v2.json") 600 | err = os.WriteFile(filePathComplexV2, []byte(complexV2SpecJSON), 0644) 601 | require.NoError(t, err) 602 | docComplexV2, versionComplexV2, err := LoadSwagger(filePathComplexV2) 603 | require.NoError(t, err) 604 | require.Equal(t, VersionV2, versionComplexV2) 605 | specComplexV2 := docComplexV2.(*spec.Swagger) 606 | 607 | // Load Params V3 spec 608 | tempDirParamsV3 := t.TempDir() 609 | filePathParamsV3 := filepath.Join(tempDirParamsV3, "params_v3.json") 610 | err = os.WriteFile(filePathParamsV3, []byte(paramsV3SpecJSON), 0644) 611 | require.NoError(t, err) 612 | docParamsV3, versionParamsV3, err := LoadSwagger(filePathParamsV3) 613 | require.NoError(t, err) 614 | require.Equal(t, VersionV3, versionParamsV3) 615 | specParamsV3 := docParamsV3.(*openapi3.T) 616 | 617 | // Load Params V2 spec 618 | tempDirParamsV2 := t.TempDir() 619 | filePathParamsV2 := filepath.Join(tempDirParamsV2, "params_v2.json") 620 | err = os.WriteFile(filePathParamsV2, []byte(paramsV2SpecJSON), 0644) 621 | require.NoError(t, err) 622 | docParamsV2, versionParamsV2, err := LoadSwagger(filePathParamsV2) 623 | require.NoError(t, err) 624 | require.Equal(t, VersionV2, versionParamsV2) 625 | specParamsV2 := docParamsV2.(*spec.Swagger) 626 | 627 | // Load Arrays V3 spec 628 | tempDirArraysV3 := t.TempDir() 629 | filePathArraysV3 := filepath.Join(tempDirArraysV3, "arrays_v3.json") 630 | err = os.WriteFile(filePathArraysV3, []byte(arraysV3SpecJSON), 0644) 631 | require.NoError(t, err) 632 | docArraysV3, versionArraysV3, err := LoadSwagger(filePathArraysV3) 633 | require.NoError(t, err) 634 | require.Equal(t, VersionV3, versionArraysV3) 635 | specArraysV3 := docArraysV3.(*openapi3.T) 636 | 637 | // Load Arrays V2 spec 638 | tempDirArraysV2 := t.TempDir() 639 | filePathArraysV2 := filepath.Join(tempDirArraysV2, "arrays_v2.json") 640 | err = os.WriteFile(filePathArraysV2, []byte(arraysV2SpecJSON), 0644) 641 | require.NoError(t, err) 642 | docArraysV2, versionArraysV2, err := LoadSwagger(filePathArraysV2) 643 | require.NoError(t, err) 644 | require.Equal(t, VersionV2, versionArraysV2) 645 | specArraysV2 := docArraysV2.(*spec.Swagger) 646 | 647 | // Load File V2 spec 648 | tempDirFileV2 := t.TempDir() 649 | filePathFileV2 := filepath.Join(tempDirFileV2, "file_v2.json") 650 | err = os.WriteFile(filePathFileV2, []byte(fileV2SpecJSON), 0644) 651 | require.NoError(t, err) 652 | docFileV2, versionFileV2, err := LoadSwagger(filePathFileV2) 653 | require.NoError(t, err) 654 | require.Equal(t, VersionV2, versionFileV2) 655 | specFileV2 := docFileV2.(*spec.Swagger) 656 | 657 | // --- Test Cases --- 658 | tests := []struct { 659 | name string 660 | spec interface{} 661 | version string 662 | cfg *config.Config 663 | expectError bool 664 | expectedToolSet *mcp.ToolSet // Define expected basic structure 665 | }{ 666 | { 667 | name: "V3 Minimal Spec - Default Config", 668 | spec: specV3, 669 | version: VersionV3, 670 | cfg: &config.Config{}, // Default empty config 671 | expectError: false, 672 | expectedToolSet: &mcp.ToolSet{ 673 | Name: "Minimal V3 API", 674 | Description: "", 675 | Tools: []mcp.Tool{ 676 | { 677 | Name: "getPing", 678 | Description: "Note: The API key is handled by the server, no need to provide it. Simple ping endpoint", 679 | InputSchema: mcp.Schema{Type: "object", Properties: map[string]mcp.Schema{}, Required: []string{}}, 680 | }, 681 | }, 682 | Operations: map[string]mcp.OperationDetail{ 683 | "getPing": { 684 | Method: "GET", 685 | Path: "/ping", 686 | BaseURL: "", // No server defined 687 | Parameters: []mcp.ParameterDetail{}, // Expect empty slice 688 | }, 689 | }, 690 | }, 691 | }, 692 | { 693 | name: "V2 Minimal Spec - Default Config", 694 | spec: specV2, 695 | version: VersionV2, 696 | cfg: &config.Config{}, // Default empty config 697 | expectError: false, 698 | expectedToolSet: &mcp.ToolSet{ 699 | Name: "Minimal V2 API", 700 | Description: "", 701 | Tools: []mcp.Tool{ 702 | { 703 | Name: "getHealth", 704 | Description: "Note: The API key is handled by the server, no need to provide it. Simple health check", 705 | InputSchema: mcp.Schema{Type: "object", Properties: map[string]mcp.Schema{}, Required: []string{}}, 706 | }, 707 | }, 708 | Operations: map[string]mcp.OperationDetail{ 709 | "getHealth": { 710 | Method: "GET", 711 | Path: "/health", 712 | BaseURL: "", // No host/schemes/basePath 713 | Parameters: []mcp.ParameterDetail{}, // Expect empty slice 714 | }, 715 | }, 716 | }, 717 | }, 718 | { 719 | name: "V3 Minimal Spec - Config Overrides", 720 | spec: specV3, 721 | version: VersionV3, 722 | cfg: &config.Config{ 723 | ServerBaseURL: "http://override.com/v1", 724 | DefaultToolName: "Override Name", 725 | DefaultToolDesc: "Override Desc", 726 | }, 727 | expectError: false, 728 | expectedToolSet: &mcp.ToolSet{ 729 | Name: "Override Name", // Uses override 730 | Description: "Override Desc", // Uses override 731 | Tools: []mcp.Tool{ 732 | { 733 | Name: "getPing", 734 | Description: "Note: The API key is handled by the server, no need to provide it. Simple ping endpoint", 735 | InputSchema: mcp.Schema{Type: "object", Properties: map[string]mcp.Schema{}, Required: []string{}}, 736 | }, 737 | }, 738 | Operations: map[string]mcp.OperationDetail{ 739 | "getPing": { 740 | Method: "GET", 741 | Path: "/ping", 742 | BaseURL: "http://override.com/v1", // Uses override 743 | Parameters: []mcp.ParameterDetail{}, // Expect empty slice 744 | }, 745 | }, 746 | }, 747 | }, 748 | // --- Filtering Tests (Using Complex Specs) --- 749 | { 750 | name: "V3 Complex - Include Tag1", 751 | spec: specComplexV3, 752 | version: VersionV3, 753 | cfg: &config.Config{IncludeTags: []string{"tag1"}}, 754 | expectError: false, 755 | expectedToolSet: &mcp.ToolSet{ 756 | Name: "Complex V3 API", Description: "", // Should only include listItems and createItem 757 | Tools: []mcp.Tool{{Name: "listItems"}, {Name: "createItem"}}, // Simplified for length check 758 | Operations: map[string]mcp.OperationDetail{"listItems": {}, "createItem": {}}, // Simplified for length check 759 | }, 760 | }, 761 | { 762 | name: "V3 Complex - Exclude Tag2", 763 | spec: specComplexV3, 764 | version: VersionV3, 765 | cfg: &config.Config{ExcludeTags: []string{"tag2"}}, 766 | expectError: false, 767 | expectedToolSet: &mcp.ToolSet{ 768 | Name: "Complex V3 API", Description: "", // Should include listItems and getPing 769 | Tools: []mcp.Tool{{Name: "listItems"}, {Name: "getPing"}}, // Simplified for length check 770 | Operations: map[string]mcp.OperationDetail{"listItems": {}, "getPing": {}}, // Simplified for length check 771 | }, 772 | }, 773 | { 774 | name: "V3 Complex - Include Operation listItems", 775 | spec: specComplexV3, 776 | version: VersionV3, 777 | cfg: &config.Config{IncludeOperations: []string{"listItems"}}, 778 | expectError: false, 779 | expectedToolSet: &mcp.ToolSet{ 780 | Name: "Complex V3 API", Description: "", // Should include only listItems 781 | Tools: []mcp.Tool{{Name: "listItems"}}, // Simplified for length check 782 | Operations: map[string]mcp.OperationDetail{"listItems": {}}, // Simplified for length check 783 | }, 784 | }, 785 | { 786 | name: "V3 Complex - Exclude Operation createItem, getPing", 787 | spec: specComplexV3, 788 | version: VersionV3, 789 | cfg: &config.Config{ExcludeOperations: []string{"createItem", "getPing"}}, 790 | expectError: false, 791 | expectedToolSet: &mcp.ToolSet{ 792 | Name: "Complex V3 API", Description: "", // Should include listItems and listUsers 793 | Tools: []mcp.Tool{{Name: "listItems"}, {Name: "listUsers"}}, // Simplified for length check 794 | Operations: map[string]mcp.OperationDetail{"listItems": {}, "listUsers": {}}, // Simplified for length check 795 | }, 796 | }, 797 | { 798 | name: "V2 Complex - Include Tag1", 799 | spec: specComplexV2, 800 | version: VersionV2, 801 | cfg: &config.Config{IncludeTags: []string{"tag1"}}, 802 | expectError: false, 803 | expectedToolSet: &mcp.ToolSet{ 804 | Name: "Complex V2 API", Description: "", // Should only include listItems and createItem 805 | Tools: []mcp.Tool{{Name: "listItems"}, {Name: "createItem"}}, // Simplified for length check 806 | Operations: map[string]mcp.OperationDetail{"listItems": {}, "createItem": {}}, // Simplified for length check 807 | }, 808 | }, 809 | { 810 | name: "V2 Complex - Exclude Tag2", 811 | spec: specComplexV2, 812 | version: VersionV2, 813 | cfg: &config.Config{ExcludeTags: []string{"tag2"}}, 814 | expectError: false, 815 | expectedToolSet: &mcp.ToolSet{ 816 | Name: "Complex V2 API", Description: "", // Should include listItems and getPing 817 | Tools: []mcp.Tool{{Name: "listItems"}, {Name: "getPing"}}, // Simplified for length check 818 | Operations: map[string]mcp.OperationDetail{"listItems": {}, "getPing": {}}, // Simplified for length check 819 | }, 820 | }, 821 | // --- Parameter/Schema Tests --- 822 | { 823 | name: "V3 Params and Request Body", 824 | spec: specParamsV3, 825 | version: VersionV3, 826 | cfg: &config.Config{}, 827 | expectError: false, 828 | expectedToolSet: &mcp.ToolSet{ 829 | Name: "Params V3 API", 830 | Description: "", // Updated: No description in spec info 831 | Tools: []mcp.Tool{ 832 | { 833 | Name: "testParams", 834 | Description: "Note: The API key is handled by the server, no need to provide it. Test various params", 835 | InputSchema: mcp.Schema{ 836 | Type: "object", 837 | Properties: map[string]mcp.Schema{ 838 | // Parameters merged with Request Body properties 839 | "path_param": {Type: "integer", Format: "int32"}, 840 | "query_param": {Type: "string", Enum: []interface{}{"A", "B"}}, 841 | "optional_query": {Type: "boolean"}, 842 | "X-Header-Param": {Type: "string"}, 843 | "CookieParam": {Type: "number"}, 844 | "id": {Type: "string"}, 845 | "value": {Type: "number"}, 846 | }, 847 | Required: []string{"path_param", "query_param", "X-Header-Param", "id"}, // Order might differ, will sort before assert 848 | }, 849 | }, 850 | }, 851 | Operations: map[string]mcp.OperationDetail{ 852 | "testParams": { 853 | Method: "POST", 854 | Path: "/test/{path_param}", 855 | BaseURL: "", // No server 856 | Parameters: []mcp.ParameterDetail{ 857 | {Name: "path_param", In: "path"}, 858 | {Name: "query_param", In: "query"}, 859 | {Name: "optional_query", In: "query"}, 860 | {Name: "X-Header-Param", In: "header"}, 861 | {Name: "CookieParam", In: "cookie"}, 862 | }, 863 | }, 864 | }, 865 | }, 866 | }, 867 | { 868 | name: "V2 Params and Ref", 869 | spec: specParamsV2, 870 | version: VersionV2, 871 | cfg: &config.Config{}, 872 | expectError: false, 873 | expectedToolSet: &mcp.ToolSet{ 874 | Name: "Params V2 API", 875 | Description: "", // Corrected: No description in spec info 876 | Tools: []mcp.Tool{ 877 | { 878 | Name: "testV2Params", 879 | Description: "Note: The API key is handled by the server, no need to provide it. Test V2 params and ref", 880 | InputSchema: mcp.Schema{ 881 | Type: "object", 882 | Properties: map[string]mcp.Schema{ 883 | // Path, Query, Header params first 884 | "path_id": {Type: "string"}, 885 | "query_flag": {Type: "boolean"}, 886 | "X-Request-ID": {Type: "string"}, 887 | // Body param ($ref to Item) merged 888 | "id": {Type: "string", Format: "uuid"}, 889 | "name": {Type: "string"}, 890 | }, 891 | Required: []string{"path_id", "query_flag", "id"}, // Required params + required definition props 892 | }, 893 | }, 894 | }, 895 | Operations: map[string]mcp.OperationDetail{ 896 | "testV2Params": { 897 | Method: "PUT", 898 | Path: "/test/{path_id}", 899 | BaseURL: "", // No server 900 | Parameters: []mcp.ParameterDetail{ 901 | {Name: "path_id", In: "path"}, 902 | {Name: "query_flag", In: "query"}, 903 | {Name: "X-Request-ID", In: "header"}, 904 | {Name: "body_param", In: "body"}, // Body param listed here 905 | }, 906 | }, 907 | }, 908 | }, 909 | }, 910 | // --- Array Tests --- 911 | { 912 | name: "V3 Arrays", 913 | spec: specArraysV3, 914 | version: VersionV3, 915 | cfg: &config.Config{}, 916 | expectError: false, 917 | expectedToolSet: &mcp.ToolSet{ 918 | Name: "Arrays V3 API", Description: "", 919 | Tools: []mcp.Tool{ 920 | { 921 | Name: "processArrays", 922 | Description: "Note: The API key is handled by the server, no need to provide it. Process arrays", 923 | InputSchema: mcp.Schema{ 924 | Type: "object", 925 | Properties: map[string]mcp.Schema{ 926 | "string_array_query": {Type: "array", Items: &mcp.Schema{Type: "string"}}, 927 | "int_array_body": {Type: "array", Items: &mcp.Schema{Type: "integer", Format: "int64"}}, 928 | }, 929 | Required: []string{}, // No required fields specified 930 | }, 931 | }, 932 | }, 933 | Operations: map[string]mcp.OperationDetail{ 934 | "processArrays": { 935 | Method: "POST", 936 | Path: "/process", 937 | BaseURL: "", 938 | Parameters: []mcp.ParameterDetail{ 939 | {Name: "string_array_query", In: "query"}, 940 | // Body param details are not explicitly listed in V3 op details 941 | }, 942 | }, 943 | }, 944 | }, 945 | }, 946 | { 947 | name: "V2 Arrays", 948 | spec: specArraysV2, 949 | version: VersionV2, 950 | cfg: &config.Config{}, 951 | expectError: false, 952 | expectedToolSet: &mcp.ToolSet{ 953 | Name: "Arrays V2 API", Description: "", 954 | Tools: []mcp.Tool{ 955 | { 956 | Name: "getArrays", 957 | Description: "Note: The API key is handled by the server, no need to provide it. Get arrays", 958 | InputSchema: mcp.Schema{ 959 | Type: "object", 960 | Properties: map[string]mcp.Schema{ 961 | "string_array_query": {Type: "array", Items: &mcp.Schema{Type: "string"}}, 962 | "int_array_form": {Type: "array", Items: &mcp.Schema{Type: "integer", Format: "int32"}}, 963 | }, 964 | Required: []string{}, // No required fields specified 965 | }, 966 | }, 967 | }, 968 | Operations: map[string]mcp.OperationDetail{ 969 | "getArrays": { 970 | Method: "GET", 971 | Path: "/process", 972 | BaseURL: "", 973 | Parameters: []mcp.ParameterDetail{ 974 | {Name: "string_array_query", In: "query"}, 975 | {Name: "int_array_form", In: "formData"}, 976 | }, 977 | }, 978 | }, 979 | }, 980 | }, 981 | { 982 | name: "V2 File Param", 983 | spec: specFileV2, 984 | version: VersionV2, 985 | cfg: &config.Config{}, 986 | expectError: false, 987 | expectedToolSet: &mcp.ToolSet{ 988 | Name: "File V2 API", Description: "", 989 | Tools: []mcp.Tool{ 990 | { 991 | Name: "uploadFile", 992 | Description: "Note: The API key is handled by the server, no need to provide it. Upload file", 993 | InputSchema: mcp.Schema{ 994 | Type: "object", 995 | Properties: map[string]mcp.Schema{ 996 | "description": {Type: "string"}, 997 | "file_upload": {Type: "string"}, // file type maps to string 998 | }, 999 | Required: []string{"file_upload"}, // file_upload is required 1000 | }, 1001 | }, 1002 | }, 1003 | Operations: map[string]mcp.OperationDetail{ 1004 | "uploadFile": { 1005 | Method: "POST", 1006 | Path: "/upload", 1007 | BaseURL: "", 1008 | Parameters: []mcp.ParameterDetail{ 1009 | {Name: "description", In: "formData"}, 1010 | {Name: "file_upload", In: "formData"}, 1011 | }, 1012 | }, 1013 | }, 1014 | }, 1015 | }, 1016 | // TODO: Add V3/V2 tests for refs 1017 | // TODO: Add V3/V2 tests for file types (V2) 1018 | } 1019 | 1020 | for _, tc := range tests { 1021 | t.Run(tc.name, func(t *testing.T) { 1022 | toolSet, err := GenerateToolSet(tc.spec, tc.version, tc.cfg) 1023 | 1024 | if tc.expectError { 1025 | assert.Error(t, err) 1026 | assert.Nil(t, toolSet) 1027 | } else { 1028 | assert.NoError(t, err) 1029 | require.NotNil(t, toolSet) 1030 | 1031 | // Compare basic ToolSet fields 1032 | assert.Equal(t, tc.expectedToolSet.Name, toolSet.Name, "ToolSet Name mismatch") 1033 | assert.Equal(t, tc.expectedToolSet.Description, toolSet.Description, "ToolSet Description mismatch") 1034 | 1035 | // Compare Tool/Operation counts first for filtering tests 1036 | assert.Equal(t, len(tc.expectedToolSet.Tools), len(toolSet.Tools), "Tool count mismatch") 1037 | assert.Equal(t, len(tc.expectedToolSet.Operations), len(toolSet.Operations), "Operation count mismatch") 1038 | 1039 | // If counts match, check specific tool names exist (more robust for filtering tests) 1040 | if len(tc.expectedToolSet.Tools) == len(toolSet.Tools) { 1041 | actualToolNames := make(map[string]bool) 1042 | for _, actualTool := range toolSet.Tools { 1043 | actualToolNames[actualTool.Name] = true 1044 | } 1045 | for _, expectedTool := range tc.expectedToolSet.Tools { 1046 | assert.Contains(t, actualToolNames, expectedTool.Name, "Expected tool %s not found in actual tools", expectedTool.Name) 1047 | } 1048 | } 1049 | 1050 | // If counts match, check specific operation IDs exist (more robust for filtering tests) 1051 | if len(tc.expectedToolSet.Operations) == len(toolSet.Operations) { 1052 | for opID := range tc.expectedToolSet.Operations { 1053 | assert.Contains(t, toolSet.Operations, opID, "Expected operation detail %s not found", opID) 1054 | } 1055 | } 1056 | 1057 | // Full comparison only for non-filtering tests for now (can be expanded) 1058 | if !strings.Contains(tc.name, "Complex") { 1059 | // Compare Tools slice fully 1060 | for i, expectedTool := range tc.expectedToolSet.Tools { 1061 | if i < len(toolSet.Tools) { // Bounds check 1062 | actualTool := toolSet.Tools[i] 1063 | assert.Equal(t, expectedTool.Name, actualTool.Name, "Tool[%d] Name mismatch", i) 1064 | assert.Equal(t, expectedTool.Description, actualTool.Description, "Tool[%d] Description mismatch", i) 1065 | // Sort Required slices before comparing Schemas 1066 | expectedSchema := expectedTool.InputSchema 1067 | actualSchema := actualTool.InputSchema 1068 | sort.Strings(expectedSchema.Required) 1069 | sort.Strings(actualSchema.Required) 1070 | assert.Equal(t, expectedSchema, actualSchema, "Tool[%d] InputSchema mismatch", i) 1071 | } 1072 | } 1073 | // Compare Operations map fully 1074 | for opID, expectedOpDetail := range tc.expectedToolSet.Operations { 1075 | if actualOpDetail, ok := toolSet.Operations[opID]; ok { 1076 | assert.Equal(t, expectedOpDetail.Method, actualOpDetail.Method, "OpDetail %s Method mismatch", opID) 1077 | assert.Equal(t, expectedOpDetail.Path, actualOpDetail.Path, "OpDetail %s Path mismatch", opID) 1078 | assert.Equal(t, expectedOpDetail.BaseURL, actualOpDetail.BaseURL, "OpDetail %s BaseURL mismatch", opID) 1079 | assert.Equal(t, expectedOpDetail.Parameters, actualOpDetail.Parameters, "OpDetail %s Parameters mismatch", opID) 1080 | } 1081 | } 1082 | } 1083 | } 1084 | }) 1085 | } 1086 | } 1087 | -------------------------------------------------------------------------------- /pkg/server/manager.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | "net/http" 7 | "sync" 8 | ) 9 | 10 | // client holds information about a connected SSE client. 11 | type client struct { 12 | writer http.ResponseWriter 13 | flusher http.Flusher 14 | // channel chan []byte // Could be used later for broadcasting updates 15 | } 16 | 17 | // connectionManager manages active client connections. 18 | type connectionManager struct { 19 | clients map[*http.Request]*client // Use request ptr as key 20 | mu sync.RWMutex 21 | toolSet []byte // Pre-encoded toolset JSON 22 | } 23 | 24 | // newConnectionManager creates a manager. 25 | func newConnectionManager(toolSetJSON []byte) *connectionManager { 26 | return &connectionManager{ 27 | clients: make(map[*http.Request]*client), 28 | toolSet: toolSetJSON, 29 | } 30 | } 31 | 32 | // addClient registers a new client and sends the initial toolset. 33 | func (m *connectionManager) addClient(r *http.Request, w http.ResponseWriter, f http.Flusher) { 34 | newClient := &client{writer: w, flusher: f} 35 | m.mu.Lock() 36 | m.clients[r] = newClient 37 | m.mu.Unlock() 38 | 39 | log.Printf("Client connected: %s (Total: %d)", r.RemoteAddr, m.getClientCount()) 40 | 41 | // Send initial toolset immediately 42 | go m.sendToolset(newClient) // Send in a goroutine to avoid blocking registration? 43 | } 44 | 45 | // removeClient removes a client. 46 | func (m *connectionManager) removeClient(r *http.Request) { 47 | m.mu.Lock() 48 | _, ok := m.clients[r] 49 | if ok { 50 | delete(m.clients, r) 51 | log.Printf("Client disconnected: %s (Total: %d)", r.RemoteAddr, len(m.clients)) 52 | } else { 53 | log.Printf("Attempted to remove already disconnected client: %s", r.RemoteAddr) 54 | } 55 | m.mu.Unlock() 56 | } 57 | 58 | // getClientCount returns the number of active clients. 59 | func (m *connectionManager) getClientCount() int { 60 | m.mu.RLock() 61 | count := len(m.clients) 62 | m.mu.RUnlock() 63 | return count 64 | } 65 | 66 | // sendToolset sends the pre-encoded toolset to a specific client. 67 | func (m *connectionManager) sendToolset(c *client) { 68 | if c == nil { 69 | return 70 | } 71 | log.Printf("Attempting to send toolset to client...") 72 | _, err := fmt.Fprintf(c.writer, "event: tool_set\ndata: %s\n\n", string(m.toolSet)) 73 | if err != nil { 74 | // This error often happens if the client disconnected before/during the write 75 | log.Printf("Error sending toolset data to client: %v (client likely disconnected)", err) 76 | // Optionally trigger removal here if possible, though context done in handler is primary mechanism 77 | return 78 | } 79 | // Flush the data 80 | c.flusher.Flush() 81 | log.Println("Sent tool_set event and flushed.") 82 | } 83 | -------------------------------------------------------------------------------- /pkg/server/manager_test.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "net/http/httptest" 7 | "strings" 8 | "testing" 9 | "time" 10 | 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | // mockResponseWriter implements http.ResponseWriter and http.Flusher for testing SSE. 15 | type mockResponseWriter struct { 16 | *httptest.ResponseRecorder // Embed to get ResponseWriter behavior 17 | flushed bool // Track if Flush was called 18 | forceError error // Added for testing error handling 19 | } 20 | 21 | // NewMockResponseWriter creates a new mock response writer. 22 | func NewMockResponseWriter() *mockResponseWriter { 23 | return &mockResponseWriter{ 24 | ResponseRecorder: httptest.NewRecorder(), 25 | } 26 | } 27 | 28 | // Write method for mockResponseWriter (ensure it handles forceError) 29 | func (m *mockResponseWriter) Write(p []byte) (int, error) { 30 | if m.forceError != nil { 31 | return 0, m.forceError 32 | } 33 | return m.ResponseRecorder.Write(p) // Use embedded writer 34 | } 35 | 36 | // Flush method for mockResponseWriter 37 | func (m *mockResponseWriter) Flush() { 38 | if m.forceError != nil { // Don't flush if write failed 39 | return 40 | } 41 | m.flushed = true 42 | // We don't actually flush the embedded recorder in this mock 43 | } 44 | 45 | // --- Simple Mock Flusher --- 46 | type mockFlusher struct { 47 | flushed bool 48 | } 49 | 50 | func (f *mockFlusher) Flush() { 51 | f.flushed = true 52 | } 53 | 54 | // --- End Mock Flusher --- 55 | 56 | func TestManager_Run_Stop(t *testing.T) { 57 | // Basic test to ensure the manager can start and stop. 58 | // More comprehensive tests involving resource handling would be needed. 59 | 60 | // Dummy tool set JSON for initialization 61 | dummyToolSet := []byte(`{"tools": []}`) 62 | 63 | m := newConnectionManager(dummyToolSet) 64 | 65 | // Basic run/stop test - might need refinement depending on Run() implementation 66 | // We need a way to observe if Run() is actually doing something or blocking. 67 | // For now, just test start and stop signals. 68 | stopChan := make(chan struct{}) 69 | go func() { 70 | // Need to figure out what Run expects or does. 71 | // If Run is intended to block, this test structure needs adjustment. 72 | // For now, assume Run might just start background tasks and doesn't block indefinitely. 73 | // If it expects specific input or state, that needs mocking. 74 | // Placeholder: Simulate Run behavior relevant to Stop. 75 | // If Run blocks, this goroutine might hang. 76 | <-stopChan // Simulate Run blocking until Stop is called 77 | }() 78 | 79 | // Simulate adding a client to test remove logic 80 | req := httptest.NewRequest(http.MethodGet, "/events", nil) 81 | mrr := NewMockResponseWriter() // Use the mock 82 | m.addClient(req, mrr, mrr) // Pass the mock which implements both interfaces 83 | if m.getClientCount() != 1 { 84 | t.Errorf("Expected 1 client after add, got %d", m.getClientCount()) 85 | } 86 | 87 | time.Sleep(100 * time.Millisecond) // Give time for potential background tasks 88 | 89 | // Test removing the client 90 | m.removeClient(req) 91 | if m.getClientCount() != 0 { 92 | t.Errorf("Expected 0 clients after remove, got %d", m.getClientCount()) 93 | } 94 | 95 | // Simulate stopping the manager 96 | close(stopChan) // Signal the placeholder Run goroutine to exit 97 | 98 | // Need a way to verify Stop() worked. If it closes internal channels, 99 | // we could potentially check that. Without knowing Stop's implementation, 100 | // this is a basic check. 101 | // Maybe add a dedicated Stop() method to connectionManager if Run blocks? 102 | // Or check internal state if possible. 103 | 104 | // Example: If Stop closes a known channel: 105 | // select { 106 | // case <-m.internalStopChan: // Assuming internalStopChan exists and is closed by Stop() 107 | // // Expected behavior 108 | // case <-time.After(1 * time.Second): 109 | // t.Fatal("Manager did not signal stop within the expected time") 110 | // } 111 | } 112 | 113 | // Define a dummy non-flusher if needed 114 | type nonFlusher struct { 115 | http.ResponseWriter 116 | } 117 | 118 | func (nf *nonFlusher) Flush() { /* Do nothing */ } 119 | 120 | func TestManager_AddRemoveClient(t *testing.T) { 121 | dummyToolSet := []byte(`{"tools": []}`) 122 | m := newConnectionManager(dummyToolSet) 123 | 124 | req1 := httptest.NewRequest(http.MethodGet, "/events?id=1", nil) 125 | mrr1 := NewMockResponseWriter() // Use mock 126 | 127 | req2 := httptest.NewRequest(http.MethodGet, "/events?id=2", nil) 128 | mrr2 := NewMockResponseWriter() // Use mock 129 | 130 | m.addClient(req1, mrr1, mrr1) // Pass mock 131 | if count := m.getClientCount(); count != 1 { 132 | t.Errorf("Expected 1 client, got %d", count) 133 | } 134 | 135 | m.addClient(req2, mrr2, mrr2) // Pass mock 136 | if count := m.getClientCount(); count != 2 { 137 | t.Errorf("Expected 2 clients, got %d", count) 138 | } 139 | 140 | m.removeClient(req1) 141 | if count := m.getClientCount(); count != 1 { 142 | t.Errorf("Expected 1 client after removing req1, got %d", count) 143 | } 144 | // Ensure the correct client was removed 145 | m.mu.RLock() 146 | _, exists := m.clients[req1] 147 | m.mu.RUnlock() 148 | if exists { 149 | t.Error("req1 should have been removed but still exists in map") 150 | } 151 | 152 | m.removeClient(req2) 153 | if count := m.getClientCount(); count != 0 { 154 | t.Errorf("Expected 0 clients after removing req2, got %d", count) 155 | } 156 | 157 | // Test removing non-existent client 158 | m.removeClient(req1) // Remove again 159 | if count := m.getClientCount(); count != 0 { 160 | t.Errorf("Expected 0 clients after removing non-existent, got %d", count) 161 | } 162 | } 163 | 164 | // Test for sendToolset needs a way to capture output sent to the client. 165 | // httptest.ResponseRecorder can capture the body. 166 | func TestManager_SendToolset(t *testing.T) { 167 | toolSetData := `{"tools": ["tool1", "tool2"]}` 168 | m := newConnectionManager([]byte(toolSetData)) 169 | 170 | mrr := NewMockResponseWriter() // Use mock 171 | 172 | // Directly create a client struct instance for testing sendToolset specifically 173 | // Note: This bypasses addClient logic for focused testing of sendToolset. 174 | testClient := &client{writer: mrr, flusher: mrr} // Use mock for both 175 | 176 | m.sendToolset(testClient) 177 | 178 | // Use strings.TrimSpace for comparison to avoid issues with subtle whitespace differences 179 | // Escape inner quotes 180 | expectedOutputPattern := "event: tool_set\ndata: {\"tools\": [\"tool1\", \"tool2\"]}\n\n" 181 | actualOutput := mrr.Body.String() 182 | 183 | if strings.TrimSpace(actualOutput) != strings.TrimSpace(expectedOutputPattern) { 184 | // Use %q to quote strings, making whitespace visible 185 | t.Errorf("Expected toolset output matching pattern %q, got %q", expectedOutputPattern, actualOutput) 186 | } 187 | if !mrr.flushed { // Check if flush was called 188 | t.Error("Expected Flush() to be called on the writer, but it wasn't") 189 | } 190 | 191 | // Test sending to nil client 192 | m.sendToolset(nil) // Should not panic 193 | } 194 | 195 | // Test case for when writing the toolset fails (e.g., client disconnected) 196 | func TestConnectionManager_SendToolset_WriteError(t *testing.T) { 197 | mgr := newConnectionManager([]byte(`{"tool":"set"}`)) 198 | 199 | // Create a mock writer that always returns an error 200 | mockWriter := &mockResponseWriter{ 201 | ResponseRecorder: httptest.NewRecorder(), // Initialize embedded recorder 202 | forceError: fmt.Errorf("simulated write error"), 203 | } 204 | mockFlusher := &mockFlusher{} 205 | 206 | // Create a client with the erroring writer 207 | mockClient := &client{ 208 | writer: mockWriter, 209 | flusher: mockFlusher, 210 | } 211 | 212 | // Call sendToolset - we expect it to log the error and return early 213 | // We don't easily assert the log, but we run it for coverage. 214 | mgr.sendToolset(mockClient) 215 | 216 | // Assert that Flush was NOT called because the function should have returned early 217 | assert.False(t, mockFlusher.flushed, "Flush should not be called when Write fails") 218 | // Assert that Write was attempted (optional, depends on mock capabilities) 219 | // If mockResponseWriter tracks calls, assert Write was called once. 220 | } 221 | -------------------------------------------------------------------------------- /pkg/server/server.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "encoding/json" 7 | "fmt" 8 | "io" 9 | "log" 10 | "net/http" 11 | "net/url" 12 | "strings" 13 | "sync" 14 | "time" 15 | 16 | // "fmt" // No longer needed here 17 | // "sync" // No longer needed here 18 | 19 | "github.com/ckanthony/openapi-mcp/pkg/config" 20 | "github.com/ckanthony/openapi-mcp/pkg/mcp" 21 | "github.com/google/uuid" // Import UUID package 22 | ) 23 | 24 | // --- JSON-RPC Structures (Re-introduced for Handshake/Messages) --- 25 | 26 | type jsonRPCRequest struct { 27 | Jsonrpc string `json:"jsonrpc"` 28 | Method string `json:"method"` 29 | Params interface{} `json:"params,omitempty"` 30 | ID interface{} `json:"id,omitempty"` // Can be string, number, or null 31 | } 32 | 33 | type jsonRPCResponse struct { 34 | Jsonrpc string `json:"jsonrpc"` 35 | Result interface{} `json:"result,omitempty"` 36 | Error *jsonError `json:"error,omitempty"` 37 | ID interface{} `json:"id"` // ID should match the request ID 38 | } 39 | 40 | type jsonError struct { 41 | Code int `json:"code"` 42 | Message string `json:"message"` 43 | Data interface{} `json:"data,omitempty"` 44 | } 45 | 46 | // --- MCP Message Structures (Kept for clarity on expected payloads) --- 47 | 48 | // MCPMessage represents a generic message exchanged over the transport. 49 | // Note: Adapt this structure based on the exact MCP spec requirements if needed. 50 | // This structure is now more for understanding the *payloads* within JSON-RPC. 51 | type MCPMessage struct { 52 | Type string `json:"type"` // e.g., "initialize", "tools/list", "tools/call", "tool_result", "error" 53 | ID string `json:"id,omitempty"` // Unique message ID (less relevant for JSON-RPC wrapper) 54 | Payload json.RawMessage `json:"payload,omitempty"` // Content specific to the message type 55 | ConnID string `json:"connectionId,omitempty"` // Included in responses related to a connection 56 | } 57 | 58 | // MCPError defines a structured error for MCP responses. 59 | // This will be used within the 'Error.Data' field of a jsonRPCResponse. 60 | type MCPError struct { 61 | Code int `json:"code,omitempty"` // Optional error code 62 | Message string `json:"message"` 63 | Data interface{} `json:"data,omitempty"` // Optional additional data 64 | } 65 | 66 | // ToolCallParams represents the expected payload for a tools/call request. 67 | // This will be the structure within the 'params' field of a jsonRPCRequest. 68 | type ToolCallParams struct { 69 | ToolName string `json:"name"` // Aligning with gin-mcp JSON-RPC 'name' 70 | Input map[string]interface{} `json:"arguments"` // Aligning with gin-mcp JSON-RPC 'arguments' 71 | } 72 | 73 | // ToolResultContent represents an item in the 'content' array of a tool_result. 74 | type ToolResultContent struct { 75 | Type string `json:"type"` 76 | Text string `json:"text"` // Assuming text/JSON string result 77 | // Add other content types if needed 78 | } 79 | 80 | // ToolResultPayload represents the structure for the 'result' of a 'tool_result' JSON-RPC response. 81 | type ToolResultPayload struct { 82 | Content []ToolResultContent `json:"content"` // Array of content items 83 | IsError bool `json:"isError"` // Aligning with gin-mcp 84 | Error *MCPError `json:"error,omitempty"` // Detailed error info if IsError is true 85 | ToolCallID string `json:"tool_call_id,omitempty"` // Optional: Can be helpful 86 | } 87 | 88 | // --- Server State --- 89 | 90 | // activeConnections stores channels for sending messages back to active SSE clients. 91 | var activeConnections = make(map[string]chan jsonRPCResponse) // Changed value type 92 | var connMutex sync.RWMutex 93 | 94 | // Channel buffer size 95 | const messageChannelBufferSize = 10 96 | 97 | // --- Server Implementation --- 98 | 99 | // ServeMCP starts an HTTP server handling MCP communication. 100 | func ServeMCP(addr string, toolSet *mcp.ToolSet, cfg *config.Config) error { 101 | log.Printf("Preparing ToolSet for MCP...") 102 | 103 | // --- Handler Functions --- 104 | mcpHandler := func(w http.ResponseWriter, r *http.Request) { 105 | // CORS Headers (Apply to all relevant requests) 106 | w.Header().Set("Access-Control-Allow-Origin", "*") // Be more specific in production 107 | w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS") 108 | w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With, X-Connection-ID") 109 | w.Header().Set("Access-Control-Expose-Headers", "X-Connection-ID") 110 | 111 | if r.Method == http.MethodOptions { 112 | log.Println("Responding to OPTIONS request") 113 | w.WriteHeader(http.StatusNoContent) // Use 204 No Content for OPTIONS 114 | return 115 | } 116 | 117 | if r.Method == http.MethodGet { 118 | httpMethodGetHandler(w, r) // Handle SSE connection setup 119 | } else if r.Method == http.MethodPost { 120 | httpMethodPostHandler(w, r, toolSet, cfg) // Pass the cfg object here 121 | } else { 122 | log.Printf("Method Not Allowed: %s", r.Method) 123 | http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed) 124 | } 125 | } 126 | 127 | // Setup server mux 128 | mux := http.NewServeMux() 129 | mux.HandleFunc("/mcp", mcpHandler) // Single endpoint for GET/POST/OPTIONS 130 | 131 | log.Printf("MCP server listening on %s/mcp", addr) 132 | return http.ListenAndServe(addr, mux) 133 | } 134 | 135 | // httpMethodGetHandler handles the initial GET request to establish the SSE connection. 136 | func httpMethodGetHandler(w http.ResponseWriter, r *http.Request) { 137 | connectionID := uuid.New().String() 138 | log.Printf("SSE client connecting: %s (Assigning ID: %s)", r.RemoteAddr, connectionID) 139 | 140 | flusher, ok := w.(http.Flusher) 141 | if !ok { 142 | http.Error(w, "Streaming unsupported!", http.StatusInternalServerError) 143 | log.Println("Error: Client connection does not support flushing") 144 | return 145 | } 146 | 147 | // --- Set headers FIRST --- 148 | w.Header().Set("Content-Type", "text/event-stream") 149 | w.Header().Set("Cache-Control", "no-cache") 150 | w.Header().Set("Connection", "keep-alive") 151 | // CORS headers are set in the main handler 152 | w.Header().Set("X-Connection-ID", connectionID) 153 | w.Header().Set("X-Accel-Buffering", "no") // Useful for proxies like Nginx 154 | w.WriteHeader(http.StatusOK) // Write headers and status code 155 | flusher.Flush() // Ensure headers are sent immediately 156 | 157 | // --- Send initial :ok --- (Must happen *after* headers) 158 | if _, err := fmt.Fprintf(w, ":ok\n\n"); err != nil { 159 | log.Printf("Error sending SSE preamble to %s (ID: %s): %v", r.RemoteAddr, connectionID, err) 160 | return // Cannot proceed if preamble fails 161 | } 162 | flusher.Flush() 163 | log.Printf("Sent :ok preamble to %s (ID: %s)", r.RemoteAddr, connectionID) 164 | 165 | // --- Send initial SSE events --- (endpoint, mcp-ready) 166 | endpointURL := fmt.Sprintf("/mcp?sessionId=%s", connectionID) // Assuming /mcp is the mount path 167 | if err := writeSSEEvent(w, "endpoint", endpointURL); err != nil { 168 | log.Printf("Error sending SSE endpoint event to %s (ID: %s): %v", r.RemoteAddr, connectionID, err) 169 | return 170 | } 171 | flusher.Flush() 172 | log.Printf("Sent endpoint event to %s (ID: %s)", r.RemoteAddr, connectionID) 173 | 174 | readyMsg := jsonRPCRequest{ // Use request struct for notification format 175 | Jsonrpc: "2.0", 176 | Method: "mcp-ready", 177 | Params: map[string]interface{}{ // Put data in params 178 | "connectionId": connectionID, 179 | "status": "connected", 180 | "protocol": "2.0", 181 | }, 182 | } 183 | if err := writeSSEEvent(w, "message", readyMsg); err != nil { 184 | log.Printf("Error sending SSE mcp-ready event to %s (ID: %s): %v", r.RemoteAddr, connectionID, err) 185 | return 186 | } 187 | flusher.Flush() 188 | log.Printf("Sent mcp-ready event to %s (ID: %s)", r.RemoteAddr, connectionID) 189 | 190 | // --- Setup message channel and store connection --- 191 | msgChan := make(chan jsonRPCResponse, messageChannelBufferSize) // Channel for responses 192 | connMutex.Lock() 193 | activeConnections[connectionID] = msgChan 194 | connMutex.Unlock() 195 | log.Printf("Registered channel for connection %s. Active connections: %d", connectionID, len(activeConnections)) 196 | 197 | // --- Cleanup function --- 198 | cleanup := func() { 199 | connMutex.Lock() 200 | delete(activeConnections, connectionID) 201 | connMutex.Unlock() 202 | close(msgChan) // Close channel when connection ends 203 | log.Printf("Removed connection %s. Active connections: %d", connectionID, len(activeConnections)) 204 | } 205 | defer cleanup() 206 | 207 | // --- Goroutine to write messages from channel to SSE stream --- 208 | ctx, cancel := context.WithCancel(r.Context()) 209 | defer cancel() 210 | 211 | go func() { 212 | log.Printf("[SSE Writer %s] Starting message writer goroutine", connectionID) 213 | defer log.Printf("[SSE Writer %s] Exiting message writer goroutine", connectionID) 214 | for { 215 | select { 216 | case <-ctx.Done(): 217 | return // Exit if main context is cancelled 218 | case resp, ok := <-msgChan: 219 | if !ok { 220 | log.Printf("[SSE Writer %s] Message channel closed.", connectionID) 221 | return // Exit if channel is closed 222 | } 223 | log.Printf("[SSE Writer %s] Sending message (ID: %v) via SSE", connectionID, resp.ID) 224 | if err := writeSSEEvent(w, "message", resp); err != nil { 225 | log.Printf("[SSE Writer %s] Error writing message to SSE stream: %v. Cancelling context.", connectionID, err) 226 | cancel() // Signal main loop to exit on write error 227 | return 228 | } 229 | flusher.Flush() // Flush after writing message 230 | } 231 | } 232 | }() 233 | 234 | // --- Keep connection alive (main loop) --- 235 | keepAliveTicker := time.NewTicker(20 * time.Second) 236 | defer keepAliveTicker.Stop() 237 | 238 | log.Printf("[SSE %s] Entering keep-alive loop", connectionID) 239 | for { 240 | select { 241 | case <-ctx.Done(): 242 | log.Printf("[SSE %s] Context done. Exiting keep-alive loop.", connectionID) 243 | return // Exit loop if context cancelled (client disconnect or write error) 244 | case <-keepAliveTicker.C: 245 | // Send JSON-RPC ping notification instead of SSE comment 246 | pingMsg := jsonRPCRequest{ // Use request struct for notification format 247 | Jsonrpc: "2.0", 248 | Method: "ping", 249 | Params: map[string]interface{}{ // Include timestamp like gin-mcp 250 | "timestamp": time.Now().Unix(), 251 | }, 252 | } 253 | if err := writeSSEEvent(w, "message", pingMsg); err != nil { 254 | log.Printf("[SSE %s] Error sending ping notification: %v. Closing connection.", connectionID, err) 255 | cancel() // Signal writer goroutine and exit 256 | return 257 | } 258 | flusher.Flush() 259 | } 260 | } 261 | } 262 | 263 | // writeSSEEvent formats and writes data as a Server-Sent Event. 264 | func writeSSEEvent(w http.ResponseWriter, eventName string, data interface{}) error { 265 | buffer := bytes.Buffer{} 266 | if eventName != "" { 267 | buffer.WriteString(fmt.Sprintf("event: %s\n", eventName)) 268 | } 269 | 270 | // Marshal data to JSON if it's not a simple string already 271 | var dataStr string 272 | if strData, ok := data.(string); ok && eventName == "endpoint" { // Special case for endpoint URL 273 | dataStr = strData 274 | } else { 275 | jsonData, err := json.Marshal(data) 276 | if err != nil { 277 | return fmt.Errorf("failed to marshal data for SSE event '%s': %w", eventName, err) 278 | } 279 | dataStr = string(jsonData) 280 | } 281 | 282 | // Write data line(s). Split multiline JSON for proper SSE formatting. 283 | lines := strings.Split(dataStr, "\n") 284 | for _, line := range lines { 285 | buffer.WriteString(fmt.Sprintf("data: %s\n", line)) 286 | } 287 | 288 | // Add final newline 289 | buffer.WriteString("\n") 290 | 291 | // Write to the response writer 292 | _, err := w.Write(buffer.Bytes()) 293 | if err != nil { 294 | return fmt.Errorf("failed to write SSE event '%s': %w", eventName, err) 295 | } 296 | return nil 297 | } 298 | 299 | // httpMethodPostHandler handles incoming POST requests containing MCP messages. 300 | func httpMethodPostHandler(w http.ResponseWriter, r *http.Request, toolSet *mcp.ToolSet, cfg *config.Config) { 301 | // --- Original Logic (Restored) --- 302 | connID := r.Header.Get("X-Connection-ID") // Try header first 303 | if connID == "" { 304 | connID = r.URL.Query().Get("sessionId") // Fallback to query parameter 305 | log.Printf("X-Connection-ID header missing, checking sessionId query param: found='%s'", connID) 306 | } 307 | 308 | if connID == "" { 309 | log.Println("Error: POST request received without X-Connection-ID header or sessionId query parameter") 310 | http.Error(w, "Missing X-Connection-ID header or sessionId query parameter", http.StatusBadRequest) 311 | return 312 | } 313 | 314 | // Find the corresponding message channel for this connection 315 | connMutex.RLock() 316 | msgChan, isActive := activeConnections[connID] 317 | connMutex.RUnlock() 318 | 319 | if !isActive { 320 | log.Printf("Error: POST request received for inactive/unknown connection ID: %s", connID) 321 | // Still send sync error here, as we don't have a channel 322 | tryWriteHTTPError(w, http.StatusNotFound, "Invalid or expired connection ID") 323 | return 324 | } 325 | 326 | bodyBytes, err := io.ReadAll(r.Body) 327 | if err != nil { 328 | log.Printf("Error reading POST request body for %s: %v", connID, err) 329 | // Create error response in the ToolResultPayload format 330 | errPayload := ToolResultPayload{ 331 | IsError: true, 332 | Error: &MCPError{ 333 | Code: -32700, // JSON-RPC Parse Error Code 334 | Message: "Parse error reading request body", 335 | }, 336 | // ToolCallID doesn't really apply here, maybe use connID or leave empty? 337 | // ToolCallID: connID, 338 | } 339 | errResp := jsonRPCResponse{ 340 | Jsonrpc: "2.0", 341 | ID: nil, // ID is unknown if we can't read the body 342 | Result: errPayload, 343 | Error: nil, // Ensure top-level error is nil 344 | } 345 | // Attempt to send via SSE channel 346 | select { 347 | case msgChan <- errResp: 348 | log.Printf("Queued read error response (ID: %v) for %s onto SSE channel (as Result)", errResp.ID, connID) 349 | // Send HTTP 202 Accepted back to the POST request 350 | w.WriteHeader(http.StatusAccepted) 351 | fmt.Fprintln(w, "Request accepted (with parse error), response will be sent via SSE.") 352 | default: 353 | log.Printf("Error: Failed to queue read error response (ID: %v) for %s - SSE channel likely full or closed.", errResp.ID, connID) 354 | // Send an error back on the POST request if channel fails 355 | tryWriteHTTPError(w, http.StatusInternalServerError, "Failed to queue error response for SSE channel") 356 | } 357 | return // Stop processing 358 | } 359 | // No defer r.Body.Close() needed here as io.ReadAll reads to EOF 360 | 361 | log.Printf("Received POST data for %s: %s", connID, string(bodyBytes)) 362 | 363 | // Attempt to unmarshal into a temporary map first to extract ID if possible 364 | var rawReq map[string]interface{} 365 | var reqID interface{} // Keep track of ID even if full unmarshal fails 366 | 367 | // Try unmarshalling into raw map 368 | if err := json.Unmarshal(bodyBytes, &rawReq); err == nil { 369 | // Ensure reqID is treated as a string or number if possible, handle potential null 370 | if idVal, idExists := rawReq["id"]; idExists && idVal != nil { 371 | reqID = idVal 372 | } else { 373 | reqID = nil // Explicitly set to nil if missing or JSON null 374 | } 375 | } else { 376 | // Full unmarshal failed, log it but continue to try specific struct 377 | log.Printf("Warning: Initial unmarshal into map failed for %s: %v. Will attempt specific struct unmarshal.", connID, err) 378 | reqID = nil // ID is unknown 379 | } 380 | 381 | var req jsonRPCRequest // Expect JSON-RPC request 382 | if err := json.Unmarshal(bodyBytes, &req); err != nil { 383 | log.Printf("Error decoding JSON-RPC request for %s: %v", connID, err) 384 | // Use createJSONRPCError to correctly format the error response 385 | errResp := createJSONRPCError(reqID, -32700, "Parse error decoding JSON request", err.Error()) 386 | 387 | // Attempt to send via SSE channel 388 | select { 389 | case msgChan <- errResp: 390 | log.Printf("Queued decode error response (ID: %v) for %s onto SSE channel", errResp.ID, connID) 391 | // Send HTTP 202 Accepted back to the POST request 392 | w.WriteHeader(http.StatusAccepted) 393 | // Use a specific message for decode errors 394 | fmt.Fprintln(w, "Request accepted (with decode error), response will be sent via SSE.") 395 | default: 396 | log.Printf("Error: Failed to queue decode error response (ID: %v) for %s - SSE channel likely full or closed.", errResp.ID, connID) 397 | // Send an error back on the POST request if channel fails 398 | tryWriteHTTPError(w, http.StatusInternalServerError, "Failed to queue error response for SSE channel") 399 | } 400 | return // Stop processing 401 | } 402 | 403 | // If we successfully unmarshalled 'req', ensure reqID matches req.ID 404 | if req.ID != nil { 405 | reqID = req.ID 406 | } else { 407 | reqID = nil 408 | } 409 | 410 | // --- Variable to hold the final response to be sent via SSE --- 411 | var respToSend jsonRPCResponse 412 | 413 | // --- Validate JSON-RPC Request --- 414 | if req.Jsonrpc != "2.0" { 415 | log.Printf("Invalid JSON-RPC version ('%s') for %s, ID: %v", req.Jsonrpc, connID, reqID) 416 | respToSend = createJSONRPCError(reqID, -32600, "Invalid Request: jsonrpc field must be \"2.0\"", nil) 417 | } else if req.Method == "" { 418 | log.Printf("Missing JSON-RPC method for %s, ID: %v", connID, reqID) 419 | respToSend = createJSONRPCError(reqID, -32600, "Invalid Request: method field is missing or empty", nil) 420 | } else { 421 | // --- Process the valid request --- 422 | log.Printf("Processing JSON-RPC message for %s: Method=%s, ID=%v", connID, req.Method, reqID) 423 | switch req.Method { 424 | case "initialize": 425 | incomingInitializeJSON, _ := json.Marshal(req) 426 | log.Printf("DEBUG: Handling 'initialize' for %s. Incoming request: %s", connID, string(incomingInitializeJSON)) 427 | respToSend = handleInitializeJSONRPC(connID, &req) 428 | outgoingInitializeJSON, _ := json.Marshal(respToSend) 429 | log.Printf("DEBUG: Prepared 'initialize' response for %s. Outgoing response: %s", connID, string(outgoingInitializeJSON)) 430 | case "notifications/initialized": 431 | log.Printf("Received 'notifications/initialized' notification for %s. Ignoring.", connID) 432 | w.WriteHeader(http.StatusAccepted) 433 | fmt.Fprintln(w, "Notification received.") 434 | return // Return early, do not send anything on SSE channel 435 | case "tools/list": 436 | respToSend = handleToolsListJSONRPC(connID, &req, toolSet) 437 | case "tools/call": 438 | respToSend = handleToolCallJSONRPC(connID, &req, toolSet, cfg) 439 | default: 440 | log.Printf("Received unknown JSON-RPC method '%s' for %s", req.Method, connID) 441 | respToSend = createJSONRPCError(reqID, -32601, fmt.Sprintf("Method not found: %s", req.Method), nil) 442 | } 443 | } 444 | 445 | // --- Send response ASYNCHRONOUSLY via SSE channel (unless handled earlier) --- 446 | select { 447 | case msgChan <- respToSend: 448 | log.Printf("Queued response (ID: %v) for %s onto SSE channel", respToSend.ID, connID) 449 | // Send HTTP 202 Accepted back to the POST request 450 | w.WriteHeader(http.StatusAccepted) 451 | // Use the standard message for successfully queued responses 452 | fmt.Fprintln(w, "Request accepted, response will be sent via SSE.") 453 | default: 454 | log.Printf("Error: Failed to queue response (ID: %v) for %s - SSE channel likely full or closed.", respToSend.ID, connID) 455 | http.Error(w, "Failed to queue response for SSE channel", http.StatusInternalServerError) 456 | } 457 | } 458 | 459 | // --- JSON-RPC Message Handlers --- // Implementations returning jsonRPCResponse 460 | 461 | func handleInitializeJSONRPC(connID string, req *jsonRPCRequest) jsonRPCResponse { 462 | log.Printf("Handling 'initialize' (JSON-RPC) for %s", connID) 463 | 464 | // Construct the result payload based on gin-mcp's structure using map[string]interface{} 465 | resultPayload := map[string]interface{}{ 466 | "protocolVersion": "2024-11-05", // Aligning with gin-mcp 467 | "capabilities": map[string]interface{}{ 468 | "tools": map[string]interface{}{ 469 | "enabled": true, 470 | "config": map[string]interface{}{ 471 | "listChanged": false, 472 | }, 473 | }, 474 | "prompts": map[string]interface{}{ 475 | "enabled": false, 476 | }, 477 | "resources": map[string]interface{}{ 478 | "enabled": true, 479 | }, 480 | "logging": map[string]interface{}{ 481 | "enabled": false, 482 | }, 483 | "roots": map[string]interface{}{ 484 | "listChanged": false, 485 | }, 486 | }, 487 | "serverInfo": map[string]interface{}{ 488 | "name": "OpenAPI-MCP", // Or use config name if available 489 | "version": "openapi-mcp-0.1.0", // Your server version 490 | "apiVersion": "2024-11-05", // MCP API version 491 | }, 492 | "connectionId": connID, // Include the connection ID 493 | } 494 | 495 | return jsonRPCResponse{ 496 | Jsonrpc: "2.0", 497 | ID: req.ID, // Match request ID 498 | Result: resultPayload, 499 | } 500 | } 501 | 502 | func handleToolsListJSONRPC(connID string, req *jsonRPCRequest, toolSet *mcp.ToolSet) jsonRPCResponse { 503 | log.Printf("Handling 'tools/list' (JSON-RPC) for %s", connID) 504 | 505 | // Construct the result payload based on gin-mcp's structure 506 | resultPayload := map[string]interface{}{ 507 | "tools": toolSet.Tools, 508 | "metadata": map[string]interface{}{ 509 | "version": "2024-11-05", // Align with gin-mcp if possible 510 | "count": len(toolSet.Tools), 511 | }, 512 | } 513 | 514 | return jsonRPCResponse{ 515 | Jsonrpc: "2.0", 516 | ID: req.ID, // Match request ID 517 | Result: resultPayload, 518 | } 519 | } 520 | 521 | // executeToolCall performs the actual HTTP request based on the resolved operation and parameters. 522 | // It now correctly handles API key injection based on the *cfg* parameter. 523 | func executeToolCall(params *ToolCallParams, toolSet *mcp.ToolSet, cfg *config.Config) (*http.Response, error) { 524 | toolName := params.ToolName 525 | toolInput := params.Input // This is the map[string]interface{} from the client 526 | 527 | log.Printf("[ExecuteToolCall] Looking up details for tool: %s", toolName) 528 | operation, ok := toolSet.Operations[toolName] 529 | if !ok { 530 | log.Printf("[ExecuteToolCall] Error: Operation details not found for tool '%s'", toolName) 531 | return nil, fmt.Errorf("operation details for tool '%s' not found", toolName) 532 | } 533 | log.Printf("[ExecuteToolCall] Found operation: Method=%s, Path=%s", operation.Method, operation.Path) 534 | 535 | // --- Resolve API Key (using cfg passed from main) --- 536 | resolvedKey := cfg.GetAPIKey() 537 | apiKeyName := cfg.APIKeyName 538 | apiKeyLocation := cfg.APIKeyLocation 539 | hasServerKey := resolvedKey != "" && apiKeyName != "" && apiKeyLocation != "" 540 | 541 | log.Printf("[ExecuteToolCall] API Key Details: Name='%s', In='%s', HasServerValue=%t", apiKeyName, apiKeyLocation, resolvedKey != "") 542 | 543 | // --- Prepare Request Components --- 544 | baseURL := operation.BaseURL // Use BaseURL from the specific operation 545 | if cfg.ServerBaseURL != "" { 546 | baseURL = cfg.ServerBaseURL // Override if global base URL is set 547 | log.Printf("[ExecuteToolCall] Overriding base URL with global config: %s", baseURL) 548 | } 549 | if baseURL == "" { 550 | log.Printf("[ExecuteToolCall] Warning: No base URL found for operation %s and no global override set.", toolName) 551 | // For now, assume relative if empty. 552 | } 553 | 554 | path := operation.Path 555 | queryParams := url.Values{} 556 | pathParams := make(map[string]string) 557 | headerParams := make(http.Header) // For headers to add 558 | cookieParams := []*http.Cookie{} // For cookies to add 559 | bodyData := make(map[string]interface{}) // For building the request body 560 | requestBodyRequired := operation.Method == "POST" || operation.Method == "PUT" || operation.Method == "PATCH" 561 | 562 | // Create a map of expected parameters from the operation details for easier lookup 563 | expectedParams := make(map[string]string) // Map param name to its location ('in') 564 | for _, p := range operation.Parameters { 565 | expectedParams[p.Name] = p.In 566 | } 567 | 568 | // --- Process Input Parameters (Separating and Handling API Key Override) --- 569 | log.Printf("[ExecuteToolCall] Processing %d input parameters...", len(toolInput)) 570 | for key, value := range toolInput { 571 | // --- API Key Override Check --- 572 | // If this input param is the API key AND we have a valid server key config, 573 | // skip processing the client's value entirely. 574 | if hasServerKey && key == apiKeyName { 575 | log.Printf("[ExecuteToolCall] Skipping client-provided param '%s' due to server API key override.", key) 576 | continue 577 | } 578 | // --- End API Key Override --- 579 | 580 | paramLocation, knownParam := expectedParams[key] 581 | pathPlaceholder := "{" + key + "}" // OpenAPI uses {param} 582 | 583 | if strings.Contains(path, pathPlaceholder) { 584 | // Handle path parameter substitution 585 | pathParams[key] = fmt.Sprintf("%v", value) 586 | log.Printf("[ExecuteToolCall] Found path parameter %s=%v", key, value) 587 | } else if knownParam { 588 | // Handle parameters defined in the spec (query, header, cookie) 589 | switch paramLocation { 590 | case "query": 591 | queryParams.Add(key, fmt.Sprintf("%v", value)) 592 | log.Printf("[ExecuteToolCall] Found query parameter %s=%v (from spec)", key, value) 593 | case "header": 594 | headerParams.Add(key, fmt.Sprintf("%v", value)) 595 | log.Printf("[ExecuteToolCall] Found header parameter %s=%v (from spec)", key, value) 596 | case "cookie": 597 | cookieParams = append(cookieParams, &http.Cookie{Name: key, Value: fmt.Sprintf("%v", value)}) 598 | log.Printf("[ExecuteToolCall] Found cookie parameter %s=%v (from spec)", key, value) 599 | // case "formData": // TODO: Handle form data if needed 600 | // bodyData[key] = value // Or handle differently based on content type 601 | // log.Printf("[ExecuteToolCall] Found formData parameter %s=%v (from spec)", key, value) 602 | default: 603 | // Known parameter but location handling is missing or mismatched. 604 | if paramLocation == "path" && (operation.Method == "GET" || operation.Method == "DELETE") { 605 | // If spec says 'path' but it wasn't in the actual path, and it's a GET/DELETE, 606 | // treat it as a query parameter as a fallback. 607 | log.Printf("[ExecuteToolCall] Warning: Parameter '%s' is 'path' in spec but not in URL path '%s'. Adding to query parameters as fallback for GET/DELETE.", key, operation.Path) 608 | queryParams.Add(key, fmt.Sprintf("%v", value)) 609 | } else { 610 | // Otherwise, log the warning and ignore. 611 | log.Printf("[ExecuteToolCall] Warning: Parameter '%s' has unsupported or unhandled location '%s' in spec. Ignoring.", key, paramLocation) 612 | } 613 | } 614 | } else if requestBodyRequired { 615 | // If parameter is not in path or defined in spec params, and method expects a body, 616 | // assume it belongs in the request body. 617 | bodyData[key] = value 618 | log.Printf("[ExecuteToolCall] Added body parameter %s=%v (assumed)", key, value) 619 | } else { 620 | // Parameter not in path, not in spec, and not a body method. 621 | // This could be an extraneous parameter like 'explanation'. Log it. 622 | log.Printf("[ExecuteToolCall] Ignoring parameter '%s' as it doesn't match path or known parameter location for method %s.", key, operation.Method) 623 | } 624 | } 625 | 626 | // --- Substitute Path Parameters --- 627 | for key, value := range pathParams { 628 | path = strings.Replace(path, "{"+key+"}", value, -1) 629 | } 630 | 631 | // --- Inject Server API Key (if applicable) --- 632 | if hasServerKey { 633 | log.Printf("[ExecuteToolCall] Injecting server API key (Name: %s, Location: %s)", apiKeyName, string(apiKeyLocation)) 634 | switch apiKeyLocation { 635 | case config.APIKeyLocationQuery: 636 | queryParams.Set(apiKeyName, resolvedKey) // Set overrides any previous value 637 | log.Printf("[ExecuteToolCall] Injected API key '%s' into query parameters", apiKeyName) 638 | case config.APIKeyLocationHeader: 639 | headerParams.Set(apiKeyName, resolvedKey) // Set overrides any previous value 640 | log.Printf("[ExecuteToolCall] Injected API key '%s' into headers", apiKeyName) 641 | case config.APIKeyLocationPath: 642 | pathPlaceholder := "{" + apiKeyName + "}" 643 | if strings.Contains(path, pathPlaceholder) { 644 | path = strings.Replace(path, pathPlaceholder, resolvedKey, -1) 645 | log.Printf("[ExecuteToolCall] Injected API key into path parameter '%s'", apiKeyName) 646 | } else { 647 | log.Printf("[ExecuteToolCall] Warning: API key location is 'path' but placeholder '%s' not found in final path '%s' for injection.", pathPlaceholder, path) 648 | } 649 | case config.APIKeyLocationCookie: 650 | // Check if cookie already exists from input, replace if so 651 | foundCookie := false 652 | for i, c := range cookieParams { 653 | if c.Name == apiKeyName { 654 | log.Printf("[ExecuteToolCall] Replacing existing cookie '%s' with injected API key.", apiKeyName) 655 | cookieParams[i] = &http.Cookie{Name: apiKeyName, Value: resolvedKey} // Replace existing 656 | foundCookie = true 657 | break 658 | } 659 | } 660 | if !foundCookie { 661 | log.Printf("[ExecuteToolCall] Adding new cookie '%s' with injected API key.", apiKeyName) 662 | cookieParams = append(cookieParams, &http.Cookie{Name: apiKeyName, Value: resolvedKey}) // Append new 663 | } 664 | default: 665 | // Use log.Printf for consistency 666 | log.Printf("Warning: Unsupported API key location specified in config: '%s'", apiKeyLocation) 667 | } 668 | } else { 669 | log.Printf("[ExecuteToolCall] Skipping server API key injection (config incomplete or key unresolved).") 670 | } 671 | 672 | // --- Final URL Construction --- 673 | // Reconstruct query string *after* potential API key injection 674 | targetURL := baseURL + path 675 | if len(queryParams) > 0 { 676 | targetURL += "?" + queryParams.Encode() 677 | } 678 | log.Printf("[ExecuteToolCall] Final Target URL: %s %s", operation.Method, targetURL) 679 | 680 | // --- Prepare Request Body --- 681 | var reqBody io.Reader 682 | var bodyBytes []byte // Keep for logging 683 | if requestBodyRequired && len(bodyData) > 0 { 684 | var err error 685 | bodyBytes, err = json.Marshal(bodyData) 686 | if err != nil { 687 | log.Printf("[ExecuteToolCall] Error marshalling request body: %v", err) 688 | return nil, fmt.Errorf("error marshalling request body: %w", err) 689 | } 690 | reqBody = bytes.NewBuffer(bodyBytes) 691 | log.Printf("[ExecuteToolCall] Request body: %s", string(bodyBytes)) 692 | } 693 | 694 | // --- Create HTTP Request --- 695 | req, err := http.NewRequest(operation.Method, targetURL, reqBody) 696 | if err != nil { 697 | log.Printf("[ExecuteToolCall] Error creating HTTP request: %v", err) 698 | return nil, fmt.Errorf("error creating request: %w", err) 699 | } 700 | 701 | // --- Set Headers --- 702 | // Default headers 703 | req.Header.Set("Accept", "application/json") // Assume JSON response typical for APIs 704 | if reqBody != nil { 705 | req.Header.Set("Content-Type", "application/json") // Assume JSON body if body exists 706 | } 707 | 708 | // Add headers collected from input/spec AND potentially injected API key 709 | for key, values := range headerParams { 710 | // Note: We use Set, assuming single value per header from input typically. 711 | // If multi-value headers are needed from spec/input, use Add. 712 | if len(values) > 0 { 713 | req.Header.Set(key, values[0]) 714 | } 715 | } 716 | 717 | // Add custom headers from config (comma-separated) 718 | if cfg.CustomHeaders != "" { 719 | headers := strings.Split(cfg.CustomHeaders, ",") 720 | for _, h := range headers { 721 | parts := strings.SplitN(h, ":", 2) 722 | if len(parts) == 2 { 723 | headerName := strings.TrimSpace(parts[0]) 724 | headerValue := strings.TrimSpace(parts[1]) 725 | if headerName != "" { 726 | req.Header.Set(headerName, headerValue) // Set overrides potential input 727 | log.Printf("[ExecuteToolCall] Added custom header from config: %s", headerName) 728 | } 729 | } 730 | } 731 | } 732 | 733 | // --- Add Cookies --- 734 | for _, cookie := range cookieParams { 735 | req.AddCookie(cookie) 736 | } 737 | 738 | log.Printf("[ExecuteToolCall] Sending request with headers: %v", req.Header) 739 | if len(req.Cookies()) > 0 { 740 | log.Printf("[ExecuteToolCall] Sending request with cookies: %+v", req.Cookies()) 741 | } 742 | 743 | // --- Execute HTTP Request --- 744 | log.Printf("[ExecuteToolCall] Sending request with headers: %v", req.Header) 745 | client := &http.Client{Timeout: 30 * time.Second} 746 | resp, err := client.Do(req) 747 | if err != nil { 748 | log.Printf("[ExecuteToolCall] Error executing HTTP request: %v", err) 749 | return nil, fmt.Errorf("error executing request: %w", err) 750 | } 751 | 752 | log.Printf("[ExecuteToolCall] Request executed. Status Code: %d", resp.StatusCode) 753 | // Note: Don't close resp.Body here, the caller (handleToolCallJSONRPC) needs it. 754 | return resp, nil 755 | } 756 | 757 | func handleToolCallJSONRPC(connID string, req *jsonRPCRequest, toolSet *mcp.ToolSet, cfg *config.Config) jsonRPCResponse { 758 | // req.Params is interface{}, but should contain json.RawMessage for tools/call 759 | rawParams, ok := req.Params.(json.RawMessage) 760 | if !ok { 761 | // If it's not RawMessage, maybe it was already decoded to a map? Handle that case too. 762 | if paramsMap, mapOk := req.Params.(map[string]interface{}); mapOk { 763 | // Attempt to marshal the map back to JSON bytes 764 | var marshalErr error 765 | rawParams, marshalErr = json.Marshal(paramsMap) 766 | if marshalErr != nil { 767 | log.Printf("Error marshalling params map for %s: %v", connID, marshalErr) 768 | return createJSONRPCError(req.ID, -32602, "Invalid parameters format (map marshal failed)", marshalErr.Error()) 769 | } 770 | log.Printf("Handling 'tools/call' (JSON-RPC) for %s, Params: %s (from map)", connID, string(rawParams)) 771 | } else { 772 | log.Printf("Invalid parameters format for tools/call (not json.RawMessage or map[string]interface{}): %T", req.Params) 773 | return createJSONRPCError(req.ID, -32602, "Invalid parameters format (expected JSON object)", nil) 774 | } 775 | } else { 776 | log.Printf("Handling 'tools/call' (JSON-RPC) for %s, Params: %s (from RawMessage)", connID, string(rawParams)) 777 | } 778 | 779 | // Now, unmarshal the rawParams ([]byte) into ToolCallParams 780 | var params ToolCallParams 781 | if err := json.Unmarshal(rawParams, ¶ms); err != nil { 782 | log.Printf("Error unmarshalling tools/call params for %s: %v", connID, err) 783 | return createJSONRPCError(req.ID, -32602, "Invalid parameters structure (unmarshal)", err.Error()) 784 | } 785 | 786 | log.Printf("Executing tool '%s' for %s with input: %+v", params.ToolName, connID, params.Input) 787 | 788 | // --- Execute the actual tool call --- 789 | httpResp, execErr := executeToolCall(¶ms, toolSet, cfg) 790 | 791 | // --- Process Response --- 792 | var resultPayload ToolResultPayload 793 | if execErr != nil { 794 | log.Printf("Error executing tool call '%s': %v", params.ToolName, execErr) 795 | resultPayload = ToolResultPayload{ 796 | IsError: true, 797 | Error: &MCPError{ 798 | Message: fmt.Sprintf("Failed to execute tool '%s': %v", params.ToolName, execErr), 799 | }, 800 | ToolCallID: fmt.Sprintf("%v", req.ID), 801 | } 802 | } else { 803 | defer httpResp.Body.Close() // Ensure body is closed 804 | bodyBytes, readErr := io.ReadAll(httpResp.Body) 805 | if readErr != nil { 806 | log.Printf("Error reading response body for tool '%s': %v", params.ToolName, readErr) 807 | resultPayload = ToolResultPayload{ 808 | IsError: true, 809 | Error: &MCPError{ 810 | Message: fmt.Sprintf("Failed to read response from tool '%s': %v", params.ToolName, readErr), 811 | }, 812 | ToolCallID: fmt.Sprintf("%v", req.ID), 813 | } 814 | } else { 815 | log.Printf("Received response body for tool '%s': %s", params.ToolName, string(bodyBytes)) 816 | // Check status code for API-level errors 817 | if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { 818 | resultPayload = ToolResultPayload{ 819 | IsError: true, 820 | Error: &MCPError{ 821 | Code: httpResp.StatusCode, 822 | Message: fmt.Sprintf("Tool '%s' API call failed with status %s", params.ToolName, httpResp.Status), 823 | Data: string(bodyBytes), // Include response body in error data 824 | }, 825 | ToolCallID: fmt.Sprintf("%v", req.ID), 826 | } 827 | } else { 828 | // Successful execution 829 | resultContent := []ToolResultContent{ 830 | { 831 | Type: "text", // TODO: Handle JSON responses properly if Content-Type indicates it 832 | Text: string(bodyBytes), 833 | }, 834 | } 835 | resultPayload = ToolResultPayload{ 836 | Content: resultContent, 837 | IsError: false, 838 | ToolCallID: fmt.Sprintf("%v", req.ID), 839 | } 840 | } 841 | } 842 | } 843 | 844 | // --- Send Response --- 845 | return jsonRPCResponse{ 846 | Jsonrpc: "2.0", 847 | ID: req.ID, // Match request ID 848 | Result: resultPayload, // Use the actual result payload 849 | } 850 | } 851 | 852 | // --- Helper Functions (Updated for JSON-RPC) --- 853 | 854 | // sendJSONRPCResponse sends a JSON-RPC response *synchronously*. 855 | // Keep this for now for sending synchronous errors on POST decode/read failures. 856 | func sendJSONRPCResponse(w http.ResponseWriter, resp jsonRPCResponse) { 857 | w.Header().Set("Content-Type", "application/json") 858 | if err := json.NewEncoder(w).Encode(resp); err != nil { 859 | log.Printf("Error encoding JSON-RPC response (ID: %v) for ConnID %v: %v", resp.ID, resp.Error, err) 860 | // Attempt to send a plain text error if JSON encoding fails 861 | tryWriteHTTPError(w, http.StatusInternalServerError, "Internal Server Error encoding JSON-RPC response") 862 | } 863 | log.Printf("Sent JSON-RPC response: Method=%s, ID=%v", getMethodFromResponse(resp), resp.ID) 864 | } 865 | 866 | // createJSONRPCError creates a JSON-RPC error response. 867 | func createJSONRPCError(id interface{}, code int, message string, data interface{}) jsonRPCResponse { 868 | jsonErr := &jsonError{Code: code, Message: message, Data: data} 869 | return jsonRPCResponse{ 870 | Jsonrpc: "2.0", 871 | ID: id, // Error response should echo the request ID 872 | Error: jsonErr, 873 | } 874 | } 875 | 876 | // sendJSONRPCError sends a JSON-RPC error response. 877 | func sendJSONRPCError(w http.ResponseWriter, connID string, id interface{}, code int, message string, data interface{}) { 878 | resp := createJSONRPCError(id, code, message, data) 879 | log.Printf("Sending JSON-RPC Error for ConnID %s, ID %v: Code=%d, Message='%s'", connID, id, code, message) 880 | sendJSONRPCResponse(w, resp) 881 | } 882 | 883 | // Helper to get the method name for logging purposes (from the result/error structure if possible) 884 | func getMethodFromResponse(resp jsonRPCResponse) string { 885 | if resp.Result != nil { 886 | // Attempt to infer method from result structure if it has a type field 887 | if resMap, ok := resp.Result.(map[string]interface{}); ok { 888 | if methodType, typeOk := resMap["type"].(string); typeOk { 889 | return methodType + "_result" 890 | } 891 | } 892 | // Infer based on known result types if possible 893 | if _, ok := resp.Result.(map[string]interface{}); ok && resp.Result.(map[string]interface{})["tools"] != nil { 894 | return "tool_set" 895 | } 896 | // If not easily identifiable, just indicate success 897 | return "success" 898 | } else if resp.Error != nil { 899 | return "error" 900 | } 901 | return "unknown" 902 | } 903 | 904 | // tryWriteHTTPError attempts to write an HTTP error, ignoring failures. 905 | func tryWriteHTTPError(w http.ResponseWriter, code int, message string) { 906 | if _, err := w.Write([]byte(message)); err != nil { 907 | log.Printf("Error writing plain HTTP error response: %v", err) 908 | } 909 | log.Printf("Sent plain HTTP error: %s (Code: %d)", message, code) 910 | } 911 | -------------------------------------------------------------------------------- /pkg/server/server_test.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "encoding/json" 7 | "fmt" 8 | "io" 9 | "log" 10 | "net/http" 11 | "net/http/httptest" 12 | "strings" 13 | "sync" 14 | "testing" 15 | "time" 16 | 17 | "github.com/ckanthony/openapi-mcp/pkg/config" 18 | "github.com/ckanthony/openapi-mcp/pkg/mcp" 19 | "github.com/google/uuid" 20 | "github.com/stretchr/testify/assert" 21 | "github.com/stretchr/testify/require" 22 | ) 23 | 24 | // --- Re-added Helper Functions --- 25 | 26 | // Helper function to create a simple ToolSet for testing tool calls 27 | func createTestToolSetForCall() *mcp.ToolSet { 28 | return &mcp.ToolSet{ 29 | Name: "Call Test API", 30 | Tools: []mcp.Tool{ 31 | { 32 | Name: "get_user", 33 | Description: "Get user details", 34 | InputSchema: mcp.Schema{ 35 | Type: "object", 36 | Properties: map[string]mcp.Schema{ 37 | "user_id": {Type: "string"}, 38 | }, 39 | Required: []string{"user_id"}, 40 | }, 41 | }, 42 | { 43 | Name: "post_data", 44 | Description: "Post some data", 45 | InputSchema: mcp.Schema{ 46 | Type: "object", 47 | Properties: map[string]mcp.Schema{ 48 | "data": {Type: "string"}, 49 | }, 50 | Required: []string{"data"}, 51 | }, 52 | }, 53 | }, 54 | Operations: map[string]mcp.OperationDetail{ 55 | "get_user": { 56 | Method: "GET", 57 | Path: "/users/{user_id}", 58 | Parameters: []mcp.ParameterDetail{ 59 | {Name: "user_id", In: "path"}, 60 | }, 61 | }, 62 | "post_data": { 63 | Method: "POST", 64 | Path: "/data", 65 | Parameters: []mcp.ParameterDetail{}, // Body params assumed 66 | }, 67 | }, 68 | } 69 | } 70 | 71 | // Helper to safely manage activeConnections for tests 72 | func setupTestConnection(connID string) chan jsonRPCResponse { 73 | msgChan := make(chan jsonRPCResponse, 1) // Buffer of 1 sufficient for most tests 74 | connMutex.Lock() 75 | activeConnections[connID] = msgChan 76 | connMutex.Unlock() 77 | return msgChan 78 | } 79 | 80 | func cleanupTestConnection(connID string) { 81 | connMutex.Lock() 82 | msgChan, exists := activeConnections[connID] 83 | if exists { 84 | delete(activeConnections, connID) 85 | close(msgChan) 86 | } 87 | connMutex.Unlock() 88 | } 89 | 90 | // --- End Re-added Helper Functions --- 91 | 92 | func TestHttpMethodPostHandler(t *testing.T) { 93 | // --- Setup common test items --- 94 | toolSet := createTestToolSetForCall() // Use the helper 95 | cfg := &config.Config{} // Basic config 96 | // NOTE: connID is now generated within each subtest to ensure isolation 97 | 98 | // --- Define Test Cases --- 99 | tests := []struct { 100 | name string 101 | requestBodyFn func(connID string) string // Function to generate body with dynamic connID 102 | expectedSyncStatus int // Expected status code for the immediate POST response 103 | expectedSyncBody string // Expected body for the immediate POST response 104 | checkAsyncResponse func(t *testing.T, resp jsonRPCResponse) // Function to check async response 105 | mockBackend http.HandlerFunc // Optional mock backend for tool calls 106 | setupChannelDirectly func(connID string) chan jsonRPCResponse // Optional: For specific channel setups 107 | }{ 108 | { 109 | name: "Valid Initialize Request", 110 | requestBodyFn: func(connID string) string { 111 | return fmt.Sprintf(`{ 112 | "jsonrpc": "2.0", 113 | "method": "initialize", 114 | "id": "init-post-1", 115 | "params": {"connectionId": "%s"} 116 | }`, connID) 117 | }, 118 | expectedSyncStatus: http.StatusAccepted, 119 | expectedSyncBody: "Request accepted, response will be sent via SSE.\n", 120 | checkAsyncResponse: func(t *testing.T, resp jsonRPCResponse) { 121 | assert.Equal(t, "init-post-1", resp.ID) 122 | assert.Nil(t, resp.Error) 123 | resultMap, ok := resp.Result.(map[string]interface{}) 124 | require.True(t, ok) 125 | assert.Contains(t, resultMap, "connectionId") // Check existence, actual ID checked separately 126 | assert.Equal(t, "2024-11-05", resultMap["protocolVersion"]) 127 | }, 128 | }, 129 | { 130 | name: "Valid Tools List Request", 131 | requestBodyFn: func(connID string) string { 132 | return `{ 133 | "jsonrpc": "2.0", 134 | "method": "tools/list", 135 | "id": "list-post-1" 136 | }` 137 | }, 138 | expectedSyncStatus: http.StatusAccepted, 139 | expectedSyncBody: "Request accepted, response will be sent via SSE.\n", 140 | checkAsyncResponse: func(t *testing.T, resp jsonRPCResponse) { 141 | assert.Equal(t, "list-post-1", resp.ID) 142 | assert.Nil(t, resp.Error) 143 | resultMap, ok := resp.Result.(map[string]interface{}) 144 | require.True(t, ok) 145 | assert.Contains(t, resultMap, "metadata") 146 | assert.Contains(t, resultMap, "tools") 147 | metadata, _ := resultMap["metadata"].(map[string]interface{}) 148 | assert.Equal(t, 2, metadata["count"]) // Corrected: Expect int(2) 149 | }, 150 | }, 151 | { 152 | name: "Valid Tool Call Request (Success)", 153 | requestBodyFn: func(connID string) string { 154 | return `{ 155 | "jsonrpc": "2.0", 156 | "method": "tools/call", 157 | "id": "call-post-1", 158 | "params": {"name": "get_user", "arguments": {"user_id": "postUser"}} 159 | }` 160 | }, 161 | expectedSyncStatus: http.StatusAccepted, 162 | expectedSyncBody: "Request accepted, response will be sent via SSE.\n", 163 | checkAsyncResponse: func(t *testing.T, resp jsonRPCResponse) { 164 | assert.Equal(t, "call-post-1", resp.ID) 165 | assert.Nil(t, resp.Error) 166 | resultPayload, ok := resp.Result.(ToolResultPayload) 167 | require.True(t, ok) 168 | assert.False(t, resultPayload.IsError) 169 | require.Len(t, resultPayload.Content, 1) 170 | assert.JSONEq(t, `{"id":"postUser"}`, resultPayload.Content[0].Text) 171 | }, 172 | mockBackend: func(w http.ResponseWriter, r *http.Request) { 173 | w.WriteHeader(http.StatusOK) 174 | fmt.Fprintln(w, `{"id":"postUser"}`) 175 | }, 176 | }, 177 | { 178 | name: "Valid Tool Call Request (Tool Not Found)", 179 | requestBodyFn: func(connID string) string { 180 | return `{ 181 | "jsonrpc": "2.0", 182 | "method": "tools/call", 183 | "id": "call-post-err-1", 184 | "params": {"name": "nonexistent_tool", "arguments": {}} 185 | }` 186 | }, 187 | expectedSyncStatus: http.StatusAccepted, 188 | expectedSyncBody: "Request accepted, response will be sent via SSE.\n", 189 | checkAsyncResponse: func(t *testing.T, resp jsonRPCResponse) { 190 | assert.Equal(t, "call-post-err-1", resp.ID) 191 | assert.Nil(t, resp.Error) 192 | resultPayload, ok := resp.Result.(ToolResultPayload) 193 | require.True(t, ok) 194 | assert.True(t, resultPayload.IsError) 195 | require.NotNil(t, resultPayload.Error) 196 | assert.Contains(t, resultPayload.Error.Message, "operation details for tool 'nonexistent_tool' not found") 197 | }, 198 | }, 199 | { 200 | name: "Malformed JSON Request", 201 | requestBodyFn: func(connID string) string { 202 | return `{"jsonrpc": "2.0", "method": "initialize"` 203 | }, 204 | expectedSyncStatus: http.StatusAccepted, // Even decode errors return 202, error is sent async 205 | expectedSyncBody: "Request accepted (with decode error), response will be sent via SSE.\n", 206 | checkAsyncResponse: func(t *testing.T, resp jsonRPCResponse) { 207 | assert.Nil(t, resp.ID) // ID might be nil if request parsing failed early 208 | require.NotNil(t, resp.Error) 209 | assert.Equal(t, -32700, resp.Error.Code) // Parse Error 210 | assert.Equal(t, "Parse error decoding JSON request", resp.Error.Message) // Corrected assertion 211 | }, 212 | }, 213 | { 214 | name: "Missing JSON-RPC Version", 215 | requestBodyFn: func(connID string) string { 216 | return `{ 217 | "method": "initialize", 218 | "id": "rpc-err-1" 219 | }` 220 | }, 221 | expectedSyncStatus: http.StatusAccepted, 222 | expectedSyncBody: "Request accepted, response will be sent via SSE.\n", 223 | checkAsyncResponse: func(t *testing.T, resp jsonRPCResponse) { 224 | assert.Equal(t, "rpc-err-1", resp.ID) 225 | require.NotNil(t, resp.Error) 226 | assert.Equal(t, -32600, resp.Error.Code) // Invalid Request 227 | assert.Contains(t, resp.Error.Message, "jsonrpc field must be \"2.0\"") 228 | }, 229 | }, 230 | { 231 | name: "Unknown Method", 232 | requestBodyFn: func(connID string) string { 233 | return `{ 234 | "jsonrpc": "2.0", 235 | "method": "unknown/method", 236 | "id": "rpc-err-2" 237 | }` 238 | }, 239 | expectedSyncStatus: http.StatusAccepted, 240 | expectedSyncBody: "Request accepted, response will be sent via SSE.\n", 241 | checkAsyncResponse: func(t *testing.T, resp jsonRPCResponse) { 242 | assert.Equal(t, "rpc-err-2", resp.ID) 243 | require.NotNil(t, resp.Error) 244 | assert.Equal(t, -32601, resp.Error.Code) // Method not found 245 | assert.Contains(t, resp.Error.Message, "Method not found") 246 | }, 247 | }, 248 | { 249 | name: "Missing Method", 250 | requestBodyFn: func(connID string) string { 251 | return `{ 252 | "jsonrpc": "2.0", 253 | "id": "rpc-err-3" 254 | }` 255 | }, 256 | expectedSyncStatus: http.StatusAccepted, 257 | expectedSyncBody: "Request accepted, response will be sent via SSE.\n", 258 | checkAsyncResponse: func(t *testing.T, resp jsonRPCResponse) { 259 | assert.Equal(t, "rpc-err-3", resp.ID) 260 | require.NotNil(t, resp.Error) 261 | assert.Equal(t, -32600, resp.Error.Code) // Invalid Request 262 | assert.Equal(t, "Invalid Request: method field is missing or empty", resp.Error.Message) // Corrected assertion 263 | }, 264 | }, 265 | { 266 | name: "Error Queuing Response To SSE", 267 | requestBodyFn: func(connID string) string { // Use a simple valid request like tools/list 268 | return `{ 269 | "jsonrpc": "2.0", 270 | "method": "tools/list", 271 | "id": "list-post-err-queue" 272 | }` 273 | }, 274 | expectedSyncStatus: http.StatusInternalServerError, // Expect 500 when channel is blocked 275 | expectedSyncBody: "Failed to queue response for SSE channel\n", // Specific error message expected 276 | setupChannelDirectly: func(connID string) chan jsonRPCResponse { 277 | // Create a NON-BUFFERED channel to simulate blocking/full channel 278 | msgChan := make(chan jsonRPCResponse) // No buffer size! 279 | connMutex.Lock() 280 | activeConnections[connID] = msgChan 281 | connMutex.Unlock() 282 | // Important: Do NOT start a reader for this channel 283 | return msgChan 284 | }, 285 | checkAsyncResponse: nil, // No async response should be successfully sent 286 | }, 287 | } 288 | 289 | // --- Run Test Cases --- 290 | for _, tc := range tests { 291 | t.Run(tc.name, func(t *testing.T) { 292 | connID := uuid.NewString() // Generate unique connID for each subtest 293 | 294 | // Setup mock backend if needed for this test case 295 | var backendServer *httptest.Server 296 | // --- Add Connection ID before test --- 297 | var msgChan chan jsonRPCResponse 298 | if tc.setupChannelDirectly != nil { 299 | // Use custom setup if provided (e.g., for blocking channel test) 300 | msgChan = tc.setupChannelDirectly(connID) 301 | } else { 302 | // Default setup using the helper with buffered channel 303 | msgChan = setupTestConnection(connID) 304 | } 305 | defer cleanupTestConnection(connID) // Ensure cleanup after test 306 | 307 | if tc.mockBackend != nil { 308 | backendServer = httptest.NewServer(tc.mockBackend) 309 | defer backendServer.Close() 310 | // IMPORTANT: Update the toolset's BaseURL for the relevant operation 311 | if strings.Contains(tc.requestBodyFn(connID), "get_user") { // Simple check based on request 312 | op := toolSet.Operations["get_user"] 313 | op.BaseURL = backendServer.URL 314 | toolSet.Operations["get_user"] = op 315 | } 316 | // Update post_data BaseURL if needed 317 | if strings.Contains(tc.requestBodyFn(connID), "post_data") { 318 | op := toolSet.Operations["post_data"] 319 | op.BaseURL = backendServer.URL 320 | toolSet.Operations["post_data"] = op 321 | } 322 | } 323 | 324 | reqBody := tc.requestBodyFn(connID) // Generate request body 325 | req := httptest.NewRequest(http.MethodPost, "/mcp", strings.NewReader(reqBody)) 326 | req.Header.Set("Content-Type", "application/json") 327 | req.Header.Set("X-Connection-ID", connID) // Use the generated connID 328 | rr := httptest.NewRecorder() 329 | 330 | httpMethodPostHandler(rr, req, toolSet, cfg) 331 | 332 | // 1. Check synchronous response 333 | assert.Equal(t, tc.expectedSyncStatus, rr.Code, "Unexpected status code for sync response") 334 | // Trim space for comparison as http.Error might add a newline our literal doesn't have 335 | assert.Equal(t, strings.TrimSpace(tc.expectedSyncBody), strings.TrimSpace(rr.Body.String()), "Unexpected body for sync response") 336 | 337 | // 2. Check asynchronous response (sent via SSE channel) 338 | if tc.checkAsyncResponse != nil { 339 | select { 340 | case asyncResp := <-msgChan: 341 | tc.checkAsyncResponse(t, asyncResp) 342 | case <-time.After(100 * time.Millisecond): // Add a timeout 343 | t.Fatal("Timeout waiting for async response on SSE channel") 344 | } 345 | } else { 346 | // If no async check is defined, ensure nothing was sent (e.g., for queue error test) 347 | select { 348 | case unexpectedResp, ok := <-msgChan: 349 | if ok { // Only fail if the channel wasn't closed AND we got a message 350 | t.Errorf("Received unexpected async response when none was expected: %+v", unexpectedResp) 351 | } 352 | // If !ok, channel was closed, which is fine/expected after cleanup 353 | case <-time.After(50 * time.Millisecond): 354 | // Success - no message received quickly, channel likely blocked as expected 355 | } 356 | } 357 | }) 358 | } 359 | } 360 | 361 | func TestHttpMethodGetHandler(t *testing.T) { 362 | // --- Setup --- 363 | // Reset global state for this test 364 | connMutex.Lock() 365 | originalConnections := activeConnections 366 | activeConnections = make(map[string]chan jsonRPCResponse) 367 | connMutex.Unlock() 368 | 369 | req, err := http.NewRequest("GET", "/mcp", nil) 370 | require.NoError(t, err, "Failed to create request") 371 | 372 | rr := httptest.NewRecorder() 373 | 374 | // Ensure cleanup happens regardless of test outcome 375 | defer func() { 376 | connMutex.Lock() 377 | // Clean up any connections potentially left by the test 378 | for id, ch := range activeConnections { 379 | close(ch) 380 | delete(activeConnections, id) 381 | log.Printf("[DEFER Cleanup] Closed channel and removed connection %s", id) 382 | } 383 | activeConnections = originalConnections // Restore the original map 384 | connMutex.Unlock() 385 | }() 386 | 387 | // --- Execute Handler (in a goroutine as it blocks waiting for context) --- 388 | ctx, cancel := context.WithCancel(context.Background()) 389 | req = req.WithContext(ctx) 390 | 391 | hwg := sync.WaitGroup{} 392 | hwg.Add(1) 393 | go func() { 394 | defer hwg.Done() 395 | // Simulate some work before handler returns 396 | // In a real scenario, this would block on ctx.Done() or keepAliveTicker 397 | // For the test, we just call cancel() after a short delay 398 | // to simulate the connection ending gracefully. 399 | time.AfterFunc(100*time.Millisecond, cancel) // Allow handler to start and write initial data 400 | httpMethodGetHandler(rr, req) 401 | }() 402 | 403 | // Wait for the handler goroutine to finish. 404 | // This ensures all writes to rr are complete before we read. 405 | if !waitTimeout(&hwg, 2*time.Second) { // Use a reasonable timeout 406 | t.Fatal("Handler goroutine did not exit cleanly after context cancellation") 407 | } 408 | 409 | // --- Assertions (Performed *after* handler completion) --- 410 | assert.Equal(t, http.StatusOK, rr.Code, "Status code should be OK") 411 | 412 | // Check headers are set correctly 413 | assert.Equal(t, "text/event-stream", rr.Header().Get("Content-Type")) 414 | assert.Equal(t, "no-cache", rr.Header().Get("Cache-Control")) 415 | assert.Equal(t, "keep-alive", rr.Header().Get("Connection")) 416 | connID := rr.Header().Get("X-Connection-ID") 417 | assert.NotEmpty(t, connID, "X-Connection-ID header should be set") 418 | 419 | // Check connection was registered and then cleaned up 420 | connMutex.RLock() 421 | _, exists := originalConnections[connID] // Check original map after cleanup 422 | connMutex.RUnlock() 423 | assert.False(t, exists, "Connection ID should be removed from map after handler exits") 424 | 425 | // Check initial body content is present 426 | bodyContent := rr.Body.String() 427 | assert.Contains(t, bodyContent, ":ok\n\n", "Body should contain :ok preamble") 428 | // Construct the expected endpoint data string accurately 429 | expectedEndpointData := "data: /mcp?sessionId=" + connID + "\n\n" 430 | assert.Contains(t, bodyContent, "event: endpoint\n"+expectedEndpointData, "Body should contain endpoint event") 431 | assert.Contains(t, bodyContent, "event: message\ndata: {", "Body should contain start of a message event (e.g., mcp-ready)") 432 | // Check if connectionId is present in the ready message (adjust based on actual JSON structure) 433 | assert.Contains(t, bodyContent, `"connectionId":"`+connID+`"`, "Body should contain mcp-ready event with correct connection ID") 434 | 435 | // The explicit cleanupTestConnection call is not needed because the handler's defer and the test's defer handle it. 436 | } 437 | 438 | func TestExecuteToolCall(t *testing.T) { 439 | tests := []struct { 440 | name string 441 | params ToolCallParams 442 | opDetail mcp.OperationDetail 443 | cfg *config.Config 444 | expectError bool 445 | containsError string 446 | requestAsserter func(t *testing.T, r *http.Request) // Function to assert details of the received HTTP request 447 | backendResponse string // Response body from mock backend 448 | backendStatusCode int // Status code from mock backend 449 | }{ 450 | // --- Basic GET with Path Param --- 451 | { 452 | name: "GET with path parameter", 453 | params: ToolCallParams{ 454 | ToolName: "get_item", 455 | Input: map[string]interface{}{"item_id": "item123"}, 456 | }, 457 | opDetail: mcp.OperationDetail{ 458 | Method: "GET", 459 | Path: "/items/{item_id}", 460 | Parameters: []mcp.ParameterDetail{{Name: "item_id", In: "path"}}, 461 | }, 462 | cfg: &config.Config{}, 463 | expectError: false, 464 | backendStatusCode: http.StatusOK, 465 | backendResponse: `{"status":"ok"}`, 466 | requestAsserter: func(t *testing.T, r *http.Request) { 467 | assert.Equal(t, http.MethodGet, r.Method) 468 | assert.Equal(t, "/items/item123", r.URL.Path) 469 | assert.Empty(t, r.URL.RawQuery) 470 | }, 471 | }, 472 | // --- POST with Query, Header, Cookie, and Body Params --- 473 | { 474 | name: "POST with various params", 475 | params: ToolCallParams{ 476 | ToolName: "create_resource", 477 | Input: map[string]interface{}{ 478 | "queryArg": "value1", 479 | "X-Custom-Hdr": "headerValue", 480 | "sessionToken": "cookieValue", 481 | "bodyFieldA": "A", 482 | "bodyFieldB": 123, 483 | }, 484 | }, 485 | opDetail: mcp.OperationDetail{ 486 | Method: "POST", 487 | Path: "/resources", 488 | Parameters: []mcp.ParameterDetail{ 489 | {Name: "queryArg", In: "query"}, 490 | {Name: "X-Custom-Hdr", In: "header"}, 491 | {Name: "sessionToken", In: "cookie"}, 492 | // Body fields are implicitly handled 493 | }, 494 | }, 495 | cfg: &config.Config{}, 496 | expectError: false, 497 | backendStatusCode: http.StatusCreated, 498 | backendResponse: `{"id":"res456"}`, 499 | requestAsserter: func(t *testing.T, r *http.Request) { 500 | assert.Equal(t, http.MethodPost, r.Method) 501 | assert.Equal(t, "/resources", r.URL.Path) 502 | assert.Equal(t, "value1", r.URL.Query().Get("queryArg")) 503 | assert.Equal(t, "headerValue", r.Header.Get("X-Custom-Hdr")) 504 | cookie, err := r.Cookie("sessionToken") 505 | require.NoError(t, err) 506 | assert.Equal(t, "cookieValue", cookie.Value) 507 | bodyBytes, _ := io.ReadAll(r.Body) 508 | assert.JSONEq(t, `{"bodyFieldA":"A", "bodyFieldB":123}`, string(bodyBytes)) 509 | }, 510 | }, 511 | // --- API Key Injection (Header) --- 512 | { 513 | name: "API Key Injection (Header)", 514 | params: ToolCallParams{ 515 | ToolName: "get_secure", 516 | Input: map[string]interface{}{}, // No client key provided 517 | }, 518 | opDetail: mcp.OperationDetail{Method: "GET", Path: "/secure"}, 519 | cfg: &config.Config{ 520 | APIKey: "secret-server-key", 521 | APIKeyName: "Authorization", 522 | APIKeyLocation: config.APIKeyLocationHeader, 523 | }, 524 | expectError: false, 525 | backendStatusCode: http.StatusOK, 526 | requestAsserter: func(t *testing.T, r *http.Request) { 527 | assert.Equal(t, "secret-server-key", r.Header.Get("Authorization")) 528 | }, 529 | }, 530 | // --- API Key Injection (Query) --- 531 | { 532 | name: "API Key Injection (Query)", 533 | params: ToolCallParams{ 534 | ToolName: "get_secure", 535 | Input: map[string]interface{}{"otherParam": "abc"}, 536 | }, 537 | opDetail: mcp.OperationDetail{Method: "GET", Path: "/secure", Parameters: []mcp.ParameterDetail{{Name: "otherParam", In: "query"}}}, 538 | cfg: &config.Config{ 539 | APIKey: "secret-server-key-q", 540 | APIKeyName: "api_key", 541 | APIKeyLocation: config.APIKeyLocationQuery, 542 | }, 543 | expectError: false, 544 | backendStatusCode: http.StatusOK, 545 | requestAsserter: func(t *testing.T, r *http.Request) { 546 | assert.Equal(t, "secret-server-key-q", r.URL.Query().Get("api_key")) 547 | assert.Equal(t, "abc", r.URL.Query().Get("otherParam")) // Ensure other params are preserved 548 | }, 549 | }, 550 | // --- API Key Injection (Path) --- 551 | { 552 | name: "API Key Injection (Path)", 553 | params: ToolCallParams{ 554 | ToolName: "get_secure_path", 555 | Input: map[string]interface{}{}, // Key comes from config 556 | }, 557 | opDetail: mcp.OperationDetail{Method: "GET", Path: "/secure/{apiKey}/data"}, 558 | cfg: &config.Config{ 559 | APIKey: "path-key-123", 560 | APIKeyName: "apiKey", // Matches the placeholder name 561 | APIKeyLocation: config.APIKeyLocationPath, 562 | }, 563 | expectError: false, 564 | backendStatusCode: http.StatusOK, 565 | requestAsserter: func(t *testing.T, r *http.Request) { 566 | assert.Equal(t, "/secure/path-key-123/data", r.URL.Path) 567 | }, 568 | }, 569 | // --- API Key Injection (Cookie) --- 570 | { 571 | name: "API Key Injection (Cookie)", 572 | params: ToolCallParams{ 573 | ToolName: "get_secure_cookie", 574 | Input: map[string]interface{}{}, // Key comes from config 575 | }, 576 | opDetail: mcp.OperationDetail{Method: "GET", Path: "/secure_cookie"}, 577 | cfg: &config.Config{ 578 | APIKey: "cookie-key-abc", 579 | APIKeyName: "AuthToken", 580 | APIKeyLocation: config.APIKeyLocationCookie, 581 | }, 582 | expectError: false, 583 | backendStatusCode: http.StatusOK, 584 | requestAsserter: func(t *testing.T, r *http.Request) { 585 | cookie, err := r.Cookie("AuthToken") 586 | require.NoError(t, err) 587 | assert.Equal(t, "cookie-key-abc", cookie.Value) 588 | }, 589 | }, 590 | // --- Base URL Handling Tests --- 591 | { 592 | name: "Base URL from Default (Mock Server)", 593 | params: ToolCallParams{ToolName: "get_default_url", Input: map[string]interface{}{}}, 594 | opDetail: mcp.OperationDetail{Method: "GET", Path: "/path1"}, // No BaseURL here 595 | cfg: &config.Config{}, // No global override 596 | expectError: false, 597 | backendStatusCode: http.StatusOK, 598 | requestAsserter: func(t *testing.T, r *http.Request) { 599 | // Should hit the mock server at the correct path 600 | assert.Equal(t, "/path1", r.URL.Path) 601 | }, 602 | }, 603 | { 604 | name: "Base URL from Global Config Override", 605 | params: ToolCallParams{ToolName: "get_global_url", Input: map[string]interface{}{}}, 606 | opDetail: mcp.OperationDetail{Method: "GET", Path: "/path2", BaseURL: "http://should-be-ignored.com"}, 607 | // cfg will be updated in test loop to point ServerBaseURL to mock server 608 | cfg: &config.Config{}, 609 | expectError: false, 610 | backendStatusCode: http.StatusOK, 611 | requestAsserter: func(t *testing.T, r *http.Request) { 612 | // Should hit the mock server (set via cfg override) at the correct path 613 | assert.Equal(t, "/path2", r.URL.Path) 614 | }, 615 | }, 616 | // --- Error Case (Tool Not Found in ToolSet) --- 617 | { 618 | name: "Error - Tool Not Found", 619 | params: ToolCallParams{ 620 | ToolName: "nonexistent", 621 | Input: map[string]interface{}{}, 622 | }, 623 | opDetail: mcp.OperationDetail{}, // Not used, error occurs before this 624 | cfg: &config.Config{}, 625 | expectError: true, 626 | containsError: "operation details for tool 'nonexistent' not found", 627 | requestAsserter: nil, // No request should be made 628 | backendStatusCode: 0, // Not applicable 629 | }, 630 | } 631 | 632 | for _, tc := range tests { 633 | t.Run(tc.name, func(t *testing.T) { 634 | // --- Mock Backend Setup --- 635 | var backendServer *httptest.Server 636 | backendServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 637 | if tc.requestAsserter != nil { 638 | tc.requestAsserter(t, r) 639 | } 640 | w.WriteHeader(tc.backendStatusCode) 641 | fmt.Fprint(w, tc.backendResponse) 642 | })) 643 | defer backendServer.Close() 644 | 645 | // --- Prepare ToolSet (using mock server URL if needed) --- 646 | toolSet := &mcp.ToolSet{ 647 | Operations: make(map[string]mcp.OperationDetail), 648 | } 649 | 650 | // Clone config to avoid modifying the template test case config 651 | testCfg := *tc.cfg 652 | 653 | // Special handling for the global override test case 654 | if tc.name == "Base URL from Global Config Override" { 655 | testCfg.ServerBaseURL = backendServer.URL // Point global override to mock server 656 | } 657 | 658 | // If the opDetail needs a BaseURL, set it to the mock server ONLY if it wasn't 659 | // already set in the test case definition AND the global override isn't being used. 660 | if tc.opDetail.Method != "" { // Only add if it's a valid detail for the test 661 | if tc.opDetail.BaseURL == "" && testCfg.ServerBaseURL == "" { 662 | tc.opDetail.BaseURL = backendServer.URL 663 | } 664 | toolSet.Operations[tc.params.ToolName] = tc.opDetail 665 | } 666 | 667 | // --- Execute Function --- 668 | httpResp, err := executeToolCall(&tc.params, toolSet, &testCfg) // Use the potentially modified testCfg 669 | 670 | // --- Assertions --- 671 | if tc.expectError { 672 | assert.Error(t, err) 673 | if tc.containsError != "" { 674 | assert.Contains(t, err.Error(), tc.containsError) 675 | } 676 | assert.Nil(t, httpResp) 677 | } else { 678 | assert.NoError(t, err) 679 | require.NotNil(t, httpResp) 680 | defer httpResp.Body.Close() 681 | assert.Equal(t, tc.backendStatusCode, httpResp.StatusCode) 682 | bodyBytes, _ := io.ReadAll(httpResp.Body) 683 | assert.Equal(t, tc.backendResponse, string(bodyBytes)) 684 | } 685 | }) 686 | } 687 | } 688 | 689 | func TestWriteSSEEvent(t *testing.T) { 690 | tests := []struct { 691 | name string 692 | eventName string 693 | data interface{} 694 | expectedOut string 695 | expectError bool 696 | }{ 697 | { 698 | name: "Simple String Data", 699 | eventName: "endpoint", 700 | data: "/mcp?sessionId=123", 701 | expectedOut: "event: endpoint\ndata: /mcp?sessionId=123\n\n", 702 | expectError: false, 703 | }, 704 | { 705 | name: "Struct Data (JSON-RPC Request)", 706 | eventName: "message", 707 | data: jsonRPCRequest{ 708 | Jsonrpc: "2.0", 709 | Method: "mcp-ready", 710 | Params: map[string]interface{}{"connectionId": "abc"}, 711 | }, 712 | // Note: JSON marshaling order isn't guaranteed, so use JSONEq or check fields 713 | expectedOut: "event: message\ndata: {\"jsonrpc\":\"2.0\",\"method\":\"mcp-ready\",\"params\":{\"connectionId\":\"abc\"}}\n\n", 714 | expectError: false, 715 | }, 716 | { 717 | name: "Struct Data (JSON-RPC Response)", 718 | eventName: "message", 719 | data: jsonRPCResponse{ 720 | Jsonrpc: "2.0", 721 | Result: map[string]interface{}{"status": "ok"}, 722 | ID: "req-1", 723 | }, 724 | expectedOut: "event: message\ndata: {\"jsonrpc\":\"2.0\",\"result\":{\"status\":\"ok\"},\"id\":\"req-1\"}\n\n", 725 | expectError: false, 726 | }, 727 | { 728 | name: "Error - Unmarshalable Data", 729 | eventName: "error", 730 | data: make(chan int), // Channels cannot be marshaled to JSON 731 | expectError: true, 732 | }, 733 | } 734 | 735 | for _, tc := range tests { 736 | t.Run(tc.name, func(t *testing.T) { 737 | rr := httptest.NewRecorder() 738 | err := writeSSEEvent(rr, tc.eventName, tc.data) 739 | 740 | if tc.expectError { 741 | assert.Error(t, err) 742 | } else { 743 | assert.NoError(t, err) 744 | // For struct data, use JSONEq for robust comparison 745 | if _, isStruct := tc.data.(jsonRPCRequest); isStruct { 746 | prefix := fmt.Sprintf("event: %s\ndata: ", tc.eventName) 747 | suffix := "\n\n" 748 | require.True(t, strings.HasPrefix(rr.Body.String(), prefix)) 749 | require.True(t, strings.HasSuffix(rr.Body.String(), suffix)) 750 | actualJSON := strings.TrimSuffix(strings.TrimPrefix(rr.Body.String(), prefix), suffix) 751 | expectedJSONBytes, _ := json.Marshal(tc.data) 752 | assert.JSONEq(t, string(expectedJSONBytes), actualJSON) 753 | } else if _, isStruct := tc.data.(jsonRPCResponse); isStruct { 754 | prefix := fmt.Sprintf("event: %s\ndata: ", tc.eventName) 755 | suffix := "\n\n" 756 | require.True(t, strings.HasPrefix(rr.Body.String(), prefix)) 757 | require.True(t, strings.HasSuffix(rr.Body.String(), suffix)) 758 | actualJSON := strings.TrimSuffix(strings.TrimPrefix(rr.Body.String(), prefix), suffix) 759 | expectedJSONBytes, _ := json.Marshal(tc.data) 760 | assert.JSONEq(t, string(expectedJSONBytes), actualJSON) 761 | } else { 762 | // For simple types, direct string comparison is fine 763 | assert.Equal(t, tc.expectedOut, rr.Body.String()) 764 | } 765 | } 766 | }) 767 | } 768 | } 769 | 770 | func TestTryWriteHTTPError(t *testing.T) { 771 | rr := httptest.NewRecorder() 772 | message := "Test Error Message" 773 | code := http.StatusInternalServerError 774 | 775 | tryWriteHTTPError(rr, code, message) 776 | 777 | // Note: tryWriteHTTPError doesn't set the status code, it only writes the body. 778 | // The calling function is expected to have set the code earlier. 779 | // So, we only check the body content here. 780 | assert.Equal(t, message, rr.Body.String()) 781 | } 782 | 783 | func TestGetMethodFromResponse(t *testing.T) { 784 | tests := []struct { 785 | name string 786 | response jsonRPCResponse 787 | expected string 788 | }{ 789 | { 790 | name: "Error Response", 791 | response: jsonRPCResponse{ 792 | Error: &jsonError{Code: -32600, Message: "..."}, 793 | }, 794 | expected: "error", 795 | }, 796 | { 797 | name: "Tool List Response", 798 | response: jsonRPCResponse{ 799 | Result: map[string]interface{}{"tools": []interface{}{}, "metadata": map[string]interface{}{}}, 800 | }, 801 | expected: "tool_set", 802 | }, 803 | { 804 | name: "Initialize Response (Result is Map)", 805 | response: jsonRPCResponse{ 806 | Result: map[string]interface{}{"protocolVersion": "...", "capabilities": map[string]interface{}{}}, 807 | }, 808 | expected: "success", // Falls back to 'success' as type isn't explicitly set 809 | }, 810 | { 811 | name: "Tool Call Response (Result is ToolResultPayload)", 812 | response: jsonRPCResponse{ 813 | Result: ToolResultPayload{Content: []ToolResultContent{{Type: "text", Text: "..."}}}, 814 | }, 815 | expected: "success", // Falls back to 'success' 816 | }, 817 | { 818 | name: "Empty Response", 819 | response: jsonRPCResponse{}, 820 | expected: "unknown", 821 | }, 822 | } 823 | 824 | for _, tc := range tests { 825 | t.Run(tc.name, func(t *testing.T) { 826 | actual := getMethodFromResponse(tc.response) 827 | assert.Equal(t, tc.expected, actual) 828 | }) 829 | } 830 | } 831 | 832 | // --- Mock ResponseWriter for error simulation --- 833 | 834 | // mockResponseWriter implements http.ResponseWriter and http.Flusher for testing SSE. 835 | type sseMockResponseWriter struct { 836 | hdr http.Header // Internal map for headers 837 | statusCode int 838 | body *bytes.Buffer 839 | flushed bool 840 | forceError error // If set, Write and Flush will return this error 841 | failAfterNWrites int // Start failing after this many writes (-1 = disable) 842 | writesMade int // Counter for writes made 843 | } 844 | 845 | // Renamed constructor 846 | func newSseMockResponseWriter() *sseMockResponseWriter { 847 | return &sseMockResponseWriter{ 848 | hdr: make(http.Header), // Initialize internal map 849 | body: &bytes.Buffer{}, 850 | failAfterNWrites: -1, // Default to disabled 851 | } 852 | } 853 | 854 | // Implement http.ResponseWriter interface 855 | func (m *sseMockResponseWriter) Header() http.Header { 856 | return m.hdr // Return the internal map 857 | } 858 | 859 | func (m *sseMockResponseWriter) WriteHeader(statusCode int) { 860 | m.statusCode = statusCode 861 | } 862 | 863 | func (m *sseMockResponseWriter) Write(p []byte) (int, error) { 864 | // Check if already forced error 865 | if m.forceError != nil { 866 | return 0, m.forceError 867 | } 868 | 869 | // Increment write count 870 | m.writesMade++ 871 | 872 | // Check if write count triggers failure 873 | if m.failAfterNWrites >= 0 && m.writesMade >= m.failAfterNWrites { 874 | m.forceError = fmt.Errorf("forced write error after %d writes", m.failAfterNWrites) 875 | log.Printf("DEBUG: sseMockResponseWriter triggering error: %v", m.forceError) // Debug log 876 | return 0, m.forceError 877 | } 878 | 879 | // Proceed with normal write 880 | return m.body.Write(p) 881 | } 882 | 883 | // Implement http.Flusher interface 884 | func (m *sseMockResponseWriter) Flush() { 885 | // Check if already forced error 886 | if m.forceError != nil { 887 | // Optional: log or handle repeated flush attempts after error 888 | return 889 | } 890 | 891 | // Check if flush count triggers failure (less common to fail on flush, but possible) 892 | // We are primarily testing Write failures, so we might skip count check here for simplicity 893 | // or use a separate failAfterNFlushes counter if needed. 894 | 895 | m.flushed = true 896 | } 897 | 898 | // Helper to get body content 899 | func (m *sseMockResponseWriter) String() string { 900 | return m.body.String() 901 | } 902 | 903 | // --- End Mock ResponseWriter --- 904 | 905 | func TestHttpMethodGetHandler_WriteErrors(t *testing.T) { 906 | tests := []struct { 907 | name string 908 | errorOnStage string // "preamble", "endpoint", "ready", "ping", "message" 909 | forceError error // Error to set on the mock writer *before* handler runs 910 | expectConnRemoved bool 911 | }{ 912 | {"Error on Preamble (:ok)", "preamble", fmt.Errorf("forced write error during preamble"), true}, 913 | // Removed: {"Error on Endpoint Event", "endpoint", nil, true}, // Hard to simulate reliably without patching 914 | // Removed: {"Error on MCP-Ready Event", "ready", nil, true}, // Hard to simulate reliably without patching 915 | // TODO: Add test for error during keep-alive ping 916 | // TODO: Add test for error during message write from channel 917 | } 918 | 919 | for _, tc := range tests { 920 | t.Run(tc.name, func(t *testing.T) { 921 | // Use renamed mock writer 922 | mockWriter := newSseMockResponseWriter() 923 | req := httptest.NewRequest(http.MethodGet, "/mcp", nil) 924 | var connID string // Variable to capture assigned ID 925 | 926 | // Set the error on the writer *before* calling the handler 927 | if tc.forceError != nil { 928 | mockWriter.forceError = tc.forceError 929 | } 930 | 931 | // Need to capture connID *if* headers get written before error 932 | // We can check mockWriter.Header() after the handler potentially runs 933 | 934 | // Inject error based on the test stage - REMOVED FUNCTION PATCHING 935 | /* 936 | originalWriteSSE := writeSSEEvent 937 | defer func() { writeSSEEvent = originalWriteSSE }() // Restore original 938 | 939 | writeSSEEvent = func(w http.ResponseWriter, eventName string, data interface{}) error { 940 | // ... removed patching logic ... 941 | } 942 | */ 943 | 944 | // Execute handler in goroutine as it might block briefly before erroring 945 | done := make(chan struct{}) 946 | go func() { 947 | defer close(done) 948 | httpMethodGetHandler(mockWriter, req) 949 | }() 950 | 951 | // Wait for the handler goroutine to finish or timeout 952 | select { 953 | case <-done: 954 | // Handler finished (presumably due to error) 955 | case <-time.After(200 * time.Millisecond): // Generous timeout 956 | t.Fatal("Timeout waiting for httpMethodGetHandler goroutine to exit after injected error") 957 | } 958 | 959 | // Capture ConnID *after* handler exit, in case headers were set before error 960 | connID = mockWriter.Header().Get("X-Connection-ID") 961 | 962 | // Assert connection removal 963 | if tc.expectConnRemoved && connID != "" { 964 | connMutex.RLock() 965 | _, exists := activeConnections[connID] 966 | connMutex.RUnlock() 967 | assert.False(t, exists, "Connection %s should have been removed from activeConnections after write error", connID) 968 | } else if tc.expectConnRemoved && connID == "" { 969 | t.Log("Cannot assert connection removal as ConnID was not captured before error") 970 | } 971 | }) 972 | } 973 | } 974 | 975 | func TestHttpMethodGetHandler_GoroutineErrors(t *testing.T) { 976 | t.Run("Error_on_Message_Write", func(t *testing.T) { 977 | // Estimate writes before first message: :ok(1), endpoint(1), ready(1) = 3 writes 978 | // Target failure on the 4th write (first write of the actual message event line) 979 | mockWriter := newSseMockResponseWriter() 980 | mockWriter.failAfterNWrites = 4 // Fail on the 4th write overall 981 | 982 | req := httptest.NewRequest(http.MethodGet, "/mcp", nil) 983 | var connID string 984 | var msgChan chan jsonRPCResponse 985 | 986 | // Clean connections before test 987 | connMutex.Lock() 988 | activeConnections = make(map[string]chan jsonRPCResponse) 989 | connMutex.Unlock() 990 | defer func() { 991 | // Clean up after test, ensure channel is closed if exists 992 | connMutex.Lock() 993 | if msgChan != nil { 994 | // Only delete from map, handler is responsible for closing channel 995 | delete(activeConnections, connID) 996 | } 997 | activeConnections = make(map[string]chan jsonRPCResponse) // Reset for other tests 998 | connMutex.Unlock() 999 | }() 1000 | 1001 | done := make(chan struct{}) 1002 | go func() { 1003 | defer close(done) 1004 | httpMethodGetHandler(mockWriter, req) 1005 | log.Println("DEBUG: httpMethodGetHandler goroutine exited") 1006 | }() 1007 | 1008 | // Wait for the connection to be established 1009 | assert.Eventually(t, func() bool { 1010 | connMutex.RLock() 1011 | defer connMutex.RUnlock() 1012 | for id, ch := range activeConnections { 1013 | connID = id 1014 | msgChan = ch 1015 | log.Printf("DEBUG: Connection established: %s", connID) 1016 | return true 1017 | } 1018 | return false 1019 | }, 200*time.Millisecond, 20*time.Millisecond, "Connection not established in time") 1020 | 1021 | require.NotEmpty(t, connID, "connID should have been captured") 1022 | require.NotNil(t, msgChan, "msgChan should have been captured") 1023 | 1024 | // Send a message that should trigger the write error 1025 | testResp := jsonRPCResponse{Jsonrpc: "2.0", ID: "test-msg-1", Result: "test data"} 1026 | log.Printf("DEBUG: Sending test message to channel for %s", connID) 1027 | select { 1028 | case msgChan <- testResp: 1029 | log.Printf("DEBUG: Test message sent to channel for %s", connID) 1030 | case <-time.After(100 * time.Millisecond): 1031 | t.Fatal("Timeout sending message to channel") 1032 | } 1033 | 1034 | // Wait for the handler goroutine to finish due to the write error 1035 | select { 1036 | case <-done: 1037 | log.Printf("DEBUG: Handler goroutine finished as expected after message write error") 1038 | // Handler finished (presumably due to write error) 1039 | case <-time.After(1000 * time.Millisecond): // Increased timeout to 1 second 1040 | t.Fatal("Timeout waiting for httpMethodGetHandler goroutine to exit after message write error") 1041 | } 1042 | 1043 | // Assert connection removal 1044 | connMutex.RLock() 1045 | _, exists := activeConnections[connID] 1046 | connMutex.RUnlock() 1047 | assert.False(t, exists, "Connection %s should have been removed after message write error", connID) 1048 | }) 1049 | 1050 | // TODO: Add sub-test for Error_on_Ping_Write 1051 | } 1052 | 1053 | // Helper function to wait for a WaitGroup with a timeout 1054 | func waitTimeout(wg *sync.WaitGroup, timeout time.Duration) bool { 1055 | c := make(chan struct{}) 1056 | go func() { 1057 | defer close(c) 1058 | wg.Wait() 1059 | }() 1060 | select { 1061 | case <-c: 1062 | return true // Completed normally 1063 | case <-time.After(timeout): 1064 | return false // Timed out 1065 | } 1066 | } 1067 | --------------------------------------------------------------------------------