├── .gitignore ├── .npmignore ├── LICENSE ├── README.md ├── TODO.md ├── docs ├── api.md ├── examples.md ├── troubleshooting.md └── webhooks.md ├── mcp-replicate-0.1.0.tgz ├── mcp-replicate-0.1.1.tgz ├── package-lock.json ├── package.json ├── src ├── index.ts ├── models │ ├── collection.ts │ ├── hardware.ts │ ├── model.ts │ ├── openapi.ts │ ├── prediction.ts │ └── webhook.ts ├── replicate_client.ts ├── services │ ├── cache.ts │ ├── error.ts │ ├── image_viewer.ts │ └── webhook.ts ├── templates │ ├── manager.ts │ ├── parameters │ │ ├── quality.ts │ │ ├── size.ts │ │ └── style.ts │ └── prompts │ │ └── text_to_image.ts ├── tests │ ├── integration.test.ts │ └── protocol.test.ts ├── tools │ ├── handlers.ts │ ├── image_viewer.ts │ ├── index.ts │ ├── models.ts │ └── predictions.ts └── types │ └── mcp.ts └── tsconfig.json /.gitignore: -------------------------------------------------------------------------------- 1 | node_modules/ 2 | build/ 3 | *.log 4 | .env* 5 | .DS_Store 6 | dist/ 7 | -------------------------------------------------------------------------------- /.npmignore: -------------------------------------------------------------------------------- 1 | # Source 2 | src/ 3 | tests/ 4 | 5 | # Development configs 6 | .git/ 7 | .github/ 8 | .gitignore 9 | .npmrc 10 | tsconfig.json 11 | vitest.config.ts 12 | .biome.json 13 | 14 | # IDE 15 | .vscode/ 16 | .idea/ 17 | 18 | # Build artifacts 19 | coverage/ 20 | *.log 21 | .DS_Store 22 | 23 | # Development files 24 | TODO.md 25 | CONTRIBUTING.md 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 deepfates 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to 
use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Replicate MCP Server 2 | 3 | A [Model Context Protocol](https://github.com/mcp-sdk/mcp) server implementation for Replicate. Run Replicate models through a simple tool-based interface. 4 | 5 | ## Quickstart 6 | 7 | 1. Install the server: 8 | 9 | ```bash 10 | npm install -g mcp-replicate 11 | ``` 12 | 13 | 2. Get your Replicate API token: 14 | 15 | - Go to [Replicate API tokens page](https://replicate.com/account/api-tokens) 16 | - Create a new token if you don't have one 17 | - Copy the token for the next step 18 | 19 | 3. 
Configure Claude Desktop: 20 | - Open Claude Desktop Settings (,) 21 | - Select the "Developer" section in the sidebar 22 | - Click "Edit Config" to open the configuration file 23 | - Add the following configuration, replacing `your_token_here` with your actual Replicate API token: 24 | 25 | ```json 26 | { 27 | "mcpServers": { 28 | "replicate": { 29 | "command": "mcp-replicate", 30 | "env": { 31 | "REPLICATE_API_TOKEN": "your_token_here" 32 | } 33 | } 34 | } 35 | } 36 | ``` 37 | 38 | 4. Start Claude Desktop. You should see a 🔨 hammer icon in the bottom right corner of new chat windows, indicating the tools are available. 39 | 40 | (You can also use any other MCP client, such as Cursor, Cline, or Continue.) 41 | 42 | ## Alternative Installation Methods 43 | 44 | ### Install from source 45 | 46 | ```bash 47 | git clone https://github.com/deepfates/mcp-replicate 48 | cd mcp-replicate 49 | npm install 50 | npm run build 51 | npm start 52 | ``` 53 | 54 | ### Run with npx 55 | 56 | ```bash 57 | npx mcp-replicate 58 | ``` 59 | 60 | ## Features 61 | 62 | ### Models 63 | 64 | - Search models using semantic search 65 | - Browse models and collections 66 | - Get detailed model information and versions 67 | 68 | ### Predictions 69 | 70 | - Create predictions with text or structured input 71 | - Track prediction status 72 | - Cancel running predictions 73 | - List your recent predictions 74 | 75 | ### Image Handling 76 | 77 | - View generated images in your browser 78 | - Manage image cache for better performance 79 | 80 | ## Configuration 81 | 82 | The server needs a Replicate API token to work. You can get one at [Replicate](https://replicate.com/account/api-tokens). 83 | 84 | There are two ways to provide the token: 85 | 86 | ### 1. 
In Claude Desktop Config (Recommended) 87 | 88 | Add it to your Claude Desktop configuration as shown in the Quickstart section: 89 | 90 | ```json 91 | { 92 | "mcpServers": { 93 | "replicate": { 94 | "command": "mcp-replicate", 95 | "env": { 96 | "REPLICATE_API_TOKEN": "your_token_here" 97 | } 98 | } 99 | } 100 | } 101 | ``` 102 | 103 | ### 2. As Environment Variable 104 | 105 | Alternatively, you can set it as an environment variable if you're using another MCP client: 106 | 107 | ```bash 108 | export REPLICATE_API_TOKEN=your_token_here 109 | ``` 110 | 111 | ## Available Tools 112 | 113 | ### Model Tools 114 | 115 | - `search_models`: Find models using semantic search 116 | - `list_models`: Browse available models 117 | - `get_model`: Get details about a specific model 118 | - `list_collections`: Browse model collections 119 | - `get_collection`: Get details about a specific collection 120 | 121 | ### Prediction Tools 122 | 123 | - `create_prediction`: Run a model with your inputs 124 | - `create_and_poll_prediction`: Run a model with your inputs and wait until it's completed 125 | - `get_prediction`: Check a prediction's status 126 | - `cancel_prediction`: Stop a running prediction 127 | - `list_predictions`: See your recent predictions 128 | 129 | ### Image Tools 130 | 131 | - `view_image`: Open an image in your browser 132 | - `clear_image_cache`: Clean up cached images 133 | - `get_image_cache_stats`: Check cache usage 134 | 135 | ## Troubleshooting 136 | 137 | ### Server is running but tools aren't showing up 138 | 139 | 1. Check that Claude Desktop is properly configured with the MCP server settings 140 | 2. Ensure your Replicate API token is set correctly 141 | 3. Try restarting both the server and Claude Desktop 142 | 4. Check the server logs for any error messages 143 | 144 | ### Tools are visible but not working 145 | 146 | 1. Verify your Replicate API token is valid 147 | 2. Check your internet connection 148 | 3. 
Look for any error messages in the server output 149 | 150 | ## Development 151 | 152 | 1. Install dependencies: 153 | 154 | ```bash 155 | npm install 156 | ``` 157 | 158 | 2. Start development server (with auto-reload): 159 | 160 | ```bash 161 | npm run dev 162 | ``` 163 | 164 | 3. Check code style: 165 | 166 | ```bash 167 | npm run lint 168 | ``` 169 | 170 | 4. Format code: 171 | 172 | ```bash 173 | npm run format 174 | ``` 175 | 176 | ## Requirements 177 | 178 | - Node.js >= 18.0.0 179 | - TypeScript >= 5.0.0 180 | - [Claude Desktop](https://claude.ai/download) for using the tools 181 | 182 | ## License 183 | 184 | MIT 185 | -------------------------------------------------------------------------------- /TODO.md: -------------------------------------------------------------------------------- 1 | # MCP Server for Replicate - Simplified Implementation Plan 2 | 3 | ## Core Philosophy 4 | - Minimize complexity by focusing on tools over resources 5 | - Follow MCP spec for core functionality 6 | - Keep transport layer simple (stdio for local, SSE for remote) 7 | - Implement only essential features 8 | 9 | ## Current Status 10 | ✓ Basic functionality implemented: 11 | - Tool-based access to models and predictions 12 | - Type-safe interactions with protocol compliance 13 | - Simple error handling 14 | - Basic rate limiting 15 | - SSE transport layer for remote connections 16 | 17 | ## Implementation Plan 18 | 19 | ### Phase 1: Core Simplification (✓ Complete) 20 | 1. Replace Resource System with Tools 21 | - [x] Convert model listing to search_models tool 22 | - [x] Convert prediction access to get_prediction tool 23 | - [x] Remove resource-based URI schemes 24 | - [x] Simplify server initialization 25 | 26 | 2. Streamline Client Implementation 27 | - [x] Simplify ReplicateClient class 28 | - [x] Remove complex caching layers 29 | - [x] Implement basic error handling 30 | - [x] Add simple rate limiting 31 | 32 | 3. 
Transport Layer 33 | - [x] Keep stdio for local communication 34 | - [x] Implement basic SSE for remote (no complex retry logic) 35 | - [x] Remove unnecessary transport abstractions 36 | 37 | ### Phase 2: Essential Tools (✓ Complete) 38 | 1. Model Management 39 | - [x] search_models - Find models by query 40 | - [x] get_model - Get model details 41 | - [x] list_versions - List model versions 42 | 43 | 2. Prediction Handling 44 | - [x] create_prediction - Run model inference 45 | - [x] get_prediction - Check prediction status 46 | - [x] cancel_prediction - Stop running prediction 47 | 48 | 3. Image Tools 49 | - [x] view_image - Display result in browser 50 | - [x] save_image - Save to local filesystem 51 | 52 | ### Phase 3: Testing & Documentation (🚧 In Progress) 53 | 1. Testing 54 | - [x] Add basic protocol compliance tests 55 | - [x] Test core tool functionality 56 | - [x] Add integration tests 57 | 58 | 2. Documentation 59 | - [x] Update API reference for simplified interface 60 | - [ ] Add clear usage examples 61 | - [ ] Create troubleshooting guide 62 | 63 | ### Phase 4: Optional Enhancements (🚧 In Progress) 64 | 1. Webhook Support 65 | - [x] Simple webhook configuration 66 | - [x] Basic retry logic 67 | - [x] Event formatting 68 | 69 | 2. Template System 70 | - [ ] Basic parameter templates 71 | - [ ] Simple validation 72 | - [ ] Example presets 73 | 74 | ## Next Steps 75 | 76 | 1. Documentation: 77 | - Add clear usage examples 78 | - Create troubleshooting guide 79 | - Document common error cases 80 | 81 | 2. Template System: 82 | - Design parameter template format 83 | - Implement validation logic 84 | - Create example presets 85 | 86 | 3. 
Testing: 87 | - Add more edge case tests 88 | - Improve error handling coverage 89 | - Add performance benchmarks 90 | 91 | Legend: 92 | - [x] Completed 93 | - [ ] Not started 94 | - ✓ Phase complete 95 | - 🚧 Phase in progress 96 | - ❌ Phase not started 97 | -------------------------------------------------------------------------------- /docs/api.md: -------------------------------------------------------------------------------- 1 | # Replicate MCP Server API Reference 2 | 3 | ## Overview 4 | 5 | The Replicate MCP Server provides a Model Context Protocol (MCP) interface to Replicate's AI model platform. This document details the available tools and methods for interacting with the server. 6 | 7 | ## Tools 8 | 9 | ### search_models 10 | 11 | Search for models using semantic search. 12 | 13 | #### Input Schema 14 | ```typescript 15 | { 16 | query: string // Search query 17 | } 18 | ``` 19 | 20 | #### Example 21 | ```typescript 22 | await client.searchModels({ 23 | query: "high quality text to image models" 24 | }); 25 | ``` 26 | 27 | ### list_models 28 | 29 | List available models with optional filtering. 30 | 31 | #### Input Schema 32 | ```typescript 33 | { 34 | owner?: string, // Filter by model owner 35 | cursor?: string // Pagination cursor 36 | } 37 | ``` 38 | 39 | #### Example 40 | ```typescript 41 | await client.listModels({ 42 | owner: "stability-ai" 43 | }); 44 | ``` 45 | 46 | ### create_prediction 47 | 48 | Create a new prediction using a model version. 49 | 50 | #### Input Schema 51 | ```typescript 52 | { 53 | version: string, // Model version ID 54 | input: Record<string, unknown>, // Model input parameters 55 | webhook_url?: string // Optional webhook URL 56 | } 57 | ``` 58 | 59 | #### Example 60 | ```typescript 61 | await client.createPrediction({ 62 | version: "stability-ai/sdxl@v1.0.0", 63 | input: { 64 | prompt: "A serene mountain landscape" 65 | } 66 | }); 67 | ``` 68 | 69 | ### cancel_prediction 70 | 71 | Cancel a running prediction. 
72 | 73 | #### Input Schema 74 | ```typescript 75 | { 76 | prediction_id: string // ID of prediction to cancel 77 | } 78 | ``` 79 | 80 | #### Example 81 | ```typescript 82 | await client.cancelPrediction({ 83 | prediction_id: "pred_123abc" 84 | }); 85 | ``` 86 | 87 | ### get_prediction 88 | 89 | Get details about a specific prediction. 90 | 91 | #### Input Schema 92 | ```typescript 93 | { 94 | prediction_id: string // ID of prediction to get details for 95 | } 96 | ``` 97 | 98 | #### Example 99 | ```typescript 100 | await client.getPrediction({ 101 | prediction_id: "pred_123abc" 102 | }); 103 | ``` 104 | 105 | ## Templates 106 | 107 | The server includes a template system for common parameter configurations. 108 | 109 | ### Quality Templates 110 | 111 | Preset quality levels for image generation: 112 | - `draft`: Fast, lower quality results 113 | - `balanced`: Good balance of speed and quality 114 | - `quality`: High quality output 115 | - `extreme`: Maximum quality, slower generation 116 | 117 | ### Style Templates 118 | 119 | Common artistic styles: 120 | - `photographic`: Realistic photo-like images 121 | - `digital-art`: Digital artwork style 122 | - `cinematic`: Movie-like composition 123 | - `anime`: Anime/manga style 124 | - `painting`: Traditional painting styles 125 | 126 | ### Size Templates 127 | 128 | Standard image dimensions: 129 | - `square`: 1024x1024 130 | - `portrait`: 832x1216 131 | - `landscape`: 1216x832 132 | - `widescreen`: 1344x768 133 | 134 | ## Error Handling 135 | 136 | The server uses a simple error handling system with a single error class and clear error messages. 137 | 138 | ### ReplicateError 139 | 140 | All errors from the Replicate API are instances of `ReplicateError`. This class provides a consistent way to handle errors across the API. 
141 | 142 | ```typescript 143 | class ReplicateError extends Error { 144 | name: "ReplicateError"; 145 | message: string; 146 | context?: Record<string, unknown>; 147 | } 148 | ``` 149 | 150 | ### Error Factory Functions 151 | 152 | The API provides factory functions to create standardized errors: 153 | 154 | ```typescript 155 | const createError = { 156 | rateLimit: (retryAfter: number) => 157 | new ReplicateError("Rate limit exceeded", { retryAfter }), 158 | 159 | authentication: (details?: string) => 160 | new ReplicateError("Authentication failed", { details }), 161 | 162 | notFound: (resource: string) => 163 | new ReplicateError("Model not found", { resource }), 164 | 165 | validation: (field: string, message: string) => 166 | new ReplicateError("Invalid input parameters", { field, message }), 167 | 168 | timeout: (operation: string, ms: number) => 169 | new ReplicateError("Operation timed out", { operation, timeoutMs: ms }), 170 | }; 171 | ``` 172 | 173 | ### Example Error Handling 174 | 175 | ```typescript 176 | try { 177 | const prediction = await client.createPrediction({ 178 | version: "stability-ai/sdxl@latest", 179 | input: { prompt: "Test prompt" } 180 | }); 181 | } catch (error) { 182 | if (error instanceof ReplicateError) { 183 | console.error(error.message); 184 | // Access additional context if available 185 | if (error.context) { 186 | console.error("Error context:", error.context); 187 | } 188 | } 189 | } 190 | ``` 191 | 192 | ### Automatic Retries 193 | 194 | The API includes built-in retry functionality for certain types of errors: 195 | 196 | ```typescript 197 | const result = await ErrorHandler.withRetries( 198 | async () => client.listModels(), 199 | { 200 | maxAttempts: 3, 201 | minDelay: 1000, 202 | maxDelay: 30000, 203 | retryIf: (error) => error instanceof ReplicateError 204 | } 205 | ); 206 | ``` 207 | 208 | ## Rate Limiting 209 | 210 | The server implements basic rate limiting to prevent abuse. 
When rate limits are exceeded, the API returns a `ReplicateError` with the message "Rate limit exceeded" and includes a `retryAfter` value in the context. 211 | 212 | ### Limits 213 | - API requests per minute 214 | - Concurrent predictions per user 215 | 216 | ### Rate Limit Headers 217 | ```typescript 218 | { 219 | "X-RateLimit-Limit": "100", // Maximum requests per minute 220 | "X-RateLimit-Remaining": "95", // Remaining requests 221 | "X-RateLimit-Reset": "1704891600" // Unix timestamp when limit resets 222 | } 223 | ``` 224 | -------------------------------------------------------------------------------- /docs/examples.md: -------------------------------------------------------------------------------- 1 | # Replicate MCP Server Usage Examples 2 | 3 | This document provides practical examples of common use cases for the Replicate MCP server. 4 | 5 | ## Basic Usage 6 | 7 | ### Generating Images with Text 8 | 9 | ```typescript 10 | // Create a prediction with SDXL 11 | const prediction = await client.createPrediction({ 12 | version: "stability-ai/sdxl@latest", 13 | input: { 14 | prompt: "A serene mountain landscape at sunset", 15 | quality: "balanced", 16 | style: "photographic" 17 | } 18 | }); 19 | 20 | // Get prediction status 21 | const status = await client.getPrediction({ 22 | prediction_id: prediction.id 23 | }); 24 | 25 | // Cancel a running prediction if needed 26 | await client.cancelPrediction({ 27 | prediction_id: prediction.id 28 | }); 29 | ``` 30 | 31 | ### Browsing Models 32 | 33 | ```typescript 34 | // List all models (uses caching) 35 | const models = await client.listModels({}); 36 | 37 | // Filter models by owner (cached by owner) 38 | const stabilityModels = await client.listModels({ 39 | owner: "stability-ai" 40 | }); 41 | 42 | // Search for specific models (cached by query) 43 | const searchResults = await client.searchModels({ 44 | query: "text to image models with good quality" 45 | }); 46 | ``` 47 | 48 | ### Working with Collections 49 | 
50 | ```typescript 51 | // List available collections (cached) 52 | const collections = await client.listCollections({}); 53 | 54 | // Get details of a specific collection (cached by slug) 55 | const textToImage = await client.getCollection({ 56 | slug: "text-to-image" 57 | }); 58 | 59 | // Browse models in a collection 60 | const collectionModels = textToImage.models; 61 | ``` 62 | 63 | ## Advanced Usage 64 | 65 | ### Using Templates 66 | 67 | ```typescript 68 | // Using quality templates 69 | const highQualityPrediction = await client.createPrediction({ 70 | version: "stability-ai/sdxl@latest", 71 | input: { 72 | prompt: "A futuristic cityscape", 73 | ...templates.quality.extreme, 74 | ...templates.style.cinematic, 75 | ...templates.size.widescreen 76 | } 77 | }); 78 | ``` 79 | 80 | ### Webhook Integration 81 | 82 | ```typescript 83 | // Create prediction with webhook notification 84 | const prediction = await client.createPrediction({ 85 | version: "stability-ai/sdxl@latest", 86 | input: { 87 | prompt: "An abstract digital artwork" 88 | }, 89 | webhook_url: "https://api.myapp.com/webhooks/replicate" 90 | }); 91 | 92 | // Example webhook handler (Express.js) 93 | app.post("/webhooks/replicate", async (req, res) => { 94 | try { 95 | const signature = req.headers["x-replicate-signature"]; 96 | const webhookSecret = await client.getWebhookSecret(); 97 | 98 | if (!verifyWebhookSignature(signature, webhookSecret, req.body)) { 99 | throw new ValidationError("Invalid signature"); 100 | } 101 | 102 | const { event, prediction } = req.body; 103 | switch (event) { 104 | case "prediction.completed": 105 | await handleCompletedPrediction(prediction); 106 | break; 107 | case "prediction.failed": 108 | await handleFailedPrediction(prediction); 109 | break; 110 | } 111 | 112 | res.status(200).send("OK"); 113 | } catch (error) { 114 | console.error(ErrorHandler.createErrorReport(error)); 115 | res.status(error instanceof ValidationError ? 
401 : 500).json({ 116 | error: error.message, 117 | code: error.name 118 | }); 119 | } 120 | }); 121 | ``` 122 | 123 | ### Error Handling 124 | 125 | ```typescript 126 | try { 127 | const prediction = await client.createPrediction({ 128 | version: "stability-ai/sdxl@latest", 129 | input: { prompt: "Test prompt" } 130 | }); 131 | } catch (error) { 132 | if (error instanceof ReplicateError) { 133 | console.error(error.message); 134 | // Optional: access additional context if available 135 | if (error.context) { 136 | console.error("Error context:", error.context); 137 | } 138 | } else { 139 | console.error("Unexpected error:", error); 140 | } 141 | } 142 | ``` 143 | 144 | ### Automatic Retries 145 | 146 | ```typescript 147 | // Using the built-in retry functionality 148 | const result = await ErrorHandler.withRetries( 149 | async () => client.createPrediction({ 150 | version: "stability-ai/sdxl@latest", 151 | input: { prompt: "A test image" } 152 | }), 153 | { 154 | maxAttempts: 3, 155 | minDelay: 1000, 156 | maxDelay: 10000, 157 | retryIf: (error) => error instanceof ReplicateError, 158 | onRetry: (error, attempt) => { 159 | console.warn( 160 | `Request failed: ${error.message}. `, 161 | `Retrying (attempt ${attempt + 1}/3)` 162 | ); 163 | } 164 | } 165 | ); 166 | ``` 167 | 168 | ### Handling Common Errors 169 | 170 | ```typescript 171 | try { 172 | const prediction = await client.createPrediction({ 173 | version: "stability-ai/sdxl@latest", 174 | input: { prompt: "Test prompt" } 175 | }); 176 | } catch (error) { 177 | if (error instanceof ReplicateError) { 178 | switch (error.message) { 179 | case "Rate limit exceeded": 180 | const retryAfter = error.context?.retryAfter; 181 | console.log(`Rate limit hit. 
Retry after ${retryAfter} seconds`); 182 | break; 183 | case "Authentication failed": 184 | console.log("Please check your API token"); 185 | break; 186 | case "Model not found": 187 | console.log(`Model ${error.context?.resource} not found`); 188 | break; 189 | case "Invalid input parameters": 190 | console.log(`Invalid input: ${error.context?.field} - ${error.context?.message}`); 191 | break; 192 | default: 193 | console.log("Operation failed:", error.message); 194 | } 195 | } 196 | } 197 | ``` 198 | 199 | ### Batch Processing 200 | 201 | ```typescript 202 | // Batch processing example 203 | async function batchProcess(prompts: string[]) { 204 | const predictions = await Promise.all( 205 | prompts.map(prompt => 206 | client.createPrediction({ 207 | version: "stability-ai/sdxl@latest", 208 | input: { prompt } 209 | }) 210 | ) 211 | ); 212 | 213 | // Monitor all predictions 214 | return Promise.all( 215 | predictions.map(prediction => 216 | client.getPredictionStatus(prediction.id) 217 | ) 218 | ); 219 | } 220 | 221 | // Usage 222 | const prompts = [ 223 | "A serene beach at sunset", 224 | "A mystical forest with glowing mushrooms", 225 | "A futuristic space station" 226 | ]; 227 | 228 | const results = await batchProcess(prompts); 229 | ``` 230 | 231 | ## Best Practices 232 | 233 | 1. Use proper error handling: 234 | ```typescript 235 | try { 236 | // Your code here 237 | } catch (error) { 238 | if (error instanceof ReplicateError) { 239 | console.error(error.message); 240 | } else { 241 | console.error("Unexpected error:", error); 242 | } 243 | } 244 | ``` 245 | 246 | 2. 
Implement proper retry logic: 247 | ```typescript 248 | let attempts = 0; 249 | const maxAttempts = 3; 250 | 251 | while (attempts < maxAttempts) { 252 | try { 253 | const result = await makeRequest(); 254 | break; 255 | } catch (error) { 256 | attempts++; 257 | if (attempts === maxAttempts) throw error; 258 | await new Promise(resolve => setTimeout(resolve, 1000 * attempts)); 259 | } 260 | } 261 | ``` 262 | 263 | 3. Use webhooks effectively: 264 | ```typescript 265 | // Create prediction with webhook notification 266 | const prediction = await client.createPrediction({ 267 | version: "stability-ai/sdxl@latest", 268 | input: { 269 | prompt: "An abstract digital artwork" 270 | }, 271 | webhook_url: "https://api.myapp.com/webhooks/replicate" 272 | }); 273 | ``` 274 | 275 | 4. Handle rate limits gracefully: 276 | ```typescript 277 | try { 278 | await makeRequest(); 279 | } catch (error) { 280 | if (error.message.includes("rate limit")) { 281 | console.log("Rate limit hit. Please try again later."); 282 | } 283 | } 284 | ``` 285 | 286 | 5. Clean up resources: 287 | ```typescript 288 | try { 289 | // Use API 290 | await makeRequest(); 291 | } catch (error) { 292 | // Handle errors 293 | console.error(error); 294 | } 295 | ``` 296 | -------------------------------------------------------------------------------- /docs/troubleshooting.md: -------------------------------------------------------------------------------- 1 | # Troubleshooting Guide 2 | 3 | This guide helps you diagnose and resolve common issues with the Replicate MCP Server. 
4 | 5 | ## Error Handling System 6 | 7 | ### ReplicateError Class 8 | 9 | ```typescript 10 | // Base error class with context 11 | class ReplicateError extends Error { 12 | constructor(message: string, public context?: Record<string, unknown>) { 13 | super(message); 14 | this.name = "ReplicateError"; 15 | } 16 | } 17 | ``` 18 | 19 | ### Error Factory Functions 20 | 21 | ```typescript 22 | const createError = { 23 | rateLimit: (retryAfter: number) => 24 | new ReplicateError("Rate limit exceeded", { retryAfter }), 25 | 26 | authentication: (details?: string) => 27 | new ReplicateError("Authentication failed", { details }), 28 | 29 | notFound: (resource: string) => 30 | new ReplicateError("Model not found", { resource }), 31 | 32 | validation: (field: string, message: string) => 33 | new ReplicateError("Invalid input parameters", { field, message }), 34 | 35 | timeout: (operation: string, ms: number) => 36 | new ReplicateError("Operation timed out", { operation, timeoutMs: ms }), 37 | }; 38 | ``` 39 | 40 | ### Error Reports 41 | 42 | ```typescript 43 | interface ErrorReport { 44 | name: string; 45 | message: string; 46 | context?: Record<string, unknown>; 47 | timestamp: string; 48 | } 49 | 50 | // Generate error reports 51 | const report = ErrorHandler.createErrorReport(error); 52 | console.error(JSON.stringify(report, null, 2)); 53 | ``` 54 | 55 | ## Common Issues 56 | 57 | ### 1. Authentication Issues 58 | 59 | #### Symptoms 60 | - "Authentication failed" errors 61 | - 401 Unauthorized responses 62 | - Authentication-related webhook failures 63 | 64 | #### Solutions 65 | 1. Verify API Token 66 | ```typescript 67 | try { 68 | await client.listModels(); 69 | } catch (error) { 70 | if (error instanceof ReplicateError && error.message === "Authentication failed") { 71 | console.error("Authentication failed:", error.context?.details); 72 | } 73 | } 74 | ``` 75 | 76 | 2. 
Check Environment Variables 77 | ```typescript 78 | if (!process.env.REPLICATE_API_TOKEN) { 79 | throw createError.authentication("Missing API token"); 80 | } 81 | ``` 82 | 83 | ### 2. Rate Limiting Issues 84 | 85 | #### Handling Rate Limits 86 | 87 | ```typescript 88 | try { 89 | await client.createPrediction({ 90 | version: "stability-ai/sdxl@latest", 91 | input: { prompt: "test" } 92 | }); 93 | } catch (error) { 94 | if (error instanceof ReplicateError && error.message === "Rate limit exceeded") { 95 | const retryAfter = error.context?.retryAfter; 96 | console.log(`Rate limit hit. Retry after ${retryAfter} seconds`); 97 | } 98 | } 99 | ``` 100 | 101 | ### 3. Network Issues 102 | 103 | #### Enhanced Network Error Handling 104 | 105 | ```typescript 106 | const result = await ErrorHandler.withRetries( 107 | async () => client.createPrediction({ 108 | version: "stability-ai/sdxl@latest", 109 | input: { prompt: "test" } 110 | }), 111 | { 112 | maxAttempts: 3, 113 | minDelay: 1000, 114 | maxDelay: 30000, 115 | retryIf: (error) => error instanceof ReplicateError, 116 | onRetry: (error, attempt) => { 117 | console.warn( 118 | `Request failed: ${error.message}. `, 119 | `Retrying (attempt ${attempt + 1}/3)` 120 | ); 121 | } 122 | } 123 | ); 124 | ``` 125 | 126 | ### 4. Prediction Issues 127 | 128 | #### Input Validation 129 | 130 | ```typescript 131 | try { 132 | const prediction = await client.createPrediction({ 133 | version: "stability-ai/sdxl@latest", 134 | input: { prompt: "" } // Invalid empty prompt 135 | }); 136 | } catch (error) { 137 | if (error instanceof ReplicateError && error.message === "Invalid input parameters") { 138 | console.error(`Validation error: ${error.context?.field} - ${error.context?.message}`); 139 | } 140 | } 141 | ``` 142 | 143 | ## Getting Help 144 | 145 | If you're experiencing issues: 146 | 147 | 1. 
Check error reports: 148 | ```typescript 149 | const errorReport = ErrorHandler.createErrorReport(error); 150 | console.error(JSON.stringify(errorReport, null, 2)); 151 | ``` 152 | 153 | 2. Generate a diagnostic report: 154 | ```typescript 155 | const diagnostics = { 156 | error: ErrorHandler.createErrorReport(error), 157 | timestamp: new Date().toISOString(), 158 | environment: { 159 | nodeVersion: process.version, 160 | platform: process.platform, 161 | arch: process.arch, 162 | env: process.env.NODE_ENV 163 | } 164 | }; 165 | ``` 166 | 167 | 3. Contact support with detailed information: 168 | - Error reports with context 169 | - Environment details 170 | - Steps to reproduce 171 | 172 | 4. Resources: 173 | - [GitHub Issues](https://github.com/replicate/replicate/issues) 174 | - [API Documentation](https://replicate.com/docs) 175 | - [Discord Community](https://discord.gg/replicate) 176 | - [Support Email](mailto:support@replicate.com) 177 | -------------------------------------------------------------------------------- /docs/webhooks.md: -------------------------------------------------------------------------------- 1 | # Webhook Integration Guide 2 | 3 | This guide explains how to integrate webhooks with the Replicate MCP Server for receiving updates about predictions. 4 | 5 | ## Overview 6 | 7 | Webhooks provide a way to receive asynchronous notifications about events in your Replicate predictions. Instead of polling for updates, your application can receive push notifications when important events occur. 
8 | 9 | ## Supported Events 10 | 11 | The server sends webhooks for the following events: 12 | 13 | - `prediction.started`: When a prediction begins processing 14 | - `prediction.completed`: When a prediction successfully completes 15 | - `prediction.failed`: When a prediction encounters an error 16 | 17 | ## Webhook Payload 18 | 19 | Each webhook delivery includes a JSON payload with event details: 20 | 21 | ```typescript 22 | interface WebhookPayload { 23 | event: string; // Event type 24 | prediction: { // Prediction details 25 | id: string; // Prediction ID 26 | version: string; // Model version 27 | input: Record<string, unknown>; // Input parameters 28 | output?: any; // Output data (for completed predictions) 29 | error?: string; // Error message (for failed predictions) 30 | status: string; // Current status 31 | created_at: string; // Creation timestamp 32 | started_at?: string; // Processing start timestamp 33 | completed_at?: string; // Completion timestamp 34 | }; 35 | timestamp: string; // Event timestamp 36 | } 37 | ``` 38 | 39 | ## Setting Up Webhooks 40 | 41 | ### 1. Create a Webhook Endpoint 42 | 43 | First, create an endpoint in your application to receive webhook notifications. 
Example using Express: 44 | 45 | ```typescript 46 | import express from "express"; 47 | 48 | const app = express(); 49 | app.use(express.json()); 50 | 51 | app.post("/webhooks/replicate", async (req, res) => { 52 | try { 53 | const { event, prediction } = req.body; 54 | 55 | switch (event) { 56 | case "prediction.started": 57 | await handlePredictionStarted(prediction); 58 | break; 59 | case "prediction.completed": 60 | await handlePredictionCompleted(prediction); 61 | break; 62 | case "prediction.failed": 63 | await handlePredictionFailed(prediction); 64 | break; 65 | default: 66 | console.warn(`Unknown event type: ${event}`); 67 | } 68 | 69 | // Return 200 OK quickly to acknowledge receipt 70 | res.status(200).send("OK"); 71 | } catch (error) { 72 | // Log error for debugging 73 | console.error("Webhook processing failed:", error); 74 | 75 | // Return 500 error 76 | res.status(500).json({ 77 | error: "Webhook processing failed", 78 | details: error instanceof Error ? error.message : String(error) 79 | }); 80 | } 81 | }); 82 | ``` 83 | 84 | ### 2. 
Configure Webhook URL 85 | 86 | When creating a prediction, include your webhook URL: 87 | 88 | ```typescript 89 | const prediction = await client.createPrediction({ 90 | version: "stability-ai/sdxl@latest", 91 | input: { 92 | prompt: "A serene landscape" 93 | }, 94 | webhook_url: "https://api.yourapp.com/webhooks/replicate" 95 | }); 96 | ``` 97 | 98 | ## Error Handling 99 | 100 | When processing webhooks, implement proper error handling: 101 | 102 | ```typescript 103 | app.post("/webhooks/replicate", async (req, res) => { 104 | try { 105 | const { event, prediction } = req.body; 106 | 107 | // Process webhook asynchronously 108 | processWebhookAsync(event, prediction).catch(error => { 109 | if (error instanceof ReplicateError) { 110 | console.error("Webhook processing failed:", error.message, error.context); 111 | } else { 112 | console.error("Unexpected error in webhook processing:", error); 113 | } 114 | }); 115 | 116 | // Return success quickly 117 | res.status(200).send("OK"); 118 | } catch (error) { 119 | if (error instanceof ReplicateError) { 120 | console.error("Webhook error:", error.message, error.context); 121 | res.status(400).json({ 122 | error: error.message, 123 | context: error.context 124 | }); 125 | } else { 126 | console.error("Unexpected webhook error:", error); 127 | res.status(500).json({ 128 | error: "Internal server error", 129 | message: error instanceof Error ? 
error.message : String(error) 130 | }); 131 | } 132 | } 133 | }); 134 | 135 | async function processWebhookAsync(event: string, prediction: any) { 136 | try { 137 | switch (event) { 138 | case "prediction.completed": 139 | await handlePredictionCompleted(prediction); 140 | break; 141 | case "prediction.failed": 142 | await handlePredictionFailed(prediction); 143 | break; 144 | default: 145 | throw createError.validation("event", `Unknown event type: ${event}`); 146 | } 147 | } catch (error) { 148 | // Log error but don't throw since we're in an async context 149 | if (error instanceof ReplicateError) { 150 | console.error("Failed to process webhook:", error.message, error.context); 151 | } else { 152 | console.error("Unexpected error in webhook processing:", error); 153 | } 154 | } 155 | } 156 | ``` 157 | 158 | ## Best Practices 159 | 160 | 1. **Use HTTPS** 161 | - Always use HTTPS for webhook endpoints 162 | - Ensure proper TLS configuration 163 | 164 | 2. **Handle Errors Gracefully** 165 | - Use the `ReplicateError` class for consistent error handling 166 | - Include relevant context in error messages 167 | - Return appropriate status codes 168 | 169 | 3. **Process Asynchronously** 170 | - Handle webhook processing in the background 171 | - Return 200 OK quickly to acknowledge receipt 172 | - Use error handling in async handlers 173 | 174 | 4. 
**Monitor Webhook Health** 175 | - Log webhook deliveries and errors 176 | - Track success/failure rates 177 | - Monitor processing times 178 | -------------------------------------------------------------------------------- /mcp-replicate-0.1.0.tgz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepfates/mcp-replicate/1c7a2e6b01f78959414ca1a4a9973e3cd878216b/mcp-replicate-0.1.0.tgz -------------------------------------------------------------------------------- /mcp-replicate-0.1.1.tgz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepfates/mcp-replicate/1c7a2e6b01f78959414ca1a4a9973e3cd878216b/mcp-replicate-0.1.1.tgz -------------------------------------------------------------------------------- /package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "mcp-replicate", 3 | "version": "0.1.1", 4 | "description": "Run Replicate models through a simple MCP server interface", 5 | "type": "module", 6 | "main": "build/index.js", 7 | "types": "build/index.d.ts", 8 | "bin": { 9 | "mcp-replicate": "build/index.js" 10 | }, 11 | "files": [ 12 | "build", 13 | "README.md", 14 | "LICENSE" 15 | ], 16 | "scripts": { 17 | "build": "tsc", 18 | "postbuild": "chmod +x build/index.js", 19 | "start": "node build/index.js", 20 | "dev": "tsc -w", 21 | "test": "vitest", 22 | "test:watch": "vitest watch", 23 | "test:coverage": "vitest run --coverage", 24 | "lint": "biome check .", 25 | "format": "biome format . 
--write", 26 | "prepublishOnly": "npm run build" 27 | }, 28 | "dependencies": { 29 | "@modelcontextprotocol/sdk": "^0.6.0", 30 | "replicate": "^1.0.1" 31 | }, 32 | "devDependencies": { 33 | "@biomejs/biome": "^1.5.3", 34 | "@types/node": "^20.11.5", 35 | "typescript": "^5.3.3", 36 | "vitest": "^1.2.1" 37 | }, 38 | "engines": { 39 | "node": ">=18.0.0" 40 | }, 41 | "keywords": [ 42 | "replicate", 43 | "mcp", 44 | "machine-learning", 45 | "ai", 46 | "model-context-protocol" 47 | ], 48 | "author": "deepfates", 49 | "license": "MIT", 50 | "repository": { 51 | "type": "git", 52 | "url": "https://github.com/deepfates/mcp-replicate" 53 | }, 54 | "bugs": { 55 | "url": "https://github.com/deepfates/mcp-replicate/issues" 56 | }, 57 | "homepage": "https://github.com/deepfates/mcp-replicate#readme", 58 | "publishConfig": { 59 | "access": "public" 60 | }, 61 | "vitest": { 62 | "include": [ 63 | "src/**/*.test.ts" 64 | ], 65 | "exclude": [ 66 | "node_modules", 67 | "build" 68 | ], 69 | "environment": "node", 70 | "testTimeout": 10000, 71 | "coverage": { 72 | "provider": "v8", 73 | "reporter": [ 74 | "text", 75 | "html" 76 | ], 77 | "exclude": [ 78 | "node_modules", 79 | "build", 80 | "**/*.test.ts", 81 | "**/*.d.ts" 82 | ] 83 | } 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /src/index.ts: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env node 2 | 3 | /** 4 | * MCP server implementation for Replicate. 5 | * Provides access to Replicate models and predictions through MCP. 
6 | */ 7 | 8 | import { Server } from "@modelcontextprotocol/sdk/server/index.js"; 9 | import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js"; 10 | import { 11 | CallToolRequestSchema, 12 | ListToolsRequestSchema, 13 | } from "@modelcontextprotocol/sdk/types.js"; 14 | 15 | import { ReplicateClient } from "./replicate_client.js"; 16 | import type { Model } from "./models/model.js"; 17 | import type { 18 | Prediction, 19 | ModelIO, 20 | PredictionStatus, 21 | } from "./models/prediction.js"; 22 | import type { Collection } from "./models/collection.js"; 23 | 24 | import { tools } from "./tools/index.js"; 25 | import { 26 | handleSearchModels, 27 | handleListModels, 28 | handleListCollections, 29 | handleGetCollection, 30 | handleCreatePrediction, 31 | handleCreateAndPollPrediction, 32 | handleCancelPrediction, 33 | handleGetPrediction, 34 | handleListPredictions, 35 | handleGetModel, 36 | } from "./tools/handlers.js"; 37 | import { 38 | handleViewImage, 39 | handleClearImageCache, 40 | handleGetImageCacheStats, 41 | } from "./tools/image_viewer.js"; 42 | 43 | // Initialize Replicate client 44 | const client = new ReplicateClient(); 45 | 46 | // Cache for models, predictions, collections, and prediction status 47 | const modelCache = new Map(); 48 | const predictionCache = new Map(); 49 | const collectionCache = new Map(); 50 | const predictionStatus = new Map(); 51 | 52 | // Cache object for tool handlers 53 | const cache = { 54 | modelCache, 55 | predictionCache, 56 | collectionCache, 57 | predictionStatus, 58 | }; 59 | 60 | /** 61 | * Create an MCP server with capabilities for 62 | * tools (to run predictions) 63 | */ 64 | const server = new Server( 65 | { 66 | name: "replicate", 67 | version: "0.1.0", 68 | }, 69 | { 70 | capabilities: { 71 | tools: {}, 72 | prompts: {}, 73 | }, 74 | } 75 | ); 76 | 77 | /** 78 | * Handler that lists available tools. 
79 | */ 80 | server.setRequestHandler(ListToolsRequestSchema, async () => { 81 | return { tools }; 82 | }); 83 | 84 | /** 85 | * Handler for tools. 86 | */ 87 | server.setRequestHandler(CallToolRequestSchema, async (request) => { 88 | switch (request.params.name) { 89 | case "search_models": 90 | return handleSearchModels(client, cache, { 91 | query: String(request.params.arguments?.query), 92 | }); 93 | 94 | case "list_models": 95 | return handleListModels(client, cache, { 96 | owner: request.params.arguments?.owner as string | undefined, 97 | cursor: request.params.arguments?.cursor as string | undefined, 98 | }); 99 | 100 | case "list_collections": 101 | return handleListCollections(client, cache, { 102 | cursor: request.params.arguments?.cursor as string | undefined, 103 | }); 104 | 105 | case "get_collection": 106 | return handleGetCollection(client, cache, { 107 | slug: String(request.params.arguments?.slug), 108 | }); 109 | 110 | case "create_prediction": 111 | return handleCreatePrediction(client, cache, { 112 | version: request.params.arguments?.version as string | undefined, 113 | model: request.params.arguments?.model as string | undefined, 114 | input: request.params.arguments?.input as ModelIO, 115 | webhook: request.params.arguments?.webhook_url as string | undefined, 116 | }); 117 | 118 | case "create_and_poll_prediction": 119 | return handleCreateAndPollPrediction(client, cache, { 120 | version: request.params.arguments?.version as string | undefined, 121 | model: request.params.arguments?.model as string | undefined, 122 | input: request.params.arguments?.input as ModelIO, 123 | webhook: request.params.arguments?.webhook_url as string | undefined, 124 | pollInterval: request.params.arguments?.poll_interval as 125 | | number 126 | | undefined, 127 | timeout: request.params.arguments?.timeout as number | undefined, 128 | }); 129 | 130 | case "cancel_prediction": 131 | return handleCancelPrediction(client, cache, { 132 | prediction_id: 
String(request.params.arguments?.prediction_id), 133 | }); 134 | 135 | case "get_prediction": 136 | return handleGetPrediction(client, cache, { 137 | prediction_id: String(request.params.arguments?.prediction_id), 138 | }); 139 | 140 | case "list_predictions": 141 | return handleListPredictions(client, cache, { 142 | limit: request.params.arguments?.limit as number | undefined, 143 | cursor: request.params.arguments?.cursor as string | undefined, 144 | }); 145 | 146 | case "get_model": 147 | return handleGetModel(client, cache, { 148 | owner: String(request.params.arguments?.owner), 149 | name: String(request.params.arguments?.name), 150 | }); 151 | 152 | case "view_image": 153 | return handleViewImage(request); 154 | 155 | case "clear_image_cache": 156 | return handleClearImageCache(request); 157 | 158 | case "get_image_cache_stats": 159 | return handleGetImageCacheStats(request); 160 | 161 | default: 162 | throw new Error("Unknown tool"); 163 | } 164 | }); 165 | 166 | /** 167 | * Start the server using stdio transport. 168 | */ 169 | async function main() { 170 | const transport = new StdioServerTransport(); 171 | await server.connect(transport); 172 | } 173 | 174 | main().catch((error) => { 175 | console.error("Server error:", error); 176 | process.exit(1); 177 | }); 178 | -------------------------------------------------------------------------------- /src/models/collection.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * Data models for Replicate collections. 3 | */ 4 | 5 | import type { Model } from "./model.js"; 6 | 7 | /** 8 | * A collection of related models on Replicate. 
9 | */ 10 | export interface Collection { 11 | /** Unique identifier for this collection */ 12 | id: string; 13 | /** Human-readable name of the collection */ 14 | name: string; 15 | /** URL-friendly slug for the collection */ 16 | slug: string; 17 | /** Description of the collection's purpose */ 18 | description?: string; 19 | /** Models included in this collection */ 20 | models: Model[]; 21 | /** Whether this collection is featured */ 22 | featured?: boolean; 23 | /** When this collection was created */ 24 | created_at: string; 25 | /** When this collection was last updated */ 26 | updated_at?: string; 27 | } 28 | 29 | /** 30 | * Response format for listing collections. 31 | */ 32 | export interface CollectionList { 33 | /** List of collections */ 34 | collections: Collection[]; 35 | /** Cursor for pagination */ 36 | next_cursor?: string; 37 | /** Total number of collections */ 38 | total_count?: number; 39 | } 40 | -------------------------------------------------------------------------------- /src/models/hardware.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * Data models for Replicate hardware options. 3 | */ 4 | 5 | /** 6 | * A hardware option for running models on Replicate. 7 | */ 8 | export interface Hardware { 9 | /** Human-readable name of the hardware */ 10 | name: string; 11 | /** SKU identifier for the hardware */ 12 | sku: string; 13 | } 14 | 15 | /** 16 | * Response format for listing hardware options. 17 | */ 18 | export interface HardwareList { 19 | /** List of available hardware options */ 20 | hardware: Hardware[]; 21 | } 22 | -------------------------------------------------------------------------------- /src/models/model.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * Data models for Replicate models and versions. 
3 | */ 4 | 5 | /** 6 | * OpenAPI schema types 7 | */ 8 | export interface OpenAPISchema { 9 | openapi: string; 10 | info: { 11 | title: string; 12 | version: string; 13 | }; 14 | paths: Record; 15 | components?: { 16 | schemas?: Record; 17 | parameters?: Record; 18 | }; 19 | } 20 | 21 | /** 22 | * A specific version of a model on Replicate. 23 | */ 24 | export interface ModelVersion { 25 | /** Unique identifier for this model version */ 26 | id: string; 27 | /** When this version was created */ 28 | created_at: string; 29 | /** Version of Cog used to create this model */ 30 | cog_version: string; 31 | /** OpenAPI schema for the model */ 32 | openapi_schema: OpenAPISchema; 33 | /** Model identifier (owner/name) */ 34 | model?: string; 35 | /** Replicate version identifier */ 36 | replicate_version?: string; 37 | /** Hardware configuration for this version */ 38 | hardware?: string; 39 | } 40 | 41 | /** 42 | * Model information returned from Replicate. 43 | */ 44 | export interface Model { 45 | /** Unique identifier in format owner/name */ 46 | id: string; 47 | /** Owner of the model (user or organization) */ 48 | owner: string; 49 | /** Name of the model */ 50 | name: string; 51 | /** Description of the model's purpose and usage */ 52 | description?: string; 53 | /** Model visibility (public/private) */ 54 | visibility: "public" | "private"; 55 | /** URL to model's GitHub repository */ 56 | github_url?: string; 57 | /** URL to model's research paper */ 58 | paper_url?: string; 59 | /** URL to model's license */ 60 | license_url?: string; 61 | /** Number of times this model has been run */ 62 | run_count?: number; 63 | /** URL to model's cover image */ 64 | cover_image_url?: string; 65 | /** Latest version of the model */ 66 | latest_version?: ModelVersion; 67 | /** Default example inputs */ 68 | default_example?: Record; 69 | /** Whether this model is featured */ 70 | featured?: boolean; 71 | /** Model tags */ 72 | tags?: string[]; 73 | } 74 | 75 | /** 76 | * 
Response format for listing models. 77 | */ 78 | export interface ModelList { 79 | /** List of models */ 80 | models: Model[]; 81 | /** Cursor for pagination */ 82 | next_cursor?: string; 83 | /** Total number of models */ 84 | total_count?: number; 85 | } 86 | -------------------------------------------------------------------------------- /src/models/openapi.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * OpenAPI schema types. 3 | */ 4 | 5 | export interface OpenAPISchema { 6 | openapi: string; 7 | info: { 8 | title: string; 9 | version: string; 10 | }; 11 | components?: { 12 | schemas?: Record; 13 | }; 14 | } 15 | 16 | export interface SchemaObject { 17 | type?: string; 18 | required?: string[]; 19 | properties?: Record; 20 | additionalProperties?: boolean | SchemaObject; 21 | } 22 | 23 | export interface PropertyObject { 24 | type: string; 25 | format?: string; 26 | description?: string; 27 | default?: unknown; 28 | minimum?: number; 29 | maximum?: number; 30 | enum?: string[]; 31 | items?: SchemaObject; 32 | } 33 | 34 | export type ModelIO = Record; 35 | -------------------------------------------------------------------------------- /src/models/prediction.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * Data models for Replicate predictions. 3 | */ 4 | 5 | /** 6 | * Status of a prediction. 7 | */ 8 | export enum PredictionStatus { 9 | Starting = "starting", 10 | Processing = "processing", 11 | Succeeded = "succeeded", 12 | Failed = "failed", 13 | Canceled = "canceled", 14 | } 15 | 16 | /** 17 | * Model input/output types 18 | */ 19 | export type ModelIO = Record; 20 | 21 | /** 22 | * Input parameters for creating a prediction. 
23 | */ 24 | export interface PredictionInput { 25 | /** Model version to use for prediction */ 26 | model_version: string; 27 | /** Model-specific input parameters */ 28 | input: ModelIO; 29 | /** Optional template ID to use */ 30 | template_id?: string; 31 | /** URL for webhook notifications */ 32 | webhook_url?: string; 33 | /** Events to trigger webhooks */ 34 | webhook_events?: string[]; 35 | /** Whether to wait for prediction completion */ 36 | wait?: boolean; 37 | /** Max seconds to wait if wait=True (1-60) */ 38 | wait_timeout?: number; 39 | /** Whether to request streaming output */ 40 | stream?: boolean; 41 | } 42 | 43 | /** 44 | * A prediction (model run) on Replicate. 45 | */ 46 | export interface Prediction { 47 | /** Unique identifier for this prediction */ 48 | id: string; 49 | /** Model version used for this prediction */ 50 | version: string; 51 | /** Current status of the prediction */ 52 | status: PredictionStatus | string; 53 | /** Input parameters used for the prediction */ 54 | input: ModelIO; 55 | /** Output from the prediction if completed */ 56 | output?: ModelIO; 57 | /** Error message if prediction failed */ 58 | error?: string; 59 | /** Execution logs from the prediction */ 60 | logs?: string; 61 | /** When the prediction was created */ 62 | created_at: string; 63 | /** When the prediction started processing */ 64 | started_at?: string; 65 | /** When the prediction completed */ 66 | completed_at?: string; 67 | /** Related API URLs for this prediction */ 68 | urls: Record; 69 | /** Performance metrics if available */ 70 | metrics?: Record; 71 | /** URL for streaming output if requested */ 72 | stream_url?: string; 73 | } 74 | -------------------------------------------------------------------------------- /src/models/webhook.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * Webhook types and utilities. 
3 | */ 4 | 5 | import crypto from "node:crypto"; 6 | 7 | export interface WebhookConfig { 8 | url: string; 9 | secret?: string; 10 | retries?: number; 11 | timeout?: number; 12 | } 13 | 14 | export interface WebhookEvent { 15 | type: WebhookEventType; 16 | timestamp: string; 17 | data: Record; 18 | } 19 | 20 | export type WebhookEventType = 21 | | "prediction.created" 22 | | "prediction.processing" 23 | | "prediction.succeeded" 24 | | "prediction.failed" 25 | | "prediction.canceled"; 26 | 27 | export interface WebhookDeliveryResult { 28 | success: boolean; 29 | statusCode?: number; 30 | error?: string; 31 | retryCount: number; 32 | timestamp: string; 33 | } 34 | 35 | /** 36 | * Generate a webhook signature for request verification. 37 | */ 38 | export function generateWebhookSignature( 39 | payload: string, 40 | secret: string 41 | ): string { 42 | const hmac = crypto.createHmac("sha256", secret); 43 | hmac.update(payload); 44 | return `sha256=${hmac.digest("hex")}`; 45 | } 46 | 47 | /** 48 | * Verify a webhook signature from request headers. 49 | */ 50 | export function verifyWebhookSignature( 51 | payload: string, 52 | signature: string, 53 | secret: string 54 | ): boolean { 55 | const expectedSignature = generateWebhookSignature(payload, secret); 56 | return crypto.timingSafeEqual( 57 | Buffer.from(signature), 58 | Buffer.from(expectedSignature) 59 | ); 60 | } 61 | 62 | /** 63 | * Format a webhook event payload. 64 | */ 65 | export function formatWebhookEvent( 66 | type: WebhookEventType, 67 | data: Record 68 | ): WebhookEvent { 69 | return { 70 | type, 71 | timestamp: new Date().toISOString(), 72 | data, 73 | }; 74 | } 75 | 76 | /** 77 | * Validate webhook configuration. 
78 | */ 79 | export function validateWebhookConfig(config: WebhookConfig): string[] { 80 | const errors: string[] = []; 81 | 82 | // Validate URL 83 | try { 84 | new URL(config.url); 85 | } catch { 86 | errors.push("Invalid webhook URL"); 87 | } 88 | 89 | // Validate secret 90 | if (config.secret && config.secret.length < 32) { 91 | errors.push("Webhook secret should be at least 32 characters long"); 92 | } 93 | 94 | // Validate retries 95 | if (config.retries !== undefined) { 96 | if (!Number.isInteger(config.retries) || config.retries < 0) { 97 | errors.push("Retries must be a non-negative integer"); 98 | } 99 | } 100 | 101 | // Validate timeout 102 | if (config.timeout !== undefined) { 103 | if (!Number.isInteger(config.timeout) || config.timeout < 1000) { 104 | errors.push("Timeout must be at least 1000ms"); 105 | } 106 | } 107 | 108 | return errors; 109 | } 110 | 111 | /** 112 | * Default webhook configuration. 113 | */ 114 | export const DEFAULT_WEBHOOK_CONFIG: Partial = { 115 | retries: 3, 116 | timeout: 10000, // 10 seconds 117 | }; 118 | 119 | /** 120 | * Retry delay calculator with exponential backoff. 121 | */ 122 | export function calculateRetryDelay(attempt: number, baseDelay = 1000): number { 123 | const maxDelay = 60000; // 1 minute 124 | const delay = Math.min(baseDelay * 2 ** attempt, maxDelay); 125 | // Add jitter to prevent thundering herd 126 | return delay + Math.random() * 1000; 127 | } 128 | -------------------------------------------------------------------------------- /src/replicate_client.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * Replicate API client implementation with caching support. 
3 | */ 4 | 5 | import Replicate from "replicate"; 6 | import type { Model, ModelList, ModelVersion } from "./models/model.js"; 7 | import type { 8 | Prediction, 9 | PredictionInput, 10 | PredictionStatus, 11 | ModelIO, 12 | } from "./models/prediction.js"; 13 | import type { Collection, CollectionList } from "./models/collection.js"; 14 | import { 15 | modelCache, 16 | predictionCache, 17 | collectionCache, 18 | } from "./services/cache.js"; 19 | import { ReplicateError, ErrorHandler, createError } from "./services/error.js"; 20 | 21 | // Constants 22 | const REPLICATE_API_BASE = "https://api.replicate.com/v1"; 23 | const DEFAULT_TIMEOUT = 60000; // 60 seconds in milliseconds 24 | const MAX_RETRIES = 3; 25 | const MIN_RETRY_DELAY = 1000; // 1 second in milliseconds 26 | const MAX_RETRY_DELAY = 10000; // 10 seconds in milliseconds 27 | const DEFAULT_RATE_LIMIT = 100; // requests per minute 28 | 29 | // Type definitions for Replicate client responses 30 | interface ReplicateModel { 31 | owner: string; 32 | name: string; 33 | description?: string; 34 | visibility?: "public" | "private"; 35 | github_url?: string; 36 | paper_url?: string; 37 | license_url?: string; 38 | run_count?: number; 39 | cover_image_url?: string; 40 | default_example?: Record; 41 | featured?: boolean; 42 | tags?: string[]; 43 | latest_version?: ModelVersion; 44 | } 45 | 46 | interface ReplicatePrediction { 47 | id: string; 48 | version: string; 49 | status: string; 50 | input: Record; 51 | output?: unknown; 52 | error?: unknown; 53 | logs?: string; 54 | created_at: string; 55 | started_at?: string; 56 | completed_at?: string; 57 | urls: Record; 58 | metrics?: Record; 59 | } 60 | 61 | interface ReplicatePage { 62 | results: T[]; 63 | next?: string; 64 | previous?: string; 65 | total?: number; 66 | } 67 | 68 | interface CreatePredictionOptions { 69 | version?: string; 70 | model?: string; 71 | input: ModelIO | string; 72 | webhook?: string; 73 | } 74 | 75 | /** 76 | * Client for interacting with 
the Replicate API. 77 | */ 78 | export class ReplicateClient { 79 | private api_token: string; 80 | private rate_limit: number; 81 | private request_times: number[]; 82 | private retry_count: number; 83 | private client: Replicate; 84 | 85 | constructor(api_token?: string) { 86 | this.api_token = api_token || process.env.REPLICATE_API_TOKEN || ""; 87 | if (!this.api_token) { 88 | throw new Error("Replicate API token is required"); 89 | } 90 | 91 | this.rate_limit = DEFAULT_RATE_LIMIT; 92 | this.request_times = []; 93 | this.retry_count = 0; 94 | this.client = new Replicate({ auth: this.api_token }); 95 | } 96 | 97 | /** 98 | * Wait if necessary to comply with rate limiting. 99 | */ 100 | private async waitForRateLimit(): Promise { 101 | const now = Date.now(); 102 | 103 | // Remove request times older than 1 minute 104 | this.request_times = this.request_times.filter((t) => now - t <= 60000); 105 | 106 | if (this.request_times.length >= this.rate_limit) { 107 | // Calculate wait time based on oldest request 108 | const wait_time = 60000 - (now - this.request_times[0]); 109 | if (wait_time > 0) { 110 | console.debug(`Rate limit reached. Waiting ${wait_time}ms`); 111 | await new Promise((resolve) => setTimeout(resolve, wait_time)); 112 | } 113 | } 114 | 115 | this.request_times.push(now); 116 | } 117 | 118 | /** 119 | * Handle rate limits and other response headers. 
120 | */ 121 | private async handleResponse(response: Response): Promise { 122 | // Update rate limit from headers if available 123 | const limit = response.headers.get("X-RateLimit-Limit"); 124 | const remaining = response.headers.get("X-RateLimit-Remaining"); 125 | const reset = response.headers.get("X-RateLimit-Reset"); 126 | 127 | if (limit) { 128 | this.rate_limit = Number.parseInt(limit, 10); 129 | } 130 | 131 | // Handle rate limit exceeded 132 | if (response.status === 429) { 133 | const retryAfter = Number.parseInt( 134 | response.headers.get("Retry-After") || "60", 135 | 10 136 | ); 137 | throw createError.rateLimit(retryAfter); 138 | } 139 | } 140 | 141 | /** 142 | * Make an HTTP request with retries and rate limiting. 143 | */ 144 | private async makeRequest( 145 | method: string, 146 | endpoint: string, 147 | options: RequestInit = {} 148 | ): Promise { 149 | await this.waitForRateLimit(); 150 | 151 | return ErrorHandler.withRetries( 152 | async () => { 153 | const response = await fetch(`${REPLICATE_API_BASE}${endpoint}`, { 154 | method, 155 | headers: { 156 | Authorization: `Token ${this.api_token}`, 157 | "Content-Type": "application/json", 158 | ...options.headers, 159 | }, 160 | ...options, 161 | signal: AbortSignal.timeout(DEFAULT_TIMEOUT), 162 | }); 163 | 164 | await this.handleResponse(response); 165 | 166 | if (!response.ok) { 167 | throw await ErrorHandler.parseAPIError(response); 168 | } 169 | 170 | return response.json(); 171 | }, 172 | { 173 | maxAttempts: MAX_RETRIES, 174 | minDelay: MIN_RETRY_DELAY, 175 | maxDelay: MAX_RETRY_DELAY, 176 | onRetry: (error: Error, attempt: number) => { 177 | console.warn( 178 | `Request failed: ${error.message}. `, 179 | `Retrying (attempt ${attempt + 1}/${MAX_RETRIES})` 180 | ); 181 | }, 182 | } 183 | ); 184 | } 185 | 186 | /** 187 | * List available models on Replicate with pagination. 
188 | */ 189 | async listModels( 190 | options: { owner?: string; cursor?: string } = {} 191 | ): Promise { 192 | try { 193 | // Check cache first 194 | const cacheKey = `models:${options.owner || "all"}:${ 195 | options.cursor || "" 196 | }`; 197 | const cached = modelCache.get(cacheKey); 198 | if (cached) { 199 | return cached; 200 | } 201 | 202 | if (options.owner) { 203 | // If owner is specified, use search to find their models 204 | const response = (await this.client.models.search( 205 | `owner:${options.owner}` 206 | )) as unknown as ReplicatePage; 207 | 208 | const result: ModelList = { 209 | models: response.results.map((model) => ({ 210 | id: `${model.owner}/${model.name}`, 211 | owner: model.owner, 212 | name: model.name, 213 | description: model.description || "", 214 | visibility: model.visibility || "public", 215 | github_url: model.github_url, 216 | paper_url: model.paper_url, 217 | license_url: model.license_url, 218 | run_count: model.run_count, 219 | cover_image_url: model.cover_image_url, 220 | default_example: model.default_example, 221 | featured: model.featured || false, 222 | tags: model.tags || [], 223 | latest_version: model.latest_version 224 | ? 
{ 225 | id: model.latest_version.id, 226 | created_at: model.latest_version.created_at, 227 | cog_version: model.latest_version.cog_version, 228 | openapi_schema: { 229 | ...model.latest_version.openapi_schema, 230 | openapi: "3.0.0", 231 | info: { 232 | title: `${model.owner}/${model.name}`, 233 | version: model.latest_version.id, 234 | }, 235 | paths: {}, 236 | }, 237 | } 238 | : undefined, 239 | })), 240 | next_cursor: response.next, 241 | total_count: response.total || response.results.length, 242 | }; 243 | 244 | // Cache the result 245 | modelCache.set(cacheKey, result); 246 | return result; 247 | } 248 | 249 | // Otherwise list all models 250 | const params = new URLSearchParams(); 251 | if (options.cursor) { 252 | params.set("cursor", options.cursor); 253 | } 254 | 255 | const response = await this.makeRequest>( 256 | "GET", 257 | `/models${params.toString() ? `?${params.toString()}` : ""}` 258 | ); 259 | 260 | const result: ModelList = { 261 | models: response.results.map((model) => ({ 262 | id: `${model.owner}/${model.name}`, 263 | owner: model.owner, 264 | name: model.name, 265 | description: model.description || "", 266 | visibility: model.visibility || "public", 267 | github_url: model.github_url, 268 | paper_url: model.paper_url, 269 | license_url: model.license_url, 270 | run_count: model.run_count, 271 | cover_image_url: model.cover_image_url, 272 | default_example: model.default_example, 273 | featured: model.featured || false, 274 | tags: model.tags || [], 275 | latest_version: model.latest_version 276 | ? 
{ 277 | id: model.latest_version.id, 278 | created_at: model.latest_version.created_at, 279 | cog_version: model.latest_version.cog_version, 280 | openapi_schema: { 281 | ...model.latest_version.openapi_schema, 282 | openapi: "3.0.0", 283 | info: { 284 | title: `${model.owner}/${model.name}`, 285 | version: model.latest_version.id, 286 | }, 287 | paths: {}, 288 | }, 289 | } 290 | : undefined, 291 | })), 292 | next_cursor: response.next, 293 | total_count: response.total || response.results.length, 294 | }; 295 | 296 | // Cache the result 297 | modelCache.set(cacheKey, result); 298 | return result; 299 | } catch (error) { 300 | throw ErrorHandler.parseAPIError(error as Response); 301 | } 302 | } 303 | 304 | /** 305 | * Search for models using semantic search. 306 | */ 307 | async searchModels(query: string): Promise { 308 | try { 309 | // Check cache first 310 | const cacheKey = `search:${query}`; 311 | const cached = modelCache.get(cacheKey); 312 | if (cached) { 313 | return cached; 314 | } 315 | 316 | // Use the official client for search 317 | const response = (await this.client.models.search( 318 | query 319 | )) as unknown as ReplicatePage; 320 | 321 | const result: ModelList = { 322 | models: response.results.map((model) => ({ 323 | id: `${model.owner}/${model.name}`, 324 | owner: model.owner, 325 | name: model.name, 326 | description: model.description || "", 327 | visibility: model.visibility || "public", 328 | github_url: model.github_url, 329 | paper_url: model.paper_url, 330 | license_url: model.license_url, 331 | run_count: model.run_count, 332 | cover_image_url: model.cover_image_url, 333 | default_example: model.default_example, 334 | featured: model.featured || false, 335 | tags: model.tags || [], 336 | latest_version: model.latest_version 337 | ? 
{ 338 | id: model.latest_version.id, 339 | created_at: model.latest_version.created_at, 340 | cog_version: model.latest_version.cog_version, 341 | openapi_schema: { 342 | ...model.latest_version.openapi_schema, 343 | openapi: "3.0.0", 344 | info: { 345 | title: `${model.owner}/${model.name}`, 346 | version: model.latest_version.id, 347 | }, 348 | paths: {}, 349 | }, 350 | } 351 | : undefined, 352 | })), 353 | next_cursor: response.next, 354 | total_count: response.results.length, 355 | }; 356 | 357 | // Cache the result 358 | modelCache.set(cacheKey, result); 359 | return result; 360 | } catch (error) { 361 | throw ErrorHandler.parseAPIError(error as Response); 362 | } 363 | } 364 | 365 | /** 366 | * List available collections. 367 | */ 368 | async listCollections( 369 | options: { 370 | cursor?: string; 371 | } = {} 372 | ): Promise { 373 | try { 374 | // Check cache first 375 | const cacheKey = `collections:${options.cursor || ""}`; 376 | const cached = collectionCache.get(cacheKey); 377 | if (cached) { 378 | return cached; 379 | } 380 | 381 | // Use the official client for collections 382 | const response = await this.client.collections.list(); 383 | 384 | const result: CollectionList = { 385 | collections: response.results.map((collection) => ({ 386 | id: collection.slug, 387 | name: collection.name, 388 | slug: collection.slug, 389 | description: collection.description || "", 390 | models: 391 | collection.models?.map((model) => ({ 392 | id: `${model.owner}/${model.name}`, 393 | owner: model.owner, 394 | name: model.name, 395 | description: model.description || "", 396 | visibility: model.visibility || "public", 397 | github_url: model.github_url, 398 | paper_url: model.paper_url, 399 | license_url: model.license_url, 400 | run_count: model.run_count, 401 | cover_image_url: model.cover_image_url, 402 | default_example: model.default_example 403 | ? 
({ 404 | input: model.default_example.input, 405 | output: model.default_example.output, 406 | error: model.default_example.error, 407 | status: model.default_example.status, 408 | logs: model.default_example.logs, 409 | metrics: model.default_example.metrics, 410 | } as Record) 411 | : undefined, 412 | featured: false, 413 | tags: [], 414 | latest_version: model.latest_version 415 | ? { 416 | id: model.latest_version.id, 417 | created_at: model.latest_version.created_at, 418 | cog_version: model.latest_version.cog_version, 419 | openapi_schema: { 420 | ...model.latest_version.openapi_schema, 421 | openapi: "3.0.0", 422 | info: { 423 | title: `${model.owner}/${model.name}`, 424 | version: model.latest_version.id, 425 | }, 426 | paths: {}, 427 | }, 428 | } 429 | : undefined, 430 | })) || [], 431 | featured: false, 432 | created_at: new Date().toISOString(), 433 | updated_at: undefined, 434 | })), 435 | next_cursor: response.next, 436 | total_count: response.results.length, 437 | }; 438 | 439 | // Cache the result 440 | collectionCache.set(cacheKey, result); 441 | return result; 442 | } catch (error) { 443 | throw ErrorHandler.parseAPIError(error as Response); 444 | } 445 | } 446 | 447 | /** 448 | * Get a specific collection by slug. 
449 | */ 450 | async getCollection(slug: string): Promise { 451 | try { 452 | // Check cache first 453 | const cacheKey = `collection:${slug}`; 454 | const cached = collectionCache.get(cacheKey); 455 | if (cached?.collections?.length === 1) { 456 | return cached.collections[0]; 457 | } 458 | 459 | interface CollectionResponse { 460 | id: string; 461 | name: string; 462 | slug: string; 463 | description?: string; 464 | models: ReplicateModel[]; 465 | featured?: boolean; 466 | created_at: string; 467 | updated_at?: string; 468 | } 469 | 470 | const response = await this.makeRequest( 471 | "GET", 472 | `/collections/${slug}` 473 | ); 474 | 475 | const collection: Collection = { 476 | id: response.id, 477 | name: response.name, 478 | slug: response.slug, 479 | description: response.description || "", 480 | models: response.models.map((model) => ({ 481 | id: `${model.owner}/${model.name}`, 482 | owner: model.owner, 483 | name: model.name, 484 | description: model.description || "", 485 | visibility: model.visibility || "public", 486 | github_url: model.github_url, 487 | paper_url: model.paper_url, 488 | license_url: model.license_url, 489 | run_count: model.run_count, 490 | cover_image_url: model.cover_image_url, 491 | default_example: model.default_example, 492 | featured: model.featured || false, 493 | tags: model.tags || [], 494 | latest_version: model.latest_version, 495 | })), 496 | featured: response.featured || false, 497 | created_at: response.created_at, 498 | updated_at: response.updated_at, 499 | }; 500 | 501 | // Cache the result as a single-item collection list 502 | collectionCache.set(cacheKey, { 503 | collections: [collection], 504 | total_count: 1, 505 | }); 506 | return collection; 507 | } catch (error) { 508 | throw ErrorHandler.parseAPIError(error as Response); 509 | } 510 | } 511 | 512 | /** 513 | * Create a new prediction. 
514 | */ 515 | async createPrediction( 516 | options: CreatePredictionOptions 517 | ): Promise { 518 | try { 519 | // If input is a string, wrap it in an object with 'prompt' property 520 | const input = 521 | typeof options.input === "string" 522 | ? { prompt: options.input } 523 | : options.input; 524 | 525 | // Create prediction parameters with the correct type 526 | const predictionParams = options.version 527 | ? { 528 | version: options.version, 529 | input, 530 | webhook: options.webhook, 531 | } 532 | : { 533 | model: options.model!, 534 | input, 535 | webhook: options.webhook, 536 | }; 537 | 538 | if (!options.version && !options.model) { 539 | throw new Error("Either model or version must be provided"); 540 | } 541 | 542 | // Use the official client for predictions 543 | const prediction = (await this.client.predictions.create( 544 | predictionParams 545 | )) as unknown as ReplicatePrediction; 546 | 547 | const result = { 548 | id: prediction.id, 549 | version: prediction.version, 550 | status: prediction.status as PredictionStatus, 551 | input: prediction.input as ModelIO, 552 | output: prediction.output as ModelIO | undefined, 553 | error: prediction.error ? String(prediction.error) : undefined, 554 | logs: prediction.logs, 555 | created_at: prediction.created_at, 556 | started_at: prediction.started_at, 557 | completed_at: prediction.completed_at, 558 | urls: prediction.urls, 559 | metrics: prediction.metrics, 560 | }; 561 | 562 | // Cache the result 563 | predictionCache.set(`prediction:${prediction.id}`, [result]); 564 | return result; 565 | } catch (error) { 566 | throw ErrorHandler.parseAPIError(error as Response); 567 | } 568 | } 569 | 570 | /** 571 | * Get the status of a prediction. 
572 | */ 573 | async getPredictionStatus(prediction_id: string): Promise { 574 | try { 575 | // Check cache first 576 | const cacheKey = `prediction:${prediction_id}`; 577 | const cached = predictionCache.get(cacheKey); 578 | // Only use cache for completed predictions 579 | if ( 580 | cached?.length === 1 && 581 | ["succeeded", "failed", "canceled"].includes(cached[0].status) 582 | ) { 583 | return cached[0]; 584 | } 585 | 586 | // Use the official client for predictions 587 | const prediction = (await this.client.predictions.get( 588 | prediction_id 589 | )) as unknown as ReplicatePrediction; 590 | 591 | const result = { 592 | id: prediction.id, 593 | version: prediction.version, 594 | status: prediction.status as PredictionStatus, 595 | input: prediction.input as ModelIO, 596 | output: prediction.output as ModelIO | undefined, 597 | error: prediction.error ? String(prediction.error) : undefined, 598 | logs: prediction.logs, 599 | created_at: prediction.created_at, 600 | started_at: prediction.started_at, 601 | completed_at: prediction.completed_at, 602 | urls: prediction.urls, 603 | metrics: prediction.metrics, 604 | }; 605 | 606 | // Cache completed predictions 607 | if (["succeeded", "failed", "canceled"].includes(result.status)) { 608 | predictionCache.set(cacheKey, [result]); 609 | } 610 | 611 | return result; 612 | } catch (error) { 613 | throw ErrorHandler.parseAPIError(error as Response); 614 | } 615 | } 616 | 617 | /** 618 | * Cancel a running prediction. 619 | */ 620 | async cancelPrediction(prediction_id: string): Promise { 621 | try { 622 | const response = await this.makeRequest( 623 | "POST", 624 | `/predictions/${prediction_id}/cancel` 625 | ); 626 | 627 | const result = { 628 | id: response.id, 629 | version: response.version, 630 | status: response.status as PredictionStatus, 631 | input: response.input as ModelIO, 632 | output: response.output as ModelIO | undefined, 633 | error: response.error ? 
String(response.error) : undefined, 634 | logs: response.logs, 635 | created_at: response.created_at, 636 | started_at: response.started_at, 637 | completed_at: response.completed_at, 638 | urls: response.urls, 639 | metrics: response.metrics, 640 | }; 641 | 642 | // Update cache 643 | predictionCache.set(`prediction:${prediction_id}`, [result]); 644 | return result; 645 | } catch (error) { 646 | throw ErrorHandler.parseAPIError(error as Response); 647 | } 648 | } 649 | 650 | /** 651 | * List predictions with optional filtering. 652 | */ 653 | async listPredictions( 654 | options: { 655 | status?: PredictionStatus; 656 | limit?: number; 657 | cursor?: string; 658 | } = {} 659 | ): Promise { 660 | try { 661 | // Check cache first 662 | const cacheKey = `predictions:${options.status || "all"}:${ 663 | options.limit || "all" 664 | }:${options.cursor || ""}`; 665 | const cached = predictionCache.get(cacheKey); 666 | if (cached) { 667 | return cached; 668 | } 669 | 670 | // Use the official client for predictions 671 | const response = 672 | (await this.client.predictions.list()) as unknown as ReplicatePage; 673 | 674 | // Filter and limit results 675 | const filteredPredictions = options.status 676 | ? response.results.filter((p) => p.status === options.status) 677 | : response.results; 678 | 679 | const limitedPredictions = options.limit 680 | ? filteredPredictions.slice(0, options.limit) 681 | : filteredPredictions; 682 | 683 | const result = limitedPredictions.map((prediction) => ({ 684 | id: prediction.id, 685 | version: prediction.version, 686 | status: prediction.status as PredictionStatus, 687 | input: prediction.input as ModelIO, 688 | output: prediction.output as ModelIO | undefined, 689 | error: prediction.error ? 
String(prediction.error) : undefined, 690 | logs: prediction.logs, 691 | created_at: prediction.created_at, 692 | started_at: prediction.started_at, 693 | completed_at: prediction.completed_at, 694 | urls: prediction.urls, 695 | metrics: prediction.metrics, 696 | })); 697 | 698 | // Cache the result 699 | predictionCache.set(cacheKey, result); 700 | return result; 701 | } catch (error) { 702 | throw ErrorHandler.parseAPIError(error as Response); 703 | } 704 | } 705 | 706 | /** 707 | * Get details of a specific model including versions. 708 | */ 709 | async getModel(owner: string, name: string): Promise { 710 | try { 711 | // Check cache first 712 | const cacheKey = `model:${owner}/${name}`; 713 | const cached = modelCache.get(cacheKey); 714 | if (cached?.models?.length === 1) { 715 | return cached.models[0]; 716 | } 717 | 718 | // Use direct API request to get model details 719 | const response = await this.makeRequest( 720 | "GET", 721 | `/models/${owner}/${name}` 722 | ).catch((error) => { 723 | throw ErrorHandler.parseAPIError(error); 724 | }); 725 | 726 | // Get model versions 727 | const versionsResponse = await this.makeRequest< 728 | ReplicatePage 729 | >("GET", `/models/${owner}/${name}/versions`).catch((error) => { 730 | throw ErrorHandler.parseAPIError(error); 731 | }); 732 | 733 | const model: Model = { 734 | id: `${response.owner}/${response.name}`, 735 | owner: response.owner, 736 | name: response.name, 737 | description: response.description || "", 738 | visibility: response.visibility || "public", 739 | github_url: response.github_url, 740 | paper_url: response.paper_url, 741 | license_url: response.license_url, 742 | run_count: response.run_count, 743 | cover_image_url: response.cover_image_url, 744 | default_example: response.default_example, 745 | latest_version: versionsResponse.results[0] 746 | ? 
{ 747 | id: versionsResponse.results[0].id, 748 | created_at: versionsResponse.results[0].created_at, 749 | cog_version: versionsResponse.results[0].cog_version, 750 | openapi_schema: { 751 | ...versionsResponse.results[0].openapi_schema, 752 | openapi: "3.0.0", 753 | info: { 754 | title: `${response.owner}/${response.name}`, 755 | version: versionsResponse.results[0].id, 756 | }, 757 | paths: {}, 758 | }, 759 | } 760 | : undefined, 761 | }; 762 | 763 | // Cache the result as a single-item model list 764 | modelCache.set(cacheKey, { 765 | models: [model], 766 | total_count: 1, 767 | }); 768 | return model; 769 | } catch (error) { 770 | if (error instanceof Promise) { 771 | throw new ReplicateError("Failed to fetch model details"); 772 | } 773 | throw ErrorHandler.parseAPIError(error as Response); 774 | } 775 | } 776 | 777 | /** 778 | * Get the webhook signing secret. 779 | */ 780 | async getWebhookSecret(): Promise { 781 | interface WebhookResponse { 782 | key: string; 783 | } 784 | 785 | const response = await this.makeRequest( 786 | "GET", 787 | "/webhooks/default/secret" 788 | ); 789 | 790 | return response.key; 791 | } 792 | } 793 | -------------------------------------------------------------------------------- /src/services/cache.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * Cache service implementation with TTL and LRU strategy. 
3 | */ 4 | 5 | import type { Collection, CollectionList } from "../models/collection.js"; 6 | import type { Model, ModelList } from "../models/model.js"; 7 | import type { Prediction } from "../models/prediction.js"; 8 | 9 | interface CacheEntry { 10 | value: T; 11 | timestamp: number; 12 | lastAccessed: number; 13 | } 14 | 15 | export interface CacheStats { 16 | hits: number; 17 | misses: number; 18 | evictions: number; 19 | size: number; 20 | } 21 | 22 | export class Cache { 23 | private cache: Map>; 24 | private maxSize: number; 25 | private ttl: number; 26 | private stats: CacheStats; 27 | 28 | constructor(maxSize = 1000, ttlSeconds = 300) { 29 | this.cache = new Map(); 30 | this.maxSize = maxSize; 31 | this.ttl = ttlSeconds * 1000; // Convert to milliseconds 32 | this.stats = { 33 | hits: 0, 34 | misses: 0, 35 | evictions: 0, 36 | size: 0, 37 | }; 38 | } 39 | 40 | /** 41 | * Set a value in the cache with TTL. 42 | */ 43 | set(key: string, value: T): void { 44 | // Evict oldest entries if cache is full 45 | if (this.cache.size >= this.maxSize) { 46 | this.evictOldest(); 47 | } 48 | 49 | this.cache.set(key, { 50 | value, 51 | timestamp: Date.now(), 52 | lastAccessed: Date.now(), 53 | }); 54 | this.stats.size = this.cache.size; 55 | } 56 | 57 | /** 58 | * Get a value from the cache, considering TTL. 59 | */ 60 | get(key: string): T | null { 61 | const entry = this.cache.get(key); 62 | 63 | if (!entry) { 64 | this.stats.misses++; 65 | return null; 66 | } 67 | 68 | // Check if entry has expired 69 | if (Date.now() - entry.timestamp > this.ttl) { 70 | this.cache.delete(key); 71 | this.stats.evictions++; 72 | this.stats.size = this.cache.size; 73 | this.stats.misses++; 74 | return null; 75 | } 76 | 77 | // Update last accessed time for LRU 78 | entry.lastAccessed = Date.now(); 79 | this.stats.hits++; 80 | return entry.value; 81 | } 82 | 83 | /** 84 | * Remove a specific key from the cache. 
85 | */ 86 | delete(key: string): void { 87 | if (this.cache.delete(key)) { 88 | this.stats.evictions++; 89 | this.stats.size = this.cache.size; 90 | } 91 | } 92 | 93 | /** 94 | * Clear all entries from the cache. 95 | */ 96 | clear(): void { 97 | this.cache.clear(); 98 | this.stats.evictions += this.stats.size; 99 | this.stats.size = 0; 100 | } 101 | 102 | /** 103 | * Get cache statistics. 104 | */ 105 | getStats(): CacheStats { 106 | return { ...this.stats }; 107 | } 108 | 109 | /** 110 | * Warm up the cache with initial data. 111 | */ 112 | warmup(entries: [string, T][]): void { 113 | for (const [key, value] of entries) { 114 | this.set(key, value); 115 | } 116 | } 117 | 118 | /** 119 | * Remove expired entries from the cache. 120 | */ 121 | cleanup(): number { 122 | const now = Date.now(); 123 | let removed = 0; 124 | 125 | for (const [key, entry] of this.cache.entries()) { 126 | if (now - entry.timestamp > this.ttl) { 127 | this.cache.delete(key); 128 | removed++; 129 | } 130 | } 131 | 132 | this.stats.evictions += removed; 133 | this.stats.size = this.cache.size; 134 | return removed; 135 | } 136 | 137 | /** 138 | * Get all valid (non-expired) keys in the cache. 139 | */ 140 | keys(): string[] { 141 | const now = Date.now(); 142 | return Array.from(this.cache.entries()) 143 | .filter(([_, entry]) => now - entry.timestamp <= this.ttl) 144 | .map(([key]) => key); 145 | } 146 | 147 | /** 148 | * Check if a key exists and is not expired. 
149 | */ 150 | has(key: string): boolean { 151 | const entry = this.cache.get(key); 152 | if (!entry) return false; 153 | 154 | const expired = Date.now() - entry.timestamp > this.ttl; 155 | if (expired) { 156 | this.cache.delete(key); 157 | this.stats.evictions++; 158 | this.stats.size = this.cache.size; 159 | return false; 160 | } 161 | 162 | return true; 163 | } 164 | 165 | private evictOldest(): void { 166 | let oldestKey: string | null = null; 167 | let oldestAccess = Number.POSITIVE_INFINITY; 168 | 169 | // Find the least recently used entry 170 | for (const [key, entry] of this.cache.entries()) { 171 | if (entry.lastAccessed < oldestAccess) { 172 | oldestAccess = entry.lastAccessed; 173 | oldestKey = key; 174 | } 175 | } 176 | 177 | // Remove the oldest entry 178 | if (oldestKey) { 179 | this.cache.delete(oldestKey); 180 | this.stats.evictions++; 181 | this.stats.size = this.cache.size; 182 | } 183 | } 184 | } 185 | 186 | // Create specialized cache instances for different types 187 | export const modelCache = new Cache(500, 3600); // 1 hour TTL for models 188 | export const predictionCache = new Cache(1000, 60); // 1 minute TTL for predictions 189 | export const collectionCache = new Cache(100, 3600); // 1 hour TTL for collections 190 | -------------------------------------------------------------------------------- /src/services/error.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * Simple error handling system for Replicate API. 3 | */ 4 | 5 | interface APIErrorResponse { 6 | error?: string; 7 | code?: string; 8 | details?: Record; 9 | } 10 | 11 | /** 12 | * Base class for all Replicate API errors. 
13 | */ 14 | export class ReplicateError extends Error { 15 | constructor(message: string, public context?: Record) { 16 | super(message); 17 | this.name = "ReplicateError"; 18 | } 19 | } 20 | 21 | // Re-export specialized error types as ReplicateError instances 22 | export const createError = { 23 | rateLimit: (retryAfter: number) => 24 | new ReplicateError("Rate limit exceeded", { retryAfter }), 25 | 26 | authentication: (details?: string) => 27 | new ReplicateError("Authentication failed", { details }), 28 | 29 | notFound: (resource: string) => 30 | new ReplicateError("Model not found", { resource }), 31 | 32 | api: (status: number, message: string) => 33 | new ReplicateError("API error", { status, message }), 34 | 35 | prediction: (id: string, message: string) => 36 | new ReplicateError("Prediction failed", { predictionId: id, message }), 37 | 38 | validation: (field: string, message: string) => 39 | new ReplicateError("Invalid input parameters", { field, message }), 40 | 41 | timeout: (operation: string, ms: number) => 42 | new ReplicateError("Operation timed out", { operation, timeoutMs: ms }), 43 | }; 44 | 45 | interface RetryOptions { 46 | maxAttempts?: number; 47 | minDelay?: number; 48 | maxDelay?: number; 49 | backoffFactor?: number; 50 | retryIf?: (error: Error) => boolean; 51 | onRetry?: (error: Error, attempt: number) => void; 52 | } 53 | 54 | /** 55 | * Error handling utilities. 56 | */ 57 | export const ErrorHandler = { 58 | /** 59 | * Check if an error should trigger a retry. 60 | */ 61 | isRetryable(error: Error): boolean { 62 | if (!(error instanceof ReplicateError)) return false; 63 | 64 | const retryableMessages = [ 65 | "Rate limit exceeded", 66 | "Internal server error", 67 | "Gateway timeout", 68 | "Service unavailable", 69 | ]; 70 | 71 | return retryableMessages.some((msg) => error.message.includes(msg)); 72 | }, 73 | 74 | /** 75 | * Calculate delay for exponential backoff. 
76 | */ 77 | getBackoffDelay( 78 | attempt: number, 79 | { 80 | minDelay = 1000, 81 | maxDelay = 30000, 82 | backoffFactor = 2, 83 | }: Partial = {} 84 | ): number { 85 | const delay = minDelay * backoffFactor ** attempt; 86 | return Math.min(delay, maxDelay); 87 | }, 88 | 89 | /** 90 | * Execute an operation with automatic retries. 91 | */ 92 | async withRetries( 93 | operation: () => Promise, 94 | optionsOrMaxAttempts: RetryOptions | number = {} 95 | ): Promise { 96 | const options: RetryOptions = 97 | typeof optionsOrMaxAttempts === "number" 98 | ? { maxAttempts: optionsOrMaxAttempts } 99 | : optionsOrMaxAttempts; 100 | 101 | const { 102 | maxAttempts = 3, 103 | minDelay = 1000, 104 | maxDelay = 30000, 105 | backoffFactor = 2, 106 | retryIf = this.isRetryable, 107 | onRetry, 108 | } = options; 109 | 110 | let lastError: Error | undefined; 111 | 112 | for (let attempt = 0; attempt < maxAttempts; attempt++) { 113 | try { 114 | return await operation(); 115 | } catch (error) { 116 | lastError = error instanceof Error ? error : new Error(String(error)); 117 | 118 | if (attempt === maxAttempts - 1 || !retryIf(lastError)) { 119 | throw lastError; 120 | } 121 | 122 | const delay = this.getBackoffDelay(attempt, { 123 | minDelay, 124 | maxDelay, 125 | backoffFactor, 126 | }); 127 | 128 | if (onRetry) { 129 | onRetry(lastError, attempt); 130 | } 131 | 132 | await new Promise((resolve) => setTimeout(resolve, delay)); 133 | } 134 | } 135 | 136 | throw lastError || new Error("Operation failed"); 137 | }, 138 | 139 | /** 140 | * Parse error from API response. 
141 | */ 142 | async parseAPIError(response: Response): Promise { 143 | const status = response.status; 144 | let message = response.statusText; 145 | let context: Record = { status }; 146 | 147 | try { 148 | const contentType = response.headers.get("content-type"); 149 | if (contentType?.includes("application/json")) { 150 | const details = (await response.json()) as APIErrorResponse; 151 | message = details.error || message; 152 | context = { ...context, ...details }; 153 | } 154 | } catch { 155 | // Ignore JSON parsing errors 156 | } 157 | 158 | return new ReplicateError(message, context); 159 | }, 160 | 161 | /** 162 | * Create a detailed error report. 163 | */ 164 | createErrorReport(error: Error): { 165 | name: string; 166 | message: string; 167 | context?: Record; 168 | timestamp: string; 169 | } { 170 | if (error instanceof ReplicateError) { 171 | return { 172 | name: error.name, 173 | message: error.message, 174 | context: error.context, 175 | timestamp: new Date().toISOString(), 176 | }; 177 | } 178 | 179 | return { 180 | name: error.name || "Error", 181 | message: error.message, 182 | timestamp: new Date().toISOString(), 183 | }; 184 | }, 185 | }; 186 | -------------------------------------------------------------------------------- /src/services/image_viewer.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * Image viewer service for handling system image display and caching. 
3 | */ 4 | 5 | import { Cache, type CacheStats } from "./cache.js"; 6 | 7 | // MCP-specific error types 8 | export enum ErrorCode { 9 | InvalidRequest = "INVALID_REQUEST", 10 | InternalError = "INTERNAL_ERROR", 11 | UnsupportedFormat = "UNSUPPORTED_FORMAT", 12 | } 13 | 14 | export class McpError extends Error { 15 | constructor(public code: ErrorCode, message: string) { 16 | super(message); 17 | this.name = "McpError"; 18 | } 19 | } 20 | 21 | // Supported image formats and their MIME types 22 | export const IMAGE_MIME_TYPES = { 23 | "image/jpeg": [".jpg", ".jpeg"], 24 | "image/png": [".png"], 25 | "image/gif": [".gif"], 26 | "image/webp": [".webp"], 27 | } as const; 28 | 29 | export type ImageFormat = keyof typeof IMAGE_MIME_TYPES; 30 | 31 | interface ImageMetadata { 32 | format: ImageFormat; 33 | url: string; 34 | localPath?: string; 35 | width?: number; 36 | height?: number; 37 | } 38 | 39 | // Create a specialized cache for images with 1 hour TTL 40 | export const imageCache = new Cache(200, 3600); 41 | 42 | export class ImageViewer { 43 | private static instance: ImageViewer; 44 | 45 | private constructor() { 46 | // Private constructor for singleton pattern 47 | } 48 | 49 | /** 50 | * Get singleton instance of ImageViewer 51 | */ 52 | static getInstance(): ImageViewer { 53 | if (!ImageViewer.instance) { 54 | ImageViewer.instance = new ImageViewer(); 55 | } 56 | return ImageViewer.instance; 57 | } 58 | 59 | /** 60 | * Display an image in the system's default web browser 61 | */ 62 | async displayImage(url: string): Promise { 63 | try { 64 | // Check cache first 65 | const cached = imageCache.get(url); 66 | if (cached) { 67 | await this.openInBrowser(cached.localPath || cached.url); 68 | return; 69 | } 70 | 71 | // If not cached, fetch and cache the image metadata 72 | const metadata = await this.fetchImageMetadata(url); 73 | imageCache.set(url, metadata); 74 | 75 | // Display the image 76 | await this.openInBrowser(metadata.localPath || url); 77 | } catch 
(error: unknown) { 78 | throw new McpError( 79 | ErrorCode.InternalError, 80 | `Failed to display image: ${ 81 | error instanceof Error ? error.message : String(error) 82 | }` 83 | ); 84 | } 85 | } 86 | 87 | /** 88 | * Fetch metadata for an image URL 89 | */ 90 | private async fetchImageMetadata(url: string): Promise { 91 | // For now, just return basic metadata without strict MIME type checking 92 | return { 93 | format: "image/png", // Default format 94 | url: url, 95 | }; 96 | } 97 | 98 | /** 99 | * Open an image URL in the system's default web browser 100 | */ 101 | private async openInBrowser(url: string): Promise { 102 | try { 103 | // Create a simple HTML page to display the image 104 | const html = ` 105 | 106 | 107 | 108 | Image Viewer 109 | 125 | 126 | 127 | Image preview 128 | 129 | `; 130 | 131 | // Create a temporary file to host the HTML 132 | const tempPath = `/tmp/mcp-image-viewer-${Date.now()}.html`; 133 | const fs = await import("node:fs/promises"); 134 | await fs.writeFile(tempPath, html); 135 | 136 | // Launch browser with the HTML file 137 | const fileUrl = `file://${tempPath}`; 138 | 139 | // Use the system's browser_action tool 140 | const { exec } = await import("node:child_process"); 141 | const { promisify } = await import("node:util"); 142 | const execAsync = promisify(exec); 143 | 144 | // Open the file in the default browser 145 | if (process.platform === "darwin") { 146 | await execAsync(`open "${fileUrl}"`); 147 | } else if (process.platform === "win32") { 148 | await execAsync(`start "" "${fileUrl}"`); 149 | } else { 150 | await execAsync(`xdg-open "${fileUrl}"`); 151 | } 152 | 153 | // Clean up the temporary file after a delay to ensure browser has loaded it 154 | setTimeout(async () => { 155 | try { 156 | await fs.unlink(tempPath); 157 | } catch (error) { 158 | console.error("Failed to clean up temporary file:", error); 159 | } 160 | }, 5000); 161 | } catch (error: unknown) { 162 | throw new McpError( 163 | ErrorCode.InternalError, 
164 | `Failed to open browser: ${ 165 | error instanceof Error ? error.message : String(error) 166 | }` 167 | ); 168 | } 169 | } 170 | 171 | /** 172 | * Clear the image cache 173 | */ 174 | clearCache(): void { 175 | imageCache.clear(); 176 | } 177 | 178 | /** 179 | * Get statistics about the image cache 180 | */ 181 | getCacheStats(): CacheStats { 182 | return imageCache.getStats(); 183 | } 184 | } 185 | -------------------------------------------------------------------------------- /src/services/webhook.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * Webhook delivery service. 3 | */ 4 | 5 | import { 6 | type WebhookConfig, 7 | type WebhookEvent, 8 | type WebhookDeliveryResult, 9 | DEFAULT_WEBHOOK_CONFIG, 10 | generateWebhookSignature, 11 | calculateRetryDelay, 12 | validateWebhookConfig, 13 | } from "../models/webhook.js"; 14 | 15 | interface QueuedWebhook { 16 | id: string; 17 | config: WebhookConfig; 18 | event: WebhookEvent; 19 | retryCount: number; 20 | lastAttempt?: Date; 21 | nextAttempt?: Date; 22 | } 23 | 24 | /** 25 | * Service for managing webhook deliveries with retry logic. 26 | */ 27 | export class WebhookService { 28 | private queue: Map; 29 | private processing: boolean; 30 | private deliveryResults: Map; 31 | 32 | constructor() { 33 | this.queue = new Map(); 34 | this.processing = false; 35 | this.deliveryResults = new Map(); 36 | } 37 | 38 | /** 39 | * Validate webhook configuration. 40 | */ 41 | validateWebhookConfig(config: WebhookConfig): string[] { 42 | return validateWebhookConfig(config); 43 | } 44 | 45 | /** 46 | * Queue a webhook event for delivery. 
47 | */ 48 | async queueWebhook( 49 | config: Partial, 50 | event: WebhookEvent 51 | ): Promise { 52 | const id = crypto.randomUUID(); 53 | const fullConfig = { 54 | ...DEFAULT_WEBHOOK_CONFIG, 55 | ...config, 56 | } as WebhookConfig; 57 | 58 | this.queue.set(id, { 59 | id, 60 | config: fullConfig, 61 | event, 62 | retryCount: 0, 63 | }); 64 | 65 | // Start processing if not already running 66 | if (!this.processing) { 67 | this.processQueue().catch(console.error); 68 | } 69 | 70 | return id; 71 | } 72 | 73 | /** 74 | * Get delivery results for a webhook. 75 | */ 76 | getDeliveryResults(webhookId: string): WebhookDeliveryResult[] { 77 | return this.deliveryResults.get(webhookId) || []; 78 | } 79 | 80 | /** 81 | * Process the webhook queue. 82 | */ 83 | private async processQueue(): Promise { 84 | if (this.processing) return; 85 | 86 | this.processing = true; 87 | try { 88 | // Process all webhooks immediately in test environment 89 | if (process.env.NODE_ENV === "test") { 90 | const webhooks = Array.from(this.queue.values()); 91 | await Promise.all( 92 | webhooks.map((webhook) => this.deliverWebhook(webhook)) 93 | ); 94 | return; 95 | } 96 | 97 | // Normal processing for non-test environment 98 | while (this.queue.size > 0) { 99 | const now = new Date(); 100 | const readyWebhooks = Array.from(this.queue.values()).filter( 101 | (webhook) => !webhook.nextAttempt || webhook.nextAttempt <= now 102 | ); 103 | 104 | if (readyWebhooks.length === 0) { 105 | // No webhooks ready for delivery, wait for the next one 106 | const nextAttempt = Math.min( 107 | ...Array.from(this.queue.values()) 108 | .map((w) => w.nextAttempt?.getTime() || Date.now()) 109 | .filter((t) => t > Date.now()) 110 | ); 111 | await new Promise((resolve) => 112 | setTimeout(resolve, nextAttempt - Date.now()) 113 | ); 114 | continue; 115 | } 116 | 117 | // Process ready webhooks in parallel 118 | await Promise.all( 119 | readyWebhooks.map((webhook) => this.deliverWebhook(webhook)) 120 | ); 121 | } 122 | } 
finally { 123 | this.processing = false; 124 | } 125 | } 126 | 127 | /** 128 | * Attempt to deliver a webhook. 129 | */ 130 | private async deliverWebhook(webhook: QueuedWebhook): Promise { 131 | const { id, config, event, retryCount } = webhook; 132 | const payload = JSON.stringify(event); 133 | 134 | // Prepare headers 135 | const headers: Record = { 136 | "Content-Type": "application/json", 137 | "User-Agent": "MCP-Replicate-Webhook/1.0", 138 | "X-Webhook-ID": id, 139 | "X-Event-Type": event.type, 140 | "X-Timestamp": event.timestamp, 141 | }; 142 | 143 | // Add signature if secret is provided 144 | if (config.secret) { 145 | headers["X-Signature"] = generateWebhookSignature(payload, config.secret); 146 | } 147 | 148 | try { 149 | // Attempt delivery with timeout 150 | const controller = new AbortController(); 151 | const timeout = setTimeout( 152 | () => controller.abort(), 153 | config.timeout || DEFAULT_WEBHOOK_CONFIG.timeout 154 | ); 155 | 156 | let response: Response; 157 | try { 158 | response = await fetch(config.url, { 159 | method: "POST", 160 | headers, 161 | body: payload, 162 | signal: controller.signal, 163 | }); 164 | 165 | clearTimeout(timeout); 166 | 167 | // Record result 168 | const result: WebhookDeliveryResult = { 169 | success: response.ok, 170 | statusCode: response.status, 171 | retryCount, 172 | timestamp: new Date().toISOString(), 173 | }; 174 | 175 | if (!response.ok) { 176 | result.error = `HTTP ${response.status}: ${response.statusText}`; 177 | } 178 | 179 | this.recordDeliveryResult(id, result); 180 | 181 | // Handle failed delivery 182 | const maxRetries = 183 | config.retries ?? DEFAULT_WEBHOOK_CONFIG.retries ?? 
3; 184 | if (!response.ok && retryCount < maxRetries) { 185 | // Schedule retry 186 | const delay = calculateRetryDelay(retryCount); 187 | webhook.retryCount++; 188 | webhook.lastAttempt = new Date(); 189 | webhook.nextAttempt = new Date(Date.now() + delay); 190 | 191 | // In test environment, immediately retry 192 | if (process.env.NODE_ENV === "test") { 193 | await this.deliverWebhook(webhook); 194 | } 195 | return; 196 | } 197 | 198 | // Delivery succeeded or max retries reached 199 | this.queue.delete(id); 200 | } catch (error) { 201 | clearTimeout(timeout); 202 | throw error; 203 | } 204 | } catch (error) { 205 | // Record error result 206 | const result: WebhookDeliveryResult = { 207 | success: false, 208 | error: error instanceof Error ? error.message : String(error), 209 | retryCount, 210 | timestamp: new Date().toISOString(), 211 | }; 212 | 213 | this.recordDeliveryResult(id, result); 214 | 215 | // Handle error 216 | const maxRetries = config.retries ?? DEFAULT_WEBHOOK_CONFIG.retries ?? 3; 217 | if (retryCount < maxRetries) { 218 | // Schedule retry 219 | const delay = calculateRetryDelay(retryCount); 220 | webhook.retryCount++; 221 | webhook.lastAttempt = new Date(); 222 | webhook.nextAttempt = new Date(Date.now() + delay); 223 | 224 | // In test environment, immediately retry 225 | if (process.env.NODE_ENV === "test") { 226 | await this.deliverWebhook(webhook); 227 | } 228 | return; 229 | } 230 | 231 | // Max retries reached 232 | this.queue.delete(id); 233 | } 234 | } 235 | 236 | /** 237 | * Record a delivery result. 
238 | */ 239 | private recordDeliveryResult( 240 | webhookId: string, 241 | result: WebhookDeliveryResult 242 | ): void { 243 | const results = this.deliveryResults.get(webhookId) || []; 244 | results.push(result); 245 | this.deliveryResults.set(webhookId, results); 246 | 247 | // Clean up old results (keep last 10) 248 | if (results.length > 10) { 249 | results.splice(0, results.length - 10); 250 | } 251 | } 252 | } 253 | -------------------------------------------------------------------------------- /src/templates/manager.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * Template manager for handling image generation parameters. 3 | */ 4 | 5 | import { qualityPresets, type QualityPreset } from "./parameters/quality.js"; 6 | import { stylePresets, type StylePreset } from "./parameters/style.js"; 7 | import { 8 | sizePresets, 9 | type SizePreset, 10 | scaleToMaxSize, 11 | } from "./parameters/size.js"; 12 | 13 | import type { ModelIO } from "../models/openapi.js"; 14 | 15 | export interface ImageGenerationParameters extends ModelIO { 16 | prompt: string; 17 | negative_prompt?: string; 18 | width: number; 19 | height: number; 20 | num_inference_steps?: number; 21 | guidance_scale?: number; 22 | scheduler?: string; 23 | style_strength?: number; 24 | seed?: number; 25 | num_outputs?: number; 26 | } 27 | 28 | export interface TemplateOptions { 29 | quality?: keyof typeof qualityPresets; 30 | style?: keyof typeof stylePresets; 31 | size?: keyof typeof sizePresets; 32 | custom_size?: { width: number; height: number }; 33 | seed?: number; 34 | num_outputs?: number; 35 | } 36 | 37 | /** 38 | * Manages templates and parameter generation for image generation. 39 | */ 40 | export class TemplateManager { 41 | private maxImageSize: number; 42 | 43 | constructor(maxImageSize = 1024) { 44 | this.maxImageSize = maxImageSize; 45 | } 46 | 47 | /** 48 | * Get all available presets. 
49 | */ 50 | getAvailablePresets() { 51 | return { 52 | quality: Object.entries(qualityPresets).map(([id, preset]) => ({ 53 | id, 54 | ...preset, 55 | })), 56 | style: Object.entries(stylePresets).map(([id, preset]) => ({ 57 | id, 58 | ...preset, 59 | })), 60 | size: Object.entries(sizePresets).map(([id, preset]) => ({ 61 | id, 62 | ...preset, 63 | })), 64 | }; 65 | } 66 | 67 | /** 68 | * Generate parameters by combining presets and options. 69 | */ 70 | generateParameters( 71 | prompt: string, 72 | options: TemplateOptions = {} 73 | ): ImageGenerationParameters { 74 | // Get presets 75 | const qualityPreset = options.quality 76 | ? qualityPresets[options.quality] 77 | : qualityPresets.balanced; 78 | const stylePreset = options.style 79 | ? stylePresets[options.style] 80 | : stylePresets.photorealistic; 81 | const sizePreset = options.size 82 | ? sizePresets[options.size] 83 | : sizePresets.square; 84 | 85 | // Handle custom size 86 | let { width, height } = options.custom_size || sizePreset.parameters; 87 | if (width > this.maxImageSize || height > this.maxImageSize) { 88 | ({ width, height } = scaleToMaxSize(width, height, this.maxImageSize)); 89 | } 90 | 91 | // Combine prompts 92 | const fullPrompt = [ 93 | stylePreset.parameters.prompt_prefix, 94 | prompt.trim(), 95 | stylePreset.parameters.prompt_suffix, 96 | ] 97 | .filter(Boolean) 98 | .join(" "); 99 | 100 | // Combine negative prompts 101 | const negativePrompts = [ 102 | qualityPreset.parameters.negative_prompt, 103 | stylePreset.parameters.negative_prompt, 104 | ] 105 | .filter(Boolean) 106 | .join(", "); 107 | 108 | // Combine parameters 109 | return { 110 | prompt: fullPrompt, 111 | negative_prompt: negativePrompts || undefined, 112 | width, 113 | height, 114 | num_inference_steps: qualityPreset.parameters.num_inference_steps, 115 | guidance_scale: qualityPreset.parameters.guidance_scale, 116 | scheduler: qualityPreset.parameters.scheduler, 117 | style_strength: stylePreset.parameters.style_strength, 
118 | seed: options.seed, 119 | num_outputs: options.num_outputs || 1, 120 | }; 121 | } 122 | 123 | /** 124 | * Validate parameters against model constraints. 125 | */ 126 | validateParameters( 127 | parameters: ImageGenerationParameters, 128 | modelConstraints: { 129 | min_width?: number; 130 | max_width?: number; 131 | min_height?: number; 132 | max_height?: number; 133 | step_size?: number; 134 | supported_schedulers?: string[]; 135 | } = {} 136 | ): void { 137 | const errors: string[] = []; 138 | 139 | // Basic validation 140 | if (parameters.width <= 0) { 141 | errors.push("Width must be positive"); 142 | } 143 | if (parameters.height <= 0) { 144 | errors.push("Height must be positive"); 145 | } 146 | 147 | // Model constraints validation 148 | if ( 149 | modelConstraints.min_width && 150 | parameters.width < modelConstraints.min_width 151 | ) { 152 | errors.push(`Width must be at least ${modelConstraints.min_width}`); 153 | } 154 | if ( 155 | modelConstraints.max_width && 156 | parameters.width > modelConstraints.max_width 157 | ) { 158 | errors.push(`Width must be at most ${modelConstraints.max_width}`); 159 | } 160 | if ( 161 | modelConstraints.min_height && 162 | parameters.height < modelConstraints.min_height 163 | ) { 164 | errors.push(`Height must be at least ${modelConstraints.min_height}`); 165 | } 166 | if ( 167 | modelConstraints.max_height && 168 | parameters.height > modelConstraints.max_height 169 | ) { 170 | errors.push(`Height must be at most ${modelConstraints.max_height}`); 171 | } 172 | 173 | // Validate step size 174 | if (modelConstraints.step_size) { 175 | if (parameters.width % modelConstraints.step_size !== 0) { 176 | errors.push( 177 | `Width must be a multiple of ${modelConstraints.step_size}` 178 | ); 179 | } 180 | if (parameters.height % modelConstraints.step_size !== 0) { 181 | errors.push( 182 | `Height must be a multiple of ${modelConstraints.step_size}` 183 | ); 184 | } 185 | } 186 | 187 | // Validate scheduler 188 | if ( 189 | 
modelConstraints.supported_schedulers && 190 | parameters.scheduler && 191 | !modelConstraints.supported_schedulers.includes(parameters.scheduler) 192 | ) { 193 | errors.push( 194 | `Scheduler must be one of: ${modelConstraints.supported_schedulers.join( 195 | ", " 196 | )}` 197 | ); 198 | } 199 | 200 | // Validate other parameters 201 | if (parameters.num_inference_steps && parameters.num_inference_steps < 1) { 202 | errors.push("Number of inference steps must be positive"); 203 | } 204 | if (parameters.guidance_scale && parameters.guidance_scale < 1) { 205 | errors.push("Guidance scale must be positive"); 206 | } 207 | if ( 208 | parameters.style_strength && 209 | (parameters.style_strength < 0 || parameters.style_strength > 1) 210 | ) { 211 | errors.push("Style strength must be between 0 and 1"); 212 | } 213 | if (parameters.num_outputs && parameters.num_outputs < 1) { 214 | errors.push("Number of outputs must be positive"); 215 | } 216 | 217 | if (errors.length > 0) { 218 | throw new Error(`Parameter validation failed:\n${errors.join("\n")}`); 219 | } 220 | } 221 | 222 | /** 223 | * Suggest parameters based on prompt analysis. 
224 | */ 225 | suggestParameters(prompt: string): TemplateOptions { 226 | // Simple keyword-based suggestions 227 | const suggestions: TemplateOptions = {}; 228 | 229 | // Quality suggestions 230 | if (prompt.includes("quick") || prompt.includes("draft")) { 231 | suggestions.quality = "draft"; 232 | } else if (prompt.includes("high quality") || prompt.includes("detailed")) { 233 | suggestions.quality = "quality"; 234 | } 235 | 236 | // Style suggestions 237 | if (prompt.includes("photo") || prompt.includes("realistic")) { 238 | suggestions.style = "photorealistic"; 239 | } else if (prompt.includes("anime") || prompt.includes("manga")) { 240 | suggestions.style = "anime"; 241 | } else if (prompt.includes("painting") || prompt.includes("oil")) { 242 | suggestions.style = "oil_painting"; 243 | } else if (prompt.includes("watercolor")) { 244 | suggestions.style = "watercolor"; 245 | } else if (prompt.includes("pixel") || prompt.includes("8-bit")) { 246 | suggestions.style = "pixel_art"; 247 | } else if (prompt.includes("minimal")) { 248 | suggestions.style = "minimalist"; 249 | } 250 | 251 | // Size suggestions 252 | if (prompt.includes("portrait") || prompt.includes("vertical")) { 253 | suggestions.size = "portrait"; 254 | } else if (prompt.includes("landscape") || prompt.includes("horizontal")) { 255 | suggestions.size = "landscape"; 256 | } else if (prompt.includes("panorama") || prompt.includes("wide")) { 257 | suggestions.size = "panoramic"; 258 | } else if (prompt.includes("instagram")) { 259 | suggestions.size = prompt.includes("story") 260 | ? 
"instagram_story" 261 | : "instagram_post"; 262 | } else if (prompt.includes("twitter header")) { 263 | suggestions.size = "twitter_header"; 264 | } 265 | 266 | return suggestions; 267 | } 268 | } 269 | -------------------------------------------------------------------------------- /src/templates/parameters/quality.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * Quality presets for image generation. 3 | */ 4 | 5 | export interface QualityPreset { 6 | name: string; 7 | description: string; 8 | parameters: { 9 | num_inference_steps?: number; 10 | guidance_scale?: number; 11 | scheduler?: string; 12 | negative_prompt?: string; 13 | }; 14 | } 15 | 16 | export const qualityPresets: Record = { 17 | draft: { 18 | name: "Draft", 19 | description: "Quick, low-quality preview with minimal steps", 20 | parameters: { 21 | num_inference_steps: 20, 22 | guidance_scale: 7, 23 | scheduler: "DPMSolverMultistep", 24 | }, 25 | }, 26 | balanced: { 27 | name: "Balanced", 28 | description: "Good balance between quality and speed", 29 | parameters: { 30 | num_inference_steps: 30, 31 | guidance_scale: 7.5, 32 | scheduler: "DPMSolverMultistep", 33 | negative_prompt: "blurry, low quality, distorted", 34 | }, 35 | }, 36 | quality: { 37 | name: "Quality", 38 | description: "High-quality output with more steps", 39 | parameters: { 40 | num_inference_steps: 50, 41 | guidance_scale: 8, 42 | scheduler: "DPMSolverMultistep", 43 | negative_prompt: "blurry, low quality, distorted, ugly, deformed", 44 | }, 45 | }, 46 | extreme: { 47 | name: "Extreme", 48 | description: "Maximum quality with extensive steps", 49 | parameters: { 50 | num_inference_steps: 100, 51 | guidance_scale: 9, 52 | scheduler: "DPMSolverMultistep", 53 | negative_prompt: 54 | "blurry, low quality, distorted, ugly, deformed, noisy, grainy, oversaturated", 55 | }, 56 | }, 57 | }; 58 | -------------------------------------------------------------------------------- 
/src/templates/parameters/size.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * Size and aspect ratio presets for image generation. 3 | */ 4 | 5 | export interface SizePreset { 6 | name: string; 7 | description: string; 8 | parameters: { 9 | width: number; 10 | height: number; 11 | aspect_ratio?: string; 12 | recommended_for?: string[]; 13 | }; 14 | } 15 | 16 | export const sizePresets: Record = { 17 | square: { 18 | name: "Square", 19 | description: "Perfect square format", 20 | parameters: { 21 | width: 1024, 22 | height: 1024, 23 | aspect_ratio: "1:1", 24 | recommended_for: ["social media", "profile pictures", "album covers"], 25 | }, 26 | }, 27 | portrait: { 28 | name: "Portrait", 29 | description: "Vertical format for portraits", 30 | parameters: { 31 | width: 768, 32 | height: 1024, 33 | aspect_ratio: "3:4", 34 | recommended_for: ["portraits", "mobile wallpapers", "book covers"], 35 | }, 36 | }, 37 | landscape: { 38 | name: "Landscape", 39 | description: "Horizontal format for landscapes", 40 | parameters: { 41 | width: 1024, 42 | height: 768, 43 | aspect_ratio: "4:3", 44 | recommended_for: ["landscapes", "desktop wallpapers", "banners"], 45 | }, 46 | }, 47 | widescreen: { 48 | name: "Widescreen", 49 | description: "16:9 format for modern displays", 50 | parameters: { 51 | width: 1024, 52 | height: 576, 53 | aspect_ratio: "16:9", 54 | recommended_for: [ 55 | "desktop backgrounds", 56 | "presentations", 57 | "video thumbnails", 58 | ], 59 | }, 60 | }, 61 | panoramic: { 62 | name: "Panoramic", 63 | description: "Extra wide format for panoramas", 64 | parameters: { 65 | width: 1024, 66 | height: 384, 67 | aspect_ratio: "21:9", 68 | recommended_for: [ 69 | "panoramic landscapes", 70 | "ultra-wide displays", 71 | "banners", 72 | ], 73 | }, 74 | }, 75 | instagram_post: { 76 | name: "Instagram Post", 77 | description: "Optimized for Instagram posts", 78 | parameters: { 79 | width: 1080, 80 | height: 1080, 81 | 
aspect_ratio: "1:1", 82 | recommended_for: ["instagram posts", "social media"], 83 | }, 84 | }, 85 | instagram_story: { 86 | name: "Instagram Story", 87 | description: "Optimized for Instagram stories", 88 | parameters: { 89 | width: 1080, 90 | height: 1920, 91 | aspect_ratio: "9:16", 92 | recommended_for: ["instagram stories", "mobile content"], 93 | }, 94 | }, 95 | twitter_header: { 96 | name: "Twitter Header", 97 | description: "Optimized for Twitter profile headers", 98 | parameters: { 99 | width: 1500, 100 | height: 500, 101 | aspect_ratio: "3:1", 102 | recommended_for: ["twitter headers", "social media banners"], 103 | }, 104 | }, 105 | }; 106 | 107 | /** 108 | * Get the closest size preset for given dimensions. 109 | */ 110 | export function findClosestSizePreset( 111 | width: number, 112 | height: number 113 | ): SizePreset { 114 | const targetRatio = width / height; 115 | let closestPreset = sizePresets.square; 116 | let smallestDiff = Number.POSITIVE_INFINITY; 117 | 118 | for (const preset of Object.values(sizePresets)) { 119 | const presetRatio = preset.parameters.width / preset.parameters.height; 120 | const ratioDiff = Math.abs(presetRatio - targetRatio); 121 | 122 | if (ratioDiff < smallestDiff) { 123 | smallestDiff = ratioDiff; 124 | closestPreset = preset; 125 | } 126 | } 127 | 128 | return closestPreset; 129 | } 130 | 131 | /** 132 | * Scale dimensions to fit within maximum size while maintaining aspect ratio. 133 | */ 134 | export function scaleToMaxSize( 135 | width: number, 136 | height: number, 137 | maxSize: number 138 | ): { width: number; height: number } { 139 | if (width <= maxSize && height <= maxSize) { 140 | return { width, height }; 141 | } 142 | 143 | const ratio = width / height; 144 | return ratio > 1 145 | ? 
{ 146 | // Width is larger 147 | width: maxSize, 148 | height: Math.round(maxSize / ratio), 149 | } 150 | : { 151 | // Height is larger 152 | width: Math.round(maxSize * ratio), 153 | height: maxSize, 154 | }; 155 | } 156 | -------------------------------------------------------------------------------- /src/templates/parameters/style.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * Style presets for image generation. 3 | */ 4 | 5 | export interface StylePreset { 6 | name: string; 7 | description: string; 8 | parameters: { 9 | prompt_prefix?: string; 10 | prompt_suffix?: string; 11 | negative_prompt?: string; 12 | style_strength?: number; 13 | }; 14 | } 15 | 16 | export const stylePresets: Record = { 17 | photorealistic: { 18 | name: "Photorealistic", 19 | description: "Highly detailed, realistic photography style", 20 | parameters: { 21 | prompt_prefix: "photorealistic, highly detailed photograph,", 22 | prompt_suffix: "8k uhd, high resolution, professional photography", 23 | negative_prompt: 24 | "illustration, painting, drawing, cartoon, anime, rendered, 3d, cgi", 25 | style_strength: 0.8, 26 | }, 27 | }, 28 | cinematic: { 29 | name: "Cinematic", 30 | description: "Movie-like scenes with dramatic lighting", 31 | parameters: { 32 | prompt_prefix: "cinematic shot, dramatic lighting, movie scene,", 33 | prompt_suffix: "anamorphic lens, film grain, depth of field, bokeh", 34 | negative_prompt: "flat lighting, flash photography, overexposed", 35 | style_strength: 0.85, 36 | }, 37 | }, 38 | anime: { 39 | name: "Anime", 40 | description: "Japanese anime and manga style", 41 | parameters: { 42 | prompt_prefix: "anime style, manga illustration,", 43 | prompt_suffix: "clean lines, vibrant colors, detailed anime drawing", 44 | negative_prompt: "photorealistic, 3d rendered, western animation", 45 | style_strength: 0.9, 46 | }, 47 | }, 48 | digital_art: { 49 | name: "Digital Art", 50 | description: "Modern digital art style", 
51 | parameters: { 52 | prompt_prefix: "digital art, concept art,", 53 | prompt_suffix: "highly detailed, sharp focus, vibrant colors", 54 | negative_prompt: "traditional media, watercolor, oil painting", 55 | style_strength: 0.85, 56 | }, 57 | }, 58 | oil_painting: { 59 | name: "Oil Painting", 60 | description: "Classical oil painting style", 61 | parameters: { 62 | prompt_prefix: "oil painting, traditional art, painterly,", 63 | prompt_suffix: "detailed brushwork, canvas texture, rich colors", 64 | negative_prompt: "digital art, photograph, 3d rendered", 65 | style_strength: 0.9, 66 | }, 67 | }, 68 | watercolor: { 69 | name: "Watercolor", 70 | description: "Soft watercolor painting style", 71 | parameters: { 72 | prompt_prefix: "watercolor painting, soft and dreamy,", 73 | prompt_suffix: "flowing colors, wet on wet technique, artistic", 74 | negative_prompt: "sharp edges, harsh contrast, digital art", 75 | style_strength: 0.85, 76 | }, 77 | }, 78 | pixel_art: { 79 | name: "Pixel Art", 80 | description: "Retro pixel art style", 81 | parameters: { 82 | prompt_prefix: "pixel art, retro game style,", 83 | prompt_suffix: "8-bit, pixelated, video game art", 84 | negative_prompt: "smooth gradients, photorealistic, high resolution", 85 | style_strength: 0.95, 86 | }, 87 | }, 88 | minimalist: { 89 | name: "Minimalist", 90 | description: "Clean, simple minimalist style", 91 | parameters: { 92 | prompt_prefix: "minimalist design, simple composition,", 93 | prompt_suffix: "clean lines, negative space, geometric", 94 | negative_prompt: "busy, cluttered, detailed, ornate", 95 | style_strength: 0.8, 96 | }, 97 | }, 98 | }; 99 | -------------------------------------------------------------------------------- /src/templates/prompts/text_to_image.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * Text-to-image generation prompt. 
3 | */ 4 | 5 | import type { MCPMessage } from "../../types/mcp.js"; 6 | import { TemplateManager, type TemplateOptions } from "../manager.js"; 7 | 8 | const templateManager = new TemplateManager(); 9 | 10 | /** 11 | * Generate a text-to-image prompt with parameter suggestions. 12 | */ 13 | export function generateTextToImagePrompt(userPrompt: string): MCPMessage { 14 | // Get parameter suggestions based on prompt 15 | const suggestions = templateManager.suggestParameters(userPrompt); 16 | 17 | // Get all available presets for reference 18 | const presets = templateManager.getAvailablePresets(); 19 | 20 | // Generate example parameters 21 | const exampleParams = templateManager.generateParameters( 22 | userPrompt, 23 | suggestions 24 | ); 25 | 26 | return { 27 | jsonrpc: "2.0", 28 | method: "prompt/text_to_image", 29 | params: { 30 | messages: [ 31 | { 32 | role: "assistant", 33 | content: { 34 | type: "text", 35 | text: "I'll help you generate an image. Here's what I understand from your prompt:", 36 | }, 37 | }, 38 | { 39 | role: "assistant", 40 | content: { 41 | type: "text", 42 | text: `Based on your description, I suggest: 43 | ${suggestions.quality ? `- Quality: ${suggestions.quality} mode` : ""} 44 | ${suggestions.style ? `- Style: ${suggestions.style} style` : ""} 45 | ${suggestions.size ? `- Size: ${suggestions.size} format` : ""} 46 | 47 | Here are the parameters I'll use: 48 | - Prompt: "${exampleParams.prompt}" 49 | ${ 50 | exampleParams.negative_prompt 51 | ? `- Negative prompt: "${exampleParams.negative_prompt}"` 52 | : "" 53 | } 54 | - Size: ${exampleParams.width}x${exampleParams.height} 55 | - Steps: ${exampleParams.num_inference_steps} 56 | - Guidance scale: ${exampleParams.guidance_scale} 57 | ${exampleParams.scheduler ? `- Scheduler: ${exampleParams.scheduler}` : ""} 58 | ${ 59 | exampleParams.style_strength 60 | ? `- Style strength: ${exampleParams.style_strength}` 61 | : "" 62 | } 63 | 64 | Would you like to adjust any of these settings? 
You can choose from: 65 | 66 | Quality presets: 67 | ${presets.quality.map((p) => `- ${p.name}: ${p.description}`).join("\n")} 68 | 69 | Style presets: 70 | ${presets.style.map((p) => `- ${p.name}: ${p.description}`).join("\n")} 71 | 72 | Size presets: 73 | ${presets.size.map((p) => `- ${p.name}: ${p.description}`).join("\n")} 74 | 75 | Or you can specify custom parameters: 76 | - Custom size (e.g., "make it 1024x768") 77 | - Number of images (e.g., "generate 4 variations") 78 | - Seed number for reproducibility 79 | 80 | Let me know if you want to proceed with these settings or make any adjustments.`, 81 | }, 82 | }, 83 | ], 84 | parameters: exampleParams, 85 | suggestions, 86 | presets, 87 | }, 88 | }; 89 | } 90 | 91 | /** 92 | * Parse user response to extract parameter adjustments. 93 | */ 94 | export function parseParameterAdjustments( 95 | response: string 96 | ): Partial { 97 | const adjustments: Partial = {}; 98 | 99 | // Quality adjustments 100 | if (response.match(/\b(draft|quick|fast)\b/i)) { 101 | adjustments.quality = "draft"; 102 | } else if (response.match(/\b(balanced|medium|default)\b/i)) { 103 | adjustments.quality = "balanced"; 104 | } else if (response.match(/\b(quality|high|detailed)\b/i)) { 105 | adjustments.quality = "quality"; 106 | } else if (response.match(/\b(extreme|maximum|best)\b/i)) { 107 | adjustments.quality = "extreme"; 108 | } 109 | 110 | // Style adjustments 111 | if (response.match(/\b(photo|realistic)\b/i)) { 112 | adjustments.style = "photorealistic"; 113 | } else if (response.match(/\b(anime|manga)\b/i)) { 114 | adjustments.style = "anime"; 115 | } else if (response.match(/\b(digital[\s-]?art)\b/i)) { 116 | adjustments.style = "digital_art"; 117 | } else if (response.match(/\b(oil[\s-]?painting)\b/i)) { 118 | adjustments.style = "oil_painting"; 119 | } else if (response.match(/\b(watercolor)\b/i)) { 120 | adjustments.style = "watercolor"; 121 | } else if (response.match(/\b(pixel[\s-]?art|8[\s-]?bit)\b/i)) { 122 | 
adjustments.style = "pixel_art"; 123 | } else if (response.match(/\b(minimal|minimalist)\b/i)) { 124 | adjustments.style = "minimalist"; 125 | } 126 | 127 | // Size adjustments 128 | if (response.match(/\b(square|1:1)\b/i)) { 129 | adjustments.size = "square"; 130 | } else if (response.match(/\b(portrait|vertical|3:4)\b/i)) { 131 | adjustments.size = "portrait"; 132 | } else if (response.match(/\b(landscape|horizontal|4:3)\b/i)) { 133 | adjustments.size = "landscape"; 134 | } else if (response.match(/\b(widescreen|16:9)\b/i)) { 135 | adjustments.size = "widescreen"; 136 | } else if (response.match(/\b(panoramic|21:9)\b/i)) { 137 | adjustments.size = "panoramic"; 138 | } else if (response.match(/\b(instagram[\s-]?story)\b/i)) { 139 | adjustments.size = "instagram_story"; 140 | } else if (response.match(/\b(instagram[\s-]?post)\b/i)) { 141 | adjustments.size = "instagram_post"; 142 | } else if (response.match(/\b(twitter[\s-]?header)\b/i)) { 143 | adjustments.size = "twitter_header"; 144 | } 145 | 146 | // Custom size 147 | const sizeMatch = response.match(/(\d+)\s*x\s*(\d+)/i); 148 | if (sizeMatch) { 149 | adjustments.custom_size = { 150 | width: Number.parseInt(sizeMatch[1], 10), 151 | height: Number.parseInt(sizeMatch[2], 10), 152 | }; 153 | } 154 | 155 | // Number of outputs 156 | const numMatch = response.match( 157 | /\b(\d+)\s*(outputs?|variations?|images?)\b/i 158 | ); 159 | if (numMatch) { 160 | adjustments.num_outputs = Number.parseInt(numMatch[1], 10); 161 | } 162 | 163 | // Seed 164 | const seedMatch = response.match(/\bseed\s*[=:]?\s*(\d+)\b/i); 165 | if (seedMatch) { 166 | adjustments.seed = Number.parseInt(seedMatch[1], 10); 167 | } 168 | 169 | return adjustments; 170 | } 171 | -------------------------------------------------------------------------------- /src/tests/integration.test.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * Integration tests with Replicate API. 
3 | */ 4 | 5 | import { 6 | describe, 7 | it, 8 | expect, 9 | beforeEach, 10 | afterEach, 11 | vi, 12 | beforeAll, 13 | afterAll, 14 | } from "vitest"; 15 | import { ReplicateClient } from "../replicate_client.js"; 16 | import { WebhookService } from "../services/webhook.js"; 17 | import { TemplateManager } from "../templates/manager.js"; 18 | import { ErrorHandler, ReplicateError } from "../services/error.js"; 19 | import { Cache } from "../services/cache.js"; 20 | import type { Model } from "../models/model.js"; 21 | import type { Prediction } from "../models/prediction.js"; 22 | import { PredictionStatus } from "../models/prediction.js"; 23 | import type { SchemaObject, PropertyObject } from "../models/openapi.js"; 24 | import type { WebhookEvent } from "../models/webhook.js"; 25 | import { createError } from "../services/error.js"; 26 | 27 | // Mock environment variables 28 | process.env.REPLICATE_API_TOKEN = "test_token"; 29 | 30 | describe("Replicate API Integration", () => { 31 | let client: ReplicateClient; 32 | let webhookService: WebhookService; 33 | let templateManager: TemplateManager; 34 | let cache: Cache; 35 | 36 | beforeEach(() => { 37 | // Initialize components 38 | client = new ReplicateClient(); 39 | webhookService = new WebhookService(); 40 | templateManager = new TemplateManager(); 41 | cache = new Cache(); 42 | 43 | // Mock client methods 44 | vi.spyOn(client, "listModels").mockResolvedValue({ 45 | models: [ 46 | { 47 | owner: "stability-ai", 48 | name: "sdxl", 49 | description: "Test model", 50 | id: "stability-ai/sdxl", 51 | visibility: "public", 52 | latest_version: { 53 | id: "test-version", 54 | created_at: new Date().toISOString(), 55 | cog_version: "0.3.0", 56 | openapi_schema: { 57 | openapi: "3.0.0", 58 | info: { 59 | title: "Test Model API", 60 | version: "1.0.0", 61 | }, 62 | paths: {}, 63 | components: { 64 | schemas: { 65 | Input: { 66 | type: "object", 67 | required: ["prompt"], 68 | properties: { 69 | prompt: { type: "string" }, 
70 | width: { type: "number", minimum: 0 }, 71 | height: { type: "number", minimum: 0 }, 72 | }, 73 | }, 74 | }, 75 | }, 76 | }, 77 | }, 78 | }, 79 | ], 80 | }); 81 | 82 | vi.spyOn(client, "searchModels").mockResolvedValue({ 83 | models: [ 84 | { 85 | id: "stability-ai/sdxl", 86 | owner: "stability-ai", 87 | name: "sdxl", 88 | description: "Text to image model", 89 | visibility: "public", 90 | }, 91 | ], 92 | }); 93 | 94 | vi.spyOn(client, "getModel").mockResolvedValue({ 95 | owner: "stability-ai", 96 | name: "sdxl", 97 | description: "Test model", 98 | id: "stability-ai/sdxl", 99 | visibility: "public", 100 | latest_version: { 101 | id: "test-version", 102 | created_at: new Date().toISOString(), 103 | cog_version: "0.3.0", 104 | openapi_schema: { 105 | openapi: "3.0.0", 106 | info: { 107 | title: "Test Model API", 108 | version: "1.0.0", 109 | }, 110 | paths: {}, 111 | components: { 112 | schemas: { 113 | Input: { 114 | type: "object", 115 | required: ["prompt"], 116 | properties: { 117 | prompt: { type: "string" }, 118 | width: { type: "number", minimum: 0 }, 119 | height: { type: "number", minimum: 0 }, 120 | }, 121 | }, 122 | }, 123 | }, 124 | }, 125 | }, 126 | }); 127 | 128 | vi.spyOn(client, "createPrediction").mockResolvedValue({ 129 | id: "test-prediction", 130 | version: "test-version", 131 | status: PredictionStatus.Starting, 132 | input: { prompt: "test" }, 133 | created_at: new Date().toISOString(), 134 | urls: {}, 135 | }); 136 | 137 | vi.spyOn(client, "getPredictionStatus").mockResolvedValue({ 138 | id: "test-prediction", 139 | version: "test-version", 140 | status: PredictionStatus.Succeeded, 141 | input: { prompt: "test" }, 142 | output: { image: "test.png" }, 143 | created_at: new Date().toISOString(), 144 | urls: {}, 145 | }); 146 | 147 | vi.spyOn(client, "listPredictions").mockResolvedValue([ 148 | { 149 | id: "test-prediction", 150 | version: "test-version", 151 | status: PredictionStatus.Succeeded, 152 | input: { prompt: "test" }, 153 | output: 
{ image: "test.png" }, 154 | created_at: new Date().toISOString(), 155 | urls: {}, 156 | }, 157 | ]); 158 | 159 | vi.spyOn(client, "listCollections").mockResolvedValue({ 160 | collections: [ 161 | { 162 | id: "test-collection", 163 | name: "Test Collection", 164 | slug: "test", 165 | description: "Test collection", 166 | models: [], 167 | created_at: new Date().toISOString(), 168 | }, 169 | ], 170 | }); 171 | 172 | vi.spyOn(client, "getCollection").mockResolvedValue({ 173 | id: "test-collection", 174 | name: "Test Collection", 175 | slug: "test", 176 | description: "Test collection", 177 | models: [], 178 | created_at: new Date().toISOString(), 179 | }); 180 | 181 | // Mock webhook service 182 | vi.spyOn(webhookService, "queueWebhook").mockResolvedValue("test-webhook"); 183 | vi.spyOn(webhookService, "getDeliveryResults").mockReturnValue([ 184 | { 185 | success: false, 186 | error: "Mock delivery failure", 187 | retryCount: 0, 188 | timestamp: new Date().toISOString(), 189 | }, 190 | ]); 191 | }); 192 | 193 | afterEach(() => { 194 | cache.clear(); 195 | vi.clearAllMocks(); 196 | }); 197 | 198 | describe("Error Handling", () => { 199 | it("should handle authentication errors", async () => { 200 | const invalidClient = new ReplicateClient("invalid_token"); 201 | vi.spyOn(invalidClient, "listModels").mockRejectedValue( 202 | createError.authentication() 203 | ); 204 | 205 | await expect(invalidClient.listModels()).rejects.toThrow(ReplicateError); 206 | }); 207 | 208 | it("should handle rate limit errors with retries", async () => { 209 | const rateLimitError = createError.rateLimit(1); 210 | 211 | // Mock client to throw rate limit error once then succeed 212 | let attempts = 0; 213 | vi.spyOn(client, "listModels").mockImplementation(async () => { 214 | if (attempts === 0) { 215 | attempts++; 216 | throw rateLimitError; 217 | } 218 | return { models: [] }; 219 | }); 220 | 221 | const result = await ErrorHandler.withRetries( 222 | async () => client.listModels(), 223 | 
{ 224 | maxAttempts: 2, 225 | retryIf: (error: Error) => error instanceof ReplicateError, 226 | } 227 | ); 228 | 229 | expect(attempts).toBe(1); 230 | expect(result).toEqual({ models: [] }); 231 | }); 232 | 233 | it("should handle network errors with retries", async () => { 234 | const networkError = createError.api(500, "Connection failed"); 235 | 236 | // Mock client to throw network error twice then succeed 237 | let attempts = 0; 238 | vi.spyOn(client, "listModels").mockImplementation(async () => { 239 | if (attempts < 2) { 240 | attempts++; 241 | throw networkError; 242 | } 243 | return { models: [] }; 244 | }); 245 | 246 | const result = await ErrorHandler.withRetries( 247 | async () => client.listModels(), 248 | { 249 | maxAttempts: 3, 250 | retryIf: (error: Error) => error instanceof ReplicateError, 251 | } 252 | ); 253 | 254 | expect(attempts).toBe(2); 255 | expect(result).toEqual({ models: [] }); 256 | }); 257 | 258 | it("should handle validation errors", async () => { 259 | vi.spyOn(client, "createPrediction").mockRejectedValue( 260 | createError.validation("version", "Invalid version") 261 | ); 262 | 263 | await expect( 264 | client.createPrediction({ 265 | version: "invalid-version", 266 | input: {}, 267 | }) 268 | ).rejects.toThrow(ReplicateError); 269 | }); 270 | 271 | it("should generate detailed error reports", async () => { 272 | const error = createError.validation("test", "Test error"); 273 | 274 | const report = ErrorHandler.createErrorReport(error); 275 | expect(report).toMatchObject({ 276 | name: "ReplicateError", 277 | message: "Invalid input parameters", 278 | context: { 279 | field: "test", 280 | message: "Test error", 281 | }, 282 | timestamp: expect.any(String), 283 | }); 284 | }); 285 | }); 286 | 287 | describe("Caching Behavior", () => { 288 | it("should cache model listings", async () => { 289 | // First request should hit API 290 | const result1 = await client.listModels(); 291 | expect(result1.models.length).toBeGreaterThan(0); 292 
| 293 | // Mock cache hit 294 | const result2 = await client.listModels(); 295 | expect(result2).toEqual(result1); 296 | }); 297 | 298 | it("should cache model details with TTL", async () => { 299 | const owner = "stability-ai"; 300 | const name = "sdxl"; 301 | 302 | // First request should hit API 303 | const model1 = await client.getModel(owner, name); 304 | expect(model1).toBeDefined(); 305 | 306 | // Mock cache hit 307 | const model2 = await client.getModel(owner, name); 308 | expect(model2).toEqual(model1); 309 | 310 | // Advance time past TTL 311 | vi.useFakeTimers(); 312 | vi.advanceTimersByTime(25 * 60 * 60 * 1000); // 25 hours 313 | 314 | // Request should hit API again 315 | const model3 = await client.getModel(owner, name); 316 | expect(model3).toBeDefined(); 317 | }); 318 | 319 | it("should handle cache invalidation for predictions", async () => { 320 | const prediction = await client.createPrediction({ 321 | version: "stability-ai/sdxl@latest", 322 | input: { 323 | prompt: "test", 324 | }, 325 | }); 326 | 327 | // Initial status should be cached 328 | const status1 = await client.getPredictionStatus(prediction.id); 329 | const status2 = await client.getPredictionStatus(prediction.id); 330 | expect(status2).toEqual(status1); 331 | 332 | // Completed predictions should stay cached 333 | if (status2.status === PredictionStatus.Succeeded) { 334 | const status3 = await client.getPredictionStatus(prediction.id); 335 | expect(status3).toEqual(status2); 336 | } 337 | // In-progress predictions should refresh 338 | else { 339 | const status3 = await client.getPredictionStatus(prediction.id); 340 | expect(status3).toBeDefined(); 341 | } 342 | }); 343 | }); 344 | 345 | describe("Model Operations", () => { 346 | it("should list available models", async () => { 347 | const result = await client.listModels(); 348 | expect(result.models).toBeInstanceOf(Array); 349 | expect(result.models.length).toBeGreaterThan(0); 350 | 351 | const model = result.models[0]; 352 | 
expect(model).toHaveProperty("owner"); 353 | expect(model).toHaveProperty("name"); 354 | expect(model).toHaveProperty("description"); 355 | }); 356 | 357 | it("should search models by query", async () => { 358 | const query = "text to image"; 359 | const result = await client.searchModels(query); 360 | expect(result.models).toBeInstanceOf(Array); 361 | expect(result.models.length).toBeGreaterThan(0); 362 | 363 | // Results should be relevant to the query 364 | const relevantModels = result.models.filter( 365 | (model) => 366 | model.description?.toLowerCase().includes("text") || 367 | model.description?.toLowerCase().includes("image") 368 | ); 369 | expect(relevantModels.length).toBeGreaterThan(0); 370 | }); 371 | 372 | it("should get model details", async () => { 373 | // Use a known stable model 374 | const owner = "stability-ai"; 375 | const name = "sdxl"; 376 | const model = await client.getModel(owner, name); 377 | 378 | expect(model.owner).toBe(owner); 379 | expect(model.name).toBe(name); 380 | expect(model.latest_version).toBeDefined(); 381 | expect(model.latest_version?.openapi_schema).toBeDefined(); 382 | }); 383 | }); 384 | 385 | describe("Prediction Operations", () => { 386 | let testModel: Model; 387 | 388 | beforeEach(async () => { 389 | // Get a test model for predictions 390 | const models = await client.listModels(); 391 | const foundModel = models.models.find( 392 | (m) => m.owner === "stability-ai" && m.name === "sdxl" 393 | ); 394 | if (!foundModel || !foundModel.latest_version) { 395 | throw new Error("Test model not found or missing version"); 396 | } 397 | testModel = foundModel; 398 | expect(testModel).toBeDefined(); 399 | }); 400 | 401 | it("should create prediction with community model version", async () => { 402 | if (!testModel.latest_version) { 403 | throw new Error("Test model missing version"); 404 | } 405 | const prediction = await client.createPrediction({ 406 | version: testModel.latest_version.id, 407 | input: { prompt: "test" }, 
408 | }); 409 | 410 | expect(prediction.id).toBeDefined(); 411 | expect(prediction.status).toBe(PredictionStatus.Starting); 412 | expect(prediction.version).toBe(testModel.latest_version.id); 413 | }); 414 | 415 | it("should create prediction with official model", async () => { 416 | const prediction = await client.createPrediction({ 417 | model: "stability-ai/sdxl", 418 | input: { prompt: "test" }, 419 | }); 420 | 421 | expect(prediction.id).toBeDefined(); 422 | expect(prediction.status).toBe(PredictionStatus.Starting); 423 | }); 424 | 425 | it("should create and track prediction", async () => { 426 | if (!testModel.latest_version) { 427 | throw new Error("Test model missing version"); 428 | } 429 | const prompt = "a photo of a mountain landscape at sunset"; 430 | const params = templateManager.generateParameters(prompt, { 431 | quality: "quality", 432 | style: "photorealistic", 433 | size: "landscape", 434 | }); 435 | 436 | // Create prediction 437 | const prediction = await client.createPrediction({ 438 | version: testModel.latest_version.id, 439 | input: params as Record, 440 | }); 441 | 442 | expect(prediction.id).toBeDefined(); 443 | expect(prediction.status).toBe(PredictionStatus.Starting); 444 | 445 | // Mock the status progression 446 | vi.spyOn(client, "getPredictionStatus") 447 | .mockResolvedValueOnce({ 448 | ...prediction, 449 | status: PredictionStatus.Processing, 450 | }) 451 | .mockResolvedValueOnce({ 452 | ...prediction, 453 | status: PredictionStatus.Succeeded, 454 | output: { image: "test.png" }, 455 | }); 456 | 457 | // Check processing status 458 | const processingStatus = await client.getPredictionStatus(prediction.id); 459 | expect(processingStatus.status).toBe(PredictionStatus.Processing); 460 | 461 | // Check final status 462 | const finalStatus = await client.getPredictionStatus(prediction.id); 463 | expect(finalStatus.status).toBe(PredictionStatus.Succeeded); 464 | expect(finalStatus.output).toBeDefined(); 465 | }); 466 | 467 | it("should 
handle webhook notifications", async () => { 468 | if (!testModel.latest_version) { 469 | throw new Error("Test model missing version"); 470 | } 471 | // Setup fake timers 472 | vi.useFakeTimers(); 473 | 474 | const prompt = "a photo of a mountain landscape at sunset"; 475 | const params = templateManager.generateParameters(prompt, { 476 | quality: "quality", 477 | style: "photorealistic", 478 | size: "landscape", 479 | }); 480 | 481 | // Create mock webhook server 482 | const mockWebhook = { 483 | url: "https://example.com/webhook", 484 | }; 485 | 486 | // Create prediction with webhook 487 | const prediction = await client.createPrediction({ 488 | version: testModel.latest_version.id, 489 | input: params as Record, 490 | webhook: mockWebhook.url, 491 | }); 492 | 493 | // Queue webhook delivery 494 | const webhookId = await webhookService.queueWebhook( 495 | { url: mockWebhook.url }, 496 | { 497 | type: "prediction.created", 498 | timestamp: new Date().toISOString(), 499 | data: JSON.parse(JSON.stringify(prediction)), 500 | } as WebhookEvent 501 | ); 502 | 503 | // Advance timers instead of waiting 504 | await vi.runAllTimersAsync(); 505 | 506 | const results = webhookService.getDeliveryResults(webhookId); 507 | expect(results).toHaveLength(1); 508 | // Expect failure since we're using a mock URL 509 | expect(results[0].success).toBe(false); 510 | 511 | // Cleanup 512 | vi.useRealTimers(); 513 | }); 514 | }); 515 | 516 | describe("Collection Operations", () => { 517 | it("should list collections", async () => { 518 | const result = await client.listCollections(); 519 | expect(result.collections).toBeInstanceOf(Array); 520 | expect(result.collections.length).toBeGreaterThan(0); 521 | 522 | const collection = result.collections[0]; 523 | expect(collection).toHaveProperty("name"); 524 | expect(collection).toHaveProperty("slug"); 525 | expect(collection).toHaveProperty("models"); 526 | }); 527 | 528 | it("should get collection details", async () => { 529 | const 
collections = await client.listCollections(); 530 | const testCollection = collections.collections[0]; 531 | 532 | const collection = await client.getCollection(testCollection.slug); 533 | expect(collection.name).toBe(testCollection.name); 534 | expect(collection.slug).toBe(testCollection.slug); 535 | expect(collection.models).toBeInstanceOf(Array); 536 | }); 537 | }); 538 | 539 | describe("Template System Integration", () => { 540 | it("should generate valid parameters for models", async () => { 541 | // Get SDXL model for testing 542 | const model = await client.getModel("stability-ai", "sdxl"); 543 | const schema = model.latest_version?.openapi_schema; 544 | expect(schema).toBeDefined(); 545 | 546 | // Generate parameters 547 | const prompt = "a detailed portrait in anime style"; 548 | const params = templateManager.generateParameters(prompt, { 549 | quality: "quality", 550 | style: "anime", 551 | size: "portrait", 552 | }); 553 | 554 | // Validate against model schema 555 | const errors = []; 556 | const inputSchema = schema?.components?.schemas?.Input as SchemaObject; 557 | if (inputSchema) { 558 | // Check required fields 559 | if (inputSchema.required) { 560 | for (const field of inputSchema.required) { 561 | if (!(field in params)) { 562 | errors.push(`Missing required field: ${field}`); 563 | } 564 | } 565 | } 566 | 567 | // Validate property types 568 | if (inputSchema.properties) { 569 | for (const [field, prop] of Object.entries( 570 | inputSchema.properties as Record 571 | )) { 572 | if (field in params) { 573 | const value = params[field as keyof typeof params] as unknown; 574 | switch (prop.type) { 575 | case "number": { 576 | const numValue = value as number; 577 | if (typeof numValue !== "number") { 578 | errors.push(`${field} must be a number`); 579 | } else { 580 | if (prop.minimum !== undefined && numValue < prop.minimum) { 581 | errors.push(`${field} must be >= ${prop.minimum}`); 582 | } 583 | if (prop.maximum !== undefined && numValue > 
prop.maximum) { 584 | errors.push(`${field} must be <= ${prop.maximum}`); 585 | } 586 | } 587 | break; 588 | } 589 | case "string": { 590 | const strValue = value as string; 591 | if (typeof strValue !== "string") { 592 | errors.push(`${field} must be a string`); 593 | } else if (prop.enum && !prop.enum.includes(strValue)) { 594 | errors.push( 595 | `${field} must be one of: ${prop.enum.join(", ")}` 596 | ); 597 | } 598 | break; 599 | } 600 | case "integer": { 601 | const intValue = value as number; 602 | if (!Number.isInteger(intValue)) { 603 | errors.push(`${field} must be an integer`); 604 | } 605 | break; 606 | } 607 | } 608 | } 609 | } 610 | } 611 | } 612 | 613 | expect(errors).toHaveLength(0); 614 | }); 615 | }); 616 | }); 617 | -------------------------------------------------------------------------------- /src/tests/protocol.test.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * Protocol compliance tests. 3 | */ 4 | 5 | import { 6 | describe, 7 | it, 8 | expect, 9 | beforeEach, 10 | afterEach, 11 | vi, 12 | beforeAll, 13 | afterAll, 14 | } from "vitest"; 15 | import { Server } from "@modelcontextprotocol/sdk/server/index.js"; 16 | import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js"; 17 | import { WebhookService } from "../services/webhook.js"; 18 | import { ReplicateClient } from "../replicate_client.js"; 19 | import { TemplateManager } from "../templates/manager.js"; 20 | import { ErrorHandler, ReplicateError } from "../services/error.js"; 21 | import { PredictionStatus } from "../models/prediction.js"; 22 | import { createError } from "../services/error.js"; 23 | 24 | // Set test environment 25 | process.env.NODE_ENV = "test"; 26 | process.env.REPLICATE_API_TOKEN = "test_token"; 27 | 28 | describe("Protocol Compliance", () => { 29 | let server: Server; 30 | let client: ReplicateClient; 31 | let webhookService: WebhookService; 32 | let templateManager: TemplateManager; 33 | let 
transport: StdioServerTransport; 34 | 35 | beforeAll(() => { 36 | // Enable fake timers 37 | vi.useFakeTimers(); 38 | }); 39 | 40 | beforeEach(async () => { 41 | // Initialize components 42 | client = new ReplicateClient(); 43 | webhookService = new WebhookService(); 44 | templateManager = new TemplateManager(); 45 | transport = new StdioServerTransport(); 46 | 47 | // Create server with test configuration 48 | server = new Server( 49 | { 50 | name: "replicate-test", 51 | version: "0.1.0", 52 | }, 53 | { 54 | capabilities: { 55 | tools: {}, 56 | prompts: {}, 57 | }, 58 | } 59 | ); 60 | 61 | // Initialize server 62 | await server.connect(transport); 63 | }); 64 | 65 | afterEach(async () => { 66 | // Clean up 67 | if (server) { 68 | await server.close(); 69 | } 70 | vi.clearAllMocks(); 71 | }); 72 | 73 | afterAll(() => { 74 | // Restore real timers 75 | vi.useRealTimers(); 76 | }); 77 | 78 | describe("Message Format", () => { 79 | it("should use JSON-RPC 2.0 format", async () => { 80 | const message = { 81 | jsonrpc: "2.0", 82 | method: "test", 83 | params: {}, 84 | id: 1, 85 | }; 86 | 87 | expect(message.jsonrpc).toBe("2.0"); 88 | expect(message).toHaveProperty("method"); 89 | expect(message).toHaveProperty("params"); 90 | expect(message).toHaveProperty("id"); 91 | }); 92 | 93 | it("should handle notifications without id", async () => { 94 | const notification = { 95 | jsonrpc: "2.0", 96 | method: "test", 97 | params: {}, 98 | }; 99 | 100 | expect(notification.jsonrpc).toBe("2.0"); 101 | expect(notification).toHaveProperty("method"); 102 | expect(notification).toHaveProperty("params"); 103 | expect(notification).not.toHaveProperty("id"); 104 | }); 105 | }); 106 | 107 | describe("Error Handling", () => { 108 | it("should handle rate limit errors with retries", async () => { 109 | const rateLimitError = createError.rateLimit(1); 110 | 111 | // Mock client to throw rate limit error once then succeed 112 | let attempts = 0; 113 | vi.spyOn(client, 
"listModels").mockImplementation(async () => { 114 | if (attempts === 0) { 115 | attempts++; 116 | throw rateLimitError; 117 | } 118 | return { models: [] }; 119 | }); 120 | 121 | // Mock setTimeout to advance timers 122 | const setTimeoutSpy = vi.spyOn(global, "setTimeout"); 123 | 124 | // Start the retry operation 125 | const resultPromise = ErrorHandler.withRetries( 126 | async () => client.listModels(), 127 | { 128 | maxAttempts: 2, 129 | minDelay: 100, // Use smaller delay for tests 130 | maxDelay: 200, 131 | retryIf: (error: Error) => error instanceof ReplicateError, 132 | } 133 | ); 134 | 135 | // Wait for setTimeout to be called and advance timer 136 | await vi.waitFor(() => setTimeoutSpy.mock.calls.length > 0); 137 | await vi.runAllTimersAsync(); 138 | 139 | // Wait for the result 140 | const result = await resultPromise; 141 | 142 | expect(attempts).toBe(1); 143 | expect(result).toEqual({ models: [] }); 144 | }); 145 | 146 | it("should generate detailed error reports", () => { 147 | const error = createError.validation("test", "Test error"); 148 | 149 | const report = ErrorHandler.createErrorReport(error); 150 | expect(report).toMatchObject({ 151 | name: "ReplicateError", 152 | message: "Invalid input parameters", 153 | context: { 154 | field: "test", 155 | message: "Test error", 156 | }, 157 | timestamp: expect.any(String), 158 | }); 159 | }); 160 | 161 | it("should handle prediction status transitions", async () => { 162 | const prediction = { 163 | id: "test", 164 | status: PredictionStatus.Starting, 165 | version: "test-version", 166 | }; 167 | 168 | const statusUpdates: PredictionStatus[] = []; 169 | const mockTransport = { 170 | notify: vi.fn().mockImplementation((notification) => { 171 | if (notification.method === "prediction/status") { 172 | statusUpdates.push(notification.params.status); 173 | } 174 | }), 175 | }; 176 | 177 | // Simulate status transitions 178 | prediction.status = PredictionStatus.Processing; 179 | await mockTransport.notify({ 
180 | method: "prediction/status", 181 | params: { status: prediction.status }, 182 | }); 183 | 184 | prediction.status = PredictionStatus.Succeeded; 185 | await mockTransport.notify({ 186 | method: "prediction/status", 187 | params: { status: prediction.status }, 188 | }); 189 | 190 | expect(statusUpdates).toEqual([ 191 | PredictionStatus.Processing, 192 | PredictionStatus.Succeeded, 193 | ]); 194 | expect(mockTransport.notify).toHaveBeenCalledTimes(2); 195 | }); 196 | }); 197 | 198 | describe("Webhook Integration", () => { 199 | beforeEach(() => { 200 | // Mock fetch for webhook delivery 201 | vi.spyOn(global, "fetch").mockImplementation(async () => { 202 | return Promise.resolve({ 203 | ok: true, 204 | status: 200, 205 | statusText: "OK", 206 | } as Response); 207 | }); 208 | }); 209 | 210 | it("should validate webhook configuration", () => { 211 | const validConfig = { 212 | url: "https://example.com/webhook", 213 | secret: "1234567890abcdef1234567890abcdef", 214 | retries: 3, 215 | timeout: 5000, 216 | }; 217 | 218 | const invalidConfig = { 219 | url: "not-a-url", 220 | secret: "too-short", 221 | retries: -1, 222 | timeout: 500, 223 | }; 224 | 225 | expect(webhookService.validateWebhookConfig(validConfig)).toHaveLength(0); 226 | expect(webhookService.validateWebhookConfig(invalidConfig)).toHaveLength( 227 | 4 228 | ); 229 | }); 230 | 231 | it("should handle webhook delivery with retries", async () => { 232 | const deliverySpy = vi.spyOn( 233 | webhookService, 234 | "deliverWebhook" as keyof WebhookService 235 | ); 236 | const fetchSpy = vi.spyOn(global, "fetch"); 237 | 238 | // First call fails, second succeeds 239 | fetchSpy 240 | .mockRejectedValueOnce(new Error("Delivery failed")) 241 | .mockResolvedValueOnce({ 242 | ok: true, 243 | status: 200, 244 | statusText: "OK", 245 | } as Response); 246 | 247 | const webhookId = await webhookService.queueWebhook( 248 | { 249 | url: "https://example.com/webhook", 250 | retries: 1, 251 | }, 252 | { 253 | type: 
"prediction.created", 254 | timestamp: new Date().toISOString(), 255 | data: { id: "123" }, 256 | } 257 | ); 258 | 259 | // Wait for delivery attempts 260 | await vi.advanceTimersByTimeAsync(100); // First attempt 261 | await vi.advanceTimersByTimeAsync(200); // Retry attempt 262 | await vi.advanceTimersByTimeAsync(100); // Processing time 263 | 264 | const results = webhookService.getDeliveryResults(webhookId); 265 | expect(results).toHaveLength(2); 266 | expect(results[0].success).toBe(false); 267 | expect(results[1].success).toBe(true); 268 | expect(deliverySpy).toHaveBeenCalledTimes(2); 269 | }); 270 | }); 271 | 272 | describe("Template System", () => { 273 | it("should generate parameters from templates", () => { 274 | const prompt = "a photo of a mountain landscape at sunset"; 275 | const params = templateManager.generateParameters(prompt, { 276 | quality: "quality", 277 | style: "photorealistic", 278 | size: "landscape", 279 | }); 280 | 281 | expect(params).toHaveProperty("prompt"); 282 | expect(params).toHaveProperty("negative_prompt"); 283 | expect(params).toHaveProperty("width"); 284 | expect(params).toHaveProperty("height"); 285 | expect(params).toHaveProperty("num_inference_steps"); 286 | expect(params).toHaveProperty("guidance_scale"); 287 | }); 288 | 289 | it("should suggest parameters based on prompt", () => { 290 | const prompt = "a detailed anime-style portrait"; 291 | const suggestions = templateManager.suggestParameters(prompt); 292 | 293 | expect(suggestions.quality).toBe("quality"); 294 | expect(suggestions.style).toBe("anime"); 295 | expect(suggestions.size).toBe("portrait"); 296 | }); 297 | 298 | it("should validate template parameters", () => { 299 | const validTemplate = templateManager.generateParameters("test prompt", { 300 | quality: "quality", 301 | style: "photorealistic", 302 | size: "landscape", 303 | }); 304 | 305 | const invalidTemplate = { 306 | ...validTemplate, 307 | width: -100, // Invalid width 308 | }; 309 | 310 | expect(() 
=> 311 | templateManager.validateParameters(validTemplate) 312 | ).not.toThrow(); 313 | expect(() => 314 | templateManager.validateParameters(invalidTemplate) 315 | ).toThrow(); 316 | }); 317 | }); 318 | }); 319 | -------------------------------------------------------------------------------- /src/tools/handlers.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * Handlers for MCP tools. 3 | */ 4 | 5 | import type { ReplicateClient } from "../replicate_client.js"; 6 | import type { Model } from "../models/model.js"; 7 | import type { ModelIO, Prediction } from "../models/prediction.js"; 8 | import { PredictionStatus } from "../models/prediction.js"; 9 | import type { Collection } from "../models/collection.js"; 10 | 11 | /** 12 | * Cache for models, predictions, and collections. 13 | */ 14 | interface Cache { 15 | modelCache: Map; 16 | predictionCache: Map; 17 | collectionCache: Map; 18 | predictionStatus: Map; 19 | } 20 | 21 | /** 22 | * Get error message from unknown error. 23 | */ 24 | function getErrorMessage(error: unknown): string { 25 | if (error instanceof Error) { 26 | return error.message; 27 | } 28 | if (error instanceof Promise) { 29 | return "An asynchronous error occurred. Please try again."; 30 | } 31 | return String(error); 32 | } 33 | 34 | /** 35 | * Handle the search_models tool. 
36 | */ 37 | export async function handleSearchModels( 38 | client: ReplicateClient, 39 | cache: Cache, 40 | args: { query: string } 41 | ) { 42 | try { 43 | const result = await client.searchModels(args.query); 44 | 45 | // Update cache 46 | for (const model of result.models) { 47 | cache.modelCache.set(`${model.owner}/${model.name}`, model); 48 | } 49 | 50 | return { 51 | content: [ 52 | { 53 | type: "text", 54 | text: `Found ${result.models.length} models matching "${args.query}":`, 55 | }, 56 | ...result.models.map((model) => ({ 57 | type: "text" as const, 58 | text: `- ${model.owner}/${model.name}: ${ 59 | model.description || "No description" 60 | }`, 61 | })), 62 | ], 63 | }; 64 | } catch (error) { 65 | return { 66 | isError: true, 67 | content: [ 68 | { 69 | type: "text", 70 | text: `Error searching models: ${getErrorMessage(error)}`, 71 | }, 72 | ], 73 | }; 74 | } 75 | } 76 | 77 | /** 78 | * Handle the list_models tool. 79 | */ 80 | export async function handleListModels( 81 | client: ReplicateClient, 82 | cache: Cache, 83 | args: { owner?: string; cursor?: string } 84 | ) { 85 | try { 86 | const result = await client.listModels(args); 87 | 88 | // Update cache 89 | for (const model of result.models) { 90 | cache.modelCache.set(`${model.owner}/${model.name}`, model); 91 | } 92 | 93 | return { 94 | content: [ 95 | { 96 | type: "text", 97 | text: args.owner ? `Models by ${args.owner}:` : "Available models:", 98 | }, 99 | ...result.models.map((model) => ({ 100 | type: "text" as const, 101 | text: `- ${model.owner}/${model.name}: ${ 102 | model.description || "No description" 103 | }`, 104 | })), 105 | result.next_cursor 106 | ? 
{ 107 | type: "text" as const, 108 | text: `\nUse cursor "${result.next_cursor}" to see more results.`, 109 | } 110 | : null, 111 | ].filter(Boolean), 112 | }; 113 | } catch (error) { 114 | return { 115 | isError: true, 116 | content: [ 117 | { 118 | type: "text", 119 | text: `Error listing models: ${getErrorMessage(error)}`, 120 | }, 121 | ], 122 | }; 123 | } 124 | } 125 | 126 | /** 127 | * Handle the list_collections tool. 128 | */ 129 | export async function handleListCollections( 130 | client: ReplicateClient, 131 | cache: Cache, 132 | args: { cursor?: string } 133 | ) { 134 | try { 135 | const result = await client.listCollections(args); 136 | 137 | // Update cache 138 | for (const collection of result.collections) { 139 | cache.collectionCache.set(collection.slug, collection); 140 | } 141 | 142 | return { 143 | content: [ 144 | { 145 | type: "text", 146 | text: "Available collections:", 147 | }, 148 | ...result.collections.map((collection: Collection) => ({ 149 | type: "text" as const, 150 | text: `- ${collection.name} (slug: ${collection.slug}): ${ 151 | collection.description || 152 | `A collection of ${collection.models.length} models` 153 | }`, 154 | })), 155 | result.next_cursor 156 | ? { 157 | type: "text" as const, 158 | text: `\nUse cursor "${result.next_cursor}" to see more results.`, 159 | } 160 | : null, 161 | ].filter(Boolean), 162 | }; 163 | } catch (error) { 164 | return { 165 | isError: true, 166 | content: [ 167 | { 168 | type: "text", 169 | text: `Error listing collections: ${getErrorMessage(error)}`, 170 | }, 171 | ], 172 | }; 173 | } 174 | } 175 | 176 | /** 177 | * Handle the get_collection tool. 
178 | */ 179 | export async function handleGetCollection( 180 | client: ReplicateClient, 181 | cache: Cache, 182 | args: { slug: string } 183 | ) { 184 | try { 185 | const collection = await client.getCollection(args.slug); 186 | 187 | // Update cache 188 | cache.collectionCache.set(collection.slug, collection); 189 | 190 | return { 191 | content: [ 192 | { 193 | type: "text", 194 | text: `Collection: ${collection.name}`, 195 | }, 196 | collection.description 197 | ? { 198 | type: "text" as const, 199 | text: collection.description, 200 | } 201 | : null, 202 | { 203 | type: "text", 204 | text: "\nModels in this collection:", 205 | }, 206 | ...collection.models.map((model: Model) => ({ 207 | type: "text" as const, 208 | text: `- ${model.owner}/${model.name}: ${ 209 | model.description || "No description" 210 | }`, 211 | })), 212 | ].filter(Boolean), 213 | }; 214 | } catch (error) { 215 | return { 216 | isError: true, 217 | content: [ 218 | { 219 | type: "text", 220 | text: `Error getting collection: ${getErrorMessage(error)}`, 221 | }, 222 | ], 223 | }; 224 | } 225 | } 226 | 227 | /** 228 | * Handle the create_prediction tool. 229 | */ 230 | export async function handleCreatePrediction( 231 | client: ReplicateClient, 232 | cache: Cache, 233 | args: { 234 | version: string | undefined; 235 | model: string | undefined; 236 | input: ModelIO | string; 237 | webhook?: string; 238 | } 239 | ) { 240 | try { 241 | // If input is a string, wrap it in an object with 'prompt' property 242 | const input = 243 | typeof args.input === "string" ? 
{ prompt: args.input } : args.input; 244 | 245 | const prediction = await client.createPrediction({ 246 | ...args, 247 | input, 248 | }); 249 | 250 | // Cache the prediction and its initial status 251 | cache.predictionCache.set(prediction.id, prediction); 252 | cache.predictionStatus.set( 253 | prediction.id, 254 | prediction.status as PredictionStatus 255 | ); 256 | 257 | return { 258 | content: [ 259 | { 260 | type: "text", 261 | text: `Created prediction ${prediction.id}`, 262 | }, 263 | ], 264 | }; 265 | } catch (error) { 266 | return { 267 | isError: true, 268 | content: [ 269 | { 270 | type: "text", 271 | text: `Error creating prediction: ${getErrorMessage(error)}`, 272 | }, 273 | ], 274 | }; 275 | } 276 | } 277 | 278 | /** 279 | * Creates a prediction and handles the full lifecycle: 280 | * sends the request, polls until completion, and returns the final result URL. 281 | */ 282 | export async function handleCreateAndPollPrediction( 283 | client: ReplicateClient, 284 | cache: Cache, 285 | args: { 286 | version: string | undefined; 287 | model: string | undefined; 288 | input: ModelIO | string; 289 | webhook?: string; 290 | pollInterval?: number; 291 | timeout?: number; 292 | } 293 | ) { 294 | // If input is a string, wrap it in an object with 'prompt' property 295 | const input = 296 | typeof args.input === "string" ? 
{ prompt: args.input } : args.input; 297 | 298 | let prediction; 299 | try { 300 | prediction = await client.createPrediction({ 301 | ...args, 302 | input, 303 | }); 304 | } catch (error) { 305 | return { 306 | isError: true, 307 | content: [ 308 | { 309 | type: "text", 310 | text: `Error creating prediction: ${getErrorMessage(error)}`, 311 | }, 312 | ], 313 | }; 314 | } 315 | 316 | // Cache the prediction and its initial status 317 | cache.predictionCache.set(prediction.id, prediction); 318 | cache.predictionStatus.set( 319 | prediction.id, 320 | prediction.status as PredictionStatus 321 | ); 322 | 323 | const shouldContinuePolling = ( 324 | prediction: Prediction | null, 325 | timeoutAt: number 326 | ) => { 327 | if (!prediction) return true; 328 | if (performance.now() > timeoutAt) return false; 329 | 330 | if ( 331 | prediction.status === "succeeded" || 332 | prediction.status === "failed" || 333 | prediction.status === "canceled" 334 | ) { 335 | return false; 336 | } 337 | return true; 338 | }; 339 | 340 | const { pollInterval = 1, timeout = 60 } = args; 341 | const predictionId = prediction.id; 342 | let timeoutAt = performance.now() + timeout * 1000; 343 | 344 | do { 345 | await new Promise((resolve) => setTimeout(resolve, pollInterval * 1000)); 346 | 347 | try { 348 | prediction = await client.getPredictionStatus(predictionId); 349 | } catch (error) { 350 | console.error(error); 351 | } 352 | } while (shouldContinuePolling(prediction, timeoutAt)); 353 | 354 | if (timeoutAt < performance.now()) { 355 | console.warn( 356 | `Timeout reached while polling prediction by id: ${predictionId}` 357 | ); 358 | return { 359 | isError: true, 360 | content: [ 361 | { 362 | type: "text", 363 | text: `Timeout reached while polling prediction by id: ${predictionId}`, 364 | }, 365 | ], 366 | }; 367 | } 368 | 369 | if (prediction.status === "canceled" || prediction.status === "failed") { 370 | console.warn( 371 | `Prediction with id: ${predictionId} ${prediction.status} with 
error ${prediction.error}` 372 | ); 373 | return { 374 | isError: true, 375 | content: [ 376 | { 377 | type: "text", 378 | text: `Prediction with id: ${predictionId} ${prediction.status} with error ${prediction.error}`, 379 | }, 380 | ], 381 | }; 382 | } 383 | 384 | return { 385 | content: [ 386 | { 387 | type: "text", 388 | text: `Created prediction ${prediction.id}, Output: ${prediction.output}`, 389 | }, 390 | ], 391 | }; 392 | } 393 | 394 | /** 395 | * Handle the cancel_prediction tool. 396 | */ 397 | export async function handleCancelPrediction( 398 | client: ReplicateClient, 399 | cache: Cache, 400 | args: { prediction_id: string } 401 | ) { 402 | try { 403 | const prediction = await client.cancelPrediction(args.prediction_id); 404 | // Update cache 405 | cache.predictionCache.set(prediction.id, prediction); 406 | cache.predictionStatus.set( 407 | prediction.id, 408 | prediction.status as PredictionStatus 409 | ); 410 | 411 | return { 412 | content: [ 413 | { 414 | type: "text", 415 | text: `Cancelled prediction ${prediction.id}`, 416 | }, 417 | ], 418 | }; 419 | } catch (error) { 420 | return { 421 | isError: true, 422 | content: [ 423 | { 424 | type: "text", 425 | text: `Error cancelling prediction: ${getErrorMessage(error)}`, 426 | }, 427 | ], 428 | }; 429 | } 430 | } 431 | 432 | /** 433 | * Handle the get_model tool. 434 | */ 435 | export async function handleGetModel( 436 | client: ReplicateClient, 437 | cache: Cache, 438 | args: { owner: string; name: string } 439 | ) { 440 | try { 441 | const model = await client.getModel(args.owner, args.name); 442 | 443 | // Update cache 444 | cache.modelCache.set(`${model.owner}/${model.name}`, model); 445 | 446 | return { 447 | content: [ 448 | { 449 | type: "text", 450 | text: `Model: ${model.owner}/${model.name}`, 451 | }, 452 | model.description 453 | ? 
{ 454 | type: "text" as const, 455 | text: `\nDescription: ${model.description}`, 456 | } 457 | : null, 458 | { 459 | type: "text", 460 | text: "\nLatest version:", 461 | }, 462 | model.latest_version 463 | ? { 464 | type: "text" as const, 465 | text: `ID: ${model.latest_version.id}\nCreated: ${model.latest_version.created_at}`, 466 | } 467 | : { 468 | type: "text" as const, 469 | text: "No versions available", 470 | }, 471 | ].filter(Boolean), 472 | }; 473 | } catch (error) { 474 | return { 475 | isError: true, 476 | content: [ 477 | { 478 | type: "text", 479 | text: `Error getting model: ${getErrorMessage(error)}`, 480 | }, 481 | ], 482 | }; 483 | } 484 | } 485 | 486 | /** 487 | * Handle the get_prediction tool. 488 | */ 489 | export async function handleGetPrediction( 490 | client: ReplicateClient, 491 | cache: Cache, 492 | args: { prediction_id: string } 493 | ) { 494 | try { 495 | const prediction = await client.getPredictionStatus(args.prediction_id); 496 | 497 | const previousStatus = cache.predictionStatus.get(prediction.id); 498 | 499 | // Update cache 500 | cache.predictionCache.set(prediction.id, prediction); 501 | cache.predictionStatus.set( 502 | prediction.id, 503 | prediction.status as PredictionStatus 504 | ); 505 | 506 | return { 507 | content: [ 508 | { 509 | type: "text", 510 | text: `Prediction ${prediction.id}:`, 511 | }, 512 | { 513 | type: "text", 514 | text: `Status: ${prediction.status}\nModel version: ${prediction.version}\nCreated: ${prediction.created_at}`, 515 | }, 516 | prediction.input 517 | ? { 518 | type: "text" as const, 519 | text: `\nInput:\n${JSON.stringify(prediction.input, null, 2)}`, 520 | } 521 | : null, 522 | prediction.output 523 | ? { 524 | type: "text" as const, 525 | text: `\nOutput:\n${JSON.stringify(prediction.output, null, 2)}`, 526 | } 527 | : null, 528 | prediction.error 529 | ? { 530 | type: "text" as const, 531 | text: `\nError: ${prediction.error}`, 532 | } 533 | : null, 534 | prediction.logs 535 | ? 
{ 536 | type: "text" as const, 537 | text: `\nLogs:\n${prediction.logs}`, 538 | } 539 | : null, 540 | ].filter(Boolean), 541 | }; 542 | } catch (error) { 543 | return { 544 | isError: true, 545 | content: [ 546 | { 547 | type: "text", 548 | text: `Error getting prediction: ${getErrorMessage(error)}`, 549 | }, 550 | ], 551 | }; 552 | } 553 | } 554 | 555 | /** 556 | * Handle the list_predictions tool. 557 | */ 558 | /** 559 | * Estimate prediction progress based on logs and status. 560 | */ 561 | function estimateProgress(prediction: Prediction): number { 562 | if (prediction.status === PredictionStatus.Succeeded) return 100; 563 | if ( 564 | prediction.status === PredictionStatus.Failed || 565 | prediction.status === PredictionStatus.Canceled 566 | ) 567 | return 0; 568 | if (prediction.status === PredictionStatus.Starting) return 0; 569 | 570 | // Try to parse progress from logs 571 | if (prediction.logs) { 572 | const match = prediction.logs.match(/progress: (\d+)%/); 573 | if (match) { 574 | return Number.parseInt(match[1], 10); 575 | } 576 | } 577 | 578 | // Default to 50% if processing but no specific progress info 579 | return prediction.status === PredictionStatus.Processing ? 
50 : 0; 580 | } 581 | 582 | export async function handleListPredictions( 583 | client: ReplicateClient, 584 | cache: Cache, 585 | args: { limit?: number; cursor?: string } 586 | ) { 587 | try { 588 | const predictions = await client.listPredictions({ 589 | limit: args.limit || 10, 590 | }); 591 | 592 | // Update cache and status tracking 593 | for (const prediction of predictions) { 594 | const previousStatus = cache.predictionStatus.get(prediction.id); 595 | cache.predictionCache.set(prediction.id, prediction); 596 | cache.predictionStatus.set( 597 | prediction.id, 598 | prediction.status as PredictionStatus 599 | ); 600 | } 601 | 602 | // Format predictions as text 603 | const predictionTexts = predictions.map((prediction) => { 604 | const status = prediction.status.toUpperCase(); 605 | const model = prediction.version; 606 | const time = new Date(prediction.created_at).toLocaleString(); 607 | return `- ID: ${prediction.id}\n Status: ${status}\n Model: ${model}\n Created: ${time}`; 608 | }); 609 | 610 | return { 611 | content: [ 612 | { 613 | type: "text", 614 | text: `Found ${predictions.length} predictions:`, 615 | }, 616 | { 617 | type: "text", 618 | text: predictionTexts.join("\n\n"), 619 | }, 620 | ], 621 | }; 622 | } catch (error) { 623 | return { 624 | isError: true, 625 | content: [ 626 | { 627 | type: "text", 628 | text: `Error listing predictions: ${getErrorMessage(error)}`, 629 | }, 630 | ], 631 | }; 632 | } 633 | } 634 | -------------------------------------------------------------------------------- /src/tools/image_viewer.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * MCP tools for image viewing functionality. 
3 | */ 4 | 5 | import { ImageViewer } from "../services/image_viewer.js"; 6 | import type { Tool } from "@modelcontextprotocol/sdk/types.js"; 7 | import type { MCPRequest } from "../types/mcp.js"; 8 | 9 | /** 10 | * Tool for displaying images in the system's default web browser 11 | */ 12 | export const viewImageTool: Tool = { 13 | name: "view_image", 14 | description: "Display an image in the system's default web browser", 15 | inputSchema: { 16 | type: "object", 17 | properties: { 18 | url: { 19 | type: "string", 20 | description: "URL of the image to display", 21 | }, 22 | }, 23 | required: ["url"], 24 | }, 25 | handler: handleViewImage, 26 | }; 27 | 28 | /** 29 | * Tool for clearing the image cache 30 | */ 31 | export const clearImageCacheTool: Tool = { 32 | name: "clear_image_cache", 33 | description: "Clear the image viewer cache", 34 | inputSchema: { 35 | type: "object", 36 | properties: {}, 37 | }, 38 | handler: handleClearImageCache, 39 | }; 40 | 41 | /** 42 | * Tool for getting image cache statistics 43 | */ 44 | export const getImageCacheStatsTool: Tool = { 45 | name: "get_image_cache_stats", 46 | description: "Get statistics about the image cache", 47 | inputSchema: { 48 | type: "object", 49 | properties: {}, 50 | }, 51 | handler: handleGetImageCacheStats, 52 | }; 53 | 54 | /** 55 | * Display an image in the system's default web browser 56 | */ 57 | export async function handleViewImage(request: any) { 58 | const url = request.params.arguments?.url; 59 | 60 | if (typeof url !== "string") { 61 | return { 62 | isError: true, 63 | content: [ 64 | { 65 | type: "text", 66 | text: "URL parameter is required", 67 | }, 68 | ], 69 | }; 70 | } 71 | 72 | try { 73 | const viewer = ImageViewer.getInstance(); 74 | await viewer.displayImage(url); 75 | 76 | return { 77 | content: [ 78 | { 79 | type: "text", 80 | text: "Image displayed successfully", 81 | }, 82 | ], 83 | }; 84 | } catch (error) { 85 | return { 86 | isError: true, 87 | content: [ 88 | { 89 | type: 
"text", 90 | text: `Failed to display image: ${ 91 | error instanceof Error ? error.message : String(error) 92 | }`, 93 | }, 94 | ], 95 | }; 96 | } 97 | } 98 | 99 | /** 100 | * Clear the image viewer cache 101 | */ 102 | export async function handleClearImageCache(request: any) { 103 | try { 104 | const viewer = ImageViewer.getInstance(); 105 | viewer.clearCache(); 106 | 107 | return { 108 | content: [ 109 | { 110 | type: "text", 111 | text: "Image cache cleared successfully", 112 | }, 113 | ], 114 | }; 115 | } catch (error) { 116 | return { 117 | isError: true, 118 | content: [ 119 | { 120 | type: "text", 121 | text: `Failed to clear cache: ${ 122 | error instanceof Error ? error.message : String(error) 123 | }`, 124 | }, 125 | ], 126 | }; 127 | } 128 | } 129 | 130 | /** 131 | * Get image viewer cache statistics 132 | */ 133 | export async function handleGetImageCacheStats(request: any) { 134 | try { 135 | const viewer = ImageViewer.getInstance(); 136 | const stats = viewer.getCacheStats(); 137 | 138 | return { 139 | content: [ 140 | { 141 | type: "text", 142 | text: "Cache statistics:", 143 | }, 144 | { 145 | type: "text", 146 | text: JSON.stringify(stats, null, 2), 147 | }, 148 | ], 149 | }; 150 | } catch (error) { 151 | return { 152 | isError: true, 153 | content: [ 154 | { 155 | type: "text", 156 | text: `Failed to get cache stats: ${ 157 | error instanceof Error ? error.message : String(error) 158 | }`, 159 | }, 160 | ], 161 | }; 162 | } 163 | } 164 | -------------------------------------------------------------------------------- /src/tools/index.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * MCP tools for interacting with Replicate. 
 */

// Barrel re-exports: consumers can import any tool definition directly
// from "./tools".
export * from "./models.js";
export * from "./predictions.js";
export * from "./image_viewer.js";

import type { Tool } from "@modelcontextprotocol/sdk/types.js";
import {
  searchModelsTool,
  listModelsTool,
  listCollectionsTool,
  getCollectionTool,
  getModelTool,
} from "./models.js";
import {
  createPredictionTool,
  createAndPollPredictionTool,
  cancelPredictionTool,
  getPredictionTool,
  listPredictionsTool,
} from "./predictions.js";
import {
  viewImageTool,
  clearImageCacheTool,
  getImageCacheStatsTool,
} from "./image_viewer.js";

/**
 * All available tools.
 * This is the single registration list the server exposes to MCP clients.
 */
export const tools: Tool[] = [
  searchModelsTool,
  listModelsTool,
  listCollectionsTool,
  getCollectionTool,
  createPredictionTool,
  createAndPollPredictionTool,
  cancelPredictionTool,
  getPredictionTool,
  listPredictionsTool,
  getModelTool,
  // Image viewer tools
  viewImageTool,
  clearImageCacheTool,
  getImageCacheStatsTool,
];
--------------------------------------------------------------------------------
/src/tools/models.ts:
--------------------------------------------------------------------------------
/**
 * Tools for interacting with Replicate models.
 */

import type { Tool } from "@modelcontextprotocol/sdk/types.js";
// NOTE(review): the two type imports below are not referenced in the
// visible tool definitions — presumably retained for handler typing;
// confirm before removing.
import type { ReplicateClient } from "../replicate_client.js";
import type { Model } from "../models/model.js";

/**
 * Tool for searching models using semantic search.
 */
export const searchModelsTool: Tool = {
  name: "search_models",
  description: "Search for models using semantic search",
  inputSchema: {
    type: "object",
    properties: {
      query: {
        type: "string",
        description: "Search query",
      },
    },
    required: ["query"],
  },
};

/**
 * Tool for listing available models.
 */
export const listModelsTool: Tool = {
  name: "list_models",
  description: "List available models with optional filtering",
  inputSchema: {
    type: "object",
    // No required fields: both owner and cursor are optional filters.
    properties: {
      owner: {
        type: "string",
        description: "Filter by model owner",
      },
      cursor: {
        type: "string",
        description: "Pagination cursor",
      },
    },
  },
};

/**
 * Tool for listing model collections.
 */
export const listCollectionsTool: Tool = {
  name: "list_collections",
  description: "List available model collections",
  inputSchema: {
    type: "object",
    properties: {
      cursor: {
        type: "string",
        description: "Pagination cursor",
      },
    },
  },
};

/**
 * Tool for getting collection details.
 */
export const getCollectionTool: Tool = {
  name: "get_collection",
  description: "Get details of a specific collection",
  inputSchema: {
    type: "object",
    properties: {
      slug: {
        type: "string",
        description: "Collection slug",
      },
    },
    required: ["slug"],
  },
};

/**
 * Tool for getting model details including versions.
 */
export const getModelTool: Tool = {
  name: "get_model",
  description: "Get details of a specific model including available versions",
  inputSchema: {
    type: "object",
    properties: {
      owner: {
        type: "string",
        description: "Model owner",
      },
      name: {
        type: "string",
        description: "Model name",
      },
    },
    // Both parts of the "owner/name" identifier are mandatory.
    required: ["owner", "name"],
  },
};
--------------------------------------------------------------------------------
/src/tools/predictions.ts:
--------------------------------------------------------------------------------
/**
 * Tools for managing Replicate predictions.
 */

import type { Tool } from "@modelcontextprotocol/sdk/types.js";
// NOTE(review): these type imports are unused in the visible definitions;
// presumably retained for handler signatures — confirm before removing.
import type { ReplicateClient } from "../replicate_client.js";
import type { Prediction } from "../models/prediction.js";

/**
 * Tool for creating new predictions.
 */
export const createPredictionTool: Tool = {
  name: "create_prediction",
  description:
    "Create a new prediction using either a model version (for community models) or model name (for official models)",
  inputSchema: {
    type: "object",
    properties: {
      version: {
        type: "string",
        description: "Model version ID to use (for community models)",
      },
      model: {
        type: "string",
        description: "Model name to use (for official models)",
      },
      input: {
        type: "object",
        description: "Input parameters for the model",
        additionalProperties: true,
      },
      webhook_url: {
        type: "string",
        description: "Optional webhook URL for notifications",
      },
    },
    // Exactly one of "version" or "model" must accompany "input".
    oneOf: [
      { required: ["version", "input"] },
      { required: ["model", "input"] },
    ],
  },
};

/**
 * Tool for creating a new prediction, waiting for it to complete,
 * and returning the final output URL.
 */
export const createAndPollPredictionTool: Tool = {
  name: "create_and_poll_prediction",
  description:
    "Create a new prediction and wait until it's completed. Accepts either a model version (for community models) or a model name (for official models)",
  inputSchema: {
    type: "object",
    properties: {
      version: {
        type: "string",
        description: "Model version ID to use (for community models)",
      },
      model: {
        type: "string",
        description: "Model name to use (for official models)",
      },
      input: {
        type: "object",
        description: "Input parameters for the model",
        additionalProperties: true,
      },
      webhook_url: {
        type: "string",
        description: "Optional webhook URL for notifications",
      },
      // NOTE(review): units for poll_interval and timeout are presumably
      // seconds — confirm against the handler implementation.
      poll_interval: {
        type: "number",
        description: "Optional interval between polls (default: 1)",
      },
      timeout: {
        type: "number",
        description: "Optional timeout for prediction (default: 60)",
      },
    },
    // Same constraint as create_prediction: version XOR model, plus input.
    oneOf: [
      { required: ["version", "input"] },
      { required: ["model", "input"] },
    ],
  },
};

/**
 * Tool for canceling predictions.
 */
export const cancelPredictionTool: Tool = {
  name: "cancel_prediction",
  description: "Cancel a running prediction",
  inputSchema: {
    type: "object",
    properties: {
      prediction_id: {
        type: "string",
        description: "ID of the prediction to cancel",
      },
    },
    required: ["prediction_id"],
  },
};

/**
 * Tool for getting prediction details.
 */
export const getPredictionTool: Tool = {
  name: "get_prediction",
  description: "Get details about a specific prediction",
  inputSchema: {
    type: "object",
    properties: {
      prediction_id: {
        type: "string",
        description: "ID of the prediction to get details for",
      },
    },
    required: ["prediction_id"],
  },
};

/**
 * Tool for listing recent predictions.
126 | */ 127 | export const listPredictionsTool: Tool = { 128 | name: "list_predictions", 129 | description: "List recent predictions", 130 | inputSchema: { 131 | type: "object", 132 | properties: { 133 | limit: { 134 | type: "number", 135 | description: "Maximum number of predictions to return", 136 | default: 10, 137 | }, 138 | cursor: { 139 | type: "string", 140 | description: "Cursor for pagination", 141 | }, 142 | }, 143 | }, 144 | }; 145 | -------------------------------------------------------------------------------- /src/types/mcp.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * Type definitions for MCP protocol. 3 | */ 4 | 5 | import { EventEmitter } from "node:events"; 6 | 7 | export interface MCPMessage { 8 | jsonrpc: "2.0"; 9 | id?: string | number; 10 | method?: string; 11 | params?: Record; 12 | } 13 | 14 | export interface MCPRequest extends MCPMessage { 15 | method: string; 16 | params: Record; 17 | } 18 | 19 | export interface MCPResponse extends MCPMessage { 20 | result?: unknown; 21 | error?: { 22 | code: number; 23 | message: string; 24 | data?: unknown; 25 | }; 26 | } 27 | 28 | export interface MCPNotification extends MCPMessage { 29 | method: string; 30 | params: Record; 31 | } 32 | 33 | export interface MCPResource { 34 | uri: string; 35 | mimeType?: string; 36 | text?: string; 37 | } 38 | 39 | export abstract class BaseTransport extends EventEmitter { 40 | abstract connect(): Promise; 41 | abstract disconnect(): Promise; 42 | abstract send(message: MCPMessage): Promise; 43 | } 44 | -------------------------------------------------------------------------------- /tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "target": "ES2022", 4 | "module": "NodeNext", 5 | "moduleResolution": "NodeNext", 6 | "outDir": "./build", 7 | "rootDir": "./src", 8 | "strict": true, 9 | "esModuleInterop": true, 10 | "skipLibCheck": 
true, 11 | "forceConsistentCasingInFileNames": true, 12 | "declaration": true, 13 | "sourceMap": true 14 | }, 15 | "include": ["src/**/*"], 16 | "exclude": ["node_modules", "dist"] 17 | } 18 | --------------------------------------------------------------------------------