2 |
3 |
Samplers | 78 |Official Bridge | 79 |SD-WebUI Bridge | 80 |||
---|---|---|---|---|
No Karras | 83 |Karras | 84 |No Karras | 85 |Karras | 86 ||
k_lms | 91 |✔️ | 92 |✔️ | 93 |✔️ | 94 |✔️ | 95 |
k_heun | 98 |✔️ | 99 |✔️ | 100 |✔️ | 101 |✔️ | 102 |
k_euler | 105 |✔️ | 106 |✔️ | 107 |✔️ | 108 |✔️ | 109 |
k_euler_a | 112 |✔️ | 113 |✔️ | 114 |✔️ | 115 |✔️ | 116 |
k_dpm_2 | 119 |✔️ | 120 |✔️ | 121 |✔️ | 122 |✔️ | 123 |
k_dpm_2_a | 126 |✔️ | 127 |✔️ | 128 |✔️ | 129 |✔️ | 130 |
k_dpm_fast | 133 |✔️ | 134 |✔️ | 135 |✔️ | 136 |✔️ | 137 |
k_dpm_adaptive | 140 |✔️ | 141 |✔️ | 142 |✔️ | 143 |✔️ | 144 |
k_dpmpp_2s_a | 147 |✔️ | 148 |✔️ | 149 |✔️ | 150 |✔️ | 151 |
k_dpmpp_2m | 154 |✔️ | 155 |✔️ | 156 |✔️ | 157 |✔️ | 158 |
k_dpmpp_sde | 161 |✔️ | 162 |✔️ | 163 |✔️ | 164 |✔️ | 165 |
dpmsolver | 168 |✔️ | 169 |✔️ | 170 |❌ | 171 |❌ | 172 |
ddim | 175 |❌ | 176 |❌ | 177 |✔️ | 178 |❌ | 179 |
plms | 182 |❌ | 183 |❌ | 184 |✔️ | 185 |❌ | 186 |
{x[0]}: {x[1]}
", 248 | horde.state.to_dict().items(), 249 | ) 250 | ) 251 | images = ( 252 | [horde.state.image] if horde.state.image is not None else [] 253 | ) 254 | if image and show_images: 255 | return cid, html, horde.state.status, images 256 | return cid, html, horde.state.status 257 | 258 | with gr.Row(): 259 | log = gr.HTML(elem_id=tab_prefix + "log") 260 | 261 | refresh.click( 262 | fn=lambda: on_refresh(), 263 | outputs=[current_id, log, state], 264 | show_progress=False, 265 | ) 266 | refresh_image.click( 267 | fn=lambda: on_refresh(True), 268 | outputs=[current_id, log, state, preview], 269 | show_progress=False, 270 | ) 271 | apply_settings.click( 272 | fn=apply_stable_horde_settings, 273 | inputs=[ 274 | enable, 275 | name, 276 | apikey, 277 | allow_img2img, 278 | allow_painting, 279 | allow_unsafe_ipaddr, 280 | allow_post_processing, 281 | restore_settings, 282 | nsfw, 283 | interval, 284 | max_pixels, 285 | endpoint, 286 | show_images, 287 | save_images, 288 | save_images_folder, 289 | ], 290 | outputs=[status, running_type], 291 | ) 292 | 293 | return ((demo, "Stable Horde Worker", "stable-horde"),) 294 | 295 | 296 | script_callbacks.on_app_started(on_app_started) 297 | script_callbacks.on_ui_tabs(on_ui_tabs) 298 | -------------------------------------------------------------------------------- /stable_horde/__init__.py: -------------------------------------------------------------------------------- 1 | from .horde import StableHorde 2 | from .config import StableHordeConfig 3 | 4 | __all__ = ["StableHorde", "StableHordeConfig"] 5 | -------------------------------------------------------------------------------- /stable_horde/config.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os.path as path 3 | from typing import Any 4 | 5 | 6 | class StableHordeConfig(object): 7 | enabled: bool = False 8 | endpoint: str = "https://stablehorde.net/" 9 | apikey: str = "00000000" 10 | name: str = "" 11 | interval: int = 10 12 | max_pixels: int = 1048576 # 1024x1024 13 | nsfw: bool = False 14 | restore_settings: bool = True 15 | allow_img2img: bool = True 16 | allow_painting: bool = True 17 | allow_unsafe_ipaddr: bool = True 18 | allow_post_processing: bool = True 19 | show_image_preview: bool = False 20 | save_images: bool = False 21 | save_images_folder: str = "horde" 22 | current_models: dict = {} 23 | hires_firstphase_resolution: int = 512 24 | hr_upscaler: str = "Latent" 25 | 26 | def __init__(self, basedir: str): 27 | self.basedir = basedir 28 | self.config = self.load() 29 | 30 | def __getattribute__(self, item: str): 31 | if item in ["config", "basedir", "load", "save"]: 32 | return super().__getattribute__(item) 33 | value = self.config.get(item, None) 34 | if value is None: 35 | return super().__getattribute__(item) 36 | return value 37 | 38 | def __setattr__(self, key: str, value: Any): 39 | if key == "config" or key == "basedir": 40 | super().__setattr__(key, value) 41 | else: 42 | self.config[key] = value 43 | self.save() 44 | 45 | def load(self): 46 | if not path.exists(path.join(self.basedir, "config.json")): 47 | self.config = { 48 | "enabled": False, 49 | "allow_img2img": True, 50 | "allow_painting": True, 51 | "allow_unsafe_ipaddr": True, 52 | "allow_post_processing": True, 53 | "restore_settings": True, 54 | "show_image_preview": False, 55 | "save_images": False, 56 | "save_images_folder": "horde", 57 | "endpoint": "https://stablehorde.net/", 58 | "apikey": "00000000", 59 | "name": "", 60 | "interval": 10, 61 | "max_pixels": 1048576, 62 | "nsfw": False, 63 | "hr_upscaler": "Latent", 64 | "hires_firstphase_resolution": 512, 65 | } 66 | self.save() 67 | 68 | with open(path.join(self.basedir, "config.json"), "r") as f: 69 | return json.load(f) 70 | 71 | def save(self): 72 | with open(path.join(self.basedir, "config.json"), "w") as f: 73 | json.dump(self.config, f, indent=2) 74 | -------------------------------------------------------------------------------- /stable_horde/horde.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import json 3 | from os import path 4 | from typing import Any, Dict, Optional 5 | from re import sub 6 | 7 | import aiohttp 8 | from .job import HordeJob 9 | from .config import StableHordeConfig 10 | import numpy as np 11 | from diffusers.pipelines.stable_diffusion.safety_checker import ( 12 | StableDiffusionSafetyChecker, 13 | ) 14 | from PIL import Image 15 | from transformers.models.auto.feature_extraction_auto import ( 16 | AutoFeatureExtractor, 17 | ) 18 | 19 | from modules.images import save_image 20 | from modules import ( 21 | shared, 22 | call_queue, 23 | processing, 24 | sd_models, 25 | sd_samplers, 26 | ) 27 | 28 | # flake8: noqa: E501 29 | stable_horde_supported_models_url = "https://raw.githubusercontent.com/Haidra-Org/AI-Horde-image-model-reference/main/stable_diffusion.json" 30 | 31 | safety_model_id = "CompVis/stable-diffusion-safety-checker" 32 | safety_feature_extractor = None 33 | safety_checker = None 34 | 35 | 36 | class State: 37 | def __init__(self): 38 | self._status = "" 39 | self.id: Optional[str] = None 40 | self.prompt: Optional[str] = None 41 | self.negative_prompt: Optional[str] = None 42 | self.scale: Optional[float] = None 43 | self.steps: Optional[int] = None 44 | self.sampler: Optional[str] = None 45 | self.image: Optional[Image.Image] = None 46 | 47 | @property 48 | def status(self): 49 | return self._status 50 | 51 | @status.setter 52 | def status(self, value): 53 | self._status = value 54 | if shared.cmd_opts.nowebui: 55 | print(value) 56 | 57 | def to_dict(self): 58 | return { 59 | "status": self.status, 60 | "prompt": self.prompt, 61 | "negative_prompt": self.negative_prompt, 62 | "scale": self.scale, 63 | "steps": self.steps, 64 | "sampler": self.sampler, 65 | } 66 | 67 | 68 | class StableHorde: 69 | def __init__(self, basedir: str, config: StableHordeConfig): 70 | self.basedir = basedir 71 | self.config = config 72 | self.session: Optional[aiohttp.ClientSession] = None 73 | 74 | self.sfw_request_censor = Image.open( 75 | path.join(self.config.basedir, "assets", "nsfw_censor_sfw_request.png") 76 | ) 77 | 78 | self.supported_models = [] 79 | self.current_models = {} 80 | 81 | self.state = State() 82 | 83 | async def get_supported_models(self): 84 | attempts = 10 85 | while attempts > 0: 86 | attempts -= 1 87 | async with aiohttp.ClientSession() as session: 88 | try: 89 | async with session.get(stable_horde_supported_models_url) as resp: 90 | if resp.status != 200: 91 | raise aiohttp.ClientError() 92 | data = await resp.text() 93 | supported_models: Dict[str, Any] = json.loads(data) 94 | 95 | self.supported_models = list(supported_models.values()) 96 | return 97 | except Exception: 98 | print( 99 | f"Failed to get supported models, retrying in 1 second... \ 100 | ({attempts} attempts left" 101 | ) 102 | await asyncio.sleep(1) 103 | raise Exception("Failed to get supported models after 10 attempts") 104 | 105 | def detect_current_model(self): 106 | model_checkpoint = shared.opts.sd_model_checkpoint 107 | checkpoint_info = sd_models.checkpoints_list.get(model_checkpoint, None) 108 | if checkpoint_info is None: 109 | return f"Model checkpoint {model_checkpoint} not found" 110 | 111 | for model in self.supported_models: 112 | try: 113 | remote_hash = model["config"]["files"][0]["sha256sum"] 114 | except KeyError: 115 | continue 116 | 117 | if shared.opts.sd_checkpoint_hash == remote_hash: 118 | self.current_models = {model["name"]: checkpoint_info.name} 119 | 120 | if len(self.current_models) == 0: 121 | return f"Current model {model_checkpoint} not found on StableHorde" 122 | 123 | def set_current_models(self, model_names: list): 124 | """Set the current models in horde and config""" 125 | remote_hashes = {} 126 | self.current_models = { 127 | k: v for k, v in self.current_models.items() if v in model_names 128 | } 129 | # get the sha256 of all supported models 130 | for model in self.supported_models: 131 | try: 132 | remote_hashes[model["config"]["files"][0]["sha256sum"].lower()] = model[ 133 | "name" 134 | ] 135 | except KeyError: 136 | continue 137 | # get the sha256 of all local models and compare it to the remote hashes 138 | # if the sha256 matches, add the model to the current models list 139 | for checkpoint in sd_models.checkpoints_list.values(): 140 | checkpoint: sd_models.CheckpointInfo 141 | if checkpoint.name in model_names: 142 | # skip sha256 calculation if the model already has hash 143 | if checkpoint.sha256 is None: 144 | local_hash = sd_models.hashes.sha256( 145 | checkpoint.filename, f"checkpoint/{checkpoint.name}" 146 | ) 147 | else: 148 | local_hash = checkpoint.sha256 149 | if checkpoint.name in self.config.current_models.values(): 150 | continue 151 | 152 | if local_hash in remote_hashes: 153 | self.current_models[remote_hashes[local_hash]] = checkpoint.name 154 | print( 155 | f"sha256 for {checkpoint.name} is {local_hash} \ 156 | and it's supported by StableHorde" 157 | ) 158 | else: 159 | print( 160 | f"sha256 for {checkpoint.name} is {local_hash} \ 161 | but it's not supported by StableHorde" 162 | ) 163 | 164 | self.config.current_models = self.current_models 165 | self.config.save() 166 | return self.current_models 167 | 168 | async def run(self): 169 | await self.get_supported_models() 170 | self.current_models = self.config.current_models 171 | 172 | while True: 173 | if len(self.current_models) == 0: 174 | result = self.detect_current_model() 175 | if result is not None: 176 | self.state.status = result 177 | # Wait 10 seconds before retrying to detect the current model 178 | # if the current model is not listed in the Stable Horde supported 179 | # models, we don't want to spam the server with requests 180 | await asyncio.sleep(10) 181 | continue 182 | 183 | await asyncio.sleep(self.config.interval) 184 | 185 | if self.config.enabled: 186 | try: 187 | # Require a queue lock to prevent getting jobs when 188 | # there are generation jobs from webui. 189 | with call_queue.queue_lock: 190 | req = await HordeJob.get( 191 | await self.get_session(), 192 | self.config, 193 | list(self.current_models.keys()), 194 | ) 195 | if req is None: 196 | continue 197 | 198 | await self.handle_request(req) 199 | except Exception: 200 | import traceback 201 | 202 | traceback.print_exc() 203 | 204 | def patch_sampler_names(self): 205 | """Add more samplers that the Stable Horde supports, 206 | but are not included in the default sd_samplers module. 207 | """ 208 | from modules import sd_samplers 209 | 210 | try: 211 | # Old versions of webui put every samplers in `modules.sd_samplers` 212 | # But the newer version split them into several files 213 | # Happened in https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/4df63d2d197f26181758b5108f003f225fe84874 # noqa E501 214 | from modules.sd_samplers import KDiffusionSampler, SamplerData 215 | except ImportError: 216 | from modules.sd_samplers_kdiffusion import KDiffusionSampler 217 | from modules.sd_samplers_common import SamplerData 218 | 219 | if sd_samplers.samplers_map.get("euler a karras"): 220 | # already patched 221 | return 222 | 223 | samplers = [ 224 | SamplerData( 225 | name, 226 | lambda model, fn=func: KDiffusionSampler(fn, model), 227 | [alias], 228 | {"scheduler": "karras"}, 229 | ) 230 | for name, func, alias in [ 231 | ("Euler a Karras", "sample_euler_ancestral", "k_euler_a_ka"), 232 | ("Euler Karras", "sample_euler", "k_euler_ka"), 233 | ("LMS Karras", "sample_lms", "k_lms_ka"), 234 | ("Heun Karras", "sample_heun", "k_heun_ka"), 235 | ("DPM2 Karras", "sample_dpm_2", "k_dpm_2_ka"), 236 | ("DPM2 a Karras", "sample_dpm_2_ancestral", "k_dpm_2_a_ka"), 237 | ("DPM++ 2S a Karras", "sample_dpmpp_2s_ancestral", "k_dpmpp_2s_a_ka"), 238 | ("DPM++ 2M Karras", "sample_dpmpp_2m", "k_dpmpp_2m_ka"), 239 | ("DPM++ SDE Karras", "sample_dpmpp_sde", "k_dpmpp_sde_ka"), 240 | ("DPM fast Karras", "sample_dpm_fast", "k_dpm_fast_ka"), 241 | ("DPM adaptive Karras", "sample_dpm_adaptive", "k_dpm_ad_ka"), 242 | ] 243 | ] 244 | sd_samplers.samplers.extend(samplers) 245 | sd_samplers.samplers_for_img2img.extend(samplers) 246 | sd_samplers.all_samplers_map.update({s.name: s for s in samplers}) 247 | for sampler in samplers: 248 | sd_samplers.samplers_map[sampler.name.lower()] = sampler.name 249 | for alias in sampler.aliases: 250 | sd_samplers.samplers_map[alias.lower()] = sampler.name 251 | 252 | async def handle_request(self, job: HordeJob): 253 | self.patch_sampler_names() 254 | 255 | self.state.status = f"Get popped generation request {job.id}, \ 256 | model {job.model}, sampler {job.sampler}" 257 | sampler_name = job.sampler 258 | if sampler_name == "k_dpm_adaptive": 259 | sampler_name = "k_dpm_ad" 260 | if sampler_name not in sd_samplers.samplers_map: 261 | self.state.status = f"ERROR: Unknown sampler {sampler_name}" 262 | return 263 | if job.karras: 264 | sampler_name += "_ka" 265 | 266 | # Map model name to model 267 | local_model = self.current_models.get(job.model, shared.sd_model) 268 | # Short hash for info text 269 | local_model_shorthash = None 270 | for checkpoint in sd_models.checkpoints_list.values(): 271 | checkpoint: sd_models.CheckpointInfo 272 | if checkpoint.name == local_model: 273 | if not checkpoint.shorthash: 274 | checkpoint.calculate_shorthash() 275 | local_model_shorthash = checkpoint.shorthash 276 | break 277 | if local_model_shorthash is None: 278 | raise Exception(f"ERROR: Unknown model {local_model}") 279 | 280 | sampler = sd_samplers.samplers_map.get(sampler_name, None) 281 | if sampler is None: 282 | raise Exception(f"ERROR: Unknown sampler {sampler_name}") 283 | 284 | postprocessors = job.postprocessors 285 | 286 | params = { 287 | "sd_model": local_model, 288 | "prompt": job.prompt, 289 | "negative_prompt": job.negative_prompt, 290 | "sampler_name": sampler, 291 | "cfg_scale": job.cfg_scale, 292 | "seed": job.seed, 293 | "denoising_strength": job.denoising_strength, 294 | "height": job.height, 295 | "width": job.width, 296 | "subseed": job.subseed, 297 | "steps": job.steps, 298 | "tiling": job.tiling, 299 | "n_iter": job.n_iter, 300 | "do_not_save_samples": True, 301 | "do_not_save_grid": True, 302 | "override_settings": { 303 | "sd_model_checkpoint": local_model, 304 | }, 305 | "enable_hr": job.hires_fix, 306 | "hr_upscaler": self.config.hr_upscaler, 307 | "override_settings_restore_afterwards": self.config.restore_settings, 308 | } 309 | 310 | if job.hires_fix: 311 | ar = job.width / job.height 312 | params["firstphase_width"] = min( 313 | self.config.hires_firstphase_resolution, 314 | int(self.config.hires_firstphase_resolution * ar), 315 | ) 316 | params["firstphase_height"] = min( 317 | self.config.hires_firstphase_resolution, 318 | int(self.config.hires_firstphase_resolution / ar), 319 | ) 320 | 321 | if job.source_image is not None: 322 | p = processing.StableDiffusionProcessingImg2Img( 323 | init_images=[job.source_image], 324 | mask=job.source_mask, 325 | **params, 326 | ) 327 | else: 328 | p = processing.StableDiffusionProcessingTxt2Img(**params) 329 | 330 | with call_queue.queue_lock: 331 | shared.state.begin() 332 | # hijack clip skip 333 | hijacked = False 334 | old_clip_skip = shared.opts.CLIP_stop_at_last_layers 335 | if ( 336 | job.clip_skip >= 1 337 | and job.clip_skip != shared.opts.CLIP_stop_at_last_layers 338 | ): 339 | shared.opts.CLIP_stop_at_last_layers = job.clip_skip 340 | hijacked = True 341 | processed = processing.process_images(p) 342 | 343 | if hijacked: 344 | shared.opts.CLIP_stop_at_last_layers = old_clip_skip 345 | shared.state.end() 346 | 347 | has_nsfw = False 348 | 349 | with call_queue.queue_lock: 350 | if job.nsfw_censor: 351 | x_image = np.array(processed.images[0]) 352 | image, has_nsfw = self.check_safety(x_image) 353 | if has_nsfw: 354 | job.censored = True 355 | 356 | else: 357 | image = processed.images[0] 358 | 359 | if not has_nsfw and ( 360 | "GFPGAN" in postprocessors or "CodeFormers" in postprocessors 361 | ): 362 | model = "CodeFormer" if "CodeFormers" in postprocessors else "GFPGAN" 363 | face_restorators = [x for x in shared.face_restorers if x.name() == model] 364 | if len(face_restorators) == 0: 365 | print(f"ERROR: No face restorer for {model}") 366 | 367 | else: 368 | with call_queue.queue_lock: 369 | image = face_restorators[0].restore(np.array(image)) 370 | image = Image.fromarray(image) 371 | 372 | if "RealESRGAN_x4plus" in postprocessors and not has_nsfw: 373 | from modules.postprocessing import run_extras 374 | 375 | with call_queue.queue_lock: 376 | images, _info, _wtf = run_extras( 377 | image=image, 378 | extras_mode=0, 379 | resize_mode=0, 380 | show_extras_results=True, 381 | upscaling_resize=2, 382 | upscaling_resize_h=None, 383 | upscaling_resize_w=None, 384 | upscaling_crop=False, 385 | upscale_first=False, 386 | extras_upscaler_1="R-ESRGAN 4x+", # 8 - RealESRGAN_x4plus 387 | extras_upscaler_2=None, 388 | extras_upscaler_2_visibility=0.0, 389 | gfpgan_visibility=0.0, 390 | codeformer_visibility=0.0, 391 | codeformer_weight=0.0, 392 | image_folder="", 393 | input_dir="", 394 | output_dir="", 395 | save_output=False, 396 | ) 397 | 398 | image = images[0] 399 | 400 | # Saving image locally 401 | infotext = ( 402 | processing.create_infotext( 403 | p, 404 | p.all_prompts, 405 | p.all_seeds, 406 | p.all_subseeds, 407 | "Stable Horde", 408 | 0, 409 | 0, 410 | ) 411 | if shared.opts.enable_pnginfo 412 | else None 413 | ) 414 | # workaround for model name and hash since webui 415 | # uses shard.sd_model instead of local_model 416 | infotext = sub( 417 | "Model:(.*?),", 418 | "Model: " + local_model.split(".")[0] + ",", 419 | infotext, 420 | ) 421 | infotext = sub( 422 | "Model hash:(.*?),", 423 | "Model hash: " + local_model_shorthash + ",", 424 | infotext, 425 | ) 426 | if self.config.save_images: 427 | save_image( 428 | image, 429 | self.config.save_images_folder, 430 | "", 431 | job.seed, 432 | job.prompt, 433 | "png", 434 | info=infotext, 435 | p=p, 436 | ) 437 | 438 | self.state.id = job.id 439 | self.state.prompt = job.prompt 440 | self.state.negative_prompt = job.negative_prompt 441 | self.state.scale = job.cfg_scale 442 | self.state.steps = job.steps 443 | self.state.sampler = sampler_name 444 | self.state.image = image 445 | 446 | res = await job.submit(image) 447 | if res: 448 | self.state.status = f"Submission accepted, reward {res} received." 449 | 450 | # check and replace nsfw content 451 | def check_safety(self, x_image): 452 | global safety_feature_extractor, safety_checker 453 | 454 | if safety_feature_extractor is None: 455 | safety_feature_extractor = AutoFeatureExtractor.from_pretrained( 456 | safety_model_id 457 | ) 458 | safety_checker = StableDiffusionSafetyChecker.from_pretrained( 459 | safety_model_id 460 | ) 461 | 462 | safety_checker_input = safety_feature_extractor(x_image, return_tensors="pt") 463 | image, has_nsfw_concept = safety_checker( 464 | images=x_image, clip_input=safety_checker_input.pixel_values 465 | ) 466 | 467 | if has_nsfw_concept and any(has_nsfw_concept): 468 | return self.sfw_request_censor, has_nsfw_concept 469 | return Image.fromarray(image), has_nsfw_concept 470 | 471 | async def get_session(self) -> aiohttp.ClientSession: 472 | if self.session is None: 473 | headers = { 474 | "apikey": self.config.apikey, 475 | "Content-Type": "application/json", 476 | } 477 | self.session = aiohttp.ClientSession(self.config.endpoint, headers=headers) 478 | # check if apikey has changed 479 | elif self.session.headers["apikey"] != self.config.apikey: 480 | await self.session.close() 481 | self.session = None 482 | self.session = await self.get_session() 483 | return self.session 484 | 485 | def handle_error(self, status: int, res: Dict[str, Any]): 486 | if status == 401: 487 | self.state.status = "ERROR: Invalid API Key" 488 | elif status == 403: 489 | self.state.status = f"ERROR: Access Denied. ({res.get('message', '')})" 490 | elif status == 404: 491 | self.state.status = "ERROR: Request Not Found" 492 | else: 493 | self.state.status = f"ERROR: Unknown Error {status}" 494 | print(f"ERROR: Unknown Error, {res}") 495 | -------------------------------------------------------------------------------- /stable_horde/job.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import base64 3 | from enum import Enum 4 | import io 5 | from random import randint 6 | from typing import List, Optional 7 | from PIL import Image 8 | 9 | import aiohttp 10 | from .config import StableHordeConfig 11 | 12 | 13 | class JobStatus(Enum): 14 | PENDING = "pending" 15 | RUNNING = "running" 16 | GENERATED = "generated" 17 | SUBMITTING = "submitting" 18 | UPLOADED = "uploaded" 19 | SUBMITTED = "submitted" 20 | DONE = "done" 21 | ERROR = "error" 22 | 23 | 24 | class HordeJob: 25 | retry_interval: int = 1 26 | censored = False 27 | 28 | def __init__( 29 | self, 30 | session: aiohttp.ClientSession, 31 | id: str, 32 | model: str, 33 | prompt: str, 34 | negative_prompt: str, 35 | sampler: str, 36 | cfg_scale: float, 37 | seed: int, 38 | denoising_strength: float, 39 | n_iter: int, 40 | height: int, 41 | width: int, 42 | subseed: int, 43 | steps: int, 44 | karras: bool, 45 | tiling: bool, 46 | postprocessors: List[str], 47 | nsfw_censor: bool = False, 48 | clip_skip: int = 0, 49 | source_image: Optional[Image.Image] = None, 50 | source_processing: Optional[str] = "img2img", 51 | source_mask: Optional[Image.Image] = None, 52 | r2_upload: Optional[str] = None, 53 | hires_fix: bool = False, 54 | ): 55 | self.status: JobStatus = JobStatus.PENDING 56 | self.session = session 57 | self.id = id 58 | self.model = model 59 | self.prompt = prompt 60 | self.negative_prompt = negative_prompt 61 | self.sampler = sampler 62 | self.cfg_scale = cfg_scale 63 | self.seed = seed 64 | self.denoising_strength = denoising_strength 65 | self.n_iter = n_iter 66 | self.height = height 67 | self.width = width 68 | self.subseed = subseed 69 | self.steps = steps 70 | self.karras = karras 71 | self.tiling = tiling 72 | self.postprocessors = postprocessors 73 | self.nsfw_censor = nsfw_censor 74 | self.clip_skip = clip_skip 75 | self.source_image = source_image 76 | self.source_processing = ( 77 | source_processing # "img2img", "inpainting", "outpainting" 78 | ) 79 | self.source_mask = source_mask 80 | self.r2_upload = r2_upload 81 | self.hires_fix = hires_fix 82 | 83 | async def submit(self, image: Image.Image): 84 | self.status = JobStatus.SUBMITTING 85 | 86 | bytesio = io.BytesIO() 87 | image.save(bytesio, format="WebP", quality=95) 88 | 89 | if self.r2_upload: 90 | async with aiohttp.ClientSession() as session: 91 | attempts = 10 92 | while attempts > 0: 93 | try: 94 | r = await session.put(self.r2_upload, data=bytesio.getvalue()) 95 | break 96 | except aiohttp.ClientConnectorError: 97 | attempts -= 1 98 | await asyncio.sleep(self.retry_interval) 99 | continue 100 | generation = "R2" 101 | 102 | self.status = JobStatus.UPLOADED 103 | 104 | else: 105 | generation = base64.b64encode(bytesio.getvalue()).decode("utf8") 106 | 107 | post_data = { 108 | "id": self.id, 109 | "generation": generation, 110 | "seed": self.seed, 111 | "state": "censored" if self.censored else "ok", 112 | } 113 | 114 | attempts = 10 115 | while attempts > 0: 116 | try: 117 | r = await self.session.post("/api/v2/generate/submit", json=post_data) 118 | 119 | try: 120 | res = await r.json() 121 | 122 | if r.status == 404: 123 | print(f"job {self.id} has been submitted already") 124 | return 125 | 126 | if r.status == 500: 127 | print( 128 | f"Failed to submit job with status code {r.status}, retry!" 129 | ) 130 | attempts -= 1 131 | await asyncio.sleep(self.retry_interval) 132 | continue 133 | 134 | if r.ok: 135 | self.status = JobStatus.SUBMITTED 136 | reward = res.get("reward", None) 137 | if reward: 138 | self.status = JobStatus.DONE 139 | return reward 140 | else: 141 | print( 142 | "Failed to submit job with status code" 143 | + f"{r.status}: {res.get('message')}" 144 | ) 145 | return None 146 | except Exception: 147 | print("Error when decoding response, the server might be down.") 148 | return None 149 | 150 | except aiohttp.ClientConnectorError: 151 | attempts -= 1 152 | await asyncio.sleep(self.retry_interval) 153 | continue 154 | 155 | self.status = JobStatus.ERROR 156 | 157 | async def error(self): 158 | self.status = JobStatus.ERROR 159 | 160 | post_data = {"id": self.id, "state": "faulted"} 161 | attempts = 10 162 | while attempts > 0: 163 | try: 164 | r = await self.session.post("/api/v2/generate/submit", json=post_data) 165 | if r.ok: 166 | print("Successfully reported error to Stable Horde") 167 | return 168 | else: 169 | res = await r.json() 170 | print( 171 | "Failed to report error with status code" 172 | + f"{r.status}: {res.get('message')}" 173 | ) 174 | return 175 | except aiohttp.ClientConnectorError: 176 | attempts -= 1 177 | await asyncio.sleep(self.retry_interval) 178 | continue 179 | 180 | @classmethod 181 | async def get( 182 | cls, 183 | session: aiohttp.ClientSession, 184 | config: StableHordeConfig, 185 | models: List[str], 186 | ): 187 | # Stable Horde uses a bridge version to differentiate between different 188 | # bridge agents which is used to determine the bridge agent's capabilities. 189 | # We should increment the version number when we add new features to the bridge 190 | # agent. 191 | # 192 | # When we increment the version number, we should also update the AI-Horde side: 193 | # https://github.com/db0/AI-Horde/blob/main/horde/bridge_reference.py 194 | # 195 | # 1 - img2img, inpainting, karras, r2, CodeFormers 196 | # 2 - tiling 197 | # 3 - r2 source 198 | # 4 - hires_fix, clip_skip 199 | version = 4 200 | name = "SD-WebUI Stable Horde Worker Bridge" 201 | repo = "https://github.com/sdwebui-w-horde/sd-webui-stable-horde-worker" 202 | # https://stablehorde.net/api/ 203 | post_data = { 204 | "name": config.name, 205 | "priority_usernames": [], 206 | "nsfw": config.nsfw, 207 | "blacklist": [], 208 | "models": models, 209 | # TODO: add support for bridge version 14 (r2_source) 210 | "bridge_version": 13, 211 | "bridge_agent": f"{name}:{version}:{repo}", 212 | "threads": 1, 213 | "max_pixels": config.max_pixels, 214 | "allow_img2img": config.allow_img2img, 215 | "allow_painting": config.allow_painting, 216 | "allow_unsafe_ipaddr": config.allow_unsafe_ipaddr, 217 | } 218 | 219 | r = await session.post("/api/v2/generate/pop", json=post_data) 220 | 221 | req = await r.json() 222 | 223 | if r.status != 200: 224 | raise Exception(f"Failed to get job: {req.get('message')}") 225 | 226 | if not req.get("id"): 227 | return 228 | 229 | payload = req.get("payload") 230 | prompt = payload.get("prompt") 231 | if "###" in prompt: 232 | prompt, negative = map(lambda x: x.strip(), prompt.rsplit("###", 1)) 233 | else: 234 | negative = "" 235 | 236 | async def to_image(base64str: Optional[str]) -> Optional[Image.Image]: 237 | if not base64str: 238 | return None 239 | # support for r2 source, which is a url rather than a base64 string 240 | if base64str.startswith("http"): 241 | async with aiohttp.ClientSession() as session: 242 | attempts = 10 243 | while attempts > 0: 244 | try: 245 | r = await session.get(base64str) 246 | return Image.open(await r.read()) 247 | except aiohttp.ClientConnectorError: 248 | attempts -= 1 249 | await asyncio.sleep(1) 250 | continue 251 | raise Exception("Failed to download source image") 252 | 253 | return Image.open(base64.b64decode(base64str)) 254 | 255 | return cls( 256 | session=session, 257 | id=req["id"], 258 | prompt=prompt, 259 | negative_prompt=negative, 260 | sampler=payload.get("sampler_name"), 261 | cfg_scale=payload.get("cfg_scale", 5), 262 | seed=int(payload.get("seed", randint(0, 2**32))), 263 | denoising_strength=payload.get("denoising_strength", 0.75), 264 | n_iter=payload.get("n_iter", 1), 265 | height=payload["height"], 266 | width=payload["width"], 267 | subseed=payload.get("seed_variation", 1), 268 | steps=payload.get("ddim_steps", 30), 269 | karras=payload.get("karras", False), 270 | tiling=payload.get("tiling", False), 271 | clip_skip=payload.get("clip_skip", 1), 272 | postprocessors=payload.get("post_processing", []), 273 | nsfw_censor=payload.get("use_nsfw_censor", False), 274 | model=req["model"], 275 | source_image=await to_image(req.get("source_image")), 276 | source_processing=req.get("source_processing"), 277 | source_mask=await to_image(req.get("source_mask")), 278 | r2_upload=req.get("r2_upload"), 279 | hires_fix=payload.get("hires_fix", False), 280 | ) 281 | --------------------------------------------------------------------------------