├── .env.example ├── AIEDITS.md ├── AINOTES.md ├── API.md ├── DESIGN.md ├── Dockerfile ├── EDITS.md ├── MCP_INTEGRATION.md ├── MEMORY_MANAGEMENT.md ├── NOTES.md ├── README.md ├── app ├── __init__.py ├── admin │ ├── __init__.py │ └── routes.py ├── api │ ├── __init__.py │ ├── mcp.py │ └── routes.py ├── main │ ├── __init__.py │ └── routes.py ├── models │ ├── __init__.py │ ├── generator.py │ └── manager.py ├── static │ ├── css │ │ ├── gallery.css │ │ └── style.css │ ├── images │ │ └── favicon.png │ └── js │ │ ├── gallery.js │ │ ├── main.js │ │ └── modules │ │ ├── imageLoader.js │ │ ├── modalManager.js │ │ ├── selectionManager.js │ │ ├── shortcuts.js │ │ ├── uiUtils.js │ │ └── viewManager.js ├── templates │ ├── admin.html │ ├── gallery.html │ └── index.html └── utils │ ├── cleanup_models.py │ ├── config.py │ ├── db.py │ ├── download_models.py │ ├── image.py │ ├── logging_config.py │ ├── model_config_test.py │ ├── queue.py │ ├── schema.sql │ └── watchdog.py ├── docker-compose.yml ├── examples ├── ai_assistant_mcp_example.py └── mcp_client_example.py ├── instance └── config.py ├── media ├── basic.png ├── enrich.png ├── gallery.png ├── generate.png ├── queue.png └── single.png ├── requirements.txt ├── run.py ├── run.sh ├── test_generate.py ├── test_watchdog.py └── tests └── test_mcp.py /.env.example: -------------------------------------------------------------------------------- 1 | MODEL_FOLDER=./models 2 | IMAGES_FOLDER=./images 3 | EXTERNAL_MODEL_FOLDER= 4 | EXTERNAL_IMAGES_FOLDER= 5 | HF_TOKEN= 6 | FAL_KEY= 7 | OPENAI_ENDPOINT= 8 | OPENAI_API_KEY= 9 | OPENAI_MODEL= 10 | CIVITAI_API_KEY= 11 | REPLICATE_API_KEY= 12 | 13 | # Model Configuration 14 | # Format: MODEL_NAME=;;;;[;] 15 | # The is an optional JSON string for extra model settings (e.g., step config). 
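# Fields, in order: name;repo;description;source;requires_auth[;options_json]
# (The final optional field is the JSON settings string referred to above; see the MODEL_1–MODEL_9 examples below.)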
16 | # Example Fixed Steps: MODEL_X="fixed-step-model;some/repo;...;{\\"fixed_steps\\": 25}" 17 | # Example Step Range: MODEL_Y="range-step-model;another/repo;...;{\\"steps\\":{\\"min\\":10,\\"max\\":40,\\"default\\":20}}" 18 | # Note: When sourcing this file in a shell, quotes are required around values with semicolons 19 | # When used as .env file directly, quotes are optional but must be consistent 20 | # Example: MODEL_1=flux-1;black-forest-labs/FLUX.1-dev;FLUX base model;huggingface;true 21 | MODEL_1="flux-1;black-forest-labs/FLUX.1-dev;FLUX base model;huggingface;true;{\"steps\": {\"min\": 20, \"max\": 50, \"default\": 30}}" 22 | MODEL_2="sd-3.5;stabilityai/stable-diffusion-3.5-large;Stable Diffusion 3.5;huggingface;true;{\"steps\": {\"min\": 20, \"max\": 50, \"default\": 30}}" 23 | MODEL_3="flux-schnell;black-forest-labs/FLUX.1-schnell;FLUX Schnell;huggingface;true;{\"steps\": {\"min\": 4, \"max\": 20, \"default\": 10}}" 24 | MODEL_4="animagine-xl;cagliostrolab/animagine-xl-4.0;Animagine XL;huggingface;true;{\"steps\": {\"min\": 15, \"max\": 40, \"default\": 25}}" 25 | MODEL_5="sana-sprint;Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers;Sana Sprint 1.6B;huggingface;true;{\"fixed_steps\": 2}" 26 | 27 | # Remove the MODEL_CONFIG_FluxHF line and add this: 28 | MODEL_6="flux-hf-api;black-forest-labs/FLUX.1-dev;FLUX.1 via HF API;huggingface_api;true;{\"provider\": \"huggingface-inference-api\"}" 29 | MODEL_7="ltx-video-t2v-api;Lightricks/LTX-Video;LTX Text-to-Video (Fal.ai via Fal API Key);fal_api;true;{\"type\": \"t2v\", \"provider\": \"fal-ai\"}" 30 | MODEL_8="ltx-video-i2v-api;Lightricks/LTX-Video;LTX Image-to-Video (Fal.ai via Fal API Key);fal_api;true;{\"type\": \"i2v\", \"provider\": \"fal-ai\"}" 31 | MODEL_9="sd-3.5-hf-api;stabilityai/stable-diffusion-3.5-large;Stable Diffusion 3.5 via HF API;huggingface_api;true;{\"provider\": \"huggingface-inference-api\"}" 32 | 33 | # Enable/disable downloading specific models (values: true/false) 34 | DOWNLOAD_MODEL_1=true 35 | DOWNLOAD_MODEL_2=true 36 | DOWNLOAD_MODEL_3=true 37 | DOWNLOAD_MODEL_4=true 38 | DOWNLOAD_MODEL_5=true 39 | 40 | # --- Example for Replicate API based model --- 41 | # MODEL_CONFIG_Replicate_Example='{ 42 | # "repo": "owner/model-name:versionhash", # IMPORTANT: Use Replicate's specific model identifier here 43 | # "type": "sdxl", # Or other appropriate type 44 | # "description": "Example Model via Replicate", 45 | # "source": "huggingface_api", 46 | # "files": [], 47 | # "options_json": "{\"provider\": \"replicate\"}", 48 | # "download_enabled": true, 49 | # "notes": "Ensure REPLICATE_API_KEY is set in .env if using this." 50 | # }' 51 | 52 | # Enable/disable downloading specific models (values: true/false) 53 | # For MODEL_CONFIG_ lines, use the name after MODEL_CONFIG_ e.g., DOWNLOAD_REPLICATE_EXAMPLE=true -------------------------------------------------------------------------------- /AIEDITS.md: -------------------------------------------------------------------------------- 1 | # AI Edits Log 2 | 3 | ## Edit 2024-07-27_05 (Fix Missing Import) 4 | 5 | - **File:** `app/api/routes.py` 6 | - **Change:** Removed the import statement `from app.utils.rate_limit import rate_limit`. 7 | - **Reason:** This import was causing a `ModuleNotFoundError` during startup because the `app/utils/rate_limit.py` file was deleted in the previous step (Edit 2024-07-27_04) as part of removing the rate limiting feature. 8 | - **Aligned with AINOTES:** Yes, notes updated. 
9 | 10 | ## Edit 2024-07-27_04 (Rate Limit Removal) 11 | 12 | - **Files:** 13 | - `app/api/routes.py`: Removed `@rate_limit` decorator from `generate_image`. 14 | - `app/utils/rate_limit.py`: Deleted the file. 15 | - `instance/config.py`: Removed rate limit config lines. 16 | - `app/__init__.py`: Removed default rate limit config from `from_mapping`. 17 | - **Change:** Completely removed the IP-based rate limiting functionality. 18 | - **Reason:** User requested removal as the feature was not desired. 19 | - **Aligned with AINOTES:** Yes, notes updated to reflect removal. 20 | 21 | ## Edit 2024-07-27_03 22 | 23 | - **File:** `instance/config.py` 24 | - **Change:** Created/modified file to set `ENABLE_RATE_LIMIT = False`. 25 | - **Reason:** To correctly disable IP-based rate limiting (now superseded by removal). 26 | - **Aligned with AINOTES:** Yes, updated notes reflect the correct config location. 27 | 28 | ## Edit 2024-07-27_02 29 | 30 | - **File:** `app/utils/config.py` 31 | - **Change:** Added `ENABLE_RATE_LIMIT = False` (Ineffective - wrong file). 32 | - **Reason:** Attempted to disable rate limiting. 33 | - **Aligned with AINOTES:** Noted as ineffective. 34 | 35 | ## Edit 2024-07-27_01 36 | 37 | - **File:** `app/models/generator.py` 38 | - **Change:** Removed the call to `_force_memory_cleanup()` within the `finally` block of the `_generation_lock` in the `process_job` method. 39 | - **Reason:** To improve performance by enabling model caching between jobs. 40 | - **Aligned with AINOTES:** Yes, updated notes reflect this performance optimization. 41 | 42 | ## Edit 2024-07-26_01 (Example Previous Edit) 43 | ... (other previous edits) ... -------------------------------------------------------------------------------- /AINOTES.md: -------------------------------------------------------------------------------- 1 | # Project Notes 2 | 3 | ## Current State 4 | - **Fixed `ModuleNotFoundError`** by removing the unused import for `rate_limit` in `app/api/routes.py`. 5 | - Removed IP-based rate limiting functionality entirely: 6 | - Removed `@rate_limit` decorator from `app/api/routes.py`. 7 | - Deleted `app/utils/rate_limit.py`. 8 | - Removed related config from `instance/config.py` and `app/__init__.py`. 9 | - Optimized image generation pipeline by removing forced model unloading after each job in `generator.py`. 10 | - This allows model caching between jobs using the same model, significantly reducing processing time and estimated wait times. 11 | - Investigated potential causes for long user wait times (identified aggressive memory cleanup as primary cause). 12 | 13 | ## Functions & APIs 14 | - Refer to `API.md` for API definitions. 15 | - `app/__init__.py`: Creates Flask app, loads config (from mapping and `instance/config.py`). *Rate limit config removed.* 16 | - `instance/config.py`: Instance-specific configuration overrides. *Rate limit config removed.* 17 | - `app/api/routes.py`: Defines API endpoints. *Removed rate_limit import.* 18 | - `app/models/generator.GenerationPipeline`: Manages job queue and image generation. 19 | - `process_job`: Processes a single job (modified to keep model loaded). 20 | - `_force_memory_cleanup`: (Now less frequently used) Unloads models and clears GPU memory. 21 | - `app/models.manager.ModelManager`: Handles loading, unloading, and generation with specific models. 22 | - `app/utils.queue.QueueManager`: Manages job persistence and status in the database. 23 | - `app/utils.image.ImageManager`: Handles image saving. 
24 | - `app/utils/config.py`: Contains utility functions. 25 | - `app/static/js/main.js`: Handles frontend logic, including status polling and wait time estimation (`pollGenerationStatus`). 26 | 27 | ## Design Alignment (DESIGN.md) 28 | - Current changes align with minimizing complexity and improving performance/user experience. 29 | 30 | ## Known Issues / TODOs 31 | - Monitor performance after the optimization. 32 | - Ensure memory management remains stable without forced cleanup after *every* job (Watchdog should help). -------------------------------------------------------------------------------- /DESIGN.md: -------------------------------------------------------------------------------- 1 | - **Project Name:** CyberImage 2 | - **Overview:** Design a dynamic cyberpunk flux image generator that allows users to create, view, and explore images with advanced customization and user experience. 3 | - **Core Features:** 4 | - **Model Selection:** Users can choose from various generation models for different styles and effects. 5 | - **Queue System:** Implement a queuing system to manage multiple image generation requests efficiently. 6 | - **Dynamic Gallery View:** A gallery displaying all generated images in PNG or JPG format, with a grid view for effective browsing. 7 | - **Dynamic Generation View:** Real-time display of image generation progress and result preview. 8 | - **Enrich Prompt Button:** A feature that sends the current prompt to an OpenAI-compatible endpoint to expand the prompt into a more detailed and epic version. 9 | - **Image Display and Format:** Generated images should be displayed in high-quality PNG or JPG format. 10 | - **Prompt and Audit Logging:** Save all prompts along with generation data (model, time, and settings) in a SQLite table for auditing and reference. 11 | - **Image Details View:** Clicking on an image shows: 12 | - Image preview 13 | - Date and time of generation 14 | - Model used 15 | - Generation prompt 16 | - All step settings and seed values 17 | - **Unique Image Naming:** Each image is named using a unique ID for easy reference and retrieval. 18 | - **Technical Requirements: 19 | - **Technologies to be Used:** 20 | - **Python 3.12** for backend logic and system integrations. 21 | - **Flask** for building the web application with dynamic routing and responsive views. 22 | - **Hugging Face** for model integration and prompt enhancement. 23 | - **Docker** for containerization, ensuring consistent deployment across environments. 24 | - **Storage Design Consideration:** 25 | - External volumes or folders should be used for image storage to separate application data from container storage. 26 | - Image storage paths should be configurable for flexible deployment and scalability. 27 | - **Event Logging and Auditing:** 28 | - All user actions and system events are logged in SQLite for complete auditing. 29 | - An audit view is available to review SQLite audit data, including prompts, model selections, generation settings, and image views. 30 | - All system events are logged to the console using Python logging with detailed datetime stamps and event context for debugging and monitoring.** 31 | - **Frontend:** Dynamic and responsive UI for gallery and generation views. 32 | - **Backend:** Efficient handling of image generation requests and prompt enrichment via OpenAI-compatible API. 33 | - **Database:** SQLite database to store prompts, audit data, and image metadata. 34 | - **User Experience Flow:** 35 | 1. 
User enters a prompt and selects a generation model. 36 | 2. User can enhance the prompt using the Enrich Prompt Button. 37 | - cinematic 38 | - asymmetrical 39 | - unique 40 | - rule of thirds 41 | - strong negative space 42 | - production photography 43 | 3. Generation request is added to the queue and processed. 44 | 4. Image generation progress is shown in real-time. 45 | 5. Result is displayed in the dynamic generation view. 46 | 6. Image is saved in the gallery with the unique ID. 47 | 7. User can click on an image to view details and audit information. 48 | - **- **UI Design:** 49 | - **Color Scheme:** 50 | - Primary Color: Neon Green 51 | - Background: Black 52 | - Data Text: White 53 | - **Theme and Style:** 54 | - Cyberpunk dynamic, vivid, and beautiful aesthetics. 55 | - High contrast with vibrant neon elements for a futuristic feel. 56 | - Sleek animations and fluid transitions to enhance user experience. 57 | - Dynamic layouts for a visually engaging experience. 58 | 59 | - **Additional Notes:**** 60 | - Ensure the UI/UX is sleek and intuitive for seamless navigation. 61 | - Optimize image loading and caching for faster gallery performance. 62 | - Secure API endpoints and data storage for user-generated content. 63 | 64 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # Use Python 3.12 as base image 2 | FROM python:3.12-slim 3 | 4 | # Set working directory 5 | WORKDIR /app 6 | 7 | # Install system dependencies and newer libstdc++ 8 | RUN apt-get update && apt-get install -y \ 9 | build-essential \ 10 | curl \ 11 | git \ 12 | software-properties-common \ 13 | && rm -rf /var/lib/apt/lists/* \ 14 | && apt-get update \ 15 | && apt-get install -y gcc-12 g++-12 \ 16 | && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 100 \ 17 | && update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-12 100 18 | 19 | # Create non-root user 20 | RUN useradd -u 1000 -m -s /bin/bash appuser \ 21 | && chown -R appuser:appuser /app 22 | 23 | RUN pip install --no-cache-dir uv 24 | 25 | # Install Python dependencies 26 | COPY --chown=appuser:appuser requirements.txt . 27 | RUN uv pip install --no-cache-dir -r requirements.txt --system 28 | 29 | # Copy application code 30 | COPY --chown=appuser:appuser . . 31 | ENV FLASK_APP=run.py 32 | 33 | # Create mount points and set permissions 34 | RUN mkdir -p /app/images /app/models \ 35 | && chown -R appuser:appuser /app 36 | 37 | # Switch to non-root user 38 | USER appuser 39 | 40 | # Expose port 41 | EXPOSE 5050 42 | 43 | # Run with Gunicorn in production mode - ensure single thread processing 44 | CMD ["gunicorn", "--workers=1", "--threads=1", "--bind=0.0.0.0:5050", "--timeout=120", "run:app"] -------------------------------------------------------------------------------- /MCP_INTEGRATION.md: -------------------------------------------------------------------------------- 1 | # Integrating CyberImage with AI Systems using MCP 2 | 3 | This guide explains how to integrate CyberImage with AI assistants and other systems using the Model Context Protocol (MCP). 4 | 5 | ## What is Model Context Protocol? 6 | 7 | The [Model Context Protocol (MCP)](https://spec.modelcontextprotocol.io/specification/2024-11-05/) is an open protocol that standardizes how AI systems interact with external tools and capabilities. 
It allows AI models to access context from various sources and use external tools through a standardized JSON-RPC interface. 8 | 9 | ## CyberImage MCP Integration 10 | 11 | CyberImage implements MCP to enable AI assistants to generate images directly using our image generation capabilities. The MCP integration uses the same default model as the web interface, ensuring consistency across all application interfaces. This default model is determined dynamically, prioritizing "flux-1" if available, or otherwise using the first available model in the configuration. 12 | 13 | This provides a standardized way for AI systems to: 14 | 15 | 1. Discover available image generation models 16 | 2. Generate images based on text prompts 17 | 3. Track the progress of image generation 18 | 4. Retrieve the resulting images 19 | 20 | ## MCP Endpoint 21 | 22 | The MCP endpoint is accessible at: 23 | 24 | ``` 25 | http://localhost:5050/api/mcp 26 | ``` 27 | 28 | For production deployments, replace `localhost:5050` with your actual server address. 29 | 30 | ## Supported Methods 31 | 32 | CyberImage supports the following MCP methods: 33 | 34 | ### 1. context.image_generation.models 35 | 36 | Lists all available image generation models. 37 | 38 | **Request:** 39 | ```json 40 | { 41 | "jsonrpc": "2.0", 42 | "method": "context.image_generation.models", 43 | "id": "request-123" 44 | } 45 | ``` 46 | 47 | **Response:** 48 | ```json 49 | { 50 | "jsonrpc": "2.0", 51 | "result": { 52 | "models": { 53 | "flux-2": { 54 | "id": "black-forest-labs/FLUX.1-dev", 55 | "description": "FP8 quantized version of FLUX-1 for memory efficiency" 56 | }, 57 | "flux-1": { 58 | "id": "black-forest-labs/FLUX.1-dev", 59 | "description": "High-quality image generation model optimized for detailed outputs" 60 | }, 61 | "sd-3.5": { 62 | "id": "stabilityai/stable-diffusion-3.5-large", 63 | "description": "Latest Stable Diffusion model with improved quality and speed" 64 | }, 65 | "flux-abliterated": { 66 | "id": "aoxo/flux.1dev-abliteratedv2", 67 | "description": "Modified FLUX model with enhanced capabilities" 68 | } 69 | }, 70 | "default": "flux-1" 71 | }, 72 | "id": "request-123" 73 | } 74 | ``` 75 | 76 | The `default` field in the response indicates the system's current default model, which is dynamically determined based on available models (prioritizing "flux-1" if available, otherwise using the first available model). 77 | 78 | ### 2. context.image_generation.generate 79 | 80 | Submits an image generation request. 81 | 82 | **Request:** 83 | ```json 84 | { 85 | "jsonrpc": "2.0", 86 | "method": "context.image_generation.generate", 87 | "params": { 88 | "prompt": "A detailed description of the image you want to generate", 89 | "negative_prompt": "What to exclude from the image (optional)", 90 | "model": "flux-1", 91 | "settings": { 92 | "num_images": 1, 93 | "num_inference_steps": 30, 94 | "guidance_scale": 7.5, 95 | "height": 1024, 96 | "width": 1024 97 | } 98 | }, 99 | "id": "request-123" 100 | } 101 | ``` 102 | 103 | If the `model` parameter is omitted, the system will use the default model as returned by the `context.image_generation.models` method. 104 | 105 | **Response:** 106 | ```json 107 | { 108 | "jsonrpc": "2.0", 109 | "result": { 110 | "job_id": "0c907066-985e-467b-8d11-07419060ef91", 111 | "status": "pending", 112 | "num_images": 1 113 | }, 114 | "id": "request-123" 115 | } 116 | ``` 117 | 118 | ### 3. context.image_generation.status 119 | 120 | Checks the status of a generation job. 
121 | 122 | **Request:** 123 | ```json 124 | { 125 | "jsonrpc": "2.0", 126 | "method": "context.image_generation.status", 127 | "params": { 128 | "job_id": "0c907066-985e-467b-8d11-07419060ef91" 129 | }, 130 | "id": "request-123" 131 | } 132 | ``` 133 | 134 | **Response (when pending/processing):** 135 | ```json 136 | { 137 | "jsonrpc": "2.0", 138 | "result": { 139 | "job_id": "0c907066-985e-467b-8d11-07419060ef91", 140 | "status": "processing", 141 | "model": "flux-2", 142 | "prompt": "A detailed description...", 143 | "created_at": "2024-03-20T10:30:00Z", 144 | "started_at": "2024-03-20T10:30:01Z", 145 | "completed_at": null, 146 | "progress": { 147 | "preparing": false, 148 | "loading_model": false, 149 | "generating": true, 150 | "saving": false, 151 | "completed": false, 152 | "failed": false, 153 | "step": 15, 154 | "total_steps": 30 155 | } 156 | }, 157 | "id": "request-123" 158 | } 159 | ``` 160 | 161 | **Response (when completed):** 162 | ```json 163 | { 164 | "jsonrpc": "2.0", 165 | "result": { 166 | "job_id": "0c907066-985e-467b-8d11-07419060ef91", 167 | "status": "completed", 168 | "model": "flux-2", 169 | "prompt": "A detailed description...", 170 | "created_at": "2024-03-20T10:30:00Z", 171 | "started_at": "2024-03-20T10:30:01Z", 172 | "completed_at": "2024-03-20T10:31:00Z", 173 | "progress": { 174 | "preparing": false, 175 | "loading_model": false, 176 | "generating": false, 177 | "saving": false, 178 | "completed": true, 179 | "failed": false, 180 | "step": 30, 181 | "total_steps": 30 182 | }, 183 | "images": [ 184 | { 185 | "id": "img_123456", 186 | "url": "/api/get_image/img_123456", 187 | "metadata": { 188 | "model_id": "flux-2", 189 | "prompt": "A detailed description...", 190 | "settings": { 191 | "num_inference_steps": 30, 192 | "guidance_scale": 7.5, 193 | "height": 1024, 194 | "width": 1024 195 | } 196 | } 197 | } 198 | ] 199 | }, 200 | "id": "request-123" 201 | } 202 | ``` 203 | 204 | ## Typical Integration Flow 205 | 206 | 1. **Model Discovery**: Call `context.image_generation.models` to discover available models. 207 | 2. **Image Generation**: Call `context.image_generation.generate` with a prompt and optionally a model selection. 208 | 3. **Status Polling**: Call `context.image_generation.status` periodically to check job progress. 209 | 4. **Image Retrieval**: Once the job is completed, use the URL from the response to retrieve the image. 
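For step 4, the `url` values in the completed status response are relative paths (e.g. `/api/get_image/img_123456`). A minimal retrieval sketch in Python, assuming the server is reachable at `localhost:5050` and that the endpoint returns raw image bytes:

```python
import requests

BASE_URL = "http://localhost:5050"

def download_image(image_url: str, output_path: str) -> None:
    """Fetch a completed image from the URL returned by the status method."""
    # The status response returns relative URLs; prefix them with the server address.
    full_url = f"{BASE_URL}{image_url}" if image_url.startswith("/") else image_url
    response = requests.get(full_url, timeout=60)
    response.raise_for_status()
    with open(output_path, "wb") as f:
        f.write(response.content)

# Example: download_image("/api/get_image/img_123456", "result.png")
```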
210 | 211 | ## Integration Examples 212 | 213 | ### Python Example (Basic) 214 | 215 | ```python 216 | import requests 217 | import json 218 | import time 219 | 220 | # Configuration 221 | MCP_ENDPOINT = "http://localhost:5050/api/mcp" 222 | 223 | def mcp_request(method, params=None): 224 | """Make a request to the MCP endpoint""" 225 | payload = { 226 | "jsonrpc": "2.0", 227 | "method": method, 228 | "params": params or {}, 229 | "id": "request-" + str(int(time.time())) 230 | } 231 | 232 | response = requests.post(MCP_ENDPOINT, json=payload) 233 | if response.status_code == 200: 234 | return response.json() 235 | else: 236 | print(f"Error: {response.status_code}") 237 | return None 238 | 239 | def get_default_model(): 240 | """Get the default model from the system""" 241 | response = mcp_request("context.image_generation.models") 242 | if response and "result" in response: 243 | return response["result"]["default"] 244 | return None 245 | 246 | def generate_image(prompt, model=None, negative_prompt=None): 247 | """Generate an image using the MCP endpoint""" 248 | # Prepare parameters 249 | params = {"prompt": prompt} 250 | if model: 251 | params["model"] = model 252 | if negative_prompt: 253 | params["negative_prompt"] = negative_prompt 254 | 255 | # Submit job 256 | response = mcp_request("context.image_generation.generate", params) 257 | if not response or "error" in response: 258 | print("Error generating image") 259 | return None 260 | 261 | job_id = response["result"]["job_id"] 262 | print(f"Job submitted with ID: {job_id}") 263 | 264 | # Poll for completion 265 | while True: 266 | time.sleep(3) # Wait between polls 267 | status_response = mcp_request("context.image_generation.status", {"job_id": job_id}) 268 | 269 | if not status_response or "error" in status_response: 270 | print("Error checking status") 271 | break 272 | 273 | status = status_response["result"]["status"] 274 | print(f"Status: {status}") 275 | 276 | if status == "completed": 277 | # Get the image URL 278 | image_url = status_response["result"]["images"][0]["url"] 279 | full_url = f"http://localhost:5050{image_url}" if image_url.startswith("/") else image_url 280 | print(f"Image available at: {full_url}") 281 | return full_url 282 | 283 | elif status == "failed": 284 | print("Generation failed") 285 | break 286 | 287 | return None 288 | 289 | # Example usage 290 | if __name__ == "__main__": 291 | # First example: Using the system's default model 292 | default_model = get_default_model() 293 | print(f"Using system default model: {default_model}") 294 | 295 | image_url = generate_image("A beautiful sunset over mountains") 296 | if image_url: 297 | print(f"Success! Image available at: {image_url}") 298 | 299 | # Second example: Explicitly specifying a model 300 | image_url = generate_image("A futuristic cityscape", model="sd-3.5") 301 | if image_url: 302 | print(f"Success! Image available at: {image_url}") 303 | ``` 304 | 305 | ### Complete AI Assistant Integration 306 | 307 | For a complete AI assistant integration example, see the `examples/ai_assistant_mcp_example.py` file in the repository. 308 | 309 | ## Best Practices 310 | 311 | 1. **Handle Rate Limiting**: Implement exponential backoff when making repeated requests. 312 | 2. **Provide Detailed Prompts**: The quality of the generated image depends on the prompt detail. 313 | 3. **Use Model Discovery**: Call the `context.image_generation.models` method to discover available models and the current default model. 314 | 4. 
**Error Handling**: Be prepared to handle error responses from the API. 315 | 5. **Timeouts**: Set appropriate timeouts for both the initial request and status polling. 316 | 6. **User Consent**: When using in AI assistants, ensure users are aware of and consent to image generation. 317 | 318 | ## Troubleshooting 319 | 320 | ### Common Issues 321 | 322 | 1. **Connection Refused**: Ensure the CyberImage server is running and accessible. 323 | 2. **Invalid Method**: Check that you're using the correct method names as listed above. 324 | 3. **Job Not Found**: The job ID might be invalid or the job may have been cleaned up. 325 | 4. **Rate Limiting**: You may be making too many requests in a short period. 326 | 327 | ### Debugging 328 | 329 | For debugging, you can set the `DEBUG` environment variable to see more detailed logs: 330 | 331 | ```bash 332 | DEBUG=1 python your_integration_script.py 333 | ``` 334 | 335 | ## Getting Help 336 | 337 | If you encounter issues with MCP integration, please: 338 | 339 | 1. Check the examples in the `examples/` directory 340 | 2. Review the API documentation in `API.md` 341 | 3. Open an issue on the GitHub repository -------------------------------------------------------------------------------- /MEMORY_MANAGEMENT.md: -------------------------------------------------------------------------------- 1 | # CyberImage Memory Management 2 | 3 | This document describes the job monitoring and recovery systems in the CyberImage application. 4 | 5 | ## Overview 6 | 7 | To address stalled jobs and ensure reliable image generation, we've implemented several robust mechanisms: 8 | 9 | 1. **Job Watchdog**: Continuously monitors job execution, automatically recovering from stalled jobs. 10 | 2. **Enhanced Error Handling**: Improved cleanup procedures during normal and error conditions. 11 | 3. **Process Management**: Better handling of process termination to ensure all resources are released. 12 | 4. **Job Recovery**: Mechanisms to reset stalled jobs and retry failed ones. 13 | 14 | ## Important Note on GPU Memory Usage 15 | 16 | **100% GPU memory usage is normal and expected** when running Flux models. The models are designed to utilize all available GPU memory for optimal performance. Our monitoring system focuses on detecting stalled jobs rather than memory usage thresholds. 17 | 18 | ## Job Watchdog (`app/utils/watchdog.py`) 19 | 20 | The Job Watchdog provides automatic monitoring and recovery for stalled jobs: 21 | 22 | - **Job Monitoring**: Tracks job execution time and detects stalled jobs. 23 | - **Stalled Job Detection**: Identifies jobs that have been processing for too long (default: 15 minutes). 24 | - **Emergency Recovery**: Resets stalled jobs and cleans up resources when necessary. 25 | 26 | Configuration parameters: 27 | - `max_stalled_time`: Maximum time a job can be processing (default: 900 seconds / 15 minutes) 28 | - `check_interval`: Time between health checks (default: 30 seconds) 29 | - `recovery_cooldown`: Minimum time between emergency recoveries (default: 300 seconds) 30 | 31 | ## Enhanced Error Handling (`app/models/manager.py`) 32 | 33 | The ModelManager now includes: 34 | 35 | - **Improved Error Handling**: Better cleanup in both normal and error paths. 36 | - **Forced Cleanup**: Resource reclamation when needed. 37 | 38 | Key methods: 39 | - `_force_memory_cleanup()`: Frees GPU resources when necessary. 
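The exact cleanup logic lives in `app/models/manager.py` and is not reproduced here; the sketch below only illustrates the general pattern of a forced GPU cleanup under PyTorch (the `model_cache` argument is a hypothetical stand-in for whatever pipeline cache the real manager keeps):

```python
import gc
import torch

def force_memory_cleanup(model_cache: dict) -> None:
    """Illustrative forced cleanup: drop cached pipelines and release GPU memory."""
    # Drop references to loaded pipelines so they become garbage-collectable.
    model_cache.clear()
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()   # return cached allocations to the CUDA driver
        torch.cuda.ipc_collect()   # clean up inter-process CUDA handles
```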
40 | 41 | ## Process Management (`run.py`) 42 | 43 | The application startup has been enhanced to: 44 | 45 | - **Set Memory Environment Variables**: Optimizes PyTorch memory allocation. 46 | - **Register Cleanup Handlers**: Ensures resources are released on shutdown. 47 | - **Handle Signals**: Properly manages termination signals. 48 | 49 | ## Integration with GenerationPipeline 50 | 51 | The GenerationPipeline now: 52 | 53 | - **Initializes the Watchdog**: Starts monitoring when the application launches. 54 | - **Resets Stalled Jobs**: Recovers jobs that were processing during previous crashes. 55 | - **Manages Clean Shutdown**: Ensures the watchdog is properly stopped on application exit. 56 | 57 | ## Testing 58 | 59 | A test script is provided to verify the new features: 60 | 61 | ```bash 62 | python test_watchdog.py [--test {stall,load,all}] [--jobs NUM_JOBS] 63 | ``` 64 | 65 | This script can: 66 | - Simulate stalled jobs (note: this test takes at least 15 minutes to complete) 67 | - Run load tests with multiple jobs 68 | 69 | ## Best Practices 70 | 71 | To ensure stable operation: 72 | 73 | 1. Run the application with a single worker and single thread in production: 74 | ``` 75 | gunicorn -w 1 --threads=1 -b 0.0.0.0:5050 --timeout=120 run:app 76 | ``` 77 | 78 | This ensures that only one request is processed at a time, preventing memory issues from concurrent processing. 79 | 80 | 2. Monitor job status regularly: 81 | ``` 82 | curl http://localhost:5050/api/queue 83 | ``` 84 | 85 | 3. If issues persist, you can manually reset the application: 86 | ``` 87 | sudo systemctl restart cyberimage 88 | ``` 89 | 90 | ## Debugging Job Issues 91 | 92 | If you encounter stalled jobs: 93 | 94 | 1. Check the application logs for watchdog activity. 95 | 2. Verify that the watchdog is detecting and recovering stalled jobs. 96 | 3. Consider adjusting the `max_stalled_time` parameter if your jobs legitimately need more time to complete. 97 | 98 | ## Future Improvements 99 | 100 | Potential future enhancements: 101 | 102 | 1. Job retry with exponential backoff 103 | 2. Process-specific resource isolation 104 | 3. Job execution analytics and reporting 105 | 4. Dynamic adjustment of model settings based on job complexity -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CyberImage 2 | 3 | **AI Image Generation Platform** 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 |
15 | 16 |
17 | 18 | 🎨 Stable Diffusion | 🌍 Web Interface | ⚡ Fast | 🎯 Precise | 🖼️ Gallery | 🔄 Queue System | 📡 API | 🤖 MCP Support 19 | 20 |
21 | 22 | ## 🌟 Features 23 | 24 |
25 | 26 | | 🎨 **Image Generation** | 🖥️ **Web Interface** | 🖼️ **Gallery Features** | ⚡ **Performance** | 27 | |------------------------|---------------------|------------------------|-------------------| 28 | | 🤖 State-of-the-art AI models | 🌃 Cyberpunk-themed UI | 📊 Multiple view options | 💾 Smart model caching | 29 | | 🎛️ Customizable parameters | ⏱️ Real-time status updates | 🔍 Detailed image modal | 🧠 Memory optimization | 30 | | 🚫 Negative prompt support | 🖱️ Interactive gallery | ⬇️ Quick download actions | 🔄 Efficient queue system | 31 | | 📦 Batch image generation | 📱 Mobile-responsive design | 📋 Copy prompt feature | 🏃‍♂️ Background processing | 32 | | 📈 Real-time progress | 🌈 Beautiful UI | 🔍 Search and filtering | 🔒 No data collection | 33 | | 🎯 Precise control | 🎮 Intuitive controls | 🏷️ Tagging system | 🏠 Local image storage | 34 | | 🧩 Model compatibility | 🌙 Dark mode support | | | 35 | | 🤖 **MCP Integration** | 🔌 **AI Accessibility** | | | 36 | | 🔗 AI assistant support | 🔄 JSON-RPC interface | | | 37 | 38 |
39 | 40 | ## 🤖 Model Context Protocol (MCP) Support 41 | 42 | CyberImage now implements the [Model Context Protocol (MCP)](https://spec.modelcontextprotocol.io/specification/2024-11-05/), enabling AI assistants and other tools to seamlessly generate images through a standardized interface. 43 | 44 | ### What is MCP? 45 | 46 | MCP is an open protocol that enables AI systems to interact with external tools and capabilities in a standardized way. With MCP support, AI assistants can generate images directly through CyberImage using JSON-RPC calls. 47 | 48 | ### Key MCP Features 49 | 50 | - **JSON-RPC 2.0 Interface**: Simple, standardized format for all requests 51 | - **Dynamic Default Model**: Uses the system's default model (same as the web UI), prioritizing "flux-1" if available 52 | - **Seamless Queue Integration**: Jobs from AI assistants are integrated into the same queue as web UI requests 53 | - **Progress Tracking**: AI systems can track generation progress in real-time 54 | - **Standard Format**: Follows the MCP specification for interoperability with any MCP-compatible AI system 55 | 56 | ### Supported MCP Methods 57 | 58 | | Method | Description | 59 | |--------|-------------| 60 | | `context.image_generation.models` | List all available models | 61 | | `context.image_generation.generate` | Generate images based on a prompt | 62 | | `context.image_generation.status` | Check the status of a generation job | 63 | 64 | ### Using the MCP Endpoint 65 | 66 | AI assistants can connect to the MCP endpoint at: 67 | 68 | ``` 69 | http://localhost:5050/api/mcp 70 | ``` 71 | 72 | For implementation examples, see the `examples/` directory: 73 | - `mcp_client_example.py`: General MCP client implementation 74 | - `ai_assistant_mcp_example.py`: Specialized client for AI assistants 75 | 76 | ## The Enhance/Enrich Button 77 | 78 | The Enrich button sends your current prompt to the configured OpenAI-compatible endpoint (using the `OPENAI_ENDPOINT`, `OPENAI_API_KEY`, and `OPENAI_MODEL` settings) and returns an expanded, more detailed version of the prompt to use for generation. 79 | 80 | This is my favorite feature: it lets you take a basic image prompt and enrich it into something far more detailed, using a number of prompt-writing techniques shared with me by an AI expert. A minimal sketch of such an enrichment call is shown below. 81 | 82 | 83 | 84 | 85 | 86 | 87 |
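The server-side enrichment code is not shown in this document, but the call it makes is a standard chat completion against the configured endpoint. A minimal sketch, assuming an OpenAI-compatible server (such as Ollama at `http://127.0.0.1:11434/v1`) and the `OPENAI_*` variables from `.env`; the system prompt here is illustrative only:

```python
import os
import requests

def enrich_prompt(prompt: str) -> str:
    """Expand a short image prompt via the configured OpenAI-compatible endpoint."""
    endpoint = os.getenv("OPENAI_ENDPOINT", "http://127.0.0.1:11434/v1").rstrip("/")
    response = requests.post(
        f"{endpoint}/chat/completions",
        headers={"Authorization": f"Bearer {os.getenv('OPENAI_API_KEY', '')}"},
        json={
            "model": os.getenv("OPENAI_MODEL", ""),
            "messages": [
                {"role": "system",
                 "content": "Rewrite the user's image prompt into a richer, more detailed version."},
                {"role": "user", "content": prompt},
            ],
        },
        timeout=60,
    )
    response.raise_for_status()
    return response.json()["choices"][0]["message"]["content"]

# Example: enriched = enrich_prompt("a neon city at night")
```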
88 | 89 | 90 | ## ⚡ Installation 91 | 92 | ### Requirements 93 | 94 | | Requirement | Specification | 95 | |-------------|---------------| 96 | | Python | 3.12 (for local installation) | 97 | | GPU | CUDA-capable | 98 | | RAM | 16GB+ recommended | 99 | | VRAM | 24GB+ recommended | 100 | | Disk Space | 250GB+ for models | 101 | | Container | Docker & Docker Compose (for containerized installation) | 102 | | API | Huggingface API Key (free! for downloading models) | 103 | 104 | ### Environment Setup 105 | 1. Copy the example environment file: 106 | ```bash 107 | cp .env.example .env 108 | ``` 109 | 110 | 2. Configure the environment variables in `.env`: 111 | 112 | - For a local install, you can use the MODEL_FOLDER and IMAGES_FOLDER to store the models and images locally in different directories. 113 | 114 | - For a docker install, you can use the EXTERNAL_MODEL_FOLDER and EXTERNAL_IMAGES_FOLDER to store the models and images externally in different directories. 115 | 116 | - For a docker install, you will need a Huggingface API Key to download the models. 117 | 118 | - The openai endpoints works fine with Ollama with 127.0.0.1:11434/v1 as the endpoint and the openai api key as the key, or host.docker.internal:11434/v1 as the endpoint and the openai api key as the key if ollama is running on the host machine. The model needs to be something in 127.0.0.1:11434/v1/models on your system. 119 | 120 | > ***If you don't configure the openai endpoint, the enrich prompt will not work.*** 121 | 122 | - The civitai api key is optional, it is used to download models from civitai (not configured currently). 123 | 124 | ```env 125 | MODEL_FOLDER=./models 126 | IMAGES_FOLDER=./images 127 | EXTERNAL_MODEL_FOLDER= 128 | EXTERNAL_IMAGES_FOLDER= 129 | HF_TOKEN= 130 | OPENAI_ENDPOINT= 131 | OPENAI_API_KEY= 132 | OPENAI_MODEL= 133 | CIVITAI_API_KEY= 134 | # Add provider-specific API keys here if needed, e.g.: 135 | # FAL_AI_API_KEY= 136 | ``` 137 | 138 | ### Hugging Face and Provider API Keys Workflow 139 | 140 | CyberImage utilizes API keys for various functionalities, primarily interacting with Hugging Face and specific model providers through their APIs. 141 | 142 | - **`HF_TOKEN`**: This is your general Hugging Face API token. It is primarily used for: 143 | - Downloading models from the Hugging Face Hub (both for local use and as a fallback for some API models). 144 | - Authenticating with the Hugging Face Inference API for models that don't have a specific third-party provider or when a provider-specific key isn't set. 145 | Ensure this token has the necessary permissions (usually 'read' access is sufficient for downloads and basic inference). 146 | 147 | - **Provider-Specific API Keys (e.g., `FAL_AI_API_KEY`)**: Some models available through the Hugging Face API are hosted or served by third-party providers (e.g., `fal-ai`, `replicate`). These providers often require their own API keys for access. 148 | - When you configure a model in `.env` that specifies a `provider` in its `options_json` (e.g., `{"type": "t2v", "provider": "fal-ai"}`), the application will look for a corresponding environment variable for that provider's API key. 149 | - The naming convention for these keys is typically `PROVIDER_NAME_API_KEY` (e.g., `FAL_AI_API_KEY`, `REPLICATE_API_KEY`). 150 | - If a provider-specific key is found and required by the model's configuration, it will be used for authentication with that provider's service. 
Otherwise, `HF_TOKEN` might be used as a fallback if the provider accepts it, or the API call might fail if a dedicated key is mandatory. 151 | 152 | **Configuration:** 153 | 154 | 1. Add your `HF_TOKEN` to the `.env` file. 155 | 2. If you plan to use models that rely on a specific provider (like `fal-ai` for LTX-Video), obtain an API key from that provider and add it to your `.env` file using the appropriate variable name (e.g., `FAL_AI_API_KEY=your_fal_ai_key_here`). 156 | 157 | The application's `ModelManager` is responsible for selecting the correct API token based on the model's configuration (`source` and `provider` fields). 158 | 159 | ### Docker Installation (Recommended) 160 | 161 | 1. Clone the repository: 162 | ```bash 163 | git clone https://github.com/ramborogers/cyberimage.git 164 | cd cyberimage 165 | ``` 166 | 2. Use the run.sh script to start the application (easiest): 167 | ```bash 168 | # This will start the application in a container 169 | ./run.sh start 170 | ``` 171 | 172 | Alternatively, use Docker or Docker Compose directly: 173 | 174 | 3. Using docker-compose: 175 | ```bash 176 | # Start the application 177 | docker-compose up -d 178 | 179 | # View logs 180 | docker-compose logs -f 181 | 182 | # Stop the application 183 | docker-compose down 184 | ``` 185 | 186 | 4. Using Docker CLI: 187 | ```bash 188 | # Build the image 189 | docker build -t cyberimage . 190 | 191 | # Run the container 192 | docker run -d \ 193 | --name cyberimage \ 194 | --gpus all \ 195 | -p 7860:5050 \ 196 | -v $(pwd)/models:/app/models \ 197 | -v $(pwd)/images:/app/images \ 198 | --env-file .env \ 199 | cyberimage 200 | ``` 201 | 202 | 5. Open in browser: 203 | ``` 204 | http://localhost:7860 205 | ``` 206 | 207 | ![CyberImage](media/generate.png) 208 | 209 | ![CyberImage](media/gallery.png) 210 | 211 | ![CyberImage](media/queue.png) 212 | 213 | ![CyberImage](media/single.png) 214 | 215 | 216 | ### Local Installation 217 | 218 | 1. Clone the repository: 219 | ```bash 220 | git clone https://github.com/ramborogers/cyberimage.git 221 | cd cyberimage 222 | ``` 223 | 224 | 2. Create a virtual environment: 225 | ```bash 226 | python -m venv venv 227 | source venv/bin/activate # Linux/Mac 228 | # or 229 | .\venv\Scripts\activate # Windows 230 | ``` 231 | 232 | 3. Install dependencies: 233 | ```bash 234 | pip install -r requirements.txt 235 | ``` 236 | 237 | 4. Download models: 238 | ```bash 239 | python download_models.py 240 | ``` 241 | 242 | ### Generation Parameters 243 | - **Model**: Choose from multiple AI models 244 | - **Prompt**: Describe your desired image 245 | - **Negative Prompt**: Specify elements to exclude 246 | - **Size**: Select output dimensions 247 | - **Steps**: Control generation quality 248 | - **Guidance**: Adjust prompt adherence 249 | - **Batch Size**: Generate multiple images 250 | 251 | ## 💡 Use Cases 252 | - **AI Assistant Integration**: Allow AI assistants to generate images based on user conversations 253 | - **Family Images**: My children love to use this 254 | - **Digital Art**: Create unique artwork 255 | - **Concept Design**: Generate design ideas 256 | - **Visual Inspiration**: Explore creative concepts 257 | - **Content Creation**: Generate visual content 258 | 259 | 260 | ## 🔧 Configuration 261 | 262 | Copy the .env.example file to .env and edit the .env file to configure the models you want to use. 263 | 264 | ## 🖼️ Managing Models 265 | 266 | CyberImage uses environment variables in the `.env` file to configure models. 
You can easily add, remove, or modify models by editing this file. 267 | 268 | ### Model Configuration Format 269 | 270 | Models are defined using the following format: 271 | 272 | ``` 273 | MODEL_<index>=<name>;<repo>;<description>;<source>;<requires_auth>[;<options_json>] 274 | ``` 275 | 276 | Where: 277 | - `<index>`: Numerical index (1, 2, 3, etc.). Must be unique. 278 | - `<name>`: Unique identifier for the model (used as directory name). 279 | - `<repo>`: HuggingFace repository path, GGUF file URL, or other model identifier. 280 | - `<description>`: Human-readable description shown in the UI. 281 | - `<source>`: Source platform (`huggingface`, `gguf_url`, `civitai`, etc.). Used by `download_models.py`. 282 | - `<requires_auth>`: Whether authentication (e.g., a HuggingFace token) is required (true/false). 283 | - `<options_json>` (Optional): A JSON string containing additional model-specific parameters. 284 | - `{"fixed_steps": N}`: Forces the model to use a specific number of steps (e.g., for `sana-sprint`). 285 | - `{"type": "t2v"}`: Explicitly marks the model as Text-to-Video. 286 | - `{"type": "i2v"}`: Explicitly marks the model as Image-to-Video. 287 | - *Note:* The application also attempts auto-detection based on model names (e.g., "t2v", "i2v"). Explicit types take precedence. 288 | 289 | ### Adding/Disabling Models 290 | 291 | - **To add:** Add a new `MODEL_<index>` line with a unique index and set `DOWNLOAD_MODEL_<index>=true`. 292 | - **To disable download:** Set `DOWNLOAD_MODEL_<index>=false`. The model remains in the UI if already downloaded. 293 | - **To remove:** Delete or comment out both the `MODEL_<index>` and `DOWNLOAD_MODEL_<index>` lines. 294 | 295 | ### Example Configuration 296 | 297 | Here's an example `.env` configuration showing various model types: 298 | 299 | ```env 300 | # --- Image Models --- 301 | MODEL_1="flux-1;black-forest-labs/FLUX.1-dev;FLUX Dev;huggingface;true" 302 | DOWNLOAD_MODEL_1=true 303 | 304 | MODEL_2="sd-3.5;stabilityai/stable-diffusion-3.5-large;Stable Diffusion 3.5;huggingface;true" 305 | DOWNLOAD_MODEL_2=true 306 | 307 | MODEL_3="flux-schnell;black-forest-labs/FLUX.1-schnell;FLUX Schnell;huggingface;true" 308 | DOWNLOAD_MODEL_3=true 309 | 310 | MODEL_4="sana-sprint;Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers;Sana Sprint 1.6B (Fixed Steps);huggingface;false;{\\"fixed_steps\\": 2}" 311 | DOWNLOAD_MODEL_4=true 312 | 313 | # --- Video Models --- 314 | # Text-to-Video (Wan) 315 | MODEL_5="wan-t2v-1.3b;Wan-AI/Wan2.1-T2V-1.3B-Diffusers;Wan Text2Video 1.3B;huggingface;false;{\\"type\\": \\"t2v\\"}" 316 | DOWNLOAD_MODEL_5=true 317 | 318 | # Image-to-Video (Wan) 319 | MODEL_6="wan-i2v-14b;Wan-AI/Wan2.1-I2V-14B-480P;Wan Image-to-Video (14B, 480p);huggingface;false;{\\"type\\": \\"i2v\\"}" 320 | DOWNLOAD_MODEL_6=true 321 | 322 | # Text-to-Video (LTX GGUF - requires download_models.py support for gguf_url) 323 | MODEL_7='LTX-Video-t2v;https://huggingface.co/city96/LTX-Video-gguf/resolve/main/ltx-video-2b-v0.9-Q3_K_S.gguf;LTX Video GGUF T2V;gguf_url;false;{"type": "t2v"}' 324 | DOWNLOAD_MODEL_7=true 325 | 326 | # Image-to-Video (LTX GGUF - uses same file as above) 327 | MODEL_8='LTX-Video-i2v;https://huggingface.co/city96/LTX-Video-gguf/resolve/main/ltx-video-2b-v0.9-Q3_K_S.gguf;LTX Video GGUF I2V;gguf_url;false;{"type": "i2v"}' 328 | DOWNLOAD_MODEL_8=false # Set to false if MODEL_7 already downloads the file 329 | 330 | # --- Other Examples --- 331 | # Disabled Model 332 | MODEL_9="animagine-xl;cagliostrolab/animagine-xl-4.0;Animagine XL;huggingface;true" 333 | DOWNLOAD_MODEL_9=false 334 | 335 | # Commented out/Removed Model (example for Civitai - requires API key and download_models.py support) 336 | #
MODEL_10="my-custom-model;civitai:12345;My Custom Model;civitai;true" 337 | # DOWNLOAD_MODEL_10=true 338 | ``` 339 | 340 | *(Remember to restart the application after changing `.env`)* 341 | 342 | ## 🎬 Video Generation Support 343 | 344 | CyberImage extends beyond static images, offering powerful **Text-to-Video (T2V)** and **Image-to-Video (I2V)** generation capabilities! 345 | 346 | ### Features 347 | 348 | - **Text-to-Video (T2V):** Generate videos directly from text prompts using the main generation interface. 349 | - **Image-to-Video (I2V):** Bring your existing generated images to life by creating videos from them, initiated directly from the gallery. 350 | - **Supported Models:** Leverages advanced video models like `WanPipeline`, `WanImageToVideoPipeline`, and experimentally supports `LTXPipeline` (including GGUF variants). 351 | - **Seamless Integration:** Video jobs are handled by the same robust queue system as image jobs. 352 | - **Configuration:** Add and manage video models via the `.env` file as described in the **Managing Models** section above. The application identifies video models using the optional `type` parameter in the model configuration or by keywords in the model name (e.g., "t2v", "i2v"). 353 | 354 | ### Using Video Generation 355 | 356 | 1. **Text-to-Video:** 357 | * Select a T2V model (identified by name or `[Video]` tag) in the main model dropdown. 358 | * Enter your prompt. 359 | * Click "🎬 Generate Video". 360 | 2. **Image-to-Video:** 361 | * Go to the Gallery or view a single image. 362 | * Click the 🎥 (Generate Video) icon on the desired image. 363 | * A modal will appear; enter a video prompt and select an available I2V model. 364 | * Submit the job. 365 | 366 | Generated videos will appear in the gallery alongside images, with appropriate video player controls. 367 | 368 | ## 🤝 Contributing 369 | 370 | 1. Fork the repository 371 | 2. Create a feature branch 372 | 3. Commit your changes 373 | 4. Push to the branch 374 | 5. Create a Pull Request 375 | 376 |
377 | 378 | ## ⚖️ License 379 | 380 |

381 | CyberImage is licensed under the GNU General Public License v3.0 (GPLv3).
382 | Free Software 383 |

384 | 385 | [![License: GPL v3](https://img.shields.io/badge/License-GPLv3-blue.svg?style=for-the-badge)](https://www.gnu.org/licenses/gpl-3.0) 386 | 387 | ### Connect With Me 🤝 388 | 389 | [![GitHub](https://img.shields.io/badge/GitHub-RamboRogers-181717?style=for-the-badge&logo=github)](https://github.com/RamboRogers) 390 | [![Twitter](https://img.shields.io/badge/Twitter-@rogerscissp-1DA1F2?style=for-the-badge&logo=twitter)](https://x.com/rogerscissp) 391 | [![Website](https://img.shields.io/badge/Web-matthewrogers.org-00ADD8?style=for-the-badge&logo=google-chrome)](https://matthewrogers.org) 392 | 393 | ![RamboRogers](https://github.com/RamboRogers/netventory/raw/master/media/ramborogers.png) 394 | 395 |
396 | 397 | --- 398 | 399 | Made with 💚 by [Matthew Rogers] -------------------------------------------------------------------------------- /app/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | CyberImage - Epic AI Image Generation Service 3 | """ 4 | import os 5 | from flask import Flask, g, current_app, request 6 | from flask_cors import CORS 7 | from pathlib import Path 8 | from app.utils.logging_config import setup_logging 9 | from app.utils.config import get_available_models 10 | from app.models.generator import GenerationPipeline 11 | from contextlib import suppress 12 | import sys 13 | import signal 14 | 15 | def create_app(test_config=None): 16 | """Create and configure the Flask application""" 17 | app = Flask(__name__) 18 | CORS(app) 19 | 20 | # Get directory configurations from environment or use defaults 21 | images_dir = os.getenv("IMAGES_FOLDER", "./images") 22 | models_dir = os.getenv("MODEL_FOLDER", "./models") 23 | 24 | # Ensure required directories exist 25 | images_path = Path(images_dir) 26 | images_path.mkdir(exist_ok=True) 27 | (images_path / "db").mkdir(exist_ok=True) # Create db directory under images 28 | 29 | models_path = Path(models_dir) 30 | models_path.mkdir(exist_ok=True) 31 | 32 | # Default configuration 33 | app.config.from_mapping( 34 | DATABASE=os.path.join(os.path.abspath(images_dir), "db", "cyberimage.sqlite"), 35 | IMAGES_PATH=os.path.abspath(images_dir), 36 | MODELS_PATH=os.path.abspath(models_dir), 37 | MAX_QUEUE_SIZE=10, 38 | MODEL_CACHE_SIZE=1, # Only keep one model in memory at a time (memory constrained) 39 | ENABLE_XFORMERS=True, # Enable memory efficient attention 40 | USE_HALF_PRECISION=True, # Use FP16 for models when possible 41 | CLEANUP_INTERVAL=3600, # Cleanup stalled jobs every hour 42 | MAX_PROMPT_LENGTH=800, 43 | DEFAULT_STEPS=30, 44 | DEFAULT_GUIDANCE_SCALE=7.5, 45 | DEFAULT_MODEL="flux-1", # Updated default model 46 | AVAILABLE_MODELS=get_available_models() 47 | ) 48 | 49 | # Load API tokens from environment variables 50 | app.config['HF_TOKEN'] = os.getenv('HF_TOKEN') 51 | app.config['FAL_AI_API_KEY'] = os.getenv('FAL_AI_API_KEY') 52 | app.config['REPLICATE_API_KEY'] = os.getenv('REPLICATE_API_KEY') 53 | 54 | # Force single process mode in production 55 | if not app.debug: 56 | os.environ["GUNICORN_CMD_ARGS"] = "--workers=1" 57 | 58 | if test_config is None: 59 | # Load the instance config, if it exists, when not testing 60 | app.config.from_pyfile("config.py", silent=True) 61 | else: 62 | # Load the test config if passed in 63 | app.config.update(test_config) 64 | 65 | # Setup logging 66 | setup_logging(app) 67 | 68 | # Initialize database 69 | from app.utils import db 70 | db.init_app(app) 71 | 72 | # Download models if needed 73 | with app.app_context(): 74 | from app.utils.download_models import download_all_models 75 | print("\n📥 Checking/Downloading required models...") 76 | sys.stdout.flush() 77 | download_all_models() 78 | 79 | def get_or_create_generator(): 80 | """Get existing generator or create a new one if needed""" 81 | if not hasattr(current_app, '_generator'): 82 | current_app._generator = GenerationPipeline() 83 | return current_app._generator 84 | 85 | @app.before_request 86 | def before_request(): 87 | """Ensure generator is available before API requests""" 88 | if request.endpoint and 'api.' 
in request.endpoint: 89 | g.generator = get_or_create_generator() 90 | 91 | # Replace the existing cleanup with a more targeted version 92 | def cleanup_handler(signum=None, frame=None): 93 | """Handle cleanup only on actual process termination""" 94 | try: 95 | if hasattr(current_app, '_generator'): 96 | generator = current_app._generator 97 | if generator and generator.is_running: 98 | print("\n🛑 Performing final cleanup on process termination...") 99 | sys.stdout.flush() 100 | generator.stop() 101 | print("✅ Cleanup completed") 102 | sys.stdout.flush() 103 | except Exception as e: 104 | print(f"\n⚠️ Final cleanup warning: {str(e)}") 105 | sys.stdout.flush() 106 | 107 | # Register the cleanup handler for process termination 108 | signal.signal(signal.SIGTERM, cleanup_handler) 109 | signal.signal(signal.SIGINT, cleanup_handler) 110 | 111 | # Register blueprints 112 | from app.api import bp as api_bp 113 | app.register_blueprint(api_bp, url_prefix="/api") 114 | 115 | # Register main blueprint for frontend routes 116 | from app.main import bp as main_bp 117 | app.register_blueprint(main_bp) 118 | 119 | # Register admin blueprint 120 | from app.admin import bp as admin_bp 121 | app.register_blueprint(admin_bp) 122 | 123 | # Register MCP blueprint 124 | from app.api.mcp import mcp_bp 125 | app.register_blueprint(mcp_bp, url_prefix="/api") 126 | 127 | # Health check endpoint 128 | @app.route("/health") 129 | def health_check(): 130 | generator = get_or_create_generator() 131 | try: 132 | from app.utils.queue import QueueManager 133 | queue_status = QueueManager.get_queue_status() 134 | return { 135 | "status": "healthy", 136 | "version": "1.0.0", 137 | "device": generator.model_manager._device if generator.model_manager else "not_initialized", 138 | "queue": { 139 | "in_memory_size": generator.generation_queue.qsize() if generator.generation_queue else 0, 140 | "pending": queue_status["pending"], 141 | "processing": queue_status["processing"], 142 | "completed": queue_status["completed"], 143 | "failed": queue_status["failed"], 144 | "total": queue_status["total"] 145 | }, 146 | "is_main_process": generator.is_main_process, 147 | "is_running": generator.is_running 148 | } 149 | except Exception as e: 150 | app.logger.error(f"Error in health check: {str(e)}") 151 | return { 152 | "status": "degraded", 153 | "version": "1.0.0", 154 | "error": str(e), 155 | "queue": { 156 | "in_memory_size": generator.generation_queue.qsize() if generator.generation_queue else 0 157 | }, 158 | "is_main_process": generator.is_main_process, 159 | "is_running": generator.is_running 160 | } 161 | 162 | app.logger.info("Application initialized successfully") 163 | return app -------------------------------------------------------------------------------- /app/admin/__init__.py: -------------------------------------------------------------------------------- 1 | from flask import Blueprint 2 | 3 | bp = Blueprint('admin', __name__) 4 | 5 | from app.admin import routes -------------------------------------------------------------------------------- /app/admin/routes.py: -------------------------------------------------------------------------------- 1 | from flask import render_template, jsonify, request, redirect, url_for, current_app 2 | from app.admin import bp 3 | from app.utils.queue import QueueManager 4 | from app.utils.image import ImageManager 5 | from app.models import AVAILABLE_MODELS 6 | import json 7 | from datetime import datetime 8 | from app.utils.db import get_db 9 | 10 | @bp.route('/admin') 11 | def 
admin_panel(): 12 | """Render the admin panel page""" 13 | # Get all jobs from the queue 14 | jobs = QueueManager.get_all_jobs() 15 | queue_status = QueueManager.get_queue_status() 16 | 17 | # Get generator status 18 | try: 19 | from app.models.generator import GenerationPipeline 20 | generator = GenerationPipeline() 21 | generator_status = { 22 | "is_running": generator.is_running, 23 | "is_main_process": generator.is_main_process 24 | } 25 | except Exception: 26 | generator_status = { 27 | "is_running": False, 28 | "is_main_process": False 29 | } 30 | 31 | return render_template('admin.html', 32 | jobs=jobs, 33 | queue_status=queue_status, 34 | generator_status=generator_status, 35 | available_models=AVAILABLE_MODELS) 36 | 37 | @bp.route('/admin/queue/clear', methods=['POST']) 38 | def clear_queue(): 39 | """Clear jobs from the queue with the specified status""" 40 | status = request.form.get('status', None) 41 | job_id = request.form.get('job_id', None) 42 | 43 | if job_id: 44 | # Delete a specific job 45 | QueueManager.delete_job(job_id) 46 | return jsonify({"success": True, "message": f"Job {job_id} deleted successfully"}) 47 | elif status: 48 | # Delete all jobs with the specified status 49 | count = QueueManager.clear_queue_by_status(status) 50 | return jsonify({"success": True, "message": f"Cleared {count} jobs with status '{status}'"}) 51 | else: 52 | # Delete all jobs (careful with this!) 53 | count = QueueManager.clear_all_jobs() 54 | return jsonify({"success": True, "message": f"Cleared {count} jobs from the queue"}) 55 | 56 | @bp.route('/admin/job/') 57 | def job_details(job_id): 58 | """Get details for a specific job""" 59 | job = QueueManager.get_job(job_id) 60 | images = ImageManager.get_images_for_job(job_id) if job else [] 61 | 62 | # For AJAX requests return JSON 63 | if request.headers.get('X-Requested-With') == 'XMLHttpRequest': 64 | return jsonify({ 65 | "job": job, 66 | "images": images 67 | }) 68 | 69 | # For regular requests, render the template 70 | return render_template('admin.html', 71 | job=job, 72 | images=images, 73 | available_models=AVAILABLE_MODELS) 74 | 75 | @bp.route('/admin/queue/retry', methods=['POST']) 76 | def retry_job(): 77 | """Retry a failed job (admin-initiated retries always force retry)""" 78 | job_id = request.form.get('job_id') 79 | 80 | if not job_id: 81 | return jsonify({ 82 | "success": False, 83 | "message": "No job ID provided" 84 | }) 85 | 86 | # Get job details before retry for logging 87 | db = get_db() 88 | try: 89 | # Get original job details 90 | job = db.execute( 91 | """ 92 | SELECT id, model_id, prompt, negative_prompt, settings, status, error_message 93 | FROM jobs 94 | WHERE id = ? 
95 | """, 96 | (job_id,) 97 | ).fetchone() 98 | 99 | if not job: 100 | return jsonify({ 101 | "success": False, 102 | "message": f"Job {job_id} not found" 103 | }) 104 | 105 | # Only retry failed jobs 106 | if job["status"] != "failed": 107 | return jsonify({ 108 | "success": False, 109 | "message": f"Job {job_id} is not in failed state (current state: {job['status']})" 110 | }) 111 | 112 | # Extract settings 113 | settings = json.loads(job["settings"]) 114 | 115 | # Reset retry count to 0 - all admin retries are force retries 116 | settings["retry_count"] = 0 117 | 118 | # Log this admin retry 119 | settings["admin_retried"] = True 120 | settings["admin_retry_timestamp"] = datetime.now().isoformat() 121 | 122 | # Update the job to pending 123 | db.execute( 124 | """ 125 | UPDATE jobs 126 | SET status = 'pending', 127 | started_at = NULL, 128 | completed_at = NULL, 129 | settings = ?, 130 | error_message = ? 131 | WHERE id = ? 132 | """, 133 | ( 134 | json.dumps(settings), 135 | "Job retried by admin", 136 | job_id 137 | ) 138 | ) 139 | 140 | db.commit() 141 | current_app.logger.info(f"Admin panel: Retried job {job_id} (retry count reset)") 142 | 143 | return jsonify({ 144 | "success": True, 145 | "message": f"Job {job_id} has been retried with retry count reset" 146 | }) 147 | 148 | except Exception as e: 149 | db.rollback() 150 | current_app.logger.error(f"Error retrying job {job_id}: {str(e)}") 151 | return jsonify({ 152 | "success": False, 153 | "message": f"Error retrying job: {str(e)}" 154 | }) 155 | 156 | @bp.route('/admin/queue/retry-all', methods=['POST']) 157 | def retry_all_failed_jobs(): 158 | """Retry all failed jobs (admin-initiated retries always force retry)""" 159 | db = get_db() 160 | 161 | try: 162 | # Get all failed jobs 163 | jobs = db.execute( 164 | """ 165 | SELECT id, model_id, prompt, negative_prompt, settings, error_message 166 | FROM jobs 167 | WHERE status = 'failed' 168 | ORDER BY created_at DESC 169 | """ 170 | ).fetchall() 171 | 172 | if not jobs: 173 | return jsonify({ 174 | "success": True, 175 | "message": "No failed jobs to retry" 176 | }) 177 | 178 | # Log the attempt 179 | current_app.logger.info(f"Admin panel: Retrying all {len(jobs)} failed jobs") 180 | 181 | # Try to retry each job 182 | retry_count = 0 183 | retry_results = [] 184 | 185 | for job in jobs: 186 | job_id = job["id"] 187 | retry_result = { 188 | "job_id": job_id, 189 | "prompt": job["prompt"][:50] + "..." if len(job["prompt"]) > 50 else job["prompt"], 190 | "model": job["model_id"], 191 | "success": False 192 | } 193 | 194 | try: 195 | # Extract settings 196 | settings = json.loads(job["settings"]) 197 | 198 | # Reset retry count to 0 - all admin retries are force retries 199 | settings["retry_count"] = 0 200 | 201 | # Log this admin retry 202 | settings["admin_retried"] = True 203 | settings["admin_retry_timestamp"] = datetime.now().isoformat() 204 | 205 | # Update the job to pending 206 | db.execute( 207 | """ 208 | UPDATE jobs 209 | SET status = 'pending', 210 | started_at = NULL, 211 | completed_at = NULL, 212 | settings = ?, 213 | error_message = ? 214 | WHERE id = ? 
215 | """, 216 | ( 217 | json.dumps(settings), 218 | "Job retried by admin", 219 | job_id 220 | ) 221 | ) 222 | 223 | retry_count += 1 224 | retry_result["success"] = True 225 | current_app.logger.info(f"Successfully retried job {job_id}") 226 | 227 | except Exception as e: 228 | retry_result["error"] = str(e) 229 | current_app.logger.error(f"Error retrying job {job_id}: {str(e)}") 230 | 231 | retry_results.append(retry_result) 232 | 233 | # Commit all changes 234 | db.commit() 235 | 236 | # Log summary 237 | current_app.logger.info(f"Bulk retry complete: {retry_count}/{len(jobs)} jobs successfully queued") 238 | 239 | return jsonify({ 240 | "success": True, 241 | "message": f"Retried {retry_count} failed jobs", 242 | "details": { 243 | "total_failed": len(jobs), 244 | "total_retried": retry_count, 245 | "results": retry_results 246 | } 247 | }) 248 | 249 | except Exception as e: 250 | db.rollback() 251 | current_app.logger.error(f"Error during bulk retry: {str(e)}") 252 | return jsonify({ 253 | "success": False, 254 | "message": f"Error retrying all jobs: {str(e)}" 255 | }) 256 | 257 | @bp.route('/admin/queue/force-retry', methods=['POST']) 258 | def force_retry_job(): 259 | """Force retry a failed job regardless of retry count""" 260 | job_id = request.form.get('job_id') 261 | 262 | if not job_id: 263 | return jsonify({ 264 | "success": False, 265 | "message": "No job ID provided" 266 | }) 267 | 268 | # Get job details before retry for logging 269 | db = get_db() 270 | try: 271 | # Get original job details 272 | job = db.execute( 273 | """ 274 | SELECT id, model_id, prompt, negative_prompt, settings, status, error_message 275 | FROM jobs 276 | WHERE id = ? 277 | """, 278 | (job_id,) 279 | ).fetchone() 280 | 281 | if not job: 282 | return jsonify({ 283 | "success": False, 284 | "message": f"Job {job_id} not found" 285 | }) 286 | 287 | # Only retry failed jobs 288 | if job["status"] != "failed": 289 | return jsonify({ 290 | "success": False, 291 | "message": f"Job {job_id} is not in failed state (current state: {job['status']})" 292 | }) 293 | 294 | # Extract settings 295 | settings = json.loads(job["settings"]) 296 | 297 | # Reset retry count to 0 298 | settings["retry_count"] = 0 299 | 300 | # Log this force retry 301 | settings["force_retried"] = True 302 | settings["force_retry_timestamp"] = datetime.now().isoformat() 303 | 304 | # Update the job to pending 305 | db.execute( 306 | """ 307 | UPDATE jobs 308 | SET status = 'pending', 309 | started_at = NULL, 310 | completed_at = NULL, 311 | settings = ?, 312 | error_message = ? 313 | WHERE id = ? 
314 | """, 315 | ( 316 | json.dumps(settings), 317 | "Job force-retried by admin", 318 | job_id 319 | ) 320 | ) 321 | 322 | db.commit() 323 | current_app.logger.info(f"Admin panel: Force-retried job {job_id} (retry count reset)") 324 | 325 | return jsonify({ 326 | "success": True, 327 | "message": f"Job {job_id} has been force-retried with retry count reset" 328 | }) 329 | 330 | except Exception as e: 331 | db.rollback() 332 | current_app.logger.error(f"Error force-retrying job {job_id}: {str(e)}") 333 | return jsonify({ 334 | "success": False, 335 | "message": f"Error force-retrying job: {str(e)}" 336 | }) 337 | 338 | @bp.route('/admin/queue/force-retry-all', methods=['POST']) 339 | def force_retry_all_failed_jobs(): 340 | """Force retry all failed jobs regardless of retry count""" 341 | db = get_db() 342 | 343 | try: 344 | # Get all failed jobs 345 | jobs = db.execute( 346 | """ 347 | SELECT id, model_id, prompt, negative_prompt, settings, error_message 348 | FROM jobs 349 | WHERE status = 'failed' 350 | ORDER BY created_at DESC 351 | """ 352 | ).fetchall() 353 | 354 | if not jobs: 355 | return jsonify({ 356 | "success": True, 357 | "message": "No failed jobs to retry" 358 | }) 359 | 360 | # Log the attempt 361 | current_app.logger.info(f"Admin panel: Force-retrying all {len(jobs)} failed jobs") 362 | 363 | # Try to retry each job 364 | retry_count = 0 365 | retry_results = [] 366 | 367 | for job in jobs: 368 | job_id = job["id"] 369 | retry_result = { 370 | "job_id": job_id, 371 | "prompt": job["prompt"][:50] + "..." if len(job["prompt"]) > 50 else job["prompt"], 372 | "model": job["model_id"], 373 | "success": False 374 | } 375 | 376 | try: 377 | # Extract settings 378 | settings = json.loads(job["settings"]) 379 | 380 | # Reset retry count to 0 381 | settings["retry_count"] = 0 382 | 383 | # Log this force retry 384 | settings["force_retried"] = True 385 | settings["force_retry_timestamp"] = datetime.now().isoformat() 386 | 387 | # Update the job to pending 388 | db.execute( 389 | """ 390 | UPDATE jobs 391 | SET status = 'pending', 392 | started_at = NULL, 393 | completed_at = NULL, 394 | settings = ?, 395 | error_message = ? 396 | WHERE id = ? 
397 | """, 398 | ( 399 | json.dumps(settings), 400 | "Job force-retried by admin", 401 | job_id 402 | ) 403 | ) 404 | 405 | retry_count += 1 406 | retry_result["success"] = True 407 | current_app.logger.info(f"Successfully force-retried job {job_id}") 408 | 409 | except Exception as e: 410 | retry_result["error"] = str(e) 411 | current_app.logger.error(f"Error force-retrying job {job_id}: {str(e)}") 412 | 413 | retry_results.append(retry_result) 414 | 415 | # Commit all changes 416 | db.commit() 417 | 418 | # Log summary 419 | current_app.logger.info(f"Bulk force-retry complete: {retry_count}/{len(jobs)} jobs successfully queued") 420 | 421 | return jsonify({ 422 | "success": True, 423 | "message": f"Force-retried {retry_count} failed jobs", 424 | "details": { 425 | "total_failed": len(jobs), 426 | "total_retried": retry_count, 427 | "results": retry_results 428 | } 429 | }) 430 | 431 | except Exception as e: 432 | db.rollback() 433 | current_app.logger.error(f"Error during bulk force-retry: {str(e)}") 434 | return jsonify({ 435 | "success": False, 436 | "message": f"Error force-retrying all jobs: {str(e)}" 437 | }) 438 | 439 | @bp.route('/admin/queue/restart-generator', methods=['POST']) 440 | def restart_generator(): 441 | """Force restart the generation pipeline to process pending jobs""" 442 | try: 443 | # Get the generator instance 444 | from app.models.generator import GenerationPipeline 445 | generator = GenerationPipeline() 446 | 447 | # Force re-initialization 448 | generator.is_running = False 449 | generator._initialize() 450 | 451 | # Trigger queue processing 452 | generator._queue_event.set() 453 | 454 | # Get current queue status 455 | queue_status = QueueManager.get_queue_status() 456 | pending_count = queue_status.get("pending", 0) 457 | 458 | current_app.logger.info(f"Admin panel: Generator pipeline restarted. {pending_count} pending jobs in queue.") 459 | 460 | return jsonify({ 461 | "success": True, 462 | "message": f"Generator pipeline restarted. {pending_count} pending jobs in queue.", 463 | "is_running": generator.is_running, 464 | "is_main_process": generator.is_main_process 465 | }) 466 | except Exception as e: 467 | current_app.logger.error(f"Error restarting generator: {str(e)}") 468 | return jsonify({ 469 | "success": False, 470 | "message": f"Error restarting generator: {str(e)}" 471 | }) -------------------------------------------------------------------------------- /app/api/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | API Blueprint for CyberImage 3 | """ 4 | from flask import Blueprint 5 | 6 | bp = Blueprint("api", __name__) 7 | 8 | # Import routes after blueprint creation to avoid circular imports 9 | from app.api import routes 10 | from app.api.mcp import mcp_bp -------------------------------------------------------------------------------- /app/api/mcp.py: -------------------------------------------------------------------------------- 1 | """ 2 | Model Context Protocol (MCP) implementation for CyberImage API 3 | 4 | This module implements the MCP specification for AI image generation. 5 | The default model used for generation matches the system's DEFAULT_MODEL, 6 | which is dynamically determined based on available models 7 | (prioritizing flux-1 if available, otherwise using the first available model). 
8 | """ 9 | import json 10 | import logging 11 | import uuid 12 | import time 13 | from flask import jsonify, request, current_app, Blueprint, g 14 | from app.models import AVAILABLE_MODELS, DEFAULT_MODEL 15 | from app.utils.queue import QueueManager 16 | from app.api.routes import APIError 17 | 18 | mcp_bp = Blueprint('mcp', __name__) 19 | 20 | # Configure logger 21 | logger = logging.getLogger(__name__) 22 | 23 | @mcp_bp.route("/mcp", methods=["POST"]) 24 | def handle_mcp(): 25 | """ 26 | Handle MCP requests according to the protocol specification 27 | https://spec.modelcontextprotocol.io/specification/2024-11-05/ 28 | """ 29 | try: 30 | # Parse the MCP request 31 | mcp_request = request.get_json() 32 | 33 | # Validate basic JSON-RPC 2.0 structure 34 | if not mcp_request or "jsonrpc" not in mcp_request or mcp_request["jsonrpc"] != "2.0": 35 | return jsonify({ 36 | "jsonrpc": "2.0", 37 | "error": { 38 | "code": -32600, 39 | "message": "Invalid Request: Not a valid JSON-RPC 2.0 request" 40 | }, 41 | "id": mcp_request.get("id", None) 42 | }), 400 43 | 44 | # Get method, params, and request ID 45 | method = mcp_request.get("method", "") 46 | params = mcp_request.get("params", {}) 47 | req_id = mcp_request.get("id", str(uuid.uuid4())) 48 | 49 | # Route to appropriate handler based on method 50 | if method == "context.image_generation.generate": 51 | result = handle_generate(params) 52 | return jsonify({ 53 | "jsonrpc": "2.0", 54 | "result": result, 55 | "id": req_id 56 | }) 57 | elif method == "context.image_generation.status": 58 | result = handle_status(params) 59 | return jsonify({ 60 | "jsonrpc": "2.0", 61 | "result": result, 62 | "id": req_id 63 | }) 64 | elif method == "context.image_generation.models": 65 | result = handle_models(params) 66 | return jsonify({ 67 | "jsonrpc": "2.0", 68 | "result": result, 69 | "id": req_id 70 | }) 71 | else: 72 | # Method not found 73 | return jsonify({ 74 | "jsonrpc": "2.0", 75 | "error": { 76 | "code": -32601, 77 | "message": f"Method not found: {method}" 78 | }, 79 | "id": req_id 80 | }), 404 81 | 82 | except Exception as e: 83 | logger.error(f"MCP error: {str(e)}") 84 | return jsonify({ 85 | "jsonrpc": "2.0", 86 | "error": { 87 | "code": -32000, 88 | "message": f"Server error: {str(e)}" 89 | }, 90 | "id": request.get_json().get("id", None) if request.get_json() else None 91 | }), 500 92 | 93 | def handle_generate(params): 94 | """ 95 | Handle image generation requests 96 | 97 | Expected params: 98 | { 99 | "prompt": "A detailed description of the image", 100 | "negative_prompt": "What to avoid in the image (optional)", 101 | "model": "model-id (optional, defaults to system's DEFAULT_MODEL)", 102 | "settings": { 103 | "height": 1024 (optional), 104 | "width": 1024 (optional), 105 | "num_inference_steps": 30 (optional), 106 | "guidance_scale": 7.5 (optional), 107 | "num_images": 1 (optional) 108 | } 109 | } 110 | """ 111 | # Validate required parameters 112 | if "prompt" not in params: 113 | raise APIError("Missing required parameter: prompt", 400) 114 | 115 | # Extract parameters 116 | prompt = params["prompt"] 117 | negative_prompt = params.get("negative_prompt", "") 118 | model_id = params.get("model", DEFAULT_MODEL) 119 | 120 | # Log which model is being used 121 | if "model" not in params: 122 | logger.info(f"MCP: Using default model: {model_id}") 123 | else: 124 | logger.info(f"MCP: Using specified model: {model_id}") 125 | 126 | # Validate model 127 | if model_id not in AVAILABLE_MODELS: 128 | available_models = ", 
".join(AVAILABLE_MODELS.keys()) 129 | raise APIError(f"Invalid model. Available models: {available_models}", 400) 130 | 131 | # Extract settings 132 | settings = params.get("settings", {}) 133 | 134 | # Add negative prompt to settings if provided 135 | if negative_prompt: 136 | settings["negative_prompt"] = negative_prompt 137 | 138 | # Get number of images to generate (default to 1) 139 | num_images = min(int(settings.get("num_images", 1)), 8) 140 | settings["num_images"] = num_images 141 | 142 | # Create job in the database 143 | job_id = QueueManager.add_job(model_id, prompt, settings) 144 | 145 | # Add to generation queue 146 | g.generator.add_job({ 147 | "id": job_id, 148 | "model_id": model_id, 149 | "prompt": prompt, 150 | "negative_prompt": negative_prompt, 151 | "settings": settings 152 | }) 153 | 154 | logger.info(f"MCP: Added job to queue: {job_id} for {num_images} images") 155 | 156 | return { 157 | "job_id": job_id, 158 | "status": "pending", 159 | "num_images": num_images 160 | } 161 | 162 | def handle_status(params): 163 | """ 164 | Handle job status requests 165 | 166 | Expected params: 167 | { 168 | "job_id": "UUID of the job" 169 | } 170 | """ 171 | # Validate required parameters 172 | if "job_id" not in params: 173 | raise APIError("Missing required parameter: job_id", 400) 174 | 175 | job_id = params["job_id"] 176 | job_status = QueueManager.get_job(job_id) 177 | 178 | if not job_status: 179 | raise APIError(f"Job not found: {job_id}", 404) 180 | 181 | # Format response according to MCP standards 182 | response = { 183 | "job_id": job_id, 184 | "status": job_status["status"], 185 | "model": job_status["model_id"], 186 | "prompt": job_status["prompt"], 187 | "created_at": job_status.get("created_at", ""), 188 | "started_at": job_status.get("started_at", ""), 189 | "completed_at": job_status.get("completed_at", ""), 190 | "progress": job_status.get("progress", {}), 191 | } 192 | 193 | # Include images if job is completed 194 | if job_status["status"] == "completed" and "images" in job_status: 195 | response["images"] = [{ 196 | "id": img["id"], 197 | "url": f"/api/get_image/{img['id']}", 198 | "metadata": img.get("metadata", {}) 199 | } for img in job_status["images"]] 200 | 201 | return response 202 | 203 | def handle_models(params): 204 | """ 205 | Handle models listing requests 206 | 207 | Returns a dictionary containing: 208 | - All available models with their IDs and descriptions 209 | - The system's default model, which is determined dynamically 210 | based on the application configuration (prioritizing flux-1 211 | if available, otherwise the first available model) 212 | 213 | This ensures consistency between the MCP API and the web interface. 
214 | """ 215 | # Return available models with the system default model 216 | return { 217 | "models": {name: {"id": info["id"], "description": info["description"]} 218 | for name, info in AVAILABLE_MODELS.items()}, 219 | "default": DEFAULT_MODEL 220 | } -------------------------------------------------------------------------------- /app/main/__init__.py: -------------------------------------------------------------------------------- 1 | from flask import Blueprint 2 | 3 | bp = Blueprint('main', __name__) 4 | 5 | from app.main import routes -------------------------------------------------------------------------------- /app/main/routes.py: -------------------------------------------------------------------------------- 1 | from flask import render_template, jsonify, current_app, request 2 | from app.main import bp 3 | from app.utils.queue import QueueManager 4 | from app.utils.image import ImageManager 5 | from app.models import AVAILABLE_MODELS 6 | import json 7 | 8 | @bp.route('/') 9 | @bp.route('/index') 10 | def index(): 11 | """Render the main page""" 12 | recent_images_raw = ImageManager.get_recent_images(limit=12) 13 | # Parse metadata JSON 14 | recent_images = [] 15 | for img in recent_images_raw: 16 | try: 17 | img_dict = dict(img) # Convert Row object to dict 18 | if isinstance(img_dict.get('metadata'), str): 19 | img_dict['metadata'] = json.loads(img_dict['metadata']) 20 | else: 21 | # Assume it's already a dict or handle appropriately 22 | img_dict['metadata'] = img_dict.get('metadata', {}) 23 | recent_images.append(img_dict) 24 | except (json.JSONDecodeError, TypeError) as e: 25 | current_app.logger.error(f"Error parsing metadata for recent image {img.get('id', '')}: {e}") 26 | img_dict = dict(img) 27 | img_dict['metadata'] = {} # Add empty dict on error 28 | recent_images.append(img_dict) 29 | 30 | return render_template('index.html', 31 | recent_images=recent_images, 32 | available_models=AVAILABLE_MODELS) 33 | 34 | @bp.route('/gallery') 35 | def gallery(): 36 | """Render the gallery page""" 37 | page = request.args.get('page', 1, type=int) 38 | model_id = request.args.get('model', None) 39 | sort_by = request.args.get('sort', 'newest') 40 | 41 | # Get images with pagination 42 | result_raw = ImageManager.get_all_images( 43 | page=page, 44 | per_page=24, 45 | model_id=model_id 46 | ) 47 | 48 | # Parse metadata JSON for main page load 49 | images_parsed = [] 50 | for img in result_raw['images']: 51 | try: 52 | img_dict = dict(img) # Convert Row object to dict 53 | if isinstance(img_dict.get('metadata'), str): 54 | img_dict['metadata'] = json.loads(img_dict['metadata']) 55 | else: 56 | img_dict['metadata'] = img_dict.get('metadata', {}) 57 | images_parsed.append(img_dict) 58 | except (json.JSONDecodeError, TypeError) as e: 59 | current_app.logger.error(f"Error parsing metadata for gallery image {img.get('id', '')}: {e}") 60 | img_dict = dict(img) 61 | img_dict['metadata'] = {} # Add empty dict on error 62 | images_parsed.append(img_dict) 63 | 64 | # Handle AJAX requests for infinite scroll 65 | if request.headers.get('X-Requested-With') == 'XMLHttpRequest': 66 | # For AJAX, we might only need basic info, or parse here too if needed by JS 67 | # For simplicity, returning basic info as before. 68 | # If JS needs parsed metadata, parse it here like above. 
69 | ajax_images = [] 70 | for img in result_raw['images']: 71 | img_dict = dict(img) 72 | # Parse metadata ONLY if needed by infinite scroll JS 73 | # metadata = json.loads(img_dict['metadata']) if isinstance(img_dict.get('metadata'), str) else img_dict.get('metadata', {}) 74 | ajax_images.append({ 75 | 'id': img_dict['id'], 76 | # Use parsed metadata values if needed by JS, otherwise keep raw strings 77 | 'prompt': img_dict.get('prompt', ''), # Get prompt safely 78 | 'model_id': img_dict.get('model_id', 'Unknown'), # Get model_id safely 79 | 'created_at': img_dict['created_at'].isoformat() if img_dict['created_at'] else None, 80 | # Pass parsed type if JS needs it 81 | # 'media_type': metadata.get('type', 'image') 82 | }) 83 | return jsonify({'images': ajax_images}) 84 | 85 | return render_template('gallery.html', 86 | images=images_parsed, # Pass parsed images 87 | total_pages=result_raw['pages'], 88 | current_page=result_raw['current_page'], 89 | selected_model=model_id, 90 | sort_by=sort_by, 91 | available_models=AVAILABLE_MODELS) -------------------------------------------------------------------------------- /app/models/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Models initialization for CyberImage 3 | """ 4 | from app.utils.config import get_available_models 5 | 6 | # Get available models from configuration 7 | AVAILABLE_MODELS = get_available_models() 8 | 9 | # Set the default model (use the first model or flux-1 if available) 10 | DEFAULT_MODEL = next( 11 | (model_id for model_id in AVAILABLE_MODELS if model_id == "flux-1"), 12 | next(iter(AVAILABLE_MODELS.keys())) if AVAILABLE_MODELS else None 13 | ) 14 | 15 | # Make models accessible from this module 16 | __all__ = ["AVAILABLE_MODELS", "DEFAULT_MODEL"] -------------------------------------------------------------------------------- /app/static/images/favicon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RamboRogers/cyberimage/ac21429a46c03b8292a20fe57b13a10db6942bf2/app/static/images/favicon.png -------------------------------------------------------------------------------- /app/static/js/modules/imageLoader.js: -------------------------------------------------------------------------------- 1 | export class ImageLoader { 2 | constructor() { 3 | this.observer = null; 4 | this.initializeObserver(); 5 | } 6 | 7 | initializeObserver() { 8 | this.observer = new IntersectionObserver( 9 | (entries) => { 10 | entries.forEach(entry => { 11 | if (entry.isIntersecting) { 12 | const img = entry.target; 13 | this.loadImage(img); 14 | this.observer.unobserve(img); 15 | } 16 | }); 17 | }, 18 | { 19 | rootMargin: '50px 0px', 20 | threshold: 0.1 21 | } 22 | ); 23 | } 24 | 25 | loadImage(img) { 26 | if (!img.dataset.src) return; 27 | 28 | // Start loading animation 29 | img.classList.add('loading'); 30 | 31 | // Create a temporary image to load in background 32 | const tempImg = new Image(); 33 | 34 | tempImg.onload = () => { 35 | img.src = tempImg.src; 36 | img.classList.remove('loading'); 37 | img.classList.add('loaded'); 38 | }; 39 | 40 | tempImg.onerror = () => { 41 | img.classList.remove('loading'); 42 | img.classList.add('error'); 43 | // Add error placeholder 44 | img.src = '/static/images/error-placeholder.png'; 45 | }; 46 | 47 | tempImg.src = img.dataset.src; 48 | } 49 | 50 | observe(img) { 51 | if (!img.dataset.src) { 52 | img.dataset.src = img.src; 53 | img.src = '/static/images/placeholder.png'; 
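      // Descriptive note: the original URL is now parked in data-src; loadImage() swaps it
      // back into img.src once the IntersectionObserver reports the element is near the viewport.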
54 | } 55 | this.observer.observe(img); 56 | } 57 | 58 | observeAll() { 59 | document.querySelectorAll('.gallery-item img').forEach(img => { 60 | this.observe(img); 61 | }); 62 | } 63 | 64 | // Handle newly added images (e.g., from infinite scroll) 65 | handleNewImages(container) { 66 | container.querySelectorAll('img:not(.loaded)').forEach(img => { 67 | this.observe(img); 68 | }); 69 | } 70 | } 71 | 72 | export const imageLoader = new ImageLoader(); -------------------------------------------------------------------------------- /app/static/js/modules/selectionManager.js: -------------------------------------------------------------------------------- 1 | export class SelectionManager { 2 | constructor() { 3 | this.selectedItems = new Set(); 4 | this.batchOperationsEl = document.querySelector('.batch-operations'); 5 | this.selectedCountEl = document.querySelector('.selected-count'); 6 | this.initializeListeners(); 7 | } 8 | 9 | initializeListeners() { 10 | // Handle item selection 11 | document.addEventListener('click', (e) => { 12 | const item = e.target.closest('.gallery-item'); 13 | if (!item) return; 14 | 15 | if (e.shiftKey) { 16 | this.toggleSelection(item); 17 | } 18 | }); 19 | 20 | // Handle batch operations 21 | const batchButtons = document.querySelectorAll('.batch-actions button'); 22 | batchButtons.forEach(button => { 23 | button.addEventListener('click', () => { 24 | const action = button.classList[0].replace('batch-', ''); 25 | this.executeBatchAction(action); 26 | }); 27 | }); 28 | } 29 | 30 | toggleSelection(item) { 31 | const itemId = item.dataset.imageId; 32 | 33 | if (this.selectedItems.has(itemId)) { 34 | this.selectedItems.delete(itemId); 35 | item.classList.remove('selected'); 36 | } else { 37 | this.selectedItems.add(itemId); 38 | item.classList.add('selected'); 39 | } 40 | 41 | this.updateUI(); 42 | } 43 | 44 | selectAll() { 45 | document.querySelectorAll('.gallery-item').forEach(item => { 46 | const itemId = item.dataset.imageId; 47 | this.selectedItems.add(itemId); 48 | item.classList.add('selected'); 49 | }); 50 | this.updateUI(); 51 | } 52 | 53 | deselectAll() { 54 | this.selectedItems.clear(); 55 | document.querySelectorAll('.gallery-item').forEach(item => { 56 | item.classList.remove('selected'); 57 | }); 58 | this.updateUI(); 59 | } 60 | 61 | updateUI() { 62 | const count = this.selectedItems.size; 63 | 64 | if (count > 0) { 65 | this.batchOperationsEl.classList.add('visible'); 66 | this.selectedCountEl.textContent = `${count} selected`; 67 | } else { 68 | this.batchOperationsEl.classList.remove('visible'); 69 | } 70 | } 71 | 72 | async executeBatchAction(action) { 73 | const selectedIds = Array.from(this.selectedItems); 74 | 75 | switch (action) { 76 | case 'download': 77 | this.downloadSelected(selectedIds); 78 | break; 79 | case 'delete': 80 | await this.deleteSelected(selectedIds); 81 | break; 82 | case 'tag': 83 | this.showTagDialog(selectedIds); 84 | break; 85 | } 86 | } 87 | 88 | async downloadSelected(ids) { 89 | for (const id of ids) { 90 | const response = await fetch(`/api/get_image/${id}`); 91 | const blob = await response.blob(); 92 | const url = window.URL.createObjectURL(blob); 93 | const a = document.createElement('a'); 94 | a.href = url; 95 | a.download = `image-${id}.png`; 96 | document.body.appendChild(a); 97 | a.click(); 98 | document.body.removeChild(a); 99 | window.URL.revokeObjectURL(url); 100 | } 101 | } 102 | 103 | async deleteSelected(ids) { 104 | if (!confirm(`Are you sure you want to delete ${ids.length} images?`)) { 105 | return; 
106 | } 107 | 108 | try { 109 | await Promise.all(ids.map(id => 110 | fetch(`/api/delete_image/${id}`, { method: 'DELETE' }) 111 | )); 112 | 113 | // Remove deleted items from DOM 114 | ids.forEach(id => { 115 | const item = document.querySelector(`[data-image-id="${id}"]`); 116 | if (item) item.remove(); 117 | }); 118 | 119 | this.selectedItems.clear(); 120 | this.updateUI(); 121 | } catch (error) { 122 | console.error('Error deleting images:', error); 123 | alert('Failed to delete some images. Please try again.'); 124 | } 125 | } 126 | 127 | showTagDialog(ids) { 128 | // Implementation for tag dialog 129 | console.log('Show tag dialog for:', ids); 130 | } 131 | } 132 | 133 | export const selectionManager = new SelectionManager(); -------------------------------------------------------------------------------- /app/static/js/modules/shortcuts.js: -------------------------------------------------------------------------------- 1 | export const SHORTCUTS = { 2 | 'j': 'Next image', 3 | 'k': 'Previous image', 4 | 'f': 'Toggle favorite', 5 | 'c': 'Copy prompt', 6 | 'd': 'Download image', 7 | '/': 'Focus search', 8 | 'esc': 'Close modal/clear selection', 9 | 'space': 'Toggle selection', 10 | 'a': 'Select all', 11 | 'shift+a': 'Deselect all' 12 | }; 13 | 14 | class ShortcutManager { 15 | constructor() { 16 | this.handlers = new Map(); 17 | this.isEnabled = true; 18 | this.initializeShortcuts(); 19 | } 20 | 21 | initializeShortcuts() { 22 | document.addEventListener('keydown', (e) => { 23 | if (!this.isEnabled) return; 24 | 25 | // Don't trigger shortcuts when typing in input fields 26 | if (e.target.matches('input, textarea')) return; 27 | 28 | const key = this.getKeyCombo(e); 29 | const handler = this.handlers.get(key); 30 | 31 | if (handler) { 32 | e.preventDefault(); 33 | handler(); 34 | } 35 | }); 36 | } 37 | 38 | getKeyCombo(e) { 39 | const parts = []; 40 | if (e.shiftKey) parts.push('shift'); 41 | if (e.ctrlKey) parts.push('ctrl'); 42 | if (e.altKey) parts.push('alt'); 43 | parts.push(e.key.toLowerCase()); 44 | return parts.join('+'); 45 | } 46 | 47 | register(key, handler) { 48 | this.handlers.set(key.toLowerCase(), handler); 49 | } 50 | 51 | unregister(key) { 52 | this.handlers.delete(key.toLowerCase()); 53 | } 54 | 55 | disable() { 56 | this.isEnabled = false; 57 | } 58 | 59 | enable() { 60 | this.isEnabled = true; 61 | } 62 | } 63 | 64 | export const shortcutManager = new ShortcutManager(); -------------------------------------------------------------------------------- /app/static/js/modules/uiUtils.js: -------------------------------------------------------------------------------- 1 | // UI Utility Functions 2 | 3 | export function formatDate(date) { 4 | return new Intl.DateTimeFormat('default', { 5 | month: 'short', 6 | day: 'numeric', 7 | hour: 'numeric', 8 | minute: 'numeric', 9 | hour12: true 10 | }).format(date); 11 | } 12 | 13 | export function formatDateLong(date) { 14 | return new Intl.DateTimeFormat('default', { 15 | year: 'numeric', 16 | month: 'short', 17 | day: 'numeric', 18 | hour: 'numeric', 19 | minute: 'numeric', 20 | second: 'numeric', 21 | timeZoneName: 'short' 22 | }).format(date); 23 | } 24 | 25 | export function addNeonFlash(element) { 26 | element.style.boxShadow = 'var(--neon-green-glow)'; 27 | setTimeout(() => { 28 | element.style.boxShadow = ''; 29 | }, 1000); 30 | } 31 | 32 | export function showMainFeedback(message, type = 'info') { 33 | let feedbackDiv = document.getElementById('main-feedback-message'); 34 | if (!feedbackDiv) { 35 | feedbackDiv = 
document.createElement('div'); 36 | feedbackDiv.id = 'main-feedback-message'; 37 | feedbackDiv.style.position = 'fixed'; 38 | feedbackDiv.style.bottom = '20px'; 39 | feedbackDiv.style.left = '50%'; 40 | feedbackDiv.style.transform = 'translateX(-50%)'; 41 | feedbackDiv.style.padding = '10px 20px'; 42 | feedbackDiv.style.borderRadius = 'var(--border-radius)'; 43 | feedbackDiv.style.zIndex = '3000'; 44 | feedbackDiv.style.opacity = '0'; 45 | feedbackDiv.style.transition = 'opacity 0.3s ease'; 46 | document.body.appendChild(feedbackDiv); 47 | } 48 | 49 | feedbackDiv.textContent = message; 50 | if (type === 'success') { 51 | feedbackDiv.style.backgroundColor = 'rgba(57, 255, 20, 0.8)'; // Neon green bg 52 | feedbackDiv.style.color = 'black'; 53 | feedbackDiv.style.border = '1px solid var(--neon-green)'; 54 | } else if (type === 'error') { 55 | feedbackDiv.style.backgroundColor = 'rgba(255, 68, 68, 0.8)'; // Red bg 56 | feedbackDiv.style.color = 'white'; 57 | feedbackDiv.style.border = '1px solid #ff4444'; 58 | } else { 59 | feedbackDiv.style.backgroundColor = 'rgba(30, 30, 30, 0.9)'; // Dark bg for info 60 | feedbackDiv.style.color = 'white'; 61 | feedbackDiv.style.border = '1px solid var(--neon-green-dim)'; 62 | } 63 | 64 | // Fade in 65 | setTimeout(() => { feedbackDiv.style.opacity = '1'; }, 10); 66 | 67 | // Clear previous timeout if exists 68 | if (feedbackDiv.dataset.timeoutId) { 69 | clearTimeout(parseInt(feedbackDiv.dataset.timeoutId)); 70 | } 71 | 72 | // Set timeout to fade out 73 | const timeoutId = setTimeout(() => { 74 | feedbackDiv.style.opacity = '0'; 75 | // Optional: remove element after fade out 76 | // setTimeout(() => { feedbackDiv.remove(); }, 300); 77 | }, 3000); // Display for 3 seconds 78 | feedbackDiv.dataset.timeoutId = timeoutId.toString(); 79 | } 80 | 81 | // Helper function to format time in seconds to a readable format 82 | // This was already in main.js, but it fits well with other UI utils. 
83 | export function formatTime(seconds) { 84 | if (seconds < 60) { 85 | return `${Math.round(seconds)} sec`; 86 | } else if (seconds < 3600) { 87 | return `${Math.round(seconds / 60)} min`; 88 | } else { 89 | const hours = Math.floor(seconds / 3600); 90 | const minutes = Math.round((seconds % 3600) / 60); 91 | return `${hours}h ${minutes}m`; 92 | } 93 | } -------------------------------------------------------------------------------- /app/static/js/modules/viewManager.js: -------------------------------------------------------------------------------- 1 | export class ViewManager { 2 | constructor() { 3 | this.currentView = 'grid'; 4 | this.container = document.querySelector('.gallery-grid'); 5 | this.initializeViewToggles(); 6 | } 7 | 8 | initializeViewToggles() { 9 | const toggles = document.querySelectorAll('.view-toggle'); 10 | toggles.forEach(toggle => { 11 | toggle.addEventListener('click', () => { 12 | const view = toggle.dataset.view; 13 | this.setView(view); 14 | 15 | // Update active state 16 | toggles.forEach(t => t.classList.remove('active')); 17 | toggle.classList.add('active'); 18 | }); 19 | }); 20 | } 21 | 22 | setView(view) { 23 | if (this.currentView === view) return; 24 | 25 | // Remove old view 26 | this.container.classList.remove(`view-${this.currentView}`); 27 | 28 | // Add new view 29 | this.container.classList.add(`view-${view}`); 30 | this.container.dataset.view = view; 31 | 32 | // Store current view 33 | this.currentView = view; 34 | 35 | // Save preference 36 | localStorage.setItem('preferred-view', view); 37 | 38 | // Dispatch event for other modules 39 | window.dispatchEvent(new CustomEvent('viewchange', { detail: { view } })); 40 | } 41 | 42 | getStoredView() { 43 | return localStorage.getItem('preferred-view') || 'grid'; 44 | } 45 | 46 | initialize() { 47 | const storedView = this.getStoredView(); 48 | this.setView(storedView); 49 | 50 | // Set active state on the correct toggle 51 | const activeToggle = document.querySelector(`[data-view="${storedView}"]`); 52 | if (activeToggle) { 53 | activeToggle.classList.add('active'); 54 | } 55 | } 56 | } 57 | 58 | export const viewManager = new ViewManager(); -------------------------------------------------------------------------------- /app/templates/gallery.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | CyberImage Gallery - AI Generated Images 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 35 | 36 |
125 | 126 | 127 | 142 | 143 | 185 | 186 | 196 | 197 | 198 | 242 | 243 | 244 | 245 | 246 | -------------------------------------------------------------------------------- /app/utils/cleanup_models.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility to clean up model directories with issues 3 | """ 4 | import os 5 | from pathlib import Path 6 | import shutil 7 | import sys 8 | 9 | def print_status(message, status="info"): 10 | """Print a status message""" 11 | status_icons = { 12 | "info": "ℹ️", 13 | "success": "✅", 14 | "error": "❌", 15 | "warning": "⚠️", 16 | } 17 | icon = status_icons.get(status, "ℹ️") 18 | print(f"\n{icon} {message}") 19 | sys.stdout.flush() 20 | 21 | def cleanup_models_directory(models_dir): 22 | """Clean up the models directory""" 23 | print_status(f"Cleaning up models directory: {models_dir}") 24 | 25 | # Ensure the directory exists 26 | if not models_dir.exists(): 27 | print_status("Models directory does not exist", "error") 28 | return False 29 | 30 | # Find all problematic directories 31 | problem_dirs = [] 32 | 33 | for path in models_dir.iterdir(): 34 | if not path.is_dir(): 35 | continue 36 | 37 | # Check if the directory name has quotes or ends with _temp 38 | if '"' in path.name or "'" in path.name or path.name.endswith("_temp"): 39 | problem_dirs.append(path) 40 | 41 | if not problem_dirs: 42 | print_status("No problematic directories found", "success") 43 | return True 44 | 45 | # Report the directories found 46 | print_status(f"Found {len(problem_dirs)} problematic directories:", "warning") 47 | for path in problem_dirs: 48 | print(f" - {path.name}") 49 | 50 | # Ask for confirmation 51 | confirm = input("\nDelete these directories? (y/n): ").strip().lower() 52 | if confirm != 'y': 53 | print_status("Cleanup aborted", "info") 54 | return False 55 | 56 | # Delete the directories 57 | deleted = 0 58 | for path in problem_dirs: 59 | try: 60 | print_status(f"Removing {path.name}...", "info") 61 | shutil.rmtree(path) 62 | deleted += 1 63 | except Exception as e: 64 | print_status(f"Failed to remove {path.name}: {str(e)}", "error") 65 | 66 | print_status(f"Removed {deleted} of {len(problem_dirs)} directories", "success") 67 | return True 68 | 69 | def main(): 70 | """Main entry point""" 71 | # Get models directory from environment or use default 72 | models_dir = Path(os.getenv("MODEL_FOLDER", "./models")) 73 | 74 | print_status(f"Models Cleanup Utility\n\nModels directory: {models_dir}") 75 | cleanup_models_directory(models_dir) 76 | 77 | if __name__ == "__main__": 78 | main() -------------------------------------------------------------------------------- /app/utils/config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Configuration utilities for CyberImage 3 | """ 4 | import os 5 | import logging 6 | import json 7 | from typing import Dict, List, Any 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | def parse_model_config() -> Dict[str, Dict[str, Any]]: 12 | """ 13 | Parse model configuration from environment variables. 
14 | 15 | Format: MODEL_=;;;;[;] 16 | Source options: huggingface (local), huggingface_api, fal_api 17 | Example Local: MODEL_1=flux-1;black-forest-labs/FLUX.1-dev;FLUX base model;huggingface;true 18 | Example HF API: MODEL_10=llava-hf/llava-1.5-7b-hf;llava-hf/llava-1.5-7b-hf;LLaVA 1.5 7B HF;huggingface_api;true;{\"provider\": \"huggingface-inference-api\", \"type\": \"vqa\"} 19 | Example Fal API: MODEL_11=ltx-video-i2v-api;Lightricks/LTX-Video;LTX Image-to-Video;fal_api;true;{\"provider\": \"fal-ai\", \"type\": \"i2v\", \"fal_function_id\": \"fal-ai/ltx-video-13b-dev/image-to-video\"} 20 | 21 | Returns: 22 | Dict mapping model names to their configurations 23 | """ 24 | models_config = {} 25 | enabled_models = set() 26 | 27 | # Common file lists for model completeness checks 28 | common_file_lists = { 29 | "flux": [ 30 | "model_index.json", 31 | "ae.safetensors" 32 | # Removed flux1-dev.safetensors to be more forgiving 33 | ], 34 | "flux-schnell": [ 35 | "model_index.json", 36 | "ae.safetensors" 37 | # No specific safetensors file check since it may vary 38 | ], 39 | "sd3": [ 40 | "model_index.json", 41 | "sd3.5_large.safetensors" 42 | ], 43 | "sdxl": [ 44 | "model_index.json", 45 | "vae/diffusion_pytorch_model.safetensors" 46 | ], 47 | "animagine-xl": [ 48 | "model_index.json", 49 | "unet/diffusion_pytorch_model.safetensors", 50 | "vae/diffusion_pytorch_model.safetensors", 51 | "text_encoder/model.safetensors", 52 | "text_encoder_2/model.safetensors" 53 | ], 54 | "flux-abliterated": [ 55 | "model_index.json", 56 | "transformer/config.json", 57 | "transformer/diffusion_pytorch_model-00001-of-00003.safetensors", 58 | "transformer/diffusion_pytorch_model-00002-of-00003.safetensors", 59 | "transformer/diffusion_pytorch_model-00003-of-00003.safetensors", 60 | "transformer/diffusion_pytorch_model.safetensors.index.json", 61 | "vae/config.json", 62 | "vae/diffusion_pytorch_model.safetensors" 63 | ] 64 | } 65 | 66 | # Find all model definitions in environment variables 67 | for key, value in os.environ.items(): 68 | # Skip MODEL_FOLDER which is a path setting, not a model definition 69 | if key == 'MODEL_FOLDER': 70 | continue 71 | 72 | if key.startswith('MODEL_') and value and '_' not in key[6:]: 73 | try: 74 | model_num = key[6:] # Extract number from MODEL_N 75 | 76 | # Strip quotes from the value 77 | value = value.strip() 78 | if (value.startswith('"') and value.endswith('"')) or (value.startswith("'") and value.endswith("'")): 79 | value = value[1:-1] 80 | 81 | parts = value.split(';') 82 | 83 | # Updated check for minimum parts (5 required, 6th optional) 84 | if len(parts) < 5: 85 | logger.warning(f"Invalid model configuration format for {key}: {value} (needs at least 5 parts)") 86 | continue 87 | 88 | name, repo, description, source, requires_auth = parts[:5] 89 | options_json = parts[5] if len(parts) > 5 else None 90 | name = name.strip() 91 | 92 | # Parse the optional JSON configuration 93 | step_config = {} 94 | if options_json: 95 | options_json_stripped = options_json.strip('"\' ') # Clean up outer quotes/spaces 96 | options_json_cleaned = options_json_stripped.replace('\\"', '"') 97 | try: 98 | parsed_options = json.loads(options_json_cleaned) 99 | if isinstance(parsed_options, dict): 100 | step_config = parsed_options 101 | else: 102 | logger.warning(f"Optional config for {key} is not a JSON object: {options_json_cleaned}") 103 | except json.JSONDecodeError as json_err: 104 | logger.warning(f"Invalid JSON in optional config for {key}: {options_json_cleaned} - Error: {json_err}") 
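                # Illustrative walk-through of the parsing above (hypothetical entry, not part of any real .env):
                #   MODEL_12="my-sdxl;someorg/my-sdxl-repo;Example SDXL model;huggingface;true;{\"steps\": {\"min\": 15, \"max\": 40, \"default\": 25}}"
                # After the outer quotes are stripped and the value is split on ';', parts[5] holds the
                # escaped JSON; strip('"\' ') plus replace('\\"', '"') reduce it to
                #   '{"steps": {"min": 15, "max": 40, "default": 25}}'
                # which json.loads turns into the step_config dict consumed further below.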
105 | # else: 106 | # print(f"DEBUG: No options_json found for {key}") 107 | 108 | # Further strip any quotes from individual parts 109 | name = name.strip('"\'') 110 | repo = repo.strip('"\'') 111 | description = description.strip('"\'') 112 | source_val = source.strip('"\'').lower() # Renamed to avoid conflict with 'source' module, and pre-process 113 | requires_auth = requires_auth.strip('"\'') 114 | 115 | # Determine if model is enabled based on its type 116 | # Models from API sources (huggingface_api, fal_api) are enabled by default unless explicitly disabled in options 117 | if source_val in ["huggingface_api", "fal_api"]: 118 | download_enabled = step_config.get('download_enabled', True) 119 | api_source_type = source_val.replace("_api", "").capitalize() 120 | logger.info(f"{api_source_type} API model {name} is {'enabled' if download_enabled else 'disabled'} for use") 121 | else: # Local models 122 | # For local models, check DOWNLOAD_MODEL_N environment variable 123 | download_key = f"DOWNLOAD_MODEL_{model_num}" 124 | download_value = os.environ.get(download_key, "true") 125 | # Strip quotes from download value too 126 | if (download_value.startswith('"') and download_value.endswith('"')) or (download_value.startswith("'") and download_value.endswith("'")): 127 | download_value = download_value[1:-1] 128 | download_enabled = download_value.lower() == "true" 129 | 130 | if not download_enabled: 131 | logger.info(f"Model {name} is defined but disabled for use") 132 | 133 | # Determine model type from name (heuristic, can be overridden by JSON) 134 | heuristic_model_type = "generic" # Start with generic 135 | if "flux" in name.lower(): 136 | heuristic_model_type = "flux" 137 | elif "sd-3" in name.lower(): 138 | heuristic_model_type = "sd3" 139 | elif "animagine" in name.lower(): 140 | heuristic_model_type = "sdxl" # Animagine XL uses SDXL architecture 141 | elif "xl" in name.lower() or "sdxl" in name.lower(): 142 | heuristic_model_type = "sdxl" 143 | 144 | # Get appropriate file list based on model name or heuristic type 145 | files = [] 146 | if source_val != "huggingface_api": # Only need files for local models 147 | if name in common_file_lists: 148 | files = common_file_lists[name] 149 | elif heuristic_model_type in common_file_lists: 150 | files = common_file_lists[heuristic_model_type] 151 | 152 | # If source is an API source (huggingface_api, fal_api), files are not applicable locally 153 | if source_val in ["huggingface_api", "fal_api"]: 154 | files = [] # API models don't have local files for completeness check 155 | logger.debug(f"Model {name} is from {source_val}, setting files to empty list.") 156 | 157 | # Validate provider for API sources 158 | if source_val in ["huggingface_api", "fal_api"]: 159 | if not isinstance(step_config, dict) or "provider" not in step_config: 160 | logger.error(f"Model {name} with source '{source_val}' is missing 'provider' in its JSON options. Skipping this model.") 161 | continue 162 | if not isinstance(step_config.get("provider"), str) or not step_config.get("provider").strip(): 163 | logger.error(f"Model {name} with source '{source_val}' has an invalid or empty 'provider' in its JSON options: '{step_config.get('provider')}'. 
Skipping this model.") 164 | continue 165 | logger.info(f"Configured {source_val} model {name} with provider: {step_config['provider']}") 166 | 167 | # Determine final model type: Explicit JSON type > Heuristic > Default ('image') 168 | final_model_type = 'image' # Default to image 169 | if heuristic_model_type != 'generic': # If heuristic found something specific 170 | final_model_type = heuristic_model_type 171 | 172 | # --- Override model_type if specified in step_config --- # 173 | if isinstance(step_config, dict) and 'type' in step_config: 174 | explicit_type = step_config['type'] 175 | if isinstance(explicit_type, str) and explicit_type: 176 | logger.debug(f"Overriding model type with explicit type '{explicit_type}' from JSON options for {name}") 177 | final_model_type = explicit_type.lower() # Use explicit type, lowercased 178 | else: 179 | logger.warning(f"'type' found in JSON options for {name}, but it's not a valid string: {explicit_type}") 180 | # --- End Override --- # 181 | 182 | # --- Standardize non-video types to 'image' --- # 183 | if final_model_type not in ['t2v', 'i2v']: 184 | if final_model_type != 'image': # Log if we are changing it 185 | logger.debug(f"Standardizing model type '{final_model_type}' to 'image' for {name}") 186 | final_model_type = 'image' 187 | # --- End Standardization --- # 188 | 189 | models_config[name] = { 190 | "repo": repo.strip(), 191 | "description": description.strip(), 192 | "source": source_val, # Use the processed source_val 193 | "requires_auth": requires_auth.lower() == "true", 194 | "download_enabled": download_enabled, # Represents if the model (local or API) is active 195 | "type": final_model_type, # Use the determined final type 196 | "files": files, # Will be empty for API sources like huggingface_api or fal_api 197 | "step_config": step_config # Add parsed step config (contains provider for API models) 198 | } 199 | 200 | # Store enabled models 201 | if download_enabled: 202 | enabled_models.add(name) 203 | 204 | logger.debug(f"Loaded model configuration for {name}") 205 | 206 | except Exception as e: 207 | logger.warning(f"Error parsing model configuration {key}: {str(e)}") 208 | 209 | # If no models are defined in environment, use default configuration 210 | if not models_config: 211 | logger.warning("No models found in environment variables, using defaults") 212 | models_config = { 213 | "flux-1": { 214 | "repo": "black-forest-labs/FLUX.1-dev", 215 | "description": "FLUX base model", 216 | "requires_auth": True, 217 | "source": "huggingface", 218 | "download_enabled": True, 219 | "type": "flux", 220 | "files": common_file_lists["flux"] 221 | }, 222 | "sd-3.5": { 223 | "repo": "stabilityai/stable-diffusion-3.5-large", 224 | "description": "Stable Diffusion 3.5", 225 | "requires_auth": True, 226 | "source": "huggingface", 227 | "download_enabled": True, 228 | "type": "sd3", 229 | "files": common_file_lists["sd3"] 230 | }, 231 | "flux-abliterated": { 232 | "repo": "aoxo/flux.1dev-abliteratedv2", 233 | "description": "FLUX Abliterated variant", 234 | "requires_auth": True, 235 | "source": "huggingface", 236 | "download_enabled": True, 237 | "type": "flux", 238 | "files": common_file_lists["flux-abliterated"] 239 | } 240 | } 241 | enabled_models = set(models_config.keys()) 242 | 243 | # Log summary of configured models 244 | logger.info(f"Loaded {len(models_config)} model configurations") 245 | logger.info(f"Models enabled for download: {', '.join(enabled_models)}") 246 | 247 | return models_config 248 | 249 | def 
get_downloadable_models() -> Dict[str, Dict[str, Any]]: 250 | """ 251 | Get only the models that are enabled for download 252 | 253 | Returns: 254 | Dict mapping model names to their configurations for enabled models 255 | """ 256 | all_models = parse_model_config() 257 | return {name: config for name, config in all_models.items() 258 | if config.get("download_enabled", True)} 259 | 260 | def get_available_models() -> Dict[str, Dict[str, Any]]: 261 | """ 262 | Get all available models for the UI (includes display metadata) 263 | 264 | Returns: 265 | Dict mapping model names to their UI configurations 266 | """ 267 | parsed_models = parse_model_config() 268 | ui_models = {} 269 | 270 | # Add display information and standardize format for UI 271 | for name, config in parsed_models.items(): 272 | # Start with a copy of the full original config to preserve all fields 273 | model_entry = config.copy() 274 | 275 | # Set/Override UI-specific fields 276 | model_entry["id"] = config.get("repo", name) # Use repo as ID, fallback to name 277 | model_entry["name"] = config.get("display_name", name) # Use display_name, fallback to name (original key) 278 | 279 | # Ensure 'source' is definitely there (it should be from parse_model_config) 280 | if 'source' not in model_entry: 281 | logger.warning(f"Source field was missing from parsed_config for {name} in get_available_models. This is unexpected.") 282 | model_entry['source'] = config.get('source', 'unknown') # Should already be there 283 | 284 | ui_models[name] = model_entry 285 | 286 | return ui_models 287 | 288 | # --- Rate Limit Configuration --- 289 | ENABLE_RATE_LIMIT = False # Set to False to disable IP-based hourly rate limiting 290 | RATE_LIMIT_HOURLY = 10 # Default: Max 10 requests per hour per IP (only applies if ENABLE_RATE_LIMIT is True) 291 | # --- End Rate Limit Configuration --- -------------------------------------------------------------------------------- /app/utils/db.py: -------------------------------------------------------------------------------- 1 | """ 2 | Database utilities for CyberImage 3 | """ 4 | import sqlite3 5 | import click 6 | import os 7 | from flask import current_app, g 8 | from flask.cli import with_appcontext 9 | import json 10 | from datetime import datetime 11 | 12 | def get_db(): 13 | """Get database connection, initializing if necessary""" 14 | if "db" not in g: 15 | # Ensure the database directory exists 16 | db_path = current_app.config["DATABASE"] 17 | os.makedirs(os.path.dirname(db_path), exist_ok=True) 18 | 19 | # Check if we need to initialize the database 20 | needs_init = not os.path.exists(db_path) 21 | 22 | # Connect to the database 23 | g.db = sqlite3.connect( 24 | db_path, 25 | detect_types=sqlite3.PARSE_DECLTYPES 26 | ) 27 | g.db.row_factory = sqlite3.Row 28 | 29 | # Initialize if needed 30 | if needs_init: 31 | current_app.logger.info("Database does not exist, initializing...") 32 | init_db() 33 | current_app.logger.info("Database initialized successfully") 34 | 35 | return g.db 36 | 37 | def close_db(e=None): 38 | """Close database connection""" 39 | db = g.pop("db", None) 40 | if db is not None: 41 | db.close() 42 | 43 | def init_db(): 44 | """Initialize database schema""" 45 | db = g.db # Use existing connection from get_db() 46 | 47 | # Read and execute schema 48 | schema_path = os.path.join(os.path.dirname(__file__), "schema.sql") 49 | with open(schema_path, "r") as f: 50 | db.executescript(f.read()) 51 | 52 | def init_app(app): 53 | """Register database functions with the Flask app""" 54 | 
app.teardown_appcontext(close_db) 55 | # We don't need the init-db command anymore since it's automatic 56 | # But we'll keep it for manual reinitialization if needed 57 | app.cli.add_command(init_db_command) 58 | 59 | @click.command("init-db") 60 | @with_appcontext 61 | def init_db_command(): 62 | """Clear the existing data and create new tables.""" 63 | init_db() 64 | click.echo("Initialized the database.") -------------------------------------------------------------------------------- /app/utils/image.py: -------------------------------------------------------------------------------- 1 | """ 2 | Image utilities for CyberImage 3 | """ 4 | import os 5 | import uuid 6 | import json 7 | from datetime import datetime 8 | import pytz 9 | from typing import Dict, Optional, List, Union 10 | from PIL import Image 11 | from flask import current_app 12 | from app.utils.db import get_db 13 | 14 | class ImageManager: 15 | """Manages image storage and retrieval""" 16 | 17 | @staticmethod 18 | def _convert_to_local(utc_dt) -> datetime: 19 | """Convert UTC datetime to local time""" 20 | if not utc_dt: 21 | return None 22 | 23 | # Ensure UTC timezone is set 24 | if utc_dt.tzinfo is None: 25 | utc_dt = pytz.UTC.localize(utc_dt) 26 | 27 | # Convert to local timezone 28 | local_tz = datetime.now().astimezone().tzinfo 29 | return utc_dt.astimezone(local_tz) 30 | 31 | @staticmethod 32 | def save_image(image: Optional[Image.Image], job_id: str, metadata: Dict, image_id: Optional[str] = None, file_path: Optional[str] = None) -> str: 33 | """Save an image/video record in the database. If image is provided, save it. 34 | 35 | Args: 36 | image: PIL Image object (optional, provide None if file already saved, e.g., video) 37 | job_id: The ID of the job that generated the media. 38 | metadata: Dictionary containing metadata about the media. 39 | image_id: Optional pre-generated ID for the media record. 40 | file_path: Optional pre-generated relative file path for the media record. 41 | 42 | Returns: 43 | The ID of the saved media record. 44 | """ 45 | db = get_db() 46 | # Use provided ID or generate a new one 47 | media_id = image_id if image_id else str(uuid.uuid4()) 48 | 49 | # Get job information for metadata (optional, enhance metadata if job exists) 50 | try: 51 | job = db.execute( 52 | """ 53 | SELECT model_id, prompt, settings 54 | FROM jobs WHERE id = ? 
55 | """, 56 | (job_id,) 57 | ).fetchone() 58 | 59 | if job: 60 | job_settings = json.loads(job["settings"]) 61 | # Update metadata with job info, ensure settings are merged/overwritten correctly 62 | metadata.update({ 63 | "model_id": metadata.get("model_id", job["model_id"]), # Prioritize metadata's model_id 64 | "prompt": metadata.get("prompt", job["prompt"]), # Prioritize metadata's prompt 65 | # Merge settings, prioritizing specific metadata settings over job settings 66 | "settings": {**job_settings, **metadata.get("settings", {})} 67 | }) 68 | else: 69 | current_app.logger.warning(f"Job {job_id} not found when saving media {media_id}") 70 | except Exception as e: 71 | current_app.logger.error(f"Error fetching job info {job_id} for media {media_id}: {e}") 72 | 73 | relative_path = None 74 | full_save_path = None 75 | 76 | if file_path: 77 | # Use pre-generated relative path 78 | relative_path = file_path 79 | full_save_path = os.path.join(current_app.config["IMAGES_PATH"], relative_path) 80 | current_app.logger.debug(f"Using pre-defined path for media {media_id}: {relative_path}") 81 | elif image: 82 | # Generate path and save the image file 83 | today = datetime.utcnow().strftime("%Y/%m/%d") 84 | image_dir = os.path.join(current_app.config["IMAGES_PATH"], today) 85 | os.makedirs(image_dir, exist_ok=True) 86 | 87 | file_name = f"{media_id}.png" 88 | relative_path = os.path.join(today, file_name) 89 | full_save_path = os.path.join(image_dir, file_name) 90 | current_app.logger.debug(f"Saving new image for media {media_id} to: {relative_path}") 91 | try: 92 | image.save(full_save_path, "PNG") 93 | except Exception as save_err: 94 | raise Exception(f"Failed to save image file to {full_save_path}: {save_err}") 95 | else: 96 | # Error: No image provided and no file_path provided 97 | raise ValueError("ImageManager.save_image requires either an image object or a file_path.") 98 | 99 | # Record in database 100 | try: 101 | db.execute( 102 | """ 103 | INSERT INTO images (id, job_id, file_path, metadata) 104 | VALUES (?, ?, ?, ?) 
105 | """, 106 | (media_id, job_id, relative_path, json.dumps(metadata)) 107 | ) 108 | db.commit() 109 | current_app.logger.info(f"Saved media record {media_id} to database.") 110 | return media_id 111 | except Exception as e: 112 | db.rollback() 113 | # Clean up the image file only if we created it in this function call 114 | if full_save_path and not file_path and os.path.exists(full_save_path): 115 | current_app.logger.warning(f"Rolling back DB commit, deleting created file: {full_save_path}") 116 | os.remove(full_save_path) 117 | raise Exception(f"Failed to save media record {media_id} to database: {str(e)}") 118 | 119 | @staticmethod 120 | def get_image_path(image_id: str) -> Optional[str]: 121 | """Get the full path to an image""" 122 | db = get_db() 123 | result = db.execute( 124 | "SELECT file_path FROM images WHERE id = ?", 125 | (image_id,) 126 | ).fetchone() 127 | 128 | if result is None: 129 | return None 130 | 131 | return os.path.join(current_app.config["IMAGES_PATH"], result["file_path"]) 132 | 133 | @staticmethod 134 | def get_image_metadata(image_id: str) -> Optional[Dict]: 135 | """Get metadata for an image""" 136 | db = get_db() 137 | result = db.execute( 138 | "SELECT metadata FROM images WHERE id = ?", 139 | (image_id,) 140 | ).fetchone() 141 | 142 | if result is None: 143 | return None 144 | 145 | return json.loads(result["metadata"]) 146 | 147 | @staticmethod 148 | def get_image_info(image_id: str) -> Optional[Dict]: 149 | """Get complete information for a single image by ID""" 150 | db = get_db() 151 | img = db.execute( 152 | """ 153 | SELECT id, file_path, job_id, created_at, metadata 154 | FROM images 155 | WHERE id = ? 156 | """, 157 | (image_id,) 158 | ).fetchone() 159 | 160 | if img is None: 161 | return None 162 | 163 | return { 164 | "id": img["id"], 165 | "file_path": img["file_path"], 166 | "job_id": img["job_id"], 167 | "created_at": ImageManager._convert_to_local(img["created_at"]), 168 | "metadata": json.loads(img["metadata"]) 169 | } 170 | 171 | @staticmethod 172 | def get_job_images(job_id: str) -> list: 173 | """Get all images associated with a job""" 174 | db = get_db() 175 | images = db.execute( 176 | """ 177 | SELECT id, file_path, created_at, metadata 178 | FROM images 179 | WHERE job_id = ? 180 | ORDER BY created_at DESC 181 | """, 182 | (job_id,) 183 | ).fetchall() 184 | 185 | return [{ 186 | "id": img["id"], 187 | "file_path": img["file_path"], 188 | "created_at": ImageManager._convert_to_local(img["created_at"]), 189 | "metadata": json.loads(img["metadata"]) 190 | } for img in images] 191 | 192 | @staticmethod 193 | def get_recent_images(limit: int = 12) -> List[Dict]: 194 | """Get recent images with metadata""" 195 | db = get_db() 196 | images = db.execute( 197 | """ 198 | SELECT id, file_path, created_at, metadata 199 | FROM images 200 | ORDER BY created_at DESC 201 | LIMIT ? 
202 | """, 203 | (limit,) 204 | ).fetchall() 205 | 206 | return [{ 207 | "id": img["id"], 208 | "created_at": img["created_at"], # SQLite returns datetime object directly 209 | "metadata": json.loads(img["metadata"]), 210 | "model_id": json.loads(img["metadata"])["model_id"], 211 | "prompt": json.loads(img["metadata"])["prompt"] 212 | } for img in images] 213 | 214 | @staticmethod 215 | def get_all_images(page: int = 1, per_page: int = 24, model_id: Optional[str] = None) -> Dict: 216 | """Get all images with optional filtering and pagination""" 217 | db = get_db() 218 | offset = (page - 1) * per_page 219 | 220 | # Build query based on filters 221 | query = """ 222 | SELECT id, file_path, created_at, metadata 223 | FROM images 224 | """ 225 | params = [] 226 | 227 | if model_id: 228 | query += " WHERE json_extract(metadata, '$.model_id') = ?" 229 | params.append(model_id) 230 | 231 | # Get total count 232 | count_query = f"SELECT COUNT(*) as total FROM ({query})" 233 | total = db.execute(count_query, params).fetchone()["total"] 234 | 235 | # Add pagination 236 | query += " ORDER BY created_at DESC LIMIT ? OFFSET ?" 237 | params.extend([per_page, offset]) 238 | 239 | # Get images 240 | images = db.execute(query, params).fetchall() 241 | 242 | return { 243 | "images": [{ 244 | "id": img["id"], 245 | "created_at": img["created_at"], # SQLite returns datetime object directly 246 | "metadata": json.loads(img["metadata"]), 247 | "model_id": json.loads(img["metadata"])["model_id"], 248 | "prompt": json.loads(img["metadata"])["prompt"] 249 | } for img in images], 250 | "total": total, 251 | "pages": (total + per_page - 1) // per_page, 252 | "current_page": page 253 | } 254 | 255 | @staticmethod 256 | def delete_image(image_id: str) -> bool: 257 | """Delete an image and its associated database record""" 258 | db = get_db() 259 | 260 | try: 261 | # Get image path 262 | result = db.execute( 263 | "SELECT file_path FROM images WHERE id = ?", 264 | (image_id,) 265 | ).fetchone() 266 | 267 | if result is None: 268 | return False 269 | 270 | # Delete the file 271 | file_path = os.path.join(current_app.config["IMAGES_PATH"], result["file_path"]) 272 | if os.path.exists(file_path): 273 | os.remove(file_path) 274 | 275 | # Delete database record 276 | db.execute("DELETE FROM images WHERE id = ?", (image_id,)) 277 | db.commit() 278 | 279 | return True 280 | 281 | except Exception as e: 282 | db.rollback() 283 | current_app.logger.error(f"Failed to delete image {image_id}: {str(e)}") 284 | return False 285 | 286 | @staticmethod 287 | def get_images_for_job(job_id: str) -> List[Dict]: 288 | """Get all images associated with a specific job""" 289 | db = get_db() 290 | images = db.execute( 291 | """ 292 | SELECT id, file_path, created_at, metadata 293 | FROM images 294 | WHERE job_id = ? 
295 | ORDER BY created_at DESC 296 | """, 297 | (job_id,) 298 | ).fetchall() 299 | 300 | return [{ 301 | "id": img["id"], 302 | "file_path": img["file_path"], 303 | "created_at": ImageManager._convert_to_local(img["created_at"]), 304 | "metadata": json.loads(img["metadata"]), 305 | "model_id": json.loads(img["metadata"])["model_id"], 306 | "prompt": json.loads(img["metadata"])["prompt"] 307 | } for img in images] -------------------------------------------------------------------------------- /app/utils/logging_config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Logging configuration for CyberImage 3 | """ 4 | import os 5 | import logging 6 | import logging.handlers 7 | from datetime import datetime 8 | from pathlib import Path 9 | 10 | def setup_logging(app): 11 | """Configure application-wide logging""" 12 | # Create logs directory 13 | log_dir = Path("logs") 14 | log_dir.mkdir(exist_ok=True) 15 | 16 | # Create formatters 17 | verbose_formatter = logging.Formatter( 18 | '%(asctime)s - %(name)s - %(levelname)s - %(message)s' 19 | ) 20 | simple_formatter = logging.Formatter( 21 | '%(levelname)s - %(message)s' 22 | ) 23 | 24 | # Create handlers 25 | # Console handler 26 | console_handler = logging.StreamHandler() 27 | console_handler.setLevel(logging.INFO) 28 | console_handler.setFormatter(simple_formatter) 29 | 30 | # File handlers 31 | today = datetime.now().strftime("%Y-%m-%d") 32 | 33 | # Application log 34 | app_handler = logging.handlers.RotatingFileHandler( 35 | log_dir / f"app-{today}.log", 36 | maxBytes=10485760, # 10MB 37 | backupCount=10 38 | ) 39 | app_handler.setLevel(logging.DEBUG) 40 | app_handler.setFormatter(verbose_formatter) 41 | 42 | # Error log 43 | error_handler = logging.handlers.RotatingFileHandler( 44 | log_dir / f"error-{today}.log", 45 | maxBytes=10485760, # 10MB 46 | backupCount=10 47 | ) 48 | error_handler.setLevel(logging.ERROR) 49 | error_handler.setFormatter(verbose_formatter) 50 | 51 | # Model log 52 | model_handler = logging.handlers.RotatingFileHandler( 53 | log_dir / f"model-{today}.log", 54 | maxBytes=10485760, # 10MB 55 | backupCount=10 56 | ) 57 | model_handler.setLevel(logging.DEBUG) 58 | model_handler.setFormatter(verbose_formatter) 59 | 60 | # Configure root logger 61 | root_logger = logging.getLogger() 62 | root_logger.setLevel(logging.DEBUG) 63 | root_logger.addHandler(console_handler) 64 | root_logger.addHandler(app_handler) 65 | root_logger.addHandler(error_handler) 66 | 67 | # Configure specific loggers 68 | model_logger = logging.getLogger("app.models") 69 | model_logger.addHandler(model_handler) 70 | 71 | # Disable propagation for model logger to avoid duplicate logs 72 | model_logger.propagate = False 73 | 74 | # Log startup 75 | app.logger.info("Logging system initialized") -------------------------------------------------------------------------------- /app/utils/model_config_test.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test script for model configuration 3 | """ 4 | import os 5 | import sys 6 | from config import parse_model_config, get_downloadable_models, get_available_models 7 | 8 | def main(): 9 | print("\n===== Testing Model Configuration =====\n") 10 | 11 | # Test environment variables 12 | os.environ["MODEL_1"] = "test-model-1;repo/test;Test Model;huggingface;true" 13 | os.environ["MODEL_2"] = "test-model-2;org/repo;Another Model;huggingface;true" 14 | os.environ["DOWNLOAD_MODEL_2"] = "false" 15 | 16 | # Test parsing model 
configuration 17 | print("Parsing model configuration...") 18 | models = parse_model_config() 19 | 20 | print(f"\nFound {len(models)} models:") 21 | for name, config in models.items(): 22 | print(f"- {name}: {config['description']} ({config['repo']})") 23 | print(f" Source: {config['source']}, Download enabled: {config['download_enabled']}") 24 | print(f" Type: {config['type']}") 25 | print(f" Files to check: {len(config.get('files', []))}") 26 | 27 | # Test getting downloadable models 28 | print("\nGetting downloadable models...") 29 | download_models = get_downloadable_models() 30 | print(f"Found {len(download_models)} downloadable models:") 31 | for name in download_models: 32 | print(f"- {name}") 33 | 34 | # Test getting available models for UI 35 | print("\nGetting available models for UI...") 36 | available_models = get_available_models() 37 | print(f"Found {len(available_models)} available models:") 38 | for name, config in available_models.items(): 39 | print(f"- {name} (id: {config['id']})") 40 | 41 | print("\nTest completed successfully!") 42 | 43 | if __name__ == "__main__": 44 | main() -------------------------------------------------------------------------------- /app/utils/schema.sql: -------------------------------------------------------------------------------- 1 | -- Initialize database schema for CyberImage 2 | 3 | -- Drop existing tables if they exist 4 | DROP TABLE IF EXISTS jobs; 5 | DROP TABLE IF EXISTS images; 6 | 7 | -- Create jobs table for queue management 8 | CREATE TABLE jobs ( 9 | id TEXT PRIMARY KEY, 10 | status TEXT NOT NULL CHECK(status IN ('pending', 'processing', 'completed', 'failed')), 11 | model_id TEXT NOT NULL, 12 | prompt TEXT NOT NULL, 13 | negative_prompt TEXT, 14 | settings TEXT NOT NULL, -- JSON string of generation settings 15 | created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, 16 | started_at TIMESTAMP, 17 | completed_at TIMESTAMP, 18 | error_message TEXT 19 | ); 20 | 21 | -- Create images table for generated images 22 | CREATE TABLE images ( 23 | id TEXT PRIMARY KEY, 24 | job_id TEXT NOT NULL, 25 | file_path TEXT NOT NULL, 26 | created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, 27 | metadata TEXT NOT NULL, -- JSON string of image metadata 28 | FOREIGN KEY (job_id) REFERENCES jobs (id) 29 | ); 30 | 31 | -- Create indexes for common queries 32 | CREATE INDEX idx_jobs_status ON jobs(status); 33 | CREATE INDEX idx_jobs_created_at ON jobs(created_at); 34 | CREATE INDEX idx_images_job_id ON images(job_id); -------------------------------------------------------------------------------- /app/utils/watchdog.py: -------------------------------------------------------------------------------- 1 | """ 2 | GPU watchdog for monitoring job execution and recovery 3 | """ 4 | import threading 5 | import time 6 | import torch 7 | import gc 8 | import sys 9 | import logging 10 | from flask import current_app 11 | from app.utils.queue import QueueManager 12 | from datetime import datetime 13 | 14 | # Configure logging 15 | logger = logging.getLogger(__name__) 16 | 17 | class GPUWatchdog: 18 | """ 19 | Monitors job execution and provides recovery mechanisms for stalled jobs. 20 | Note: High GPU memory usage (up to 100%) is normal and expected for Flux models. 
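
    Illustrative usage (a sketch only; assumes a ModelManager instance and the Flask
    app object are available, matching the constructor below):

        watchdog = GPUWatchdog(model_manager, app=app, max_stalled_time=400)
        watchdog.start()
        # ... application runs, jobs are monitored ...
        watchdog.stop()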
21 | """ 22 | 23 | def __init__(self, model_manager, app=None, max_stalled_time=400, check_interval=30, 24 | recovery_cooldown=300): 25 | """ 26 | Initialize the GPU watchdog 27 | 28 | Args: 29 | model_manager: The ModelManager instance to monitor 30 | app: Flask application instance for context 31 | max_stalled_time: Maximum time (in seconds) a job can be "processing" before recovery 32 | check_interval: Time (in seconds) between health checks 33 | recovery_cooldown: Minimum time (in seconds) between emergency recoveries 34 | """ 35 | self.model_manager = model_manager 36 | self.app = app 37 | self.max_stalled_time = max_stalled_time 38 | self.check_interval = check_interval 39 | self.recovery_cooldown = recovery_cooldown 40 | 41 | self.running = False 42 | self.thread = None 43 | self.last_recovery_time = 0 44 | self.recovery_in_progress = False 45 | 46 | logger.info(f"GPU Watchdog initialized (max_stalled_time={max_stalled_time}s, check_interval={check_interval}s)") 47 | 48 | def _get_app(self): 49 | """Get the Flask application instance or the stored one""" 50 | return self.app or current_app._get_current_object() 51 | 52 | def start(self): 53 | """Start the watchdog monitoring thread""" 54 | if self.running: 55 | return 56 | 57 | self.running = True 58 | self.thread = threading.Thread(target=self._monitor_loop, daemon=True) 59 | self.thread.start() 60 | logger.info("GPU Watchdog started") 61 | 62 | def stop(self): 63 | """Stop the watchdog monitoring thread""" 64 | self.running = False 65 | if self.thread and self.thread.is_alive(): 66 | self.thread.join(timeout=5) 67 | logger.info("GPU Watchdog stopped") 68 | 69 | def _monitor_loop(self): 70 | """Main monitoring loop""" 71 | while self.running: 72 | try: 73 | # Log GPU memory status (for information only) 74 | self._log_gpu_memory() 75 | 76 | # Check for stalled jobs - this is our primary focus 77 | with self._get_app().app_context(): 78 | self._check_stalled_jobs() 79 | 80 | # Sleep until next check 81 | time.sleep(self.check_interval) 82 | except Exception as e: 83 | logger.error(f"Watchdog error: {str(e)}") 84 | # Sleep a bit longer on error 85 | time.sleep(self.check_interval * 2) 86 | 87 | def _log_gpu_memory(self): 88 | """Log GPU memory usage for monitoring purposes (no recovery action)""" 89 | if not torch.cuda.is_available(): 90 | return # Skip if not using CUDA 91 | 92 | try: 93 | # Get current memory usage 94 | memory_allocated = torch.cuda.memory_allocated() / (1024**3) # Convert to GB 95 | memory_reserved = torch.cuda.memory_reserved() / (1024**3) # Convert to GB 96 | total_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3) 97 | 98 | # Calculate usage percentage 99 | usage_percentage = memory_allocated / total_memory 100 | 101 | # Only log if usage is above 90% or below 10% to reduce chattiness 102 | if usage_percentage > 0.9 or usage_percentage < 0.1: 103 | logger.info(f"GPU Memory: {memory_allocated:.2f}GB/{total_memory:.2f}GB ({usage_percentage:.2%})") 104 | 105 | # Don't print to stdout for regular checks 106 | except Exception as e: 107 | logger.error(f"Error checking GPU memory: {str(e)}") 108 | # Don't print to stdout 109 | 110 | def _check_stalled_jobs(self): 111 | """Check for jobs that have been processing for too long""" 112 | try: 113 | processing_jobs = QueueManager.get_processing_jobs() 114 | current_time = time.time() 115 | 116 | for job in processing_jobs: 117 | # Skip if no started_at time 118 | if not job.get("started_at"): 119 | continue 120 | 121 | # Parse the timestamp from the 
database 122 | start_time = job.get("started_at") 123 | 124 | # Handle different types of start_time values with robust parsing 125 | try: 126 | # Case 1: Already a timestamp (float or int) 127 | if isinstance(start_time, (int, float)): 128 | pass # Already in the correct format 129 | 130 | # Case 2: Already a datetime object 131 | elif isinstance(start_time, datetime): 132 | start_time = start_time.timestamp() 133 | 134 | # Case 3: String format - try multiple parsing approaches 135 | elif isinstance(start_time, str): 136 | # Try parsing as ISO format first 137 | try: 138 | # Handle various ISO formats with/without timezone 139 | if 'T' in start_time: 140 | # Standard ISO format 141 | dt = datetime.fromisoformat(start_time.replace('Z', '+00:00')) 142 | else: 143 | # SQLite date format 144 | dt = datetime.strptime(start_time, '%Y-%m-%d %H:%M:%S') 145 | start_time = dt.timestamp() 146 | except ValueError: 147 | # Try other common formats 148 | for fmt in ['%Y-%m-%d %H:%M:%S.%f', '%Y-%m-%d %H:%M:%S', '%Y-%m-%dT%H:%M:%S.%fZ']: 149 | try: 150 | dt = datetime.strptime(start_time, fmt) 151 | start_time = dt.timestamp() 152 | break 153 | except ValueError: 154 | continue 155 | else: 156 | # If we get here, none of the formats worked 157 | raise ValueError(f"Could not parse timestamp: {start_time}") 158 | 159 | # Case 4: Unhandled type 160 | else: 161 | logger.warning(f"Unhandled timestamp type: {type(start_time)}") 162 | continue 163 | 164 | # Calculate how long the job has been processing 165 | processing_time = current_time - start_time 166 | 167 | # If it's been processing too long, attempt recovery 168 | if processing_time > self.max_stalled_time: 169 | logger.warning(f"Stalled job detected: {job['id']} (processing for {processing_time:.1f}s)") 170 | self._recover_stalled_job(job) 171 | 172 | except Exception as e: 173 | logger.warning(f"Error processing job timestamp for job {job['id']}: {str(e)}") 174 | continue 175 | 176 | except Exception as e: 177 | logger.error(f"Error checking stalled jobs: {str(e)}") 178 | 179 | def _recover_stalled_job(self, job): 180 | """Mark a stalled job as failed and resubmit it for retry""" 181 | try: 182 | job_id = job["id"] 183 | logger.info(f"Recovering stalled job: {job_id}") 184 | 185 | # Mark job as failed 186 | QueueManager.update_job_status( 187 | job_id, 188 | "failed", 189 | "Job failed due to timeout - will be retried automatically" 190 | ) 191 | 192 | # Force model cleanup 193 | if self.model_manager: 194 | self._perform_emergency_recovery() 195 | 196 | # Now retry the job 197 | retried_job_id = QueueManager.retry_failed_job(job_id) 198 | if retried_job_id: 199 | logger.info(f"Job {job_id} reset to pending state and will be retried") 200 | else: 201 | logger.warning(f"Could not retry job {job_id} - may have exceeded retry limit") 202 | 203 | except Exception as e: 204 | logger.error(f"Error recovering stalled job: {str(e)}") 205 | 206 | def _perform_emergency_recovery(self): 207 | """Handle critical situations by cleaning up resources""" 208 | if self.recovery_in_progress: 209 | logger.info("Emergency recovery already in progress, skipping") 210 | return 211 | 212 | # Use time.time() directly to avoid any datetime conversion issues 213 | current_time = time.time() 214 | 215 | # Ensure last_recovery_time is a valid timestamp 216 | if not isinstance(self.last_recovery_time, (int, float)): 217 | logger.warning(f"Invalid last_recovery_time ({type(self.last_recovery_time)}), resetting to 0") 218 | self.last_recovery_time = 0 219 | 220 | if current_time - 
self.last_recovery_time < self.recovery_cooldown: 221 | # Don't log cooldown periods to reduce noise 222 | return 223 | 224 | self.recovery_in_progress = True 225 | self.last_recovery_time = current_time 226 | 227 | try: 228 | logger.warning("EMERGENCY RECOVERY INITIATED") 229 | 230 | # First, mark all currently processing jobs as failed 231 | with self._get_app().app_context(): 232 | self._recover_all_processing_jobs() 233 | 234 | # Force model cleanup 235 | if self.model_manager: 236 | self.model_manager._force_memory_cleanup() 237 | 238 | # Additional cleanup steps 239 | if torch.cuda.is_available(): 240 | torch.cuda.empty_cache() 241 | gc.collect() 242 | torch.cuda.synchronize() 243 | 244 | # Check memory after cleanup 245 | memory_allocated = torch.cuda.memory_allocated() / (1024**3) 246 | total_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3) 247 | usage_percentage = memory_allocated / total_memory 248 | 249 | logger.info(f"Emergency recovery completed. GPU memory: {memory_allocated:.2f}GB ({usage_percentage:.2%})") 250 | except Exception as e: 251 | logger.error(f"Error during emergency recovery: {str(e)}") 252 | finally: 253 | self.recovery_in_progress = False 254 | 255 | def _recover_all_processing_jobs(self): 256 | """Mark all currently processing jobs as failed during emergency recovery and retry them""" 257 | try: 258 | processing_jobs = QueueManager.get_processing_jobs() 259 | if not processing_jobs: 260 | logger.info("No processing jobs to recover") 261 | return 262 | 263 | logger.info(f"Recovering {len(processing_jobs)} processing jobs") 264 | for job in processing_jobs: 265 | job_id = job["id"] 266 | QueueManager.update_job_status( 267 | job_id, 268 | "failed", 269 | "Job interrupted by emergency recovery" 270 | ) 271 | 272 | # Try to retry the job 273 | retried_job_id = QueueManager.retry_failed_job(job_id) 274 | if not retried_job_id: 275 | logger.warning(f"Could not retry job {job_id} - may have exceeded retry limit") 276 | 277 | except Exception as e: 278 | logger.error(f"Error recovering processing jobs: {str(e)}") -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | services: 2 | cyberimage: 3 | build: . 
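    # Illustrative sketch (not enabled here): run.sh passes --gpus all to docker run;
    # to request the GPU when starting with `docker compose up` directly, a deploy
    # block along these lines could be added to this service (assumes the NVIDIA
    # container toolkit is installed):
    #   deploy:
    #     resources:
    #       reservations:
    #         devices:
    #           - driver: nvidia
    #             count: all
    #             capabilities: [gpu]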
4 | container_name: cyberimage # Force specific container name 5 | user: 1000:1000 # Run as UID 1000 6 | ports: 7 | - "7860:5050" 8 | volumes: 9 | - ${EXTERNAL_IMAGES_FOLDER}:/app/images 10 | - ${EXTERNAL_MODEL_FOLDER}:/app/models 11 | env_file: 12 | - .env 13 | environment: 14 | - IMAGES_FOLDER=/app/images 15 | - MODEL_FOLDER=/app/models 16 | restart: unless-stopped -------------------------------------------------------------------------------- /examples/ai_assistant_mcp_example.py: -------------------------------------------------------------------------------- 1 | """ 2 | Example of how an AI Assistant can use the CyberImage MCP endpoint for image generation 3 | This demonstrates the Model Context Protocol implementation specifically for AI assistants 4 | """ 5 | import json 6 | import requests 7 | import time 8 | import uuid 9 | import os 10 | import base64 11 | 12 | # API configuration 13 | API_URL = "http://localhost:5050/api/mcp" # Adjust for your deployment 14 | 15 | def generate_image_for_assistant(prompt, model=None, negative_prompt=""): 16 | """ 17 | Complete function for AI assistants to generate an image and get a URL 18 | This function handles the entire process from generation request to downloading 19 | 20 | Args: 21 | prompt (str): The image description 22 | model (str, optional): Model ID to use (defaults to flux-2) 23 | negative_prompt (str, optional): What to exclude from the image 24 | 25 | Returns: 26 | dict: A dictionary with status and either the image URL or error message 27 | """ 28 | # Step 1: Check available models 29 | models_info = list_models() 30 | if not models_info: 31 | return {"status": "error", "message": "Failed to retrieve model information"} 32 | 33 | # Use provided model or default 34 | model_id = model if model and model in models_info["models"] else models_info["default"] 35 | 36 | # Step 2: Generate image 37 | job_id = submit_generation_job(prompt, model_id, negative_prompt) 38 | if not job_id: 39 | return {"status": "error", "message": "Failed to submit generation job"} 40 | 41 | # Step 3: Wait for completion 42 | status = wait_for_job_completion(job_id) 43 | if not status: 44 | return {"status": "error", "message": "Job timed out or failed"} 45 | 46 | if status["status"] == "failed": 47 | return {"status": "error", "message": f"Generation failed: {status.get('message', 'Unknown error')}"} 48 | 49 | # Step 4: Return image information 50 | if status["status"] == "completed" and "images" in status and len(status["images"]) > 0: 51 | image_url = status["images"][0]["url"] 52 | # Convert relative URL to absolute URL 53 | if image_url.startswith("/"): 54 | image_url = f"http://localhost:5050{image_url}" 55 | 56 | return { 57 | "status": "success", 58 | "image_url": image_url, 59 | "prompt": prompt, 60 | "model": model_id 61 | } 62 | 63 | return {"status": "error", "message": "Unknown error occurred"} 64 | 65 | def mcp_request(method, params=None): 66 | """ 67 | Make an MCP request to the CyberImage API 68 | 69 | Args: 70 | method (str): The MCP method to call 71 | params (dict): Parameters for the method 72 | 73 | Returns: 74 | dict: The API response result or None on error 75 | """ 76 | request_id = str(uuid.uuid4()) 77 | payload = { 78 | "jsonrpc": "2.0", 79 | "method": method, 80 | "params": params or {}, 81 | "id": request_id 82 | } 83 | 84 | try: 85 | response = requests.post(API_URL, json=payload, timeout=30) 86 | response.raise_for_status() 87 | 88 | result = response.json() 89 | if "error" in result: 90 | print(f"API error: 
{result['error']['message']}") 91 | return None 92 | 93 | return result.get("result") 94 | except Exception as e: 95 | print(f"Request error: {str(e)}") 96 | return None 97 | 98 | def list_models(): 99 | """Get available models""" 100 | return mcp_request("context.image_generation.models") 101 | 102 | def submit_generation_job(prompt, model, negative_prompt=""): 103 | """Submit a generation job and return the job ID""" 104 | params = { 105 | "prompt": prompt, 106 | "model": model, 107 | "negative_prompt": negative_prompt, 108 | "settings": { 109 | "num_images": 1, 110 | "num_inference_steps": 30, 111 | "guidance_scale": 7.5, 112 | "height": 1024, 113 | "width": 1024 114 | } 115 | } 116 | 117 | result = mcp_request("context.image_generation.generate", params) 118 | return result["job_id"] if result and "job_id" in result else None 119 | 120 | def check_job_status(job_id): 121 | """Check status of a generation job""" 122 | return mcp_request("context.image_generation.status", {"job_id": job_id}) 123 | 124 | def wait_for_job_completion(job_id, max_wait_seconds=300, poll_interval=5): 125 | """Wait for a job to complete with timeout""" 126 | start_time = time.time() 127 | 128 | while time.time() - start_time < max_wait_seconds: 129 | status = check_job_status(job_id) 130 | 131 | if not status: 132 | return None 133 | 134 | if status["status"] in ["completed", "failed"]: 135 | return status 136 | 137 | time.sleep(poll_interval) 138 | 139 | return None 140 | 141 | # Example usage 142 | if __name__ == "__main__": 143 | # Example of how an AI assistant would call this function 144 | result = generate_image_for_assistant( 145 | prompt="A beautiful mountain landscape with a lake and forest at sunset", 146 | negative_prompt="blurry, distorted, low quality" 147 | ) 148 | 149 | if result["status"] == "success": 150 | print(f"Image generated successfully!") 151 | print(f"Image URL: {result['image_url']}") 152 | print(f"Prompt: {result['prompt']}") 153 | print(f"Model: {result['model']}") 154 | else: 155 | print(f"Error: {result['message']}") -------------------------------------------------------------------------------- /examples/mcp_client_example.py: -------------------------------------------------------------------------------- 1 | """ 2 | Example of how to use the CyberImage MCP endpoint 3 | This demonstrates the Model Context Protocol implementation for AI image generation 4 | """ 5 | import json 6 | import requests 7 | import time 8 | import uuid 9 | import os 10 | from PIL import Image 11 | import io 12 | import base64 13 | 14 | # API configuration 15 | API_URL = "http://localhost:5050/api/mcp" # Adjust for your deployment 16 | 17 | def mcp_request(method, params=None): 18 | """ 19 | Make an MCP request to the CyberImage API 20 | 21 | Args: 22 | method (str): The MCP method to call 23 | params (dict): Parameters for the method 24 | 25 | Returns: 26 | dict: The API response 27 | """ 28 | request_id = str(uuid.uuid4()) 29 | payload = { 30 | "jsonrpc": "2.0", 31 | "method": method, 32 | "params": params or {}, 33 | "id": request_id 34 | } 35 | 36 | response = requests.post(API_URL, json=payload) 37 | 38 | if response.status_code != 200: 39 | print(f"Error: {response.status_code}") 40 | print(response.text) 41 | return None 42 | 43 | return response.json() 44 | 45 | def list_models(): 46 | """Get available models""" 47 | response = mcp_request("context.image_generation.models") 48 | 49 | if not response or "error" in response: 50 | print("Error listing models:", response.get("error", 
{}).get("message", "Unknown error")) 51 | return 52 | 53 | result = response["result"] 54 | print("\nAvailable Models:") 55 | print("================") 56 | 57 | for name, info in result["models"].items(): 58 | print(f"- {name}: {info['description']}") 59 | 60 | print(f"\nDefault model: {result['default']}") 61 | return result 62 | 63 | def generate_image(prompt, model=None, negative_prompt="", settings=None): 64 | """ 65 | Generate an image using the CyberImage MCP API 66 | 67 | Args: 68 | prompt (str): The image description 69 | model (str, optional): Model to use 70 | negative_prompt (str, optional): What to exclude from the image 71 | settings (dict, optional): Additional generation settings 72 | 73 | Returns: 74 | str: Job ID for the generation request 75 | """ 76 | params = { 77 | "prompt": prompt, 78 | "negative_prompt": negative_prompt 79 | } 80 | 81 | if model: 82 | params["model"] = model 83 | 84 | if settings: 85 | params["settings"] = settings 86 | 87 | print(f"Generating image with prompt: '{prompt}'") 88 | print(f"Using model: {model or 'default'}") 89 | 90 | response = mcp_request("context.image_generation.generate", params) 91 | 92 | if not response or "error" in response: 93 | print("Error generating image:", response.get("error", {}).get("message", "Unknown error")) 94 | return None 95 | 96 | result = response["result"] 97 | job_id = result["job_id"] 98 | print(f"Job submitted with ID: {job_id}") 99 | print(f"Status: {result['status']}") 100 | 101 | return job_id 102 | 103 | def check_status(job_id): 104 | """ 105 | Check the status of a generation job 106 | 107 | Args: 108 | job_id (str): The job ID to check 109 | 110 | Returns: 111 | dict: The job status 112 | """ 113 | response = mcp_request("context.image_generation.status", {"job_id": job_id}) 114 | 115 | if not response or "error" in response: 116 | print("Error checking status:", response.get("error", {}).get("message", "Unknown error")) 117 | return None 118 | 119 | return response["result"] 120 | 121 | def download_image(image_url, save_path): 122 | """ 123 | Download an image from the given URL 124 | 125 | Args: 126 | image_url (str): The image URL 127 | save_path (str): Where to save the image 128 | 129 | Returns: 130 | bool: True if successful, False otherwise 131 | """ 132 | # Get full URL 133 | if image_url.startswith("/"): 134 | image_url = f"http://localhost:5050{image_url}" 135 | 136 | response = requests.get(image_url) 137 | 138 | if response.status_code != 200: 139 | print(f"Error downloading image: {response.status_code}") 140 | return False 141 | 142 | with open(save_path, "wb") as f: 143 | f.write(response.content) 144 | 145 | print(f"Image saved to {save_path}") 146 | return True 147 | 148 | def wait_for_completion(job_id, poll_interval=5, max_attempts=60): 149 | """ 150 | Wait for a job to complete 151 | 152 | Args: 153 | job_id (str): The job ID to check 154 | poll_interval (int): How often to check the status in seconds 155 | max_attempts (int): Maximum number of polling attempts 156 | 157 | Returns: 158 | dict: Final job status or None if timed out 159 | """ 160 | attempts = 0 161 | 162 | while attempts < max_attempts: 163 | status = check_status(job_id) 164 | 165 | if not status: 166 | return None 167 | 168 | print(f"Status: {status['status']}") 169 | 170 | if status["status"] == "processing" and "progress" in status: 171 | progress = status["progress"] 172 | if progress.get("generating") and progress.get("step") is not None: 173 | print(f"Generating: Step 
{progress['step']}/{progress['total_steps']}") 174 | 175 | if status["status"] in ["completed", "failed"]: 176 | return status 177 | 178 | time.sleep(poll_interval) 179 | attempts += 1 180 | 181 | print(f"Timed out after {max_attempts * poll_interval} seconds") 182 | return None 183 | 184 | def main(): 185 | """Main function demonstrating the MCP client""" 186 | # List available models 187 | models_info = list_models() 188 | 189 | if not models_info: 190 | return 191 | 192 | # Use the default model 193 | default_model = models_info["default"] 194 | 195 | # Generate an image 196 | prompt = "A futuristic cyberpunk city at night with neon lights and flying cars" 197 | negative_prompt = "blurry, low quality, distorted" 198 | 199 | settings = { 200 | "num_images": 1, 201 | "num_inference_steps": 30, 202 | "guidance_scale": 7.5, 203 | "height": 1024, 204 | "width": 1024 205 | } 206 | 207 | job_id = generate_image( 208 | prompt=prompt, 209 | model=default_model, 210 | negative_prompt=negative_prompt, 211 | settings=settings 212 | ) 213 | 214 | if not job_id: 215 | return 216 | 217 | # Wait for the job to complete 218 | final_status = wait_for_completion(job_id) 219 | 220 | if not final_status: 221 | return 222 | 223 | if final_status["status"] == "completed" and "images" in final_status: 224 | # Download each generated image 225 | for i, image_info in enumerate(final_status["images"]): 226 | image_url = image_info["url"] 227 | save_path = f"generated_image_{i+1}.png" 228 | download_image(image_url, save_path) 229 | elif final_status["status"] == "failed": 230 | print(f"Generation failed: {final_status.get('message', 'Unknown error')}") 231 | 232 | if __name__ == "__main__": 233 | main() -------------------------------------------------------------------------------- /instance/config.py: -------------------------------------------------------------------------------- 1 | # Instance specific configuration 2 | 3 | # You can add other instance-specific overrides here, e.g.: 4 | # SECRET_KEY = 'your-secret-key' 5 | # DATABASE = '/path/to/your/production/database.sqlite' -------------------------------------------------------------------------------- /media/basic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RamboRogers/cyberimage/ac21429a46c03b8292a20fe57b13a10db6942bf2/media/basic.png -------------------------------------------------------------------------------- /media/enrich.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RamboRogers/cyberimage/ac21429a46c03b8292a20fe57b13a10db6942bf2/media/enrich.png -------------------------------------------------------------------------------- /media/gallery.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RamboRogers/cyberimage/ac21429a46c03b8292a20fe57b13a10db6942bf2/media/gallery.png -------------------------------------------------------------------------------- /media/generate.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RamboRogers/cyberimage/ac21429a46c03b8292a20fe57b13a10db6942bf2/media/generate.png -------------------------------------------------------------------------------- /media/queue.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/RamboRogers/cyberimage/ac21429a46c03b8292a20fe57b13a10db6942bf2/media/queue.png -------------------------------------------------------------------------------- /media/single.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RamboRogers/cyberimage/ac21429a46c03b8292a20fe57b13a10db6942bf2/media/single.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==1.4.0 2 | annotated-types==0.7.0 3 | bitsandbytes==0.45.2 4 | black==24.1.1 5 | blinker==1.9.0 6 | certifi==2025.1.31 7 | charset-normalizer==3.4.1 8 | click==8.1.8 9 | click-default-group==1.2.4 10 | coverage==7.6.12 11 | diffusers 12 | torchvision 13 | ftfy 14 | gguf 15 | imageio-ffmpeg 16 | imageio 17 | filelock==3.17.0 18 | Flask==3.1.0 19 | Flask-Cors==5.0.0 20 | fsspec==2025.2.0 21 | fal-client 22 | gunicorn==23.0.0 23 | hf_transfer==0.1.9 24 | huggingface-hub==0.29.1 25 | idna==3.10 26 | importlib_metadata==8.6.1 27 | iniconfig==2.0.0 28 | inquirerpy==0.3.4 29 | itsdangerous==2.2.0 30 | Jinja2==3.1.5 31 | MarkupSafe==3.0.2 32 | mpmath==1.3.0 33 | mypy-extensions==1.0.0 34 | networkx==3.4.2 35 | numpy==2.2.3 36 | nvidia-cublas-cu12==12.4.5.8 37 | nvidia-cuda-cupti-cu12==12.4.127 38 | nvidia-cuda-nvrtc-cu12==12.4.127 39 | nvidia-cuda-runtime-cu12==12.4.127 40 | nvidia-cudnn-cu12==9.1.0.70 41 | nvidia-cufft-cu12==11.2.1.3 42 | nvidia-curand-cu12==10.3.5.147 43 | nvidia-cusolver-cu12==11.6.1.9 44 | nvidia-cusparse-cu12==12.3.1.170 45 | nvidia-cusparselt-cu12==0.6.2 46 | nvidia-nccl-cu12==2.21.5 47 | nvidia-nvjitlink-cu12==12.4.127 48 | nvidia-nvtx-cu12==12.4.127 49 | packaging==24.2 50 | pathspec==0.12.1 51 | pfzy==0.3.4 52 | pillow==10.2.0 53 | platformdirs==4.3.6 54 | pluggy==1.5.0 55 | prompt_toolkit==3.0.50 56 | protobuf==4.25.3 57 | psutil==7.0.0 58 | pydantic==2.6.1 59 | pydantic_core==2.16.2 60 | pytest==8.0.0 61 | pytest-cov==4.1.0 62 | python-dateutil==2.9.0.post0 63 | python-dotenv==1.0.0 64 | PyYAML==6.0.2 65 | regex==2024.11.6 66 | requests>=2.31.0 67 | safetensors==0.5.2 68 | sentencepiece==0.2.0 69 | setuptools==75.8.0 70 | six==1.17.0 71 | sqlite-fts4==1.0.3 72 | sqlite-utils==3.35.1 73 | sympy==1.13.1 74 | tabulate==0.9.0 75 | tokenizers==0.21.0 76 | torch==2.6.0 77 | tqdm==4.66.1 78 | transformers==4.49.0 79 | triton==3.2.0 80 | typing_extensions==4.12.2 81 | urllib3==2.3.0 82 | uuid==1.30 83 | wcwidth==0.2.13 84 | Werkzeug==3.1.3 85 | zipp==3.21.0 86 | pytz==2024.1 87 | openai>=1.0.0 88 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | """ 2 | CyberImage Application Entry Point 3 | 4 | Sets up proper system configuration for PyTorch memory management 5 | and ensures clean shutdown of GPU resources. 
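
Illustrative production invocation (a single worker is required; the same example
is printed in the __main__ block at the bottom of this file):

    gunicorn -w 1 -b 0.0.0.0:5050 run:app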
6 | """ 7 | import os 8 | import sys 9 | import signal 10 | import atexit 11 | import gc 12 | import threading 13 | import time 14 | from app import create_app 15 | 16 | # --- PyTorch Memory Management Configuration --- 17 | # Allow memory to be allocated more efficiently (prevents fragmentation) 18 | os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True,max_split_size_mb:512' 19 | # Restrict to only use one GPU (device 0) if multiple are available 20 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 21 | # Add optimization for memory allocation/deallocation 22 | os.environ['CUDA_LAUNCH_BLOCKING'] = '1' 23 | # Enable tensor debugging when needed 24 | # os.environ['PYTORCH_DEBUG'] = '1' 25 | 26 | # --- Create the Flask application --- 27 | app = create_app() 28 | 29 | # Ensure the generator is properly initialized 30 | with app.app_context(): 31 | from app.models.generator import GenerationPipeline 32 | 33 | # Store reference to the generator for cleanup 34 | if not hasattr(app, '_generator'): 35 | print("\n🔄 Initializing GenerationPipeline during application startup") 36 | app._generator = GenerationPipeline() 37 | print("✅ GenerationPipeline initialized") 38 | sys.stdout.flush() 39 | 40 | # --- Register cleanup handlers --- 41 | def cleanup_handler(signum=None, frame=None): 42 | """ 43 | Comprehensive cleanup handler for application shutdown 44 | 45 | This ensures all GPU resources are properly released on exit 46 | """ 47 | print("\n🧹 Application shutdown - performing comprehensive cleanup...") 48 | sys.stdout.flush() 49 | 50 | try: 51 | # Set a flag to prevent new requests 52 | app.shutting_down = True 53 | 54 | # Access model manager if available in app context 55 | with app.app_context(): 56 | # Try to access the generator through the app 57 | if hasattr(app, '_generator'): 58 | print("✅ Stopping generator pipeline...") 59 | app._generator.stop() 60 | # Give it some time to complete cleanup 61 | time.sleep(1) 62 | sys.stdout.flush() 63 | 64 | # Force CUDA cleanup if available 65 | try: 66 | import torch 67 | if torch.cuda.is_available(): 68 | print("✅ Forcing CUDA memory cleanup...") 69 | sys.stdout.flush() 70 | torch.cuda.empty_cache() 71 | gc.collect() 72 | torch.cuda.synchronize() 73 | 74 | # Second pass cleanup 75 | torch.cuda.empty_cache() 76 | gc.collect() 77 | 78 | # Report final memory state 79 | memory_allocated = torch.cuda.memory_allocated() / (1024**3) 80 | print(f"✅ Final GPU memory: {memory_allocated:.2f}GB") 81 | sys.stdout.flush() 82 | except Exception as cuda_error: 83 | print(f"⚠️ CUDA cleanup warning: {str(cuda_error)}") 84 | sys.stdout.flush() 85 | 86 | except Exception as e: 87 | print(f"❌ Error during shutdown cleanup: {str(e)}") 88 | sys.stdout.flush() 89 | 90 | print("👋 Application shutdown complete!") 91 | sys.stdout.flush() 92 | 93 | # Exit if called as signal handler 94 | if signum is not None: 95 | sys.exit(0) 96 | 97 | # Register the cleanup handler for normal exit 98 | atexit.register(cleanup_handler) 99 | 100 | # Register signal handlers for graceful shutdown 101 | signal.signal(signal.SIGINT, cleanup_handler) 102 | signal.signal(signal.SIGTERM, cleanup_handler) 103 | 104 | # If running on Linux, also handle SIGQUIT 105 | if hasattr(signal, 'SIGQUIT'): 106 | signal.signal(signal.SIGQUIT, cleanup_handler) 107 | 108 | if __name__ == "__main__": 109 | print("\n🚀 Starting CyberImage with optimized GPU memory management") 110 | print("⚠️ PRODUCTION NOTE: Using a single worker is REQUIRED for this application") 111 | print(" Example: gunicorn -w 1 -b 
0.0.0.0:5050 run:app") 112 | print("⚠️ Multiple workers will cause memory/generation conflicts!") 113 | sys.stdout.flush() 114 | 115 | # For development, use threaded=True with a single process 116 | # The threaded Flask server is fine since we have a dedicated generation thread 117 | app.run(debug=True, host="0.0.0.0", port=5050, use_reloader=False, threaded=True) -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #if .env exists, source it 4 | if [ -f .env ]; then 5 | source .env 6 | fi 7 | #if MODEL_FOLDER or IMAGES_FOLDER is not set, exit 8 | if [ -z "$EXTERNAL_MODEL_FOLDER" ] || [ -z "$EXTERNAL_IMAGES_FOLDER" ]; then 9 | echo "❌ Error: EXTERNAL_MODEL_FOLDER or EXTERNAL_IMAGES_FOLDER is not set!" 10 | exit 1 11 | fi 12 | 13 | # Function to check if a container exists 14 | container_exists() { 15 | docker ps -a --format '{{.Names}}' | grep -q "^cyberimage$" 16 | } 17 | 18 | # Function to check if a container is running 19 | container_running() { 20 | docker ps --format '{{.Names}}' | grep -q "^cyberimage$" 21 | } 22 | 23 | # Function to check if service is responding 24 | check_service() { 25 | local max_attempts=30 26 | local attempt=1 27 | local wait_time=2 28 | 29 | echo "Checking service health..." 30 | while [ $attempt -le $max_attempts ]; do 31 | if curl -s "http://localhost:7860/health" | grep -q "healthy"; then 32 | echo "✅ Service is healthy!" 33 | return 0 34 | fi 35 | echo "⏳ Waiting for service to become healthy (attempt $attempt/$max_attempts)..." 36 | sleep $wait_time 37 | attempt=$((attempt + 1)) 38 | done 39 | echo "❌ Service failed to become healthy after $max_attempts attempts" 40 | return 1 41 | } 42 | 43 | # Check if NVIDIA GPU is available 44 | has_gpu() { 45 | if command -v nvidia-smi &> /dev/null && nvidia-smi &> /dev/null; then 46 | return 0 47 | else 48 | return 1 49 | fi 50 | } 51 | 52 | # Ensure .env file exists 53 | if [ ! -f .env ]; then 54 | echo "❌ Error: .env file not found!" 55 | exit 1 56 | fi 57 | 58 | # Create required directories if they don't exist 59 | mkdir -p $EXTERNAL_IMAGES_FOLDER $EXTERNAL_MODEL_FOLDER 60 | # Ensure correct ownership 61 | if [ "$(id -u)" -eq 0 ]; then 62 | chown -R 1000:1000 $EXTERNAL_IMAGES_FOLDER $EXTERNAL_MODEL_FOLDER 63 | else 64 | # If not root, try sudo if the directories aren't owned by the current user 65 | if [ "$(stat -c '%u' $EXTERNAL_IMAGES_FOLDER)" != "$(id -u)" ]; then 66 | sudo chown -R "$(id -u):$(id -g)" $EXTERNAL_IMAGES_FOLDER $EXTERNAL_MODEL_FOLDER 67 | fi 68 | fi 69 | 70 | # Stop and remove existing container if it exists 71 | if container_exists; then 72 | echo "🛑 Stopping existing container..." 73 | docker stop cyberimage 74 | echo "🗑️ Removing existing container..." 75 | docker rm cyberimage 76 | fi 77 | 78 | # Build the image 79 | echo "🏗️ Building Docker image..." 80 | docker build -t cyberimage . 81 | 82 | # Prepare GPU options if available 83 | GPU_OPTIONS="" 84 | if has_gpu; then 85 | echo "🎮 NVIDIA GPU detected, enabling GPU support..." 86 | GPU_OPTIONS="--gpus all" 87 | else 88 | echo "⚠️ No NVIDIA GPU detected, running in CPU mode..." 89 | fi 90 | 91 | # Start the container 92 | echo "🚀 Starting container..." 
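# Note: the flags below largely mirror docker-compose.yml (host port 7860 -> container
# port 5050, the image and model volume mounts, the .env file, and the unless-stopped
# restart policy); $GPU_OPTIONS expands to --gpus all only when an NVIDIA GPU was
# detected above.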
93 | docker run -d \ 94 | --name cyberimage \ 95 | --user $(id -u):$(id -g) \ 96 | -p 7860:5050 \ 97 | -v "$EXTERNAL_IMAGES_FOLDER:/app/images" \ 98 | -v "$EXTERNAL_MODEL_FOLDER:/app/models" \ 99 | --env-file .env \ 100 | -e FLASK_APP=run.py \ 101 | --restart unless-stopped \ 102 | $GPU_OPTIONS \ 103 | cyberimage 104 | 105 | # Wait a moment for container to start 106 | sleep 2 107 | 108 | # Check if container started successfully 109 | if ! container_running; then 110 | echo "❌ Failed to start container!" 111 | docker logs cyberimage 112 | exit 1 113 | fi 114 | 115 | # Check service health 116 | if ! check_service; then 117 | echo "❌ Service health check failed!" 118 | docker logs cyberimage 119 | exit 1 120 | fi 121 | 122 | echo " 123 | ✨ CyberImage is ready! 124 | 📝 Logs: docker logs -f cyberimage 125 | 🔍 Status: docker ps 126 | 🏥 Health: curl http://localhost:7860/health 127 | " -------------------------------------------------------------------------------- /test_generate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test script for CyberImage generation pipeline 3 | """ 4 | import sys 5 | import time 6 | import requests 7 | from typing import Dict, Optional 8 | from app.models import AVAILABLE_MODELS 9 | 10 | # Configuration 11 | API_BASE = "http://localhost:5050/api" 12 | TEST_PROMPTS = { 13 | "flux-2": "A cyberpunk city at night with neon signs and flying cars, highly detailed", 14 | "flux-1": "A futuristic metropolis with towering skyscrapers and holographic billboards", 15 | "sd-3.5": "A futuristic robot in a zen garden, detailed digital art", 16 | "flux-abliterated": "A mystical forest with glowing mushrooms and floating crystals, fantasy art" 17 | } 18 | 19 | def print_status(message: str, status: str = "info") -> None: 20 | """Print formatted status messages""" 21 | status_icons = { 22 | "info": "ℹ️", 23 | "success": "✅", 24 | "error": "❌", 25 | "warning": "⚠️", 26 | "pending": "⏳" 27 | } 28 | icon = status_icons.get(status, "ℹ️") 29 | print(f"\n{icon} {message}") 30 | sys.stdout.flush() 31 | 32 | def submit_generation(model_id: str, prompt: str) -> Optional[str]: 33 | """Submit a generation request""" 34 | try: 35 | print_status(f"Testing model: {model_id}", "info") 36 | print_status(f"Prompt: {prompt}", "info") 37 | 38 | # Submit generation request 39 | print_status("Submitting generation request...", "pending") 40 | response = requests.post(f"{API_BASE}/generate", json={ 41 | "model_id": model_id, 42 | "prompt": prompt, 43 | "settings": { 44 | "num_inference_steps": 35, 45 | "guidance_scale": 8.0, 46 | "height": 1024, 47 | "width": 1024, 48 | "max_sequence_length": 512 49 | } 50 | }) 51 | response.raise_for_status() 52 | job_id = response.json()["job_id"] 53 | print_status(f"Job submitted successfully. 
Job ID: {job_id}", "success") 54 | return job_id 55 | except Exception as e: 56 | print_status(f"Failed to submit job: {str(e)}", "error") 57 | return None 58 | 59 | def check_status(job_id: str, timeout: int = 300) -> bool: 60 | """Check job status with timeout""" 61 | start_time = time.time() 62 | try: 63 | while time.time() - start_time < timeout: 64 | print_status("Checking job status...", "pending") 65 | response = requests.get(f"{API_BASE}/status/{job_id}") 66 | response.raise_for_status() 67 | status = response.json() 68 | 69 | if status["status"] == "completed": 70 | print_status("Generation completed successfully!", "success") 71 | return True 72 | elif status["status"] == "failed": 73 | print_status(f"Generation failed: {status.get('error', 'Unknown error')}", "error") 74 | return False 75 | elif status["status"] in ["pending", "processing"]: 76 | print_status(f"Status: {status['status']}", "info") 77 | time.sleep(5) 78 | continue 79 | else: 80 | print_status(f"Unknown status: {status['status']}", "warning") 81 | return False 82 | 83 | print_status(f"Timeout after {timeout} seconds", "error") 84 | return False 85 | except Exception as e: 86 | print_status(f"Error checking status: {str(e)}", "error") 87 | return False 88 | 89 | def test_model(model_id: str, prompt: str) -> Dict: 90 | """Test a single model""" 91 | results = { 92 | "model_id": model_id, 93 | "prompt": prompt, 94 | "success": False, 95 | "error": None 96 | } 97 | 98 | try: 99 | # Submit job 100 | job_id = submit_generation(model_id, prompt) 101 | if not job_id: 102 | results["error"] = "Failed to submit job" 103 | return results 104 | 105 | # Check status 106 | success = check_status(job_id) 107 | results["success"] = success 108 | if not success: 109 | results["error"] = "Generation failed or timed out" 110 | 111 | except Exception as e: 112 | results["error"] = str(e) 113 | 114 | return results 115 | 116 | def main(): 117 | """Main test function""" 118 | print_status("Starting CyberImage Generation Tests", "info") 119 | 120 | results = [] 121 | for model_id in AVAILABLE_MODELS: 122 | prompt = TEST_PROMPTS.get(model_id, "A beautiful landscape at sunset, digital art") 123 | print_status(f"\nTesting {model_id}", "info") 124 | result = test_model(model_id, prompt) 125 | results.append(result) 126 | 127 | # Print summary 128 | print_status("\nTest Summary:", "info") 129 | for result in results: 130 | status = "success" if result["success"] else "error" 131 | message = f"Model: {result['model_id']} - {'Success' if result['success'] else 'Failed'}" 132 | if result["error"]: 133 | message += f" - Error: {result['error']}" 134 | print_status(message, status) 135 | 136 | if __name__ == "__main__": 137 | main() -------------------------------------------------------------------------------- /test_watchdog.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | GPU Watchdog Test Script for CyberImage 4 | 5 | This script allows testing the GPU watchdog and memory management functionality 6 | by simulating various scenarios including memory leaks, stalled jobs, and recovery. 
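
Run it against a locally running CyberImage instance (the script targets the API at
http://localhost:5050) and with the requests, tqdm, and torch packages installed.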
7 | """ 8 | import os 9 | import sys 10 | import time 11 | import json 12 | import argparse 13 | import requests 14 | import threading 15 | from tqdm import tqdm 16 | import torch 17 | 18 | # Base API URL 19 | API_BASE = "http://localhost:5050/api" 20 | 21 | # Test prompts 22 | TEST_PROMPTS = [ 23 | "A cyberpunk city at night with neon signs and flying cars, highly detailed", 24 | "A futuristic metropolis with towering skyscrapers and holographic billboards", 25 | "A mystical forest with glowing mushrooms and floating crystals, fantasy art", 26 | "A robot samurai standing in a Zen garden, intricate digital art", 27 | "An abandoned space station floating in orbit around an alien planet", 28 | "A post-apocalyptic landscape with overgrown ruins and a vibrant sunset", 29 | ] 30 | 31 | # Output formatting 32 | def print_header(title): 33 | """Print a formatted header""" 34 | print("\n" + "=" * 80) 35 | print(f" {title.upper()} ".center(80, '=')) 36 | print("=" * 80) 37 | 38 | def print_status(message, status="info"): 39 | """Print formatted status messages""" 40 | status_icons = { 41 | "info": "ℹ️", 42 | "success": "✅", 43 | "error": "❌", 44 | "warning": "⚠️", 45 | "pending": "⏳" 46 | } 47 | icon = status_icons.get(status, "ℹ️") 48 | print(f"\n{icon} {message}") 49 | sys.stdout.flush() 50 | 51 | def check_gpu_memory(): 52 | """Check current GPU memory usage""" 53 | if not torch.cuda.is_available(): 54 | print_status("CUDA not available, skipping memory check", "warning") 55 | return None 56 | 57 | try: 58 | memory_allocated = torch.cuda.memory_allocated() / (1024**3) # Convert to GB 59 | memory_reserved = torch.cuda.memory_reserved() / (1024**3) # Convert to GB 60 | total_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3) 61 | 62 | print_status(f"GPU Memory Status:", "info") 63 | print(f" • Total Memory: {total_memory:.2f}GB") 64 | print(f" • Allocated Memory: {memory_allocated:.2f}GB") 65 | print(f" • Reserved Memory: {memory_reserved:.2f}GB") 66 | print(f" • Usage: {memory_allocated/total_memory:.2%}") 67 | 68 | return { 69 | "total": total_memory, 70 | "allocated": memory_allocated, 71 | "reserved": memory_reserved, 72 | "usage_percentage": memory_allocated/total_memory 73 | } 74 | except Exception as e: 75 | print_status(f"Error checking GPU memory: {str(e)}", "error") 76 | return None 77 | 78 | def submit_job(model_id="flux-1", prompt=None): 79 | """Submit a generation job""" 80 | if prompt is None: 81 | import random 82 | prompt = random.choice(TEST_PROMPTS) 83 | 84 | try: 85 | print_status(f"Submitting job for model: {model_id}", "info") 86 | print(f" • Prompt: {prompt}") 87 | 88 | response = requests.post(f"{API_BASE}/generate", json={ 89 | "model_id": model_id, 90 | "prompt": prompt, 91 | "settings": { 92 | "num_inference_steps": 30, 93 | "guidance_scale": 7.5, 94 | "height": 1024, 95 | "width": 1024 96 | } 97 | }) 98 | 99 | if response.status_code == 200: 100 | job_data = response.json() 101 | job_id = job_data["job_id"] 102 | print_status(f"Job submitted successfully: {job_id}", "success") 103 | return job_id 104 | else: 105 | print_status(f"Failed to submit job. 
Status: {response.status_code}, Response: {response.text}", "error") 106 | return None 107 | except Exception as e: 108 | print_status(f"Error submitting job: {str(e)}", "error") 109 | return None 110 | 111 | def check_job_status(job_id): 112 | """Check the status of a submitted job""" 113 | try: 114 | response = requests.get(f"{API_BASE}/status/{job_id}") 115 | if response.status_code == 200: 116 | status_data = response.json() 117 | return status_data 118 | else: 119 | print_status(f"Failed to check job status. Status: {response.status_code}", "error") 120 | return None 121 | except Exception as e: 122 | print_status(f"Error checking job status: {str(e)}", "error") 123 | return None 124 | 125 | def wait_for_job(job_id, timeout=300, check_interval=2): 126 | """Wait for a job to complete with progress tracking""" 127 | start_time = time.time() 128 | previous_status = None 129 | 130 | with tqdm(total=100, desc="Job Progress", ncols=100) as pbar: 131 | while time.time() - start_time < timeout: 132 | status_data = check_job_status(job_id) 133 | 134 | if not status_data: 135 | time.sleep(check_interval) 136 | continue 137 | 138 | status = status_data.get("status") 139 | if status != previous_status: 140 | print_status(f"Job status: {status}", "info") 141 | previous_status = status 142 | 143 | # Update progress bar based on status 144 | if status == "completed": 145 | pbar.update(100 - pbar.n) # Complete the progress 146 | print_status("Job completed successfully!", "success") 147 | return True, status_data 148 | elif status == "failed": 149 | error_msg = status_data.get("message", "Unknown error") 150 | print_status(f"Job failed: {error_msg}", "error") 151 | return False, status_data 152 | elif status == "processing": 153 | # Estimate progress 154 | if pbar.n < 50: 155 | pbar.update(1) # Slowly increase 156 | elif pbar.n < 80: 157 | pbar.update(0.5) # Even slower 158 | 159 | time.sleep(check_interval) 160 | 161 | print_status(f"Timeout after {timeout} seconds", "error") 162 | return False, {"status": "timeout"} 163 | 164 | def get_queue_status(): 165 | """Get the current queue status""" 166 | try: 167 | response = requests.get(f"{API_BASE}/queue") 168 | if response.status_code == 200: 169 | queue_data = response.json() 170 | 171 | print_status("Current Queue Status:", "info") 172 | print(f" • Pending: {queue_data.get('pending', 0)}") 173 | print(f" • Processing: {queue_data.get('processing', 0)}") 174 | print(f" • Completed: {queue_data.get('completed', 0)}") 175 | print(f" • Failed: {queue_data.get('failed', 0)}") 176 | print(f" • Total: {queue_data.get('total', 0)}") 177 | 178 | return queue_data 179 | else: 180 | print_status(f"Failed to get queue status. 
181 |             return None
182 |     except Exception as e:
183 |         print_status(f"Error getting queue status: {str(e)}", "error")
184 |         return None
185 | 
186 | def check_health():
187 |     """Check the health status of the app"""
188 |     try:
189 |         response = requests.get("http://localhost:5050/health")
190 |         if response.status_code == 200:
191 |             health_data = response.json()
192 | 
193 |             print_status("Health Status:", "info")
194 |             print(f" • Status: {health_data.get('status', 'unknown')}")
195 |             print(f" • Device: {health_data.get('device', 'unknown')}")
196 | 
197 |             queue = health_data.get('queue', {})
198 |             print(f" • Queue Size: {queue.get('in_memory_size', 0)}")
199 |             print(f" • Is Main Process: {health_data.get('is_main_process', False)}")
200 |             print(f" • Is Running: {health_data.get('is_running', False)}")
201 | 
202 |             return health_data
203 |         else:
204 |             print_status(f"Failed to get health status. Status: {response.status_code}", "error")
205 |             return None
206 |     except Exception as e:
207 |         print_status(f"Error checking health: {str(e)}", "error")
208 |         return None
209 | 
210 | def simulate_memory_leak(size_mb=500, count=5):
211 |     """Simulate a memory leak by creating large tensors"""
212 |     if not torch.cuda.is_available():
213 |         print_status("CUDA not available, cannot simulate memory leak", "error")
214 |         return
215 | 
216 |     print_header("SIMULATING MEMORY LEAK")
217 |     print_status(f"Creating {count} tensors of {size_mb}MB each", "warning")
218 | 
219 |     # Store tensors to prevent garbage collection
220 |     tensors = []
221 | 
222 |     # Check memory before
223 |     print_status("Memory before leak:", "info")
224 |     check_gpu_memory()
225 | 
226 |     # Create large tensors
227 |     try:
228 |         for i in range(count):
229 |             # Each tensor will be approximately size_mb
230 |             num_elements = size_mb * 1024 * 1024 // 4 # 4 bytes per float32
231 |             tensor = torch.ones(num_elements, device="cuda")
232 |             tensors.append(tensor)
233 | 
234 |             print_status(f"Created tensor {i+1}/{count}", "info")
235 |             check_gpu_memory()
236 |             time.sleep(2) # Give time for the watchdog to detect
237 |     except Exception as e:
238 |         print_status(f"Error creating tensors: {str(e)}", "error")
239 | 
240 |     # Check memory after
241 |     print_status("Memory after leak:", "info")
242 |     check_gpu_memory()
243 | 
244 |     # Return the tensors so they're not garbage collected
245 |     return tensors
246 | 
247 | def simulate_stalled_job():
248 |     """Simulate a stalled job by submitting a job and then causing it to stall"""
249 |     print_header("SIMULATING STALLED JOB")
250 | 
251 |     # Submit a job
252 |     job_id = submit_job()
253 |     if not job_id:
254 |         print_status("Failed to submit job for stall test", "error")
255 |         return
256 | 
257 |     print_status("Waiting for job to start processing...", "info")
258 |     # Wait for job to enter processing state
259 |     processing = False
260 |     max_wait = 60 # seconds
261 |     start_time = time.time()
262 | 
263 |     while time.time() - start_time < max_wait:
264 |         status_data = check_job_status(job_id)
265 |         if status_data and status_data.get("status") == "processing":
266 |             processing = True
267 |             break
268 |         time.sleep(2)
269 | 
270 |     if not processing:
271 |         print_status("Job didn't start processing within expected time", "error")
272 |         return
273 | 
274 |     print_status("Job is now processing, waiting for watchdog to detect stalled job...", "warning")
275 |     print_status("Note: This test will take at least 15 minutes to complete", "info")
276 |     print_status("The watchdog is configured to detect stalled jobs after 15 minutes", "info")
277 | 
278 |     # Wait and periodically check job status
279 |     print_status("Monitoring job status to see if watchdog recovers it...", "info")
280 | 
281 |     recovery_timeout = 1200 # 20 minutes
282 |     check_interval = 30 # seconds
283 |     start_time = time.time()
284 | 
285 |     while time.time() - start_time < recovery_timeout:
286 |         status_data = check_job_status(job_id)
287 |         if not status_data:
288 |             print_status("Could not check job status", "error")
289 |         else:
290 |             status = status_data.get("status")
291 |             print_status(f"Current job status: {status}", "info")
292 | 
293 |             elapsed_time = time.time() - start_time
294 |             print_status(f"Elapsed time: {elapsed_time:.1f} seconds", "info")
295 | 
296 |             if status == "failed":
297 |                 error_msg = status_data.get("message", "Unknown error")
298 |                 if "timeout" in error_msg.lower() or "recovery" in error_msg.lower():
299 |                     print_status("Watchdog successfully recovered the stalled job!", "success")
300 |                     return True
301 |                 else:
302 |                     print_status(f"Job failed for a different reason: {error_msg}", "error")
303 |                     return False
304 |             elif status == "completed":
305 |                 print_status("Job completed successfully before watchdog timeout", "success")
306 |                 return False # Not recovered by watchdog
307 | 
308 |         time.sleep(check_interval)
309 | 
310 |     print_status("Test timed out waiting for job recovery", "error")
311 |     return False
312 | 
313 | def load_test(num_jobs=10, delay=2):
314 |     """Submit multiple jobs in quick succession to test queue management"""
315 |     print_header(f"LOAD TEST - SUBMITTING {num_jobs} JOBS")
316 | 
317 |     job_ids = []
318 |     for i in range(num_jobs):
319 |         print_status(f"Submitting job {i+1}/{num_jobs}", "info")
320 |         job_id = submit_job()
321 |         if job_id:
322 |             job_ids.append(job_id)
323 |         time.sleep(delay) # Small delay between submissions
324 | 
325 |     print_status(f"Submitted {len(job_ids)} jobs successfully", "success")
326 | 
327 |     # Monitor queue status
328 |     print_status("Monitoring queue status...", "info")
329 |     start_time = time.time()
330 |     timeout = num_jobs * 120 # Approximately 2 minutes per job
331 | 
332 |     completed = 0
333 |     failed = 0
334 | 
335 |     while time.time() - start_time < timeout:
336 |         queue_status = get_queue_status()
337 |         check_gpu_memory()
338 | 
339 |         if queue_status:
340 |             completed = queue_status.get("completed", 0)
341 |             failed = queue_status.get("failed", 0)
342 |             pending = queue_status.get("pending", 0)
343 |             processing = queue_status.get("processing", 0)
344 | 
345 |             # Check if all jobs are processed
346 |             if completed + failed >= len(job_ids) and pending == 0 and processing == 0:
347 |                 print_status("All jobs have been processed!", "success")
348 |                 break
349 | 
350 |         time.sleep(10) # Check every 10 seconds
351 | 
352 |     # Final status
353 |     print_status("Load test results:", "info")
354 |     print(f" • Total jobs submitted: {len(job_ids)}")
355 |     print(f" • Completed successfully: {completed}")
356 |     print(f" • Failed: {failed}")
357 | 
358 |     return completed, failed
359 | 
360 | def main():
361 |     """Main function to run the tests"""
362 |     parser = argparse.ArgumentParser(description="GPU Watchdog Test Script")
363 |     parser.add_argument("--test", choices=["memory", "stall", "load", "all"], default="all",
364 |                         help="Type of test to run")
365 |     parser.add_argument("--jobs", type=int, default=5, help="Number of jobs for load test")
366 |     args = parser.parse_args()
367 | 
368 |     print_header("GPU WATCHDOG TEST SCRIPT")
369 | 
370 |     # Initial status check
371 |     check_health()
372 |     check_gpu_memory()
373 |     get_queue_status()
374 | 
375 |     if args.test == "memory" or args.test == "all":
376 |         # Test 1: Simulate memory leak
377 |         simulate_memory_leak()
378 |         time.sleep(10) # Give watchdog time to react
379 |         check_gpu_memory() # Check if memory was recovered
380 | 
381 |     if args.test == "stall" or args.test == "all":
382 |         # Test 2: Simulate stalled job
383 |         simulate_stalled_job()
384 | 
385 |     if args.test == "load" or args.test == "all":
386 |         # Test 3: Load test with multiple jobs
387 |         load_test(args.jobs)
388 | 
389 |     print_header("TEST COMPLETED")
390 |     check_health()
391 |     check_gpu_memory()
392 |     get_queue_status()
393 | 
394 | if __name__ == "__main__":
395 |     main()
--------------------------------------------------------------------------------
/tests/test_mcp.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests for the Model Context Protocol (MCP) implementation
3 | """
4 | import unittest
5 | import json
6 | import time
7 | from flask import url_for
8 | from app import create_app
9 | 
10 | class MCPTestCase(unittest.TestCase):
11 |     """Tests for the MCP implementation"""
12 | 
13 |     def setUp(self):
14 |         """Set up test app"""
15 |         self.app = create_app({
16 |             'TESTING': True,
17 |             'ENABLE_RATE_LIMIT': False, # Disable rate limiting for tests
18 |         })
19 |         self.client = self.app.test_client()
20 |         self.app_context = self.app.app_context()
21 |         self.app_context.push()
22 | 
23 |     def tearDown(self):
24 |         """Clean up after tests"""
25 |         self.app_context.pop()
26 | 
27 |     def test_mcp_invalid_request(self):
28 |         """Test that invalid requests are properly rejected"""
29 |         # Test missing jsonrpc field
30 |         response = self.client.post(
31 |             '/api/mcp',
32 |             json={'method': 'context.image_generation.models'}
33 |         )
34 |         self.assertEqual(response.status_code, 400)
35 |         data = json.loads(response.data)
36 |         self.assertEqual(data['error']['code'], -32600)
37 | 
38 |         # Test invalid jsonrpc version
39 |         response = self.client.post(
40 |             '/api/mcp',
41 |             json={'jsonrpc': '1.0', 'method': 'context.image_generation.models'}
42 |         )
43 |         self.assertEqual(response.status_code, 400)
44 |         data = json.loads(response.data)
45 |         self.assertEqual(data['error']['code'], -32600)
46 | 
47 |     def test_mcp_method_not_found(self):
48 |         """Test that an invalid method is properly rejected"""
49 |         response = self.client.post(
50 |             '/api/mcp',
51 |             json={
52 |                 'jsonrpc': '2.0',
53 |                 'method': 'invalid.method',
54 |                 'id': '123'
55 |             }
56 |         )
57 |         self.assertEqual(response.status_code, 404)
58 |         data = json.loads(response.data)
59 |         self.assertEqual(data['error']['code'], -32601)
60 |         self.assertEqual(data['id'], '123')
61 | 
62 |     def test_mcp_models_method(self):
63 |         """Test the models listing method"""
64 |         response = self.client.post(
65 |             '/api/mcp',
66 |             json={
67 |                 'jsonrpc': '2.0',
68 |                 'method': 'context.image_generation.models',
69 |                 'id': '123'
70 |             }
71 |         )
72 |         self.assertEqual(response.status_code, 200)
73 |         data = json.loads(response.data)
74 |         self.assertEqual(data['jsonrpc'], '2.0')
75 |         self.assertEqual(data['id'], '123')
76 |         self.assertIn('result', data)
77 | 
78 |         # Check that we have our models and the default model
79 |         self.assertIn('models', data['result'])
80 |         self.assertIn('default', data['result'])
81 |         self.assertEqual(data['result']['default'], 'flux-2')
82 | 
83 |         # Check that we have at least one model
84 |         self.assertTrue(len(data['result']['models']) > 0)
85 | 
86 |         # Check that flux-2 exists
87 |         self.assertIn('flux-2', data['result']['models'])
88 | 
89 |     def test_mcp_generate_method_missing_params(self):
90 |         """Test the generate method with missing parameters"""
91 |         response = self.client.post(
92 |             '/api/mcp',
93 |             json={
94 |                 'jsonrpc': '2.0',
95 |                 'method': 'context.image_generation.generate',
96 |                 'params': {},
97 |                 'id': '123'
98 |             }
99 |         )
100 |         self.assertEqual(response.status_code, 200)
101 |         data = json.loads(response.data)
102 |         self.assertIn('error', data)
103 | 
104 |     def test_mcp_generate_method(self):
105 |         """Test the generate method"""
106 |         response = self.client.post(
107 |             '/api/mcp',
108 |             json={
109 |                 'jsonrpc': '2.0',
110 |                 'method': 'context.image_generation.generate',
111 |                 'params': {
112 |                     'prompt': 'Test prompt for MCP',
113 |                     'model': 'flux-2',
114 |                     'settings': {
115 |                         'num_images': 1
116 |                     }
117 |                 },
118 |                 'id': '123'
119 |             }
120 |         )
121 |         self.assertEqual(response.status_code, 200)
122 |         data = json.loads(response.data)
123 | 
124 |         # Check that we have our job ID and status
125 |         self.assertIn('result', data)
126 |         self.assertIn('job_id', data['result'])
127 |         self.assertIn('status', data['result'])
128 |         self.assertEqual(data['result']['status'], 'pending')
129 | 
130 |         # Store the job ID (note: unittest creates a fresh instance per test, so this is not shared with other tests)
131 |         self.job_id = data['result']['job_id']
132 | 
133 |     def test_mcp_status_method_missing_params(self):
134 |         """Test the status method with missing parameters"""
135 |         response = self.client.post(
136 |             '/api/mcp',
137 |             json={
138 |                 'jsonrpc': '2.0',
139 |                 'method': 'context.image_generation.status',
140 |                 'params': {},
141 |                 'id': '123'
142 |             }
143 |         )
144 |         self.assertEqual(response.status_code, 200)
145 |         data = json.loads(response.data)
146 |         self.assertIn('error', data)
147 | 
148 |     def test_mcp_status_method_invalid_job(self):
149 |         """Test the status method with an invalid job ID"""
150 |         response = self.client.post(
151 |             '/api/mcp',
152 |             json={
153 |                 'jsonrpc': '2.0',
154 |                 'method': 'context.image_generation.status',
155 |                 'params': {
156 |                     'job_id': 'invalid-job-id'
157 |                 },
158 |                 'id': '123'
159 |             }
160 |         )
161 |         self.assertEqual(response.status_code, 200)
162 |         data = json.loads(response.data)
163 |         self.assertIn('error', data)
164 | 
165 |     def test_mcp_workflow(self):
166 |         """Test the complete MCP workflow"""
167 |         # Skip in CI environment as we won't have GPU
168 |         import os
169 |         if os.environ.get('CI'):
170 |             self.skipTest("Skipping complete workflow test in CI environment")
171 | 
172 |         # 1. Submit generation job
173 |         response = self.client.post(
174 |             '/api/mcp',
175 |             json={
176 |                 'jsonrpc': '2.0',
177 |                 'method': 'context.image_generation.generate',
178 |                 'params': {
179 |                     'prompt': 'Test complete workflow',
180 |                     'settings': {
181 |                         'num_images': 1,
182 |                         'num_inference_steps': 2 # Use minimal steps for testing
183 |                     }
184 |                 },
185 |                 'id': 'workflow-test'
186 |             }
187 |         )
188 |         self.assertEqual(response.status_code, 200)
189 |         data = json.loads(response.data)
190 |         job_id = data['result']['job_id']
191 | 
192 |         # 2. Check status - should be pending or processing
193 |         response = self.client.post(
194 |             '/api/mcp',
195 |             json={
196 |                 'jsonrpc': '2.0',
197 |                 'method': 'context.image_generation.status',
198 |                 'params': {
199 |                     'job_id': job_id
200 |                 },
201 |                 'id': 'workflow-status'
202 |             }
203 |         )
204 |         self.assertEqual(response.status_code, 200)
205 |         data = json.loads(response.data)
206 |         self.assertIn('result', data)
207 |         self.assertIn('status', data['result'])
208 |         self.assertIn(data['result']['status'], ['pending', 'processing'])
209 | 
210 | if __name__ == '__main__':
211 |     unittest.main()
--------------------------------------------------------------------------------