├── .gitignore ├── .prettierrc ├── README.md ├── agent_scheduler ├── api.py ├── db │ ├── __init__.py │ ├── app_state.py │ ├── base.py │ └── task.py ├── helpers.py ├── models.py ├── task_helpers.py └── task_runner.py ├── docs ├── .DS_Store ├── CHANGELOG.md └── images │ ├── install.png │ ├── settings.png │ └── walkthrough.png ├── install.py ├── javascript └── agent-scheduler.iife.js ├── preload.py ├── scripts └── task_scheduler.py ├── style.css └── ui ├── .eslintignore ├── .eslintrc.cjs ├── .gitignore ├── package.json ├── postcss.config.js ├── src ├── assets │ └── icons │ │ ├── bookmark-filled.svg │ │ ├── bookmark.svg │ │ ├── cancel.svg │ │ ├── delete.svg │ │ ├── play.svg │ │ ├── rotate.svg │ │ ├── save.svg │ │ └── search.svg ├── extension │ ├── index.scss │ ├── index.ts │ ├── stores │ │ ├── history.store.ts │ │ ├── pending.store.ts │ │ └── shared.store.ts │ ├── tailwind.css │ └── types.ts ├── utils │ ├── ag-grid.ts │ ├── debounce.ts │ └── extract-args.ts └── vite-env.d.ts ├── tailwind.config.js ├── tsconfig.json ├── tsconfig.node.json ├── vite.config.ts ├── vite.extension.ts └── yarn.lock /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.ckpt 3 | *.safetensors 4 | *.pth 5 | /ESRGAN/* 6 | /SwinIR/* 7 | /repositories 8 | /venv 9 | /tmp 10 | /model.ckpt 11 | /models/**/* 12 | /GFPGANv1.3.pth 13 | /gfpgan/weights/*.pth 14 | /ui-config.json 15 | /outputs 16 | /config.json 17 | /log 18 | /webui.settings.bat 19 | /embeddings 20 | /styles.csv 21 | /params.txt 22 | /styles.csv.bak 23 | /webui-user.bat 24 | /webui-user.sh 25 | /interrogate 26 | /user.css 27 | /.idea 28 | notification.mp3 29 | /SwinIR 30 | /textual_inversion 31 | .vscode 32 | /extensions 33 | /test/stdout.txt 34 | /test/stderr.txt 35 | /cache.json 36 | *.sql 37 | *.db 38 | *.sqlite 39 | *.sqlite3 -------------------------------------------------------------------------------- /.prettierrc: 
-------------------------------------------------------------------------------- 1 | { 2 | "singleQuote": true, 3 | "jsxSingleQuote": false, 4 | "arrowParens": "avoid", 5 | "trailingComma": "es5", 6 | "semi": true, 7 | "tabWidth": 2, 8 | "printWidth": 100 9 | } 10 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Agent Scheduler 2 | 3 | Introducing AgentScheduler, an Automatic1111/Vladmandic Stable Diffusion Web UI extension to power up your image generation workflow! 4 | 5 | ## Table of Content 6 | 7 | - [Compatibility](#compatibility) 8 | - [Installation](#installation) 9 | - [Using Vlad Fork](#using-vlads-webui-fork) 10 | - [Using the built-in extension list](#using-the-built-in-extension-list) 11 | - [Manual clone](#manual-clone) 12 | - [Functionality](#functionality-as-of-current-version) 13 | - [Settings](#extension-settings) 14 | - [API Access](#api-access) 15 | - [Troubleshooting](#troubleshooting) 16 | - [Road Map](#road-map) 17 | - [Contributing](#contributing) 18 | - [License](#license) 19 | - [Disclaimer](#disclaimer) 20 | 21 | --- 22 | 23 | ## Compatibility 24 | 25 | This version of AgentScheduler is compatible with latest versions of: 26 | 27 | - A1111: [commit baf6946](https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/baf6946e06249c5af9851c60171692c44ef633e0) 28 | - Vladmandic: [commit 9726b4d](https://github.com/vladmandic/automatic/commit/9726b4d23cb63779964e1d4edff49dd2c9c11e51) 29 | 30 | > Older versions may not working properly. 31 | 32 | ## Installation 33 | 34 | ### Using Vlad's WebUI Fork 35 | 36 | The extension is already included in [Vlad fork](https://github.com/vladmandic/automatic)'s builtin extensions. 37 | 38 | ### Using the built-in extension list 39 | 40 | 1. Open the Extensions tab 41 | 2. Open the "Install From URL" sub-tab 42 | 3. 
Paste the repo url: https://github.com/ArtVentureX/sd-webui-agent-scheduler.git 43 | 4. Click "Install" 44 | 45 | ![Install](https://github.com/ArtVentureX/sd-webui-agent-scheduler/assets/133728487/f0fa740b-392a-4dd6-abe1-49c770ea60da) 46 | 47 | ### Manual clone 48 | 49 | ```bash 50 | git clone "https://github.com/ArtVentureX/sd-webui-agent-scheduler.git" extensions/agent-scheduler 51 | ``` 52 | 53 | (The second argument specifies the name of the folder, you can choose whatever you like). 54 | 55 | ## Basic Features 56 | 57 | ![Extension Walkthrough 1](https://github.com/ArtVentureX/sd-webui-agent-scheduler/assets/133728487/a5a039a7-d98b-4186-9131-6775f0812c39) 58 | 59 | 1️⃣ Input your usual Prompts & Settings. **Enqueue** to send your current prompts, settings, controlnets to **AgentScheduler**. 60 | 61 | ![Extension Walkthrough 2](https://github.com/ArtVentureX/sd-webui-agent-scheduler/assets/133728487/734176b4-7ee3-40e5-bb92-35608fabfc4b) 62 | 63 | 2️⃣ **AgentScheduler** Extension Tab. 64 | 65 | 3️⃣ See all queued tasks, current image being generated and tasks' associated information. **Drag and drop** the handle in the begining of each row to reaggrange the generation order. 66 | 67 | 4️⃣ **Pause** to stop queue auto generation. **Resume** to start. 68 | 69 | 5️⃣ Press ▶️ to prioritize selected task, or to start a single task when queue is paused. **Delete** tasks that you no longer want. 70 | 71 | ![ Extension Walkthrough 3](https://github.com/ArtVentureX/sd-webui-agent-scheduler/assets/133728487/23109761-2633-4b24-bbb3-091628367047) 72 | 73 | 6️⃣ Show queue history. 74 | 75 | 7️⃣ **Filter** task status or search by text. 76 | 77 | 8️⃣ **Bookmark** task to easier filtering. 78 | 79 | 9️⃣ Double click the task id to **rename** and quickly update basic parameters. Click ↩️ to **Requeue** old task. 80 | 81 | 🔟 Click on each task to **view** the generation results. 
82 | 83 | https://github.com/ArtVentureX/sd-webui-agent-scheduler/assets/133728487/50c74922-b85f-493c-9be8-b8e78f0cd061 84 | 85 | ## Hidden Features: 86 | 87 | #### Queue all checkpoints at the same time 88 | 89 | Right click the `Enqueue` button and select `Queue with all checkpoints` to quickly queue the current setting with all available checkpoints. 90 | 91 | ![image](https://github.com/ArtVentureX/sd-webui-agent-scheduler/assets/133728487/c75276e8-6d0c-4f72-91db-817f38a3fea6) 92 | 93 | #### Queue with a subset of checkpoints 94 | 95 | ![image](https://github.com/ArtVentureX/sd-webui-agent-scheduler/assets/133728487/b776d09d-c789-47f1-8884-975848bb766d) 96 | 97 | ![image](https://github.com/ArtVentureX/sd-webui-agent-scheduler/assets/133728487/bdb2b41a-5ae8-41c1-bece-7dbff67e38b7) 98 | 99 | With the custom checkpoint select enabled (see [Extension Settings](#extension-settings) section below), you can select a folder (or subfolder) to queue task with all checkpoints inside. Eg: Select `anime` will queue `anime\AOM3A1B_oragemixs`, `anime\counterfeit\Counterfeit-V2.5_fp16` and `anime\counterfeit\Counterfeit-V2.5_pruned`. 100 | 101 | #### Edit queued task 102 | 103 | Double click a queued task to edit. You can name a task by changing `task_id` or update some basic parameters: `prompt`, `negative prompt`, `sampler`, `checkpoint`, `steps`, `cfg scale`. 104 | 105 | ![image](https://github.com/ArtVentureX/sd-webui-agent-scheduler/assets/133728487/57535174-2f66-4ee7-8f3c-9f1dd3882eff) 106 | 107 | ## Extension Settings 108 | 109 | Go to `Settings > Agent Scheduler` to access extension settings. 110 | 111 | ![Settings](https://github.com/ArtVentureX/sd-webui-agent-scheduler/assets/133728487/b0377ccd-f9bf-486e-8393-c06fe26aa117) 112 | 113 | **Disable Queue Auto-Processing**: Check this option to disable queue auto-processing on start-up. You can also temporarily pause or resume the queue from the Extension tab. 
114 | 115 | **Queue Button Placement**: Change the placement of the queue button on the UI. 116 | 117 | **Hide the Checkpoint Dropdown**: The Extension provides a custom checkpoint dropdown. 118 | 119 | ![Custom Checkpoint](https://github.com/ArtVentureX/sd-webui-agent-scheduler/assets/133728487/d110d314-a208-4eec-bb54-9f8c73cb450b) 120 | 121 | By default, queued tasks use the currently loaded checkpoint. However, changing the system checkpoint requires some time to load the checkpoint into memory, and you also cannot change the checkpoint during image generation. You can use this dropdown to quickly queue a task with a custom checkpoint. 122 | 123 | **Auto Delete Queue History**: Select a timeframe to keep your queue history. Tasks that are older than the configured value will be automatically deleted. Please note that bookmarked tasks will not be deleted. 124 | 125 | ## API Access 126 | 127 | All the functionality of this extension can be accessed through HTTP APIs. You can access the API documentation via `http://127.0.0.1:7860/docs`. Remember to include `--api` in your startup arguments. 128 | 129 | ![API docs](https://github.com/ArtVentureX/sd-webui-agent-scheduler/assets/133728487/012ab2cc-b41f-4c68-8fa5-7ab4e49aa91d) 130 | 131 | #### Queue Task 132 | 133 | The two apis `/agent-scheduler/v1/queue/txt2img` and `/agent-scheduler/v1/queue/img2img` support all the parameters of the original webui apis. These apis response the task id, which can be used to perform updates later. 134 | 135 | ```json 136 | { 137 | "task_id": "string" 138 | } 139 | ``` 140 | 141 | #### Download Results 142 | 143 | Use api `/agent-scheduler/v1/results/{id}` to get the generated images. The api supports two response format: 144 | 145 | - json with base64 encoded 146 | 147 | ```json 148 | { 149 | "success": true, 150 | "data": [ 151 | { 152 | "image": "data:image/png;base64,iVBORw0KGgoAAAAN...", 153 | "infotext": "1girl\nNegative prompt: EasyNegative, badhandv4..." 
154 | }, 155 | { 156 | "image": "data:image/png;base64,iVBORw0KGgoAAAAN...", 157 | "infotext": "1girl\nNegative prompt: EasyNegative, badhandv4..." 158 | } 159 | ] 160 | } 161 | ``` 162 | 163 | - zip file with querystring `zip=true` 164 | 165 | #### API Callback 166 | 167 | Queue task with param `callback_url` to register an API callback. Eg: 168 | 169 | ```json 170 | { 171 | "prompt": "1girl", 172 | "negative_prompt": "easynegative", 173 | "callback_url": "http://somehost:port/task_completed" 174 | } 175 | ``` 176 | 177 | The callback endpoint must support `POST` method with body in `multipart/form-data` encoding. Body format: 178 | 179 | ```json 180 | { 181 | "task_id": "abc123", 182 | "status": "done", 183 | "files": [list of image files], 184 | } 185 | ``` 186 | 187 | Example code of the endpoint handle with `FastApi`: 188 | 189 | ```python 190 | from fastapi import FastAPI, UploadFile, File, Form 191 | 192 | @app.post("/task_completed") 193 | async def handle_task_completed( 194 | task_id: Annotated[str, Form()], 195 | status: Annotated[str, Form()], 196 | files: Optional[List[UploadFile]] = File(None), 197 | ): 198 | print(f"Received {len(files)} files for task {task_id} with status {status}") 199 | for file in files: 200 | print(f"* {file.filename} {file.content_type} {file.size}") 201 | # ... do something with the file contents ... 202 | 203 | # Received 1 files for task 3cf8b150-f260-4489-b6e8-d86ed8a564ca with status done 204 | # * 00008-3322209480.png image/png 416400 205 | ``` 206 | 207 | ## Troubleshooting 208 | 209 | Make sure that you are running the latest version of the extension and an updated version of the WebUI. 210 | 211 | - To update the extension, go to `Extension` tab and click `Check for Updates`, then click `Apply and restart UI`. 212 | - To update the WebUI it self, you run the command `git pull origin master` in the same folder as webui.bat (or webui.sh). 
213 | 214 | Steps to try to find the cause of issues: 215 | 216 | - Check the for errors in the WebUI output console. 217 | - Press F12 in the browser then go to the console tab and reload the page, find any error message here. 218 | 219 | Common errors: 220 | 221 | **AttributeError: module 'modules.script_callbacks' has no attribute 'on_before_reload'** 222 | 223 | If you see this error message in the output console, try update the WebUI to the latest version. 224 | 225 | **Update**: The extension is updated to print this warning message instead: **YOUR SD WEBUI IS OUTDATED AND AGENT SCHEDULER WILL NOT WORKING PROPERLY.** You can still able to use the extension but it will not working correctly after a reload. 226 | 227 | ~~**ReferenceError: submit_enqueue is not defined**~~ 228 | 229 | ~~If you click the `Enqueue` button and nothing happen, and you find above error message in the browser F12 console, follow the steps in [this comment](https://github.com/ArtVentureX/sd-webui-agent-scheduler/issues/4#issuecomment-1575986274).~~ 230 | 231 | Update: This issue is now fixed. 232 | 233 | **TypeError: issubclass() arg 1 must be a class** 234 | Please update the extension, there's a chance it's already fixed. 235 | 236 | **TypeError: Object of type X is not JSON serializable** 237 | Please update the extension, it should be fixed already. If not, please fire an issue report with the list of installed extensions. 238 | 239 | For other errors, feel free to fire a new [Github issue](https://github.com/ArtVentureX/sd-webui-agent-scheduler/issues/new/choose). 240 | 241 | 242 | ## Contributing 243 | 244 | We welcome contributions to the Agent Scheduler Extension project! Please feel free to submit issues, bug reports, and feature requests through the GitHub repository. 245 | 246 | Please give us a ⭐ if you find this extension helpful! 247 | 248 | ## License 249 | 250 | This project is licensed under the Apache License 2.0. 
251 | 252 | ## Disclaimer 253 | 254 | The author(s) of this project are not responsible for any damages or legal issues arising from the use of this software. Users are solely responsible for ensuring that they comply with any applicable laws and regulations when using this software and assume all risks associated with its use. The author(s) are not responsible for any copyright violations or legal issues arising from the use of input or output content. 255 | 256 | --- 257 | 258 | ## CRAFTED BY THE PEOPLE BUILDING [**SIPHER//AGI**](https://sipheragi.com), [**PROTOGAIA**](https://protogaia.com), [**ATHERLABS**](https://atherlabs.com/) & [**SIPHER ODYSSEY**](http://playsipher.com/) 259 | 260 | ### About ProtoGAIA 261 | 262 | ProtoGAIA offers powerful collaboration features for Generative AI Image workflows. It is designed to help designers and creative professionals of all levels collaborate more efficiently, unleash their creativity, and have full transparency and tracking over the creation process. 263 | 264 | ### Current protoGAIA Features 265 | 266 | Like any open project that seeks to bring the powerful of Generative AI to the masses, ProtoGAIA offers the following key features: 267 | 268 | ✅ Seamless Access: available on desktop and mobile 269 | 270 | ✅ Powerful Macro Abilities that allowing the chaining of tasks, which is then packaged as Macro Command ready for AI Agent Automation 271 | 272 | ✅ Multiplayer & Collaborative UX. 
Strong collaboration features, such as real-time commenting and feedback, version control, and image/file/project sharing 273 | 274 | ✅ Rooms Chat for lively discussion between users and running Generative AI workflows right in the chat 275 | 276 | ✅ Custom Models Management including Lora, Diffusion Models, Controlnet Models and more 277 | 278 | ✅ Powerful semantic search capabilities 279 | 280 | ✅ Powerful AI driven chat box that can trigger quick Generative AI tasks and workflows 281 | 282 | ✅ Building on shoulders of Giants, leveraging A1111/Vladnmandic and other pioneers, provide collaboration process from Idea to Final Results in 1 platform 283 | 284 | ✅ Automation tooling for certain repeated tasks 285 | 286 | ✅ Secure and transparent, leveraging hasing and metadata to track the origin and history of models, loras, images to allow for tracability and ease of collaboration 287 | 288 | ✅ Personalize UIUX for both beginner and experienced users to quickly remix existing SD images by editing prompts and negative prompts, selecting new training models and output quality as desired 289 | 290 | ✅ Provenance Tracking for all models, loras, images to allow for tracability and ease of collaboration 291 | 292 | ✅ Custom UIUX for both beginner and experienced users to quickly remix existing SD images by editing prompts and negative prompts, selecting new training models and output quality as desired 293 | 294 | ✅ Articles and Tutorials for learning Generative AI 295 | 296 | ✅ Voting System for best generative AI images, models, recipes, macros etc 297 | 298 | ✅ Open sharing of generative AI images, models, recipes, macros etc via the Global Explore tab 299 | 300 | ### Target Audience 301 | 302 | ProtoGAIA is designed for the following target audiences: 303 | 304 | - Creators 305 | - Small Design Teams or Freelancers 306 | - Design Agencies & Game Studios 307 | - AI Agents 308 | 309 | ## 🎉 Stay Tuned for Updates 310 | 311 | We hope you find this extension to be useful. 
We will be adding new features and improvements over time as we enhance this extension to support our creative workflows. 312 | 313 | To stay up-to-date with the latest news and updates, be sure to follow us on GitHub and Twitter. We welcome your feedback and suggestions, and are excited to hear how AgentScheduler can help you streamline your workflow and unleash your creativity! 314 | -------------------------------------------------------------------------------- /agent_scheduler/api.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | import json 4 | import base64 5 | import requests 6 | import threading 7 | from uuid import uuid4 8 | from zipfile import ZipFile 9 | from pathlib import Path 10 | from secrets import compare_digest 11 | from typing import Optional, Dict, List 12 | from datetime import datetime, timezone 13 | from collections import defaultdict 14 | from gradio.routes import App 15 | from PIL import Image 16 | from fastapi import Depends 17 | from fastapi.responses import StreamingResponse 18 | from fastapi.security import HTTPBasic, HTTPBasicCredentials 19 | from fastapi.exceptions import HTTPException 20 | from pydantic import BaseModel 21 | 22 | from modules import shared, progress, sd_models, sd_samplers 23 | 24 | from .db import Task, TaskStatus, task_manager 25 | from .models import ( 26 | Txt2ImgApiTaskArgs, 27 | Img2ImgApiTaskArgs, 28 | QueueTaskResponse, 29 | QueueStatusResponse, 30 | HistoryResponse, 31 | TaskModel, 32 | UpdateTaskArgs, 33 | ) 34 | from .task_runner import TaskRunner 35 | from .helpers import log, request_with_retry 36 | from .task_helpers import encode_image_to_base64, img2img_image_args_by_mode 37 | 38 | 39 | def api_callback(callback_url: str, task_id: str, status: TaskStatus, images: list): 40 | files = [] 41 | for img in images: 42 | img_path = Path(img) 43 | ext = img_path.suffix.lower() 44 | content_type = f"image/{ext[1:]}" 45 | files.append( 46 | ( 
47 | "files", 48 | (img_path.name, open(os.path.abspath(img), "rb"), content_type), 49 | ) 50 | ) 51 | 52 | return requests.post( 53 | callback_url, 54 | timeout=5, 55 | data={"task_id": task_id, "status": status.value}, 56 | files=files, 57 | ) 58 | 59 | 60 | def on_task_finished( 61 | task_id: str, 62 | task: Task, 63 | status: TaskStatus = None, 64 | result: dict = None, 65 | **_, 66 | ): 67 | # handle api task callback 68 | if not task.api_task_callback: 69 | return 70 | 71 | upload = lambda: api_callback( 72 | task.api_task_callback, 73 | task_id=task_id, 74 | status=status, 75 | images=result["images"], 76 | ) 77 | 78 | request_with_retry(upload) 79 | 80 | 81 | def regsiter_apis(app: App, task_runner: TaskRunner): 82 | api_credentials = {} 83 | deps = None 84 | 85 | def auth(credentials: HTTPBasicCredentials = Depends(HTTPBasic())): 86 | if credentials.username in api_credentials: 87 | if compare_digest(credentials.password, api_credentials[credentials.username]): 88 | return True 89 | 90 | raise HTTPException( 91 | status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"} 92 | ) 93 | 94 | if shared.cmd_opts.api_auth: 95 | api_credentials = {} 96 | 97 | for cred in shared.cmd_opts.api_auth.split(","): 98 | user, password = cred.split(":") 99 | api_credentials[user] = password 100 | 101 | deps = [Depends(auth)] 102 | 103 | log.info("[AgentScheduler] Registering APIs") 104 | 105 | @app.get("/agent-scheduler/v1/samplers", response_model=List[str]) 106 | def get_samplers(): 107 | return [sampler[0] for sampler in sd_samplers.all_samplers] 108 | 109 | @app.get("/agent-scheduler/v1/sd-models", response_model=List[str]) 110 | def get_sd_models(): 111 | return [x.title for x in sd_models.checkpoints_list.values()] 112 | 113 | @app.post("/agent-scheduler/v1/queue/txt2img", response_model=QueueTaskResponse, dependencies=deps) 114 | def queue_txt2img(body: Txt2ImgApiTaskArgs): 115 | task_id = str(uuid4()) 116 | args = body.dict() 
117 | checkpoint = args.pop("checkpoint", None) 118 | vae = args.pop("vae", None) 119 | callback_url = args.pop("callback_url", None) 120 | task = task_runner.register_api_task( 121 | task_id, 122 | api_task_id=None, 123 | is_img2img=False, 124 | args=args, 125 | checkpoint=checkpoint, 126 | vae=vae, 127 | ) 128 | if callback_url: 129 | task.api_task_callback = callback_url 130 | task_manager.update_task(task) 131 | 132 | task_runner.execute_pending_tasks_threading() 133 | 134 | return QueueTaskResponse(task_id=task_id) 135 | 136 | @app.post("/agent-scheduler/v1/queue/img2img", response_model=QueueTaskResponse, dependencies=deps) 137 | def queue_img2img(body: Img2ImgApiTaskArgs): 138 | task_id = str(uuid4()) 139 | args = body.dict() 140 | checkpoint = args.pop("checkpoint", None) 141 | vae = args.pop("vae", None) 142 | callback_url = args.pop("callback_url", None) 143 | task = task_runner.register_api_task( 144 | task_id, 145 | api_task_id=None, 146 | is_img2img=True, 147 | args=args, 148 | checkpoint=checkpoint, 149 | vae=vae, 150 | ) 151 | if callback_url: 152 | task.api_task_callback = callback_url 153 | task_manager.update_task(task) 154 | 155 | task_runner.execute_pending_tasks_threading() 156 | 157 | return QueueTaskResponse(task_id=task_id) 158 | 159 | def format_task_args(task): 160 | task_args = TaskRunner.instance.parse_task_args(task, deserialization=False) 161 | named_args = task_args.named_args 162 | named_args["checkpoint"] = task_args.checkpoint 163 | # remove unused args to reduce payload size 164 | named_args.pop("alwayson_scripts", None) 165 | named_args.pop("script_args", None) 166 | named_args.pop("init_images", None) 167 | for image_args in img2img_image_args_by_mode.values(): 168 | for keys in image_args: 169 | named_args.pop(keys[0], None) 170 | return named_args 171 | 172 | @app.get("/agent-scheduler/v1/queue", response_model=QueueStatusResponse, dependencies=deps) 173 | def queue_status_api(limit: int = 20, offset: int = 0): 174 | 
current_task_id = progress.current_task 175 | total_pending_tasks = task_manager.count_tasks(status="pending") 176 | pending_tasks = task_manager.get_tasks(status=TaskStatus.PENDING, limit=limit, offset=offset) 177 | position = offset 178 | parsed_tasks = [] 179 | for task in pending_tasks: 180 | params = format_task_args(task) 181 | task_data = task.dict() 182 | task_data["params"] = params 183 | if task.id == current_task_id: 184 | task_data["status"] = "running" 185 | 186 | task_data["position"] = position 187 | parsed_tasks.append(TaskModel(**task_data)) 188 | position += 1 189 | 190 | return QueueStatusResponse( 191 | current_task_id=current_task_id, 192 | pending_tasks=parsed_tasks, 193 | total_pending_tasks=total_pending_tasks, 194 | paused=TaskRunner.instance.paused, 195 | ) 196 | 197 | @app.get("/agent-scheduler/v1/export") 198 | def export_queue(limit: int = 1000, offset: int = 0): 199 | pending_tasks = task_manager.get_tasks(status=TaskStatus.PENDING, limit=limit, offset=offset) 200 | pending_tasks = [Task.from_table(t).to_json() for t in pending_tasks] 201 | return pending_tasks 202 | 203 | class StringRequestBody(BaseModel): 204 | content: str 205 | 206 | @app.post("/agent-scheduler/v1/import") 207 | def import_queue(queue: StringRequestBody): 208 | try: 209 | objList = json.loads(queue.content) 210 | taskList: List[Task] = [] 211 | for obj in objList: 212 | if "id" not in obj or not obj["id"] or obj["id"] == "": 213 | obj["id"] = str(uuid4()) 214 | obj["result"] = None 215 | obj["status"] = TaskStatus.PENDING 216 | task = Task.from_json(obj) 217 | taskList.append(task) 218 | 219 | for task in taskList: 220 | exists = task_manager.get_task(task.id) 221 | if exists: 222 | task_manager.update_task(task) 223 | else: 224 | task_manager.add_task(task) 225 | return {"success": True, "message": "Queue imported"} 226 | except Exception as e: 227 | print(e) 228 | return {"success": False, "message": "Import Failed"} 229 | 230 | 
@app.get("/agent-scheduler/v1/history", response_model=HistoryResponse, dependencies=deps) 231 | def history_api(status: str = None, limit: int = 20, offset: int = 0): 232 | bookmarked = True if status == "bookmarked" else None 233 | if not status or status == "all" or bookmarked: 234 | status = [ 235 | TaskStatus.DONE, 236 | TaskStatus.FAILED, 237 | TaskStatus.INTERRUPTED, 238 | ] 239 | 240 | total = task_manager.count_tasks(status=status) 241 | tasks = task_manager.get_tasks( 242 | status=status, 243 | bookmarked=bookmarked, 244 | limit=limit, 245 | offset=offset, 246 | order="desc", 247 | ) 248 | parsed_tasks = [] 249 | for task in tasks: 250 | params = format_task_args(task) 251 | task_data = task.dict() 252 | task_data["params"] = params 253 | parsed_tasks.append(TaskModel(**task_data)) 254 | 255 | return HistoryResponse( 256 | total=total, 257 | tasks=parsed_tasks, 258 | ) 259 | 260 | @app.get("/agent-scheduler/v1/task/{id}", dependencies=deps) 261 | def get_task(id: str): 262 | task = task_manager.get_task(id) 263 | if task is None: 264 | return {"success": False, "message": "Task not found"} 265 | 266 | params = format_task_args(task) 267 | task_data = task.dict() 268 | task_data["params"] = params 269 | if task.id == progress.current_task: 270 | task_data["status"] = "running" 271 | if task_data["status"] == TaskStatus.PENDING: 272 | task_data["position"] = task_manager.get_task_position(id) 273 | 274 | return {"success": True, "data": TaskModel(**task_data)} 275 | 276 | @app.get("/agent-scheduler/v1/task/{id}/position", dependencies=deps) 277 | def get_task_position(id: str): 278 | task = task_manager.get_task(id) 279 | if task is None: 280 | return {"success": False, "message": "Task not found"} 281 | 282 | position = None if task.status != TaskStatus.PENDING else task_manager.get_task_position(id) 283 | return {"success": True, "data": {"status": task.status, "position": position}} 284 | 285 | @app.put("/agent-scheduler/v1/task/{id}", dependencies=deps) 
286 | def update_task(id: str, body: UpdateTaskArgs): 287 | task = task_manager.get_task(id) 288 | if task is None: 289 | return {"success": False, "message": "Task not found"} 290 | 291 | should_save = False 292 | if body.name is not None: 293 | task.name = body.name 294 | should_save = True 295 | 296 | if body.checkpoint or body.params: 297 | params: Dict = json.loads(task.params) 298 | if body.checkpoint is not None: 299 | params["checkpoint"] = body.checkpoint 300 | if body.checkpoint is not None: 301 | params["args"].update(body.params) 302 | 303 | task.params = json.dumps(params) 304 | should_save = True 305 | 306 | if should_save: 307 | task_manager.update_task(task) 308 | 309 | return {"success": True, "message": "Task updated."} 310 | 311 | @app.post("/agent-scheduler/v1/run/{id}", dependencies=deps, deprecated=True) 312 | @app.post("/agent-scheduler/v1/task/{id}/run", dependencies=deps) 313 | def run_task(id: str): 314 | if progress.current_task is not None: 315 | if progress.current_task == id: 316 | return {"success": False, "message": "Task is running"} 317 | else: 318 | # move task up in queue 319 | task_manager.prioritize_task(id, 0) 320 | return { 321 | "success": True, 322 | "message": "Task is scheduled to run next", 323 | } 324 | else: 325 | # run task 326 | task = task_manager.get_task(id) 327 | current_thread = threading.Thread( 328 | target=TaskRunner.instance.execute_task, 329 | args=( 330 | task, 331 | lambda: None, 332 | ), 333 | ) 334 | current_thread.daemon = True 335 | current_thread.start() 336 | 337 | return {"success": True, "message": "Task is executing"} 338 | 339 | @app.post("/agent-scheduler/v1/requeue/{id}", dependencies=deps, deprecated=True) 340 | @app.post("/agent-scheduler/v1/task/{id}/requeue", dependencies=deps) 341 | def requeue_task(id: str): 342 | task = task_manager.get_task(id) 343 | if task is None: 344 | return {"success": False, "message": "Task not found"} 345 | 346 | task.id = str(uuid4()) 347 | task.result = None 
348 | task.status = TaskStatus.PENDING 349 | task.bookmarked = False 350 | task.name = f"Copy of {task.name}" if task.name else None 351 | task_manager.add_task(task) 352 | task_runner.execute_pending_tasks_threading() 353 | 354 | return {"success": True, "message": "Task requeued"} 355 | 356 | @app.post("/agent-scheduler/v1/task/requeue-failed", dependencies=deps) 357 | def requeue_failed_tasks(): 358 | failed_tasks = task_manager.get_tasks(status=TaskStatus.FAILED) 359 | if (len(failed_tasks)) == 0: 360 | return {"success": False, "message": "No failed tasks"} 361 | 362 | for task in failed_tasks: 363 | task.status = TaskStatus.PENDING 364 | task.result = None 365 | task.priority = int(datetime.now(timezone.utc).timestamp() * 1000) 366 | task_manager.update_task(task) 367 | 368 | return {"success": True, "message": f"Requeued {len(failed_tasks)} failed tasks"} 369 | 370 | @app.post("/agent-scheduler/v1/delete/{id}", dependencies=deps, deprecated=True) 371 | @app.delete("/agent-scheduler/v1/task/{id}", dependencies=deps) 372 | def delete_task(id: str): 373 | if progress.current_task == id: 374 | shared.state.interrupt() 375 | task_runner.interrupted = id 376 | return {"success": True, "message": "Task interrupted"} 377 | 378 | task_manager.delete_task(id) 379 | return {"success": True, "message": "Task deleted"} 380 | 381 | @app.post("/agent-scheduler/v1/move/{id}/{over_id}", dependencies=deps, deprecated=True) 382 | @app.post("/agent-scheduler/v1/task/{id}/move/{over_id}", dependencies=deps) 383 | def move_task(id: str, over_id: str): 384 | task = task_manager.get_task(id) 385 | if task is None: 386 | return {"success": False, "message": "Task not found"} 387 | 388 | if over_id == "top": 389 | task_manager.prioritize_task(id, 0) 390 | return {"success": True, "message": "Task moved to top"} 391 | elif over_id == "bottom": 392 | task_manager.prioritize_task(id, -1) 393 | return {"success": True, "message": "Task moved to bottom"} 394 | else: 395 | over_task = 
task_manager.get_task(over_id) 396 | if over_task is None: 397 | return {"success": False, "message": "Task not found"} 398 | 399 | task_manager.prioritize_task(id, over_task.priority) 400 | return {"success": True, "message": "Task moved"} 401 | 402 | @app.post("/agent-scheduler/v1/bookmark/{id}", dependencies=deps, deprecated=True) 403 | @app.post("/agent-scheduler/v1/task/{id}/bookmark", dependencies=deps) 404 | def pin_task(id: str): 405 | task = task_manager.get_task(id) 406 | if task is None: 407 | return {"success": False, "message": "Task not found"} 408 | 409 | task.bookmarked = True 410 | task_manager.update_task(task) 411 | return {"success": True, "message": "Task bookmarked"} 412 | 413 | @app.post("/agent-scheduler/v1/unbookmark/{id}", dependencies=deps, deprecated=True) 414 | @app.post("/agent-scheduler/v1/task/{id}/unbookmark") 415 | def unpin_task(id: str): 416 | task = task_manager.get_task(id) 417 | if task is None: 418 | return {"success": False, "message": "Task not found"} 419 | 420 | task.bookmarked = False 421 | task_manager.update_task(task) 422 | return {"success": True, "message": "Task unbookmarked"} 423 | 424 | @app.post("/agent-scheduler/v1/rename/{id}", dependencies=deps, deprecated=True) 425 | @app.post("/agent-scheduler/v1/task/{id}/rename", dependencies=deps) 426 | def rename_task(id: str, name: str): 427 | task = task_manager.get_task(id) 428 | if task is None: 429 | return {"success": False, "message": "Task not found"} 430 | 431 | task.name = name 432 | task_manager.update_task(task) 433 | return {"success": True, "message": "Task renamed."} 434 | 435 | @app.get("/agent-scheduler/v1/results/{id}", dependencies=deps, deprecated=True) 436 | @app.get("/agent-scheduler/v1/task/{id}/results", dependencies=deps) 437 | def get_task_results(id: str, zip: Optional[bool] = False): 438 | task = task_manager.get_task(id) 439 | if task is None: 440 | return {"success": False, "message": "Task not found"} 441 | 442 | if task.status != 
TaskStatus.DONE: 443 | return {"success": False, "message": f"Task is {task.status}"} 444 | 445 | if task.result is None: 446 | return {"success": False, "message": "Task result is not available"} 447 | 448 | result: dict = json.loads(task.result) 449 | infotexts = result.get("infotexts", None) 450 | if infotexts is None: 451 | geninfo = result.get("geninfo", {}) 452 | infotexts = geninfo.get("infotexts", defaultdict(lambda: "")) 453 | 454 | if zip: 455 | zip_buffer = io.BytesIO() 456 | 457 | # Create a new zip file in the in-memory buffer 458 | with ZipFile(zip_buffer, "w") as zip_file: 459 | # Loop through the files in the directory and add them to the zip file 460 | for image in result["images"]: 461 | if Path(image).is_file(): 462 | zip_file.write(Path(image), Path(image).name) 463 | 464 | # Reset the buffer position to the beginning to avoid truncation issues 465 | zip_buffer.seek(0) 466 | 467 | # Return the in-memory buffer as a streaming response with the appropriate headers 468 | return StreamingResponse( 469 | zip_buffer, 470 | media_type="application/zip", 471 | headers={"Content-Disposition": f"attachment; filename=results-{id}.zip"}, 472 | ) 473 | else: 474 | data = [ 475 | { 476 | "image": encode_image_to_base64(Image.open(image)), 477 | "infotext": infotexts[i], 478 | } 479 | for i, image in enumerate(result["images"]) 480 | if Path(image).is_file() 481 | ] 482 | 483 | return {"success": True, "data": data} 484 | 485 | @app.post("/agent-scheduler/v1/pause", dependencies=deps, deprecated=True) 486 | @app.post("/agent-scheduler/v1/queue/pause", dependencies=deps) 487 | def pause_queue(): 488 | shared.opts.queue_paused = True 489 | return {"success": True, "message": "Queue paused."} 490 | 491 | @app.post("/agent-scheduler/v1/resume", dependencies=deps, deprecated=True) 492 | @app.post("/agent-scheduler/v1/queue/resume", dependencies=deps) 493 | def resume_queue(): 494 | shared.opts.queue_paused = False 495 | 
TaskRunner.instance.execute_pending_tasks_threading() 496 | return {"success": True, "message": "Queue resumed."} 497 | 498 | @app.post("/agent-scheduler/v1/queue/clear", dependencies=deps) 499 | def clear_queue(): 500 | task_manager.delete_tasks(status=TaskStatus.PENDING) 501 | return {"success": True, "message": "Queue cleared."} 502 | 503 | @app.post("/agent-scheduler/v1/history/clear", dependencies=deps) 504 | def clear_history(): 505 | task_manager.delete_tasks( 506 | status=[ 507 | TaskStatus.DONE, 508 | TaskStatus.FAILED, 509 | TaskStatus.INTERRUPTED, 510 | ] 511 | ) 512 | return {"success": True, "message": "History cleared."} 513 | 514 | task_runner.on_task_finished(on_task_finished) 515 | -------------------------------------------------------------------------------- /agent_scheduler/db/__init__.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from sqlalchemy import create_engine, inspect, text, String, Text 3 | 4 | from .base import Base, metadata, db_file 5 | from .app_state import AppStateKey, AppState, AppStateManager 6 | from .task import TaskStatus, Task, TaskManager 7 | 8 | version = "2" 9 | 10 | state_manager = AppStateManager() 11 | task_manager = TaskManager() 12 | 13 | 14 | def init(): 15 | engine = create_engine(f"sqlite:///{db_file}") 16 | 17 | metadata.create_all(engine) 18 | 19 | state_manager.set_value(AppStateKey.Version, version) 20 | # check if app state exists 21 | if state_manager.get_value(AppStateKey.QueueState) is None: 22 | # create app state 23 | state_manager.set_value(AppStateKey.QueueState, "running") 24 | 25 | inspector = inspect(engine) 26 | with engine.connect() as conn: 27 | task_columns = inspector.get_columns("task") 28 | # add result column 29 | if not any(col["name"] == "result" for col in task_columns): 30 | conn.execute(text("ALTER TABLE task ADD COLUMN result TEXT")) 31 | 32 | # add api_task_id column 33 | if not any(col["name"] == "api_task_id" 
for col in task_columns): 34 | conn.execute(text("ALTER TABLE task ADD COLUMN api_task_id VARCHAR(64)")) 35 | 36 | # add api_task_callback column 37 | if not any(col["name"] == "api_task_callback" for col in task_columns): 38 | conn.execute(text("ALTER TABLE task ADD COLUMN api_task_callback VARCHAR(255)")) 39 | 40 | # add name column 41 | if not any(col["name"] == "name" for col in task_columns): 42 | conn.execute(text("ALTER TABLE task ADD COLUMN name VARCHAR(255)")) 43 | 44 | # add bookmarked column 45 | if not any(col["name"] == "bookmarked" for col in task_columns): 46 | conn.execute(text("ALTER TABLE task ADD COLUMN bookmarked BOOLEAN DEFAULT FALSE")) 47 | 48 | params_column = next(col for col in task_columns if col["name"] == "params") 49 | if version > "1" and not isinstance(params_column["type"], Text): 50 | transaction = conn.begin() 51 | conn.execute( 52 | text( 53 | """ 54 | CREATE TABLE task_temp ( 55 | id VARCHAR(64) NOT NULL, 56 | type VARCHAR(20) NOT NULL, 57 | params TEXT NOT NULL, 58 | script_params BLOB NOT NULL, 59 | priority INTEGER NOT NULL, 60 | status VARCHAR(20) NOT NULL, 61 | created_at DATETIME DEFAULT (datetime('now')) NOT NULL, 62 | updated_at DATETIME DEFAULT (datetime('now')) NOT NULL, 63 | result TEXT, 64 | PRIMARY KEY (id) 65 | )""" 66 | ) 67 | ) 68 | conn.execute(text("INSERT INTO task_temp SELECT * FROM task")) 69 | conn.execute(text("DROP TABLE task")) 70 | conn.execute(text("ALTER TABLE task_temp RENAME TO task")) 71 | transaction.commit() 72 | 73 | conn.close() 74 | 75 | 76 | __all__ = [ 77 | "init", 78 | "Base", 79 | "metadata", 80 | "db_file", 81 | "AppStateKey", 82 | "AppState", 83 | "TaskStatus", 84 | "Task", 85 | "task_manager", 86 | "state_manager", 87 | ] 88 | -------------------------------------------------------------------------------- /agent_scheduler/db/app_state.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from typing import Union 3 | 4 | from 
sqlalchemy import Column, String 5 | from sqlalchemy.orm import Session 6 | 7 | from .base import BaseTableManager, Base 8 | 9 | 10 | class AppStateKey(str, Enum): 11 | Version = "version" 12 | QueueState = "queue_state" # paused or running 13 | 14 | 15 | class AppState: 16 | def __init__(self, key: str, value: str): 17 | self.key: str = key 18 | self.value: str = value 19 | 20 | @staticmethod 21 | def from_table(table: "AppStateTable"): 22 | return AppState(table.key, table.value) 23 | 24 | def to_table(self): 25 | return AppStateTable(key=self.key, value=self.value) 26 | 27 | 28 | class AppStateTable(Base): 29 | __tablename__ = "app_state" 30 | 31 | key = Column(String(64), primary_key=True) 32 | value = Column(String(255), nullable=True) 33 | 34 | def __repr__(self): 35 | return f"AppState(key={self.key!r}, value={self.value!r})" 36 | 37 | 38 | class AppStateManager(BaseTableManager): 39 | def get_value(self, key: str) -> Union[str, None]: 40 | session = Session(self.engine) 41 | try: 42 | result = session.get(AppStateTable, key) 43 | if result: 44 | return result.value 45 | else: 46 | return None 47 | except Exception as e: 48 | print(f"Exception getting value from database: {e}") 49 | raise e 50 | finally: 51 | session.close() 52 | 53 | def set_value(self, key: str, value: str): 54 | session = Session(self.engine) 55 | try: 56 | result = session.get(AppStateTable, key) 57 | if result: 58 | result.value = value 59 | else: 60 | result = AppStateTable(key=key, value=value) 61 | session.add(result) 62 | session.commit() 63 | except Exception as e: 64 | print(f"Exception setting value in database: {e}") 65 | raise e 66 | finally: 67 | session.close() 68 | 69 | def delete_value(self, key: str): 70 | session = Session(self.engine) 71 | try: 72 | result = session.get(AppStateTable, key) 73 | if result: 74 | session.delete(result) 75 | session.commit() 76 | except Exception as e: 77 | print(f"Exception deleting value from database: {e}") 78 | raise e 79 | finally: 80 | 
session.close() 81 | -------------------------------------------------------------------------------- /agent_scheduler/db/base.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from sqlalchemy import create_engine 4 | from sqlalchemy.schema import MetaData 5 | from sqlalchemy.orm import declarative_base 6 | 7 | from modules import scripts 8 | from modules import shared 9 | 10 | if hasattr(shared.cmd_opts, "agent_scheduler_sqlite_file"): 11 | # if relative path, join with basedir 12 | if not os.path.isabs(shared.cmd_opts.agent_scheduler_sqlite_file): 13 | db_file = os.path.join(scripts.basedir(), shared.cmd_opts.agent_scheduler_sqlite_file) 14 | else: 15 | db_file = os.path.abspath(shared.cmd_opts.agent_scheduler_sqlite_file) 16 | else: 17 | db_file = os.path.join(scripts.basedir(), "task_scheduler.sqlite3") 18 | 19 | print(f"Using sqlite file: {db_file}") 20 | 21 | 22 | Base = declarative_base() 23 | metadata: MetaData = Base.metadata 24 | 25 | class BaseTableManager: 26 | def __init__(self, engine = None): 27 | # Get the db connection object, making the file and tables if needed. 28 | try: 29 | self.engine = engine if engine else create_engine(f"sqlite:///{db_file}") 30 | except Exception as e: 31 | print(f"Exception connecting to database: {e}") 32 | raise e 33 | 34 | def get_engine(self): 35 | return self.engine 36 | 37 | # Commit and close the database connection. 
38 | def quit(self): 39 | self.engine.dispose() 40 | -------------------------------------------------------------------------------- /agent_scheduler/db/task.py: -------------------------------------------------------------------------------- 1 | import json 2 | import base64 3 | from enum import Enum 4 | from datetime import datetime, timezone 5 | from typing import Optional, Union, List, Dict 6 | 7 | from sqlalchemy import ( 8 | TypeDecorator, 9 | Column, 10 | String, 11 | Text, 12 | Integer, 13 | DateTime as DateTimeImpl, 14 | LargeBinary, 15 | Boolean, 16 | text, 17 | func, 18 | ) 19 | from sqlalchemy.orm import Session 20 | 21 | from .base import BaseTableManager, Base 22 | from ..models import TaskModel 23 | 24 | 25 | class DateTime(TypeDecorator): 26 | impl = DateTimeImpl 27 | cache_ok = True 28 | 29 | def process_bind_param(self, value: Optional[datetime], _): 30 | if value is None: 31 | return None 32 | return value.astimezone(timezone.utc) 33 | 34 | def process_result_value(self, value: Optional[datetime], _): 35 | if value is None: 36 | return None 37 | if value.tzinfo is None: 38 | return value.replace(tzinfo=timezone.utc) 39 | return value.astimezone(timezone.utc) 40 | 41 | 42 | class TaskStatus(str, Enum): 43 | PENDING = "pending" 44 | RUNNING = "running" 45 | DONE = "done" 46 | FAILED = "failed" 47 | INTERRUPTED = "interrupted" 48 | 49 | 50 | class Task(TaskModel): 51 | script_params: bytes = None 52 | params: str 53 | 54 | def __init__(self, **kwargs): 55 | priority = kwargs.pop("priority", int(datetime.now(timezone.utc).timestamp() * 1000)) 56 | super().__init__(priority=priority, **kwargs) 57 | 58 | class Config(TaskModel.__config__): 59 | exclude = ["script_params"] 60 | 61 | @staticmethod 62 | def from_table(table: "TaskTable"): 63 | return Task( 64 | id=table.id, 65 | api_task_id=table.api_task_id, 66 | api_task_callback=table.api_task_callback, 67 | name=table.name, 68 | type=table.type, 69 | params=table.params, 70 | 
script_params=table.script_params, 71 | priority=table.priority, 72 | status=table.status, 73 | result=table.result, 74 | bookmarked=table.bookmarked, 75 | created_at=table.created_at, 76 | updated_at=table.updated_at, 77 | ) 78 | 79 | def to_table(self): 80 | return TaskTable( 81 | id=self.id, 82 | api_task_id=self.api_task_id, 83 | api_task_callback=self.api_task_callback, 84 | name=self.name, 85 | type=self.type, 86 | params=self.params, 87 | script_params=self.script_params, 88 | priority=self.priority, 89 | status=self.status, 90 | result=self.result, 91 | bookmarked=self.bookmarked, 92 | ) 93 | 94 | def from_json(json_obj: Dict): 95 | return Task( 96 | id=json_obj.get("id"), 97 | api_task_id=json_obj.get("api_task_id", None), 98 | api_task_callback=json_obj.get("api_task_callback", None), 99 | name=json_obj.get("name", None), 100 | type=json_obj.get("type"), 101 | status=json_obj.get("status", TaskStatus.PENDING), 102 | params=json.dumps(json_obj.get("params")), 103 | script_params=base64.b64decode(json_obj.get("script_params")), 104 | priority=json_obj.get("priority", int(datetime.now(timezone.utc).timestamp() * 1000)), 105 | result=json_obj.get("result", None), 106 | bookmarked=json_obj.get("bookmarked", False), 107 | created_at=datetime.fromtimestamp(json_obj.get("created_at", datetime.now(timezone.utc).timestamp())), 108 | updated_at=datetime.fromtimestamp(json_obj.get("updated_at", datetime.now(timezone.utc).timestamp())), 109 | ) 110 | 111 | def to_json(self): 112 | return { 113 | "id": self.id, 114 | "api_task_id": self.api_task_id, 115 | "api_task_callback": self.api_task_callback, 116 | "name": self.name, 117 | "type": self.type, 118 | "status": self.status, 119 | "params": json.loads(self.params), 120 | "script_params": base64.b64encode(self.script_params).decode("utf-8"), 121 | "priority": self.priority, 122 | "result": self.result, 123 | "bookmarked": self.bookmarked, 124 | "created_at": int(self.created_at.timestamp()), 125 | "updated_at": 
int(self.updated_at.timestamp()), 126 | } 127 | 128 | 129 | class TaskTable(Base): 130 | __tablename__ = "task" 131 | 132 | id = Column(String(64), primary_key=True) 133 | api_task_id = Column(String(64), nullable=True) 134 | api_task_callback = Column(String(255), nullable=True) 135 | name = Column(String(255), nullable=True) 136 | type = Column(String(20), nullable=False) # txt2img or img2txt 137 | params = Column(Text, nullable=False) # task args 138 | script_params = Column(LargeBinary, nullable=False) # script args 139 | priority = Column(Integer, nullable=False) 140 | status = Column(String(20), nullable=False, default="pending") # pending, running, done, failed 141 | result = Column(Text) # task result 142 | bookmarked = Column(Boolean, nullable=True, default=False) 143 | created_at = Column( 144 | DateTime, 145 | nullable=False, 146 | server_default=text("(datetime('now'))"), 147 | ) 148 | updated_at = Column( 149 | DateTime, 150 | nullable=False, 151 | server_default=text("(datetime('now'))"), 152 | onupdate=text("(datetime('now'))"), 153 | ) 154 | 155 | def __repr__(self): 156 | return f"Task(id={self.id!r}, type={self.type!r}, params={self.params!r}, status={self.status!r}, created_at={self.created_at!r})" 157 | 158 | 159 | class TaskManager(BaseTableManager): 160 | def get_task(self, id: str) -> Union[TaskTable, None]: 161 | session = Session(self.engine) 162 | try: 163 | task = session.get(TaskTable, id) 164 | 165 | return Task.from_table(task) if task else None 166 | except Exception as e: 167 | print(f"Exception getting task from database: {e}") 168 | raise e 169 | finally: 170 | session.close() 171 | 172 | def get_task_position(self, id: str) -> int: 173 | session = Session(self.engine) 174 | try: 175 | task = session.get(TaskTable, id) 176 | if task: 177 | return ( 178 | session.query(func.count(TaskTable.id)) 179 | .filter(TaskTable.status == TaskStatus.PENDING) 180 | .filter(TaskTable.priority < task.priority) 181 | .scalar() 182 | ) 183 | else: 
184 | raise Exception(f"Task with id {id} not found") 185 | except Exception as e: 186 | print(f"Exception getting task position from database: {e}") 187 | raise e 188 | finally: 189 | session.close() 190 | 191 | def get_tasks( 192 | self, 193 | type: str = None, 194 | status: Union[str, List[str]] = None, 195 | bookmarked: bool = None, 196 | api_task_id: str = None, 197 | limit: int = None, 198 | offset: int = None, 199 | order: str = "asc", 200 | ) -> List[TaskTable]: 201 | session = Session(self.engine) 202 | try: 203 | query = session.query(TaskTable) 204 | if type: 205 | query = query.filter(TaskTable.type == type) 206 | 207 | if status is not None: 208 | if isinstance(status, list): 209 | query = query.filter(TaskTable.status.in_(status)) 210 | else: 211 | query = query.filter(TaskTable.status == status) 212 | 213 | if api_task_id: 214 | query = query.filter(TaskTable.api_task_id == api_task_id) 215 | 216 | if bookmarked == True: 217 | query = query.filter(TaskTable.bookmarked == bookmarked) 218 | else: 219 | query = query.order_by(TaskTable.bookmarked.asc()) 220 | 221 | query = query.order_by(TaskTable.priority.asc() if order == "asc" else TaskTable.priority.desc()) 222 | 223 | if limit: 224 | query = query.limit(limit) 225 | 226 | if offset: 227 | query = query.offset(offset) 228 | 229 | all = query.all() 230 | return [Task.from_table(t) for t in all] 231 | except Exception as e: 232 | print(f"Exception getting tasks from database: {e}") 233 | raise e 234 | finally: 235 | session.close() 236 | 237 | def count_tasks( 238 | self, 239 | type: str = None, 240 | status: Union[str, List[str]] = None, 241 | api_task_id: str = None, 242 | ) -> int: 243 | session = Session(self.engine) 244 | try: 245 | query = session.query(TaskTable) 246 | if type: 247 | query = query.filter(TaskTable.type == type) 248 | 249 | if status is not None: 250 | if isinstance(status, list): 251 | query = query.filter(TaskTable.status.in_(status)) 252 | else: 253 | query = 
query.filter(TaskTable.status == status) 254 | 255 | if api_task_id: 256 | query = query.filter(TaskTable.api_task_id == api_task_id) 257 | 258 | return query.count() 259 | except Exception as e: 260 | print(f"Exception counting tasks from database: {e}") 261 | raise e 262 | finally: 263 | session.close() 264 | 265 | def add_task(self, task: Task) -> TaskTable: 266 | session = Session(self.engine) 267 | try: 268 | item = task.to_table() 269 | session.add(item) 270 | session.commit() 271 | return task 272 | except Exception as e: 273 | print(f"Exception adding task to database: {e}") 274 | raise e 275 | finally: 276 | session.close() 277 | 278 | def update_task(self, task: Task) -> TaskTable: 279 | session = Session(self.engine) 280 | try: 281 | current = session.get(TaskTable, task.id) 282 | if current is None: 283 | raise Exception(f"Task with id {id} not found") 284 | 285 | session.merge(task.to_table()) 286 | session.commit() 287 | return task 288 | 289 | except Exception as e: 290 | print(f"Exception updating task in database: {e}") 291 | raise e 292 | finally: 293 | session.close() 294 | 295 | def prioritize_task(self, id: str, priority: int) -> TaskTable: 296 | """0 means move to top, -1 means move to bottom, otherwise set the exact priority""" 297 | 298 | session = Session(self.engine) 299 | try: 300 | result = session.get(TaskTable, id) 301 | if result: 302 | if priority == 0: 303 | result.priority = self.__get_min_priority(status=TaskStatus.PENDING) - 1 304 | elif priority == -1: 305 | result.priority = int(datetime.now(timezone.utc).timestamp() * 1000) 306 | else: 307 | self.__move_tasks_down(priority) 308 | session.execute(text("SELECT 1")) 309 | result.priority = priority 310 | 311 | session.commit() 312 | return result 313 | else: 314 | raise Exception(f"Task with id {id} not found") 315 | except Exception as e: 316 | print(f"Exception updating task in database: {e}") 317 | raise e 318 | finally: 319 | session.close() 320 | 321 | def delete_task(self, 
id: str): 322 | session = Session(self.engine) 323 | try: 324 | result = session.get(TaskTable, id) 325 | if result: 326 | session.delete(result) 327 | session.commit() 328 | else: 329 | raise Exception(f"Task with id {id} not found") 330 | except Exception as e: 331 | print(f"Exception deleting task from database: {e}") 332 | raise e 333 | finally: 334 | session.close() 335 | 336 | def delete_tasks( 337 | self, 338 | before: datetime = None, 339 | status: Union[str, List[str]] = [ 340 | TaskStatus.DONE, 341 | TaskStatus.FAILED, 342 | TaskStatus.INTERRUPTED, 343 | ], 344 | ): 345 | session = Session(self.engine) 346 | try: 347 | query = session.query(TaskTable).filter(TaskTable.bookmarked == False) 348 | 349 | if before: 350 | query = query.filter(TaskTable.created_at < before) 351 | 352 | if status is not None: 353 | if isinstance(status, list): 354 | query = query.filter(TaskTable.status.in_(status)) 355 | else: 356 | query = query.filter(TaskTable.status == status) 357 | 358 | deleted_rows = query.delete() 359 | session.commit() 360 | 361 | return deleted_rows 362 | except Exception as e: 363 | print(f"Exception deleting tasks from database: {e}") 364 | raise e 365 | finally: 366 | session.close() 367 | 368 | def __get_min_priority(self, status: str = None) -> int: 369 | session = Session(self.engine) 370 | try: 371 | query = session.query(func.min(TaskTable.priority)) 372 | if status is not None: 373 | query = query.filter(TaskTable.status == status) 374 | 375 | min_priority = query.scalar() 376 | return min_priority if min_priority else 0 377 | except Exception as e: 378 | print(f"Exception getting min priority from database: {e}") 379 | raise e 380 | finally: 381 | session.close() 382 | 383 | def __move_tasks_down(self, priority: int): 384 | session = Session(self.engine) 385 | try: 386 | session.query(TaskTable).filter(TaskTable.priority >= priority).update( 387 | {TaskTable.priority: TaskTable.priority + 1} 388 | ) 389 | session.commit() 390 | except 
Exception as e: 391 | print(f"Exception moving tasks down in database: {e}") 392 | raise e 393 | finally: 394 | session.close() 395 | -------------------------------------------------------------------------------- /agent_scheduler/helpers.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import abc 4 | import atexit 5 | import time 6 | import logging 7 | import platform 8 | import requests 9 | import traceback 10 | from typing import Callable, List, NoReturn 11 | 12 | import gradio as gr 13 | from gradio.blocks import Block, BlockContext 14 | 15 | is_windows = platform.system() == "Windows" 16 | is_macos = platform.system() == "Darwin" 17 | 18 | if logging.getLogger().hasHandlers(): 19 | log = logging.getLogger("sd") 20 | else: 21 | import copy 22 | class ColoredFormatter(logging.Formatter): 23 | COLORS = { 24 | "DEBUG": "\033[0;36m", # CYAN 25 | "INFO": "\033[0;32m", # GREEN 26 | "WARNING": "\033[0;33m", # YELLOW 27 | "ERROR": "\033[0;31m", # RED 28 | "CRITICAL": "\033[0;37;41m", # WHITE ON RED 29 | "RESET": "\033[0m", # RESET COLOR 30 | } 31 | 32 | def format(self, record): 33 | colored_record = copy.copy(record) 34 | levelname = colored_record.levelname 35 | seq = self.COLORS.get(levelname, self.COLORS["RESET"]) 36 | colored_record.levelname = f"{seq}{levelname}{self.COLORS['RESET']}" 37 | return super().format(colored_record) 38 | 39 | # Create a new logger 40 | logger = logging.getLogger("AgentScheduler") 41 | logger.propagate = False 42 | 43 | # Add handler if we don't have one. 
44 | if not logger.handlers: 45 | handler = logging.StreamHandler(sys.stdout) 46 | handler.setFormatter(ColoredFormatter("%(levelname)s - %(message)s")) 47 | logger.addHandler(handler) 48 | 49 | # Configure logger 50 | loglevel = logging.INFO 51 | logger.setLevel(loglevel) 52 | 53 | log = logger 54 | 55 | 56 | class Singleton(abc.ABCMeta, type): 57 | """ 58 | Singleton metaclass for ensuring only one instance of a class. 59 | """ 60 | 61 | _instances = {} 62 | 63 | def __call__(cls, *args, **kwargs): 64 | """Call method for the singleton metaclass.""" 65 | if cls not in cls._instances: 66 | cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) 67 | return cls._instances[cls] 68 | 69 | 70 | def compare_components_with_ids(components: List[Block], ids: List[int]): 71 | return len(components) == len(ids) and all( 72 | component._id == _id for component, _id in zip(components, ids) 73 | ) 74 | 75 | 76 | def get_component_by_elem_id(root: Block, elem_id: str): 77 | if root.elem_id == elem_id: 78 | return root 79 | 80 | elem = None 81 | if isinstance(root, BlockContext): 82 | for block in root.children: 83 | elem = get_component_by_elem_id(block, elem_id) 84 | if elem is not None: 85 | break 86 | 87 | return elem 88 | 89 | 90 | def get_components_by_ids(root: Block, ids: List[int]): 91 | components: List[Block] = [] 92 | 93 | if root._id in ids: 94 | components.append(root) 95 | ids = [_id for _id in ids if _id != root._id] 96 | 97 | if isinstance(root, BlockContext): 98 | for block in root.children: 99 | components.extend(get_components_by_ids(block, ids)) 100 | 101 | return components 102 | 103 | 104 | def detect_control_net(root: gr.Blocks, submit: gr.Button): 105 | UiControlNetUnit = None 106 | 107 | dependencies: List[dict] = [ 108 | x 109 | for x in root.dependencies 110 | if x["trigger"] == "click" and submit._id in x["targets"] 111 | ] 112 | for d in dependencies: 113 | if len(d["outputs"]) == 1: 114 | outputs = get_components_by_ids(root, 
d["outputs"]) 115 | output = outputs[0] 116 | if ( 117 | isinstance(output, gr.State) 118 | and type(output.value).__name__ == "UiControlNetUnit" 119 | ): 120 | UiControlNetUnit = type(output.value) 121 | 122 | return UiControlNetUnit 123 | 124 | 125 | def get_dict_attribute(dict_inst: dict, name_string: str, default=None): 126 | nested_keys = name_string.split(".") 127 | value = dict_inst 128 | 129 | for key in nested_keys: 130 | value = value.get(key, None) 131 | 132 | if value is None: 133 | return default 134 | 135 | return value 136 | 137 | 138 | def set_dict_attribute(dict_inst: dict, name_string: str, value): 139 | """ 140 | Set an attribute to a dictionary using dot notation. 141 | If the attribute does not already exist, it will create a nested dictionary. 142 | 143 | Parameters: 144 | - dict_inst: the dictionary instance to set the attribute 145 | - name_string: the attribute name in dot notation (ex: 'attribute.name') 146 | - value: the value to set for the attribute 147 | 148 | Returns: 149 | None 150 | """ 151 | # Split the attribute names by dot 152 | name_list = name_string.split(".") 153 | 154 | # Traverse the dictionary and create a nested dictionary if necessary 155 | current_dict = dict_inst 156 | for name in name_list[:-1]: 157 | if name not in current_dict: 158 | current_dict[name] = {} 159 | current_dict = current_dict[name] 160 | 161 | # Set the final attribute to its value 162 | current_dict[name_list[-1]] = value 163 | 164 | 165 | def request_with_retry( 166 | make_request: Callable[[], requests.Response], 167 | max_try: int = 3, 168 | retries: int = 0, 169 | ): 170 | try: 171 | res = make_request() 172 | if res.status_code > 400: 173 | raise Exception(res.text) 174 | 175 | return True 176 | except requests.exceptions.ConnectionError: 177 | log.error("[ArtVenture] Connection error while uploading result") 178 | if retries >= max_try - 1: 179 | return False 180 | 181 | time.sleep(1) 182 | log.info(f"[ArtVenture] Retrying {retries + 1}...") 
183 | return request_with_retry( 184 | make_request, 185 | max_try=max_try, 186 | retries=retries + 1, 187 | ) 188 | except Exception as e: 189 | log.error("[ArtVenture] Error while uploading result") 190 | log.error(e) 191 | log.debug(traceback.format_exc()) 192 | return False 193 | 194 | 195 | def _exit(status: int) -> NoReturn: 196 | try: 197 | atexit._run_exitfuncs() 198 | except: 199 | pass 200 | sys.stdout.flush() 201 | sys.stderr.flush() 202 | os._exit(status) 203 | -------------------------------------------------------------------------------- /agent_scheduler/models.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime, timezone 2 | from typing import Optional, List, Any, Dict 3 | from pydantic import BaseModel, Field 4 | 5 | from modules import sd_samplers 6 | from modules.api.models import ( 7 | StableDiffusionTxt2ImgProcessingAPI, 8 | StableDiffusionImg2ImgProcessingAPI, 9 | ) 10 | 11 | 12 | def convert_datetime_to_iso_8601_with_z_suffix(dt: datetime) -> str: 13 | return dt.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z" if dt else None 14 | 15 | 16 | def transform_to_utc_datetime(dt: datetime) -> datetime: 17 | return dt.astimezone(tz=timezone.utc) 18 | 19 | 20 | class QueueStatusAPI(BaseModel): 21 | limit: Optional[int] = Field(title="Limit", description="The maximum number of tasks to return", default=20) 22 | offset: Optional[int] = Field(title="Offset", description="The offset of the tasks to return", default=0) 23 | 24 | 25 | class TaskModel(BaseModel): 26 | id: str = Field(title="Task Id") 27 | api_task_id: Optional[str] = Field(title="API Task Id", default=None) 28 | api_task_callback: Optional[str] = Field(title="API Task Callback", default=None) 29 | name: Optional[str] = Field(title="Task Name") 30 | type: str = Field(title="Task Type", description="Either txt2img or img2img") 31 | status: str = Field( 32 | "pending", 33 | title="Task Status", 34 | description="Either pending, 
running, done or failed", 35 | ) 36 | params: Dict[str, Any] = Field(title="Task Parameters", description="The parameters of the task in JSON format") 37 | priority: Optional[int] = Field(title="Task Priority") 38 | position: Optional[int] = Field(title="Task Position") 39 | result: Optional[str] = Field(title="Task Result", description="The result of the task in JSON format") 40 | bookmarked: Optional[bool] = Field(title="Is task bookmarked") 41 | created_at: Optional[datetime] = Field( 42 | title="Task Created At", 43 | description="The time when the task was created", 44 | default=None, 45 | ) 46 | updated_at: Optional[datetime] = Field( 47 | title="Task Updated At", 48 | description="The time when the task was updated", 49 | default=None, 50 | ) 51 | 52 | 53 | class Txt2ImgApiTaskArgs(StableDiffusionTxt2ImgProcessingAPI): 54 | checkpoint: Optional[str] = Field( 55 | None, 56 | title="Custom checkpoint.", 57 | description="Custom checkpoint hash. If not specified, the latest checkpoint will be used.", 58 | ) 59 | vae: Optional[str] = Field( 60 | None, 61 | title="Custom VAE.", 62 | description="Custom VAE. If not specified, the current VAE will be used.", 63 | ) 64 | sampler_index: Optional[str] = Field(sd_samplers.samplers[0].name, title="Sampler name", alias="sampler_name") 65 | callback_url: Optional[str] = Field( 66 | None, 67 | title="Callback URL", 68 | description="The callback URL to send the result to.", 69 | ) 70 | 71 | class Config(StableDiffusionTxt2ImgProcessingAPI.__config__): 72 | @staticmethod 73 | def schema_extra(schema: Dict[str, Any], model) -> None: 74 | props = schema.get("properties", {}) 75 | props.pop("send_images", None) 76 | props.pop("save_images", None) 77 | 78 | 79 | class Img2ImgApiTaskArgs(StableDiffusionImg2ImgProcessingAPI): 80 | checkpoint: Optional[str] = Field( 81 | None, 82 | title="Custom checkpoint.", 83 | description="Custom checkpoint hash. 
If not specified, the latest checkpoint will be used.", 84 | ) 85 | vae: Optional[str] = Field( 86 | None, 87 | title="Custom VAE.", 88 | description="Custom VAE. If not specified, the current VAE will be used.", 89 | ) 90 | sampler_index: Optional[str] = Field(sd_samplers.samplers[0].name, title="Sampler name", alias="sampler_name") 91 | callback_url: Optional[str] = Field( 92 | None, 93 | title="Callback URL", 94 | description="The callback URL to send the result to.", 95 | ) 96 | 97 | class Config(StableDiffusionImg2ImgProcessingAPI.__config__): 98 | @staticmethod 99 | def schema_extra(schema: Dict[str, Any], model) -> None: 100 | props = schema.get("properties", {}) 101 | props.pop("send_images", None) 102 | props.pop("save_images", None) 103 | 104 | 105 | class QueueTaskResponse(BaseModel): 106 | task_id: str = Field(title="Task Id") 107 | 108 | 109 | class QueueStatusResponse(BaseModel): 110 | current_task_id: Optional[str] = Field(title="Current Task Id", description="The on progress task id") 111 | pending_tasks: List[TaskModel] = Field(title="Pending Tasks", description="The pending tasks in the queue") 112 | total_pending_tasks: int = Field(title="Queue length", description="The total pending tasks in the queue") 113 | paused: bool = Field(title="Paused", description="Whether the queue is paused") 114 | 115 | class Config: 116 | json_encoders = {datetime: lambda dt: int(dt.timestamp() * 1e3)} 117 | 118 | 119 | class HistoryResponse(BaseModel): 120 | tasks: List[TaskModel] = Field(title="Tasks") 121 | total: int = Field(title="Task count") 122 | 123 | class Config: 124 | json_encoders = {datetime: lambda dt: int(dt.timestamp() * 1e3)} 125 | 126 | 127 | class UpdateTaskArgs(BaseModel): 128 | name: Optional[str] = Field(title="Task Name") 129 | checkpoint: Optional[str] 130 | params: Optional[Dict[str, Any]] = Field( 131 | title="Task Parameters", description="The parameters of the task in JSON format" 132 | ) 133 | 
-------------------------------------------------------------------------------- /agent_scheduler/task_helpers.py: -------------------------------------------------------------------------------- 1 | import io 2 | import zlib 3 | import base64 4 | import pickle 5 | import inspect 6 | import requests 7 | import numpy as np 8 | import torch 9 | from typing import Union, List, Dict 10 | from enum import Enum 11 | from PIL import Image, ImageOps, ImageChops, ImageEnhance, ImageFilter, PngImagePlugin 12 | from numpy import ndarray 13 | from torch import Tensor 14 | 15 | from modules import sd_samplers, scripts, shared, sd_vae, images, txt2img, img2img 16 | from modules.generation_parameters_copypaste import create_override_settings_dict 17 | from modules.sd_models import CheckpointInfo, get_closet_checkpoint_match 18 | from modules.api.models import ( 19 | StableDiffusionTxt2ImgProcessingAPI, 20 | StableDiffusionImg2ImgProcessingAPI, 21 | ) 22 | 23 | from .helpers import log, get_dict_attribute 24 | 25 | img2img_image_args_by_mode: Dict[int, List[List[str]]] = { 26 | 0: [["init_img"]], 27 | 1: [["sketch"]], 28 | 2: [["init_img_with_mask", "image"], ["init_img_with_mask", "mask"]], 29 | 3: [["inpaint_color_sketch"], ["inpaint_color_sketch_orig"]], 30 | 4: [["init_img_inpaint"], ["init_mask_inpaint"]], 31 | } 32 | 33 | 34 | def get_script_by_name(script_name: str, is_img2img: bool = False, is_always_on: bool = False) -> scripts.Script: 35 | script_runner = scripts.scripts_img2img if is_img2img else scripts.scripts_txt2img 36 | available_scripts = script_runner.alwayson_scripts if is_always_on else script_runner.selectable_scripts 37 | 38 | return next( 39 | (s for s in available_scripts if s.title().lower() == script_name.lower()), 40 | None, 41 | ) 42 | 43 | 44 | def load_image_from_url(url: str): 45 | try: 46 | response = requests.get(url) 47 | buffer = io.BytesIO(response.content) 48 | return Image.open(buffer) 49 | except Exception as e: 50 | 
log.error(f"[AgentScheduler] Error downloading image from url: {e}") 51 | return None 52 | 53 | 54 | def encode_image_to_base64(image): 55 | if isinstance(image, np.ndarray): 56 | image = Image.fromarray(image.astype("uint8")) 57 | elif isinstance(image, str): 58 | if image.startswith("http://") or image.startswith("https://"): 59 | image = load_image_from_url(image) 60 | 61 | if not isinstance(image, Image.Image): 62 | return image 63 | 64 | geninfo, _ = images.read_info_from_image(image) 65 | pnginfo = PngImagePlugin.PngInfo() 66 | if geninfo: 67 | pnginfo.add_text("parameters", geninfo) 68 | 69 | with io.BytesIO() as output_bytes: 70 | if geninfo: 71 | image.save(output_bytes, format="PNG", pnginfo=pnginfo) 72 | else: 73 | image.save(output_bytes, format="PNG") # remove pnginfo to save space 74 | bytes_data = output_bytes.getvalue() 75 | return "data:image/png;base64," + base64.b64encode(bytes_data).decode("utf-8") 76 | 77 | 78 | def serialize_image(image): 79 | if isinstance(image, np.ndarray): 80 | shape = image.shape 81 | dtype = image.dtype 82 | data = base64.b64encode(zlib.compress(image.tobytes())).decode() 83 | return {"shape": shape, "data": data, "cls": "ndarray", "dtype": str(dtype)} 84 | elif isinstance(image, torch.Tensor): 85 | shape = image.shape 86 | dtype = image.dtype 87 | data = base64.b64encode(zlib.compress(image.detach().numpy().tobytes())).decode() 88 | return { 89 | "shape": shape, 90 | "data": data, 91 | "cls": "Tensor", 92 | "device": image.device.type, 93 | "dtype": str(dtype), 94 | } 95 | elif isinstance(image, Image.Image): 96 | size = image.size 97 | mode = image.mode 98 | data = base64.b64encode(zlib.compress(image.tobytes())).decode() 99 | return { 100 | "size": size, 101 | "mode": mode, 102 | "data": data, 103 | "cls": "Image", 104 | } 105 | else: 106 | return image 107 | 108 | 109 | def deserialize_image(image_str): 110 | if isinstance(image_str, dict) and image_str.get("cls", None): 111 | cls = image_str["cls"] 112 | data = 
zlib.decompress(base64.b64decode(image_str["data"])) 113 | 114 | if cls == "ndarray": 115 | # warn if required fields are missing 116 | if image_str.get("dtype", None) is None: 117 | log.warning(f"Missing dtype for ndarray") 118 | shape = tuple(image_str["shape"]) 119 | dtype = np.dtype(image_str.get("dtype", "uint8")) 120 | image = np.frombuffer(data, dtype=dtype) 121 | return image.reshape(shape) 122 | elif cls == "Tensor": 123 | if image_str.get("device", None) is None: 124 | log.warning(f"Missing device for Tensor") 125 | shape = tuple(image_str["shape"]) 126 | dtype = np.dtype(image_str.get("dtype", "uint8")) 127 | image_np = np.frombuffer(data, dtype=dtype) 128 | return torch.from_numpy(image_np.reshape(shape)).to(device=image_str.get("device", "cpu")) 129 | else: 130 | size = tuple(image_str["size"]) 131 | mode = image_str["mode"] 132 | return Image.frombytes(mode, size, data) 133 | else: 134 | return image_str 135 | 136 | 137 | def serialize_img2img_image_args(args: Dict): 138 | for mode, image_args in img2img_image_args_by_mode.items(): 139 | for keys in image_args: 140 | if mode != args["mode"]: 141 | # set None to unused image args to save space 142 | args[keys[0]] = None 143 | elif len(keys) == 1: 144 | image = args.get(keys[0], None) 145 | args[keys[0]] = serialize_image(image) 146 | else: 147 | value = args.get(keys[0], {}) 148 | image = value.get(keys[1], None) 149 | value[keys[1]] = serialize_image(image) 150 | args[keys[0]] = value 151 | 152 | 153 | def deserialize_img2img_image_args(args: Dict): 154 | for mode, image_args in img2img_image_args_by_mode.items(): 155 | if mode != args["mode"]: 156 | continue 157 | 158 | for keys in image_args: 159 | if len(keys) == 1: 160 | image = args.get(keys[0], None) 161 | args[keys[0]] = deserialize_image(image) 162 | else: 163 | value = args.get(keys[0], {}) 164 | image = value.get(keys[1], None) 165 | value[keys[1]] = deserialize_image(image) 166 | args[keys[0]] = value 167 | 168 | 169 | def 
serialize_controlnet_args(cnet_unit): 170 | args: Dict = cnet_unit.__dict__ 171 | serialized_args = {"is_cnet": True} 172 | for k, v in args.items(): 173 | if isinstance(v, Enum): 174 | serialized_args[k] = v.value 175 | else: 176 | serialized_args[k] = v 177 | 178 | return serialized_args 179 | 180 | 181 | def deserialize_controlnet_args(args: Dict): 182 | new_args = args.copy() 183 | new_args.pop("is_cnet", None) 184 | new_args.pop("is_ui", None) 185 | 186 | return new_args 187 | 188 | 189 | def serialize_script_args(script_args: List): 190 | # convert UiControlNetUnit to dict to make it serializable 191 | for i, a in enumerate(script_args): 192 | if type(a).__name__ == "UiControlNetUnit": 193 | script_args[i] = serialize_controlnet_args(a) 194 | 195 | return zlib.compress(pickle.dumps(script_args)) 196 | 197 | 198 | def deserialize_script_args(script_args: Union[bytes, List], UiControlNetUnit = None): 199 | if type(script_args) is bytes: 200 | script_args = pickle.loads(zlib.decompress(script_args)) 201 | 202 | for i, a in enumerate(script_args): 203 | if isinstance(a, dict) and a.get("is_cnet", False): 204 | unit = deserialize_controlnet_args(a) 205 | skip_controlnet = False 206 | if UiControlNetUnit is not None: 207 | u = UiControlNetUnit() 208 | for k, v in unit.items(): 209 | if isinstance(getattr(u, k, None), Enum): 210 | # check if v is a valid enum value 211 | enum_obj: Enum= getattr(u, k) 212 | if v not in [e.value for e in enum_obj.__class__]: 213 | log.error(f"Invalid enum value {v} for {k} encountered, valid value is {enum_obj.__class__}") 214 | skip_controlnet = True 215 | break 216 | unit[k] = type(getattr(u, k))(v) 217 | if not skip_controlnet: # valid 218 | unit = UiControlNetUnit(**unit) 219 | if not skip_controlnet: # valid 220 | script_args[i] = unit 221 | 222 | return script_args 223 | 224 | 225 | def map_controlnet_args_to_api_task_args(args: Dict): 226 | if type(args).__name__ == "UiControlNetUnit": 227 | args = args.__dict__ 228 | 229 | for 
def map_controlnet_args_to_api_task_args(args: Dict):
    """Normalize one ControlNet unit for an API payload.

    Images are encoded to base64 data URLs and Enum members are replaced with
    their raw values, in place.
    """
    if type(args).__name__ == "UiControlNetUnit":
        args = args.__dict__

    for k, v in args.items():
        if k == "image" and v is not None:
            args[k] = {
                "image": encode_image_to_base64(v["image"]),
                "mask": encode_image_to_base64(v["mask"]) if v.get("mask", None) is not None else None,
            }
        if isinstance(v, Enum):
            args[k] = v.value

    return args


def map_ui_task_args_list_to_named_args(args: List, is_img2img: bool):
    """Split the flat positional UI arg list into (named_args, script_args).

    Parameter names are taken from the live txt2img/img2img entry-point
    signature, so the mapping tracks whatever WebUI version is running;
    everything past those parameters belongs to scripts.
    """
    fn = (
        getattr(img2img, "img2img_create_processing", img2img.img2img)
        if is_img2img
        else getattr(txt2img, "txt2img_create_processing", txt2img.txt2img)
    )
    arg_names = inspect.getfullargspec(fn).args

    # SD WebUI 1.5.0 has new request arg
    if "request" in arg_names:
        args.insert(arg_names.index("request"), None)

    named_args = dict(zip(arg_names, args[0 : len(arg_names)]))
    script_args = args[len(arg_names) :]

    override_settings_texts: List[str] = named_args.get("override_settings_texts", [])
    # add clip_skip if not exist in args (vlad fork has this arg)
    if named_args.get("clip_skip", None) is None:
        clip_skip = next((s for s in override_settings_texts if s.startswith("Clip skip:")), None)
        if clip_skip is None and hasattr(shared.opts, "CLIP_stop_at_last_layers"):
            override_settings_texts.append(f"Clip skip: {shared.opts.CLIP_stop_at_last_layers}")

    named_args["override_settings_texts"] = override_settings_texts

    sampler_index = named_args.get("sampler_index", None)
    if sampler_index is not None:
        # Persist the sampler by name, not index, so the task survives the
        # sampler list changing between queueing and execution.
        available_samplers = sd_samplers.samplers_for_img2img if is_img2img else sd_samplers.samplers
        sampler_name = available_samplers[named_args["sampler_index"]].name
        named_args["sampler_name"] = sampler_name
        log.debug(f"serialize sampler index: {str(sampler_index)} as {sampler_name}")

    return (
        named_args,
        script_args,
    )


def map_named_args_to_ui_task_args_list(named_args: Dict, script_args: List, is_img2img: bool):
    """Inverse of map_ui_task_args_list_to_named_args.

    Rebuilds the flat positional arg list against the current entry-point
    signature; the stored sampler name is mapped back to today's index
    (falling back to 0 when the sampler no longer exists).
    """
    fn = (
        getattr(img2img, "img2img_create_processing", img2img.img2img)
        if is_img2img
        else getattr(txt2img, "txt2img_create_processing", txt2img.txt2img)
    )
    arg_names = inspect.getfullargspec(fn).args

    sampler_name = named_args.get("sampler_name", None)
    if sampler_name is not None:
        available_samplers = sd_samplers.samplers_for_img2img if is_img2img else sd_samplers.samplers
        sampler_index = next((i for i, x in enumerate(available_samplers) if x.name == sampler_name), 0)
        named_args["sampler_index"] = sampler_index

    args = [named_args.get(name, None) for name in arg_names]
    args.extend(script_args)

    return args


def map_script_args_list_to_named(script: scripts.Script, args: List):
    """Map a script's positional args to a name->value dict.

    ControlNet is special-cased: its args are a list of unit payloads, which
    are normalized for the API instead of being named. Names come from the
    script's process/run signature (skipping self and p).
    """
    script_name = script.title().lower()

    if script_name == "controlnet":
        for i, cnet_args in enumerate(args):
            args[i] = map_controlnet_args_to_api_task_args(cnet_args)

        return args

    fn = script.process if script.alwayson else script.run
    inspection = inspect.getfullargspec(fn)
    arg_names = inspection.args[2:]
    named_script_args = dict(zip(arg_names, args[: len(arg_names)]))
    if inspection.varargs is not None:
        named_script_args[inspection.varargs] = args[len(arg_names) :]

    return named_script_args


def map_named_script_args_to_list(script: scripts.Script, named_args: Union[dict, list]):
    """Inverse of map_script_args_list_to_named.

    A dict is flattened back to positional order using the current script
    signature; a list is assumed to already be positional (ControlNet unit
    payloads are still normalized for the API).
    """
    script_name = script.title().lower()

    if isinstance(named_args, dict):
        fn = script.process if script.alwayson else script.run
        inspection = inspect.getfullargspec(fn)
        arg_names = inspection.args[2:]
        args = [named_args.get(name, None) for name in arg_names]
        if inspection.varargs is not None:
            args.extend(named_args.get(inspection.varargs, []))

        return args

    if isinstance(named_args, list):
        if script_name == "controlnet":
            for i, cnet_args in enumerate(named_args):
                named_args[i] = map_controlnet_args_to_api_task_args(cnet_args)

        return named_args
def map_ui_task_args_to_api_task_args(named_args: Dict, script_args: List, is_img2img: bool):
    """Convert deserialized UI args into an API-shaped request dict.

    Renames UI-only fields to their API equivalents, flattens the selected
    img2img mode's image inputs into init_images/mask, and maps selectable and
    alwayson script args into the API's script_name/script_args/alwayson_scripts
    structure.
    """
    api_task_args: Dict = named_args.copy()

    prompt_styles = api_task_args.pop("prompt_styles", [])
    api_task_args["styles"] = prompt_styles

    sampler_index = api_task_args.pop("sampler_index", 0)
    api_task_args["sampler_name"] = sd_samplers.samplers[sampler_index].name

    override_settings_texts = api_task_args.pop("override_settings_texts", [])
    api_task_args["override_settings"] = create_override_settings_dict(override_settings_texts)

    if is_img2img:
        mode = api_task_args.pop("mode", 0)
        # Drop the image slots of every non-active mode.
        for arg_mode, image_args in img2img_image_args_by_mode.items():
            if mode != arg_mode:
                for keys in image_args:
                    api_task_args.pop(keys[0], None)

        # the logic below is copied from modules/img2img.py
        if mode == 0:
            image = api_task_args.pop("init_img")
            image = image.convert("RGB") if image else None
            mask = None
        elif mode == 1:
            image = api_task_args.pop("sketch")
            image = image.convert("RGB") if image else None
            mask = None
        elif mode == 2:
            init_img_with_mask: Dict = api_task_args.pop("init_img_with_mask") or {}
            image = init_img_with_mask.get("image", None)
            image = image.convert("RGB") if image else None
            mask = init_img_with_mask.get("mask", None)
            if mask:
                # Merge the painted mask with the image's alpha channel.
                alpha_mask = (
                    ImageOps.invert(image.split()[-1]).convert("L").point(lambda x: 255 if x > 0 else 0, mode="1")
                )
                mask = ImageChops.lighter(alpha_mask, mask.convert("L")).convert("L")
        elif mode == 3:
            image = api_task_args.pop("inpaint_color_sketch")
            orig = api_task_args.pop("inpaint_color_sketch_orig") or image
            # NOTE(review): when image is None here, `mask` is never assigned and the
            # encode below would raise UnboundLocalError — confirm whether the UI can
            # submit mode 3 without a sketch.
            if image is not None:
                mask_alpha = api_task_args.pop("mask_alpha", 0)
                mask_blur = api_task_args.get("mask_blur", 4)
                # Mask = pixels that differ between the sketch and the original.
                pred = np.any(np.array(image) != np.array(orig), axis=-1)
                mask = Image.fromarray(pred.astype(np.uint8) * 255, "L")
                mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100)
                blur = ImageFilter.GaussianBlur(mask_blur)
                image = Image.composite(image.filter(blur), orig, mask.filter(blur))
                image = image.convert("RGB")
        elif mode == 4:
            image = api_task_args.pop("init_img_inpaint")
            mask = api_task_args.pop("init_mask_inpaint")
        else:
            raise Exception(f"Batch mode is not supported yet")

        image = ImageOps.exif_transpose(image) if image else None
        api_task_args["init_images"] = [encode_image_to_base64(image)] if image else []
        api_task_args["mask"] = encode_image_to_base64(mask) if mask else None

        # "Resize by" tab: translate the scale factor into absolute dimensions.
        selected_scale_tab = api_task_args.pop("selected_scale_tab", 0)
        scale_by = api_task_args.get("scale_by", 1)
        if selected_scale_tab == 1 and image:
            api_task_args["width"] = int(image.width * scale_by)
            api_task_args["height"] = int(image.height * scale_by)
    else:
        # hires-fix sampler: index 0 means "use same"; the list is 1-based here.
        hr_sampler_index = api_task_args.pop("hr_sampler_index", 0)
        api_task_args["hr_sampler_name"] = (
            sd_samplers.samplers_for_img2img[hr_sampler_index - 1].name if hr_sampler_index != 0 else None
        )

    # script
    script_runner = scripts.scripts_img2img if is_img2img else scripts.scripts_txt2img
    script_id = script_args[0]
    if script_id == 0:
        api_task_args["script_name"] = None
        api_task_args["script_args"] = []
    else:
        script: scripts.Script = script_runner.selectable_scripts[script_id - 1]
        api_task_args["script_name"] = script.title().lower()
        current_script_args = script_args[script.args_from : script.args_to]
        api_task_args["script_args"] = map_script_args_list_to_named(script, current_script_args)

    # alwayson scripts
    alwayson_scripts = api_task_args.get("alwayson_scripts", None)
    if not alwayson_scripts:
        api_task_args["alwayson_scripts"] = {}
        alwayson_scripts = api_task_args["alwayson_scripts"]

    for script in script_runner.alwayson_scripts:
        alwayson_script_args = script_args[script.args_from : script.args_to]
        script_name = script.title().lower()
        # Skip this extension's own UI args to avoid recursive scheduling.
        if script_name != "agent scheduler":
            named_script_args = map_script_args_list_to_named(script, alwayson_script_args)
            alwayson_scripts[script_name] = {"args": named_script_args}

    return api_task_args


def serialize_api_task_args(
    params: Dict,
    is_img2img: bool,
    checkpoint: str = None,
    vae: str = None,
) -> Dict:
    """Validate and normalize a raw API request into a storable args dict.

    Script args given as name->value dicts are flattened to positional lists,
    unknown alwayson scripts are dropped with a warning, checkpoint/VAE
    overrides are resolved into override_settings, and img2img init images /
    mask are encoded to base64.

    Raises when a named selectable script does not exist or (img2img) no init
    image was provided.
    """
    # handle named script args
    script_name = params.get("script_name", None)
    if script_name is not None and script_name != "":
        script = get_script_by_name(script_name, is_img2img)
        if script is None:
            raise Exception(f"Not found script {script_name}")

        script_args = params.get("script_args", {})
        params["script_args"] = map_named_script_args_to_list(script, script_args)

    # handle named alwayson script args
    alwayson_scripts = get_dict_attribute(params, "alwayson_scripts", {})
    assert type(alwayson_scripts) is dict

    script_runner = scripts.scripts_img2img if is_img2img else scripts.scripts_txt2img
    allowed_alwayson_scripts = {s.title().lower(): s for s in script_runner.alwayson_scripts}

    valid_alwayson_scripts = {}
    for script_name, script_args in alwayson_scripts.items():
        if script_name.lower() == "agent scheduler":
            # ignore this extension's own args
            continue

        if script_name.lower() not in allowed_alwayson_scripts:
            log.warning(f"Script {script_name} is not in script_runner.alwayson_scripts")
            continue

        script = allowed_alwayson_scripts[script_name.lower()]
        script_args = get_dict_attribute(script_args, "args", [])
        arg_list = map_named_script_args_to_list(script, script_args)
        valid_alwayson_scripts[script_name] = {"args": arg_list}

    params["alwayson_scripts"] = valid_alwayson_scripts

    # Pydantic validation of the full request payload.
    args = (
        StableDiffusionImg2ImgProcessingAPI(**params) if is_img2img else StableDiffusionTxt2ImgProcessingAPI(**params)
    )

    if args.override_settings is None:
        args.override_settings = {}

    if checkpoint is not None:
        checkpoint_info: CheckpointInfo = get_closet_checkpoint_match(checkpoint)
        if not checkpoint_info:
            log.warning(f"Checkpoint {checkpoint} not found, use current system model")
        else:
            args.override_settings["sd_model_checkpoint"] = checkpoint_info.title

    if vae is not None:
        if vae not in sd_vae.vae_dict:
            log.warning(f"VAE {vae} not found, use current system vae")
        else:
            args.override_settings["sd_vae"] = vae

    # load images from url or file if needed
    if is_img2img:
        init_images = args.init_images
        if len(init_images) == 0:
            raise Exception("At least one init image is required")

        for i, image in enumerate(init_images):
            init_images[i] = encode_image_to_base64(image)

        args.mask = encode_image_to_base64(args.mask)
        if len(init_images) > 1:
            # Multiple init images are processed as one batch.
            args.batch_size = len(init_images)

    return args.dict()
14 | from PIL import Image 15 | 16 | from modules import progress, shared, script_callbacks 17 | from modules.call_queue import queue_lock, wrap_gradio_call 18 | from modules.txt2img import txt2img 19 | from modules.img2img import img2img 20 | from modules.api.api import Api 21 | from modules.api.models import ( 22 | StableDiffusionTxt2ImgProcessingAPI, 23 | StableDiffusionImg2ImgProcessingAPI, 24 | ) 25 | 26 | from .db import TaskStatus, Task, task_manager 27 | from .helpers import ( 28 | log, 29 | detect_control_net, 30 | get_component_by_elem_id, 31 | get_dict_attribute, 32 | is_windows, 33 | is_macos, 34 | _exit, 35 | ) 36 | from .task_helpers import ( 37 | encode_image_to_base64, 38 | serialize_img2img_image_args, 39 | deserialize_img2img_image_args, 40 | serialize_script_args, 41 | deserialize_script_args, 42 | serialize_api_task_args, 43 | map_ui_task_args_list_to_named_args, 44 | map_named_args_to_ui_task_args_list, 45 | ) 46 | 47 | 48 | class OutOfMemoryError(Exception): 49 | def __init__(self, message="CUDA out of memory") -> None: 50 | self.message = message 51 | super().__init__(message) 52 | 53 | 54 | class FakeRequest: 55 | def __init__(self, username: str = None): 56 | self.username = username 57 | 58 | 59 | class ParsedTaskArgs(BaseModel): 60 | is_ui: bool 61 | named_args: Dict[str, Any] 62 | script_args: List[Any] 63 | checkpoint: Optional[str] = None 64 | vae: Optional[str] = None 65 | 66 | 67 | class TaskRunner: 68 | instance = None 69 | 70 | def __init__(self, UiControlNetUnit=None): 71 | self.UiControlNetUnit = UiControlNetUnit 72 | 73 | self.__total_pending_tasks: int = 0 74 | self.__current_thread: threading.Thread = None 75 | self.__api = Api(FastAPI(), queue_lock) 76 | 77 | self.__saved_images_path: List[str] = [] 78 | script_callbacks.on_image_saved(self.__on_image_saved) 79 | 80 | self.script_callbacks = { 81 | "task_registered": [], 82 | "task_started": [], 83 | "task_finished": [], 84 | "task_cleared": [], 85 | } 86 | 87 | # Mark this 
to True when reload UI
        self.dispose = False
        self.interrupted = None

        if TaskRunner.instance is not None:
            raise Exception("TaskRunner instance already exists")
        TaskRunner.instance = self

    @property
    def current_task_id(self) -> Union[str, None]:
        # Delegates to the WebUI progress module: the task being executed now.
        return progress.current_task

    @property
    def is_executing_task(self) -> bool:
        # True while the background worker thread is alive.
        return self.__current_thread and self.__current_thread.is_alive()

    @property
    def paused(self) -> bool:
        return getattr(shared.opts, "queue_paused", False)

    def __serialize_ui_task_args(
        self,
        is_img2img: bool,
        *args,
        checkpoint: str = None,
        vae: str = None,
        request: gr.Request = None,
    ):
        """Pack raw UI args into (params JSON, compressed script args) for DB storage."""
        named_args, script_args = map_ui_task_args_list_to_named_args(list(args), is_img2img)

        # loop through named_args and serialize images
        if is_img2img:
            serialize_img2img_image_args(named_args)

        # gr.Request is not serializable; keep only the username.
        if "request" in named_args:
            named_args["request"] = {"username": request.username}

        params = json.dumps(
            {
                "args": named_args,
                "checkpoint": checkpoint,
                "vae": vae,
                "is_ui": True,
                "is_img2img": is_img2img,
            }
        )
        script_params = serialize_script_args(script_args)

        return (params, script_params)

    def __serialize_api_task_args(
        self,
        is_img2img: bool,
        checkpoint: str = None,
        vae: str = None,
        **api_args,
    ):
        """API counterpart of __serialize_ui_task_args; validates via serialize_api_task_args."""
        named_args = serialize_api_task_args(api_args, is_img2img, checkpoint=checkpoint, vae=vae)
        # The resolved checkpoint (if any) ends up in override_settings; persist it.
        checkpoint = get_dict_attribute(named_args, "override_settings.sd_model_checkpoint", None)
        script_args = named_args.pop("script_args", [])

        params = json.dumps(
            {
                "args": named_args,
                "checkpoint": checkpoint,
                "is_ui": False,
                "is_img2img": is_img2img,
            }
        )
        script_params = serialize_script_args(script_args)

        return (params, script_params)

    def __deserialize_ui_task_args(
        self,
        is_img2img: bool,
        named_args: Dict,
        script_args: List,
        checkpoint: str = None,
        vae: str = None,
    ):
        """
        Deserialize UI task arguments
        In-place update named_args and script_args
        """

        # Apply checkpoint override
        if checkpoint is not None:
            override: List[str] = named_args.get("override_settings_texts", [])
            override = [x for x in override if not x.startswith("Model hash: ")]
            if checkpoint != "System":
                override.append("Model hash: " + checkpoint)
            named_args["override_settings_texts"] = override

        # Apply VAE override
        if vae is not None:
            override: List[str] = named_args.get("override_settings_texts", [])
            override = [x for x in override if not x.startswith("VAE: ")]
            override.append("VAE: " + vae)
            named_args["override_settings_texts"] = override

        # A1111 1.5.0-RC has new request field
        if "request" in named_args:
            named_args["request"] = FakeRequest(**named_args["request"])

        # loop through image_args and deserialize images
        if is_img2img:
            deserialize_img2img_image_args(named_args)

        # loop through script_args and deserialize images
        script_args = deserialize_script_args(script_args, self.UiControlNetUnit)

        return (named_args, script_args)

    def __deserialize_api_task_args(
        self,
        is_img2img: bool,
        named_args: Dict,
        script_args: List,
        checkpoint: str = None,
        vae: str = None,
    ):
        """API counterpart of __deserialize_ui_task_args; updates named_args in place."""
        # Apply checkpoint override
        if checkpoint is not None:
            override: Dict = named_args.get("override_settings", {})
            if checkpoint != "System":
                override["sd_model_checkpoint"] = checkpoint
            else:
                override.pop("sd_model_checkpoint", None)
            named_args["override_settings"] = override

        # Apply VAE override
        if vae is not None:
            override: Dict = named_args.get("override_settings", {})
            override["sd_vae"] = vae
            named_args["override_settings"] = override

        # load images from disk
        if is_img2img:
            init_images = named_args.get("init_images")
            for i, img in enumerate(init_images):
                if isinstance(img, str) and os.path.isfile(img):
                    image = Image.open(img)
                    init_images[i] = encode_image_to_base64(image)

        # force image saving
        named_args.update({"save_images": True, "send_images": False})

        script_args = deserialize_script_args(script_args)
        return (named_args, script_args)

    def parse_task_args(self, task: Task, deserialization: bool = True):
        """Decode a persisted Task row into runnable args.

        With deserialization=False only the JSON metadata is decoded and
        script args are ignored (used for inspection rather than execution).
        """
        parsed: Dict[str, Any] = json.loads(task.params)

        is_ui = parsed.get("is_ui", True)
        is_img2img = parsed.get("is_img2img", None)
        checkpoint = parsed.get("checkpoint", None)
        vae = parsed.get("vae", None)
        named_args: Dict[str, Any] = parsed["args"]
        script_args: List[Any] = parsed.get("script_args", task.script_params)

        if is_ui and deserialization:
            named_args, script_args = self.__deserialize_ui_task_args(
                is_img2img, named_args, script_args, checkpoint=checkpoint, vae=vae
            )
        elif deserialization:
            named_args, script_args = self.__deserialize_api_task_args(
                is_img2img, named_args, script_args, checkpoint=checkpoint, vae=vae
            )
        else:
            # ignore script_args if not deserialization
            script_args = []

        return ParsedTaskArgs(
            is_ui=is_ui,
            named_args=named_args,
            script_args=script_args,
            checkpoint=checkpoint,
            vae=vae,
        )

    def register_ui_task(
        self,
        task_id: str,
        is_img2img: bool,
        *args,
        checkpoint: str = None,
        task_name: str = None,
        request: gr.Request = None,
    ):
        """Persist a UI-originated task and announce it to listeners."""
        progress.add_task_to_queue(task_id)

        vae = 
getattr(shared.opts, "sd_vae", "Automatic") 279 | 280 | (params, script_args) = self.__serialize_ui_task_args( 281 | is_img2img, *args, checkpoint=checkpoint, vae=vae, request=request 282 | ) 283 | 284 | task_type = "img2img" if is_img2img else "txt2img" 285 | task = Task( 286 | id=task_id, 287 | name=task_name, 288 | type=task_type, 289 | params=params, 290 | script_params=script_args, 291 | ) 292 | task_manager.add_task(task) 293 | 294 | self.__run_callbacks("task_registered", task_id, is_img2img=is_img2img, is_ui=True, args=params) 295 | self.__total_pending_tasks += 1 296 | 297 | return task 298 | 299 | def register_api_task( 300 | self, 301 | task_id: str, 302 | api_task_id: str, 303 | is_img2img: bool, 304 | args: Dict, 305 | checkpoint: str = None, 306 | vae: str = None, 307 | ): 308 | progress.add_task_to_queue(task_id) 309 | 310 | (params, script_params) = self.__serialize_api_task_args(is_img2img, checkpoint=checkpoint, vae=vae, **args) 311 | 312 | task_type = "img2img" if is_img2img else "txt2img" 313 | task = Task( 314 | id=task_id, 315 | api_task_id=api_task_id, 316 | type=task_type, 317 | params=params, 318 | script_params=script_params, 319 | ) 320 | task_manager.add_task(task) 321 | 322 | self.__run_callbacks("task_registered", task_id, is_img2img=is_img2img, is_ui=False, args=params) 323 | self.__total_pending_tasks += 1 324 | 325 | return task 326 | 327 | def execute_task(self, task: Task, get_next_task: Callable[[], Task]): 328 | while True: 329 | if self.dispose: 330 | break 331 | 332 | if progress.current_task is None: 333 | task_id = task.id 334 | is_img2img = task.type == "img2img" 335 | log.info(f"[AgentScheduler] Executing task {task_id}") 336 | 337 | task_args = self.parse_task_args(task) 338 | task_meta = { 339 | "is_img2img": is_img2img, 340 | "is_ui": task_args.is_ui, 341 | "task": task, 342 | } 343 | 344 | self.interrupted = None 345 | self.__saved_images_path = [] 346 | self.__run_callbacks("task_started", task_id, **task_meta) 347 | 
348 | # enable image saving 349 | samples_save = shared.opts.samples_save 350 | shared.opts.samples_save = True 351 | 352 | res = self.__execute_task(task_id, is_img2img, task_args) 353 | 354 | # disable image saving 355 | shared.opts.samples_save = samples_save 356 | 357 | if not res or isinstance(res, Exception): 358 | if isinstance(res, OutOfMemoryError): 359 | log.error(f"[AgentScheduler] Task {task_id} failed: CUDA OOM. Queue will be paused.") 360 | shared.opts.queue_paused = True 361 | else: 362 | log.error(f"[AgentScheduler] Task {task_id} failed: {res}") 363 | log.debug(traceback.format_exc()) 364 | 365 | if getattr(shared.opts, "queue_automatic_requeue_failed_task", False): 366 | log.info(f"[AgentScheduler] Requeue task {task_id}") 367 | task.status = TaskStatus.PENDING 368 | task.priority = int(datetime.now(timezone.utc).timestamp() * 1000) 369 | task_manager.update_task(task) 370 | else: 371 | task.status = TaskStatus.FAILED 372 | task.result = str(res) if res else None 373 | task_manager.update_task(task) 374 | self.__run_callbacks("task_finished", task_id, status=TaskStatus.FAILED, **task_meta) 375 | else: 376 | is_interrupted = self.interrupted == task_id 377 | if is_interrupted: 378 | log.info(f"\n[AgentScheduler] Task {task.id} interrupted") 379 | task.status = TaskStatus.INTERRUPTED 380 | task_manager.update_task(task) 381 | self.__run_callbacks( 382 | "task_finished", 383 | task_id, 384 | status=TaskStatus.INTERRUPTED, 385 | **task_meta, 386 | ) 387 | else: 388 | geninfo = json.loads(res) 389 | result = { 390 | "images": self.__saved_images_path.copy(), 391 | "geninfo": geninfo, 392 | } 393 | 394 | task.status = TaskStatus.DONE 395 | task.result = json.dumps(result) 396 | task_manager.update_task(task) 397 | self.__run_callbacks( 398 | "task_finished", 399 | task_id, 400 | status=TaskStatus.DONE, 401 | result=result, 402 | **task_meta, 403 | ) 404 | 405 | self.__saved_images_path = [] 406 | else: 407 | time.sleep(2) 408 | continue 409 | 410 | task 
= get_next_task() 411 | if not task: 412 | if not self.paused: 413 | time.sleep(1) 414 | self.__on_completed() 415 | break 416 | 417 | def execute_pending_tasks_threading(self): 418 | if self.paused: 419 | log.info("[AgentScheduler] Runner is paused") 420 | return 421 | 422 | if self.is_executing_task: 423 | log.info("[AgentScheduler] Runner already started") 424 | return 425 | 426 | pending_task = self.__get_pending_task() 427 | if pending_task: 428 | # Start the infinite loop in a separate thread 429 | self.__current_thread = threading.Thread( 430 | target=self.execute_task, 431 | args=( 432 | pending_task, 433 | self.__get_pending_task, 434 | ), 435 | ) 436 | self.__current_thread.daemon = True 437 | self.__current_thread.start() 438 | 439 | def __execute_task(self, task_id: str, is_img2img: bool, task_args: ParsedTaskArgs): 440 | if task_args.is_ui: 441 | ui_args = map_named_args_to_ui_task_args_list(task_args.named_args, task_args.script_args, is_img2img) 442 | 443 | return self.__execute_ui_task(task_id, is_img2img, *ui_args) 444 | else: 445 | return self.__execute_api_task( 446 | task_id, 447 | is_img2img, 448 | script_args=task_args.script_args, 449 | **task_args.named_args, 450 | ) 451 | 452 | def __execute_ui_task(self, task_id: str, is_img2img: bool, *args): 453 | func = wrap_gradio_call(img2img if is_img2img else txt2img, add_stats=True) 454 | 455 | with queue_lock: 456 | shared.state.begin() 457 | progress.start_task(task_id) 458 | 459 | res = None 460 | try: 461 | result = func(*args) 462 | if result[0] is None and hasattr(shared.state, "oom") and shared.state.oom: 463 | res = OutOfMemoryError() 464 | elif "CUDA out of memory" in result[2]: 465 | res = OutOfMemoryError() 466 | else: 467 | res = result[1] 468 | except Exception as e: 469 | res = e 470 | finally: 471 | progress.finish_task(task_id) 472 | 473 | shared.state.end() 474 | 475 | return res 476 | 477 | def __execute_api_task(self, task_id: str, is_img2img: bool, **kwargs): 478 | 
        """Run an API-queued task through the internal API handlers.

        Returns the generation-info value (result.info) on success, or an
        Exception (OutOfMemoryError for detected CUDA OOM) on failure.
        """
        progress.start_task(task_id)

        res = None
        try:
            result = (
                self.__api.img2imgapi(StableDiffusionImg2ImgProcessingAPI(**kwargs))
                if is_img2img
                else self.__api.text2imgapi(StableDiffusionTxt2ImgProcessingAPI(**kwargs))
            )
            res = result.info
        except Exception as e:
            # Surface CUDA OOM as a dedicated type so the caller can pause
            # the queue; any other exception is returned as-is.
            if "CUDA out of memory" in str(e):
                res = OutOfMemoryError()
            else:
                res = e
        finally:
            progress.finish_task(task_id)

        return res

    def __get_pending_task(self):
        """Return the next pending task, or None.

        Returns None when the runner is disposed or paused; fires the
        "task_cleared" callbacks when the pending queue is empty.
        """
        if self.dispose:
            return None

        if self.paused:
            log.info("[AgentScheduler] Runner is paused")
            return None

        # # delete task that are too old
        # retention_days = 30
        # if (
        #     getattr(shared.opts, "queue_history_retention_days", None)
        #     and shared.opts.queue_history_retention_days in task_history_retenion_map
        # ):
        #     retention_days = task_history_retenion_map[shared.opts.queue_history_retention_days]

        # if retention_days > 0:
        #     deleted_rows = task_manager.delete_tasks(before=datetime.now() - timedelta(days=retention_days))
        #     if deleted_rows > 0:
        #         log.debug(f"[AgentScheduler] Deleted {deleted_rows} tasks older than {retention_days} days")

        self.__total_pending_tasks = task_manager.count_tasks(status="pending")

        # get more task if needed
        if self.__total_pending_tasks > 0:
            log.info(f"[AgentScheduler] Total pending tasks: {self.__total_pending_tasks}")
            pending_tasks = task_manager.get_tasks(status="pending", limit=1)
            if len(pending_tasks) > 0:
                return pending_tasks[0]
        else:
            log.info("[AgentScheduler] Task queue is empty")
            self.__run_callbacks("task_cleared")

    def __on_image_saved(self, data: script_callbacks.ImageSaveParams):
        # Collect paths of images saved while the current task is running;
        # execute_task stores them in the task result.
        if self.current_task_id is None:
            return

        outpath_grids = shared.opts.outdir_grids or shared.opts.outdir_txt2img_grids
| if data.filename.startswith(outpath_grids): 537 | self.__saved_images_path.insert(0, data.filename) 538 | else: 539 | self.__saved_images_path.append(data.filename) 540 | 541 | def __on_completed(self): 542 | action = getattr(shared.opts, "queue_completion_action", "Do nothing") 543 | 544 | if action == "Do nothing": 545 | return 546 | 547 | command = None 548 | if action == "Shut down": 549 | log.info("[AgentScheduler] Shutting down...") 550 | if is_windows: 551 | command = ["shutdown", "/s", "/hybrid", "/t", "0"] 552 | elif is_macos: 553 | command = ["osascript", "-e", 'tell application "Finder" to shut down'] 554 | else: 555 | command = ["systemctl", "poweroff"] 556 | elif action == "Restart": 557 | log.info("[AgentScheduler] Restarting...") 558 | if is_windows: 559 | command = ["shutdown", "/r", "/t", "0"] 560 | elif is_macos: 561 | command = ["osascript", "-e", 'tell application "Finder" to restart'] 562 | else: 563 | command = ["systemctl", "reboot"] 564 | elif action == "Sleep": 565 | log.info("[AgentScheduler] Sleeping...") 566 | if is_windows: 567 | if not ctypes.windll.PowrProf.SetSuspendState(False, False, False): 568 | print(f"Couldn't sleep: {ctypes.GetLastError()}") 569 | elif is_macos: 570 | command = ["osascript", "-e", 'tell application "Finder" to sleep'] 571 | else: 572 | command = ["sh", "-c", 'systemctl hybrid-sleep || (echo "Couldn\'t hybrid sleep, will try to suspend instead: $?"; systemctl suspend)'] 573 | elif action == "Hibernate": 574 | log.info("[AgentScheduler] Hibernating...") 575 | if is_windows: 576 | command = ["shutdown", "/h"] 577 | elif is_macos: 578 | command = ["osascript", "-e", 'tell application "Finder" to sleep'] 579 | else: 580 | command = ["systemctl", "hibernate"] 581 | elif action == "Stop webui": 582 | log.info("[AgentScheduler] Stopping webui...") 583 | _exit(0) 584 | 585 | if command: 586 | subprocess.Popen(command) 587 | 588 | if action in {"Shut down", "Restart"}: 589 | _exit(0) 590 | 591 | def 
on_task_registered(self, callback: Callable): 592 | """Callback when a task is registered 593 | 594 | Callback signature: callback(task_id: str, is_img2img: bool, is_ui: bool, args: Dict) 595 | """ 596 | 597 | self.script_callbacks["task_registered"].append(callback) 598 | 599 | def on_task_started(self, callback: Callable): 600 | """Callback when a task is started 601 | 602 | Callback signature: callback(task_id: str, is_img2img: bool, is_ui: bool) 603 | """ 604 | 605 | self.script_callbacks["task_started"].append(callback) 606 | 607 | def on_task_finished(self, callback: Callable): 608 | """Callback when a task is finished 609 | 610 | Callback signature: callback(task_id: str, is_img2img: bool, is_ui: bool, status: TaskStatus, result: Dict) 611 | """ 612 | 613 | self.script_callbacks["task_finished"].append(callback) 614 | 615 | def on_task_cleared(self, callback: Callable): 616 | self.script_callbacks["task_cleared"].append(callback) 617 | 618 | def __run_callbacks(self, name: str, *args, **kwargs): 619 | for callback in self.script_callbacks[name]: 620 | callback(*args, **kwargs) 621 | 622 | 623 | def get_instance(block) -> TaskRunner: 624 | if TaskRunner.instance is None: 625 | if block is not None: 626 | txt2img_submit_button = get_component_by_elem_id(block, "txt2img_generate") 627 | UiControlNetUnit = detect_control_net(block, txt2img_submit_button) 628 | TaskRunner(UiControlNetUnit) 629 | else: 630 | TaskRunner() 631 | 632 | if not hasattr(script_callbacks, "on_before_reload"): 633 | log.warning( 634 | "*****************************************************************************************\n" 635 | + "[AgentScheduler] YOUR SD WEBUI IS OUTDATED AND AGENT SCHEDULER WILL NOT WORKING PROPERLY." 
636 | + "*****************************************************************************************\n", 637 | ) 638 | else: 639 | 640 | def on_before_reload(): 641 | # Tell old instance to stop 642 | TaskRunner.instance.dispose = True 643 | # force recreate the instance 644 | TaskRunner.instance = None 645 | 646 | script_callbacks.on_before_reload(on_before_reload) 647 | 648 | return TaskRunner.instance 649 | -------------------------------------------------------------------------------- /docs/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SipherAGI/sd-webui-agent-scheduler/a33753321b914c6122df96d1dc0b5117d38af680/docs/.DS_Store -------------------------------------------------------------------------------- /docs/CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Change Logs 2 | 3 | ## 2023/08/10 4 | 5 | New features: 6 | - New API `/task/{id}/position` to get task position in queue [#105](https://github.com/ArtVentureX/sd-webui-agent-scheduler/issues/105) 7 | - Display task in local timezone [#95](https://github.com/ArtVentureX/sd-webui-agent-scheduler/issues/95) 8 | 9 | Bugs fixing: 10 | - `alwayson_scripts` should allow script name in all cases [#102](https://github.com/ArtVentureX/sd-webui-agent-scheduler/issues/102) 11 | - Fix `script_args` not working when queue task via API [#103](https://github.com/ArtVentureX/sd-webui-agent-scheduler/issues/103) 12 | 13 | ## 2023/08/02 14 | 15 | Bugs fixing: 16 | - Fix task_id is duplicated [#97](https://github.com/ArtVentureX/sd-webui-agent-scheduler/issues/97) 17 | - Fix [#100](https://github.com/ArtVentureX/sd-webui-agent-scheduler/issues/100) 18 | 19 | ## 2023/07/25 20 | 21 | New features: 22 | - Clear queue and clear history 23 | - Queue task with specific name & queue with all checkpoints [#88](https://github.com/ArtVentureX/sd-webui-agent-scheduler/issues/88) 24 | 25 | ## 2023/07/24 26 | 
27 | New features: 28 | - New API `/task/{id}` to get single task (https://github.com/ArtVentureX/sd-webui-agent-scheduler/issues/86) 29 | - Update queued task 30 | - Minor change to support changes in SD webui 1.5.0-RC 31 | 32 | Bugs fixing: 33 | - Fixed https://github.com/ArtVentureX/sd-webui-agent-scheduler/issues/87 34 | 35 | ## 2023/07/16 36 | 37 | - Use pickle to serialize script args 38 | - Fix task re-ordering not working [#79](https://github.com/ArtVentureX/sd-webui-agent-scheduler/issues/79) 39 | 40 | ## 2023/07/12 41 | 42 | - Fix: batch_size is ignored when queue img2img task via api 43 | 44 | ## 2023/07/11 45 | 46 | - Add clip_skip to queue params 47 | - Add support for api task callback 48 | 49 | ## 2023/06/29 50 | 51 | - Switch js format to iife 52 | - Bugs fixing 53 | 54 | ## 2023/06/23 55 | 56 | - Add setting to disable keyboard shortcut 57 | - Bugs fixing 58 | 59 | ## 2023/06/21 60 | 61 | - Add enqueue keyboard shortcut 62 | - Bugs fixing 63 | 64 | ## 2023/06/20 65 | 66 | - Add api to download task's generated images 67 | - Add setting to render extension UI below the main UI 68 | - Display task datetime in local timezone 69 | - Persist the grid state (columns order, sorting) for next session 70 | - Bugs fixing 71 | 72 | ## 2023/06/07 73 | 74 | - Re-organize folder structure for better loading time 75 | - Prevent duplicate ui initialization 76 | - Prevent unnecessary data refresh 77 | 78 | ## 2023/06/06 79 | 80 | - Force image saving when run task 81 | - Auto pause queue when OOM error detected 82 | 83 | ## 2023/06/05 84 | 85 | - Able to view queue history 86 | - Bookmark task 87 | - Rename task 88 | - Requeue a task 89 | - View generated images of a task 90 | - Send generation params directly to txt2img, img2img 91 | - Add apis to queue task 92 | - Bugs fixing 93 | 94 | ## 2023/06/02 95 | 96 | - Remove the queue placement option `Above Generate Button` 97 | - Make the grid height scale with window resize 98 | - Keep the previous generation result 
when click enqueue 99 | - Fix: unable to run a specific task when queue is paused 100 | 101 | ## 2023/06/01 102 | 103 | - Add a flag to enable/disable queue auto processing 104 | - Add queue button placement setting 105 | - Add a flag to hide the custom checkpoint select 106 | - Rewrite frontend code in typescript 107 | - Bugs fixing 108 | 109 | ## 2023/05/29 110 | 111 | - First release 112 | -------------------------------------------------------------------------------- /docs/images/install.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SipherAGI/sd-webui-agent-scheduler/a33753321b914c6122df96d1dc0b5117d38af680/docs/images/install.png -------------------------------------------------------------------------------- /docs/images/settings.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SipherAGI/sd-webui-agent-scheduler/a33753321b914c6122df96d1dc0b5117d38af680/docs/images/settings.png -------------------------------------------------------------------------------- /docs/images/walkthrough.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SipherAGI/sd-webui-agent-scheduler/a33753321b914c6122df96d1dc0b5117d38af680/docs/images/walkthrough.png -------------------------------------------------------------------------------- /install.py: -------------------------------------------------------------------------------- 1 | import launch 2 | 3 | if not launch.is_installed("sqlalchemy"): 4 | launch.run_pip("install sqlalchemy", "requirement for task-scheduler") 5 | -------------------------------------------------------------------------------- /preload.py: -------------------------------------------------------------------------------- 1 | # preload.py is used for cmd line arguments 2 | def preload(parser): 3 | parser.add_argument( 4 | 
"--agent-scheduler-sqlite-file", 5 | help="sqlite file to use for the database connection. It can be abs or relative path(from base path) default: task_scheduler.sqlite3", 6 | default="task_scheduler.sqlite3", 7 | ) -------------------------------------------------------------------------------- /scripts/task_scheduler.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import gradio as gr 4 | from PIL import Image 5 | from uuid import uuid4 6 | from typing import List 7 | from collections import defaultdict 8 | from datetime import datetime, timedelta 9 | 10 | from modules import call_queue, shared, script_callbacks, scripts, ui_components 11 | from modules.shared import list_checkpoint_tiles, refresh_checkpoints 12 | from modules.cmd_args import parser 13 | from modules.ui import create_refresh_button 14 | from modules.ui_common import save_files 15 | from modules.sd_models import model_path 16 | from modules.generation_parameters_copypaste import ( 17 | registered_param_bindings, 18 | register_paste_params_button, 19 | connect_paste_params_buttons, 20 | parse_generation_parameters, 21 | ParamBinding, 22 | ) 23 | 24 | from agent_scheduler.task_runner import TaskRunner, get_instance 25 | from agent_scheduler.helpers import log, compare_components_with_ids, get_components_by_ids, is_macos 26 | from agent_scheduler.db import init as init_db, task_manager, TaskStatus 27 | from agent_scheduler.api import regsiter_apis 28 | 29 | is_sdnext = parser.description == "SD.Next" 30 | ToolButton = gr.Button if is_sdnext else ui_components.ToolButton 31 | 32 | task_runner: TaskRunner = None 33 | 34 | checkpoint_current = "Current Checkpoint" 35 | checkpoint_runtime = "Runtime Checkpoint" 36 | queue_with_every_checkpoints = "$$_queue_with_all_checkpoints_$$" 37 | 38 | ui_placement_as_tab = "As a tab" 39 | ui_placement_append_to_main = "Append to main UI" 40 | 41 | placement_under_generate = "Under Generate 
button" 42 | placement_between_prompt_and_generate = "Between Prompt and Generate button" 43 | 44 | completion_action_choices = ["Do nothing", "Shut down", "Restart", "Sleep", "Hibernate", "Stop webui"] 45 | 46 | task_filter_choices = ["All", "Bookmarked", "Done", "Failed", "Interrupted"] 47 | 48 | enqueue_key_modifiers = [ 49 | "Command" if is_macos else "Ctrl", 50 | "Control" if is_macos else "Alt", 51 | "Shift", 52 | ] 53 | enqueue_default_hotkey = enqueue_key_modifiers[0] + "+KeyE" 54 | enqueue_key_codes = {} 55 | enqueue_key_codes.update({chr(i): "Key" + chr(i) for i in range(ord("A"), ord("Z") + 1)}) 56 | enqueue_key_codes.update({chr(i): "Digit" + chr(i) for i in range(ord("0"), ord("9") + 1)}) 57 | enqueue_key_codes.update({"`": "Backquote", "Enter": "Enter"}) 58 | 59 | task_history_retenion_map = { 60 | "1 day": 1, 61 | "3 days": 3, 62 | "7 days": 7, 63 | "14 days": 14, 64 | "30 days": 30, 65 | "90 days": 90, 66 | "Keep forever": 0, 67 | } 68 | 69 | init_db() 70 | 71 | 72 | class Script(scripts.Script): 73 | def __init__(self): 74 | super().__init__() 75 | script_callbacks.on_app_started(lambda block, _: self.on_app_started(block)) 76 | self.checkpoint_override = checkpoint_current 77 | self.generate_button = None 78 | self.enqueue_row = None 79 | self.checkpoint_dropdown = None 80 | self.submit_button = None 81 | 82 | def title(self): 83 | return "Agent Scheduler" 84 | 85 | def show(self, is_img2img): 86 | return scripts.AlwaysVisible 87 | 88 | def on_checkpoint_changed(self, checkpoint): 89 | self.checkpoint_override = checkpoint 90 | 91 | def after_component(self, component, **_kwargs): 92 | generate_id = "txt2img_generate" if self.is_txt2img else "img2img_generate" 93 | generate_box = "txt2img_generate_box" if self.is_txt2img else "img2img_generate_box" 94 | actions_column_id = "txt2img_actions_column" if self.is_txt2img else "img2img_actions_column" 95 | neg_id = "txt2img_neg_prompt" if self.is_txt2img else "img2img_neg_prompt" 96 | toprow_id = 
"txt2img_toprow" if self.is_txt2img else "img2img_toprow" 97 | 98 | def add_enqueue_row(elem_id): 99 | parent = component.parent 100 | while parent is not None: 101 | if parent.elem_id == elem_id: 102 | self.add_enqueue_button() 103 | component.parent.children.pop() 104 | parent.add(self.enqueue_row) 105 | break 106 | parent = parent.parent 107 | 108 | if component.elem_id == generate_id: 109 | self.generate_button = component 110 | if getattr(shared.opts, "compact_prompt_box", False): 111 | add_enqueue_row(generate_box) 112 | else: 113 | if getattr(shared.opts, "queue_button_placement", placement_under_generate) == placement_under_generate: 114 | add_enqueue_row(actions_column_id) 115 | elif component.elem_id == neg_id: 116 | if not getattr(shared.opts, "compact_prompt_box", False): 117 | if getattr(shared.opts, "queue_button_placement", placement_under_generate) == placement_between_prompt_and_generate: 118 | add_enqueue_row(toprow_id) 119 | 120 | def on_app_started(self, block): 121 | if self.generate_button is not None: 122 | self.bind_enqueue_button(block) 123 | 124 | def add_enqueue_button(self): 125 | id_part = "img2img" if self.is_img2img else "txt2img" 126 | with gr.Row(elem_id=f"{id_part}_enqueue_wrapper") as row: 127 | self.enqueue_row = row 128 | hide_checkpoint = getattr(shared.opts, "queue_button_hide_checkpoint", True) 129 | self.checkpoint_dropdown = gr.Dropdown( 130 | choices=get_checkpoint_choices(), 131 | value=checkpoint_current, 132 | show_label=False, 133 | interactive=True, 134 | visible=not hide_checkpoint, 135 | ) 136 | if not hide_checkpoint: 137 | create_refresh_button( 138 | self.checkpoint_dropdown, 139 | refresh_checkpoints, 140 | lambda: {"choices": get_checkpoint_choices()}, 141 | f"refresh_{id_part}_checkpoint", 142 | ) 143 | self.submit_button = gr.Button("Enqueue", elem_id=f"{id_part}_enqueue", variant="primary") 144 | 145 | def bind_enqueue_button(self, root: gr.Blocks): 146 | generate = self.generate_button 147 | is_img2img = 
self.is_img2img 148 | dependencies: List[dict] = [ 149 | x for x in root.dependencies if x["trigger"] == "click" and generate._id in x["targets"] 150 | ] 151 | 152 | dependency: dict = None 153 | cnet_dependency: dict = None 154 | UiControlNetUnit = None 155 | for d in dependencies: 156 | if len(d["outputs"]) == 1: 157 | outputs = get_components_by_ids(root, d["outputs"]) 158 | output = outputs[0] 159 | if isinstance(output, gr.State) and type(output.value).__name__ == "UiControlNetUnit": 160 | cnet_dependency = d 161 | UiControlNetUnit = type(output.value) 162 | 163 | elif len(d["outputs"]) == 4: 164 | dependency = d 165 | 166 | with root: 167 | if self.checkpoint_dropdown is not None: 168 | self.checkpoint_dropdown.change(fn=self.on_checkpoint_changed, inputs=[self.checkpoint_dropdown]) 169 | 170 | fn_block = next(fn for fn in root.fns if compare_components_with_ids(fn.inputs, dependency["inputs"])) 171 | fn = self.wrap_register_ui_task() 172 | inputs = fn_block.inputs.copy() 173 | inputs.insert(0, self.checkpoint_dropdown) 174 | args = dict( 175 | fn=fn, 176 | _js="submit_enqueue_img2img" if is_img2img else "submit_enqueue", 177 | inputs=inputs, 178 | outputs=None, 179 | show_progress=False, 180 | ) 181 | 182 | self.submit_button.click(**args) 183 | 184 | if cnet_dependency is not None: 185 | cnet_fn_block = next( 186 | fn for fn in root.fns if compare_components_with_ids(fn.inputs, cnet_dependency["inputs"]) 187 | ) 188 | self.submit_button.click( 189 | fn=UiControlNetUnit, 190 | inputs=cnet_fn_block.inputs, 191 | outputs=cnet_fn_block.outputs, 192 | queue=False, 193 | ) 194 | 195 | def wrap_register_ui_task(self): 196 | def f(request: gr.Request, *args): 197 | if len(args) == 0: 198 | raise Exception("Invalid call") 199 | 200 | checkpoint: str = args[0] 201 | task_id = args[1] 202 | args = args[1:] 203 | task_name = None 204 | 205 | if task_id == queue_with_every_checkpoints: 206 | task_id = str(uuid4()) 207 | checkpoint = list_checkpoint_tiles() 208 | else: 
209 | if not task_id.startswith("task("): 210 | task_name = task_id 211 | task_id = str(uuid4()) 212 | 213 | if checkpoint is None or checkpoint == "" or checkpoint == checkpoint_current: 214 | checkpoint = [shared.sd_model.sd_checkpoint_info.title] 215 | elif checkpoint == checkpoint_runtime: 216 | checkpoint = [None] 217 | elif checkpoint.endswith(" checkpoints)"): 218 | checkpoint_dir = " ".join(checkpoint.split(" ")[0:-2]) 219 | checkpoint = list(filter(lambda c: c.startswith(checkpoint_dir), list_checkpoint_tiles())) 220 | else: 221 | checkpoint = [checkpoint] 222 | 223 | for i, c in enumerate(checkpoint): 224 | t_id = task_id if i == 0 else f"{task_id}.{i}" 225 | task_runner.register_ui_task( 226 | t_id, 227 | self.is_img2img, 228 | *args, 229 | checkpoint=c, 230 | task_name=task_name, 231 | request=request, 232 | ) 233 | 234 | task_runner.execute_pending_tasks_threading() 235 | 236 | return f 237 | 238 | 239 | def get_checkpoint_choices(): 240 | checkpoints: List[str] = list_checkpoint_tiles() 241 | 242 | checkpoint_dirs = defaultdict(lambda: 0) 243 | for checkpoint in checkpoints: 244 | checkpoint_dir = os.path.dirname(checkpoint) 245 | while checkpoint_dir != "" and checkpoint_dir != "/": 246 | checkpoint_dirs[checkpoint_dir] += 1 247 | checkpoint_dir = os.path.dirname(checkpoint_dir) 248 | 249 | choices = checkpoints 250 | choices.extend([f"{d} ({checkpoint_dirs[d]} checkpoints)" for d in checkpoint_dirs.keys()]) 251 | choices = sorted(choices) 252 | 253 | choices.insert(0, checkpoint_runtime) 254 | choices.insert(0, checkpoint_current) 255 | 256 | return choices 257 | 258 | 259 | def create_send_to_buttons(): 260 | return { 261 | "txt2img": ToolButton( 262 | "➠ text" if is_sdnext else "📝", 263 | elem_id="agent_scheduler_send_to_txt2img", 264 | tooltip="Send generation parameters to txt2img tab.", 265 | ), 266 | "img2img": ToolButton( 267 | "➠ image" if is_sdnext else "🖼️", 268 | elem_id="agent_scheduler_send_to_img2img", 269 | tooltip="Send image and 
generation parameters to img2img tab.", 270 | ), 271 | "inpaint": ToolButton( 272 | "➠ inpaint" if is_sdnext else "🎨️", 273 | elem_id="agent_scheduler_send_to_inpaint", 274 | tooltip="Send image and generation parameters to img2img inpaint tab.", 275 | ), 276 | "extras": ToolButton( 277 | "➠ process" if is_sdnext else "📐", 278 | elem_id="agent_scheduler_send_to_extras", 279 | tooltip="Send image and generation parameters to extras tab.", 280 | ), 281 | } 282 | 283 | 284 | def infotexts_to_geninfo(infotexts: List[str]): 285 | all_promts = [] 286 | all_seeds = [] 287 | 288 | geninfo = {"infotexts": infotexts, "all_prompts": all_promts, "all_seeds": all_seeds, "index_of_first_image": 0} 289 | 290 | for infotext in infotexts: 291 | # Dynamic prompt breaks layout of infotext 292 | if "Template: " in infotext: 293 | lines = infotext.split("\n") 294 | lines = [l for l in lines if not (l.startswith("Template: ") or l.startswith("Negative Template: "))] 295 | infotext = "\n".join(lines) 296 | 297 | params = parse_generation_parameters(infotext) 298 | 299 | if "prompt" not in geninfo: 300 | geninfo["prompt"] = params.get("Prompt", "") 301 | geninfo["negative_prompt"] = params.get("Negative prompt", "") 302 | geninfo["seed"] = params.get("Seed", "-1") 303 | geninfo["sampler_name"] = params.get("Sampler", "") 304 | geninfo["cfg_scale"] = params.get("CFG scale", "") 305 | geninfo["steps"] = params.get("Steps", "0") 306 | geninfo["width"] = params.get("Size-1", "512") 307 | geninfo["height"] = params.get("Size-2", "512") 308 | 309 | all_promts.append(params.get("Prompt", "")) 310 | all_seeds.append(params.get("Seed", "-1")) 311 | 312 | return geninfo 313 | 314 | 315 | def get_task_results(task_id: str, image_idx: int = None): 316 | task = task_manager.get_task(task_id) 317 | 318 | galerry = None 319 | geninfo = None 320 | infotext = None 321 | if task is None: 322 | pass 323 | elif task.status != TaskStatus.DONE: 324 | infotext = f"Status: {task.status}" 325 | if task.status == 
TaskStatus.FAILED and task.result: 326 | infotext += f"\nError: {task.result}" 327 | elif task.status == TaskStatus.DONE: 328 | try: 329 | result: dict = json.loads(task.result) 330 | images = result.get("images", []) 331 | geninfo = result.get("geninfo", None) 332 | if isinstance(geninfo, dict): 333 | infotexts = geninfo.get("infotexts", []) 334 | else: 335 | infotexts = result.get("infotexts", []) 336 | geninfo = infotexts_to_geninfo(infotexts) 337 | 338 | galerry = [Image.open(i) for i in images if os.path.exists(i)] if image_idx is None else gr.update() 339 | idx = image_idx if image_idx is not None else 0 340 | if idx < len(infotexts): 341 | infotext = infotexts[idx] 342 | except Exception as e: 343 | log.error(f"[AgentScheduler] Failed to load task result") 344 | log.error(e) 345 | infotext = f"Failed to load task result: {str(e)}" 346 | 347 | res = ( 348 | gr.Textbox.update(infotext, visible=infotext is not None), 349 | gr.Row.update(visible=galerry is not None), 350 | ) 351 | 352 | if image_idx is None: 353 | geninfo = json.dumps(geninfo) if geninfo else None 354 | res += ( 355 | galerry, 356 | gr.Textbox.update(geninfo), 357 | gr.File.update(None, visible=False), 358 | gr.HTML.update(None), 359 | ) 360 | 361 | return res 362 | 363 | 364 | def remove_old_tasks(): 365 | # delete task that are too old 366 | 367 | retention_days = 30 368 | if ( 369 | getattr(shared.opts, "queue_history_retention_days", None) 370 | and shared.opts.queue_history_retention_days in task_history_retenion_map 371 | ): 372 | retention_days = task_history_retenion_map[shared.opts.queue_history_retention_days] 373 | 374 | if retention_days > 0: 375 | deleted_rows = task_manager.delete_tasks(before=datetime.now() - timedelta(days=retention_days)) 376 | if deleted_rows > 0: 377 | log.debug(f"[AgentScheduler] Deleted {deleted_rows} tasks older than {retention_days} days") 378 | 379 | 380 | def on_ui_tab(**_kwargs): 381 | grid_page_size = getattr(shared.opts, "queue_grid_page_size", 0) 382 
| 383 | with gr.Blocks(analytics_enabled=False) as scheduler_tab: 384 | with gr.Tabs(elem_id="agent_scheduler_tabs"): 385 | with gr.Tab("Task Queue", id=0, elem_id="agent_scheduler_pending_tasks_tab"): 386 | with gr.Row(elem_id="agent_scheduler_pending_tasks_wrapper"): 387 | with gr.Column(scale=1): 388 | with gr.Row(elem_id="agent_scheduler_pending_tasks_actions", elem_classes="flex-row"): 389 | paused = getattr(shared.opts, "queue_paused", False) 390 | 391 | gr.Button( 392 | "Pause", 393 | elem_id="agent_scheduler_action_pause", 394 | variant="stop", 395 | visible=not paused, 396 | ) 397 | gr.Button( 398 | "Resume", 399 | elem_id="agent_scheduler_action_resume", 400 | variant="primary", 401 | visible=paused, 402 | ) 403 | gr.Button( 404 | "Refresh", 405 | elem_id="agent_scheduler_action_reload", 406 | variant="secondary", 407 | ) 408 | gr.Button( 409 | "Clear", 410 | elem_id="agent_scheduler_action_clear_queue", 411 | variant="stop", 412 | ) 413 | gr.Button( 414 | "Export", 415 | elem_id="agent_scheduler_action_export", 416 | variant="secondary", 417 | ) 418 | gr.Button( 419 | "Import", 420 | elem_id="agent_scheduler_action_import", 421 | variant="secondary", 422 | ) 423 | gr.HTML(f'') 424 | 425 | with gr.Row(elem_classes=["agent_scheduler_filter_container", "flex-row", "ml-auto"]): 426 | gr.Textbox( 427 | max_lines=1, 428 | placeholder="Search", 429 | label="Search", 430 | show_label=False, 431 | min_width=0, 432 | elem_id="agent_scheduler_action_search", 433 | ) 434 | gr.HTML( 435 | f'
' 436 | ) 437 | with gr.Column(scale=1): 438 | gr.Gallery( 439 | elem_id="agent_scheduler_current_task_images", 440 | label="Output", 441 | show_label=False, 442 | columns=2, 443 | object_fit="contain", 444 | ) 445 | with gr.Tab("Task History", id=1, elem_id="agent_scheduler_history_tab"): 446 | with gr.Row(elem_id="agent_scheduler_history_wrapper"): 447 | with gr.Column(scale=1): 448 | with gr.Row(elem_id="agent_scheduler_history_actions", elem_classes="flex-row"): 449 | gr.Button( 450 | "Requeue Failed", 451 | elem_id="agent_scheduler_action_requeue", 452 | variant="primary", 453 | ) 454 | gr.Button( 455 | "Refresh", 456 | elem_id="agent_scheduler_action_refresh_history", 457 | elem_classes="agent_scheduler_action_refresh", 458 | variant="secondary", 459 | ) 460 | gr.Button( 461 | "Clear", 462 | elem_id="agent_scheduler_action_clear_history", 463 | variant="stop", 464 | ) 465 | 466 | with gr.Row(elem_classes=["agent_scheduler_filter_container", "flex-row", "ml-auto"]): 467 | status = gr.Dropdown( 468 | elem_id="agent_scheduler_status_filter", 469 | choices=task_filter_choices, 470 | value="All", 471 | show_label=False, 472 | min_width=0, 473 | ) 474 | gr.Textbox( 475 | max_lines=1, 476 | placeholder="Search", 477 | label="Search", 478 | show_label=False, 479 | min_width=0, 480 | elem_id="agent_scheduler_action_search_history", 481 | ) 482 | gr.HTML( 483 | f'
' 484 | ) 485 | with gr.Column(scale=1, elem_id="agent_scheduler_history_results"): 486 | galerry = gr.Gallery( 487 | elem_id="agent_scheduler_history_gallery", 488 | label="Output", 489 | show_label=False, 490 | columns=2, 491 | preview=True, 492 | object_fit="contain", 493 | ) 494 | with gr.Row( 495 | elem_id="agent_scheduler_history_result_actions", 496 | visible=False, 497 | ) as result_actions: 498 | if is_sdnext: 499 | with gr.Group(): 500 | save = ToolButton( 501 | "💾", 502 | elem_id="agent_scheduler_save", 503 | tooltip=f"Save the image to a dedicated directory ({shared.opts.outdir_save}).", 504 | ) 505 | save_zip = None 506 | else: 507 | save = ToolButton( 508 | "💾", 509 | elem_id="agent_scheduler_save", 510 | tooltip=f"Save the image to a dedicated directory ({shared.opts.outdir_save}).", 511 | ) 512 | save_zip = ToolButton( 513 | "🗃️", 514 | elem_id="agent_scheduler_save_zip", 515 | tooltip=f"Save zip archive with images to a dedicated directory ({shared.opts.outdir_save})", 516 | ) 517 | send_to_buttons = create_send_to_buttons() 518 | with gr.Group(): 519 | generation_info = gr.Textbox(visible=False, elem_id=f"agent_scheduler_generation_info") 520 | infotext = gr.TextArea( 521 | label="Generation Info", 522 | elem_id=f"agent_scheduler_history_infotext", 523 | interactive=False, 524 | visible=True, 525 | lines=3, 526 | ) 527 | download_files = gr.File( 528 | None, 529 | file_count="multiple", 530 | interactive=False, 531 | show_label=False, 532 | visible=False, 533 | elem_id=f"agent_scheduler_download_files", 534 | ) 535 | html_log = gr.HTML(elem_id=f"agent_scheduler_html_log", elem_classes="html-log") 536 | selected_task = gr.Textbox( 537 | elem_id="agent_scheduler_history_selected_task", 538 | visible=False, 539 | show_label=False, 540 | ) 541 | selected_image_id = gr.Textbox( 542 | elem_id="agent_scheduler_history_selected_image", 543 | visible=False, 544 | show_label=False, 545 | ) 546 | 547 | # register event handlers 548 | status.change( 549 | 
fn=lambda x: None, 550 | _js="agent_scheduler_status_filter_changed", 551 | inputs=[status], 552 | ) 553 | save.click( 554 | fn=lambda x, y, z: call_queue.wrap_gradio_call(save_files)(x, y, False, int(z)), 555 | _js="(x, y, z) => [x, y, selected_gallery_index()]", 556 | inputs=[generation_info, galerry, infotext], 557 | outputs=[download_files, html_log], 558 | show_progress=False, 559 | ) 560 | if save_zip: 561 | save_zip.click( 562 | fn=lambda x, y, z: call_queue.wrap_gradio_call(save_files)(x, y, True, int(z)), 563 | _js="(x, y, z) => [x, y, selected_gallery_index()]", 564 | inputs=[generation_info, galerry, infotext], 565 | outputs=[download_files, html_log], 566 | ) 567 | selected_task.change( 568 | fn=lambda x: get_task_results(x, None), 569 | inputs=[selected_task], 570 | outputs=[infotext, result_actions, galerry, generation_info, download_files, html_log], 571 | ) 572 | selected_image_id.change( 573 | fn=lambda x, y: get_task_results(x, image_idx=int(y)), 574 | inputs=[selected_task, selected_image_id], 575 | outputs=[infotext, result_actions], 576 | ) 577 | try: 578 | for paste_tabname, paste_button in send_to_buttons.items(): 579 | register_paste_params_button( 580 | ParamBinding( 581 | paste_button=paste_button, 582 | tabname=paste_tabname, 583 | source_text_component=infotext, 584 | source_image_component=galerry, 585 | ) 586 | ) 587 | except: 588 | pass 589 | 590 | return [(scheduler_tab, "Agent Scheduler", "agent_scheduler")] 591 | 592 | 593 | def on_ui_settings(): 594 | section = ("agent_scheduler", "Agent Scheduler") 595 | shared.opts.add_option( 596 | "queue_paused", 597 | shared.OptionInfo( 598 | False, 599 | "Disable queue auto-processing", 600 | gr.Checkbox, 601 | {"interactive": True}, 602 | section=section, 603 | ), 604 | ) 605 | shared.opts.add_option( 606 | "queue_button_hide_checkpoint", 607 | shared.OptionInfo( 608 | True, 609 | "Hide the custom checkpoint dropdown", 610 | gr.Checkbox, 611 | {}, 612 | section=section, 613 | ), 614 | ) 615 
| shared.opts.add_option( 616 | "queue_button_placement", 617 | shared.OptionInfo( 618 | placement_under_generate, 619 | "Queue button placement", 620 | gr.Radio, 621 | lambda: { 622 | "choices": [ 623 | placement_under_generate, 624 | placement_between_prompt_and_generate, 625 | ] 626 | }, 627 | section=section, 628 | ), 629 | ) 630 | shared.opts.add_option( 631 | "queue_ui_placement", 632 | shared.OptionInfo( 633 | ui_placement_as_tab, 634 | "Task queue UI placement", 635 | gr.Radio, 636 | lambda: { 637 | "choices": [ 638 | ui_placement_as_tab, 639 | ui_placement_append_to_main, 640 | ] 641 | }, 642 | section=section, 643 | ), 644 | ) 645 | shared.opts.add_option( 646 | "queue_history_retention_days", 647 | shared.OptionInfo( 648 | "30 days", 649 | "Auto delete queue history (bookmarked tasks excluded)", 650 | gr.Radio, 651 | lambda: { 652 | "choices": list(task_history_retenion_map.keys()), 653 | }, 654 | section=section, 655 | ), 656 | ) 657 | shared.opts.add_option( 658 | "queue_automatic_requeue_failed_task", 659 | shared.OptionInfo( 660 | False, 661 | "Auto requeue failed tasks", 662 | gr.Checkbox, 663 | {}, 664 | section=section, 665 | ), 666 | ) 667 | shared.opts.add_option( 668 | "queue_grid_page_size", 669 | shared.OptionInfo( 670 | 0, 671 | "Task list page size (0 for auto)", 672 | gr.Slider, 673 | {"minimum": 0, "maximum": 200, "step": 1}, 674 | section=section, 675 | ), 676 | ) 677 | 678 | def enqueue_keyboard_shortcut(disabled: bool, modifiers, key_code: str): 679 | if disabled: 680 | modifiers.insert(0, "Disabled") 681 | 682 | shortcut = "+".join(sorted(modifiers) + [enqueue_key_codes[key_code]]) 683 | 684 | return ( 685 | shortcut, 686 | gr.CheckboxGroup.update(interactive=not disabled), 687 | gr.Dropdown.update(interactive=not disabled), 688 | ) 689 | 690 | def enqueue_keyboard_shortcut_ui(**_kwargs): 691 | value = _kwargs.get("value", enqueue_default_hotkey) 692 | parts = value.split("+") 693 | key = parts.pop() 694 | key_code_value = [k for k, v 
in enqueue_key_codes.items() if v == key] 695 | modifiers = [m for m in parts if m in enqueue_key_modifiers] 696 | disabled = "Disabled" in value 697 | 698 | with gr.Group(elem_id="enqueue_keyboard_shortcut_wrapper"): 699 | modifiers = gr.CheckboxGroup( 700 | enqueue_key_modifiers, 701 | value=modifiers, 702 | label="Enqueue keyboard shortcut", 703 | elem_id="enqueue_keyboard_shortcut_modifiers", 704 | interactive=not disabled, 705 | ) 706 | key_code = gr.Dropdown( 707 | choices=list(enqueue_key_codes.keys()), 708 | value="E" if len(key_code_value) == 0 else key_code_value[0], 709 | elem_id="enqueue_keyboard_shortcut_key", 710 | label="Key", 711 | interactive=not disabled, 712 | ) 713 | shortcut = gr.Textbox(**_kwargs) 714 | disable = gr.Checkbox( 715 | value=disabled, 716 | elem_id="enqueue_keyboard_shortcut_disable", 717 | label="Disable keyboard shortcut", 718 | ) 719 | 720 | modifiers.change( 721 | fn=enqueue_keyboard_shortcut, 722 | inputs=[disable, modifiers, key_code], 723 | outputs=[shortcut, modifiers, key_code], 724 | ) 725 | key_code.change( 726 | fn=enqueue_keyboard_shortcut, 727 | inputs=[disable, modifiers, key_code], 728 | outputs=[shortcut, modifiers, key_code], 729 | ) 730 | disable.change( 731 | fn=enqueue_keyboard_shortcut, 732 | inputs=[disable, modifiers, key_code], 733 | outputs=[shortcut, modifiers, key_code], 734 | ) 735 | 736 | return shortcut 737 | 738 | shared.opts.add_option( 739 | "queue_keyboard_shortcut", 740 | shared.OptionInfo( 741 | enqueue_default_hotkey, 742 | "Enqueue keyboard shortcut", 743 | enqueue_keyboard_shortcut_ui, 744 | { 745 | "interactive": False, 746 | }, 747 | section=section, 748 | ), 749 | ) 750 | shared.opts.add_option( 751 | "queue_completion_action", 752 | shared.OptionInfo( 753 | "Do nothing", 754 | "Action after queue completion", 755 | gr.Radio, 756 | lambda: { 757 | "choices": completion_action_choices, 758 | }, 759 | section=section, 760 | ), 761 | ) 762 | 763 | 764 | def on_app_started(block: gr.Blocks, 
app): 765 | global task_runner 766 | task_runner = get_instance(block) 767 | task_runner.execute_pending_tasks_threading() 768 | regsiter_apis(app, task_runner) 769 | task_runner.on_task_cleared(lambda: remove_old_tasks()) 770 | 771 | if getattr(shared.opts, "queue_ui_placement", "") == ui_placement_append_to_main and block: 772 | with block: 773 | with block.children[1]: 774 | bindings = registered_param_bindings.copy() 775 | registered_param_bindings.clear() 776 | on_ui_tab() 777 | connect_paste_params_buttons() 778 | registered_param_bindings.extend(bindings) 779 | 780 | 781 | if getattr(shared.opts, "queue_ui_placement", "") != ui_placement_append_to_main: 782 | script_callbacks.on_ui_tabs(on_ui_tab) 783 | 784 | script_callbacks.on_ui_settings(on_ui_settings) 785 | script_callbacks.on_app_started(on_app_started) 786 | -------------------------------------------------------------------------------- /ui/.eslintignore: -------------------------------------------------------------------------------- 1 | dist 2 | .eslintrc.cjs -------------------------------------------------------------------------------- /ui/.eslintrc.cjs: -------------------------------------------------------------------------------- 1 | module.exports = { 2 | root: true, 3 | env: { browser: true, es2020: true }, 4 | extends: [ 5 | 'eslint:recommended', 6 | 'plugin:@typescript-eslint/recommended', 7 | 'plugin:react-hooks/recommended', 8 | 'plugin:deprecation/recommended', 9 | ], 10 | parser: '@typescript-eslint/parser', 11 | parserOptions: { 12 | ecmaVersion: 'latest', 13 | sourceType: 'module', 14 | project: 'tsconfig.json', 15 | }, 16 | plugins: ['react-refresh', 'simple-import-sort'], 17 | rules: { 18 | 'comma-dangle': [ 19 | 'error', 20 | { 21 | 'arrays': 'always-multiline', 22 | 'objects': 'always-multiline', 23 | 'imports': 'always-multiline', 24 | 'exports': 'always-multiline', 25 | 'functions': 'never', 26 | }, 27 | ], 28 | 'semi': ['error', 'always'], 29 | 'semi-spacing': ['error', { 
'after': true, 'before': false }], 30 | 'semi-style': ['error', 'last'], 31 | 'no-extra-semi': 'error', 32 | 'no-unexpected-multiline': 'error', 33 | 'no-unreachable': 'error', 34 | 'no-irregular-whitespace': ['error', { 'skipTemplates': true }], 35 | 'react-refresh/only-export-components': 'warn', 36 | 'simple-import-sort/imports': [ 37 | 'error', 38 | { 39 | groups: [ 40 | // Side effect imports. 41 | ['^\\u0000'], 42 | // Node.js builtins. 43 | [`^(${require('module').builtinModules.join('|')})(/|$)`], 44 | // Packages. `react` related packages come first. 45 | ['^react', '^\\w', '^@\\w'], 46 | // Type 47 | [`^(@@types)(/.*|$)`], 48 | // Internal packages. 49 | [ 50 | `^(~)(/.*|$)`, 51 | ], 52 | // Parent imports. Put `..` last. 53 | ['^\\.\\.(?!/?$)', '^\\.\\./?$'], 54 | // Other relative imports. Put same-folder imports and `.` last. 55 | ['^\\./(?=.*/)(?!/?$)', '^\\.(?!/?$)', '^\\./?$'], 56 | // Style imports. 57 | ['^.+\\.s?css$'], 58 | ], 59 | }, 60 | ], 61 | '@typescript-eslint/no-explicit-any': 'off', 62 | '@typescript-eslint/no-non-null-assertion': 'off', 63 | '@typescript-eslint/strict-boolean-expressions': [ 64 | 'error', 65 | { 66 | allowString: false, 67 | allowNumber: false, 68 | allowNullableObject: false, 69 | }, 70 | ], 71 | }, 72 | } 73 | -------------------------------------------------------------------------------- /ui/.gitignore: -------------------------------------------------------------------------------- 1 | # Logs 2 | logs 3 | *.log 4 | npm-debug.log* 5 | yarn-debug.log* 6 | yarn-error.log* 7 | pnpm-debug.log* 8 | lerna-debug.log* 9 | 10 | node_modules 11 | dist 12 | dist-ssr 13 | *.local 14 | 15 | # Editor directories and files 16 | .vscode/* 17 | !.vscode/extensions.json 18 | .idea 19 | .DS_Store 20 | *.suo 21 | *.ntvs* 22 | *.njsproj 23 | *.sln 24 | *.sw? 
25 | -------------------------------------------------------------------------------- /ui/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "ui", 3 | "private": true, 4 | "version": "0.0.0", 5 | "scripts": { 6 | "dev": "vite", 7 | "build": "tsc && yarn build:extension", 8 | "build:extension": "vite build --config vite.extension.ts", 9 | "lint": "eslint src --ext ts,tsx --fix --report-unused-disable-directives --max-warnings 0", 10 | "preview": "vite preview" 11 | }, 12 | "dependencies": { 13 | "ag-grid-community": "^31.1.1", 14 | "notyf": "^3.10.0", 15 | "react": "^18.2.0", 16 | "react-dom": "^18.2.0", 17 | "zustand": "^4.5.1" 18 | }, 19 | "devDependencies": { 20 | "@types/node": "^20.11.19", 21 | "@types/react": "^18.2.57", 22 | "@types/react-dom": "^18.2.19", 23 | "@typescript-eslint/eslint-plugin": "^7.0.2", 24 | "@typescript-eslint/parser": "^7.0.2", 25 | "@vitejs/plugin-react": "^4.2.1", 26 | "autoprefixer": "^10.4.17", 27 | "eslint": "^8.56.0", 28 | "eslint-plugin-deprecation": "^2.0.0", 29 | "eslint-plugin-react-hooks": "^4.6.0", 30 | "eslint-plugin-react-refresh": "^0.4.5", 31 | "eslint-plugin-simple-import-sort": "^12.0.0", 32 | "postcss": "^8.4.35", 33 | "sass": "^1.71.1", 34 | "tailwindcss": "^3.4.1", 35 | "typescript": "^5.3.3", 36 | "vite": "^5.1.4" 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /ui/postcss.config.js: -------------------------------------------------------------------------------- 1 | module.exports = { 2 | plugins: { 3 | tailwindcss: {}, 4 | autoprefixer: {}, 5 | }, 6 | } 7 | -------------------------------------------------------------------------------- /ui/src/assets/icons/bookmark-filled.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /ui/src/assets/icons/bookmark.svg: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /ui/src/assets/icons/cancel.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | -------------------------------------------------------------------------------- /ui/src/assets/icons/delete.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /ui/src/assets/icons/play.svg: -------------------------------------------------------------------------------- 1 | 11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /ui/src/assets/icons/rotate.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /ui/src/assets/icons/save.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /ui/src/assets/icons/search.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | -------------------------------------------------------------------------------- /ui/src/extension/index.scss: -------------------------------------------------------------------------------- 1 | @use 'ag-grid-community/styles/ag-theme-alpine.css'; 2 | @import './tailwind.css'; 3 | 4 | /* ========================================================================= */ 5 | 6 | .ag-theme-gradio { 7 | @extend .ag-theme-alpine; 8 | .dark & { 9 | @extend .ag-theme-alpine-dark; 10 | } 11 | } 12 | .ag-theme-gradio, 13 | .dark .ag-theme-gradio { 14 | > * { 15 | --body-text-color: 
'inherit'; 16 | } 17 | 18 | --ag-alpine-active-color: var(--color-accent); 19 | --ag-selected-row-background-color: var(--table-row-focus); 20 | --ag-modal-overlay-background-color: transparent; 21 | --ag-row-hover-color: transparent; 22 | --ag-column-hover-color: transparent; 23 | --ag-input-focus-border-color: var(--input-border-color-focus); 24 | --ag-background-color: var(--table-even-background-fill); 25 | --ag-foreground-color: var(--body-text-color); 26 | --ag-border-color: var(--border-color-primary); 27 | --ag-secondary-border-color: var(--border-color-primary); 28 | --ag-header-background-color: var(--table-even-background-fill); 29 | --ag-tooltip-background-color: var(--table-even-background-fill); 30 | --ag-odd-row-background-color: var(--table-even-background-fill); 31 | --ag-control-panel-background-color: var(--table-even-background-fill); 32 | --ag-invalid-color: var(--error-text-color); 33 | --ag-input-border-color: var(--input-border-color); 34 | --ag-disabled-foreground-color: var(--body-text-color-subdued); 35 | --ag-row-border-color: var(--border-color-primary); 36 | 37 | --ag-row-height: 45px; 38 | --ag-header-height: 45px; 39 | --ag-cell-horizontal-padding: calc(var(--ag-grid-size) * 2); 40 | 41 | .ag-root-wrapper { 42 | border-radius: var(--table-radius); 43 | } 44 | 45 | .ag-row-even { 46 | background: var(--table-odd-background-fill); 47 | } 48 | 49 | .ag-row-highlight-above, 50 | .ag-row-highlight-below { 51 | &::after { 52 | width: 100%; 53 | height: 2px; 54 | left: 0; 55 | z-index: 3; 56 | } 57 | } 58 | .ag-row-highlight-above { 59 | &::after { 60 | top: -1.5px; 61 | } 62 | &.ag-row-first::after { 63 | top: 0; 64 | } 65 | } 66 | .ag-row-highlight-below { 67 | &::after { 68 | bottom: -1.5px; 69 | } 70 | &.ag-row-last::after { 71 | bottom: 0; 72 | } 73 | } 74 | 75 | .cell-span { 76 | border-bottom-color: var(--ag-border-color); 77 | } 78 | .cell-not-span { 79 | opacity: 0; 80 | } 81 | .ag-input-field-input { 82 | background-color: 
var(--input-background-fill); 83 | } 84 | .ag-select .ag-picker-field-wrapper { 85 | background-color: var(--input-background-fill); 86 | } 87 | } 88 | .ag-center-cols-viewport { 89 | scrollbar-width: none; 90 | -ms-overflow-style: none; 91 | &::-webkit-scrollbar { 92 | display: none; 93 | } 94 | } 95 | .ag-horizontal-left-spacer, 96 | .ag-horizontal-right-spacer { 97 | overflow-x: hidden; 98 | } 99 | .ag-overlay { 100 | z-index: 5; 101 | } 102 | 103 | .notyf { 104 | font-family: var(--font); 105 | 106 | .notyf__toast { 107 | padding: 0 16px; 108 | border-radius: 6px; 109 | 110 | &.notyf__toast--success .notyf__ripple { 111 | background-color: #22c55e !important; 112 | } 113 | 114 | &.notyf__toast--error .notyf__ripple { 115 | background-color: #ef4444 !important; 116 | } 117 | } 118 | 119 | .notyf__wrapper { 120 | padding: 12px 0; 121 | } 122 | } 123 | 124 | /* ========================================================================= */ 125 | 126 | #tabs > #agent_scheduler_tabs { 127 | margin-top: var(--layout-gap); 128 | } 129 | 130 | .ag-cell.pending-actions { 131 | .ag-row.ag-row-editing & { 132 | .control-actions { 133 | display: none; 134 | } 135 | } 136 | 137 | .ag-row:not(.ag-row-editing) & { 138 | .edit-actions { 139 | display: none; 140 | } 141 | } 142 | } 143 | 144 | .ag-cell.wrap-cell { 145 | line-height: var(--line-lg); 146 | padding-top: calc(var(--ag-cell-horizontal-padding) - 1px); 147 | padding-bottom: calc(var(--ag-cell-horizontal-padding) - 1px); 148 | } 149 | 150 | button.ts-btn-action { 151 | display: inline-flex; 152 | justify-content: center; 153 | align-items: center; 154 | transition: var(--button-transition); 155 | box-shadow: var(--button-shadow); 156 | padding: var(--size-1) var(--size-2) !important; 157 | text-align: center; 158 | 159 | &:hover, 160 | &[disabled] { 161 | box-shadow: var(--button-shadow-hover); 162 | } 163 | 164 | &[disabled] { 165 | opacity: 0.5; 166 | filter: grayscale(30%); 167 | cursor: not-allowed; 168 | } 169 | 170 
| &:active { 171 | box-shadow: var(--button-shadow-active); 172 | } 173 | 174 | &.primary { 175 | border: var(--button-border-width) solid var(--button-primary-border-color); 176 | background: var(--button-primary-background-fill); 177 | color: var(--button-primary-text-color); 178 | 179 | &:hover, 180 | &[disabled] { 181 | border-color: var(--button-primary-border-color-hover); 182 | background: var(--button-primary-background-fill-hover); 183 | color: var(--button-primary-text-color-hover); 184 | } 185 | } 186 | 187 | &.secondary { 188 | border: var(--button-border-width) solid var(--button-secondary-border-color); 189 | background: var(--button-secondary-background-fill); 190 | color: var(--button-secondary-text-color); 191 | 192 | &:hover, 193 | &[disabled] { 194 | border-color: var(--button-secondary-border-color-hover); 195 | background: var(--button-secondary-background-fill-hover); 196 | color: var(--button-secondary-text-color-hover); 197 | } 198 | } 199 | 200 | &.stop { 201 | border: var(--button-border-width) solid var(--button-cancel-border-color); 202 | background: var(--button-cancel-background-fill); 203 | color: var(--button-cancel-text-color); 204 | 205 | &:hover, 206 | &[disabled] { 207 | border-color: var(--button-cancel-border-color-hover); 208 | background: var(--button-cancel-background-fill-hover); 209 | color: var(--button-cancel-text-color-hover); 210 | } 211 | } 212 | } 213 | 214 | .ts-bookmark { 215 | color: var(--body-text-color-subdued) !important; 216 | } 217 | .ts-bookmarked { 218 | color: var(--color-accent) !important; 219 | } 220 | 221 | #agent_scheduler_pending_tasks_grid, 222 | #agent_scheduler_history_tasks_grid { 223 | height: calc(100vh - 300px); 224 | min-height: 400px; 225 | } 226 | 227 | #agent_scheduler_pending_tasks_wrapper, 228 | #agent_scheduler_history_wrapper { 229 | border: none; 230 | border-width: 0; 231 | box-shadow: none; 232 | justify-content: flex-end; 233 | gap: var(--layout-gap); 234 | padding: 0; 235 | 236 | 
@media (max-width: 1024px) { 237 | flex-wrap: wrap; 238 | } 239 | 240 | > div:last-child { 241 | width: 100%; 242 | 243 | @media (min-width: 1280px) { 244 | min-width: 400px !important; 245 | max-width: min(25%, 700px); 246 | } 247 | } 248 | 249 | > button { 250 | flex: 0 0 auto; 251 | } 252 | } 253 | 254 | #agent_scheduler_history_actions, 255 | #agent_scheduler_pending_tasks_actions { 256 | gap: calc(var(--layout-gap) / 2); 257 | min-height: 36px; 258 | } 259 | 260 | #agent_scheduler_history_result_actions { 261 | display: flex; 262 | justify-content: center; 263 | 264 | > div.form { 265 | flex: 0 0 auto !important; 266 | } 267 | 268 | > div.gr-group { 269 | flex: 1 1 100% !important; 270 | } 271 | } 272 | 273 | #agent_scheduler_pending_tasks_wrapper { 274 | .livePreview { 275 | margin: 0; 276 | padding-top: 100%; 277 | 278 | img { 279 | top: 0; 280 | border-radius: 5px; 281 | } 282 | } 283 | 284 | .progressDiv { 285 | height: 42px; 286 | line-height: 42px; 287 | max-width: 100%; 288 | text-align: center; 289 | position: static; 290 | font-size: var(--button-large-text-size); 291 | font-weight: var(--button-large-text-weight); 292 | 293 | .progress { 294 | height: 42px; 295 | line-height: 42px; 296 | } 297 | 298 | + .livePreview { 299 | margin-top: calc(40px + var(--layout-gap)); 300 | } 301 | } 302 | } 303 | 304 | #agent_scheduler_current_task_images, 305 | #agent_scheduler_history_gallery { 306 | width: 100%; 307 | padding-top: calc(100%); 308 | position: relative; 309 | box-sizing: content-box; 310 | 311 | > div { 312 | position: absolute; 313 | top: 0; 314 | left: 0; 315 | width: 100%; 316 | height: 100%; 317 | } 318 | } 319 | 320 | #agent_scheduler_history_gallery { 321 | @media screen and (min-width: 1280px) { 322 | .fixed-height { 323 | min-height: 400px; 324 | } 325 | } 326 | } 327 | 328 | .ml-auto { 329 | margin-left: auto; 330 | } 331 | 332 | .gradio-row.flex-row { 333 | > *, 334 | > .form, 335 | > .form > * { 336 | flex: initial; 337 | width: initial; 
338 | min-width: initial; 339 | } 340 | } 341 | 342 | .agent_scheduler_filter_container { 343 | > div.form { 344 | margin: 0; 345 | } 346 | } 347 | 348 | #agent_scheduler_status_filter { 349 | width: var(--size-36); 350 | padding: 0 !important; 351 | 352 | label > div { 353 | height: 100%; 354 | } 355 | } 356 | 357 | #agent_scheduler_action_search, 358 | #agent_scheduler_action_search_history { 359 | width: var(--size-64); 360 | padding: 0 !important; 361 | 362 | > label { 363 | position: relative; 364 | height: 100%; 365 | } 366 | 367 | input.ts-search-input { 368 | padding: var(--block-padding); 369 | height: 100%; 370 | } 371 | } 372 | 373 | #txt2img_enqueue_wrapper, 374 | #img2img_enqueue_wrapper { 375 | min-width: 210px; 376 | display: flex; 377 | flex-direction: column; 378 | gap: calc(var(--layout-gap) / 2); 379 | 380 | > div:first-child { 381 | flex-direction: row; 382 | flex-wrap: nowrap; 383 | align-items: stretch; 384 | flex: 0 0 auto; 385 | flex-grow: unset !important; 386 | } 387 | } 388 | 389 | :not(#txt2img_generate_box) > #txt2img_enqueue_wrapper, 390 | :not(#img2img_generate_box) > #img2img_enqueue_wrapper { 391 | align-self: flex-start; 392 | } 393 | 394 | #img2img_toprow .interrogate-col.has-queue-button { 395 | min-width: unset !important; 396 | flex-direction: row !important; 397 | gap: calc(var(--layout-gap) / 2) !important; 398 | 399 | button { 400 | margin: 0; 401 | } 402 | } 403 | 404 | #enqueue_keyboard_shortcut_wrapper { 405 | flex-wrap: wrap; 406 | 407 | .form { 408 | display: flex; 409 | flex-direction: row; 410 | align-items: flex-end; 411 | flex-wrap: nowrap; 412 | 413 | > div, 414 | fieldset { 415 | flex: 0 0 auto; 416 | width: auto; 417 | } 418 | 419 | #enqueue_keyboard_shortcut_modifiers { 420 | width: 300px; 421 | } 422 | #enqueue_keyboard_shortcut_key { 423 | width: 100px; 424 | } 425 | #setting_queue_keyboard_shortcut { 426 | display: none; 427 | } 428 | #enqueue_keyboard_shortcut_disable { 429 | width: 100%; 430 | } 431 | } 432 
| } 433 | 434 | .modification-indicator + #enqueue_keyboard_shortcut_wrapper { 435 | #enqueue_keyboard_shortcut_disable { 436 | padding-left: 12px !important; 437 | } 438 | } -------------------------------------------------------------------------------- /ui/src/extension/stores/history.store.ts: -------------------------------------------------------------------------------- 1 | import { createStore } from 'zustand/vanilla'; 2 | 3 | import { ResponseStatus, Task, TaskHistoryResponse, TaskStatus } from '../types'; 4 | 5 | type HistoryTasksState = { 6 | total: number; 7 | tasks: Task[]; 8 | status?: TaskStatus; 9 | }; 10 | 11 | type HistoryTasksActions = { 12 | refresh: (options?: { limit?: number; offset?: number }) => Promise; 13 | onFilterStatus: (status?: TaskStatus) => void; 14 | bookmarkTask: (id: string, bookmarked: boolean) => Promise; 15 | renameTask: (id: string, name: string) => Promise; 16 | requeueTask: (id: string) => Promise; 17 | requeueFailedTasks: () => Promise; 18 | clearHistory: () => Promise; 19 | }; 20 | 21 | export type HistoryTasksStore = ReturnType; 22 | 23 | export const createHistoryTasksStore = (initialState: HistoryTasksState) => { 24 | const store = createStore()(() => initialState); 25 | const { getState, setState, subscribe } = store; 26 | 27 | const actions: HistoryTasksActions = { 28 | refresh: async options => { 29 | const { limit = 1000, offset = 0 } = options ?? {}; 30 | const status = getState().status ?? ''; 31 | 32 | return fetch(`/agent-scheduler/v1/history?status=${status}&limit=${limit}&offset=${offset}`) 33 | .then(response => response.json()) 34 | .then((data: TaskHistoryResponse) => { 35 | setState({ ...data }); 36 | return data; 37 | }); 38 | }, 39 | onFilterStatus: status => { 40 | setState({ status }); 41 | actions.refresh(); 42 | }, 43 | bookmarkTask: async (id: string, bookmarked: boolean) => { 44 | return fetch(`/agent-scheduler/v1/task/${id}/${bookmarked ? 
'bookmark' : 'unbookmark'}`, { 45 | method: 'POST', 46 | }).then(response => response.json()); 47 | }, 48 | renameTask: async (id: string, name: string) => { 49 | return fetch(`/agent-scheduler/v1/task/${id}/rename?name=${encodeURIComponent(name)}`, { 50 | method: 'POST', 51 | headers: { 'Content-Type': 'application/json' }, 52 | }).then(response => response.json()); 53 | }, 54 | requeueTask: async (id: string) => { 55 | return fetch(`/agent-scheduler/v1/task/${id}/requeue`, { method: 'POST' }).then(response => 56 | response.json() 57 | ); 58 | }, 59 | requeueFailedTasks: async () => { 60 | return fetch('/agent-scheduler/v1/task/requeue-failed', { method: 'POST' }).then(response => { 61 | actions.refresh(); 62 | return response.json(); 63 | }); 64 | }, 65 | clearHistory: async () => { 66 | return fetch('/agent-scheduler/v1/history/clear', { method: 'POST' }).then(response => { 67 | actions.refresh(); 68 | return response.json(); 69 | }); 70 | }, 71 | }; 72 | 73 | return { getState, setState, subscribe, ...actions }; 74 | }; 75 | -------------------------------------------------------------------------------- /ui/src/extension/stores/pending.store.ts: -------------------------------------------------------------------------------- 1 | import { createStore } from 'zustand/vanilla'; 2 | 3 | import { ResponseStatus, Task } from '../types'; 4 | 5 | type PendingTasksState = { 6 | current_task_id: string | null; 7 | total_pending_tasks: number; 8 | pending_tasks: Task[]; 9 | paused: boolean; 10 | }; 11 | 12 | type PendingTasksActions = { 13 | refresh: () => Promise; 14 | exportQueue: () => Promise; 15 | importQueue: (str: string) => Promise; 16 | pauseQueue: () => Promise; 17 | resumeQueue: () => Promise; 18 | clearQueue: () => Promise; 19 | runTask: (id: string) => Promise; 20 | moveTask: (id: string, overId: string) => Promise; 21 | updateTask: (id: string, task: Task) => Promise; 22 | deleteTask: (id: string) => Promise; 23 | }; 24 | 25 | export type PendingTasksStore 
= ReturnType; 26 | 27 | export const createPendingTasksStore = (initialState: PendingTasksState) => { 28 | const store = createStore()(() => initialState); 29 | const { getState, setState, subscribe } = store; 30 | 31 | const actions: PendingTasksActions = { 32 | refresh: async () => { 33 | return fetch('/agent-scheduler/v1/queue?limit=1000') 34 | .then(response => response.json()) 35 | .then(setState); 36 | }, 37 | exportQueue: async () => { 38 | return fetch('/agent-scheduler/v1/export').then(response => response.json()); 39 | }, 40 | importQueue: async (str: string) => { 41 | const bodyObj = { 42 | content: str, 43 | }; 44 | return fetch(`/agent-scheduler/v1/import`, { 45 | method: 'POST', 46 | headers: { 'Content-Type': 'application/json' }, 47 | body: JSON.stringify(bodyObj), 48 | }) 49 | .then(response => response.json()) 50 | .then(data => { 51 | setTimeout(() => { 52 | actions.refresh(); 53 | }, 3000); 54 | return data; 55 | }); 56 | }, 57 | pauseQueue: async () => { 58 | return fetch('/agent-scheduler/v1/queue/pause', { method: 'POST' }) 59 | .then(response => response.json()) 60 | .then(data => { 61 | setTimeout(() => { 62 | actions.refresh(); 63 | }, 500); 64 | return data; 65 | }); 66 | }, 67 | resumeQueue: async () => { 68 | return fetch('/agent-scheduler/v1/queue/resume', { method: 'POST' }) 69 | .then(response => response.json()) 70 | .then(data => { 71 | setTimeout(() => { 72 | actions.refresh(); 73 | }, 500); 74 | return data; 75 | }); 76 | }, 77 | clearQueue: async () => { 78 | return fetch('/agent-scheduler/v1/queue/clear', { method: 'POST' }) 79 | .then(response => response.json()) 80 | .then(data => { 81 | actions.refresh(); 82 | return data; 83 | }); 84 | }, 85 | runTask: async (id: string) => { 86 | return fetch(`/agent-scheduler/v1/task/${id}/run`, { method: 'POST' }) 87 | .then(response => response.json()) 88 | .then(data => { 89 | setTimeout(() => { 90 | actions.refresh(); 91 | }, 500); 92 | return data; 93 | }); 94 | }, 95 | moveTask: 
async (id: string, overId: string) => { 96 | return fetch(`/agent-scheduler/v1/task/${id}/move/${overId}`, { method: 'POST' }) 97 | .then(response => response.json()) 98 | .then(data => { 99 | actions.refresh(); 100 | return data; 101 | }); 102 | }, 103 | updateTask: async (id: string, task: Task) => { 104 | const newValue = { 105 | name: task.name, 106 | checkpoint: task.params.checkpoint, 107 | params: { 108 | prompt: task.params.prompt, 109 | negative_prompt: task.params.negative_prompt, 110 | sampler_name: task.params.sampler_name, 111 | steps: task.params.steps, 112 | cfg_scale: task.params.cfg_scale, 113 | }, 114 | }; 115 | return fetch(`/agent-scheduler/v1/task/${id}`, { 116 | method: 'PUT', 117 | body: JSON.stringify(newValue), 118 | headers: { 'Content-Type': 'application/json' }, 119 | }).then(response => response.json()); 120 | }, 121 | deleteTask: async (id: string) => { 122 | return fetch(`/agent-scheduler/v1/task/${id}`, { method: 'DELETE' }).then(response => 123 | response.json() 124 | ); 125 | }, 126 | }; 127 | 128 | return { getState, setState, subscribe, ...actions }; 129 | }; 130 | -------------------------------------------------------------------------------- /ui/src/extension/stores/shared.store.ts: -------------------------------------------------------------------------------- 1 | import { createStore } from 'zustand/vanilla'; 2 | 3 | type SelectedTab = 'history' | 'pending'; 4 | 5 | type SharedState = { 6 | uiAsTab: boolean; 7 | selectedTab: SelectedTab; 8 | }; 9 | 10 | type SharedActions = { 11 | setSelectedTab: (tab: SelectedTab) => void; 12 | getSamplers: () => Promise; 13 | getCheckpoints: () => Promise; 14 | }; 15 | 16 | export const createSharedStore = (initialState: SharedState) => { 17 | const store = createStore(() => initialState); 18 | const { getState, setState, subscribe } = store; 19 | 20 | const actions: SharedActions = { 21 | setSelectedTab: (tab: SelectedTab) => { 22 | setState({ selectedTab: tab }); 23 | }, 24 | 
getSamplers: async () => { 25 | return fetch('/agent-scheduler/v1/samplers').then(response => response.json()); 26 | }, 27 | getCheckpoints: async () => { 28 | return fetch('/agent-scheduler/v1/sd-models').then(response => response.json()); 29 | }, 30 | }; 31 | 32 | return { getState, setState, subscribe, ...actions }; 33 | }; 34 | -------------------------------------------------------------------------------- /ui/src/extension/tailwind.css: -------------------------------------------------------------------------------- 1 | @tailwind components; 2 | @tailwind utilities; 3 | 4 | @layer components { 5 | .ts-search-input { 6 | padding-left: calc(var(--input-padding) + 24px + var(--spacing-sm)) !important; 7 | } 8 | 9 | .ts-search-icon { 10 | @apply absolute text-[--body-text-color] left-[--input-padding] inset-y-[--input-padding] flex 11 | items-center pointer-events-none; 12 | } 13 | 14 | .ts-btn-action { 15 | @apply !m-0 first:rounded-l-[--button-small-radius] last:rounded-r-[--button-small-radius]; 16 | } 17 | 18 | @keyframes blink { 19 | 0%, 20 | 100% { 21 | color: var(--color-accent); 22 | } 23 | 50% { 24 | color: var(--color-accent-soft); 25 | } 26 | } 27 | 28 | .ag-cell.task-running { 29 | color: var(--color-accent); 30 | animation: 1s blink ease infinite; 31 | } 32 | 33 | .ag-cell.task-failed { 34 | @apply text-[--error-text-color]; 35 | } 36 | 37 | .ag-cell.task-interrupted { 38 | @apply text-[--body-text-color-subdued]; 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /ui/src/extension/types.ts: -------------------------------------------------------------------------------- 1 | export type TaskStatus = 'pending' | 'running' | 'done' | 'failed' | 'interrupted' | 'saved'; 2 | 3 | export type Task = { 4 | id: string; 5 | api_task_id?: string; 6 | name?: string; 7 | type: string; 8 | status: TaskStatus; 9 | params: Record; 10 | priority: number; 11 | result: string; 12 | bookmarked?: boolean; 13 | editing?: 
boolean; 14 | created_at: number; 15 | updated_at: number; 16 | }; 17 | 18 | export type ResponseStatus = { 19 | success: boolean; 20 | message: string; 21 | }; 22 | 23 | export type TaskHistoryResponse = { 24 | tasks: Task[]; 25 | total: number; 26 | }; 27 | 28 | export type ProgressResponse = { 29 | active: boolean; 30 | completed: boolean; 31 | eta: number; 32 | id_live_preview: number; 33 | live_preview: string | null; 34 | paused: boolean; 35 | progress: number; 36 | queued: false; 37 | }; 38 | -------------------------------------------------------------------------------- /ui/src/utils/ag-grid.ts: -------------------------------------------------------------------------------- 1 | import { FocusService, GridApi, IRowNode, RowHighlightPosition } from 'ag-grid-community'; 2 | 3 | // patches to suppress scrolling when mouse down on a cell by default 4 | const setFocusedCell = FocusService.prototype.setFocusedCell; 5 | FocusService.prototype.setFocusedCell = function (params) { 6 | if (params.preventScrollOnBrowserFocus == null) { 7 | params.preventScrollOnBrowserFocus = true; 8 | } 9 | return setFocusedCell.call(this, params); 10 | }; 11 | 12 | export const getRowNodeAtPixel = (api: GridApi, pixel: number) => { 13 | if (api.getDisplayedRowCount() === 0) return; 14 | 15 | const firstRowIndexOfPage = api.paginationGetPageSize() * api.paginationGetCurrentPage(); 16 | const firstRowNodeOfPage = api.getDisplayedRowAtIndex(firstRowIndexOfPage)!; 17 | const rowTopOfPage = firstRowNodeOfPage.rowTop!; 18 | 19 | const lastRowIndexOfPage = Math.min( 20 | api.paginationGetPageSize() * (api.paginationGetCurrentPage() + 1) - 1, 21 | api.getDisplayedRowCount() - 1 22 | ); 23 | const lastRowNodeOfPage = api.getDisplayedRowAtIndex(lastRowIndexOfPage)!; 24 | const rowBottomOfPage = lastRowNodeOfPage.rowTop! 
+ lastRowNodeOfPage.rowHeight!; 25 | 26 | let rowNode: IRowNode | undefined; 27 | api.forEachNodeAfterFilterAndSort(node => { 28 | const rowTop = node.rowTop!, rowHeight = node.rowHeight!; 29 | if (rowTop < rowBottomOfPage) { 30 | const pixelOnRow = pixel - (rowTop - rowTopOfPage); 31 | if (pixelOnRow > 0 && pixelOnRow < rowHeight) { 32 | rowNode = node; 33 | } 34 | } 35 | }); 36 | return rowNode; 37 | }; 38 | 39 | export const getPixelOnRow = (api: GridApi, rowNode: IRowNode, pixel: number) => { 40 | const firstRowIndexOfPage = api.paginationGetPageSize() * api.paginationGetCurrentPage(); 41 | const firstRowNodeOfPage = api.getDisplayedRowAtIndex(firstRowIndexOfPage)!; 42 | const pageFirstRowTop = firstRowNodeOfPage.rowTop!; 43 | return pixel - (rowNode.rowTop! - pageFirstRowTop); 44 | }; 45 | 46 | export const getHighlightPosition = (api: GridApi, rowNode: IRowNode, pixel: number) => { 47 | const pixelOnRow = getPixelOnRow(api, rowNode, pixel); 48 | return pixelOnRow < rowNode.rowHeight! / 2 ? 
RowHighlightPosition.Above : RowHighlightPosition.Below; 49 | }; 50 | -------------------------------------------------------------------------------- /ui/src/utils/debounce.ts: -------------------------------------------------------------------------------- 1 | export const debounce = (fn: (this: T, ...args: P) => R, ms = 300) => { 2 | let timeout: ReturnType | undefined; 3 | return function (this, ...args) { 4 | clearTimeout(timeout); 5 | timeout = setTimeout(() => fn.apply(this, args), ms); 6 | }; 7 | }; 8 | -------------------------------------------------------------------------------- /ui/src/utils/extract-args.ts: -------------------------------------------------------------------------------- 1 | export const extractArgs = (func: (...args: any[]) => any) => { 2 | return (func + '') 3 | .replace(/[/][/].*$/gm, '') // strip single-line comments 4 | .replace(/\s+/g, '') // strip white space 5 | .replace(/[/][*][^/*]*[*][/]/g, '') // strip multi-line comments 6 | .split('){', 1)[0] 7 | .replace(/^[^(]*[(]/, '') // extract the parameters 8 | .replace(/=[^,]+/g, '') // strip any ES6 defaults 9 | .split(',') 10 | .filter(Boolean); // split & filter [""] 11 | }; 12 | -------------------------------------------------------------------------------- /ui/src/vite-env.d.ts: -------------------------------------------------------------------------------- 1 | /// 2 | -------------------------------------------------------------------------------- /ui/tailwind.config.js: -------------------------------------------------------------------------------- 1 | module.exports = { 2 | content: ['./src/**/*.{js,ts,tsx}'], 3 | darkMode: 'class', 4 | theme: { 5 | extend: {}, 6 | }, 7 | corePlugins: { 8 | container: false, 9 | }, 10 | plugins: [], 11 | safelist: [ 12 | { pattern: /task-(pending|running|done|failed|interrupted|saved)/ }, 13 | ], 14 | } -------------------------------------------------------------------------------- /ui/tsconfig.json: 
--------------------------------------------------------------------------------
{
  "compilerOptions": {
    "target": "ES2020",
    "useDefineForClassFields": true,
    "lib": ["ES2020", "DOM", "DOM.Iterable"],
    "module": "ESNext",
    "skipLibCheck": true,

    /* Bundler mode */
    "moduleResolution": "bundler",
    "allowImportingTsExtensions": true,
    "resolveJsonModule": true,
    "isolatedModules": true,
    "noEmit": true,
    "jsx": "react-jsx",

    /* Linting */
    "strict": true,
    "noFallthroughCasesInSwitch": true,
    "noImplicitThis": true,
    "noImplicitReturns": true,
    "noUnusedLocals": true,
    "noUnusedParameters": true
  },
  "include": ["src"],
  "references": [{ "path": "./tsconfig.node.json" }]
}
--------------------------------------------------------------------------------
/ui/tsconfig.node.json:
--------------------------------------------------------------------------------
{
  "compilerOptions": {
    "composite": true,
    "skipLibCheck": true,
    "module": "ESNext",
    "moduleResolution": "bundler",
    "allowSyntheticDefaultImports": true
  },
  "include": ["vite.config.ts"]
}
--------------------------------------------------------------------------------
/ui/vite.config.ts:
--------------------------------------------------------------------------------
import { defineConfig } from 'vite';
import react from '@vitejs/plugin-react';

// https://vitejs.dev/config/
// Dev/app build config: bundles the extension entry as an ES module
// one directory up (the extension root).
export default defineConfig({
  plugins: [react()],
  build: {
    outDir: '../',
    rollupOptions: {
      input: {
        main: 'index.html',
      },
    },
    lib: {
      name: 'agent-scheduler',
      entry: 'src/extension/index.ts',
      fileName: 'javascript/extension',
      formats: ['es']
    },
  },
});
--------------------------------------------------------------------------------
/ui/vite.extension.ts:
--------------------------------------------------------------------------------
import { defineConfig } from 'vite';

// https://vitejs.dev/config/
// Extension build config: emits a single IIFE bundle
// (javascript/agent-scheduler.iife.js) loaded directly by the WebUI.
export default defineConfig({
  build: {
    outDir: '../',
    copyPublicDir: false,
    lib: {
      name: 'agentScheduler',
      entry: 'src/extension/index.ts',
      fileName: 'javascript/agent-scheduler',
      formats: ['iife']
    },
  },
});
--------------------------------------------------------------------------------