├── .github └── workflows │ └── publish.yml ├── LICENSE ├── README.md ├── __init__.py ├── assets ├── style_t2i.jpg ├── style_t2i_sdxl.jpg └── style_transfer.jpg ├── comfyui_nodes.py ├── examples ├── 1.png ├── 26.jpg ├── 40.jpg └── lecun.png ├── losses.py ├── pipeline_flux.py ├── pipeline_sd.py ├── pipeline_sdxl.py ├── pyproject.toml ├── requirements.txt ├── train_vae.py ├── utils.py └── workflows ├── style_t2i_generation_flux.json ├── style_t2i_generation_sd15.json ├── style_t2i_generation_sdxl.json └── style_transfer_sd15.json /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish to Comfy registry 2 | on: 3 | workflow_dispatch: 4 | push: 5 | branches: 6 | - main 7 | paths: 8 | - "pyproject.toml" 9 | 10 | permissions: 11 | issues: write 12 | 13 | jobs: 14 | publish-node: 15 | name: Publish Custom Node to registry 16 | runs-on: ubuntu-latest 17 | if: ${{ github.repository_owner == 'zichongc' }} 18 | steps: 19 | - name: Check out code 20 | uses: actions/checkout@v4 21 | with: 22 | submodules: true 23 | - name: Publish Custom Node 24 | uses: Comfy-Org/publish-node-action@v1 25 | with: 26 | ## Add your own personal access token to your Github Repository secrets and reference it here. 27 | personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Zichong Chen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## ComfyUI-Attention-Distillation 2 | 3 | Non-native [AttentionDistillation](https://xugao97.github.io/AttentionDistillation/) for ComfyUI. 4 | 5 | Official ComfyUI demo for the paper [AttentionDistillation](https://arxiv.org/abs/2502.20235), implemented as an extension of ComfyUI. Note that this extension incorporates AttentionDistillation using `diffusers`. 6 | 7 | The official code for AttentionDistillation can be found [here](https://github.com/xugao97/AttentionDistillation). 8 | 9 | ### 🔥🔥 News 10 | * **2025/03/10**: Workflows for style-specific T2I generation using **SDXL** and **Flux**(beta) have been released. 
11 | * **2025/02/27**: We released the ComfyUI implementation of Attention Distillation and two workflows for style transfer and style-specific text-to-image generation using Stable Diffusion 1.5. 12 | * **2025/02/27**: The official code for AttentionDistillation has been released [here](https://github.com/xugao97/AttentionDistillation). 13 | 14 | ### 🛒 Installation 15 | Download or `git clone` this repository into the `ComfyUI/custom_nodes/` directory, or use the Manager for a streamlined setup. 16 | 17 | 18 | ##### Install manually 19 | 1. `cd custom_nodes` 20 | 2. `git clone ...` 21 | 3. `cd ComfyUI-Attention-Distillation` 22 | 4. `pip install -r requirements.txt` 23 | 5. Restart ComfyUI. 24 | 25 | ### 📒 How to Use 26 | ##### Download T2I diffusion models 27 | This implementation uses `diffusers`-format checkpoints. Download the required models and place them in the `ComfyUI/models/diffusers` directory: 28 | |Model|Model Name and Link| 29 | |:---:|:---:| 30 | | Stable Diffusion (v1.5, v2.1) | [stable-diffusion-v1-5/stable-diffusion-v1-5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5)
[stabilityai/stable-diffusion-2-1](https://huggingface.co/stabilityai/stable-diffusion-2-1) | 31 | | SDXL | [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) | 32 | | Flux (dev) | [black-forest-labs/FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) | 33 | 34 | 35 | ~~*Note: Currently, only Stable Diffusion v1.5 is required.*~~ 36 | 37 | ##### Load the workflow 38 | Workflows for various tasks are available in `ComfyUI/custom_nodes/Comfy-Attention-Distillation/workflows`. Simply load them to get started. Additionally, we've included usage examples in the [Examples](#examples) section for your reference. 39 | 40 | ### 🔍 Examples 41 | 42 | #### Style-specific text-to-image generation 43 | `style_t2i_generation_sd15.json` 44 | 45 | 46 | 47 | 48 | `style_t2i_generation_sdxl.json` 49 | 50 | 51 | 52 | 53 | `style_t2i_generation_flux.json` (beta) 54 | 55 | 56 | 57 | #### Style Transfer 58 | `style_transfer_sd15.json` 59 | 60 | 61 | 62 | 63 | ### 📃 TODOs 64 | - [x] Workflow for style-specific text-to-image generation using SDXL. 65 | - [x] Workflow for style-specific text-to-image generation using Flux. 66 | - [ ] Workflow for texture synthesis. 67 | 68 | 69 | 81 | 82 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from .comfyui_nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS 2 | 3 | __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"] 4 | -------------------------------------------------------------------------------- /assets/style_t2i.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zichongc/ComfyUI-Attention-Distillation/4c00ca9b2604c9e27dc1427ec49f9b108a1e71d0/assets/style_t2i.jpg -------------------------------------------------------------------------------- /assets/style_t2i_sdxl.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zichongc/ComfyUI-Attention-Distillation/4c00ca9b2604c9e27dc1427ec49f9b108a1e71d0/assets/style_t2i_sdxl.jpg -------------------------------------------------------------------------------- /assets/style_transfer.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zichongc/ComfyUI-Attention-Distillation/4c00ca9b2604c9e27dc1427ec49f9b108a1e71d0/assets/style_transfer.jpg -------------------------------------------------------------------------------- /comfyui_nodes.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from PIL import Image 4 | from diffusers import DDIMScheduler 5 | 6 | from comfy.comfy_types import IO 7 | import comfy.model_management as mm 8 | import node_helpers 9 | import folder_paths 10 | from huggingface_hub import hf_hub_download 11 | from tqdm import tqdm 12 | 13 | from torchvision.transforms.functional import resize, to_tensor 14 | from accelerate.utils import set_seed 15 | from .pipeline_sd import ADPipeline 16 | from .pipeline_sdxl import ADPipeline as ADXLPipeline 17 | from .pipeline_flux import ADPipeline as ADFluxPipeline 18 | from .utils import Controller 19 | from .utils import sd15_file_names, sdxl_file_names, flux_file_names 20 | 21 | 22 | class PureText: 23 | @classmethod 24 | def INPUT_TYPES(s): 25 | return { 26 | "required": { 27 | 
"text": (IO.STRING, {"multiline": True, "dynamicPrompts": True, "tooltip": "The text to be encoded."}), 28 | } 29 | } 30 | 31 | RETURN_TYPES = (IO.CONDITIONING,) 32 | FUNCTION = "get_prompt" 33 | CATEGORY = "AttentionDistillationWrapper" 34 | 35 | def get_prompt(self, text): 36 | return (text,) 37 | 38 | 39 | class LoadPILImage: 40 | @classmethod 41 | def INPUT_TYPES(s): 42 | input_dir = folder_paths.get_input_directory() 43 | files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] 44 | return {"required": 45 | {"image": (sorted(files), {"image_upload": True})}, 46 | } 47 | 48 | CATEGORY = "AttentionDistillationWrapper" 49 | 50 | RETURN_TYPES = ("IMAGE",) 51 | RETURN_NAMES = ("image",) 52 | FUNCTION = "load_image" 53 | 54 | def load_image(self, image): 55 | image_path = folder_paths.get_annotated_filepath(image) 56 | img = node_helpers.pillow(Image.open, image_path).convert('RGB') 57 | return (img,) 58 | 59 | 60 | class ResizeImage: 61 | RETURN_TYPES = ("IMAGE",) 62 | RETURN_NAMES = ("image",) 63 | FUNCTION = "resize_image" 64 | 65 | CATEGORY = "AttentionDistillationWrapper" 66 | 67 | @classmethod 68 | def INPUT_TYPES(s): 69 | return { 70 | "required": { 71 | "image": ("IMAGE",), 72 | "resolution": ("INT", {"default": 512, "min": 256, "max": 4096, "step": 8}), 73 | }, 74 | } 75 | 76 | def resize_image(self, image, resolution): 77 | if isinstance(image, torch.Tensor): 78 | assert image.ndim == 4 79 | if (image.shape[1] != 3 and image.shape[-1] == 3): 80 | image = image.permute(0, 3, 1, 2) 81 | image = resize(image, size=resolution) 82 | return (image,) 83 | 84 | 85 | class LoadDistiller: 86 | RETURN_TYPES = ("DISTILLER",) 87 | RETURN_NAMES = ("distiller",) 88 | FUNCTION = "load_model" 89 | CATEGORY = "AttentionDistillationWrapper" 90 | 91 | @classmethod 92 | def INPUT_TYPES(s): 93 | return { 94 | 'required': { 95 | "model": (['stable-diffusion-v1-5', 'stable-diffusion-xl-base-1.0', 'FLUX.1-dev'], {"default": "stable-diffusion-v1-5"}), 96 | "precision": (['bf16', 'fp32'], {"default": 'bf16'}), 97 | }, 98 | } 99 | 100 | @torch.inference_mode(False) 101 | def load_model(self, model, precision): 102 | weight_dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision] 103 | if precision == 'fp32': 104 | precision = 'no' 105 | device = mm.get_torch_device() 106 | 107 | model_name = os.path.join(folder_paths.models_dir, 'diffusers', model) 108 | model_class = { 109 | "stable-diffusion-v1-5": ADPipeline, 110 | "stable-diffusion-xl-base-1.0": ADXLPipeline, 111 | "FLUX.1-dev": ADFluxPipeline, 112 | }[model] 113 | 114 | if not os.path.exists(model_name): 115 | print(f"Please download target model to : {model_name}") 116 | 117 | try: 118 | if model == "FLUX.1-dev": 119 | distiller = model_class.from_pretrained( 120 | model_name, safety_checker=None, torch_dtype=weight_dtype 121 | ).to(device) 122 | else: 123 | scheduler = DDIMScheduler.from_pretrained(model_name, subfolder='scheduler') 124 | distiller = model_class.from_pretrained( 125 | model_name, scheduler=scheduler, safety_checker=None, torch_dtype=weight_dtype 126 | ).to(device) 127 | except: 128 | print('Download models...') 129 | 130 | repo_name = { 131 | "stable-diffusion-v1-5": "stable-diffusion-v1-5/stable-diffusion-v1-5", 132 | "stable-diffusion-xl-base-1.0": "stabilityai/stable-diffusion-xl-base-1.0", 133 | "FLUX.1-dev": "black-forest-labs/FLUX.1-dev", 134 | }[model] 135 | 136 | file_names = { 137 | "stable-diffusion-v1-5": sd15_file_names, 138 | "stable-diffusion-xl-base-1.0": 
sdxl_file_names, 139 | "FLUX.1-dev": flux_file_names, 140 | }[model] 141 | 142 | pbar = tqdm(file_names) 143 | for file_name in pbar: 144 | pbar.set_description(f'Downloading {file_name}') 145 | if not os.path.exists(os.path.join(model_name, file_name)): 146 | hf_hub_download(repo_id=repo_name, filename=file_name, local_dir=model_name) 147 | pbar.update() 148 | 149 | 150 | if model == "FLUX.1-dev": 151 | distiller = model_class.from_pretrained( 152 | model_name, safety_checker=None, torch_dtype=weight_dtype 153 | ).to(device) 154 | else: 155 | scheduler = DDIMScheduler.from_pretrained(model_name, subfolder='scheduler') 156 | distiller = model_class.from_pretrained( 157 | model_name, scheduler=scheduler, safety_checker=None, torch_dtype=weight_dtype 158 | ).to(device) 159 | 160 | if hasattr(distiller, 'unet'): 161 | distiller.classifier = distiller.unet 162 | elif hasattr(distiller, 'transformer'): 163 | distiller.classifier = distiller.transformer 164 | else: 165 | raise ValueError("Failed to initialize the classifier.") 166 | 167 | return ({"distiller": distiller, "precision": precision, 'weight_dtype': weight_dtype},) 168 | 169 | 170 | class ADOptimizer: 171 | @classmethod 172 | def INPUT_TYPES(s): 173 | return { 174 | "required": { 175 | "distiller": ("DISTILLER",), 176 | "content": ("IMAGE",), 177 | "style": ("IMAGE",), 178 | "steps": ("INT", {"default": 200, "min": 1, "max": 500, "step": 1}), 179 | "content_weight": ("FLOAT", {"default": 0.25, "min": 0., "max": 10., "step": 0.001}), 180 | "lr": ("FLOAT", {"default": 0.05, "min": 0.001, "max": 0.5, "step": 0.001}), 181 | "height": ("INT", {"default": 512, "min": 256, "max": 4096, "step": 8}), 182 | "width": ("INT", {"default": 512, "min": 256, "max": 4096, "step": 8}), 183 | "seed": ("INT", {"default": 2025, "min": 0, "max": 0xffffffffffffffff, "step": 1}), 184 | } 185 | } 186 | RETURN_TYPES = ("IMAGE",) 187 | RETURN_NAMES = ("image",) 188 | FUNCTION = "process" 189 | CATEGORY = "AttentionDistillationWrapper" 190 | 191 | @torch.inference_mode(False) 192 | def process(self, distiller, content, style, steps, content_weight, lr, height, width, seed): 193 | precision = distiller['precision'] 194 | attn_distiller = distiller['distiller'] 195 | 196 | assert isinstance(attn_distiller, ADPipeline), "Only support SD1.5 for style transfer." 197 | assert isinstance(style, Image.Image) and isinstance(content, Image.Image), "Please use the image loader in `AttentionDistillationWrapper->Load PIL Image` for loading image." 
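# Normalize the inputs for optimization: the style reference is resized to 512x512,
# PIL images are converted to [0, 1] float tensors with a batch dimension, and any
# channels-last (NHWC) tensors are permuted to NCHW before being passed to the optimizer.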
198 | 199 | if isinstance(style, torch.Tensor) and style.ndim == 3: 200 | style = resize(style.unsqueeze(0), (512, 512)) 201 | elif isinstance(style, Image.Image): 202 | style = to_tensor(resize(style, (512, 512))).unsqueeze(0) 203 | 204 | if isinstance(content, torch.Tensor) and content.ndim == 3: 205 | content = content.unsqueeze(0) 206 | elif isinstance(content, Image.Image): 207 | content = to_tensor(content).unsqueeze(0) 208 | 209 | assert isinstance(style, torch.Tensor) and style.ndim == 4 210 | assert isinstance(content, torch.Tensor) and content.ndim == 4 211 | 212 | if (style.shape[1] != 3 and style.shape[-1] == 3): 213 | style = style.permute(0, 3, 1, 2) 214 | if (content.shape[1] != 3 and content.shape[-1] == 3): 215 | content = content.permute(0, 3, 1, 2) 216 | 217 | print(content.shape) 218 | controller = Controller(self_layers=(10, 16)) 219 | set_seed(seed) 220 | 221 | print('style', style.min(), style.max()) 222 | print('content', content.min(), content.max()) 223 | 224 | images = attn_distiller.optimize( 225 | lr=lr, 226 | batch_size=1, 227 | iters=1, 228 | width=width, 229 | height=height, 230 | weight=content_weight, 231 | controller=controller, 232 | style_image=style, 233 | content_image=content, 234 | mixed_precision=precision, 235 | num_inference_steps=steps, 236 | enable_gradient_checkpoint=False, 237 | ) 238 | images = images.permute(0, 2, 3, 1).float() 239 | return (images,) 240 | 241 | 242 | class ADSampler: 243 | @classmethod 244 | def INPUT_TYPES(s): 245 | return { 246 | "required": { 247 | "distiller": ("DISTILLER",), 248 | "style": ("IMAGE",), 249 | "positive": (IO.CONDITIONING,), 250 | "negative": (IO.CONDITIONING,), 251 | "steps": ("INT", {"default": 50, "min": 1, "max": 200, "step": 1}), 252 | "lr": ("FLOAT", {"default": 0.015, "min": 0.001, "max": 1., "step": 0.001}), 253 | "iters": ("INT", {"default": 2, "min": 0, "max": 5, "step": 1}), 254 | "cfg": ("FLOAT", {"default": 7.5, "min": 1., "max": 20., "step": 0.01}), 255 | "num_images_per_prompt": ("INT", {"default": 1, "min": 1, "max": 5, "step": 1}), 256 | "seed": ("INT", {"default": 2025, "min": 0, "max": 0xffffffffffffffff}), 257 | "height": ("INT", {"default": 512, "min": 256, "max": 4096, "step": 8}), 258 | "width": ("INT", {"default": 512, "min": 256, "max": 4096, "step": 8}), 259 | } 260 | } 261 | RETURN_TYPES = ("IMAGE",) 262 | RETURN_NAMES = ("images",) 263 | FUNCTION = "process" 264 | CATEGORY = "AttentionDistillationWrapper" 265 | 266 | DEFAULT_CONFIGS = { 267 | ADPipeline: {'self_layers': (10, 16), 'resolution': (512, 512), 'enable_gradient_checkpoint': False}, 268 | ADXLPipeline: {'self_layers': (64, 70), 'resolution': (1024, 1024), 'enable_gradient_checkpoint': True}, 269 | ADFluxPipeline: {'self_layers': (50, 57), 'resolution': (512, 512), 'enable_gradient_checkpoint': True}, 270 | } 271 | 272 | @torch.inference_mode(False) 273 | def process(self, distiller, style, positive, negative, steps, lr, iters, cfg, num_images_per_prompt, seed, height, width): 274 | precision = distiller['precision'] 275 | attn_distiller = distiller['distiller'] 276 | 277 | assert isinstance(style, Image.Image), "Please use the image loader in `AttentionDistillationWrapper->Load PIL Image` for loading image." 
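# Per-pipeline defaults (self-attention layer range, style resolution, and whether to
# enable gradient checkpointing) are looked up from DEFAULT_CONFIGS by the distiller's class.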
278 | 279 | default_config = self.DEFAULT_CONFIGS[type(attn_distiller)] 280 | print(default_config) 281 | 282 | controller = Controller(self_layers=default_config['self_layers']) 283 | 284 | if isinstance(style, torch.Tensor) and style.ndim == 3: 285 | style = resize(style.unsqueeze(0), default_config['resolution']) 286 | elif isinstance(style, Image.Image): 287 | style = to_tensor(resize(style, default_config['resolution'])).unsqueeze(0) 288 | 289 | assert isinstance(style, torch.Tensor) and style.ndim == 4 290 | 291 | if (style.shape[1] != 3 and style.shape[-1] == 3): 292 | style = style.permute(0, 3, 1, 2) 293 | 294 | print('style', style.min(), style.max(), style.mean()) 295 | set_seed(seed) 296 | images = attn_distiller.sample( 297 | controller=controller, 298 | iters=iters, 299 | lr=lr, 300 | adain=True, 301 | height=height, 302 | width=width, 303 | mixed_precision=precision, 304 | style_image=style, 305 | prompt=positive, 306 | negative_prompt=negative, 307 | guidance_scale=cfg, 308 | num_inference_steps=steps, 309 | num_images_per_prompt=num_images_per_prompt, 310 | enable_gradient_checkpoint=default_config['enable_gradient_checkpoint'] 311 | ) 312 | images = images.permute(0, 2, 3, 1).float() 313 | return (images,) 314 | 315 | 316 | NODE_CLASS_MAPPINGS = { 317 | "LoadDistiller": LoadDistiller, 318 | "ADOptimizer": ADOptimizer, 319 | "ADSampler": ADSampler, 320 | "LoadPILImage": LoadPILImage, 321 | "PureText": PureText, 322 | "ResizeImage": ResizeImage, 323 | } 324 | 325 | NODE_DISPLAY_NAME_MAPPINGS = { 326 | "LoadDistiller": "Load Distiller", 327 | "ADHandler": "Handler for Attention Distillation", 328 | "ADOptimizer": "Optimization-Based Style Transfer", 329 | "ADSampler": "Sampler for Style-Specific Text-to-Image", 330 | "LoadPILImage": "Load PIL Image", 331 | "PureText": "Text Prompt", 332 | "ResizeImage": "Resize Image", 333 | } 334 | -------------------------------------------------------------------------------- /examples/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zichongc/ComfyUI-Attention-Distillation/4c00ca9b2604c9e27dc1427ec49f9b108a1e71d0/examples/1.png -------------------------------------------------------------------------------- /examples/26.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zichongc/ComfyUI-Attention-Distillation/4c00ca9b2604c9e27dc1427ec49f9b108a1e71d0/examples/26.jpg -------------------------------------------------------------------------------- /examples/40.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zichongc/ComfyUI-Attention-Distillation/4c00ca9b2604c9e27dc1427ec49f9b108a1e71d0/examples/40.jpg -------------------------------------------------------------------------------- /examples/lecun.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zichongc/ComfyUI-Attention-Distillation/4c00ca9b2604c9e27dc1427ec49f9b108a1e71d0/examples/lecun.png -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | loss_fn = torch.nn.L1Loss() 8 | 9 | 10 | def ad_loss( 11 | q_list, ks_list, vs_list, self_out_list, scale=1, source_mask=None, target_mask=None 12 | 
): 13 | loss = 0 14 | attn_mask = None 15 | for q, ks, vs, self_out in zip(q_list, ks_list, vs_list, self_out_list): 16 | if source_mask is not None and target_mask is not None: 17 | w = h = int(np.sqrt(q.shape[2])) 18 | mask_1 = torch.flatten(F.interpolate(source_mask, size=(h, w))) 19 | mask_2 = torch.flatten(F.interpolate(target_mask, size=(h, w))) 20 | attn_mask = mask_1.unsqueeze(0) == mask_2.unsqueeze(1) 21 | attn_mask=attn_mask.to(q.device) 22 | 23 | target_out = F.scaled_dot_product_attention( 24 | q * scale, 25 | torch.cat(torch.chunk(ks, ks.shape[0]), 2).repeat(q.shape[0], 1, 1, 1), 26 | torch.cat(torch.chunk(vs, vs.shape[0]), 2).repeat(q.shape[0], 1, 1, 1), 27 | attn_mask=attn_mask 28 | ) 29 | loss += loss_fn(self_out, target_out.detach()) 30 | return loss 31 | 32 | 33 | 34 | def q_loss(q_list, qc_list): 35 | loss = 0 36 | for q, qc in zip(q_list, qc_list): 37 | loss += loss_fn(q, qc.detach()) 38 | return loss 39 | 40 | # weight = 200 41 | def qk_loss(q_list, k_list, qc_list, kc_list): 42 | loss = 0 43 | for q, k, qc, kc in zip(q_list, k_list, qc_list, kc_list): 44 | scale_factor = 1 / math.sqrt(q.size(-1)) 45 | self_map = torch.softmax(q @ k.transpose(-2, -1) * scale_factor, dim=-1) 46 | target_map = torch.softmax(qc @ kc.transpose(-2, -1) * scale_factor, dim=-1) 47 | loss += loss_fn(self_map, target_map.detach()) 48 | return loss 49 | 50 | # weight = 1 51 | def qkv_loss(q_list, k_list, vc_list, c_out_list): 52 | loss = 0 53 | for q, k, vc, target_out in zip(q_list, k_list, vc_list, c_out_list): 54 | self_out = F.scaled_dot_product_attention(q, k, vc) 55 | loss += loss_fn(self_out, target_out.detach()) 56 | return loss 57 | -------------------------------------------------------------------------------- /pipeline_flux.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from typing import Any, Callable, Dict, List, Optional, Union 3 | from tqdm import tqdm 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from .utils import DataCache, register_attn_control_flux, adain_flux 8 | from accelerate import Accelerator 9 | from diffusers import FluxPipeline 10 | 11 | 12 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents 13 | def retrieve_latents( 14 | encoder_output: torch.Tensor, 15 | generator: Optional[torch.Generator] = None, 16 | sample_mode: str = "sample", 17 | ): 18 | if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": 19 | return encoder_output.latent_dist.sample(generator) 20 | elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": 21 | return encoder_output.latent_dist.mode() 22 | elif hasattr(encoder_output, "latents"): 23 | return encoder_output.latents 24 | else: 25 | raise AttributeError("Could not access latents of provided encoder_output") 26 | 27 | 28 | class ADPipeline(FluxPipeline): 29 | def freeze(self): 30 | self.transformer.requires_grad_(False) 31 | self.text_encoder.requires_grad_(False) 32 | self.text_encoder_2.requires_grad_(False) 33 | self.vae.requires_grad_(False) 34 | 35 | @torch.no_grad() 36 | def image2latent(self, image): 37 | dtype = next(self.vae.parameters()).dtype 38 | device = self._execution_device 39 | image = image.to(device=device, dtype=dtype) * 2.0 - 1.0 40 | latent = retrieve_latents(self.vae.encode(image)) 41 | latent = ( 42 | latent - self.vae.config.shift_factor 43 | ) * self.vae.config.scaling_factor 44 | return latent 45 | 46 | @torch.no_grad() 47 | def 
latent2image(self, latent, height, width): 48 | dtype = next(self.vae.parameters()).dtype 49 | device = self._execution_device 50 | latent = latent.to(device=device, dtype=dtype) 51 | latents = self._unpack_latents(latent, height, width, self.vae_scale_factor) 52 | latents = ( 53 | latents / self.vae.config.scaling_factor 54 | ) + self.vae.config.shift_factor 55 | image = self.vae.decode(latents, return_dict=False)[0] 56 | return (image * 0.5 + 0.5).clamp(0, 1) 57 | 58 | def init(self, enable_gradient_checkpoint): 59 | self.freeze() 60 | self.enable_vae_slicing() 61 | # self.enable_model_cpu_offload() 62 | # self.enable_vae_tiling() 63 | weight_dtype = torch.float32 64 | if self.accelerator.mixed_precision == "fp16": 65 | weight_dtype = torch.float16 66 | elif self.accelerator.mixed_precision == "bf16": 67 | weight_dtype = torch.bfloat16 68 | 69 | # Move unet, vae and text_encoder to device and cast to weight_dtype 70 | self.transformer.to(self.accelerator.device, dtype=weight_dtype) 71 | self.vae.to(self.accelerator.device, dtype=weight_dtype) 72 | self.text_encoder.to(self.accelerator.device, dtype=weight_dtype) 73 | self.classifier.to(self.accelerator.device, dtype=weight_dtype) 74 | self.classifier = self.accelerator.prepare(self.classifier) 75 | if enable_gradient_checkpoint: 76 | self.classifier.enable_gradient_checkpointing() 77 | 78 | def sample( 79 | self, 80 | style_image=None, 81 | controller=None, 82 | loss_fn=torch.nn.L1Loss(), 83 | start_time=9999, 84 | lr=0.05, 85 | iters=2, 86 | adain=True, 87 | mixed_precision="no", 88 | enable_gradient_checkpoint=False, 89 | prompt: Union[str, List[str]] = None, 90 | prompt_2: Optional[Union[str, List[str]]] = None, 91 | height: Optional[int] = None, 92 | width: Optional[int] = None, 93 | num_inference_steps: int = 28, 94 | # timesteps: List[int] = None, 95 | guidance_scale: float = 3.5, 96 | num_images_per_prompt: Optional[int] = 1, 97 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 98 | latents: Optional[torch.FloatTensor] = None, 99 | prompt_embeds: Optional[torch.FloatTensor] = None, 100 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 101 | output_type: Optional[str] = "pil", 102 | return_dict: bool = True, 103 | joint_attention_kwargs: Optional[Dict[str, Any]] = None, 104 | callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, 105 | callback_on_step_end_tensor_inputs: List[str] = ["latents"], 106 | max_sequence_length: int = 512, 107 | **kwargs 108 | ): 109 | height = height or self.default_sample_size * self.vae_scale_factor 110 | width = width or self.default_sample_size * self.vae_scale_factor 111 | device = self._execution_device 112 | self.accelerator = Accelerator( 113 | mixed_precision=mixed_precision, gradient_accumulation_steps=1 114 | ) 115 | 116 | self.init(enable_gradient_checkpoint) 117 | 118 | (null_embeds, null_pooled_embeds, null_text_ids) = self.encode_prompt( 119 | prompt="", 120 | prompt_2=prompt_2, 121 | ) 122 | ( 123 | prompt_embeds, 124 | pooled_prompt_embeds, 125 | text_ids, 126 | ) = self.encode_prompt( 127 | prompt=prompt, 128 | prompt_2=prompt_2, 129 | prompt_embeds=prompt_embeds, 130 | pooled_prompt_embeds=pooled_prompt_embeds, 131 | device=device, 132 | num_images_per_prompt=num_images_per_prompt, 133 | max_sequence_length=max_sequence_length, 134 | ) 135 | # 4. 
Prepare latent variables 136 | num_channels_latents = self.transformer.config.in_channels // 4 137 | latents, latent_image_ids = self.prepare_latents( 138 | num_images_per_prompt, 139 | num_channels_latents, 140 | height, 141 | width, 142 | null_embeds.dtype, 143 | device, 144 | generator, 145 | latents, 146 | ) 147 | 148 | # print(style_image.shape) 149 | height_, width_ = style_image.shape[2], style_image.shape[3] 150 | style_latent = self.image2latent(style_image) 151 | # print(style_latent.shape) 152 | # print(latents.shape) 153 | style_latent = self._pack_latents(style_latent, 1, num_channels_latents, style_latent.shape[2], style_latent.shape[3]) 154 | 155 | _, null_image_id = self.prepare_latents( 156 | num_images_per_prompt, 157 | num_channels_latents, 158 | height_, 159 | width_, 160 | null_embeds.dtype, 161 | device, 162 | generator, 163 | style_latent, 164 | ) 165 | 166 | # 5. Prepare timesteps 167 | sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) 168 | image_seq_len = latents.shape[1] 169 | mu = calculate_shift( 170 | image_seq_len, 171 | self.scheduler.config.base_image_seq_len, 172 | self.scheduler.config.max_image_seq_len, 173 | self.scheduler.config.base_shift, 174 | self.scheduler.config.max_shift, 175 | ) 176 | timesteps, num_inference_steps = retrieve_timesteps( 177 | self.scheduler, 178 | num_inference_steps, 179 | device, 180 | None, 181 | sigmas, 182 | mu=mu, 183 | ) 184 | 185 | timesteps = self.scheduler.timesteps 186 | # print(f"timesteps: {timesteps}") 187 | self._num_timesteps = len(timesteps) 188 | 189 | cache = DataCache() 190 | 191 | register_attn_control_flux( 192 | self.classifier.transformer_blocks, 193 | controller=controller, 194 | cache=cache, 195 | ) 196 | register_attn_control_flux( 197 | self.classifier.single_transformer_blocks, 198 | controller=controller, 199 | cache=cache, 200 | ) 201 | # handle guidance 202 | if self.transformer.config.guidance_embeds: 203 | guidance = torch.full( 204 | [1], guidance_scale, device=device, dtype=torch.float32 205 | ) 206 | guidance = guidance.expand(latents.shape[0]) 207 | else: 208 | guidance = None 209 | 210 | null_guidance = torch.full( 211 | [1], 1, device=device, dtype=torch.float32 212 | ) 213 | 214 | # print(controller.num_self_layers) 215 | 216 | 217 | pbar = tqdm(timesteps, desc="Sample") 218 | for i, t in enumerate(pbar): 219 | timestep = t.expand(latents.shape[0]).to(latents.dtype) 220 | with torch.no_grad(): 221 | noise_pred = self.transformer( 222 | hidden_states=latents, 223 | timestep=timestep / 1000, 224 | guidance=guidance, 225 | pooled_projections=pooled_prompt_embeds, 226 | encoder_hidden_states=prompt_embeds, 227 | txt_ids=text_ids, 228 | img_ids=latent_image_ids, 229 | joint_attention_kwargs=None, 230 | return_dict=False, 231 | )[0] 232 | 233 | # compute the previous noisy sample x_t -> x_t-1 234 | latents = self.scheduler.step( 235 | noise_pred, t, latents, return_dict=False 236 | )[0] 237 | if t < start_time: 238 | if i < num_inference_steps - 1: 239 | timestep = timesteps[i+1:i+2] 240 | # print(timestep) 241 | noise = torch.randn_like(style_latent) 242 | # print(style_latent.shape) 243 | style_latent_ = self.scheduler.scale_noise(style_latent, timestep, noise) 244 | else: 245 | timestep = torch.tensor([0], device=style_latent.device) 246 | style_latent_ = style_latent 247 | 248 | cache.clear() 249 | controller.step() 250 | 251 | _ = self.transformer( 252 | hidden_states=style_latent_, 253 | timestep=timestep / 1000, 254 | guidance=null_guidance, 255 | 
pooled_projections=null_pooled_embeds, 256 | encoder_hidden_states=null_embeds, 257 | txt_ids=null_text_ids, 258 | img_ids=null_image_id, 259 | joint_attention_kwargs=None, 260 | return_dict=False, 261 | )[0] 262 | _, ref_k_list, ref_v_list, _ = cache.get() 263 | 264 | if adain: 265 | latents = adain_flux(latents, style_latent_) 266 | 267 | latents = latents.detach() 268 | optimizer = torch.optim.Adam([latents.requires_grad_()], lr=lr) 269 | optimizer = self.accelerator.prepare(optimizer) 270 | 271 | for _ in range(iters): 272 | cache.clear() 273 | controller.step() 274 | optimizer.zero_grad() 275 | _ = self.classifier( 276 | hidden_states=latents, 277 | timestep=timestep / 1000, 278 | guidance=null_guidance, 279 | pooled_projections=null_pooled_embeds, 280 | encoder_hidden_states=null_embeds, 281 | txt_ids=null_text_ids, 282 | img_ids=latent_image_ids, 283 | joint_attention_kwargs=None, 284 | return_dict=False, 285 | )[0] 286 | q_list, _, _, self_out_list = cache.get() 287 | ref_self_out_list = [ 288 | F.scaled_dot_product_attention( 289 | q, 290 | ref_k, 291 | ref_v, 292 | ) 293 | for q, ref_k, ref_v in zip(q_list, ref_k_list, ref_v_list) 294 | ] 295 | style_loss = sum( 296 | [ 297 | loss_fn(self_out, ref_self_out.detach()) 298 | for self_out, ref_self_out in zip( 299 | self_out_list, ref_self_out_list 300 | ) 301 | ] 302 | ) 303 | loss = style_loss 304 | self.accelerator.backward(loss) 305 | # loss.backward() 306 | optimizer.step() 307 | 308 | pbar.set_postfix(loss=loss.item(), time=t.item()) 309 | torch.cuda.empty_cache() 310 | latents = latents.detach() 311 | return self.latent2image(latents, height, width) 312 | 313 | 314 | def calculate_shift( 315 | image_seq_len, 316 | base_seq_len: int = 256, 317 | max_seq_len: int = 4096, 318 | base_shift: float = 0.5, 319 | max_shift: float = 1.16, 320 | ): 321 | m = (max_shift - base_shift) / (max_seq_len - base_seq_len) 322 | b = base_shift - m * base_seq_len 323 | mu = image_seq_len * m + b 324 | return mu 325 | 326 | 327 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps 328 | def retrieve_timesteps( 329 | scheduler, 330 | num_inference_steps: Optional[int] = None, 331 | device: Optional[Union[str, torch.device]] = None, 332 | timesteps: Optional[List[int]] = None, 333 | sigmas: Optional[List[float]] = None, 334 | **kwargs, 335 | ): 336 | r""" 337 | Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles 338 | custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. 339 | 340 | Args: 341 | scheduler (`SchedulerMixin`): 342 | The scheduler to get timesteps from. 343 | num_inference_steps (`int`): 344 | The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` 345 | must be `None`. 346 | device (`str` or `torch.device`, *optional*): 347 | The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 348 | timesteps (`List[int]`, *optional*): 349 | Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, 350 | `num_inference_steps` and `sigmas` must be `None`. 351 | sigmas (`List[float]`, *optional*): 352 | Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, 353 | `num_inference_steps` and `timesteps` must be `None`. 
354 | 355 | Returns: 356 | `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the 357 | second element is the number of inference steps. 358 | """ 359 | if timesteps is not None and sigmas is not None: 360 | raise ValueError( 361 | "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values" 362 | ) 363 | if timesteps is not None: 364 | accepts_timesteps = "timesteps" in set( 365 | inspect.signature(scheduler.set_timesteps).parameters.keys() 366 | ) 367 | if not accepts_timesteps: 368 | raise ValueError( 369 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" 370 | f" timestep schedules. Please check whether you are using the correct scheduler." 371 | ) 372 | scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) 373 | timesteps = scheduler.timesteps 374 | num_inference_steps = len(timesteps) 375 | elif sigmas is not None: 376 | accept_sigmas = "sigmas" in set( 377 | inspect.signature(scheduler.set_timesteps).parameters.keys() 378 | ) 379 | if not accept_sigmas: 380 | raise ValueError( 381 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" 382 | f" sigmas schedules. Please check whether you are using the correct scheduler." 383 | ) 384 | scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) 385 | timesteps = scheduler.timesteps 386 | num_inference_steps = len(timesteps) 387 | else: 388 | scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) 389 | timesteps = scheduler.timesteps 390 | return timesteps, num_inference_steps 391 | -------------------------------------------------------------------------------- /pipeline_sd.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | from typing import Any, Dict, List, Optional, Tuple, Union 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from accelerate import Accelerator 8 | from diffusers import StableDiffusionPipeline 9 | from diffusers.image_processor import PipelineImageInput 10 | from .losses import ad_loss, q_loss 11 | from .utils import DataCache, register_attn_control, adain 12 | from tqdm import tqdm 13 | 14 | 15 | class ADPipeline(StableDiffusionPipeline): 16 | def freeze(self): 17 | self.vae.requires_grad_(False) 18 | self.unet.requires_grad_(False) 19 | self.text_encoder.requires_grad_(False) 20 | self.classifier.requires_grad_(False) 21 | 22 | @torch.no_grad() 23 | def image2latent(self, image): 24 | dtype = next(self.vae.parameters()).dtype 25 | device = self._execution_device 26 | image = image.to(device=device, dtype=dtype) * 2.0 - 1.0 27 | latent = self.vae.encode(image)["latent_dist"].mean 28 | latent = latent * self.vae.config.scaling_factor 29 | return latent 30 | 31 | @torch.no_grad() 32 | def latent2image(self, latent): 33 | dtype = next(self.vae.parameters()).dtype 34 | device = self._execution_device 35 | latent = latent.to(device=device, dtype=dtype) 36 | latent = latent / self.vae.config.scaling_factor 37 | image = self.vae.decode(latent)[0] 38 | return (image * 0.5 + 0.5).clamp(0, 1) 39 | 40 | def init(self, enable_gradient_checkpoint): 41 | self.freeze() 42 | self.enable_vae_slicing() 43 | # self.enable_model_cpu_offload() 44 | # self.enable_vae_tiling() 45 | weight_dtype = torch.float32 46 | if self.accelerator.mixed_precision == "fp16": 47 | weight_dtype = torch.float16 48 | elif self.accelerator.mixed_precision == "bf16": 49 
| weight_dtype = torch.bfloat16 50 | 51 | # Move unet, vae and text_encoder to device and cast to weight_dtype 52 | self.unet.to(self.accelerator.device, dtype=weight_dtype) 53 | self.vae.to(self.accelerator.device, dtype=weight_dtype) 54 | self.text_encoder.to(self.accelerator.device, dtype=weight_dtype) 55 | self.classifier.to(self.accelerator.device, dtype=weight_dtype) 56 | self.classifier = self.accelerator.prepare(self.classifier) 57 | if enable_gradient_checkpoint: 58 | self.classifier.enable_gradient_checkpointing() 59 | 60 | def sample( 61 | self, 62 | lr=0.05, 63 | iters=1, 64 | attn_scale=1, 65 | adain=False, 66 | weight=0.25, 67 | controller=None, 68 | style_image=None, 69 | content_image=None, 70 | mixed_precision="no", 71 | start_time=999, 72 | enable_gradient_checkpoint=False, 73 | prompt: Union[str, List[str]] = None, 74 | height: Optional[int] = None, 75 | width: Optional[int] = None, 76 | num_inference_steps: int = 50, 77 | guidance_scale: float = 7.5, 78 | negative_prompt: Optional[Union[str, List[str]]] = None, 79 | num_images_per_prompt: Optional[int] = 1, 80 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 81 | latents: Optional[torch.Tensor] = None, 82 | prompt_embeds: Optional[torch.Tensor] = None, 83 | negative_prompt_embeds: Optional[torch.Tensor] = None, 84 | ip_adapter_image: Optional[PipelineImageInput] = None, 85 | ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, 86 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 87 | guidance_rescale: float = 0.0, 88 | clip_skip: Optional[int] = None, 89 | **kwargs, 90 | ): 91 | # 0. Default height and width to unet 92 | height = height or self.unet.config.sample_size * self.vae_scale_factor 93 | width = width or self.unet.config.sample_size * self.vae_scale_factor 94 | self._guidance_scale = guidance_scale 95 | self._guidance_rescale = guidance_rescale 96 | self._clip_skip = clip_skip 97 | self._cross_attention_kwargs = cross_attention_kwargs 98 | self._interrupt = False 99 | 100 | self.accelerator = Accelerator( 101 | mixed_precision=mixed_precision, gradient_accumulation_steps=1 102 | ) 103 | self.init(enable_gradient_checkpoint) 104 | 105 | # 2. Define call parameters 106 | if prompt is not None and isinstance(prompt, str): 107 | batch_size = 1 108 | elif prompt is not None and isinstance(prompt, list): 109 | batch_size = len(prompt) 110 | else: 111 | batch_size = prompt_embeds.shape[0] 112 | 113 | device = self._execution_device 114 | 115 | # 3. Encode input prompt 116 | lora_scale = ( 117 | self.cross_attention_kwargs.get("scale", None) 118 | if self.cross_attention_kwargs is not None 119 | else None 120 | ) 121 | do_cfg = guidance_scale > 1.0 122 | 123 | prompt_embeds, negative_prompt_embeds = self.encode_prompt( 124 | prompt, 125 | device, 126 | num_images_per_prompt, 127 | do_cfg, 128 | negative_prompt, 129 | prompt_embeds=prompt_embeds, 130 | negative_prompt_embeds=negative_prompt_embeds, 131 | lora_scale=lora_scale, 132 | clip_skip=self.clip_skip, 133 | ) 134 | 135 | # For classifier free guidance, we need to do two forward passes. 
136 | # Here we concatenate the unconditional and text embeddings into a single batch 137 | # to avoid doing two forward passes 138 | if do_cfg: 139 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) 140 | 141 | if ip_adapter_image is not None or ip_adapter_image_embeds is not None: 142 | image_embeds = self.prepare_ip_adapter_image_embeds( 143 | ip_adapter_image, 144 | ip_adapter_image_embeds, 145 | device, 146 | batch_size * num_images_per_prompt, 147 | do_cfg, 148 | ) 149 | 150 | # 5. Prepare latent variables 151 | num_channels_latents = self.unet.config.in_channels 152 | latents = self.prepare_latents( 153 | batch_size * num_images_per_prompt, 154 | num_channels_latents, 155 | height, 156 | width, 157 | prompt_embeds.dtype, 158 | device, 159 | generator, 160 | latents, 161 | ) 162 | 163 | # 6.1 Add image embeds for IP-Adapter 164 | added_cond_kwargs = ( 165 | {"image_embeds": image_embeds} 166 | if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) 167 | else None 168 | ) 169 | 170 | # 6.2 Optionally get Guidance Scale Embedding 171 | timestep_cond = None 172 | if self.unet.config.time_cond_proj_dim is not None: 173 | guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat( 174 | batch_size * num_images_per_prompt 175 | ) 176 | timestep_cond = self.get_guidance_scale_embedding( 177 | guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim 178 | ).to(device=device, dtype=latents.dtype) 179 | 180 | self.scheduler.set_timesteps(num_inference_steps) 181 | timesteps = self.scheduler.timesteps 182 | self.style_latent = self.image2latent(style_image) 183 | if content_image is not None: 184 | self.content_latent = self.image2latent(content_image) 185 | else: 186 | self.content_latent = None 187 | null_embeds = self.encode_prompt("", device, 1, False)[0] 188 | self.null_embeds = null_embeds 189 | self.null_embeds_for_latents = torch.cat([null_embeds] * latents.shape[0]) 190 | self.null_embeds_for_style = torch.cat( 191 | [null_embeds] * self.style_latent.shape[0] 192 | ) 193 | 194 | self.adain = adain 195 | self.attn_scale = attn_scale 196 | self.cache = DataCache() 197 | self.controller = controller 198 | register_attn_control( 199 | self.classifier, controller=self.controller, cache=self.cache 200 | ) 201 | print("Total self attention layers of Unet: ", controller.num_self_layers) 202 | print("Self attention layers for AD: ", controller.self_layers) 203 | 204 | pbar = tqdm(timesteps, desc="Sample") 205 | for i, t in enumerate(pbar): 206 | with torch.no_grad(): 207 | # expand the latents if we are doing classifier free guidance 208 | latent_model_input = torch.cat([latents] * 2) if do_cfg else latents 209 | latent_model_input = self.scheduler.scale_model_input( 210 | latent_model_input, t 211 | ) 212 | # predict the noise residual 213 | noise_pred = self.unet( 214 | latent_model_input, 215 | t, 216 | encoder_hidden_states=prompt_embeds, 217 | timestep_cond=timestep_cond, 218 | cross_attention_kwargs=self.cross_attention_kwargs, 219 | added_cond_kwargs=added_cond_kwargs, 220 | return_dict=False, 221 | )[0] 222 | 223 | # perform guidance 224 | if do_cfg: 225 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 226 | noise_pred = noise_pred_uncond + self.guidance_scale * ( 227 | noise_pred_text - noise_pred_uncond 228 | ) 229 | latents = self.scheduler.step( 230 | noise_pred, t, latents, return_dict=False 231 | )[0] 232 | if iters > 0 and t < start_time: 233 | latents = self.AD(latents, t, lr, iters, pbar, weight) 234 | 235 
| images = self.latent2image(latents) 236 | # Offload all models 237 | self.maybe_free_model_hooks() 238 | return images 239 | 240 | def optimize( 241 | self, 242 | latents=None, 243 | attn_scale=1.0, 244 | lr=0.05, 245 | iters=1, 246 | weight=0, 247 | width=512, 248 | height=512, 249 | batch_size=1, 250 | controller=None, 251 | style_image=None, 252 | content_image=None, 253 | mixed_precision="no", 254 | num_inference_steps=50, 255 | enable_gradient_checkpoint=False, 256 | source_mask=None, 257 | target_mask=None, 258 | ): 259 | height = height // self.vae_scale_factor 260 | width = width // self.vae_scale_factor 261 | 262 | self.accelerator = Accelerator( 263 | mixed_precision=mixed_precision, gradient_accumulation_steps=1 264 | ) 265 | self.init(enable_gradient_checkpoint) 266 | 267 | style_latent = self.image2latent(style_image) 268 | latents = torch.randn((batch_size, 4, height, width), device=self.device) 269 | null_embeds = self.encode_prompt("", self.device, 1, False)[0] 270 | null_embeds_for_latents = null_embeds.repeat(latents.shape[0], 1, 1) 271 | null_embeds_for_style = null_embeds.repeat(style_latent.shape[0], 1, 1) 272 | 273 | if content_image is not None: 274 | content_latent = self.image2latent(content_image) 275 | latents = torch.cat([content_latent.clone()] * batch_size) 276 | null_embeds_for_content = null_embeds.repeat(content_latent.shape[0], 1, 1) 277 | 278 | self.cache = DataCache() 279 | self.controller = controller 280 | register_attn_control( 281 | self.classifier, controller=self.controller, cache=self.cache 282 | ) 283 | print("Total self attention layers of Unet: ", controller.num_self_layers) 284 | print("Self attention layers for AD: ", controller.self_layers) 285 | 286 | self.scheduler.set_timesteps(num_inference_steps) 287 | timesteps = self.scheduler.timesteps 288 | latents = latents.detach().float() 289 | optimizer = torch.optim.Adam([latents.requires_grad_()], lr=lr) 290 | optimizer = self.accelerator.prepare(optimizer) 291 | pbar = tqdm(timesteps, desc="Optimize") 292 | for i, t in enumerate(pbar): 293 | # t = torch.tensor([1], device=self.device) 294 | with torch.no_grad(): 295 | qs_list, ks_list, vs_list, s_out_list = self.extract_feature( 296 | style_latent, 297 | t, 298 | null_embeds_for_style, 299 | ) 300 | if content_image is not None: 301 | qc_list, kc_list, vc_list, c_out_list = self.extract_feature( 302 | content_latent, 303 | t, 304 | null_embeds_for_content, 305 | ) 306 | for j in range(iters): 307 | style_loss = 0 308 | content_loss = 0 309 | optimizer.zero_grad() 310 | q_list, k_list, v_list, self_out_list = self.extract_feature( 311 | latents, 312 | t, 313 | null_embeds_for_latents, 314 | ) 315 | style_loss = ad_loss(q_list, ks_list, vs_list, self_out_list, scale=attn_scale, source_mask=source_mask, target_mask=target_mask) 316 | if content_image is not None: 317 | content_loss = q_loss(q_list, qc_list) 318 | # content_loss = qk_loss(q_list, k_list, qc_list, kc_list) 319 | # content_loss = qkv_loss(q_list, k_list, vc_list, c_out_list) 320 | loss = style_loss + content_loss * weight 321 | self.accelerator.backward(loss) 322 | optimizer.step() 323 | pbar.set_postfix(loss=loss.item(), time=t.item(), iter=j) 324 | images = self.latent2image(latents) 325 | # Offload all models 326 | self.maybe_free_model_hooks() 327 | return images 328 | 329 | def panorama( 330 | self, 331 | lr=0.05, 332 | iters=1, 333 | attn_scale=1, 334 | adain=False, 335 | controller=None, 336 | style_image=None, 337 | mixed_precision="no", 338 | 
enable_gradient_checkpoint=False, 339 | prompt: Union[str, List[str]] = None, 340 | height: Optional[int] = None, 341 | width: Optional[int] = None, 342 | num_inference_steps: int = 50, 343 | guidance_scale: float = 1, 344 | stride=8, 345 | view_batch_size: int = 16, 346 | negative_prompt: Optional[Union[str, List[str]]] = None, 347 | num_images_per_prompt: Optional[int] = 1, 348 | eta: float = 0.0, 349 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 350 | latents: Optional[torch.Tensor] = None, 351 | prompt_embeds: Optional[torch.Tensor] = None, 352 | negative_prompt_embeds: Optional[torch.Tensor] = None, 353 | ip_adapter_image: Optional[PipelineImageInput] = None, 354 | ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, 355 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 356 | guidance_rescale: float = 0.0, 357 | clip_skip: Optional[int] = None, 358 | **kwargs, 359 | ): 360 | 361 | # 0. Default height and width to unet 362 | height = height or self.unet.config.sample_size * self.vae_scale_factor 363 | width = width or self.unet.config.sample_size * self.vae_scale_factor 364 | 365 | self._guidance_scale = guidance_scale 366 | self._guidance_rescale = guidance_rescale 367 | self._clip_skip = clip_skip 368 | self._cross_attention_kwargs = cross_attention_kwargs 369 | self._interrupt = False 370 | 371 | self.accelerator = Accelerator( 372 | mixed_precision=mixed_precision, gradient_accumulation_steps=1 373 | ) 374 | self.init(enable_gradient_checkpoint) 375 | 376 | # 2. Define call parameters 377 | if prompt is not None and isinstance(prompt, str): 378 | batch_size = 1 379 | elif prompt is not None and isinstance(prompt, list): 380 | batch_size = len(prompt) 381 | else: 382 | batch_size = prompt_embeds.shape[0] 383 | 384 | device = self._execution_device 385 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 386 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 387 | # corresponds to doing no classifier free guidance. 388 | do_cfg = guidance_scale > 1.0 389 | 390 | if ip_adapter_image is not None or ip_adapter_image_embeds is not None: 391 | image_embeds = self.prepare_ip_adapter_image_embeds( 392 | ip_adapter_image, 393 | ip_adapter_image_embeds, 394 | device, 395 | batch_size * num_images_per_prompt, 396 | self.do_classifier_free_guidance, 397 | ) 398 | 399 | # 3. Encode input prompt 400 | text_encoder_lora_scale = ( 401 | cross_attention_kwargs.get("scale", None) 402 | if cross_attention_kwargs is not None 403 | else None 404 | ) 405 | prompt_embeds, negative_prompt_embeds = self.encode_prompt( 406 | prompt, 407 | device, 408 | num_images_per_prompt, 409 | do_cfg, 410 | negative_prompt, 411 | prompt_embeds=prompt_embeds, 412 | negative_prompt_embeds=negative_prompt_embeds, 413 | lora_scale=text_encoder_lora_scale, 414 | clip_skip=clip_skip, 415 | ) 416 | # For classifier free guidance, we need to do two forward passes. 417 | # Here we concatenate the unconditional and text embeddings into a single batch 418 | # to avoid doing two forward passes 419 | if do_cfg: 420 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) 421 | 422 | # 5. Prepare latent variables 423 | num_channels_latents = self.unet.config.in_channels 424 | latents = self.prepare_latents( 425 | batch_size * num_images_per_prompt, 426 | num_channels_latents, 427 | height, 428 | width, 429 | prompt_embeds.dtype, 430 | device, 431 | generator, 432 | latents, 433 | ) 434 | 435 | # 6. 
Define panorama grid and initialize views for synthesis. 436 | # prepare batch grid 437 | views = self.get_views_(height, width, window_size=64, stride=stride) 438 | views_batch = [ 439 | views[i : i + view_batch_size] 440 | for i in range(0, len(views), view_batch_size) 441 | ] 442 | print(len(views), len(views_batch), views_batch) 443 | self.scheduler.set_timesteps(num_inference_steps) 444 | views_scheduler_status = [copy.deepcopy(self.scheduler.__dict__)] * len( 445 | views_batch 446 | ) 447 | count = torch.zeros_like(latents) 448 | value = torch.zeros_like(latents) 449 | 450 | # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline 451 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 452 | 453 | # 7.1 Add image embeds for IP-Adapter 454 | added_cond_kwargs = ( 455 | {"image_embeds": image_embeds} 456 | if ip_adapter_image is not None or ip_adapter_image_embeds is not None 457 | else None 458 | ) 459 | 460 | # 7.2 Optionally get Guidance Scale Embedding 461 | timestep_cond = None 462 | if self.unet.config.time_cond_proj_dim is not None: 463 | guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat( 464 | batch_size * num_images_per_prompt 465 | ) 466 | timestep_cond = self.get_guidance_scale_embedding( 467 | guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim 468 | ).to(device=device, dtype=latents.dtype) 469 | 470 | # 8. Denoising loop 471 | # Each denoising step also includes refinement of the latents with respect to the 472 | # views. 473 | 474 | timesteps = self.scheduler.timesteps 475 | self.style_latent = self.image2latent(style_image) 476 | self.content_latent = None 477 | null_embeds = self.encode_prompt("", device, 1, False)[0] 478 | self.null_embeds = null_embeds 479 | self.null_embeds_for_latents = torch.cat([null_embeds] * latents.shape[0]) 480 | self.null_embeds_for_style = torch.cat( 481 | [null_embeds] * self.style_latent.shape[0] 482 | ) 483 | self.adain = adain 484 | self.attn_scale = attn_scale 485 | self.cache = DataCache() 486 | self.controller = controller 487 | register_attn_control( 488 | self.classifier, controller=self.controller, cache=self.cache 489 | ) 490 | print("Total self attention layers of Unet: ", controller.num_self_layers) 491 | print("Self attention layers for AD: ", controller.self_layers) 492 | 493 | pbar = tqdm(timesteps, desc="Sample") 494 | for i, t in enumerate(pbar): 495 | count.zero_() 496 | value.zero_() 497 | # generate views 498 | # Here, we iterate through different spatial crops of the latents and denoise them. These 499 | # denoised (latent) crops are then averaged to produce the final latent 500 | # for the current timestep via MultiDiffusion. Please see Sec. 
4.1 in the 501 | # MultiDiffusion paper for more details: https://arxiv.org/abs/2302.08113 502 | # Batch views denoise 503 | for j, batch_view in enumerate(views_batch): 504 | vb_size = len(batch_view) 505 | # get the latents corresponding to the current view coordinates 506 | latents_for_view = torch.cat( 507 | [ 508 | latents[:, :, h_start:h_end, w_start:w_end] 509 | for h_start, h_end, w_start, w_end in batch_view 510 | ] 511 | ) 512 | # rematch block's scheduler status 513 | self.scheduler.__dict__.update(views_scheduler_status[j]) 514 | 515 | # expand the latents if we are doing classifier free guidance 516 | latent_model_input = ( 517 | latents_for_view.repeat_interleave(2, dim=0) 518 | if do_cfg 519 | else latents_for_view 520 | ) 521 | 522 | latent_model_input = self.scheduler.scale_model_input( 523 | latent_model_input, t 524 | ) 525 | 526 | # repeat prompt_embeds for batch 527 | prompt_embeds_input = torch.cat([prompt_embeds] * vb_size) 528 | 529 | # predict the noise residual 530 | with torch.no_grad(): 531 | noise_pred = self.unet( 532 | latent_model_input, 533 | t, 534 | encoder_hidden_states=prompt_embeds_input, 535 | timestep_cond=timestep_cond, 536 | cross_attention_kwargs=cross_attention_kwargs, 537 | added_cond_kwargs=added_cond_kwargs, 538 | ).sample 539 | 540 | # perform guidance 541 | if do_cfg: 542 | noise_pred_uncond, noise_pred_text = ( 543 | noise_pred[::2], 544 | noise_pred[1::2], 545 | ) 546 | noise_pred = noise_pred_uncond + guidance_scale * ( 547 | noise_pred_text - noise_pred_uncond 548 | ) 549 | 550 | # compute the previous noisy sample x_t -> x_t-1 551 | latents_denoised_batch = self.scheduler.step( 552 | noise_pred, t, latents_for_view, **extra_step_kwargs 553 | ).prev_sample 554 | if iters > 0: 555 | self.null_embeds_for_latents = torch.cat( 556 | [self.null_embeds] * noise_pred.shape[0] 557 | ) 558 | latents_denoised_batch = self.AD( 559 | latents_denoised_batch, t, lr, iters, pbar 560 | ) 561 | # save views scheduler status after sample 562 | views_scheduler_status[j] = copy.deepcopy(self.scheduler.__dict__) 563 | 564 | # extract value from batch 565 | for latents_view_denoised, (h_start, h_end, w_start, w_end) in zip( 566 | latents_denoised_batch.chunk(vb_size), batch_view 567 | ): 568 | 569 | value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised 570 | count[:, :, h_start:h_end, w_start:w_end] += 1 571 | 572 | # take the MultiDiffusion step. Eq. 
5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113 573 | latents = torch.where(count > 0, value / count, value) 574 | 575 | images = self.latent2image(latents) 576 | # Offload all models 577 | self.maybe_free_model_hooks() 578 | return images 579 | 580 | def AD(self, latents, t, lr, iters, pbar, weight=0): 581 | t = max( 582 | t 583 | - self.scheduler.config.num_train_timesteps 584 | // self.scheduler.num_inference_steps, 585 | torch.tensor([0], device=self.device), 586 | ) 587 | if self.adain: 588 | noise = torch.randn_like(self.style_latent) 589 | style_latent = self.scheduler.add_noise(self.style_latent, noise, t) 590 | latents = adain(latents, style_latent) 591 | 592 | with torch.no_grad(): 593 | qs_list, ks_list, vs_list, s_out_list = self.extract_feature( 594 | self.style_latent, 595 | t, 596 | self.null_embeds_for_style, 597 | add_noise=True, 598 | ) 599 | if self.content_latent is not None: 600 | qc_list, kc_list, vc_list, c_out_list = self.extract_feature( 601 | self.content_latent, 602 | t, 603 | self.null_embeds, 604 | add_noise=True, 605 | ) 606 | 607 | latents = latents.detach() 608 | optimizer = torch.optim.Adam([latents.requires_grad_()], lr=lr) 609 | optimizer = self.accelerator.prepare(optimizer) 610 | 611 | for j in range(iters): 612 | style_loss = 0 613 | content_loss = 0 614 | optimizer.zero_grad() 615 | q_list, k_list, v_list, self_out_list = self.extract_feature( 616 | latents, 617 | t, 618 | self.null_embeds_for_latents, 619 | add_noise=False, 620 | ) 621 | style_loss = ad_loss(q_list, ks_list, vs_list, self_out_list, scale=self.attn_scale) 622 | if self.content_latent is not None: 623 | content_loss = q_loss(q_list, qc_list) 624 | # content_loss = qk_loss(q_list, k_list, qc_list, kc_list) 625 | # content_loss = qkv_loss(q_list, k_list, vc_list, c_out_list) 626 | loss = style_loss + content_loss * weight 627 | self.accelerator.backward(loss) 628 | optimizer.step() 629 | 630 | pbar.set_postfix(loss=loss.item(), time=t.item(), iter=j) 631 | latents = latents.detach() 632 | return latents 633 | 634 | def extract_feature( 635 | self, 636 | latent, 637 | t, 638 | embeds, 639 | add_noise=False, 640 | ): 641 | self.cache.clear() 642 | self.controller.step() 643 | if add_noise: 644 | noise = torch.randn_like(latent) 645 | latent_ = self.scheduler.add_noise(latent, noise, t) 646 | else: 647 | latent_ = latent 648 | _ = self.classifier(latent_, t, embeds)[0] 649 | return self.cache.get() 650 | 651 | def get_views_( 652 | self, 653 | panorama_height: int, 654 | panorama_width: int, 655 | window_size: int = 64, 656 | stride: int = 8, 657 | ) -> List[Tuple[int, int, int, int]]: 658 | panorama_height //= 8 659 | panorama_width //= 8 660 | 661 | num_blocks_height = ( 662 | math.ceil((panorama_height - window_size) / stride) + 1 663 | if panorama_height > window_size 664 | else 1 665 | ) 666 | num_blocks_width = ( 667 | math.ceil((panorama_width - window_size) / stride) + 1 668 | if panorama_width > window_size 669 | else 1 670 | ) 671 | 672 | views = [] 673 | for i in range(int(num_blocks_height)): 674 | for j in range(int(num_blocks_width)): 675 | h_start = int(min(i * stride, panorama_height - window_size)) 676 | w_start = int(min(j * stride, panorama_width - window_size)) 677 | 678 | h_end = h_start + window_size 679 | w_end = w_start + window_size 680 | 681 | views.append((h_start, h_end, w_start, w_end)) 682 | 683 | return views 684 | -------------------------------------------------------------------------------- /pipeline_sdxl.py: 
-------------------------------------------------------------------------------- 1 | import math 2 | from typing import Any, Dict, List, Optional, Tuple, Union 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from accelerate import Accelerator 7 | from accelerate.utils import ( 8 | DistributedDataParallelKwargs, 9 | ProjectConfiguration, 10 | set_seed, 11 | ) 12 | from diffusers import StableDiffusionXLPipeline 13 | from diffusers.image_processor import PipelineImageInput 14 | from diffusers.utils.torch_utils import is_compiled_module 15 | 16 | from .utils import DataCache, register_attn_control, adain 17 | from .losses import ad_loss 18 | from tqdm import tqdm 19 | 20 | 21 | class ADPipeline(StableDiffusionXLPipeline): 22 | def freeze(self): 23 | self.unet.requires_grad_(False) 24 | self.text_encoder.requires_grad_(False) 25 | self.text_encoder_2.requires_grad_(False) 26 | self.vae.requires_grad_(False) 27 | self.classifier.requires_grad_(False) 28 | 29 | @torch.no_grad() 30 | def image2latent(self, image): 31 | dtype = next(self.vae.parameters()).dtype 32 | device = self._execution_device 33 | image = image.to(device=device, dtype=dtype) * 2.0 - 1.0 34 | latent = self.vae.encode(image)["latent_dist"].mean 35 | latent = latent * self.vae.config.scaling_factor 36 | return latent 37 | 38 | @torch.no_grad() 39 | def latent2image(self, latent): 40 | dtype = next(self.vae.parameters()).dtype 41 | device = self._execution_device 42 | latent = latent.to(device=device, dtype=dtype) 43 | latent = latent / self.vae.config.scaling_factor 44 | image = self.vae.decode(latent)[0] 45 | return (image * 0.5 + 0.5).clamp(0, 1) 46 | 47 | def init(self, enable_gradient_checkpoint): 48 | self.freeze() 49 | self.enable_vae_slicing() 50 | # self.enable_model_cpu_offload() 51 | # self.enable_vae_tiling() 52 | weight_dtype = torch.float32 53 | if self.accelerator.mixed_precision == "fp16": 54 | weight_dtype = torch.float16 55 | elif self.accelerator.mixed_precision == "bf16": 56 | weight_dtype = torch.bfloat16 57 | 58 | # Move unet, vae and text_encoder to device and cast to weight_dtype 59 | self.unet.to(self.accelerator.device, dtype=weight_dtype) 60 | self.vae.to(self.accelerator.device, dtype=weight_dtype) 61 | self.text_encoder.to(self.accelerator.device, dtype=weight_dtype) 62 | self.text_encoder_2.to(self.accelerator.device, dtype=weight_dtype) 63 | self.classifier.to(self.accelerator.device, dtype=weight_dtype) 64 | self.classifier = self.accelerator.prepare(self.classifier) 65 | if enable_gradient_checkpoint: 66 | self.classifier.enable_gradient_checkpointing() 67 | # self.classifier.train() 68 | 69 | 70 | def sample( 71 | self, 72 | lr=0.05, 73 | iters=1, 74 | adain=True, 75 | controller=None, 76 | style_image=None, 77 | mixed_precision="no", 78 | init_from_style=False, 79 | start_time=999, 80 | prompt: Union[str, List[str]] = None, 81 | prompt_2: Optional[Union[str, List[str]]] = None, 82 | height: Optional[int] = None, 83 | width: Optional[int] = None, 84 | num_inference_steps: int = 50, 85 | denoising_end: Optional[float] = None, 86 | guidance_scale: float = 5.0, 87 | negative_prompt: Optional[Union[str, List[str]]] = None, 88 | negative_prompt_2: Optional[Union[str, List[str]]] = None, 89 | num_images_per_prompt: Optional[int] = 1, 90 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 91 | latents: Optional[torch.Tensor] = None, 92 | prompt_embeds: Optional[torch.Tensor] = None, 93 | negative_prompt_embeds: Optional[torch.Tensor] = None, 94 | 
pooled_prompt_embeds: Optional[torch.Tensor] = None, 95 | negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, 96 | ip_adapter_image: Optional[PipelineImageInput] = None, 97 | ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, 98 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 99 | guidance_rescale: float = 0.0, 100 | original_size: Optional[Tuple[int, int]] = None, 101 | crops_coords_top_left: Tuple[int, int] = (0, 0), 102 | target_size: Optional[Tuple[int, int]] = None, 103 | negative_original_size: Optional[Tuple[int, int]] = None, 104 | negative_crops_coords_top_left: Tuple[int, int] = (0, 0), 105 | negative_target_size: Optional[Tuple[int, int]] = None, 106 | clip_skip: Optional[int] = None, 107 | enable_gradient_checkpoint=False, 108 | **kwargs, 109 | ): 110 | # 0. Default height and width to unet 111 | height = height or self.default_sample_size * self.vae_scale_factor 112 | width = width or self.default_sample_size * self.vae_scale_factor 113 | 114 | original_size = original_size or (height, width) 115 | target_size = target_size or (height, width) 116 | self._guidance_scale = guidance_scale 117 | self._guidance_rescale = guidance_rescale 118 | self._clip_skip = clip_skip 119 | self._cross_attention_kwargs = cross_attention_kwargs 120 | self._denoising_end = denoising_end 121 | self._interrupt = False 122 | 123 | self.accelerator = Accelerator( 124 | mixed_precision=mixed_precision, gradient_accumulation_steps=1 125 | ) 126 | self.init(enable_gradient_checkpoint) 127 | 128 | # 2. Define call parameters 129 | if prompt is not None and isinstance(prompt, str): 130 | batch_size = 1 131 | elif prompt is not None and isinstance(prompt, list): 132 | batch_size = len(prompt) 133 | else: 134 | batch_size = prompt_embeds.shape[0] 135 | 136 | device = self._execution_device 137 | 138 | # 3. Encode input prompt 139 | lora_scale = ( 140 | self.cross_attention_kwargs.get("scale", None) 141 | if self.cross_attention_kwargs is not None 142 | else None 143 | ) 144 | 145 | ( 146 | prompt_embeds, 147 | negative_prompt_embeds, 148 | pooled_prompt_embeds, 149 | negative_pooled_prompt_embeds, 150 | ) = self.encode_prompt( 151 | prompt=prompt, 152 | prompt_2=prompt_2, 153 | device=device, 154 | num_images_per_prompt=num_images_per_prompt, 155 | do_classifier_free_guidance=self.do_classifier_free_guidance, 156 | negative_prompt=negative_prompt, 157 | negative_prompt_2=negative_prompt_2, 158 | prompt_embeds=prompt_embeds, 159 | negative_prompt_embeds=negative_prompt_embeds, 160 | pooled_prompt_embeds=pooled_prompt_embeds, 161 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, 162 | lora_scale=lora_scale, 163 | clip_skip=self.clip_skip, 164 | ) 165 | 166 | # 5. Prepare latent variables 167 | num_channels_latents = self.unet.config.in_channels 168 | latents = self.prepare_latents( 169 | batch_size * num_images_per_prompt, 170 | num_channels_latents, 171 | height, 172 | width, 173 | prompt_embeds.dtype, 174 | device, 175 | generator, 176 | latents, 177 | ) 178 | 179 | # 7. 
Prepare added time ids & embeddings 180 | add_text_embeds = pooled_prompt_embeds 181 | if self.text_encoder_2 is None: 182 | text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) 183 | else: 184 | text_encoder_projection_dim = self.text_encoder_2.config.projection_dim 185 | 186 | add_time_ids = self._get_add_time_ids( 187 | original_size, 188 | crops_coords_top_left, 189 | target_size, 190 | dtype=prompt_embeds.dtype, 191 | text_encoder_projection_dim=text_encoder_projection_dim, 192 | ) 193 | null_add_time_ids = add_time_ids.to(device) 194 | if negative_original_size is not None and negative_target_size is not None: 195 | negative_add_time_ids = self._get_add_time_ids( 196 | negative_original_size, 197 | negative_crops_coords_top_left, 198 | negative_target_size, 199 | dtype=prompt_embeds.dtype, 200 | text_encoder_projection_dim=text_encoder_projection_dim, 201 | ) 202 | else: 203 | negative_add_time_ids = add_time_ids 204 | 205 | if self.do_classifier_free_guidance: 206 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) 207 | add_text_embeds = torch.cat( 208 | [negative_pooled_prompt_embeds, add_text_embeds], dim=0 209 | ) 210 | add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) 211 | 212 | prompt_embeds = prompt_embeds.to(device) 213 | add_text_embeds = add_text_embeds.to(device) 214 | add_time_ids = add_time_ids.to(device).repeat( 215 | batch_size * num_images_per_prompt, 1 216 | ) 217 | 218 | if ip_adapter_image is not None or ip_adapter_image_embeds is not None: 219 | image_embeds = self.prepare_ip_adapter_image_embeds( 220 | ip_adapter_image, 221 | ip_adapter_image_embeds, 222 | device, 223 | batch_size * num_images_per_prompt, 224 | self.do_classifier_free_guidance, 225 | ) 226 | # 8.1 Apply denoising_end 227 | if ( 228 | self.denoising_end is not None 229 | and isinstance(self.denoising_end, float) 230 | and self.denoising_end > 0 231 | and self.denoising_end < 1 232 | ): 233 | discrete_timestep_cutoff = int( 234 | round( 235 | self.scheduler.config.num_train_timesteps 236 | - (self.denoising_end * self.scheduler.config.num_train_timesteps) 237 | ) 238 | ) 239 | num_inference_steps = len( 240 | list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)) 241 | ) 242 | timesteps = timesteps[:num_inference_steps] 243 | 244 | # 9. 
Optionally get Guidance Scale Embedding 245 | timestep_cond = None 246 | if self.unet.config.time_cond_proj_dim is not None: 247 | guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat( 248 | batch_size * num_images_per_prompt 249 | ) 250 | timestep_cond = self.get_guidance_scale_embedding( 251 | guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim 252 | ).to(device=device, dtype=latents.dtype) 253 | self.timestep_cond = timestep_cond 254 | (null_embeds, _, null_pooled_embeds, _) = self.encode_prompt("", device=device) 255 | 256 | added_cond_kwargs = { 257 | "text_embeds": add_text_embeds, 258 | "time_ids": add_time_ids 259 | } 260 | if ip_adapter_image is not None or ip_adapter_image_embeds is not None: 261 | added_cond_kwargs["image_embeds"] = image_embeds 262 | 263 | self.scheduler.set_timesteps(num_inference_steps) 264 | 265 | timesteps = self.scheduler.timesteps 266 | style_latent = self.image2latent(style_image) 267 | if init_from_style: 268 | latents = torch.cat([style_latent] * latents.shape[0]) 269 | noise = torch.randn_like(latents) 270 | latents = self.scheduler.add_noise( 271 | latents, 272 | noise, 273 | torch.tensor([999]), 274 | ) 275 | 276 | self.style_latent = style_latent 277 | self.null_embeds_for_latents = torch.cat([null_embeds] * (latents.shape[0])) 278 | self.null_embeds_for_style = torch.cat([null_embeds] * style_latent.shape[0]) 279 | self.null_added_cond_kwargs_for_latents = { 280 | "text_embeds": torch.cat([null_pooled_embeds] * (latents.shape[0])), 281 | "time_ids": torch.cat([null_add_time_ids] * (latents.shape[0])), 282 | } 283 | self.null_added_cond_kwargs_for_style = { 284 | "text_embeds": torch.cat([null_pooled_embeds] * style_latent.shape[0]), 285 | "time_ids": torch.cat([null_add_time_ids] * style_latent.shape[0]), 286 | } 287 | self.adain = adain 288 | self.cache = DataCache() 289 | self.controller = controller 290 | register_attn_control( 291 | self.classifier, controller=controller, cache=self.cache 292 | ) 293 | print("Total self attention layers of Unet: ", controller.num_self_layers) 294 | print("Self attention layers for AD: ", controller.self_layers) 295 | 296 | pbar = tqdm(timesteps, desc="Sample") 297 | for i, t in enumerate(pbar): 298 | with torch.no_grad(): 299 | # expand the latents if we are doing classifier free guidance 300 | latent_model_input = ( 301 | torch.cat([latents] * 2) 302 | if self.do_classifier_free_guidance 303 | else latents 304 | ) 305 | 306 | # predict the noise residual 307 | noise_pred = self.unet( 308 | latent_model_input, 309 | t, 310 | encoder_hidden_states=prompt_embeds, 311 | timestep_cond=timestep_cond, 312 | cross_attention_kwargs=self.cross_attention_kwargs, 313 | added_cond_kwargs=added_cond_kwargs, 314 | return_dict=False, 315 | )[0] 316 | 317 | # perform guidance 318 | if self.do_classifier_free_guidance: 319 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 320 | noise_pred = noise_pred_uncond + self.guidance_scale * ( 321 | noise_pred_text - noise_pred_uncond 322 | ) 323 | latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] 324 | 325 | if iters > 0 and t < start_time: 326 | latents = self.AD(latents, t, lr, iters, pbar) 327 | 328 | 329 | # Offload all models 330 | # self.enable_model_cpu_offload() 331 | images = self.latent2image(latents) 332 | self.maybe_free_model_hooks() 333 | return images 334 | 335 | def AD(self, latents, t, lr, iters, pbar): 336 | t = max( 337 | t 338 | - self.scheduler.config.num_train_timesteps 339 | // 
self.scheduler.num_inference_steps, 340 | torch.tensor([0], device=self.device), 341 | ) 342 | 343 | if self.adain: 344 | noise = torch.randn_like(self.style_latent) 345 | style_latent = self.scheduler.add_noise(self.style_latent, noise, t) 346 | latents = adain(latents, style_latent) 347 | 348 | with torch.no_grad(): 349 | qs_list, ks_list, vs_list, s_out_list = self.extract_feature( 350 | self.style_latent, 351 | t, 352 | self.null_embeds_for_style, 353 | self.timestep_cond, 354 | self.null_added_cond_kwargs_for_style, 355 | add_noise=True, 356 | ) 357 | # latents = latents.to(dtype=torch.float32) 358 | latents = latents.detach() 359 | optimizer = torch.optim.Adam([latents.requires_grad_()], lr=lr) 360 | optimizer, latents = self.accelerator.prepare(optimizer, latents) 361 | 362 | for j in range(iters): 363 | optimizer.zero_grad() 364 | q_list, k_list, v_list, self_out_list = self.extract_feature( 365 | latents, 366 | t, 367 | self.null_embeds_for_latents, 368 | self.timestep_cond, 369 | self.null_added_cond_kwargs_for_latents, 370 | add_noise=False, 371 | ) 372 | 373 | loss = ad_loss(q_list, ks_list, vs_list, self_out_list) 374 | self.accelerator.backward(loss) 375 | optimizer.step() 376 | 377 | pbar.set_postfix(loss=loss.item(), time=t.item(), iter=j) 378 | latents = latents.detach() 379 | return latents 380 | 381 | def extract_feature( 382 | self, 383 | latent, 384 | t, 385 | encoder_hidden_states, 386 | timestep_cond, 387 | added_cond_kwargs, 388 | add_noise=False, 389 | ): 390 | self.cache.clear() 391 | self.controller.step() 392 | if add_noise: 393 | noise = torch.randn_like(latent) 394 | latent_ = self.scheduler.add_noise(latent, noise, t) 395 | else: 396 | latent_ = latent 397 | self.classifier( 398 | latent_, 399 | t, 400 | encoder_hidden_states=encoder_hidden_states, 401 | timestep_cond=timestep_cond, 402 | added_cond_kwargs=added_cond_kwargs, 403 | return_dict=False, 404 | )[0] 405 | return self.cache.get() 406 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "attention-distillation" 3 | description = "Non-native [a/AttentionDistillation](https://github.com/xugao97/AttentionDistillation) for ComfyUI.\nOfficial ComfyUI demo for the paper AttentionDistillation, implemented as an extension of ComfyUI. Note that this extension incorporates AttentionDistillation using diffusers." 
4 | version = "1.1.0" 5 | license = {file = "LICENSE"} 6 | dependencies = ["diffusers", "accelerate", "Pillow", "torch>=2.1.0", "tqdm", "huggingface_hub", "sentencepiece", "protobuf"] 7 | 8 | [project.urls] 9 | Repository = "https://github.com/zichongc/ComfyUI-Attention-Distillation" 10 | # Used by Comfy Registry https://comfyregistry.org 11 | 12 | [tool.comfy] 13 | PublisherId = "zichongc" 14 | DisplayName = "ComfyUI-Attention-Distillation" 15 | Icon = "" 16 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | diffusers 2 | accelerate 3 | Pillow 4 | torch>=2.1.0 5 | protobuf 6 | sentencepiece 7 | tqdm -------------------------------------------------------------------------------- /train_vae.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import torch 5 | from diffusers import AutoencoderKL 6 | from torch import nn 7 | from torch.optim import Adam 8 | from .utils import load_image, save_image 9 | 10 | 11 | def main(args): 12 | os.makedirs(args.out_dir, exist_ok=True) 13 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 14 | 15 | vae = AutoencoderKL.from_pretrained(args.vae_model_path).to( 16 | device, dtype=torch.float32 17 | ) 18 | vae.requires_grad_(False) 19 | 20 | image = load_image(args.image_path, size=(512, 512)).to(device, dtype=torch.float32) 21 | image = image * 2 - 1 22 | save_image(image / 2 + 0.5, f"{args.out_dir}/ori_image.png") 23 | 24 | latents = vae.encode(image)["latent_dist"].mean 25 | save_image(latents, f"{args.out_dir}/latents.png") 26 | 27 | rec_image = vae.decode(latents, return_dict=False)[0] 28 | save_image(rec_image / 2 + 0.5, f"{args.out_dir}/rec_image.png") 29 | 30 | for param in vae.decoder.parameters(): 31 | param.requires_grad = True 32 | 33 | loss_fn = nn.L1Loss() 34 | optimizer = Adam(vae.decoder.parameters(), lr=args.learning_rate) 35 | 36 | # Training loop 37 | for epoch in range(args.num_epochs): 38 | reconstructed = vae.decode(latents, return_dict=False)[0] 39 | loss = loss_fn(reconstructed, image) 40 | 41 | optimizer.zero_grad() 42 | loss.backward() 43 | optimizer.step() 44 | 45 | print(f"Epoch {epoch+1}/{args.num_epochs}, Loss: {loss.item()}") 46 | 47 | rec_image = vae.decode(latents, return_dict=False)[0] 48 | save_image(rec_image / 2 + 0.5, f"{args.out_dir}/trained_rec_image.png") 49 | vae.save_pretrained( 50 | f"{args.out_dir}/trained_vae_{os.path.basename(args.image_path)}" 51 | ) 52 | 53 | 54 | 55 | if __name__ == "__main__": 56 | parser = argparse.ArgumentParser( 57 | description="Train a VAE with given image and settings." 
58 | ) 59 | 60 | # Add arguments 61 | parser.add_argument( 62 | "--out_dir", 63 | type=str, 64 | default="./trained_vae/", 65 | help="Output directory to save results", 66 | ) 67 | parser.add_argument( 68 | "--vae_model_path", 69 | type=str, 70 | required=True, 71 | help="Path to the pretrained VAE model", 72 | ) 73 | parser.add_argument( 74 | "--image_path", type=str, required=True, help="Path to the input image" 75 | ) 76 | parser.add_argument( 77 | "--learning_rate", 78 | type=float, 79 | default=1e-4, 80 | help="Learning rate for the optimizer", 81 | ) 82 | parser.add_argument( 83 | "--num_epochs", type=int, default=75, help="Number of training epochs" 84 | ) 85 | 86 | args = parser.parse_args() 87 | main(args) 88 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | from PIL import Image 5 | from torchvision.transforms import ToTensor 6 | from torchvision.utils import save_image 7 | # import matplotlib.pyplot as plt 8 | import math 9 | 10 | 11 | sd15_file_names = [ 12 | 'feature_extractor/preprocessor_config.json', 13 | 'scheduler/scheduler_config.json', 14 | 'text_encoder/config.json', 15 | 'text_encoder/model.safetensors', 16 | 'tokenizer/merges.txt', 17 | 'tokenizer/special_tokens_map.json', 18 | 'tokenizer/tokenizer_config.json', 19 | 'tokenizer/vocab.json', 20 | 'unet/config.json', 21 | 'unet/diffusion_pytorch_model.safetensors', 22 | 'vae/config.json', 23 | 'vae/diffusion_pytorch_model.safetensors', 24 | 'model_index.json' 25 | ] 26 | 27 | sdxl_file_names = [ 28 | 'model_index.json', 29 | 'vae/config.json', 30 | 'vae/diffusion_pytorch_model.safetensors', 31 | 'unet/config.json', 32 | 'unet/diffusion_pytorch_model.safetensors', 33 | 'tokenizer/merges.txt', 34 | 'tokenizer/special_tokens_map.json', 35 | 'tokenizer/tokenizer_config.json', 36 | 'tokenizer/vocab.json', 37 | 'tokenizer_2/merges.txt', 38 | 'tokenizer_2/special_tokens_map.json', 39 | 'tokenizer_2/tokenizer_config.json', 40 | 'tokenizer_2/vocab.json', 41 | 'text_encoder/config.json', 42 | 'text_encoder/model.safetensors', 43 | 'text_encoder_2/config.json', 44 | 'text_encoder_2/model.safetensors', 45 | 'scheduler/scheduler_config.json', 46 | ] 47 | 48 | flux_file_names = [ 49 | 'model_index.json', 50 | 'vae/config.json', 51 | 'vae/diffusion_pytorch_model.safetensors', 52 | 'transformer/config.json', 53 | 'transformer/diffusion_pytorch_model-00001-of-00003.safetensors', 54 | 'transformer/diffusion_pytorch_model-00002-of-00003.safetensors', 55 | 'transformer/diffusion_pytorch_model-00003-of-00003.safetensors', 56 | 'transformer/diffusion_pytorch_model.safetensors.index.json', 57 | 'tokenizer/merges.txt', 58 | 'tokenizer/special_tokens_map.json', 59 | 'tokenizer/tokenizer_config.json', 60 | 'tokenizer/vocab.json', 61 | 'tokenizer_2/spiece.model', 62 | 'tokenizer_2/special_tokens_map.json', 63 | 'tokenizer_2/tokenizer_config.json', 64 | 'tokenizer_2/tokenizer.json', 65 | 'text_encoder/config.json', 66 | 'text_encoder/model.safetensors', 67 | 'text_encoder_2/config.json', 68 | 'text_encoder_2/model-00001-of-00002.safetensors', 69 | 'text_encoder_2/model-00002-of-00002.safetensors', 70 | 'text_encoder_2/model.safetensors.index.json', 71 | 'scheduler/scheduler_config.json', 72 | ] 73 | 74 | 75 | def register_attn_control(unet, controller, cache=None): 76 | def attn_forward(self): 77 | def forward( 78 | hidden_states, 79 | 
encoder_hidden_states=None, 80 | attention_mask=None, 81 | temb=None, 82 | *args, 83 | **kwargs, 84 | ): 85 | residual = hidden_states 86 | if self.spatial_norm is not None: 87 | hidden_states = self.spatial_norm(hidden_states, temb) 88 | 89 | input_ndim = hidden_states.ndim 90 | 91 | if input_ndim == 4: 92 | batch_size, channel, height, width = hidden_states.shape 93 | hidden_states = hidden_states.view( 94 | batch_size, channel, height * width 95 | ).transpose(1, 2) 96 | 97 | batch_size, sequence_length, _ = ( 98 | hidden_states.shape 99 | if encoder_hidden_states is None 100 | else encoder_hidden_states.shape 101 | ) 102 | 103 | if attention_mask is not None: 104 | attention_mask = self.prepare_attention_mask( 105 | attention_mask, sequence_length, batch_size 106 | ) 107 | # scaled_dot_product_attention expects attention_mask shape to be 108 | # (batch, heads, source_length, target_length) 109 | attention_mask = attention_mask.view( 110 | batch_size, self.heads, -1, attention_mask.shape[-1] 111 | ) 112 | 113 | if self.group_norm is not None: 114 | hidden_states = self.group_norm( 115 | hidden_states.transpose(1, 2) 116 | ).transpose(1, 2) 117 | 118 | q = self.to_q(hidden_states) 119 | is_self = encoder_hidden_states is None 120 | 121 | if encoder_hidden_states is None: 122 | encoder_hidden_states = hidden_states 123 | elif self.norm_cross: 124 | encoder_hidden_states = self.norm_encoder_hidden_states( 125 | encoder_hidden_states 126 | ) 127 | 128 | k = self.to_k(encoder_hidden_states) 129 | v = self.to_v(encoder_hidden_states) 130 | 131 | inner_dim = k.shape[-1] 132 | head_dim = inner_dim // self.heads 133 | 134 | q = q.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) 135 | k = k.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) 136 | v = v.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) 137 | # the output of sdp = (batch, num_heads, seq_len, head_dim) 138 | # TODO: add support for attn.scale when we move to Torch 2.1 139 | hidden_states = F.scaled_dot_product_attention( 140 | q, k, v, attn_mask=attention_mask, dropout_p=0.0, is_causal=False 141 | ) 142 | if is_self and controller.cur_self_layer in controller.self_layers: 143 | cache.add(q, k, v, hidden_states) 144 | 145 | hidden_states = hidden_states.transpose(1, 2).reshape( 146 | batch_size, -1, self.heads * head_dim 147 | ) 148 | hidden_states = hidden_states.to(q.dtype) 149 | 150 | # linear proj 151 | hidden_states = self.to_out[0](hidden_states) 152 | # dropout 153 | hidden_states = self.to_out[1](hidden_states) 154 | 155 | if input_ndim == 4: 156 | hidden_states = hidden_states.transpose(-1, -2).reshape( 157 | batch_size, channel, height, width 158 | ) 159 | if self.residual_connection: 160 | hidden_states = hidden_states + residual 161 | 162 | hidden_states = hidden_states / self.rescale_output_factor 163 | 164 | if is_self: 165 | controller.cur_self_layer += 1 166 | 167 | return hidden_states 168 | 169 | return forward 170 | 171 | def modify_forward(net, count): 172 | for name, subnet in net.named_children(): 173 | if net.__class__.__name__ == "Attention": # spatial Transformer layer 174 | net.forward = attn_forward(net) 175 | return count + 1 176 | elif hasattr(net, "children"): 177 | count = modify_forward(subnet, count) 178 | return count 179 | 180 | cross_att_count = 0 181 | for net_name, net in unet.named_children(): 182 | cross_att_count += modify_forward(net, 0) 183 | controller.num_self_layers = cross_att_count // 2 184 | 185 | 186 | def register_attn_control_flux(unet, controller, 
cache=None): 187 | def attn_forward(self): 188 | 189 | def forward( 190 | hidden_states, 191 | encoder_hidden_states=None, 192 | attention_mask=None, 193 | image_rotary_emb=None, 194 | *args, 195 | **kwargs, 196 | ): 197 | batch_size, _, _ = ( 198 | hidden_states.shape 199 | if encoder_hidden_states is None 200 | else encoder_hidden_states.shape 201 | ) 202 | 203 | # `sample` projections. 204 | query = self.to_q(hidden_states) 205 | key = self.to_k(hidden_states) 206 | value = self.to_v(hidden_states) 207 | 208 | inner_dim = key.shape[-1] 209 | head_dim = inner_dim // self.heads 210 | 211 | query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) 212 | key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) 213 | value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) 214 | 215 | if self.norm_q is not None: 216 | query = self.norm_q(query) 217 | if self.norm_k is not None: 218 | key = self.norm_k(key) 219 | 220 | # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` 221 | if encoder_hidden_states is not None: 222 | # `context` projections. 223 | encoder_hidden_states_query_proj = self.add_q_proj( 224 | encoder_hidden_states 225 | ) 226 | encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states) 227 | encoder_hidden_states_value_proj = self.add_v_proj( 228 | encoder_hidden_states 229 | ) 230 | 231 | encoder_hidden_states_query_proj = ( 232 | encoder_hidden_states_query_proj.view( 233 | batch_size, -1, self.heads, head_dim 234 | ).transpose(1, 2) 235 | ) 236 | encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( 237 | batch_size, -1, self.heads, head_dim 238 | ).transpose(1, 2) 239 | encoder_hidden_states_value_proj = ( 240 | encoder_hidden_states_value_proj.view( 241 | batch_size, -1, self.heads, head_dim 242 | ).transpose(1, 2) 243 | ) 244 | 245 | if self.norm_added_q is not None: 246 | encoder_hidden_states_query_proj = self.norm_added_q( 247 | encoder_hidden_states_query_proj 248 | ) 249 | if self.norm_added_k is not None: 250 | encoder_hidden_states_key_proj = self.norm_added_k( 251 | encoder_hidden_states_key_proj 252 | ) 253 | 254 | # attention 255 | query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) 256 | key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) 257 | value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) 258 | 259 | if image_rotary_emb is not None: 260 | from diffusers.models.embeddings import apply_rotary_emb 261 | 262 | query = apply_rotary_emb(query, image_rotary_emb) 263 | key = apply_rotary_emb(key, image_rotary_emb) 264 | 265 | hidden_states = F.scaled_dot_product_attention( 266 | query, key, value, dropout_p=0.0, is_causal=False 267 | ) 268 | if controller.cur_self_layer in controller.self_layers: 269 | # print("cache added") 270 | cache.add(query, key, value, hidden_states) 271 | # if encoder_hidden_states is None: 272 | controller.cur_self_layer += 1 273 | 274 | hidden_states = hidden_states.transpose(1, 2).reshape( 275 | batch_size, -1, self.heads * head_dim 276 | ) 277 | 278 | hidden_states = hidden_states.to(query.dtype) 279 | 280 | if encoder_hidden_states is not None: 281 | encoder_hidden_states, hidden_states = ( 282 | hidden_states[:, : encoder_hidden_states.shape[1]], 283 | hidden_states[:, encoder_hidden_states.shape[1] :], 284 | ) 285 | 286 | # linear proj 287 | hidden_states = self.to_out[0](hidden_states) 288 | # dropout 289 | hidden_states = self.to_out[1](hidden_states) 290 | encoder_hidden_states = 
self.to_add_out(encoder_hidden_states) 291 | 292 | return hidden_states, encoder_hidden_states 293 | else: 294 | return hidden_states 295 | 296 | return forward 297 | 298 | def modify_forward(net, count): 299 | # print(net.named_children()) 300 | for name, subnet in net.named_children(): 301 | if net.__class__.__name__ == "Attention": # spatial Transformer layer 302 | net.forward = attn_forward(net) 303 | return count + 1 304 | elif hasattr(net, "children"): 305 | count = modify_forward(subnet, count) 306 | return count 307 | 308 | cross_att_count = 0 309 | cross_att_count += modify_forward(unet, 0) 310 | controller.num_self_layers += cross_att_count 311 | 312 | 313 | def load_image(image_path, size=None, mode="RGB"): 314 | img = Image.open(image_path).convert(mode) 315 | if size is None: 316 | width, height = img.size 317 | new_width = (width // 64) * 64 318 | new_height = (height // 64) * 64 319 | size = (new_width, new_height) 320 | img = img.resize(size, Image.BICUBIC) 321 | return ToTensor()(img).unsqueeze(0) 322 | 323 | 324 | def adain(source, target, eps=1e-6): 325 | source_mean, source_std = torch.mean(source, dim=(2, 3), keepdim=True), torch.std( 326 | source, dim=(2, 3), keepdim=True 327 | ) 328 | target_mean, target_std = torch.mean( 329 | target, dim=(0, 2, 3), keepdim=True 330 | ), torch.std(target, dim=(0, 2, 3), keepdim=True) 331 | normalized_source = (source - source_mean) / (source_std + eps) 332 | transferred_source = normalized_source * target_std + target_mean 333 | 334 | return transferred_source 335 | 336 | 337 | def adain_flux(source, target, eps=1e-6): 338 | source_mean, source_std = torch.mean(source, dim=1, keepdim=True), torch.std( 339 | source, dim=1, keepdim=True 340 | ) 341 | target_mean, target_std = torch.mean( 342 | target, dim=(0, 1), keepdim=True 343 | ), torch.std(target, dim=(0, 1), keepdim=True) 344 | normalized_source = (source - source_mean) / (source_std + eps) 345 | transferred_source = normalized_source * target_std + target_mean 346 | 347 | return transferred_source 348 | 349 | 350 | class Controller: 351 | def step(self): 352 | self.cur_self_layer = 0 353 | 354 | def __init__(self, self_layers=(0, 16)): 355 | self.num_self_layers = -1 356 | self.cur_self_layer = 0 357 | self.self_layers = list(range(*self_layers)) 358 | 359 | 360 | class DataCache: 361 | def __init__(self): 362 | self.q = [] 363 | self.k = [] 364 | self.v = [] 365 | self.out = [] 366 | 367 | def clear(self): 368 | self.q.clear() 369 | self.k.clear() 370 | self.v.clear() 371 | self.out.clear() 372 | 373 | def add(self, q, k, v, out): 374 | self.q.append(q) 375 | self.k.append(k) 376 | self.v.append(v) 377 | self.out.append(out) 378 | 379 | def get(self): 380 | return self.q.copy(), self.k.copy(), self.v.copy(), self.out.copy() 381 | 382 | 383 | 384 | # def show_image(path, title, display_height=3, title_fontsize=12): 385 | # img = Image.open(path) 386 | # img_width, img_height = img.size 387 | 388 | # aspect_ratio = img_width / img_height 389 | # display_width = display_height * aspect_ratio 390 | 391 | # plt.figure(figsize=(display_width, display_height)) 392 | # plt.imshow(img) 393 | # plt.title(title, 394 | # fontsize=title_fontsize, 395 | # fontweight='bold', 396 | # pad=20) 397 | # plt.axis('off') 398 | # plt.tight_layout() 399 | # plt.show() 400 | -------------------------------------------------------------------------------- /workflows/style_t2i_generation_flux.json: -------------------------------------------------------------------------------- 1 | 
{"last_node_id":6,"last_link_id":5,"nodes":[{"id":6,"type":"ResizeImage","pos":[254.7171630859375,483.82080078125],"size":[315,58],"flags":{},"order":3,"mode":0,"inputs":[{"name":"image","type":"IMAGE","link":1,"localized_name":"image"}],"outputs":[{"name":"image","type":"IMAGE","links":[2],"slot_index":0,"localized_name":"image"}],"properties":{"Node name for S&R":"ResizeImage"},"widgets_values":[512]},{"id":5,"type":"PreviewImage","pos":[1068.2235107421875,130.05441284179688],"size":[540.3896484375,543.4026489257812],"flags":{},"order":5,"mode":0,"inputs":[{"name":"images","type":"IMAGE","link":5,"localized_name":"images"}],"outputs":[],"properties":{"Node name for S&R":"PreviewImage"},"widgets_values":[]},{"id":3,"type":"LoadPILImage","pos":[326.92474365234375,104.5999755859375],"size":[315,294],"flags":{},"order":0,"mode":0,"inputs":[],"outputs":[{"name":"image","type":"IMAGE","links":[3],"slot_index":0,"localized_name":"image"}],"title":"Style Reference","properties":{"Node name for S&R":"LoadPILImage"},"widgets_values":["40.jpg","image"]},{"id":2,"type":"LoadPILImage","pos":[680.6911010742188,103.04146575927734],"size":[315,294],"flags":{},"order":1,"mode":0,"inputs":[],"outputs":[{"name":"image","type":"IMAGE","links":[1],"slot_index":0,"localized_name":"image"}],"title":"Content Image","properties":{"Node name for S&R":"LoadPILImage"},"widgets_values":["lecun.png","image"]},{"id":1,"type":"LoadDistiller","pos":[256.7950439453125,609.0155029296875],"size":[315,82],"flags":{},"order":2,"mode":0,"inputs":[],"outputs":[{"name":"distiller","type":"DISTILLER","links":[4],"slot_index":0,"localized_name":"distiller"}],"properties":{"Node name for S&R":"LoadDistiller"},"widgets_values":["stable-diffusion-v1-5","bf16"]},{"id":4,"type":"ADOptimizer","pos":[620.4314575195312,465.638916015625],"size":[415.8000183105469,242],"flags":{},"order":4,"mode":0,"inputs":[{"name":"distiller","type":"DISTILLER","link":4,"localized_name":"distiller"},{"name":"content","type":"IMAGE","link":2,"localized_name":"content"},{"name":"style","type":"IMAGE","link":3,"localized_name":"style"}],"outputs":[{"name":"image","type":"IMAGE","links":[5],"slot_index":0,"localized_name":"image"}],"properties":{"Node name for S&R":"ADOptimizer"},"widgets_values":[300,0.23,0.05,512,512,2025,"fixed"]}],"links":[[1,2,0,6,0,"IMAGE"],[2,6,0,4,1,"IMAGE"],[3,3,0,4,2,"IMAGE"],[4,1,0,4,0,"DISTILLER"],[5,4,0,5,0,"IMAGE"]],"groups":[{"id":1,"title":"Attention Distillation for Style Transfer","bounding":[150.79759216308594,-23.581954956054688,1556.8001708984375,755.0857543945312],"color":"#3f789e","font_size":24,"flags":{}}],"config":{},"extra":{"ds":{"scale":0.9090909090909092,"offset":[116.54605349790273,20.61595758295544]},"node_versions":{"comfy-core":"0.3.12"}},"version":0.4} -------------------------------------------------------------------------------- /workflows/style_t2i_generation_sd15.json: -------------------------------------------------------------------------------- 1 | {"last_node_id":6,"last_link_id":5,"nodes":[{"id":6,"type":"PreviewImage","pos":[1050.10546875,266.7724609375],"size":[529.428466796875,492.2856750488281],"flags":{},"order":5,"mode":0,"inputs":[{"name":"images","type":"IMAGE","link":5}],"outputs":[],"properties":{"Node name for S&R":"PreviewImage"},"widgets_values":[]},{"id":5,"type":"LoadPILImage","pos":[525.5338134765625,133.05821228027344],"size":[315,294],"flags":{},"order":0,"mode":0,"inputs":[],"outputs":[{"name":"image","type":"IMAGE","links":[1],"slot_index":0}],"title":"Style 
Reference","properties":{"Node name for S&R":"LoadPILImage"},"widgets_values":["40.jpg","image"]},{"id":3,"type":"PureText","pos":[25.32625961303711,601.0584106445312],"size":[400,200],"flags":{},"order":1,"mode":0,"inputs":[],"outputs":[{"name":"CONDITIONING","type":"CONDITIONING","links":[4],"slot_index":0}],"title":"Negative prompt","properties":{"Node name for S&R":"PureText"},"widgets_values":["blur, low quality"]},{"id":2,"type":"PureText","pos":[27.923683166503906,327.5516357421875],"size":[400,200],"flags":{},"order":2,"mode":0,"inputs":[],"outputs":[{"name":"CONDITIONING","type":"CONDITIONING","links":[3],"slot_index":0}],"title":"Positive prompt","properties":{"Node name for S&R":"PureText"},"widgets_values":["A portrait of Mr. Donald Trump"]},{"id":1,"type":"LoadDistiller","pos":[75.71586608886719,163.65560913085938],"size":[315,82],"flags":{},"order":3,"mode":0,"inputs":[],"outputs":[{"name":"distiller","type":"DISTILLER","links":[2],"slot_index":0}],"properties":{"Node name for S&R":"LoadDistiller"},"widgets_values":["stable-diffusion-v1-5","bf16"]},{"id":4,"type":"ADSampler","pos":[492.3912658691406,484.4866943359375],"size":[504,334],"flags":{},"order":4,"mode":0,"inputs":[{"name":"distiller","type":"DISTILLER","link":2},{"name":"style","type":"IMAGE","link":1},{"name":"positive","type":"CONDITIONING","link":3},{"name":"negative","type":"CONDITIONING","link":4}],"outputs":[{"name":"images","type":"IMAGE","links":[5],"slot_index":0}],"properties":{"Node name for S&R":"ADSampler"},"widgets_values":[50,0.015,2,7.5,0,1,2025,"increment",512,512]}],"links":[[1,5,0,4,1,"IMAGE"],[2,1,0,4,0,"DISTILLER"],[3,2,0,4,2,"CONDITIONING"],[4,3,0,4,3,"CONDITIONING"],[5,4,0,6,0,"IMAGE"]],"groups":[{"id":1,"title":"Attention Distillation for Style-Specific Text-to-Image Generation","bounding":[-50.282806396484375,25.119415283203125,1678.1817626953125,820.259765625],"color":"#3f789e","font_size":24,"flags":{}}],"config":{},"extra":{"ds":{"scale":1,"offset":[147.782927316091,-0.36286816241710085]},"node_versions":{"comfy-core":"0.3.12"}},"version":0.4} -------------------------------------------------------------------------------- /workflows/style_t2i_generation_sdxl.json: -------------------------------------------------------------------------------- 1 | {"last_node_id":6,"last_link_id":5,"nodes":[{"id":5,"type":"LoadPILImage","pos":[642.10498046875,143.9153289794922],"size":[315,294],"flags":{},"order":0,"mode":0,"inputs":[],"outputs":[{"name":"image","type":"IMAGE","links":[1],"slot_index":0,"localized_name":"image"}],"title":"Style Reference","properties":{"Node name for S&R":"LoadPILImage"},"widgets_values":["1.png","image"]},{"id":1,"type":"LoadDistiller","pos":[192.28724670410156,174.5127410888672],"size":[315,82],"flags":{},"order":1,"mode":0,"inputs":[],"outputs":[{"name":"distiller","type":"DISTILLER","links":[2],"slot_index":0,"localized_name":"distiller"}],"properties":{"Node name for S&R":"LoadDistiller"},"widgets_values":["stable-diffusion-xl-base-1.0","bf16"]},{"id":3,"type":"PureText","pos":[141.8976287841797,611.9155883789062],"size":[400,200],"flags":{},"order":2,"mode":0,"inputs":[],"outputs":[{"name":"CONDITIONING","type":"CONDITIONING","links":[4],"slot_index":0,"localized_name":"CONDITIONING"}],"title":"Negative prompt","properties":{"Node name for 
S&R":"PureText"},"widgets_values":[""]},{"id":6,"type":"PreviewImage","pos":[1166.67724609375,277.6296081542969],"size":[529.428466796875,492.2856750488281],"flags":{},"order":5,"mode":0,"inputs":[{"name":"images","type":"IMAGE","link":5,"localized_name":"images"}],"outputs":[],"properties":{"Node name for S&R":"PreviewImage"},"widgets_values":[]},{"id":2,"type":"PureText","pos":[144.49505615234375,338.4087829589844],"size":[400,200],"flags":{},"order":3,"mode":0,"inputs":[],"outputs":[{"name":"CONDITIONING","type":"CONDITIONING","links":[3],"slot_index":0,"localized_name":"CONDITIONING"}],"title":"Positive prompt","properties":{"Node name for S&R":"PureText"},"widgets_values":["A photo of Big Ben, London."]},{"id":4,"type":"ADSampler","pos":[608.9625244140625,495.3438415527344],"size":[504,334],"flags":{},"order":4,"mode":0,"inputs":[{"name":"distiller","type":"DISTILLER","link":2,"localized_name":"distiller"},{"name":"style","type":"IMAGE","link":1,"localized_name":"style"},{"name":"positive","type":"CONDITIONING","link":3,"localized_name":"positive"},{"name":"negative","type":"CONDITIONING","link":4,"localized_name":"negative"}],"outputs":[{"name":"images","type":"IMAGE","links":[5],"slot_index":0,"localized_name":"images"}],"properties":{"Node name for S&R":"ADSampler"},"widgets_values":[50,0.015,2,7.5,1,2025,"fixed",1024,1024]}],"links":[[1,5,0,4,1,"IMAGE"],[2,1,0,4,0,"DISTILLER"],[3,2,0,4,2,"CONDITIONING"],[4,3,0,4,3,"CONDITIONING"],[5,4,0,6,0,"IMAGE"]],"groups":[{"id":1,"title":"Attention Distillation for Style-Specific Text-to-Image Generation (SDXL)","bounding":[66.2885971069336,35.97654342651367,1678.1817626953125,820.259765625],"color":"#88A","font_size":24,"flags":{}}],"config":{},"extra":{"ds":{"scale":0.9090909090909092,"offset":[116.54605349790273,20.61595758295544]},"node_versions":{"comfy-core":"0.3.12"}},"version":0.4} -------------------------------------------------------------------------------- /workflows/style_transfer_sd15.json: -------------------------------------------------------------------------------- 1 | {"last_node_id":6,"last_link_id":5,"nodes":[{"id":6,"type":"ResizeImage","pos":[254.7171630859375,483.82080078125],"size":[315,58],"flags":{},"order":3,"mode":0,"inputs":[{"name":"image","type":"IMAGE","link":1}],"outputs":[{"name":"image","type":"IMAGE","links":[2],"slot_index":0}],"properties":{"Node name for S&R":"ResizeImage"},"widgets_values":[512]},{"id":5,"type":"PreviewImage","pos":[1068.2235107421875,130.05441284179688],"size":[540.3896484375,543.4026489257812],"flags":{},"order":5,"mode":0,"inputs":[{"name":"images","type":"IMAGE","link":5}],"outputs":[],"properties":{"Node name for S&R":"PreviewImage"}},{"id":3,"type":"LoadPILImage","pos":[326.92474365234375,104.5999755859375],"size":[315,294],"flags":{},"order":0,"mode":0,"inputs":[],"outputs":[{"name":"image","type":"IMAGE","links":[3],"slot_index":0}],"title":"Style Reference","properties":{"Node name for S&R":"LoadPILImage"},"widgets_values":["40.jpg","image"]},{"id":2,"type":"LoadPILImage","pos":[680.6911010742188,103.04146575927734],"size":[315,294],"flags":{},"order":1,"mode":0,"inputs":[],"outputs":[{"name":"image","type":"IMAGE","links":[1],"slot_index":0}],"title":"Content Image","properties":{"Node name for 
S&R":"LoadPILImage"},"widgets_values":["lecun.png","image"]},{"id":1,"type":"LoadDistiller","pos":[256.7950439453125,609.0155029296875],"size":[315,82],"flags":{},"order":2,"mode":0,"inputs":[],"outputs":[{"name":"distiller","type":"DISTILLER","links":[4],"slot_index":0}],"properties":{"Node name for S&R":"LoadDistiller"},"widgets_values":["stable-diffusion-v1-5","bf16"]},{"id":4,"type":"ADOptimizer","pos":[620.4314575195312,465.638916015625],"size":[415.8000183105469,242],"flags":{},"order":4,"mode":0,"inputs":[{"name":"distiller","type":"DISTILLER","link":4},{"name":"content","type":"IMAGE","link":2},{"name":"style","type":"IMAGE","link":3}],"outputs":[{"name":"image","type":"IMAGE","links":[5],"slot_index":0}],"properties":{"Node name for S&R":"ADOptimizer"},"widgets_values":[300,0.23,0.05,512,512,2025,"fixed"]}],"links":[[1,2,0,6,0,"IMAGE"],[2,6,0,4,1,"IMAGE"],[3,3,0,4,2,"IMAGE"],[4,1,0,4,0,"DISTILLER"],[5,4,0,5,0,"IMAGE"]],"groups":[{"id":1,"title":"Attention Distillation for Style Transfer","bounding":[150.79759216308594,-23.581954956054688,1556.8001708984375,755.0857543945312],"color":"#3f789e","font_size":24,"flags":{}}],"config":{},"extra":{"ds":{"scale":0.8264462809917354,"offset":[279.35430419292214,247.5915815675024]},"node_versions":{"comfy-core":"0.3.12"}},"version":0.4} --------------------------------------------------------------------------------