├── tests └── __init__.py ├── stablepy ├── face_restoration │ ├── __init__.py │ ├── gfpgan.py │ ├── restoreformer.py │ ├── codeformer.py │ ├── main_face_restoration.py │ └── face_restoration_utils.py ├── diffusers_vanilla │ ├── sd_embed │ │ ├── __init__.py │ │ └── LICENSE │ ├── extra_pipe │ │ ├── __init__.py │ │ └── flux │ │ │ └── __init__.py │ ├── preprocessor │ │ ├── __init__.py │ │ ├── transformers_lib │ │ │ ├── __init__.py │ │ │ └── pipelines.py │ │ ├── controlnet_aux_beta │ │ │ ├── __init__.py │ │ │ ├── teed │ │ │ │ ├── Fsmish.py │ │ │ │ ├── LICENSE.txt │ │ │ │ ├── Xsmish.py │ │ │ │ ├── __init__.py │ │ │ │ └── ted.py │ │ │ ├── lineart_standard │ │ │ │ └── __init__.py │ │ │ ├── preprocessor_utils.py │ │ │ └── anyline │ │ │ │ └── __init__.py │ │ ├── constans_preprocessor.py │ │ ├── image_utils.py │ │ └── main_preprocessor.py │ ├── extra_scheduler │ │ ├── __init__.py │ │ └── scheduling_euler_discrete_variants.py │ ├── inpainting_canvas.py │ ├── style_prompt_config.py │ ├── t5_embedder.py │ ├── extra_model_loaders.py │ ├── high_resolution.py │ ├── prompt_weights.py │ ├── sampler_scheduler_config.py │ ├── lora_loader.py │ ├── adetailer.py │ └── main_prompt_embeds.py ├── __version__.py ├── logging │ └── logging_setup.py ├── __init__.py └── upscalers │ ├── pipelines │ ├── swinir.py │ ├── scunet.py │ ├── base.py │ └── common.py │ ├── main_upscaler.py │ └── utils_upscaler.py ├── requirements_dev.txt ├── pyproject.toml └── .gitignore /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /stablepy/face_restoration/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /stablepy/diffusers_vanilla/sd_embed/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /stablepy/__version__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.6.5" 2 | -------------------------------------------------------------------------------- /stablepy/diffusers_vanilla/extra_pipe/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /stablepy/diffusers_vanilla/preprocessor/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /stablepy/diffusers_vanilla/extra_pipe/flux/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /stablepy/diffusers_vanilla/extra_scheduler/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /stablepy/diffusers_vanilla/preprocessor/transformers_lib/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /stablepy/diffusers_vanilla/preprocessor/controlnet_aux_beta/__init__.py: -------------------------------------------------------------------------------- 
1 | -------------------------------------------------------------------------------- /requirements_dev.txt: -------------------------------------------------------------------------------- 1 | omegaconf==2.3.0 2 | torch 3 | git+https://github.com/huggingface/diffusers.git 4 | git+https://github.com/damian0815/compel.git 5 | invisible_watermark 6 | transformers 7 | accelerate 8 | scipy 9 | safetensors==0.3.3 10 | xformers 11 | mediapy 12 | ipywidgets==7.7.1 13 | controlnet_aux==0.0.6 14 | mediapipe==0.10.1 15 | pytorch-lightning 16 | git+https://github.com/R3gm/asdff.git 17 | -------------------------------------------------------------------------------- /stablepy/diffusers_vanilla/preprocessor/controlnet_aux_beta/teed/Fsmish.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script based on: 3 | Wang, Xueliang, Honge Ren, and Achuan Wang. 4 | "Smish: A Novel Activation Function for Deep Learning Methods. 5 | " Electronics 11.4 (2022): 540. 6 | """ 7 | 8 | # import pytorch 9 | import torch 10 | 11 | 12 | @torch.jit.script 13 | def smish(input): 14 | """ 15 | Applies the mish function element-wise: 16 | mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(sigmoid(x)))) 17 | See additional documentation for mish class. 18 | """ 19 | return input * torch.tanh(torch.log(1 + torch.sigmoid(input))) 20 | -------------------------------------------------------------------------------- /stablepy/logging/logging_setup.py: -------------------------------------------------------------------------------- 1 | import logging, sys 2 | 3 | def setup_logger(name_log): 4 | logger = logging.getLogger(name_log) 5 | logger.setLevel(logging.INFO) 6 | 7 | _default_handler = logging.StreamHandler() # Set sys.stderr as stream. 8 | _default_handler.flush = sys.stderr.flush 9 | logger.addHandler(_default_handler) 10 | 11 | logger.propagate = False 12 | 13 | handlers = logger.handlers 14 | 15 | for handler in handlers: 16 | formatter = logging.Formatter("[%(levelname)s] >> %(message)s") 17 | handler.setFormatter(formatter) 18 | 19 | #logger.handlers 20 | 21 | return logger 22 | 23 | logger = setup_logger("stablepy") 24 | logger.setLevel(logging.INFO) 25 | -------------------------------------------------------------------------------- /stablepy/diffusers_vanilla/preprocessor/controlnet_aux_beta/teed/LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Xavier Soria Poma 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /stablepy/diffusers_vanilla/preprocessor/controlnet_aux_beta/teed/Xsmish.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script based on: 3 | Wang, Xueliang, Honge Ren, and Achuan Wang. 4 | "Smish: A Novel Activation Function for Deep Learning Methods. 5 | " Electronics 11.4 (2022): 540. 6 | smish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + sigmoid(x))) 7 | """ 8 | 9 | # import pytorch 10 | # import activation functions 11 | from torch import nn 12 | 13 | from .Fsmish import smish 14 | 15 | 16 | class Smish(nn.Module): 17 | """ 18 | Applies the mish function element-wise: 19 | mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x))) 20 | Shape: 21 | - Input: (N, *) where * means, any number of additional 22 | dimensions 23 | - Output: (N, *), same shape as the input 24 | Examples: 25 | >>> m = Mish() 26 | >>> input = torch.randn(2) 27 | >>> output = m(input) 28 | Reference: https://pytorch.org/docs/stable/generated/torch.nn.Mish.html 29 | """ 30 | 31 | def __init__(self): 32 | """ 33 | Init method. 34 | """ 35 | super().__init__() 36 | 37 | def forward(self, input): 38 | """ 39 | Forward pass of the function. 40 | """ 41 | return smish(input) 42 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "stablepy" 3 | version = "0.6.5" 4 | description = "A tool for easy use of stable diffusion" 5 | authors = ["Roger Condori(R3gm) "] 6 | readme = "README.md" 7 | 8 | [tool.poetry.dependencies] 9 | python = "^3.10" 10 | torch = {version = "*", source = "pytorch-gpu-src"} 11 | torchvision = {version = "*", source = "pytorch-gpu-src"} 12 | torchaudio = {version = "*", source = "pytorch-gpu-src"} 13 | omegaconf = ">=2.3.0" 14 | diffusers = "0.31.0" 15 | compel = "2.0.2" 16 | invisible-watermark = "0.2.0" 17 | transformers = ">=4.47.1,!=4.57.0" 18 | accelerate = "1.2.1" 19 | safetensors = ">=0.4.3" 20 | mediapy = ">=1.1.9" 21 | ipywidgets = "7.7.1" 22 | controlnet-aux = "0.0.6" 23 | mediapipe = ">=0.10.5" 24 | pytorch-lightning = ">=2.0.9.post0" 25 | ultralytics = "<=8.3.48" 26 | huggingface_hub = ">=0.23.1" 27 | peft = ">=0.11.1" 28 | torchsde = ">=0.2.6" 29 | onnxruntime = ">=1.18.0" 30 | insightface = ">=0.7.3" 31 | opencv-contrib-python = ">=4.8.0.76" 32 | sentencepiece = "*" 33 | numpy = "<=1.26.4" 34 | lark = "*" 35 | spandrel = "0.4.1" 36 | spandrel-extra-arches = "0.2.0" 37 | facexlib = "0.3.0" 38 | 39 | [[tool.poetry.source]] 40 | name = "pytorch-gpu-src" 41 | url = "https://download.pytorch.org/whl/cu121" 42 | priority = "explicit" 43 | 44 | [build-system] 45 | requires = ["poetry-core"] 46 | build-backend = "poetry.core.masonry.api" 47 | -------------------------------------------------------------------------------- /stablepy/face_restoration/gfpgan.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import torch 4 | 5 | from ..upscalers.utils_upscaler import load_spandrel_model, load_file_from_url 6 | from .face_restoration_utils import 
CommonFaceRestoration
 7 | 
 8 | model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
 9 | model_download_name = "GFPGANv1.4.pth"
10 | 
11 | 
12 | class FaceRestorerGFPGAN(CommonFaceRestoration):
13 |     def name(self):
14 |         return "GFPGAN"
15 | 
16 |     def get_device(self):
17 |         return self.device
18 | 
19 |     def load_net(self, path=None) -> torch.nn.Module:
20 |         if not path:
21 |             path = model_url
22 | 
23 |         if path.startswith("http"):
24 |             path = load_file_from_url(
25 |                 path,
26 |                 model_dir=self.model_path,
27 |             )
28 | 
29 |         return load_spandrel_model(
30 |             path,
31 |             device=self.get_device(),
32 |             expected_architecture='GFPGAN',
33 |             prefer_half=False,
34 |         ).model
35 | 
36 |     def restore(self, np_image, **kwargs):
37 |         def restore_face(cropped_face_t):
38 |             assert self.net is not None
39 |             return self.net(cropped_face_t, return_rgb=False)[0]
40 | 
41 |         return self.restore_with_helper(np_image, restore_face)
42 | 
-------------------------------------------------------------------------------- /stablepy/diffusers_vanilla/preprocessor/controlnet_aux_beta/lineart_standard/__init__.py: --------------------------------------------------------------------------------
 1 | # Code based on the repository comfyui_controlnet_aux:
 2 | # https://github.com/Fannovel16/comfyui_controlnet_aux/blob/main/src/controlnet_aux/lineart_standard/__init__.py
 3 | import cv2
 4 | import numpy as np
 5 | from PIL import Image
 6 | from ..preprocessor_utils import HWC3, fast_resize_image
 7 | 
 8 | 
 9 | class LineartStandardDetector:
10 |     def __call__(
11 |         self,
12 |         image=None,
13 |         guassian_sigma=6.0,
14 |         intensity_threshold=8,
15 |         **kwargs
16 |     ):
17 | 
18 |         detect_resolution = kwargs.pop("detect_resolution", 512)
19 |         image_resolution = kwargs.pop("image_resolution", 512)
20 | 
21 |         image = HWC3(image)
22 |         original_height, original_width, _ = image.shape
23 |         image = fast_resize_image(image, detect_resolution)
24 | 
25 |         x = image.astype(np.float32)
26 |         g = cv2.GaussianBlur(x, (0, 0), guassian_sigma)
27 |         intensity = np.min(g - x, axis=2).clip(0, 255)
28 |         intensity /= max(16, np.median(intensity[intensity > intensity_threshold]))
29 |         intensity *= 127
30 |         detected_map = intensity.clip(0, 255).astype(np.uint8)
31 | 
32 |         detected_map = HWC3(detected_map)
33 |         resize_result = fast_resize_image(
34 |             detected_map, image_resolution
35 |         )
36 | 
37 |         return Image.fromarray(resize_result)
38 | 
-------------------------------------------------------------------------------- /stablepy/face_restoration/restoreformer.py: --------------------------------------------------------------------------------
 1 | from __future__ import annotations
 2 | import os
 3 | 
 4 | import torch
 5 | 
 6 | from ..upscalers.utils_upscaler import load_spandrel_model, load_file_from_url
 7 | from .face_restoration_utils import CommonFaceRestoration
 8 | 
 9 | model_url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth'
10 | 
11 | 
12 | class FaceRestorerRestoreFormer(CommonFaceRestoration):
13 |     def name(self):
14 |         return "RestoreFormer"
15 | 
16 |     def load_net(self, path=None) -> torch.nn.Module:
17 |         if not path:
18 |             path = model_url
19 | 
20 |         if path.startswith("http"):
21 |             path = load_file_from_url(
22 |                 path,
23 |                 model_dir=self.model_path,
24 |             )
25 | 
26 |         if os.path.exists(path):
27 |             return load_spandrel_model(
28 |                 path,
29 |                 device=self.device,
30 |                 expected_architecture='RestoreFormer',
31 |                 prefer_half=False,
32 |             ).model
33 |         raise ValueError("No RestoreFormer model found")
34 | 
35 |     def get_device(self):
36 |         return self.device
37 | 
38 |     def restore(self, np_image, **kwargs):
39 | 
40 |         def restore_face(cropped_face_t):
41 |             assert self.net is not None
42 |             return self.net(cropped_face_t)[0]
43 | 
44 |         return self.restore_with_helper(np_image, restore_face)
45 | 
-------------------------------------------------------------------------------- /stablepy/face_restoration/codeformer.py: --------------------------------------------------------------------------------
 1 | from __future__ import annotations
 2 | import os
 3 | 
 4 | import torch
 5 | 
 6 | from ..upscalers.utils_upscaler import load_spandrel_model, load_file_from_url
 7 | from .face_restoration_utils import CommonFaceRestoration
 8 | 
 9 | model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
10 | 
11 | 
12 | class FaceRestorerCodeFormer(CommonFaceRestoration):
13 |     def name(self):
14 |         return "CodeFormer"
15 | 
16 |     def load_net(self, path=None) -> torch.nn.Module:
17 |         if not path:
18 |             path = model_url
19 | 
20 |         if path.startswith("http"):
21 |             path = load_file_from_url(
22 |                 path,
23 |                 model_dir=self.model_path,
24 |             )
25 | 
26 |         if os.path.exists(path):
27 |             return load_spandrel_model(
28 |                 path,
29 |                 device=self.device,
30 |                 expected_architecture='CodeFormer',
31 |                 prefer_half=False,
32 |             ).model
33 |         raise ValueError("No codeformer model found")
34 | 
35 |     def get_device(self):
36 |         return self.device
37 | 
38 |     def restore(self, np_image, w: float | None = None, **kwargs):
39 |         if w is None:
40 |             w = 0.5
41 | 
42 |         def restore_face(cropped_face_t):
43 |             assert self.net is not None
44 |             return self.net(cropped_face_t, weight=w, adain=True)[0]
45 | 
46 |         return self.restore_with_helper(np_image, restore_face)
47 | 
-------------------------------------------------------------------------------- /stablepy/__init__.py: --------------------------------------------------------------------------------
 1 | from .__version__ import __version__
 2 | from .diffusers_vanilla.model import Model_Diffusers
 3 | from .diffusers_vanilla.adetailer import ad_model_process
 4 | from .diffusers_vanilla import utils
 5 | from .upscalers.esrgan import UpscalerESRGAN, UpscalerLanczos, UpscalerNearest
 6 | from .logging.logging_setup import logger
 7 | from .diffusers_vanilla.high_resolution import LATENT_UPSCALERS, ALL_BUILTIN_UPSCALERS
 8 | from .diffusers_vanilla.constants import (
 9 |     CONTROLNET_MODEL_IDS,
10 |     VALID_TASKS,
11 |     FLASH_LORA,
12 |     SCHEDULER_CONFIG_MAP,
13 |     scheduler_names,
14 |     IP_ADAPTER_MODELS,
15 |     IP_ADAPTERS_SD,
16 |     IP_ADAPTERS_SDXL,
17 |     REPO_IMAGE_ENCODER,
18 |     ALL_PROMPT_WEIGHT_OPTIONS,
19 |     PROMPT_WEIGHT_OPTIONS_PRIORITY,
20 |     SD15_TASKS,
21 |     SDXL_TASKS,
22 |     SCHEDULE_TYPE_OPTIONS,
23 |     SCHEDULE_PREDICTION_TYPE_OPTIONS,
24 |     FLUX_SCHEDULE_TYPES,
25 |     FLUX_SCHEDULE_TYPE_OPTIONS,
26 |     VALID_FILENAME_PATTERNS,
27 | )
28 | from .diffusers_vanilla.sampler_scheduler_config import (
29 |     check_scheduler_compatibility
30 | )
31 | from .diffusers_vanilla.preprocessor.constans_preprocessor import (
32 |     TASK_AND_PREPROCESSORS,
33 |     T2I_PREPROCESSOR_NAME,
34 |     ALL_PREPROCESSOR_TASKS,
35 | )
36 | from .diffusers_vanilla.preprocessor.main_preprocessor import Preprocessor
37 | from .upscalers.main_upscaler import BUILTIN_UPSCALERS, load_upscaler_model
38 | from .face_restoration.main_face_restoration import (
39 |     FACE_RESTORATION_MODELS,
40 |     batch_process_face_restoration,
41 |     load_face_restoration_model,
42 |     process_face_restoration,
43 | )
44 | 
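The __init__.py above defines the package's public surface: among other things it re-exports load_upscaler_model together with the face-restoration helpers (FACE_RESTORATION_MODELS, load_face_restoration_model, process_face_restoration, batch_process_face_restoration) whose implementations appear later in this listing. The snippet below is a minimal usage sketch, not a file from the repository; it assumes only the signatures shown in upscalers/main_upscaler.py, upscalers/pipelines/swinir.py, upscalers/pipelines/base.py and face_restoration/main_face_restoration.py, and the image paths and parameter values are illustrative placeholders.

from PIL import Image

from stablepy import (
    load_face_restoration_model,
    load_upscaler_model,
    process_face_restoration,
)

# Placeholder input path; any RGB image works.
img = Image.open("portrait.png")

# "SwinIR 4x" is one of BUILTIN_UPSCALERS; tile, tile_overlap, device and half
# are the keyword arguments documented in load_upscaler_model().
upscaler = load_upscaler_model(
    model="SwinIR 4x", tile=192, tile_overlap=8, device="cuda", half=False
)
img = upscaler.upscale(img, 2.0)  # Upscaler.upscale(img, scale, disable_progress_bar=False)

# "CodeFormer" is one of FACE_RESTORATION_MODELS; visibility blends the restored
# faces over the source image, and the weight argument is only used by CodeFormer.
restorer = load_face_restoration_model("CodeFormer", "cuda")
img = process_face_restoration(
    img,
    restorer,
    face_restoration_visibility=1.0,
    face_restoration_weight=0.5,
)
img.save("portrait_upscaled_restored.png")

For a list of images, batch_process_face_restoration wraps the same load-and-process steps and releases the model when it finishes.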
-------------------------------------------------------------------------------- /stablepy/diffusers_vanilla/preprocessor/controlnet_aux_beta/teed/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | import torch 5 | from einops import rearrange 6 | from huggingface_hub import hf_hub_download 7 | from PIL import Image 8 | from ..preprocessor_utils import HWC3, fast_resize_image, safe_step 9 | from .ted import TED 10 | 11 | 12 | class TEEDdetector: 13 | def __init__(self, pretrained_model_or_path=None, filename=None, subfolder=None): 14 | if not pretrained_model_or_path: 15 | pretrained_model_or_path = "fal-ai/teed" 16 | filename = "5_model.pth" 17 | subfolder = None 18 | if os.path.isdir(pretrained_model_or_path): 19 | model_path = os.path.join(pretrained_model_or_path, filename) 20 | else: 21 | model_path = hf_hub_download( 22 | pretrained_model_or_path, filename, subfolder=subfolder 23 | ) 24 | 25 | model = TED() 26 | model.load_state_dict(torch.load(model_path, map_location="cpu")) 27 | self.model = model 28 | 29 | def to(self, device): 30 | self.model.to(device) 31 | return self 32 | 33 | def __call__( 34 | self, 35 | image, 36 | safe_steps=2, 37 | **kwargs 38 | ): 39 | detect_resolution = kwargs.pop("detect_resolution", 512) 40 | image_resolution = kwargs.pop("image_resolution", 512) 41 | 42 | device = next(iter(self.model.parameters())).device 43 | 44 | image = HWC3(image) 45 | original_height, original_width, _ = image.shape 46 | image = fast_resize_image(image, detect_resolution) 47 | 48 | assert image.ndim == 3 49 | height, width, _ = image.shape 50 | with torch.no_grad(): 51 | image_teed = torch.from_numpy(image.copy()).float().to(device) 52 | image_teed = rearrange(image_teed, "h w c -> 1 c h w") 53 | edges = self.model(image_teed) 54 | edges = [e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges] 55 | edges = [ 56 | cv2.resize(e, (width, height), interpolation=cv2.INTER_LINEAR) 57 | for e in edges 58 | ] 59 | edges = np.stack(edges, axis=2) 60 | edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64))) 61 | if safe_steps != 0: 62 | edge = safe_step(edge, safe_steps) 63 | edge = (edge * 255.0).clip(0, 255).astype(np.uint8) 64 | 65 | detected_map = edge 66 | 67 | detected_map = HWC3(detected_map) 68 | resize_result = fast_resize_image( 69 | detected_map, image_resolution 70 | ) 71 | 72 | return Image.fromarray(resize_result) 73 | -------------------------------------------------------------------------------- /stablepy/diffusers_vanilla/preprocessor/constans_preprocessor.py: -------------------------------------------------------------------------------- 1 | AUX_TASKS = [ 2 | "HED", 3 | "Midas", 4 | "MLSD", 5 | "Openpose", 6 | "PidiNet", 7 | "NormalBae", 8 | "Lineart", 9 | "LineartAnime", 10 | "Canny", 11 | "ContentShuffle", 12 | ] 13 | 14 | TRANSFORMERS_LIB_TASKS = ["DPT", "UPerNet", "ZoeDepth", "SegFormer", "DepthAnything"] 15 | 16 | AUX_BETA_TASKS = ["TEED", "Anyline", "Lineart standard"] 17 | 18 | EXTRA_AUX_TASKS = ["Recolor", "Blur"] 19 | 20 | ALL_PREPROCESSOR_TASKS = AUX_TASKS + TRANSFORMERS_LIB_TASKS + AUX_BETA_TASKS + EXTRA_AUX_TASKS 21 | 22 | T2I_PREPROCESSOR_NAME = { 23 | "sdxl_canny_t2i": "Canny", 24 | "sdxl_openpose_t2i": "Openpose core", 25 | "sdxl_sketch_t2i": "PidiNet", 26 | "sdxl_depth-midas_t2i": "Midas", 27 | "sdxl_lineart_t2i": "Lineart", 28 | } 29 | 30 | TASK_AND_PREPROCESSORS = { 31 | "openpose": [ 32 | "Openpose", 33 | "Openpose core", 34 | 
"None", 35 | ], 36 | "scribble": [ 37 | "HED", 38 | "PidiNet", 39 | "TEED", 40 | "None", 41 | ], 42 | "softedge": [ 43 | "PidiNet", 44 | "HED", 45 | "HED safe", 46 | "PidiNet safe", 47 | "TEED", 48 | "None", 49 | ], 50 | "segmentation": [ 51 | "UPerNet", 52 | "SegFormer", 53 | "None", 54 | ], 55 | "depth": [ 56 | "DPT", 57 | "Midas", 58 | "ZoeDepth", 59 | "DepthAnything", 60 | "None", 61 | ], 62 | "normalbae": [ 63 | "NormalBae", 64 | "None", 65 | ], 66 | "lineart": [ 67 | "Lineart", 68 | "Lineart coarse", 69 | "Lineart (anime)", 70 | "Lineart standard", 71 | "Anyline", 72 | "None", 73 | "None (anime)", 74 | ], 75 | "lineart_anime": [ 76 | "Lineart", 77 | "Lineart coarse", 78 | "Lineart (anime)", 79 | "Lineart standard", 80 | "Anyline", 81 | "None", 82 | "None (anime)", 83 | ], 84 | "shuffle": [ 85 | "ContentShuffle", 86 | "None", 87 | ], 88 | "canny": [ 89 | "Canny", 90 | "None", 91 | ], 92 | "mlsd": [ 93 | "MLSD", 94 | "None", 95 | ], 96 | "ip2p": [ 97 | "None" 98 | ], 99 | "recolor": [ 100 | "Recolor luminance", 101 | "Recolor intensity", 102 | "None", 103 | ], 104 | "pattern": [ 105 | "None", 106 | ], 107 | "tile": [ 108 | "Blur", 109 | "None", 110 | ], 111 | "repaint": [ 112 | "None", 113 | ], 114 | "inpaint": [ 115 | "None", 116 | ], 117 | "img2img": [ 118 | "None", 119 | ], 120 | } 121 | -------------------------------------------------------------------------------- /stablepy/upscalers/pipelines/swinir.py: -------------------------------------------------------------------------------- 1 | from ..utils_upscaler import upscale_2, load_spandrel_model, release_resources_upscaler, load_file_from_url 2 | from .base import Upscaler, UpscalerData 3 | from PIL import Image 4 | 5 | SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth" 6 | 7 | 8 | class UpscalerSwinIR(Upscaler): 9 | def __init__(self, model="SwinIR 4x", tile=192, tile_overlap=8, device="cuda", half=False, **kwargs): 10 | self._cached_model = None 11 | self._cached_model_config = None 12 | self.name = "SwinIR" 13 | self.model_url = SWINIR_MODEL_URL 14 | self.model_name = "SwinIR 4x" 15 | super().__init__() 16 | self.scalers = [UpscalerData(self.model_name, self.model_url, self)] 17 | 18 | self.device = device 19 | self.half = half 20 | self.tile = tile 21 | self.tile_overlap = tile_overlap 22 | 23 | release_resources_upscaler() 24 | 25 | try: 26 | self.model_descriptor = self.load_model(model) 27 | except Exception as e: 28 | print(f"Unable to load SwinIR model {model}: {e}") 29 | self.model_descriptor = None 30 | 31 | def do_upscale(self, img: Image.Image) -> Image.Image: 32 | release_resources_upscaler() 33 | 34 | if self.model_descriptor is None: 35 | return img 36 | 37 | img = upscale_2( 38 | img, 39 | self.model_descriptor, 40 | tile_size=self.tile, 41 | tile_overlap=self.tile_overlap, 42 | scale=self.model_descriptor.scale, 43 | desc="SwinIR", 44 | disable_progress_bar=self.disable_progress_bar, 45 | ) 46 | 47 | release_resources_upscaler() 48 | 49 | return img 50 | 51 | def load_model(self, path, scale=4): 52 | if self.scalers[0].name == path: 53 | path = self.scalers[0].data_path 54 | 55 | if path.startswith("http"): 56 | filename = load_file_from_url( 57 | url=path, 58 | model_dir=self.model_download_path, 59 | file_name=f"{self.model_name.replace(' ', '_')}.pth", 60 | ) 61 | else: 62 | filename = path 63 | 64 | model_descriptor = load_spandrel_model( 65 | filename, 66 | device=self.device, 67 | # prefer_half=self.half, 68 | 
expected_architecture="SwinIR", 69 | ) 70 | 71 | # try: 72 | # model_descriptor.model.compile() 73 | # except Exception: 74 | # logger.warning("Failed to compile SwinIR model, fallback to JIT") 75 | 76 | return model_descriptor 77 | 78 | 79 | if __name__ == "__main__": 80 | from PIL import Image 81 | 82 | up = UpscalerSwinIR(model="SwinIR 4x", tile=192, tile_overlap=8, device="cuda", half=False) 83 | scale_up = 1.1 84 | img = Image.open("img.png") 85 | print(img.size) 86 | img_up = up.upscale(img, scale_up) 87 | print(img_up.size) 88 | -------------------------------------------------------------------------------- /stablepy/upscalers/main_upscaler.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | ANY_UPSCALER = (".pipelines.common", "UpscalerCommon") 4 | 5 | UPSCALER_MAP = { 6 | # None: (".pipelines.base", "UpscalerNone"), 7 | # "None": (".pipelines.base", "UpscalerNone"), 8 | "Lanczos": (".pipelines.base", "UpscalerLanczos"), 9 | "Nearest": (".pipelines.base", "UpscalerNearest"), 10 | "ESRGAN_4x": (".pipelines.common", "UpscalerCommon"), 11 | "DAT x2": (".pipelines.common", "UpscalerCommon"), 12 | "DAT x3": (".pipelines.common", "UpscalerCommon"), 13 | "DAT x4": (".pipelines.common", "UpscalerCommon"), 14 | "HAT x4": (".pipelines.common", "UpscalerCommon"), 15 | "R-ESRGAN General 4xV3": (".pipelines.common", "UpscalerCommon"), 16 | "R-ESRGAN General WDN 4xV3": (".pipelines.common", "UpscalerCommon"), 17 | "R-ESRGAN AnimeVideo": (".pipelines.common", "UpscalerCommon"), 18 | "R-ESRGAN 4x+": (".pipelines.common", "UpscalerCommon"), 19 | "R-ESRGAN 4x+ Anime6B": (".pipelines.common", "UpscalerCommon"), 20 | "R-ESRGAN 2x+": (".pipelines.common", "UpscalerCommon"), 21 | "ScuNET GAN": (".pipelines.scunet", "UpscalerScuNET"), 22 | "ScuNET PSNR": (".pipelines.scunet", "UpscalerScuNET"), 23 | "SwinIR 4x": (".pipelines.swinir", "UpscalerSwinIR"), 24 | } 25 | 26 | BUILTIN_UPSCALERS = list(UPSCALER_MAP.keys()) 27 | 28 | 29 | def load_upscaler_model(**kwargs): 30 | """ 31 | Loads and returns an upscaler model class instance based on the provided keyword arguments. 32 | 33 | Keyword Args: 34 | model (str): The name or path of the model to load. It can be any of the BUILTIN_UPSCALERS. 35 | tile (int, optional): The size of the tiles to use for upscaling. Default is 192. 36 | tile_overlap (int, optional): The overlap between tiles. Default is 8. 37 | device (str, optional): The device to use for computation, e.g., "cuda" or "cpu". Default is "cuda". 38 | half (bool, optional): Whether to use half-precision floats. Default is False. 39 | **kwargs: Additional keyword arguments to pass to the model class constructor. 40 | 41 | Returns: 42 | object: An instance of the upscaler model class. 
43 | 44 | Example: 45 | from PIL import Image 46 | 47 | # Load the upscaler model 48 | upscaler = load_upscaler_model(model="your_model_name_or_path", tile=192, tile_overlap=8, device="cuda", half=False) 49 | 50 | # Open an image using PIL 51 | img_pre_up = Image.open("path_to_your_image.jpg") 52 | 53 | # Define the upscaling parameters 54 | upscaler_increases_size = 1.4 55 | disable_progress_bar = False 56 | 57 | # Use the upscaler to upscale the image 58 | image_pos_up = upscaler.upscale(img_pre_up, upscaler_increases_size, disable_progress_bar) 59 | """ 60 | 61 | model = kwargs.get("model", None) 62 | 63 | # Get the module and class model based on `model` 64 | module_path, class_name = UPSCALER_MAP.get(model, ANY_UPSCALER) 65 | 66 | # Import the module and get the class 67 | module = importlib.import_module(module_path, package=__package__) 68 | cls = getattr(module, class_name) 69 | 70 | return cls(**kwargs) 71 | -------------------------------------------------------------------------------- /stablepy/diffusers_vanilla/preprocessor/controlnet_aux_beta/preprocessor_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import cv2 4 | import numpy as np 5 | import torch 6 | 7 | 8 | def HWC3(x): 9 | assert x.dtype == np.uint8 10 | if x.ndim == 2: 11 | x = x[:, :, None] 12 | assert x.ndim == 3 13 | H, W, C = x.shape 14 | assert C == 1 or C == 3 or C == 4 15 | if C == 3: 16 | return x 17 | if C == 1: 18 | return np.concatenate([x, x, x], axis=2) 19 | if C == 4: 20 | color = x[:, :, 0:3].astype(np.float32) 21 | alpha = x[:, :, 3:4].astype(np.float32) / 255.0 22 | y = color * alpha + 255.0 * (1.0 - alpha) 23 | y = y.clip(0, 255).astype(np.uint8) 24 | return y 25 | 26 | 27 | def make_noise_disk(H, W, C, F): 28 | noise = np.random.uniform(low=0, high=1, size=((H // F) + 2, (W // F) + 2, C)) 29 | noise = cv2.resize(noise, (W + 2 * F, H + 2 * F), interpolation=cv2.INTER_CUBIC) 30 | noise = noise[F: F + H, F: F + W] 31 | noise -= np.min(noise) 32 | noise /= np.max(noise) 33 | if C == 1: 34 | noise = noise[:, :, None] 35 | return noise 36 | 37 | 38 | def nms(x, t, s): 39 | x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s) 40 | 41 | f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8) 42 | f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8) 43 | f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8) 44 | f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8) 45 | 46 | y = np.zeros_like(x) 47 | 48 | for f in [f1, f2, f3, f4]: 49 | np.putmask(y, cv2.dilate(x, kernel=f) == x, x) 50 | 51 | z = np.zeros_like(y, dtype=np.uint8) 52 | z[y > t] = 255 53 | return z 54 | 55 | 56 | def min_max_norm(x): 57 | x -= np.min(x) 58 | x /= np.maximum(np.max(x), 1e-5) 59 | return x 60 | 61 | 62 | def safe_step(x, step=2): 63 | y = x.astype(np.float32) * float(step + 1) 64 | y = y.astype(np.int32).astype(np.float32) / float(step) 65 | return y 66 | 67 | 68 | def img2mask(img, H, W, low=10, high=90): 69 | assert img.ndim == 3 or img.ndim == 2 70 | assert img.dtype == np.uint8 71 | 72 | if img.ndim == 3: 73 | y = img[:, :, random.randrange(0, img.shape[2])] 74 | else: 75 | y = img 76 | 77 | y = cv2.resize(y, (W, H), interpolation=cv2.INTER_CUBIC) 78 | 79 | if random.uniform(0, 1) < 0.5: 80 | y = 255 - y 81 | 82 | return y < np.percentile(y, random.randrange(low, high)) 83 | 84 | 85 | def fast_resize_image(input_image, resolution): 86 | H, W, C = input_image.shape 87 | H = float(H) 88 | W = float(W) 89 | 
k = float(resolution) / min(H, W) 90 | H *= k 91 | W *= k 92 | H = int(np.round(H / 64.0)) * 64 93 | W = int(np.round(W / 64.0)) * 64 94 | img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA) 95 | return img 96 | 97 | 98 | def torch_gc(): 99 | if torch.cuda.is_available(): 100 | torch.cuda.empty_cache() 101 | torch.cuda.ipc_collect() 102 | -------------------------------------------------------------------------------- /stablepy/upscalers/pipelines/scunet.py: -------------------------------------------------------------------------------- 1 | from ..utils_upscaler import upscale_2, load_spandrel_model, release_resources_upscaler, load_file_from_url 2 | from .base import Upscaler, UpscalerData 3 | import PIL.Image 4 | import os 5 | 6 | 7 | class UpscalerScuNET(Upscaler): 8 | def __init__(self, model="ScuNET GAN", tile=192, tile_overlap=8, device="cuda", half=False, **kwargs): 9 | self.name = "ScuNET" 10 | self.model_name = "ScuNET GAN" 11 | self.model_name2 = "ScuNET PSNR" 12 | self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_gan.pth" 13 | self.model_url2 = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_psnr.pth" 14 | super().__init__() 15 | model_a = UpscalerData(self.model_name, self.model_url, self, 4) 16 | model_b = UpscalerData(self.model_name2, self.model_url2, self) 17 | self.scalers = [model_a, model_b] 18 | 19 | self.device = device 20 | self.half = half 21 | self.tile = tile 22 | self.tile_overlap = tile_overlap 23 | 24 | release_resources_upscaler() 25 | 26 | try: 27 | self.model_descriptor = self.load_model(model) 28 | except Exception as e: 29 | print(f"Unable to load ScuNET model {model}: {e}") 30 | self.model_descriptor = None 31 | 32 | def do_upscale(self, img: PIL.Image.Image): 33 | release_resources_upscaler() 34 | 35 | if self.model_descriptor is None: 36 | return img 37 | 38 | img = upscale_2( 39 | img, 40 | self.model_descriptor, 41 | tile_size=self.tile, 42 | tile_overlap=self.tile_overlap, 43 | scale=1, # ScuNET is a denoising model, not an upscaler 44 | desc='ScuNET', 45 | disable_progress_bar=self.disable_progress_bar, 46 | ) 47 | 48 | release_resources_upscaler() 49 | 50 | return img 51 | 52 | def load_model(self, path: str): 53 | for scaler in self.scalers: 54 | if scaler.name == path: 55 | if scaler.local_data_path.startswith("http"): 56 | scaler.local_data_path = load_file_from_url( 57 | scaler.data_path, 58 | model_dir=self.model_download_path, 59 | ) 60 | if not os.path.exists(scaler.local_data_path): 61 | raise FileNotFoundError(f"ScuNET data missing: {scaler.local_data_path}") 62 | return load_spandrel_model( 63 | scaler.local_data_path, 64 | device=self.device, 65 | # prefer_half=self.half, 66 | expected_architecture="SCUNet", 67 | ) 68 | raise ValueError(f"Unable to find model info: {path}") 69 | 70 | 71 | if __name__ == "__main__": 72 | from PIL import Image 73 | 74 | up = UpscalerScuNET(model="ScuNET PSNR", tile=192, tile_overlap=8, device="cuda", half=False) 75 | scale_up = 1.1 76 | img = Image.open("img.png") 77 | print(img.size) 78 | img_up = up.upscale(img, scale_up) # ScuNET PSNR ScuNET GAN 79 | print(img_up.size) 80 | -------------------------------------------------------------------------------- /stablepy/diffusers_vanilla/inpainting_canvas.py: -------------------------------------------------------------------------------- 1 | # ===================================== 2 | # Inpainting canvas 3 | # ===================================== 4 | 
canvas_html = """ 5 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 72 | """ 73 | 74 | import base64, os 75 | from base64 import b64decode 76 | import numpy as np 77 | import matplotlib.pyplot as plt 78 | from IPython.display import display, HTML 79 | from ..logging.logging_setup import logger 80 | import torch 81 | 82 | class NotValid(Exception): 83 | pass 84 | 85 | 86 | def draw(imgm, filename="drawing.png", w=400, h=200, line_width=1): 87 | try: 88 | from google.colab.output import eval_js 89 | 90 | display( 91 | HTML(canvas_html % (w, h, w, h, filename.split(".")[-1], imgm, line_width)) 92 | ) 93 | data = eval_js("data") 94 | binary = b64decode(data.split(",")[1]) 95 | with open(filename, "wb") as f: 96 | f.write(binary) 97 | logger.info(f"Created draw and saved: {filename}") 98 | except Exception: 99 | raise NotValid( 100 | "The 'image_mask' parameter is required for this task. If you're trying " 101 | "to use the option to draw mask from a Notebook, this option is only" 102 | " available for colab" 103 | ) 104 | 105 | 106 | # the control image of init_image and mask_image 107 | def make_inpaint_condition(image, image_mask): 108 | image = np.array(image.convert("RGB")).astype(np.float32) / 255.0 109 | image_mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0 110 | 111 | assert ( 112 | image.shape[0:1] == image_mask.shape[0:1] 113 | ), "image and image_mask must have the same image size" 114 | image[image_mask > 0.5] = -1.0 # set as masked pixel 115 | image = np.expand_dims(image, 0).transpose(0, 3, 1, 2) 116 | image = torch.from_numpy(image) 117 | return image 118 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | diffusers/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | .pybuilder/ 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | # For a library or package, you might want to ignore these files since the code is 88 | # intended to run in multiple environments; otherwise, check them in: 89 | # .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # poetry 99 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 100 | # This is especially recommended for binary packages to ensure reproducibility, and is more 101 | # commonly ignored for libraries. 102 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 103 | #poetry.lock 104 | 105 | # pdm 106 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 107 | #pdm.lock 108 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 109 | # in version control. 110 | # https://pdm.fming.dev/#use-with-ide 111 | .pdm.toml 112 | 113 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 114 | __pypackages__/ 115 | 116 | # Celery stuff 117 | celerybeat-schedule 118 | celerybeat.pid 119 | 120 | # SageMath parsed files 121 | *.sage.py 122 | 123 | # Environments 124 | .env 125 | .venv 126 | env/ 127 | venv/ 128 | ENV/ 129 | env.bak/ 130 | venv.bak/ 131 | 132 | # Spyder project settings 133 | .spyderproject 134 | .spyproject 135 | 136 | # Rope project settings 137 | .ropeproject 138 | 139 | # mkdocs documentation 140 | /site 141 | 142 | # mypy 143 | .mypy_cache/ 144 | .dmypy.json 145 | dmypy.json 146 | 147 | # Pyre type checker 148 | .pyre/ 149 | 150 | # pytype static type analyzer 151 | .pytype/ 152 | 153 | # Cython debug symbols 154 | cython_debug/ 155 | 156 | # PyCharm 157 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 158 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 159 | # and can be added to the global gitignore or merged into this file. For a more nuclear 160 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
161 | #.idea/ 162 | -------------------------------------------------------------------------------- /stablepy/upscalers/pipelines/base.py: -------------------------------------------------------------------------------- 1 | import os 2 | from abc import abstractmethod 3 | import PIL 4 | from PIL import Image 5 | 6 | LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) 7 | NEAREST = (Image.Resampling.NEAREST if hasattr(Image, 'Resampling') else Image.NEAREST) 8 | 9 | 10 | class Upscaler: 11 | name = None 12 | model_path = None 13 | model_name = None 14 | model_url = None 15 | model = None 16 | scalers: list 17 | tile = True 18 | 19 | def __init__(self): 20 | self.tile_size = 192 21 | self.tile_pad = 8 22 | self.device = "cpu" 23 | self.scale = 1 24 | self.half = False 25 | self.model_download_path = os.path.join(os.path.expanduser("~"), ".cache", "upscalers") 26 | self.can_tile = True 27 | self.disable_progress_bar = False 28 | 29 | @abstractmethod 30 | def do_upscale(self, img: PIL.Image): 31 | return img 32 | 33 | def upscale(self, img: PIL.Image, scale, disable_progress_bar=False): 34 | self.disable_progress_bar = disable_progress_bar 35 | self.scale = scale 36 | 37 | dest_w = int((img.width * scale) // 8 * 8) 38 | dest_h = int((img.height * scale) // 8 * 8) 39 | 40 | for i in range(3): 41 | if img.width >= dest_w and img.height >= dest_h and (i > 0 or scale != 1): 42 | break 43 | 44 | shape = (img.width, img.height) 45 | 46 | img = self.do_upscale(img) 47 | 48 | if shape == (img.width, img.height): 49 | break 50 | 51 | if img.width != dest_w or img.height != dest_h: 52 | img = img.resize((int(dest_w), int(dest_h)), resample=LANCZOS) 53 | 54 | return img 55 | 56 | @abstractmethod 57 | def load_model(self, path: str): 58 | pass 59 | 60 | 61 | class UpscalerData: 62 | name = None 63 | data_path = None 64 | scale: int = 4 65 | scaler: Upscaler = None 66 | model: None 67 | 68 | def __init__(self, name: str, path: str, upscaler: Upscaler = None, scale: int = 4, supports_half=True, model=None): 69 | self.name = name 70 | self.data_path = path 71 | self.local_data_path = path 72 | self.scaler = upscaler 73 | self.scale = scale 74 | self.supports_half = supports_half 75 | self.model = model 76 | 77 | def __repr__(self): 78 | return f"" 79 | 80 | 81 | class UpscalerNone(Upscaler): 82 | name = "None" 83 | scalers = [] 84 | 85 | def load_model(self, path): 86 | pass 87 | 88 | def do_upscale(self, img): 89 | return img 90 | 91 | def __init__(self, **kwargs): 92 | super().__init__() 93 | self.scalers = [UpscalerData("None", None, self)] 94 | 95 | 96 | class UpscalerLanczos(Upscaler): 97 | scalers = [] 98 | 99 | def do_upscale(self, img): 100 | return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=LANCZOS) 101 | 102 | def load_model(self, _): 103 | pass 104 | 105 | def __init__(self, **kwargs): 106 | super().__init__() 107 | self.name = "Lanczos" 108 | self.scalers = [UpscalerData("Lanczos", None, self)] 109 | 110 | 111 | class UpscalerNearest(Upscaler): 112 | scalers = [] 113 | 114 | def do_upscale(self, img): 115 | return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=NEAREST) 116 | 117 | def load_model(self, _): 118 | pass 119 | 120 | def __init__(self, **kwargs): 121 | super().__init__() 122 | self.name = "Nearest" 123 | self.scalers = [UpscalerData("Nearest", None, self)] 124 | 125 | 126 | if __name__ == "__main__": 127 | up = UpscalerNearest() 128 | scale_up = 1.1 129 | img = 
Image.open("img.png") 130 | print(img.size) 131 | img_up = up.upscale(img, scale_up) 132 | print(img_up.size) 133 | -------------------------------------------------------------------------------- /stablepy/diffusers_vanilla/style_prompt_config.py: -------------------------------------------------------------------------------- 1 | from ..logging.logging_setup import logger 2 | from typing import Tuple 3 | import json 4 | 5 | 6 | def get_json_content(file_path): 7 | try: 8 | with open(file_path, 'rt', encoding="utf-8") as file: 9 | json_data = json.load(file) 10 | return json_data 11 | except Exception as e: 12 | logger.error(f"A Problem occurred: {str(e)}") 13 | 14 | 15 | BASE_STYLE_LIST = [ 16 | { 17 | "name": "(No style)", 18 | "prompt": "{prompt}", 19 | "negative_prompt": "", 20 | }, 21 | { 22 | "name": "Cinematic", 23 | "prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy", 24 | "negative_prompt": "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured", 25 | }, 26 | { 27 | "name": "Photographic", 28 | "prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed", 29 | "negative_prompt": "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly", 30 | }, 31 | { 32 | "name": "Anime", 33 | "prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed", 34 | "negative_prompt": "photo, deformed, black and white, realism, disfigured, low contrast", 35 | }, 36 | { 37 | "name": "Manga", 38 | "prompt": "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style", 39 | "negative_prompt": "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style", 40 | }, 41 | { 42 | "name": "Digital Art", 43 | "prompt": "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed", 44 | "negative_prompt": "photo, photorealistic, realism, ugly", 45 | }, 46 | { 47 | "name": "Pixel art", 48 | "prompt": "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics", 49 | "negative_prompt": "sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic", 50 | }, 51 | { 52 | "name": "Fantasy art", 53 | "prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy", 54 | "negative_prompt": "photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, sloppy, duplicate, mutated, black and white", 55 | }, 56 | { 57 | "name": "Neonpunk", 58 | "prompt": "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional", 59 | "negative_prompt": "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured", 60 | }, 61 | { 62 | "name": "3D Model", 63 | "prompt": "professional 3d model {prompt} . 
octane render, highly detailed, volumetric, dramatic lighting", 64 | "negative_prompt": "ugly, deformed, noisy, low poly, blurry, painting", 65 | }, 66 | ] 67 | 68 | styles_data = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in BASE_STYLE_LIST} 69 | STYLE_NAMES = list(styles_data.keys()) 70 | DEFAULT_STYLE_NAME = "(No style)" 71 | 72 | 73 | def apply_style(style_name_list: list, positive: str, negative: str = "", styles_data=None, STYLE_NAMES=None) -> Tuple[str, str]: 74 | 75 | for style_name in style_name_list: 76 | if style_name in ["(No style)", "base", "", None] or "none by" in style_name: 77 | return positive, negative 78 | if style_name in STYLE_NAMES: 79 | p, n = styles_data.get(style_name) 80 | if p.strip() == "{prompt}": 81 | return positive, negative 82 | positive, negative = p.replace("{prompt}", positive), n + ", " + negative 83 | else: 84 | logger.warning(f"{style_name} style not found") 85 | 86 | return positive, negative 87 | -------------------------------------------------------------------------------- /stablepy/diffusers_vanilla/preprocessor/controlnet_aux_beta/anyline/__init__.py: -------------------------------------------------------------------------------- 1 | # code based in https://github.com/TheMistoAI/ComfyUI-Anyline/blob/main/anyline.py 2 | import os 3 | import cv2 4 | import numpy as np 5 | import torch 6 | from einops import rearrange 7 | from huggingface_hub import hf_hub_download 8 | from PIL import Image 9 | from ..teed.ted import TED 10 | from ..preprocessor_utils import HWC3, fast_resize_image, safe_step 11 | 12 | 13 | class AnylineDetector: 14 | def __init__(self, pretrained_model_or_path=None, filename=None, subfolder=None): 15 | if not pretrained_model_or_path: 16 | pretrained_model_or_path = "TheMistoAI/MistoLine" 17 | filename = "MTEED.pth" 18 | subfolder = "Anyline" 19 | if os.path.isdir(pretrained_model_or_path): 20 | model_path = os.path.join(pretrained_model_or_path, filename) 21 | else: 22 | model_path = hf_hub_download( 23 | pretrained_model_or_path, filename, subfolder=subfolder 24 | ) 25 | 26 | model = TED() 27 | model.load_state_dict(torch.load(model_path, map_location="cpu")) 28 | self.model = model 29 | 30 | def to(self, device): 31 | self.model.to(device) 32 | return self 33 | 34 | def __call__( 35 | self, 36 | image, 37 | guassian_sigma=2.0, 38 | intensity_threshold=3, 39 | **kwargs 40 | ): 41 | from skimage import morphology 42 | 43 | detect_resolution = kwargs.pop("detect_resolution", 512) 44 | image_resolution = kwargs.pop("image_resolution", 512) 45 | 46 | device = next(iter(self.model.parameters())).device 47 | 48 | image = HWC3(image) 49 | original_height, original_width, _ = image.shape 50 | image = fast_resize_image(image, detect_resolution) 51 | 52 | assert image.ndim == 3 53 | height, width, _ = image.shape 54 | with torch.no_grad(): 55 | image_teed = torch.from_numpy(image.copy()).float().to(device) 56 | image_teed = rearrange(image_teed, "h w c -> 1 c h w") 57 | edges = self.model(image_teed) 58 | edges = [e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges] 59 | edges = [ 60 | cv2.resize(e, (width, height), interpolation=cv2.INTER_LINEAR) 61 | for e in edges 62 | ] 63 | edges = np.stack(edges, axis=2) 64 | edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64))) 65 | edge = safe_step(edge, 2) 66 | edge = (edge * 255.0).clip(0, 255).astype(np.uint8) 67 | 68 | mteed_result = edge 69 | mteed_result = HWC3(mteed_result) 70 | 71 | x = image.astype(np.float32) 72 | g = cv2.GaussianBlur(x, (0, 0), 
guassian_sigma) 73 | intensity = np.min(g - x, axis=2).clip(0, 255) 74 | intensity /= max(16, np.median(intensity[intensity > intensity_threshold])) 75 | intensity *= 127 76 | lineart_result = intensity.clip(0, 255).astype(np.uint8) 77 | 78 | lineart_result = HWC3(lineart_result) 79 | 80 | lineart_result = self.get_intensity_mask( 81 | lineart_result, lower_bound=0, upper_bound=255 82 | ) 83 | 84 | cleaned = morphology.remove_small_objects( 85 | lineart_result.astype(bool), min_size=36, connectivity=1 86 | ) 87 | lineart_result = lineart_result * cleaned 88 | final_result = self.combine_layers(mteed_result, lineart_result) 89 | 90 | final_result = HWC3(final_result.astype(np.uint8)) 91 | resize_result = fast_resize_image( 92 | final_result, image_resolution 93 | ) 94 | 95 | return Image.fromarray(resize_result) 96 | 97 | def get_intensity_mask(self, image_array, lower_bound, upper_bound): 98 | mask = image_array[:, :, 0] 99 | mask = np.where((mask >= lower_bound) & (mask <= upper_bound), mask, 0) 100 | mask = np.expand_dims(mask, 2).repeat(3, axis=2) 101 | return mask 102 | 103 | def combine_layers(self, base_layer, top_layer): 104 | mask = top_layer.astype(bool) 105 | temp = 1 - (1 - top_layer) * (1 - base_layer) 106 | result = base_layer * (~mask) + temp * mask 107 | return result 108 | -------------------------------------------------------------------------------- /stablepy/face_restoration/main_face_restoration.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | 4 | from ..logging.logging_setup import logger 5 | from ..diffusers_vanilla.utils import release_resources 6 | 7 | FACE_RESTORATION_MODELS = ["CodeFormer", "GFPGAN", "RestoreFormer"] 8 | 9 | 10 | def load_face_restoration_model(face_restoration_model, device): 11 | """ 12 | Load the specified face restoration model. 13 | Parameters: 14 | face_restoration_model (str): The name of the face restoration model to load. 15 | Must be one of the models listed in FACE_RESTORATION_MODELS. 16 | device (str): The device to load the model on (e.g., 'cpu' or 'cuda'). 17 | Returns: 18 | model: An instance of the specified face restoration model, or None if an invalid model name is provided. 19 | """ 20 | 21 | if face_restoration_model == FACE_RESTORATION_MODELS[0]: 22 | from .codeformer import FaceRestorerCodeFormer 23 | model = FaceRestorerCodeFormer(device) 24 | elif face_restoration_model == FACE_RESTORATION_MODELS[1]: 25 | from .gfpgan import FaceRestorerGFPGAN 26 | model = FaceRestorerGFPGAN(device) 27 | elif face_restoration_model == FACE_RESTORATION_MODELS[2]: 28 | from .restoreformer import FaceRestorerRestoreFormer 29 | model = FaceRestorerRestoreFormer(device) 30 | else: 31 | valid_models = ", ".join(FACE_RESTORATION_MODELS) 32 | logger.error(f"Invalid face restoration model: {face_restoration_model}. Valid models are: {valid_models}") 33 | return None 34 | logger.info(f"Face restoration: {face_restoration_model}") 35 | model.load_net() 36 | return model 37 | 38 | 39 | def process_face_restoration( 40 | source_img, 41 | model, 42 | face_restoration_visibility, 43 | face_restoration_weight 44 | ): 45 | """ 46 | Process a single image for face restoration. 47 | 48 | Parameters: 49 | source_img (PIL.Image): The source image to be processed. 50 | model: The face restoration model to use. 51 | face_restoration_visibility (float): The visibility of the restored face in the final image. 52 | Should be between 0 and 1. 
53 | face_restoration_weight (float): The weight parameter for CodeFormer model. 54 | 55 | Returns: 56 | PIL.Image: The processed image with face restoration applied. 57 | """ 58 | if face_restoration_visibility == 0 or model is None: 59 | return source_img 60 | 61 | source_img = source_img.convert("RGB") 62 | 63 | restored_img = model.restore( 64 | np.array(source_img, dtype=np.uint8), w=face_restoration_weight 65 | ) 66 | res = Image.fromarray(restored_img) 67 | 68 | if face_restoration_visibility < 1.0: 69 | res = Image.blend(source_img, res, face_restoration_visibility) 70 | 71 | return res 72 | 73 | 74 | def batch_process_face_restoration( 75 | images, 76 | face_restoration_model, 77 | face_restoration_visibility, 78 | face_restoration_weight, 79 | device="cuda", 80 | ): 81 | """ 82 | Processes a batch of images for face restoration using the specified model and parameters. 83 | Args: 84 | images (list): List of images to be processed. 85 | face_restoration_model (str): The name or path of the face restoration model to be used. 86 | face_restoration_visibility (float): The visibility parameter for face restoration. 87 | face_restoration_weight (float): The weight parameter for CodeFormer model. 88 | device (str, optional): The device to run the model on, either 'cuda' or 'cpu'. Defaults to 'cuda'. 89 | Returns: 90 | list: A list of processed images. If an error occurs during processing, the original image is returned in the list. 91 | """ 92 | 93 | model = load_face_restoration_model(face_restoration_model, device) 94 | 95 | result_list = [] 96 | for source_img in images: 97 | try: 98 | res = process_face_restoration( 99 | source_img, 100 | model, 101 | face_restoration_visibility, 102 | face_restoration_weight 103 | ) 104 | result_list.append(res) 105 | except Exception as e: 106 | logger.error(f"Failed face restoration: {str(e)}", exc_info=True) 107 | result_list.append(source_img) 108 | 109 | del model 110 | release_resources() 111 | 112 | return result_list 113 | -------------------------------------------------------------------------------- /stablepy/diffusers_vanilla/preprocessor/image_utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | def resize_image(input_image, resolution, interpolation=None): 6 | H, W, C = input_image.shape 7 | H = float(H) 8 | W = float(W) 9 | k = float(resolution) / max(H, W) 10 | H *= k 11 | W *= k 12 | H = int(np.round(H / 64.0)) * 64 13 | W = int(np.round(W / 64.0)) * 64 14 | if interpolation is None: 15 | interpolation = cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA 16 | img = cv2.resize(input_image, (W, H), interpolation=interpolation) 17 | return img 18 | 19 | 20 | def apply_gaussian_blur(image_np, ksize=5): 21 | sigmaX = ksize / 2 22 | ksize = int(ksize) 23 | if ksize % 2 == 0: 24 | ksize += 1 25 | blurred_image_np = cv2.GaussianBlur(image_np, (ksize, ksize), sigmaX=sigmaX) 26 | return blurred_image_np 27 | 28 | 29 | def recolor_luminance(img, thr_a=1.0, **kwargs): 30 | result = cv2.cvtColor(HWC3(img), cv2.COLOR_BGR2LAB) 31 | result = result[:, :, 0].astype(np.float32) / 255.0 32 | result = result ** thr_a 33 | result = (result * 255.0).clip(0, 255).astype(np.uint8) 34 | result = cv2.cvtColor(result, cv2.COLOR_GRAY2RGB) 35 | return result 36 | 37 | 38 | def recolor_intensity(img, thr_a=1.0, **kwargs): 39 | result = cv2.cvtColor(HWC3(img), cv2.COLOR_BGR2HSV) 40 | result = result[:, :, 2].astype(np.float32) / 255.0 41 | result = result ** thr_a 42 | result = 
(result * 255.0).clip(0, 255).astype(np.uint8) 43 | result = cv2.cvtColor(result, cv2.COLOR_GRAY2RGB) 44 | return result 45 | 46 | 47 | def HWC3(x): 48 | """ copy """ 49 | assert x.dtype == np.uint8 50 | if x.ndim == 2: 51 | x = x[:, :, None] 52 | assert x.ndim == 3 53 | H, W, C = x.shape 54 | assert C == 1 or C == 3 or C == 4 55 | if C == 3: 56 | return x 57 | if C == 1: 58 | return np.concatenate([x, x, x], axis=2) 59 | if C == 4: 60 | color = x[:, :, 0:3].astype(np.float32) 61 | alpha = x[:, :, 3:4].astype(np.float32) / 255.0 62 | y = color * alpha + 255.0 * (1.0 - alpha) 63 | y = y.clip(0, 255).astype(np.uint8) 64 | return y 65 | 66 | 67 | def ade_palette(): 68 | """ADE20K palette that maps each class to RGB values.""" 69 | return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], 70 | [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255], 71 | [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], 72 | [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82], 73 | [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], 74 | [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255], 75 | [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], 76 | [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224], 77 | [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], 78 | [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7], 79 | [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], 80 | [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255], 81 | [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], 82 | [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255], 83 | [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255], 84 | [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255], 85 | [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0], 86 | [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0], 87 | [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255], 88 | [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255], 89 | [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20], 90 | [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255], 91 | [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255], 92 | [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255], 93 | [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0], 94 | [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0], 95 | [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255], 96 | [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112], 97 | [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160], 98 | [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163], 99 | [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0], 100 | [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0], 101 | [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255], 102 | [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204], 103 | [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255], 104 | [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255], 105 | [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194], 106 | [102, 255, 0], [92, 0, 255]] 107 | -------------------------------------------------------------------------------- /stablepy/diffusers_vanilla/t5_embedder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import namedtuple 3 | from .multi_emphasis_prompt import ( 4 | get_current_option, 5 | parse_prompt_attention, 6 | ) 7 | 8 | PromptChunkFix = 
namedtuple('PromptChunkFix', ['offset', 'embedding']) 9 | 10 | 11 | class PromptChunk: 12 | def __init__(self): 13 | self.tokens = [] 14 | self.multipliers = [] 15 | 16 | 17 | class T5TextProcessingEngine: 18 | def __init__(self, text_encoder, tokenizer, emphasis_name="Original", min_length=256): 19 | super().__init__() 20 | 21 | self.text_encoder = text_encoder.encoder 22 | self.tokenizer = tokenizer 23 | 24 | self.device = text_encoder.device.type 25 | self.emphasis = get_current_option(emphasis_name)() 26 | self.min_length = min_length 27 | self.id_end = 1 28 | self.id_pad = 0 29 | 30 | vocab = self.tokenizer.get_vocab() 31 | 32 | self.comma_token = vocab.get(',', None) 33 | 34 | self.token_mults = {} 35 | 36 | tokens_with_parens = [(k, v) for k, v in vocab.items() if '(' in k or ')' in k or '[' in k or ']' in k] 37 | for text, ident in tokens_with_parens: 38 | mult = 1.0 39 | for c in text: 40 | if c == '[': 41 | mult /= 1.1 42 | if c == ']': 43 | mult *= 1.1 44 | if c == '(': 45 | mult *= 1.1 46 | if c == ')': 47 | mult /= 1.1 48 | 49 | if mult != 1.0: 50 | self.token_mults[ident] = mult 51 | 52 | def tokenize(self, texts): 53 | tokenized = self.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"] 54 | return tokenized 55 | 56 | def encode_with_transformers(self, tokens): 57 | tokens = tokens.to(self.device) 58 | 59 | z = self.text_encoder( 60 | input_ids=tokens, 61 | )[0] 62 | 63 | return z 64 | 65 | def tokenize_line(self, line): 66 | if self.emphasis.name != "None": 67 | parsed = parse_prompt_attention(line) 68 | else: 69 | parsed = [[line, 1.0]] 70 | 71 | tokenized = self.tokenize([text for text, _ in parsed]) 72 | 73 | chunks = [] 74 | chunk = PromptChunk() 75 | token_count = 0 76 | 77 | def next_chunk(): 78 | nonlocal token_count 79 | nonlocal chunk 80 | 81 | chunk.tokens = chunk.tokens + [self.id_end] 82 | chunk.multipliers = chunk.multipliers + [1.0] 83 | current_chunk_length = len(chunk.tokens) 84 | 85 | token_count += current_chunk_length 86 | remaining_count = self.min_length - current_chunk_length 87 | 88 | if remaining_count > 0: 89 | chunk.tokens += [self.id_pad] * remaining_count 90 | chunk.multipliers += [1.0] * remaining_count 91 | 92 | chunks.append(chunk) 93 | chunk = PromptChunk() 94 | 95 | for tokens, (text, weight) in zip(tokenized, parsed): 96 | if text == 'BREAK' and weight == -1: 97 | next_chunk() 98 | continue 99 | 100 | position = 0 101 | while position < len(tokens): 102 | token = tokens[position] 103 | chunk.tokens.append(token) 104 | chunk.multipliers.append(weight) 105 | position += 1 106 | 107 | if chunk.tokens or not chunks: 108 | next_chunk() 109 | 110 | return chunks, token_count 111 | 112 | def __call__(self, texts): 113 | if not isinstance(texts, list): 114 | texts = [texts] 115 | 116 | zs = [] 117 | cache = {} 118 | 119 | for line in texts: 120 | if line in cache: 121 | line_z_values = cache[line] 122 | else: 123 | chunks, token_count = self.tokenize_line(line) 124 | line_z_values = [] 125 | 126 | # pad all chunks to length of longest chunk 127 | max_tokens = 0 128 | for chunk in chunks: 129 | max_tokens = max(len(chunk.tokens), max_tokens) 130 | 131 | for chunk in chunks: 132 | tokens = chunk.tokens 133 | multipliers = chunk.multipliers 134 | 135 | remaining_count = max_tokens - len(tokens) 136 | if remaining_count > 0: 137 | tokens += [self.id_pad] * remaining_count 138 | multipliers += [1.0] * remaining_count 139 | 140 | z = self.process_tokens([tokens], [multipliers])[0] 141 | line_z_values.append(z) 142 | cache[line] = 
line_z_values 143 | 144 | zs.extend(line_z_values) 145 | 146 | return torch.stack(zs) 147 | 148 | def process_tokens(self, batch_tokens, batch_multipliers): 149 | tokens = torch.asarray(batch_tokens) 150 | 151 | z = self.encode_with_transformers(tokens) 152 | 153 | self.emphasis.tokens = batch_tokens 154 | self.emphasis.multipliers = torch.asarray(batch_multipliers).to(z) 155 | self.emphasis.z = z 156 | self.emphasis.after_transformers() 157 | z = self.emphasis.z 158 | 159 | return z 160 | -------------------------------------------------------------------------------- /stablepy/upscalers/pipelines/common.py: -------------------------------------------------------------------------------- 1 | from ..utils_upscaler import load_spandrel_model, upscale_with_model, release_resources_upscaler, load_file_from_url 2 | from .base import Upscaler, UpscalerData 3 | import os 4 | 5 | 6 | class UpscalerCommon(Upscaler): 7 | def __init__(self, model="R-ESRGAN 4x+", tile=192, tile_overlap=8, device="cuda", half=False, **kwargs): 8 | self.name = "RealESRGAN" 9 | super().__init__() 10 | self.scalers = get_models(self) 11 | 12 | self.device = device 13 | self.half = half 14 | self.tile = tile 15 | self.tile_overlap = tile_overlap 16 | 17 | release_resources_upscaler() 18 | 19 | try: 20 | self.model_descriptor = self.load_model(model) 21 | except Exception as e: 22 | print(f"Unable to load upscaler model {model}: {e}") 23 | self.model_descriptor = None 24 | 25 | def do_upscale(self, img): 26 | release_resources_upscaler() 27 | 28 | if self.model_descriptor is None: 29 | return img 30 | 31 | return upscale_with_model( 32 | self.model_descriptor, 33 | img, 34 | tile_size=self.tile, 35 | tile_overlap=self.tile_overlap, 36 | # TODO: `outscale`? 37 | disable_progress_bar=self.disable_progress_bar, 38 | ) 39 | 40 | def load_model(self, path): 41 | for scaler in self.scalers: 42 | if scaler.name == path: 43 | if scaler.local_data_path.startswith("http"): 44 | scaler.local_data_path = load_file_from_url( 45 | scaler.data_path, 46 | model_dir=self.model_download_path, 47 | ) 48 | if not os.path.exists(scaler.local_data_path): 49 | raise FileNotFoundError(f"Upscaler model data missing: {scaler.local_data_path}") 50 | return load_spandrel_model( 51 | scaler.local_data_path, 52 | device=self.device, 53 | prefer_half=self.half if scaler.supports_half else False, 54 | ) 55 | 56 | # Load custom model 57 | if path.startswith("http"): 58 | filename = load_file_from_url( 59 | url=path, 60 | model_dir=self.model_download_path, 61 | ) 62 | else: 63 | filename = path 64 | 65 | if not os.path.isfile(filename): 66 | raise FileNotFoundError(f"Model file {filename} not found") 67 | 68 | return load_spandrel_model( 69 | filename, 70 | device=self.device, 71 | prefer_half=self.half, 72 | ) 73 | 74 | 75 | def get_models(scaler: UpscalerCommon): 76 | return [ 77 | # ESRGAN 78 | UpscalerData( 79 | name="ESRGAN_4x", 80 | path="https://github.com/cszn/KAIR/releases/download/v1.0/ESRGAN.pth", 81 | scale=4, 82 | upscaler=scaler, 83 | ), 84 | # R-ESRGAN 85 | UpscalerData( 86 | name="R-ESRGAN General 4xV3", 87 | path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth", 88 | scale=4, 89 | upscaler=scaler, 90 | ), 91 | UpscalerData( 92 | name="R-ESRGAN General WDN 4xV3", 93 | path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth", 94 | scale=4, 95 | upscaler=scaler, 96 | ), 97 | UpscalerData( 98 | name="R-ESRGAN AnimeVideo", 99 | 
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth", 100 | scale=4, 101 | upscaler=scaler, 102 | ), 103 | UpscalerData( 104 | name="R-ESRGAN 4x+", 105 | path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth", 106 | scale=4, 107 | upscaler=scaler, 108 | ), 109 | UpscalerData( 110 | name="R-ESRGAN 4x+ Anime6B", 111 | path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth", 112 | scale=4, 113 | upscaler=scaler, 114 | ), 115 | UpscalerData( 116 | name="R-ESRGAN 2x+", 117 | path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth", 118 | scale=2, 119 | upscaler=scaler, 120 | ), 121 | # DAT 122 | UpscalerData( 123 | name="DAT x2", 124 | path="https://huggingface.co/w-e-w/DAT/resolve/main/experiments/pretrained_models/DAT/DAT_x2.pth", 125 | scale=2, 126 | upscaler=scaler, 127 | supports_half=False, 128 | ), 129 | UpscalerData( 130 | name="DAT x3", 131 | path="https://huggingface.co/w-e-w/DAT/resolve/main/experiments/pretrained_models/DAT/DAT_x3.pth", 132 | scale=3, 133 | upscaler=scaler, 134 | supports_half=False, 135 | ), 136 | UpscalerData( 137 | name="DAT x4", 138 | path="https://huggingface.co/w-e-w/DAT/resolve/main/experiments/pretrained_models/DAT/DAT_x4.pth", 139 | scale=4, 140 | upscaler=scaler, 141 | supports_half=False, 142 | ), 143 | # HAT 144 | UpscalerData( 145 | name="HAT x4", 146 | path="https://huggingface.co/Phips/4xNomos8kSCHAT-S/resolve/main/4xNomos8kSCHAT-S.safetensors", 147 | scale=4, 148 | upscaler=scaler, 149 | supports_half=False, 150 | ), 151 | ] 152 | 153 | 154 | if __name__ == "__main__": 155 | from PIL import Image 156 | 157 | up = UpscalerCommon(model="R-ESRGAN 2x+", tile=192, tile_overlap=8, device="cuda", half=False) 158 | scale_up = 1.1 159 | img = Image.open("img.png") 160 | print(img.size) 161 | img_up = up.upscale(img, scale_up) 162 | print(img_up.size) 163 | -------------------------------------------------------------------------------- /stablepy/diffusers_vanilla/extra_model_loaders.py: -------------------------------------------------------------------------------- 1 | from diffusers import ( 2 | MotionAdapter, 3 | AnimateDiffPipeline, 4 | AutoPipelineForImage2Image, 5 | StableDiffusionXLPipeline, 6 | StableDiffusionControlNetInpaintPipeline, 7 | StableDiffusionXLInpaintPipeline, 8 | ControlNetModel, 9 | StableDiffusionPipeline 10 | ) 11 | import torch 12 | from ..logging.logging_setup import logger 13 | 14 | 15 | def custom_task_model_loader( 16 | pipe, 17 | model_category="detailfix", 18 | task_name="txt2img", 19 | torch_dtype=torch.float16, 20 | load_text_encoder=False, 21 | ): 22 | # Pipe detailfix_pipe 23 | if model_category == "detailfix": 24 | 25 | if hasattr(pipe, "transformer"): 26 | from .extra_pipe.flux.pipeline_flux_inpaint import FluxInpaintPipeline 27 | return FluxInpaintPipeline( 28 | vae=pipe.vae, 29 | text_encoder=(pipe.text_encoder if load_text_encoder else None), 30 | tokenizer=pipe.tokenizer, 31 | scheduler=pipe.scheduler, 32 | text_encoder_2=(pipe.text_encoder_2 if load_text_encoder else None), 33 | tokenizer_2=pipe.tokenizer_2, 34 | transformer=pipe.transformer, 35 | ) 36 | 37 | if not hasattr(pipe, "text_encoder_2"): 38 | # sd df 39 | if torch_dtype == torch.float16: 40 | type_params = {"torch_dtype": torch.float16, "variant": "fp16"} 41 | else: 42 | type_params = {"torch_dtype": torch.float32} 43 | logger.debug(f"Params detailfix sd controlnet {type_params}") 44 | 
controlnet_detailfix = ControlNetModel.from_pretrained( 45 | "lllyasviel/control_v11p_sd15_inpaint", **type_params, 46 | ) 47 | detailfix_pipe = StableDiffusionControlNetInpaintPipeline( 48 | vae=pipe.vae, 49 | text_encoder=(pipe.text_encoder if load_text_encoder else None), 50 | tokenizer=pipe.tokenizer, 51 | unet=pipe.unet, 52 | controlnet=controlnet_detailfix, 53 | scheduler=pipe.scheduler, 54 | safety_checker=pipe.safety_checker, 55 | feature_extractor=pipe.feature_extractor, 56 | image_encoder=pipe.image_encoder, 57 | requires_safety_checker=pipe.config.requires_safety_checker, 58 | ) 59 | else: 60 | # sdxl df 61 | detailfix_pipe = StableDiffusionXLInpaintPipeline( 62 | vae=pipe.vae, 63 | text_encoder=(pipe.text_encoder if load_text_encoder else None), 64 | text_encoder_2=(pipe.text_encoder_2 if load_text_encoder else None), 65 | tokenizer=pipe.tokenizer, 66 | tokenizer_2=pipe.tokenizer_2, 67 | unet=pipe.unet, 68 | # controlnet=controlnet, 69 | scheduler=pipe.scheduler, 70 | feature_extractor=pipe.feature_extractor, 71 | image_encoder=pipe.image_encoder, 72 | ) 73 | detailfix_pipe.enable_vae_slicing() 74 | detailfix_pipe.enable_vae_tiling() 75 | detailfix_pipe.watermark = None 76 | 77 | return detailfix_pipe 78 | 79 | elif model_category in ["hires", "detailfix_img2img"]: 80 | 81 | if hasattr(pipe, "transformer"): 82 | from .extra_pipe.flux.pipeline_flux_img2img import FluxImg2ImgPipeline 83 | return FluxImg2ImgPipeline( 84 | vae=pipe.vae, 85 | text_encoder=(pipe.text_encoder if load_text_encoder else None), 86 | tokenizer=pipe.tokenizer, 87 | scheduler=pipe.scheduler, 88 | text_encoder_2=(pipe.text_encoder_2 if load_text_encoder else None), 89 | tokenizer_2=pipe.tokenizer_2, 90 | transformer=pipe.transformer, 91 | ) 92 | 93 | # Pipe hires detailfix_pipe img2img 94 | if task_name != "txt2img" or hasattr(pipe, "set_pag_applied_layers"): 95 | if not hasattr(pipe, "text_encoder_2"): 96 | hires_pipe = StableDiffusionPipeline( 97 | vae=pipe.vae, 98 | text_encoder=(pipe.text_encoder if load_text_encoder else None), 99 | tokenizer=pipe.tokenizer, 100 | unet=pipe.unet, 101 | scheduler=pipe.scheduler, 102 | safety_checker=pipe.safety_checker, 103 | feature_extractor=pipe.feature_extractor, 104 | image_encoder=pipe.image_encoder, 105 | requires_safety_checker=pipe.config.requires_safety_checker, 106 | ) 107 | 108 | else: 109 | hires_pipe = StableDiffusionXLPipeline( 110 | vae=pipe.vae, 111 | text_encoder=(pipe.text_encoder if load_text_encoder else None), 112 | text_encoder_2=(pipe.text_encoder_2 if load_text_encoder else None), 113 | tokenizer=pipe.tokenizer, 114 | tokenizer_2=pipe.tokenizer_2, 115 | unet=pipe.unet, 116 | scheduler=pipe.scheduler, 117 | feature_extractor=pipe.feature_extractor, 118 | image_encoder=pipe.image_encoder, 119 | ) 120 | 121 | hires_pipe = AutoPipelineForImage2Image.from_pipe(hires_pipe, enable_pag=False) 122 | else: 123 | hires_pipe = AutoPipelineForImage2Image.from_pipe(pipe, enable_pag=False) 124 | 125 | if hasattr(hires_pipe, "text_encoder_2"): 126 | hires_pipe.enable_vae_slicing() 127 | hires_pipe.enable_vae_tiling() 128 | hires_pipe.watermark = None 129 | 130 | return hires_pipe 131 | 132 | elif model_category == "animatediff": 133 | # Pipe animatediff 134 | if not hasattr(pipe, "text_encoder_2"): 135 | adapter = MotionAdapter.from_pretrained( 136 | "guoyww/animatediff-motion-adapter-v1-5-2" 137 | ) 138 | adapter.to("cuda" if torch.cuda.is_available() else "cpu") 139 | 140 | animatediff_pipe = AnimateDiffPipeline( 141 | vae=pipe.vae, 142 | 
text_encoder=pipe.text_encoder, 143 | tokenizer=pipe.tokenizer, 144 | unet=pipe.unet, 145 | motion_adapter=adapter, 146 | scheduler=pipe.scheduler, 147 | feature_extractor=pipe.feature_extractor, 148 | image_encoder=pipe.image_encoder, 149 | ) 150 | else: 151 | raise ValueError("Animatediff not implemented for SDXL") 152 | 153 | return animatediff_pipe 154 | -------------------------------------------------------------------------------- /stablepy/diffusers_vanilla/preprocessor/transformers_lib/pipelines.py: -------------------------------------------------------------------------------- 1 | from transformers import pipeline, AutoImageProcessor, SegformerForSemanticSegmentation, UperNetForSemanticSegmentation 2 | import torch 3 | import PIL 4 | import numpy as np 5 | from ..image_utils import HWC3, resize_image, ade_palette 6 | import cv2 7 | 8 | 9 | class ZoeDepth: 10 | def __init__(self): 11 | self.model = pipeline(task="depth-estimation", model="Intel/zoedepth-nyu-kitti", device=-1) 12 | 13 | @torch.inference_mode() 14 | def __call__(self, image: np.ndarray, **kwargs) -> PIL.Image.Image: 15 | detect_resolution = kwargs.pop("detect_resolution", 512) 16 | image_resolution = kwargs.pop("image_resolution", 512) 17 | image = HWC3(image) 18 | image = resize_image(image, resolution=detect_resolution) 19 | image = PIL.Image.fromarray(image) 20 | 21 | result = self.model(image) 22 | depth = result["depth"] 23 | 24 | depth_array = np.array(depth) 25 | depth_inverted = np.max(depth_array) - depth_array 26 | depth_inverted = HWC3(depth_inverted.astype(np.uint8)) 27 | 28 | resize_result = resize_image( 29 | depth_inverted, resolution=image_resolution, interpolation=cv2.INTER_NEAREST 30 | ) 31 | 32 | return PIL.Image.fromarray(resize_result) 33 | 34 | def to(self, device): 35 | self.model.device = torch.device(device) 36 | self.model.model.to(device) 37 | 38 | 39 | class DPTDepthEstimator: 40 | def __init__(self): 41 | self.model = pipeline("depth-estimation", device=-1) 42 | 43 | def __call__(self, image: np.ndarray, **kwargs) -> PIL.Image.Image: 44 | detect_resolution = kwargs.pop("detect_resolution", 512) 45 | image_resolution = kwargs.pop("image_resolution", 512) 46 | image = np.array(image) 47 | image = HWC3(image) 48 | image = resize_image(image, resolution=detect_resolution) 49 | image = PIL.Image.fromarray(image) 50 | image = self.model(image) 51 | image = image["depth"] 52 | image = np.array(image) 53 | image = HWC3(image) 54 | image = resize_image(image, resolution=image_resolution) 55 | return PIL.Image.fromarray(image) 56 | 57 | def to(self, device): 58 | self.model.model.to(device) 59 | self.model.device = torch.device(device) 60 | 61 | 62 | class UP_ImageSegmentor: 63 | def __init__(self): 64 | self.image_processor = AutoImageProcessor.from_pretrained( 65 | "openmmlab/upernet-convnext-small" 66 | ) 67 | self.image_segmentor = UperNetForSemanticSegmentation.from_pretrained( 68 | "openmmlab/upernet-convnext-small" 69 | ) 70 | 71 | @torch.inference_mode() 72 | def __call__(self, image: np.ndarray, **kwargs) -> PIL.Image.Image: 73 | detect_resolution = kwargs.pop("detect_resolution", 512) 74 | image_resolution = kwargs.pop("image_resolution", 512) 75 | image = HWC3(image) 76 | image = resize_image(image, resolution=detect_resolution) 77 | image = PIL.Image.fromarray(image) 78 | 79 | pixel_values = self.image_processor(image, return_tensors="pt").pixel_values 80 | outputs = self.image_segmentor(pixel_values) 81 | seg = self.image_processor.post_process_semantic_segmentation( 82 | outputs, 
target_sizes=[image.size[::-1]] 83 | )[0] 84 | color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) 85 | for label, color in enumerate(ade_palette()): 86 | color_seg[seg == label, :] = color 87 | color_seg = color_seg.astype(np.uint8) 88 | 89 | color_seg = resize_image( 90 | color_seg, resolution=image_resolution, interpolation=cv2.INTER_NEAREST 91 | ) 92 | return PIL.Image.fromarray(color_seg) 93 | 94 | 95 | class SegFormer: 96 | def __init__(self): 97 | self.image_processor = AutoImageProcessor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512") 98 | self.image_segmentor = SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512") 99 | 100 | @torch.inference_mode() 101 | def __call__(self, image: np.ndarray, **kwargs) -> PIL.Image.Image: 102 | detect_resolution = kwargs.pop("detect_resolution", 512) 103 | image_resolution = kwargs.pop("image_resolution", 512) 104 | image = HWC3(image) 105 | image = resize_image(image, resolution=detect_resolution) 106 | image = PIL.Image.fromarray(image) 107 | 108 | pixel_values = self.image_processor(image, return_tensors="pt").pixel_values.to(self.image_segmentor.device.type) 109 | outputs = self.image_segmentor(pixel_values) 110 | seg = self.image_processor.post_process_semantic_segmentation( 111 | outputs, target_sizes=[image.size[::-1]] 112 | )[0].cpu() 113 | color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) 114 | for label, color in enumerate(ade_palette()): 115 | color_seg[seg == label, :] = color 116 | color_seg = color_seg.astype(np.uint8) 117 | 118 | color_seg = resize_image( 119 | color_seg, resolution=image_resolution, interpolation=cv2.INTER_NEAREST 120 | ) 121 | return PIL.Image.fromarray(color_seg) 122 | 123 | def to(self, device): 124 | self.image_segmentor.to(device) 125 | 126 | 127 | class DepthAnything: 128 | def __init__(self): 129 | from transformers import AutoImageProcessor, AutoModelForDepthEstimation 130 | 131 | self.image_processor = AutoImageProcessor.from_pretrained("LiheYoung/depth-anything-small-hf") 132 | self.model = AutoModelForDepthEstimation.from_pretrained("LiheYoung/depth-anything-small-hf") 133 | 134 | @torch.inference_mode() 135 | def __call__(self, image: np.ndarray, **kwargs) -> PIL.Image.Image: 136 | detect_resolution = kwargs.pop("detect_resolution", 512) 137 | image_resolution = kwargs.pop("image_resolution", 512) 138 | image = HWC3(image) 139 | image = resize_image(image, resolution=detect_resolution) 140 | image = PIL.Image.fromarray(image) 141 | 142 | inputs = self.image_processor(images=image, return_tensors="pt").to(self.model.device.type) 143 | 144 | with torch.no_grad(): 145 | outputs = self.model(**inputs) 146 | 147 | post_processed_output = self.image_processor.post_process_depth_estimation( 148 | outputs, 149 | target_sizes=[(image.height, image.width)], 150 | ) 151 | 152 | predicted_depth = post_processed_output[0]["predicted_depth"] 153 | depth = predicted_depth * 255 / predicted_depth.max() 154 | depth = depth.detach().cpu().numpy() 155 | 156 | depth = HWC3(depth.astype(np.uint8)) 157 | resize_result = resize_image( 158 | depth, resolution=image_resolution, interpolation=cv2.INTER_NEAREST 159 | ) 160 | 161 | return PIL.Image.fromarray(resize_result) 162 | 163 | def to(self, device): 164 | self.model.to(device) 165 | -------------------------------------------------------------------------------- /stablepy/diffusers_vanilla/high_resolution.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import gc 3 | from diffusers import DDIMScheduler 4 | from diffusers.image_processor import VaeImageProcessor 5 | 6 | # from ..upscalers.esrgan import UpscalerESRGAN, UpscalerLanczos, UpscalerNearest 7 | from ..upscalers.main_upscaler import load_upscaler_model, BUILTIN_UPSCALERS 8 | from ..logging.logging_setup import logger 9 | 10 | latent_upscale_modes = { 11 | "Latent": {"mode": "bilinear", "antialias": False}, 12 | "Latent (antialiased)": {"mode": "bilinear", "antialias": True}, 13 | "Latent (bicubic)": {"mode": "bicubic", "antialias": False}, 14 | "Latent (bicubic antialiased)": {"mode": "bicubic", "antialias": True}, 15 | "Latent (nearest)": {"mode": "nearest", "antialias": False}, 16 | "Latent (nearest-exact)": {"mode": "nearest-exact", "antialias": False}, 17 | } 18 | 19 | LATENT_UPSCALERS = list(latent_upscale_modes.keys()) 20 | ALL_BUILTIN_UPSCALERS = BUILTIN_UPSCALERS[:2] + LATENT_UPSCALERS + BUILTIN_UPSCALERS[2:] 21 | 22 | 23 | def process_images_high_resolution( 24 | images, 25 | upscaler_model_path, upscaler_increases_size, 26 | upscaler_tile_size=None, upscaler_tile_overlap=None, 27 | hires_steps=1, hires_params_config=None, 28 | task_name=None, 29 | generator=None, 30 | hires_pipe=None, 31 | disable_progress_bar=False, 32 | # hires_apply_cn_tile=None, 33 | ): 34 | 35 | def upscale_images(images, upscaler_model_path, upscaler_tile_size, upscaler_tile_overlap): 36 | device_upscaler = "cuda" if torch.cuda.is_available() else "cpu" 37 | 38 | if upscaler_model_path is not None: 39 | # if upscaler_model_path == "Lanczos": 40 | # scaler = UpscalerLanczos() 41 | # elif upscaler_model_path == "Nearest": 42 | # scaler = UpscalerNearest() 43 | # else: 44 | # scaler = UpscalerESRGAN(upscaler_tile_size, upscaler_tile_overlap) 45 | 46 | scaler = load_upscaler_model( 47 | model=upscaler_model_path, 48 | tile=upscaler_tile_size, 49 | tile_overlap=upscaler_tile_overlap, 50 | device=device_upscaler, 51 | half=True if device_upscaler == "cuda" else False, 52 | ) 53 | 54 | # result_scaler = [] 55 | # for img_pre_up in images: 56 | # image_pos_up = scaler.upscale( 57 | # img_pre_up, upscaler_increases_size, upscaler_model_path 58 | # ) 59 | # torch.cuda.empty_cache() 60 | # gc.collect() 61 | # result_scaler.append(image_pos_up) 62 | 63 | result_scaler = [] 64 | for img_pre_up in images: 65 | try: 66 | image_pos_up = scaler.upscale( 67 | img_pre_up, upscaler_increases_size, disable_progress_bar 68 | ) 69 | except Exception: 70 | logger.error("Upscaler switching to basic 'Nearest'", exc_info=True) 71 | scaler = load_upscaler_model( 72 | model="Nearest", 73 | tile=upscaler_tile_size, 74 | tile_overlap=upscaler_tile_overlap, 75 | ) 76 | image_pos_up = scaler.upscale( 77 | img_pre_up, upscaler_increases_size 78 | ) 79 | result_scaler.append(image_pos_up) 80 | 81 | images = result_scaler 82 | logger.info(f"Upscale resolution: {images[0].size[0]}x{images[0].size[1]}") 83 | 84 | return images 85 | 86 | def hires_fix(images): 87 | if str(hires_pipe.__class__.__name__) in ["FluxImg2ImgPipeline", "FluxInpaintPipeline"]: 88 | hires_params_config["height"] = images[0].size[1] 89 | hires_params_config["width"] = images[0].size[0] 90 | 91 | result_hires = [] 92 | for img_pre_hires in images: 93 | try: 94 | img_pos_hires = hires_pipe( 95 | generator=generator, 96 | image=img_pre_hires, 97 | **hires_params_config, 98 | ).images[0] 99 | except Exception as e: 100 | e = str(e) 101 | if "Tensor with 2 elements 
cannot be converted to Scalar" in e: 102 | logger.debug(e) 103 | logger.error("Error in sampler; trying with DDIM sampler") 104 | hires_pipe.scheduler = DDIMScheduler.from_config(hires_pipe.scheduler.config) 105 | img_pos_hires = hires_pipe( 106 | generator=generator, 107 | image=img_pre_hires, 108 | **hires_params_config, 109 | ).images[0] 110 | elif "The size of tensor a (0) must match the size of tensor b (3) at non-singleton" in e or "cannot reshape tensor of 0 elements into shape [0, -1, 1, 512] because the unspecified dimensi" in e: 111 | logger.error("Strength or steps too low for the model to produce a satisfactory response, returning image only with upscaling.") 112 | img_pos_hires = img_pre_hires 113 | else: 114 | logger.error(e) 115 | logger.error("The hiresfix couldn't be applied, returning image only with upscaling.") 116 | img_pos_hires = img_pre_hires 117 | torch.cuda.empty_cache() 118 | gc.collect() 119 | result_hires.append(img_pos_hires) 120 | images = result_hires 121 | 122 | return images 123 | 124 | if upscaler_model_path in LATENT_UPSCALERS: 125 | 126 | image_processor = VaeImageProcessor() 127 | images_conversion = [] 128 | for img_base in images: 129 | if not isinstance(img_base, torch.Tensor): 130 | prep_image = image_processor.preprocess(img_base) 131 | prep_image = prep_image.to(device=hires_pipe.vae.device.type, dtype=hires_pipe.vae.dtype) 132 | 133 | with torch.no_grad(): 134 | img_base = hires_pipe.vae.encode(prep_image).latent_dist.sample() 135 | 136 | img_base = hires_pipe.vae.config.scaling_factor * img_base 137 | 138 | images_conversion.append(img_base) 139 | 140 | config_latent = latent_upscale_modes[upscaler_model_path] 141 | 142 | logger.debug(str(images_conversion[0].shape)) 143 | 144 | images = [ 145 | torch.nn.functional.interpolate( 146 | im_l, 147 | size=( 148 | int(images_conversion[0].shape[2] * upscaler_increases_size), # maybe round instead of int 149 | int(images_conversion[0].shape[3] * upscaler_increases_size), 150 | ), 151 | mode=config_latent["mode"], 152 | antialias=config_latent["antialias"], 153 | ) for im_l in images_conversion 154 | ] 155 | 156 | logger.debug(str(images[0].shape)) 157 | logger.info( 158 | "Latent resolution: " 159 | f"{images[0].shape[3] * 8}x{images[0].shape[2] * 8}" 160 | ) 161 | 162 | torch.cuda.empty_cache() 163 | 164 | else: 165 | images = upscale_images( 166 | images, upscaler_model_path, upscaler_tile_size, upscaler_tile_overlap 167 | ) 168 | 169 | if hires_steps > 1: 170 | images = hires_fix(images) 171 | 172 | return images 173 | -------------------------------------------------------------------------------- /stablepy/face_restoration/face_restoration_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import os 3 | from functools import cached_property 4 | from typing import TYPE_CHECKING, Callable 5 | 6 | import cv2 7 | import numpy as np 8 | import torch 9 | 10 | from ..diffusers_vanilla.utils import release_resources 11 | from ..logging.logging_setup import logger 12 | 13 | if TYPE_CHECKING: 14 | from facexlib.utils.face_restoration_helper import FaceRestoreHelper 15 | 16 | 17 | def bgr_image_to_rgb_tensor(img: np.ndarray) -> torch.Tensor: 18 | """Convert a BGR NumPy image in [0..1] range to a PyTorch RGB float32 tensor.""" 19 | assert img.shape[2] == 3, "image must be RGB" 20 | if img.dtype == "float64": 21 | img = img.astype("float32") 22 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 23 | return 
torch.from_numpy(img.transpose(2, 0, 1)).float() 24 | 25 | 26 | def rgb_tensor_to_bgr_image(tensor: torch.Tensor, *, min_max=(0.0, 1.0)) -> np.ndarray: 27 | """ 28 | Convert a PyTorch RGB tensor in range `min_max` to a BGR NumPy image in [0..1] range. 29 | """ 30 | tensor = tensor.squeeze(0).float().detach().cpu().clamp_(*min_max) 31 | tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) 32 | assert tensor.dim() == 3, "tensor must be RGB" 33 | img_np = tensor.numpy().transpose(1, 2, 0) 34 | if img_np.shape[2] == 1: # gray image, no RGB/BGR required 35 | return np.squeeze(img_np, axis=2) 36 | return cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB) 37 | 38 | 39 | def create_face_helper(device) -> FaceRestoreHelper: 40 | from facexlib.detection import retinaface 41 | from facexlib.utils.face_restoration_helper import FaceRestoreHelper 42 | if hasattr(retinaface, 'device'): 43 | retinaface.device = device 44 | return FaceRestoreHelper( 45 | upscale_factor=1, 46 | face_size=512, 47 | crop_ratio=(1, 1), 48 | det_model='retinaface_resnet50', 49 | save_ext='png', 50 | use_parse=True, 51 | device=device, 52 | ) 53 | 54 | 55 | def restore_with_face_helper( 56 | np_image: np.ndarray, 57 | face_helper: FaceRestoreHelper, 58 | restore_face: Callable[[torch.Tensor], torch.Tensor], 59 | device, 60 | ) -> np.ndarray: 61 | """ 62 | Find faces in the image using face_helper, restore them using restore_face, and paste them back into the image. 63 | 64 | `restore_face` should take a cropped face image and return a restored face image. 65 | """ 66 | from torchvision.transforms.functional import normalize 67 | np_image = np_image[:, :, ::-1] 68 | original_resolution = np_image.shape[0:2] 69 | 70 | try: 71 | logger.debug("Detecting faces...") 72 | face_helper.clean_all() 73 | face_helper.read_image(np_image) 74 | face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5) 75 | face_helper.align_warp_face() 76 | logger.debug("Found %d faces, restoring", len(face_helper.cropped_faces)) 77 | for cropped_face in face_helper.cropped_faces: 78 | cropped_face_t = bgr_image_to_rgb_tensor(cropped_face / 255.0) 79 | normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) 80 | cropped_face_t = cropped_face_t.unsqueeze(0).to(device) 81 | 82 | try: 83 | with torch.no_grad(): 84 | cropped_face_t = restore_face(cropped_face_t) 85 | release_resources() 86 | except Exception: 87 | logger.error('Failed face-restoration inference', exc_info=True) 88 | 89 | restored_face = rgb_tensor_to_bgr_image(cropped_face_t, min_max=(-1, 1)) 90 | restored_face = (restored_face * 255.0).astype('uint8') 91 | face_helper.add_restored_face(restored_face) 92 | 93 | logger.debug("Merging restored faces into image") 94 | face_helper.get_inverse_affine(None) 95 | img = face_helper.paste_faces_to_input_image() 96 | img = img[:, :, ::-1] 97 | if original_resolution != img.shape[0:2]: 98 | img = cv2.resize( 99 | img, 100 | (0, 0), 101 | fx=original_resolution[1] / img.shape[1], 102 | fy=original_resolution[0] / img.shape[0], 103 | interpolation=cv2.INTER_LINEAR, 104 | ) 105 | logger.debug("Face restoration complete") 106 | finally: 107 | face_helper.clean_all() 108 | return img 109 | 110 | 111 | class FaceRestoration: 112 | def name(self): 113 | return "None" 114 | 115 | def restore(self, np_image): 116 | return np_image 117 | 118 | 119 | class CommonFaceRestoration(FaceRestoration): 120 | net: torch.Module | None 121 | model_url: str 122 | model_download_name: str 123 | 124 | def __init__(self, device): 125 | 
super().__init__() 126 | self.device = device 127 | self.net = None 128 | self.model_path = os.path.join(os.path.expanduser("~"), ".cache", "face_restoration_models") 129 | os.makedirs(self.model_path, exist_ok=True) 130 | 131 | @cached_property 132 | def face_helper(self) -> FaceRestoreHelper: 133 | return create_face_helper(self.get_device()) 134 | 135 | def send_model_to(self, device): 136 | if self.net: 137 | logger.debug("Sending %s to %s", self.net.__class__.__name__, device) 138 | self.net.to(device) 139 | if self.face_helper: 140 | logger.debug("Sending face helper to %s", device) 141 | self.face_helper.face_det.to(device) 142 | self.face_helper.face_parse.to(device) 143 | 144 | def get_device(self): 145 | raise NotImplementedError("get_device must be implemented by subclasses") 146 | 147 | def load_net(self) -> torch.Module: 148 | raise NotImplementedError("load_net must be implemented by subclasses") 149 | 150 | def restore_with_helper( 151 | self, 152 | np_image: np.ndarray, 153 | restore_face: Callable[[torch.Tensor], torch.Tensor], 154 | ) -> np.ndarray: 155 | try: 156 | if self.net is None: 157 | self.net = self.load_net() 158 | except Exception: 159 | logger.warning("Unable to load face-restoration model", exc_info=True) 160 | return np_image 161 | 162 | try: 163 | release_resources() 164 | self.send_model_to(self.get_device()) 165 | return restore_with_face_helper(np_image, self.face_helper, restore_face, self.get_device()) 166 | finally: 167 | # self.send_model_to("cpu") 168 | pass 169 | 170 | 171 | def patch_facexlib(dirname: str) -> None: 172 | import facexlib.detection 173 | import facexlib.parsing 174 | 175 | det_facex_load_file_from_url = facexlib.detection.load_file_from_url 176 | par_facex_load_file_from_url = facexlib.parsing.load_file_from_url 177 | 178 | def update_kwargs(kwargs): 179 | return dict(kwargs, save_dir=dirname, model_dir=None) 180 | 181 | def facex_load_file_from_url(**kwargs): 182 | return det_facex_load_file_from_url(**update_kwargs(kwargs)) 183 | 184 | def facex_load_file_from_url2(**kwargs): 185 | return par_facex_load_file_from_url(**update_kwargs(kwargs)) 186 | 187 | facexlib.detection.load_file_from_url = facex_load_file_from_url 188 | facexlib.parsing.load_file_from_url = facex_load_file_from_url2 189 | -------------------------------------------------------------------------------- /stablepy/diffusers_vanilla/prompt_weights.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import re 3 | 4 | ESCAPED_SYNTACTIC_SYMBOLS = [ 5 | '"', 6 | '(', 7 | ')', 8 | '=', 9 | # '-', 10 | # '+', 11 | # '.', 12 | # ',', 13 | ] 14 | 15 | TRANSLATION_DICT = { 16 | ord(symbol): "\\" + symbol for symbol in ESCAPED_SYNTACTIC_SYMBOLS 17 | } 18 | 19 | 20 | def parse_prompt_attention(text): 21 | re_attention = re.compile(r""" 22 | \\\(| 23 | \\\)| 24 | \\\[| 25 | \\]| 26 | \\\\| 27 | \\| 28 | \(| 29 | \[| 30 | :([+-]?[.\d]+)\)| 31 | \)| 32 | ]| 33 | [^\\()\[\]:]+| 34 | : 35 | """, re.X) 36 | 37 | res = [] 38 | round_brackets = [] 39 | square_brackets = [] 40 | 41 | round_bracket_multiplier = 1.1 42 | square_bracket_multiplier = 1 / 1.1 43 | 44 | def multiply_range(start_position, multiplier): 45 | for p in range(start_position, len(res)): 46 | res[p][1] *= multiplier 47 | 48 | for m in re_attention.finditer(text): 49 | text = m.group(0) 50 | weight = m.group(1) 51 | 52 | if text.startswith('\\'): 53 | res.append([text[1:], 1.0]) 54 | elif text == '(': 55 | round_brackets.append(len(res)) 56 | elif text == '[': 57 
| square_brackets.append(len(res)) 58 | elif weight is not None and len(round_brackets) > 0: 59 | multiply_range(round_brackets.pop(), float(weight)) 60 | elif text == ')' and len(round_brackets) > 0: 61 | multiply_range(round_brackets.pop(), round_bracket_multiplier) 62 | elif text == ']' and len(square_brackets) > 0: 63 | multiply_range(square_brackets.pop(), square_bracket_multiplier) 64 | else: 65 | parts = re.split(re.compile(r"\s*\bBREAK\b\s*", re.S), text) 66 | for i, part in enumerate(parts): 67 | if i > 0: 68 | res.append(["BREAK", -1]) 69 | res.append([part, 1.0]) 70 | 71 | for pos in round_brackets: 72 | multiply_range(pos, round_bracket_multiplier) 73 | 74 | for pos in square_brackets: 75 | multiply_range(pos, square_bracket_multiplier) 76 | 77 | if len(res) == 0: 78 | res = [["", 1.0]] 79 | 80 | # merge runs of identical weights 81 | i = 0 82 | while i + 1 < len(res): 83 | if res[i][1] == res[i + 1][1]: 84 | res[i][0] += res[i + 1][0] 85 | res.pop(i + 1) 86 | else: 87 | i += 1 88 | 89 | return res 90 | 91 | 92 | def prompt_attention_to_invoke_prompt(attention): 93 | tokens = [] 94 | for text, weight in attention: 95 | text = text.translate(TRANSLATION_DICT) 96 | 97 | # Round weight to 2 decimal places 98 | weight = round(weight, 2) 99 | if weight == 1.0: 100 | tokens.append(text) 101 | elif weight < 1.0: 102 | if weight < 0.8: 103 | tokens.append(f"({text}){weight}") 104 | else: 105 | tokens.append(f"({text})-" + "-" * int((1.0 - weight) * 10)) 106 | else: 107 | if weight < 1.3: 108 | tokens.append(f"({text})" + "+" * int((weight - 1.0) * 10)) 109 | else: 110 | tokens.append(f"({text}){weight}") 111 | return "".join(tokens) 112 | 113 | 114 | def concat_tensor(t): 115 | t_list = torch.split(t, 1, dim=0) 116 | t = torch.cat(t_list, dim=1) 117 | return t 118 | 119 | 120 | def merge_embeds(prompt_chanks, compel): 121 | num_chanks = len(prompt_chanks) 122 | if num_chanks != 0: 123 | power_prompt = 1/(num_chanks*(num_chanks+1)//2) 124 | prompt_embs = compel(prompt_chanks) 125 | t_list = list(torch.split(prompt_embs, 1, dim=0)) 126 | for i in range(num_chanks): 127 | t_list[-(i+1)] = t_list[-(i+1)] * ((i+1)*power_prompt) 128 | prompt_emb = torch.stack(t_list, dim=0).sum(dim=0) 129 | else: 130 | prompt_emb = compel('') 131 | return prompt_emb 132 | 133 | 134 | def detokenize(chunk, actual_prompt): 135 | chunk[-1] = chunk[-1].replace('', '') 136 | chanked_prompt = ''.join(chunk).strip() 137 | while '' in chanked_prompt: 138 | if actual_prompt[chanked_prompt.find('')] == ' ': 139 | chanked_prompt = chanked_prompt.replace('', ' ', 1) 140 | else: 141 | chanked_prompt = chanked_prompt.replace('', '', 1) 142 | actual_prompt = actual_prompt.replace(chanked_prompt, '') 143 | return chanked_prompt.strip(), actual_prompt.strip() 144 | 145 | 146 | def tokenize_line(line, tokenizer): # split into chunks 147 | actual_prompt = line.lower().strip() 148 | actual_tokens = tokenizer.tokenize(actual_prompt) 149 | max_tokens = tokenizer.model_max_length - 2 150 | comma_token = tokenizer.tokenize(',')[0] 151 | 152 | chunks = [] 153 | chunk = [] 154 | for item in actual_tokens: 155 | chunk.append(item) 156 | if len(chunk) == max_tokens: 157 | if chunk[-1] != comma_token: 158 | for i in range(max_tokens-1, -1, -1): 159 | if chunk[i] == comma_token: 160 | actual_chunk, actual_prompt = detokenize(chunk[:i+1], actual_prompt) 161 | chunks.append(actual_chunk) 162 | chunk = chunk[i+1:] 163 | break 164 | else: 165 | actual_chunk, actual_prompt = detokenize(chunk, actual_prompt) 166 | chunks.append(actual_chunk) 
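                        # Descriptive note (editorial): this for-else branch runs only when no
                        # comma token was found anywhere in the current max_tokens window, so the
                        # whole window was just flushed as a single chunk; a fresh chunk starts below.
                        # Illustrative, hypothetical prompt: "red fox, detailed fur, forest" would
                        # normally be cut at the last comma that still fits in the window, and only a
                        # comma-free run of max_tokens tokens falls through to this branch.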
167 | chunk = [] 168 | else: 169 | actual_chunk, actual_prompt = detokenize(chunk, actual_prompt) 170 | chunks.append(actual_chunk) 171 | chunk = [] 172 | if chunk: 173 | actual_chunk, _ = detokenize(chunk, actual_prompt) 174 | chunks.append(actual_chunk) 175 | 176 | return chunks 177 | 178 | 179 | def get_embed_new(prompt, pipeline, compel, only_convert_string=False, compel_process_sd=False): 180 | 181 | if compel_process_sd: 182 | return merge_embeds(tokenize_line(prompt, pipeline.tokenizer), compel) 183 | else: 184 | # fix bug weights conversion excessive emphasis 185 | prompt = prompt.replace("((", "(").replace("))", ")") 186 | 187 | # Convert to Compel 188 | attention = parse_prompt_attention(prompt) 189 | global_attention_chanks = [] 190 | 191 | for att in attention: 192 | for chank in att[0].split(','): 193 | temp_prompt_chanks = tokenize_line(chank, pipeline.tokenizer) 194 | for small_chank in temp_prompt_chanks: 195 | temp_dict = { 196 | "weight": round(att[1], 2), 197 | "lenght": len(pipeline.tokenizer.tokenize(f'{small_chank},')), 198 | "prompt": f'{small_chank},' 199 | } 200 | global_attention_chanks.append(temp_dict) 201 | 202 | max_tokens = pipeline.tokenizer.model_max_length - 2 203 | global_prompt_chanks = [] 204 | current_list = [] 205 | current_length = 0 206 | for item in global_attention_chanks: 207 | if current_length + item['lenght'] > max_tokens: 208 | global_prompt_chanks.append(current_list) 209 | current_list = [[item['prompt'], item['weight']]] 210 | current_length = item['lenght'] 211 | else: 212 | if not current_list: 213 | current_list.append([item['prompt'], item['weight']]) 214 | else: 215 | if item['weight'] != current_list[-1][1]: 216 | current_list.append([item['prompt'], item['weight']]) 217 | else: 218 | current_list[-1][0] += f" {item['prompt']}" 219 | current_length += item['lenght'] 220 | if current_list: 221 | global_prompt_chanks.append(current_list) 222 | 223 | if only_convert_string: 224 | return ' '.join([prompt_attention_to_invoke_prompt(i) for i in global_prompt_chanks]) 225 | 226 | return merge_embeds([prompt_attention_to_invoke_prompt(i) for i in global_prompt_chanks], compel) 227 | 228 | 229 | def add_comma_after_pattern_ti(text): 230 | pattern = re.compile(r'\b\w+_\d+\b') 231 | modified_text = pattern.sub(lambda x: x.group() + ',', text) 232 | return modified_text 233 | -------------------------------------------------------------------------------- /stablepy/diffusers_vanilla/sampler_scheduler_config.py: -------------------------------------------------------------------------------- 1 | from diffusers import EulerDiscreteScheduler 2 | import json 3 | from huggingface_hub import hf_hub_download 4 | from ..logging.logging_setup import logger 5 | from .constants import ( 6 | SD15, 7 | SDXL, 8 | FLUX, 9 | INCOMPATIBILITY_SAMPLER_SCHEDULE, 10 | SCHEDULE_TYPES, 11 | SCHEDULE_TYPE_OPTIONS, 12 | SCHEDULE_PREDICTION_TYPE, 13 | AYS_SCHEDULES, 14 | FLUX_SCHEDULE_TYPES, 15 | FLUX_SCHEDULE_TYPE_OPTIONS, 16 | ) 17 | import numpy as np 18 | 19 | 20 | def configure_scheduler(pipe, schedule_type, schedule_prediction_type): 21 | 22 | if "Flux" in str(pipe.__class__.__name__): 23 | 24 | flux_selected_schedule = FLUX_SCHEDULE_TYPES.get(schedule_type, None) 25 | if flux_selected_schedule: 26 | pipe.scheduler.register_to_config(**flux_selected_schedule) 27 | 28 | return None 29 | 30 | # Get the configuration for the selected schedule 31 | selected_schedule = SCHEDULE_TYPES.get(schedule_type) 32 | 33 | if selected_schedule: 34 | # Set all schedule types to 
False first 35 | default_config = { 36 | "use_karras_sigmas": False, 37 | "use_exponential_sigmas": False, 38 | "use_beta_sigmas": False, 39 | } 40 | pipe.scheduler.register_to_config(**default_config) 41 | 42 | # Apply the specific configuration for the selected schedule 43 | pipe.scheduler.register_to_config(**selected_schedule) 44 | 45 | # Get the configuration for the selected prediction type 46 | selected_prediction = SCHEDULE_PREDICTION_TYPE.get( 47 | schedule_prediction_type 48 | ) 49 | 50 | if selected_prediction: 51 | # Update the prediction type in the scheduler's config 52 | if isinstance(selected_prediction, dict): 53 | pipe.scheduler.register_to_config(**selected_prediction) 54 | else: 55 | pipe.scheduler.register_to_config( 56 | prediction_type=selected_prediction 57 | ) 58 | 59 | if ( 60 | hasattr(pipe.scheduler.config, "prediction_type") 61 | and pipe.scheduler.config.prediction_type == "v_prediction" 62 | ): 63 | pipe.scheduler.register_to_config( 64 | rescale_betas_zero_snr=True, 65 | ) 66 | 67 | 68 | def verify_schedule_integrity(model_scheduler, base_model_id): 69 | # noobai v-pred repo id 70 | if base_model_id.lower().startswith("laxhar/noobai-xl-vpred-"): 71 | model_scheduler.register_to_config( 72 | prediction_type="v_prediction", 73 | ) 74 | 75 | if ( 76 | hasattr(model_scheduler.config, "prediction_type") 77 | and model_scheduler.config.prediction_type == "v_prediction" 78 | ): 79 | model_scheduler.register_to_config( 80 | rescale_betas_zero_snr=True, 81 | ) 82 | 83 | if not hasattr(model_scheduler.config, "algorithm_type"): 84 | return model_scheduler 85 | 86 | logger.debug("Resetting scheduler settings") 87 | 88 | scheduler_xl = hf_hub_download( 89 | repo_id="stabilityai/stable-diffusion-xl-base-1.0", 90 | filename="scheduler/scheduler_config.json" 91 | ) 92 | with open(scheduler_xl, 'r', encoding="utf-8") as file: 93 | params_ = json.load(file) 94 | 95 | original_scheduler = EulerDiscreteScheduler.from_config(params_) 96 | 97 | model_params_ = dict(model_scheduler.config.items()) 98 | original_params_ = dict(original_scheduler.config.items()) 99 | 100 | new_value_params = {} 101 | for k, v in model_params_.items(): 102 | if not k.startswith("_") and k in original_params_: 103 | new_value_params[k] = v 104 | 105 | logger.debug( 106 | "The next configurations are loaded" 107 | f" from the repo model scheduler: {(new_value_params)}" 108 | ) 109 | 110 | original_scheduler.register_to_config( 111 | **new_value_params 112 | ) 113 | 114 | if ( 115 | hasattr(original_scheduler.config, "prediction_type") 116 | and original_scheduler.config.prediction_type == "v_prediction" 117 | ): 118 | original_scheduler.register_to_config( 119 | rescale_betas_zero_snr=True, 120 | ) 121 | 122 | return original_scheduler 123 | 124 | 125 | def check_scheduler_compatibility(cls, sampler, schedule_type): 126 | msg = "" 127 | auto_schedule = SCHEDULE_TYPE_OPTIONS[0] 128 | 129 | for old_sampler_config, def_schedule in [("Karras", "Karras"), ("trailing", "SGM Uniform"), (" Lu", "Lambdas")]: 130 | if old_sampler_config in sampler: 131 | sampler = sampler.replace(old_sampler_config, "").strip() 132 | if schedule_type == auto_schedule: 133 | schedule_type = def_schedule 134 | 135 | if cls == FLUX: 136 | if "Flow" not in sampler: 137 | sampler = "FlowMatch DPM++ 2M" 138 | msg += ( 139 | "The selected sampler does not work with FLUX models;" 140 | f" so it has been switched to {sampler}. 
" 141 | ) 142 | 143 | valid_schedule = FLUX_SCHEDULE_TYPES.get(schedule_type, None) 144 | 145 | if schedule_type != auto_schedule: 146 | if sampler == "FlowMatch Euler": 147 | msg += ( 148 | "FlowMatch Euler only support" 149 | f" '{auto_schedule}' schedule type." 150 | ) 151 | schedule_type = auto_schedule 152 | elif not valid_schedule: 153 | msg += ( 154 | f"The sampler: {sampler} only support schedule types" 155 | f": {', '.join(FLUX_SCHEDULE_TYPE_OPTIONS)}" 156 | f". Changed to '{auto_schedule}'." 157 | ) 158 | schedule_type = auto_schedule 159 | 160 | return sampler, schedule_type, msg 161 | 162 | if "Flow" in sampler: 163 | sampler = sampler.replace("FlowMatch ", "") 164 | msg += ( 165 | "The selected sampler works only with FLUX models;" 166 | f" so it has been switched to {sampler}. " 167 | ) 168 | 169 | incompatible_schedule = INCOMPATIBILITY_SAMPLER_SCHEDULE.get(sampler, []) 170 | if schedule_type in incompatible_schedule: 171 | COMPATIBLE_SCHEDULES = [ 172 | item for item in SCHEDULE_TYPE_OPTIONS 173 | if item not in incompatible_schedule 174 | ] 175 | 176 | msg += ( 177 | f"The sampler: {sampler} only support schedule types" 178 | f": {', '.join(COMPATIBLE_SCHEDULES)}" 179 | f". Changed to '{auto_schedule}'." 180 | ) 181 | schedule_type = auto_schedule 182 | 183 | return sampler, schedule_type, msg 184 | 185 | 186 | def loglinear_interp(t_steps, num_steps): 187 | """ 188 | Performs log-linear interpolation of a given array of decreasing numbers. 189 | """ 190 | xs = np.linspace(0, 1, len(t_steps)) 191 | ys = np.log(t_steps[::-1]) 192 | 193 | new_xs = np.linspace(0, 1, num_steps) 194 | new_ys = np.interp(new_xs, xs, ys) 195 | 196 | interped_ys = np.exp(new_ys)[::-1].copy() 197 | return interped_ys 198 | 199 | 200 | def ays_timesteps(cls, schedule, num_steps): 201 | if schedule not in AYS_SCHEDULES: 202 | return {} 203 | 204 | list_steps = AYS_SCHEDULES[schedule] 205 | 206 | if cls == SD15: 207 | steps = list_steps[0] 208 | elif cls == SDXL: 209 | steps = list_steps[1] 210 | else: 211 | raise ValueError( 212 | f"The pipeline {cls} does not support AYS scheduling." 
213 | ) 214 | 215 | key_param = "sigmas" if "sigmas" in schedule else "timesteps" 216 | 217 | if schedule == "AYS timesteps": 218 | t_steps = loglinear_interp(steps, num_steps) 219 | # steps = t_steps[t_steps != 0] 220 | # t_steps = np.append(t_steps, 0) 221 | steps = np.round(t_steps).astype(int).tolist() 222 | elif schedule == "AYS sigmas": 223 | t_steps = loglinear_interp(steps, num_steps) 224 | t_steps[-1] = .0 225 | steps = t_steps.tolist() 226 | 227 | return {key_param: steps} 228 | -------------------------------------------------------------------------------- /stablepy/diffusers_vanilla/lora_loader.py: -------------------------------------------------------------------------------- 1 | # ===================================== 2 | # LoRA Loaders 3 | # ===================================== 4 | import torch 5 | from safetensors.torch import load_file 6 | from collections import defaultdict 7 | from ..logging.logging_setup import logger 8 | import safetensors 9 | import os 10 | import string 11 | import traceback 12 | import logging 13 | 14 | VALID_LORA_LAYERS_SDXL = [ 15 | "input_blocks", 16 | "middle_block", 17 | "output_blocks", 18 | "text_model", 19 | ".down_blocks", 20 | ".mid_block", 21 | ".up_blocks", 22 | # "text_projection", # text encoder 2 layer 23 | # "conv_in", # unet extra layers 24 | # "time_proj", 25 | # "time_embedding", 26 | # "time_embedding.linear_1", 27 | # "time_embedding.act", 28 | # "time_embedding.linear_2", 29 | # "add_time_proj", 30 | # "add_embedding", 31 | # "add_embedding.linear_1", 32 | # "add_embedding.linear_2", 33 | # "conv_norm_out", 34 | # "conv_out" 35 | ] 36 | 37 | 38 | def load_lora_weights(pipeline, checkpoint_path, multiplier, device, dtype): 39 | LORA_PREFIX_UNET = "lora_unet" 40 | LORA_PREFIX_TEXT_ENCODER = "lora_te" 41 | # load LoRA weight from .safetensors 42 | if isinstance(checkpoint_path, str): 43 | checkpoint_path = [checkpoint_path] 44 | for ckptpath in checkpoint_path: 45 | state_dict = load_file(ckptpath, device=device) 46 | 47 | updates = defaultdict(dict) 48 | for key, value in state_dict.items(): 49 | # it is suggested to print out the key, it usually will be something like below 50 | # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight" 51 | 52 | layer, elem = key.split(".", 1) 53 | updates[layer][elem] = value 54 | 55 | # directly update weight in diffusers model 56 | for layer, elems in updates.items(): 57 | if "text" in layer: 58 | layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split( 59 | "_" 60 | ) 61 | curr_layer = pipeline.text_encoder 62 | else: 63 | layer_infos = layer.split(LORA_PREFIX_UNET + "_")[-1].split("_") 64 | curr_layer = pipeline.unet 65 | 66 | # find the target layer 67 | temp_name = layer_infos.pop(0) 68 | while len(layer_infos) > -1: 69 | try: 70 | curr_layer = curr_layer.__getattr__(temp_name) 71 | if len(layer_infos) > 0: 72 | temp_name = layer_infos.pop(0) 73 | elif len(layer_infos) == 0: 74 | break 75 | except Exception: 76 | if len(temp_name) > 0: 77 | temp_name += "_" + layer_infos.pop(0) 78 | else: 79 | temp_name = layer_infos.pop(0) 80 | 81 | # get elements for this layer 82 | weight_up = elems["lora_up.weight"].to(dtype) 83 | weight_down = elems["lora_down.weight"].to(dtype) 84 | alpha = elems["alpha"] 85 | if alpha: 86 | alpha = alpha.item() / weight_up.shape[1] 87 | else: 88 | alpha = 1.0 89 | 90 | # update weight 91 | if len(weight_up.shape) == 4: 92 | curr_layer.weight.data += ( 93 | multiplier 94 | * alpha 95 | * torch.mm( 96 | weight_up.squeeze(3).squeeze(2), 
97 | weight_down.squeeze(3).squeeze(2), 98 | ) 99 | .unsqueeze(2) 100 | .unsqueeze(3) 101 | ) 102 | else: 103 | curr_layer.weight.data += ( 104 | multiplier * alpha * torch.mm(weight_up, weight_down) 105 | ) 106 | 107 | logger.debug(f"Config LoRA: multiplier {multiplier} | alpha {alpha}") 108 | 109 | return pipeline 110 | 111 | 112 | def validate_lora_layers(lora_path): 113 | state_dict = safetensors.torch.load_file(lora_path, device="cpu") 114 | state_dict = { 115 | k: w for k, w in state_dict.items() 116 | if any(ly in k for ly in VALID_LORA_LAYERS_SDXL) 117 | } 118 | 119 | return state_dict 120 | 121 | 122 | def lora_mix_load(pipe, lora_path, alpha_scale=1.0, device="cuda", dtype=torch.float16): 123 | if hasattr(pipe, "text_encoder_2"): 124 | # sdxl lora 125 | try: 126 | pipe.load_lora_weights(lora_path) 127 | pipe.fuse_lora(lora_scale=alpha_scale) 128 | pipe.unload_lora_weights() 129 | except Exception as e: 130 | pipe.unload_lora_weights() 131 | if "size mismatch for" in str(e) or not os.path.exists(lora_path): 132 | raise e 133 | 134 | logger.debug(str(e)) 135 | 136 | state_dict = validate_lora_layers(lora_path) 137 | 138 | if not state_dict: 139 | raise ValueError("No valid lora layers were found.") 140 | 141 | try: 142 | pipe.load_lora_weights(state_dict) 143 | pipe.fuse_lora(lora_scale=alpha_scale) 144 | pipe.unload_lora_weights() 145 | except Exception as e: 146 | pipe.unload_lora_weights() 147 | raise e 148 | else: 149 | # sd lora 150 | try: 151 | pipe = load_lora_weights( 152 | pipe, [lora_path], alpha_scale, device=device, dtype=dtype 153 | ) 154 | except Exception as e: 155 | logger.debug(f"{str(e)} \nDiffusers loader>>") 156 | try: 157 | pipe.load_lora_weights(lora_path) 158 | pipe.fuse_lora(lora_scale=alpha_scale) 159 | pipe.unload_lora_weights() 160 | except Exception as e: 161 | pipe.unload_lora_weights() 162 | raise e 163 | 164 | return pipe 165 | 166 | 167 | def load_no_fused_lora(pipe, num_loras, current_lora_list, current_lora_scale_list): 168 | 169 | lora_status = [None] * num_loras 170 | 171 | logger.debug("Unloading and reloading LoRA weights on the fly") 172 | pipe.unload_lora_weights() 173 | 174 | active_adapters = [] 175 | active_adapters_scales = [] 176 | number_to_value = {i: letter for i, letter in enumerate(string.ascii_uppercase[:num_loras])} 177 | 178 | for i, (lora, scale) in enumerate(zip(current_lora_list, current_lora_scale_list)): 179 | if lora: 180 | try: 181 | adapter_name = number_to_value[i] 182 | pipe.load_lora_weights(lora, adapter_name=adapter_name) 183 | active_adapters.append(adapter_name) 184 | active_adapters_scales.append(scale) 185 | lora_status[i] = True 186 | logger.info(f"Loaded LoRA on the fly: {lora}") 187 | except Exception as e: 188 | lora_status[i] = False 189 | if lora in pipe.get_active_adapters(): 190 | pipe.delete_adapters(lora) 191 | 192 | if "size mismatch for" in str(e) or not os.path.exists(lora): 193 | if logger.isEnabledFor(logging.DEBUG): 194 | traceback.print_exc() 195 | raise RuntimeError(f"ERROR > LoRA not compatible: {lora}") 196 | 197 | state_dict = validate_lora_layers(lora) 198 | 199 | if not state_dict: 200 | logger.debug("No valid LoRA layers were found.") 201 | if logger.isEnabledFor(logging.DEBUG): 202 | traceback.print_exc() 203 | raise RuntimeError(f"ERROR > LoRA not compatible: {lora}") 204 | 205 | try: 206 | pipe.load_lora_weights(state_dict, adapter_name=adapter_name) 207 | active_adapters.append(adapter_name) 208 | active_adapters_scales.append(scale) 209 | logger.info(f"Loaded LoRA on the fly: {lora}") 
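                    # Descriptive note (editorial): reaching this point means the retry with the
                    # state_dict filtered by validate_lora_layers succeeded, so the adapter is
                    # registered and the status flag for this LoRA is set below.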
210 | lora_status[i] = True 211 | except Exception: 212 | if lora in pipe.get_active_adapters(): 213 | pipe.delete_adapters(lora) 214 | if logger.isEnabledFor(logging.DEBUG): 215 | traceback.print_exc() 216 | raise RuntimeError(f"ERROR > LoRA not compatible: {lora}") 217 | 218 | if active_adapters: 219 | pipe.set_adapters(active_adapters, adapter_weights=active_adapters_scales) 220 | 221 | return lora_status # A wrongly loaded LoRA can cause issues. 222 | -------------------------------------------------------------------------------- /stablepy/diffusers_vanilla/extra_scheduler/scheduling_euler_discrete_variants.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, Union 2 | 3 | from diffusers.utils import logging 4 | from diffusers.utils.torch_utils import randn_tensor 5 | import torch 6 | import math 7 | from diffusers.schedulers.scheduling_euler_discrete import ( 8 | EulerDiscreteScheduler, 9 | EulerDiscreteSchedulerOutput, 10 | ) 11 | 12 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 13 | 14 | 15 | class EulerDiscreteSchedulerNegative(EulerDiscreteScheduler): 16 | 17 | def step( 18 | self, 19 | model_output: torch.Tensor, 20 | timestep: Union[float, torch.Tensor], 21 | sample: torch.Tensor, 22 | s_churn: float = 0.0, 23 | s_tmin: float = 0.0, 24 | s_tmax: float = float("inf"), 25 | s_noise: float = 1.0, 26 | generator: Optional[torch.Generator] = None, 27 | return_dict: bool = True, 28 | ) -> Union[EulerDiscreteSchedulerOutput, Tuple]: 29 | """ 30 | Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion 31 | process from the learned model outputs (most often the predicted noise). 32 | 33 | Args: 34 | model_output (`torch.Tensor`): 35 | The direct output from learned diffusion model. 36 | timestep (`float`): 37 | The current discrete timestep in the diffusion chain. 38 | sample (`torch.Tensor`): 39 | A current instance of a sample created by the diffusion process. 40 | s_churn (`float`): 41 | s_tmin (`float`): 42 | s_tmax (`float`): 43 | s_noise (`float`, defaults to 1.0): 44 | Scaling factor for noise added to the sample. 45 | generator (`torch.Generator`, *optional*): 46 | A random number generator. 47 | return_dict (`bool`): 48 | Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or 49 | tuple. 50 | 51 | Returns: 52 | [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`: 53 | If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is 54 | returned, otherwise a tuple is returned where the first element is the sample tensor. 55 | """ 56 | 57 | if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)): 58 | raise ValueError( 59 | ( 60 | "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" 61 | " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" 62 | " one of the `scheduler.timesteps` as a timestep." 63 | ), 64 | ) 65 | 66 | if not self.is_scale_input_called: 67 | logger.warning( 68 | "The `scale_model_input` function should be called before `step` to ensure correct denoising. " 69 | "See `StableDiffusionPipeline` for a usage example." 
70 | ) 71 | 72 | if self.step_index is None: 73 | self._init_step_index(timestep) 74 | 75 | # Upcast to avoid precision issues when computing prev_sample 76 | sample = sample.to(torch.float32) 77 | 78 | sigma = self.sigmas[self.step_index] 79 | 80 | gamma = max(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0 81 | 82 | sigma_hat = sigma * (gamma + 1) 83 | 84 | if gamma > 0: 85 | noise = randn_tensor( 86 | model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator 87 | ) 88 | eps = noise * s_noise 89 | sample = sample - eps * (sigma_hat**2 - sigma**2) ** 0.5 90 | 91 | # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise 92 | # NOTE: "original_sample" should not be an expected prediction_type but is left in for 93 | # backwards compatibility 94 | if self.config.prediction_type == "original_sample" or self.config.prediction_type == "sample": 95 | pred_original_sample = model_output 96 | elif self.config.prediction_type == "epsilon": 97 | pred_original_sample = sample - sigma_hat * model_output 98 | elif self.config.prediction_type == "v_prediction": 99 | # denoised = model_output * c_out + input * c_skip 100 | pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) 101 | else: 102 | raise ValueError( 103 | f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" 104 | ) 105 | 106 | # 2. Convert to an ODE derivative 107 | derivative = (sample - pred_original_sample) / sigma_hat 108 | 109 | dt = self.sigmas[self.step_index + 1] - sigma_hat 110 | 111 | if self.sigmas[self.step_index + 1] > 0 and self.step_index // 2 == 1: 112 | prev_sample = -sample - derivative * dt 113 | else: 114 | prev_sample = sample + derivative * dt 115 | 116 | # Cast sample back to model compatible dtype 117 | prev_sample = prev_sample.to(model_output.dtype) 118 | 119 | # upon completion increase step index by one 120 | self._step_index += 1 121 | 122 | if not return_dict: 123 | return ( 124 | prev_sample, 125 | pred_original_sample, 126 | ) 127 | 128 | return EulerDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) 129 | 130 | 131 | class EulerDiscreteSchedulerMax(EulerDiscreteScheduler): 132 | 133 | def step( 134 | self, 135 | model_output: torch.Tensor, 136 | timestep: Union[float, torch.Tensor], 137 | sample: torch.Tensor, 138 | s_churn: float = 0.0, 139 | s_tmin: float = 0.0, 140 | s_tmax: float = float("inf"), 141 | s_noise: float = 1.0, 142 | generator: Optional[torch.Generator] = None, 143 | return_dict: bool = True, 144 | ) -> Union[EulerDiscreteSchedulerOutput, Tuple]: 145 | """ 146 | Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion 147 | process from the learned model outputs (most often the predicted noise). 148 | 149 | Args: 150 | model_output (`torch.Tensor`): 151 | The direct output from learned diffusion model. 152 | timestep (`float`): 153 | The current discrete timestep in the diffusion chain. 154 | sample (`torch.Tensor`): 155 | A current instance of a sample created by the diffusion process. 156 | s_churn (`float`): 157 | s_tmin (`float`): 158 | s_tmax (`float`): 159 | s_noise (`float`, defaults to 1.0): 160 | Scaling factor for noise added to the sample. 161 | generator (`torch.Generator`, *optional*): 162 | A random number generator. 
163 | return_dict (`bool`): 164 | Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or 165 | tuple. 166 | 167 | Returns: 168 | [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`: 169 | If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is 170 | returned, otherwise a tuple is returned where the first element is the sample tensor. 171 | """ 172 | 173 | if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)): 174 | raise ValueError( 175 | ( 176 | "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" 177 | " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" 178 | " one of the `scheduler.timesteps` as a timestep." 179 | ), 180 | ) 181 | 182 | if not self.is_scale_input_called: 183 | logger.warning( 184 | "The `scale_model_input` function should be called before `step` to ensure correct denoising. " 185 | "See `StableDiffusionPipeline` for a usage example." 186 | ) 187 | 188 | if self.step_index is None: 189 | self._init_step_index(timestep) 190 | 191 | # Upcast to avoid precision issues when computing prev_sample 192 | sample = sample.to(torch.float32) 193 | 194 | sigma = self.sigmas[self.step_index] 195 | 196 | gamma = max(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0 197 | 198 | sigma_hat = sigma * (gamma + 1) 199 | 200 | if gamma > 0: 201 | noise = randn_tensor( 202 | model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator 203 | ) 204 | eps = noise * s_noise 205 | sample = sample - eps * (sigma_hat**2 - sigma**2) ** 0.5 206 | 207 | # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise 208 | # NOTE: "original_sample" should not be an expected prediction_type but is left in for 209 | # backwards compatibility 210 | if self.config.prediction_type == "original_sample" or self.config.prediction_type == "sample": 211 | pred_original_sample = model_output 212 | elif self.config.prediction_type == "epsilon": 213 | pred_original_sample = sample - sigma_hat * model_output 214 | elif self.config.prediction_type == "v_prediction": 215 | # denoised = model_output * c_out + input * c_skip 216 | pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) 217 | else: 218 | raise ValueError( 219 | f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" 220 | ) 221 | 222 | # 2. 
Convert to an ODE derivative 223 | derivative = (sample - pred_original_sample) / sigma_hat 224 | 225 | dt = self.sigmas[self.step_index + 1] - sigma_hat 226 | 227 | prev_sample = sample + (math.cos(self.step_index + 1) / (self.step_index + 1) + 1) * derivative * dt 228 | 229 | # Cast sample back to model compatible dtype 230 | prev_sample = prev_sample.to(model_output.dtype) 231 | 232 | # upon completion increase step index by one 233 | self._step_index += 1 234 | 235 | if not return_dict: 236 | return ( 237 | prev_sample, 238 | pred_original_sample, 239 | ) 240 | 241 | return EulerDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) 242 | -------------------------------------------------------------------------------- /stablepy/diffusers_vanilla/adetailer.py: -------------------------------------------------------------------------------- 1 | # ===================================== 2 | # Adetailer 3 | # ===================================== 4 | from functools import partial 5 | import os 6 | from diffusers import DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler, DDIMScheduler 7 | from huggingface_hub import hf_hub_download 8 | from typing import Any, Callable, Iterable, List, Mapping, Optional 9 | import numpy as np 10 | from PIL import Image 11 | import torch, copy, gc 12 | from ..logging.logging_setup import logger 13 | 14 | FIXED_SIZE_CLASS = [ 15 | "StableDiffusionControlNetInpaintPipeline", 16 | "StableDiffusionXLInpaintPipeline", 17 | "FluxImg2ImgPipeline", 18 | "FluxInpaintPipeline", 19 | ] 20 | 21 | 22 | def ad_model_process( 23 | detailfix_pipe, 24 | pipe_params_df, 25 | face_detector_ad, 26 | person_detector_ad, 27 | hand_detector_ad, 28 | image_list_task, # pil 29 | mask_dilation=4, 30 | mask_blur=4, 31 | mask_padding=32, 32 | custom_model_path="", 33 | ): 34 | # input: params pipe, detailfix_pipe, paras yolo 35 | # output: list of PIL images 36 | 37 | scheduler_assigned = copy.deepcopy(detailfix_pipe.scheduler) 38 | logger.debug(f"Base sampler detailfix_pipe: {scheduler_assigned}") 39 | 40 | detailfix_pipe.safety_checker = None 41 | detailfix_pipe.to("cuda" if torch.cuda.is_available() else "cpu") 42 | 43 | # detailfi resolution param 44 | if str(detailfix_pipe.__class__.__name__) in FIXED_SIZE_CLASS: 45 | pipe_params_df["height"] = image_list_task[0].size[1] 46 | pipe_params_df["width"] = image_list_task[0].size[0] 47 | logger.debug(f"detailfix inpaint only") 48 | else: 49 | pipe_params_df.pop("height", None) 50 | pipe_params_df.pop("width", None) 51 | logger.debug(f"detailfix img2img") 52 | 53 | image_list_ad = [] 54 | 55 | detectors = [] 56 | if person_detector_ad: 57 | person_model_path = hf_hub_download("Bingsu/adetailer", "person_yolov8s-seg.pt") 58 | person_detector = partial(yolo_detector, model_path=person_model_path) 59 | detectors.append(person_detector) 60 | if face_detector_ad: 61 | face_model_path = hf_hub_download("Bingsu/adetailer", "face_yolov8n.pt") 62 | face_detector = partial(yolo_detector, model_path=face_model_path) 63 | detectors.append(face_detector) 64 | if hand_detector_ad: 65 | hand_model_path = hf_hub_download("Bingsu/adetailer", "hand_yolov8n.pt") 66 | hand_detector = partial(yolo_detector, model_path=hand_model_path) 67 | detectors.append(hand_detector) 68 | if custom_model_path and os.path.exists(custom_model_path): 69 | try: 70 | custom_detector = partial(yolo_detector, model_path=custom_model_path) 71 | detectors.append(custom_detector) 72 | except Exception as e: 73 | logger.error(f"Error 
loading custom model from {custom_model_path}: {str(e)}") 74 | 75 | image_list_ad = [] 76 | 77 | for i, init_image_base in enumerate(image_list_task): 78 | init_image = init_image_base.convert("RGB") 79 | final_image = None 80 | 81 | for j, detector in enumerate(detectors): 82 | masks = detector(init_image) 83 | 84 | if masks is None: 85 | logger.info( 86 | f"No object detected on {(i + 1)} image with {str(detector).split('/')[-1][:-2]} detector." 87 | ) 88 | continue 89 | 90 | for k, mask in enumerate(masks): 91 | mask = mask.convert("L") 92 | mask = mask_dilate(mask, mask_dilation) 93 | bbox = mask.getbbox() 94 | if bbox is None: 95 | logger.info(f"No object in {(k + 1)} mask.") 96 | continue 97 | mask = mask_gaussian_blur(mask, mask_blur) 98 | bbox_padded = bbox_padding(bbox, init_image.size, mask_padding) 99 | 100 | crop_image = init_image.crop(bbox_padded) 101 | crop_mask = mask.crop(bbox_padded) 102 | 103 | pipe_params_df["image"] = crop_image 104 | pipe_params_df["mask_image"] = crop_mask 105 | 106 | if str(detailfix_pipe.__class__.__name__) == "StableDiffusionControlNetInpaintPipeline": 107 | logger.debug("SD 1.5 detailfix") 108 | pipe_params_df["control_image"] = make_inpaint_condition(crop_image, crop_mask) 109 | 110 | try: 111 | inpaint_output = detailfix_pipe(**pipe_params_df) 112 | except Exception as e: 113 | e = str(e) 114 | if "Tensor with 2 elements cannot be converted to Scalar" in e: 115 | try: 116 | logger.error("Sampler not compatible with DetailFix; trying with DDIM sampler") 117 | logger.debug(e) 118 | detailfix_pipe.scheduler = detailfix_pipe.default_scheduler 119 | detailfix_pipe.scheduler = DDIMScheduler.from_config(detailfix_pipe.scheduler.config) 120 | 121 | inpaint_output = detailfix_pipe(**pipe_params_df) 122 | except Exception as ex: 123 | logger.error("trying with base sampler") 124 | logger.debug(str(ex)) 125 | detailfix_pipe.scheduler = detailfix_pipe.default_scheduler 126 | 127 | inpaint_output = detailfix_pipe(**pipe_params_df) 128 | elif "The size of tensor a (0) must match the size of tensor b (3) at non-singleton" in e or "cannot reshape tensor of 0 elements into shape [0, -1, 1, 512] because the unspecified dimensi" in e: 129 | logger.error(f"strength or steps too low for the model to produce a satisfactory response.") 130 | inpaint_output = [[crop_image]] 131 | else: 132 | raise ValueError(e) 133 | 134 | inpaint_image: Image.Image = inpaint_output[0][0] 135 | final_image = composite( 136 | init=init_image, 137 | mask=mask, 138 | gen=inpaint_image, 139 | bbox_padded=bbox_padded, 140 | ) 141 | init_image = final_image 142 | 143 | if final_image is not None: 144 | image_list_ad.append(final_image) 145 | else: 146 | logger.info( 147 | f"DetailFix: No detections found in image. 
Returning original image" 148 | ) 149 | image_list_ad.append(init_image_base) 150 | 151 | torch.cuda.empty_cache() 152 | gc.collect() 153 | 154 | detailfix_pipe.scheduler = scheduler_assigned 155 | 156 | torch.cuda.empty_cache() 157 | gc.collect() 158 | 159 | return image_list_ad 160 | 161 | 162 | # ===================================== 163 | # Yolo 164 | # ===================================== 165 | from pathlib import Path 166 | import numpy as np 167 | import torch 168 | from huggingface_hub import hf_hub_download 169 | from PIL import Image, ImageDraw 170 | from torchvision.transforms.functional import to_pil_image 171 | from ultralytics import YOLO 172 | 173 | 174 | def create_mask_from_bbox( 175 | bboxes: np.ndarray, shape: tuple[int, int] 176 | ) -> list[Image.Image]: 177 | """ 178 | Parameters 179 | ---------- 180 | bboxes: list[list[float]] 181 | list of [x1, y1, x2, y2] 182 | bounding boxes 183 | shape: tuple[int, int] 184 | shape of the image (width, height) 185 | 186 | Returns 187 | ------- 188 | masks: list[Image.Image] 189 | A list of masks 190 | 191 | """ 192 | masks = [] 193 | for bbox in bboxes: 194 | mask = Image.new("L", shape, "black") 195 | mask_draw = ImageDraw.Draw(mask) 196 | mask_draw.rectangle(bbox, fill="white") 197 | masks.append(mask) 198 | return masks 199 | 200 | 201 | def mask_to_pil(masks: torch.Tensor, shape: tuple[int, int]) -> list[Image.Image]: 202 | """ 203 | Parameters 204 | ---------- 205 | masks: torch.Tensor, dtype=torch.float32, shape=(N, H, W). 206 | The device can be CUDA, but `to_pil_image` takes care of that. 207 | 208 | shape: tuple[int, int] 209 | (width, height) of the original image 210 | 211 | Returns 212 | ------- 213 | images: list[Image.Image] 214 | """ 215 | n = masks.shape[0] 216 | return [to_pil_image(masks[i], mode="L").resize(shape) for i in range(n)] 217 | 218 | 219 | def yolo_detector( 220 | image: Image.Image, model_path: str | Path | None = None, confidence: float = 0.3 221 | ) -> list[Image.Image] | None: 222 | if not model_path: 223 | model_path = hf_hub_download("Bingsu/adetailer", "face_yolov8n.pt") 224 | model = YOLO(model_path) 225 | pred = model(image, conf=confidence) 226 | 227 | bboxes = pred[0].boxes.xyxy.cpu().numpy() 228 | if bboxes.size == 0: 229 | return None 230 | 231 | if pred[0].masks is None: 232 | masks = create_mask_from_bbox(bboxes, image.size) 233 | else: 234 | masks = mask_to_pil(pred[0].masks.data, image.size) 235 | 236 | return masks 237 | 238 | 239 | # ===================================== 240 | # Utils 241 | # ===================================== 242 | 243 | import cv2 244 | import numpy as np 245 | from PIL import Image, ImageFilter, ImageOps 246 | import torch 247 | 248 | 249 | def mask_dilate(image: Image.Image, value: int = 4) -> Image.Image: 250 | if value <= 0: 251 | return image 252 | 253 | arr = np.array(image) 254 | kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (value, value)) 255 | dilated = cv2.dilate(arr, kernel, iterations=1) 256 | return Image.fromarray(dilated) 257 | 258 | 259 | def mask_gaussian_blur(image: Image.Image, value: int = 4) -> Image.Image: 260 | if value <= 0: 261 | return image 262 | 263 | blur = ImageFilter.GaussianBlur(value) 264 | return image.filter(blur) 265 | 266 | 267 | def bbox_padding( 268 | bbox: tuple[int, int, int, int], image_size: tuple[int, int], value: int = 32 269 | ) -> tuple[int, int, int, int]: 270 | if value <= 0: 271 | return bbox 272 | 273 | arr = np.array(bbox).reshape(2, 2) 274 | arr[0] -= value 275 | arr[1] += value 276 | arr = np.clip(arr, 
(0, 0), image_size) 277 | return tuple(arr.flatten()) 278 | 279 | 280 | def composite( 281 | init: Image.Image, 282 | mask: Image.Image, 283 | gen: Image.Image, 284 | bbox_padded: tuple[int, int, int, int], 285 | ) -> Image.Image: 286 | img_masked = Image.new("RGBa", init.size) 287 | img_masked.paste( 288 | init.convert("RGBA").convert("RGBa"), 289 | mask=ImageOps.invert(mask), 290 | ) 291 | img_masked = img_masked.convert("RGBA") 292 | 293 | size = ( 294 | bbox_padded[2] - bbox_padded[0], 295 | bbox_padded[3] - bbox_padded[1], 296 | ) 297 | resized = gen.resize(size) 298 | 299 | output = Image.new("RGBA", init.size) 300 | output.paste(resized, bbox_padded) 301 | output.alpha_composite(img_masked) 302 | return output.convert("RGB") 303 | 304 | 305 | 306 | def make_inpaint_condition(init_image, mask_image): 307 | init_image = np.array(init_image.convert("RGB")).astype(np.float32) / 255.0 308 | mask_image = np.array(mask_image.convert("L")).astype(np.float32) / 255.0 309 | 310 | assert init_image.shape[0:1] == mask_image.shape[0:1], "image and image_mask must have the same image size" 311 | init_image[mask_image > 0.5] = -1.0 # set as masked pixel 312 | init_image = np.expand_dims(init_image, 0).transpose(0, 3, 1, 2) 313 | init_image = torch.from_numpy(init_image) 314 | return init_image 315 | -------------------------------------------------------------------------------- /stablepy/diffusers_vanilla/preprocessor/controlnet_aux_beta/teed/ted.py: -------------------------------------------------------------------------------- 1 | # Original from: https://github.com/xavysp/TEED 2 | # TEED: is a Tiny but Efficient Edge Detection, it comes from the LDC-B3 3 | # with a Slightly modification 4 | # LDC parameters: 5 | # 155665 6 | # TED > 58K 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from .Fsmish import smish as Fsmish 13 | from .Xsmish import Smish 14 | 15 | 16 | def weight_init(m): 17 | if isinstance(m, (nn.Conv2d,)): 18 | torch.nn.init.xavier_normal_(m.weight, gain=1.0) 19 | 20 | if m.bias is not None: 21 | torch.nn.init.zeros_(m.bias) 22 | 23 | # for fusion layer 24 | if isinstance(m, (nn.ConvTranspose2d,)): 25 | torch.nn.init.xavier_normal_(m.weight, gain=1.0) 26 | if m.bias is not None: 27 | torch.nn.init.zeros_(m.bias) 28 | 29 | 30 | class CoFusion(nn.Module): 31 | # from LDC 32 | 33 | def __init__(self, in_ch, out_ch): 34 | super(CoFusion, self).__init__() 35 | self.conv1 = nn.Conv2d( 36 | in_ch, 32, kernel_size=3, stride=1, padding=1 37 | ) # before 64 38 | self.conv3 = nn.Conv2d( 39 | 32, out_ch, kernel_size=3, stride=1, padding=1 40 | ) # before 64 instead of 32 41 | self.relu = nn.ReLU() 42 | self.norm_layer1 = nn.GroupNorm(4, 32) # before 64 43 | 44 | def forward(self, x): 45 | # fusecat = torch.cat(x, dim=1) 46 | attn = self.relu(self.norm_layer1(self.conv1(x))) 47 | attn = F.softmax(self.conv3(attn), dim=1) 48 | return ((x * attn).sum(1)).unsqueeze(1) 49 | 50 | 51 | class CoFusion2(nn.Module): 52 | # TEDv14-3 53 | def __init__(self, in_ch, out_ch): 54 | super(CoFusion2, self).__init__() 55 | self.conv1 = nn.Conv2d( 56 | in_ch, 32, kernel_size=3, stride=1, padding=1 57 | ) # before 64 58 | # self.conv2 = nn.Conv2d(32, 32, kernel_size=3, 59 | # stride=1, padding=1)# before 64 60 | self.conv3 = nn.Conv2d( 61 | 32, out_ch, kernel_size=3, stride=1, padding=1 62 | ) # before 64 instead of 32 63 | self.smish = Smish() # nn.ReLU(inplace=True) 64 | 65 | def forward(self, x): 66 | # fusecat = torch.cat(x, dim=1) 67 | attn = 
self.conv1(self.smish(x)) 68 | attn = self.conv3(self.smish(attn)) # before , )dim=1) 69 | 70 | # return ((fusecat * attn).sum(1)).unsqueeze(1) 71 | return ((x * attn).sum(1)).unsqueeze(1) 72 | 73 | 74 | class DoubleFusion(nn.Module): 75 | # TED fusion before the final edge map prediction 76 | def __init__(self, in_ch, out_ch): 77 | super(DoubleFusion, self).__init__() 78 | self.DWconv1 = nn.Conv2d( 79 | in_ch, in_ch * 8, kernel_size=3, stride=1, padding=1, groups=in_ch 80 | ) # before 64 81 | self.PSconv1 = nn.PixelShuffle(1) 82 | 83 | self.DWconv2 = nn.Conv2d( 84 | 24, 24 * 1, kernel_size=3, stride=1, padding=1, groups=24 85 | ) # before 64 instead of 32 86 | 87 | self.AF = Smish() # XAF() #nn.Tanh()# XAF() # # Smish()# 88 | 89 | def forward(self, x): 90 | # fusecat = torch.cat(x, dim=1) 91 | attn = self.PSconv1( 92 | self.DWconv1(self.AF(x)) 93 | ) # #TEED best res TEDv14 [8, 32, 352, 352] 94 | 95 | attn2 = self.PSconv1( 96 | self.DWconv2(self.AF(attn)) 97 | ) # #TEED best res TEDv14[8, 3, 352, 352] 98 | 99 | return Fsmish(((attn2 + attn).sum(1)).unsqueeze(1)) # TED best res 100 | 101 | 102 | class _DenseLayer(nn.Sequential): 103 | def __init__(self, input_features, out_features): 104 | super(_DenseLayer, self).__init__() 105 | 106 | ( 107 | self.add_module( 108 | "conv1", 109 | nn.Conv2d( 110 | input_features, 111 | out_features, 112 | kernel_size=3, 113 | stride=1, 114 | padding=2, 115 | bias=True, 116 | ), 117 | ), 118 | ) 119 | (self.add_module("smish1", Smish()),) 120 | self.add_module( 121 | "conv2", 122 | nn.Conv2d(out_features, out_features, kernel_size=3, stride=1, bias=True), 123 | ) 124 | 125 | def forward(self, x): 126 | x1, x2 = x 127 | 128 | new_features = super(_DenseLayer, self).forward(Fsmish(x1)) # F.relu() 129 | 130 | return 0.5 * (new_features + x2), x2 131 | 132 | 133 | class _DenseBlock(nn.Sequential): 134 | def __init__(self, num_layers, input_features, out_features): 135 | super(_DenseBlock, self).__init__() 136 | for i in range(num_layers): 137 | layer = _DenseLayer(input_features, out_features) 138 | self.add_module("denselayer%d" % (i + 1), layer) 139 | input_features = out_features 140 | 141 | 142 | class UpConvBlock(nn.Module): 143 | def __init__(self, in_features, up_scale): 144 | super(UpConvBlock, self).__init__() 145 | self.up_factor = 2 146 | self.constant_features = 16 147 | 148 | layers = self.make_deconv_layers(in_features, up_scale) 149 | assert layers is not None, layers 150 | self.features = nn.Sequential(*layers) 151 | 152 | def make_deconv_layers(self, in_features, up_scale): 153 | layers = [] 154 | all_pads = [0, 0, 1, 3, 7] 155 | for i in range(up_scale): 156 | kernel_size = 2**up_scale 157 | pad = all_pads[up_scale] # kernel_size-1 158 | out_features = self.compute_out_features(i, up_scale) 159 | layers.append(nn.Conv2d(in_features, out_features, 1)) 160 | layers.append(Smish()) 161 | layers.append( 162 | nn.ConvTranspose2d( 163 | out_features, out_features, kernel_size, stride=2, padding=pad 164 | ) 165 | ) 166 | in_features = out_features 167 | return layers 168 | 169 | def compute_out_features(self, idx, up_scale): 170 | return 1 if idx == up_scale - 1 else self.constant_features 171 | 172 | def forward(self, x): 173 | return self.features(x) 174 | 175 | 176 | class SingleConvBlock(nn.Module): 177 | def __init__(self, in_features, out_features, stride, use_ac=False): 178 | super(SingleConvBlock, self).__init__() 179 | # self.use_bn = use_bs 180 | self.use_ac = use_ac 181 | self.conv = nn.Conv2d(in_features, out_features, 1, stride=stride, 
bias=True) 182 | if self.use_ac: 183 | self.smish = Smish() 184 | 185 | def forward(self, x): 186 | x = self.conv(x) 187 | if self.use_ac: 188 | return self.smish(x) 189 | else: 190 | return x 191 | 192 | 193 | class DoubleConvBlock(nn.Module): 194 | def __init__( 195 | self, in_features, mid_features, out_features=None, stride=1, use_act=True 196 | ): 197 | super(DoubleConvBlock, self).__init__() 198 | 199 | self.use_act = use_act 200 | if out_features is None: 201 | out_features = mid_features 202 | self.conv1 = nn.Conv2d(in_features, mid_features, 3, padding=1, stride=stride) 203 | self.conv2 = nn.Conv2d(mid_features, out_features, 3, padding=1) 204 | self.smish = Smish() # nn.ReLU(inplace=True) 205 | 206 | def forward(self, x): 207 | x = self.conv1(x) 208 | x = self.smish(x) 209 | x = self.conv2(x) 210 | if self.use_act: 211 | x = self.smish(x) 212 | return x 213 | 214 | 215 | class TED(nn.Module): 216 | """Definition of Tiny and Efficient Edge Detector 217 | model 218 | """ 219 | 220 | def __init__(self): 221 | super(TED, self).__init__() 222 | self.block_1 = DoubleConvBlock( 223 | 3, 224 | 16, 225 | 16, 226 | stride=2, 227 | ) 228 | self.block_2 = DoubleConvBlock(16, 32, use_act=False) 229 | self.dblock_3 = _DenseBlock(1, 32, 48) # [32,48,100,100] before (2, 32, 64) 230 | 231 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 232 | 233 | # skip1 connection, see fig. 2 234 | self.side_1 = SingleConvBlock(16, 32, 2) 235 | 236 | # skip2 connection, see fig. 2 237 | self.pre_dense_3 = SingleConvBlock(32, 48, 1) # before (32, 64, 1) 238 | 239 | # USNet 240 | self.up_block_1 = UpConvBlock(16, 1) 241 | self.up_block_2 = UpConvBlock(32, 1) 242 | self.up_block_3 = UpConvBlock(48, 2) # (32, 64, 1) 243 | 244 | self.block_cat = DoubleFusion(3, 3) # TEED: DoubleFusion 245 | 246 | self.apply(weight_init) 247 | 248 | def slice(self, tensor, slice_shape): 249 | t_shape = tensor.shape 250 | img_h, img_w = slice_shape 251 | if img_w != t_shape[-1] or img_h != t_shape[2]: 252 | new_tensor = F.interpolate( 253 | tensor, size=(img_h, img_w), mode="bicubic", align_corners=False 254 | ) 255 | 256 | else: 257 | new_tensor = tensor 258 | # tensor[..., :height, :width] 259 | return new_tensor 260 | 261 | def resize_input(self, tensor): 262 | t_shape = tensor.shape 263 | if t_shape[2] % 8 != 0 or t_shape[3] % 8 != 0: 264 | img_w = ((t_shape[3] // 8) + 1) * 8 265 | img_h = ((t_shape[2] // 8) + 1) * 8 266 | new_tensor = F.interpolate( 267 | tensor, size=(img_h, img_w), mode="bicubic", align_corners=False 268 | ) 269 | else: 270 | new_tensor = tensor 271 | return new_tensor 272 | 273 | def crop_bdcn(data1, h, w, crop_h, crop_w): 274 | # Based on BDCN Implementation @ https://github.com/pkuCactus/BDCN 275 | _, _, h1, w1 = data1.size() 276 | assert h <= h1 and w <= w1 277 | data = data1[:, :, crop_h : crop_h + h, crop_w : crop_w + w] 278 | return data 279 | 280 | def forward(self, x, single_test=False): 281 | assert x.ndim == 4, x.shape 282 | # supose the image size is 352x352 283 | 284 | # Block 1 285 | block_1 = self.block_1(x) # [8,16,176,176] 286 | block_1_side = self.side_1(block_1) # 16 [8,32,88,88] 287 | 288 | # Block 2 289 | block_2 = self.block_2(block_1) # 32 # [8,32,176,176] 290 | block_2_down = self.maxpool(block_2) # [8,32,88,88] 291 | block_2_add = block_2_down + block_1_side # [8,32,88,88] 292 | 293 | # Block 3 294 | block_3_pre_dense = self.pre_dense_3( 295 | block_2_down 296 | ) # [8,64,88,88] block 3 L connection 297 | block_3, _ = self.dblock_3([block_2_add, block_3_pre_dense]) # 
[8,64,88,88] 298 | 299 | # upsampling blocks 300 | out_1 = self.up_block_1(block_1) 301 | out_2 = self.up_block_2(block_2) 302 | out_3 = self.up_block_3(block_3) 303 | 304 | results = [out_1, out_2, out_3] 305 | 306 | # concatenate multiscale outputs 307 | block_cat = torch.cat(results, dim=1) # Bx6xHxW 308 | block_cat = self.block_cat(block_cat) # Bx1xHxW DoubleFusion 309 | 310 | results.append(block_cat) 311 | return results 312 | 313 | 314 | if __name__ == "__main__": 315 | batch_size = 8 316 | img_height = 352 317 | img_width = 352 318 | 319 | # device = "cuda" if torch.cuda.is_available() else "cpu" 320 | device = "cpu" 321 | input = torch.rand(batch_size, 3, img_height, img_width).to(device) 322 | # target = torch.rand(batch_size, 1, img_height, img_width).to(device) 323 | print(f"input shape: {input.shape}") 324 | model = TED().to(device) 325 | output = model(input) 326 | print(f"output shapes: {[t.shape for t in output]}") 327 | 328 | # for i in range(20000): 329 | # print(i) 330 | # output = model(input) 331 | # loss = nn.MSELoss()(output[-1], target) 332 | # loss.backward() 333 | -------------------------------------------------------------------------------- /stablepy/diffusers_vanilla/sd_embed/LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /stablepy/diffusers_vanilla/main_prompt_embeds.py: -------------------------------------------------------------------------------- 1 | from .multi_emphasis_prompt import ( 2 | ClassicTextProcessingEngine, 3 | pad_equal_len, 4 | ) 5 | import torch 6 | import gc 7 | from compel import Compel, ReturnedEmbeddingsType 8 | from ..logging.logging_setup import logger 9 | from .constants import ( 10 | PROMPT_WEIGHT_OPTIONS, 11 | OLD_PROMPT_WEIGHT_OPTIONS, 12 | SD_EMBED, 13 | CLASSIC_VARIANT, 14 | ALL_PROMPT_WEIGHT_OPTIONS, 15 | ) 16 | from .prompt_weights import get_embed_new 17 | from .sd_embed.embedding_funcs import ( 18 | get_weighted_text_embeddings_sd15, 19 | get_weighted_text_embeddings_sdxl, 20 | get_weighted_text_embeddings_flux1, 21 | ) 22 | 23 | 24 | class Prompt_Embedder_Base: 25 | def __init__(self): 26 | self.last_clip_skip = None 27 | 28 | def apply_ti(self, class_name, textual_inversion, pipe, device, gui_active): 29 | 30 | if "FluxPipeline" == class_name: 31 | logger.warning("Textual Inverstion not available") 32 | return None 33 | 34 | # Textual Inversion 35 | for name, directory_name in textual_inversion: 36 | try: 37 | if class_name == "StableDiffusionPipeline": 38 | if directory_name.endswith(".pt"): 39 | model = torch.load(directory_name, map_location=device) 40 | model_tensors = model.get("string_to_param").get("*") 41 | s_model = {"emb_params": model_tensors} 42 | # save_file(s_model, directory_name[:-3] + '.safetensors') 43 | pipe.load_textual_inversion(s_model, token=name) 44 | else: 45 | # pipe.text_encoder.resize_token_embeddings(len(pipe.tokenizer),pad_to_multiple_of=128) 46 | # pipe.load_textual_inversion("./bad_prompt.pt", token="baddd") 47 | pipe.load_textual_inversion(directory_name, token=name) 48 | elif class_name == "StableDiffusionXLPipeline": 49 | from safetensors.torch import load_file 50 | state_dict = load_file(directory_name) 51 | pipe.load_textual_inversion(state_dict["clip_g"], token=name, text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2) 52 | pipe.load_textual_inversion(state_dict["clip_l"], token=name, text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer) 53 | else: 54 | logger.error("Textual Inversion not combatible") 55 | 56 | logger.info(f"Applied: {name}") 57 | except Exception as e: 58 | exception = str(e) 59 | if name in exception: 60 | logger.debug(f"Previous loaded embed {name}") 61 | else: 62 | logger.error(exception) 63 | logger.error(f"Can't apply embed {name}") 64 | 65 | @torch.no_grad() 66 | def __call__(self, prompt, negative_prompt, syntax_weights, pipe, clip_skip, compel): 67 | if syntax_weights in CLASSIC_VARIANT: 68 | emphasis = CLASSIC_VARIANT[syntax_weights] 69 | return self.classic_variant( 70 | prompt, negative_prompt, pipe, clip_skip, emphasis 71 | ) 72 | elif syntax_weights in SD_EMBED: 73 | return self.sd_embed_variant( 74 | prompt, negative_prompt, pipe, clip_skip 75 | ) 76 | else: 77 | return self.compel_processor( 78 | prompt, negative_prompt, pipe, clip_skip, syntax_weights, compel 79 | ) 80 | 81 | 82 | class Promt_Embedder_SD1(Prompt_Embedder_Base): 83 | 84 | def classic_variant(self, prompt, negative_prompt, pipe, clip_skip, emphasis): 85 | 86 | clip_l_engine = ClassicTextProcessingEngine( 87 | text_encoder=pipe.text_encoder, 88 | tokenizer=pipe.tokenizer, 89 | chunk_length=75, 90 | emphasis_name=emphasis, 91 | text_projection=False, 92 | minimal_clip_skip=1, 93 | clip_skip=2 if clip_skip else 1, 94 | return_pooled=False, 
95 | final_layer_norm=True, 96 | ) 97 | 98 | cond = clip_l_engine(prompt) 99 | uncond = clip_l_engine(negative_prompt) 100 | 101 | cond, uncond = pad_equal_len(clip_l_engine, cond, uncond) 102 | 103 | return cond, uncond, None 104 | 105 | def sd_embed_variant(self, prompt, negative_prompt, pipe, clip_skip): 106 | 107 | ( 108 | cond_embeddings, 109 | uncond_embeddings 110 | ) = get_weighted_text_embeddings_sd15( 111 | pipe, 112 | prompt=prompt, 113 | neg_prompt=negative_prompt, 114 | clip_skip=(1 if clip_skip else 0) 115 | ) 116 | 117 | return cond_embeddings, uncond_embeddings, None 118 | 119 | def compel_processor(self, prompt, negative_prompt, pipe, clip_skip, syntax_weights, compel): 120 | 121 | if compel is None or clip_skip != self.last_clip_skip: 122 | compel = Compel( 123 | tokenizer=pipe.tokenizer, 124 | text_encoder=pipe.text_encoder, 125 | truncate_long_prompts=False, 126 | returned_embeddings_type=( 127 | ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NORMALIZED 128 | if clip_skip 129 | else ReturnedEmbeddingsType.LAST_HIDDEN_STATES_NORMALIZED), 130 | ) 131 | self.last_clip_skip = clip_skip 132 | 133 | # Syntax weights 134 | compel_weights = False if syntax_weights == "Classic" else True 135 | # pipe.to(device) 136 | prompt_emb = get_embed_new( 137 | prompt, pipe, compel, compel_process_sd=compel_weights 138 | ) 139 | negative_prompt_emb = get_embed_new( 140 | negative_prompt, pipe, compel, compel_process_sd=compel_weights 141 | ) 142 | 143 | # Fix error shape 144 | if prompt_emb.shape != negative_prompt_emb.shape: 145 | ( 146 | prompt_emb, 147 | negative_prompt_emb, 148 | ) = compel.pad_conditioning_tensors_to_same_length( 149 | [prompt_emb, negative_prompt_emb] 150 | ) 151 | 152 | return prompt_emb, negative_prompt_emb, compel 153 | 154 | 155 | class Promt_Embedder_SDXL(Prompt_Embedder_Base): 156 | def classic_variant(self, prompt, negative_prompt, pipe, clip_skip, emphasis): 157 | 158 | clip_l_engine = ClassicTextProcessingEngine( 159 | text_encoder=pipe.text_encoder, 160 | tokenizer=pipe.tokenizer, 161 | emphasis_name=emphasis, 162 | text_projection=False, 163 | minimal_clip_skip=2, 164 | clip_skip=2 if clip_skip else 1, 165 | return_pooled=False, 166 | final_layer_norm=False, 167 | ) 168 | 169 | clip_g_engine = ClassicTextProcessingEngine( 170 | text_encoder=pipe.text_encoder_2, 171 | tokenizer=pipe.tokenizer_2, 172 | emphasis_name=emphasis, 173 | text_projection=True, 174 | minimal_clip_skip=2, 175 | clip_skip=2 if clip_skip else 1, 176 | return_pooled=True, 177 | final_layer_norm=False, 178 | ) 179 | 180 | cond = clip_l_engine(prompt) 181 | uncond = clip_l_engine(negative_prompt) 182 | cond, uncond = pad_equal_len(clip_l_engine, cond, uncond) 183 | 184 | cond_2, cond_pooled = clip_g_engine(prompt) 185 | uncond_2, uncond_pooled = clip_g_engine(negative_prompt) 186 | clip_g_engine.return_pooled = False 187 | cond_2, uncond_2 = pad_equal_len(clip_g_engine, cond_2, uncond_2) 188 | 189 | cond_embed = torch.cat((cond, cond_2), dim=2) 190 | neg_uncond_embed = torch.cat((uncond, uncond_2), dim=2) 191 | 192 | all_cond = torch.cat([cond_embed, neg_uncond_embed]) 193 | all_pooled = torch.cat([cond_pooled, uncond_pooled]) 194 | 195 | return all_cond, all_pooled, None 196 | 197 | def sd_embed_variant(self, prompt, negative_prompt, pipe, clip_skip): 198 | 199 | ( 200 | cond_embed, 201 | neg_uncond_embed, 202 | cond_pooled, 203 | uncond_pooled 204 | ) = get_weighted_text_embeddings_sdxl( 205 | pipe, 206 | prompt=prompt, 207 | neg_prompt=negative_prompt, 208 | ) 209 | 210 | all_cond = 
torch.cat([cond_embed, neg_uncond_embed]) 211 | 212 | all_pooled = torch.cat([cond_pooled, uncond_pooled]) 213 | 214 | assert torch.equal(all_cond[0:1], cond_embed), "Tensors are not equal" 215 | 216 | return all_cond, all_pooled, None 217 | 218 | def compel_processor(self, prompt, negative_prompt, pipe, clip_skip, syntax_weights, compel): 219 | 220 | if compel is None: 221 | compel = Compel( 222 | tokenizer=[pipe.tokenizer, pipe.tokenizer_2], 223 | text_encoder=[pipe.text_encoder, pipe.text_encoder_2], 224 | requires_pooled=[False, True], 225 | truncate_long_prompts=False, 226 | returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, 227 | ) 228 | 229 | # Syntax weights 230 | # pipe.to(device) 231 | if syntax_weights == "Classic": 232 | prompt = get_embed_new( 233 | prompt, pipe, compel, only_convert_string=True 234 | ) 235 | negative_prompt = get_embed_new( 236 | negative_prompt, pipe, compel, only_convert_string=True 237 | ) 238 | 239 | conditioning, pooled = compel([prompt, negative_prompt]) 240 | 241 | return conditioning, pooled, compel 242 | 243 | 244 | class Promt_Embedder_FLUX(Prompt_Embedder_Base): 245 | def classic_variant(self, prompt, negative_prompt, pipe, clip_skip, emphasis): 246 | 247 | torch.cuda.empty_cache() 248 | gc.collect() 249 | 250 | clip_l_engine = ClassicTextProcessingEngine( 251 | text_encoder=pipe.text_encoder, 252 | tokenizer=pipe.tokenizer, 253 | emphasis_name=emphasis, 254 | text_projection=False, 255 | minimal_clip_skip=1, 256 | clip_skip=1, 257 | return_pooled=True, 258 | final_layer_norm=True, 259 | ) 260 | 261 | from .t5_embedder import T5TextProcessingEngine 262 | t5_engine = T5TextProcessingEngine( 263 | pipe.text_encoder_2, 264 | pipe.tokenizer_2, 265 | emphasis_name=emphasis, 266 | min_length=( 267 | 256 # 512 if pipe.transformer.config.guidance_embeds else 256 268 | ), 269 | ) 270 | 271 | _, cond_pooled = clip_l_engine(prompt) 272 | cond = t5_engine(prompt) 273 | 274 | cond_pooled = cond_pooled.to(dtype=pipe.text_encoder.dtype) 275 | 276 | if cond.shape[0] > 1: 277 | tensor_slices = [cond[i:i + 1, :, :] for i in range(cond.shape[0])] 278 | cond = torch.cat(tensor_slices, dim=1) 279 | cond = cond.to(dtype=pipe.text_encoder_2.dtype) 280 | 281 | torch.cuda.empty_cache() 282 | gc.collect() 283 | 284 | return cond, cond_pooled, None 285 | 286 | def sd_embed_variant(self, prompt, negative_prompt, pipe, clip_skip): 287 | 288 | torch.cuda.empty_cache() 289 | gc.collect() 290 | 291 | ( 292 | positive_embeddings, 293 | pooled_embeddings 294 | ) = get_weighted_text_embeddings_flux1( 295 | pipe=pipe, 296 | prompt=prompt, 297 | ) 298 | 299 | positive_embeddings = positive_embeddings.to(dtype=pipe.text_encoder_2.dtype) 300 | pooled_embeddings = pooled_embeddings.to(dtype=pipe.text_encoder.dtype) 301 | 302 | torch.cuda.empty_cache() 303 | gc.collect() 304 | 305 | return positive_embeddings, pooled_embeddings, None 306 | 307 | def compel_processor(self, prompt, negative_prompt, pipe, clip_skip, syntax_weights, compel): 308 | 309 | # pipe.text_encoder_2.to("cuda") 310 | # pipe.transformer.to("cpu") 311 | # torch.cuda.empty_cache() 312 | # gc.collect() 313 | 314 | prompt_embeds, pooled_prompt_embeds, text_ids = pipe.encode_prompt( 315 | prompt=prompt, 316 | prompt_2=None, 317 | device=pipe.text_encoder.device, 318 | num_images_per_prompt=1, 319 | prompt_embeds=None, 320 | pooled_prompt_embeds=None, 321 | max_sequence_length=512, 322 | lora_scale=None, 323 | ) 324 | 325 | # pipe.text_encoder_2.to("cpu") 326 | # 
pipe.transformer.to("cuda") 327 | torch.cuda.empty_cache() 328 | gc.collect() 329 | 330 | return prompt_embeds, pooled_prompt_embeds, None 331 | -------------------------------------------------------------------------------- /stablepy/upscalers/utils_upscaler.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import os 3 | from typing import Callable 4 | from collections import namedtuple 5 | import gc 6 | 7 | import tqdm 8 | import numpy as np 9 | from PIL import Image 10 | import spandrel 11 | import torch 12 | import torch.nn 13 | import math 14 | 15 | from stablepy.logging.logging_setup import logger 16 | 17 | _spandrel_extra_init_state = None 18 | 19 | 20 | def load_file_from_url( 21 | url: str, 22 | *, 23 | model_dir: str, 24 | progress: bool = True, 25 | file_name: str | None = None, 26 | re_download: bool = False, 27 | ) -> str: 28 | """Download a file from `url` into `model_dir`, using the file present if possible. 29 | Returns the path to the downloaded file. 30 | 31 | file_name: if specified, it will be used as the filename, otherwise the filename will be extracted from the url. 32 | file is downloaded to {file_name}.tmp then moved to the final location after download is complete. 33 | re_download: forcibly re-download the file even if it already exists. 34 | """ 35 | from urllib.parse import urlparse 36 | import requests 37 | 38 | if not file_name: 39 | parts = urlparse(url) 40 | file_name = os.path.basename(parts.path) 41 | 42 | cached_file = os.path.abspath(os.path.join(model_dir, file_name)) 43 | 44 | if re_download or not os.path.exists(cached_file): 45 | os.makedirs(model_dir, exist_ok=True) 46 | temp_file = os.path.join(model_dir, f"{file_name}.tmp") 47 | logger.info(f'\nDownloading: "{url}" to {cached_file}') 48 | response = requests.get(url, stream=True) 49 | response.raise_for_status() 50 | total_size = int(response.headers.get('content-length', 0)) 51 | with tqdm.auto.tqdm(total=total_size, unit='B', unit_scale=True, desc=file_name, disable=not progress) as progress_bar: 52 | with open(temp_file, 'wb') as file: 53 | for chunk in response.iter_content(chunk_size=1024): 54 | if chunk: 55 | file.write(chunk) 56 | progress_bar.update(len(chunk)) 57 | 58 | os.rename(temp_file, cached_file) 59 | return cached_file 60 | 61 | 62 | def _init_spandrel_extra_archs() -> None: 63 | """ 64 | Try to initialize `spandrel_extra_archs` (exactly once). 
65 | """ 66 | global _spandrel_extra_init_state 67 | if _spandrel_extra_init_state is not None: 68 | return 69 | 70 | try: 71 | import spandrel 72 | import spandrel_extra_arches 73 | spandrel.MAIN_REGISTRY.add(*spandrel_extra_arches.EXTRA_REGISTRY) 74 | _spandrel_extra_init_state = True 75 | except Exception: 76 | logger.warning("Failed to load spandrel_extra_arches", exc_info=True) 77 | _spandrel_extra_init_state = False 78 | 79 | 80 | def load_spandrel_model( 81 | path: str | os.PathLike, 82 | *, 83 | device: str | torch.device | None, 84 | prefer_half: bool = False, 85 | dtype: str | torch.dtype | None = None, 86 | expected_architecture: str | None = None, 87 | ) -> spandrel.ModelDescriptor: 88 | global _spandrel_extra_init_state 89 | 90 | import spandrel 91 | _init_spandrel_extra_archs() 92 | 93 | model_descriptor = spandrel.ModelLoader(device=device).load_from_file(str(path)) 94 | arch = model_descriptor.architecture 95 | if expected_architecture and arch.name != expected_architecture: 96 | logger.warning( 97 | f"Model {path!r} is not a {expected_architecture!r} model (got {arch.name!r})", 98 | ) 99 | half = False 100 | if prefer_half: 101 | if model_descriptor.supports_half: 102 | model_descriptor.model.half() 103 | half = True 104 | else: 105 | logger.info("Model %s does not support half precision, ignoring half", path) 106 | if dtype: 107 | model_descriptor.model.to(dtype=dtype) 108 | model_descriptor.model.eval() 109 | logger.debug( 110 | "Loaded %s from %s (device=%s, half=%s, dtype=%s)", 111 | arch, path, device, half, dtype, 112 | ) 113 | return model_descriptor 114 | 115 | 116 | def get_param(model) -> torch.nn.Parameter: 117 | """ 118 | Find the first parameter in a model or module. 119 | """ 120 | if hasattr(model, "model") and hasattr(model.model, "parameters"): 121 | # Unpeel a model descriptor to get at the actual Torch module. 122 | model = model.model 123 | 124 | for param in model.parameters(): 125 | return param 126 | 127 | raise ValueError(f"No parameters found in model {model!r}") 128 | 129 | 130 | class Grid(namedtuple("_Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])): 131 | @property 132 | def tile_count(self) -> int: 133 | """ 134 | The total number of tiles in the grid. 
135 | """ 136 | return sum(len(row[2]) for row in self.tiles) 137 | 138 | 139 | def split_grid(image: Image.Image, tile_w: int = 512, tile_h: int = 512, overlap: int = 64) -> Grid: 140 | w, h = image.size 141 | 142 | non_overlap_width = tile_w - overlap 143 | non_overlap_height = tile_h - overlap 144 | 145 | cols = math.ceil((w - overlap) / non_overlap_width) 146 | rows = math.ceil((h - overlap) / non_overlap_height) 147 | 148 | dx = (w - tile_w) / (cols - 1) if cols > 1 else 0 149 | dy = (h - tile_h) / (rows - 1) if rows > 1 else 0 150 | 151 | grid = Grid([], tile_w, tile_h, w, h, overlap) 152 | for row in range(rows): 153 | row_images = [] 154 | 155 | y = int(row * dy) 156 | 157 | if y + tile_h >= h: 158 | y = h - tile_h 159 | 160 | for col in range(cols): 161 | x = int(col * dx) 162 | 163 | if x + tile_w >= w: 164 | x = w - tile_w 165 | 166 | tile = image.crop((x, y, x + tile_w, y + tile_h)) 167 | 168 | row_images.append([x, tile_w, tile]) 169 | 170 | grid.tiles.append([y, tile_h, row_images]) 171 | 172 | return grid 173 | 174 | 175 | def combine_grid(grid): 176 | def make_mask_image(r): 177 | r = r * 255 / grid.overlap 178 | r = r.astype(np.uint8) 179 | return Image.fromarray(r, 'L') 180 | 181 | mask_w = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0)) 182 | mask_h = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1)) 183 | 184 | combined_image = Image.new("RGB", (grid.image_w, grid.image_h)) 185 | for y, h, row in grid.tiles: 186 | combined_row = Image.new("RGB", (grid.image_w, h)) 187 | for x, w, tile in row: 188 | if x == 0: 189 | combined_row.paste(tile, (0, 0)) 190 | continue 191 | 192 | combined_row.paste(tile.crop((0, 0, grid.overlap, h)), (x, 0), mask=mask_w) 193 | combined_row.paste(tile.crop((grid.overlap, 0, w, h)), (x + grid.overlap, 0)) 194 | 195 | if y == 0: 196 | combined_image.paste(combined_row, (0, 0)) 197 | continue 198 | 199 | combined_image.paste(combined_row.crop((0, 0, combined_row.width, grid.overlap)), (0, y), mask=mask_h) 200 | combined_image.paste(combined_row.crop((0, grid.overlap, combined_row.width, h)), (0, y + grid.overlap)) 201 | 202 | return combined_image 203 | 204 | 205 | def pil_image_to_torch_bgr(img: Image.Image) -> torch.Tensor: 206 | img = np.array(img.convert("RGB")) 207 | img = img[:, :, ::-1] # flip RGB to BGR 208 | img = np.transpose(img, (2, 0, 1)) # HWC to CHW 209 | img = np.ascontiguousarray(img) / 255 # Rescale to [0, 1] 210 | return torch.from_numpy(img) 211 | 212 | 213 | def torch_bgr_to_pil_image(tensor: torch.Tensor) -> Image.Image: 214 | if tensor.ndim == 4: 215 | # If we're given a tensor with a batch dimension, squeeze it out 216 | # (but only if it's a batch of size 1). 217 | if tensor.shape[0] != 1: 218 | raise ValueError(f"{tensor.shape} does not describe a BCHW tensor") 219 | tensor = tensor.squeeze(0) 220 | assert tensor.ndim == 3, f"{tensor.shape} does not describe a CHW tensor" 221 | # TODO: is `tensor.float().cpu()...numpy()` the most efficient idiom? 222 | arr = tensor.float().cpu().clamp_(0, 1).numpy() # clamp 223 | arr = 255.0 * np.moveaxis(arr, 0, 2) # CHW to HWC, rescale 224 | arr = arr.round().astype(np.uint8) 225 | arr = arr[:, :, ::-1] # flip BGR to RGB 226 | return Image.fromarray(arr, "RGB") 227 | 228 | 229 | def upscale_pil_patch(model, img: Image.Image) -> Image.Image: 230 | """ 231 | Upscale a given PIL image using the given model. 
232 | """ 233 | param = get_param(model) 234 | 235 | with torch.inference_mode(): 236 | tensor = pil_image_to_torch_bgr(img).unsqueeze(0) # add batch dimension 237 | tensor = tensor.to(device=param.device, dtype=param.dtype) 238 | return torch_bgr_to_pil_image(model(tensor)) 239 | 240 | 241 | def upscale_with_model( 242 | model: Callable[[torch.Tensor], torch.Tensor], 243 | img: Image.Image, 244 | *, 245 | tile_size: int, 246 | tile_overlap: int = 0, 247 | desc="tiled upscale", 248 | disable_progress_bar=False, 249 | ) -> Image.Image: 250 | if tile_size <= 0: 251 | logger.debug("Upscaling %s without tiling", img) 252 | output = upscale_pil_patch(model, img) 253 | logger.debug("=> %s", output) 254 | return output 255 | 256 | grid = split_grid(img, tile_size, tile_size, tile_overlap) 257 | newtiles = [] 258 | 259 | with tqdm.auto.tqdm(total=grid.tile_count, desc=desc, disable=disable_progress_bar) as p: 260 | for y, h, row in grid.tiles: 261 | newrow = [] 262 | for x, w, tile in row: 263 | output = upscale_pil_patch(model, tile) 264 | scale_factor = output.width // tile.width 265 | newrow.append([x * scale_factor, w * scale_factor, output]) 266 | p.update(1) 267 | newtiles.append([y * scale_factor, h * scale_factor, newrow]) 268 | 269 | newgrid = Grid( 270 | newtiles, 271 | tile_w=grid.tile_w * scale_factor, 272 | tile_h=grid.tile_h * scale_factor, 273 | image_w=grid.image_w * scale_factor, 274 | image_h=grid.image_h * scale_factor, 275 | overlap=grid.overlap * scale_factor, 276 | ) 277 | return combine_grid(newgrid) 278 | 279 | 280 | def tiled_upscale_2( 281 | img: torch.Tensor, 282 | model, 283 | *, 284 | tile_size: int, 285 | tile_overlap: int, 286 | scale: int, 287 | device: torch.device, 288 | desc="Tiled upscale", 289 | disable_progress_bar=False, 290 | ): 291 | # Alternative implementation of `upscale_with_model` originally used by 292 | # SwinIR and ScuNET. It differs from `upscale_with_model` in that tiling and 293 | # weighting is done in PyTorch space, as opposed to `images.Grid` doing it in 294 | # Pillow space without weighting. 295 | 296 | b, c, h, w = img.size() 297 | tile_size = min(tile_size, h, w) 298 | 299 | if tile_size <= 0: 300 | logger.debug("Upscaling %s without tiling", img.shape) 301 | return model(img) 302 | 303 | stride = tile_size - tile_overlap 304 | h_idx_list = list(range(0, h - tile_size, stride)) + [h - tile_size] 305 | w_idx_list = list(range(0, w - tile_size, stride)) + [w - tile_size] 306 | result = torch.zeros( 307 | b, 308 | c, 309 | h * scale, 310 | w * scale, 311 | device=device, 312 | dtype=img.dtype, 313 | ) 314 | weights = torch.zeros_like(result) 315 | logger.debug("Upscaling %s to %s with tiles", img.shape, result.shape) 316 | with tqdm.auto.tqdm(total=len(h_idx_list) * len(w_idx_list), desc=desc, disable=disable_progress_bar) as pbar: 317 | for h_idx in h_idx_list: 318 | for w_idx in w_idx_list: 319 | 320 | # Only move this patch to the device if it's not already there. 
321 | in_patch = img[ 322 | ..., 323 | h_idx: h_idx + tile_size, 324 | w_idx: w_idx + tile_size, 325 | ].to(device=device) 326 | 327 | out_patch = model(in_patch) 328 | 329 | result[ 330 | ..., 331 | h_idx * scale: (h_idx + tile_size) * scale, 332 | w_idx * scale: (w_idx + tile_size) * scale, 333 | ].add_(out_patch) 334 | 335 | out_patch_mask = torch.ones_like(out_patch) 336 | 337 | weights[ 338 | ..., 339 | h_idx * scale: (h_idx + tile_size) * scale, 340 | w_idx * scale: (w_idx + tile_size) * scale, 341 | ].add_(out_patch_mask) 342 | 343 | pbar.update(1) 344 | 345 | output = result.div_(weights) 346 | 347 | return output 348 | 349 | 350 | def upscale_2( 351 | img: Image.Image, 352 | model, 353 | *, 354 | tile_size: int, 355 | tile_overlap: int, 356 | scale: int, 357 | desc: str, 358 | disable_progress_bar: bool, 359 | ): 360 | """ 361 | Convenience wrapper around `tiled_upscale_2` that handles PIL images. 362 | """ 363 | param = get_param(model) 364 | tensor = pil_image_to_torch_bgr(img).to(device=model.device, dtype=param.dtype).unsqueeze(0) # add batch dimension 365 | 366 | with torch.no_grad(): 367 | output = tiled_upscale_2( 368 | tensor, 369 | model, 370 | tile_size=tile_size, 371 | tile_overlap=tile_overlap, 372 | scale=scale, 373 | desc=desc, 374 | device=param.device, 375 | disable_progress_bar=disable_progress_bar, 376 | ) 377 | return torch_bgr_to_pil_image(output) 378 | 379 | 380 | def release_resources_upscaler(): 381 | torch.cuda.empty_cache() 382 | gc.collect() 383 | -------------------------------------------------------------------------------- /stablepy/diffusers_vanilla/preprocessor/main_preprocessor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import PIL 3 | import torch 4 | 5 | from ..utils import release_resources 6 | from .image_utils import ( 7 | HWC3, 8 | resize_image, 9 | apply_gaussian_blur, 10 | recolor_luminance, 11 | recolor_intensity, 12 | ) 13 | from .constans_preprocessor import ( 14 | AUX_TASKS, 15 | TRANSFORMERS_LIB_TASKS, 16 | AUX_BETA_TASKS, 17 | EXTRA_AUX_TASKS, 18 | TASK_AND_PREPROCESSORS, 19 | ) 20 | from ..utils import convert_image_to_numpy_array 21 | import cv2 22 | 23 | 24 | def calculate_new_resolution(input_image_shape, resolution): 25 | H, W = input_image_shape[:2] 26 | H = float(H) 27 | W = float(W) 28 | k = float(resolution) / max(H, W) 29 | H *= k 30 | W *= k 31 | H = int(np.round(H / 64.0)) * 64 32 | W = int(np.round(W / 64.0)) * 64 33 | return (H, W) 34 | 35 | 36 | def standardize_image(np_image, res=1024, spected_res=None) -> np.ndarray: 37 | """ 38 | Ensure the image is in HWC3 format and resize it so that the 39 | largest dimension matches the given resolution, rounding dimensions 40 | to the nearest multiple of 64. 
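
    For example, following the scaling described above with `res=1024`, an
    853x1280 (HxW) input is scaled by `k = 1024 / 1280 = 0.8` and rounded to
    704x1024. If `spected_res` is given and differs from the result, the image is
    additionally resized to that exact (H, W) shape with Lanczos interpolation.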
41 | """ 42 | if np_image is None: 43 | raise ValueError("image must be defined.") 44 | 45 | if not isinstance(np_image, np.ndarray): 46 | np_image = convert_image_to_numpy_array(np_image) 47 | 48 | np_image = HWC3(np_image) 49 | np_image = resize_image(np_image, resolution=res) 50 | 51 | if spected_res is not None: 52 | if (np_image.shape[0], np_image.shape[1]) != spected_res: 53 | np_image = cv2.resize( 54 | np_image, 55 | (spected_res[1], spected_res[0]), 56 | interpolation=cv2.INTER_LANCZOS4 57 | ) 58 | 59 | return np_image 60 | 61 | 62 | def process_basic_task(image: np.ndarray, resolution: int) -> PIL.Image.Image: 63 | """Process basic tasks that require only resizing.""" 64 | image = standardize_image(image, resolution) 65 | return PIL.Image.fromarray(image) 66 | 67 | 68 | class RecolorDetector: 69 | def __call__(self, image=None, gamma_correction=1.0, image_resolution=512, mode="luminance", **kwargs): 70 | """Process the 'recolor' task.""" 71 | if mode == "luminance": 72 | func_c = recolor_luminance 73 | elif mode == "intensity": 74 | func_c = recolor_intensity 75 | else: 76 | raise ValueError("Invalid recolor mode") 77 | 78 | return func_c( 79 | standardize_image(image, image_resolution), thr_a=gamma_correction 80 | ) 81 | 82 | 83 | class BlurDetector(RecolorDetector): 84 | def __call__(self, image=None, image_resolution=512, blur_sigma=5, **kwargs): 85 | """Process the 'tile' task with Gaussian blur.""" 86 | return apply_gaussian_blur( 87 | standardize_image(image, image_resolution), ksize=blur_sigma 88 | ) 89 | 90 | 91 | class Preprocessor: 92 | MODEL_ID = "lllyasviel/Annotators" 93 | 94 | def __init__(self): 95 | self.model = None 96 | self.name = "" 97 | 98 | def _load_aux_model(self, name: str): 99 | """Lazy load models from the `controlnet_aux` library.""" 100 | import controlnet_aux as cnx 101 | 102 | model_map = { 103 | "HED": lambda: cnx.HEDdetector.from_pretrained(self.MODEL_ID), 104 | "Midas": lambda: cnx.MidasDetector.from_pretrained(self.MODEL_ID), 105 | "MLSD": lambda: cnx.MLSDdetector.from_pretrained(self.MODEL_ID), 106 | "Openpose": lambda: cnx.OpenposeDetector.from_pretrained(self.MODEL_ID), 107 | "PidiNet": lambda: cnx.PidiNetDetector.from_pretrained(self.MODEL_ID), 108 | "NormalBae": lambda: cnx.NormalBaeDetector.from_pretrained(self.MODEL_ID), 109 | "Lineart": lambda: cnx.LineartDetector.from_pretrained(self.MODEL_ID), 110 | "LineartAnime": lambda: cnx.LineartAnimeDetector.from_pretrained(self.MODEL_ID), 111 | "Canny": lambda: cnx.CannyDetector(), 112 | "ContentShuffle": lambda: cnx.ContentShuffleDetector(), 113 | } 114 | 115 | if name in model_map: 116 | return model_map[name]() 117 | 118 | raise ValueError(f"Unsupported task name: {name}") 119 | 120 | def _load_transformers_model(self, name: str): 121 | """Lazy load models from the `.transformers_lib.pipelines`.""" 122 | from .transformers_lib.pipelines import ( 123 | DPTDepthEstimator, 124 | UP_ImageSegmentor, 125 | ZoeDepth, 126 | SegFormer, 127 | DepthAnything, 128 | ) 129 | model_map = { 130 | TRANSFORMERS_LIB_TASKS[0]: DPTDepthEstimator, 131 | TRANSFORMERS_LIB_TASKS[1]: UP_ImageSegmentor, 132 | TRANSFORMERS_LIB_TASKS[2]: ZoeDepth, 133 | TRANSFORMERS_LIB_TASKS[3]: SegFormer, 134 | TRANSFORMERS_LIB_TASKS[4]: DepthAnything, 135 | } 136 | if name in model_map: 137 | return model_map[name]() 138 | raise ValueError(f"Unsupported task name: {name}") 139 | 140 | def _load_custom_model(self, name: str): 141 | """Lazy load custom models from specialized modules.""" 142 | if name == AUX_BETA_TASKS[0]: 143 | from 
.controlnet_aux_beta.teed import TEEDdetector 144 | return TEEDdetector() 145 | elif name == AUX_BETA_TASKS[1]: 146 | from .controlnet_aux_beta.anyline import AnylineDetector 147 | return AnylineDetector() 148 | elif name == AUX_BETA_TASKS[2]: 149 | from .controlnet_aux_beta.lineart_standard import LineartStandardDetector 150 | return LineartStandardDetector() 151 | raise ValueError(f"Unsupported task name: {name}") 152 | 153 | def _load_extra_model(self, name: str): 154 | """Lazy load custom models from specialized modules.""" 155 | if name == EXTRA_AUX_TASKS[0]: 156 | return RecolorDetector() 157 | elif name == EXTRA_AUX_TASKS[1]: 158 | return BlurDetector() 159 | raise ValueError(f"Unsupported task name: {name}") 160 | 161 | def to(self, device): 162 | if hasattr(self.model, "to"): 163 | self.model.to(device) 164 | 165 | def load(self, name: str, use_cuda: bool = False) -> None: 166 | """ 167 | Load the specified preprocessor model. 168 | Parameters: 169 | name (str): The name of the preprocessor model to load. 170 | use_cuda (bool, optional): If True, the model will be moved to GPU. Defaults to False. 171 | Raises: 172 | ValueError: If the specified preprocessor name is not recognized. 173 | Notes: 174 | - If the specified model is already loaded, the function will return early. 175 | - The function will release any previously held resources before loading a new model. 176 | - The model can be loaded from different sources based on the name provided. 177 | """ 178 | 179 | if name == self.name: 180 | if use_cuda: 181 | self.to("cuda") 182 | return # Skip if already loaded 183 | 184 | if name in AUX_TASKS: 185 | self.model = self._load_aux_model(name) 186 | elif name in TRANSFORMERS_LIB_TASKS: 187 | self.model = self._load_transformers_model(name) 188 | elif name in AUX_BETA_TASKS: 189 | self.model = self._load_custom_model(name) 190 | elif name in EXTRA_AUX_TASKS: 191 | self.model = self._load_extra_model(name) 192 | else: 193 | raise ValueError(f"Unknown preprocessor name: {name}") 194 | 195 | release_resources() 196 | 197 | self.name = name 198 | 199 | if use_cuda: 200 | self.to("cuda") 201 | 202 | def __call__(self, image: PIL.Image.Image, **kwargs) -> PIL.Image.Image: 203 | """ 204 | Process an image using the preprocessor function or model. 205 | 206 | Args: 207 | image (PIL.Image.Image): The input image. 208 | **kwargs: Additional parameters for preprocessing, which may include: 209 | - image_resolution (int): The proportional maximum resolution of the input image while preserving the aspect ratio (applicable for all tasks). 210 | - detect_resolution (int): The resolution to preprocess to (applicable for all tasks). 211 | - low_threshold (int): Low threshold for edge detection (applicable for Canny). 212 | - high_threshold (int): High threshold for edge detection (applicable for Canny). 213 | - thr_v (float): Threshold for MLSD value detection (applicable for MLSD). 214 | - thr_d (float): Threshold for MLSD distance detection (applicable for MLSD). 215 | - mode (str): Mode for Recolor (e.g., 'intensity' or 'luminance') (applicable for Recolor). 216 | - gamma_correction (float): Gamma correction value for Recolor (applicable for Recolor). 217 | - blur_sigma (int): Sigma value for Blur (applicable for Blur). 218 | - hand_and_face (bool): Whether to include hand and face detection (applicable for Openpose). 219 | - scribble (bool): Whether to use scribble mode (applicable for HED). 220 | - safe (bool): Whether to use safe mode (applicable for PidiNet). 
221 |                 - coarse (bool): Whether to use coarse mode (applicable for Lineart).
222 | 
223 |         Returns:
224 |             PIL.Image.Image: The processed image.
225 |         """
226 | 
227 |         if not self.model:
228 |             raise RuntimeError("No model is loaded. Please call `load()` first.")
229 | 
230 |         if not isinstance(image, np.ndarray):
231 |             image = convert_image_to_numpy_array(image)
232 | 
233 |         image_resolution = kwargs.get("image_resolution", 1024)
234 |         spected_resolution = calculate_new_resolution(image.shape, image_resolution)
235 | 
236 |         if self.name == "Canny":
237 |             image = self._process_canny(image, **kwargs)
238 |         elif self.name == "Midas":
239 |             image = self._process_midas(image, **kwargs)
240 |         else:
241 |             image = self.model(image, **kwargs)
242 | 
243 |         image = standardize_image(image, image_resolution, spected_resolution)
244 | 
245 |         return PIL.Image.fromarray(image)
246 | 
247 |     def _process_canny(self, image: PIL.Image.Image, **kwargs) -> PIL.Image.Image:
248 |         """Process an image using the Canny preprocessor."""
249 |         detect_resolution = kwargs.pop("detect_resolution", None)
250 |         image = np.array(image)
251 |         image = HWC3(image)
252 |         if detect_resolution:
253 |             image = resize_image(image, resolution=detect_resolution)
254 |         return self.model(image, **kwargs)
255 | 
256 |     def _process_midas(self, image: PIL.Image.Image, **kwargs) -> PIL.Image.Image:
257 |         """Process an image using the Midas preprocessor."""
258 |         detect_resolution = kwargs.pop("detect_resolution", 512)
259 |         image = np.array(image)
260 |         image = HWC3(image)
261 |         image = resize_image(image, resolution=detect_resolution)
262 |         return self.model(image)  # extra kwargs are not forwarded to the Midas detector
263 | 
264 | 
265 | def get_preprocessor_params(
266 |     image: np.ndarray,
267 |     task_name: str,
268 |     preprocessor_name: str,
269 |     image_resolution: int,
270 |     preprocess_resolution: int,
271 |     low_threshold: int,
272 |     high_threshold: int,
273 |     value_threshold: float,
274 |     distance_threshold: float,
275 |     gamma_correction: float,
276 |     blur_sigma: int,
277 | ) -> tuple[dict, str]:
278 |     """
279 |     Determine the parameters and model name for preprocessing.
280 | 
281 |     Args:
282 |         image (np.ndarray): The input image.
283 |         task_name (str): The name of the task.
284 |         preprocessor_name (str): The name of the preprocessor.
285 |         image_resolution (int): Target resolution of the processed image (largest side, aspect ratio preserved).
286 |         preprocess_resolution (int): The resolution to preprocess to.
287 |         low_threshold (int): Low threshold for Canny edge detection.
288 |         high_threshold (int): High threshold for Canny edge detection.
289 |         value_threshold (float): Threshold for MLSD value detection.
290 |         distance_threshold (float): Threshold for MLSD distance detection.
291 |         gamma_correction (float): Gamma correction value passed to Recolor as thr_a.
292 |         blur_sigma (int): Sigma value for the Gaussian blur used by the tile task.
293 | 
294 |     Returns:
295 |         tuple[dict, str]: A dictionary of parameters for preprocessing and the model name.
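
    Example:
        For `task_name="canny"` the returned `model_name` is `"Canny"` and the
        parameter dict holds the image, the resolution entries, and the low/high
        Canny thresholds; for `task_name="tile"` the model name is `"Blur"` and
        `blur_sigma` is added instead.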
296 | """ 297 | params_preprocessor = { 298 | "image": image, 299 | "image_resolution": image_resolution, 300 | "detect_resolution": preprocess_resolution, 301 | } 302 | model_name = None 303 | 304 | if task_name in ["canny", "sdxl_canny_t2i"]: 305 | params_preprocessor.update({ 306 | "low_threshold": low_threshold, 307 | "high_threshold": high_threshold, 308 | }) 309 | model_name = "Canny" 310 | elif task_name in ["openpose", "sdxl_openpose_t2i"]: 311 | params_preprocessor["hand_and_face"] = not ("core" in preprocessor_name) 312 | model_name = "Openpose" 313 | elif task_name in ["depth", "sdxl_depth-midas_t2i"]: 314 | model_name = preprocessor_name 315 | elif task_name == "mlsd": 316 | params_preprocessor.update({ 317 | "thr_v": value_threshold, 318 | "thr_d": distance_threshold, 319 | }) 320 | model_name = "MLSD" 321 | elif task_name in ["scribble", "sdxl_sketch_t2i"]: 322 | if "HED" in preprocessor_name: 323 | params_preprocessor["scribble"] = False 324 | model_name = "HED" 325 | elif "TEED" in preprocessor_name: 326 | model_name = "TEED" 327 | else: 328 | params_preprocessor["safe"] = False 329 | model_name = "PidiNet" 330 | elif task_name == "softedge": 331 | if "HED" in preprocessor_name: 332 | params_preprocessor["scribble"] = "safe" in preprocessor_name 333 | model_name = "HED" 334 | elif "TEED" in preprocessor_name: 335 | model_name = "TEED" 336 | else: 337 | params_preprocessor["safe"] = "safe" in preprocessor_name 338 | model_name = "PidiNet" 339 | elif task_name == "segmentation": 340 | model_name = preprocessor_name 341 | elif task_name == "normalbae": 342 | model_name = "NormalBae" 343 | elif task_name in ["lineart", "lineart_anime", "sdxl_lineart_t2i"]: 344 | if preprocessor_name in ["Lineart standard", "Anyline"]: 345 | model_name = preprocessor_name 346 | else: 347 | model_name = "LineartAnime" if "anime" in preprocessor_name.lower() else "Lineart" 348 | if "coarse" in preprocessor_name: 349 | params_preprocessor["coarse"] = "coarse" in preprocessor_name 350 | elif task_name == "shuffle": 351 | params_preprocessor.pop("detect_resolution", None) 352 | model_name = preprocessor_name 353 | elif task_name == "recolor": 354 | if "intensity" in preprocessor_name: 355 | params_preprocessor["mode"] = "intensity" 356 | else: 357 | params_preprocessor["mode"] = "luminance" 358 | params_preprocessor["gamma_correction"] = gamma_correction 359 | model_name = "Recolor" 360 | elif task_name == "tile": 361 | params_preprocessor["blur_sigma"] = blur_sigma 362 | model_name = "Blur" 363 | 364 | return params_preprocessor, model_name 365 | --------------------------------------------------------------------------------