├── .gitignore
├── requirements.txt
├── examples
│   ├── base.png
│   ├── qr_code.png
│   ├── controlnet.png
│   ├── inpainting.png
│   ├── change_clothes.png
│   ├── reference_only.png
│   └── remove_something.png
├── pyproject.toml
├── .github
│   └── workflows
│       └── publish.yml
├── README.md
├── pipelines
│   ├── __init__.py
│   └── jannchie.py
└── __init__.py

/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | compel
2 | diffusers
3 | numpy
--------------------------------------------------------------------------------
/examples/base.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jannchie/ComfyUI-J/HEAD/examples/base.png
--------------------------------------------------------------------------------
/examples/qr_code.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jannchie/ComfyUI-J/HEAD/examples/qr_code.png
--------------------------------------------------------------------------------
/examples/controlnet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jannchie/ComfyUI-J/HEAD/examples/controlnet.png
--------------------------------------------------------------------------------
/examples/inpainting.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jannchie/ComfyUI-J/HEAD/examples/inpainting.png
--------------------------------------------------------------------------------
/examples/change_clothes.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jannchie/ComfyUI-J/HEAD/examples/change_clothes.png
--------------------------------------------------------------------------------
/examples/reference_only.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jannchie/ComfyUI-J/HEAD/examples/reference_only.png
--------------------------------------------------------------------------------
/examples/remove_something.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jannchie/ComfyUI-J/HEAD/examples/remove_something.png
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "comfyui-j"
3 | description = "This is a completely different set of nodes than Comfy's own KSampler series. This set of nodes is based on Diffusers, which makes it easier to import models, apply prompts with weights, inpaint, reference only, controlnet, etc."
4 | version = "1.1.0" 5 | license = "LICENSE" 6 | dependencies = ["compel", "diffusers", "numpy"] 7 | 8 | [project.urls] 9 | Repository = "https://github.com/Jannchie/ComfyUI-J" 10 | # Used by Comfy Registry https://comfyregistry.org 11 | 12 | [tool.comfy] 13 | PublisherId = "" 14 | DisplayName = "ComfyUI-J" 15 | Icon = "" 16 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish to Comfy registry 2 | on: 3 | workflow_dispatch: 4 | push: 5 | branches: 6 | - main 7 | paths: 8 | - "pyproject.toml" 9 | 10 | permissions: 11 | issues: write 12 | 13 | jobs: 14 | publish-node: 15 | name: Publish Custom Node to registry 16 | runs-on: ubuntu-latest 17 | if: ${{ github.repository_owner == 'Jannchie' }} 18 | steps: 19 | - name: Check out code 20 | uses: actions/checkout@v4 21 | - name: Publish Custom Node 22 | uses: Comfy-Org/publish-node-action@v1 23 | with: 24 | ## Add your own personal access token to your Github Repository secrets and reference it here. 25 | personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ComfyUI-J 2 | 3 | ## Introduction 4 | 5 | Jannchie's ComfyUI custom nodes. 6 | 7 | This is a completely different set of nodes than Comfy's own KSampler series. 8 | This set of nodes is based on Diffusers, which makes it easier to import models, apply prompts with weights, inpaint, reference only, controlnet, etc. 9 | 10 | ## Installation 11 | 12 | In the `custom_nodes` directory, run 13 | 14 | ```bash 15 | git clone https://github.com/Jannchie/ComfyUI-J 16 | cd ComfyUI-J 17 | pip install -r requirements.txt 18 | ``` 19 | 20 | ## Examples 21 | 22 | ### Base Usage of Jannchie's Diffusers Pipeline 23 | 24 | You only have to deal with 4 nodes. The default comfy workflow uses 7 nodes to achieve the same result. 25 | 26 | ![Base Usage](./examples/base.png) 27 | 28 | ### Reference Only with Jannchie's Diffusers Pipeline 29 | 30 | ref_only supports two modes: attn and attn + adain, and can adjust the style fidelity parameter to control the style. 31 | 32 | ![Reference only](./examples/reference_only.png) 33 | 34 | ### ControlNet with Jannchie's Diffusers Pipeline 35 | 36 | ContorlNet is also easier to use. A DiffusersControlnetLoader node is provided for loading models. This node automatically detects if the corresponding ControlNet has been downloaded locally, and pulls the model from the huggingface if it has not. 37 | 38 | ![ControlNet](./examples/controlnet.png) 39 | 40 | ## Inpainting with Jannchie's Diffusers Pipeline 41 | 42 | ![Inpainting](./examples/inpainting.png) 43 | 44 | ## Remove something with Jannchie's Diffusers Pipeline 45 | 46 | ![Remove something](./examples/remove_something.png) 47 | 48 | ## Change Clothes with Jannchie's Diffusers Pipeline 49 | 50 | This is a composite application of diffusers pipeline custom node. Includes: 51 | 52 | - Reference only 53 | - ControlNet 54 | - Inpainting 55 | - Textual Inversion 56 | 57 | This is a demonstration of a simple workflow for properly dressing a character. 58 | 59 | A checkpoint for stablediffusion 1.5 is all your need. But for full automation, I use the `Comfyui_segformer_b2_clothes` custom node for generating masks. you can draw your own masks without it. 
60 | 
61 | ![Change Clothes](./examples/change_clothes.png)
62 | 
63 | ### QR Code
64 | 
65 | ![QR Code](./examples/qr_code.png)
66 | 
67 | ## FAQ
68 | 
69 | ### Why Diffusers?
70 | 
71 | Unlike Web UI and Comfy, Diffusers is an image generation toolkit aimed at researchers. It has a large ecosystem, a clearer code structure, and a simpler interface.
72 | 
73 | ComfyUI's KSampler is nice, but some of its features are incomplete or hard to access. It's 2024 and I still haven't found a good Reference Only implementation; Inpaint also works differently than I thought it would; I don't understand why ControlNet nodes need a CLIP input; and I don't want to deal with what's going on with Latent: please just return an Image instead of making me decode it with a VAE. Diffusers provides a pipeline wrapper that makes generation a lot easier.
74 | 
75 | ### Why ComfyUI?
76 | 
77 | Combining research results is still not an easy task, and Comfy is good at composing nodes and sharing those combinations with others. While debugging custom nodes as a developer can be a pain, using Comfy makes it faster to verify ideas and share them.
78 | 
79 | ## TODO
80 | 
81 | - [ ] Add LoRA support
82 | - [ ] Stable Diffusion XL support
83 | 
--------------------------------------------------------------------------------
/pipelines/__init__.py:
--------------------------------------------------------------------------------
1 | import contextlib
2 | 
3 | from diffusers import AutoencoderKL, AutoencoderTiny
4 | from diffusers.schedulers import (
5 |     DEISMultistepScheduler,
6 |     DPMSolverMultistepScheduler,
7 |     DPMSolverSinglestepScheduler,
8 |     EulerAncestralDiscreteScheduler,
9 |     EulerDiscreteScheduler,
10 |     HeunDiscreteScheduler,
11 |     KDPM2AncestralDiscreteScheduler,
12 |     KDPM2DiscreteScheduler,
13 |     LMSDiscreteScheduler,
14 |     UniPCMultistepScheduler,
15 | )
16 | 
17 | import comfy.model_management
18 | import folder_paths
19 | 
20 | from .jannchie import *
21 | 
22 | schedulers = {
23 |     "DPM++ 2M": DPMSolverMultistepScheduler(),
24 |     "DPM++ 2M Karras": DPMSolverMultistepScheduler(use_karras_sigmas=True),
25 |     "DPM++ 2M SDE": DPMSolverMultistepScheduler(algorithm_type="sde-dpmsolver++"),
26 |     "DPM++ 2M SDE Karras": DPMSolverMultistepScheduler(
27 |         use_karras_sigmas=True, algorithm_type="sde-dpmsolver++"
28 |     ),
29 |     "DPM++ SDE": DPMSolverSinglestepScheduler(),
30 |     "DPM++ SDE Karras": DPMSolverSinglestepScheduler(use_karras_sigmas=True),
31 |     "DPM2": KDPM2DiscreteScheduler(),
32 |     "DPM2 Karras": KDPM2DiscreteScheduler(use_karras_sigmas=True),
33 |     "DPM2 a": KDPM2AncestralDiscreteScheduler(),
34 |     "DPM2 a Karras": KDPM2AncestralDiscreteScheduler(use_karras_sigmas=True),
35 |     "Euler": EulerDiscreteScheduler(),
36 |     "Euler a": EulerAncestralDiscreteScheduler(),
37 |     "Heun": HeunDiscreteScheduler(),
38 |     "LMS": LMSDiscreteScheduler(),
39 |     "LMS Karras": LMSDiscreteScheduler(use_karras_sigmas=True),
40 |     "DEIS": DEISMultistepScheduler(),
41 |     "UniPC": UniPCMultistepScheduler(),
42 | }
43 | 
44 | 
45 | class PipelineWrapper:
46 | 
47 |     def __init__(
48 |         self,
49 |         ckpt_path: str,
50 |         vae_path: str = None,
51 |         scheduler_name: str = None,
52 |         use_tiny_vae: bool = False,
53 |     ):
54 |         scheduler = schedulers.get(scheduler_name)
55 |         device = comfy.model_management.get_torch_device()
56 |         vae_dtype = comfy.model_management.vae_dtype()
57 |         unet_dtype = comfy.model_management.unet_dtype()
58 |         if ckpt_path.endswith(".safetensors"):
59 |             self.pipeline = JannchiePipeline.from_single_file(
60 |                 ckpt_path,
61 |                 torch_dtype=unet_dtype,
62 |
cache_dir=folder_paths.get_folder_paths("diffusers")[0], 63 | use_safetensors=True, 64 | ) 65 | else: 66 | self.pipeline = JannchiePipeline.from_pretrained( 67 | ckpt_path, 68 | torch_dtype=unet_dtype, 69 | cache_dir=folder_paths.get_folder_paths("diffusers")[0], 70 | use_safetensors=ckpt_path.endswith(".safetensors"), 71 | ) 72 | 73 | if use_tiny_vae: 74 | self.pipeline.vae = AutoencoderTiny.from_pretrained("madebyollin/taesd").to( 75 | device=self.pipeline.device, dtype=vae_dtype 76 | ) 77 | 78 | elif vae_path: 79 | if vae_path.endswith(".safetensors"): 80 | self.pipeline.vae = AutoencoderKL.from_single_file( 81 | vae_path, 82 | torch_dtype=vae_dtype, 83 | cache_dir=folder_paths.get_folder_paths("diffusers"), 84 | use_safetensors=True, 85 | ) 86 | else: 87 | self.pipeline.vae = AutoencoderKL.from_pretrained( 88 | vae_path, 89 | torch_dtype=vae_dtype, 90 | cache_dir=folder_paths.get_folder_paths("diffusers"), 91 | use_safetensors=vae_path.endswith(".safetensors"), 92 | ) 93 | 94 | if scheduler: 95 | self.pipeline.scheduler = scheduler 96 | self.pipeline.to(device) 97 | self.pipeline.vae.to(vae_dtype) 98 | self.pipeline.safety_checker = None 99 | with contextlib.suppress(Exception): 100 | self.pipeline.enable_xformers_memory_efficient_attention() 101 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import gc 3 | import random 4 | from collections import Counter 5 | 6 | import numpy as np 7 | import torch 8 | from compel import Compel, DiffusersTextualInversionManager 9 | from diffusers import StableDiffusionPipeline 10 | from diffusers.models import ControlNetModel 11 | from diffusers.utils.torch_utils import randn_tensor 12 | from PIL import Image 13 | 14 | import comfy.model_management 15 | import folder_paths 16 | from comfy.utils import ProgressBar 17 | 18 | from .pipelines import ControlNetUnit, ControlNetUnits, PipelineWrapper, schedulers 19 | 20 | 21 | def resize_with_padding(image: Image.Image, target_size: tuple[int, int]): 22 | # 打开图像 23 | 24 | # 计算缩放比例 25 | width_ratio = target_size[0] / image.width 26 | height_ratio = target_size[1] / image.height 27 | ratio = min(width_ratio, height_ratio) 28 | 29 | # 计算调整后的尺寸 30 | new_width = int(image.width * ratio) 31 | new_height = int(image.height * ratio) 32 | 33 | # 缩放图像 34 | image = image.resize((new_width, new_height), Image.Resampling.LANCZOS) 35 | 36 | # 创建黑色背景图像 37 | background = Image.new("RGBA", target_size, (0, 0, 0, 0)) 38 | 39 | # 计算粘贴位置 40 | position = ((target_size[0] - new_width) // 2, (target_size[1] - new_height) // 2) 41 | 42 | # 粘贴调整后的图像到黑色背景上 43 | background.paste(image, position) 44 | return background 45 | 46 | 47 | def comfy_image_to_pil(image: torch.Tensor): 48 | image = image.squeeze(0) # (1, H, W, C) => (H, W, C) 49 | image = image * 255 # 0 ~ 1 => 0 ~ 255 50 | image = image.to(dtype=torch.uint8) # float32 => uint8 51 | return Image.fromarray(image.numpy()) # tensor => PIL.Image.Image 52 | 53 | 54 | def get_prompt_embeds(pipe, prompt, negative_prompt): 55 | textual_inversion_manager = DiffusersTextualInversionManager(pipe) 56 | compel = Compel( 57 | tokenizer=pipe.tokenizer, 58 | text_encoder=pipe.text_encoder, 59 | textual_inversion_manager=textual_inversion_manager, 60 | truncate_long_prompts=False, 61 | ) 62 | 63 | prompt_embeds = compel.build_conditioning_tensor(prompt) 64 | negative_prompt_embeds = compel.build_conditioning_tensor(negative_prompt) 65 | [ 
66 |         prompt_embeds,
67 |         negative_prompt_embeds,
68 |     ] = compel.pad_conditioning_tensors_to_same_length(
69 |         [prompt_embeds, negative_prompt_embeds]
70 |     )
71 |     return prompt_embeds, negative_prompt_embeds
72 | 
73 | 
74 | def latents_to_img_tensor(pipeline, latents):
75 |     # 1. The input latents are a tensor in the -1 ~ 1 range.
76 |     # 2. Scale them first.
77 |     scaled_latents = latents / pipeline.vae.config.scaling_factor
78 |     # Convert to the VAE dtype.
79 |     scaled_latents = scaled_latents.to(dtype=comfy.model_management.vae_dtype())
80 |     # print(scaled_latents.dtype, pipeline.vae.dtype)
81 |     # 3. Decode; the result is a tensor in the -1 ~ 1 range.
82 |     dec_tensor = pipeline.vae.decode(scaled_latents, return_dict=False)[0]
83 |     # 4. Rescale to the 0 ~ 1 range.
84 |     dec_images = pipeline.image_processor.postprocess(
85 |         dec_tensor,
86 |         output_type="pt",
87 |         do_denormalize=[True for _ in range(scaled_latents.shape[0])],
88 |     )
89 |     # 5. Convert to a float32 tensor.
90 |     res = torch.nan_to_num(dec_images).to(dtype=torch.float32)
91 |     # 6. Move the channel dimension to the end.
92 |     # res shape torch.Size([1, 3, 512, 512]) => torch.Size([1, 512, 512, 3])
93 |     res = res.permute(0, 2, 3, 1)
94 |     return res
95 | 
96 | 
97 | def latents_to_mask_tensor(pipeline, latents):
98 |     # 1. The input latents are a tensor in the -1 ~ 1 range.
99 |     # 2. Scale them first.
100 |     scaled_latents = latents / pipeline.vae.config.scaling_factor
101 |     # 3. Decode; the result is a tensor in the -1 ~ 1 range.
102 |     dec_tensor = pipeline.vae.decode(scaled_latents, return_dict=False)[0]
103 |     # 4. Rescale to the 0 ~ 1 range.
104 |     dec_images = pipeline.mask_processor.postprocess(
105 |         dec_tensor,
106 |         output_type="pt",
107 |     )
108 |     # 5. Convert to a float32 tensor.
109 |     res = torch.nan_to_num(dec_images).to(dtype=torch.float32)
110 |     # 6. Move the channel dimension to the end.
111 |     # res shape torch.Size([1, 3, 512, 512]) => torch.Size([1, 512, 512, 3])
112 |     res = res.permute(0, 2, 3, 1)
113 |     return res
114 | 
115 | 
116 | def prepare_latents(
117 |     pipe: StableDiffusionPipeline,
118 |     batch_size: int,
119 |     height: int,
120 |     width: int,
121 |     dtype: torch.dtype,
122 |     device: torch.device,
123 |     generator: torch.Generator,
124 |     latents=None,
125 | ):
126 |     shape = (
127 |         batch_size,
128 |         pipe.unet.config.in_channels,
129 |         height // pipe.vae_scale_factor,
130 |         width // pipe.vae_scale_factor,
131 |     )
132 |     if isinstance(generator, list) and len(generator) != batch_size:
133 |         raise ValueError(
134 |             f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
135 |             f" size of {batch_size}. Make sure the batch size matches the length of the generators."
136 | ) 137 | 138 | if latents is None: 139 | latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) 140 | else: 141 | latents = latents.to(device) 142 | 143 | # scale the initial noise by the standard deviation required by the scheduler 144 | latents = latents * pipe.scheduler.init_noise_sigma 145 | return latents 146 | 147 | 148 | def prepare_image( 149 | pipeline: StableDiffusionPipeline, 150 | seed=47, 151 | batch_size=1, 152 | height=512, 153 | width=512, 154 | ): 155 | generator = torch.Generator() 156 | generator.manual_seed(seed) 157 | latents = prepare_latents( 158 | pipe=pipeline, 159 | batch_size=batch_size, 160 | height=height, 161 | width=width, 162 | generator=generator, 163 | device=comfy.model_management.get_torch_device(), 164 | dtype=comfy.model_management.VAE_DTYPE, 165 | ) 166 | return latents_to_img_tensor(pipeline, latents) 167 | 168 | 169 | class GetFilledColorImage: 170 | RETURN_TYPES = ("IMAGE",) 171 | FUNCTION = "run" 172 | CATEGORY = "Jannchie" 173 | 174 | @classmethod 175 | def INPUT_TYPES(cls): 176 | return { 177 | "required": { 178 | "width": ( 179 | "INT", 180 | { 181 | "default": 512, 182 | "min": 0, 183 | "max": 8192, 184 | "step": 64, 185 | "display": "number", 186 | }, 187 | ), 188 | "height": ( 189 | "INT", 190 | { 191 | "default": 512, 192 | "min": 0, 193 | "max": 8192, 194 | "step": 64, 195 | "display": "number", 196 | }, 197 | ), 198 | "red": ( 199 | "FLOAT", 200 | { 201 | "default": 0.0, 202 | "min": 0.0, 203 | "max": 1.0, 204 | "step": 0.01, 205 | "display": "number", 206 | }, 207 | ), 208 | "green": ( 209 | "FLOAT", 210 | { 211 | "default": 0.0, 212 | "min": 0.0, 213 | "max": 1.0, 214 | "step": 0.01, 215 | "display": "number", 216 | }, 217 | ), 218 | "blue": ( 219 | "FLOAT", 220 | { 221 | "default": 0.0, 222 | "min": 0.0, 223 | "max": 1.0, 224 | "step": 0.01, 225 | "display": "number", 226 | }, 227 | ), 228 | }, 229 | } 230 | 231 | def run(self, width, height, red, green, blue): 232 | image = torch.tensor(np.full((height, width, 3), (red, green, blue))) 233 | # 再转换成 0 - 1 之间的浮点数 234 | image = image 235 | image = image.unsqueeze(0) 236 | return (image,) 237 | 238 | 239 | class DiffusersCompelPromptEmbedding: 240 | CATEGORY = "Jannchie" 241 | FUNCTION = "run" 242 | RETURN_TYPES = ("DIFFUSERS_PROMPT_EMBEDDING", "DIFFUSERS_PROMPT_EMBEDDING") 243 | RETURN_NAMES = ("positive prompt embedding", "negative prompt embedding") 244 | 245 | @classmethod 246 | def INPUT_TYPES(cls): 247 | return { 248 | "required": { 249 | "pipeline": ("DIFFUSERS_PIPELINE",), 250 | "positive_prompt": ( 251 | "STRING", 252 | { 253 | "multiline": True, 254 | "default": "(masterpiece)1.2, (best quality)1.4", 255 | }, 256 | ), 257 | "negative_prompt": ("STRING", {"multiline": True, "default": ""}), 258 | } 259 | } 260 | 261 | def run( 262 | self, 263 | pipeline: StableDiffusionPipeline, 264 | positive_prompt: str, 265 | negative_prompt: str, 266 | ): 267 | return get_prompt_embeds(pipeline, positive_prompt, negative_prompt) 268 | 269 | 270 | class DiffusersTextureInversionLoader: 271 | CATEGORY = "Jannchie" 272 | FUNCTION = "run" 273 | RETURN_TYPES = ("DIFFUSERS_PIPELINE",) 274 | RETURN_NAMES = ("pipeline",) 275 | 276 | @classmethod 277 | def INPUT_TYPES(cls): 278 | return { 279 | "required": { 280 | "pipeline": ("DIFFUSERS_PIPELINE",), 281 | "texture_inversion": (folder_paths.get_filename_list("embeddings"),), 282 | }, 283 | } 284 | 285 | def run(self, pipeline: StableDiffusionPipeline, texture_inversion: str): 286 | with contextlib.suppress(Exception): 287 | 
path = folder_paths.get_full_path("embeddings", texture_inversion) 288 | token = texture_inversion.split(".")[0] 289 | pipeline.load_textual_inversion(path, token=token) 290 | print(f"Loaded {texture_inversion}") 291 | return (pipeline,) 292 | 293 | 294 | class GetAverageColorFromImage: 295 | CATEGORY = "Jannchie" 296 | FUNCTION = "run" 297 | RETURN_TYPES = ("FLOAT", "FLOAT", "FLOAT") 298 | RETURN_NAMES = ("red", "green", "blue") 299 | 300 | @classmethod 301 | def INPUT_TYPES(cls): 302 | return { 303 | "required": { 304 | "image": ("IMAGE",), 305 | "average": (("mean", "mode"),), 306 | }, 307 | "optional": { 308 | "mask": ("MASK",), 309 | }, 310 | } 311 | 312 | def run(self, image: torch.Tensor, average: str, mask: torch.Tensor = None): 313 | if mask is not None: 314 | assert ( 315 | mask.ndim == image.ndim - 1 316 | ), "Mask dimensions must be one less than image dimensions." 317 | mask = mask.unsqueeze(3) # Unsqueeze to match (B, 1, H, W) 318 | if mask is not None and torch.sum(mask) == 0: 319 | mask = None 320 | if average == "mean": 321 | return self.run_avg(image, mask) 322 | elif average == "mode": 323 | return self.run_mode(image, mask) 324 | else: 325 | raise ValueError("average must be either 'mean' or 'mode'") 326 | 327 | def run_avg(self, image: torch.Tensor, mask: torch.Tensor = None): 328 | masked_image = image * mask if mask is not None else image 329 | 330 | pixel_sum = torch.sum(masked_image, dim=(1, 2)) 331 | if mask is not None: 332 | pixel_count = torch.sum(mask, dim=(1, 2)).unsqueeze(1) 333 | else: 334 | pixel_count = torch.tensor(image.shape[1] * image.shape[2]).unsqueeze(0) 335 | average_rgb = pixel_sum / pixel_count 336 | average_rgb = torch.round(average_rgb * 255) 337 | return tuple(average_rgb.squeeze().int().tolist()) 338 | 339 | def run_mode(self, image: torch.Tensor, mask: torch.Tensor = None): 340 | if mask is not None: 341 | image = image * mask 342 | 343 | # Flatten the image to a 2D matrix where each row is a color 344 | flattened_image = image.view(-1, image.shape[-1]) 345 | 346 | # If mask is provided, remove rows where mask is zero 347 | if mask is not None: 348 | flattened_mask = mask.view(-1, 1) 349 | flattened_image = flattened_image[flattened_mask.squeeze() > 0] 350 | 351 | # Convert the pixel values to a format that can be efficiently counted 352 | unique_colors, counts = torch.unique(flattened_image, return_counts=True, dim=0) 353 | 354 | # Find the most frequent color 355 | max_idx = torch.argmax(counts) 356 | mode_rgb = unique_colors[max_idx] 357 | 358 | mode_rgb = torch.round(mode_rgb * 255) 359 | return tuple(mode_rgb.int().tolist()) 360 | 361 | 362 | class DiffusersXLPipeline: 363 | CATEGORY = "Jannchie" 364 | FUNCTION = "run" 365 | RETURN_TYPES = ("DIFFUSERS_PIPELINE",) 366 | RETURN_NAMES = ("pipeline",) 367 | 368 | @classmethod 369 | def INPUT_TYPES(cls): 370 | return { 371 | "required": { 372 | "ckpt_name": ([],), 373 | }, 374 | "optional": { 375 | "vae_name": ( 376 | folder_paths.get_filename_list("vae") + ["-"], 377 | {"default": "-"}, 378 | ), 379 | "scheduler_name": ( 380 | list(schedulers.keys()) + ["-"], 381 | { 382 | "default": "-", 383 | }, 384 | ), 385 | "use_tiny_vae": ( 386 | ["disable", "enable"], 387 | { 388 | "default": "disable", 389 | }, 390 | ), 391 | }, 392 | } 393 | 394 | def run( 395 | self, 396 | ckpt_name: str, 397 | vae_name: str = None, 398 | scheduler_name: str = None, 399 | use_tiny_vae: str = "disable", 400 | ): 401 | ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) 402 | if ckpt_path is None: 403 
| ckpt_path = ckpt_name 404 | if vae_name == "-": 405 | vae_path = None 406 | else: 407 | vae_path = folder_paths.get_full_path("vae", vae_name) 408 | if scheduler_name == "-": 409 | scheduler_name = None 410 | 411 | self.pipeline_wrapper = PipelineWrapper( 412 | ckpt_path, 413 | vae_path, 414 | scheduler_name, 415 | pipeline=StableDiffusionPipeline, 416 | use_tiny_vae=use_tiny_vae == "enable", 417 | ) 418 | return (self.pipeline_wrapper.pipeline,) 419 | 420 | 421 | class DiffusersPipeline: 422 | CATEGORY = "Jannchie" 423 | FUNCTION = "run" 424 | RETURN_TYPES = ("DIFFUSERS_PIPELINE",) 425 | RETURN_NAMES = ("pipeline",) 426 | 427 | @classmethod 428 | def INPUT_TYPES(cls): 429 | return { 430 | "required": { 431 | "ckpt_name": (folder_paths.get_filename_list("checkpoints"),), 432 | }, 433 | "optional": { 434 | "vae_name": ( 435 | folder_paths.get_filename_list("vae") + ["-"], 436 | {"default": "-"}, 437 | ), 438 | "scheduler_name": ( 439 | list(schedulers.keys()) + ["-"], 440 | { 441 | "default": "-", 442 | }, 443 | ), 444 | "use_tiny_vae": ( 445 | ["disable", "enable"], 446 | { 447 | "default": "disable", 448 | }, 449 | ), 450 | }, 451 | } 452 | 453 | def run( 454 | self, 455 | ckpt_name: str, 456 | vae_name: str = None, 457 | scheduler_name: str = None, 458 | use_tiny_vae: str = "disable", 459 | ): 460 | torch.cuda.empty_cache() 461 | gc.collect() 462 | ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) 463 | if ckpt_path is None: 464 | ckpt_path = ckpt_name 465 | if vae_name == "-": 466 | vae_path = None 467 | else: 468 | vae_path = folder_paths.get_full_path("vae", vae_name) 469 | if scheduler_name == "-": 470 | scheduler_name = None 471 | 472 | self.pipeline_wrapper = PipelineWrapper( 473 | ckpt_path, vae_path, scheduler_name, use_tiny_vae=use_tiny_vae == "enable" 474 | ) 475 | return (self.pipeline_wrapper.pipeline,) 476 | 477 | 478 | class DiffusersPrepareLatents: 479 | CATEGORY = "Jannchie" 480 | FUNCTION = "run" 481 | RETURN_TYPES = ("LATENT",) 482 | RETURN_NAMES = ("latents",) 483 | 484 | @classmethod 485 | def INPUT_TYPES(cls): 486 | return { 487 | "required": { 488 | "pipeline": ("DIFFUSERS_PIPELINE",), 489 | "batch_size": ("INT", {"default": 1, "min": 1, "max": 16, "step": 1}), 490 | "height": ("INT", {"default": 512, "min": 0, "max": 8192, "step": 64}), 491 | "width": ("INT", {"default": 512, "min": 0, "max": 8192, "step": 64}), 492 | }, 493 | "optional": { 494 | "latents": ("LATENT", {"default": None}), 495 | "seed": ( 496 | "INT", 497 | {"default": None, "min": 0, "step": 1, "max": 999999999}, 498 | ), 499 | }, 500 | } 501 | 502 | def run( 503 | self, 504 | pipeline: StableDiffusionPipeline, 505 | batch_size: int = 1, 506 | height: int = 512, 507 | width: int = 512, 508 | latents: torch.Tensor | None = None, 509 | seed: int | None = None, 510 | ): 511 | if seed is None: 512 | seed = random.randint(0, 999999999) 513 | device = comfy.model_management.get_torch_device() 514 | generator = torch.Generator(device) 515 | generator.manual_seed(seed) 516 | latents = prepare_latents( 517 | pipe=pipeline, 518 | batch_size=batch_size, 519 | height=height, 520 | width=width, 521 | dtype=comfy.model_management.vae_dtype(), 522 | device=device, 523 | generator=generator, 524 | latents=latents, 525 | ) 526 | return (latents,) 527 | 528 | 529 | class DiffusersDecoder: 530 | CATEGORY = "Jannchie" 531 | FUNCTION = "run" 532 | RETURN_TYPES = ("IMAGE",) 533 | RETURN_NAMES = ("images",) 534 | 535 | @classmethod 536 | def INPUT_TYPES(cls): 537 | return { 538 | "required": { 539 | 
"pipeline": ("DIFFUSERS_PIPELINE",), 540 | "latents": ("LATENT",), 541 | }, 542 | } 543 | 544 | def run(self, pipeline: StableDiffusionPipeline, latents: torch.Tensor): 545 | res = latents_to_img_tensor(pipeline, latents) 546 | return (res,) 547 | 548 | 549 | # 'https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth' 550 | 551 | controlnet_list = [ 552 | "canny", 553 | "openpose", 554 | "depth", 555 | "tile", 556 | "ip2p", 557 | "shuffle", 558 | "inpaint", 559 | "lineart", 560 | "mlsd", 561 | "normalbae", 562 | "scribble", 563 | "seg", 564 | "softedge", 565 | "lineart_anime", 566 | "other", 567 | ] 568 | 569 | 570 | class DiffusersControlNetLoader: 571 | CATEGORY = "Jannchie" 572 | FUNCTION = "run" 573 | RETURN_TYPES = ("DIFFUSERS_CONTROLNET",) 574 | RETURN_NAMES = ("controlnet",) 575 | 576 | @classmethod 577 | def INPUT_TYPES(cls): 578 | return { 579 | "required": { 580 | "controlnet_model_name": (controlnet_list,), 581 | }, 582 | "optional": { 583 | "controlnet_model_file": (folder_paths.get_filename_list("controlnet"),) 584 | }, 585 | } 586 | 587 | def run(self, controlnet_model_name: str, controlnet_model_file: str = ""): 588 | file_list = folder_paths.get_filename_list("controlnet") 589 | if controlnet_model_name == "other": 590 | controlnet_model_path = folder_paths.get_full_path( 591 | "controlnet", controlnet_model_file 592 | ) 593 | else: 594 | if controlnet_model_name == "depth": 595 | file_name = f"control_v11f1p_sd15_{controlnet_model_name}.pth" 596 | elif controlnet_model_name == "tile": 597 | file_name = f"control_v11f1e_sd15_{controlnet_model_name}.pth" 598 | else: 599 | file_name = f"control_v11p_sd15_{controlnet_model_name}.pth" 600 | controlnet_model_path = next( 601 | ( 602 | folder_paths.get_full_path("controlnet", file) 603 | for file in file_list 604 | if file_name in file 605 | ), 606 | None, 607 | ) 608 | if controlnet_model_path is None: 609 | controlnet_model_path = f"https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/{file_name}" 610 | controlnet = ControlNetModel.from_single_file( 611 | controlnet_model_path, 612 | cache_dir=folder_paths.get_folder_paths("controlnet")[0], 613 | ).to( 614 | device=comfy.model_management.get_torch_device(), 615 | dtype=comfy.model_management.unet_dtype(), 616 | ) 617 | return (controlnet,) 618 | 619 | 620 | class DiffusersControlNetUnit: 621 | CATEGORY = "Jannchie" 622 | FUNCTION = "run" 623 | RETURN_TYPES = ("CONTROLNET_UNIT",) 624 | RETURN_NAMES = ("controlnet unit",) 625 | 626 | @classmethod 627 | def INPUT_TYPES(cls): 628 | return { 629 | "required": { 630 | "controlnet": ("DIFFUSERS_CONTROLNET",), 631 | "image": ("IMAGE",), 632 | "scale": ( 633 | "FLOAT", 634 | {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.1}, 635 | ), 636 | "start": ( 637 | "FLOAT", 638 | {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.1}, 639 | ), 640 | "end": ( 641 | "FLOAT", 642 | {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.1}, 643 | ), 644 | }, 645 | } 646 | 647 | def run( 648 | self, 649 | controlnet: ControlNetModel, 650 | image: torch.Tensor, 651 | scale: float, 652 | start: float, 653 | end: float, 654 | ): 655 | unit = ControlNetUnit( 656 | controlnet=controlnet, 657 | image=comfy_image_to_pil(image), 658 | scale=scale, 659 | start=start, 660 | end=end, 661 | ) 662 | return ((unit,),) 663 | 664 | 665 | class DiffusersControlNetUnitStack: 666 | CATEGORY = "Jannchie" 667 | FUNCTION = "run" 668 | RETURN_TYPES = ("CONTROLNET_UNIT",) 669 | RETURN_NAMES = ("controlnet unit",) 670 | 671 | 
@classmethod 672 | def INPUT_TYPES(cls): 673 | return { 674 | "required": { 675 | "controlnet_unit_1": ("CONTROLNET_UNIT",), 676 | }, 677 | "optional": { 678 | "controlnet_unit_2": ( 679 | "CONTROLNET_UNIT", 680 | { 681 | "default": None, 682 | }, 683 | ), 684 | "controlnet_unit_3": ( 685 | "CONTROLNET_UNIT", 686 | { 687 | "default": None, 688 | }, 689 | ), 690 | }, 691 | } 692 | 693 | def run( 694 | self, 695 | controlnet_unit_1: tuple[ControlNetModel], 696 | controlnet_unit_2: tuple[ControlNetModel] | None = None, 697 | controlnet_unit_3: tuple[ControlNetModel] | None = None, 698 | ): 699 | stack = [] 700 | if controlnet_unit_1: 701 | stack += controlnet_unit_1 702 | if controlnet_unit_2: 703 | stack += controlnet_unit_2 704 | if controlnet_unit_3: 705 | stack += controlnet_unit_3 706 | return (stack,) 707 | 708 | 709 | class DiffusersGenerator: 710 | CATEGORY = "Jannchie" 711 | FUNCTION = "run" 712 | RETURN_TYPES = ("IMAGE",) 713 | RETURN_NAMES = ("images",) 714 | 715 | @classmethod 716 | def INPUT_TYPES(cls): 717 | return { 718 | "required": { 719 | "pipeline": ("DIFFUSERS_PIPELINE",), 720 | "positive_prompt_embedding": ("DIFFUSERS_PROMPT_EMBEDDING",), 721 | "negative_prompt_embedding": ("DIFFUSERS_PROMPT_EMBEDDING",), 722 | "strength": ( 723 | "FLOAT", 724 | {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.02}, 725 | ), 726 | "num_inference_steps": ( 727 | "INT", 728 | {"default": 30, "min": 1, "max": 100, "step": 1}, 729 | ), 730 | "guidance_scale": ( 731 | "FLOAT", 732 | {"default": 7.0, "min": 0.0, "max": 30.0, "step": 0.02}, 733 | ), 734 | "seed": ( 735 | "INT", 736 | {"default": 0, "min": 0, "step": 1, "max": 999999999999}, 737 | ), 738 | "batch_size": ("INT", {"default": 1, "min": 1, "max": 16, "step": 1}), 739 | "width": ( 740 | "INT", 741 | { 742 | "default": 512, 743 | "min": 64, 744 | "max": 8192, 745 | "step": 64, 746 | }, 747 | ), 748 | "height": ( 749 | "INT", 750 | { 751 | "default": 512, 752 | "min": 64, 753 | "max": 8192, 754 | "step": 64, 755 | }, 756 | ), 757 | "reference_strength": ( 758 | "FLOAT", 759 | { 760 | "default": 1.0, 761 | "min": 0.0, 762 | "max": 1.0, 763 | "step": 0.01, 764 | }, 765 | ), 766 | "reference_style_fidelity": ( 767 | "FLOAT", 768 | { 769 | "default": 0.5, 770 | "min": 0.0, 771 | "max": 1.0, 772 | "step": 0.01, 773 | }, 774 | ), 775 | }, 776 | "optional": { 777 | "images": ("IMAGE",), 778 | "mask": ("MASK",), 779 | "controlnet_units": ("CONTROLNET_UNIT",), 780 | "reference_image": ( 781 | "IMAGE", 782 | {"default": None}, 783 | ), 784 | "reference_only": ( 785 | ["disable", "enable"], 786 | { 787 | "default": "disable", 788 | }, 789 | ), 790 | "reference_only_adain": ( 791 | ["disable", "enable"], 792 | { 793 | "default": "disable", 794 | }, 795 | ), 796 | }, 797 | } 798 | 799 | def run( 800 | self, 801 | pipeline: StableDiffusionPipeline, 802 | positive_prompt_embedding: torch.Tensor, 803 | negative_prompt_embedding: torch.Tensor, 804 | width: int, 805 | height: int, 806 | batch_size: int, 807 | images: torch.Tensor | None = None, 808 | num_inference_steps: int = 30, 809 | strength: float = 1.0, 810 | guidance_scale: float = 7.0, 811 | controlnet_units: tuple[ControlNetUnit] = None, 812 | seed=None, 813 | mask: torch.Tensor | None = None, 814 | reference_only: str = "disable", 815 | reference_only_adain: str = "disable", 816 | reference_image: torch.Tensor | None = None, 817 | reference_style_fidelity: float = 0.5, 818 | reference_strength: float = 1.0, 819 | ): 820 | reference_only = reference_only == "enable" 821 | reference_only_adain = 
reference_only_adain == "enable" 822 | latents = None 823 | pbar = ProgressBar(int(num_inference_steps * strength)) 824 | device = comfy.model_management.get_torch_device() 825 | if not seed: 826 | seed = random.randint(0, 999999999999) 827 | generator = torch.Generator(device) 828 | generator.manual_seed(seed) 829 | # (B, H, W, C) to (B, C, H, W) 830 | if images is None: 831 | latents = prepare_latents( 832 | pipe=pipeline, 833 | batch_size=batch_size, 834 | height=height, 835 | width=width, 836 | generator=generator, 837 | device=device, 838 | dtype=comfy.model_management.vae_dtype(), 839 | ) 840 | images = latents_to_img_tensor(pipeline, latents) 841 | else: 842 | images = images 843 | 844 | # positive_prompt_embedding 和 negative_prompt_embedding 需要匹配 batch_size 845 | positive_prompt_embedding = positive_prompt_embedding.repeat(batch_size, 1, 1) 846 | negative_prompt_embedding = negative_prompt_embedding.repeat(batch_size, 1, 1) 847 | width = images.shape[2] 848 | height = images.shape[1] 849 | 850 | def callback(*_): 851 | pbar.update(1) 852 | 853 | if controlnet_units is not None: 854 | for unit in controlnet_units: 855 | target_image_shape = (width, height) 856 | unit_img = resize_with_padding(unit.image, target_image_shape) 857 | unit.image = unit_img 858 | controlnet_units = ControlNetUnits(controlnet_units) 859 | result = pipeline( 860 | image=images, 861 | mask_image=mask, 862 | ref_image=reference_image if reference_image is not None else images, 863 | generator=generator, 864 | width=width, 865 | height=height, 866 | prompt_embeds=positive_prompt_embedding, 867 | negative_prompt_embeds=negative_prompt_embedding, 868 | num_inference_steps=num_inference_steps, 869 | guidance_scale=guidance_scale, 870 | callback_steps=1, 871 | strength=strength, 872 | controlnet_units=controlnet_units, 873 | callback=callback, 874 | reference_strength=reference_strength, 875 | reference_attn=reference_only, 876 | reference_adain=reference_only_adain, 877 | style_fidelity=reference_style_fidelity, 878 | return_dict=True, 879 | ) 880 | # image = result["images"][0] 881 | # images to torch.Tensor 882 | imgs = [np.array(img) for img in result["images"]] 883 | imgs = torch.tensor(imgs) 884 | result["images"][0].save("1.png") 885 | # 0 ~ 255 to 0 ~ 1 886 | imgs = imgs / 255 887 | # (B, C, H, W) to (B, H, W, C) 888 | torch.cuda.empty_cache() 889 | gc.collect() 890 | return (imgs,) 891 | 892 | 893 | NODE_CLASS_MAPPINGS = { 894 | "GetFilledColorImage": GetFilledColorImage, 895 | "GetAverageColorFromImage": GetAverageColorFromImage, 896 | "DiffusersPipeline": DiffusersPipeline, 897 | "DiffusersXLPipeline": DiffusersXLPipeline, 898 | "DiffusersGenerator": DiffusersGenerator, 899 | "DiffusersPrepareLatents": DiffusersPrepareLatents, 900 | "DiffusersDecoder": DiffusersDecoder, 901 | "DiffusersCompelPromptEmbedding": DiffusersCompelPromptEmbedding, 902 | "DiffusersTextureInversionLoader": DiffusersTextureInversionLoader, 903 | "DiffusersControlnetLoader": DiffusersControlNetLoader, 904 | "DiffusersControlnetUnit": DiffusersControlNetUnit, 905 | "DiffusersControlnetUnitStack": DiffusersControlNetUnitStack, 906 | } 907 | NODE_DISPLAY_NAME_MAPPINGS = { 908 | "GetFilledColorImage": "Get Filled Color Image Jannchie", 909 | "GetAverageColorFromImage": "Get Average Color From Image Jannchie", 910 | "DiffusersPipeline": "🤗 Diffusers Pipeline", 911 | "DiffusersXLPipeline": "🤗 Diffusers XL Pipeline", 912 | "DiffusersGenerator": "🤗 Diffusers Generator", 913 | "DiffusersPrepareLatents": "🤗 Diffusers Prepare Latents", 914 | 
"DiffusersDecoder": "🤗 Diffusers Decoder", 915 | "DiffusersCompelPromptEmbedding": "🤗 Diffusers Compel Prompt Embedding", 916 | "DiffusersTextureInversionLoader": "🤗 Diffusers Texture Inversion Embedding Loader", 917 | "DiffusersControlnetLoader": "🤗 Diffusers Controlnet Loader", 918 | "DiffusersControlnetUnit": "🤗 Diffusers Controlnet Unit", 919 | "DiffusersControlnetUnitStack": "🤗 Diffusers Controlnet Unit Stack", 920 | } 921 | -------------------------------------------------------------------------------- /pipelines/jannchie.py: -------------------------------------------------------------------------------- 1 | # Inspired by: https://github.com/Mikubill/sd-webui-controlnet/discussions/1236 and https://github.com/Mikubill/sd-webui-controlnet/discussions/1280 2 | import inspect 3 | import logging 4 | from dataclasses import dataclass 5 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union 6 | 7 | import numpy as np 8 | import PIL.Image 9 | import torch 10 | from diffusers import StableDiffusionControlNetPipeline 11 | from diffusers.image_processor import PipelineImageInput, VaeImageProcessor 12 | from diffusers.models import ControlNetModel, UNet2DConditionModel 13 | from diffusers.models.attention import BasicTransformerBlock 14 | from diffusers.models.autoencoders import AutoencoderKL 15 | from diffusers.models.unets.unet_2d_blocks import ( 16 | CrossAttnDownBlock2D, 17 | CrossAttnUpBlock2D, 18 | DownBlock2D, 19 | UNetMidBlock2DCrossAttn, 20 | UpBlock2D, 21 | ) 22 | from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel 23 | from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput 24 | from diffusers.pipelines.stable_diffusion.safety_checker import ( 25 | StableDiffusionSafetyChecker, 26 | ) 27 | from diffusers.schedulers import KarrasDiffusionSchedulers 28 | from diffusers.utils.torch_utils import is_compiled_module, randn_tensor 29 | from transformers import ( 30 | CLIPImageProcessor, 31 | CLIPTextModel, 32 | CLIPTokenizer, 33 | CLIPVisionModelWithProjection, 34 | ) 35 | 36 | formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") 37 | logger = logging.getLogger(__name__) 38 | logger.setLevel(logging.INFO) 39 | ch = logging.StreamHandler() 40 | ch.setFormatter(formatter) 41 | logger.addHandler(ch) 42 | 43 | basic_transformer_idx = 0 44 | 45 | 46 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents 47 | def retrieve_latents( 48 | encoder_output: torch.Tensor, 49 | generator: Optional[torch.Generator] = None, 50 | sample_mode: str = "sample", 51 | ): 52 | if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": 53 | return encoder_output.latent_dist.sample(generator) 54 | elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": 55 | return encoder_output.latent_dist.mode() 56 | elif hasattr(encoder_output, "latents"): 57 | return encoder_output.latents 58 | else: 59 | raise AttributeError("Could not access latents of provided encoder_output") 60 | 61 | 62 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps 63 | def retrieve_timesteps( 64 | scheduler, 65 | num_inference_steps: Optional[int] = None, 66 | device: Optional[Union[str, torch.device]] = None, 67 | timesteps: Optional[List[int]] = None, 68 | **kwargs, 69 | ): 70 | """ 71 | Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles 72 | custom timesteps. 
Any kwargs will be supplied to `scheduler.set_timesteps`. 73 | 74 | Args: 75 | scheduler (`SchedulerMixin`): 76 | The scheduler to get timesteps from. 77 | num_inference_steps (`int`): 78 | The number of diffusion steps used when generating samples with a pre-trained model. If used, 79 | `timesteps` must be `None`. 80 | device (`str` or `torch.device`, *optional*): 81 | The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 82 | timesteps (`List[int]`, *optional*): 83 | Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default 84 | timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` 85 | must be `None`. 86 | 87 | Returns: 88 | `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the 89 | second element is the number of inference steps. 90 | """ 91 | if timesteps is not None: 92 | accepts_timesteps = "timesteps" in set( 93 | inspect.signature(scheduler.set_timesteps).parameters.keys() 94 | ) 95 | if not accepts_timesteps: 96 | raise ValueError( 97 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" 98 | f" timestep schedules. Please check whether you are using the correct scheduler." 99 | ) 100 | scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) 101 | timesteps = scheduler.timesteps 102 | num_inference_steps = len(timesteps) 103 | else: 104 | scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) 105 | timesteps = scheduler.timesteps 106 | return timesteps, num_inference_steps 107 | 108 | 109 | def _images_to_tensors( 110 | imgs: List[PIL.Image.Image], 111 | width: int, 112 | height: int, 113 | device: torch.device, 114 | dtype: torch.dtype, 115 | ) -> torch.Tensor: 116 | buf = [] 117 | for image_ in imgs: 118 | assert isinstance(image_, PIL.Image.Image) 119 | image_ = image_.convert("RGB") 120 | image_ = image_.resize((width, height), resample=PIL.Image.Resampling.LANCZOS) 121 | image_ = np.array(image_) 122 | image_ = image_[None, :] 123 | buf.append(image_) 124 | 125 | image = np.concatenate(buf, axis=0) 126 | image = np.array(image).astype(np.float32) / 255.0 127 | image = (image - 0.5) / 0.5 128 | image = image.transpose(0, 3, 1, 2) 129 | image = torch.from_numpy(image) 130 | 131 | assert isinstance(image, torch.Tensor) 132 | 133 | image = image.to(device=device, dtype=dtype) 134 | 135 | return image 136 | 137 | 138 | def mask_images_to_float_tensor( 139 | imgs: List[PIL.Image.Image], 140 | resize_wh: Optional[Tuple[int, int]] = None, 141 | resample: Optional[PIL.Image.Resampling] = None, 142 | ) -> torch.Tensor: 143 | width, height = imgs[0].size 144 | if resize_wh is not None: 145 | width, height = resize_wh 146 | if resample is None: 147 | resample = PIL.Image.Resampling.LANCZOS 148 | mask = [i.resize((width, height), resample=resample) for i in imgs] 149 | else: 150 | mask = imgs 151 | mask = np.stack([np.array(m.convert("L")) for m in mask], axis=0) 152 | assert mask.shape == (len(imgs), height, width) 153 | mask = mask.astype(np.float32) / 255.0 154 | mask = torch.from_numpy(mask) 155 | 156 | assert mask.shape[0] == len(imgs) 157 | if mask.min() < 0 or mask.max() > 1: 158 | raise ValueError("Mask should be in [0, 1] range") 159 | return mask 160 | 161 | 162 | def torch_dfs(model: torch.nn.Module): 163 | result = [model] 164 | for child in model.children(): 165 | result += torch_dfs(child) 166 | return result 167 | 168 | 
169 | @dataclass 170 | class ControlNetUnit: 171 | controlnet: ControlNetModel 172 | image: PIL.Image.Image 173 | scale: float 174 | start: float 175 | end: float 176 | 177 | 178 | class ControlNetUnits: 179 | def __init__( 180 | self, 181 | units: tuple[ControlNetUnit], 182 | ): 183 | self.controlnets = [unit.controlnet for unit in units] 184 | self.images = [unit.image for unit in units] 185 | self.scales = [unit.scale for unit in units] 186 | self.starts = [unit.start for unit in units] 187 | self.ends = [unit.end for unit in units] 188 | 189 | 190 | class JannchiePipeline(StableDiffusionControlNetPipeline): 191 | @torch.no_grad() 192 | def __call__( 193 | self, 194 | prompt: Union[str, List[str]] = None, 195 | image: Union[torch.FloatTensor, PIL.Image.Image] = None, 196 | ref_image: Union[torch.FloatTensor, PIL.Image.Image] = None, 197 | ref_image_mask: Union[torch.FloatTensor, PIL.Image.Image] = None, 198 | height: Optional[int] = None, 199 | width: Optional[int] = None, 200 | num_inference_steps: int = 50, 201 | guidance_scale: float = 7.5, 202 | negative_prompt: Optional[Union[str, List[str]]] = None, 203 | num_images_per_prompt: Optional[int] = 1, 204 | eta: float = 0.0, 205 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 206 | latents: Optional[torch.FloatTensor] = None, 207 | prompt_embeds: Optional[torch.FloatTensor] = None, 208 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 209 | output_type: Optional[str] = "pil", 210 | return_dict: bool = True, 211 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 212 | callback_steps: int = 1, 213 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 214 | controlnet_units: ControlNetUnits = None, 215 | guess_mode: bool = False, 216 | reference_attn: bool = False, 217 | reference_adain: bool = False, 218 | attention_auto_machine_weight: float = 100.0, 219 | gn_auto_machine_weight: float = 1.0, 220 | style_fidelity: float = 0.5, 221 | write_mask: Union[torch.FloatTensor, PIL.Image.Image] = None, 222 | bool_mask=False, 223 | desc: Optional[str] = None, 224 | strength=1.0, 225 | timesteps: List[int] = None, 226 | mask_image: PipelineImageInput = None, 227 | masked_image_latents: Optional[torch.FloatTensor] = None, 228 | ip_adapter_image: Optional[PipelineImageInput] = None, 229 | ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None, 230 | reference_strength: float = 1.0, 231 | *arg, 232 | **args, 233 | ): 234 | device = self.unet.device 235 | if height == None: 236 | if isinstance(image, torch.Tensor): 237 | if image is not None: 238 | height = image.shape[-2] 239 | elif ref_image is not None: 240 | height = ref_image.shape[-2] 241 | else: 242 | height = 512 243 | elif isinstance(image, PIL.Image.Image): 244 | _, height = image.size 245 | if width == None: 246 | if isinstance(image, torch.Tensor): 247 | if image is not None: 248 | width = image.shape[-1] 249 | elif ref_image is not None: 250 | width = ref_image.shape[-1] 251 | else: 252 | width = 512 253 | elif isinstance(image, PIL.Image.Image): 254 | width, _ = image.size 255 | 256 | if arg or args: 257 | logger.warning(f"Unused arguments: {arg}, {args}") 258 | if desc is None: 259 | desc = "Jannchie's Pipeline" 260 | self.set_progress_bar_config(desc=desc) 261 | controlnet_conditioning_scale = [] 262 | control_guidance_start = [] 263 | control_guidance_end = [] 264 | if controlnet_units: 265 | self.controlnet = MultiControlNetModel( 266 | controlnets=controlnet_units.controlnets 267 | ) 268 | 
controlnet_images = controlnet_units.images 269 | control_guidance_start = controlnet_units.starts 270 | control_guidance_end = controlnet_units.ends 271 | controlnet_conditioning_scale = controlnet_units.scales 272 | else: 273 | controlnet_images = [] 274 | self.controlnet = MultiControlNetModel(controlnets=[]) 275 | if not reference_attn and not reference_adain: 276 | ref_image = None 277 | if self.controlnet: 278 | controlnet = ( 279 | self.controlnet._orig_mod 280 | if is_compiled_module(self.controlnet) 281 | else self.controlnet 282 | ) 283 | controlnet.to(device) 284 | n_controlnet_unit = ( 285 | len(controlnet.nets) 286 | if isinstance(controlnet, MultiControlNetModel) 287 | else 0 288 | ) 289 | else: 290 | n_controlnet_unit = 0 291 | # 1. Check inputs. Raise error if not correct 292 | self.check_inputs( 293 | prompt, 294 | controlnet_images, 295 | callback_steps, 296 | negative_prompt, 297 | prompt_embeds, 298 | negative_prompt_embeds, 299 | controlnet_conditioning_scale=controlnet_conditioning_scale, 300 | control_guidance_start=control_guidance_start, 301 | control_guidance_end=control_guidance_end, 302 | ) 303 | 304 | # 2. Define call parameters 305 | if prompt is not None and isinstance(prompt, str): 306 | batch_size = 1 307 | elif prompt is not None and isinstance(prompt, list): 308 | batch_size = len(prompt) 309 | else: 310 | # 输入的是 prompt_embeds 311 | batch_size = prompt_embeds.shape[0] 312 | 313 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 314 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 315 | # corresponds to doing no classifier free guidance. 316 | do_classifier_free_guidance = guidance_scale > 1.0 317 | if self.controlnet: 318 | if len(self.controlnet.nets) > 1: 319 | assert isinstance(controlnet_images, list) 320 | if n_controlnet_unit != 0: 321 | global_pool_conditions = ( 322 | controlnet.config.global_pool_conditions 323 | if isinstance(controlnet, ControlNetModel) 324 | else controlnet.nets[0].config.global_pool_conditions 325 | ) 326 | guess_mode = guess_mode or global_pool_conditions 327 | 328 | # 3. Encode input prompt 329 | logger.debug("Encoding prompt") 330 | text_encoder_lora_scale = ( 331 | cross_attention_kwargs.get("scale", None) 332 | if cross_attention_kwargs is not None 333 | else None 334 | ) 335 | prompt_embeds = self.encode_prompt( 336 | prompt, 337 | device, 338 | num_images_per_prompt, 339 | do_classifier_free_guidance, 340 | negative_prompt, 341 | prompt_embeds=prompt_embeds, 342 | negative_prompt_embeds=negative_prompt_embeds, 343 | lora_scale=text_encoder_lora_scale, 344 | ) 345 | prompt_embeds = torch.cat(prompt_embeds[::-1], dim=0) 346 | 347 | # 4. 
Prepare image 348 | logger.debug("Preparing image") 349 | if n_controlnet_unit != 0: 350 | if isinstance(controlnet, ControlNetModel): 351 | controlnet_images = self.prepare_image( 352 | image=controlnet_images, 353 | width=width, 354 | height=height, 355 | batch_size=batch_size * num_images_per_prompt, 356 | num_images_per_prompt=num_images_per_prompt, 357 | device=device, 358 | dtype=controlnet.dtype, 359 | do_classifier_free_guidance=do_classifier_free_guidance, 360 | guess_mode=guess_mode, 361 | ).to(device=device) 362 | height, width = controlnet_images.shape[-2:] 363 | elif isinstance(controlnet, MultiControlNetModel): 364 | images = [] 365 | 366 | for image_ in controlnet_images: 367 | image_ = self.prepare_image( 368 | image=image_, 369 | width=width, 370 | height=height, 371 | batch_size=batch_size * num_images_per_prompt, 372 | num_images_per_prompt=num_images_per_prompt, 373 | device=device, 374 | dtype=controlnet.dtype, 375 | do_classifier_free_guidance=do_classifier_free_guidance, 376 | guess_mode=guess_mode, 377 | ).to(device=device) 378 | 379 | images.append(image_) 380 | 381 | controlnet_images = images 382 | height, width = controlnet_images[0].shape[-2:] 383 | else: 384 | assert False 385 | 386 | # 5. Preprocess reference image 387 | logger.debug("Preprocessing reference image") 388 | if ref_image is not None: 389 | if isinstance(ref_image, PIL.Image.Image): 390 | ref_image = self.image_processor.preprocess( 391 | ref_image, height=height, width=width 392 | ) 393 | ref_image = self.norm_image_tensor(ref_image) 394 | # 6. Prepare timesteps 395 | logger.debug("Preparing timesteps") 396 | timesteps, num_inference_steps = retrieve_timesteps( 397 | self.scheduler, num_inference_steps, device, timesteps 398 | ) 399 | timesteps, num_inference_steps = self.get_timesteps( 400 | num_inference_steps=num_inference_steps, 401 | strength=strength, 402 | ) 403 | latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) 404 | 405 | # 7. Prepare latent variables 406 | num_channels_latents = self.unet.config.in_channels 407 | if image is not None: 408 | if isinstance(image, PIL.Image.Image): 409 | image = self.image_processor.preprocess( 410 | image, height=height, width=width 411 | ) 412 | if isinstance(image, torch.Tensor): 413 | image = self.norm_image_tensor(image) 414 | input_latents = self.image_to_latents( 415 | image, 416 | batch_size * num_images_per_prompt, 417 | self.unet.dtype, 418 | device, 419 | generator, 420 | False, # it will duplicate the latents after this step 421 | ) 422 | 423 | num_channels_unet = self.unet.config.in_channels 424 | return_image_latents = num_channels_unet == 4 425 | latents_outputs = self.prepare_latents( 426 | batch_size * num_images_per_prompt, 427 | num_channels_latents, 428 | height, 429 | width, 430 | self.unet.dtype, 431 | device, 432 | generator, 433 | latents, 434 | image, 435 | latent_timestep, 436 | is_strength_max=strength == 1.0, 437 | return_noise=True, 438 | return_image_latents=return_image_latents, 439 | ) 440 | if return_image_latents: 441 | input_latents, noise, image_latents = latents_outputs 442 | else: 443 | input_latents, noise = latents_outputs 444 | 445 | # 7. 
Prepare mask latent variables 446 | if mask_image is not None: 447 | mask_condition = self.mask_processor.preprocess( 448 | mask_image, height=height, width=width 449 | ).to(device=device) 450 | init_image = image 451 | init_image = init_image.to(dtype=torch.float32, device=device) 452 | if masked_image_latents is None: 453 | masked_image = init_image * (mask_condition < 0.5) 454 | else: 455 | masked_image = masked_image_latents 456 | mask, masked_image_latents = self.prepare_mask_latents( 457 | mask_condition, 458 | masked_image, 459 | batch_size * num_images_per_prompt, 460 | height, 461 | width, 462 | self.unet.dtype, 463 | device, 464 | generator, 465 | do_classifier_free_guidance, 466 | ) 467 | 468 | # 8. Prepare reference latent variables 469 | if ref_image is not None: 470 | ref_image_latents = self.image_to_latents( 471 | ref_image, 472 | batch_size * num_images_per_prompt, 473 | self.unet.dtype, 474 | device, 475 | generator, 476 | do_classifier_free_guidance, 477 | ) 478 | 479 | # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline 480 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 481 | 482 | if ip_adapter_image is not None or ip_adapter_image_embeds is not None: 483 | image_embeds = self.prepare_ip_adapter_image_embeds( 484 | ip_adapter_image, 485 | ip_adapter_image_embeds, 486 | device, 487 | batch_size * num_images_per_prompt, 488 | do_classifier_free_guidance, 489 | ) 490 | 491 | # Add image embeds for IP-Adapter 492 | added_cond_kwargs = ( 493 | {"image_embeds": image_embeds} 494 | if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) 495 | else {} 496 | ) 497 | 498 | # text_embeds for reference, TODO: I forgot why it is needed 499 | added_cond_kwargs["text_embeds"] = prompt_embeds 500 | 501 | ref_mask_dict, out_mask_dict = self.get_ref_mask_dicts( 502 | ref_image_mask, 503 | height, 504 | width, 505 | num_images_per_prompt, 506 | write_mask, 507 | bool_mask, 508 | device, 509 | batch_size, 510 | ) 511 | ref_data = ReferenceData( 512 | ref_image=ref_image, 513 | ref_image_mask=ref_image_mask, 514 | style_fidelity=style_fidelity, 515 | attention_auto_machine_weight=attention_auto_machine_weight, 516 | gn_auto_machine_weight=gn_auto_machine_weight, 517 | ref_mask_dict=ref_mask_dict, 518 | out_mask_dict=out_mask_dict, 519 | strength=reference_strength, 520 | ) 521 | if reference_attn: 522 | self.unet = ReferenceOnlyUNet2DConditionModel.from_unet( 523 | self.unet, 524 | ref_data, 525 | reference_attn, 526 | reference_adain, 527 | ) 528 | else: 529 | self.unet = ReferenceOnlyUNet2DConditionModel.revert_unet( 530 | self.unet 531 | ) # 9. Modify self attention and group norm 532 | if ref_image is not None: 533 | self.unet.ref_data.MODE = "write" 534 | self.unet.ref_data.uc_mask = ( 535 | torch.Tensor( 536 | [1] * batch_size * num_images_per_prompt 537 | + [0] * batch_size * num_images_per_prompt 538 | ) 539 | .type_as(ref_image_latents) 540 | .bool() 541 | ) 542 | 543 | if self.controlnet: 544 | # Create tensor stating which controlnets to keep 545 | controlnet_keep = [] 546 | for i in range(len(timesteps)): 547 | keeps = [ 548 | 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) 549 | for s, e in zip(control_guidance_start, control_guidance_end) 550 | ] 551 | controlnet_keep.append( 552 | keeps[0] if isinstance(controlnet, ControlNetModel) else keeps 553 | ) 554 | 555 | # 11. 
Denoising loop 556 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 557 | with self.progress_bar(total=num_inference_steps) as progress_bar: 558 | for i, t in enumerate(timesteps): 559 | # ref only part 560 | if reference_attn: 561 | self.unet.ref_data.progress = i / num_inference_steps 562 | 563 | if ref_image is not None: 564 | single_shape = (1,) + ref_image_latents.shape[1:] 565 | single_noise = randn_tensor( 566 | single_shape, 567 | generator=generator, 568 | device=device, 569 | dtype=ref_image_latents.dtype, 570 | ) 571 | noise_for_ref = single_noise.repeat_interleave( 572 | ref_image_latents.shape[0], dim=0 573 | ) 574 | ref_xt = self.scheduler.add_noise( 575 | ref_image_latents, 576 | noise_for_ref, 577 | t.reshape( 578 | 1, 579 | ), 580 | ) 581 | # ref_xt = self.scheduler.scale_model_input(ref_xt, t) 582 | 583 | self.unet.ref_data.MODE = "write" 584 | self.unet( 585 | ref_xt, 586 | t, 587 | encoder_hidden_states=prompt_embeds, 588 | cross_attention_kwargs=cross_attention_kwargs, 589 | return_dict=False, 590 | added_cond_kwargs=added_cond_kwargs, 591 | ) 592 | self.unet.ref_data.MODE = "read" 593 | 594 | # expand the latents if we are doing classifier free guidance 595 | latent_model_input = ( 596 | torch.cat([input_latents] * 2) 597 | if do_classifier_free_guidance 598 | else input_latents 599 | ) 600 | 601 | # controlnet(s) inference 602 | if guess_mode and do_classifier_free_guidance: 603 | # Infer ControlNet only for the conditional batch. 604 | control_model_input = input_latents 605 | control_model_input = self.scheduler.scale_model_input( 606 | control_model_input, t 607 | ) 608 | controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] 609 | else: 610 | control_model_input = latent_model_input 611 | controlnet_prompt_embeds = prompt_embeds 612 | 613 | # calculate final conditioning_scale 614 | if isinstance(controlnet_keep[i], list): 615 | cond_scale = [ 616 | c * s 617 | for c, s in zip( 618 | controlnet_conditioning_scale, controlnet_keep[i] 619 | ) 620 | ] 621 | else: 622 | controlnet_cond_scale = controlnet_conditioning_scale 623 | if isinstance(controlnet_cond_scale, list): 624 | controlnet_cond_scale = controlnet_cond_scale[0] 625 | cond_scale = controlnet_cond_scale * controlnet_keep[i] 626 | 627 | assert isinstance( 628 | self.controlnet, (ControlNetModel, MultiControlNetModel) 629 | ) 630 | if n_controlnet_unit != 0: 631 | down_block_res_samples, mid_block_res_sample = self.controlnet( 632 | control_model_input.to( 633 | device=device, dtype=self.controlnet.dtype 634 | ), 635 | t, 636 | encoder_hidden_states=controlnet_prompt_embeds, 637 | controlnet_cond=controlnet_images, 638 | conditioning_scale=cond_scale, 639 | guess_mode=guess_mode, 640 | return_dict=False, 641 | ) 642 | 643 | if guess_mode and do_classifier_free_guidance: 644 | # Infered ControlNet only for the conditional batch. 645 | # To apply the output of ControlNet to both the unconditional and conditional batches, 646 | # add 0 to the unconditional batch to keep it unchanged. 
647 | down_block_res_samples = [ 648 | torch.cat([torch.zeros_like(d), d]) 649 | for d in down_block_res_samples 650 | ] 651 | mid_block_res_sample = torch.cat( 652 | [ 653 | torch.zeros_like(mid_block_res_sample), 654 | mid_block_res_sample, 655 | ] 656 | ) 657 | else: 658 | down_block_res_samples, mid_block_res_sample = None, None 659 | # predict the noise residual 660 | noise_pred = self.unet( 661 | latent_model_input.to(device=device, dtype=self.unet.dtype), 662 | t, 663 | encoder_hidden_states=prompt_embeds.to( 664 | device=device, dtype=self.unet.dtype 665 | ), 666 | cross_attention_kwargs=cross_attention_kwargs, 667 | down_block_additional_residuals=down_block_res_samples, 668 | mid_block_additional_residual=mid_block_res_sample, 669 | added_cond_kwargs=added_cond_kwargs, 670 | )["sample"] 671 | # perform guidance 672 | if do_classifier_free_guidance: 673 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 674 | noise_pred = noise_pred_uncond + guidance_scale * ( 675 | noise_pred_text - noise_pred_uncond 676 | ) 677 | 678 | # compute the previous noisy sample x_t -> x_t-1 679 | input_latents = self.scheduler.step( 680 | noise_pred, t, input_latents, **extra_step_kwargs 681 | )["prev_sample"] 682 | if num_channels_unet == 4 and ( 683 | mask_image is not None or masked_image_latents is not None 684 | ): 685 | init_latents_proper = image_latents 686 | if do_classifier_free_guidance: 687 | init_mask, _ = mask.chunk(2) 688 | else: 689 | init_mask = mask 690 | 691 | if i < len(timesteps) - 1: 692 | noise_timestep = timesteps[i + 1] 693 | init_latents_proper = self.scheduler.add_noise( 694 | init_latents_proper, noise, torch.tensor([noise_timestep]) 695 | ) 696 | init_latents_proper = init_latents_proper.to( 697 | device=device, dtype=self.unet.dtype 698 | ) 699 | 700 | input_latents = ( 701 | 1 - init_mask 702 | ) * init_latents_proper + init_mask * input_latents 703 | # call the callback, if provided 704 | if i == len(timesteps) - 1 or ( 705 | (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 706 | ): 707 | progress_bar.update() 708 | if callback is not None and i % callback_steps == 0: 709 | step_idx = i // getattr(self.scheduler, "order", 1) 710 | callback(step_idx, t, input_latents) 711 | # If we do sequential model offloading, let's offload unet and controlnet 712 | # manually for max memory savings 713 | if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: 714 | self.unet.to("cpu") 715 | self.controlnet.to("cpu") 716 | torch.cuda.empty_cache() 717 | 718 | if output_type != "latent": 719 | input_latents = input_latents.to(device=device, dtype=self.vae.dtype) 720 | result_imgs = self.vae.decode( 721 | input_latents / self.vae.config.scaling_factor, return_dict=False 722 | )[0] 723 | result_imgs, has_nsfw_concept = self.run_safety_checker( 724 | result_imgs, device, prompt_embeds.dtype 725 | ) 726 | else: 727 | result_imgs = input_latents 728 | has_nsfw_concept = None 729 | 730 | if has_nsfw_concept is None: 731 | do_denormalize = [True] * result_imgs.shape[0] 732 | else: 733 | if isinstance(has_nsfw_concept, list): 734 | do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] 735 | else: 736 | do_denormalize = [not has_nsfw_concept] 737 | # nan to zero 738 | result_imgs = torch.nan_to_num(result_imgs, nan=0.0, posinf=0.0, neginf=0.0) 739 | result_imgs = self.image_processor.postprocess( 740 | result_imgs, output_type=output_type, do_denormalize=do_denormalize 741 | ) 742 | 743 | # Offload last model to CPU 744 | if 
hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: 745 | self.final_offload_hook.offload() 746 | 747 | if not return_dict: 748 | return (result_imgs, has_nsfw_concept) 749 | img_out = self.get_img_from_latents(latents=input_latents) 750 | 751 | modules = torch_dfs(self.unet) 752 | # unload the reference-only hack 753 | for module in modules: 754 | 755 | if getattr(module, "_original_inner_forward", None) is not None: 756 | # unregister the attn forward hook 757 | module.forward = module._original_inner_forward 758 | 759 | if getattr(module, "original_forward", None) is not None: 760 | # unregister the adain forward hook 761 | module.forward = module.original_forward 762 | 763 | return StableDiffusionPipelineOutput( 764 | images=img_out, nsfw_content_detected=has_nsfw_concept 765 | ) 766 | 767 | def get_ref_mask_dicts( 768 | self, 769 | ref_image_mask, 770 | height, 771 | width, 772 | num_images_per_prompt, 773 | write_mask, 774 | bool_mask, 775 | device, 776 | batch_size, 777 | ): 778 | latent_width = width // self.vae_scale_factor 779 | latent_height = height // self.vae_scale_factor 780 | ref_mask_dict = {} 781 | out_mask_dict = {} 782 | for i in range(4): 783 | w = latent_width >> i 784 | h = latent_height >> i 785 | 786 | resize_wh = (w, h) 787 | if ref_image_mask: 788 | # resize ref_image_mask 789 | tmp_mt_key = mask_images_to_float_tensor( 790 | [ref_image_mask], 791 | resize_wh=resize_wh, 792 | ).to(device=device, dtype=self.unet.dtype) 793 | 794 | mt_key = ( 795 | tmp_mt_key.flatten() > 0.5 if bool_mask else tmp_mt_key.flatten() 796 | ) 797 | ref_mask_dict[mt_key.shape[-1]] = mt_key.repeat( 798 | batch_size * num_images_per_prompt, 1 799 | ) 800 | 801 | if write_mask: 802 | tmp_mt_query = mask_images_to_float_tensor( 803 | [write_mask], 804 | resize_wh=resize_wh, 805 | ).to(device=device, dtype=self.unet.dtype) 806 | if bool_mask: 807 | mt_query = tmp_mt_query.flatten(1) > 0.5 808 | else: 809 | mt_query = tmp_mt_query.flatten(1) 810 | out_mask_dict[mt_query.shape[-1]] = mt_query.repeat( 811 | batch_size * num_images_per_prompt, 1 812 | ) 813 | 814 | return ref_mask_dict, out_mask_dict 815 | 816 | def norm_image_tensor(self, ref_image): 817 | # if the image tensor has 3 dims, it lacks a batch dimension 818 | if len(ref_image.shape) == 3: 819 | # add a batch dimension 820 | ref_image = ref_image.unsqueeze(0) 821 | if ref_image.shape[3] == 3: 822 | # convert to channels-first layout 823 | ref_image = ref_image.permute(0, 3, 1, 2) 824 | if ref_image.min() >= 0: 825 | # values are in [0, 1] 826 | # normalize values to [-1, 1] 827 | ref_image = (ref_image * 2 - 1).clamp(-1, 1) 828 | return ref_image 829 | 830 | def __init__( 831 | self, 832 | vae: AutoencoderKL, 833 | text_encoder: CLIPTextModel, 834 | tokenizer: CLIPTokenizer, 835 | unet: UNet2DConditionModel, 836 | scheduler: KarrasDiffusionSchedulers, 837 | safety_checker: StableDiffusionSafetyChecker, 838 | feature_extractor: CLIPImageProcessor, 839 | controlnet: Union[ 840 | ControlNetModel, 841 | List[ControlNetModel], 842 | Tuple[ControlNetModel], 843 | MultiControlNetModel, 844 | ] = None, 845 | image_encoder: CLIPVisionModelWithProjection = None, 846 | requires_safety_checker: bool = True, 847 | ): 848 | if controlnet is None: 849 | controlnet = [] 850 | pipe_class_name = self.__class__.__name__ 851 | self.set_progress_bar_config( 852 | desc=f"Running {pipe_class_name}...", 853 | unit_scale=True, 854 | bar_format="{desc}: {percentage:3.0f}%|{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]", 855 | ) 856 | self.vae: AutoencoderKL 857 | self.text_encoder: CLIPTextModel 858 |
self.tokenizer: CLIPTokenizer 859 | self.unet: UNet2DConditionModel 860 | super().__init__( 861 | vae=vae, 862 | text_encoder=text_encoder, 863 | tokenizer=tokenizer, 864 | unet=unet, 865 | controlnet=controlnet, 866 | scheduler=scheduler, 867 | safety_checker=safety_checker, 868 | feature_extractor=feature_extractor, 869 | image_encoder=image_encoder, 870 | requires_safety_checker=requires_safety_checker, 871 | ) 872 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) 873 | self.mask_processor = VaeImageProcessor( 874 | vae_scale_factor=self.vae_scale_factor, 875 | do_normalize=False, 876 | do_resize=False, 877 | do_binarize=True, 878 | do_convert_grayscale=True, 879 | ) 880 | 881 | def image_to_latents( 882 | self, 883 | image, 884 | batch_size, 885 | dtype, 886 | device, 887 | generator, 888 | do_classifier_free_guidance, 889 | ): 890 | image = image.to(device=device, dtype=dtype) 891 | 892 | # encode the mask image into latents space so we can concatenate it to the latents 893 | if isinstance(generator, list): 894 | image_latents = [ 895 | retrieve_latents(self.vae.encode(image[i : i + 1]), generator[i]) 896 | for i in range(batch_size) 897 | ] 898 | image_latents = torch.cat(image_latents, dim=0) 899 | else: 900 | image = image.to(self.vae.dtype) 901 | image_latents = retrieve_latents(self.vae.encode(image), generator) 902 | image_latents = self.vae.config.scaling_factor * image_latents 903 | 904 | # duplicate mask and ref_image_latents for each generation per prompt, using mps friendly method 905 | if image_latents.shape[0] < batch_size: 906 | if batch_size % image_latents.shape[0] != 0: 907 | raise ValueError( 908 | "The passed images and the required batch size don't match. Images are supposed to be duplicated" 909 | f" to a total batch size of {batch_size}, but {image_latents.shape[0]} images were passed." 910 | " Make sure the number of images that you pass is divisible by the total requested batch size." 
911 | ) 912 | image_latents = image_latents.repeat( 913 | batch_size // image_latents.shape[0], 1, 1, 1 914 | ) 915 | 916 | image_latents = ( 917 | torch.cat([image_latents] * 2) 918 | if do_classifier_free_guidance 919 | else image_latents 920 | ) 921 | 922 | # aligning device to prevent device errors when concating it with the latent model input 923 | image_latents = image_latents.to(device=device, dtype=dtype) 924 | return image_latents 925 | 926 | # def decode_latents(self, latents: torch.Tensor): 927 | # return self.get_img_from_latents(latents=latents) 928 | 929 | def get_img_from_latents(self, latents: torch.Tensor): 930 | # scale and decode the image latents with vae 931 | if len(latents.shape) == 3: 932 | latents = latents[None] 933 | norm_latents = latents 934 | dec_tensor = self.vae.decode( 935 | norm_latents / self.vae.config.scaling_factor, return_dict=False 936 | )[0] 937 | dec_images = self.image_processor.postprocess( 938 | dec_tensor, output_type="np", do_denormalize=[True] * dec_tensor.shape[0] 939 | ) 940 | dec_image_zero = dec_images 941 | dec_image_zero = np.nan_to_num(dec_image_zero) 942 | image_out_np = np.clip( 943 | (dec_image_zero * 255.0).round().astype(int), a_min=0, a_max=255 944 | ).astype(np.uint8) 945 | return [PIL.Image.fromarray(img) for img in image_out_np] 946 | 947 | def encode_images_to_latents( 948 | self, 949 | imgs: List[PIL.Image.Image], 950 | generator: torch.Generator, 951 | device: torch.device, 952 | dtype: torch.dtype, 953 | ) -> torch.Tensor: 954 | width, height = imgs[0].size 955 | image_tensor = _images_to_tensors( 956 | imgs=imgs, width=width, height=height, device=device, dtype=dtype 957 | ) 958 | # encode the mask image into latents space so we can concatenate it to the latents 959 | if isinstance(generator, list): 960 | image_latent = torch.cat( 961 | [ 962 | retrieve_latents( 963 | self.vae.encode(image_tensor[i : i + 1]), generator[i] 964 | ) 965 | for i in range(image_tensor.shape[0]) 966 | ], 967 | dim=0, 968 | ) 969 | else: 970 | image_latent = retrieve_latents(self.vae.encode(image_tensor), generator) 971 | image_latent = self.vae.config.scaling_factor * image_latent 972 | 973 | return image_latent.to(device=device, dtype=dtype) 974 | 975 | def get_timesteps(self, num_inference_steps, strength: float): 976 | # get the original timestep using init_timestep 977 | init_timestep = min(int(num_inference_steps * strength), num_inference_steps) 978 | t_start = max(num_inference_steps - init_timestep, 0) 979 | timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] 980 | if hasattr(self.scheduler, "set_begin_index"): 981 | self.scheduler.set_begin_index(t_start * self.scheduler.order) 982 | 983 | return timesteps, num_inference_steps - t_start 984 | 985 | def prepare_latents( 986 | self, 987 | batch_size, 988 | num_channels_latents, 989 | height, 990 | width, 991 | dtype, 992 | device, 993 | generator, 994 | latents=None, 995 | image=None, 996 | timestep=None, 997 | is_strength_max=True, 998 | return_noise=False, 999 | return_image_latents=False, 1000 | ): 1001 | shape = ( 1002 | batch_size, 1003 | num_channels_latents, 1004 | height // self.vae_scale_factor, 1005 | width // self.vae_scale_factor, 1006 | ) 1007 | if return_image_latents or (latents is None and not is_strength_max): 1008 | # TODO: check it 1009 | if image is None: 1010 | image = torch.randn(shape, device=device, dtype=dtype) 1011 | image = image.to(device=device, dtype=dtype) 1012 | 1013 | if image.shape[1] == 4: 1014 | image_latents = image 1015 | else: 
1016 | image_latents = self._encode_vae_image(image=image, generator=generator) 1017 | image_latents = image_latents.repeat( 1018 | batch_size // image_latents.shape[0], 1, 1, 1 1019 | ) 1020 | image_latents.to(device=device, dtype=dtype) 1021 | 1022 | if latents is None: 1023 | noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) 1024 | if is_strength_max: 1025 | latents = noise * self.scheduler.init_noise_sigma 1026 | else: 1027 | latents = self.scheduler.add_noise(image_latents, noise, timestep) 1028 | else: 1029 | noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) 1030 | if is_strength_max: 1031 | latents = noise * self.scheduler.init_noise_sigma 1032 | else: 1033 | latents = self.scheduler.add_noise(latents, noise, timestep) 1034 | 1035 | outputs = (latents,) 1036 | 1037 | if return_noise: 1038 | outputs += (noise,) 1039 | 1040 | if return_image_latents: 1041 | outputs += (image_latents,) 1042 | return outputs 1043 | 1044 | def prepare_mask_latents( 1045 | self, 1046 | mask: torch.Tensor, 1047 | masked_image: torch.Tensor, 1048 | batch_size: int, 1049 | height: int, 1050 | width: int, 1051 | dtype: torch.dtype, 1052 | device: torch.device, 1053 | generator: torch.Generator, 1054 | do_classifier_free_guidance: bool, 1055 | ): 1056 | # resize the mask to latents shape as we concatenate the mask to the latents 1057 | # we do that before converting to dtype to avoid breaking in case we're using cpu_offload 1058 | # and half precision 1059 | mask = torch.nn.functional.interpolate( 1060 | mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) 1061 | ) 1062 | mask = mask.to(device=device, dtype=dtype) 1063 | 1064 | masked_image = masked_image.to(device=device, dtype=dtype) 1065 | 1066 | if masked_image.shape[1] == 4: 1067 | masked_image_latents = masked_image 1068 | else: 1069 | masked_image_latents = self._encode_vae_image( 1070 | masked_image, generator=generator 1071 | ) 1072 | 1073 | # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method 1074 | if mask.shape[0] < batch_size: 1075 | if batch_size % mask.shape[0] != 0: 1076 | raise ValueError( 1077 | "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" 1078 | f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" 1079 | " of masks that you pass is divisible by the total requested batch size." 1080 | ) 1081 | mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) 1082 | if masked_image_latents.shape[0] < batch_size: 1083 | if batch_size % masked_image_latents.shape[0] != 0: 1084 | raise ValueError( 1085 | "The passed images and the required batch size don't match. Images are supposed to be duplicated" 1086 | f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." 1087 | " Make sure the number of images that you pass is divisible by the total requested batch size." 
1088 | ) 1089 | masked_image_latents = masked_image_latents.repeat( 1090 | batch_size // masked_image_latents.shape[0], 1, 1, 1 1091 | ) 1092 | 1093 | mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask 1094 | masked_image_latents = ( 1095 | torch.cat([masked_image_latents] * 2) 1096 | if do_classifier_free_guidance 1097 | else masked_image_latents 1098 | ) 1099 | 1100 | # aligning device to prevent device errors when concatenating it with the latent model input 1101 | masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) 1102 | return mask, masked_image_latents 1103 | 1104 | def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): 1105 | if isinstance(generator, list): 1106 | image_latents = [ 1107 | retrieve_latents( 1108 | self.vae.encode(image[i : i + 1]), generator=generator[i] 1109 | ) 1110 | for i in range(image.shape[0]) 1111 | ] 1112 | image_latents = torch.cat(image_latents, dim=0) 1113 | else: 1114 | image = image.to(self.vae.dtype) 1115 | image_latents = retrieve_latents( 1116 | self.vae.encode(image), generator=generator 1117 | ) 1118 | 1119 | image_latents = self.vae.config.scaling_factor * image_latents 1120 | return image_latents 1121 | 1122 | 1123 | @dataclass 1124 | class ReferenceData: 1125 | ref_image: Union[torch.FloatTensor, PIL.Image.Image] = None 1126 | ref_image_mask: Union[torch.FloatTensor, PIL.Image.Image] = None 1127 | MODE: str = "write" 1128 | progress: float = 0.0 1129 | uc_mask: torch.Tensor = None 1130 | bool_mask: bool = True 1131 | style_fidelity: float = 1.0 1132 | do_classifier_free_guidance: bool = True 1133 | attention_auto_machine_weight: float = 100.0 1134 | gn_auto_machine_weight: float = 1.0 1135 | ref_mask_dict: dict = None 1136 | out_mask_dict: dict = None 1137 | strength: float = 1.0 1138 | 1139 | 1140 | class ReferenceOnlyUNet2DConditionModel(UNet2DConditionModel): 1141 | @classmethod 1142 | def from_unet( 1143 | cls, 1144 | unet: UNet2DConditionModel, 1145 | ref_data: ReferenceData = ReferenceData(), 1146 | reference_attn: bool = False, 1147 | reference_adain: bool = False, 1148 | ) -> "ReferenceOnlyUNet2DConditionModel": 1149 | # convert the UNet and its sub-blocks to the reference-only subclasses in place 1150 | basic_transformer_idx = 0 1151 | basic_transformer_blocks = [] 1152 | for module in torch_dfs(unet): 1153 | if reference_attn: 1154 | if isinstance(module, BasicTransformerBlock): 1155 | basic_transformer_blocks.append(module) 1156 | module.__class__ = BasicTransformerBlockReferenceOnly 1157 | module.ref_data = ref_data 1158 | module.bank = [] 1159 | module.idx = basic_transformer_idx 1160 | basic_transformer_idx += 1 1161 | elif reference_adain: 1162 | if isinstance(module, CrossAttnDownBlock2D): 1163 | module.__class__ = CrossAttnDownBlock2DReferenceOnly 1164 | module.ref_data = ref_data 1165 | module.bank = [] 1166 | if isinstance(module, DownBlock2D): 1167 | module.__class__ = DownBlock2DReferenceOnly 1168 | if isinstance(module, UNetMidBlock2DCrossAttn): 1169 | module.__class__ = UNetMidBlock2DCrossAttnReferenceOnly 1170 | if isinstance(module, UpBlock2D): 1171 | module.__class__ = UpBlock2DReferenceOnly 1172 | if isinstance(module, CrossAttnUpBlock2D): 1173 | module.__class__ = CrossAttnUpBlock2DReferenceOnly 1174 | module.ref_data = ref_data 1175 | unet.mid_block.gn_weight = 0 1176 | down_blocks = unet.down_blocks 1177 | module.mean_bank = [] 1178 | module.var_bank = [] 1179 | for w, module in enumerate(down_blocks): 1180 | module.gn_weight = 1.0 - float(w) / float(len(down_blocks)) 1181 | module.gn_weight *= 2 1182 | 1183 | up_blocks
= unet.up_blocks 1184 | for w, module in enumerate(up_blocks): 1185 | module.gn_weight = float(w) / float(len(up_blocks)) 1186 | module.gn_weight *= 2 1187 | 1188 | # compute attn_weight for each transformer block 1189 | basic_transformer_blocks = sorted( 1190 | basic_transformer_blocks, key=lambda x: -x.norm1.normalized_shape[0] 1191 | ) 1192 | 1193 | for i, module in enumerate(basic_transformer_blocks): 1194 | module.attn_weight = float(i) / float(len(basic_transformer_blocks)) 1195 | unet.__class__ = cls 1196 | unet.ref_data = ref_data 1197 | return unet 1198 | 1199 | @classmethod 1200 | def revert_unet( 1201 | cls, unet: "ReferenceOnlyUNet2DConditionModel" 1202 | ) -> UNet2DConditionModel: 1203 | unet.__class__ = UNet2DConditionModel 1204 | for module in torch_dfs(unet): 1205 | if isinstance(module, BasicTransformerBlockReferenceOnly): 1206 | module.__class__ = BasicTransformerBlock 1207 | if isinstance(module, CrossAttnDownBlock2DReferenceOnly): 1208 | module.__class__ = CrossAttnDownBlock2D 1209 | if isinstance(module, DownBlock2DReferenceOnly): 1210 | module.__class__ = DownBlock2D 1211 | if isinstance(module, UNetMidBlock2DCrossAttnReferenceOnly): 1212 | module.__class__ = UNetMidBlock2DCrossAttn 1213 | if isinstance(module, UpBlock2DReferenceOnly): 1214 | module.__class__ = UpBlock2D 1215 | if isinstance(module, CrossAttnUpBlock2DReferenceOnly): 1216 | module.__class__ = CrossAttnUpBlock2D 1217 | return unet 1218 | 1219 | 1220 | class BasicTransformerBlockReferenceOnly(BasicTransformerBlock): 1221 | 1222 | @classmethod 1223 | def from_module( 1224 | cls, module: BasicTransformerBlock 1225 | ) -> "BasicTransformerBlockReferenceOnly": 1226 | module.__class__ = cls 1227 | return module 1228 | 1229 | def forward( 1230 | self, 1231 | hidden_states: torch.FloatTensor, 1232 | attention_mask: torch.FloatTensor | None = None, 1233 | encoder_hidden_states: torch.FloatTensor | None = None, 1234 | encoder_attention_mask: torch.FloatTensor | None = None, 1235 | timestep: torch.LongTensor | None = None, 1236 | cross_attention_kwargs: Dict[str, Any] = None, 1237 | class_labels: torch.LongTensor | None = None, 1238 | _: Dict[str, torch.Tensor] | None = None, 1239 | ) -> torch.FloatTensor: 1240 | assert isinstance(self.idx, int) 1241 | ref_data = self.ref_data 1242 | assert isinstance(ref_data, ReferenceData) 1243 | bank = self.bank 1244 | assert isinstance(bank, list) 1245 | 1246 | if self.use_ada_layer_norm: 1247 | norm_hidden_states = self.norm1(hidden_states, timestep) 1248 | elif self.use_ada_layer_norm_zero: 1249 | ( 1250 | norm_hidden_states, 1251 | gate_msa, 1252 | shift_mlp, 1253 | scale_mlp, 1254 | gate_mlp, 1255 | ) = self.norm1( 1256 | hidden_states, 1257 | timestep, 1258 | class_labels, 1259 | hidden_dtype=hidden_states.dtype, 1260 | ) 1261 | else: 1262 | norm_hidden_states = self.norm1(hidden_states) 1263 | 1264 | # 1. Retrieve lora scale. 1265 | lora_scale = ( 1266 | cross_attention_kwargs.get("scale", 1.0) 1267 | if cross_attention_kwargs is not None 1268 | else 1.0 1269 | ) 1270 | 1271 | # 2. Prepare GLIGEN inputs 1272 | cross_attention_kwargs = ( 1273 | cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} 1274 | ) 1275 | gligen_kwargs = cross_attention_kwargs.pop("gligen", None) 1276 | 1277 | # 1.
Self-Attention 1278 | cross_attention_kwargs = ( 1279 | cross_attention_kwargs if cross_attention_kwargs is not None else {} 1280 | ) 1281 | 1282 | if self.only_cross_attention: 1283 | attn_output = self.attn1( 1284 | norm_hidden_states, 1285 | encoder_hidden_states=( 1286 | encoder_hidden_states if self.only_cross_attention else None 1287 | ), 1288 | attention_mask=attention_mask, 1289 | **cross_attention_kwargs, 1290 | ) 1291 | elif ref_data.MODE == "read": 1292 | style_fidelity = ref_data.style_fidelity 1293 | attention_auto_machine_weight = ref_data.attention_auto_machine_weight 1294 | if attention_auto_machine_weight > self.attn_weight: 1295 | attn_output_uc = self.attn1( 1296 | norm_hidden_states, 1297 | encoder_hidden_states=torch.cat( 1298 | [norm_hidden_states] + self.bank, dim=1 1299 | ), 1300 | # attention_mask=attention_mask, 1301 | **cross_attention_kwargs, 1302 | ) 1303 | attn_output_c = attn_output_uc.clone() 1304 | do_classifier_free_guidance = ref_data.do_classifier_free_guidance 1305 | if do_classifier_free_guidance and style_fidelity > 0: 1306 | uc_mask = ref_data.uc_mask 1307 | attn_output_c[uc_mask] = self.attn1( 1308 | norm_hidden_states[uc_mask], 1309 | encoder_hidden_states=norm_hidden_states[uc_mask], 1310 | **cross_attention_kwargs, 1311 | ) 1312 | attn_output = ( 1313 | style_fidelity * attn_output_c 1314 | + (1.0 - style_fidelity) * attn_output_uc 1315 | ) 1316 | attn_output *= ref_data.strength 1317 | bank.clear() 1318 | else: 1319 | # without reference only 1320 | attn_output = self.attn1( 1321 | norm_hidden_states, 1322 | encoder_hidden_states=( 1323 | encoder_hidden_states if self.only_cross_attention else None 1324 | ), 1325 | attention_mask=attention_mask, 1326 | **cross_attention_kwargs, 1327 | ) 1328 | 1329 | elif ref_data.MODE == "write": 1330 | bank.append(norm_hidden_states.detach().clone()) 1331 | attn_output = self.attn1( 1332 | norm_hidden_states, 1333 | encoder_hidden_states=( 1334 | encoder_hidden_states if self.only_cross_attention else None 1335 | ), 1336 | attention_mask=attention_mask, 1337 | **cross_attention_kwargs, 1338 | ) 1339 | if self.use_ada_layer_norm_zero: 1340 | attn_output = gate_msa.unsqueeze(1) * attn_output 1341 | 1342 | hidden_states = attn_output + hidden_states 1343 | 1344 | # 2.5 GLIGEN Control 1345 | if gligen_kwargs is not None: 1346 | hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) 1347 | 1348 | # 2. Cross-Attention 1349 | if self.attn2 is not None: 1350 | norm_hidden_states = ( 1351 | self.norm2(hidden_states, timestep) 1352 | if self.use_ada_layer_norm 1353 | else self.norm2(hidden_states) 1354 | ) 1355 | 1356 | attn_output = self.attn2( 1357 | norm_hidden_states, 1358 | encoder_hidden_states=encoder_hidden_states, 1359 | attention_mask=encoder_attention_mask, 1360 | **cross_attention_kwargs, 1361 | ) 1362 | hidden_states = attn_output + hidden_states 1363 | 1364 | gligen_kwargs = cross_attention_kwargs.pop("gligen", None) 1365 | 1366 | # 3. 
Feed-forward 1367 | norm_hidden_states = self.norm3(hidden_states) 1368 | 1369 | if self.use_ada_layer_norm_zero: 1370 | norm_hidden_states = ( 1371 | norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] 1372 | ) 1373 | 1374 | ff_output = self.ff(norm_hidden_states, scale=lora_scale) 1375 | 1376 | if self.use_ada_layer_norm_zero: 1377 | ff_output = gate_mlp.unsqueeze(1) * ff_output 1378 | 1379 | return ff_output + hidden_states 1380 | 1381 | 1382 | class CrossAttnDownBlock2DReferenceOnly(CrossAttnDownBlock2D): 1383 | @classmethod 1384 | def from_module( 1385 | cls, module: CrossAttnDownBlock2D 1386 | ) -> "CrossAttnDownBlock2DReferenceOnly": 1387 | module.__class__ = cls 1388 | return module 1389 | 1390 | def forward( 1391 | self, 1392 | hidden_states: torch.FloatTensor, 1393 | temb: torch.FloatTensor | None = None, 1394 | encoder_hidden_states: torch.FloatTensor | None = None, 1395 | attention_mask: torch.FloatTensor | None = None, 1396 | cross_attention_kwargs: Dict[str, Any] | None = None, 1397 | encoder_attention_mask: torch.FloatTensor | None = None, 1398 | additional_residuals: torch.FloatTensor | None = None, 1399 | ) -> Tuple[torch.FloatTensor | Tuple[torch.FloatTensor]]: 1400 | MODE = self.ref_data.MODE 1401 | gn_auto_machine_weight = self.ref_data.gn_auto_machine_weight 1402 | do_classifier_free_guidance = self.ref_data.do_classifier_free_guidance 1403 | style_fidelity = self.ref_data.style_fidelity 1404 | uc_mask = self.ref_data.uc_mask 1405 | 1406 | eps = 1e-6 1407 | # TODO(Patrick, William) - attention mask is not used 1408 | output_states = () 1409 | 1410 | blocks = list(zip(self.resnets, self.attentions)) 1411 | for i, (resnet, attn) in enumerate(blocks): 1412 | hidden_states = resnet(hidden_states, temb) 1413 | hidden_states = attn( 1414 | hidden_states, 1415 | encoder_hidden_states=encoder_hidden_states, 1416 | cross_attention_kwargs=cross_attention_kwargs, 1417 | attention_mask=attention_mask, 1418 | encoder_attention_mask=encoder_attention_mask, 1419 | return_dict=False, 1420 | )[0] 1421 | if MODE == "write" and gn_auto_machine_weight >= self.gn_weight: 1422 | var, mean = torch.var_mean( 1423 | hidden_states, dim=(2, 3), keepdim=True, correction=0 1424 | ) 1425 | self.mean_bank.append([mean]) 1426 | self.var_bank.append([var]) 1427 | if MODE == "read" and (len(self.mean_bank) > 0 and len(self.var_bank) > 0): 1428 | var, mean = torch.var_mean( 1429 | hidden_states, dim=(2, 3), keepdim=True, correction=0 1430 | ) 1431 | std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5 1432 | mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i])) 1433 | var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i])) 1434 | std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5 1435 | hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc 1436 | hidden_states_c = hidden_states_uc.clone() 1437 | if do_classifier_free_guidance and style_fidelity > 0: 1438 | hidden_states_c[uc_mask] = hidden_states[uc_mask] 1439 | hidden_states = ( 1440 | style_fidelity * hidden_states_c 1441 | + (1.0 - style_fidelity) * hidden_states_uc 1442 | ) 1443 | hidden_states *= self.ref_data.strength 1444 | # apply additional residuals to the output of the last pair of resnet and attention blocks 1445 | if i == len(blocks) - 1 and additional_residuals is not None: 1446 | hidden_states = hidden_states + additional_residuals 1447 | output_states = output_states + (hidden_states,) 1448 | 1449 | if MODE == "read": 1450 | self.mean_bank = [] 1451 | 
self.var_bank = [] 1452 | 1453 | if self.downsamplers is not None: 1454 | for downsampler in self.downsamplers: 1455 | hidden_states = downsampler(hidden_states) 1456 | 1457 | output_states = output_states + (hidden_states,) 1458 | 1459 | return hidden_states, output_states 1460 | 1461 | 1462 | class DownBlock2DReferenceOnly(DownBlock2D): 1463 | @classmethod 1464 | def from_module(cls, module: DownBlock2D): 1465 | instance = cls() 1466 | instance.__dict__.update(module.__dict__) 1467 | return instance 1468 | 1469 | def forward( 1470 | self, 1471 | hidden_states: torch.FloatTensor, 1472 | temb: torch.FloatTensor | None = None, 1473 | *args, 1474 | **kwargs, 1475 | ) -> Tuple[torch.FloatTensor | Tuple[torch.FloatTensor]]: 1476 | 1477 | MODE = self.ref_data.MODE 1478 | gn_auto_machine_weight = self.ref_data.gn_auto_machine_weight 1479 | do_classifier_free_guidance = self.ref_data.do_classifier_free_guidance 1480 | style_fidelity = self.ref_data.style_fidelity 1481 | uc_mask = self.ref_data.uc_mask 1482 | 1483 | eps = 1e-6 1484 | output_states = () 1485 | for i, resnet in enumerate(self.resnets): 1486 | hidden_states = resnet(hidden_states, temb) 1487 | if MODE == "write" and gn_auto_machine_weight >= self.gn_weight: 1488 | var, mean = torch.var_mean( 1489 | hidden_states, dim=(2, 3), keepdim=True, correction=0 1490 | ) 1491 | self.mean_bank.append([mean]) 1492 | self.var_bank.append([var]) 1493 | if MODE == "read" and (len(self.mean_bank) > 0 and len(self.var_bank) > 0): 1494 | var, mean = torch.var_mean( 1495 | hidden_states, dim=(2, 3), keepdim=True, correction=0 1496 | ) 1497 | std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5 1498 | mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i])) 1499 | var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i])) 1500 | std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5 1501 | hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc 1502 | hidden_states_c = hidden_states_uc.clone() 1503 | if do_classifier_free_guidance and style_fidelity > 0: 1504 | hidden_states_c[uc_mask] = hidden_states[uc_mask] 1505 | hidden_states = ( 1506 | style_fidelity * hidden_states_c 1507 | + (1.0 - style_fidelity) * hidden_states_uc 1508 | ) 1509 | hidden_states *= self.ref_data.strength 1510 | 1511 | output_states = output_states + (hidden_states,) 1512 | 1513 | if MODE == "read": 1514 | self.mean_bank = [] 1515 | self.var_bank = [] 1516 | 1517 | if self.downsamplers is not None: 1518 | for downsampler in self.downsamplers: 1519 | hidden_states = downsampler(hidden_states) 1520 | 1521 | output_states = output_states + (hidden_states,) 1522 | 1523 | return hidden_states, output_states 1524 | 1525 | 1526 | class UNetMidBlock2DCrossAttnReferenceOnly(UNetMidBlock2DCrossAttn): 1527 | @classmethod 1528 | def from_module(cls, module: UNetMidBlock2DCrossAttn): 1529 | instance = cls() 1530 | instance.__dict__.update(module.__dict__) 1531 | return instance 1532 | 1533 | def forward( 1534 | self, 1535 | hidden_states: torch.FloatTensor, 1536 | temb: torch.FloatTensor | None = None, 1537 | encoder_hidden_states: torch.FloatTensor | None = None, 1538 | attention_mask: torch.FloatTensor | None = None, 1539 | cross_attention_kwargs: Dict[str, Any] | None = None, 1540 | encoder_attention_mask: torch.FloatTensor | None = None, 1541 | ) -> torch.FloatTensor: 1542 | return super().forward( 1543 | hidden_states, 1544 | temb, 1545 | encoder_hidden_states, 1546 | attention_mask, 1547 | cross_attention_kwargs, 1548 | 
encoder_attention_mask, 1549 | ) 1550 | 1551 | def forward(self, *args, **kwargs): 1552 | MODE = self.ref_data.MODE 1553 | gn_auto_machine_weight = self.ref_data.gn_auto_machine_weight 1554 | do_classifier_free_guidance = self.ref_data.do_classifier_free_guidance 1555 | style_fidelity = self.ref_data.style_fidelity 1556 | uc_mask = self.ref_data.uc_mask 1557 | x = super().forward(*args, **kwargs) 1558 | if MODE == "write" and gn_auto_machine_weight >= self.gn_weight: 1559 | var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0) 1560 | self.mean_bank.append(mean) 1561 | self.var_bank.append(var) 1562 | if MODE == "read": 1563 | if len(self.mean_bank) > 0 and len(self.var_bank) > 0: 1564 | var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0) 1565 | eps = 1e-6 1566 | std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5 1567 | mean_acc = sum(self.mean_bank) / float(len(self.mean_bank)) 1568 | var_acc = sum(self.var_bank) / float(len(self.var_bank)) 1569 | std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5 1570 | x_uc = (((x - mean) / std) * std_acc) + mean_acc 1571 | x_c = x_uc.clone() 1572 | if do_classifier_free_guidance and style_fidelity > 0: 1573 | x_c[uc_mask] = x[uc_mask] 1574 | x = style_fidelity * x_c + (1.0 - style_fidelity) * x_uc 1575 | x *= self.ref_data.strength 1576 | self.mean_bank = [] 1577 | self.var_bank = [] 1578 | return x 1579 | 1580 | 1581 | class UpBlock2DReferenceOnly(UpBlock2D): 1582 | @classmethod 1583 | def from_module(cls, module: UpBlock2D): 1584 | instance = cls() 1585 | instance.__dict__.update(module.__dict__) 1586 | return instance 1587 | 1588 | def forward( 1589 | self, 1590 | hidden_states: torch.FloatTensor, 1591 | res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], 1592 | temb: Optional[torch.FloatTensor] = None, 1593 | upsample_size: Optional[int] = None, 1594 | *args, 1595 | **kwargs, 1596 | ) -> torch.FloatTensor: 1597 | MODE = self.ref_data.MODE 1598 | gn_auto_machine_weight = self.ref_data.gn_auto_machine_weight 1599 | do_classifier_free_guidance = self.ref_data.do_classifier_free_guidance 1600 | style_fidelity = self.ref_data.style_fidelity 1601 | uc_mask = self.ref_data.uc_mask 1602 | 1603 | eps = 1e-6 1604 | for i, resnet in enumerate(self.resnets): 1605 | # pop res hidden states 1606 | res_hidden_states = res_hidden_states_tuple[-1] 1607 | res_hidden_states_tuple = res_hidden_states_tuple[:-1] 1608 | hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) 1609 | hidden_states = resnet(hidden_states, temb) 1610 | 1611 | if MODE == "write" and gn_auto_machine_weight >= self.gn_weight: 1612 | var, mean = torch.var_mean( 1613 | hidden_states, dim=(2, 3), keepdim=True, correction=0 1614 | ) 1615 | self.mean_bank.append([mean]) 1616 | self.var_bank.append([var]) 1617 | if MODE == "read" and (len(self.mean_bank) > 0 and len(self.var_bank) > 0): 1618 | var, mean = torch.var_mean( 1619 | hidden_states, dim=(2, 3), keepdim=True, correction=0 1620 | ) 1621 | std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5 1622 | mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i])) 1623 | var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i])) 1624 | std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5 1625 | hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc 1626 | hidden_states_c = hidden_states_uc.clone() 1627 | if do_classifier_free_guidance and style_fidelity > 0: 1628 | hidden_states_c[uc_mask] = hidden_states[uc_mask] 
1629 | hidden_states = ( 1630 | style_fidelity * hidden_states_c 1631 | + (1.0 - style_fidelity) * hidden_states_uc 1632 | ) 1633 | hidden_states *= self.ref_data.strength 1634 | 1635 | if MODE == "read": 1636 | self.mean_bank = [] 1637 | self.var_bank = [] 1638 | 1639 | if self.upsamplers is not None: 1640 | for upsampler in self.upsamplers: 1641 | hidden_states = upsampler(hidden_states, upsample_size) 1642 | 1643 | return hidden_states 1644 | 1645 | 1646 | class CrossAttnUpBlock2DReferenceOnly(CrossAttnUpBlock2D): 1647 | @classmethod 1648 | def from_module(cls, module: CrossAttnUpBlock2D): 1649 | instance = cls() 1650 | instance.__dict__.update(module.__dict__) 1651 | return instance 1652 | 1653 | def forward( 1654 | self, 1655 | hidden_states: torch.FloatTensor, 1656 | res_hidden_states_tuple: Tuple[torch.FloatTensor], 1657 | temb: torch.FloatTensor | None = None, 1658 | encoder_hidden_states: torch.FloatTensor | None = None, 1659 | cross_attention_kwargs: Dict[str, Any] | None = None, 1660 | upsample_size: int | None = None, 1661 | attention_mask: torch.FloatTensor | None = None, 1662 | encoder_attention_mask: torch.FloatTensor | None = None, 1663 | ) -> torch.FloatTensor: 1664 | 1665 | MODE = self.ref_data.MODE 1666 | gn_auto_machine_weight = self.ref_data.gn_auto_machine_weight 1667 | do_classifier_free_guidance = self.ref_data.do_classifier_free_guidance 1668 | style_fidelity = self.ref_data.style_fidelity 1669 | uc_mask = self.ref_data.uc_mask 1670 | 1671 | eps = 1e-6 1672 | # TODO(Patrick, William) - attention mask is not used 1673 | for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)): 1674 | # pop res hidden states 1675 | res_hidden_states = res_hidden_states_tuple[-1] 1676 | res_hidden_states_tuple = res_hidden_states_tuple[:-1] 1677 | hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) 1678 | hidden_states = resnet(hidden_states, temb) 1679 | hidden_states = attn( 1680 | hidden_states, 1681 | encoder_hidden_states=encoder_hidden_states, 1682 | cross_attention_kwargs=cross_attention_kwargs, 1683 | attention_mask=attention_mask, 1684 | encoder_attention_mask=encoder_attention_mask, 1685 | return_dict=False, 1686 | )[0] 1687 | 1688 | if MODE == "write" and gn_auto_machine_weight >= self.gn_weight: 1689 | var, mean = torch.var_mean( 1690 | hidden_states, dim=(2, 3), keepdim=True, correction=0 1691 | ) 1692 | self.mean_bank.append([mean]) 1693 | self.var_bank.append([var]) 1694 | if MODE == "read" and (len(self.mean_bank) > 0 and len(self.var_bank) > 0): 1695 | var, mean = torch.var_mean( 1696 | hidden_states, dim=(2, 3), keepdim=True, correction=0 1697 | ) 1698 | std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5 1699 | mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i])) 1700 | var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i])) 1701 | std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5 1702 | hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc 1703 | hidden_states_c = hidden_states_uc.clone() 1704 | if do_classifier_free_guidance and style_fidelity > 0: 1705 | hidden_states_c[uc_mask] = hidden_states[uc_mask] 1706 | hidden_states = ( 1707 | style_fidelity * hidden_states_c 1708 | + (1.0 - style_fidelity) * hidden_states_uc 1709 | ) 1710 | hidden_states *= self.ref_data.strength 1711 | 1712 | if MODE == "read": 1713 | self.mean_bank = [] 1714 | self.var_bank = [] 1715 | 1716 | if self.upsamplers is not None: 1717 | for upsampler in self.upsamplers: 1718 | hidden_states 
= upsampler(hidden_states, upsample_size) 1719 | 1720 | return hidden_states 1721 | --------------------------------------------------------------------------------
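The `*ReferenceOnly` blocks above implement reference-only conditioning as a two-pass statistic swap: in "write" mode each hooked block banks what the noised reference latents produce (normalized hidden states for the attention blocks, per-channel spatial mean/variance for the group-norm blocks), and in "read" mode the current features are re-standardized with those banked statistics, blended with an untouched copy according to `style_fidelity`, and finally scaled by the reference strength. The snippet below is a minimal standalone sketch of just that statistic-matching step, using made-up toy tensors rather than the repository's classes.

```python
# Toy illustration of the AdaIN-style statistic matching performed by the
# *ReferenceOnly group-norm blocks ("write" banks stats, "read" re-applies them).
# Shapes and values are made up; nothing here uses the repository's classes.
import torch

eps = 1e-6
style_fidelity = 0.5                      # plays the role of ReferenceData.style_fidelity
batch = 2                                 # per-prompt batch size
feats = torch.randn(2 * batch, 320, 64, 64)               # current features, CFG-doubled
uc_mask = torch.tensor([1] * batch + [0] * batch).bool()  # first half = unconditional rows

# "write" pass: bank the spatial statistics produced by the noised reference latents
ref_feats = torch.randn(2 * batch, 320, 64, 64)
var_acc, mean_acc = torch.var_mean(ref_feats, dim=(2, 3), keepdim=True, correction=0)
std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5

# "read" pass: re-standardize the current features with the banked statistics
var, mean = torch.var_mean(feats, dim=(2, 3), keepdim=True, correction=0)
std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
feats_uc = ((feats - mean) / std) * std_acc + mean_acc

# style_fidelity blends in a copy whose unconditional rows keep their original features
feats_c = feats_uc.clone()
feats_c[uc_mask] = feats[uc_mask]
out = style_fidelity * feats_c + (1.0 - style_fidelity) * feats_uc
```

In the pipeline itself the blended result is additionally multiplied by `ReferenceData.strength` (the `reference_strength` argument), and the transformer blocks apply the same `style_fidelity` blend to a self-attention output computed against the banked reference tokens rather than to group-norm statistics.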