├── .python-version ├── .env.example ├── src └── mcp_server_replicate │ ├── __init__.py │ ├── templates │ ├── parameters │ │ ├── __init__.py │ │ ├── llama.py │ │ ├── controlnet.py │ │ ├── common_configs.py │ │ ├── stable_diffusion.py │ │ └── prompt_templates.py │ └── prompts │ │ ├── text_to_image.py │ │ ├── text_generation.py │ │ ├── image_to_image.py │ │ └── controlnet.py │ ├── models │ ├── hardware.py │ ├── collection.py │ ├── webhook.py │ ├── prediction.py │ └── model.py │ ├── tools │ ├── webhook_tools.py │ ├── hardware_tools.py │ ├── collection_tools.py │ ├── prediction_tools.py │ └── model_tools.py │ ├── __main__.py │ ├── replicate_client.py │ └── server.py ├── .gitignore ├── smithery.yaml ├── test_client.py ├── Dockerfile ├── pyproject.toml ├── tests ├── conftest.py ├── utils │ └── mock_client.py ├── unit │ └── test_parameters │ │ └── test_controlnet.py └── test_server.py ├── docs ├── workflows.md └── templates.md ├── CONTRIBUTING.md ├── PLAN.md ├── .cursorrules └── README.md /.python-version: -------------------------------------------------------------------------------- 1 | 3.11 2 | -------------------------------------------------------------------------------- /.env.example: -------------------------------------------------------------------------------- 1 | REPLICATE_API_TOKEN=your_token_here 2 | -------------------------------------------------------------------------------- /src/mcp_server_replicate/__init__.py: -------------------------------------------------------------------------------- 1 | """MCP Server implementation for the Replicate API.""" 2 | 3 | from .__main__ import main 4 | 5 | __all__ = ["main"] 6 | -------------------------------------------------------------------------------- /src/mcp_server_replicate/templates/parameters/__init__.py: -------------------------------------------------------------------------------- 1 | """Parameter templates for Replicate models.""" 2 | 3 | from .common_configs import TEMPLATES as COMMON_TEMPLATES 4 | from .stable_diffusion import TEMPLATES as SD_TEMPLATES 5 | from .controlnet import TEMPLATES as CONTROLNET_TEMPLATES 6 | 7 | # Merge all templates 8 | TEMPLATES = { 9 | **COMMON_TEMPLATES, 10 | **SD_TEMPLATES, 11 | **CONTROLNET_TEMPLATES, 12 | } 13 | 14 | __all__ = ["TEMPLATES"] -------------------------------------------------------------------------------- /src/mcp_server_replicate/models/hardware.py: -------------------------------------------------------------------------------- 1 | """Data models for Replicate hardware options.""" 2 | 3 | from typing import List 4 | 5 | from pydantic import BaseModel, Field 6 | 7 | class Hardware(BaseModel): 8 | """A hardware option for running models on Replicate.""" 9 | name: str = Field(..., description="Human-readable name of the hardware") 10 | sku: str = Field(..., description="SKU identifier for the hardware") 11 | 12 | class HardwareList(BaseModel): 13 | """Response format for listing hardware options.""" 14 | hardware: List[Hardware] -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.so 6 | .Python 7 | build/ 8 | develop-eggs/ 9 | dist/ 10 | downloads/ 11 | eggs/ 12 | .eggs/ 13 | lib/ 14 | lib64/ 15 | parts/ 16 | sdist/ 17 | var/ 18 | wheels/ 19 | *.egg-info/ 20 | .installed.cfg 21 | *.egg 22 | 23 | # Virtual environments 24 | .env 25 | .venv 26 | env/ 27 | venv/ 28 | ENV/ 29 | env.bak/ 30 | 
venv.bak/ 31 | 32 | # Testing 33 | .coverage 34 | htmlcov/ 35 | .pytest_cache/ 36 | .tox/ 37 | .hypothesis/ 38 | 39 | # IDE 40 | .idea/ 41 | .vscode/ 42 | *.swp 43 | *.swo 44 | *~ 45 | 46 | # OS 47 | .DS_Store 48 | Thumbs.db 49 | 50 | # Project specific 51 | *.log 52 | -------------------------------------------------------------------------------- /smithery.yaml: -------------------------------------------------------------------------------- 1 | # Smithery configuration file: https://smithery.ai/docs/deployments 2 | 3 | startCommand: 4 | type: stdio 5 | configSchema: 6 | # JSON Schema defining the configuration options for the MCP. 7 | type: object 8 | required: 9 | - replicateApiToken 10 | properties: 11 | replicateApiToken: 12 | type: string 13 | description: The API key for the Replicate server. 14 | commandFunction: 15 | # A function that produces the CLI command to start the MCP on stdio. 16 | |- 17 | config => ({ command: 'uv', args: ['tool', 'run', 'mcp-server-replicate'], env: { REPLICATE_API_TOKEN: config.replicateApiToken } }) -------------------------------------------------------------------------------- /test_client.py: -------------------------------------------------------------------------------- 1 | import os 2 | import replicate 3 | from pprint import pprint 4 | 5 | # Get API token from environment 6 | api_token = os.getenv("REPLICATE_API_TOKEN") 7 | print(f"Using API token: {api_token}") 8 | 9 | # Create client 10 | client = replicate.Client(api_token=api_token) 11 | 12 | # Search for models 13 | print("\nSearching for flux models...") 14 | page = client.models.search("flux") 15 | 16 | # Print page attributes to understand pagination 17 | print("\nPage attributes:") 18 | pprint(vars(page)) 19 | 20 | # Print results 21 | print("\nAll models:") 22 | for model in page.results: 23 | print(f"Model: {model.owner}/{model.name}") 24 | print(f"Created: {getattr(model, 'created_at', 'N/A')}") 25 | print(f"Run count: {getattr(model, 'run_count', 'N/A')}") 26 | print("---") -------------------------------------------------------------------------------- /src/mcp_server_replicate/models/collection.py: -------------------------------------------------------------------------------- 1 | """Data models for Replicate collections.""" 2 | 3 | from typing import List, Optional 4 | 5 | from pydantic import BaseModel, Field 6 | 7 | from .model import Model 8 | 9 | class Collection(BaseModel): 10 | """A collection of related models on Replicate.""" 11 | name: str = Field(..., description="Name of the collection") 12 | slug: str = Field(..., description="URL-friendly identifier for the collection") 13 | description: Optional[str] = Field(None, description="Description of the collection's purpose") 14 | models: List[Model] = Field(default_factory=list, description="Models in this collection") 15 | 16 | class CollectionList(BaseModel): 17 | """Response format for listing collections.""" 18 | collections: List[Collection] 19 | next_cursor: Optional[str] = None -------------------------------------------------------------------------------- /src/mcp_server_replicate/tools/webhook_tools.py: -------------------------------------------------------------------------------- 1 | """FastMCP tools for managing Replicate webhooks.""" 2 | 3 | from typing import Any 4 | 5 | from mcp.server.fastmcp import FastMCP 6 | from ..models.webhook import WebhookEvent, WebhookPayload 7 | 8 | mcp = FastMCP() 9 | 10 | @mcp.tool( 11 | name="get_webhook_secret", 12 | description="Get the signing secret for verifying 
webhook requests.", 13 | ) 14 | async def get_webhook_secret() -> str: 15 | """Get webhook signing secret.""" 16 | raise NotImplementedError 17 | 18 | @mcp.tool( 19 | name="verify_webhook", 20 | description="Verify that a webhook request came from Replicate.", 21 | ) 22 | async def verify_webhook( 23 | payload: WebhookPayload, 24 | signature: str, 25 | secret: str, 26 | ) -> bool: 27 | """Verify webhook signature.""" 28 | raise NotImplementedError -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # Use the official Python image with a version that matches our needs 2 | FROM python:3.11-slim-bookworm AS base 3 | 4 | # Set the working directory in the container 5 | WORKDIR /app 6 | 7 | # Copy the pyproject.toml and the lock file to the container 8 | COPY pyproject.toml uv.lock /app/ 9 | 10 | # Install UV for dependency management 11 | RUN pip install uv 12 | 13 | # Install the dependencies using UV 14 | RUN uv sync --frozen --no-install-project --no-dev --no-editable 15 | 16 | # Copy the project files into the container 17 | COPY src/ /app/src/ 18 | 19 | # Ensure the entrypoint is executable 20 | RUN chmod +x /app/src/mcp_server_replicate/__main__.py 21 | 22 | # Specify the command to run the MCP server 23 | CMD ["uv", "tool", "run", "mcp-server-replicate"] 24 | 25 | # Set environment variable for API token (to be overridden when running the container) 26 | ENV REPLICATE_API_TOKEN=your_api_key_here -------------------------------------------------------------------------------- /src/mcp_server_replicate/tools/hardware_tools.py: -------------------------------------------------------------------------------- 1 | """FastMCP tools for managing Replicate hardware options.""" 2 | 3 | from mcp.server.fastmcp import FastMCP 4 | from ..models.hardware import Hardware, HardwareList 5 | from ..replicate_client import ReplicateClient 6 | 7 | mcp = FastMCP() 8 | 9 | @mcp.tool( 10 | name="list_hardware", 11 | description="List available hardware options for running models.", 12 | ) 13 | async def list_hardware() -> HardwareList: 14 | """List available hardware options for running models. 
15 | 16 | Returns: 17 | HardwareList containing available hardware options 18 | 19 | Raises: 20 | RuntimeError: If the Replicate client fails to initialize 21 | Exception: If the API request fails 22 | """ 23 | async with ReplicateClient() as client: 24 | result = await client.list_hardware() 25 | return HardwareList(hardware=[Hardware(**hw) for hw in result]) -------------------------------------------------------------------------------- /src/mcp_server_replicate/__main__.py: -------------------------------------------------------------------------------- 1 | """Entry point for the MCP server.""" 2 | 3 | import argparse 4 | import logging 5 | import os 6 | from typing import NoReturn 7 | 8 | from .server import create_server 9 | 10 | 11 | def main() -> NoReturn: 12 | """Run the MCP server with configured log level.""" 13 | parser = argparse.ArgumentParser(description="Replicate MCP Server") 14 | parser.add_argument( 15 | "--log-level", 16 | choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], 17 | default=os.getenv("LOG_LEVEL", "WARNING"), 18 | help="Set the logging level (default: WARNING, env: LOG_LEVEL)", 19 | ) 20 | args = parser.parse_args() 21 | 22 | # Configure log level 23 | log_level = getattr(logging, args.log_level.upper()) 24 | 25 | # Create and run server with configured log level 26 | mcp = create_server(log_level=log_level) 27 | mcp.run() 28 | raise SystemExit(0) 29 | 30 | 31 | if __name__ == "__main__": 32 | main() 33 | -------------------------------------------------------------------------------- /src/mcp_server_replicate/models/webhook.py: -------------------------------------------------------------------------------- 1 | """Data models for Replicate webhooks.""" 2 | 3 | from datetime import datetime 4 | from enum import Enum 5 | from typing import Any, Dict, Optional 6 | 7 | from pydantic import BaseModel, Field 8 | 9 | class WebhookEventType(str, Enum): 10 | """Types of events that can trigger webhooks.""" 11 | START = "start" 12 | OUTPUT = "output" 13 | LOGS = "logs" 14 | COMPLETED = "completed" 15 | 16 | class WebhookEvent(BaseModel): 17 | """A webhook event from Replicate.""" 18 | type: WebhookEventType 19 | prediction_id: str = Field(..., description="ID of the prediction that triggered this event") 20 | timestamp: datetime = Field(..., description="When this event occurred") 21 | data: Dict[str, Any] = Field(..., description="Event-specific data payload") 22 | 23 | class WebhookPayload(BaseModel): 24 | """The full payload of a webhook request.""" 25 | event: WebhookEvent 26 | prediction: Dict[str, Any] = Field(..., description="Full prediction object at time of event") -------------------------------------------------------------------------------- /src/mcp_server_replicate/templates/prompts/text_to_image.py: -------------------------------------------------------------------------------- 1 | """Prompt templates for text-to-image models.""" 2 | 3 | from typing import Dict 4 | 5 | BASIC_PROMPT = { 6 | "id": "basic-text-to-image", 7 | "name": "Basic Text to Image", 8 | "description": "Simple prompt template for text-to-image generation", 9 | "template": "{prompt}", 10 | "variables": { 11 | "prompt": "The main text description of the image to generate" 12 | }, 13 | "model_type": "text-to-image", 14 | "version": "1.0.0", 15 | } 16 | 17 | DETAILED_PROMPT = { 18 | "id": "detailed-text-to-image", 19 | "name": "Detailed Text to Image", 20 | "description": "Detailed prompt template with style and quality modifiers", 21 | "template": "{prompt}, {style}, 
{quality}", 22 | "variables": { 23 | "prompt": "The main text description of the image to generate", 24 | "style": "The artistic style (e.g., 'oil painting', 'digital art')", 25 | "quality": "Quality modifiers (e.g., 'high quality', '4K, detailed')" 26 | }, 27 | "model_type": "text-to-image", 28 | "version": "1.0.0", 29 | } 30 | 31 | NEGATIVE_PROMPT = { 32 | "id": "negative-text-to-image", 33 | "name": "Text to Image with Negative Prompt", 34 | "description": "Prompt template with both positive and negative prompts", 35 | "template": "{prompt} || {negative_prompt}", 36 | "variables": { 37 | "prompt": "The main text description of what to include", 38 | "negative_prompt": "Description of elements to avoid" 39 | }, 40 | "model_type": "text-to-image", 41 | "version": "1.0.0", 42 | } 43 | 44 | # Export all templates 45 | TEMPLATES: Dict[str, dict] = { 46 | "basic": BASIC_PROMPT, 47 | "detailed": DETAILED_PROMPT, 48 | "negative": NEGATIVE_PROMPT, 49 | } -------------------------------------------------------------------------------- /src/mcp_server_replicate/templates/prompts/text_generation.py: -------------------------------------------------------------------------------- 1 | """Prompt templates for text generation models.""" 2 | 3 | from typing import Dict 4 | 5 | BASIC_PROMPT = { 6 | "id": "basic-text-generation", 7 | "name": "Basic Text Generation", 8 | "description": "Simple prompt template for text generation", 9 | "template": "{prompt}", 10 | "variables": { 11 | "prompt": "The main text input for generation" 12 | }, 13 | "model_type": "text-generation", 14 | "version": "1.0.0", 15 | } 16 | 17 | CHAT_PROMPT = { 18 | "id": "chat-text-generation", 19 | "name": "Chat Completion", 20 | "description": "Chat-style prompt template with system and user messages", 21 | "template": "System: {system_message}\nUser: {user_message}", 22 | "variables": { 23 | "system_message": "Instructions or context for the model's behavior", 24 | "user_message": "The user's input message or query" 25 | }, 26 | "model_type": "text-generation", 27 | "version": "1.0.0", 28 | } 29 | 30 | STRUCTURED_PROMPT = { 31 | "id": "structured-text-generation", 32 | "name": "Structured Text Generation", 33 | "description": "Template for generating text with specific format instructions", 34 | "template": "{context}\n\nTask: {task}\n\nFormat: {format_instructions}\n\nInput: {input}", 35 | "variables": { 36 | "context": "Background information or context", 37 | "task": "Specific task or objective", 38 | "format_instructions": "Instructions for output format", 39 | "input": "The specific input to process" 40 | }, 41 | "model_type": "text-generation", 42 | "version": "1.0.0", 43 | } 44 | 45 | # Export all templates 46 | TEMPLATES: Dict[str, dict] = { 47 | "basic": BASIC_PROMPT, 48 | "chat": CHAT_PROMPT, 49 | "structured": STRUCTURED_PROMPT, 50 | } -------------------------------------------------------------------------------- /src/mcp_server_replicate/tools/collection_tools.py: -------------------------------------------------------------------------------- 1 | """FastMCP tools for browsing Replicate collections.""" 2 | 3 | from typing import Any 4 | 5 | from mcp.server.fastmcp import FastMCP 6 | from ..models.collection import Collection, CollectionList 7 | from ..replicate_client import ReplicateClient 8 | 9 | mcp = FastMCP() 10 | 11 | @mcp.tool( 12 | name="list_collections", 13 | description="List available model collections on Replicate.", 14 | ) 15 | async def list_collections() -> CollectionList: 16 | """List available model 
collections on Replicate. 17 | 18 | Returns: 19 | CollectionList containing available collections 20 | 21 | Raises: 22 | RuntimeError: If the Replicate client fails to initialize 23 | Exception: If the API request fails 24 | """ 25 | async with ReplicateClient() as client: 26 | result = await client.list_collections() 27 | return CollectionList(collections=[Collection(**collection) for collection in result]) 28 | 29 | @mcp.tool( 30 | name="get_collection_details", 31 | description="Get detailed information about a specific collection.", 32 | ) 33 | async def get_collection_details(collection_slug: str) -> Collection: 34 | """Get detailed information about a specific collection. 35 | 36 | Args: 37 | collection_slug: The slug identifier of the collection 38 | 39 | Returns: 40 | Collection object containing detailed collection information 41 | 42 | Raises: 43 | RuntimeError: If the Replicate client fails to initialize 44 | ValueError: If the collection is not found 45 | Exception: If the API request fails 46 | """ 47 | async with ReplicateClient() as client: 48 | result = await client.get_collection(collection_slug) 49 | return Collection(**result) -------------------------------------------------------------------------------- /src/mcp_server_replicate/templates/prompts/image_to_image.py: -------------------------------------------------------------------------------- 1 | """Prompt templates for image-to-image transformation models.""" 2 | 3 | from typing import Dict 4 | 5 | BASIC_TRANSFORM = { 6 | "id": "basic-image-transform", 7 | "name": "Basic Image Transform", 8 | "description": "Simple template for image transformation with guidance text", 9 | "template": "{prompt}", 10 | "variables": { 11 | "prompt": "Text description of the desired transformation" 12 | }, 13 | "model_type": "image-to-image", 14 | "version": "1.0.0", 15 | } 16 | 17 | STYLE_TRANSFER = { 18 | "id": "style-transfer", 19 | "name": "Style Transfer", 20 | "description": "Template for artistic style transfer with detailed style instructions", 21 | "template": "Transform the image in the style of {style_description}, {additional_details}", 22 | "variables": { 23 | "style_description": "Description of the target artistic style", 24 | "additional_details": "Additional style and quality specifications" 25 | }, 26 | "model_type": "image-to-image", 27 | "version": "1.0.0", 28 | } 29 | 30 | INPAINTING = { 31 | "id": "inpainting", 32 | "name": "Image Inpainting", 33 | "description": "Template for image inpainting with mask and instructions", 34 | "template": "Replace the masked area with {replacement_description}, maintaining {consistency_instructions}", 35 | "variables": { 36 | "replacement_description": "Description of what to generate in the masked area", 37 | "consistency_instructions": "Instructions for maintaining consistency with the rest of the image" 38 | }, 39 | "model_type": "image-to-image", 40 | "version": "1.0.0", 41 | } 42 | 43 | # Export all templates 44 | TEMPLATES: Dict[str, dict] = { 45 | "basic": BASIC_TRANSFORM, 46 | "style": STYLE_TRANSFER, 47 | "inpainting": INPAINTING, 48 | } -------------------------------------------------------------------------------- /src/mcp_server_replicate/templates/prompts/controlnet.py: -------------------------------------------------------------------------------- 1 | """Prompt templates for ControlNet-guided image generation.""" 2 | 3 | from typing import Dict 4 | 5 | CANNY_CONTROL = { 6 | "id": "canny-controlnet", 7 | "name": "Canny Edge Control", 8 | "description": "Template 
for generating images guided by Canny edge detection", 9 | "template": "Generate {prompt}, following the edge guidance, {style_instructions}", 10 | "variables": { 11 | "prompt": "Description of the image to generate", 12 | "style_instructions": "Additional style and artistic instructions" 13 | }, 14 | "model_type": "controlnet", 15 | "version": "1.0.0", 16 | } 17 | 18 | DEPTH_CONTROL = { 19 | "id": "depth-controlnet", 20 | "name": "Depth Map Control", 21 | "description": "Template for generating images guided by depth maps", 22 | "template": "Create {prompt}, maintaining the provided depth structure, {composition_notes}", 23 | "variables": { 24 | "prompt": "Description of the image to generate", 25 | "composition_notes": "Notes about composition and spatial arrangement" 26 | }, 27 | "model_type": "controlnet", 28 | "version": "1.0.0", 29 | } 30 | 31 | POSE_CONTROL = { 32 | "id": "pose-controlnet", 33 | "name": "Pose Control", 34 | "description": "Template for generating images guided by pose estimation", 35 | "template": "Generate {prompt}, following the pose structure, {detail_instructions}", 36 | "variables": { 37 | "prompt": "Description of the image to generate", 38 | "detail_instructions": "Instructions for details and refinements" 39 | }, 40 | "model_type": "controlnet", 41 | "version": "1.0.0", 42 | } 43 | 44 | SEGMENTATION_CONTROL = { 45 | "id": "segmentation-controlnet", 46 | "name": "Segmentation Control", 47 | "description": "Template for generating images guided by segmentation maps", 48 | "template": "Create {prompt}, following the segmentation layout, {region_instructions}", 49 | "variables": { 50 | "prompt": "Description of the image to generate", 51 | "region_instructions": "Specific instructions for different regions" 52 | }, 53 | "model_type": "controlnet", 54 | "version": "1.0.0", 55 | } 56 | 57 | # Export all templates 58 | TEMPLATES: Dict[str, dict] = { 59 | "canny": CANNY_CONTROL, 60 | "depth": DEPTH_CONTROL, 61 | "pose": POSE_CONTROL, 62 | "segmentation": SEGMENTATION_CONTROL, 63 | } -------------------------------------------------------------------------------- /src/mcp_server_replicate/models/prediction.py: -------------------------------------------------------------------------------- 1 | """Data models for Replicate predictions.""" 2 | 3 | from datetime import datetime 4 | from enum import Enum 5 | from typing import Any, Dict, List, Optional 6 | 7 | from pydantic import BaseModel, Field 8 | 9 | class PredictionStatus(str, Enum): 10 | """Status of a prediction.""" 11 | STARTING = "starting" 12 | PROCESSING = "processing" 13 | SUCCEEDED = "succeeded" 14 | FAILED = "failed" 15 | CANCELED = "canceled" 16 | 17 | class PredictionInput(BaseModel): 18 | """Input parameters for creating a prediction.""" 19 | model_version: str = Field(..., description="Model version to use for prediction") 20 | input: Dict[str, Any] = Field(..., description="Model-specific input parameters") 21 | template_id: Optional[str] = Field(None, description="Optional template ID to use") 22 | webhook_url: Optional[str] = Field(None, description="URL for webhook notifications") 23 | webhook_events: Optional[List[str]] = Field(None, description="Events to trigger webhooks") 24 | wait: bool = Field(False, description="Whether to wait for prediction completion") 25 | wait_timeout: Optional[int] = Field(None, description="Max seconds to wait if wait=True (1-60)") 26 | stream: bool = Field(False, description="Whether to request streaming output") 27 | 28 | class Prediction(BaseModel): 29 | """A 
prediction (model run) on Replicate."""
30 |     id: str = Field(..., description="Unique identifier for this prediction")
31 |     version: str = Field(..., description="Model version used for this prediction")
32 |     status: PredictionStatus = Field(..., description="Current status of the prediction")
33 |     input: Dict[str, Any] = Field(..., description="Input parameters used for the prediction")
34 |     output: Optional[Any] = Field(None, description="Output from the prediction if completed")
35 |     error: Optional[str] = Field(None, description="Error message if prediction failed")
36 |     logs: Optional[str] = Field(None, description="Execution logs from the prediction")
37 |     created_at: datetime
38 |     started_at: Optional[datetime] = None
39 |     completed_at: Optional[datetime] = None
40 |     urls: Dict[str, str] = Field(..., description="Related API URLs for this prediction")
41 |     metrics: Optional[Dict[str, float]] = Field(None, description="Performance metrics if available")
42 |     stream_url: Optional[str] = Field(None, description="URL for streaming output if requested")
--------------------------------------------------------------------------------
/src/mcp_server_replicate/models/model.py:
--------------------------------------------------------------------------------
1 | """Data models for Replicate models and versions."""
2 | 
3 | from datetime import datetime
4 | from typing import Any, Dict, List, Optional, Union
5 | 
6 | from pydantic import BaseModel, Field, model_validator
7 | 
8 | class ModelVersion(BaseModel):
9 |     """A specific version of a model on Replicate."""
10 |     id: str = Field(..., description="Unique identifier for this model version")
11 |     created_at: datetime
12 |     cog_version: str
13 |     openapi_schema: Dict[str, Any]
14 |     model: Optional[str] = Field(None, description="Model identifier (owner/name)")
15 |     replicate_version: Optional[str] = Field(None, description="Replicate version identifier")
16 |     hardware: Optional[str] = Field(None, description="Hardware configuration for this version")
17 | 
18 | class Model(BaseModel):
19 |     """Model information returned from Replicate."""
20 |     id: str = Field(..., description="Unique identifier in format owner/name")
21 |     owner: str = Field(..., description="Owner of the model (user or organization)")
22 |     name: str = Field(..., description="Name of the model")
23 |     description: Optional[str] = Field(None, description="Description of the model's purpose and usage")
24 |     visibility: str = Field("public", description="Model visibility (public/private)")
25 |     github_url: Optional[str] = Field(None, description="URL to model's GitHub repository")
26 |     paper_url: Optional[str] = Field(None, description="URL to model's research paper")
27 |     license_url: Optional[str] = Field(None, description="URL to model's license")
28 |     run_count: Optional[int] = Field(None, description="Number of times this model has been run")
29 |     cover_image_url: Optional[str] = Field(None, description="URL to model's cover image")
30 |     latest_version: Optional[ModelVersion] = Field(None, description="Latest version of the model")
31 |     default_example: Optional[Dict[str, Any]] = Field(None, description="Default example inputs")
32 |     featured: Optional[bool] = Field(None, description="Whether this model is featured")
33 |     tags: Optional[List[str]] = Field(default_factory=list, description="Model tags")
34 | 
35 |     @model_validator(mode="before")
36 |     @classmethod
37 |     def validate_id(cls, data: Any) -> Any:
38 |         """Construct the id from owner/name when it is not provided.
39 | 
40 |         A model-level "before" validator is used here because a field
41 |         validator on `id` would run before `owner` and `name` are available.
42 |         """
43 |         if isinstance(data, dict) and not data.get("id"):
44 |             owner = data.get("owner")
45 |             name = data.get("name")
46 |             if owner and name:
47 |                 data["id"] = f"{owner}/{name}"
48 |             else:
49 |                 raise ValueError("Either id or both owner and name must be provided")
50 |         return data
51 | 
52 | class ModelList(BaseModel):
53 |     """Response format for listing models."""
54 |     models: List[Model]
55 |     next_cursor: Optional[str] = None
56 |     total_count: Optional[int] = None
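The identifier fallback above can be exercised directly; a minimal sketch with hypothetical values:

```python
from mcp_server_replicate.models.model import Model

# id may be omitted; the before-mode validator derives it from owner/name
model = Model(owner="stability-ai", name="sdxl")
assert model.id == "stability-ai/sdxl"
```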
--------------------------------------------------------------------------------
/src/mcp_server_replicate/templates/parameters/llama.py:
--------------------------------------------------------------------------------
1 | """Parameter templates for LLaMA models."""
2 | 
3 | from typing import Dict, Any
4 | 
5 | LLAMA_70B_PARAMETERS = {
6 |     "id": "llama-70b-base",
7 |     "name": "LLaMA 70B Base Parameters",
8 |     "description": "Default parameters for LLaMA 70B models",
9 |     "model_type": "text-generation",
10 |     "default_parameters": {
11 |         "temperature": 0.7,
12 |         "top_p": 0.9,
13 |         "max_tokens": 512,
14 |         "repetition_penalty": 1.1,
15 |         "stop_sequences": [],
16 |         "system_prompt": "You are a helpful AI assistant.",
17 |     },
18 |     "parameter_schema": {
19 |         "type": "object",
20 |         "properties": {
21 |             "prompt": {"type": "string", "description": "The input text prompt"},
22 |             "system_prompt": {"type": "string", "description": "System prompt for setting assistant behavior"},
23 |             "temperature": {"type": "number", "minimum": 0, "maximum": 2},
24 |             "top_p": {"type": "number", "minimum": 0, "maximum": 1},
25 |             "max_tokens": {"type": "integer", "minimum": 1, "maximum": 4096},
26 |             "repetition_penalty": {"type": "number", "minimum": 0.1, "maximum": 2.0},
27 |             "stop_sequences": {"type": "array", "items": {"type": "string"}},
28 |         },
29 |         "required": ["prompt"]
30 |     },
31 |     "version": "1.0.0",
32 | }
33 | 
34 | LLAMA_13B_PARAMETERS = {
35 |     "id": "llama-13b-base",
36 |     "name": "LLaMA 13B Base Parameters",
37 |     "description": "Default parameters for LLaMA 13B models",
38 |     "model_type": "text-generation",
39 |     "default_parameters": {
40 |         "temperature": 0.7,
41 |         "top_p": 0.9,
42 |         "max_tokens": 256,
43 |         "repetition_penalty": 1.1,
44 |         "stop_sequences": [],
45 |         "system_prompt": "You are a helpful AI assistant.",
46 |     },
47 |     "parameter_schema": {
48 |         "type": "object",
49 |         "properties": {
50 |             "prompt": {"type": "string", "description": "The input text prompt"},
51 |             "system_prompt": {"type": "string", "description": "System prompt for setting assistant behavior"},
52 |             "temperature": {"type": "number", "minimum": 0, "maximum": 2},
53 |             "top_p": {"type": "number", "minimum": 0, "maximum": 1},
54 |             "max_tokens": {"type": "integer", "minimum": 1, "maximum": 2048},
55 |             "repetition_penalty": {"type": "number", "minimum": 0.1, "maximum": 2.0},
56 |             "stop_sequences": {"type": "array", "items": {"type": "string"}},
57 |         },
58 |         "required": ["prompt"]
59 |     },
60 |     "version": "1.0.0",
61 | }
62 | 
63 | # Export all templates
64 | TEMPLATES: Dict[str, Dict[str, Any]] = {
65 |     "llama70b": LLAMA_70B_PARAMETERS,
66 |     "llama13b": LLAMA_13B_PARAMETERS,
67 | }
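As a usage sketch, a parameter template can be combined with the project's jsonschema dependency to merge defaults and validate a request (illustrative only, not part of the package):

```python
from jsonschema import validate

from mcp_server_replicate.templates.parameters.llama import LLAMA_70B_PARAMETERS

# Caller input is layered over the template defaults, then checked against the schema
params = {**LLAMA_70B_PARAMETERS["default_parameters"], "prompt": "Summarize this text."}
validate(instance=params, schema=LLAMA_70B_PARAMETERS["parameter_schema"])
```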
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["hatchling"]
3 | build-backend = "hatchling.build"
4 | 
5 | [tool.hatch.build]
6 | packages = ["src/mcp_server_replicate"]
7 | 
8 | [project]
9 | name = "mcp-server-replicate"
10 | version = "0.1.9"
11 | description = "FastMCP server implementation for the Replicate API, providing resource-based access to AI model inference"
12 | readme = "README.md"
13 | requires-python = ">=3.11"
14 | license = "MIT"
15 | keywords = [
16 |     "mcp",
17 |     "replicate",
18 |     "ai",
19 |     "inference",
20 |     "stable-diffusion",
21 |     "image-generation",
22 | ]
23 | authors = [{ name = "Gerred Dillon", email = "hello@gerred.org" }]
24 | classifiers = [
25 |     "Development Status :: 3 - Alpha",
26 |     "Intended Audience :: Developers",
27 |     "License :: OSI Approved :: MIT License",
28 |     "Programming Language :: Python :: 3",
29 |     "Programming Language :: Python :: 3.11",
30 |     "Topic :: Software Development :: Libraries :: Python Modules",
31 |     "Topic :: Scientific/Engineering :: Artificial Intelligence",
32 |     "Topic :: Multimedia :: Graphics",
33 | ]
34 | dependencies = [
35 |     "mcp[cli]>=1.2.0",
36 |     "replicate>=1.0.4",
37 |     "python-dotenv>=1.0.0",
38 |     "jsonschema>=4.21.1",
39 |     "httpx>=0.26.0",
40 |     "pydantic>=2.5.0",
41 | ]
42 | 
43 | [project.optional-dependencies]
44 | dev = [
45 |     "pytest>=7.4.0",
46 |     "pytest-cov>=4.1.0",
47 |     "pytest-asyncio>=0.23.2",
48 |     "black>=23.12.0",
49 |     "ruff>=0.1.9",
50 |     "mypy>=1.8.0",
51 |     "build>=1.0.3",
52 |     "twine>=4.0.2",
53 | ]
54 | 
55 | [project.scripts]
56 | mcp-server-replicate = "mcp_server_replicate.__main__:main"
57 | 
58 | [project.urls]
59 | Homepage = "https://github.com/gerred/mcp-server-replicate"
60 | Documentation = "https://github.com/gerred/mcp-server-replicate#readme"
61 | Issues = "https://github.com/gerred/mcp-server-replicate/issues"
62 | 
63 | [tool.pytest.ini_options]
64 | minversion = "7.0"
65 | addopts = "-ra -q --cov"
66 | testpaths = ["tests"]
67 | asyncio_mode = "auto"
68 | asyncio_default_fixture_loop_scope = "function"
69 | 
70 | [tool.black]
71 | line-length = 88
72 | target-version = ["py311"]
73 | 
74 | [tool.ruff]
75 | line-length = 120
76 | target-version = "py311"
77 | 
78 | [tool.ruff.lint]
79 | select = ["E", "F", "B", "I", "UP"]
80 | ignore = []
81 | 
82 | [tool.mypy]
83 | python_version = "3.11"
84 | strict = true
85 | warn_return_any = true
86 | warn_unused_configs = true
87 | disallow_untyped_defs = true
88 | [[tool.mypy.overrides]]
89 | module = "replicate.*"
90 | ignore_missing_imports = true
91 | 
92 | [tool.coverage.run]
93 | source = ["mcp_server_replicate"]
94 | branch = true
95 | 
96 | [tool.coverage.report]
97 | exclude_lines = [
98 |     "pragma: no cover",
99 |     "def __repr__",
100 |     "if self.debug:",
101 |     "raise NotImplementedError",
102 |     "if __name__ == .__main__.:",
103 |     "pass",
104 |     "raise ImportError",
105 | ]
106 | ignore_errors = true
107 | omit = ["tests/*", "setup.py"]
108 | 
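Taken together with the dev extras and pytest settings above, a local check might look like this (a sketch; assumes uv is installed):

```bash
# Install with dev extras, then run the suite;
# pytest applies addopts (-ra -q --cov) from pyproject.toml automatically
uv pip install -e ".[dev]"
uv run pytest
```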
"stability-ai", 28 | "name": "sdxl", 29 | "description": "A text-to-image generative AI model", 30 | "visibility": "public", 31 | "github_url": None, 32 | "paper_url": None, 33 | "license_url": None, 34 | "latest_version": { 35 | "id": "39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b", 36 | "created_at": "2023-09-22T21:00:00.000Z", 37 | } 38 | } 39 | ] 40 | } 41 | }, 42 | "predictions": { 43 | "create": { 44 | "id": "test-prediction-id", 45 | "version": "test-version-id", 46 | "status": "starting", 47 | "created_at": "2024-01-06T00:00:00.000Z", 48 | "started_at": None, 49 | "completed_at": None, 50 | "urls": { 51 | "get": "https://api.replicate.com/v1/predictions/test-prediction-id", 52 | "cancel": "https://api.replicate.com/v1/predictions/test-prediction-id/cancel", 53 | } 54 | }, 55 | "get": { 56 | "id": "test-prediction-id", 57 | "version": "test-version-id", 58 | "status": "succeeded", 59 | "created_at": "2024-01-06T00:00:00.000Z", 60 | "started_at": "2024-01-06T00:00:01.000Z", 61 | "completed_at": "2024-01-06T00:00:10.000Z", 62 | "output": ["https://replicate.delivery/test-output.png"], 63 | "urls": { 64 | "get": "https://api.replicate.com/v1/predictions/test-prediction-id", 65 | "cancel": "https://api.replicate.com/v1/predictions/test-prediction-id/cancel", 66 | } 67 | } 68 | } 69 | } 70 | 71 | @pytest.fixture 72 | async def mock_client() -> AsyncGenerator[Dict[str, Any], None]: 73 | """Return a mock client for testing.""" 74 | client = { 75 | "api_token": "test-token", 76 | "base_url": "https://api.replicate.com/v1", 77 | "responses": mock_api_responses(), 78 | } 79 | yield client -------------------------------------------------------------------------------- /tests/utils/mock_client.py: -------------------------------------------------------------------------------- 1 | """Mock client for testing Replicate API interactions.""" 2 | 3 | from typing import Dict, Any, Optional, List 4 | from dataclasses import dataclass 5 | from datetime import datetime, timezone 6 | 7 | @dataclass 8 | class MockResponse: 9 | """Mock HTTP response.""" 10 | status_code: int 11 | json_data: Dict[str, Any] 12 | 13 | async def json(self) -> Dict[str, Any]: 14 | """Return JSON response data.""" 15 | return self.json_data 16 | 17 | class MockReplicateClient: 18 | """Mock Replicate API client for testing.""" 19 | 20 | def __init__(self, responses: Dict[str, Any]) -> None: 21 | """Initialize mock client with predefined responses.""" 22 | self.responses = responses 23 | self.calls: List[Dict[str, Any]] = [] 24 | 25 | async def get_model(self, owner: str, name: str) -> MockResponse: 26 | """Mock get model endpoint.""" 27 | self.calls.append({ 28 | "method": "GET", 29 | "endpoint": f"models/{owner}/{name}", 30 | "timestamp": datetime.now(timezone.utc), 31 | }) 32 | 33 | # Return first model from list response as default 34 | model = self.responses["models"]["list"]["results"][0] 35 | return MockResponse(200, model) 36 | 37 | async def create_prediction( 38 | self, 39 | version: str, 40 | input: Dict[str, Any], 41 | webhook: Optional[str] = None, 42 | ) -> MockResponse: 43 | """Mock create prediction endpoint.""" 44 | self.calls.append({ 45 | "method": "POST", 46 | "endpoint": "predictions", 47 | "data": { 48 | "version": version, 49 | "input": input, 50 | "webhook": webhook, 51 | }, 52 | "timestamp": datetime.now(timezone.utc), 53 | }) 54 | return MockResponse(201, self.responses["predictions"]["create"]) 55 | 56 | async def get_prediction(self, id: str) -> MockResponse: 57 | """Mock get 
--------------------------------------------------------------------------------
/tests/utils/mock_client.py:
--------------------------------------------------------------------------------
1 | """Mock client for testing Replicate API interactions."""
2 | 
3 | from typing import Dict, Any, Optional, List
4 | from dataclasses import dataclass
5 | from datetime import datetime, timezone
6 | 
7 | @dataclass
8 | class MockResponse:
9 |     """Mock HTTP response."""
10 |     status_code: int
11 |     json_data: Dict[str, Any]
12 | 
13 |     async def json(self) -> Dict[str, Any]:
14 |         """Return JSON response data."""
15 |         return self.json_data
16 | 
17 | class MockReplicateClient:
18 |     """Mock Replicate API client for testing."""
19 | 
20 |     def __init__(self, responses: Dict[str, Any]) -> None:
21 |         """Initialize mock client with predefined responses."""
22 |         self.responses = responses
23 |         self.calls: List[Dict[str, Any]] = []
24 | 
25 |     async def get_model(self, owner: str, name: str) -> MockResponse:
26 |         """Mock get model endpoint."""
27 |         self.calls.append({
28 |             "method": "GET",
29 |             "endpoint": f"models/{owner}/{name}",
30 |             "timestamp": datetime.now(timezone.utc),
31 |         })
32 | 
33 |         # Return first model from list response as default
34 |         model = self.responses["models"]["list"]["results"][0]
35 |         return MockResponse(200, model)
36 | 
37 |     async def create_prediction(
38 |         self,
39 |         version: str,
40 |         input: Dict[str, Any],
41 |         webhook: Optional[str] = None,
42 |     ) -> MockResponse:
43 |         """Mock create prediction endpoint."""
44 |         self.calls.append({
45 |             "method": "POST",
46 |             "endpoint": "predictions",
47 |             "data": {
48 |                 "version": version,
49 |                 "input": input,
50 |                 "webhook": webhook,
51 |             },
52 |             "timestamp": datetime.now(timezone.utc),
53 |         })
54 |         return MockResponse(201, self.responses["predictions"]["create"])
55 | 
56 |     async def get_prediction(self, id: str) -> MockResponse:
57 |         """Mock get prediction endpoint."""
58 |         self.calls.append({
59 |             "method": "GET",
60 |             "endpoint": f"predictions/{id}",
61 |             "timestamp": datetime.now(timezone.utc),
62 |         })
63 |         return MockResponse(200, self.responses["predictions"]["get"])
64 | 
65 |     async def cancel_prediction(self, id: str) -> MockResponse:
66 |         """Mock cancel prediction endpoint."""
67 |         self.calls.append({
68 |             "method": "POST",
69 |             "endpoint": f"predictions/{id}/cancel",
70 |             "timestamp": datetime.now(timezone.utc),
71 |         })
72 | 
73 |         response = self.responses["predictions"]["get"].copy()
74 |         response["status"] = "canceled"
75 |         return MockResponse(200, response)
76 | 
77 |     def get_calls(self, method: Optional[str] = None, endpoint: Optional[str] = None) -> List[Dict[str, Any]]:
78 |         """Get filtered API calls."""
79 |         filtered = self.calls
80 |         if method:
81 |             filtered = [c for c in filtered if c["method"] == method]
82 |         if endpoint:
83 |             filtered = [c for c in filtered if c["endpoint"] == endpoint]
84 |         return filtered
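A sketch of the mock in use (hypothetical test; relies on the asyncio auto mode and the `mock_api_responses` fixture defined in conftest.py):

```python
from tests.utils.mock_client import MockReplicateClient

async def test_create_prediction_records_the_call(mock_api_responses):
    client = MockReplicateClient(mock_api_responses)
    response = await client.create_prediction(version="test-version-id", input={"prompt": "hi"})
    assert response.status_code == 201
    assert client.get_calls(method="POST", endpoint="predictions")
```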
--------------------------------------------------------------------------------
/src/mcp_server_replicate/tools/prediction_tools.py:
--------------------------------------------------------------------------------
1 | """FastMCP tools for managing Replicate predictions."""
2 | 
3 | from typing import Any, Optional
4 | 
5 | from mcp.server.fastmcp import FastMCP
6 | from ..models.prediction import Prediction, PredictionInput, PredictionStatus
7 | from ..replicate_client import ReplicateClient
8 | 
9 | mcp = FastMCP()
10 | 
11 | @mcp.tool(
12 |     name="create_prediction",
13 |     description="Create a new prediction using a Replicate model.",
14 | )
15 | async def create_prediction(input: PredictionInput) -> Prediction:
16 |     """Create a new prediction using a Replicate model.
17 | 
18 |     Args:
19 |         input: PredictionInput containing model, version, and input parameters
20 | 
21 |     Returns:
22 |         Prediction object containing the prediction details and status
23 | 
24 |     Raises:
25 |         RuntimeError: If the Replicate client fails to initialize
26 |         ValueError: If the model or version is not found
27 |         Exception: If the prediction creation fails
28 |     """
29 |     async with ReplicateClient() as client:
30 |         result = await client.predict(
31 |             # PredictionInput carries the pinned version in `model_version`
32 |             version=input.model_version,
33 |             input_data=input.input,
34 |             wait=input.wait,
35 |             wait_timeout=input.wait_timeout,
36 |             stream=input.stream
37 |         )
38 |         return Prediction(**result)
39 | 
40 | @mcp.tool(
41 |     name="get_prediction",
42 |     description="Get the current status and results of a prediction.",
43 | )
44 | async def get_prediction(prediction_id: str) -> Prediction:
45 |     """Get the current status and results of a prediction.
46 | 
47 |     Args:
48 |         prediction_id: The ID of the prediction to retrieve
49 | 
50 |     Returns:
51 |         Prediction object containing the current status and results
52 | 
53 |     Raises:
54 |         RuntimeError: If the Replicate client fails to initialize
55 |         ValueError: If the prediction is not found
56 |         Exception: If the status check fails
57 |     """
58 |     async with ReplicateClient() as client:
59 |         result = await client.get_prediction_status(prediction_id)
60 |         return Prediction(**result)
61 | 
62 | @mcp.tool(
63 |     name="cancel_prediction",
64 |     description="Cancel a running prediction.",
65 | )
66 | async def cancel_prediction(prediction_id: str) -> Prediction:
67 |     """Cancel a running prediction.
68 | 
69 |     Args:
70 |         prediction_id: The ID of the prediction to cancel
71 | 
72 |     Returns:
73 |         Prediction object containing the updated status
74 | 
75 |     Raises:
76 |         RuntimeError: If the Replicate client fails to initialize
77 |         ValueError: If the prediction is not found
78 |         Exception: If the cancellation fails
79 |     """
80 |     async with ReplicateClient() as client:
81 |         result = await client.cancel_prediction(prediction_id)
82 |         return Prediction(**result)
83 | 
84 | @mcp.tool(
85 |     name="list_predictions",
86 |     description="List recent predictions with optional filtering.",
87 | )
88 | async def list_predictions(
89 |     status: Optional[PredictionStatus] = None,
90 |     limit: int = 10
91 | ) -> list[Prediction]:
92 |     """List recent predictions with optional filtering.
93 | 
94 |     Args:
95 |         status: Optional status to filter predictions by
96 |         limit: Maximum number of predictions to return (1-100)
97 | 
98 |     Returns:
99 |         List of Prediction objects
100 | 
101 |     Raises:
102 |         RuntimeError: If the Replicate client fails to initialize
103 |         ValueError: If limit is out of range
104 |         Exception: If the API request fails
105 |     """
106 |     async with ReplicateClient() as client:
107 |         result = await client.list_predictions(
108 |             status=status.value if status else None,
109 |             limit=limit
110 |         )
111 |         return [Prediction(**prediction) for prediction in result]
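As an end-to-end sketch of the tools above (the version id is a placeholder; substitute a real version hash in practice):

```python
from mcp_server_replicate.models.prediction import PredictionInput
from mcp_server_replicate.tools.prediction_tools import create_prediction, get_prediction

async def run_example() -> None:
    # "example-version-id" is hypothetical; wait_timeout must be within 1-60 seconds
    prediction = await create_prediction(
        PredictionInput(
            model_version="example-version-id",
            input={"prompt": "a photorealistic cat portrait"},
            wait=True,
            wait_timeout=60,
        )
    )
    refreshed = await get_prediction(prediction.id)
    print(refreshed.status)
```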
--------------------------------------------------------------------------------
/src/mcp_server_replicate/tools/model_tools.py:
--------------------------------------------------------------------------------
1 | """FastMCP tools for interacting with Replicate models."""
2 | 
3 | from typing import Any, Optional
4 | 
5 | from mcp.server.fastmcp import FastMCP
6 | from ..models.model import Model, ModelVersion, ModelList
7 | from ..replicate_client import ReplicateClient
8 | 
9 | mcp = FastMCP()
10 | 
11 | @mcp.tool(
12 |     name="list_models",
13 |     description="List available models on Replicate with optional filtering by owner.",
14 | )
15 | async def list_models(owner: Optional[str] = None) -> ModelList:
16 |     """List available models on Replicate.
17 | 
18 |     Args:
19 |         owner: Optional owner username to filter models by
20 | 
21 |     Returns:
22 |         ModelList containing the available models and pagination info
23 | 
24 |     Raises:
25 |         RuntimeError: If the Replicate client fails to initialize
26 |         Exception: If the API request fails
27 |     """
28 |     async with ReplicateClient() as client:
29 |         result = await client.list_models(owner=owner)
30 |         return ModelList(
31 |             models=[Model(**model) for model in result["models"]],
32 |             next_cursor=result.get("next_cursor"),
33 |             total_count=result.get("total_models")
34 |         )
35 | 
36 | @mcp.tool(
37 |     name="search_models",
38 |     description="Search for models using semantic search.",
39 | )
40 | async def search_models(query: str) -> ModelList:
41 |     """Search for models using semantic search.
42 | 
43 |     Args:
44 |         query: Search query string
45 | 
46 |     Returns:
47 |         ModelList containing the matching models and pagination info
48 | 
49 |     Raises:
50 |         RuntimeError: If the Replicate client fails to initialize
51 |         Exception: If the API request fails
52 |     """
53 |     async with ReplicateClient() as client:
54 |         result = await client.search_models(query)
55 |         return ModelList(
56 |             models=[Model(**model) for model in result["models"]],
57 |             next_cursor=result.get("next_cursor"),
58 |             total_count=result.get("total_models")
59 |         )
60 | 
61 | @mcp.tool(
62 |     name="get_model_details",
63 |     description="Get detailed information about a specific model.",
64 | )
65 | async def get_model_details(model_id: str) -> Model:
66 |     """Get detailed information about a specific model.
67 | 
68 |     Args:
69 |         model_id: Model identifier in format 'owner/model'
70 | 
71 |     Returns:
72 |         Model object containing detailed model information
73 | 
74 |     Raises:
75 |         RuntimeError: If the Replicate client fails to initialize
76 |         ValueError: If the model is not found
77 |         Exception: If the API request fails
78 |     """
79 |     owner, name = model_id.split("/")
80 |     async with ReplicateClient() as client:
81 |         # First try to find the model in the owner's models
82 |         result = await client.list_models(owner=owner)
83 |         for model in result["models"]:
84 |             if model["name"] == name:
85 |                 return Model(**model)
86 | 
87 |         # If not found, try searching for it
88 |         search_result = await client.search_models(model_id)
89 |         for model in search_result["models"]:
90 |             if f"{model['owner']}/{model['name']}" == model_id:
91 |                 return Model(**model)
92 | 
93 |         raise ValueError(f"Model not found: {model_id}")
94 | 
95 | @mcp.tool(
96 |     name="get_model_versions",
97 |     description="Get available versions for a model.",
98 | )
99 | async def get_model_versions(model_id: str) -> list[ModelVersion]:
100 |     """Get available versions for a model.
101 | 
102 |     Args:
103 |         model_id: Model identifier in format 'owner/model'
104 | 
105 |     Returns:
106 |         List of ModelVersion objects containing version metadata
107 | 
108 |     Raises:
109 |         RuntimeError: If the Replicate client fails to initialize
110 |         ValueError: If the model is not found
111 |         Exception: If the API request fails
112 |     """
113 |     async with ReplicateClient() as client:
114 |         versions = await client.get_model_versions(model_id)
115 |         return [ModelVersion(**version) for version in versions]
--------------------------------------------------------------------------------
/docs/workflows.md:
--------------------------------------------------------------------------------
1 | # MCP Server Workflow Patterns
2 | 
3 | ## Text-to-Image Generation
4 | 
5 | The text-to-image workflow is designed to help users find and use the right model for their specific needs. The workflow follows these steps:
6 | 
7 | 1. **Initial User Input**
8 | 
9 |    - User provides their requirements through the `text_to_image` prompt
10 |    - System guides them to specify:
11 |      - Subject matter
12 |      - Style preferences
13 |      - Quality requirements
14 |      - Technical requirements (size, etc.)
15 | 
16 | 2. **Model Discovery**
17 | 
18 |    - System uses `search_available_models` to find suitable models
19 |    - Models are scored based on:
20 |      - Popularity (run count)
21 |      - Featured status
22 |      - Version stability
23 |      - Tag matching
24 |      - Task relevance
25 |    - Results are sorted but presented to user for selection
26 | 
27 | 3. 
**Model Selection** 28 | 29 | - User can view model details using `get_model_details` 30 | - System provides guidance on model capabilities 31 | - User makes final model selection 32 | - System confirms version to use 33 | 34 | 4. **Parameter Configuration** 35 | 36 | - System applies quality presets based on user requirements 37 | - Style presets are applied if specified 38 | - Size and other parameters are configured 39 | - User can override any parameters 40 | 41 | 5. **Image Generation** 42 | - System uses `generate_image` with selected model and parameters 43 | - Progress is tracked 44 | - Results are validated 45 | - Retry logic handles failures 46 | 47 | ### Example Flow 48 | 49 | ```python 50 | # 1. User provides requirements through text_to_image prompt 51 | response = await text_to_image() 52 | 53 | # 2. Search for models based on requirements 54 | models = await search_available_models( 55 | query="cat portrait", 56 | style="photorealistic" 57 | ) 58 | 59 | # 3. Get details for user's chosen model 60 | model = await get_model_details("stability-ai/sdxl") 61 | 62 | # 4. Generate image with chosen model 63 | result = await generate_image( 64 | model_version=model.latest_version.id, 65 | prompt="a photorealistic cat portrait", 66 | style="photorealistic", 67 | quality="balanced" 68 | ) 69 | ``` 70 | 71 | ### Key Principles 72 | 73 | 1. **User Agency** 74 | 75 | - System suggests but doesn't decide 76 | - User makes final model selection 77 | - Parameters can be overridden 78 | 79 | 2. **Transparency** 80 | 81 | - Model scoring is visible 82 | - Capabilities are clearly communicated 83 | - Limitations are disclosed 84 | 85 | 3. **Flexibility** 86 | 87 | - Multiple models can be used 88 | - Parameters are customizable 89 | - Quality/style presets are optional 90 | 91 | 4. **Reliability** 92 | - Version stability is considered 93 | - Error handling is robust 94 | - Progress is tracked 95 | 96 | ## Template Usage 97 | 98 | Templates provide consistent parameter sets but should be used flexibly: 99 | 100 | 1. **Quality Presets** 101 | 102 | - Provide baseline parameters 103 | - Can be overridden 104 | - Match user's speed/quality needs 105 | 106 | 2. **Style Presets** 107 | 108 | - Enhance prompts 109 | - Add style-specific parameters 110 | - Are optional and customizable 111 | 112 | 3. **Aspect Ratio Presets** 113 | - Match common use cases 114 | - Ensure valid dimensions 115 | - Can be customized 116 | 117 | ## Error Handling 118 | 119 | 1. **Model Selection** 120 | 121 | - Handle no matches gracefully 122 | - Provide alternatives 123 | - Explain limitations 124 | 125 | 2. **Parameter Validation** 126 | 127 | - Validate before submission 128 | - Provide clear error messages 129 | - Suggest corrections 130 | 131 | 3. **Generation Failures** 132 | - Implement retry logic 133 | - Track progress 134 | - Provide status updates 135 | 136 | ## Best Practices 137 | 138 | 1. **Model Selection** 139 | 140 | - Always let user choose model 141 | - Provide scoring context 142 | - Explain trade-offs 143 | 144 | 2. **Parameter Configuration** 145 | 146 | - Use presets as starting points 147 | - Allow customization 148 | - Validate combinations 149 | 150 | 3. **Error Handling** 151 | 152 | - Be proactive about potential issues 153 | - Provide clear error messages 154 | - Implement proper retry logic 155 | 156 | 4. 
**User Interaction**
157 |    - Guide, don't decide
158 |    - Explain options
159 |    - Respect user choices
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to MCP Server Replicate
2 | 
3 | First off, thank you for considering contributing to MCP Server Replicate! It's people like you that make it a great tool for everyone.
4 | 
5 | ## Code of Conduct
6 | 
7 | This project and everyone participating in it is governed by our Code of Conduct. By participating, you are expected to uphold this code.
8 | 
9 | ## How Can I Contribute?
10 | 
11 | ### Reporting Bugs
12 | 
13 | Before creating bug reports, please check the issue list, as you might find you don't need to create one. When you are creating a bug report, please include as many details as possible:
14 | 
15 | - Use a clear and descriptive title
16 | - Describe the exact steps which reproduce the problem
17 | - Provide specific examples to demonstrate the steps
18 | - Describe the behavior you observed after following the steps
19 | - Explain which behavior you expected to see instead and why
20 | - Include any error messages or logs
21 | 
22 | ### Suggesting Enhancements
23 | 
24 | Enhancement suggestions are tracked as GitHub issues. When creating an enhancement suggestion, please include:
25 | 
26 | - A clear and descriptive title
27 | - A detailed description of the proposed functionality
28 | - Any possible drawbacks or alternatives you've considered
29 | - If possible, a rough implementation plan
30 | 
31 | ### Pull Requests
32 | 
33 | 1. Fork the repo and create your branch from `main`
34 | 2. If you've added code that should be tested, add tests
35 | 3. If you've changed APIs, update the documentation
36 | 4. Ensure the test suite passes
37 | 5. Make sure your code follows the existing style
38 | 6. Issue that pull request!
39 | 
40 | ## Development Process
41 | 
42 | 1. Set up your development environment:
43 | 
44 |    ```bash
45 |    # Clone your fork
46 |    git clone https://github.com/your-username/mcp-server-replicate.git
47 |    cd mcp-server-replicate
48 | 
49 |    # Create virtual environment
50 |    python -m venv .venv
51 |    source .venv/bin/activate  # Linux/macOS
52 |    # or
53 |    .venv\Scripts\activate  # Windows
54 | 
55 |    # Install development dependencies
56 |    pip install -e ".[dev]"
57 | 
58 |    # Install pre-commit hooks
59 |    pre-commit install
60 |    ```
61 | 
62 | 2. Make your changes:
63 | 
64 |    - Write your code
65 |    - Add or update tests
66 |    - Update documentation
67 |    - Run the test suite
68 | 
69 | 3. Commit your changes:
70 | 
71 |    ```bash
72 |    # Stage your changes
73 |    git add .
74 | 
75 |    # Commit using conventional commits
76 |    git commit -m "feat: add amazing feature"
77 |    # or
78 |    git commit -m "fix: resolve issue with something"
79 |    ```
80 | 
81 | 4. 
Push and create a PR: 82 | 83 | ```bash 84 | git push origin your-branch-name 85 | ``` 86 | 87 | ## Style Guide 88 | 89 | ### Python Code Style 90 | 91 | - Follow [PEP 8](https://www.python.org/dev/peps/pep-0008/) 92 | - Use [Black](https://github.com/psf/black) for formatting 93 | - Use [Ruff](https://github.com/astral-sh/ruff) for linting 94 | - Use [mypy](https://github.com/python/mypy) for type checking 95 | - Write [Google-style docstrings](https://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings) 96 | 97 | ### Commit Messages 98 | 99 | Follow [Conventional Commits](https://www.conventionalcommits.org/): 100 | 101 | - `feat:` for new features 102 | - `fix:` for bug fixes 103 | - `docs:` for documentation changes 104 | - `style:` for formatting changes 105 | - `refactor:` for code refactoring 106 | - `test:` for adding tests 107 | - `chore:` for maintenance tasks 108 | 109 | ### Testing 110 | 111 | - Write unit tests for new functionality 112 | - Ensure all tests pass before submitting PR 113 | - Maintain or improve code coverage 114 | - Test edge cases and error conditions 115 | 116 | ## Documentation 117 | 118 | - Update README.md if needed 119 | - Add docstrings to new functions/classes 120 | - Update API documentation 121 | - Include examples for new features 122 | 123 | ## Review Process 124 | 125 | 1. Automated checks must pass: 126 | 127 | - Tests 128 | - Linting 129 | - Type checking 130 | - Code coverage 131 | 132 | 2. Code review requirements: 133 | 134 | - At least one approval 135 | - No unresolved comments 136 | - All automated checks pass 137 | 138 | 3. Merge requirements: 139 | - Up-to-date with main branch 140 | - No conflicts 141 | - All review requirements met 142 | 143 | ## Getting Help 144 | 145 | - Check the [documentation](docs/) 146 | - Join our [Discord community](https://discord.gg/cursor) 147 | - Ask in GitHub Discussions 148 | - Tag maintainers in issues/PRs 149 | 150 | ## Recognition 151 | 152 | Contributors will be recognized in: 153 | 154 | - The project's README 155 | - Release notes 156 | - GitHub's contributors page 157 | 158 | Thank you for contributing to MCP Server Replicate! 🎉 159 | -------------------------------------------------------------------------------- /PLAN.md: -------------------------------------------------------------------------------- 1 | # MCP Server Replicate - Implementation Plan 2 | 3 | ## Current Status 4 | 5 | The MCP Server Replicate project implements a FastMCP server for the Replicate API, providing: 6 | 7 | - Resource-based image generation and management 8 | - Subscription-based updates for generation progress 9 | - Template-driven parameter configuration 10 | - Comprehensive model discovery and selection 11 | - Webhook integration for external notifications 12 | 13 | ## Core Components 14 | 15 | ### 1. Resource Management 16 | 17 | - ✅ Generation resource templates 18 | - ✅ Resource subscription system 19 | - ✅ Resource listing and filtering 20 | - ✅ Resource search capabilities 21 | - ✅ Status-based filtering 22 | 23 | ### 2. Image Generation 24 | 25 | - ✅ Text-to-image generation 26 | - ✅ Quality presets 27 | - ✅ Style presets 28 | - ✅ Progress tracking 29 | - ✅ Error handling 30 | - ✅ Resource-based results 31 | 32 | ### 3. Model Management 33 | 34 | - ✅ Model discovery 35 | - ✅ Model search 36 | - ✅ Collection management 37 | - ✅ Hardware options 38 | - ✅ Version tracking 39 | 40 | ### 4. 
Template System 41 | 42 | - ✅ Parameter validation 43 | - ✅ Quality presets 44 | - ✅ Style presets 45 | - ✅ Version tracking 46 | - ✅ Schema validation 47 | 48 | ### 5. Client Integration 49 | 50 | - ✅ Async HTTP client 51 | - ✅ Rate limiting 52 | - ✅ Error handling 53 | - ✅ Webhook support 54 | - ✅ Resource streaming 55 | 56 | ## Upcoming Features 57 | 58 | ### Short Term (1-2 months) 59 | 60 | 1. Image-to-Image Generation 61 | 62 | - Support for image transformation 63 | - Inpainting capabilities 64 | - Style transfer 65 | - Upscaling 66 | 67 | 2. Enhanced Resource Management 68 | 69 | - Resource caching 70 | - Batch operations 71 | - Resource metadata 72 | - Resource tagging 73 | 74 | 3. Advanced Templates 75 | - Custom template creation 76 | - Template inheritance 77 | - Dynamic parameter validation 78 | - Template versioning 79 | 80 | ### Medium Term (3-6 months) 81 | 82 | 1. Advanced Model Features 83 | 84 | - Model fine-tuning support 85 | - Custom model deployment 86 | - Model performance metrics 87 | - A/B testing capabilities 88 | 89 | 2. Enhanced Monitoring 90 | 91 | - Usage analytics 92 | - Cost tracking 93 | - Performance monitoring 94 | - Error reporting 95 | 96 | 3. Integration Features 97 | - OAuth support 98 | - API key rotation 99 | - Rate limit optimization 100 | - Webhook enhancements 101 | 102 | ### Long Term (6+ months) 103 | 104 | 1. Enterprise Features 105 | 106 | - Multi-tenant support 107 | - Resource quotas 108 | - Audit logging 109 | - Role-based access 110 | 111 | 2. Advanced Workflows 112 | 113 | - Pipeline creation 114 | - Workflow templates 115 | - Custom scheduling 116 | - Result post-processing 117 | 118 | 3. Developer Tools 119 | - CLI improvements 120 | - SDK generation 121 | - Documentation tooling 122 | - Testing utilities 123 | 124 | ## Contributing 125 | 126 | See [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines on contributing to this project. 127 | 128 | ## Implementation Notes 129 | 130 | ### Resource System 131 | 132 | The resource system is implemented using FastMCP's resource capabilities: 133 | 134 | - Resources are identified by URIs (e.g., `generations://123`) 135 | - Resources support subscription for updates 136 | - Resources can be listed, filtered, and searched 137 | - Resources maintain proper state management 138 | 139 | ### Template System 140 | 141 | Templates provide structured parameter handling: 142 | 143 | - JSON Schema validation 144 | - Version tracking 145 | - Parameter inheritance 146 | - Default values 147 | - Validation rules 148 | 149 | ### Client Integration 150 | 151 | The client system provides robust API interaction: 152 | 153 | - Async operations 154 | - Proper error handling 155 | - Rate limiting 156 | - Resource streaming 157 | - Webhook support 158 | 159 | ### Installation and Usage 160 | 161 | The package provides two main ways to run the server: 162 | 163 | 1. Using UVX (recommended): 164 | ```bash 165 | uvx mcp-server-replicate 166 | ``` 167 | 168 | 2. Using UV directly: 169 | ```bash 170 | uv run mcp-server-replicate 171 | ``` 172 | 173 | The server can be integrated with Claude Desktop by configuring the appropriate command in `claude_desktop_config.json`. 174 | 175 | ## Testing Strategy 176 | 177 | 1. Unit Tests 178 | 179 | - Component isolation 180 | - Mocked dependencies 181 | - Edge case coverage 182 | - Error scenarios 183 | 184 | 2. Integration Tests 185 | 186 | - API interaction 187 | - Resource management 188 | - Template validation 189 | - Client operations 190 | 191 | 3. 
End-to-End Tests 192 | - Complete workflows 193 | - Real API interaction 194 | - Performance testing 195 | - Load testing 196 | -------------------------------------------------------------------------------- /tests/unit/test_parameters/test_controlnet.py: 1 | """Tests for ControlNet parameter templates.""" 2 | 3 | import pytest 4 | from typing import Dict, Any 5 | 6 | from mcp_server_replicate.templates.parameters.controlnet import ( 7 | TEMPLATES, 8 | CANNY_PARAMETERS, 9 | DEPTH_PARAMETERS, 10 | POSE_PARAMETERS, 11 | SEGMENTATION_PARAMETERS, 12 | ) 13 | 14 | def test_templates_export(): 15 | """Test that all templates are properly exported.""" 16 | assert "canny" in TEMPLATES 17 | assert "depth" in TEMPLATES 18 | assert "pose" in TEMPLATES 19 | assert "segmentation" in TEMPLATES 20 | 21 | @pytest.mark.parametrize("template", [ 22 | CANNY_PARAMETERS, 23 | DEPTH_PARAMETERS, 24 | POSE_PARAMETERS, 25 | SEGMENTATION_PARAMETERS, 26 | ]) 27 | def test_template_structure(template: Dict[str, Any]): 28 | """Test that each template has the required structure.""" 29 | assert "id" in template 30 | assert "name" in template 31 | assert "description" in template 32 | assert "model_type" in template 33 | assert "control_type" in template 34 | assert "default_parameters" in template 35 | assert "parameter_schema" in template 36 | assert "version" in template 37 | 38 | # Check parameter schema structure 39 | schema = template["parameter_schema"] 40 | assert schema["type"] == "object" 41 | assert "properties" in schema 42 | assert "required" in schema 43 | 44 | # Check required base parameters 45 | properties = schema["properties"] 46 | assert "prompt" in properties 47 | assert "image" in properties 48 | assert "num_inference_steps" in properties 49 | assert "guidance_scale" in properties 50 | assert "controlnet_conditioning_scale" in properties 51 | 52 | # Check required fields are listed 53 | assert "prompt" in schema["required"] 54 | assert "image" in schema["required"] 55 | 56 | def test_canny_parameters(): 57 | """Test Canny edge detection specific parameters.""" 58 | template = CANNY_PARAMETERS 59 | 60 | # Check Canny-specific parameters 61 | assert "low_threshold" in template["default_parameters"] 62 | assert "high_threshold" in template["default_parameters"] 63 | 64 | properties = template["parameter_schema"]["properties"] 65 | assert "low_threshold" in properties 66 | assert properties["low_threshold"]["type"] == "integer" 67 | assert properties["low_threshold"]["minimum"] == 1 68 | assert properties["low_threshold"]["maximum"] == 255 69 | 70 | assert "high_threshold" in properties 71 | assert properties["high_threshold"]["type"] == "integer" 72 | assert properties["high_threshold"]["minimum"] == 1 73 | assert properties["high_threshold"]["maximum"] == 255 74 | 75 | def test_depth_parameters(): 76 | """Test depth estimation specific parameters.""" 77 | template = DEPTH_PARAMETERS 78 | 79 | # Check depth-specific parameters 80 | assert "detect_resolution" in template["default_parameters"] 81 | assert "boost" in template["default_parameters"] 82 | 83 | properties = template["parameter_schema"]["properties"] 84 | assert "detect_resolution" in properties 85 | assert properties["detect_resolution"]["type"] == "integer" 86 | assert properties["detect_resolution"]["minimum"] == 128 87 | assert properties["detect_resolution"]["maximum"] == 1024 88 | 89 | assert "boost" in properties 90 | assert properties["boost"]["type"] == 
"number" 91 | assert properties["boost"]["minimum"] == 0.0 92 | assert properties["boost"]["maximum"] == 2.0 93 | 94 | def test_pose_parameters(): 95 | """Test pose detection specific parameters.""" 96 | template = POSE_PARAMETERS 97 | 98 | # Check pose-specific parameters 99 | assert "detect_resolution" in template["default_parameters"] 100 | assert "include_hand_pose" in template["default_parameters"] 101 | assert "include_face_landmarks" in template["default_parameters"] 102 | 103 | properties = template["parameter_schema"]["properties"] 104 | assert "detect_resolution" in properties 105 | assert "include_hand_pose" in properties 106 | assert properties["include_hand_pose"]["type"] == "boolean" 107 | assert "include_face_landmarks" in properties 108 | assert properties["include_face_landmarks"]["type"] == "boolean" 109 | 110 | def test_segmentation_parameters(): 111 | """Test segmentation specific parameters.""" 112 | template = SEGMENTATION_PARAMETERS 113 | 114 | # Check segmentation-specific parameters 115 | assert "detect_resolution" in template["default_parameters"] 116 | assert "output_type" in template["default_parameters"] 117 | 118 | properties = template["parameter_schema"]["properties"] 119 | assert "detect_resolution" in properties 120 | assert "output_type" in properties 121 | assert properties["output_type"]["type"] == "string" 122 | assert "ade20k" in properties["output_type"]["enum"] 123 | assert "coco" in properties["output_type"]["enum"] -------------------------------------------------------------------------------- /.cursorrules: -------------------------------------------------------------------------------- 1 | // You are a Staff/Principal Python Engineer with deep expertise in: 2 | // - Modern Python development practices (Python 3.11+) 3 | // - FastMCP and Model Context Protocol implementations 4 | // - Replicate API integration and model management 5 | // - Async Python and high-performance patterns 6 | // Your role is to ensure this codebase maintains the highest standards of Python engineering. 
7 | 8 | "pyproject.toml": 9 | - "Target Python 3.11+ with type safety and modern features" 10 | - "Use hatchling for modern, standardized builds" 11 | - "Maintain strict dependency versioning (fastmcp>=0.1.0, replicate>=0.15.0)" 12 | - "Configure comprehensive test coverage with pytest-cov" 13 | - "Enforce strict type checking with mypy (disallow_untyped_defs=true)" 14 | - "Use ruff for fast, comprehensive linting (E, F, B, I, UP rules)" 15 | - "Maintain 88-char line length (black) with 120-char linting tolerance" 16 | - "Use uv for dependency management and virtual environments" 17 | - "Keep dev dependencies in optional-dependencies section" 18 | - "Define console scripts for CLI entry points" 19 | 20 | "*.py": 21 | - "Use strict type hints with Python 3.11+ features" 22 | - "Write comprehensive docstrings following Google style" 23 | - "Follow PEP 8 with black formatting" 24 | - "Use Pydantic v2 models for all data structures" 25 | - "Implement structured logging with proper levels" 26 | - "Use pathlib exclusively for file operations" 27 | - "Handle errors with custom exception hierarchies" 28 | - "Implement proper async context managers" 29 | - "Use f-strings for all string formatting" 30 | - "Avoid mutable default arguments" 31 | - "Use Pydantic v2 field_validator instead of validator" 32 | - "Implement proper cleanup in context managers" 33 | - "Use dataclasses for internal data structures" 34 | 35 | "src/mcp_server_replicate/server.py": 36 | - "Create FastMCP instance with descriptive name and version" 37 | - "Organize tools by model type and functionality" 38 | - "Use async handlers for all I/O operations" 39 | - "Implement proper validation with Pydantic models" 40 | - "Return structured responses from tools" 41 | - "Handle errors with appropriate status codes" 42 | - "Use descriptive tool docstrings" 43 | - "Configure logging with proper context" 44 | - "Use proper typing for tool parameters" 45 | - "Implement graceful shutdown handling" 46 | - "Track tool usage metrics" 47 | - "Handle concurrent requests properly" 48 | 49 | "src/mcp_server_replicate/templates/parameters/**/*.py": 50 | - "Use Pydantic models for parameter validation" 51 | - "Implement version tracking for templates" 52 | - "Share common parameters via base classes" 53 | - "Document all parameters comprehensively" 54 | - "Validate parameter constraints" 55 | - "Use proper type hints" 56 | - "Export templates via __all__" 57 | - "Maintain backward compatibility" 58 | - "Include parameter examples" 59 | - "Document parameter interactions" 60 | - "Handle model-specific requirements" 61 | - "Implement parameter validation logic" 62 | 63 | "src/mcp_server_replicate/replicate_client.py": 64 | - "Use httpx with proper timeout handling" 65 | - "Implement exponential backoff with rate limiting" 66 | - "Use proper API versioning" 67 | - "Implement comprehensive error mapping" 68 | - "Use structured logging for API operations" 69 | - "Implement proper connection pooling" 70 | - "Handle API authentication securely" 71 | - "Use async context managers" 72 | - "Implement proper request retries" 73 | - "Handle rate limiting headers" 74 | - "Track API usage metrics" 75 | - "Cache responses appropriately" 76 | 77 | "src/mcp_server_replicate/models/**/*.py": 78 | - "Use Pydantic models for all data structures" 79 | - "Implement proper validation logic" 80 | - "Handle model versioning" 81 | - "Track model usage metrics" 82 | - "Implement proper caching" 83 | - "Handle model-specific configurations" 84 | - "Validate model 
compatibility" 85 | - "Track model performance metrics" 86 | - "Handle model updates gracefully" 87 | - "Implement proper cleanup" 88 | 89 | "tests/**/*.py": 90 | - "Maintain minimum 90% coverage" 91 | - "Use pytest fixtures for common scenarios" 92 | - "Implement proper async test patterns" 93 | - "Use pytest-randomly with fixed seeds" 94 | - "Mock external services comprehensively" 95 | - "Test all error conditions" 96 | - "Implement proper cleanup in fixtures" 97 | - "Use coverage exclusions appropriately" 98 | - "Test parameter validation thoroughly" 99 | - "Use parametrized tests for variations" 100 | - "Test async timeouts and cancellation" 101 | - "Validate error responses" 102 | - "Test rate limiting handling" 103 | - "Test concurrent operations" 104 | 105 | ".pre-commit-config.yaml": 106 | - "Configure ruff with specified rule sets" 107 | - "Enable strict mypy type checking" 108 | - "Run pytest with coverage enforcement" 109 | - "Enforce black formatting" 110 | - "Check for security issues" 111 | - "Validate pyproject.toml format" 112 | - "Check for large files" 113 | - "Validate JSON/YAML syntax" 114 | - "Check for merge conflicts" 115 | - "Enforce commit message format" 116 | - "Check for debug statements" 117 | 118 | "docs/**/*.md": 119 | - "Follow Google documentation style" 120 | - "Include code examples" 121 | - "Document all parameters" 122 | - "Explain error scenarios" 123 | - "Provide troubleshooting guides" 124 | - "Include version compatibility" 125 | - "Document breaking changes" 126 | - "Add API reference" 127 | - "Include architecture diagrams" 128 | - "Document performance considerations" 129 | - "Provide upgrade guides" 130 | - "Include security considerations" 131 | -------------------------------------------------------------------------------- /src/mcp_server_replicate/templates/parameters/controlnet.py: -------------------------------------------------------------------------------- 1 | """Parameter templates for ControlNet models.""" 2 | 3 | from typing import Dict, Any 4 | from enum import Enum 5 | 6 | class ControlMode(str, Enum): 7 | """Control modes for ControlNet.""" 8 | BALANCED = "balanced" 9 | PROMPT = "prompt" 10 | CONTROL = "control" 11 | 12 | CONTROLNET_PARAMETERS = { 13 | "id": "controlnet-base", 14 | "name": "ControlNet Base Parameters", 15 | "description": "Parameters for ControlNet-enabled Stable Diffusion models", 16 | "model_type": "controlnet", 17 | "default_parameters": { 18 | "control_mode": "balanced", 19 | "control_scale": 0.9, 20 | "begin_control_step": 0.0, 21 | "end_control_step": 1.0, 22 | "detection_resolution": 512, 23 | "image_resolution": 512, 24 | "guess_mode": False, 25 | }, 26 | "parameter_schema": { 27 | "type": "object", 28 | "properties": { 29 | "control_image": { 30 | "type": "string", 31 | "format": "uri", 32 | "description": "URL or base64 of the control image (edge map, depth map, pose, etc.)" 33 | }, 34 | "control_mode": { 35 | "type": "string", 36 | "enum": ["balanced", "prompt", "control"], 37 | "description": "How to balance between prompt and control. balanced=0.5/0.5, prompt=0.25/0.75, control=0.75/0.25" 38 | }, 39 | "control_scale": { 40 | "type": "number", 41 | "minimum": 0.0, 42 | "maximum": 2.0, 43 | "description": "Overall influence of the control signal. Higher values = stronger control." 
44 | }, 45 | "begin_control_step": { 46 | "type": "number", 47 | "minimum": 0.0, 48 | "maximum": 1.0, 49 | "description": "When to start applying control (0.0 = start, 1.0 = end)" 50 | }, 51 | "end_control_step": { 52 | "type": "number", 53 | "minimum": 0.0, 54 | "maximum": 1.0, 55 | "description": "When to stop applying control (0.0 = start, 1.0 = end)" 56 | }, 57 | "detection_resolution": { 58 | "type": "integer", 59 | "minimum": 256, 60 | "maximum": 1024, 61 | "multipleOf": 8, 62 | "description": "Resolution for control signal detection. Higher = more detail but slower." 63 | }, 64 | "image_resolution": { 65 | "type": "integer", 66 | "minimum": 256, 67 | "maximum": 1024, 68 | "multipleOf": 8, 69 | "description": "Output image resolution. Should match detection_resolution for best results." 70 | }, 71 | "guess_mode": { 72 | "type": "boolean", 73 | "description": "Enable 'guess mode' for reference-only control (no exact matching)" 74 | }, 75 | "preprocessor": { 76 | "type": "string", 77 | "enum": [ 78 | "canny", 79 | "depth", 80 | "mlsd", 81 | "normal", 82 | "openpose", 83 | "scribble", 84 | "seg", 85 | "shuffle", 86 | "softedge", 87 | "tile" 88 | ], 89 | "description": "Type of preprocessing to apply to control image" 90 | } 91 | }, 92 | "required": ["control_image", "preprocessor"], 93 | "dependencies": { 94 | "preprocessor": { 95 | "oneOf": [ 96 | { 97 | "properties": { 98 | "preprocessor": {"enum": ["canny"]}, 99 | "low_threshold": { 100 | "type": "integer", 101 | "minimum": 1, 102 | "maximum": 255, 103 | "description": "Lower threshold for Canny edge detection" 104 | }, 105 | "high_threshold": { 106 | "type": "integer", 107 | "minimum": 1, 108 | "maximum": 255, 109 | "description": "Upper threshold for Canny edge detection" 110 | } 111 | } 112 | }, 113 | { 114 | "properties": { 115 | "preprocessor": {"enum": ["mlsd"]}, 116 | "score_threshold": { 117 | "type": "number", 118 | "minimum": 0.1, 119 | "maximum": 0.9, 120 | "description": "Confidence threshold for line detection" 121 | }, 122 | "distance_threshold": { 123 | "type": "number", 124 | "minimum": 0.1, 125 | "maximum": 20.0, 126 | "description": "Distance threshold for line merging" 127 | } 128 | } 129 | } 130 | ] 131 | } 132 | } 133 | }, 134 | "version": "1.0.0" 135 | } 136 | 137 | # Export all templates 138 | TEMPLATES: Dict[str, Dict[str, Any]] = { 139 | "controlnet": CONTROLNET_PARAMETERS, 140 | } -------------------------------------------------------------------------------- /docs/templates.md: -------------------------------------------------------------------------------- 1 | # Template Documentation 2 | 3 | ## Overview 4 | 5 | This document provides comprehensive documentation for all templates available in the MCP Server for Replicate. Templates are organized into several categories: 6 | 7 | 1. Model Parameters 8 | 2. Common Configurations 9 | 3. Prompt Templates 10 | 11 | ## Model Parameters 12 | 13 | ### SDXL Parameters 14 | 15 | The SDXL template provides parameters optimized for Stable Diffusion XL models. 
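Before the full listing, here is a minimal sketch of how the template might be applied in code, assuming the `TEMPLATES` registry from `src/mcp_server_replicate/templates/parameters/stable_diffusion.py`. The `jsonschema` package is an illustrative choice here, not a dependency this project pins (the server exposes its own `validate_template_parameters` tool for the same purpose):

```python
# Minimal sketch: merge the SDXL defaults with user overrides, then validate.
# jsonschema is an assumed third-party helper, not part of this project.
from jsonschema import validate

from mcp_server_replicate.templates.parameters.stable_diffusion import TEMPLATES

template = TEMPLATES["sdxl"]
params = {**template["default_parameters"], "prompt": "a mountain lake at dawn"}
validate(instance=params, schema=template["parameter_schema"])  # raises on invalid input
```

The full parameter set with defaults: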
16 | 17 | ```python 18 | { 19 | "prompt": "your detailed prompt", 20 | "negative_prompt": "elements to avoid", 21 | "width": 1024, # 512-2048, multiple of 8 22 | "height": 1024, # 512-2048, multiple of 8 23 | "num_inference_steps": 50, # 1-150 24 | "guidance_scale": 7.5, # 1-20 25 | "prompt_strength": 1.0, # 0-1 26 | "refine": "expert_ensemble_refiner", # or "no_refiner", "base_image_refiner" 27 | "scheduler": "K_EULER", # or "DDIM", "DPM_MULTISTEP", "PNDM", "KLMS" 28 | "num_outputs": 1, # 1-4 29 | "high_noise_frac": 0.8, # 0-1 30 | "seed": None, # None or integer 31 | "apply_watermark": True 32 | } 33 | ``` 34 | 35 | ### SD 1.5 Parameters 36 | 37 | The SD 1.5 template provides parameters optimized for Stable Diffusion 1.5 models. 38 | 39 | ```python 40 | { 41 | "prompt": "your detailed prompt", 42 | "negative_prompt": "elements to avoid", 43 | "width": 512, # 256-1024, multiple of 8 44 | "height": 512, # 256-1024, multiple of 8 45 | "num_inference_steps": 50, # 1-150 46 | "guidance_scale": 7.5, # 1-20 47 | "scheduler": "K_EULER", # or "DDIM", "DPM_MULTISTEP", "PNDM", "KLMS" 48 | "num_outputs": 1, # 1-4 49 | "seed": None, # None or integer 50 | "apply_watermark": True 51 | } 52 | ``` 53 | 54 | ### ControlNet Parameters 55 | 56 | The ControlNet template provides parameters for controlled image generation. 57 | 58 | ```python 59 | { 60 | "control_image": "image_url_or_base64", 61 | "control_mode": "balanced", # or "prompt", "control" 62 | "control_scale": 0.9, # 0-2 63 | "begin_control_step": 0.0, # 0-1 64 | "end_control_step": 1.0, # 0-1 65 | "detection_resolution": 512, # 256-1024, multiple of 8 66 | "image_resolution": 512, # 256-1024, multiple of 8 67 | "guess_mode": False, 68 | "preprocessor": "canny" # or other preprocessors 69 | } 70 | ``` 71 | 72 | ## Common Configurations 73 | 74 | ### Quality Presets 75 | 76 | Pre-configured quality settings for different use cases: 77 | 78 | - `draft`: Fast iterations (20 steps) 79 | - `balanced`: General use (30 steps) 80 | - `quality`: High quality (50 steps) 81 | - `extreme`: Maximum quality (150 steps) 82 | 83 | ### Style Presets 84 | 85 | Pre-configured style settings: 86 | 87 | - `photorealistic`: Highly detailed photo style 88 | - `cinematic`: Movie-like dramatic style 89 | - `anime`: Anime/manga style 90 | - `digital_art`: Modern digital art style 91 | - `oil_painting`: Classical painting style 92 | 93 | ### Aspect Ratio Presets 94 | 95 | Common aspect ratios with optimal resolutions: 96 | 97 | - `square`: 1:1 (1024x1024) 98 | - `portrait`: 2:3 (832x1216) 99 | - `landscape`: 3:2 (1216x832) 100 | - `wide`: 16:9 (1344x768) 101 | - `mobile`: 9:16 (768x1344) 102 | 103 | ### Negative Prompt Presets 104 | 105 | Quality control negative prompts: 106 | 107 | - `quality_control`: Basic quality control 108 | - `strict_quality`: Comprehensive quality control 109 | - `photo_quality`: Photo-specific quality control 110 | - `artistic_quality`: Art-specific quality control 111 | 112 | ## Prompt Templates 113 | 114 | ### Text-to-Image 115 | 116 | #### Detailed Scene Template 117 | 118 | ``` 119 | {subject} in {setting}, {lighting} lighting, {mood} atmosphere, {style} style, {details} 120 | ``` 121 | 122 | Example: 123 | 124 | ``` 125 | "a young explorer in ancient temple ruins, dramatic golden hour lighting, mysterious atmosphere, cinematic style, vines growing on weathered stone, dust particles in light beams" 126 | ``` 127 | 
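Because the format strings use named `{placeholder}` fields, they can be filled directly with Python's `str.format`. A minimal sketch (the field values are illustrative, mirroring the example above):

```python
# Minimal sketch: filling the detailed-scene template with str.format.
DETAILED_SCENE = (
    "{subject} in {setting}, {lighting} lighting, "
    "{mood} atmosphere, {style} style, {details}"
)

prompt = DETAILED_SCENE.format(
    subject="a young explorer",
    setting="ancient temple ruins",
    lighting="dramatic golden hour",
    mood="mysterious",
    style="cinematic",
    details="vines growing on weathered stone, dust particles in light beams",
)
# prompt now matches the example string shown above
```

128 | #### Character Portrait Template 129 | 130 | ``` 131 | {gender} {character_type}, {appearance}, {clothing}, {expression},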
{pose}, {style} style, {background} 132 | ``` 133 | 134 | #### Landscape Template 135 | 136 | ``` 137 | {environment} landscape, {time_of_day}, {weather}, {features}, {style} style, {mood} mood 138 | ``` 139 | 140 | ### Image-to-Image 141 | 142 | #### Style Transfer Template 143 | 144 | ``` 145 | Transform into {style} style, {quality} quality, maintain {preserve} from original 146 | ``` 147 | 148 | #### Variation Template 149 | 150 | ``` 151 | Similar to original but with {changes}, {style} style, {quality} quality 152 | ``` 153 | 154 | ### ControlNet 155 | 156 | #### Pose-Guided Template 157 | 158 | ``` 159 | {subject} in {pose_description}, {clothing}, {style} style, {background} 160 | ``` 161 | 162 | #### Depth-Guided Template 163 | 164 | ``` 165 | {subject} with {depth_elements}, {perspective}, {style} style 166 | ``` 167 | 168 | ## Best Practices 169 | 170 | 1. **Parameter Selection** 171 | 172 | - Start with preset configurations 173 | - Adjust parameters gradually 174 | - Use appropriate aspect ratios for your use case 175 | 176 | 2. **Prompt Engineering** 177 | 178 | - Use detailed, specific descriptions 179 | - Include style and quality indicators 180 | - Use negative prompts for quality control 181 | 182 | 3. **ControlNet Usage** 183 | 184 | - Match detection and output resolutions 185 | - Use appropriate preprocessors for your use case 186 | - Adjust control scale based on desired influence 187 | 188 | 4. **Quality Optimization** 189 | - Use higher step counts for final outputs 190 | - Adjust guidance scale for creativity vs. accuracy 191 | - Use refiners for enhanced quality 192 | 193 | ## Version History 194 | 195 | - v1.1.0: Added comprehensive parameter descriptions and validation 196 | - v1.0.0: Initial release with basic parameters 197 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MCP Server Replicate 2 | 3 | [![Python Version](https://img.shields.io/badge/python-3.11%2B-blue.svg)](https://www.python.org/downloads/) 4 | [![License](https://img.shields.io/badge/license-MIT-green.svg)](LICENSE) 5 | [![Code Style](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) 6 | [![Type Checker](https://img.shields.io/badge/type%20checker-mypy-blue.svg)](https://github.com/python/mypy) 7 | [![Ruff](https://img.shields.io/badge/linter-ruff-red.svg)](https://github.com/astral-sh/ruff) 8 | [![PyPI version](https://badge.fury.io/py/mcp-server-replicate.svg)](https://pypi.org/project/mcp-server-replicate/) 9 | [![smithery badge](https://smithery.ai/badge/@gerred/mcp-server-replicate)](https://smithery.ai/server/@gerred/mcp-server-replicate) 10 | 11 | A FastMCP server implementation for the Replicate API, providing resource-based access to AI model inference with a focus on image generation. 
12 | 13 | Server Replicate MCP server 14 | 15 | ## Features 16 | 17 | - 🖼️ Resource-based image generation and management 18 | - 🔄 Real-time updates through subscriptions 19 | - 📝 Template-driven parameter configuration 20 | - 🔍 Comprehensive model discovery and selection 21 | - 🪝 Webhook integration for external notifications 22 | - 🎨 Quality and style presets for optimal results 23 | - 📊 Progress tracking and status monitoring 24 | - 🔒 Secure API key management 25 | 26 | ## Available Prompts 27 | 28 | The server provides several specialized prompts for different tasks: 29 | 30 | ### Text to Image (Primary) 31 | 32 | Our most thoroughly tested and robust prompt. Optimized for generating high-quality images from text descriptions with: 33 | 34 | - Detailed style control 35 | - Quality presets (draft, balanced, quality, extreme) 36 | - Size and aspect ratio customization 37 | - Progress tracking and real-time updates 38 | 39 | Example: 40 | 41 | ``` 42 | Create a photorealistic mountain landscape at sunset with snow-capped peaks, quality level: quality, style: photorealistic 43 | ``` 44 | 45 | ### Other Prompts 46 | 47 | - **Image to Image**: Transform existing images (coming soon) 48 | - **Model Selection**: Get help choosing the right model for your task 49 | - **Parameter Help**: Understand and configure model parameters 50 | 51 | ## Prerequisites 52 | 53 | - Python 3.11 or higher 54 | - A Replicate API key (get one at https://replicate.com/account) 55 | - [UV](https://github.com/astral-sh/uv) for dependency management 56 | 57 | ## Installation 58 | 59 | ### Installing via Smithery 60 | 61 | To install MCP Server Replicate for Claude Desktop automatically via [Smithery](https://smithery.ai/server/@gerred/mcp-server-replicate): 62 | 63 | ```bash 64 | npx -y @smithery/cli install @gerred/mcp-server-replicate --client claude 65 | ``` 66 | 67 | ### Installing Manually 68 | You can install the package directly from PyPI: 69 | 70 | ```bash 71 | # Using UV (recommended) 72 | uv pip install mcp-server-replicate 73 | 74 | # Using uv tool for an isolated install 75 | uv tool install mcp-server-replicate 76 | 77 | # Using pip 78 | pip install mcp-server-replicate 79 | ``` 80 | 81 | ## Claude Desktop Integration 82 | 83 | 1. Make sure you have the latest version of Claude Desktop installed 84 | 2. Open your Claude Desktop configuration: 85 | 86 | ```bash 87 | # macOS 88 | code ~/Library/Application\ Support/Claude/claude_desktop_config.json 89 | 90 | # Windows 91 | code %APPDATA%\Claude\claude_desktop_config.json 92 | ``` 93 | 94 | 3. Add the server configuration: 95 | 96 | ```json 97 | { 98 | "globalShortcut": "Shift+Alt+A", 99 | "mcpServers": { 100 | "replicate": { 101 | "command": "uv", 102 | "args": ["tool", "run", "mcp-server-replicate"], 103 | "env": { 104 | "REPLICATE_API_TOKEN": "APITOKEN" 105 | }, 106 | "cwd": "$PATH_TO_REPO" 107 | } 108 | } 109 | } 110 | ``` 111 | 112 | 4. Set your Replicate API key: 113 | 114 | ```bash 115 | # Option 1: Set in your environment 116 | export REPLICATE_API_TOKEN=your_api_key_here 117 | 118 | # Option 2: Create a .env file in your home directory 119 | echo "REPLICATE_API_TOKEN=your_api_key_here" > ~/.env 120 | ``` 121 | 122 | 5. Restart Claude Desktop completely 123 | 124 | You should now see the 🔨 icon in Claude Desktop, indicating that the MCP server is available. 125 | 126 | ## Usage 127 | 128 | Once connected to Claude Desktop, you can: 129 | 130 | 1.
Generate images with natural language: 131 | 132 | ``` 133 | Create a photorealistic mountain landscape at sunset with snow-capped peaks 134 | ``` 135 | 136 | 2. Browse your generations: 137 | 138 | ``` 139 | Show me my recent image generations 140 | ``` 141 | 142 | 3. Search through generations: 143 | 144 | ``` 145 | Find my landscape generations 146 | ``` 147 | 148 | 4. Check generation status: 149 | ``` 150 | What's the status of my last generation? 151 | ``` 152 | 153 | ## Troubleshooting 154 | 155 | ### Server not showing up in Claude Desktop 156 | 157 | 1. Check the Claude Desktop logs: 158 | 159 | ```bash 160 | tail -n 20 -f ~/Library/Logs/Claude/mcp*.log 161 | ``` 162 | 163 | 2. Verify your configuration: 164 | 165 | - Make sure the path in `claude_desktop_config.json` is absolute 166 | - Ensure UV is installed and in your PATH 167 | - Check that your Replicate API key is set 168 | 169 | 3. Try restarting Claude Desktop 170 | 171 | For more detailed troubleshooting, see our [Debugging Guide](docs/debugging.md). 172 | 173 | ## Documentation 174 | 175 | - [Implementation Plan](PLAN.md) 176 | - [Contributing Guide](CONTRIBUTING.md) 177 | - [API Reference](docs/api.md) 178 | - [Resource System](docs/resources.md) 179 | - [Template System](docs/templates.md) 180 | 181 | ## Development 182 | 183 | 1. Clone the repository: 184 | 185 | ```bash 186 | git clone https://github.com/gerred/mcp-server-replicate.git 187 | cd mcp-server-replicate 188 | ``` 189 | 190 | 2. Install development dependencies: 191 | 192 | ```bash 193 | uv pip install --system ".[dev]" 194 | ``` 195 | 196 | 3. Install pre-commit hooks: 197 | 198 | ```bash 199 | pre-commit install 200 | ``` 201 | 202 | 4. Run tests: 203 | 204 | ```bash 205 | pytest 206 | ``` 207 | 208 | ## Contributing 209 | 210 | We welcome contributions! Please see our [Contributing Guide](CONTRIBUTING.md) for details. 211 | 212 | ## License 213 | 214 | This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. 
215 | -------------------------------------------------------------------------------- /src/mcp_server_replicate/templates/parameters/common_configs.py: -------------------------------------------------------------------------------- 1 | """Common model configuration templates that can be reused across different models.""" 2 | 3 | from typing import Any 4 | 5 | QUALITY_PRESETS = { 6 | "id": "quality-presets", 7 | "name": "Quality Presets", 8 | "description": "Common quality presets for different generation scenarios", 9 | "model_type": "any", 10 | "presets": { 11 | "draft": { 12 | "description": "Fast draft quality for quick iterations", 13 | "parameters": { 14 | "num_inference_steps": 20, 15 | "guidance_scale": 5.0, 16 | "width": 512, 17 | "height": 512, 18 | }, 19 | }, 20 | "balanced": { 21 | "description": "Balanced quality and speed for most use cases", 22 | "parameters": { 23 | "num_inference_steps": 30, 24 | "guidance_scale": 7.5, 25 | "width": 768, 26 | "height": 768, 27 | }, 28 | }, 29 | "quality": { 30 | "description": "High quality for final outputs", 31 | "parameters": { 32 | "num_inference_steps": 50, 33 | "guidance_scale": 7.5, 34 | "width": 1024, 35 | "height": 1024, 36 | }, 37 | }, 38 | "extreme": { 39 | "description": "Maximum quality, very slow", 40 | "parameters": { 41 | "num_inference_steps": 150, 42 | "guidance_scale": 8.0, 43 | "width": 1536, 44 | "height": 1536, 45 | }, 46 | }, 47 | }, 48 | "version": "1.0.0", 49 | } 50 | 51 | STYLE_PRESETS = { 52 | "id": "style-presets", 53 | "name": "Style Presets", 54 | "description": "Common style presets for different artistic looks", 55 | "model_type": "any", 56 | "presets": { 57 | "photorealistic": { 58 | "description": "Highly detailed photorealistic style", 59 | "parameters": { 60 | "prompt_prefix": "professional photograph, photorealistic, highly detailed, 8k uhd", 61 | "negative_prompt": "painting, drawing, illustration, anime, cartoon, artistic, unrealistic", 62 | "guidance_scale": 8.0, 63 | }, 64 | }, 65 | "cinematic": { 66 | "description": "Dramatic cinematic style", 67 | "parameters": { 68 | "prompt_prefix": "cinematic shot, dramatic lighting, movie scene, high budget film", 69 | "negative_prompt": "low quality, amateur, poorly lit", 70 | "guidance_scale": 7.5, 71 | }, 72 | }, 73 | "anime": { 74 | "description": "Anime/manga style", 75 | "parameters": { 76 | "prompt_prefix": "anime style, manga art, clean lines, vibrant colors", 77 | "negative_prompt": "photorealistic, 3d render, photograph, western art style", 78 | "guidance_scale": 7.0, 79 | }, 80 | }, 81 | "digital_art": { 82 | "description": "Digital art style", 83 | "parameters": { 84 | "prompt_prefix": "digital art, vibrant colors, detailed illustration", 85 | "negative_prompt": "photograph, realistic, grainy, noisy", 86 | "guidance_scale": 7.0, 87 | }, 88 | }, 89 | "oil_painting": { 90 | "description": "Oil painting style", 91 | "parameters": { 92 | "prompt_prefix": "oil painting, textured brushstrokes, artistic, rich colors", 93 | "negative_prompt": "photograph, digital art, 3d render, smooth", 94 | "guidance_scale": 7.0, 95 | }, 96 | }, 97 | }, 98 | "version": "1.0.0", 99 | } 100 | 101 | ASPECT_RATIO_PRESETS = { 102 | "id": "aspect-ratio-presets", 103 | "name": "Aspect Ratio Presets", 104 | "description": "Common aspect ratio presets for different use cases", 105 | "model_type": "any", 106 | "presets": { 107 | "square": { 108 | "description": "1:1 square format", 109 | "parameters": { 110 | "width": 1024, 111 | "height": 1024, 112 | }, 113 | }, 114 | 
"portrait": { 115 | "description": "2:3 portrait format", 116 | "parameters": { 117 | "width": 832, 118 | "height": 1216, 119 | }, 120 | }, 121 | "landscape": { 122 | "description": "3:2 landscape format", 123 | "parameters": { 124 | "width": 1216, 125 | "height": 832, 126 | }, 127 | }, 128 | "wide": { 129 | "description": "16:9 widescreen format", 130 | "parameters": { 131 | "width": 1344, 132 | "height": 768, 133 | }, 134 | }, 135 | "mobile": { 136 | "description": "9:16 mobile format", 137 | "parameters": { 138 | "width": 768, 139 | "height": 1344, 140 | }, 141 | }, 142 | }, 143 | "version": "1.0.0", 144 | } 145 | 146 | NEGATIVE_PROMPT_PRESETS = { 147 | "id": "negative-prompt-presets", 148 | "name": "Negative Prompt Presets", 149 | "description": "Common negative prompts for quality control", 150 | "model_type": "any", 151 | "presets": { 152 | "quality_control": { 153 | "description": "Basic quality control negative prompt", 154 | "parameters": {"negative_prompt": "ugly, blurry, low quality, distorted, disfigured, bad anatomy"}, 155 | }, 156 | "strict_quality": { 157 | "description": "Strict quality control negative prompt", 158 | "parameters": { 159 | "negative_prompt": "ugly, blurry, low quality, distorted, disfigured, bad anatomy, bad proportions, duplicate, extra limbs, missing limbs, poorly drawn face, poorly drawn hands, mutation, mutated, extra fingers, missing fingers, floating limbs, disconnected limbs, malformed limbs, oversaturated, undersaturated" 160 | }, 161 | }, 162 | "photo_quality": { 163 | "description": "Photo-specific quality control", 164 | "parameters": { 165 | "negative_prompt": "blurry, low quality, noise, grain, chromatic aberration, lens flare, overexposed, underexposed, bad composition, amateur, poorly lit" 166 | }, 167 | }, 168 | "artistic_quality": { 169 | "description": "Art-specific quality control", 170 | "parameters": { 171 | "negative_prompt": "amateur, poorly drawn, bad art, poorly drawn hands, poorly drawn face, poorly drawn eyes, poorly drawn nose, poorly drawn mouth, poorly drawn ears, poorly drawn body, poorly drawn legs, poorly drawn feet" 172 | }, 173 | }, 174 | }, 175 | "version": "1.0.0", 176 | } 177 | 178 | # Export all templates 179 | TEMPLATES: dict[str, dict[str, Any]] = { 180 | "quality": QUALITY_PRESETS, 181 | "style": STYLE_PRESETS, 182 | "aspect_ratio": ASPECT_RATIO_PRESETS, 183 | "negative_prompt": NEGATIVE_PROMPT_PRESETS, 184 | } 185 | -------------------------------------------------------------------------------- /tests/test_server.py: -------------------------------------------------------------------------------- 1 | """Tests for FastMCP server implementation.""" 2 | 3 | import json 4 | import pytest 5 | from unittest.mock import AsyncMock, patch 6 | from datetime import datetime 7 | 8 | from mcp-server-replicate.server import create_server 9 | from mcp-server-replicate.models.model import Model, ModelList 10 | from mcp-server-replicate.models.collection import Collection, CollectionList 11 | from mcp-server-replicate.models.hardware import Hardware, HardwareList 12 | from mcp-server-replicate.models.webhook import WebhookPayload 13 | 14 | # Test data 15 | MOCK_MODEL = { 16 | "owner": "stability-ai", 17 | "name": "sdxl", 18 | "description": "Stable Diffusion XL", 19 | "visibility": "public", 20 | "latest_version_id": "v1.0.0", 21 | "latest_version_created_at": "2024-01-01T00:00:00Z" 22 | } 23 | 24 | MOCK_COLLECTION = { 25 | "name": "Text to Image", 26 | "slug": "text-to-image", 27 | "description": "Models for generating images 
from text" 28 | } 29 | 30 | MOCK_HARDWARE = { 31 | "name": "GPU T4", 32 | "sku": "gpu-t4" 33 | } 34 | 35 | MOCK_PREDICTION = { 36 | "id": "pred_123", 37 | "status": "succeeded", 38 | "input": {"prompt": "test"}, 39 | "output": "result", 40 | "created_at": "2024-01-01T00:00:00Z" 41 | } 42 | 43 | @pytest.fixture 44 | async def server(): 45 | """Create server instance for testing.""" 46 | return create_server(log_level=0) 47 | 48 | @pytest.fixture 49 | def mock_client(): 50 | """Create mock ReplicateClient.""" 51 | with patch("mcp-server-replicate.server.ReplicateClient") as mock: 52 | client = AsyncMock() 53 | mock.return_value.__aenter__.return_value = client 54 | yield client 55 | 56 | # Model Tools Tests 57 | async def test_list_models(server, mock_client): 58 | """Test list_models tool.""" 59 | mock_client.list_models.return_value = { 60 | "models": [MOCK_MODEL], 61 | "next_cursor": "next", 62 | "total_models": 1 63 | } 64 | 65 | result = await server.tools["list_models"].func() 66 | assert isinstance(result, ModelList) 67 | assert len(result.models) == 1 68 | assert result.models[0].owner == MOCK_MODEL["owner"] 69 | assert result.next_cursor == "next" 70 | assert result.total_count == 1 71 | 72 | async def test_search_models(server, mock_client): 73 | """Test search_models tool.""" 74 | mock_client.search_models.return_value = { 75 | "models": [MOCK_MODEL], 76 | "next_cursor": None, 77 | "total_models": 1 78 | } 79 | 80 | result = await server.tools["search_models"].func(query="stable diffusion") 81 | assert isinstance(result, ModelList) 82 | assert len(result.models) == 1 83 | assert result.models[0].name == MOCK_MODEL["name"] 84 | 85 | # Collection Tools Tests 86 | async def test_list_collections(server, mock_client): 87 | """Test list_collections tool.""" 88 | mock_client.list_collections.return_value = [MOCK_COLLECTION] 89 | 90 | result = await server.tools["list_collections"].func() 91 | assert isinstance(result, CollectionList) 92 | assert len(result.collections) == 1 93 | assert result.collections[0].name == MOCK_COLLECTION["name"] 94 | 95 | async def test_get_collection_details(server, mock_client): 96 | """Test get_collection_details tool.""" 97 | mock_client.get_collection.return_value = MOCK_COLLECTION 98 | 99 | result = await server.tools["get_collection_details"].func(collection_slug="text-to-image") 100 | assert isinstance(result, Collection) 101 | assert result.name == MOCK_COLLECTION["name"] 102 | assert result.slug == MOCK_COLLECTION["slug"] 103 | 104 | # Hardware Tools Tests 105 | async def test_list_hardware(server, mock_client): 106 | """Test list_hardware tool.""" 107 | mock_client.list_hardware.return_value = [MOCK_HARDWARE] 108 | 109 | result = await server.tools["list_hardware"].func() 110 | assert isinstance(result, HardwareList) 111 | assert len(result.hardware) == 1 112 | assert result.hardware[0].name == MOCK_HARDWARE["name"] 113 | 114 | # Template Tools Tests 115 | async def test_list_templates(server): 116 | """Test list_templates tool.""" 117 | result = await server.tools["list_templates"].func() 118 | assert isinstance(result, dict) 119 | for template in result.values(): 120 | assert "schema" in template 121 | assert "description" in template 122 | assert "version" in template 123 | 124 | async def test_validate_template_parameters(server): 125 | """Test validate_template_parameters tool.""" 126 | # This requires actual template data from TEMPLATES 127 | with pytest.raises(ValueError): 128 | await 
server.tools["validate_template_parameters"].func({"template": "invalid"}) 129 | 130 | # Prediction Tools Tests 131 | async def test_create_prediction(server, mock_client): 132 | """Test create_prediction tool.""" 133 | mock_client.create_prediction.return_value.json.return_value = MOCK_PREDICTION 134 | 135 | result = await server.tools["create_prediction"].func({ 136 | "version": "v1", 137 | "input": {"prompt": "test"} 138 | }) 139 | assert result["id"] == MOCK_PREDICTION["id"] 140 | assert result["status"] == MOCK_PREDICTION["status"] 141 | 142 | async def test_get_prediction(server, mock_client): 143 | """Test get_prediction tool.""" 144 | mock_client.get_prediction.return_value.json.return_value = MOCK_PREDICTION 145 | 146 | result = await server.tools["get_prediction"].func("pred_123") 147 | assert result["id"] == MOCK_PREDICTION["id"] 148 | assert result["status"] == MOCK_PREDICTION["status"] 149 | 150 | async def test_cancel_prediction(server, mock_client): 151 | """Test cancel_prediction tool.""" 152 | mock_client.cancel_prediction.return_value.json.return_value = { 153 | **MOCK_PREDICTION, 154 | "status": "canceled" 155 | } 156 | 157 | result = await server.tools["cancel_prediction"].func("pred_123") 158 | assert result["id"] == MOCK_PREDICTION["id"] 159 | assert result["status"] == "canceled" 160 | 161 | # Webhook Tools Tests 162 | async def test_get_webhook_secret(server, mock_client): 163 | """Test get_webhook_secret tool.""" 164 | mock_client.get_webhook_secret.return_value = "secret123" 165 | 166 | result = await server.tools["get_webhook_secret"].func() 167 | assert result == "secret123" 168 | 169 | async def test_verify_webhook(server): 170 | """Test verify_webhook tool.""" 171 | payload = WebhookPayload( 172 | id="evt_123", 173 | created_at=datetime.now(), 174 | type="prediction.completed", 175 | data={"prediction": MOCK_PREDICTION} 176 | ) 177 | secret = "test_secret" 178 | 179 | # Calculate valid signature 180 | payload_str = json.dumps(payload.model_dump(), sort_keys=True) 181 | import hmac, hashlib 182 | signature = hmac.new( 183 | secret.encode(), 184 | payload_str.encode(), 185 | hashlib.sha256 186 | ).hexdigest() 187 | 188 | # Test valid signature 189 | result = await server.tools["verify_webhook"].func(payload, signature, secret) 190 | assert result is True 191 | 192 | # Test invalid signature 193 | result = await server.tools["verify_webhook"].func(payload, "invalid", secret) 194 | assert result is False 195 | 196 | # Test empty signature 197 | result = await server.tools["verify_webhook"].func(payload, "", secret) 198 | assert result is False -------------------------------------------------------------------------------- /src/mcp_server_replicate/templates/parameters/stable_diffusion.py: -------------------------------------------------------------------------------- 1 | """Parameter templates for Stable Diffusion models.""" 2 | 3 | from typing import Dict, Any 4 | 5 | SDXL_PARAMETERS = { 6 | "id": "sdxl-base", 7 | "name": "SDXL Base Parameters", 8 | "description": "Default parameters for SDXL models with comprehensive options for high-quality image generation", 9 | "model_type": "stable-diffusion", 10 | "default_parameters": { 11 | "width": 1024, 12 | "height": 1024, 13 | "num_inference_steps": 50, 14 | "guidance_scale": 7.5, 15 | "prompt_strength": 1.0, 16 | "refine": "expert_ensemble_refiner", 17 | "scheduler": "K_EULER", 18 | "num_outputs": 1, 19 | "high_noise_frac": 0.8, 20 | "seed": None, 21 | "apply_watermark": True, 22 | }, 23 | "parameter_schema": { 24 
| "type": "object", 25 | "properties": { 26 | "prompt": { 27 | "type": "string", 28 | "description": "Text prompt for image generation. Use descriptive language and artistic terms for better results.", 29 | "minLength": 1, 30 | "maxLength": 2000 31 | }, 32 | "negative_prompt": { 33 | "type": "string", 34 | "description": "Text prompt for elements to avoid. Common defaults: 'ugly, blurry, low quality, distorted'", 35 | "maxLength": 2000 36 | }, 37 | "width": { 38 | "type": "integer", 39 | "minimum": 512, 40 | "maximum": 2048, 41 | "multipleOf": 8, 42 | "description": "Image width in pixels. Must be multiple of 8. Larger sizes need more memory." 43 | }, 44 | "height": { 45 | "type": "integer", 46 | "minimum": 512, 47 | "maximum": 2048, 48 | "multipleOf": 8, 49 | "description": "Image height in pixels. Must be multiple of 8. Larger sizes need more memory." 50 | }, 51 | "num_inference_steps": { 52 | "type": "integer", 53 | "minimum": 1, 54 | "maximum": 150, 55 | "description": "Number of denoising steps. Higher values = better quality but slower generation." 56 | }, 57 | "guidance_scale": { 58 | "type": "number", 59 | "minimum": 1, 60 | "maximum": 20, 61 | "description": "How closely to follow the prompt. Higher values = more literal but may be less creative." 62 | }, 63 | "prompt_strength": { 64 | "type": "number", 65 | "minimum": 0, 66 | "maximum": 1, 67 | "description": "Strength of the prompt in image-to-image tasks. 1.0 = full prompt strength." 68 | }, 69 | "refine": { 70 | "type": "string", 71 | "enum": ["no_refiner", "expert_ensemble_refiner", "base_image_refiner"], 72 | "description": "Type of refinement to apply. expert_ensemble_refiner provides best quality." 73 | }, 74 | "scheduler": { 75 | "type": "string", 76 | "enum": ["DDIM", "DPM_MULTISTEP", "K_EULER", "PNDM", "KLMS"], 77 | "description": "Sampling method. K_EULER is a good default, DDIM for more deterministic results." 78 | }, 79 | "num_outputs": { 80 | "type": "integer", 81 | "minimum": 1, 82 | "maximum": 4, 83 | "description": "Number of images to generate in parallel. More outputs = longer generation time." 84 | }, 85 | "high_noise_frac": { 86 | "type": "number", 87 | "minimum": 0.0, 88 | "maximum": 1.0, 89 | "description": "Fraction of inference steps to use for high noise. Higher = more variation." 90 | }, 91 | "seed": { 92 | "type": ["integer", "null"], 93 | "minimum": 0, 94 | "maximum": 2147483647, 95 | "description": "Random seed for reproducible generation. null for random seed." 96 | }, 97 | "apply_watermark": { 98 | "type": "boolean", 99 | "description": "Whether to apply invisible watermarking to detect AI-generated images." 100 | } 101 | }, 102 | "required": ["prompt"] 103 | }, 104 | "version": "1.1.0", 105 | } 106 | 107 | SD_15_PARAMETERS = { 108 | "id": "sd-1.5-base", 109 | "name": "Stable Diffusion 1.5 Parameters", 110 | "description": "Default parameters for SD 1.5 models with comprehensive options for stable image generation", 111 | "model_type": "stable-diffusion", 112 | "default_parameters": { 113 | "width": 512, 114 | "height": 512, 115 | "num_inference_steps": 50, 116 | "guidance_scale": 7.5, 117 | "scheduler": "K_EULER", 118 | "num_outputs": 1, 119 | "seed": None, 120 | "apply_watermark": True, 121 | }, 122 | "parameter_schema": { 123 | "type": "object", 124 | "properties": { 125 | "prompt": { 126 | "type": "string", 127 | "description": "Text prompt for image generation. 
Use descriptive language and artistic terms.", 128 | "minLength": 1, 129 | "maxLength": 2000 130 | }, 131 | "negative_prompt": { 132 | "type": "string", 133 | "description": "Text prompt for elements to avoid. Common defaults: 'ugly, blurry, low quality'", 134 | "maxLength": 2000 135 | }, 136 | "width": { 137 | "type": "integer", 138 | "minimum": 256, 139 | "maximum": 1024, 140 | "multipleOf": 8, 141 | "description": "Image width in pixels. Must be multiple of 8. SD 1.5 works best at 512x512." 142 | }, 143 | "height": { 144 | "type": "integer", 145 | "minimum": 256, 146 | "maximum": 1024, 147 | "multipleOf": 8, 148 | "description": "Image height in pixels. Must be multiple of 8. SD 1.5 works best at 512x512." 149 | }, 150 | "num_inference_steps": { 151 | "type": "integer", 152 | "minimum": 1, 153 | "maximum": 150, 154 | "description": "Number of denoising steps. Higher values = better quality but slower." 155 | }, 156 | "guidance_scale": { 157 | "type": "number", 158 | "minimum": 1, 159 | "maximum": 20, 160 | "description": "How closely to follow the prompt. 7.5 is a good default." 161 | }, 162 | "scheduler": { 163 | "type": "string", 164 | "enum": ["DDIM", "DPM_MULTISTEP", "K_EULER", "PNDM", "KLMS"], 165 | "description": "Sampling method. K_EULER is a good default for quality/speed balance." 166 | }, 167 | "num_outputs": { 168 | "type": "integer", 169 | "minimum": 1, 170 | "maximum": 4, 171 | "description": "Number of images to generate in parallel. More outputs = longer time." 172 | }, 173 | "seed": { 174 | "type": ["integer", "null"], 175 | "minimum": 0, 176 | "maximum": 2147483647, 177 | "description": "Random seed for reproducible generation. null for random seed." 178 | }, 179 | "apply_watermark": { 180 | "type": "boolean", 181 | "description": "Whether to apply invisible watermarking to detect AI-generated images." 
182 | } 183 | }, 184 | "required": ["prompt"] 185 | }, 186 | "version": "1.1.0", 187 | } 188 | 189 | # Export all templates 190 | TEMPLATES: Dict[str, Dict[str, Any]] = { 191 | "sdxl": SDXL_PARAMETERS, 192 | "sd15": SD_15_PARAMETERS, 193 | } -------------------------------------------------------------------------------- /src/mcp_server_replicate/templates/parameters/prompt_templates.py: -------------------------------------------------------------------------------- 1 | """Prompt templates for different generation tasks and styles.""" 2 | 3 | from typing import Dict, Any 4 | 5 | TEXT_TO_IMAGE_TEMPLATES = { 6 | "id": "text-to-image-prompts", 7 | "name": "Text to Image Prompt Templates", 8 | "description": "Templates for generating effective text-to-image prompts", 9 | "model_type": "any", 10 | "templates": { 11 | "detailed_scene": { 12 | "description": "Template for detailed scene descriptions", 13 | "format": "{subject} in {setting}, {lighting} lighting, {mood} atmosphere, {style} style, {details}", 14 | "examples": [ 15 | { 16 | "parameters": { 17 | "subject": "a young explorer", 18 | "setting": "ancient temple ruins", 19 | "lighting": "dramatic golden hour", 20 | "mood": "mysterious", 21 | "style": "cinematic", 22 | "details": "vines growing on weathered stone, dust particles in light beams" 23 | }, 24 | "result": "a young explorer in ancient temple ruins, dramatic golden hour lighting, mysterious atmosphere, cinematic style, vines growing on weathered stone, dust particles in light beams" 25 | } 26 | ], 27 | "parameter_descriptions": { 28 | "subject": "Main subject or focus of the image", 29 | "setting": "Location or environment", 30 | "lighting": "Type and quality of lighting", 31 | "mood": "Overall emotional tone", 32 | "style": "Visual or artistic style", 33 | "details": "Additional specific details" 34 | } 35 | }, 36 | "character_portrait": { 37 | "description": "Template for character portraits", 38 | "format": "{gender} {character_type}, {appearance}, {clothing}, {expression}, {pose}, {style} style, {background}", 39 | "examples": [ 40 | { 41 | "parameters": { 42 | "gender": "female", 43 | "character_type": "warrior", 44 | "appearance": "long red hair, battle-scarred", 45 | "clothing": "ornate plate armor", 46 | "expression": "determined look", 47 | "pose": "heroic stance", 48 | "style": "digital art", 49 | "background": "stormy sky" 50 | }, 51 | "result": "female warrior, long red hair, battle-scarred, ornate plate armor, determined look, heroic stance, digital art style, stormy sky" 52 | } 53 | ], 54 | "parameter_descriptions": { 55 | "gender": "Character's gender", 56 | "character_type": "Role or profession", 57 | "appearance": "Physical characteristics", 58 | "clothing": "Outfit description", 59 | "expression": "Facial expression", 60 | "pose": "Body position", 61 | "style": "Visual style", 62 | "background": "Background setting" 63 | } 64 | }, 65 | "landscape": { 66 | "description": "Template for landscape scenes", 67 | "format": "{environment} landscape, {time_of_day}, {weather}, {features}, {style} style, {mood} mood", 68 | "examples": [ 69 | { 70 | "parameters": { 71 | "environment": "mountain", 72 | "time_of_day": "sunset", 73 | "weather": "partly cloudy", 74 | "features": "snow-capped peaks, alpine lake, pine forest", 75 | "style": "oil painting", 76 | "mood": "peaceful" 77 | }, 78 | "result": "mountain landscape, sunset, partly cloudy, snow-capped peaks, alpine lake, pine forest, oil painting style, peaceful mood" 79 | } 80 | ], 81 | "parameter_descriptions": { 82 | 
"environment": "Type of landscape", 83 | "time_of_day": "Time of day", 84 | "weather": "Weather conditions", 85 | "features": "Notable landscape features", 86 | "style": "Visual style", 87 | "mood": "Emotional atmosphere" 88 | } 89 | } 90 | }, 91 | "version": "1.0.0" 92 | } 93 | 94 | IMAGE_TO_IMAGE_TEMPLATES = { 95 | "id": "image-to-image-prompts", 96 | "name": "Image to Image Prompt Templates", 97 | "description": "Templates for effective image-to-image modification prompts", 98 | "model_type": "any", 99 | "templates": { 100 | "style_transfer": { 101 | "description": "Template for transferring style to an image", 102 | "format": "Transform into {style} style, {quality} quality, maintain {preserve} from original", 103 | "examples": [ 104 | { 105 | "parameters": { 106 | "style": "oil painting", 107 | "quality": "masterpiece", 108 | "preserve": "composition and lighting" 109 | }, 110 | "result": "Transform into oil painting style, masterpiece quality, maintain composition and lighting from original" 111 | } 112 | ], 113 | "parameter_descriptions": { 114 | "style": "Target artistic style", 115 | "quality": "Quality level", 116 | "preserve": "Elements to preserve" 117 | } 118 | }, 119 | "variation": { 120 | "description": "Template for creating variations", 121 | "format": "Similar to original but with {changes}, {style} style, {quality} quality", 122 | "examples": [ 123 | { 124 | "parameters": { 125 | "changes": "different color scheme", 126 | "style": "same", 127 | "quality": "high quality" 128 | }, 129 | "result": "Similar to original but with different color scheme, same style, high quality" 130 | } 131 | ], 132 | "parameter_descriptions": { 133 | "changes": "Desired changes", 134 | "style": "Style modification", 135 | "quality": "Quality level" 136 | } 137 | } 138 | }, 139 | "version": "1.0.0" 140 | } 141 | 142 | CONTROLNET_TEMPLATES = { 143 | "id": "controlnet-prompts", 144 | "name": "ControlNet Prompt Templates", 145 | "description": "Templates for ControlNet-guided image generation", 146 | "model_type": "controlnet", 147 | "templates": { 148 | "pose_guided": { 149 | "description": "Template for pose-guided generation", 150 | "format": "{subject} in {pose_description}, {clothing}, {style} style, {background}", 151 | "examples": [ 152 | { 153 | "parameters": { 154 | "subject": "young athlete", 155 | "pose_description": "dynamic running pose", 156 | "clothing": "sports attire", 157 | "style": "photorealistic", 158 | "background": "track field" 159 | }, 160 | "result": "young athlete in dynamic running pose, sports attire, photorealistic style, track field" 161 | } 162 | ], 163 | "parameter_descriptions": { 164 | "subject": "Main subject", 165 | "pose_description": "Description of the pose", 166 | "clothing": "Outfit description", 167 | "style": "Visual style", 168 | "background": "Background setting" 169 | } 170 | }, 171 | "depth_guided": { 172 | "description": "Template for depth-guided generation", 173 | "format": "{subject} with {depth_elements}, {perspective}, {style} style", 174 | "examples": [ 175 | { 176 | "parameters": { 177 | "subject": "forest path", 178 | "depth_elements": "trees fading into distance", 179 | "perspective": "one-point perspective", 180 | "style": "photorealistic" 181 | }, 182 | "result": "forest path with trees fading into distance, one-point perspective, photorealistic style" 183 | } 184 | ], 185 | "parameter_descriptions": { 186 | "subject": "Main subject", 187 | "depth_elements": "Elements showing depth", 188 | "perspective": "Type of perspective", 189 | 
"style": "Visual style" 190 | } 191 | } 192 | }, 193 | "version": "1.0.0" 194 | } 195 | 196 | # Export all templates 197 | TEMPLATES: Dict[str, Dict[str, Any]] = { 198 | "text_to_image": TEXT_TO_IMAGE_TEMPLATES, 199 | "image_to_image": IMAGE_TO_IMAGE_TEMPLATES, 200 | "controlnet": CONTROLNET_TEMPLATES, 201 | } -------------------------------------------------------------------------------- /src/mcp_server_replicate/replicate_client.py: -------------------------------------------------------------------------------- 1 | """Replicate API client implementation.""" 2 | 3 | import logging 4 | import os 5 | import time 6 | from typing import Any, Optional, Dict, AsyncGenerator 7 | import asyncio 8 | import random 9 | 10 | import httpx 11 | import replicate 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | # Constants 16 | REPLICATE_API_BASE = "https://api.replicate.com/v1" 17 | DEFAULT_TIMEOUT = 60.0 18 | MAX_RETRIES = 3 19 | MIN_RETRY_DELAY = 1.0 20 | MAX_RETRY_DELAY = 10.0 21 | DEFAULT_RATE_LIMIT = 100 # requests per minute 22 | 23 | class RateLimitExceeded(Exception): 24 | """Raised when rate limit is exceeded.""" 25 | 26 | def __init__(self, retry_after: float): 27 | """Initialize with retry after duration.""" 28 | self.retry_after = retry_after 29 | super().__init__(f"Rate limit exceeded. Retry after {retry_after} seconds.") 30 | 31 | class ReplicateClient: 32 | """Client for interacting with the Replicate API.""" 33 | 34 | def __init__(self, api_token: str | None = None) -> None: 35 | """Initialize the Replicate client. 36 | 37 | Args: 38 | api_token: Replicate API token for authentication 39 | """ 40 | self.client = None 41 | self.error = None 42 | self.api_token = api_token 43 | self._rate_limit = DEFAULT_RATE_LIMIT 44 | self._request_times: list[float] = [] 45 | self._retry_count = 0 46 | self.http_client = None # Initialize to None, will be set up in __aenter__ 47 | 48 | if not api_token or not api_token.strip(): 49 | self.error = "Replicate API token is required" 50 | return 51 | 52 | os.environ["REPLICATE_API_TOKEN"] = api_token 53 | self.client = replicate.Client() 54 | 55 | async def __aenter__(self): 56 | """Async context manager entry.""" 57 | # Initialize httpx client for direct API calls 58 | self.http_client = httpx.AsyncClient( 59 | base_url=REPLICATE_API_BASE, 60 | headers={ 61 | "Authorization": f"Bearer {self.api_token}", 62 | "Content-Type": "application/json", 63 | }, 64 | timeout=DEFAULT_TIMEOUT 65 | ) 66 | return self 67 | 68 | async def __aexit__(self, exc_type, exc_val, exc_tb): 69 | """Async context manager exit.""" 70 | if self.http_client: 71 | await self.http_client.aclose() 72 | self.http_client = None 73 | 74 | async def _ensure_http_client(self): 75 | """Ensure http_client is initialized.""" 76 | if not self.http_client: 77 | self.http_client = httpx.AsyncClient( 78 | base_url=REPLICATE_API_BASE, 79 | headers={ 80 | "Authorization": f"Bearer {self.api_token}", 81 | "Content-Type": "application/json", 82 | }, 83 | timeout=DEFAULT_TIMEOUT 84 | ) 85 | 86 | async def _wait_for_rate_limit(self) -> None: 87 | """Wait if necessary to comply with rate limiting.""" 88 | now = time.time() 89 | 90 | # Remove request times older than 1 minute 91 | self._request_times = [t for t in self._request_times if now - t <= 60] 92 | 93 | if len(self._request_times) >= self._rate_limit: 94 | # Calculate wait time based on oldest request 95 | wait_time = 60 - (now - self._request_times[0]) 96 | if wait_time > 0: 97 | logger.debug(f"Rate limit reached. 
Waiting {wait_time:.2f} seconds") 98 | await asyncio.sleep(wait_time) 99 | 100 | self._request_times.append(now) 101 | 102 | async def _handle_response(self, response: httpx.Response) -> None: 103 | """Handle rate limits and other response headers. 104 | 105 | Args: 106 | response: The HTTP response to handle 107 | 108 | Raises: 109 | RateLimitExceeded: If rate limit is exceeded 110 | """ 111 | # Update rate limit from headers if available 112 | if "X-RateLimit-Limit" in response.headers: 113 | self._rate_limit = int(response.headers["X-RateLimit-Limit"]) 114 | 115 | # Handle rate limit exceeded 116 | if response.status_code == 429: 117 | retry_after = float(response.headers.get("Retry-After", 60)) 118 | raise RateLimitExceeded(retry_after) 119 | 120 | async def _make_request( 121 | self, 122 | method: str, 123 | endpoint: str, 124 | **kwargs: Any 125 | ) -> httpx.Response: 126 | """Make an HTTP request with retries and rate limiting. 127 | 128 | Args: 129 | method: HTTP method to use 130 | endpoint: API endpoint to call 131 | **kwargs: Additional arguments to pass to httpx 132 | 133 | Returns: 134 | HTTP response 135 | 136 | Raises: 137 | Exception: If the request fails after retries 138 | """ 139 | await self._ensure_http_client() # Guard against use outside the async context manager 140 | await self._wait_for_rate_limit() 141 | for attempt in range(MAX_RETRIES): 142 | try: 143 | response = await self.http_client.request(method, endpoint, **kwargs) 144 | await self._handle_response(response) 145 | response.raise_for_status() 146 | self._retry_count = 0 # Reset on success 147 | return response 148 | 149 | except RateLimitExceeded as e: 150 | logger.warning(f"Rate limit exceeded. Waiting {e.retry_after} seconds") 151 | await asyncio.sleep(e.retry_after) 152 | continue 153 | 154 | except httpx.HTTPError as e: 155 | self._retry_count += 1 156 | if attempt == MAX_RETRIES - 1: 157 | raise 158 | 159 | # Calculate exponential backoff with jitter 160 | delay = min( 161 | MAX_RETRY_DELAY, 162 | MIN_RETRY_DELAY * (2 ** attempt) + random.uniform(0, 1) 163 | ) 164 | logger.warning( 165 | f"Request failed: {str(e)}. " 166 | f"Retrying in {delay:.2f} seconds " 167 | f"(attempt {attempt + 1}/{MAX_RETRIES})" 168 | ) 169 | await asyncio.sleep(delay) 170 | continue 171 | raise Exception(f"Request failed after {MAX_RETRIES} retries") # Avoid silently returning None when every attempt was rate limited 172 | async def list_models(self, owner: str | None = None, cursor: str | None = None) -> dict[str, Any]: 173 | """List available models on Replicate with pagination. 174 | 175 | Args: 176 | owner: Optional owner username to filter models 177 | cursor: Pagination cursor from previous response 178 | 179 | Returns: 180 | Dict containing models list, next cursor, and total count 181 | 182 | Raises: 183 | Exception: If the API request fails 184 | """ 185 | if not self.client: 186 | raise RuntimeError("Client not initialized. 
Check error property for details.") 187 | 188 | try: 189 | # Build params dict only including cursor if provided 190 | params = {} 191 | if cursor: 192 | params["cursor"] = cursor 193 | 194 | # Get models collection with pagination 195 | models = self.client.models.list(**params) 196 | 197 | # Get pagination info 198 | next_cursor = models.next_cursor if hasattr(models, "next_cursor") else None 199 | total_models = models.total if hasattr(models, "total") else None 200 | 201 | # Filter by owner if specified 202 | if owner: 203 | models = [m for m in models if m.owner == owner] 204 | 205 | # Format models with complete structure 206 | formatted_models = [] 207 | for model in models: 208 | model_data = { 209 | "id": f"{model.owner}/{model.name}", 210 | "owner": model.owner, 211 | "name": model.name, 212 | "description": model.description, 213 | "visibility": model.visibility, 214 | "github_url": getattr(model, "github_url", None), 215 | "paper_url": getattr(model, "paper_url", None), 216 | "license_url": getattr(model, "license_url", None), 217 | "run_count": getattr(model, "run_count", None), 218 | "cover_image_url": getattr(model, "cover_image_url", None), 219 | "default_example": getattr(model, "default_example", None), 220 | "featured": getattr(model, "featured", None), 221 | "tags": getattr(model, "tags", []), 222 | } 223 | 224 | # Add latest version info if available 225 | if model.latest_version: 226 | model_data["latest_version"] = { 227 | "id": model.latest_version.id, 228 | "created_at": model.latest_version.created_at, 229 | "cog_version": model.latest_version.cog_version, 230 | "openapi_schema": model.latest_version.openapi_schema, 231 | "model": f"{model.owner}/{model.name}", 232 | "replicate_version": getattr(model.latest_version, "replicate_version", None), 233 | "hardware": getattr(model.latest_version, "hardware", None), 234 | } 235 | 236 | formatted_models.append(model_data) 237 | 238 | return { 239 | "models": formatted_models, 240 | "next_cursor": next_cursor, 241 | "total_count": total_models, 242 | } 243 | 244 | except Exception as err: 245 | logger.error(f"Failed to list models: {str(err)}") 246 | raise Exception(f"Failed to list models: {str(err)}") from err 247 | 248 | def get_model_versions(self, model: str) -> list[dict[str, Any]]: 249 | """Get available versions for a model. 250 | 251 | Args: 252 | model: Model identifier in format 'owner/model' 253 | 254 | Returns: 255 | List of model versions with their metadata 256 | 257 | Raises: 258 | ValueError: If the model is not found 259 | Exception: If the API request fails 260 | """ 261 | if not self.client: 262 | raise RuntimeError("Client not initialized. 
Check error property for details.") 263 | 264 | try: 265 | # Get model 266 | model_obj = self.client.models.get(model) 267 | if not model_obj: 268 | raise ValueError(f"Model not found: {model}") 269 | 270 | # Get versions 271 | versions = model_obj.versions.list() 272 | 273 | # Return minimal version metadata 274 | return [ 275 | { 276 | "id": version.id, 277 | "created_at": version.created_at.isoformat() if version.created_at else None, 278 | "cog_version": version.cog_version, 279 | "openapi_schema": version.openapi_schema, 280 | } 281 | for version in versions 282 | ] 283 | 284 | except ValueError as err: 285 | logger.error(f"Validation error: {str(err)}") 286 | raise 287 | except Exception as err: 288 | logger.error(f"Failed to get model versions: {str(err)}") 289 | raise Exception(f"Failed to get model versions: {str(err)}") from err 290 | 291 | async def predict( 292 | self, 293 | model: str, 294 | input_data: dict[str, Any], 295 | version: str | None = None, 296 | wait: bool = False, 297 | wait_timeout: int | None = None, 298 | stream: bool = False, 299 | ) -> dict[str, Any]: 300 | """Run a prediction using a Replicate model. 301 | 302 | Args: 303 | model: Model identifier in format 'owner/model' 304 | input_data: Model-specific input parameters 305 | version: Optional model version hash 306 | wait: Whether to wait for prediction completion 307 | wait_timeout: Max seconds to wait if wait=True (1-60) 308 | stream: Whether to request streaming output 309 | 310 | Returns: 311 | Dict containing prediction details and optional stream URL 312 | 313 | Raises: 314 | ValueError: If the model or version is not found 315 | Exception: If the prediction fails 316 | """ 317 | if not self.client: 318 | raise RuntimeError("Client not initialized. Check error property for details.") 319 | 320 | try: 321 | # Validate wait_timeout 322 | if wait and wait_timeout: 323 | if not 1 <= wait_timeout <= 60: 324 | raise ValueError("wait_timeout must be between 1 and 60 seconds") 325 | 326 | # Get model 327 | model_obj = self.client.models.get(model) 328 | if not model_obj: 329 | raise ValueError(f"Model not found: {model}") 330 | 331 | # Get specific version or latest 332 | if version: 333 | model_version = model_obj.versions.get(version) 334 | if not model_version: 335 | raise ValueError(f"Version not found: {version}") 336 | else: 337 | model_version = model_obj.latest_version 338 | 339 | # Prepare headers 340 | headers = { 341 | "Authorization": f"Bearer {self.api_token}", 342 | "Content-Type": "application/json", 343 | } 344 | if wait: 345 | if wait_timeout: 346 | headers["Prefer"] = f"wait={wait_timeout}" 347 | else: 348 | headers["Prefer"] = "wait" 349 | 350 | # Prepare request body 351 | body = { 352 | "input": input_data, 353 | "stream": stream, 354 | } 355 | if version: 356 | body["version"] = version 357 | 358 | # Create prediction using rate-limited request 359 | response = await self._make_request( 360 | "POST", 361 | "/predictions", 362 | headers=headers, 363 | json=body 364 | ) 365 | data = response.json() 366 | 367 | # Format response 368 | result = { 369 | "id": data["id"], 370 | "status": data["status"], 371 | "input": data["input"], 372 | "output": data.get("output"), 373 | "error": data.get("error"), 374 | "logs": data.get("logs"), 375 | "created_at": data.get("created_at"), 376 | "started_at": data.get("started_at"), 377 | "completed_at": data.get("completed_at"), 378 | "urls": data.get("urls", {}), 379 | } 380 | 381 | # Add metrics if available 382 | if "metrics" in data: 383 | 
result["metrics"] = data["metrics"] 384 | 385 | # Add stream URL if requested and available 386 | if stream and "urls" in data and "stream" in data["urls"]: 387 | result["stream_url"] = data["urls"]["stream"] 388 | 389 | return result 390 | 391 | except ValueError as err: 392 | logger.error(f"Validation error: {str(err)}") 393 | raise 394 | except httpx.HTTPError as err: 395 | logger.error(f"HTTP error during prediction: {str(err)}") 396 | raise Exception(f"Prediction failed: {str(err)}") from err 397 | except Exception as err: 398 | logger.error(f"Prediction failed: {str(err)}") 399 | raise Exception(f"Prediction failed: {str(err)}") from err 400 | 401 | async def get_prediction_status(self, prediction_id: str) -> dict[str, Any]: 402 | """Get the status of a prediction. 403 | 404 | Args: 405 | prediction_id: ID of the prediction to check 406 | 407 | Returns: 408 | Dict containing current status and output of the prediction 409 | 410 | Raises: 411 | ValueError: If the prediction is not found 412 | Exception: If the API request fails 413 | """ 414 | if not self.client: 415 | raise RuntimeError("Client not initialized. Check error property for details.") 416 | 417 | try: 418 | # Get prediction 419 | prediction = self.client.predictions.get(prediction_id) 420 | if not prediction: 421 | raise ValueError(f"Prediction not found: {prediction_id}") 422 | 423 | # Return prediction status and output 424 | return { 425 | "id": prediction.id, 426 | "status": prediction.status, 427 | "output": prediction.output, 428 | "error": prediction.error, 429 | "created_at": prediction.created_at.isoformat() if prediction.created_at else None, 430 | "started_at": prediction.started_at.isoformat() if prediction.started_at else None, 431 | "completed_at": prediction.completed_at.isoformat() if prediction.completed_at else None, 432 | "urls": prediction.urls, 433 | "metrics": prediction.metrics, 434 | } 435 | 436 | except ValueError as err: 437 | logger.error(f"Validation error: {str(err)}") 438 | raise 439 | except Exception as err: 440 | logger.error(f"Failed to get prediction status: {str(err)}") 441 | raise Exception(f"Failed to get prediction status: {str(err)}") from err 442 | 443 | async def search_models( 444 | self, 445 | query: str, 446 | cursor: Optional[str] = None, 447 | ) -> dict[str, Any]: 448 | """Search for models using the QUERY endpoint. 449 | 450 | Args: 451 | query: Search query string 452 | cursor: Optional pagination cursor 453 | 454 | Returns: 455 | Dict containing search results with pagination info 456 | 457 | Raises: 458 | Exception: If the API request fails 459 | """ 460 | if not self.client: 461 | raise RuntimeError("Client not initialized. 
Check error property for details.") 462 | 463 | try: 464 | # Build URL with cursor if provided 465 | url = "/models" 466 | if cursor: 467 | url = f"{url}?cursor={cursor}" 468 | 469 | # Make QUERY request 470 | response = await self.http_client.request( 471 | "QUERY", 472 | url, 473 | content=query, 474 | headers={"Content-Type": "text/plain"} 475 | ) 476 | response.raise_for_status() 477 | data = response.json() 478 | 479 | # Format response with complete model structure 480 | return { 481 | "models": [ 482 | { 483 | "id": f"{model['owner']}/{model['name']}", 484 | "owner": model["owner"], 485 | "name": model["name"], 486 | "description": model.get("description"), 487 | "visibility": model.get("visibility", "public"), 488 | "github_url": model.get("github_url"), 489 | "paper_url": model.get("paper_url"), 490 | "license_url": model.get("license_url"), 491 | "run_count": model.get("run_count"), 492 | "cover_image_url": model.get("cover_image_url"), 493 | "default_example": model.get("default_example"), 494 | "featured": model.get("featured", False), 495 | "tags": model.get("tags", []), 496 | "latest_version": model.get("latest_version", { 497 | "id": model.get("latest_version", {}).get("id"), 498 | "created_at": model.get("latest_version", {}).get("created_at"), 499 | "cog_version": model.get("latest_version", {}).get("cog_version"), 500 | "openapi_schema": model.get("latest_version", {}).get("openapi_schema"), 501 | "model": f"{model['owner']}/{model['name']}", 502 | "replicate_version": model.get("latest_version", {}).get("replicate_version"), 503 | "hardware": model.get("latest_version", {}).get("hardware"), 504 | } if model.get("latest_version") else None), 505 | } 506 | for model in data.get("results", []) 507 | ], 508 | "next_cursor": data.get("next"), 509 | "total_count": data.get("total"), 510 | } 511 | 512 | except httpx.HTTPError as err: 513 | logger.error(f"HTTP error during model search: {str(err)}") 514 | raise Exception(f"Failed to search models: {str(err)}") from err 515 | except Exception as err: 516 | logger.error(f"Failed to search models: {str(err)}") 517 | raise Exception(f"Failed to search models: {str(err)}") from err 518 | 519 | async def list_hardware(self) -> list[dict[str, str]]: 520 | """Get list of available hardware options for running models. 521 | 522 | Returns: 523 | List of hardware options with name and SKU 524 | 525 | Raises: 526 | Exception: If the API request fails 527 | """ 528 | if not self.client: 529 | raise RuntimeError("Client not initialized. Check error property for details.") 530 | 531 | try: 532 | response = await self.http_client.get("/hardware") 533 | response.raise_for_status() 534 | 535 | return [ 536 | { 537 | "name": hw["name"], 538 | "sku": hw["sku"], 539 | } 540 | for hw in response.json() 541 | ] 542 | 543 | except httpx.HTTPError as err: 544 | logger.error(f"HTTP error getting hardware options: {str(err)}") 545 | raise Exception(f"Failed to get hardware options: {str(err)}") from err 546 | except Exception as err: 547 | logger.error(f"Failed to get hardware options: {str(err)}") 548 | raise Exception(f"Failed to get hardware options: {str(err)}") from err 549 | 550 | async def list_collections(self) -> list[dict[str, Any]]: 551 | """Get list of available model collections. 552 | 553 | Returns: 554 | List of collections with their metadata 555 | 556 | Raises: 557 | Exception: If the API request fails 558 | """ 559 | if not self.client: 560 | raise RuntimeError("Client not initialized. 
Check error property for details.") 561 | 562 | try: 563 | response = await self.http_client.get("/collections") 564 | response.raise_for_status() 565 | data = response.json() 566 | 567 | return [ 568 | { 569 | "name": collection["name"], 570 | "slug": collection["slug"], 571 | "description": collection.get("description"), 572 | } 573 | for collection in data.get("results", []) 574 | ] 575 | 576 | except httpx.HTTPError as err: 577 | logger.error(f"HTTP error listing collections: {str(err)}") 578 | raise Exception(f"Failed to list collections: {str(err)}") from err 579 | except Exception as err: 580 | logger.error(f"Failed to list collections: {str(err)}") 581 | raise Exception(f"Failed to list collections: {str(err)}") from err 582 | 583 | async def get_collection(self, collection_slug: str) -> dict[str, Any]: 584 | """Get details of a specific collection including its models. 585 | 586 | Args: 587 | collection_slug: The slug identifier of the collection 588 | 589 | Returns: 590 | Collection details including contained models 591 | 592 | Raises: 593 | ValueError: If the collection is not found 594 | Exception: If the API request fails 595 | """ 596 | if not self.client: 597 | raise RuntimeError("Client not initialized. Check error property for details.") 598 | 599 | try: 600 | response = await self.http_client.get(f"/collections/{collection_slug}") 601 | response.raise_for_status() 602 | data = response.json() 603 | 604 | return { 605 | "name": data["name"], 606 | "slug": data["slug"], 607 | "description": data.get("description"), 608 | "models": [ 609 | { 610 | "id": f"{model['owner']}/{model['name']}", 611 | "owner": model["owner"], 612 | "name": model["name"], 613 | "description": model.get("description"), 614 | "visibility": model.get("visibility", "public"), 615 | "latest_version": model.get("latest_version"), 616 | } 617 | for model in data.get("models", []) 618 | ] 619 | } 620 | 621 | except httpx.HTTPStatusError as err: 622 | if err.response.status_code == 404: 623 | raise ValueError(f"Collection not found: {collection_slug}") 624 | logger.error(f"HTTP error getting collection: {str(err)}") 625 | raise Exception(f"Failed to get collection: {str(err)}") from err 626 | except Exception as err: 627 | logger.error(f"Failed to get collection: {str(err)}") 628 | raise Exception(f"Failed to get collection: {str(err)}") from err 629 | 630 | async def get_webhook_secret(self) -> str: 631 | """Get the signing secret for the default webhook endpoint. 632 | 633 | This secret is used to verify that webhook requests are coming from Replicate. 634 | 635 | Returns: 636 | The webhook signing secret 637 | 638 | Raises: 639 | Exception: If the API request fails 640 | """ 641 | if not self.client: 642 | raise RuntimeError("Client not initialized. Check error property for details.") 643 | 644 | try: 645 | response = await self.http_client.get("/webhooks/default/secret") 646 | response.raise_for_status() 647 | data = response.json() 648 | 649 | return data["key"] 650 | 651 | except httpx.HTTPError as err: 652 | logger.error(f"HTTP error getting webhook secret: {str(err)}") 653 | raise Exception(f"Failed to get webhook secret: {str(err)}") from err 654 | except Exception as err: 655 | logger.error(f"Failed to get webhook secret: {str(err)}") 656 | raise Exception(f"Failed to get webhook secret: {str(err)}") from err 657 | 658 | async def cancel_prediction(self, prediction_id: str) -> dict[str, Any]: 659 | """Cancel a running prediction. 
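        Illustrative usage sketch (the prediction id is a placeholder and a
        valid api_token is assumed; this is not part of the original API docs):

            async with ReplicateClient(api_token=token) as client:
                result = await client.cancel_prediction("rrr4z55o...")
                print(result["status"])  # "canceled" once Replicate processes it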
660 | 661 | Args: 662 | prediction_id: The ID of the prediction to cancel 663 | 664 | Returns: 665 | Dict containing the updated prediction status 666 | 667 | Raises: 668 | ValueError: If the prediction is not found 669 | Exception: If the cancellation fails 670 | """ 671 | if not self.client: 672 | raise RuntimeError("Client not initialized. Check error property for details.") 673 | 674 | try: 675 | response = await self.http_client.post( 676 | f"/predictions/{prediction_id}/cancel", 677 | headers={ 678 | "Authorization": f"Bearer {self.api_token}", 679 | "Content-Type": "application/json", 680 | } 681 | ) 682 | response.raise_for_status() 683 | return response.json() 684 | 685 | except httpx.HTTPStatusError as err: 686 | if err.response.status_code == 404: 687 | raise ValueError(f"Prediction not found: {prediction_id}") 688 | logger.error(f"Failed to cancel prediction: {str(err)}") 689 | raise Exception(f"Failed to cancel prediction: {str(err)}") from err 690 | except Exception as err: 691 | logger.error(f"Failed to cancel prediction: {str(err)}") 692 | raise Exception(f"Failed to cancel prediction: {str(err)}") from err 693 | 694 | async def list_predictions( 695 | self, 696 | status: str | None = None, 697 | limit: int = 10 698 | ) -> list[dict[str, Any]]: 699 | """List recent predictions with optional filtering. 700 | 701 | Args: 702 | status: Optional status to filter by (starting|processing|succeeded|failed|canceled) 703 | limit: Maximum number of predictions to return (1-100) 704 | 705 | Returns: 706 | List of prediction objects 707 | 708 | Raises: 709 | ValueError: If limit is out of range 710 | Exception: If the API request fails 711 | """ 712 | if not self.client: 713 | raise RuntimeError("Client not initialized. Check error property for details.") 714 | 715 | if not 1 <= limit <= 100: 716 | raise ValueError("limit must be between 1 and 100") 717 | 718 | try: 719 | params = {"limit": limit} 720 | if status: 721 | params["status"] = status 722 | 723 | response = await self.http_client.get( 724 | "/predictions", 725 | params=params, 726 | headers={ 727 | "Authorization": f"Bearer {self.api_token}", 728 | "Content-Type": "application/json", 729 | } 730 | ) 731 | response.raise_for_status() 732 | return response.json() 733 | 734 | except Exception as err: 735 | logger.error(f"Failed to list predictions: {str(err)}") 736 | raise Exception(f"Failed to list predictions: {str(err)}") from err 737 | 738 | async def create_prediction( 739 | self, 740 | version: str, 741 | input: Dict[str, Any], 742 | webhook: Optional[str] = None, 743 | ) -> Dict[str, Any]: 744 | """Create a new prediction using a model version. 745 | 746 | Args: 747 | version: Model version ID 748 | input: Model input parameters 749 | webhook: Optional webhook URL for prediction updates 750 | 751 | Returns: 752 | Dict containing prediction details 753 | 754 | Raises: 755 | Exception: If the prediction creation fails 756 | """ 757 | if not self.client: 758 | raise RuntimeError("Client not initialized. 
Check error property for details.") 759 | 760 | try: 761 | await self._ensure_http_client() 762 | 763 | # Prepare request body 764 | body = { 765 | "version": version, 766 | "input": input, 767 | } 768 | if webhook: 769 | body["webhook"] = webhook 770 | 771 | # Create prediction using rate-limited request 772 | response = await self._make_request( 773 | "POST", 774 | "/predictions", 775 | json=body 776 | ) 777 | data = response.json() 778 | 779 | # Format response 780 | result = { 781 | "id": data["id"], 782 | "status": data["status"], 783 | "input": data["input"], 784 | "output": data.get("output"), 785 | "error": data.get("error"), 786 | "logs": data.get("logs"), 787 | "created_at": data.get("created_at"), 788 | "started_at": data.get("started_at"), 789 | "completed_at": data.get("completed_at"), 790 | "urls": data.get("urls", {}), 791 | } 792 | 793 | # Add metrics if available 794 | if "metrics" in data: 795 | result["metrics"] = data["metrics"] 796 | 797 | return result 798 | 799 | except Exception as err: 800 | logger.error(f"Failed to create prediction: {str(err)}") 801 | raise Exception(f"Failed to create prediction: {str(err)}") from err 802 | -------------------------------------------------------------------------------- /src/mcp_server_replicate/server.py: -------------------------------------------------------------------------------- 1 | """FastMCP server implementation for Replicate API.""" 2 | 3 | import asyncio 4 | import base64 5 | import hashlib 6 | import hmac 7 | import json 8 | import logging 9 | import os 10 | import webbrowser 11 | from collections.abc import Sequence 12 | from typing import Any 13 | 14 | import httpx 15 | import jsonschema 16 | from mcp.server.fastmcp import FastMCP 17 | from mcp.server.fastmcp.prompts.base import Message, TextContent, UserMessage 18 | from mcp.server.session import ServerSession 19 | from mcp.types import ( 20 | AnyUrl, 21 | BlobResourceContents, 22 | EmptyResult, 23 | ResourceUpdatedNotification, 24 | TextResourceContents, 25 | ) 26 | from pydantic import BaseModel, Field, field_validator 27 | 28 | from .models.collection import Collection, CollectionList 29 | from .models.hardware import Hardware, HardwareList 30 | from .models.model import Model, ModelList 31 | from .models.webhook import WebhookPayload 32 | from .replicate_client import ReplicateClient 33 | from .templates.parameters.common_configs import QUALITY_PRESETS, STYLE_PRESETS, TEMPLATES 34 | 35 | logger = logging.getLogger(__name__) 36 | 37 | 38 | class SubscriptionRequest(BaseModel): 39 | """Request model for subscription operations.""" 40 | 41 | uri: str = Field(..., description="Resource URI to subscribe to") 42 | session_id: str = Field(..., description="ID of the session making the request") 43 | 44 | 45 | class GenerationSubscriptionManager: 46 | """Manages subscriptions to generation resources.""" 47 | 48 | def __init__(self): 49 | self._subscriptions: dict[str, set[ServerSession]] = {} 50 | self._check_task: asyncio.Task | None = None 51 | 52 | async def subscribe(self, uri: str, session: ServerSession): 53 | """Subscribe a session to generation updates.""" 54 | prediction_id = uri.replace("generations://", "") 55 | if prediction_id not in self._subscriptions: 56 | self._subscriptions[prediction_id] = set() 57 | self._subscriptions[prediction_id].add(session) 58 | 59 | # Start checking if not already running 60 | if not self._check_task: 61 | self._check_task = asyncio.create_task(self._check_generations()) 62 | 63 | async def unsubscribe(self, uri: str, 
session: ServerSession): 64 | """Unsubscribe a session from generation updates.""" 65 | prediction_id = uri.replace("generations://", "") 66 | if prediction_id in self._subscriptions: 67 | self._subscriptions[prediction_id].discard(session) 68 | if not self._subscriptions[prediction_id]: 69 | del self._subscriptions[prediction_id] 70 | 71 | # Stop checking if no more subscriptions 72 | if not self._subscriptions and self._check_task: 73 | self._check_task.cancel() 74 | self._check_task = None 75 | 76 | async def _check_generations(self): 77 | """Periodically check subscribed generations and notify of updates.""" 78 | async with ReplicateClient(api_token=os.getenv("REPLICATE_API_TOKEN")) as client: 79 | while True: 80 | try: 81 | for prediction_id, sessions in list(self._subscriptions.items()): 82 | try: 83 | result = await client.get_prediction_status(prediction_id) 84 | # Notify on completion or failure 85 | if result["status"] in ["succeeded", "failed", "canceled"]: 86 | # For succeeded generations with image output 87 | if result["status"] == "succeeded" and result.get("output"): 88 | # For image generation models, output is typically a list with the image URL as first item 89 | image_url = ( 90 | result["output"][0] if isinstance(result["output"], list) else result["output"] 91 | ) 92 | 93 | # First send a notification with just the URL and metadata 94 | notification = ResourceUpdatedNotification( 95 | method="notifications/resources/updated", 96 | params={"uri": f"generations://{prediction_id}"}, 97 | ) 98 | 99 | # Create text resource with metadata and URL 100 | text_resource = TextResourceContents( 101 | type="text", 102 | uri=f"generations://{prediction_id}", 103 | mimeType="application/json", 104 | text=json.dumps( 105 | { 106 | "status": "succeeded", 107 | "image_url": image_url, 108 | "created_at": result.get("created_at"), 109 | "completed_at": result.get("completed_at"), 110 | "metrics": result.get("metrics", {}), 111 | "urls": result.get("urls", {}), 112 | "input": result.get("input", {}), 113 | }, 114 | indent=2, 115 | ), 116 | ) 117 | 118 | # Send notification and text resource to all sessions 119 | for session in sessions: 120 | await session.send_notification(notification) 121 | await session.send_resource(text_resource) 122 | 123 | # Remove from subscriptions since we've notified 124 | del self._subscriptions[prediction_id] 125 | else: 126 | # For failed or canceled generations, create text resource 127 | resource = TextResourceContents( 128 | uri=f"generations://{prediction_id}", 129 | mimeType="application/json", 130 | text=json.dumps( 131 | { 132 | "status": result["status"], 133 | "error": result.get("error"), 134 | "created_at": result.get("created_at"), 135 | "completed_at": result.get("completed_at"), 136 | "metrics": result.get("metrics", {}), 137 | "urls": result.get("urls", {}), 138 | }, 139 | indent=2, 140 | ), 141 | ) 142 | 143 | # Send notification with the resource 144 | notification = ResourceUpdatedNotification( 145 | method="notifications/resources/updated", 146 | params={"uri": AnyUrl(f"generations://{prediction_id}")}, 147 | ) 148 | for session in sessions: 149 | await session.send_notification(notification) 150 | # Also send the resource directly 151 | await session.send_resource(resource) 152 | 153 | # Remove completed/failed generation from subscriptions 154 | del self._subscriptions[prediction_id] 155 | except Exception as e: 156 | logger.error(f"Error checking generation {prediction_id}: {e}") 157 | 158 | if not self._subscriptions: 159 | break 
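                    # Nothing terminal happened this pass: fall through to the
                    # fixed-interval poll below. The resulting cadence (assuming no
                    # rate-limit pauses) is a status check every 2 seconds while
                    # subscriptions remain, with a longer 5-second back-off sleep
                    # whenever the check itself raises.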
160 | 161 | await asyncio.sleep(2.0) # Poll every 2 seconds 162 | except asyncio.CancelledError: 163 | break 164 | except Exception as e: 165 | logger.error(f"Error in generation check loop: {e}") 166 | await asyncio.sleep(5.0) # Back off on errors 167 | 168 | 169 | async def select_model_for_task( 170 | task: str, 171 | style: str | None = None, 172 | quality: str = "balanced", 173 | ) -> tuple[Model, dict[str, Any]]: 174 | """Select the best model for a given task and get optimal parameters. 175 | 176 | Args: 177 | task: Task description/prompt 178 | style: Optional style preference 179 | quality: Quality preset (draft, balanced, quality, extreme) 180 | 181 | Returns: 182 | Tuple of (selected model, optimized parameters) 183 | """ 184 | # Build search query 185 | search_query = task 186 | if style: 187 | search_query = f"{style} style {search_query}" 188 | 189 | # Search for models 190 | async with ReplicateClient(api_token=os.getenv("REPLICATE_API_TOKEN")) as client: 191 | result = await client.search_models(search_query) 192 | 193 | if not result["models"]: 194 | raise ValueError("No suitable models found for the task") 195 | 196 | # Score and rank models 197 | scored_models = [] 198 | for model in result["models"]: 199 | score = 0 200 | 201 | # Popularity score (0-50) 202 | run_count = model.get("run_count") or 0 # run_count can be None in API responses 203 | score += min(50, (run_count / 1000) * 50) 204 | 205 | # Featured bonus 206 | if model.get("featured"): 207 | score += 20 208 | 209 | # Version stability 210 | if model.get("latest_version"): 211 | score += 10 212 | 213 | # Tag matching 214 | tags = model.get("tags", []) 215 | if style and any(style.lower() in tag.lower() for tag in tags): 216 | score += 15 217 | if "image" in tags or "text-to-image" in tags: 218 | score += 15 219 | 220 | scored_models.append((model, score)) 221 | 222 | # Sort by score 223 | scored_models.sort(key=lambda x: x[1], reverse=True) 224 | selected_model = scored_models[0][0] 225 | 226 | # Get quality preset 227 | quality_preset = TEMPLATES["quality-presets"]["presets"].get( 228 | quality, TEMPLATES["quality-presets"]["presets"]["balanced"] 229 | ) 230 | 231 | # Get style preset if specified 232 | parameters = quality_preset["parameters"].copy() 233 | if style: 234 | style_preset = TEMPLATES["style-presets"]["presets"].get( 235 | style.lower(), TEMPLATES["style-presets"]["presets"].get("photorealistic") 236 | ) 237 | if style_preset: 238 | parameters.update(style_preset["parameters"]) 239 | 240 | return Model(**selected_model), parameters 241 | 242 | 243 | class TemplateInput(BaseModel): 244 | """Input for template-based operations.""" 245 | 246 | template: str = Field(..., description="Template identifier") 247 | parameters: dict[str, Any] = Field(default_factory=dict, description="Template parameters") 248 | 249 | @field_validator("template") 250 | def validate_template(cls, v: str) -> str: 251 | """Validate template identifier.""" 252 | if v not in TEMPLATES: 253 | raise ValueError(f"Unknown template: {v}") 254 | return v 255 | 256 | @field_validator("parameters") 257 | def validate_parameters(cls, v: dict[str, Any], info: Any) -> dict[str, Any]: # Pydantic v2 passes ValidationInfo here, not a dict of values 258 | """Validate template parameters.""" 259 | if "template" not in info.data: 260 | return v 261 | 262 | template = TEMPLATES[info.data["template"]] 263 | try: 264 | jsonschema.validate(v, template["parameter_schema"]) 265 | except jsonschema.exceptions.ValidationError as e: 266 | raise ValueError(f"Invalid parameters: {e.message}") from e 267 | return v 268 | 269 | 270 | class 
PredictionInput(BaseModel): 271 | """Input for prediction operations.""" 272 | 273 | version: str = Field(..., description="Model version ID") 274 | input: dict[str, Any] = Field(..., description="Model input parameters") 275 | webhook: str | None = Field(None, description="Webhook URL for prediction updates") 276 | 277 | @field_validator("input") 278 | def validate_input(cls, v: dict[str, Any]) -> dict[str, Any]: 279 | """Validate prediction input.""" 280 | if not isinstance(v, dict): 281 | raise ValueError("Input must be a dictionary") 282 | return v 283 | 284 | 285 | def create_server(*, log_level: int = logging.WARNING) -> FastMCP: 286 | """Create and configure the FastMCP server. 287 | 288 | Args: 289 | log_level: The logging level to use. Defaults to WARNING. 290 | 291 | Returns: 292 | Configured FastMCP server instance. 293 | """ 294 | # Configure logging 295 | logging.basicConfig(level=log_level) 296 | logger.setLevel(log_level) 297 | 298 | # Verify API token is available 299 | api_token = os.getenv("REPLICATE_API_TOKEN") 300 | if not api_token: 301 | raise ValueError( 302 | "REPLICATE_API_TOKEN environment variable is required. " "Get your token from https://replicate.com/account" 303 | ) 304 | 305 | # Create server instance 306 | mcp = FastMCP("Replicate Server") 307 | 308 | # Add resources 309 | @mcp.resource("templates://list") 310 | def list_available_templates() -> str: 311 | """List all available templates with descriptions.""" 312 | template_info = [] 313 | for name, template in TEMPLATES.items(): 314 | template_info.append( 315 | f"Template: {name}\n" 316 | f"Description: {template.get('description', 'No description')}\n" 317 | f"Version: {template.get('version', '1.0.0')}\n" 318 | "---" 319 | ) 320 | return "\n".join(template_info) 321 | 322 | @mcp.resource("templates://{name}") 323 | def get_template_details(name: str) -> str: 324 | """Get detailed information about a specific template.""" 325 | if name not in TEMPLATES: 326 | raise ValueError(f"Template not found: {name}") 327 | 328 | template = TEMPLATES[name] 329 | return json.dumps( 330 | { 331 | "name": name, 332 | "description": template.get("description", ""), 333 | "version": template.get("version", "1.0.0"), 334 | "parameter_schema": template["parameter_schema"], 335 | "examples": template.get("examples", []), 336 | }, 337 | indent=2, 338 | ) 339 | 340 | @mcp.resource("generations://{prediction_id}") 341 | async def get_generation(prediction_id: str) -> TextResourceContents | BlobResourceContents: 342 | """Get a specific image generation result.""" 343 | async with ReplicateClient(api_token=os.getenv("REPLICATE_API_TOKEN")) as client: 344 | result = await client.get_prediction_status(prediction_id) 345 | 346 | # If not succeeded, return status info 347 | if result["status"] != "succeeded": 348 | return TextResourceContents( 349 | uri=f"generations://{prediction_id}", 350 | mimeType="application/json", 351 | text=json.dumps( 352 | { 353 | "status": result["status"], 354 | "created_at": result.get("created_at"), 355 | "started_at": result.get("started_at"), 356 | "completed_at": result.get("completed_at"), 357 | "error": result.get("error"), 358 | "logs": result.get("logs"), 359 | "urls": result.get("urls", {}), 360 | "metrics": result.get("metrics", {}), 361 | } 362 | ), 363 | ) 364 | 365 | # For succeeded generations, return image URL and metadata 366 | image_url = result["output"][0] if isinstance(result["output"], list) else result["output"] 367 | return TextResourceContents( 368 | 
uri=f"generations://{prediction_id}", 369 | mimeType="application/json", 370 | text=json.dumps( 371 | { 372 | "status": "succeeded", 373 | "image_url": image_url, 374 | "created_at": result.get("created_at"), 375 | "completed_at": result.get("completed_at"), 376 | "metrics": result.get("metrics", {}), 377 | "urls": result.get("urls", {}), 378 | "input": result.get("input", {}), 379 | } 380 | ), 381 | ) 382 | 383 | @mcp.resource("generations://list") 384 | async def list_generations() -> TextResourceContents: 385 | """List all available generations with their details and resource URIs.""" 386 | async with ReplicateClient(api_token=os.getenv("REPLICATE_API_TOKEN")) as client: 387 | predictions = await client.list_predictions(limit=100) 388 | return TextResourceContents( 389 | uri="generations://list", 390 | mimeType="application/json", 391 | text=json.dumps( 392 | { 393 | "total_count": len(predictions), 394 | "generations": [ 395 | { 396 | "id": p["id"], 397 | "status": p["status"], 398 | "created_at": p.get("created_at"), 399 | "completed_at": p.get("completed_at"), 400 | "prompt": p.get("input", {}).get("prompt"), # Extract prompt for easy reference 401 | "style": p.get("input", {}).get("style"), # Extract style for filtering 402 | "quality": p.get("input", {}).get("quality", "balanced"), 403 | "error": p.get("error"), 404 | "resource_uri": f"generations://{p['id']}", 405 | "metrics": p.get("metrics", {}), # Include performance metrics 406 | "urls": p.get("urls", {}), # Include direct URLs 407 | } 408 | for p in predictions 409 | ], 410 | }, 411 | indent=2, 412 | ), 413 | ) 414 | 415 | @mcp.resource("generations://search/{query}") 416 | async def search_generations(query: str) -> TextResourceContents: 417 | """Search through available generations by prompt text or metadata.""" 418 | async with ReplicateClient(api_token=os.getenv("REPLICATE_API_TOKEN")) as client: 419 | predictions = await client.list_predictions(limit=100) 420 | # Improved search - check prompt, style, and quality 421 | filtered = [] 422 | query_lower = query.lower() 423 | for p in predictions: 424 | input_data = p.get("input", {}) 425 | searchable_text = ( 426 | f"{input_data.get('prompt', '')} {input_data.get('style', '')} {input_data.get('quality', '')}" 427 | ) 428 | if query_lower in searchable_text.lower(): 429 | filtered.append(p) 430 | 431 | return TextResourceContents( 432 | uri=f"generations://search/{query}", 433 | mimeType="application/json", 434 | text=json.dumps( 435 | { 436 | "query": query, 437 | "total_count": len(filtered), 438 | "generations": [ 439 | { 440 | "id": p["id"], 441 | "status": p["status"], 442 | "created_at": p.get("created_at"), 443 | "completed_at": p.get("completed_at"), 444 | "prompt": p.get("input", {}).get("prompt"), 445 | "style": p.get("input", {}).get("style"), 446 | "quality": p.get("input", {}).get("quality", "balanced"), 447 | "error": p.get("error"), 448 | "resource_uri": f"generations://{p['id']}", 449 | "metrics": p.get("metrics", {}), 450 | "urls": p.get("urls", {}), 451 | } 452 | for p in filtered 453 | ], 454 | }, 455 | indent=2, 456 | ), 457 | ) 458 | 459 | @mcp.resource("generations://status/{status}") 460 | async def filter_generations_by_status(status: str) -> TextResourceContents: 461 | """Get generations filtered by status (starting, processing, succeeded, failed, canceled).""" 462 | async with ReplicateClient(api_token=os.getenv("REPLICATE_API_TOKEN")) as client: 463 | predictions = await client.list_predictions(status=status, limit=100) 464 | return TextResourceContents( 465 | uri=f"generations://status/{status}", 466 | mimeType="application/json", 467 | text=json.dumps( 468 | { 469 | 
"status": status, 470 | "total_count": len(predictions), 471 | "generations": [ 472 | { 473 | "id": p["id"], 474 | "created_at": p.get("created_at"), 475 | "completed_at": p.get("completed_at"), 476 | "prompt": p.get("input", {}).get("prompt"), 477 | "style": p.get("input", {}).get("style"), 478 | "quality": p.get("input", {}).get("quality", "balanced"), 479 | "error": p.get("error"), 480 | "resource_uri": f"generations://{p['id']}", 481 | "metrics": p.get("metrics", {}), 482 | "urls": p.get("urls", {}), 483 | } 484 | for p in predictions 485 | ], 486 | }, 487 | indent=2, 488 | ), 489 | ) 490 | 491 | @mcp.resource("models://popular") 492 | async def get_popular_models() -> str: 493 | """Get a list of popular models on Replicate.""" 494 | async with ReplicateClient() as client: 495 | models = await client.list_models() 496 | return json.dumps( 497 | { 498 | "models": models["models"], 499 | "total": models.get("total_models", 0), 500 | }, 501 | indent=2, 502 | ) 503 | 504 | # Add prompts 505 | @mcp.prompt() 506 | def text_to_image() -> Sequence[Message]: 507 | """Generate an image from text using available models.""" 508 | return [ 509 | UserMessage( 510 | content=TextContent( 511 | type="text", 512 | text=( 513 | "I'll help you create an image using Replicate's SDXL model. " 514 | "To get the best results, please tell me:\n\n" 515 | "1. What you want to see in the image (be specific)\n" 516 | "2. Style preferences (e.g., photorealistic, anime, oil painting)\n" 517 | "3. Quality level (draft=fast, balanced=default, quality=better, extreme=best)\n" 518 | "4. Any specific requirements (size, aspect ratio, etc.)\n\n" 519 | "For example: 'Create a photorealistic mountain landscape at sunset with snow-capped peaks, " 520 | "quality level, 16:9 aspect ratio'\n\n" 521 | "Once you provide these details:\n" 522 | "- I'll start the generation and provide real-time updates\n" 523 | "- You can wait for completion or start another generation\n" 524 | "- When ready, I can show you the image or open it on your system\n" 525 | "- You can also browse, search, or manage your generations" 526 | ), 527 | ) 528 | ) 529 | ] 530 | 531 | @mcp.prompt() 532 | def image_to_image() -> Sequence[Message]: 533 | """Transform an existing image using various models.""" 534 | return [ 535 | UserMessage( 536 | content=TextContent( 537 | type="text", 538 | text=( 539 | "I'll help you transform an existing image using Replicate's models. " 540 | "Please provide:\n\n" 541 | "1. The URL or path to your source image\n" 542 | "2. The type of transformation you want (e.g., style transfer, upscaling, inpainting)\n" 543 | "3. Any specific settings or parameters\n\n" 544 | "I'll help you choose the right model and format your request." 545 | ), 546 | ) 547 | ) 548 | ] 549 | 550 | @mcp.prompt() 551 | def model_selection(task: str | None = None) -> Sequence[Message]: 552 | """Help choose the right model for a specific task.""" 553 | base_prompt = ( 554 | "I'll help you select the best Replicate model for your needs. " 555 | "Please tell me about your task and requirements, including:\n\n" 556 | "1. The type of input you have\n" 557 | "2. Your desired output\n" 558 | "3. Any specific quality or performance requirements\n" 559 | "4. 
Any budget or hardware constraints\n\n" 560 | ) 561 | 562 | if task: 563 | base_prompt += f"\nFor {task}, I recommend considering these aspects:\n" 564 | if "image" in task.lower(): 565 | base_prompt += ( 566 | "- Input/output image dimensions\n" 567 | "- Style and quality requirements\n" 568 | "- Processing speed needs\n" 569 | ) 570 | elif "text" in task.lower(): 571 | base_prompt += ( 572 | "- Input length considerations\n" "- Output format requirements\n" "- Specific language needs\n" 573 | ) 574 | 575 | return [UserMessage(content=TextContent(type="text", text=base_prompt))] 576 | 577 | @mcp.prompt() 578 | def parameter_help(template: str | None = None) -> Sequence[Message]: 579 | """Get help with model parameters and templates.""" 580 | if template and template in TEMPLATES: 581 | tmpl = TEMPLATES[template] 582 | text = ( 583 | f"I'll help you with the {template} template.\n\n" 584 | f"Description: {tmpl.get('description', 'No description')}\n" 585 | "Required Parameters:\n" 586 | + "\n".join( 587 | f"- {param}: {schema.get('description', 'No description')}" 588 | for param, schema in tmpl["parameter_schema"]["properties"].items() 589 | if param in tmpl["parameter_schema"].get("required", []) 590 | ) 591 | + "\n\nOptional Parameters:\n" 592 | + "\n".join( 593 | f"- {param}: {schema.get('description', 'No description')}" 594 | for param, schema in tmpl["parameter_schema"]["properties"].items() 595 | if param not in tmpl["parameter_schema"].get("required", []) 596 | ) 597 | ) 598 | else: 599 | text = ( 600 | "I'll help you understand and configure model parameters. " 601 | "Please provide:\n\n" 602 | "1. The model or template you're using\n" 603 | "2. Any specific parameters you need help with\n" 604 | "3. Your use case or requirements\n\n" 605 | "I'll explain the parameters and suggest appropriate values." 606 | ) 607 | 608 | return [UserMessage(content=TextContent(type="text", text=text))] 609 | 610 | @mcp.prompt() 611 | def after_generation() -> Sequence[Message]: 612 | """Prompt shown after starting an image generation.""" 613 | return [ 614 | UserMessage( 615 | content=TextContent( 616 | type="text", 617 | text=( 618 | "Your image generation has started! You have several options:\n\n" 619 | "1. Wait here - I'll check the progress and show you the image when it's ready\n" 620 | "2. Browse your generations - I can show you a list of all your generations\n" 621 | "3. Start another generation - We can create more images while this one processes\n\n" 622 | "When the image is ready, I can:\n" 623 | "- Show you the image directly\n" 624 | "- Open it with your system's default image viewer\n" 625 | "- Save it or share it\n" 626 | "- Create variations or apply transformations\n\n" 627 | "What would you like to do?" 
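                        # This prompt is the target of the "_next_prompt": "after_generation"
                        # hint that create_prediction returns, steering the conversation
                        # toward get_prediction(..., wait=True) or the generations:// resources.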
628 | ), 629 | ) 630 | ) 631 | ] 632 | 633 | # Model Discovery Tools 634 | @mcp.tool() 635 | async def list_models(owner: str | None = None) -> ModelList: 636 | """List available models on Replicate with optional filtering by owner.""" 637 | async with ReplicateClient(api_token=os.getenv("REPLICATE_API_TOKEN")) as client: 638 | result = await client.list_models(owner=owner) 639 | return ModelList( 640 | models=[Model(**model) for model in result["models"]], 641 | next_cursor=result.get("next_cursor"), 642 | total_count=result.get("total_count"), 643 | ) 644 | 645 | @mcp.tool() 646 | async def search_models(query: str) -> ModelList: 647 | """Search for models using semantic search.""" 648 | async with ReplicateClient(api_token=os.getenv("REPLICATE_API_TOKEN")) as client: 649 | result = await client.search_models(query) 650 | models = [Model(**model) for model in result["models"]] 651 | 652 | # Sort by run count as a proxy for popularity/reliability 653 | models.sort(key=lambda m: m.run_count if m.run_count else 0, reverse=True) 654 | 655 | return ModelList( 656 | models=models, next_cursor=result.get("next_cursor"), total_count=result.get("total_count") 657 | ) 658 | 659 | # Collection Management Tools 660 | @mcp.tool() 661 | async def list_collections() -> CollectionList: 662 | """List available model collections on Replicate.""" 663 | async with ReplicateClient(api_token=os.getenv("REPLICATE_API_TOKEN")) as client: 664 | result = await client.list_collections() 665 | return CollectionList(collections=[Collection(**collection) for collection in result]) 666 | 667 | @mcp.tool() 668 | async def get_collection_details(collection_slug: str) -> Collection: 669 | """Get detailed information about a specific collection.""" 670 | async with ReplicateClient(api_token=os.getenv("REPLICATE_API_TOKEN")) as client: 671 | result = await client.get_collection(collection_slug) 672 | return Collection(**result) 673 | 674 | # Hardware Tools 675 | @mcp.tool() 676 | async def list_hardware() -> HardwareList: 677 | """List available hardware options for running models.""" 678 | async with ReplicateClient(api_token=os.getenv("REPLICATE_API_TOKEN")) as client: 679 | result = await client.list_hardware() 680 | return HardwareList(hardware=[Hardware(**hw) for hw in result]) 681 | 682 | # Template Tools 683 | @mcp.tool() 684 | def list_templates() -> dict[str, Any]: 685 | """List all available templates with their schemas.""" 686 | return { 687 | name: { 688 | "schema": template["parameter_schema"], 689 | "description": template.get("description", ""), 690 | "version": template.get("version", "1.0.0"), 691 | } 692 | for name, template in TEMPLATES.items() 693 | } 694 | 695 | @mcp.tool() 696 | def validate_template_parameters(input: dict[str, Any]) -> bool: 697 | """Validate parameters against a template schema.""" 698 | TemplateInput(**input) # Raises ValueError if the template or parameters are invalid 699 | return True # If we get here, validation passed 700 | 701 | # Prediction Tools 702 | @mcp.tool() 703 | async def create_prediction(input: dict[str, Any], confirmed: bool = False) -> dict[str, Any]: 704 | """Create a new prediction using a specific model version on Replicate. 
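        Sketch of the intended two-step flow (the version hash below is a
        placeholder, not a real model version):

            first = await create_prediction({"version": "abc123", "prompt": "a red fox"})
            # first["requires_confirmation"] is True; surface first["message"] to the user
            result = await create_prediction(
                {"version": "abc123", "prompt": "a red fox"}, confirmed=True
            )
            # result carries the prediction fields plus "_next_prompt": "after_generation"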
705 | 706 | Args: 707 | input: Model input parameters including version or model details 708 | confirmed: Whether the user has explicitly confirmed the generation 709 | 710 | Returns: 711 | Prediction details if confirmed, or a confirmation request if not 712 | """ 713 | # If not confirmed, return info about what will be generated 714 | if not confirmed: 715 | # Extract model info for display 716 | model_info = "" 717 | if "version" in input: 718 | model_info = f"version: {input['version']}" 719 | elif "model_owner" in input and "model_name" in input: 720 | model_info = f"model: {input['model_owner']}/{input['model_name']}" 721 | 722 | return { 723 | "requires_confirmation": True, 724 | "message": ( 725 | "⚠️ This will use Replicate credits to generate an image with these parameters:\n\n" 726 | f"Model: {model_info}\n" 727 | f"Prompt: {input.get('prompt', 'Not specified')}\n" 728 | f"Quality: {input.get('quality', 'balanced')}\n\n" 729 | "Please confirm if you want to proceed with the generation." 730 | ), 731 | } 732 | 733 | async with ReplicateClient(api_token=os.getenv("REPLICATE_API_TOKEN")) as client: 734 | # If version is provided directly, use it 735 | if "version" in input: 736 | version = input.pop("version") 737 | # Otherwise, try to find the model and get its latest version 738 | elif "model_owner" in input and "model_name" in input: 739 | model_id = f"{input.pop('model_owner')}/{input.pop('model_name')}" 740 | search_result = await client.search_models(model_id) 741 | if not search_result["models"]: 742 | raise ValueError(f"Model not found: {model_id}") 743 | model = search_result["models"][0] 744 | if not model.get("latest_version"): 745 | raise ValueError(f"No versions found for model: {model_id}") 746 | version = model["latest_version"]["id"] 747 | else: 748 | raise ValueError("Must provide either 'version' or both 'model_owner' and 'model_name'") 749 | 750 | # Create prediction with remaining parameters as input 751 | result = await client.create_prediction(version=version, input=input, webhook=input.pop("webhook", None)) 752 | 753 | # Return result with prompt about waiting 754 | return { 755 | **result, 756 | "_next_prompt": "after_generation", # Signal to show the waiting prompt 757 | } 758 | 759 | @mcp.tool() 760 | async def get_prediction(prediction_id: str, wait: bool = False, max_retries: int | None = None) -> dict[str, Any]: 761 | """Get the status and results of a prediction.""" 762 | consecutive_errors = 0 763 | 764 | while True: 765 | try: 766 | async with ReplicateClient(api_token=os.getenv("REPLICATE_API_TOKEN")) as client: 767 | response = await client.http_client.get( 768 | f"/predictions/{prediction_id}", 769 | headers={ 770 | "Authorization": f"Bearer {client.api_token}", 771 | "Content-Type": "application/json", 772 | }, 773 | ) 774 | response.raise_for_status() 775 | data = response.json() 776 | 777 | # Build URL message with all available links 778 | urls_msg = [] 779 | 780 | # Add streaming URL if available 781 | if data.get("urls", {}).get("stream"): 782 | urls_msg.append(f"🔄 Stream URL: {data['urls']['stream']}") 783 | 784 | # Add get URL if available 785 | if data.get("urls", {}).get("get"): 786 | urls_msg.append(f"📡 Status URL: {data['urls']['get']}") 787 | 788 | # Add cancel URL if available and still processing 789 | if data.get("urls", {}).get("cancel") and data["status"] in ["starting", "processing"]: 790 | urls_msg.append(f"🛑 Cancel URL: {data['urls']['cancel']}") 791 | 792 | # If prediction is complete and has an image output 793 | if 
data["status"] == "succeeded" and data.get("output"): 794 | # For image generation models, output is typically a list with the image URL as first item 795 | image_url = data["output"][0] if isinstance(data["output"], list) else data["output"] 796 | 797 | return { 798 | "id": data["id"], 799 | "status": "succeeded", 800 | "image_url": image_url, 801 | "resource_uri": f"generations://{prediction_id}", 802 | "image_resource_uri": f"generations://{prediction_id}/image", 803 | "created_at": data.get("created_at"), 804 | "completed_at": data.get("completed_at"), 805 | "metrics": data.get("metrics", {}), 806 | "urls": data.get("urls", {}), 807 | "message": ( 808 | "🎨 Generation completed successfully!\n\n" 809 | f"You can:\n" 810 | f"1. View the image at: {image_url}\n" 811 | f"2. Open it with your system viewer (just ask me to open it)\n" 812 | f"3. Access image data at: generations://{prediction_id}/image\n" 813 | f"4. Get metadata at: generations://{prediction_id}\n" 814 | f"5. Create variations or apply transformations\n\n" 815 | "Would you like me to open the image for you?\n\n" 816 | "Available URLs:\n" + "\n".join(urls_msg) 817 | ), 818 | } 819 | 820 | # If prediction failed or was cancelled 821 | if data["status"] in ["failed", "canceled"]: 822 | error_msg = data.get("error", "No error details available") 823 | return { 824 | "id": data["id"], 825 | "status": data["status"], 826 | "error": error_msg, 827 | "resource_uri": f"generations://{prediction_id}", 828 | "created_at": data.get("created_at"), 829 | "completed_at": data.get("completed_at"), 830 | "metrics": data.get("metrics", {}), 831 | "urls": data.get("urls", {}), 832 | "message": ( 833 | f"❌ Generation {data['status']}: {error_msg}\n\n" 834 | "Available URLs:\n" + "\n".join(urls_msg) 835 | ), 836 | } 837 | 838 | # If we're still processing and not waiting, return status 839 | if not wait: 840 | return { 841 | "id": data["id"], 842 | "status": data["status"], 843 | "resource_uri": f"generations://{prediction_id}", 844 | "created_at": data.get("created_at"), 845 | "started_at": data.get("started_at"), 846 | "metrics": data.get("metrics", {}), 847 | "urls": data.get("urls", {}), 848 | "message": ( 849 | f"⏳ Generation is {data['status']}...\n" 850 | "You can:\n" 851 | "1. Keep waiting (I'll check again)\n" 852 | "2. Use the URLs above to check progress yourself\n" 853 | "3. Cancel the generation if needed" 854 | ), 855 | } 856 | 857 | # Reset error count on successful request 858 | consecutive_errors = 0 859 | 860 | # Wait before polling again 861 | await asyncio.sleep(2.0) # Increased poll interval to reduce API load 862 | 863 | except Exception as e: 864 | logger.error(f"Error checking prediction status: {str(e)}") 865 | consecutive_errors += 1 866 | 867 | # Only stop if we have a max_retries set and exceeded it 868 | if max_retries is not None and consecutive_errors >= max_retries: 869 | return { 870 | "id": prediction_id, 871 | "status": "error", 872 | "error": str(e), 873 | "resource_uri": f"generations://{prediction_id}", 874 | "message": ( 875 | f"⚠️ Having trouble checking the prediction status (tried {consecutive_errors} times).\n\n" 876 | f"The prediction might still be running! You can:\n" 877 | "1. Try checking again in a few minutes\n" 878 | "2. Visit the status URL directly: " 879 | f"https://replicate.com/p/{prediction_id}\n" 880 | "3. 
@mcp.tool()
async def cancel_prediction(prediction_id: str) -> dict[str, Any]:
    """Cancel a running prediction."""
    async with ReplicateClient(api_token=os.getenv("REPLICATE_API_TOKEN")) as client:
        # ReplicateClient.cancel_prediction is assumed to return parsed JSON like
        # the client's other helpers; awaiting .json() on its result would fail.
        return await client.cancel_prediction(prediction_id)

# Webhook Tools
@mcp.tool()
async def get_webhook_secret() -> str:
    """Get the signing secret for verifying webhook requests."""
    async with ReplicateClient(api_token=os.getenv("REPLICATE_API_TOKEN")) as client:
        return await client.get_webhook_secret()

@mcp.tool()
def verify_webhook(payload: WebhookPayload, signature: str, secret: str) -> bool:
    """Verify that a webhook request came from Replicate using HMAC-SHA256.

    Args:
        payload: The webhook payload to verify
        signature: The signature from the X-Replicate-Signature header
        secret: The webhook signing secret from get_webhook_secret

    Returns:
        True if the signature is valid, False otherwise
    """
    if not signature or not secret:
        return False

    # Convert payload to a canonical JSON string so both sides hash identical bytes
    payload_str = json.dumps(payload.model_dump(), sort_keys=True)

    # Calculate the expected signature
    expected = hmac.new(secret.encode(), payload_str.encode(), hashlib.sha256).hexdigest()

    # Compare signatures using constant-time comparison to avoid timing attacks
    return hmac.compare_digest(signature, expected)

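# Illustrative verification flow (not part of the original source; `request` is a
# hypothetical framework-agnostic stand-in for an incoming HTTP request, and the
# header name follows the docstring above):
#
#     secret = await get_webhook_secret()
#     payload = WebhookPayload(**request.json())
#     if not verify_webhook(payload, request.headers["X-Replicate-Signature"], secret):
#         raise PermissionError("Invalid webhook signature")
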
@mcp.tool()
async def search_available_models(
    query: str,
    style: str | None = None,
) -> ModelList:
    """Search for available models matching the query.

    Args:
        query: Search query describing the desired model
        style: Optional style to filter by

    Returns:
        List of matching models, sorted by relevance score (highest first)
    """
    search_query = query
    if style:
        search_query = f"{style} style {search_query}"

    async with ReplicateClient(api_token=os.getenv("REPLICATE_API_TOKEN")) as client:
        result = await client.search_models(search_query)
        models = [Model(**model) for model in result["models"]]

        # Score models but don't auto-select
        scored_models = []
        for model in models:
            score = 0
            run_count = getattr(model, "run_count", 0) or 0
            score += min(50, (run_count / 1000) * 50)  # popularity, capped at 50 points
            if getattr(model, "featured", False):
                score += 20  # featured models get a boost
            if model.latest_version:
                score += 10  # prefer models with a runnable version
            tags = getattr(model, "tags", [])
            if style and any(style.lower() in tag.lower() for tag in tags):
                score += 15  # tag matches the requested style
            if "image" in tags or "text-to-image" in tags:
                score += 15  # prefer image-generation models
            scored_models.append((model, score))

        # Sort by score but return all for user selection
        scored_models.sort(key=lambda x: x[1], reverse=True)
        return ModelList(
            models=[m[0] for m in scored_models],
            next_cursor=result.get("next_cursor"),
            total_count=result.get("total_count"),
        )

@mcp.tool()
async def get_model_details(model_id: str) -> Model:
    """Get detailed information about a specific model.

    Args:
        model_id: Model identifier in format owner/name

    Returns:
        Detailed model information
    """
    async with ReplicateClient(api_token=os.getenv("REPLICATE_API_TOKEN")) as client:
        result = await client.get_model_details(model_id)
        return Model(**result)

@mcp.tool()
async def generate_image(
    prompt: str,
    style: str | None = None,
    quality: str = "balanced",
    width: int | None = None,
    height: int | None = None,
    num_outputs: int = 1,
    seed: int | None = None,
) -> dict[str, Any]:
    """Generate an image with SDXL using quality and style presets.

    Args:
        prompt: Text description of the image to generate
        style: Optional style preset name from STYLE_PRESETS
        quality: Quality preset name from QUALITY_PRESETS; unknown values fall back to "balanced"
        width: Optional width override in pixels
        height: Optional height override in pixels
        num_outputs: Number of images to generate
        seed: Optional random seed for reproducible outputs

    Returns:
        Resource URI, status, and metadata for the new prediction
    """
    # Get quality preset parameters
    if quality not in QUALITY_PRESETS["presets"]:
        quality = "balanced"
    parameters = QUALITY_PRESETS["presets"][quality]["parameters"].copy()

    # Apply style preset if specified
    if style:
        if style in STYLE_PRESETS["presets"]:
            style_params = STYLE_PRESETS["presets"][style]["parameters"]
            # Merge prompt prefixes
            if "prompt_prefix" in style_params:
                prompt = f"{style_params['prompt_prefix']}, {prompt}"
            # Copy other parameters
            for k, v in style_params.items():
                if k != "prompt_prefix":
                    parameters[k] = v

    # Override size if specified
    if width:
        parameters["width"] = width
    if height:
        parameters["height"] = height

    # Add other parameters
    parameters.update(
        {
            "prompt": prompt,
            "num_outputs": num_outputs,
        }
    )
    if seed is not None:
        parameters["seed"] = seed

    # Create prediction with SDXL model
    async with ReplicateClient(api_token=os.getenv("REPLICATE_API_TOKEN")) as client:
        result = await client.create_prediction(
            version="39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b",  # SDXL v1.0
            input=parameters,
        )

        # Return resource information
        return {
            "resource_uri": f"generations://{result['id']}",
            "status": result["status"],
            "message": (
                "🎨 Starting your image generation...\n\n"
                f"Prompt: {prompt}\n"
                f"Style: {style or 'Default'}\n"
                f"Quality: {quality}\n\n"
                "Let me check the status of your generation. I'll use:\n"
                f"`get_prediction(\"{result['id']}\", wait=True)`\n\n"
                "This will let me monitor the progress and show you the image as soon as it's ready."
            ),
            "next_prompt": "after_generation",
            "metadata": {  # Add metadata for client use
                "prompt": prompt,
                "style": style,
                "quality": quality,
                "width": width,
                "height": height,
                "seed": seed,
                "model": "SDXL v1.0",
                "created_at": result.get("created_at"),
            },
        }

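# Illustrative call (not part of the original source): presets compose, with the
# style's prompt_prefix prepended to the prompt and explicit width/height taking
# precedence over the quality preset. The preset names below are assumptions:
#
#     result = await generate_image(
#         "a lighthouse at dusk",
#         style="watercolor",   # hypothetical STYLE_PRESETS entry
#         quality="quality",    # hypothetical QUALITY_PRESETS entry
#         width=1024,
#         height=768,
#         seed=42,
#     )
#     # result["resource_uri"] -> "generations://<prediction id>"
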
# Initialize subscription manager
subscription_manager = GenerationSubscriptionManager()

@mcp.tool()
async def subscribe_to_generation(request: SubscriptionRequest) -> EmptyResult:
    """Handle resource subscription requests."""
    if request.uri.startswith("generations://"):
        session = ServerSession(request.session_id)
        await subscription_manager.subscribe(request.uri, session)
    # Always return a result so the declared EmptyResult return type also holds
    # for non-generation URIs
    return EmptyResult()

@mcp.tool()
async def unsubscribe_from_generation(request: SubscriptionRequest) -> EmptyResult:
    """Handle resource unsubscribe requests."""
    if request.uri.startswith("generations://"):
        session = ServerSession(request.session_id)
        await subscription_manager.unsubscribe(request.uri, session)
    return EmptyResult()

@mcp.resource("generations://{prediction_id}/image")
async def get_generation_image(prediction_id: str) -> BlobResourceContents:
    """Get the image data for a completed generation."""
    async with ReplicateClient(api_token=os.getenv("REPLICATE_API_TOKEN")) as client:
        result = await client.get_prediction_status(prediction_id)

        if result["status"] != "succeeded":
            raise ValueError(f"Generation not completed: {result['status']}")

        if not result.get("output"):
            raise ValueError("No image output available")

        # Get image URL
        image_url = result["output"][0] if isinstance(result["output"], list) else result["output"]

        # Download image
        async with httpx.AsyncClient() as http_client:
            img_response = await http_client.get(image_url)
            img_response.raise_for_status()

        # Determine MIME type from the URL extension, defaulting to PNG
        ext = image_url.split(".")[-1].lower()
        mime_type = {
            "png": "image/png",
            "jpg": "image/jpeg",
            "jpeg": "image/jpeg",
            "gif": "image/gif",
            "webp": "image/webp",
        }.get(ext, "image/png")

        # Return blob contents
        return BlobResourceContents(
            type="blob",
            mimeType=mime_type,
            uri=image_url,
            blob=base64.b64encode(img_response.content).decode("ascii"),
            description="Generated image data",
        )

@mcp.tool()
async def open_image_with_system(image_url: str) -> dict[str, Any]:
    """Open an image URL with the system's default application.

    Args:
        image_url: URL of the image to open

    Returns:
        Dict containing status of the operation
    """
    try:
        # webbrowser.open hands the URL to the default browser / OS URL handler
        webbrowser.open(image_url)

        return {"status": "success", "message": "Image opened with system default application", "url": image_url}
    except Exception as e:
        logger.error(f"Failed to open image: {str(e)}")
        return {"status": "error", "message": f"Failed to open image: {str(e)}", "url": image_url}

return mcp
--------------------------------------------------------------------------------
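
A minimal end-to-end sketch of the generation flow, assuming it runs inside the
same scope as the tools above (they are defined within the factory that ends with
`return mcp`) and that REPLICATE_API_TOKEN is set as in .env.example:

    started = await generate_image("a red fox in the snow", quality="balanced")
    prediction_id = started["resource_uri"].removeprefix("generations://")
    status = await get_prediction(prediction_id, wait=True, max_retries=10)
    if status["status"] == "succeeded":
        await open_image_with_system(status["image_url"])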