├── .gitignore
├── .npmignore
├── LICENSE
├── README.md
├── TODO.md
├── docs
├── api.md
├── examples.md
├── troubleshooting.md
└── webhooks.md
├── mcp-replicate-0.1.0.tgz
├── mcp-replicate-0.1.1.tgz
├── package-lock.json
├── package.json
├── src
├── index.ts
├── models
│ ├── collection.ts
│ ├── hardware.ts
│ ├── model.ts
│ ├── openapi.ts
│ ├── prediction.ts
│ └── webhook.ts
├── replicate_client.ts
├── services
│ ├── cache.ts
│ ├── error.ts
│ ├── image_viewer.ts
│ └── webhook.ts
├── templates
│ ├── manager.ts
│ ├── parameters
│ │ ├── quality.ts
│ │ ├── size.ts
│ │ └── style.ts
│ └── prompts
│ │ └── text_to_image.ts
├── tests
│ ├── integration.test.ts
│ └── protocol.test.ts
├── tools
│ ├── handlers.ts
│ ├── image_viewer.ts
│ ├── index.ts
│ ├── models.ts
│ └── predictions.ts
└── types
│ └── mcp.ts
└── tsconfig.json
/.gitignore:
--------------------------------------------------------------------------------
1 | node_modules/
2 | build/
3 | *.log
4 | .env*
5 | .DS_Store
6 | dist/
7 |
--------------------------------------------------------------------------------
/.npmignore:
--------------------------------------------------------------------------------
1 | # Source
2 | src/
3 | tests/
4 |
5 | # Development configs
6 | .git/
7 | .github/
8 | .gitignore
9 | .npmrc
10 | tsconfig.json
11 | vitest.config.ts
12 | .biome.json
13 |
14 | # IDE
15 | .vscode/
16 | .idea/
17 |
18 | # Build artifacts
19 | coverage/
20 | *.log
21 | .DS_Store
22 |
23 | # Development files
24 | TODO.md
25 | CONTRIBUTING.md
26 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 deepfates
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Replicate MCP Server
2 |
3 | A [Model Context Protocol](https://modelcontextprotocol.io) server implementation for Replicate. Run Replicate models through a simple tool-based interface.
4 |
5 | ## Quickstart
6 |
7 | 1. Install the server:
8 |
9 | ```bash
10 | npm install -g mcp-replicate
11 | ```
12 |
13 | 2. Get your Replicate API token:
14 |
15 | - Go to [Replicate API tokens page](https://replicate.com/account/api-tokens)
16 | - Create a new token if you don't have one
17 | - Copy the token for the next step
18 |
19 | 3. Configure Claude Desktop:
20 | - Open Claude Desktop Settings (⌘,)
21 | - Select the "Developer" section in the sidebar
22 | - Click "Edit Config" to open the configuration file
23 | - Add the following configuration, replacing `your_token_here` with your actual Replicate API token:
24 |
25 | ```json
26 | {
27 | "mcpServers": {
28 | "replicate": {
29 | "command": "mcp-replicate",
30 | "env": {
31 | "REPLICATE_API_TOKEN": "your_token_here"
32 | }
33 | }
34 | }
35 | }
36 | ```
37 |
38 | 4. Start Claude Desktop. You should see a 🔨 hammer icon in the bottom right corner of new chat windows, indicating the tools are available.
39 |
40 | (You can also use any other MCP client, such as Cursor, Cline, or Continue.)
41 |
42 | ## Alternative Installation Methods
43 |
44 | ### Install from source
45 |
46 | ```bash
47 | git clone https://github.com/deepfates/mcp-replicate
48 | cd mcp-replicate
49 | npm install
50 | npm run build
51 | npm start
52 | ```
53 |
54 | ### Run with npx
55 |
56 | ```bash
57 | npx mcp-replicate
58 | ```
59 |
60 | ## Features
61 |
62 | ### Models
63 |
64 | - Search models using semantic search
65 | - Browse models and collections
66 | - Get detailed model information and versions
67 |
68 | ### Predictions
69 |
70 | - Create predictions with text or structured input
71 | - Track prediction status
72 | - Cancel running predictions
73 | - List your recent predictions
74 |
75 | ### Image Handling
76 |
77 | - View generated images in your browser
78 | - Manage image cache for better performance
79 |
80 | ## Configuration
81 |
82 | The server needs a Replicate API token to work. You can get one at [Replicate](https://replicate.com/account/api-tokens).
83 |
84 | There are two ways to provide the token:
85 |
86 | ### 1. In Claude Desktop Config (Recommended)
87 |
88 | Add it to your Claude Desktop configuration as shown in the Quickstart section:
89 |
90 | ```json
91 | {
92 | "mcpServers": {
93 | "replicate": {
94 | "command": "mcp-replicate",
95 | "env": {
96 | "REPLICATE_API_TOKEN": "your_token_here"
97 | }
98 | }
99 | }
100 | }
101 | ```
102 |
103 | ### 2. As Environment Variable
104 |
105 | Alternatively, you can set it as an environment variable if you're using another MCP client:
106 |
107 | ```bash
108 | export REPLICATE_API_TOKEN=your_token_here
109 | ```
110 |
111 | ## Available Tools
112 |
113 | ### Model Tools
114 |
115 | - `search_models`: Find models using semantic search
116 | - `list_models`: Browse available models
117 | - `get_model`: Get details about a specific model
118 | - `list_collections`: Browse model collections
119 | - `get_collection`: Get details about a specific collection
120 |
121 | ### Prediction Tools
122 |
123 | - `create_prediction`: Run a model with your inputs
124 | - `create_and_poll_prediction`: Run a model with your inputs and wait until it's completed
125 | - `get_prediction`: Check a prediction's status
126 | - `cancel_prediction`: Stop a running prediction
127 | - `list_predictions`: See your recent predictions
128 |
129 | ### Image Tools
130 |
131 | - `view_image`: Open an image in your browser
132 | - `clear_image_cache`: Clean up cached images
133 | - `get_image_cache_stats`: Check cache usage
134 |
135 | ## Troubleshooting
136 |
137 | ### Server is running but tools aren't showing up
138 |
139 | 1. Check that Claude Desktop is properly configured with the MCP server settings
140 | 2. Ensure your Replicate API token is set correctly
141 | 3. Try restarting both the server and Claude Desktop
142 | 4. Check the server logs for any error messages
143 |
144 | ### Tools are visible but not working
145 |
146 | 1. Verify your Replicate API token is valid
147 | 2. Check your internet connection
148 | 3. Look for any error messages in the server output
149 |
150 | ## Development
151 |
152 | 1. Install dependencies:
153 |
154 | ```bash
155 | npm install
156 | ```
157 |
158 | 2. Start development server (with auto-reload):
159 |
160 | ```bash
161 | npm run dev
162 | ```
163 |
164 | 3. Check code style:
165 |
166 | ```bash
167 | npm run lint
168 | ```
169 |
170 | 4. Format code:
171 |
172 | ```bash
173 | npm run format
174 | ```
175 |
176 | ## Requirements
177 |
178 | - Node.js >= 18.0.0
179 | - TypeScript >= 5.0.0
180 | - [Claude Desktop](https://claude.ai/download) for using the tools
181 |
182 | ## License
183 |
184 | MIT
185 |
--------------------------------------------------------------------------------
/TODO.md:
--------------------------------------------------------------------------------
1 | # MCP Server for Replicate - Simplified Implementation Plan
2 |
3 | ## Core Philosophy
4 | - Minimize complexity by focusing on tools over resources
5 | - Follow MCP spec for core functionality
6 | - Keep transport layer simple (stdio for local, SSE for remote)
7 | - Implement only essential features
8 |
9 | ## Current Status
10 | ✓ Basic functionality implemented:
11 | - Tool-based access to models and predictions
12 | - Type-safe interactions with protocol compliance
13 | - Simple error handling
14 | - Basic rate limiting
15 | - SSE transport layer for remote connections
16 |
17 | ## Implementation Plan
18 |
19 | ### Phase 1: Core Simplification (✓ Complete)
20 | 1. Replace Resource System with Tools
21 | - [x] Convert model listing to search_models tool
22 | - [x] Convert prediction access to get_prediction tool
23 | - [x] Remove resource-based URI schemes
24 | - [x] Simplify server initialization
25 |
26 | 2. Streamline Client Implementation
27 | - [x] Simplify ReplicateClient class
28 | - [x] Remove complex caching layers
29 | - [x] Implement basic error handling
30 | - [x] Add simple rate limiting
31 |
32 | 3. Transport Layer
33 | - [x] Keep stdio for local communication
34 | - [x] Implement basic SSE for remote (no complex retry logic)
35 | - [x] Remove unnecessary transport abstractions
36 |
37 | ### Phase 2: Essential Tools (✓ Complete)
38 | 1. Model Management
39 | - [x] search_models - Find models by query
40 | - [x] get_model - Get model details
41 | - [x] list_versions - List model versions
42 |
43 | 2. Prediction Handling
44 | - [x] create_prediction - Run model inference
45 | - [x] get_prediction - Check prediction status
46 | - [x] cancel_prediction - Stop running prediction
47 |
48 | 3. Image Tools
49 | - [x] view_image - Display result in browser
50 | - [x] save_image - Save to local filesystem
51 |
52 | ### Phase 3: Testing & Documentation (🚧 In Progress)
53 | 1. Testing
54 | - [x] Add basic protocol compliance tests
55 | - [x] Test core tool functionality
56 | - [x] Add integration tests
57 |
58 | 2. Documentation
59 | - [x] Update API reference for simplified interface
60 | - [ ] Add clear usage examples
61 | - [ ] Create troubleshooting guide
62 |
63 | ### Phase 4: Optional Enhancements (🚧 In Progress)
64 | 1. Webhook Support
65 | - [x] Simple webhook configuration
66 | - [x] Basic retry logic
67 | - [x] Event formatting
68 |
69 | 2. Template System
70 | - [ ] Basic parameter templates
71 | - [ ] Simple validation
72 | - [ ] Example presets
73 |
74 | ## Next Steps
75 |
76 | 1. Documentation:
77 | - Add clear usage examples
78 | - Create troubleshooting guide
79 | - Document common error cases
80 |
81 | 2. Template System:
82 | - Design parameter template format
83 | - Implement validation logic
84 | - Create example presets
85 |
86 | 3. Testing:
87 | - Add more edge case tests
88 | - Improve error handling coverage
89 | - Add performance benchmarks
90 |
91 | Legend:
92 | - [x] Completed
93 | - [ ] Not started
94 | - ✓ Phase complete
95 | - 🚧 Phase in progress
96 | - ❌ Phase not started
97 |
--------------------------------------------------------------------------------
/docs/api.md:
--------------------------------------------------------------------------------
1 | # Replicate MCP Server API Reference
2 |
3 | ## Overview
4 |
5 | The Replicate MCP Server provides a Model Context Protocol (MCP) interface to Replicate's AI model platform. This document details the available tools and methods for interacting with the server.
6 |
7 | ## Tools
8 |
9 | ### search_models
10 |
11 | Search for models using semantic search.
12 |
13 | #### Input Schema
14 | ```typescript
15 | {
16 | query: string // Search query
17 | }
18 | ```
19 |
20 | #### Example
21 | ```typescript
22 | await client.searchModels({
23 | query: "high quality text to image models"
24 | });
25 | ```
26 |
27 | ### list_models
28 |
29 | List available models with optional filtering.
30 |
31 | #### Input Schema
32 | ```typescript
33 | {
34 | owner?: string, // Filter by model owner
35 | cursor?: string // Pagination cursor
36 | }
37 | ```
38 |
39 | #### Example
40 | ```typescript
41 | await client.listModels({
42 | owner: "stability-ai"
43 | });
44 | ```
45 |
46 | ### create_prediction
47 |
48 | Create a new prediction using a model version.
49 |
50 | #### Input Schema
51 | ```typescript
52 | {
53 | version: string, // Model version ID
54 |   input: Record<string, unknown>, // Model input parameters
55 | webhook_url?: string // Optional webhook URL
56 | }
57 | ```
58 |
59 | #### Example
60 | ```typescript
61 | await client.createPrediction({
62 | version: "stability-ai/sdxl@v1.0.0",
63 | input: {
64 | prompt: "A serene mountain landscape"
65 | }
66 | });
67 | ```
68 |
69 | ### cancel_prediction
70 |
71 | Cancel a running prediction.
72 |
73 | #### Input Schema
74 | ```typescript
75 | {
76 | prediction_id: string // ID of prediction to cancel
77 | }
78 | ```
79 |
80 | #### Example
81 | ```typescript
82 | await client.cancelPrediction({
83 | prediction_id: "pred_123abc"
84 | });
85 | ```
86 |
87 | ### get_prediction
88 |
89 | Get details about a specific prediction.
90 |
91 | #### Input Schema
92 | ```typescript
93 | {
94 | prediction_id: string // ID of prediction to get details for
95 | }
96 | ```
97 |
98 | #### Example
99 | ```typescript
100 | await client.getPrediction({
101 | prediction_id: "pred_123abc"
102 | });
103 | ```
104 |
105 | ## Templates
106 |
107 | The server includes a template system for common parameter configurations.
108 |
109 | ### Quality Templates
110 |
111 | Preset quality levels for image generation:
112 | - `draft`: Fast, lower quality results
113 | - `balanced`: Good balance of speed and quality
114 | - `quality`: High quality output
115 | - `extreme`: Maximum quality, slower generation
116 |
117 | ### Style Templates
118 |
119 | Common artistic styles:
120 | - `photographic`: Realistic photo-like images
121 | - `digital-art`: Digital artwork style
122 | - `cinematic`: Movie-like composition
123 | - `anime`: Anime/manga style
124 | - `painting`: Traditional painting styles
125 |
126 | ### Size Templates
127 |
128 | Standard image dimensions:
129 | - `square`: 1024x1024
130 | - `portrait`: 832x1216
131 | - `landscape`: 1216x832
132 | - `widescreen`: 1344x768
133 |
134 | ## Error Handling
135 |
136 | The server uses a simple error handling system with a single error class and clear error messages.
137 |
138 | ### ReplicateError
139 |
140 | All errors from the Replicate API are instances of `ReplicateError`. This class provides a consistent way to handle errors across the API.
141 |
142 | ```typescript
143 | class ReplicateError extends Error {
144 | name: "ReplicateError";
145 | message: string;
146 |   context?: Record<string, unknown>;
147 | }
148 | ```
149 |
150 | ### Error Factory Functions
151 |
152 | The API provides factory functions to create standardized errors:
153 |
154 | ```typescript
155 | const createError = {
156 | rateLimit: (retryAfter: number) =>
157 | new ReplicateError("Rate limit exceeded", { retryAfter }),
158 |
159 | authentication: (details?: string) =>
160 | new ReplicateError("Authentication failed", { details }),
161 |
162 | notFound: (resource: string) =>
163 | new ReplicateError("Model not found", { resource }),
164 |
165 | validation: (field: string, message: string) =>
166 | new ReplicateError("Invalid input parameters", { field, message }),
167 |
168 | timeout: (operation: string, ms: number) =>
169 | new ReplicateError("Operation timed out", { operation, timeoutMs: ms }),
170 | };
171 | ```
172 |
173 | ### Example Error Handling
174 |
175 | ```typescript
176 | try {
177 | const prediction = await client.createPrediction({
178 | version: "stability-ai/sdxl@latest",
179 | input: { prompt: "Test prompt" }
180 | });
181 | } catch (error) {
182 | if (error instanceof ReplicateError) {
183 | console.error(error.message);
184 | // Access additional context if available
185 | if (error.context) {
186 | console.error("Error context:", error.context);
187 | }
188 | }
189 | }
190 | ```
191 |
192 | ### Automatic Retries
193 |
194 | The API includes built-in retry functionality for certain types of errors:
195 |
196 | ```typescript
197 | const result = await ErrorHandler.withRetries(
198 | async () => client.listModels(),
199 | {
200 | maxAttempts: 3,
201 | minDelay: 1000,
202 | maxDelay: 30000,
203 | retryIf: (error) => error instanceof ReplicateError
204 | }
205 | );
206 | ```
207 |
208 | ## Rate Limiting
209 |
210 | The server implements basic rate limiting to prevent abuse. When rate limits are exceeded, the API returns a `ReplicateError` with the message "Rate limit exceeded" and includes a `retryAfter` value in the context.
211 |
212 | ### Limits
213 | - API requests per minute
214 | - Concurrent predictions per user
215 |
216 | ### Rate Limit Headers
217 | ```typescript
218 | {
219 | "X-RateLimit-Limit": "100", // Maximum requests per minute
220 | "X-RateLimit-Remaining": "95", // Remaining requests
221 | "X-RateLimit-Reset": "1704891600" // Unix timestamp when limit resets
222 | }
223 | ```
224 |
--------------------------------------------------------------------------------
/docs/examples.md:
--------------------------------------------------------------------------------
1 | # Replicate MCP Server Usage Examples
2 |
3 | This document provides practical examples of common use cases for the Replicate MCP server.
4 |
5 | ## Basic Usage
6 |
7 | ### Generating Images with Text
8 |
9 | ```typescript
10 | // Create a prediction with SDXL
11 | const prediction = await client.createPrediction({
12 | version: "stability-ai/sdxl@latest",
13 | input: {
14 | prompt: "A serene mountain landscape at sunset",
15 | quality: "balanced",
16 | style: "photographic"
17 | }
18 | });
19 |
20 | // Get prediction status
21 | const status = await client.getPrediction({
22 | prediction_id: prediction.id
23 | });
24 |
25 | // Cancel a running prediction if needed
26 | await client.cancelPrediction({
27 | prediction_id: prediction.id
28 | });
29 | ```
30 |
31 | ### Browsing Models
32 |
33 | ```typescript
34 | // List all models (uses caching)
35 | const models = await client.listModels({});
36 |
37 | // Filter models by owner (cached by owner)
38 | const stabilityModels = await client.listModels({
39 | owner: "stability-ai"
40 | });
41 |
42 | // Search for specific models (cached by query)
43 | const searchResults = await client.searchModels({
44 | query: "text to image models with good quality"
45 | });
46 | ```
47 |
48 | ### Working with Collections
49 |
50 | ```typescript
51 | // List available collections (cached)
52 | const collections = await client.listCollections({});
53 |
54 | // Get details of a specific collection (cached by slug)
55 | const textToImage = await client.getCollection({
56 | slug: "text-to-image"
57 | });
58 |
59 | // Browse models in a collection
60 | const collectionModels = textToImage.models;
61 | ```
62 |
63 | ## Advanced Usage
64 |
65 | ### Using Templates
66 |
67 | ```typescript
68 | // Using quality templates
69 | const highQualityPrediction = await client.createPrediction({
70 | version: "stability-ai/sdxl@latest",
71 | input: {
72 | prompt: "A futuristic cityscape",
73 | ...templates.quality.extreme,
74 | ...templates.style.cinematic,
75 | ...templates.size.widescreen
76 | }
77 | });
78 | ```
79 |
80 | ### Webhook Integration
81 |
82 | ```typescript
83 | // Create prediction with webhook notification
84 | const prediction = await client.createPrediction({
85 | version: "stability-ai/sdxl@latest",
86 | input: {
87 | prompt: "An abstract digital artwork"
88 | },
89 | webhook_url: "https://api.myapp.com/webhooks/replicate"
90 | });
91 |
92 | // Example webhook handler (Express.js)
93 | app.post("/webhooks/replicate", async (req, res) => {
94 | try {
95 | const signature = req.headers["x-replicate-signature"];
96 | const webhookSecret = await client.getWebhookSecret();
97 |
98 | if (!verifyWebhookSignature(signature, webhookSecret, req.body)) {
99 | throw new ValidationError("Invalid signature");
100 | }
101 |
102 | const { event, prediction } = req.body;
103 | switch (event) {
104 | case "prediction.completed":
105 | await handleCompletedPrediction(prediction);
106 | break;
107 | case "prediction.failed":
108 | await handleFailedPrediction(prediction);
109 | break;
110 | }
111 |
112 | res.status(200).send("OK");
113 | } catch (error) {
114 | console.error(ErrorHandler.createErrorReport(error));
115 | res.status(error instanceof ValidationError ? 401 : 500).json({
116 | error: error.message,
117 | code: error.name
118 | });
119 | }
120 | });
121 | ```
122 |
123 | ### Error Handling
124 |
125 | ```typescript
126 | try {
127 | const prediction = await client.createPrediction({
128 | version: "stability-ai/sdxl@latest",
129 | input: { prompt: "Test prompt" }
130 | });
131 | } catch (error) {
132 | if (error instanceof ReplicateError) {
133 | console.error(error.message);
134 | // Optional: access additional context if available
135 | if (error.context) {
136 | console.error("Error context:", error.context);
137 | }
138 | } else {
139 | console.error("Unexpected error:", error);
140 | }
141 | }
142 | ```
143 |
144 | ### Automatic Retries
145 |
146 | ```typescript
147 | // Using the built-in retry functionality
148 | const result = await ErrorHandler.withRetries(
149 | async () => client.createPrediction({
150 | version: "stability-ai/sdxl@latest",
151 | input: { prompt: "A test image" }
152 | }),
153 | {
154 | maxAttempts: 3,
155 | minDelay: 1000,
156 | maxDelay: 10000,
157 | retryIf: (error) => error instanceof ReplicateError,
158 | onRetry: (error, attempt) => {
159 | console.warn(
160 | `Request failed: ${error.message}. `,
161 | `Retrying (attempt ${attempt + 1}/3)`
162 | );
163 | }
164 | }
165 | );
166 | ```
167 |
168 | ### Handling Common Errors
169 |
170 | ```typescript
171 | try {
172 | const prediction = await client.createPrediction({
173 | version: "stability-ai/sdxl@latest",
174 | input: { prompt: "Test prompt" }
175 | });
176 | } catch (error) {
177 | if (error instanceof ReplicateError) {
178 | switch (error.message) {
179 | case "Rate limit exceeded":
180 | const retryAfter = error.context?.retryAfter;
181 | console.log(`Rate limit hit. Retry after ${retryAfter} seconds`);
182 | break;
183 | case "Authentication failed":
184 | console.log("Please check your API token");
185 | break;
186 | case "Model not found":
187 | console.log(`Model ${error.context?.resource} not found`);
188 | break;
189 | case "Invalid input parameters":
190 | console.log(`Invalid input: ${error.context?.field} - ${error.context?.message}`);
191 | break;
192 | default:
193 | console.log("Operation failed:", error.message);
194 | }
195 | }
196 | }
197 | ```
198 |
199 | ### Batch Processing
200 |
201 | ```typescript
202 | // Batch processing example
203 | async function batchProcess(prompts: string[]) {
204 | const predictions = await Promise.all(
205 | prompts.map(prompt =>
206 | client.createPrediction({
207 | version: "stability-ai/sdxl@latest",
208 | input: { prompt }
209 | })
210 | )
211 | );
212 |
213 | // Monitor all predictions
214 | return Promise.all(
215 | predictions.map(prediction =>
216 | client.getPredictionStatus(prediction.id)
217 | )
218 | );
219 | }
220 |
221 | // Usage
222 | const prompts = [
223 | "A serene beach at sunset",
224 | "A mystical forest with glowing mushrooms",
225 | "A futuristic space station"
226 | ];
227 |
228 | const results = await batchProcess(prompts);
229 | ```
230 |
231 | ## Best Practices
232 |
233 | 1. Use proper error handling:
234 | ```typescript
235 | try {
236 | // Your code here
237 | } catch (error) {
238 | if (error instanceof ReplicateError) {
239 | console.error(error.message);
240 | } else {
241 | console.error("Unexpected error:", error);
242 | }
243 | }
244 | ```
245 |
246 | 2. Implement proper retry logic:
247 | ```typescript
248 | let attempts = 0;
249 | const maxAttempts = 3;
250 |
251 | while (attempts < maxAttempts) {
252 | try {
253 | const result = await makeRequest();
254 | break;
255 | } catch (error) {
256 | attempts++;
257 | if (attempts === maxAttempts) throw error;
258 | await new Promise(resolve => setTimeout(resolve, 1000 * attempts));
259 | }
260 | }
261 | ```
262 |
263 | 3. Use webhooks effectively:
264 | ```typescript
265 | // Create prediction with webhook notification
266 | const prediction = await client.createPrediction({
267 | version: "stability-ai/sdxl@latest",
268 | input: {
269 | prompt: "An abstract digital artwork"
270 | },
271 | webhook_url: "https://api.myapp.com/webhooks/replicate"
272 | });
273 | ```
274 |
275 | 4. Handle rate limits gracefully:
276 | ```typescript
277 | try {
278 | await makeRequest();
279 | } catch (error) {
280 | if (error.message.includes("rate limit")) {
281 | console.log("Rate limit hit. Please try again later.");
282 | }
283 | }
284 | ```
285 |
286 | 5. Clean up resources:
287 | ```typescript
288 | try {
289 | // Use API
290 | await makeRequest();
291 | } catch (error) {
292 | // Handle errors
293 | console.error(error);
294 | }
295 | ```
296 |
--------------------------------------------------------------------------------
/docs/troubleshooting.md:
--------------------------------------------------------------------------------
1 | # Troubleshooting Guide
2 |
3 | This guide helps you diagnose and resolve common issues with the Replicate MCP Server.
4 |
5 | ## Error Handling System
6 |
7 | ### ReplicateError Class
8 |
9 | ```typescript
10 | // Base error class with context
11 | class ReplicateError extends Error {
12 |   constructor(message: string, public context?: Record<string, unknown>) {
13 | super(message);
14 | this.name = "ReplicateError";
15 | }
16 | }
17 | ```
18 |
19 | ### Error Factory Functions
20 |
21 | ```typescript
22 | const createError = {
23 | rateLimit: (retryAfter: number) =>
24 | new ReplicateError("Rate limit exceeded", { retryAfter }),
25 |
26 | authentication: (details?: string) =>
27 | new ReplicateError("Authentication failed", { details }),
28 |
29 | notFound: (resource: string) =>
30 | new ReplicateError("Model not found", { resource }),
31 |
32 | validation: (field: string, message: string) =>
33 | new ReplicateError("Invalid input parameters", { field, message }),
34 |
35 | timeout: (operation: string, ms: number) =>
36 | new ReplicateError("Operation timed out", { operation, timeoutMs: ms }),
37 | };
38 | ```
39 |
40 | ### Error Reports
41 |
42 | ```typescript
43 | interface ErrorReport {
44 | name: string;
45 | message: string;
46 |   context?: Record<string, unknown>;
47 | timestamp: string;
48 | }
49 |
50 | // Generate error reports
51 | const report = ErrorHandler.createErrorReport(error);
52 | console.error(JSON.stringify(report, null, 2));
53 | ```
54 |
55 | ## Common Issues
56 |
57 | ### 1. Authentication Issues
58 |
59 | #### Symptoms
60 | - "Authentication failed" errors
61 | - 401 Unauthorized responses
62 | - Authentication-related webhook failures
63 |
64 | #### Solutions
65 | 1. Verify API Token
66 | ```typescript
67 | try {
68 | await client.listModels();
69 | } catch (error) {
70 | if (error instanceof ReplicateError && error.message === "Authentication failed") {
71 | console.error("Authentication failed:", error.context?.details);
72 | }
73 | }
74 | ```
75 |
76 | 2. Check Environment Variables
77 | ```typescript
78 | if (!process.env.REPLICATE_API_TOKEN) {
79 | throw createError.authentication("Missing API token");
80 | }
81 | ```
82 |
83 | ### 2. Rate Limiting Issues
84 |
85 | #### Handling Rate Limits
86 |
87 | ```typescript
88 | try {
89 | await client.createPrediction({
90 | version: "stability-ai/sdxl@latest",
91 | input: { prompt: "test" }
92 | });
93 | } catch (error) {
94 | if (error instanceof ReplicateError && error.message === "Rate limit exceeded") {
95 | const retryAfter = error.context?.retryAfter;
96 | console.log(`Rate limit hit. Retry after ${retryAfter} seconds`);
97 | }
98 | }
99 | ```
100 |
101 | ### 3. Network Issues
102 |
103 | #### Enhanced Network Error Handling
104 |
105 | ```typescript
106 | const result = await ErrorHandler.withRetries(
107 | async () => client.createPrediction({
108 | version: "stability-ai/sdxl@latest",
109 | input: { prompt: "test" }
110 | }),
111 | {
112 | maxAttempts: 3,
113 | minDelay: 1000,
114 | maxDelay: 30000,
115 | retryIf: (error) => error instanceof ReplicateError,
116 | onRetry: (error, attempt) => {
117 | console.warn(
118 | `Request failed: ${error.message}. `,
119 | `Retrying (attempt ${attempt + 1}/3)`
120 | );
121 | }
122 | }
123 | );
124 | ```
125 |
126 | ### 4. Prediction Issues
127 |
128 | #### Input Validation
129 |
130 | ```typescript
131 | try {
132 | const prediction = await client.createPrediction({
133 | version: "stability-ai/sdxl@latest",
134 | input: { prompt: "" } // Invalid empty prompt
135 | });
136 | } catch (error) {
137 | if (error instanceof ReplicateError && error.message === "Invalid input parameters") {
138 | console.error(`Validation error: ${error.context?.field} - ${error.context?.message}`);
139 | }
140 | }
141 | ```
142 |
143 | ## Getting Help
144 |
145 | If you're experiencing issues:
146 |
147 | 1. Check error reports:
148 | ```typescript
149 | const errorReport = ErrorHandler.createErrorReport(error);
150 | console.error(JSON.stringify(errorReport, null, 2));
151 | ```
152 |
153 | 2. Generate a diagnostic report:
154 | ```typescript
155 | const diagnostics = {
156 | error: ErrorHandler.createErrorReport(error),
157 | timestamp: new Date().toISOString(),
158 | environment: {
159 | nodeVersion: process.version,
160 | platform: process.platform,
161 | arch: process.arch,
162 | env: process.env.NODE_ENV
163 | }
164 | };
165 | ```
166 |
167 | 3. Contact support with detailed information:
168 | - Error reports with context
169 | - Environment details
170 | - Steps to reproduce
171 |
172 | 4. Resources:
173 | - [GitHub Issues](https://github.com/replicate/replicate/issues)
174 | - [API Documentation](https://replicate.com/docs)
175 | - [Discord Community](https://discord.gg/replicate)
176 | - [Support Email](mailto:support@replicate.com)
177 |
--------------------------------------------------------------------------------
/docs/webhooks.md:
--------------------------------------------------------------------------------
1 | # Webhook Integration Guide
2 |
3 | This guide explains how to integrate webhooks with the Replicate MCP Server for receiving updates about predictions.
4 |
5 | ## Overview
6 |
7 | Webhooks provide a way to receive asynchronous notifications about events in your Replicate predictions. Instead of polling for updates, your application can receive push notifications when important events occur.
8 |
9 | ## Supported Events
10 |
11 | The server sends webhooks for the following events:
12 |
13 | - `prediction.started`: When a prediction begins processing
14 | - `prediction.completed`: When a prediction successfully completes
15 | - `prediction.failed`: When a prediction encounters an error
16 |
17 | ## Webhook Payload
18 |
19 | Each webhook delivery includes a JSON payload with event details:
20 |
21 | ```typescript
22 | interface WebhookPayload {
23 | event: string; // Event type
24 | prediction: { // Prediction details
25 | id: string; // Prediction ID
26 | version: string; // Model version
27 |     input: Record<string, unknown>; // Input parameters
28 | output?: any; // Output data (for completed predictions)
29 | error?: string; // Error message (for failed predictions)
30 | status: string; // Current status
31 | created_at: string; // Creation timestamp
32 | started_at?: string; // Processing start timestamp
33 | completed_at?: string; // Completion timestamp
34 | };
35 | timestamp: string; // Event timestamp
36 | }
37 | ```
38 |
39 | ## Setting Up Webhooks
40 |
41 | ### 1. Create a Webhook Endpoint
42 |
43 | First, create an endpoint in your application to receive webhook notifications. Example using Express:
44 |
45 | ```typescript
46 | import express from "express";
47 |
48 | const app = express();
49 | app.use(express.json());
50 |
51 | app.post("/webhooks/replicate", async (req, res) => {
52 | try {
53 | const { event, prediction } = req.body;
54 |
55 | switch (event) {
56 | case "prediction.started":
57 | await handlePredictionStarted(prediction);
58 | break;
59 | case "prediction.completed":
60 | await handlePredictionCompleted(prediction);
61 | break;
62 | case "prediction.failed":
63 | await handlePredictionFailed(prediction);
64 | break;
65 | default:
66 | console.warn(`Unknown event type: ${event}`);
67 | }
68 |
69 | // Return 200 OK quickly to acknowledge receipt
70 | res.status(200).send("OK");
71 | } catch (error) {
72 | // Log error for debugging
73 | console.error("Webhook processing failed:", error);
74 |
75 | // Return 500 error
76 | res.status(500).json({
77 | error: "Webhook processing failed",
78 | details: error instanceof Error ? error.message : String(error)
79 | });
80 | }
81 | });
82 | ```
83 |
84 | ### 2. Configure Webhook URL
85 |
86 | When creating a prediction, include your webhook URL:
87 |
88 | ```typescript
89 | const prediction = await client.createPrediction({
90 | version: "stability-ai/sdxl@latest",
91 | input: {
92 | prompt: "A serene landscape"
93 | },
94 | webhook_url: "https://api.yourapp.com/webhooks/replicate"
95 | });
96 | ```
97 |
98 | ## Error Handling
99 |
100 | When processing webhooks, implement proper error handling:
101 |
102 | ```typescript
103 | app.post("/webhooks/replicate", async (req, res) => {
104 | try {
105 | const { event, prediction } = req.body;
106 |
107 | // Process webhook asynchronously
108 | processWebhookAsync(event, prediction).catch(error => {
109 | if (error instanceof ReplicateError) {
110 | console.error("Webhook processing failed:", error.message, error.context);
111 | } else {
112 | console.error("Unexpected error in webhook processing:", error);
113 | }
114 | });
115 |
116 | // Return success quickly
117 | res.status(200).send("OK");
118 | } catch (error) {
119 | if (error instanceof ReplicateError) {
120 | console.error("Webhook error:", error.message, error.context);
121 | res.status(400).json({
122 | error: error.message,
123 | context: error.context
124 | });
125 | } else {
126 | console.error("Unexpected webhook error:", error);
127 | res.status(500).json({
128 | error: "Internal server error",
129 | message: error instanceof Error ? error.message : String(error)
130 | });
131 | }
132 | }
133 | });
134 |
135 | async function processWebhookAsync(event: string, prediction: any) {
136 | try {
137 | switch (event) {
138 | case "prediction.completed":
139 | await handlePredictionCompleted(prediction);
140 | break;
141 | case "prediction.failed":
142 | await handlePredictionFailed(prediction);
143 | break;
144 | default:
145 | throw createError.validation("event", `Unknown event type: ${event}`);
146 | }
147 | } catch (error) {
148 | // Log error but don't throw since we're in an async context
149 | if (error instanceof ReplicateError) {
150 | console.error("Failed to process webhook:", error.message, error.context);
151 | } else {
152 | console.error("Unexpected error in webhook processing:", error);
153 | }
154 | }
155 | }
156 | ```
157 |
158 | ## Best Practices
159 |
160 | 1. **Use HTTPS**
161 | - Always use HTTPS for webhook endpoints
162 | - Ensure proper TLS configuration
163 |
164 | 2. **Handle Errors Gracefully**
165 | - Use the `ReplicateError` class for consistent error handling
166 | - Include relevant context in error messages
167 | - Return appropriate status codes
168 |
169 | 3. **Process Asynchronously**
170 | - Handle webhook processing in the background
171 | - Return 200 OK quickly to acknowledge receipt
172 | - Use error handling in async handlers
173 |
174 | 4. **Monitor Webhook Health**
175 | - Log webhook deliveries and errors
176 | - Track success/failure rates
177 | - Monitor processing times
178 |
--------------------------------------------------------------------------------
/mcp-replicate-0.1.0.tgz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepfates/mcp-replicate/1c7a2e6b01f78959414ca1a4a9973e3cd878216b/mcp-replicate-0.1.0.tgz
--------------------------------------------------------------------------------
/mcp-replicate-0.1.1.tgz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepfates/mcp-replicate/1c7a2e6b01f78959414ca1a4a9973e3cd878216b/mcp-replicate-0.1.1.tgz
--------------------------------------------------------------------------------
/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "mcp-replicate",
3 | "version": "0.1.1",
4 | "description": "Run Replicate models through a simple MCP server interface",
5 | "type": "module",
6 | "main": "build/index.js",
7 | "types": "build/index.d.ts",
8 | "bin": {
9 | "mcp-replicate": "build/index.js"
10 | },
11 | "files": [
12 | "build",
13 | "README.md",
14 | "LICENSE"
15 | ],
16 | "scripts": {
17 | "build": "tsc",
18 | "postbuild": "chmod +x build/index.js",
19 | "start": "node build/index.js",
20 | "dev": "tsc -w",
21 | "test": "vitest",
22 | "test:watch": "vitest watch",
23 | "test:coverage": "vitest run --coverage",
24 | "lint": "biome check .",
25 | "format": "biome format . --write",
26 | "prepublishOnly": "npm run build"
27 | },
28 | "dependencies": {
29 | "@modelcontextprotocol/sdk": "^0.6.0",
30 | "replicate": "^1.0.1"
31 | },
32 | "devDependencies": {
33 | "@biomejs/biome": "^1.5.3",
34 | "@types/node": "^20.11.5",
35 | "typescript": "^5.3.3",
36 | "vitest": "^1.2.1"
37 | },
38 | "engines": {
39 | "node": ">=18.0.0"
40 | },
41 | "keywords": [
42 | "replicate",
43 | "mcp",
44 | "machine-learning",
45 | "ai",
46 | "model-context-protocol"
47 | ],
48 | "author": "deepfates",
49 | "license": "MIT",
50 | "repository": {
51 | "type": "git",
52 | "url": "https://github.com/deepfates/mcp-replicate"
53 | },
54 | "bugs": {
55 | "url": "https://github.com/deepfates/mcp-replicate/issues"
56 | },
57 | "homepage": "https://github.com/deepfates/mcp-replicate#readme",
58 | "publishConfig": {
59 | "access": "public"
60 | },
61 | "vitest": {
62 | "include": [
63 | "src/**/*.test.ts"
64 | ],
65 | "exclude": [
66 | "node_modules",
67 | "build"
68 | ],
69 | "environment": "node",
70 | "testTimeout": 10000,
71 | "coverage": {
72 | "provider": "v8",
73 | "reporter": [
74 | "text",
75 | "html"
76 | ],
77 | "exclude": [
78 | "node_modules",
79 | "build",
80 | "**/*.test.ts",
81 | "**/*.d.ts"
82 | ]
83 | }
84 | }
85 | }
86 |
--------------------------------------------------------------------------------
/src/index.ts:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env node
2 |
3 | /**
4 | * MCP server implementation for Replicate.
5 | * Provides access to Replicate models and predictions through MCP.
6 | */
7 |
8 | import { Server } from "@modelcontextprotocol/sdk/server/index.js";
9 | import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js";
10 | import {
11 | CallToolRequestSchema,
12 | ListToolsRequestSchema,
13 | } from "@modelcontextprotocol/sdk/types.js";
14 |
15 | import { ReplicateClient } from "./replicate_client.js";
16 | import type { Model } from "./models/model.js";
17 | import type {
18 | Prediction,
19 | ModelIO,
20 | PredictionStatus,
21 | } from "./models/prediction.js";
22 | import type { Collection } from "./models/collection.js";
23 |
24 | import { tools } from "./tools/index.js";
25 | import {
26 | handleSearchModels,
27 | handleListModels,
28 | handleListCollections,
29 | handleGetCollection,
30 | handleCreatePrediction,
31 | handleCreateAndPollPrediction,
32 | handleCancelPrediction,
33 | handleGetPrediction,
34 | handleListPredictions,
35 | handleGetModel,
36 | } from "./tools/handlers.js";
37 | import {
38 | handleViewImage,
39 | handleClearImageCache,
40 | handleGetImageCacheStats,
41 | } from "./tools/image_viewer.js";
42 |
43 | // Initialize Replicate client
44 | const client = new ReplicateClient();
45 |
46 | // Cache for models, predictions, collections, and prediction status
47 | const modelCache = new Map();
48 | const predictionCache = new Map();
49 | const collectionCache = new Map();
50 | const predictionStatus = new Map();
51 |
52 | // Cache object for tool handlers
53 | const cache = {
54 | modelCache,
55 | predictionCache,
56 | collectionCache,
57 | predictionStatus,
58 | };
59 |
60 | /**
61 | * Create an MCP server with capabilities for
62 | * tools (to run predictions)
63 | */
64 | const server = new Server(
65 | {
66 | name: "replicate",
67 | version: "0.1.0",
68 | },
69 | {
70 | capabilities: {
71 | tools: {},
72 | prompts: {},
73 | },
74 | }
75 | );
76 |
77 | /**
78 | * Handler that lists available tools.
79 | */
80 | server.setRequestHandler(ListToolsRequestSchema, async () => {
81 | return { tools };
82 | });
83 |
84 | /**
85 | * Handler for tools.
86 | */
87 | server.setRequestHandler(CallToolRequestSchema, async (request) => {
88 | switch (request.params.name) {
89 | case "search_models":
90 | return handleSearchModels(client, cache, {
91 | query: String(request.params.arguments?.query),
92 | });
93 |
94 | case "list_models":
95 | return handleListModels(client, cache, {
96 | owner: request.params.arguments?.owner as string | undefined,
97 | cursor: request.params.arguments?.cursor as string | undefined,
98 | });
99 |
100 | case "list_collections":
101 | return handleListCollections(client, cache, {
102 | cursor: request.params.arguments?.cursor as string | undefined,
103 | });
104 |
105 | case "get_collection":
106 | return handleGetCollection(client, cache, {
107 | slug: String(request.params.arguments?.slug),
108 | });
109 |
110 | case "create_prediction":
111 | return handleCreatePrediction(client, cache, {
112 | version: request.params.arguments?.version as string | undefined,
113 | model: request.params.arguments?.model as string | undefined,
114 | input: request.params.arguments?.input as ModelIO,
115 | webhook: request.params.arguments?.webhook_url as string | undefined,
116 | });
117 |
118 | case "create_and_poll_prediction":
119 | return handleCreateAndPollPrediction(client, cache, {
120 | version: request.params.arguments?.version as string | undefined,
121 | model: request.params.arguments?.model as string | undefined,
122 | input: request.params.arguments?.input as ModelIO,
123 | webhook: request.params.arguments?.webhook_url as string | undefined,
124 | pollInterval: request.params.arguments?.poll_interval as
125 | | number
126 | | undefined,
127 | timeout: request.params.arguments?.timeout as number | undefined,
128 | });
129 |
130 | case "cancel_prediction":
131 | return handleCancelPrediction(client, cache, {
132 | prediction_id: String(request.params.arguments?.prediction_id),
133 | });
134 |
135 | case "get_prediction":
136 | return handleGetPrediction(client, cache, {
137 | prediction_id: String(request.params.arguments?.prediction_id),
138 | });
139 |
140 | case "list_predictions":
141 | return handleListPredictions(client, cache, {
142 | limit: request.params.arguments?.limit as number | undefined,
143 | cursor: request.params.arguments?.cursor as string | undefined,
144 | });
145 |
146 | case "get_model":
147 | return handleGetModel(client, cache, {
148 | owner: String(request.params.arguments?.owner),
149 | name: String(request.params.arguments?.name),
150 | });
151 |
152 | case "view_image":
153 | return handleViewImage(request);
154 |
155 | case "clear_image_cache":
156 | return handleClearImageCache(request);
157 |
158 | case "get_image_cache_stats":
159 | return handleGetImageCacheStats(request);
160 |
161 | default:
162 | throw new Error("Unknown tool");
163 | }
164 | });
165 |
166 | /**
167 | * Start the server using stdio transport.
168 | */
169 | async function main() {
170 | const transport = new StdioServerTransport();
171 | await server.connect(transport);
172 | }
173 |
174 | main().catch((error) => {
175 | console.error("Server error:", error);
176 | process.exit(1);
177 | });
178 |
--------------------------------------------------------------------------------
/src/models/collection.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * Data models for Replicate collections.
3 | */
4 |
5 | import type { Model } from "./model.js";
6 |
7 | /**
8 | * A collection of related models on Replicate.
9 | */
10 | export interface Collection {
11 | /** Unique identifier for this collection */
12 | id: string;
13 | /** Human-readable name of the collection */
14 | name: string;
15 | /** URL-friendly slug for the collection */
16 | slug: string;
17 | /** Description of the collection's purpose */
18 | description?: string;
19 | /** Models included in this collection */
20 | models: Model[];
21 | /** Whether this collection is featured */
22 | featured?: boolean;
23 | /** When this collection was created */
24 | created_at: string;
25 | /** When this collection was last updated */
26 | updated_at?: string;
27 | }
28 |
29 | /**
30 | * Response format for listing collections.
31 | */
32 | export interface CollectionList {
33 | /** List of collections */
34 | collections: Collection[];
35 | /** Cursor for pagination */
36 | next_cursor?: string;
37 | /** Total number of collections */
38 | total_count?: number;
39 | }
40 |
--------------------------------------------------------------------------------
/src/models/hardware.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * Data models for Replicate hardware options.
3 | */
4 |
5 | /**
6 | * A hardware option for running models on Replicate.
7 | */
8 | export interface Hardware {
9 | /** Human-readable name of the hardware */
10 | name: string;
11 | /** SKU identifier for the hardware */
12 | sku: string;
13 | }
14 |
15 | /**
16 | * Response format for listing hardware options.
17 | */
18 | export interface HardwareList {
19 | /** List of available hardware options */
20 | hardware: Hardware[];
21 | }
22 |
--------------------------------------------------------------------------------
/src/models/model.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * Data models for Replicate models and versions.
3 | */
4 |
5 | /**
6 | * OpenAPI schema types
7 | */
8 | export interface OpenAPISchema {
9 | openapi: string;
10 | info: {
11 | title: string;
12 | version: string;
13 | };
14 | paths: Record;
15 | components?: {
16 | schemas?: Record;
17 | parameters?: Record;
18 | };
19 | }
20 |
21 | /**
22 | * A specific version of a model on Replicate.
23 | */
24 | export interface ModelVersion {
25 | /** Unique identifier for this model version */
26 | id: string;
27 | /** When this version was created */
28 | created_at: string;
29 | /** Version of Cog used to create this model */
30 | cog_version: string;
31 | /** OpenAPI schema for the model */
32 | openapi_schema: OpenAPISchema;
33 | /** Model identifier (owner/name) */
34 | model?: string;
35 | /** Replicate version identifier */
36 | replicate_version?: string;
37 | /** Hardware configuration for this version */
38 | hardware?: string;
39 | }
40 |
41 | /**
42 | * Model information returned from Replicate.
43 | */
44 | export interface Model {
45 | /** Unique identifier in format owner/name */
46 | id: string;
47 | /** Owner of the model (user or organization) */
48 | owner: string;
49 | /** Name of the model */
50 | name: string;
51 | /** Description of the model's purpose and usage */
52 | description?: string;
53 | /** Model visibility (public/private) */
54 | visibility: "public" | "private";
55 | /** URL to model's GitHub repository */
56 | github_url?: string;
57 | /** URL to model's research paper */
58 | paper_url?: string;
59 | /** URL to model's license */
60 | license_url?: string;
61 | /** Number of times this model has been run */
62 | run_count?: number;
63 | /** URL to model's cover image */
64 | cover_image_url?: string;
65 | /** Latest version of the model */
66 | latest_version?: ModelVersion;
67 | /** Default example inputs */
68 | default_example?: Record;
69 | /** Whether this model is featured */
70 | featured?: boolean;
71 | /** Model tags */
72 | tags?: string[];
73 | }
74 |
75 | /**
76 | * Response format for listing models.
77 | */
78 | export interface ModelList {
79 | /** List of models */
80 | models: Model[];
81 | /** Cursor for pagination */
82 | next_cursor?: string;
83 | /** Total number of models */
84 | total_count?: number;
85 | }
86 |
--------------------------------------------------------------------------------
/src/models/openapi.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * OpenAPI schema types.
3 | */
4 |
5 | export interface OpenAPISchema {
6 | openapi: string;
7 | info: {
8 | title: string;
9 | version: string;
10 | };
11 | components?: {
12 | schemas?: Record;
13 | };
14 | }
15 |
16 | export interface SchemaObject {
17 | type?: string;
18 | required?: string[];
19 | properties?: Record;
20 | additionalProperties?: boolean | SchemaObject;
21 | }
22 |
23 | export interface PropertyObject {
24 | type: string;
25 | format?: string;
26 | description?: string;
27 | default?: unknown;
28 | minimum?: number;
29 | maximum?: number;
30 | enum?: string[];
31 | items?: SchemaObject;
32 | }
33 |
34 | export type ModelIO = Record;
35 |
--------------------------------------------------------------------------------
/src/models/prediction.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * Data models for Replicate predictions.
3 | */
4 |
5 | /**
6 | * Status of a prediction.
7 | */
8 | export enum PredictionStatus {
9 | Starting = "starting",
10 | Processing = "processing",
11 | Succeeded = "succeeded",
12 | Failed = "failed",
13 | Canceled = "canceled",
14 | }
15 |
16 | /**
17 | * Model input/output types
18 | */
19 | export type ModelIO = Record;
20 |
21 | /**
22 | * Input parameters for creating a prediction.
23 | */
24 | export interface PredictionInput {
25 | /** Model version to use for prediction */
26 | model_version: string;
27 | /** Model-specific input parameters */
28 | input: ModelIO;
29 | /** Optional template ID to use */
30 | template_id?: string;
31 | /** URL for webhook notifications */
32 | webhook_url?: string;
33 | /** Events to trigger webhooks */
34 | webhook_events?: string[];
35 | /** Whether to wait for prediction completion */
36 | wait?: boolean;
37 | /** Max seconds to wait if wait=True (1-60) */
38 | wait_timeout?: number;
39 | /** Whether to request streaming output */
40 | stream?: boolean;
41 | }
42 |
43 | /**
44 | * A prediction (model run) on Replicate.
45 | */
46 | export interface Prediction {
47 | /** Unique identifier for this prediction */
48 | id: string;
49 | /** Model version used for this prediction */
50 | version: string;
51 | /** Current status of the prediction */
52 | status: PredictionStatus | string;
53 | /** Input parameters used for the prediction */
54 | input: ModelIO;
55 | /** Output from the prediction if completed */
56 | output?: ModelIO;
57 | /** Error message if prediction failed */
58 | error?: string;
59 | /** Execution logs from the prediction */
60 | logs?: string;
61 | /** When the prediction was created */
62 | created_at: string;
63 | /** When the prediction started processing */
64 | started_at?: string;
65 | /** When the prediction completed */
66 | completed_at?: string;
67 | /** Related API URLs for this prediction */
68 | urls: Record;
69 | /** Performance metrics if available */
70 | metrics?: Record;
71 | /** URL for streaming output if requested */
72 | stream_url?: string;
73 | }
74 |
--------------------------------------------------------------------------------
/src/models/webhook.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * Webhook types and utilities.
3 | */
4 |
5 | import crypto from "node:crypto";
6 |
7 | export interface WebhookConfig {
8 | url: string;
9 | secret?: string;
10 | retries?: number;
11 | timeout?: number;
12 | }
13 |
14 | export interface WebhookEvent {
15 | type: WebhookEventType;
16 | timestamp: string;
17 | data: Record;
18 | }
19 |
20 | export type WebhookEventType =
21 | | "prediction.created"
22 | | "prediction.processing"
23 | | "prediction.succeeded"
24 | | "prediction.failed"
25 | | "prediction.canceled";
26 |
27 | export interface WebhookDeliveryResult {
28 | success: boolean;
29 | statusCode?: number;
30 | error?: string;
31 | retryCount: number;
32 | timestamp: string;
33 | }
34 |
35 | /**
36 | * Generate a webhook signature for request verification.
37 | */
38 | export function generateWebhookSignature(
39 | payload: string,
40 | secret: string
41 | ): string {
42 | const hmac = crypto.createHmac("sha256", secret);
43 | hmac.update(payload);
44 | return `sha256=${hmac.digest("hex")}`;
45 | }
46 |
47 | /**
48 | * Verify a webhook signature from request headers.
49 | */
50 | export function verifyWebhookSignature(
51 | payload: string,
52 | signature: string,
53 | secret: string
54 | ): boolean {
55 | const expectedSignature = generateWebhookSignature(payload, secret);
56 | return crypto.timingSafeEqual(
57 | Buffer.from(signature),
58 | Buffer.from(expectedSignature)
59 | );
60 | }
61 |
62 | /**
63 | * Format a webhook event payload.
64 | */
65 | export function formatWebhookEvent(
66 | type: WebhookEventType,
67 | data: Record
68 | ): WebhookEvent {
69 | return {
70 | type,
71 | timestamp: new Date().toISOString(),
72 | data,
73 | };
74 | }
75 |
76 | /**
77 | * Validate webhook configuration.
78 | */
79 | export function validateWebhookConfig(config: WebhookConfig): string[] {
80 | const errors: string[] = [];
81 |
82 | // Validate URL
83 | try {
84 | new URL(config.url);
85 | } catch {
86 | errors.push("Invalid webhook URL");
87 | }
88 |
89 | // Validate secret
90 | if (config.secret && config.secret.length < 32) {
91 | errors.push("Webhook secret should be at least 32 characters long");
92 | }
93 |
94 | // Validate retries
95 | if (config.retries !== undefined) {
96 | if (!Number.isInteger(config.retries) || config.retries < 0) {
97 | errors.push("Retries must be a non-negative integer");
98 | }
99 | }
100 |
101 | // Validate timeout
102 | if (config.timeout !== undefined) {
103 | if (!Number.isInteger(config.timeout) || config.timeout < 1000) {
104 | errors.push("Timeout must be at least 1000ms");
105 | }
106 | }
107 |
108 | return errors;
109 | }
110 |
111 | /**
112 | * Default webhook configuration.
113 | */
114 | export const DEFAULT_WEBHOOK_CONFIG: Partial = {
115 | retries: 3,
116 | timeout: 10000, // 10 seconds
117 | };
118 |
119 | /**
120 | * Retry delay calculator with exponential backoff.
121 | */
122 | export function calculateRetryDelay(attempt: number, baseDelay = 1000): number {
123 | const maxDelay = 60000; // 1 minute
124 | const delay = Math.min(baseDelay * 2 ** attempt, maxDelay);
125 | // Add jitter to prevent thundering herd
126 | return delay + Math.random() * 1000;
127 | }
128 |
--------------------------------------------------------------------------------
/src/replicate_client.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * Replicate API client implementation with caching support.
3 | */
4 |
5 | import Replicate from "replicate";
6 | import type { Model, ModelList, ModelVersion } from "./models/model.js";
7 | import type {
8 | Prediction,
9 | PredictionInput,
10 | PredictionStatus,
11 | ModelIO,
12 | } from "./models/prediction.js";
13 | import type { Collection, CollectionList } from "./models/collection.js";
14 | import {
15 | modelCache,
16 | predictionCache,
17 | collectionCache,
18 | } from "./services/cache.js";
19 | import { ReplicateError, ErrorHandler, createError } from "./services/error.js";
20 |
21 | // Constants
22 | const REPLICATE_API_BASE = "https://api.replicate.com/v1";
23 | const DEFAULT_TIMEOUT = 60000; // 60 seconds in milliseconds
24 | const MAX_RETRIES = 3;
25 | const MIN_RETRY_DELAY = 1000; // 1 second in milliseconds
26 | const MAX_RETRY_DELAY = 10000; // 10 seconds in milliseconds
27 | const DEFAULT_RATE_LIMIT = 100; // requests per minute
28 |
29 | // Type definitions for Replicate client responses
30 | interface ReplicateModel {
31 | owner: string;
32 | name: string;
33 | description?: string;
34 | visibility?: "public" | "private";
35 | github_url?: string;
36 | paper_url?: string;
37 | license_url?: string;
38 | run_count?: number;
39 | cover_image_url?: string;
40 | default_example?: Record;
41 | featured?: boolean;
42 | tags?: string[];
43 | latest_version?: ModelVersion;
44 | }
45 |
46 | interface ReplicatePrediction {
47 | id: string;
48 | version: string;
49 | status: string;
50 | input: Record;
51 | output?: unknown;
52 | error?: unknown;
53 | logs?: string;
54 | created_at: string;
55 | started_at?: string;
56 | completed_at?: string;
57 | urls: Record;
58 | metrics?: Record;
59 | }
60 |
61 | interface ReplicatePage {
62 | results: T[];
63 | next?: string;
64 | previous?: string;
65 | total?: number;
66 | }
67 |
68 | interface CreatePredictionOptions {
69 | version?: string;
70 | model?: string;
71 | input: ModelIO | string;
72 | webhook?: string;
73 | }
74 |
75 | /**
76 | * Client for interacting with the Replicate API.
77 | */
78 | export class ReplicateClient {
79 | private api_token: string;
80 | private rate_limit: number;
81 | private request_times: number[];
82 | private retry_count: number;
83 | private client: Replicate;
84 |
85 | constructor(api_token?: string) {
86 | this.api_token = api_token || process.env.REPLICATE_API_TOKEN || "";
87 | if (!this.api_token) {
88 | throw new Error("Replicate API token is required");
89 | }
90 |
91 | this.rate_limit = DEFAULT_RATE_LIMIT;
92 | this.request_times = [];
93 | this.retry_count = 0;
94 | this.client = new Replicate({ auth: this.api_token });
95 | }
96 |
97 | /**
98 | * Wait if necessary to comply with rate limiting.
99 | */
100 | private async waitForRateLimit(): Promise {
101 | const now = Date.now();
102 |
103 | // Remove request times older than 1 minute
104 | this.request_times = this.request_times.filter((t) => now - t <= 60000);
105 |
106 | if (this.request_times.length >= this.rate_limit) {
107 | // Calculate wait time based on oldest request
108 | const wait_time = 60000 - (now - this.request_times[0]);
109 | if (wait_time > 0) {
110 | console.debug(`Rate limit reached. Waiting ${wait_time}ms`);
111 | await new Promise((resolve) => setTimeout(resolve, wait_time));
112 | }
113 | }
114 |
115 | this.request_times.push(now);
116 | }
117 |
118 | /**
119 | * Handle rate limits and other response headers.
120 | */
121 | private async handleResponse(response: Response): Promise {
122 | // Update rate limit from headers if available
123 | const limit = response.headers.get("X-RateLimit-Limit");
124 | const remaining = response.headers.get("X-RateLimit-Remaining");
125 | const reset = response.headers.get("X-RateLimit-Reset");
126 |
127 | if (limit) {
128 | this.rate_limit = Number.parseInt(limit, 10);
129 | }
130 |
131 | // Handle rate limit exceeded
132 | if (response.status === 429) {
133 | const retryAfter = Number.parseInt(
134 | response.headers.get("Retry-After") || "60",
135 | 10
136 | );
137 | throw createError.rateLimit(retryAfter);
138 | }
139 | }
140 |
141 | /**
142 | * Make an HTTP request with retries and rate limiting.
143 | */
144 | private async makeRequest(
145 | method: string,
146 | endpoint: string,
147 | options: RequestInit = {}
148 | ): Promise {
149 | await this.waitForRateLimit();
150 |
151 | return ErrorHandler.withRetries(
152 | async () => {
153 | const response = await fetch(`${REPLICATE_API_BASE}${endpoint}`, {
154 | method,
155 | headers: {
156 | Authorization: `Token ${this.api_token}`,
157 | "Content-Type": "application/json",
158 | ...options.headers,
159 | },
160 | ...options,
161 | signal: AbortSignal.timeout(DEFAULT_TIMEOUT),
162 | });
163 |
164 | await this.handleResponse(response);
165 |
166 | if (!response.ok) {
167 | throw await ErrorHandler.parseAPIError(response);
168 | }
169 |
170 | return response.json();
171 | },
172 | {
173 | maxAttempts: MAX_RETRIES,
174 | minDelay: MIN_RETRY_DELAY,
175 | maxDelay: MAX_RETRY_DELAY,
176 | onRetry: (error: Error, attempt: number) => {
177 | console.warn(
178 | `Request failed: ${error.message}. `,
179 | `Retrying (attempt ${attempt + 1}/${MAX_RETRIES})`
180 | );
181 | },
182 | }
183 | );
184 | }
185 |
186 | /**
187 | * List available models on Replicate with pagination.
188 | */
189 | async listModels(
190 | options: { owner?: string; cursor?: string } = {}
191 | ): Promise {
192 | try {
193 | // Check cache first
194 | const cacheKey = `models:${options.owner || "all"}:${
195 | options.cursor || ""
196 | }`;
197 | const cached = modelCache.get(cacheKey);
198 | if (cached) {
199 | return cached;
200 | }
201 |
202 | if (options.owner) {
203 | // If owner is specified, use search to find their models
204 | const response = (await this.client.models.search(
205 | `owner:${options.owner}`
206 | )) as unknown as ReplicatePage;
207 |
208 | const result: ModelList = {
209 | models: response.results.map((model) => ({
210 | id: `${model.owner}/${model.name}`,
211 | owner: model.owner,
212 | name: model.name,
213 | description: model.description || "",
214 | visibility: model.visibility || "public",
215 | github_url: model.github_url,
216 | paper_url: model.paper_url,
217 | license_url: model.license_url,
218 | run_count: model.run_count,
219 | cover_image_url: model.cover_image_url,
220 | default_example: model.default_example,
221 | featured: model.featured || false,
222 | tags: model.tags || [],
223 | latest_version: model.latest_version
224 | ? {
225 | id: model.latest_version.id,
226 | created_at: model.latest_version.created_at,
227 | cog_version: model.latest_version.cog_version,
228 | openapi_schema: {
229 | ...model.latest_version.openapi_schema,
230 | openapi: "3.0.0",
231 | info: {
232 | title: `${model.owner}/${model.name}`,
233 | version: model.latest_version.id,
234 | },
235 | paths: {},
236 | },
237 | }
238 | : undefined,
239 | })),
240 | next_cursor: response.next,
241 | total_count: response.total || response.results.length,
242 | };
243 |
244 | // Cache the result
245 | modelCache.set(cacheKey, result);
246 | return result;
247 | }
248 |
249 | // Otherwise list all models
250 | const params = new URLSearchParams();
251 | if (options.cursor) {
252 | params.set("cursor", options.cursor);
253 | }
254 |
255 | const response = await this.makeRequest>(
256 | "GET",
257 | `/models${params.toString() ? `?${params.toString()}` : ""}`
258 | );
259 |
260 | const result: ModelList = {
261 | models: response.results.map((model) => ({
262 | id: `${model.owner}/${model.name}`,
263 | owner: model.owner,
264 | name: model.name,
265 | description: model.description || "",
266 | visibility: model.visibility || "public",
267 | github_url: model.github_url,
268 | paper_url: model.paper_url,
269 | license_url: model.license_url,
270 | run_count: model.run_count,
271 | cover_image_url: model.cover_image_url,
272 | default_example: model.default_example,
273 | featured: model.featured || false,
274 | tags: model.tags || [],
275 | latest_version: model.latest_version
276 | ? {
277 | id: model.latest_version.id,
278 | created_at: model.latest_version.created_at,
279 | cog_version: model.latest_version.cog_version,
280 | openapi_schema: {
281 | ...model.latest_version.openapi_schema,
282 | openapi: "3.0.0",
283 | info: {
284 | title: `${model.owner}/${model.name}`,
285 | version: model.latest_version.id,
286 | },
287 | paths: {},
288 | },
289 | }
290 | : undefined,
291 | })),
292 | next_cursor: response.next,
293 | total_count: response.total || response.results.length,
294 | };
295 |
296 | // Cache the result
297 | modelCache.set(cacheKey, result);
298 | return result;
299 | } catch (error) {
300 | throw ErrorHandler.parseAPIError(error as Response);
301 | }
302 | }
303 |
304 | /**
305 | * Search for models using semantic search.
306 | */
307 | async searchModels(query: string): Promise {
308 | try {
309 | // Check cache first
310 | const cacheKey = `search:${query}`;
311 | const cached = modelCache.get(cacheKey);
312 | if (cached) {
313 | return cached;
314 | }
315 |
316 | // Use the official client for search
317 | const response = (await this.client.models.search(
318 | query
319 | )) as unknown as ReplicatePage;
320 |
321 | const result: ModelList = {
322 | models: response.results.map((model) => ({
323 | id: `${model.owner}/${model.name}`,
324 | owner: model.owner,
325 | name: model.name,
326 | description: model.description || "",
327 | visibility: model.visibility || "public",
328 | github_url: model.github_url,
329 | paper_url: model.paper_url,
330 | license_url: model.license_url,
331 | run_count: model.run_count,
332 | cover_image_url: model.cover_image_url,
333 | default_example: model.default_example,
334 | featured: model.featured || false,
335 | tags: model.tags || [],
336 | latest_version: model.latest_version
337 | ? {
338 | id: model.latest_version.id,
339 | created_at: model.latest_version.created_at,
340 | cog_version: model.latest_version.cog_version,
341 | openapi_schema: {
342 | ...model.latest_version.openapi_schema,
343 | openapi: "3.0.0",
344 | info: {
345 | title: `${model.owner}/${model.name}`,
346 | version: model.latest_version.id,
347 | },
348 | paths: {},
349 | },
350 | }
351 | : undefined,
352 | })),
353 | next_cursor: response.next,
354 | total_count: response.results.length,
355 | };
356 |
357 | // Cache the result
358 | modelCache.set(cacheKey, result);
359 | return result;
360 | } catch (error) {
361 | throw ErrorHandler.parseAPIError(error as Response);
362 | }
363 | }
364 |
365 | /**
366 | * List available collections.
367 | */
368 | async listCollections(
369 | options: {
370 | cursor?: string;
371 | } = {}
372 | ): Promise {
373 | try {
374 | // Check cache first
375 | const cacheKey = `collections:${options.cursor || ""}`;
376 | const cached = collectionCache.get(cacheKey);
377 | if (cached) {
378 | return cached;
379 | }
380 |
381 | // Use the official client for collections
382 | const response = await this.client.collections.list();
383 |
384 | const result: CollectionList = {
385 | collections: response.results.map((collection) => ({
386 | id: collection.slug,
387 | name: collection.name,
388 | slug: collection.slug,
389 | description: collection.description || "",
390 | models:
391 | collection.models?.map((model) => ({
392 | id: `${model.owner}/${model.name}`,
393 | owner: model.owner,
394 | name: model.name,
395 | description: model.description || "",
396 | visibility: model.visibility || "public",
397 | github_url: model.github_url,
398 | paper_url: model.paper_url,
399 | license_url: model.license_url,
400 | run_count: model.run_count,
401 | cover_image_url: model.cover_image_url,
402 | default_example: model.default_example
403 | ? ({
404 | input: model.default_example.input,
405 | output: model.default_example.output,
406 | error: model.default_example.error,
407 | status: model.default_example.status,
408 | logs: model.default_example.logs,
409 | metrics: model.default_example.metrics,
410 | } as Record)
411 | : undefined,
412 | featured: false,
413 | tags: [],
414 | latest_version: model.latest_version
415 | ? {
416 | id: model.latest_version.id,
417 | created_at: model.latest_version.created_at,
418 | cog_version: model.latest_version.cog_version,
419 | openapi_schema: {
420 | ...model.latest_version.openapi_schema,
421 | openapi: "3.0.0",
422 | info: {
423 | title: `${model.owner}/${model.name}`,
424 | version: model.latest_version.id,
425 | },
426 | paths: {},
427 | },
428 | }
429 | : undefined,
430 | })) || [],
431 | featured: false,
432 | created_at: new Date().toISOString(),
433 | updated_at: undefined,
434 | })),
435 | next_cursor: response.next,
436 | total_count: response.results.length,
437 | };
438 |
439 | // Cache the result
440 | collectionCache.set(cacheKey, result);
441 | return result;
442 | } catch (error) {
443 | throw ErrorHandler.parseAPIError(error as Response);
444 | }
445 | }
446 |
447 | /**
448 | * Get a specific collection by slug.
449 | */
450 | async getCollection(slug: string): Promise {
451 | try {
452 | // Check cache first
453 | const cacheKey = `collection:${slug}`;
454 | const cached = collectionCache.get(cacheKey);
455 | if (cached?.collections?.length === 1) {
456 | return cached.collections[0];
457 | }
458 |
459 | interface CollectionResponse {
460 | id: string;
461 | name: string;
462 | slug: string;
463 | description?: string;
464 | models: ReplicateModel[];
465 | featured?: boolean;
466 | created_at: string;
467 | updated_at?: string;
468 | }
469 |
470 | const response = await this.makeRequest(
471 | "GET",
472 | `/collections/${slug}`
473 | );
474 |
475 | const collection: Collection = {
476 | id: response.id,
477 | name: response.name,
478 | slug: response.slug,
479 | description: response.description || "",
480 | models: response.models.map((model) => ({
481 | id: `${model.owner}/${model.name}`,
482 | owner: model.owner,
483 | name: model.name,
484 | description: model.description || "",
485 | visibility: model.visibility || "public",
486 | github_url: model.github_url,
487 | paper_url: model.paper_url,
488 | license_url: model.license_url,
489 | run_count: model.run_count,
490 | cover_image_url: model.cover_image_url,
491 | default_example: model.default_example,
492 | featured: model.featured || false,
493 | tags: model.tags || [],
494 | latest_version: model.latest_version,
495 | })),
496 | featured: response.featured || false,
497 | created_at: response.created_at,
498 | updated_at: response.updated_at,
499 | };
500 |
501 | // Cache the result as a single-item collection list
502 | collectionCache.set(cacheKey, {
503 | collections: [collection],
504 | total_count: 1,
505 | });
506 | return collection;
507 | } catch (error) {
508 | throw ErrorHandler.parseAPIError(error as Response);
509 | }
510 | }
511 |
512 | /**
513 | * Create a new prediction.
514 | */
515 | async createPrediction(
516 | options: CreatePredictionOptions
517 | ): Promise {
518 | try {
519 | // If input is a string, wrap it in an object with 'prompt' property
520 | const input =
521 | typeof options.input === "string"
522 | ? { prompt: options.input }
523 | : options.input;
524 |
525 | // Create prediction parameters with the correct type
526 | const predictionParams = options.version
527 | ? {
528 | version: options.version,
529 | input,
530 | webhook: options.webhook,
531 | }
532 | : {
533 | model: options.model!,
534 | input,
535 | webhook: options.webhook,
536 | };
537 |
538 | if (!options.version && !options.model) {
539 | throw new Error("Either model or version must be provided");
540 | }
541 |
542 | // Use the official client for predictions
543 | const prediction = (await this.client.predictions.create(
544 | predictionParams
545 | )) as unknown as ReplicatePrediction;
546 |
547 | const result = {
548 | id: prediction.id,
549 | version: prediction.version,
550 | status: prediction.status as PredictionStatus,
551 | input: prediction.input as ModelIO,
552 | output: prediction.output as ModelIO | undefined,
553 | error: prediction.error ? String(prediction.error) : undefined,
554 | logs: prediction.logs,
555 | created_at: prediction.created_at,
556 | started_at: prediction.started_at,
557 | completed_at: prediction.completed_at,
558 | urls: prediction.urls,
559 | metrics: prediction.metrics,
560 | };
561 |
562 | // Cache the result
563 | predictionCache.set(`prediction:${prediction.id}`, [result]);
564 | return result;
565 | } catch (error) {
566 | throw ErrorHandler.parseAPIError(error as Response);
567 | }
568 | }
569 |
570 | /**
571 | * Get the status of a prediction.
572 | */
573 | async getPredictionStatus(prediction_id: string): Promise {
574 | try {
575 | // Check cache first
576 | const cacheKey = `prediction:${prediction_id}`;
577 | const cached = predictionCache.get(cacheKey);
578 | // Only use cache for completed predictions
579 | if (
580 | cached?.length === 1 &&
581 | ["succeeded", "failed", "canceled"].includes(cached[0].status)
582 | ) {
583 | return cached[0];
584 | }
585 |
586 | // Use the official client for predictions
587 | const prediction = (await this.client.predictions.get(
588 | prediction_id
589 | )) as unknown as ReplicatePrediction;
590 |
591 | const result = {
592 | id: prediction.id,
593 | version: prediction.version,
594 | status: prediction.status as PredictionStatus,
595 | input: prediction.input as ModelIO,
596 | output: prediction.output as ModelIO | undefined,
597 | error: prediction.error ? String(prediction.error) : undefined,
598 | logs: prediction.logs,
599 | created_at: prediction.created_at,
600 | started_at: prediction.started_at,
601 | completed_at: prediction.completed_at,
602 | urls: prediction.urls,
603 | metrics: prediction.metrics,
604 | };
605 |
606 | // Cache completed predictions
607 | if (["succeeded", "failed", "canceled"].includes(result.status)) {
608 | predictionCache.set(cacheKey, [result]);
609 | }
610 |
611 | return result;
612 | } catch (error) {
613 | throw ErrorHandler.parseAPIError(error as Response);
614 | }
615 | }
616 |
617 | /**
618 | * Cancel a running prediction.
619 | */
620 | async cancelPrediction(prediction_id: string): Promise {
621 | try {
622 | const response = await this.makeRequest(
623 | "POST",
624 | `/predictions/${prediction_id}/cancel`
625 | );
626 |
627 | const result = {
628 | id: response.id,
629 | version: response.version,
630 | status: response.status as PredictionStatus,
631 | input: response.input as ModelIO,
632 | output: response.output as ModelIO | undefined,
633 | error: response.error ? String(response.error) : undefined,
634 | logs: response.logs,
635 | created_at: response.created_at,
636 | started_at: response.started_at,
637 | completed_at: response.completed_at,
638 | urls: response.urls,
639 | metrics: response.metrics,
640 | };
641 |
642 | // Update cache
643 | predictionCache.set(`prediction:${prediction_id}`, [result]);
644 | return result;
645 | } catch (error) {
646 | throw ErrorHandler.parseAPIError(error as Response);
647 | }
648 | }
649 |
650 | /**
651 | * List predictions with optional filtering.
652 | */
653 | async listPredictions(
654 | options: {
655 | status?: PredictionStatus;
656 | limit?: number;
657 | cursor?: string;
658 | } = {}
659 | ): Promise {
660 | try {
661 | // Check cache first
662 | const cacheKey = `predictions:${options.status || "all"}:${
663 | options.limit || "all"
664 | }:${options.cursor || ""}`;
665 | const cached = predictionCache.get(cacheKey);
666 | if (cached) {
667 | return cached;
668 | }
669 |
670 | // Use the official client for predictions
671 | const response =
672 | (await this.client.predictions.list()) as unknown as ReplicatePage;
673 |
674 | // Filter and limit results
675 | const filteredPredictions = options.status
676 | ? response.results.filter((p) => p.status === options.status)
677 | : response.results;
678 |
679 | const limitedPredictions = options.limit
680 | ? filteredPredictions.slice(0, options.limit)
681 | : filteredPredictions;
682 |
683 | const result = limitedPredictions.map((prediction) => ({
684 | id: prediction.id,
685 | version: prediction.version,
686 | status: prediction.status as PredictionStatus,
687 | input: prediction.input as ModelIO,
688 | output: prediction.output as ModelIO | undefined,
689 | error: prediction.error ? String(prediction.error) : undefined,
690 | logs: prediction.logs,
691 | created_at: prediction.created_at,
692 | started_at: prediction.started_at,
693 | completed_at: prediction.completed_at,
694 | urls: prediction.urls,
695 | metrics: prediction.metrics,
696 | }));
697 |
698 | // Cache the result
699 | predictionCache.set(cacheKey, result);
700 | return result;
701 | } catch (error) {
702 | throw ErrorHandler.parseAPIError(error as Response);
703 | }
704 | }
705 |
706 | /**
707 | * Get details of a specific model including versions.
708 | */
709 | async getModel(owner: string, name: string): Promise {
710 | try {
711 | // Check cache first
712 | const cacheKey = `model:${owner}/${name}`;
713 | const cached = modelCache.get(cacheKey);
714 | if (cached?.models?.length === 1) {
715 | return cached.models[0];
716 | }
717 |
718 | // Use direct API request to get model details
719 | const response = await this.makeRequest(
720 | "GET",
721 | `/models/${owner}/${name}`
722 | ).catch((error) => {
723 | throw ErrorHandler.parseAPIError(error);
724 | });
725 |
726 | // Get model versions
727 | const versionsResponse = await this.makeRequest<
728 | ReplicatePage
729 | >("GET", `/models/${owner}/${name}/versions`).catch((error) => {
730 | throw ErrorHandler.parseAPIError(error);
731 | });
732 |
733 | const model: Model = {
734 | id: `${response.owner}/${response.name}`,
735 | owner: response.owner,
736 | name: response.name,
737 | description: response.description || "",
738 | visibility: response.visibility || "public",
739 | github_url: response.github_url,
740 | paper_url: response.paper_url,
741 | license_url: response.license_url,
742 | run_count: response.run_count,
743 | cover_image_url: response.cover_image_url,
744 | default_example: response.default_example,
745 | latest_version: versionsResponse.results[0]
746 | ? {
747 | id: versionsResponse.results[0].id,
748 | created_at: versionsResponse.results[0].created_at,
749 | cog_version: versionsResponse.results[0].cog_version,
750 | openapi_schema: {
751 | ...versionsResponse.results[0].openapi_schema,
752 | openapi: "3.0.0",
753 | info: {
754 | title: `${response.owner}/${response.name}`,
755 | version: versionsResponse.results[0].id,
756 | },
757 | paths: {},
758 | },
759 | }
760 | : undefined,
761 | };
762 |
763 | // Cache the result as a single-item model list
764 | modelCache.set(cacheKey, {
765 | models: [model],
766 | total_count: 1,
767 | });
768 | return model;
769 | } catch (error) {
770 | if (error instanceof Promise) {
771 | throw new ReplicateError("Failed to fetch model details");
772 | }
773 | throw ErrorHandler.parseAPIError(error as Response);
774 | }
775 | }
776 |
777 | /**
778 | * Get the webhook signing secret.
779 | */
780 | async getWebhookSecret(): Promise {
781 | interface WebhookResponse {
782 | key: string;
783 | }
784 |
785 | const response = await this.makeRequest(
786 | "GET",
787 | "/webhooks/default/secret"
788 | );
789 |
790 | return response.key;
791 | }
792 | }
793 |
--------------------------------------------------------------------------------
/src/services/cache.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * Cache service implementation with TTL and LRU strategy.
3 | */
4 |
5 | import type { Collection, CollectionList } from "../models/collection.js";
6 | import type { Model, ModelList } from "../models/model.js";
7 | import type { Prediction } from "../models/prediction.js";
8 |
9 | interface CacheEntry {
10 | value: T;
11 | timestamp: number;
12 | lastAccessed: number;
13 | }
14 |
15 | export interface CacheStats {
16 | hits: number;
17 | misses: number;
18 | evictions: number;
19 | size: number;
20 | }
21 |
22 | export class Cache {
23 | private cache: Map>;
24 | private maxSize: number;
25 | private ttl: number;
26 | private stats: CacheStats;
27 |
28 | constructor(maxSize = 1000, ttlSeconds = 300) {
29 | this.cache = new Map();
30 | this.maxSize = maxSize;
31 | this.ttl = ttlSeconds * 1000; // Convert to milliseconds
32 | this.stats = {
33 | hits: 0,
34 | misses: 0,
35 | evictions: 0,
36 | size: 0,
37 | };
38 | }
39 |
40 | /**
41 | * Set a value in the cache with TTL.
42 | */
43 | set(key: string, value: T): void {
44 | // Evict oldest entries if cache is full
45 | if (this.cache.size >= this.maxSize) {
46 | this.evictOldest();
47 | }
48 |
49 | this.cache.set(key, {
50 | value,
51 | timestamp: Date.now(),
52 | lastAccessed: Date.now(),
53 | });
54 | this.stats.size = this.cache.size;
55 | }
56 |
57 | /**
58 | * Get a value from the cache, considering TTL.
59 | */
60 | get(key: string): T | null {
61 | const entry = this.cache.get(key);
62 |
63 | if (!entry) {
64 | this.stats.misses++;
65 | return null;
66 | }
67 |
68 | // Check if entry has expired
69 | if (Date.now() - entry.timestamp > this.ttl) {
70 | this.cache.delete(key);
71 | this.stats.evictions++;
72 | this.stats.size = this.cache.size;
73 | this.stats.misses++;
74 | return null;
75 | }
76 |
77 | // Update last accessed time for LRU
78 | entry.lastAccessed = Date.now();
79 | this.stats.hits++;
80 | return entry.value;
81 | }
82 |
83 | /**
84 | * Remove a specific key from the cache.
85 | */
86 | delete(key: string): void {
87 | if (this.cache.delete(key)) {
88 | this.stats.evictions++;
89 | this.stats.size = this.cache.size;
90 | }
91 | }
92 |
93 | /**
94 | * Clear all entries from the cache.
95 | */
96 | clear(): void {
97 | this.cache.clear();
98 | this.stats.evictions += this.stats.size;
99 | this.stats.size = 0;
100 | }
101 |
102 | /**
103 | * Get cache statistics.
104 | */
105 | getStats(): CacheStats {
106 | return { ...this.stats };
107 | }
108 |
109 | /**
110 | * Warm up the cache with initial data.
111 | */
112 | warmup(entries: [string, T][]): void {
113 | for (const [key, value] of entries) {
114 | this.set(key, value);
115 | }
116 | }
117 |
118 | /**
119 | * Remove expired entries from the cache.
120 | */
121 | cleanup(): number {
122 | const now = Date.now();
123 | let removed = 0;
124 |
125 | for (const [key, entry] of this.cache.entries()) {
126 | if (now - entry.timestamp > this.ttl) {
127 | this.cache.delete(key);
128 | removed++;
129 | }
130 | }
131 |
132 | this.stats.evictions += removed;
133 | this.stats.size = this.cache.size;
134 | return removed;
135 | }
136 |
137 | /**
138 | * Get all valid (non-expired) keys in the cache.
139 | */
140 | keys(): string[] {
141 | const now = Date.now();
142 | return Array.from(this.cache.entries())
143 | .filter(([_, entry]) => now - entry.timestamp <= this.ttl)
144 | .map(([key]) => key);
145 | }
146 |
147 | /**
148 | * Check if a key exists and is not expired.
149 | */
150 | has(key: string): boolean {
151 | const entry = this.cache.get(key);
152 | if (!entry) return false;
153 |
154 | const expired = Date.now() - entry.timestamp > this.ttl;
155 | if (expired) {
156 | this.cache.delete(key);
157 | this.stats.evictions++;
158 | this.stats.size = this.cache.size;
159 | return false;
160 | }
161 |
162 | return true;
163 | }
164 |
165 | private evictOldest(): void {
166 | let oldestKey: string | null = null;
167 | let oldestAccess = Number.POSITIVE_INFINITY;
168 |
169 | // Find the least recently used entry
170 | for (const [key, entry] of this.cache.entries()) {
171 | if (entry.lastAccessed < oldestAccess) {
172 | oldestAccess = entry.lastAccessed;
173 | oldestKey = key;
174 | }
175 | }
176 |
177 | // Remove the oldest entry
178 | if (oldestKey) {
179 | this.cache.delete(oldestKey);
180 | this.stats.evictions++;
181 | this.stats.size = this.cache.size;
182 | }
183 | }
184 | }
185 |
186 | // Create specialized cache instances for different types
187 | export const modelCache = new Cache(500, 3600); // 1 hour TTL for models
188 | export const predictionCache = new Cache(1000, 60); // 1 minute TTL for predictions
189 | export const collectionCache = new Cache(100, 3600); // 1 hour TTL for collections
190 |
--------------------------------------------------------------------------------
/src/services/error.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * Simple error handling system for Replicate API.
3 | */
4 |
5 | interface APIErrorResponse {
6 | error?: string;
7 | code?: string;
8 | details?: Record;
9 | }
10 |
11 | /**
12 | * Base class for all Replicate API errors.
13 | */
14 | export class ReplicateError extends Error {
15 | constructor(message: string, public context?: Record) {
16 | super(message);
17 | this.name = "ReplicateError";
18 | }
19 | }
20 |
// Factory helpers producing ReplicateError instances with a fixed message
// and a structured context payload. The message strings are matched elsewhere
// (e.g. ErrorHandler.isRetryable checks for "Rate limit exceeded"), so they
// must not be changed casually.
// NOTE(review): notFound() says "Model not found" even for non-model
// resources — looks copy-pasted; confirm before relying on the message.
export const createError = {
  rateLimit: (retryAfter: number) =>
    new ReplicateError("Rate limit exceeded", { retryAfter }),

  authentication: (details?: string) =>
    new ReplicateError("Authentication failed", { details }),

  notFound: (resource: string) =>
    new ReplicateError("Model not found", { resource }),

  api: (status: number, message: string) =>
    new ReplicateError("API error", { status, message }),

  prediction: (id: string, message: string) =>
    new ReplicateError("Prediction failed", { predictionId: id, message }),

  validation: (field: string, message: string) =>
    new ReplicateError("Invalid input parameters", { field, message }),

  timeout: (operation: string, ms: number) =>
    new ReplicateError("Operation timed out", { operation, timeoutMs: ms }),
};
44 |
/** Options controlling ErrorHandler.withRetries backoff behavior. */
interface RetryOptions {
  maxAttempts?: number; // total attempts including the first (default 3)
  minDelay?: number; // initial backoff delay in ms (default 1000)
  maxDelay?: number; // ceiling on the backoff delay in ms (default 30000)
  backoffFactor?: number; // exponential growth factor (default 2)
  retryIf?: (error: Error) => boolean; // predicate deciding whether to retry
  onRetry?: (error: Error, attempt: number) => void; // hook called before each retry sleep
}
53 |
54 | /**
55 | * Error handling utilities.
56 | */
57 | export const ErrorHandler = {
58 | /**
59 | * Check if an error should trigger a retry.
60 | */
61 | isRetryable(error: Error): boolean {
62 | if (!(error instanceof ReplicateError)) return false;
63 |
64 | const retryableMessages = [
65 | "Rate limit exceeded",
66 | "Internal server error",
67 | "Gateway timeout",
68 | "Service unavailable",
69 | ];
70 |
71 | return retryableMessages.some((msg) => error.message.includes(msg));
72 | },
73 |
74 | /**
75 | * Calculate delay for exponential backoff.
76 | */
77 | getBackoffDelay(
78 | attempt: number,
79 | {
80 | minDelay = 1000,
81 | maxDelay = 30000,
82 | backoffFactor = 2,
83 | }: Partial = {}
84 | ): number {
85 | const delay = minDelay * backoffFactor ** attempt;
86 | return Math.min(delay, maxDelay);
87 | },
88 |
89 | /**
90 | * Execute an operation with automatic retries.
91 | */
92 | async withRetries(
93 | operation: () => Promise,
94 | optionsOrMaxAttempts: RetryOptions | number = {}
95 | ): Promise {
96 | const options: RetryOptions =
97 | typeof optionsOrMaxAttempts === "number"
98 | ? { maxAttempts: optionsOrMaxAttempts }
99 | : optionsOrMaxAttempts;
100 |
101 | const {
102 | maxAttempts = 3,
103 | minDelay = 1000,
104 | maxDelay = 30000,
105 | backoffFactor = 2,
106 | retryIf = this.isRetryable,
107 | onRetry,
108 | } = options;
109 |
110 | let lastError: Error | undefined;
111 |
112 | for (let attempt = 0; attempt < maxAttempts; attempt++) {
113 | try {
114 | return await operation();
115 | } catch (error) {
116 | lastError = error instanceof Error ? error : new Error(String(error));
117 |
118 | if (attempt === maxAttempts - 1 || !retryIf(lastError)) {
119 | throw lastError;
120 | }
121 |
122 | const delay = this.getBackoffDelay(attempt, {
123 | minDelay,
124 | maxDelay,
125 | backoffFactor,
126 | });
127 |
128 | if (onRetry) {
129 | onRetry(lastError, attempt);
130 | }
131 |
132 | await new Promise((resolve) => setTimeout(resolve, delay));
133 | }
134 | }
135 |
136 | throw lastError || new Error("Operation failed");
137 | },
138 |
139 | /**
140 | * Parse error from API response.
141 | */
142 | async parseAPIError(response: Response): Promise {
143 | const status = response.status;
144 | let message = response.statusText;
145 | let context: Record = { status };
146 |
147 | try {
148 | const contentType = response.headers.get("content-type");
149 | if (contentType?.includes("application/json")) {
150 | const details = (await response.json()) as APIErrorResponse;
151 | message = details.error || message;
152 | context = { ...context, ...details };
153 | }
154 | } catch {
155 | // Ignore JSON parsing errors
156 | }
157 |
158 | return new ReplicateError(message, context);
159 | },
160 |
161 | /**
162 | * Create a detailed error report.
163 | */
164 | createErrorReport(error: Error): {
165 | name: string;
166 | message: string;
167 | context?: Record;
168 | timestamp: string;
169 | } {
170 | if (error instanceof ReplicateError) {
171 | return {
172 | name: error.name,
173 | message: error.message,
174 | context: error.context,
175 | timestamp: new Date().toISOString(),
176 | };
177 | }
178 |
179 | return {
180 | name: error.name || "Error",
181 | message: error.message,
182 | timestamp: new Date().toISOString(),
183 | };
184 | },
185 | };
186 |
--------------------------------------------------------------------------------
/src/services/image_viewer.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * Image viewer service for handling system image display and caching.
3 | */
4 |
5 | import { Cache, type CacheStats } from "./cache.js";
6 |
// MCP-specific error codes carried by McpError. String enum so serialized
// values are stable and human-readable.
export enum ErrorCode {
  InvalidRequest = "INVALID_REQUEST",
  InternalError = "INTERNAL_ERROR",
  UnsupportedFormat = "UNSUPPORTED_FORMAT",
}
13 |
14 | export class McpError extends Error {
15 | constructor(public code: ErrorCode, message: string) {
16 | super(message);
17 | this.name = "McpError";
18 | }
19 | }
20 |
21 | // Supported image formats and their MIME types
22 | export const IMAGE_MIME_TYPES = {
23 | "image/jpeg": [".jpg", ".jpeg"],
24 | "image/png": [".png"],
25 | "image/gif": [".gif"],
26 | "image/webp": [".webp"],
27 | } as const;
28 |
29 | export type ImageFormat = keyof typeof IMAGE_MIME_TYPES;
30 |
31 | interface ImageMetadata {
32 | format: ImageFormat;
33 | url: string;
34 | localPath?: string;
35 | width?: number;
36 | height?: number;
37 | }
38 |
39 | // Create a specialized cache for images with 1 hour TTL
40 | export const imageCache = new Cache(200, 3600);
41 |
42 | export class ImageViewer {
43 | private static instance: ImageViewer;
44 |
45 | private constructor() {
46 | // Private constructor for singleton pattern
47 | }
48 |
49 | /**
50 | * Get singleton instance of ImageViewer
51 | */
52 | static getInstance(): ImageViewer {
53 | if (!ImageViewer.instance) {
54 | ImageViewer.instance = new ImageViewer();
55 | }
56 | return ImageViewer.instance;
57 | }
58 |
59 | /**
60 | * Display an image in the system's default web browser
61 | */
62 | async displayImage(url: string): Promise {
63 | try {
64 | // Check cache first
65 | const cached = imageCache.get(url);
66 | if (cached) {
67 | await this.openInBrowser(cached.localPath || cached.url);
68 | return;
69 | }
70 |
71 | // If not cached, fetch and cache the image metadata
72 | const metadata = await this.fetchImageMetadata(url);
73 | imageCache.set(url, metadata);
74 |
75 | // Display the image
76 | await this.openInBrowser(metadata.localPath || url);
77 | } catch (error: unknown) {
78 | throw new McpError(
79 | ErrorCode.InternalError,
80 | `Failed to display image: ${
81 | error instanceof Error ? error.message : String(error)
82 | }`
83 | );
84 | }
85 | }
86 |
87 | /**
88 | * Fetch metadata for an image URL
89 | */
90 | private async fetchImageMetadata(url: string): Promise {
91 | // For now, just return basic metadata without strict MIME type checking
92 | return {
93 | format: "image/png", // Default format
94 | url: url,
95 | };
96 | }
97 |
98 | /**
99 | * Open an image URL in the system's default web browser
100 | */
101 | private async openInBrowser(url: string): Promise {
102 | try {
103 | // Create a simple HTML page to display the image
104 | const html = `
105 |
106 |
107 |
108 | Image Viewer
109 |
125 |
126 |
127 |
128 |
129 | `;
130 |
131 | // Create a temporary file to host the HTML
132 | const tempPath = `/tmp/mcp-image-viewer-${Date.now()}.html`;
133 | const fs = await import("node:fs/promises");
134 | await fs.writeFile(tempPath, html);
135 |
136 | // Launch browser with the HTML file
137 | const fileUrl = `file://${tempPath}`;
138 |
139 | // Use the system's browser_action tool
140 | const { exec } = await import("node:child_process");
141 | const { promisify } = await import("node:util");
142 | const execAsync = promisify(exec);
143 |
144 | // Open the file in the default browser
145 | if (process.platform === "darwin") {
146 | await execAsync(`open "${fileUrl}"`);
147 | } else if (process.platform === "win32") {
148 | await execAsync(`start "" "${fileUrl}"`);
149 | } else {
150 | await execAsync(`xdg-open "${fileUrl}"`);
151 | }
152 |
153 | // Clean up the temporary file after a delay to ensure browser has loaded it
154 | setTimeout(async () => {
155 | try {
156 | await fs.unlink(tempPath);
157 | } catch (error) {
158 | console.error("Failed to clean up temporary file:", error);
159 | }
160 | }, 5000);
161 | } catch (error: unknown) {
162 | throw new McpError(
163 | ErrorCode.InternalError,
164 | `Failed to open browser: ${
165 | error instanceof Error ? error.message : String(error)
166 | }`
167 | );
168 | }
169 | }
170 |
171 | /**
172 | * Clear the image cache
173 | */
174 | clearCache(): void {
175 | imageCache.clear();
176 | }
177 |
178 | /**
179 | * Get statistics about the image cache
180 | */
181 | getCacheStats(): CacheStats {
182 | return imageCache.getStats();
183 | }
184 | }
185 |
--------------------------------------------------------------------------------
/src/services/webhook.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * Webhook delivery service.
3 | */
4 |
5 | import {
6 | type WebhookConfig,
7 | type WebhookEvent,
8 | type WebhookDeliveryResult,
9 | DEFAULT_WEBHOOK_CONFIG,
10 | generateWebhookSignature,
11 | calculateRetryDelay,
12 | validateWebhookConfig,
13 | } from "../models/webhook.js";
14 |
// A webhook pending delivery, tracked in WebhookService's in-memory queue.
interface QueuedWebhook {
  id: string; // crypto.randomUUID() assigned when queued
  config: WebhookConfig;
  event: WebhookEvent;
  retryCount: number; // delivery attempts made so far (starts at 0)
  lastAttempt?: Date; // time of the most recent delivery attempt
  nextAttempt?: Date; // earliest time the next attempt may run
}
23 |
24 | /**
25 | * Service for managing webhook deliveries with retry logic.
26 | */
27 | export class WebhookService {
28 | private queue: Map;
29 | private processing: boolean;
30 | private deliveryResults: Map;
31 |
32 | constructor() {
33 | this.queue = new Map();
34 | this.processing = false;
35 | this.deliveryResults = new Map();
36 | }
37 |
  /**
   * Validate webhook configuration.
   *
   * Delegates to the shared validator from models/webhook.js and returns
   * its list of validation messages.
   */
  validateWebhookConfig(config: WebhookConfig): string[] {
    return validateWebhookConfig(config);
  }
44 |
45 | /**
46 | * Queue a webhook event for delivery.
47 | */
48 | async queueWebhook(
49 | config: Partial,
50 | event: WebhookEvent
51 | ): Promise {
52 | const id = crypto.randomUUID();
53 | const fullConfig = {
54 | ...DEFAULT_WEBHOOK_CONFIG,
55 | ...config,
56 | } as WebhookConfig;
57 |
58 | this.queue.set(id, {
59 | id,
60 | config: fullConfig,
61 | event,
62 | retryCount: 0,
63 | });
64 |
65 | // Start processing if not already running
66 | if (!this.processing) {
67 | this.processQueue().catch(console.error);
68 | }
69 |
70 | return id;
71 | }
72 |
73 | /**
74 | * Get delivery results for a webhook.
75 | */
76 | getDeliveryResults(webhookId: string): WebhookDeliveryResult[] {
77 | return this.deliveryResults.get(webhookId) || [];
78 | }
79 |
80 | /**
81 | * Process the webhook queue.
82 | */
83 | private async processQueue(): Promise {
84 | if (this.processing) return;
85 |
86 | this.processing = true;
87 | try {
88 | // Process all webhooks immediately in test environment
89 | if (process.env.NODE_ENV === "test") {
90 | const webhooks = Array.from(this.queue.values());
91 | await Promise.all(
92 | webhooks.map((webhook) => this.deliverWebhook(webhook))
93 | );
94 | return;
95 | }
96 |
97 | // Normal processing for non-test environment
98 | while (this.queue.size > 0) {
99 | const now = new Date();
100 | const readyWebhooks = Array.from(this.queue.values()).filter(
101 | (webhook) => !webhook.nextAttempt || webhook.nextAttempt <= now
102 | );
103 |
104 | if (readyWebhooks.length === 0) {
105 | // No webhooks ready for delivery, wait for the next one
106 | const nextAttempt = Math.min(
107 | ...Array.from(this.queue.values())
108 | .map((w) => w.nextAttempt?.getTime() || Date.now())
109 | .filter((t) => t > Date.now())
110 | );
111 | await new Promise((resolve) =>
112 | setTimeout(resolve, nextAttempt - Date.now())
113 | );
114 | continue;
115 | }
116 |
117 | // Process ready webhooks in parallel
118 | await Promise.all(
119 | readyWebhooks.map((webhook) => this.deliverWebhook(webhook))
120 | );
121 | }
122 | } finally {
123 | this.processing = false;
124 | }
125 | }
126 |
127 | /**
128 | * Attempt to deliver a webhook.
129 | */
130 | private async deliverWebhook(webhook: QueuedWebhook): Promise {
131 | const { id, config, event, retryCount } = webhook;
132 | const payload = JSON.stringify(event);
133 |
134 | // Prepare headers
135 | const headers: Record = {
136 | "Content-Type": "application/json",
137 | "User-Agent": "MCP-Replicate-Webhook/1.0",
138 | "X-Webhook-ID": id,
139 | "X-Event-Type": event.type,
140 | "X-Timestamp": event.timestamp,
141 | };
142 |
143 | // Add signature if secret is provided
144 | if (config.secret) {
145 | headers["X-Signature"] = generateWebhookSignature(payload, config.secret);
146 | }
147 |
148 | try {
149 | // Attempt delivery with timeout
150 | const controller = new AbortController();
151 | const timeout = setTimeout(
152 | () => controller.abort(),
153 | config.timeout || DEFAULT_WEBHOOK_CONFIG.timeout
154 | );
155 |
156 | let response: Response;
157 | try {
158 | response = await fetch(config.url, {
159 | method: "POST",
160 | headers,
161 | body: payload,
162 | signal: controller.signal,
163 | });
164 |
165 | clearTimeout(timeout);
166 |
167 | // Record result
168 | const result: WebhookDeliveryResult = {
169 | success: response.ok,
170 | statusCode: response.status,
171 | retryCount,
172 | timestamp: new Date().toISOString(),
173 | };
174 |
175 | if (!response.ok) {
176 | result.error = `HTTP ${response.status}: ${response.statusText}`;
177 | }
178 |
179 | this.recordDeliveryResult(id, result);
180 |
181 | // Handle failed delivery
182 | const maxRetries =
183 | config.retries ?? DEFAULT_WEBHOOK_CONFIG.retries ?? 3;
184 | if (!response.ok && retryCount < maxRetries) {
185 | // Schedule retry
186 | const delay = calculateRetryDelay(retryCount);
187 | webhook.retryCount++;
188 | webhook.lastAttempt = new Date();
189 | webhook.nextAttempt = new Date(Date.now() + delay);
190 |
191 | // In test environment, immediately retry
192 | if (process.env.NODE_ENV === "test") {
193 | await this.deliverWebhook(webhook);
194 | }
195 | return;
196 | }
197 |
198 | // Delivery succeeded or max retries reached
199 | this.queue.delete(id);
200 | } catch (error) {
201 | clearTimeout(timeout);
202 | throw error;
203 | }
204 | } catch (error) {
205 | // Record error result
206 | const result: WebhookDeliveryResult = {
207 | success: false,
208 | error: error instanceof Error ? error.message : String(error),
209 | retryCount,
210 | timestamp: new Date().toISOString(),
211 | };
212 |
213 | this.recordDeliveryResult(id, result);
214 |
215 | // Handle error
216 | const maxRetries = config.retries ?? DEFAULT_WEBHOOK_CONFIG.retries ?? 3;
217 | if (retryCount < maxRetries) {
218 | // Schedule retry
219 | const delay = calculateRetryDelay(retryCount);
220 | webhook.retryCount++;
221 | webhook.lastAttempt = new Date();
222 | webhook.nextAttempt = new Date(Date.now() + delay);
223 |
224 | // In test environment, immediately retry
225 | if (process.env.NODE_ENV === "test") {
226 | await this.deliverWebhook(webhook);
227 | }
228 | return;
229 | }
230 |
231 | // Max retries reached
232 | this.queue.delete(id);
233 | }
234 | }
235 |
236 | /**
237 | * Record a delivery result.
238 | */
239 | private recordDeliveryResult(
240 | webhookId: string,
241 | result: WebhookDeliveryResult
242 | ): void {
243 | const results = this.deliveryResults.get(webhookId) || [];
244 | results.push(result);
245 | this.deliveryResults.set(webhookId, results);
246 |
247 | // Clean up old results (keep last 10)
248 | if (results.length > 10) {
249 | results.splice(0, results.length - 10);
250 | }
251 | }
252 | }
253 |
--------------------------------------------------------------------------------
/src/templates/manager.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * Template manager for handling image generation parameters.
3 | */
4 |
5 | import { qualityPresets, type QualityPreset } from "./parameters/quality.js";
6 | import { stylePresets, type StylePreset } from "./parameters/style.js";
7 | import {
8 | sizePresets,
9 | type SizePreset,
10 | scaleToMaxSize,
11 | } from "./parameters/size.js";
12 |
13 | import type { ModelIO } from "../models/openapi.js";
14 |
/**
 * Complete parameter set for an image-generation model, as produced by
 * TemplateManager.generateParameters().
 */
export interface ImageGenerationParameters extends ModelIO {
  /** Final prompt text (style preset prefix/suffix already merged in). */
  prompt: string;
  /** Combined negative prompt from the quality and style presets, if any. */
  negative_prompt?: string;
  /** Output width in pixels (clamped to the manager's max image size). */
  width: number;
  /** Output height in pixels (clamped to the manager's max image size). */
  height: number;
  /** Inference step count, taken from the quality preset. */
  num_inference_steps?: number;
  /** Guidance scale, taken from the quality preset. */
  guidance_scale?: number;
  /** Scheduler name, taken from the quality preset. */
  scheduler?: string;
  /** Style influence in [0, 1], taken from the style preset. */
  style_strength?: number;
  /** RNG seed for reproducible output. */
  seed?: number;
  /** Number of images to generate (defaults to 1). */
  num_outputs?: number;
}
27 |
/**
 * User-facing options selecting presets and overrides when generating
 * image parameters.
 */
export interface TemplateOptions {
  /** Quality preset id; "balanced" is used when omitted. */
  quality?: keyof typeof qualityPresets;
  /** Style preset id; "photorealistic" is used when omitted. */
  style?: keyof typeof stylePresets;
  /** Size preset id; "square" is used when omitted. */
  size?: keyof typeof sizePresets;
  /** Explicit dimensions in pixels; takes precedence over `size`. */
  custom_size?: { width: number; height: number };
  /** RNG seed passed through to the model. */
  seed?: number;
  /** Number of images to generate (defaults to 1). */
  num_outputs?: number;
}
36 |
37 | /**
38 | * Manages templates and parameter generation for image generation.
39 | */
40 | export class TemplateManager {
41 | private maxImageSize: number;
42 |
43 | constructor(maxImageSize = 1024) {
44 | this.maxImageSize = maxImageSize;
45 | }
46 |
47 | /**
48 | * Get all available presets.
49 | */
50 | getAvailablePresets() {
51 | return {
52 | quality: Object.entries(qualityPresets).map(([id, preset]) => ({
53 | id,
54 | ...preset,
55 | })),
56 | style: Object.entries(stylePresets).map(([id, preset]) => ({
57 | id,
58 | ...preset,
59 | })),
60 | size: Object.entries(sizePresets).map(([id, preset]) => ({
61 | id,
62 | ...preset,
63 | })),
64 | };
65 | }
66 |
67 | /**
68 | * Generate parameters by combining presets and options.
69 | */
70 | generateParameters(
71 | prompt: string,
72 | options: TemplateOptions = {}
73 | ): ImageGenerationParameters {
74 | // Get presets
75 | const qualityPreset = options.quality
76 | ? qualityPresets[options.quality]
77 | : qualityPresets.balanced;
78 | const stylePreset = options.style
79 | ? stylePresets[options.style]
80 | : stylePresets.photorealistic;
81 | const sizePreset = options.size
82 | ? sizePresets[options.size]
83 | : sizePresets.square;
84 |
85 | // Handle custom size
86 | let { width, height } = options.custom_size || sizePreset.parameters;
87 | if (width > this.maxImageSize || height > this.maxImageSize) {
88 | ({ width, height } = scaleToMaxSize(width, height, this.maxImageSize));
89 | }
90 |
91 | // Combine prompts
92 | const fullPrompt = [
93 | stylePreset.parameters.prompt_prefix,
94 | prompt.trim(),
95 | stylePreset.parameters.prompt_suffix,
96 | ]
97 | .filter(Boolean)
98 | .join(" ");
99 |
100 | // Combine negative prompts
101 | const negativePrompts = [
102 | qualityPreset.parameters.negative_prompt,
103 | stylePreset.parameters.negative_prompt,
104 | ]
105 | .filter(Boolean)
106 | .join(", ");
107 |
108 | // Combine parameters
109 | return {
110 | prompt: fullPrompt,
111 | negative_prompt: negativePrompts || undefined,
112 | width,
113 | height,
114 | num_inference_steps: qualityPreset.parameters.num_inference_steps,
115 | guidance_scale: qualityPreset.parameters.guidance_scale,
116 | scheduler: qualityPreset.parameters.scheduler,
117 | style_strength: stylePreset.parameters.style_strength,
118 | seed: options.seed,
119 | num_outputs: options.num_outputs || 1,
120 | };
121 | }
122 |
123 | /**
124 | * Validate parameters against model constraints.
125 | */
126 | validateParameters(
127 | parameters: ImageGenerationParameters,
128 | modelConstraints: {
129 | min_width?: number;
130 | max_width?: number;
131 | min_height?: number;
132 | max_height?: number;
133 | step_size?: number;
134 | supported_schedulers?: string[];
135 | } = {}
136 | ): void {
137 | const errors: string[] = [];
138 |
139 | // Basic validation
140 | if (parameters.width <= 0) {
141 | errors.push("Width must be positive");
142 | }
143 | if (parameters.height <= 0) {
144 | errors.push("Height must be positive");
145 | }
146 |
147 | // Model constraints validation
148 | if (
149 | modelConstraints.min_width &&
150 | parameters.width < modelConstraints.min_width
151 | ) {
152 | errors.push(`Width must be at least ${modelConstraints.min_width}`);
153 | }
154 | if (
155 | modelConstraints.max_width &&
156 | parameters.width > modelConstraints.max_width
157 | ) {
158 | errors.push(`Width must be at most ${modelConstraints.max_width}`);
159 | }
160 | if (
161 | modelConstraints.min_height &&
162 | parameters.height < modelConstraints.min_height
163 | ) {
164 | errors.push(`Height must be at least ${modelConstraints.min_height}`);
165 | }
166 | if (
167 | modelConstraints.max_height &&
168 | parameters.height > modelConstraints.max_height
169 | ) {
170 | errors.push(`Height must be at most ${modelConstraints.max_height}`);
171 | }
172 |
173 | // Validate step size
174 | if (modelConstraints.step_size) {
175 | if (parameters.width % modelConstraints.step_size !== 0) {
176 | errors.push(
177 | `Width must be a multiple of ${modelConstraints.step_size}`
178 | );
179 | }
180 | if (parameters.height % modelConstraints.step_size !== 0) {
181 | errors.push(
182 | `Height must be a multiple of ${modelConstraints.step_size}`
183 | );
184 | }
185 | }
186 |
187 | // Validate scheduler
188 | if (
189 | modelConstraints.supported_schedulers &&
190 | parameters.scheduler &&
191 | !modelConstraints.supported_schedulers.includes(parameters.scheduler)
192 | ) {
193 | errors.push(
194 | `Scheduler must be one of: ${modelConstraints.supported_schedulers.join(
195 | ", "
196 | )}`
197 | );
198 | }
199 |
200 | // Validate other parameters
201 | if (parameters.num_inference_steps && parameters.num_inference_steps < 1) {
202 | errors.push("Number of inference steps must be positive");
203 | }
204 | if (parameters.guidance_scale && parameters.guidance_scale < 1) {
205 | errors.push("Guidance scale must be positive");
206 | }
207 | if (
208 | parameters.style_strength &&
209 | (parameters.style_strength < 0 || parameters.style_strength > 1)
210 | ) {
211 | errors.push("Style strength must be between 0 and 1");
212 | }
213 | if (parameters.num_outputs && parameters.num_outputs < 1) {
214 | errors.push("Number of outputs must be positive");
215 | }
216 |
217 | if (errors.length > 0) {
218 | throw new Error(`Parameter validation failed:\n${errors.join("\n")}`);
219 | }
220 | }
221 |
222 | /**
223 | * Suggest parameters based on prompt analysis.
224 | */
225 | suggestParameters(prompt: string): TemplateOptions {
226 | // Simple keyword-based suggestions
227 | const suggestions: TemplateOptions = {};
228 |
229 | // Quality suggestions
230 | if (prompt.includes("quick") || prompt.includes("draft")) {
231 | suggestions.quality = "draft";
232 | } else if (prompt.includes("high quality") || prompt.includes("detailed")) {
233 | suggestions.quality = "quality";
234 | }
235 |
236 | // Style suggestions
237 | if (prompt.includes("photo") || prompt.includes("realistic")) {
238 | suggestions.style = "photorealistic";
239 | } else if (prompt.includes("anime") || prompt.includes("manga")) {
240 | suggestions.style = "anime";
241 | } else if (prompt.includes("painting") || prompt.includes("oil")) {
242 | suggestions.style = "oil_painting";
243 | } else if (prompt.includes("watercolor")) {
244 | suggestions.style = "watercolor";
245 | } else if (prompt.includes("pixel") || prompt.includes("8-bit")) {
246 | suggestions.style = "pixel_art";
247 | } else if (prompt.includes("minimal")) {
248 | suggestions.style = "minimalist";
249 | }
250 |
251 | // Size suggestions
252 | if (prompt.includes("portrait") || prompt.includes("vertical")) {
253 | suggestions.size = "portrait";
254 | } else if (prompt.includes("landscape") || prompt.includes("horizontal")) {
255 | suggestions.size = "landscape";
256 | } else if (prompt.includes("panorama") || prompt.includes("wide")) {
257 | suggestions.size = "panoramic";
258 | } else if (prompt.includes("instagram")) {
259 | suggestions.size = prompt.includes("story")
260 | ? "instagram_story"
261 | : "instagram_post";
262 | } else if (prompt.includes("twitter header")) {
263 | suggestions.size = "twitter_header";
264 | }
265 |
266 | return suggestions;
267 | }
268 | }
269 |
--------------------------------------------------------------------------------
/src/templates/parameters/quality.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * Quality presets for image generation.
3 | */
4 |
5 | export interface QualityPreset {
6 | name: string;
7 | description: string;
8 | parameters: {
9 | num_inference_steps?: number;
10 | guidance_scale?: number;
11 | scheduler?: string;
12 | negative_prompt?: string;
13 | };
14 | }
15 |
16 | export const qualityPresets: Record = {
17 | draft: {
18 | name: "Draft",
19 | description: "Quick, low-quality preview with minimal steps",
20 | parameters: {
21 | num_inference_steps: 20,
22 | guidance_scale: 7,
23 | scheduler: "DPMSolverMultistep",
24 | },
25 | },
26 | balanced: {
27 | name: "Balanced",
28 | description: "Good balance between quality and speed",
29 | parameters: {
30 | num_inference_steps: 30,
31 | guidance_scale: 7.5,
32 | scheduler: "DPMSolverMultistep",
33 | negative_prompt: "blurry, low quality, distorted",
34 | },
35 | },
36 | quality: {
37 | name: "Quality",
38 | description: "High-quality output with more steps",
39 | parameters: {
40 | num_inference_steps: 50,
41 | guidance_scale: 8,
42 | scheduler: "DPMSolverMultistep",
43 | negative_prompt: "blurry, low quality, distorted, ugly, deformed",
44 | },
45 | },
46 | extreme: {
47 | name: "Extreme",
48 | description: "Maximum quality with extensive steps",
49 | parameters: {
50 | num_inference_steps: 100,
51 | guidance_scale: 9,
52 | scheduler: "DPMSolverMultistep",
53 | negative_prompt:
54 | "blurry, low quality, distorted, ugly, deformed, noisy, grainy, oversaturated",
55 | },
56 | },
57 | };
58 |
--------------------------------------------------------------------------------
/src/templates/parameters/size.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * Size and aspect ratio presets for image generation.
3 | */
4 |
5 | export interface SizePreset {
6 | name: string;
7 | description: string;
8 | parameters: {
9 | width: number;
10 | height: number;
11 | aspect_ratio?: string;
12 | recommended_for?: string[];
13 | };
14 | }
15 |
16 | export const sizePresets: Record = {
17 | square: {
18 | name: "Square",
19 | description: "Perfect square format",
20 | parameters: {
21 | width: 1024,
22 | height: 1024,
23 | aspect_ratio: "1:1",
24 | recommended_for: ["social media", "profile pictures", "album covers"],
25 | },
26 | },
27 | portrait: {
28 | name: "Portrait",
29 | description: "Vertical format for portraits",
30 | parameters: {
31 | width: 768,
32 | height: 1024,
33 | aspect_ratio: "3:4",
34 | recommended_for: ["portraits", "mobile wallpapers", "book covers"],
35 | },
36 | },
37 | landscape: {
38 | name: "Landscape",
39 | description: "Horizontal format for landscapes",
40 | parameters: {
41 | width: 1024,
42 | height: 768,
43 | aspect_ratio: "4:3",
44 | recommended_for: ["landscapes", "desktop wallpapers", "banners"],
45 | },
46 | },
47 | widescreen: {
48 | name: "Widescreen",
49 | description: "16:9 format for modern displays",
50 | parameters: {
51 | width: 1024,
52 | height: 576,
53 | aspect_ratio: "16:9",
54 | recommended_for: [
55 | "desktop backgrounds",
56 | "presentations",
57 | "video thumbnails",
58 | ],
59 | },
60 | },
61 | panoramic: {
62 | name: "Panoramic",
63 | description: "Extra wide format for panoramas",
64 | parameters: {
65 | width: 1024,
66 | height: 384,
67 | aspect_ratio: "21:9",
68 | recommended_for: [
69 | "panoramic landscapes",
70 | "ultra-wide displays",
71 | "banners",
72 | ],
73 | },
74 | },
75 | instagram_post: {
76 | name: "Instagram Post",
77 | description: "Optimized for Instagram posts",
78 | parameters: {
79 | width: 1080,
80 | height: 1080,
81 | aspect_ratio: "1:1",
82 | recommended_for: ["instagram posts", "social media"],
83 | },
84 | },
85 | instagram_story: {
86 | name: "Instagram Story",
87 | description: "Optimized for Instagram stories",
88 | parameters: {
89 | width: 1080,
90 | height: 1920,
91 | aspect_ratio: "9:16",
92 | recommended_for: ["instagram stories", "mobile content"],
93 | },
94 | },
95 | twitter_header: {
96 | name: "Twitter Header",
97 | description: "Optimized for Twitter profile headers",
98 | parameters: {
99 | width: 1500,
100 | height: 500,
101 | aspect_ratio: "3:1",
102 | recommended_for: ["twitter headers", "social media banners"],
103 | },
104 | },
105 | };
106 |
107 | /**
108 | * Get the closest size preset for given dimensions.
109 | */
110 | export function findClosestSizePreset(
111 | width: number,
112 | height: number
113 | ): SizePreset {
114 | const targetRatio = width / height;
115 | let closestPreset = sizePresets.square;
116 | let smallestDiff = Number.POSITIVE_INFINITY;
117 |
118 | for (const preset of Object.values(sizePresets)) {
119 | const presetRatio = preset.parameters.width / preset.parameters.height;
120 | const ratioDiff = Math.abs(presetRatio - targetRatio);
121 |
122 | if (ratioDiff < smallestDiff) {
123 | smallestDiff = ratioDiff;
124 | closestPreset = preset;
125 | }
126 | }
127 |
128 | return closestPreset;
129 | }
130 |
131 | /**
132 | * Scale dimensions to fit within maximum size while maintaining aspect ratio.
133 | */
134 | export function scaleToMaxSize(
135 | width: number,
136 | height: number,
137 | maxSize: number
138 | ): { width: number; height: number } {
139 | if (width <= maxSize && height <= maxSize) {
140 | return { width, height };
141 | }
142 |
143 | const ratio = width / height;
144 | return ratio > 1
145 | ? {
146 | // Width is larger
147 | width: maxSize,
148 | height: Math.round(maxSize / ratio),
149 | }
150 | : {
151 | // Height is larger
152 | width: Math.round(maxSize * ratio),
153 | height: maxSize,
154 | };
155 | }
156 |
--------------------------------------------------------------------------------
/src/templates/parameters/style.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * Style presets for image generation.
3 | */
4 |
5 | export interface StylePreset {
6 | name: string;
7 | description: string;
8 | parameters: {
9 | prompt_prefix?: string;
10 | prompt_suffix?: string;
11 | negative_prompt?: string;
12 | style_strength?: number;
13 | };
14 | }
15 |
16 | export const stylePresets: Record = {
17 | photorealistic: {
18 | name: "Photorealistic",
19 | description: "Highly detailed, realistic photography style",
20 | parameters: {
21 | prompt_prefix: "photorealistic, highly detailed photograph,",
22 | prompt_suffix: "8k uhd, high resolution, professional photography",
23 | negative_prompt:
24 | "illustration, painting, drawing, cartoon, anime, rendered, 3d, cgi",
25 | style_strength: 0.8,
26 | },
27 | },
28 | cinematic: {
29 | name: "Cinematic",
30 | description: "Movie-like scenes with dramatic lighting",
31 | parameters: {
32 | prompt_prefix: "cinematic shot, dramatic lighting, movie scene,",
33 | prompt_suffix: "anamorphic lens, film grain, depth of field, bokeh",
34 | negative_prompt: "flat lighting, flash photography, overexposed",
35 | style_strength: 0.85,
36 | },
37 | },
38 | anime: {
39 | name: "Anime",
40 | description: "Japanese anime and manga style",
41 | parameters: {
42 | prompt_prefix: "anime style, manga illustration,",
43 | prompt_suffix: "clean lines, vibrant colors, detailed anime drawing",
44 | negative_prompt: "photorealistic, 3d rendered, western animation",
45 | style_strength: 0.9,
46 | },
47 | },
48 | digital_art: {
49 | name: "Digital Art",
50 | description: "Modern digital art style",
51 | parameters: {
52 | prompt_prefix: "digital art, concept art,",
53 | prompt_suffix: "highly detailed, sharp focus, vibrant colors",
54 | negative_prompt: "traditional media, watercolor, oil painting",
55 | style_strength: 0.85,
56 | },
57 | },
58 | oil_painting: {
59 | name: "Oil Painting",
60 | description: "Classical oil painting style",
61 | parameters: {
62 | prompt_prefix: "oil painting, traditional art, painterly,",
63 | prompt_suffix: "detailed brushwork, canvas texture, rich colors",
64 | negative_prompt: "digital art, photograph, 3d rendered",
65 | style_strength: 0.9,
66 | },
67 | },
68 | watercolor: {
69 | name: "Watercolor",
70 | description: "Soft watercolor painting style",
71 | parameters: {
72 | prompt_prefix: "watercolor painting, soft and dreamy,",
73 | prompt_suffix: "flowing colors, wet on wet technique, artistic",
74 | negative_prompt: "sharp edges, harsh contrast, digital art",
75 | style_strength: 0.85,
76 | },
77 | },
78 | pixel_art: {
79 | name: "Pixel Art",
80 | description: "Retro pixel art style",
81 | parameters: {
82 | prompt_prefix: "pixel art, retro game style,",
83 | prompt_suffix: "8-bit, pixelated, video game art",
84 | negative_prompt: "smooth gradients, photorealistic, high resolution",
85 | style_strength: 0.95,
86 | },
87 | },
88 | minimalist: {
89 | name: "Minimalist",
90 | description: "Clean, simple minimalist style",
91 | parameters: {
92 | prompt_prefix: "minimalist design, simple composition,",
93 | prompt_suffix: "clean lines, negative space, geometric",
94 | negative_prompt: "busy, cluttered, detailed, ornate",
95 | style_strength: 0.8,
96 | },
97 | },
98 | };
99 |
--------------------------------------------------------------------------------
/src/templates/prompts/text_to_image.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * Text-to-image generation prompt.
3 | */
4 |
5 | import type { MCPMessage } from "../../types/mcp.js";
6 | import { TemplateManager, type TemplateOptions } from "../manager.js";
7 |
// Shared module-level manager used by the helpers in this file; it holds
// only the default max image size, so sharing a single instance is safe.
const templateManager = new TemplateManager();
9 |
/**
 * Generate a text-to-image prompt message with parameter suggestions.
 *
 * Analyzes the user's prompt for preset keywords, renders a human-readable
 * summary of the chosen settings for the assistant to show, and attaches the
 * machine-readable `parameters`, `suggestions`, and preset catalog in
 * `params` for programmatic use.
 *
 * @param userPrompt Free-form image description supplied by the user.
 * @returns A JSON-RPC message with method "prompt/text_to_image".
 */
export function generateTextToImagePrompt(userPrompt: string): MCPMessage {
  // Get parameter suggestions based on prompt
  const suggestions = templateManager.suggestParameters(userPrompt);

  // Get all available presets for reference
  const presets = templateManager.getAvailablePresets();

  // Generate example parameters (suggestions applied on top of defaults)
  const exampleParams = templateManager.generateParameters(
    userPrompt,
    suggestions
  );

  return {
    jsonrpc: "2.0",
    method: "prompt/text_to_image",
    params: {
      messages: [
        {
          role: "assistant",
          content: {
            type: "text",
            text: "I'll help you generate an image. Here's what I understand from your prompt:",
          },
        },
        {
          role: "assistant",
          content: {
            type: "text",
            text: `Based on your description, I suggest:
${suggestions.quality ? `- Quality: ${suggestions.quality} mode` : ""}
${suggestions.style ? `- Style: ${suggestions.style} style` : ""}
${suggestions.size ? `- Size: ${suggestions.size} format` : ""}

Here are the parameters I'll use:
- Prompt: "${exampleParams.prompt}"
${
  exampleParams.negative_prompt
    ? `- Negative prompt: "${exampleParams.negative_prompt}"`
    : ""
}
- Size: ${exampleParams.width}x${exampleParams.height}
- Steps: ${exampleParams.num_inference_steps}
- Guidance scale: ${exampleParams.guidance_scale}
${exampleParams.scheduler ? `- Scheduler: ${exampleParams.scheduler}` : ""}
${
  exampleParams.style_strength
    ? `- Style strength: ${exampleParams.style_strength}`
    : ""
}

Would you like to adjust any of these settings? You can choose from:

Quality presets:
${presets.quality.map((p) => `- ${p.name}: ${p.description}`).join("\n")}

Style presets:
${presets.style.map((p) => `- ${p.name}: ${p.description}`).join("\n")}

Size presets:
${presets.size.map((p) => `- ${p.name}: ${p.description}`).join("\n")}

Or you can specify custom parameters:
- Custom size (e.g., "make it 1024x768")
- Number of images (e.g., "generate 4 variations")
- Seed number for reproducibility

Let me know if you want to proceed with these settings or make any adjustments.`,
          },
        },
      ],
      parameters: exampleParams,
      suggestions,
      presets,
    },
  };
}
90 |
91 | /**
92 | * Parse user response to extract parameter adjustments.
93 | */
94 | export function parseParameterAdjustments(
95 | response: string
96 | ): Partial {
97 | const adjustments: Partial = {};
98 |
99 | // Quality adjustments
100 | if (response.match(/\b(draft|quick|fast)\b/i)) {
101 | adjustments.quality = "draft";
102 | } else if (response.match(/\b(balanced|medium|default)\b/i)) {
103 | adjustments.quality = "balanced";
104 | } else if (response.match(/\b(quality|high|detailed)\b/i)) {
105 | adjustments.quality = "quality";
106 | } else if (response.match(/\b(extreme|maximum|best)\b/i)) {
107 | adjustments.quality = "extreme";
108 | }
109 |
110 | // Style adjustments
111 | if (response.match(/\b(photo|realistic)\b/i)) {
112 | adjustments.style = "photorealistic";
113 | } else if (response.match(/\b(anime|manga)\b/i)) {
114 | adjustments.style = "anime";
115 | } else if (response.match(/\b(digital[\s-]?art)\b/i)) {
116 | adjustments.style = "digital_art";
117 | } else if (response.match(/\b(oil[\s-]?painting)\b/i)) {
118 | adjustments.style = "oil_painting";
119 | } else if (response.match(/\b(watercolor)\b/i)) {
120 | adjustments.style = "watercolor";
121 | } else if (response.match(/\b(pixel[\s-]?art|8[\s-]?bit)\b/i)) {
122 | adjustments.style = "pixel_art";
123 | } else if (response.match(/\b(minimal|minimalist)\b/i)) {
124 | adjustments.style = "minimalist";
125 | }
126 |
127 | // Size adjustments
128 | if (response.match(/\b(square|1:1)\b/i)) {
129 | adjustments.size = "square";
130 | } else if (response.match(/\b(portrait|vertical|3:4)\b/i)) {
131 | adjustments.size = "portrait";
132 | } else if (response.match(/\b(landscape|horizontal|4:3)\b/i)) {
133 | adjustments.size = "landscape";
134 | } else if (response.match(/\b(widescreen|16:9)\b/i)) {
135 | adjustments.size = "widescreen";
136 | } else if (response.match(/\b(panoramic|21:9)\b/i)) {
137 | adjustments.size = "panoramic";
138 | } else if (response.match(/\b(instagram[\s-]?story)\b/i)) {
139 | adjustments.size = "instagram_story";
140 | } else if (response.match(/\b(instagram[\s-]?post)\b/i)) {
141 | adjustments.size = "instagram_post";
142 | } else if (response.match(/\b(twitter[\s-]?header)\b/i)) {
143 | adjustments.size = "twitter_header";
144 | }
145 |
146 | // Custom size
147 | const sizeMatch = response.match(/(\d+)\s*x\s*(\d+)/i);
148 | if (sizeMatch) {
149 | adjustments.custom_size = {
150 | width: Number.parseInt(sizeMatch[1], 10),
151 | height: Number.parseInt(sizeMatch[2], 10),
152 | };
153 | }
154 |
155 | // Number of outputs
156 | const numMatch = response.match(
157 | /\b(\d+)\s*(outputs?|variations?|images?)\b/i
158 | );
159 | if (numMatch) {
160 | adjustments.num_outputs = Number.parseInt(numMatch[1], 10);
161 | }
162 |
163 | // Seed
164 | const seedMatch = response.match(/\bseed\s*[=:]?\s*(\d+)\b/i);
165 | if (seedMatch) {
166 | adjustments.seed = Number.parseInt(seedMatch[1], 10);
167 | }
168 |
169 | return adjustments;
170 | }
171 |
--------------------------------------------------------------------------------
/src/tests/integration.test.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * Integration tests with Replicate API.
3 | */
4 |
5 | import {
6 | describe,
7 | it,
8 | expect,
9 | beforeEach,
10 | afterEach,
11 | vi,
12 | beforeAll,
13 | afterAll,
14 | } from "vitest";
15 | import { ReplicateClient } from "../replicate_client.js";
16 | import { WebhookService } from "../services/webhook.js";
17 | import { TemplateManager } from "../templates/manager.js";
18 | import { ErrorHandler, ReplicateError } from "../services/error.js";
19 | import { Cache } from "../services/cache.js";
20 | import type { Model } from "../models/model.js";
21 | import type { Prediction } from "../models/prediction.js";
22 | import { PredictionStatus } from "../models/prediction.js";
23 | import type { SchemaObject, PropertyObject } from "../models/openapi.js";
24 | import type { WebhookEvent } from "../models/webhook.js";
25 | import { createError } from "../services/error.js";
26 |
27 | // Mock environment variables
28 | process.env.REPLICATE_API_TOKEN = "test_token";
29 |
30 | describe("Replicate API Integration", () => {
31 | let client: ReplicateClient;
32 | let webhookService: WebhookService;
33 | let templateManager: TemplateManager;
34 | let cache: Cache;
35 |
36 | beforeEach(() => {
37 | // Initialize components
38 | client = new ReplicateClient();
39 | webhookService = new WebhookService();
40 | templateManager = new TemplateManager();
41 | cache = new Cache();
42 |
43 | // Mock client methods
44 | vi.spyOn(client, "listModels").mockResolvedValue({
45 | models: [
46 | {
47 | owner: "stability-ai",
48 | name: "sdxl",
49 | description: "Test model",
50 | id: "stability-ai/sdxl",
51 | visibility: "public",
52 | latest_version: {
53 | id: "test-version",
54 | created_at: new Date().toISOString(),
55 | cog_version: "0.3.0",
56 | openapi_schema: {
57 | openapi: "3.0.0",
58 | info: {
59 | title: "Test Model API",
60 | version: "1.0.0",
61 | },
62 | paths: {},
63 | components: {
64 | schemas: {
65 | Input: {
66 | type: "object",
67 | required: ["prompt"],
68 | properties: {
69 | prompt: { type: "string" },
70 | width: { type: "number", minimum: 0 },
71 | height: { type: "number", minimum: 0 },
72 | },
73 | },
74 | },
75 | },
76 | },
77 | },
78 | },
79 | ],
80 | });
81 |
82 | vi.spyOn(client, "searchModels").mockResolvedValue({
83 | models: [
84 | {
85 | id: "stability-ai/sdxl",
86 | owner: "stability-ai",
87 | name: "sdxl",
88 | description: "Text to image model",
89 | visibility: "public",
90 | },
91 | ],
92 | });
93 |
94 | vi.spyOn(client, "getModel").mockResolvedValue({
95 | owner: "stability-ai",
96 | name: "sdxl",
97 | description: "Test model",
98 | id: "stability-ai/sdxl",
99 | visibility: "public",
100 | latest_version: {
101 | id: "test-version",
102 | created_at: new Date().toISOString(),
103 | cog_version: "0.3.0",
104 | openapi_schema: {
105 | openapi: "3.0.0",
106 | info: {
107 | title: "Test Model API",
108 | version: "1.0.0",
109 | },
110 | paths: {},
111 | components: {
112 | schemas: {
113 | Input: {
114 | type: "object",
115 | required: ["prompt"],
116 | properties: {
117 | prompt: { type: "string" },
118 | width: { type: "number", minimum: 0 },
119 | height: { type: "number", minimum: 0 },
120 | },
121 | },
122 | },
123 | },
124 | },
125 | },
126 | });
127 |
128 | vi.spyOn(client, "createPrediction").mockResolvedValue({
129 | id: "test-prediction",
130 | version: "test-version",
131 | status: PredictionStatus.Starting,
132 | input: { prompt: "test" },
133 | created_at: new Date().toISOString(),
134 | urls: {},
135 | });
136 |
137 | vi.spyOn(client, "getPredictionStatus").mockResolvedValue({
138 | id: "test-prediction",
139 | version: "test-version",
140 | status: PredictionStatus.Succeeded,
141 | input: { prompt: "test" },
142 | output: { image: "test.png" },
143 | created_at: new Date().toISOString(),
144 | urls: {},
145 | });
146 |
147 | vi.spyOn(client, "listPredictions").mockResolvedValue([
148 | {
149 | id: "test-prediction",
150 | version: "test-version",
151 | status: PredictionStatus.Succeeded,
152 | input: { prompt: "test" },
153 | output: { image: "test.png" },
154 | created_at: new Date().toISOString(),
155 | urls: {},
156 | },
157 | ]);
158 |
159 | vi.spyOn(client, "listCollections").mockResolvedValue({
160 | collections: [
161 | {
162 | id: "test-collection",
163 | name: "Test Collection",
164 | slug: "test",
165 | description: "Test collection",
166 | models: [],
167 | created_at: new Date().toISOString(),
168 | },
169 | ],
170 | });
171 |
172 | vi.spyOn(client, "getCollection").mockResolvedValue({
173 | id: "test-collection",
174 | name: "Test Collection",
175 | slug: "test",
176 | description: "Test collection",
177 | models: [],
178 | created_at: new Date().toISOString(),
179 | });
180 |
181 | // Mock webhook service
182 | vi.spyOn(webhookService, "queueWebhook").mockResolvedValue("test-webhook");
183 | vi.spyOn(webhookService, "getDeliveryResults").mockReturnValue([
184 | {
185 | success: false,
186 | error: "Mock delivery failure",
187 | retryCount: 0,
188 | timestamp: new Date().toISOString(),
189 | },
190 | ]);
191 | });
192 |
  // Reset shared state between tests: drop cached entries and mock call history.
  afterEach(() => {
    cache.clear();
    vi.clearAllMocks();
  });
197 |
198 | describe("Error Handling", () => {
199 | it("should handle authentication errors", async () => {
200 | const invalidClient = new ReplicateClient("invalid_token");
201 | vi.spyOn(invalidClient, "listModels").mockRejectedValue(
202 | createError.authentication()
203 | );
204 |
205 | await expect(invalidClient.listModels()).rejects.toThrow(ReplicateError);
206 | });
207 |
208 | it("should handle rate limit errors with retries", async () => {
209 | const rateLimitError = createError.rateLimit(1);
210 |
211 | // Mock client to throw rate limit error once then succeed
212 | let attempts = 0;
213 | vi.spyOn(client, "listModels").mockImplementation(async () => {
214 | if (attempts === 0) {
215 | attempts++;
216 | throw rateLimitError;
217 | }
218 | return { models: [] };
219 | });
220 |
221 | const result = await ErrorHandler.withRetries(
222 | async () => client.listModels(),
223 | {
224 | maxAttempts: 2,
225 | retryIf: (error: Error) => error instanceof ReplicateError,
226 | }
227 | );
228 |
229 | expect(attempts).toBe(1);
230 | expect(result).toEqual({ models: [] });
231 | });
232 |
233 | it("should handle network errors with retries", async () => {
234 | const networkError = createError.api(500, "Connection failed");
235 |
236 | // Mock client to throw network error twice then succeed
237 | let attempts = 0;
238 | vi.spyOn(client, "listModels").mockImplementation(async () => {
239 | if (attempts < 2) {
240 | attempts++;
241 | throw networkError;
242 | }
243 | return { models: [] };
244 | });
245 |
246 | const result = await ErrorHandler.withRetries(
247 | async () => client.listModels(),
248 | {
249 | maxAttempts: 3,
250 | retryIf: (error: Error) => error instanceof ReplicateError,
251 | }
252 | );
253 |
254 | expect(attempts).toBe(2);
255 | expect(result).toEqual({ models: [] });
256 | });
257 |
258 | it("should handle validation errors", async () => {
259 | vi.spyOn(client, "createPrediction").mockRejectedValue(
260 | createError.validation("version", "Invalid version")
261 | );
262 |
263 | await expect(
264 | client.createPrediction({
265 | version: "invalid-version",
266 | input: {},
267 | })
268 | ).rejects.toThrow(ReplicateError);
269 | });
270 |
271 | it("should generate detailed error reports", async () => {
272 | const error = createError.validation("test", "Test error");
273 |
274 | const report = ErrorHandler.createErrorReport(error);
275 | expect(report).toMatchObject({
276 | name: "ReplicateError",
277 | message: "Invalid input parameters",
278 | context: {
279 | field: "test",
280 | message: "Test error",
281 | },
282 | timestamp: expect.any(String),
283 | });
284 | });
285 | });
286 |
287 | describe("Caching Behavior", () => {
288 | it("should cache model listings", async () => {
289 | // First request should hit API
290 | const result1 = await client.listModels();
291 | expect(result1.models.length).toBeGreaterThan(0);
292 |
293 | // Mock cache hit
294 | const result2 = await client.listModels();
295 | expect(result2).toEqual(result1);
296 | });
297 |
298 | it("should cache model details with TTL", async () => {
299 | const owner = "stability-ai";
300 | const name = "sdxl";
301 |
302 | // First request should hit API
303 | const model1 = await client.getModel(owner, name);
304 | expect(model1).toBeDefined();
305 |
306 | // Mock cache hit
307 | const model2 = await client.getModel(owner, name);
308 | expect(model2).toEqual(model1);
309 |
310 | // Advance time past TTL
311 | vi.useFakeTimers();
312 | vi.advanceTimersByTime(25 * 60 * 60 * 1000); // 25 hours
313 |
314 | // Request should hit API again
315 | const model3 = await client.getModel(owner, name);
316 | expect(model3).toBeDefined();
317 | });
318 |
319 | it("should handle cache invalidation for predictions", async () => {
320 | const prediction = await client.createPrediction({
321 | version: "stability-ai/sdxl@latest",
322 | input: {
323 | prompt: "test",
324 | },
325 | });
326 |
327 | // Initial status should be cached
328 | const status1 = await client.getPredictionStatus(prediction.id);
329 | const status2 = await client.getPredictionStatus(prediction.id);
330 | expect(status2).toEqual(status1);
331 |
332 | // Completed predictions should stay cached
333 | if (status2.status === PredictionStatus.Succeeded) {
334 | const status3 = await client.getPredictionStatus(prediction.id);
335 | expect(status3).toEqual(status2);
336 | }
337 | // In-progress predictions should refresh
338 | else {
339 | const status3 = await client.getPredictionStatus(prediction.id);
340 | expect(status3).toBeDefined();
341 | }
342 | });
343 | });
344 |
345 | describe("Model Operations", () => {
346 | it("should list available models", async () => {
347 | const result = await client.listModels();
348 | expect(result.models).toBeInstanceOf(Array);
349 | expect(result.models.length).toBeGreaterThan(0);
350 |
351 | const model = result.models[0];
352 | expect(model).toHaveProperty("owner");
353 | expect(model).toHaveProperty("name");
354 | expect(model).toHaveProperty("description");
355 | });
356 |
357 | it("should search models by query", async () => {
358 | const query = "text to image";
359 | const result = await client.searchModels(query);
360 | expect(result.models).toBeInstanceOf(Array);
361 | expect(result.models.length).toBeGreaterThan(0);
362 |
363 | // Results should be relevant to the query
364 | const relevantModels = result.models.filter(
365 | (model) =>
366 | model.description?.toLowerCase().includes("text") ||
367 | model.description?.toLowerCase().includes("image")
368 | );
369 | expect(relevantModels.length).toBeGreaterThan(0);
370 | });
371 |
372 | it("should get model details", async () => {
373 | // Use a known stable model
374 | const owner = "stability-ai";
375 | const name = "sdxl";
376 | const model = await client.getModel(owner, name);
377 |
378 | expect(model.owner).toBe(owner);
379 | expect(model.name).toBe(name);
380 | expect(model.latest_version).toBeDefined();
381 | expect(model.latest_version?.openapi_schema).toBeDefined();
382 | });
383 | });
384 |
385 | describe("Prediction Operations", () => {
386 | let testModel: Model;
387 |
388 | beforeEach(async () => {
389 | // Get a test model for predictions
390 | const models = await client.listModels();
391 | const foundModel = models.models.find(
392 | (m) => m.owner === "stability-ai" && m.name === "sdxl"
393 | );
394 | if (!foundModel || !foundModel.latest_version) {
395 | throw new Error("Test model not found or missing version");
396 | }
397 | testModel = foundModel;
398 | expect(testModel).toBeDefined();
399 | });
400 |
401 | it("should create prediction with community model version", async () => {
402 | if (!testModel.latest_version) {
403 | throw new Error("Test model missing version");
404 | }
405 | const prediction = await client.createPrediction({
406 | version: testModel.latest_version.id,
407 | input: { prompt: "test" },
408 | });
409 |
410 | expect(prediction.id).toBeDefined();
411 | expect(prediction.status).toBe(PredictionStatus.Starting);
412 | expect(prediction.version).toBe(testModel.latest_version.id);
413 | });
414 |
415 | it("should create prediction with official model", async () => {
416 | const prediction = await client.createPrediction({
417 | model: "stability-ai/sdxl",
418 | input: { prompt: "test" },
419 | });
420 |
421 | expect(prediction.id).toBeDefined();
422 | expect(prediction.status).toBe(PredictionStatus.Starting);
423 | });
424 |
425 | it("should create and track prediction", async () => {
426 | if (!testModel.latest_version) {
427 | throw new Error("Test model missing version");
428 | }
429 | const prompt = "a photo of a mountain landscape at sunset";
430 | const params = templateManager.generateParameters(prompt, {
431 | quality: "quality",
432 | style: "photorealistic",
433 | size: "landscape",
434 | });
435 |
436 | // Create prediction
437 | const prediction = await client.createPrediction({
438 | version: testModel.latest_version.id,
439 | input: params as Record,
440 | });
441 |
442 | expect(prediction.id).toBeDefined();
443 | expect(prediction.status).toBe(PredictionStatus.Starting);
444 |
445 | // Mock the status progression
446 | vi.spyOn(client, "getPredictionStatus")
447 | .mockResolvedValueOnce({
448 | ...prediction,
449 | status: PredictionStatus.Processing,
450 | })
451 | .mockResolvedValueOnce({
452 | ...prediction,
453 | status: PredictionStatus.Succeeded,
454 | output: { image: "test.png" },
455 | });
456 |
457 | // Check processing status
458 | const processingStatus = await client.getPredictionStatus(prediction.id);
459 | expect(processingStatus.status).toBe(PredictionStatus.Processing);
460 |
461 | // Check final status
462 | const finalStatus = await client.getPredictionStatus(prediction.id);
463 | expect(finalStatus.status).toBe(PredictionStatus.Succeeded);
464 | expect(finalStatus.output).toBeDefined();
465 | });
466 |
467 | it("should handle webhook notifications", async () => {
468 | if (!testModel.latest_version) {
469 | throw new Error("Test model missing version");
470 | }
471 | // Setup fake timers
472 | vi.useFakeTimers();
473 |
474 | const prompt = "a photo of a mountain landscape at sunset";
475 | const params = templateManager.generateParameters(prompt, {
476 | quality: "quality",
477 | style: "photorealistic",
478 | size: "landscape",
479 | });
480 |
481 | // Create mock webhook server
482 | const mockWebhook = {
483 | url: "https://example.com/webhook",
484 | };
485 |
486 | // Create prediction with webhook
487 | const prediction = await client.createPrediction({
488 | version: testModel.latest_version.id,
489 | input: params as Record,
490 | webhook: mockWebhook.url,
491 | });
492 |
493 | // Queue webhook delivery
494 | const webhookId = await webhookService.queueWebhook(
495 | { url: mockWebhook.url },
496 | {
497 | type: "prediction.created",
498 | timestamp: new Date().toISOString(),
499 | data: JSON.parse(JSON.stringify(prediction)),
500 | } as WebhookEvent
501 | );
502 |
503 | // Advance timers instead of waiting
504 | await vi.runAllTimersAsync();
505 |
506 | const results = webhookService.getDeliveryResults(webhookId);
507 | expect(results).toHaveLength(1);
508 | // Expect failure since we're using a mock URL
509 | expect(results[0].success).toBe(false);
510 |
511 | // Cleanup
512 | vi.useRealTimers();
513 | });
514 | });
515 |
516 | describe("Collection Operations", () => {
517 | it("should list collections", async () => {
518 | const result = await client.listCollections();
519 | expect(result.collections).toBeInstanceOf(Array);
520 | expect(result.collections.length).toBeGreaterThan(0);
521 |
522 | const collection = result.collections[0];
523 | expect(collection).toHaveProperty("name");
524 | expect(collection).toHaveProperty("slug");
525 | expect(collection).toHaveProperty("models");
526 | });
527 |
528 | it("should get collection details", async () => {
529 | const collections = await client.listCollections();
530 | const testCollection = collections.collections[0];
531 |
532 | const collection = await client.getCollection(testCollection.slug);
533 | expect(collection.name).toBe(testCollection.name);
534 | expect(collection.slug).toBe(testCollection.slug);
535 | expect(collection.models).toBeInstanceOf(Array);
536 | });
537 | });
538 |
539 | describe("Template System Integration", () => {
540 | it("should generate valid parameters for models", async () => {
541 | // Get SDXL model for testing
542 | const model = await client.getModel("stability-ai", "sdxl");
543 | const schema = model.latest_version?.openapi_schema;
544 | expect(schema).toBeDefined();
545 |
546 | // Generate parameters
547 | const prompt = "a detailed portrait in anime style";
548 | const params = templateManager.generateParameters(prompt, {
549 | quality: "quality",
550 | style: "anime",
551 | size: "portrait",
552 | });
553 |
554 | // Validate against model schema
555 | const errors = [];
556 | const inputSchema = schema?.components?.schemas?.Input as SchemaObject;
557 | if (inputSchema) {
558 | // Check required fields
559 | if (inputSchema.required) {
560 | for (const field of inputSchema.required) {
561 | if (!(field in params)) {
562 | errors.push(`Missing required field: ${field}`);
563 | }
564 | }
565 | }
566 |
567 | // Validate property types
568 | if (inputSchema.properties) {
569 | for (const [field, prop] of Object.entries(
570 | inputSchema.properties as Record
571 | )) {
572 | if (field in params) {
573 | const value = params[field as keyof typeof params] as unknown;
574 | switch (prop.type) {
575 | case "number": {
576 | const numValue = value as number;
577 | if (typeof numValue !== "number") {
578 | errors.push(`${field} must be a number`);
579 | } else {
580 | if (prop.minimum !== undefined && numValue < prop.minimum) {
581 | errors.push(`${field} must be >= ${prop.minimum}`);
582 | }
583 | if (prop.maximum !== undefined && numValue > prop.maximum) {
584 | errors.push(`${field} must be <= ${prop.maximum}`);
585 | }
586 | }
587 | break;
588 | }
589 | case "string": {
590 | const strValue = value as string;
591 | if (typeof strValue !== "string") {
592 | errors.push(`${field} must be a string`);
593 | } else if (prop.enum && !prop.enum.includes(strValue)) {
594 | errors.push(
595 | `${field} must be one of: ${prop.enum.join(", ")}`
596 | );
597 | }
598 | break;
599 | }
600 | case "integer": {
601 | const intValue = value as number;
602 | if (!Number.isInteger(intValue)) {
603 | errors.push(`${field} must be an integer`);
604 | }
605 | break;
606 | }
607 | }
608 | }
609 | }
610 | }
611 | }
612 |
613 | expect(errors).toHaveLength(0);
614 | });
615 | });
616 | });
617 |
--------------------------------------------------------------------------------
/src/tests/protocol.test.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * Protocol compliance tests.
3 | */
4 |
5 | import {
6 | describe,
7 | it,
8 | expect,
9 | beforeEach,
10 | afterEach,
11 | vi,
12 | beforeAll,
13 | afterAll,
14 | } from "vitest";
15 | import { Server } from "@modelcontextprotocol/sdk/server/index.js";
16 | import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js";
17 | import { WebhookService } from "../services/webhook.js";
18 | import { ReplicateClient } from "../replicate_client.js";
19 | import { TemplateManager } from "../templates/manager.js";
20 | import { ErrorHandler, ReplicateError } from "../services/error.js";
21 | import { PredictionStatus } from "../models/prediction.js";
22 | import { createError } from "../services/error.js";
23 |
24 | // Set test environment
25 | process.env.NODE_ENV = "test";
26 | process.env.REPLICATE_API_TOKEN = "test_token";
27 |
28 | describe("Protocol Compliance", () => {
29 | let server: Server;
30 | let client: ReplicateClient;
31 | let webhookService: WebhookService;
32 | let templateManager: TemplateManager;
33 | let transport: StdioServerTransport;
34 |
  // Use fake timers for the whole suite so retry backoffs and webhook
  // delivery delays can be advanced deterministically.
  beforeAll(() => {
    // Enable fake timers
    vi.useFakeTimers();
  });
39 |
  // Build a fresh client/webhook/template stack and a connected MCP server
  // for every test so no state leaks between cases.
  beforeEach(async () => {
    // Initialize components
    client = new ReplicateClient();
    webhookService = new WebhookService();
    templateManager = new TemplateManager();
    transport = new StdioServerTransport();

    // Create server with test configuration
    server = new Server(
      {
        name: "replicate-test",
        version: "0.1.0",
      },
      {
        capabilities: {
          tools: {},
          prompts: {},
        },
      }
    );

    // Initialize server
    await server.connect(transport);
  });
64 |
  // Close the server (when it was created) and drop recorded mock calls.
  afterEach(async () => {
    // Clean up
    if (server) {
      await server.close();
    }
    vi.clearAllMocks();
  });
72 |
  // Undo the suite-wide fake timers from beforeAll.
  afterAll(() => {
    // Restore real timers
    vi.useRealTimers();
  });
77 |
78 | describe("Message Format", () => {
79 | it("should use JSON-RPC 2.0 format", async () => {
80 | const message = {
81 | jsonrpc: "2.0",
82 | method: "test",
83 | params: {},
84 | id: 1,
85 | };
86 |
87 | expect(message.jsonrpc).toBe("2.0");
88 | expect(message).toHaveProperty("method");
89 | expect(message).toHaveProperty("params");
90 | expect(message).toHaveProperty("id");
91 | });
92 |
93 | it("should handle notifications without id", async () => {
94 | const notification = {
95 | jsonrpc: "2.0",
96 | method: "test",
97 | params: {},
98 | };
99 |
100 | expect(notification.jsonrpc).toBe("2.0");
101 | expect(notification).toHaveProperty("method");
102 | expect(notification).toHaveProperty("params");
103 | expect(notification).not.toHaveProperty("id");
104 | });
105 | });
106 |
107 | describe("Error Handling", () => {
108 | it("should handle rate limit errors with retries", async () => {
109 | const rateLimitError = createError.rateLimit(1);
110 |
111 | // Mock client to throw rate limit error once then succeed
112 | let attempts = 0;
113 | vi.spyOn(client, "listModels").mockImplementation(async () => {
114 | if (attempts === 0) {
115 | attempts++;
116 | throw rateLimitError;
117 | }
118 | return { models: [] };
119 | });
120 |
121 | // Mock setTimeout to advance timers
122 | const setTimeoutSpy = vi.spyOn(global, "setTimeout");
123 |
124 | // Start the retry operation
125 | const resultPromise = ErrorHandler.withRetries(
126 | async () => client.listModels(),
127 | {
128 | maxAttempts: 2,
129 | minDelay: 100, // Use smaller delay for tests
130 | maxDelay: 200,
131 | retryIf: (error: Error) => error instanceof ReplicateError,
132 | }
133 | );
134 |
135 | // Wait for setTimeout to be called and advance timer
136 | await vi.waitFor(() => setTimeoutSpy.mock.calls.length > 0);
137 | await vi.runAllTimersAsync();
138 |
139 | // Wait for the result
140 | const result = await resultPromise;
141 |
142 | expect(attempts).toBe(1);
143 | expect(result).toEqual({ models: [] });
144 | });
145 |
146 | it("should generate detailed error reports", () => {
147 | const error = createError.validation("test", "Test error");
148 |
149 | const report = ErrorHandler.createErrorReport(error);
150 | expect(report).toMatchObject({
151 | name: "ReplicateError",
152 | message: "Invalid input parameters",
153 | context: {
154 | field: "test",
155 | message: "Test error",
156 | },
157 | timestamp: expect.any(String),
158 | });
159 | });
160 |
161 | it("should handle prediction status transitions", async () => {
162 | const prediction = {
163 | id: "test",
164 | status: PredictionStatus.Starting,
165 | version: "test-version",
166 | };
167 |
168 | const statusUpdates: PredictionStatus[] = [];
169 | const mockTransport = {
170 | notify: vi.fn().mockImplementation((notification) => {
171 | if (notification.method === "prediction/status") {
172 | statusUpdates.push(notification.params.status);
173 | }
174 | }),
175 | };
176 |
177 | // Simulate status transitions
178 | prediction.status = PredictionStatus.Processing;
179 | await mockTransport.notify({
180 | method: "prediction/status",
181 | params: { status: prediction.status },
182 | });
183 |
184 | prediction.status = PredictionStatus.Succeeded;
185 | await mockTransport.notify({
186 | method: "prediction/status",
187 | params: { status: prediction.status },
188 | });
189 |
190 | expect(statusUpdates).toEqual([
191 | PredictionStatus.Processing,
192 | PredictionStatus.Succeeded,
193 | ]);
194 | expect(mockTransport.notify).toHaveBeenCalledTimes(2);
195 | });
196 | });
197 |
198 | describe("Webhook Integration", () => {
199 | beforeEach(() => {
200 | // Mock fetch for webhook delivery
201 | vi.spyOn(global, "fetch").mockImplementation(async () => {
202 | return Promise.resolve({
203 | ok: true,
204 | status: 200,
205 | statusText: "OK",
206 | } as Response);
207 | });
208 | });
209 |
210 | it("should validate webhook configuration", () => {
211 | const validConfig = {
212 | url: "https://example.com/webhook",
213 | secret: "1234567890abcdef1234567890abcdef",
214 | retries: 3,
215 | timeout: 5000,
216 | };
217 |
218 | const invalidConfig = {
219 | url: "not-a-url",
220 | secret: "too-short",
221 | retries: -1,
222 | timeout: 500,
223 | };
224 |
225 | expect(webhookService.validateWebhookConfig(validConfig)).toHaveLength(0);
226 | expect(webhookService.validateWebhookConfig(invalidConfig)).toHaveLength(
227 | 4
228 | );
229 | });
230 |
231 | it("should handle webhook delivery with retries", async () => {
232 | const deliverySpy = vi.spyOn(
233 | webhookService,
234 | "deliverWebhook" as keyof WebhookService
235 | );
236 | const fetchSpy = vi.spyOn(global, "fetch");
237 |
238 | // First call fails, second succeeds
239 | fetchSpy
240 | .mockRejectedValueOnce(new Error("Delivery failed"))
241 | .mockResolvedValueOnce({
242 | ok: true,
243 | status: 200,
244 | statusText: "OK",
245 | } as Response);
246 |
247 | const webhookId = await webhookService.queueWebhook(
248 | {
249 | url: "https://example.com/webhook",
250 | retries: 1,
251 | },
252 | {
253 | type: "prediction.created",
254 | timestamp: new Date().toISOString(),
255 | data: { id: "123" },
256 | }
257 | );
258 |
259 | // Wait for delivery attempts
260 | await vi.advanceTimersByTimeAsync(100); // First attempt
261 | await vi.advanceTimersByTimeAsync(200); // Retry attempt
262 | await vi.advanceTimersByTimeAsync(100); // Processing time
263 |
264 | const results = webhookService.getDeliveryResults(webhookId);
265 | expect(results).toHaveLength(2);
266 | expect(results[0].success).toBe(false);
267 | expect(results[1].success).toBe(true);
268 | expect(deliverySpy).toHaveBeenCalledTimes(2);
269 | });
270 | });
271 |
272 | describe("Template System", () => {
273 | it("should generate parameters from templates", () => {
274 | const prompt = "a photo of a mountain landscape at sunset";
275 | const params = templateManager.generateParameters(prompt, {
276 | quality: "quality",
277 | style: "photorealistic",
278 | size: "landscape",
279 | });
280 |
281 | expect(params).toHaveProperty("prompt");
282 | expect(params).toHaveProperty("negative_prompt");
283 | expect(params).toHaveProperty("width");
284 | expect(params).toHaveProperty("height");
285 | expect(params).toHaveProperty("num_inference_steps");
286 | expect(params).toHaveProperty("guidance_scale");
287 | });
288 |
289 | it("should suggest parameters based on prompt", () => {
290 | const prompt = "a detailed anime-style portrait";
291 | const suggestions = templateManager.suggestParameters(prompt);
292 |
293 | expect(suggestions.quality).toBe("quality");
294 | expect(suggestions.style).toBe("anime");
295 | expect(suggestions.size).toBe("portrait");
296 | });
297 |
298 | it("should validate template parameters", () => {
299 | const validTemplate = templateManager.generateParameters("test prompt", {
300 | quality: "quality",
301 | style: "photorealistic",
302 | size: "landscape",
303 | });
304 |
305 | const invalidTemplate = {
306 | ...validTemplate,
307 | width: -100, // Invalid width
308 | };
309 |
310 | expect(() =>
311 | templateManager.validateParameters(validTemplate)
312 | ).not.toThrow();
313 | expect(() =>
314 | templateManager.validateParameters(invalidTemplate)
315 | ).toThrow();
316 | });
317 | });
318 | });
319 |
--------------------------------------------------------------------------------
/src/tools/handlers.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * Handlers for MCP tools.
3 | */
4 |
5 | import type { ReplicateClient } from "../replicate_client.js";
6 | import type { Model } from "../models/model.js";
7 | import type { ModelIO, Prediction } from "../models/prediction.js";
8 | import { PredictionStatus } from "../models/prediction.js";
9 | import type { Collection } from "../models/collection.js";
10 |
11 | /**
12 | * Cache for models, predictions, and collections.
13 | */
14 | interface Cache {
15 | modelCache: Map;
16 | predictionCache: Map;
17 | collectionCache: Map;
18 | predictionStatus: Map;
19 | }
20 |
/**
 * Get error message from unknown error.
 *
 * Returns the Error's own message, a fixed notice for Promise values
 * (a thrown/rejected Promise carries no printable message), or the
 * stringified value for anything else.
 */
function getErrorMessage(error: unknown): string {
  if (error instanceof Error) {
    return error.message;
  }
  if (error instanceof Promise) {
    return "An asynchronous error occurred. Please try again.";
  }
  return String(error);
}
33 |
34 | /**
35 | * Handle the search_models tool.
36 | */
37 | export async function handleSearchModels(
38 | client: ReplicateClient,
39 | cache: Cache,
40 | args: { query: string }
41 | ) {
42 | try {
43 | const result = await client.searchModels(args.query);
44 |
45 | // Update cache
46 | for (const model of result.models) {
47 | cache.modelCache.set(`${model.owner}/${model.name}`, model);
48 | }
49 |
50 | return {
51 | content: [
52 | {
53 | type: "text",
54 | text: `Found ${result.models.length} models matching "${args.query}":`,
55 | },
56 | ...result.models.map((model) => ({
57 | type: "text" as const,
58 | text: `- ${model.owner}/${model.name}: ${
59 | model.description || "No description"
60 | }`,
61 | })),
62 | ],
63 | };
64 | } catch (error) {
65 | return {
66 | isError: true,
67 | content: [
68 | {
69 | type: "text",
70 | text: `Error searching models: ${getErrorMessage(error)}`,
71 | },
72 | ],
73 | };
74 | }
75 | }
76 |
77 | /**
78 | * Handle the list_models tool.
79 | */
80 | export async function handleListModels(
81 | client: ReplicateClient,
82 | cache: Cache,
83 | args: { owner?: string; cursor?: string }
84 | ) {
85 | try {
86 | const result = await client.listModels(args);
87 |
88 | // Update cache
89 | for (const model of result.models) {
90 | cache.modelCache.set(`${model.owner}/${model.name}`, model);
91 | }
92 |
93 | return {
94 | content: [
95 | {
96 | type: "text",
97 | text: args.owner ? `Models by ${args.owner}:` : "Available models:",
98 | },
99 | ...result.models.map((model) => ({
100 | type: "text" as const,
101 | text: `- ${model.owner}/${model.name}: ${
102 | model.description || "No description"
103 | }`,
104 | })),
105 | result.next_cursor
106 | ? {
107 | type: "text" as const,
108 | text: `\nUse cursor "${result.next_cursor}" to see more results.`,
109 | }
110 | : null,
111 | ].filter(Boolean),
112 | };
113 | } catch (error) {
114 | return {
115 | isError: true,
116 | content: [
117 | {
118 | type: "text",
119 | text: `Error listing models: ${getErrorMessage(error)}`,
120 | },
121 | ],
122 | };
123 | }
124 | }
125 |
126 | /**
127 | * Handle the list_collections tool.
128 | */
129 | export async function handleListCollections(
130 | client: ReplicateClient,
131 | cache: Cache,
132 | args: { cursor?: string }
133 | ) {
134 | try {
135 | const result = await client.listCollections(args);
136 |
137 | // Update cache
138 | for (const collection of result.collections) {
139 | cache.collectionCache.set(collection.slug, collection);
140 | }
141 |
142 | return {
143 | content: [
144 | {
145 | type: "text",
146 | text: "Available collections:",
147 | },
148 | ...result.collections.map((collection: Collection) => ({
149 | type: "text" as const,
150 | text: `- ${collection.name} (slug: ${collection.slug}): ${
151 | collection.description ||
152 | `A collection of ${collection.models.length} models`
153 | }`,
154 | })),
155 | result.next_cursor
156 | ? {
157 | type: "text" as const,
158 | text: `\nUse cursor "${result.next_cursor}" to see more results.`,
159 | }
160 | : null,
161 | ].filter(Boolean),
162 | };
163 | } catch (error) {
164 | return {
165 | isError: true,
166 | content: [
167 | {
168 | type: "text",
169 | text: `Error listing collections: ${getErrorMessage(error)}`,
170 | },
171 | ],
172 | };
173 | }
174 | }
175 |
176 | /**
177 | * Handle the get_collection tool.
178 | */
179 | export async function handleGetCollection(
180 | client: ReplicateClient,
181 | cache: Cache,
182 | args: { slug: string }
183 | ) {
184 | try {
185 | const collection = await client.getCollection(args.slug);
186 |
187 | // Update cache
188 | cache.collectionCache.set(collection.slug, collection);
189 |
190 | return {
191 | content: [
192 | {
193 | type: "text",
194 | text: `Collection: ${collection.name}`,
195 | },
196 | collection.description
197 | ? {
198 | type: "text" as const,
199 | text: collection.description,
200 | }
201 | : null,
202 | {
203 | type: "text",
204 | text: "\nModels in this collection:",
205 | },
206 | ...collection.models.map((model: Model) => ({
207 | type: "text" as const,
208 | text: `- ${model.owner}/${model.name}: ${
209 | model.description || "No description"
210 | }`,
211 | })),
212 | ].filter(Boolean),
213 | };
214 | } catch (error) {
215 | return {
216 | isError: true,
217 | content: [
218 | {
219 | type: "text",
220 | text: `Error getting collection: ${getErrorMessage(error)}`,
221 | },
222 | ],
223 | };
224 | }
225 | }
226 |
227 | /**
228 | * Handle the create_prediction tool.
229 | */
230 | export async function handleCreatePrediction(
231 | client: ReplicateClient,
232 | cache: Cache,
233 | args: {
234 | version: string | undefined;
235 | model: string | undefined;
236 | input: ModelIO | string;
237 | webhook?: string;
238 | }
239 | ) {
240 | try {
241 | // If input is a string, wrap it in an object with 'prompt' property
242 | const input =
243 | typeof args.input === "string" ? { prompt: args.input } : args.input;
244 |
245 | const prediction = await client.createPrediction({
246 | ...args,
247 | input,
248 | });
249 |
250 | // Cache the prediction and its initial status
251 | cache.predictionCache.set(prediction.id, prediction);
252 | cache.predictionStatus.set(
253 | prediction.id,
254 | prediction.status as PredictionStatus
255 | );
256 |
257 | return {
258 | content: [
259 | {
260 | type: "text",
261 | text: `Created prediction ${prediction.id}`,
262 | },
263 | ],
264 | };
265 | } catch (error) {
266 | return {
267 | isError: true,
268 | content: [
269 | {
270 | type: "text",
271 | text: `Error creating prediction: ${getErrorMessage(error)}`,
272 | },
273 | ],
274 | };
275 | }
276 | }
277 |
278 | /**
279 | * Creates a prediction and handles the full lifecycle:
280 | * sends the request, polls until completion, and returns the final result URL.
281 | */
282 | export async function handleCreateAndPollPrediction(
283 | client: ReplicateClient,
284 | cache: Cache,
285 | args: {
286 | version: string | undefined;
287 | model: string | undefined;
288 | input: ModelIO | string;
289 | webhook?: string;
290 | pollInterval?: number;
291 | timeout?: number;
292 | }
293 | ) {
294 | // If input is a string, wrap it in an object with 'prompt' property
295 | const input =
296 | typeof args.input === "string" ? { prompt: args.input } : args.input;
297 |
298 | let prediction;
299 | try {
300 | prediction = await client.createPrediction({
301 | ...args,
302 | input,
303 | });
304 | } catch (error) {
305 | return {
306 | isError: true,
307 | content: [
308 | {
309 | type: "text",
310 | text: `Error creating prediction: ${getErrorMessage(error)}`,
311 | },
312 | ],
313 | };
314 | }
315 |
316 | // Cache the prediction and its initial status
317 | cache.predictionCache.set(prediction.id, prediction);
318 | cache.predictionStatus.set(
319 | prediction.id,
320 | prediction.status as PredictionStatus
321 | );
322 |
323 | const shouldContinuePolling = (
324 | prediction: Prediction | null,
325 | timeoutAt: number
326 | ) => {
327 | if (!prediction) return true;
328 | if (performance.now() > timeoutAt) return false;
329 |
330 | if (
331 | prediction.status === "succeeded" ||
332 | prediction.status === "failed" ||
333 | prediction.status === "canceled"
334 | ) {
335 | return false;
336 | }
337 | return true;
338 | };
339 |
340 | const { pollInterval = 1, timeout = 60 } = args;
341 | const predictionId = prediction.id;
342 | let timeoutAt = performance.now() + timeout * 1000;
343 |
344 | do {
345 | await new Promise((resolve) => setTimeout(resolve, pollInterval * 1000));
346 |
347 | try {
348 | prediction = await client.getPredictionStatus(predictionId);
349 | } catch (error) {
350 | console.error(error);
351 | }
352 | } while (shouldContinuePolling(prediction, timeoutAt));
353 |
354 | if (timeoutAt < performance.now()) {
355 | console.warn(
356 | `Timeout reached while polling prediction by id: ${predictionId}`
357 | );
358 | return {
359 | isError: true,
360 | content: [
361 | {
362 | type: "text",
363 | text: `Timeout reached while polling prediction by id: ${predictionId}`,
364 | },
365 | ],
366 | };
367 | }
368 |
369 | if (prediction.status === "canceled" || prediction.status === "failed") {
370 | console.warn(
371 | `Prediction with id: ${predictionId} ${prediction.status} with error ${prediction.error}`
372 | );
373 | return {
374 | isError: true,
375 | content: [
376 | {
377 | type: "text",
378 | text: `Prediction with id: ${predictionId} ${prediction.status} with error ${prediction.error}`,
379 | },
380 | ],
381 | };
382 | }
383 |
384 | return {
385 | content: [
386 | {
387 | type: "text",
388 | text: `Created prediction ${prediction.id}, Output: ${prediction.output}`,
389 | },
390 | ],
391 | };
392 | }
393 |
394 | /**
395 | * Handle the cancel_prediction tool.
396 | */
397 | export async function handleCancelPrediction(
398 | client: ReplicateClient,
399 | cache: Cache,
400 | args: { prediction_id: string }
401 | ) {
402 | try {
403 | const prediction = await client.cancelPrediction(args.prediction_id);
404 | // Update cache
405 | cache.predictionCache.set(prediction.id, prediction);
406 | cache.predictionStatus.set(
407 | prediction.id,
408 | prediction.status as PredictionStatus
409 | );
410 |
411 | return {
412 | content: [
413 | {
414 | type: "text",
415 | text: `Cancelled prediction ${prediction.id}`,
416 | },
417 | ],
418 | };
419 | } catch (error) {
420 | return {
421 | isError: true,
422 | content: [
423 | {
424 | type: "text",
425 | text: `Error cancelling prediction: ${getErrorMessage(error)}`,
426 | },
427 | ],
428 | };
429 | }
430 | }
431 |
432 | /**
433 | * Handle the get_model tool.
434 | */
435 | export async function handleGetModel(
436 | client: ReplicateClient,
437 | cache: Cache,
438 | args: { owner: string; name: string }
439 | ) {
440 | try {
441 | const model = await client.getModel(args.owner, args.name);
442 |
443 | // Update cache
444 | cache.modelCache.set(`${model.owner}/${model.name}`, model);
445 |
446 | return {
447 | content: [
448 | {
449 | type: "text",
450 | text: `Model: ${model.owner}/${model.name}`,
451 | },
452 | model.description
453 | ? {
454 | type: "text" as const,
455 | text: `\nDescription: ${model.description}`,
456 | }
457 | : null,
458 | {
459 | type: "text",
460 | text: "\nLatest version:",
461 | },
462 | model.latest_version
463 | ? {
464 | type: "text" as const,
465 | text: `ID: ${model.latest_version.id}\nCreated: ${model.latest_version.created_at}`,
466 | }
467 | : {
468 | type: "text" as const,
469 | text: "No versions available",
470 | },
471 | ].filter(Boolean),
472 | };
473 | } catch (error) {
474 | return {
475 | isError: true,
476 | content: [
477 | {
478 | type: "text",
479 | text: `Error getting model: ${getErrorMessage(error)}`,
480 | },
481 | ],
482 | };
483 | }
484 | }
485 |
486 | /**
487 | * Handle the get_prediction tool.
488 | */
489 | export async function handleGetPrediction(
490 | client: ReplicateClient,
491 | cache: Cache,
492 | args: { prediction_id: string }
493 | ) {
494 | try {
495 | const prediction = await client.getPredictionStatus(args.prediction_id);
496 |
497 | const previousStatus = cache.predictionStatus.get(prediction.id);
498 |
499 | // Update cache
500 | cache.predictionCache.set(prediction.id, prediction);
501 | cache.predictionStatus.set(
502 | prediction.id,
503 | prediction.status as PredictionStatus
504 | );
505 |
506 | return {
507 | content: [
508 | {
509 | type: "text",
510 | text: `Prediction ${prediction.id}:`,
511 | },
512 | {
513 | type: "text",
514 | text: `Status: ${prediction.status}\nModel version: ${prediction.version}\nCreated: ${prediction.created_at}`,
515 | },
516 | prediction.input
517 | ? {
518 | type: "text" as const,
519 | text: `\nInput:\n${JSON.stringify(prediction.input, null, 2)}`,
520 | }
521 | : null,
522 | prediction.output
523 | ? {
524 | type: "text" as const,
525 | text: `\nOutput:\n${JSON.stringify(prediction.output, null, 2)}`,
526 | }
527 | : null,
528 | prediction.error
529 | ? {
530 | type: "text" as const,
531 | text: `\nError: ${prediction.error}`,
532 | }
533 | : null,
534 | prediction.logs
535 | ? {
536 | type: "text" as const,
537 | text: `\nLogs:\n${prediction.logs}`,
538 | }
539 | : null,
540 | ].filter(Boolean),
541 | };
542 | } catch (error) {
543 | return {
544 | isError: true,
545 | content: [
546 | {
547 | type: "text",
548 | text: `Error getting prediction: ${getErrorMessage(error)}`,
549 | },
550 | ],
551 | };
552 | }
553 | }
554 |
555 | /**
556 | * Handle the list_predictions tool.
557 | */
558 | /**
559 | * Estimate prediction progress based on logs and status.
560 | */
561 | function estimateProgress(prediction: Prediction): number {
562 | if (prediction.status === PredictionStatus.Succeeded) return 100;
563 | if (
564 | prediction.status === PredictionStatus.Failed ||
565 | prediction.status === PredictionStatus.Canceled
566 | )
567 | return 0;
568 | if (prediction.status === PredictionStatus.Starting) return 0;
569 |
570 | // Try to parse progress from logs
571 | if (prediction.logs) {
572 | const match = prediction.logs.match(/progress: (\d+)%/);
573 | if (match) {
574 | return Number.parseInt(match[1], 10);
575 | }
576 | }
577 |
578 | // Default to 50% if processing but no specific progress info
579 | return prediction.status === PredictionStatus.Processing ? 50 : 0;
580 | }
581 |
582 | export async function handleListPredictions(
583 | client: ReplicateClient,
584 | cache: Cache,
585 | args: { limit?: number; cursor?: string }
586 | ) {
587 | try {
588 | const predictions = await client.listPredictions({
589 | limit: args.limit || 10,
590 | });
591 |
592 | // Update cache and status tracking
593 | for (const prediction of predictions) {
594 | const previousStatus = cache.predictionStatus.get(prediction.id);
595 | cache.predictionCache.set(prediction.id, prediction);
596 | cache.predictionStatus.set(
597 | prediction.id,
598 | prediction.status as PredictionStatus
599 | );
600 | }
601 |
602 | // Format predictions as text
603 | const predictionTexts = predictions.map((prediction) => {
604 | const status = prediction.status.toUpperCase();
605 | const model = prediction.version;
606 | const time = new Date(prediction.created_at).toLocaleString();
607 | return `- ID: ${prediction.id}\n Status: ${status}\n Model: ${model}\n Created: ${time}`;
608 | });
609 |
610 | return {
611 | content: [
612 | {
613 | type: "text",
614 | text: `Found ${predictions.length} predictions:`,
615 | },
616 | {
617 | type: "text",
618 | text: predictionTexts.join("\n\n"),
619 | },
620 | ],
621 | };
622 | } catch (error) {
623 | return {
624 | isError: true,
625 | content: [
626 | {
627 | type: "text",
628 | text: `Error listing predictions: ${getErrorMessage(error)}`,
629 | },
630 | ],
631 | };
632 | }
633 | }
634 |
--------------------------------------------------------------------------------
/src/tools/image_viewer.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * MCP tools for image viewing functionality.
3 | */
4 |
5 | import { ImageViewer } from "../services/image_viewer.js";
6 | import type { Tool } from "@modelcontextprotocol/sdk/types.js";
7 | import type { MCPRequest } from "../types/mcp.js";
8 |
/**
 * Tool for displaying images in the system's default web browser
 *
 * NOTE(review): unlike the model/prediction tool definitions in this package,
 * these three tools carry a `handler` property — confirm the SDK `Tool` type
 * accepts it and that the server actually dispatches through it.
 */
export const viewImageTool: Tool = {
  name: "view_image",
  description: "Display an image in the system's default web browser",
  inputSchema: {
    type: "object",
    properties: {
      url: {
        type: "string",
        description: "URL of the image to display",
      },
    },
    required: ["url"],
  },
  handler: handleViewImage,
};

/**
 * Tool for clearing the image cache
 */
export const clearImageCacheTool: Tool = {
  name: "clear_image_cache",
  description: "Clear the image viewer cache",
  // Takes no arguments.
  inputSchema: {
    type: "object",
    properties: {},
  },
  handler: handleClearImageCache,
};

/**
 * Tool for getting image cache statistics
 */
export const getImageCacheStatsTool: Tool = {
  name: "get_image_cache_stats",
  description: "Get statistics about the image cache",
  // Takes no arguments.
  inputSchema: {
    type: "object",
    properties: {},
  },
  handler: handleGetImageCacheStats,
};
53 |
54 | /**
55 | * Display an image in the system's default web browser
56 | */
57 | export async function handleViewImage(request: any) {
58 | const url = request.params.arguments?.url;
59 |
60 | if (typeof url !== "string") {
61 | return {
62 | isError: true,
63 | content: [
64 | {
65 | type: "text",
66 | text: "URL parameter is required",
67 | },
68 | ],
69 | };
70 | }
71 |
72 | try {
73 | const viewer = ImageViewer.getInstance();
74 | await viewer.displayImage(url);
75 |
76 | return {
77 | content: [
78 | {
79 | type: "text",
80 | text: "Image displayed successfully",
81 | },
82 | ],
83 | };
84 | } catch (error) {
85 | return {
86 | isError: true,
87 | content: [
88 | {
89 | type: "text",
90 | text: `Failed to display image: ${
91 | error instanceof Error ? error.message : String(error)
92 | }`,
93 | },
94 | ],
95 | };
96 | }
97 | }
98 |
99 | /**
100 | * Clear the image viewer cache
101 | */
102 | export async function handleClearImageCache(request: any) {
103 | try {
104 | const viewer = ImageViewer.getInstance();
105 | viewer.clearCache();
106 |
107 | return {
108 | content: [
109 | {
110 | type: "text",
111 | text: "Image cache cleared successfully",
112 | },
113 | ],
114 | };
115 | } catch (error) {
116 | return {
117 | isError: true,
118 | content: [
119 | {
120 | type: "text",
121 | text: `Failed to clear cache: ${
122 | error instanceof Error ? error.message : String(error)
123 | }`,
124 | },
125 | ],
126 | };
127 | }
128 | }
129 |
130 | /**
131 | * Get image viewer cache statistics
132 | */
133 | export async function handleGetImageCacheStats(request: any) {
134 | try {
135 | const viewer = ImageViewer.getInstance();
136 | const stats = viewer.getCacheStats();
137 |
138 | return {
139 | content: [
140 | {
141 | type: "text",
142 | text: "Cache statistics:",
143 | },
144 | {
145 | type: "text",
146 | text: JSON.stringify(stats, null, 2),
147 | },
148 | ],
149 | };
150 | } catch (error) {
151 | return {
152 | isError: true,
153 | content: [
154 | {
155 | type: "text",
156 | text: `Failed to get cache stats: ${
157 | error instanceof Error ? error.message : String(error)
158 | }`,
159 | },
160 | ],
161 | };
162 | }
163 | }
164 |
--------------------------------------------------------------------------------
/src/tools/index.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * MCP tools for interacting with Replicate.
3 | */
4 |
5 | export * from "./models.js";
6 | export * from "./predictions.js";
7 | export * from "./image_viewer.js";
8 |
9 | import type { Tool } from "@modelcontextprotocol/sdk/types.js";
10 | import {
11 | searchModelsTool,
12 | listModelsTool,
13 | listCollectionsTool,
14 | getCollectionTool,
15 | getModelTool,
16 | } from "./models.js";
17 | import {
18 | createPredictionTool,
19 | createAndPollPredictionTool,
20 | cancelPredictionTool,
21 | getPredictionTool,
22 | listPredictionsTool,
23 | } from "./predictions.js";
24 | import {
25 | viewImageTool,
26 | clearImageCacheTool,
27 | getImageCacheStatsTool,
28 | } from "./image_viewer.js";
29 |
/**
 * All available tools.
 *
 * Order here determines the order tools are advertised in; it is cosmetic
 * but kept stable.
 */
export const tools: Tool[] = [
  // Model discovery tools
  searchModelsTool,
  listModelsTool,
  listCollectionsTool,
  getCollectionTool,
  // Prediction lifecycle tools
  createPredictionTool,
  createAndPollPredictionTool,
  cancelPredictionTool,
  getPredictionTool,
  listPredictionsTool,
  // Model detail tool (listed out of group order)
  getModelTool,
  // Image viewer tools
  viewImageTool,
  clearImageCacheTool,
  getImageCacheStatsTool,
];
49 |
--------------------------------------------------------------------------------
/src/tools/models.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * Tools for interacting with Replicate models.
3 | */
4 |
5 | import type { Tool } from "@modelcontextprotocol/sdk/types.js";
6 | import type { ReplicateClient } from "../replicate_client.js";
7 | import type { Model } from "../models/model.js";
8 |
/**
 * Tool for searching models using semantic search.
 *
 * NOTE(review): these model tools define only the input schema — unlike the
 * image-viewer tools they carry no `handler` property, so dispatch presumably
 * happens elsewhere (see handlers.ts); verify in the server setup.
 */
export const searchModelsTool: Tool = {
  name: "search_models",
  description: "Search for models using semantic search",
  inputSchema: {
    type: "object",
    properties: {
      query: {
        type: "string",
        description: "Search query",
      },
    },
    required: ["query"],
  },
};

/**
 * Tool for listing available models.
 */
export const listModelsTool: Tool = {
  name: "list_models",
  description: "List available models with optional filtering",
  inputSchema: {
    type: "object",
    properties: {
      owner: {
        type: "string",
        description: "Filter by model owner",
      },
      cursor: {
        type: "string",
        description: "Pagination cursor",
      },
    },
  },
};

/**
 * Tool for listing model collections.
 */
export const listCollectionsTool: Tool = {
  name: "list_collections",
  description: "List available model collections",
  inputSchema: {
    type: "object",
    properties: {
      cursor: {
        type: "string",
        description: "Pagination cursor",
      },
    },
  },
};

/**
 * Tool for getting collection details.
 */
export const getCollectionTool: Tool = {
  name: "get_collection",
  description: "Get details of a specific collection",
  inputSchema: {
    type: "object",
    properties: {
      slug: {
        type: "string",
        description: "Collection slug",
      },
    },
    required: ["slug"],
  },
};

/**
 * Tool for getting model details including versions.
 */
export const getModelTool: Tool = {
  name: "get_model",
  description: "Get details of a specific model including available versions",
  inputSchema: {
    type: "object",
    properties: {
      owner: {
        type: "string",
        description: "Model owner",
      },
      name: {
        type: "string",
        description: "Model name",
      },
    },
    required: ["owner", "name"],
  },
};
104 |
--------------------------------------------------------------------------------
/src/tools/predictions.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * Tools for managing Replicate predictions.
3 | */
4 |
5 | import type { Tool } from "@modelcontextprotocol/sdk/types.js";
6 | import type { ReplicateClient } from "../replicate_client.js";
7 | import type { Prediction } from "../models/prediction.js";
8 |
/**
 * Tool for creating new predictions.
 */
export const createPredictionTool: Tool = {
  name: "create_prediction",
  description:
    "Create a new prediction using either a model version (for community models) or model name (for official models)",
  inputSchema: {
    type: "object",
    properties: {
      version: {
        type: "string",
        description: "Model version ID to use (for community models)",
      },
      model: {
        type: "string",
        description: "Model name to use (for official models)",
      },
      input: {
        type: "object",
        description: "Input parameters for the model",
        additionalProperties: true,
      },
      // NOTE(review): the handler arg types in handlers.ts expect `webhook`,
      // not `webhook_url` — confirm this property is mapped before the API
      // call, otherwise webhook notifications are silently dropped.
      webhook_url: {
        type: "string",
        description: "Optional webhook URL for notifications",
      },
    },
    // Exactly one of `version` (community) or `model` (official) is required.
    oneOf: [
      { required: ["version", "input"] },
      { required: ["model", "input"] },
    ],
  },
};

/**
 * Tool for creating a new prediction, waiting for it to complete,
 * and returning the final output URL.
 */
export const createAndPollPredictionTool: Tool = {
  name: "create_and_poll_prediction",
  description:
    "Create a new prediction and wait until it's completed. Accepts either a model version (for community models) or a model name (for official models)",
  inputSchema: {
    type: "object",
    properties: {
      version: {
        type: "string",
        description: "Model version ID to use (for community models)",
      },
      model: {
        type: "string",
        description: "Model name to use (for official models)",
      },
      input: {
        type: "object",
        description: "Input parameters for the model",
        additionalProperties: true,
      },
      // NOTE(review): handler arg types use `webhook` — see note on
      // create_prediction above.
      webhook_url: {
        type: "string",
        description: "Optional webhook URL for notifications",
      },
      // NOTE(review): handler expects camelCase `pollInterval` (seconds);
      // confirm snake_case arguments are converted before dispatch.
      poll_interval: {
        type: "number",
        description: "Optional interval between polls (default: 1)",
      },
      // Seconds; matches the handler's `timeout` default of 60.
      timeout: {
        type: "number",
        description: "Optional timeout for prediction (default: 60)",
      },
    },
    // Exactly one of `version` (community) or `model` (official) is required.
    oneOf: [
      { required: ["version", "input"] },
      { required: ["model", "input"] },
    ],
  },
};

/**
 * Tool for canceling predictions.
 */
export const cancelPredictionTool: Tool = {
  name: "cancel_prediction",
  description: "Cancel a running prediction",
  inputSchema: {
    type: "object",
    properties: {
      prediction_id: {
        type: "string",
        description: "ID of the prediction to cancel",
      },
    },
    required: ["prediction_id"],
  },
};

/**
 * Tool for getting prediction details.
 */
export const getPredictionTool: Tool = {
  name: "get_prediction",
  description: "Get details about a specific prediction",
  inputSchema: {
    type: "object",
    properties: {
      prediction_id: {
        type: "string",
        description: "ID of the prediction to get details for",
      },
    },
    required: ["prediction_id"],
  },
};

/**
 * Tool for listing recent predictions.
 */
export const listPredictionsTool: Tool = {
  name: "list_predictions",
  description: "List recent predictions",
  inputSchema: {
    type: "object",
    properties: {
      limit: {
        type: "number",
        description: "Maximum number of predictions to return",
        default: 10,
      },
      // NOTE(review): cursor is advertised here but the current handler does
      // not forward it to the client — verify pagination end to end.
      cursor: {
        type: "string",
        description: "Cursor for pagination",
      },
    },
  },
};
145 |
--------------------------------------------------------------------------------
/src/types/mcp.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * Type definitions for MCP protocol.
3 | */
4 |
5 | import { EventEmitter } from "node:events";
6 |
7 | export interface MCPMessage {
8 | jsonrpc: "2.0";
9 | id?: string | number;
10 | method?: string;
11 | params?: Record;
12 | }
13 |
14 | export interface MCPRequest extends MCPMessage {
15 | method: string;
16 | params: Record;
17 | }
18 |
19 | export interface MCPResponse extends MCPMessage {
20 | result?: unknown;
21 | error?: {
22 | code: number;
23 | message: string;
24 | data?: unknown;
25 | };
26 | }
27 |
28 | export interface MCPNotification extends MCPMessage {
29 | method: string;
30 | params: Record;
31 | }
32 |
33 | export interface MCPResource {
34 | uri: string;
35 | mimeType?: string;
36 | text?: string;
37 | }
38 |
39 | export abstract class BaseTransport extends EventEmitter {
40 | abstract connect(): Promise;
41 | abstract disconnect(): Promise;
42 | abstract send(message: MCPMessage): Promise;
43 | }
44 |
--------------------------------------------------------------------------------
/tsconfig.json:
--------------------------------------------------------------------------------
1 | {
2 | "compilerOptions": {
3 | "target": "ES2022",
4 | "module": "NodeNext",
5 | "moduleResolution": "NodeNext",
6 | "outDir": "./build",
7 | "rootDir": "./src",
8 | "strict": true,
9 | "esModuleInterop": true,
10 | "skipLibCheck": true,
11 | "forceConsistentCasingInFileNames": true,
12 | "declaration": true,
13 | "sourceMap": true
14 | },
15 | "include": ["src/**/*"],
16 | "exclude": ["node_modules", "dist"]
17 | }
18 |
--------------------------------------------------------------------------------