├── .github
│   └── workflows
│       └── publish.yml
├── LICENSE
├── README.md
├── __init__.py
├── assets
│   ├── style_t2i.jpg
│   ├── style_t2i_sdxl.jpg
│   └── style_transfer.jpg
├── comfyui_nodes.py
├── examples
│   ├── 1.png
│   ├── 26.jpg
│   ├── 40.jpg
│   └── lecun.png
├── losses.py
├── pipeline_flux.py
├── pipeline_sd.py
├── pipeline_sdxl.py
├── pyproject.toml
├── requirements.txt
├── train_vae.py
├── utils.py
└── workflows
    ├── style_t2i_generation_flux.json
    ├── style_t2i_generation_sd15.json
    ├── style_t2i_generation_sdxl.json
    └── style_transfer_sd15.json
/.github/workflows/publish.yml:
--------------------------------------------------------------------------------
1 | name: Publish to Comfy registry
2 | on:
3 |   workflow_dispatch:
4 |   push:
5 |     branches:
6 |       - main
7 |     paths:
8 |       - "pyproject.toml"
9 |
10 | permissions:
11 |   issues: write
12 |
13 | jobs:
14 |   publish-node:
15 |     name: Publish Custom Node to registry
16 |     runs-on: ubuntu-latest
17 |     if: ${{ github.repository_owner == 'zichongc' }}
18 |     steps:
19 |       - name: Check out code
20 |         uses: actions/checkout@v4
21 |         with:
22 |           submodules: true
23 |       - name: Publish Custom Node
24 |         uses: Comfy-Org/publish-node-action@v1
25 |         with:
26 |           ## Add your own personal access token to your Github Repository secrets and reference it here.
27 |           personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }}
28 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Zichong Chen
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## ComfyUI-Attention-Distillation
2 |
3 | Non-native [AttentionDistillation](https://xugao97.github.io/AttentionDistillation/) for ComfyUI.
4 |
5 | This is the official ComfyUI demo for the paper [AttentionDistillation](https://arxiv.org/abs/2502.20235), implemented as a ComfyUI extension. Note that this extension implements AttentionDistillation on top of `diffusers`.
6 |
7 | The official code for AttentionDistillation can be found [here](https://github.com/xugao97/AttentionDistillation).
8 |
9 | ### 🔥🔥 News
10 | * **2025/03/10**: Workflows for style-specific T2I generation using **SDXL** and **Flux** (beta) have been released.
11 | * **2025/02/27**: We released the ComfyUI implementation of Attention Distillation, together with two workflows for style transfer and style-specific text-to-image generation using Stable Diffusion 1.5.
12 | * **2025/02/27**: The official code for AttentionDistillation has been released [here](https://github.com/xugao97/AttentionDistillation).
13 |
14 | ### 🛒 Installation
15 | Download or `git clone` this repository into the `ComfyUI/custom_nodes/` directory, or install it via ComfyUI Manager for a streamlined setup.
16 |
17 |
18 | ##### Install manually
19 | 1. `cd custom_nodes`
20 | 2. `git clone ...`
21 | 3. `cd ComfyUI-Attention-Distillation`
22 | 4. `pip install -r requirements.txt`
23 | 5. Restart ComfyUI.
24 |
25 | ### 📒 How to Use
26 | ##### Download T2I diffusion models
27 | This implementation uses `diffusers`-format checkpoints. Download the required models and place them in the `ComfyUI/models/diffusers` directory (a download sketch follows below the table):
28 | |Model|Model Name and Link|
29 | |:---:|:---:|
30 | | Stable Diffusion (v1.5, v2.1) | [stable-diffusion-v1-5/stable-diffusion-v1-5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5)<br>[stabilityai/stable-diffusion-2-1](https://huggingface.co/stabilityai/stable-diffusion-2-1) |
31 | | SDXL | [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) |
32 | | Flux (dev) | [black-forest-labs/FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) |
33 |
34 |
35 | ~~*Note: Currently, only Stable Diffusion v1.5 is required.*~~
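If you prefer to fetch a checkpoint before launching ComfyUI (otherwise the `LoadDistiller` node will try to download any missing files on first use, see `comfyui_nodes.py`), a minimal sketch using `huggingface_hub` is shown below. The repo id and target folder are examples; adjust them to the model you need and to where your ComfyUI installation lives.

```python
# Minimal sketch: pre-download a diffusers-format checkpoint into ComfyUI/models/diffusers.
# The repo_id and local_dir below are examples, not fixed by this extension.
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="stable-diffusion-v1-5/stable-diffusion-v1-5",
    local_dir="ComfyUI/models/diffusers/stable-diffusion-v1-5",
)
```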
36 |
37 | ##### Load the workflow
38 | Workflows for various tasks are available in `ComfyUI/custom_nodes/ComfyUI-Attention-Distillation/workflows`. Simply load them to get started. Additionally, we've included usage examples in the [Examples](#examples) section for your reference.
39 |
40 | ### 🔍 Examples
41 |
42 | #### Style-specific text-to-image generation
43 | `style_t2i_generation_sd15.json`
44 |
45 |
46 |
47 |
48 | `style_t2i_generation_sdxl.json`
49 |
50 |
51 |
52 |
53 | `style_t2i_generation_flux.json` (beta)
54 |
55 |
56 |
57 | #### Style Transfer
58 | `style_transfer_sd15.json`
59 |
60 |
61 |
62 |
63 | ### 📃 TODOs
64 | - [x] Workflow for style-specific text-to-image generation using SDXL.
65 | - [x] Workflow for style-specific text-to-image generation using Flux.
66 | - [ ] Workflow for texture synthesis.
67 |
68 |
69 |
81 |
82 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | from .comfyui_nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
2 |
3 | __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
4 |
--------------------------------------------------------------------------------
/assets/style_t2i.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zichongc/ComfyUI-Attention-Distillation/4c00ca9b2604c9e27dc1427ec49f9b108a1e71d0/assets/style_t2i.jpg
--------------------------------------------------------------------------------
/assets/style_t2i_sdxl.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zichongc/ComfyUI-Attention-Distillation/4c00ca9b2604c9e27dc1427ec49f9b108a1e71d0/assets/style_t2i_sdxl.jpg
--------------------------------------------------------------------------------
/assets/style_transfer.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zichongc/ComfyUI-Attention-Distillation/4c00ca9b2604c9e27dc1427ec49f9b108a1e71d0/assets/style_transfer.jpg
--------------------------------------------------------------------------------
/comfyui_nodes.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from PIL import Image
4 | from diffusers import DDIMScheduler
5 |
6 | from comfy.comfy_types import IO
7 | import comfy.model_management as mm
8 | import node_helpers
9 | import folder_paths
10 | from huggingface_hub import hf_hub_download
11 | from tqdm import tqdm
12 |
13 | from torchvision.transforms.functional import resize, to_tensor
14 | from accelerate.utils import set_seed
15 | from .pipeline_sd import ADPipeline
16 | from .pipeline_sdxl import ADPipeline as ADXLPipeline
17 | from .pipeline_flux import ADPipeline as ADFluxPipeline
18 | from .utils import Controller
19 | from .utils import sd15_file_names, sdxl_file_names, flux_file_names
20 |
21 |
22 | class PureText:
23 | @classmethod
24 | def INPUT_TYPES(s):
25 | return {
26 | "required": {
27 | "text": (IO.STRING, {"multiline": True, "dynamicPrompts": True, "tooltip": "The text to be encoded."}),
28 | }
29 | }
30 |
31 | RETURN_TYPES = (IO.CONDITIONING,)
32 | FUNCTION = "get_prompt"
33 | CATEGORY = "AttentionDistillationWrapper"
34 |
35 | def get_prompt(self, text):
36 | return (text,)
37 |
38 |
39 | class LoadPILImage:
40 | @classmethod
41 | def INPUT_TYPES(s):
42 | input_dir = folder_paths.get_input_directory()
43 | files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
44 | return {"required":
45 | {"image": (sorted(files), {"image_upload": True})},
46 | }
47 |
48 | CATEGORY = "AttentionDistillationWrapper"
49 |
50 | RETURN_TYPES = ("IMAGE",)
51 | RETURN_NAMES = ("image",)
52 | FUNCTION = "load_image"
53 |
54 | def load_image(self, image):
55 | image_path = folder_paths.get_annotated_filepath(image)
56 | img = node_helpers.pillow(Image.open, image_path).convert('RGB')
57 | return (img,)
58 |
59 |
60 | class ResizeImage:
61 | RETURN_TYPES = ("IMAGE",)
62 | RETURN_NAMES = ("image",)
63 | FUNCTION = "resize_image"
64 |
65 | CATEGORY = "AttentionDistillationWrapper"
66 |
67 | @classmethod
68 | def INPUT_TYPES(s):
69 | return {
70 | "required": {
71 | "image": ("IMAGE",),
72 | "resolution": ("INT", {"default": 512, "min": 256, "max": 4096, "step": 8}),
73 | },
74 | }
75 |
76 | def resize_image(self, image, resolution):
77 | if isinstance(image, torch.Tensor):
78 | assert image.ndim == 4
79 | if (image.shape[1] != 3 and image.shape[-1] == 3):
80 | image = image.permute(0, 3, 1, 2)
81 | image = resize(image, size=resolution)
82 | return (image,)
83 |
84 |
85 | class LoadDistiller:
86 | RETURN_TYPES = ("DISTILLER",)
87 | RETURN_NAMES = ("distiller",)
88 | FUNCTION = "load_model"
89 | CATEGORY = "AttentionDistillationWrapper"
90 |
91 | @classmethod
92 | def INPUT_TYPES(s):
93 | return {
94 | 'required': {
95 | "model": (['stable-diffusion-v1-5', 'stable-diffusion-xl-base-1.0', 'FLUX.1-dev'], {"default": "stable-diffusion-v1-5"}),
96 | "precision": (['bf16', 'fp32'], {"default": 'bf16'}),
97 | },
98 | }
99 |
100 | @torch.inference_mode(False)
101 | def load_model(self, model, precision):
102 | weight_dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
103 | if precision == 'fp32':
104 | precision = 'no'
105 | device = mm.get_torch_device()
106 |
107 | model_name = os.path.join(folder_paths.models_dir, 'diffusers', model)
108 | model_class = {
109 | "stable-diffusion-v1-5": ADPipeline,
110 | "stable-diffusion-xl-base-1.0": ADXLPipeline,
111 | "FLUX.1-dev": ADFluxPipeline,
112 | }[model]
113 |
114 | if not os.path.exists(model_name):
115 | print(f"Please download target model to : {model_name}")
116 |
117 | try:
118 | if model == "FLUX.1-dev":
119 | distiller = model_class.from_pretrained(
120 | model_name, safety_checker=None, torch_dtype=weight_dtype
121 | ).to(device)
122 | else:
123 | scheduler = DDIMScheduler.from_pretrained(model_name, subfolder='scheduler')
124 | distiller = model_class.from_pretrained(
125 | model_name, scheduler=scheduler, safety_checker=None, torch_dtype=weight_dtype
126 | ).to(device)
127 | except Exception:
128 | print('Downloading models...')
129 |
130 | repo_name = {
131 | "stable-diffusion-v1-5": "stable-diffusion-v1-5/stable-diffusion-v1-5",
132 | "stable-diffusion-xl-base-1.0": "stabilityai/stable-diffusion-xl-base-1.0",
133 | "FLUX.1-dev": "black-forest-labs/FLUX.1-dev",
134 | }[model]
135 |
136 | file_names = {
137 | "stable-diffusion-v1-5": sd15_file_names,
138 | "stable-diffusion-xl-base-1.0": sdxl_file_names,
139 | "FLUX.1-dev": flux_file_names,
140 | }[model]
141 |
142 | pbar = tqdm(file_names)
143 | for file_name in pbar:
144 | pbar.set_description(f'Downloading {file_name}')
145 | if not os.path.exists(os.path.join(model_name, file_name)):
146 | hf_hub_download(repo_id=repo_name, filename=file_name, local_dir=model_name)
147 | pbar.update()
148 |
149 |
150 | if model == "FLUX.1-dev":
151 | distiller = model_class.from_pretrained(
152 | model_name, safety_checker=None, torch_dtype=weight_dtype
153 | ).to(device)
154 | else:
155 | scheduler = DDIMScheduler.from_pretrained(model_name, subfolder='scheduler')
156 | distiller = model_class.from_pretrained(
157 | model_name, scheduler=scheduler, safety_checker=None, torch_dtype=weight_dtype
158 | ).to(device)
159 |
160 | if hasattr(distiller, 'unet'):
161 | distiller.classifier = distiller.unet
162 | elif hasattr(distiller, 'transformer'):
163 | distiller.classifier = distiller.transformer
164 | else:
165 | raise ValueError("Failed to initialize the classifier.")
166 |
167 | return ({"distiller": distiller, "precision": precision, 'weight_dtype': weight_dtype},)
168 |
169 |
170 | class ADOptimizer:
171 | @classmethod
172 | def INPUT_TYPES(s):
173 | return {
174 | "required": {
175 | "distiller": ("DISTILLER",),
176 | "content": ("IMAGE",),
177 | "style": ("IMAGE",),
178 | "steps": ("INT", {"default": 200, "min": 1, "max": 500, "step": 1}),
179 | "content_weight": ("FLOAT", {"default": 0.25, "min": 0., "max": 10., "step": 0.001}),
180 | "lr": ("FLOAT", {"default": 0.05, "min": 0.001, "max": 0.5, "step": 0.001}),
181 | "height": ("INT", {"default": 512, "min": 256, "max": 4096, "step": 8}),
182 | "width": ("INT", {"default": 512, "min": 256, "max": 4096, "step": 8}),
183 | "seed": ("INT", {"default": 2025, "min": 0, "max": 0xffffffffffffffff, "step": 1}),
184 | }
185 | }
186 | RETURN_TYPES = ("IMAGE",)
187 | RETURN_NAMES = ("image",)
188 | FUNCTION = "process"
189 | CATEGORY = "AttentionDistillationWrapper"
190 |
191 | @torch.inference_mode(False)
192 | def process(self, distiller, content, style, steps, content_weight, lr, height, width, seed):
193 | precision = distiller['precision']
194 | attn_distiller = distiller['distiller']
195 |
196 | assert isinstance(attn_distiller, ADPipeline), "Style transfer currently supports SD1.5 only."
197 | assert isinstance(style, Image.Image) and isinstance(content, Image.Image), "Please use the image loader `AttentionDistillationWrapper -> Load PIL Image` to load images."
198 |
199 | if isinstance(style, torch.Tensor) and style.ndim == 3:
200 | style = resize(style.unsqueeze(0), (512, 512))
201 | elif isinstance(style, Image.Image):
202 | style = to_tensor(resize(style, (512, 512))).unsqueeze(0)
203 |
204 | if isinstance(content, torch.Tensor) and content.ndim == 3:
205 | content = content.unsqueeze(0)
206 | elif isinstance(content, Image.Image):
207 | content = to_tensor(content).unsqueeze(0)
208 |
209 | assert isinstance(style, torch.Tensor) and style.ndim == 4
210 | assert isinstance(content, torch.Tensor) and content.ndim == 4
211 |
212 | if (style.shape[1] != 3 and style.shape[-1] == 3):
213 | style = style.permute(0, 3, 1, 2)
214 | if (content.shape[1] != 3 and content.shape[-1] == 3):
215 | content = content.permute(0, 3, 1, 2)
216 |
217 | print(content.shape)
218 | controller = Controller(self_layers=(10, 16))
219 | set_seed(seed)
220 |
221 | print('style', style.min(), style.max())
222 | print('content', content.min(), content.max())
223 |
224 | images = attn_distiller.optimize(
225 | lr=lr,
226 | batch_size=1,
227 | iters=1,
228 | width=width,
229 | height=height,
230 | weight=content_weight,
231 | controller=controller,
232 | style_image=style,
233 | content_image=content,
234 | mixed_precision=precision,
235 | num_inference_steps=steps,
236 | enable_gradient_checkpoint=False,
237 | )
238 | images = images.permute(0, 2, 3, 1).float()
239 | return (images,)
240 |
241 |
242 | class ADSampler:
243 | @classmethod
244 | def INPUT_TYPES(s):
245 | return {
246 | "required": {
247 | "distiller": ("DISTILLER",),
248 | "style": ("IMAGE",),
249 | "positive": (IO.CONDITIONING,),
250 | "negative": (IO.CONDITIONING,),
251 | "steps": ("INT", {"default": 50, "min": 1, "max": 200, "step": 1}),
252 | "lr": ("FLOAT", {"default": 0.015, "min": 0.001, "max": 1., "step": 0.001}),
253 | "iters": ("INT", {"default": 2, "min": 0, "max": 5, "step": 1}),
254 | "cfg": ("FLOAT", {"default": 7.5, "min": 1., "max": 20., "step": 0.01}),
255 | "num_images_per_prompt": ("INT", {"default": 1, "min": 1, "max": 5, "step": 1}),
256 | "seed": ("INT", {"default": 2025, "min": 0, "max": 0xffffffffffffffff}),
257 | "height": ("INT", {"default": 512, "min": 256, "max": 4096, "step": 8}),
258 | "width": ("INT", {"default": 512, "min": 256, "max": 4096, "step": 8}),
259 | }
260 | }
261 | RETURN_TYPES = ("IMAGE",)
262 | RETURN_NAMES = ("images",)
263 | FUNCTION = "process"
264 | CATEGORY = "AttentionDistillationWrapper"
265 |
266 | DEFAULT_CONFIGS = {
267 | ADPipeline: {'self_layers': (10, 16), 'resolution': (512, 512), 'enable_gradient_checkpoint': False},
268 | ADXLPipeline: {'self_layers': (64, 70), 'resolution': (1024, 1024), 'enable_gradient_checkpoint': True},
269 | ADFluxPipeline: {'self_layers': (50, 57), 'resolution': (512, 512), 'enable_gradient_checkpoint': True},
270 | }
271 |
272 | @torch.inference_mode(False)
273 | def process(self, distiller, style, positive, negative, steps, lr, iters, cfg, num_images_per_prompt, seed, height, width):
274 | precision = distiller['precision']
275 | attn_distiller = distiller['distiller']
276 |
277 | assert isinstance(style, Image.Image), "Please use the image loader `AttentionDistillationWrapper -> Load PIL Image` to load images."
278 |
279 | default_config = self.DEFAULT_CONFIGS[type(attn_distiller)]
280 | print(default_config)
281 |
282 | controller = Controller(self_layers=default_config['self_layers'])
283 |
284 | if isinstance(style, torch.Tensor) and style.ndim == 3:
285 | style = resize(style.unsqueeze(0), default_config['resolution'])
286 | elif isinstance(style, Image.Image):
287 | style = to_tensor(resize(style, default_config['resolution'])).unsqueeze(0)
288 |
289 | assert isinstance(style, torch.Tensor) and style.ndim == 4
290 |
291 | if (style.shape[1] != 3 and style.shape[-1] == 3):
292 | style = style.permute(0, 3, 1, 2)
293 |
294 | print('style', style.min(), style.max(), style.mean())
295 | set_seed(seed)
296 | images = attn_distiller.sample(
297 | controller=controller,
298 | iters=iters,
299 | lr=lr,
300 | adain=True,
301 | height=height,
302 | width=width,
303 | mixed_precision=precision,
304 | style_image=style,
305 | prompt=positive,
306 | negative_prompt=negative,
307 | guidance_scale=cfg,
308 | num_inference_steps=steps,
309 | num_images_per_prompt=num_images_per_prompt,
310 | enable_gradient_checkpoint=default_config['enable_gradient_checkpoint']
311 | )
312 | images = images.permute(0, 2, 3, 1).float()
313 | return (images,)
314 |
315 |
316 | NODE_CLASS_MAPPINGS = {
317 | "LoadDistiller": LoadDistiller,
318 | "ADOptimizer": ADOptimizer,
319 | "ADSampler": ADSampler,
320 | "LoadPILImage": LoadPILImage,
321 | "PureText": PureText,
322 | "ResizeImage": ResizeImage,
323 | }
324 |
325 | NODE_DISPLAY_NAME_MAPPINGS = {
326 | "LoadDistiller": "Load Distiller",
327 | "ADHandler": "Handler for Attention Distillation",
328 | "ADOptimizer": "Optimization-Based Style Transfer",
329 | "ADSampler": "Sampler for Style-Specific Text-to-Image",
330 | "LoadPILImage": "Load PIL Image",
331 | "PureText": "Text Prompt",
332 | "ResizeImage": "Resize Image",
333 | }
334 |
--------------------------------------------------------------------------------
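For readers who want the whole SD1.5 style-transfer path in one place, below is a condensed sketch of what `LoadDistiller.load_model` and `ADOptimizer.process` above do, with the ComfyUI plumbing stripped away. The checkpoint path and image file names are placeholders, and the package-relative imports only resolve inside this custom-node package (the context in which ComfyUI loads `comfyui_nodes.py`), so treat it as a reading aid rather than a drop-in script.

```python
# Condensed sketch of the SD1.5 style-transfer path taken by LoadDistiller + ADOptimizer.
# Placeholder paths/filenames; the relative imports resolve only inside this package.
import torch
from PIL import Image
from diffusers import DDIMScheduler
from torchvision.transforms.functional import resize, to_tensor

from .pipeline_sd import ADPipeline
from .utils import Controller

model_dir = "ComfyUI/models/diffusers/stable-diffusion-v1-5"  # placeholder path
scheduler = DDIMScheduler.from_pretrained(model_dir, subfolder="scheduler")
pipe = ADPipeline.from_pretrained(
    model_dir, scheduler=scheduler, safety_checker=None, torch_dtype=torch.bfloat16
).to("cuda")
pipe.classifier = pipe.unet  # the UNet doubles as the attention-feature extractor

# Style is resized to 512x512; content keeps its own resolution, as in ADOptimizer.
style = to_tensor(resize(Image.open("style.jpg").convert("RGB"), (512, 512))).unsqueeze(0)
content = to_tensor(Image.open("content.png").convert("RGB")).unsqueeze(0)

images = pipe.optimize(
    lr=0.05,
    batch_size=1,
    iters=1,
    width=512,
    height=512,
    weight=0.25,                                  # content-preservation weight
    controller=Controller(self_layers=(10, 16)),  # same self-attention layers as the node
    style_image=style,
    content_image=content,
    mixed_precision="bf16",
    num_inference_steps=200,
)
# `images` is a (1, 3, H, W) tensor in [0, 1]; ADOptimizer permutes it to NHWC for ComfyUI.
```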
/examples/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zichongc/ComfyUI-Attention-Distillation/4c00ca9b2604c9e27dc1427ec49f9b108a1e71d0/examples/1.png
--------------------------------------------------------------------------------
/examples/26.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zichongc/ComfyUI-Attention-Distillation/4c00ca9b2604c9e27dc1427ec49f9b108a1e71d0/examples/26.jpg
--------------------------------------------------------------------------------
/examples/40.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zichongc/ComfyUI-Attention-Distillation/4c00ca9b2604c9e27dc1427ec49f9b108a1e71d0/examples/40.jpg
--------------------------------------------------------------------------------
/examples/lecun.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zichongc/ComfyUI-Attention-Distillation/4c00ca9b2604c9e27dc1427ec49f9b108a1e71d0/examples/lecun.png
--------------------------------------------------------------------------------
/losses.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn.functional as F
6 |
7 | loss_fn = torch.nn.L1Loss()
8 |
9 |
10 | def ad_loss(
11 | q_list, ks_list, vs_list, self_out_list, scale=1, source_mask=None, target_mask=None
12 | ):
13 | loss = 0
14 | attn_mask = None
15 | for q, ks, vs, self_out in zip(q_list, ks_list, vs_list, self_out_list):
16 | if source_mask is not None and target_mask is not None:
17 | w = h = int(np.sqrt(q.shape[2]))
18 | mask_1 = torch.flatten(F.interpolate(source_mask, size=(h, w)))
19 | mask_2 = torch.flatten(F.interpolate(target_mask, size=(h, w)))
20 | attn_mask = mask_1.unsqueeze(0) == mask_2.unsqueeze(1)
21 | attn_mask = attn_mask.to(q.device)
22 |
23 | target_out = F.scaled_dot_product_attention(
24 | q * scale,
25 | torch.cat(torch.chunk(ks, ks.shape[0]), 2).repeat(q.shape[0], 1, 1, 1),
26 | torch.cat(torch.chunk(vs, vs.shape[0]), 2).repeat(q.shape[0], 1, 1, 1),
27 | attn_mask=attn_mask
28 | )
29 | loss += loss_fn(self_out, target_out.detach())
30 | return loss
31 |
32 |
33 |
34 | def q_loss(q_list, qc_list):
35 | loss = 0
36 | for q, qc in zip(q_list, qc_list):
37 | loss += loss_fn(q, qc.detach())
38 | return loss
39 |
40 | # weight = 200
41 | def qk_loss(q_list, k_list, qc_list, kc_list):
42 | loss = 0
43 | for q, k, qc, kc in zip(q_list, k_list, qc_list, kc_list):
44 | scale_factor = 1 / math.sqrt(q.size(-1))
45 | self_map = torch.softmax(q @ k.transpose(-2, -1) * scale_factor, dim=-1)
46 | target_map = torch.softmax(qc @ kc.transpose(-2, -1) * scale_factor, dim=-1)
47 | loss += loss_fn(self_map, target_map.detach())
48 | return loss
49 |
50 | # weight = 1
51 | def qkv_loss(q_list, k_list, vc_list, c_out_list):
52 | loss = 0
53 | for q, k, vc, target_out in zip(q_list, k_list, vc_list, c_out_list):
54 | self_out = F.scaled_dot_product_attention(q, k, vc)
55 | loss += loss_fn(self_out, target_out.detach())
56 | return loss
57 |
--------------------------------------------------------------------------------
/pipeline_flux.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | from typing import Any, Callable, Dict, List, Optional, Union
3 | from tqdm import tqdm
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 | from .utils import DataCache, register_attn_control_flux, adain_flux
8 | from accelerate import Accelerator
9 | from diffusers import FluxPipeline
10 |
11 |
12 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
13 | def retrieve_latents(
14 | encoder_output: torch.Tensor,
15 | generator: Optional[torch.Generator] = None,
16 | sample_mode: str = "sample",
17 | ):
18 | if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
19 | return encoder_output.latent_dist.sample(generator)
20 | elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
21 | return encoder_output.latent_dist.mode()
22 | elif hasattr(encoder_output, "latents"):
23 | return encoder_output.latents
24 | else:
25 | raise AttributeError("Could not access latents of provided encoder_output")
26 |
27 |
28 | class ADPipeline(FluxPipeline):
29 | def freeze(self):
30 | self.transformer.requires_grad_(False)
31 | self.text_encoder.requires_grad_(False)
32 | self.text_encoder_2.requires_grad_(False)
33 | self.vae.requires_grad_(False)
34 |
35 | @torch.no_grad()
36 | def image2latent(self, image):
37 | dtype = next(self.vae.parameters()).dtype
38 | device = self._execution_device
39 | image = image.to(device=device, dtype=dtype) * 2.0 - 1.0
40 | latent = retrieve_latents(self.vae.encode(image))
41 | latent = (
42 | latent - self.vae.config.shift_factor
43 | ) * self.vae.config.scaling_factor
44 | return latent
45 |
46 | @torch.no_grad()
47 | def latent2image(self, latent, height, width):
48 | dtype = next(self.vae.parameters()).dtype
49 | device = self._execution_device
50 | latent = latent.to(device=device, dtype=dtype)
51 | latents = self._unpack_latents(latent, height, width, self.vae_scale_factor)
52 | latents = (
53 | latents / self.vae.config.scaling_factor
54 | ) + self.vae.config.shift_factor
55 | image = self.vae.decode(latents, return_dict=False)[0]
56 | return (image * 0.5 + 0.5).clamp(0, 1)
57 |
58 | def init(self, enable_gradient_checkpoint):
59 | self.freeze()
60 | self.enable_vae_slicing()
61 | # self.enable_model_cpu_offload()
62 | # self.enable_vae_tiling()
63 | weight_dtype = torch.float32
64 | if self.accelerator.mixed_precision == "fp16":
65 | weight_dtype = torch.float16
66 | elif self.accelerator.mixed_precision == "bf16":
67 | weight_dtype = torch.bfloat16
68 |
69 | # Move transformer, vae and text_encoder to device and cast to weight_dtype
70 | self.transformer.to(self.accelerator.device, dtype=weight_dtype)
71 | self.vae.to(self.accelerator.device, dtype=weight_dtype)
72 | self.text_encoder.to(self.accelerator.device, dtype=weight_dtype)
73 | self.classifier.to(self.accelerator.device, dtype=weight_dtype)
74 | self.classifier = self.accelerator.prepare(self.classifier)
75 | if enable_gradient_checkpoint:
76 | self.classifier.enable_gradient_checkpointing()
77 |
78 | def sample(
79 | self,
80 | style_image=None,
81 | controller=None,
82 | loss_fn=torch.nn.L1Loss(),
83 | start_time=9999,
84 | lr=0.05,
85 | iters=2,
86 | adain=True,
87 | mixed_precision="no",
88 | enable_gradient_checkpoint=False,
89 | prompt: Union[str, List[str]] = None,
90 | prompt_2: Optional[Union[str, List[str]]] = None,
91 | height: Optional[int] = None,
92 | width: Optional[int] = None,
93 | num_inference_steps: int = 28,
94 | # timesteps: List[int] = None,
95 | guidance_scale: float = 3.5,
96 | num_images_per_prompt: Optional[int] = 1,
97 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
98 | latents: Optional[torch.FloatTensor] = None,
99 | prompt_embeds: Optional[torch.FloatTensor] = None,
100 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
101 | output_type: Optional[str] = "pil",
102 | return_dict: bool = True,
103 | joint_attention_kwargs: Optional[Dict[str, Any]] = None,
104 | callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
105 | callback_on_step_end_tensor_inputs: List[str] = ["latents"],
106 | max_sequence_length: int = 512,
107 | **kwargs
108 | ):
109 | height = height or self.default_sample_size * self.vae_scale_factor
110 | width = width or self.default_sample_size * self.vae_scale_factor
111 | device = self._execution_device
112 | self.accelerator = Accelerator(
113 | mixed_precision=mixed_precision, gradient_accumulation_steps=1
114 | )
115 |
116 | self.init(enable_gradient_checkpoint)
117 |
118 | (null_embeds, null_pooled_embeds, null_text_ids) = self.encode_prompt(
119 | prompt="",
120 | prompt_2=prompt_2,
121 | )
122 | (
123 | prompt_embeds,
124 | pooled_prompt_embeds,
125 | text_ids,
126 | ) = self.encode_prompt(
127 | prompt=prompt,
128 | prompt_2=prompt_2,
129 | prompt_embeds=prompt_embeds,
130 | pooled_prompt_embeds=pooled_prompt_embeds,
131 | device=device,
132 | num_images_per_prompt=num_images_per_prompt,
133 | max_sequence_length=max_sequence_length,
134 | )
135 | # 4. Prepare latent variables
136 | num_channels_latents = self.transformer.config.in_channels // 4
137 | latents, latent_image_ids = self.prepare_latents(
138 | num_images_per_prompt,
139 | num_channels_latents,
140 | height,
141 | width,
142 | null_embeds.dtype,
143 | device,
144 | generator,
145 | latents,
146 | )
147 |
148 | # print(style_image.shape)
149 | height_, width_ = style_image.shape[2], style_image.shape[3]
150 | style_latent = self.image2latent(style_image)
151 | # print(style_latent.shape)
152 | # print(latents.shape)
153 | style_latent = self._pack_latents(style_latent, 1, num_channels_latents, style_latent.shape[2], style_latent.shape[3])
154 |
155 | _, null_image_id = self.prepare_latents(
156 | num_images_per_prompt,
157 | num_channels_latents,
158 | height_,
159 | width_,
160 | null_embeds.dtype,
161 | device,
162 | generator,
163 | style_latent,
164 | )
165 |
166 | # 5. Prepare timesteps
167 | sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
168 | image_seq_len = latents.shape[1]
169 | mu = calculate_shift(
170 | image_seq_len,
171 | self.scheduler.config.base_image_seq_len,
172 | self.scheduler.config.max_image_seq_len,
173 | self.scheduler.config.base_shift,
174 | self.scheduler.config.max_shift,
175 | )
176 | timesteps, num_inference_steps = retrieve_timesteps(
177 | self.scheduler,
178 | num_inference_steps,
179 | device,
180 | None,
181 | sigmas,
182 | mu=mu,
183 | )
184 |
185 | timesteps = self.scheduler.timesteps
186 | # print(f"timesteps: {timesteps}")
187 | self._num_timesteps = len(timesteps)
188 |
189 | cache = DataCache()
190 |
191 | register_attn_control_flux(
192 | self.classifier.transformer_blocks,
193 | controller=controller,
194 | cache=cache,
195 | )
196 | register_attn_control_flux(
197 | self.classifier.single_transformer_blocks,
198 | controller=controller,
199 | cache=cache,
200 | )
201 | # handle guidance
202 | if self.transformer.config.guidance_embeds:
203 | guidance = torch.full(
204 | [1], guidance_scale, device=device, dtype=torch.float32
205 | )
206 | guidance = guidance.expand(latents.shape[0])
207 | else:
208 | guidance = None
209 |
210 | null_guidance = torch.full(
211 | [1], 1, device=device, dtype=torch.float32
212 | )
213 |
214 | # print(controller.num_self_layers)
215 |
216 |
217 | pbar = tqdm(timesteps, desc="Sample")
218 | for i, t in enumerate(pbar):
219 | timestep = t.expand(latents.shape[0]).to(latents.dtype)
220 | with torch.no_grad():
221 | noise_pred = self.transformer(
222 | hidden_states=latents,
223 | timestep=timestep / 1000,
224 | guidance=guidance,
225 | pooled_projections=pooled_prompt_embeds,
226 | encoder_hidden_states=prompt_embeds,
227 | txt_ids=text_ids,
228 | img_ids=latent_image_ids,
229 | joint_attention_kwargs=None,
230 | return_dict=False,
231 | )[0]
232 |
233 | # compute the previous noisy sample x_t -> x_t-1
234 | latents = self.scheduler.step(
235 | noise_pred, t, latents, return_dict=False
236 | )[0]
237 | if t < start_time:
238 | if i < num_inference_steps - 1:
239 | timestep = timesteps[i+1:i+2]
240 | # print(timestep)
241 | noise = torch.randn_like(style_latent)
242 | # print(style_latent.shape)
243 | style_latent_ = self.scheduler.scale_noise(style_latent, timestep, noise)
244 | else:
245 | timestep = torch.tensor([0], device=style_latent.device)
246 | style_latent_ = style_latent
247 |
248 | cache.clear()
249 | controller.step()
250 |
251 | _ = self.transformer(
252 | hidden_states=style_latent_,
253 | timestep=timestep / 1000,
254 | guidance=null_guidance,
255 | pooled_projections=null_pooled_embeds,
256 | encoder_hidden_states=null_embeds,
257 | txt_ids=null_text_ids,
258 | img_ids=null_image_id,
259 | joint_attention_kwargs=None,
260 | return_dict=False,
261 | )[0]
262 | _, ref_k_list, ref_v_list, _ = cache.get()
263 |
264 | if adain:
265 | latents = adain_flux(latents, style_latent_)
266 |
267 | latents = latents.detach()
268 | optimizer = torch.optim.Adam([latents.requires_grad_()], lr=lr)
269 | optimizer = self.accelerator.prepare(optimizer)
270 |
271 | for _ in range(iters):
272 | cache.clear()
273 | controller.step()
274 | optimizer.zero_grad()
275 | _ = self.classifier(
276 | hidden_states=latents,
277 | timestep=timestep / 1000,
278 | guidance=null_guidance,
279 | pooled_projections=null_pooled_embeds,
280 | encoder_hidden_states=null_embeds,
281 | txt_ids=null_text_ids,
282 | img_ids=latent_image_ids,
283 | joint_attention_kwargs=None,
284 | return_dict=False,
285 | )[0]
286 | q_list, _, _, self_out_list = cache.get()
287 | ref_self_out_list = [
288 | F.scaled_dot_product_attention(
289 | q,
290 | ref_k,
291 | ref_v,
292 | )
293 | for q, ref_k, ref_v in zip(q_list, ref_k_list, ref_v_list)
294 | ]
295 | style_loss = sum(
296 | [
297 | loss_fn(self_out, ref_self_out.detach())
298 | for self_out, ref_self_out in zip(
299 | self_out_list, ref_self_out_list
300 | )
301 | ]
302 | )
303 | loss = style_loss
304 | self.accelerator.backward(loss)
305 | # loss.backward()
306 | optimizer.step()
307 |
308 | pbar.set_postfix(loss=loss.item(), time=t.item())
309 | torch.cuda.empty_cache()
310 | latents = latents.detach()
311 | return self.latent2image(latents, height, width)
312 |
313 |
314 | def calculate_shift(
315 | image_seq_len,
316 | base_seq_len: int = 256,
317 | max_seq_len: int = 4096,
318 | base_shift: float = 0.5,
319 | max_shift: float = 1.16,
320 | ):
321 | m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
322 | b = base_shift - m * base_seq_len
323 | mu = image_seq_len * m + b
324 | return mu
325 |
326 |
327 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
328 | def retrieve_timesteps(
329 | scheduler,
330 | num_inference_steps: Optional[int] = None,
331 | device: Optional[Union[str, torch.device]] = None,
332 | timesteps: Optional[List[int]] = None,
333 | sigmas: Optional[List[float]] = None,
334 | **kwargs,
335 | ):
336 | r"""
337 | Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
338 | custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
339 |
340 | Args:
341 | scheduler (`SchedulerMixin`):
342 | The scheduler to get timesteps from.
343 | num_inference_steps (`int`):
344 | The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
345 | must be `None`.
346 | device (`str` or `torch.device`, *optional*):
347 | The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
348 | timesteps (`List[int]`, *optional*):
349 | Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
350 | `num_inference_steps` and `sigmas` must be `None`.
351 | sigmas (`List[float]`, *optional*):
352 | Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
353 | `num_inference_steps` and `timesteps` must be `None`.
354 |
355 | Returns:
356 | `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
357 | second element is the number of inference steps.
358 | """
359 | if timesteps is not None and sigmas is not None:
360 | raise ValueError(
361 | "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
362 | )
363 | if timesteps is not None:
364 | accepts_timesteps = "timesteps" in set(
365 | inspect.signature(scheduler.set_timesteps).parameters.keys()
366 | )
367 | if not accepts_timesteps:
368 | raise ValueError(
369 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
370 | f" timestep schedules. Please check whether you are using the correct scheduler."
371 | )
372 | scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
373 | timesteps = scheduler.timesteps
374 | num_inference_steps = len(timesteps)
375 | elif sigmas is not None:
376 | accept_sigmas = "sigmas" in set(
377 | inspect.signature(scheduler.set_timesteps).parameters.keys()
378 | )
379 | if not accept_sigmas:
380 | raise ValueError(
381 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
382 | f" sigmas schedules. Please check whether you are using the correct scheduler."
383 | )
384 | scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
385 | timesteps = scheduler.timesteps
386 | num_inference_steps = len(timesteps)
387 | else:
388 | scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
389 | timesteps = scheduler.timesteps
390 | return timesteps, num_inference_steps
391 |
--------------------------------------------------------------------------------
/pipeline_sd.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import math
3 | from typing import Any, Dict, List, Optional, Tuple, Union
4 |
5 | import torch
6 | import torch.nn.functional as F
7 | from accelerate import Accelerator
8 | from diffusers import StableDiffusionPipeline
9 | from diffusers.image_processor import PipelineImageInput
10 | from .losses import ad_loss, q_loss
11 | from .utils import DataCache, register_attn_control, adain
12 | from tqdm import tqdm
13 |
14 |
15 | class ADPipeline(StableDiffusionPipeline):
16 | def freeze(self):
17 | self.vae.requires_grad_(False)
18 | self.unet.requires_grad_(False)
19 | self.text_encoder.requires_grad_(False)
20 | self.classifier.requires_grad_(False)
21 |
22 | @torch.no_grad()
23 | def image2latent(self, image):
24 | dtype = next(self.vae.parameters()).dtype
25 | device = self._execution_device
26 | image = image.to(device=device, dtype=dtype) * 2.0 - 1.0
27 | latent = self.vae.encode(image)["latent_dist"].mean
28 | latent = latent * self.vae.config.scaling_factor
29 | return latent
30 |
31 | @torch.no_grad()
32 | def latent2image(self, latent):
33 | dtype = next(self.vae.parameters()).dtype
34 | device = self._execution_device
35 | latent = latent.to(device=device, dtype=dtype)
36 | latent = latent / self.vae.config.scaling_factor
37 | image = self.vae.decode(latent)[0]
38 | return (image * 0.5 + 0.5).clamp(0, 1)
39 |
40 | def init(self, enable_gradient_checkpoint):
41 | self.freeze()
42 | self.enable_vae_slicing()
43 | # self.enable_model_cpu_offload()
44 | # self.enable_vae_tiling()
45 | weight_dtype = torch.float32
46 | if self.accelerator.mixed_precision == "fp16":
47 | weight_dtype = torch.float16
48 | elif self.accelerator.mixed_precision == "bf16":
49 | weight_dtype = torch.bfloat16
50 |
51 | # Move unet, vae and text_encoder to device and cast to weight_dtype
52 | self.unet.to(self.accelerator.device, dtype=weight_dtype)
53 | self.vae.to(self.accelerator.device, dtype=weight_dtype)
54 | self.text_encoder.to(self.accelerator.device, dtype=weight_dtype)
55 | self.classifier.to(self.accelerator.device, dtype=weight_dtype)
56 | self.classifier = self.accelerator.prepare(self.classifier)
57 | if enable_gradient_checkpoint:
58 | self.classifier.enable_gradient_checkpointing()
59 |
60 | def sample(
61 | self,
62 | lr=0.05,
63 | iters=1,
64 | attn_scale=1,
65 | adain=False,
66 | weight=0.25,
67 | controller=None,
68 | style_image=None,
69 | content_image=None,
70 | mixed_precision="no",
71 | start_time=999,
72 | enable_gradient_checkpoint=False,
73 | prompt: Union[str, List[str]] = None,
74 | height: Optional[int] = None,
75 | width: Optional[int] = None,
76 | num_inference_steps: int = 50,
77 | guidance_scale: float = 7.5,
78 | negative_prompt: Optional[Union[str, List[str]]] = None,
79 | num_images_per_prompt: Optional[int] = 1,
80 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
81 | latents: Optional[torch.Tensor] = None,
82 | prompt_embeds: Optional[torch.Tensor] = None,
83 | negative_prompt_embeds: Optional[torch.Tensor] = None,
84 | ip_adapter_image: Optional[PipelineImageInput] = None,
85 | ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
86 | cross_attention_kwargs: Optional[Dict[str, Any]] = None,
87 | guidance_rescale: float = 0.0,
88 | clip_skip: Optional[int] = None,
89 | **kwargs,
90 | ):
91 | # 0. Default height and width to unet
92 | height = height or self.unet.config.sample_size * self.vae_scale_factor
93 | width = width or self.unet.config.sample_size * self.vae_scale_factor
94 | self._guidance_scale = guidance_scale
95 | self._guidance_rescale = guidance_rescale
96 | self._clip_skip = clip_skip
97 | self._cross_attention_kwargs = cross_attention_kwargs
98 | self._interrupt = False
99 |
100 | self.accelerator = Accelerator(
101 | mixed_precision=mixed_precision, gradient_accumulation_steps=1
102 | )
103 | self.init(enable_gradient_checkpoint)
104 |
105 | # 2. Define call parameters
106 | if prompt is not None and isinstance(prompt, str):
107 | batch_size = 1
108 | elif prompt is not None and isinstance(prompt, list):
109 | batch_size = len(prompt)
110 | else:
111 | batch_size = prompt_embeds.shape[0]
112 |
113 | device = self._execution_device
114 |
115 | # 3. Encode input prompt
116 | lora_scale = (
117 | self.cross_attention_kwargs.get("scale", None)
118 | if self.cross_attention_kwargs is not None
119 | else None
120 | )
121 | do_cfg = guidance_scale > 1.0
122 |
123 | prompt_embeds, negative_prompt_embeds = self.encode_prompt(
124 | prompt,
125 | device,
126 | num_images_per_prompt,
127 | do_cfg,
128 | negative_prompt,
129 | prompt_embeds=prompt_embeds,
130 | negative_prompt_embeds=negative_prompt_embeds,
131 | lora_scale=lora_scale,
132 | clip_skip=self.clip_skip,
133 | )
134 |
135 | # For classifier free guidance, we need to do two forward passes.
136 | # Here we concatenate the unconditional and text embeddings into a single batch
137 | # to avoid doing two forward passes
138 | if do_cfg:
139 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
140 |
141 | if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
142 | image_embeds = self.prepare_ip_adapter_image_embeds(
143 | ip_adapter_image,
144 | ip_adapter_image_embeds,
145 | device,
146 | batch_size * num_images_per_prompt,
147 | do_cfg,
148 | )
149 |
150 | # 5. Prepare latent variables
151 | num_channels_latents = self.unet.config.in_channels
152 | latents = self.prepare_latents(
153 | batch_size * num_images_per_prompt,
154 | num_channels_latents,
155 | height,
156 | width,
157 | prompt_embeds.dtype,
158 | device,
159 | generator,
160 | latents,
161 | )
162 |
163 | # 6.1 Add image embeds for IP-Adapter
164 | added_cond_kwargs = (
165 | {"image_embeds": image_embeds}
166 | if (ip_adapter_image is not None or ip_adapter_image_embeds is not None)
167 | else None
168 | )
169 |
170 | # 6.2 Optionally get Guidance Scale Embedding
171 | timestep_cond = None
172 | if self.unet.config.time_cond_proj_dim is not None:
173 | guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(
174 | batch_size * num_images_per_prompt
175 | )
176 | timestep_cond = self.get_guidance_scale_embedding(
177 | guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
178 | ).to(device=device, dtype=latents.dtype)
179 |
180 | self.scheduler.set_timesteps(num_inference_steps)
181 | timesteps = self.scheduler.timesteps
182 | self.style_latent = self.image2latent(style_image)
183 | if content_image is not None:
184 | self.content_latent = self.image2latent(content_image)
185 | else:
186 | self.content_latent = None
187 | null_embeds = self.encode_prompt("", device, 1, False)[0]
188 | self.null_embeds = null_embeds
189 | self.null_embeds_for_latents = torch.cat([null_embeds] * latents.shape[0])
190 | self.null_embeds_for_style = torch.cat(
191 | [null_embeds] * self.style_latent.shape[0]
192 | )
193 |
194 | self.adain = adain
195 | self.attn_scale = attn_scale
196 | self.cache = DataCache()
197 | self.controller = controller
198 | register_attn_control(
199 | self.classifier, controller=self.controller, cache=self.cache
200 | )
201 | print("Total self attention layers of Unet: ", controller.num_self_layers)
202 | print("Self attention layers for AD: ", controller.self_layers)
203 |
204 | pbar = tqdm(timesteps, desc="Sample")
205 | for i, t in enumerate(pbar):
206 | with torch.no_grad():
207 | # expand the latents if we are doing classifier free guidance
208 | latent_model_input = torch.cat([latents] * 2) if do_cfg else latents
209 | latent_model_input = self.scheduler.scale_model_input(
210 | latent_model_input, t
211 | )
212 | # predict the noise residual
213 | noise_pred = self.unet(
214 | latent_model_input,
215 | t,
216 | encoder_hidden_states=prompt_embeds,
217 | timestep_cond=timestep_cond,
218 | cross_attention_kwargs=self.cross_attention_kwargs,
219 | added_cond_kwargs=added_cond_kwargs,
220 | return_dict=False,
221 | )[0]
222 |
223 | # perform guidance
224 | if do_cfg:
225 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
226 | noise_pred = noise_pred_uncond + self.guidance_scale * (
227 | noise_pred_text - noise_pred_uncond
228 | )
229 | latents = self.scheduler.step(
230 | noise_pred, t, latents, return_dict=False
231 | )[0]
232 | if iters > 0 and t < start_time:
233 | latents = self.AD(latents, t, lr, iters, pbar, weight)
234 |
235 | images = self.latent2image(latents)
236 | # Offload all models
237 | self.maybe_free_model_hooks()
238 | return images
239 |
240 | def optimize(
241 | self,
242 | latents=None,
243 | attn_scale=1.0,
244 | lr=0.05,
245 | iters=1,
246 | weight=0,
247 | width=512,
248 | height=512,
249 | batch_size=1,
250 | controller=None,
251 | style_image=None,
252 | content_image=None,
253 | mixed_precision="no",
254 | num_inference_steps=50,
255 | enable_gradient_checkpoint=False,
256 | source_mask=None,
257 | target_mask=None,
258 | ):
259 | height = height // self.vae_scale_factor
260 | width = width // self.vae_scale_factor
261 |
262 | self.accelerator = Accelerator(
263 | mixed_precision=mixed_precision, gradient_accumulation_steps=1
264 | )
265 | self.init(enable_gradient_checkpoint)
266 |
267 | style_latent = self.image2latent(style_image)
268 | latents = torch.randn((batch_size, 4, height, width), device=self.device)
269 | null_embeds = self.encode_prompt("", self.device, 1, False)[0]
270 | null_embeds_for_latents = null_embeds.repeat(latents.shape[0], 1, 1)
271 | null_embeds_for_style = null_embeds.repeat(style_latent.shape[0], 1, 1)
272 |
273 | if content_image is not None:
274 | content_latent = self.image2latent(content_image)
275 | latents = torch.cat([content_latent.clone()] * batch_size)
276 | null_embeds_for_content = null_embeds.repeat(content_latent.shape[0], 1, 1)
277 |
278 | self.cache = DataCache()
279 | self.controller = controller
280 | register_attn_control(
281 | self.classifier, controller=self.controller, cache=self.cache
282 | )
283 | print("Total self attention layers of Unet: ", controller.num_self_layers)
284 | print("Self attention layers for AD: ", controller.self_layers)
285 |
286 | self.scheduler.set_timesteps(num_inference_steps)
287 | timesteps = self.scheduler.timesteps
288 | latents = latents.detach().float()
289 | optimizer = torch.optim.Adam([latents.requires_grad_()], lr=lr)
290 | optimizer = self.accelerator.prepare(optimizer)
291 | pbar = tqdm(timesteps, desc="Optimize")
292 | for i, t in enumerate(pbar):
293 | # t = torch.tensor([1], device=self.device)
294 | with torch.no_grad():
295 | qs_list, ks_list, vs_list, s_out_list = self.extract_feature(
296 | style_latent,
297 | t,
298 | null_embeds_for_style,
299 | )
300 | if content_image is not None:
301 | qc_list, kc_list, vc_list, c_out_list = self.extract_feature(
302 | content_latent,
303 | t,
304 | null_embeds_for_content,
305 | )
306 | for j in range(iters):
307 | style_loss = 0
308 | content_loss = 0
309 | optimizer.zero_grad()
310 | q_list, k_list, v_list, self_out_list = self.extract_feature(
311 | latents,
312 | t,
313 | null_embeds_for_latents,
314 | )
315 | style_loss = ad_loss(q_list, ks_list, vs_list, self_out_list, scale=attn_scale, source_mask=source_mask, target_mask=target_mask)
316 | if content_image is not None:
317 | content_loss = q_loss(q_list, qc_list)
318 | # content_loss = qk_loss(q_list, k_list, qc_list, kc_list)
319 | # content_loss = qkv_loss(q_list, k_list, vc_list, c_out_list)
320 | loss = style_loss + content_loss * weight
321 | self.accelerator.backward(loss)
322 | optimizer.step()
323 | pbar.set_postfix(loss=loss.item(), time=t.item(), iter=j)
324 | images = self.latent2image(latents)
325 | # Offload all models
326 | self.maybe_free_model_hooks()
327 | return images
328 |
329 | def panorama(
330 | self,
331 | lr=0.05,
332 | iters=1,
333 | attn_scale=1,
334 | adain=False,
335 | controller=None,
336 | style_image=None,
337 | mixed_precision="no",
338 | enable_gradient_checkpoint=False,
339 | prompt: Union[str, List[str]] = None,
340 | height: Optional[int] = None,
341 | width: Optional[int] = None,
342 | num_inference_steps: int = 50,
343 | guidance_scale: float = 1,
344 | stride=8,
345 | view_batch_size: int = 16,
346 | negative_prompt: Optional[Union[str, List[str]]] = None,
347 | num_images_per_prompt: Optional[int] = 1,
348 | eta: float = 0.0,
349 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
350 | latents: Optional[torch.Tensor] = None,
351 | prompt_embeds: Optional[torch.Tensor] = None,
352 | negative_prompt_embeds: Optional[torch.Tensor] = None,
353 | ip_adapter_image: Optional[PipelineImageInput] = None,
354 | ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
355 | cross_attention_kwargs: Optional[Dict[str, Any]] = None,
356 | guidance_rescale: float = 0.0,
357 | clip_skip: Optional[int] = None,
358 | **kwargs,
359 | ):
360 |
361 | # 0. Default height and width to unet
362 | height = height or self.unet.config.sample_size * self.vae_scale_factor
363 | width = width or self.unet.config.sample_size * self.vae_scale_factor
364 |
365 | self._guidance_scale = guidance_scale
366 | self._guidance_rescale = guidance_rescale
367 | self._clip_skip = clip_skip
368 | self._cross_attention_kwargs = cross_attention_kwargs
369 | self._interrupt = False
370 |
371 | self.accelerator = Accelerator(
372 | mixed_precision=mixed_precision, gradient_accumulation_steps=1
373 | )
374 | self.init(enable_gradient_checkpoint)
375 |
376 | # 2. Define call parameters
377 | if prompt is not None and isinstance(prompt, str):
378 | batch_size = 1
379 | elif prompt is not None and isinstance(prompt, list):
380 | batch_size = len(prompt)
381 | else:
382 | batch_size = prompt_embeds.shape[0]
383 |
384 | device = self._execution_device
385 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
386 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
387 | # corresponds to doing no classifier free guidance.
388 | do_cfg = guidance_scale > 1.0
389 |
390 | if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
391 | image_embeds = self.prepare_ip_adapter_image_embeds(
392 | ip_adapter_image,
393 | ip_adapter_image_embeds,
394 | device,
395 | batch_size * num_images_per_prompt,
396 | self.do_classifier_free_guidance,
397 | )
398 |
399 | # 3. Encode input prompt
400 | text_encoder_lora_scale = (
401 | cross_attention_kwargs.get("scale", None)
402 | if cross_attention_kwargs is not None
403 | else None
404 | )
405 | prompt_embeds, negative_prompt_embeds = self.encode_prompt(
406 | prompt,
407 | device,
408 | num_images_per_prompt,
409 | do_cfg,
410 | negative_prompt,
411 | prompt_embeds=prompt_embeds,
412 | negative_prompt_embeds=negative_prompt_embeds,
413 | lora_scale=text_encoder_lora_scale,
414 | clip_skip=clip_skip,
415 | )
416 | # For classifier free guidance, we need to do two forward passes.
417 | # Here we concatenate the unconditional and text embeddings into a single batch
418 | # to avoid doing two forward passes
419 | if do_cfg:
420 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
421 |
422 | # 5. Prepare latent variables
423 | num_channels_latents = self.unet.config.in_channels
424 | latents = self.prepare_latents(
425 | batch_size * num_images_per_prompt,
426 | num_channels_latents,
427 | height,
428 | width,
429 | prompt_embeds.dtype,
430 | device,
431 | generator,
432 | latents,
433 | )
434 |
435 | # 6. Define panorama grid and initialize views for synthesis.
436 | # prepare batch grid
437 | views = self.get_views_(height, width, window_size=64, stride=stride)
438 | views_batch = [
439 | views[i : i + view_batch_size]
440 | for i in range(0, len(views), view_batch_size)
441 | ]
442 | print(len(views), len(views_batch), views_batch)
443 | self.scheduler.set_timesteps(num_inference_steps)
444 | views_scheduler_status = [copy.deepcopy(self.scheduler.__dict__)] * len(
445 | views_batch
446 | )
447 | count = torch.zeros_like(latents)
448 | value = torch.zeros_like(latents)
449 |
450 | # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
451 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
452 |
453 | # 7.1 Add image embeds for IP-Adapter
454 | added_cond_kwargs = (
455 | {"image_embeds": image_embeds}
456 | if ip_adapter_image is not None or ip_adapter_image_embeds is not None
457 | else None
458 | )
459 |
460 | # 7.2 Optionally get Guidance Scale Embedding
461 | timestep_cond = None
462 | if self.unet.config.time_cond_proj_dim is not None:
463 | guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(
464 | batch_size * num_images_per_prompt
465 | )
466 | timestep_cond = self.get_guidance_scale_embedding(
467 | guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
468 | ).to(device=device, dtype=latents.dtype)
469 |
470 | # 8. Denoising loop
471 | # Each denoising step also includes refinement of the latents with respect to the
472 | # views.
473 |
474 | timesteps = self.scheduler.timesteps
475 | self.style_latent = self.image2latent(style_image)
476 | self.content_latent = None
477 | null_embeds = self.encode_prompt("", device, 1, False)[0]
478 | self.null_embeds = null_embeds
479 | self.null_embeds_for_latents = torch.cat([null_embeds] * latents.shape[0])
480 | self.null_embeds_for_style = torch.cat(
481 | [null_embeds] * self.style_latent.shape[0]
482 | )
483 | self.adain = adain
484 | self.attn_scale = attn_scale
485 | self.cache = DataCache()
486 | self.controller = controller
487 | register_attn_control(
488 | self.classifier, controller=self.controller, cache=self.cache
489 | )
490 | print("Total self attention layers of Unet: ", controller.num_self_layers)
491 | print("Self attention layers for AD: ", controller.self_layers)
492 |
493 | pbar = tqdm(timesteps, desc="Sample")
494 | for i, t in enumerate(pbar):
495 | count.zero_()
496 | value.zero_()
497 | # generate views
498 | # Here, we iterate through different spatial crops of the latents and denoise them. These
499 | # denoised (latent) crops are then averaged to produce the final latent
500 | # for the current timestep via MultiDiffusion. Please see Sec. 4.1 in the
501 | # MultiDiffusion paper for more details: https://arxiv.org/abs/2302.08113
502 | # Batch views denoise
503 | for j, batch_view in enumerate(views_batch):
504 | vb_size = len(batch_view)
505 | # get the latents corresponding to the current view coordinates
506 | latents_for_view = torch.cat(
507 | [
508 | latents[:, :, h_start:h_end, w_start:w_end]
509 | for h_start, h_end, w_start, w_end in batch_view
510 | ]
511 | )
512 | # rematch block's scheduler status
513 | self.scheduler.__dict__.update(views_scheduler_status[j])
514 |
515 | # expand the latents if we are doing classifier free guidance
516 | latent_model_input = (
517 | latents_for_view.repeat_interleave(2, dim=0)
518 | if do_cfg
519 | else latents_for_view
520 | )
521 |
522 | latent_model_input = self.scheduler.scale_model_input(
523 | latent_model_input, t
524 | )
525 |
526 | # repeat prompt_embeds for batch
527 | prompt_embeds_input = torch.cat([prompt_embeds] * vb_size)
528 |
529 | # predict the noise residual
530 | with torch.no_grad():
531 | noise_pred = self.unet(
532 | latent_model_input,
533 | t,
534 | encoder_hidden_states=prompt_embeds_input,
535 | timestep_cond=timestep_cond,
536 | cross_attention_kwargs=cross_attention_kwargs,
537 | added_cond_kwargs=added_cond_kwargs,
538 | ).sample
539 |
540 | # perform guidance
541 | if do_cfg:
542 | noise_pred_uncond, noise_pred_text = (
543 | noise_pred[::2],
544 | noise_pred[1::2],
545 | )
546 | noise_pred = noise_pred_uncond + guidance_scale * (
547 | noise_pred_text - noise_pred_uncond
548 | )
549 |
550 | # compute the previous noisy sample x_t -> x_t-1
551 | latents_denoised_batch = self.scheduler.step(
552 | noise_pred, t, latents_for_view, **extra_step_kwargs
553 | ).prev_sample
554 | if iters > 0:
555 | self.null_embeds_for_latents = torch.cat(
556 | [self.null_embeds] * noise_pred.shape[0]
557 | )
558 | latents_denoised_batch = self.AD(
559 | latents_denoised_batch, t, lr, iters, pbar
560 | )
561 | # save the scheduler state for this view batch after the step
562 | views_scheduler_status[j] = copy.deepcopy(self.scheduler.__dict__)
563 |
564 | # extract value from batch
565 | for latents_view_denoised, (h_start, h_end, w_start, w_end) in zip(
566 | latents_denoised_batch.chunk(vb_size), batch_view
567 | ):
568 |
569 | value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
570 | count[:, :, h_start:h_end, w_start:w_end] += 1
571 |
572 | # take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
573 | latents = torch.where(count > 0, value / count, value)
574 |
575 | images = self.latent2image(latents)
576 | # Offload all models
577 | self.maybe_free_model_hooks()
578 | return images
579 |
580 | def AD(self, latents, t, lr, iters, pbar, weight=0):
581 | t = max(
582 | t
583 | - self.scheduler.config.num_train_timesteps
584 | // self.scheduler.num_inference_steps,
585 | torch.tensor([0], device=self.device),
586 | )
587 | if self.adain:
588 | noise = torch.randn_like(self.style_latent)
589 | style_latent = self.scheduler.add_noise(self.style_latent, noise, t)
590 | latents = adain(latents, style_latent)
591 |
592 | with torch.no_grad():
593 | qs_list, ks_list, vs_list, s_out_list = self.extract_feature(
594 | self.style_latent,
595 | t,
596 | self.null_embeds_for_style,
597 | add_noise=True,
598 | )
599 | if self.content_latent is not None:
600 | qc_list, kc_list, vc_list, c_out_list = self.extract_feature(
601 | self.content_latent,
602 | t,
603 | self.null_embeds,
604 | add_noise=True,
605 | )
606 |
607 | latents = latents.detach()
608 | optimizer = torch.optim.Adam([latents.requires_grad_()], lr=lr)
609 | optimizer = self.accelerator.prepare(optimizer)
610 |
611 | for j in range(iters):
612 | style_loss = 0
613 | content_loss = 0
614 | optimizer.zero_grad()
615 | q_list, k_list, v_list, self_out_list = self.extract_feature(
616 | latents,
617 | t,
618 | self.null_embeds_for_latents,
619 | add_noise=False,
620 | )
621 | style_loss = ad_loss(q_list, ks_list, vs_list, self_out_list, scale=self.attn_scale)
622 | if self.content_latent is not None:
623 | content_loss = q_loss(q_list, qc_list)
624 | # content_loss = qk_loss(q_list, k_list, qc_list, kc_list)
625 | # content_loss = qkv_loss(q_list, k_list, vc_list, c_out_list)
626 | loss = style_loss + content_loss * weight
627 | self.accelerator.backward(loss)
628 | optimizer.step()
629 |
630 | pbar.set_postfix(loss=loss.item(), time=t.item(), iter=j)
631 | latents = latents.detach()
632 | return latents
633 |
634 | def extract_feature(
635 | self,
636 | latent,
637 | t,
638 | embeds,
639 | add_noise=False,
640 | ):
641 | self.cache.clear()
642 | self.controller.step()
643 | if add_noise:
644 | noise = torch.randn_like(latent)
645 | latent_ = self.scheduler.add_noise(latent, noise, t)
646 | else:
647 | latent_ = latent
648 | _ = self.classifier(latent_, t, embeds)[0]
649 | return self.cache.get()
650 |
651 | def get_views_(
652 | self,
653 | panorama_height: int,
654 | panorama_width: int,
655 | window_size: int = 64,
656 | stride: int = 8,
657 | ) -> List[Tuple[int, int, int, int]]:
658 | panorama_height //= 8
659 | panorama_width //= 8
660 |
661 | num_blocks_height = (
662 | math.ceil((panorama_height - window_size) / stride) + 1
663 | if panorama_height > window_size
664 | else 1
665 | )
666 | num_blocks_width = (
667 | math.ceil((panorama_width - window_size) / stride) + 1
668 | if panorama_width > window_size
669 | else 1
670 | )
671 |
672 | views = []
673 | for i in range(int(num_blocks_height)):
674 | for j in range(int(num_blocks_width)):
675 | h_start = int(min(i * stride, panorama_height - window_size))
676 | w_start = int(min(j * stride, panorama_width - window_size))
677 |
678 | h_end = h_start + window_size
679 | w_end = w_start + window_size
680 |
681 | views.append((h_start, h_end, w_start, w_end))
682 |
683 | return views
684 |
--------------------------------------------------------------------------------
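The sampling loop above merges overlapping latent crops with a count/value average (Eq. 5 of MultiDiffusion). Below is a minimal, self-contained sketch of that mechanism, reusing the window/stride logic of `ADPipeline.get_views_`; the panorama size and the identity "denoiser" are placeholders, not values from the repository.

```python
import math
import torch


def get_views(panorama_height, panorama_width, window_size=64, stride=8):
    # Same logic as ADPipeline.get_views_: operate at 1/8 resolution (latent space).
    panorama_height //= 8
    panorama_width //= 8
    num_h = math.ceil((panorama_height - window_size) / stride) + 1 if panorama_height > window_size else 1
    num_w = math.ceil((panorama_width - window_size) / stride) + 1 if panorama_width > window_size else 1
    views = []
    for i in range(num_h):
        for j in range(num_w):
            h_start = min(i * stride, panorama_height - window_size)
            w_start = min(j * stride, panorama_width - window_size)
            views.append((h_start, h_start + window_size, w_start, w_start + window_size))
    return views


latents = torch.randn(1, 4, 512 // 8, 1024 // 8)  # a 512x1024 "panorama" in latent space
value = torch.zeros_like(latents)
count = torch.zeros_like(latents)

for h_start, h_end, w_start, w_end in get_views(512, 1024):
    view = latents[:, :, h_start:h_end, w_start:w_end]
    denoised_view = view  # placeholder for one denoising (+ AD) step on this crop
    value[:, :, h_start:h_end, w_start:w_end] += denoised_view
    count[:, :, h_start:h_end, w_start:w_end] += 1

# MultiDiffusion merge: average every latent position over the views that cover it.
merged = torch.where(count > 0, value / count, value)
```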
/pipeline_sdxl.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Any, Dict, List, Optional, Tuple, Union
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | from accelerate import Accelerator
7 | from accelerate.utils import (
8 | DistributedDataParallelKwargs,
9 | ProjectConfiguration,
10 | set_seed,
11 | )
12 | from diffusers import StableDiffusionXLPipeline
13 | from diffusers.image_processor import PipelineImageInput
14 | from diffusers.utils.torch_utils import is_compiled_module
15 |
16 | from .utils import DataCache, register_attn_control, adain
17 | from .losses import ad_loss
18 | from tqdm import tqdm
19 |
20 |
21 | class ADPipeline(StableDiffusionXLPipeline):
22 | def freeze(self):
23 | self.unet.requires_grad_(False)
24 | self.text_encoder.requires_grad_(False)
25 | self.text_encoder_2.requires_grad_(False)
26 | self.vae.requires_grad_(False)
27 | self.classifier.requires_grad_(False)
28 |
29 | @torch.no_grad()
30 | def image2latent(self, image):
31 | dtype = next(self.vae.parameters()).dtype
32 | device = self._execution_device
33 | image = image.to(device=device, dtype=dtype) * 2.0 - 1.0
34 | latent = self.vae.encode(image)["latent_dist"].mean
35 | latent = latent * self.vae.config.scaling_factor
36 | return latent
37 |
38 | @torch.no_grad()
39 | def latent2image(self, latent):
40 | dtype = next(self.vae.parameters()).dtype
41 | device = self._execution_device
42 | latent = latent.to(device=device, dtype=dtype)
43 | latent = latent / self.vae.config.scaling_factor
44 | image = self.vae.decode(latent)[0]
45 | return (image * 0.5 + 0.5).clamp(0, 1)
46 |
47 | def init(self, enable_gradient_checkpoint):
48 | self.freeze()
49 | self.enable_vae_slicing()
50 | # self.enable_model_cpu_offload()
51 | # self.enable_vae_tiling()
52 | weight_dtype = torch.float32
53 | if self.accelerator.mixed_precision == "fp16":
54 | weight_dtype = torch.float16
55 | elif self.accelerator.mixed_precision == "bf16":
56 | weight_dtype = torch.bfloat16
57 |
58 | # Move unet, vae, both text encoders and the classifier to device and cast to weight_dtype
59 | self.unet.to(self.accelerator.device, dtype=weight_dtype)
60 | self.vae.to(self.accelerator.device, dtype=weight_dtype)
61 | self.text_encoder.to(self.accelerator.device, dtype=weight_dtype)
62 | self.text_encoder_2.to(self.accelerator.device, dtype=weight_dtype)
63 | self.classifier.to(self.accelerator.device, dtype=weight_dtype)
64 | self.classifier = self.accelerator.prepare(self.classifier)
65 | if enable_gradient_checkpoint:
66 | self.classifier.enable_gradient_checkpointing()
67 | # self.classifier.train()
68 |
69 |
70 | def sample(
71 | self,
72 | lr=0.05,
73 | iters=1,
74 | adain=True,
75 | controller=None,
76 | style_image=None,
77 | mixed_precision="no",
78 | init_from_style=False,
79 | start_time=999,
80 | prompt: Union[str, List[str]] = None,
81 | prompt_2: Optional[Union[str, List[str]]] = None,
82 | height: Optional[int] = None,
83 | width: Optional[int] = None,
84 | num_inference_steps: int = 50,
85 | denoising_end: Optional[float] = None,
86 | guidance_scale: float = 5.0,
87 | negative_prompt: Optional[Union[str, List[str]]] = None,
88 | negative_prompt_2: Optional[Union[str, List[str]]] = None,
89 | num_images_per_prompt: Optional[int] = 1,
90 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
91 | latents: Optional[torch.Tensor] = None,
92 | prompt_embeds: Optional[torch.Tensor] = None,
93 | negative_prompt_embeds: Optional[torch.Tensor] = None,
94 | pooled_prompt_embeds: Optional[torch.Tensor] = None,
95 | negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
96 | ip_adapter_image: Optional[PipelineImageInput] = None,
97 | ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
98 | cross_attention_kwargs: Optional[Dict[str, Any]] = None,
99 | guidance_rescale: float = 0.0,
100 | original_size: Optional[Tuple[int, int]] = None,
101 | crops_coords_top_left: Tuple[int, int] = (0, 0),
102 | target_size: Optional[Tuple[int, int]] = None,
103 | negative_original_size: Optional[Tuple[int, int]] = None,
104 | negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
105 | negative_target_size: Optional[Tuple[int, int]] = None,
106 | clip_skip: Optional[int] = None,
107 | enable_gradient_checkpoint=False,
108 | **kwargs,
109 | ):
110 | # 0. Default height and width to unet
111 | height = height or self.default_sample_size * self.vae_scale_factor
112 | width = width or self.default_sample_size * self.vae_scale_factor
113 |
114 | original_size = original_size or (height, width)
115 | target_size = target_size or (height, width)
116 | self._guidance_scale = guidance_scale
117 | self._guidance_rescale = guidance_rescale
118 | self._clip_skip = clip_skip
119 | self._cross_attention_kwargs = cross_attention_kwargs
120 | self._denoising_end = denoising_end
121 | self._interrupt = False
122 |
123 | self.accelerator = Accelerator(
124 | mixed_precision=mixed_precision, gradient_accumulation_steps=1
125 | )
126 | self.init(enable_gradient_checkpoint)
127 |
128 | # 2. Define call parameters
129 | if prompt is not None and isinstance(prompt, str):
130 | batch_size = 1
131 | elif prompt is not None and isinstance(prompt, list):
132 | batch_size = len(prompt)
133 | else:
134 | batch_size = prompt_embeds.shape[0]
135 |
136 | device = self._execution_device
137 |
138 | # 3. Encode input prompt
139 | lora_scale = (
140 | self.cross_attention_kwargs.get("scale", None)
141 | if self.cross_attention_kwargs is not None
142 | else None
143 | )
144 |
145 | (
146 | prompt_embeds,
147 | negative_prompt_embeds,
148 | pooled_prompt_embeds,
149 | negative_pooled_prompt_embeds,
150 | ) = self.encode_prompt(
151 | prompt=prompt,
152 | prompt_2=prompt_2,
153 | device=device,
154 | num_images_per_prompt=num_images_per_prompt,
155 | do_classifier_free_guidance=self.do_classifier_free_guidance,
156 | negative_prompt=negative_prompt,
157 | negative_prompt_2=negative_prompt_2,
158 | prompt_embeds=prompt_embeds,
159 | negative_prompt_embeds=negative_prompt_embeds,
160 | pooled_prompt_embeds=pooled_prompt_embeds,
161 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
162 | lora_scale=lora_scale,
163 | clip_skip=self.clip_skip,
164 | )
165 |
166 | # 5. Prepare latent variables
167 | num_channels_latents = self.unet.config.in_channels
168 | latents = self.prepare_latents(
169 | batch_size * num_images_per_prompt,
170 | num_channels_latents,
171 | height,
172 | width,
173 | prompt_embeds.dtype,
174 | device,
175 | generator,
176 | latents,
177 | )
178 |
179 | # 7. Prepare added time ids & embeddings
180 | add_text_embeds = pooled_prompt_embeds
181 | if self.text_encoder_2 is None:
182 | text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
183 | else:
184 | text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
185 |
186 | add_time_ids = self._get_add_time_ids(
187 | original_size,
188 | crops_coords_top_left,
189 | target_size,
190 | dtype=prompt_embeds.dtype,
191 | text_encoder_projection_dim=text_encoder_projection_dim,
192 | )
193 | null_add_time_ids = add_time_ids.to(device)
194 | if negative_original_size is not None and negative_target_size is not None:
195 | negative_add_time_ids = self._get_add_time_ids(
196 | negative_original_size,
197 | negative_crops_coords_top_left,
198 | negative_target_size,
199 | dtype=prompt_embeds.dtype,
200 | text_encoder_projection_dim=text_encoder_projection_dim,
201 | )
202 | else:
203 | negative_add_time_ids = add_time_ids
204 |
205 | if self.do_classifier_free_guidance:
206 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
207 | add_text_embeds = torch.cat(
208 | [negative_pooled_prompt_embeds, add_text_embeds], dim=0
209 | )
210 | add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
211 |
212 | prompt_embeds = prompt_embeds.to(device)
213 | add_text_embeds = add_text_embeds.to(device)
214 | add_time_ids = add_time_ids.to(device).repeat(
215 | batch_size * num_images_per_prompt, 1
216 | )
217 |
218 | if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
219 | image_embeds = self.prepare_ip_adapter_image_embeds(
220 | ip_adapter_image,
221 | ip_adapter_image_embeds,
222 | device,
223 | batch_size * num_images_per_prompt,
224 | self.do_classifier_free_guidance,
225 | )
226 | self.scheduler.set_timesteps(num_inference_steps)
227 | timesteps = self.scheduler.timesteps
228 | 
229 | # 8.1 Apply denoising_end (requires `timesteps`, so it runs after set_timesteps)
230 | if (
231 | self.denoising_end is not None
232 | and isinstance(self.denoising_end, float)
233 | and 0 < self.denoising_end < 1
234 | ):
235 | discrete_timestep_cutoff = int(
236 | round(
237 | self.scheduler.config.num_train_timesteps
238 | - (self.denoising_end * self.scheduler.config.num_train_timesteps)
239 | )
240 | )
241 | num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
242 | timesteps = timesteps[:num_inference_steps]
243 |
244 | # 9. Optionally get Guidance Scale Embedding
245 | timestep_cond = None
246 | if self.unet.config.time_cond_proj_dim is not None:
247 | guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(
248 | batch_size * num_images_per_prompt
249 | )
250 | timestep_cond = self.get_guidance_scale_embedding(
251 | guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
252 | ).to(device=device, dtype=latents.dtype)
253 | self.timestep_cond = timestep_cond
254 | (null_embeds, _, null_pooled_embeds, _) = self.encode_prompt("", device=device)
255 |
256 | added_cond_kwargs = {
257 | "text_embeds": add_text_embeds,
258 | "time_ids": add_time_ids
259 | }
260 | if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
261 | added_cond_kwargs["image_embeds"] = image_embeds
262 |
263 | # NOTE: timesteps were already set above, before applying denoising_end;
264 | # re-calling set_timesteps here would discard the denoising_end truncation.
265 | 
266 | style_latent = self.image2latent(style_image)
267 | if init_from_style:
268 | latents = torch.cat([style_latent] * latents.shape[0])
269 | noise = torch.randn_like(latents)
270 | latents = self.scheduler.add_noise(
271 | latents,
272 | noise,
273 | torch.tensor([999]),
274 | )
275 |
276 | self.style_latent = style_latent
277 | self.null_embeds_for_latents = torch.cat([null_embeds] * (latents.shape[0]))
278 | self.null_embeds_for_style = torch.cat([null_embeds] * style_latent.shape[0])
279 | self.null_added_cond_kwargs_for_latents = {
280 | "text_embeds": torch.cat([null_pooled_embeds] * (latents.shape[0])),
281 | "time_ids": torch.cat([null_add_time_ids] * (latents.shape[0])),
282 | }
283 | self.null_added_cond_kwargs_for_style = {
284 | "text_embeds": torch.cat([null_pooled_embeds] * style_latent.shape[0]),
285 | "time_ids": torch.cat([null_add_time_ids] * style_latent.shape[0]),
286 | }
287 | self.adain = adain
288 | self.cache = DataCache()
289 | self.controller = controller
290 | register_attn_control(
291 | self.classifier, controller=controller, cache=self.cache
292 | )
293 | print("Total self attention layers of Unet: ", controller.num_self_layers)
294 | print("Self attention layers for AD: ", controller.self_layers)
295 |
296 | pbar = tqdm(timesteps, desc="Sample")
297 | for i, t in enumerate(pbar):
298 | with torch.no_grad():
299 | # expand the latents if we are doing classifier free guidance
300 | latent_model_input = (
301 | torch.cat([latents] * 2)
302 | if self.do_classifier_free_guidance
303 | else latents
304 | )
305 |
306 | # predict the noise residual
307 | noise_pred = self.unet(
308 | latent_model_input,
309 | t,
310 | encoder_hidden_states=prompt_embeds,
311 | timestep_cond=timestep_cond,
312 | cross_attention_kwargs=self.cross_attention_kwargs,
313 | added_cond_kwargs=added_cond_kwargs,
314 | return_dict=False,
315 | )[0]
316 |
317 | # perform guidance
318 | if self.do_classifier_free_guidance:
319 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
320 | noise_pred = noise_pred_uncond + self.guidance_scale * (
321 | noise_pred_text - noise_pred_uncond
322 | )
323 | latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
324 |
325 | if iters > 0 and t < start_time:
326 | latents = self.AD(latents, t, lr, iters, pbar)
327 |
328 |
329 | # Offload all models
330 | # self.enable_model_cpu_offload()
331 | images = self.latent2image(latents)
332 | self.maybe_free_model_hooks()
333 | return images
334 |
335 | def AD(self, latents, t, lr, iters, pbar):
336 | t = max(
337 | t
338 | - self.scheduler.config.num_train_timesteps
339 | // self.scheduler.num_inference_steps,
340 | torch.tensor([0], device=self.device),
341 | )
342 |
343 | if self.adain:
344 | noise = torch.randn_like(self.style_latent)
345 | style_latent = self.scheduler.add_noise(self.style_latent, noise, t)
346 | latents = adain(latents, style_latent)
347 |
348 | with torch.no_grad():
349 | qs_list, ks_list, vs_list, s_out_list = self.extract_feature(
350 | self.style_latent,
351 | t,
352 | self.null_embeds_for_style,
353 | self.timestep_cond,
354 | self.null_added_cond_kwargs_for_style,
355 | add_noise=True,
356 | )
357 | # latents = latents.to(dtype=torch.float32)
358 | latents = latents.detach()
359 | optimizer = torch.optim.Adam([latents.requires_grad_()], lr=lr)
360 | optimizer, latents = self.accelerator.prepare(optimizer, latents)
361 |
362 | for j in range(iters):
363 | optimizer.zero_grad()
364 | q_list, k_list, v_list, self_out_list = self.extract_feature(
365 | latents,
366 | t,
367 | self.null_embeds_for_latents,
368 | self.timestep_cond,
369 | self.null_added_cond_kwargs_for_latents,
370 | add_noise=False,
371 | )
372 |
373 | loss = ad_loss(q_list, ks_list, vs_list, self_out_list)
374 | self.accelerator.backward(loss)
375 | optimizer.step()
376 |
377 | pbar.set_postfix(loss=loss.item(), time=t.item(), iter=j)
378 | latents = latents.detach()
379 | return latents
380 |
381 | def extract_feature(
382 | self,
383 | latent,
384 | t,
385 | encoder_hidden_states,
386 | timestep_cond,
387 | added_cond_kwargs,
388 | add_noise=False,
389 | ):
390 | self.cache.clear()
391 | self.controller.step()
392 | if add_noise:
393 | noise = torch.randn_like(latent)
394 | latent_ = self.scheduler.add_noise(latent, noise, t)
395 | else:
396 | latent_ = latent
397 | self.classifier(
398 | latent_,
399 | t,
400 | encoder_hidden_states=encoder_hidden_states,
401 | timestep_cond=timestep_cond,
402 | added_cond_kwargs=added_cond_kwargs,
403 | return_dict=False,
404 | )[0]
405 | return self.cache.get()
406 |
--------------------------------------------------------------------------------
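For reference, a hedged usage sketch of the SDXL pipeline above for style-specific text-to-image generation. The checkpoint path is a placeholder, and pointing `classifier` at the pipeline's own UNet is an assumption about how the ComfyUI node wires things up; the parameter values mirror the defaults in `workflows/style_t2i_generation_sdxl.json`.

```python
import torch
from pipeline_sdxl import ADPipeline  # adjust the import to however the package is installed
from utils import Controller, load_image

# Placeholder checkpoint path; any diffusers-format SDXL base checkpoint should work.
pipe = ADPipeline.from_pretrained(
    "./models/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
).to("cuda")
pipe.classifier = pipe.unet  # assumption: the frozen UNet is reused as the attention-feature extractor

style = load_image("./examples/1.png", size=(1024, 1024)).to("cuda")
controller = Controller(self_layers=(0, 16))  # which self-attention layers feed ad_loss

images = pipe.sample(
    prompt="A photo of Big Ben, London.",
    style_image=style,
    controller=controller,
    lr=0.015,                 # step size of the per-timestep latent optimization
    iters=2,                  # AD iterations per denoising step
    num_inference_steps=50,
    guidance_scale=7.5,
    mixed_precision="bf16",
    height=1024,
    width=1024,
)  # float tensor in [0, 1], one image per prompt
```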
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "attention-distillation"
3 | description = "Non-native [a/AttentionDistillation](https://github.com/xugao97/AttentionDistillation) for ComfyUI.\nOfficial ComfyUI demo for the paper AttentionDistillation, implemented as an extension of ComfyUI. Note that this extension incorporates AttentionDistillation using diffusers."
4 | version = "1.1.0"
5 | license = {file = "LICENSE"}
6 | dependencies = ["diffusers", "accelerate", "Pillow", "torch>=2.1.0", "tqdm", "huggingface_hub", "sentencepiece", "protobuf"]
7 |
8 | [project.urls]
9 | Repository = "https://github.com/zichongc/ComfyUI-Attention-Distillation"
10 | # Used by Comfy Registry https://comfyregistry.org
11 |
12 | [tool.comfy]
13 | PublisherId = "zichongc"
14 | DisplayName = "ComfyUI-Attention-Distillation"
15 | Icon = ""
16 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | diffusers
2 | accelerate
3 | Pillow
4 | torch>=2.1.0
5 | protobuf
6 | sentencepiece
7 | tqdm
--------------------------------------------------------------------------------
/train_vae.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import torch
5 | from diffusers import AutoencoderKL
6 | from torch import nn
7 | from torch.optim import Adam
8 | from utils import load_image, save_image  # absolute import: this script is run directly, not as a package module
9 |
10 |
11 | def main(args):
12 | os.makedirs(args.out_dir, exist_ok=True)
13 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14 |
15 | vae = AutoencoderKL.from_pretrained(args.vae_model_path).to(
16 | device, dtype=torch.float32
17 | )
18 | vae.requires_grad_(False)
19 |
20 | image = load_image(args.image_path, size=(512, 512)).to(device, dtype=torch.float32)
21 | image = image * 2 - 1
22 | save_image(image / 2 + 0.5, f"{args.out_dir}/ori_image.png")
23 |
24 | latents = vae.encode(image)["latent_dist"].mean
25 | save_image(latents, f"{args.out_dir}/latents.png")
26 |
27 | rec_image = vae.decode(latents, return_dict=False)[0]
28 | save_image(rec_image / 2 + 0.5, f"{args.out_dir}/rec_image.png")
29 |
30 | for param in vae.decoder.parameters():
31 | param.requires_grad = True
32 |
33 | loss_fn = nn.L1Loss()
34 | optimizer = Adam(vae.decoder.parameters(), lr=args.learning_rate)
35 |
36 | # Training loop
37 | for epoch in range(args.num_epochs):
38 | reconstructed = vae.decode(latents, return_dict=False)[0]
39 | loss = loss_fn(reconstructed, image)
40 |
41 | optimizer.zero_grad()
42 | loss.backward()
43 | optimizer.step()
44 |
45 | print(f"Epoch {epoch+1}/{args.num_epochs}, Loss: {loss.item()}")
46 |
47 | rec_image = vae.decode(latents, return_dict=False)[0]
48 | save_image(rec_image / 2 + 0.5, f"{args.out_dir}/trained_rec_image.png")
49 | vae.save_pretrained(
50 | f"{args.out_dir}/trained_vae_{os.path.basename(args.image_path)}"
51 | )
52 |
53 |
54 |
55 | if __name__ == "__main__":
56 | parser = argparse.ArgumentParser(
57 | description="Train a VAE with given image and settings."
58 | )
59 |
60 | # Add arguments
61 | parser.add_argument(
62 | "--out_dir",
63 | type=str,
64 | default="./trained_vae/",
65 | help="Output directory to save results",
66 | )
67 | parser.add_argument(
68 | "--vae_model_path",
69 | type=str,
70 | required=True,
71 | help="Path to the pretrained VAE model",
72 | )
73 | parser.add_argument(
74 | "--image_path", type=str, required=True, help="Path to the input image"
75 | )
76 | parser.add_argument(
77 | "--learning_rate",
78 | type=float,
79 | default=1e-4,
80 | help="Learning rate for the optimizer",
81 | )
82 | parser.add_argument(
83 | "--num_epochs", type=int, default=75, help="Number of training epochs"
84 | )
85 |
86 | args = parser.parse_args()
87 | main(args)
88 |
--------------------------------------------------------------------------------
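train_vae.py fine-tunes only the VAE decoder on a single image (the encoder and latent stay fixed), so that decoding that image's latent reproduces it more faithfully. A hedged example invocation; both model path and output directory are placeholders, while the input image ships with the repository.

```python
# Command-line form (paths are placeholders):
#   python train_vae.py --vae_model_path ./models/vae \
#       --image_path ./examples/40.jpg --out_dir ./trained_vae/
#
# Equivalent programmatic call:
from argparse import Namespace

from train_vae import main

main(Namespace(
    out_dir="./trained_vae/",
    vae_model_path="./models/vae",   # any diffusers-format AutoencoderKL checkpoint
    image_path="./examples/40.jpg",  # example image shipped with the repo
    learning_rate=1e-4,
    num_epochs=75,
))
```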
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | from PIL import Image
5 | from torchvision.transforms import ToTensor
6 | from torchvision.utils import save_image
7 | # import matplotlib.pyplot as plt
8 | import math
9 |
10 |
11 | sd15_file_names = [
12 | 'feature_extractor/preprocessor_config.json',
13 | 'scheduler/scheduler_config.json',
14 | 'text_encoder/config.json',
15 | 'text_encoder/model.safetensors',
16 | 'tokenizer/merges.txt',
17 | 'tokenizer/special_tokens_map.json',
18 | 'tokenizer/tokenizer_config.json',
19 | 'tokenizer/vocab.json',
20 | 'unet/config.json',
21 | 'unet/diffusion_pytorch_model.safetensors',
22 | 'vae/config.json',
23 | 'vae/diffusion_pytorch_model.safetensors',
24 | 'model_index.json'
25 | ]
26 |
27 | sdxl_file_names = [
28 | 'model_index.json',
29 | 'vae/config.json',
30 | 'vae/diffusion_pytorch_model.safetensors',
31 | 'unet/config.json',
32 | 'unet/diffusion_pytorch_model.safetensors',
33 | 'tokenizer/merges.txt',
34 | 'tokenizer/special_tokens_map.json',
35 | 'tokenizer/tokenizer_config.json',
36 | 'tokenizer/vocab.json',
37 | 'tokenizer_2/merges.txt',
38 | 'tokenizer_2/special_tokens_map.json',
39 | 'tokenizer_2/tokenizer_config.json',
40 | 'tokenizer_2/vocab.json',
41 | 'text_encoder/config.json',
42 | 'text_encoder/model.safetensors',
43 | 'text_encoder_2/config.json',
44 | 'text_encoder_2/model.safetensors',
45 | 'scheduler/scheduler_config.json',
46 | ]
47 |
48 | flux_file_names = [
49 | 'model_index.json',
50 | 'vae/config.json',
51 | 'vae/diffusion_pytorch_model.safetensors',
52 | 'transformer/config.json',
53 | 'transformer/diffusion_pytorch_model-00001-of-00003.safetensors',
54 | 'transformer/diffusion_pytorch_model-00002-of-00003.safetensors',
55 | 'transformer/diffusion_pytorch_model-00003-of-00003.safetensors',
56 | 'transformer/diffusion_pytorch_model.safetensors.index.json',
57 | 'tokenizer/merges.txt',
58 | 'tokenizer/special_tokens_map.json',
59 | 'tokenizer/tokenizer_config.json',
60 | 'tokenizer/vocab.json',
61 | 'tokenizer_2/spiece.model',
62 | 'tokenizer_2/special_tokens_map.json',
63 | 'tokenizer_2/tokenizer_config.json',
64 | 'tokenizer_2/tokenizer.json',
65 | 'text_encoder/config.json',
66 | 'text_encoder/model.safetensors',
67 | 'text_encoder_2/config.json',
68 | 'text_encoder_2/model-00001-of-00002.safetensors',
69 | 'text_encoder_2/model-00002-of-00002.safetensors',
70 | 'text_encoder_2/model.safetensors.index.json',
71 | 'scheduler/scheduler_config.json',
72 | ]
73 |
74 |
75 | def register_attn_control(unet, controller, cache=None):
76 | def attn_forward(self):
77 | def forward(
78 | hidden_states,
79 | encoder_hidden_states=None,
80 | attention_mask=None,
81 | temb=None,
82 | *args,
83 | **kwargs,
84 | ):
85 | residual = hidden_states
86 | if self.spatial_norm is not None:
87 | hidden_states = self.spatial_norm(hidden_states, temb)
88 |
89 | input_ndim = hidden_states.ndim
90 |
91 | if input_ndim == 4:
92 | batch_size, channel, height, width = hidden_states.shape
93 | hidden_states = hidden_states.view(
94 | batch_size, channel, height * width
95 | ).transpose(1, 2)
96 |
97 | batch_size, sequence_length, _ = (
98 | hidden_states.shape
99 | if encoder_hidden_states is None
100 | else encoder_hidden_states.shape
101 | )
102 |
103 | if attention_mask is not None:
104 | attention_mask = self.prepare_attention_mask(
105 | attention_mask, sequence_length, batch_size
106 | )
107 | # scaled_dot_product_attention expects attention_mask shape to be
108 | # (batch, heads, source_length, target_length)
109 | attention_mask = attention_mask.view(
110 | batch_size, self.heads, -1, attention_mask.shape[-1]
111 | )
112 |
113 | if self.group_norm is not None:
114 | hidden_states = self.group_norm(
115 | hidden_states.transpose(1, 2)
116 | ).transpose(1, 2)
117 |
118 | q = self.to_q(hidden_states)
119 | is_self = encoder_hidden_states is None
120 |
121 | if encoder_hidden_states is None:
122 | encoder_hidden_states = hidden_states
123 | elif self.norm_cross:
124 | encoder_hidden_states = self.norm_encoder_hidden_states(
125 | encoder_hidden_states
126 | )
127 |
128 | k = self.to_k(encoder_hidden_states)
129 | v = self.to_v(encoder_hidden_states)
130 |
131 | inner_dim = k.shape[-1]
132 | head_dim = inner_dim // self.heads
133 |
134 | q = q.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
135 | k = k.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
136 | v = v.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
137 | # the output of sdp = (batch, num_heads, seq_len, head_dim)
138 | # TODO: add support for attn.scale when we move to Torch 2.1
139 | hidden_states = F.scaled_dot_product_attention(
140 | q, k, v, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
141 | )
142 | if is_self and controller.cur_self_layer in controller.self_layers:
143 | cache.add(q, k, v, hidden_states)
144 |
145 | hidden_states = hidden_states.transpose(1, 2).reshape(
146 | batch_size, -1, self.heads * head_dim
147 | )
148 | hidden_states = hidden_states.to(q.dtype)
149 |
150 | # linear proj
151 | hidden_states = self.to_out[0](hidden_states)
152 | # dropout
153 | hidden_states = self.to_out[1](hidden_states)
154 |
155 | if input_ndim == 4:
156 | hidden_states = hidden_states.transpose(-1, -2).reshape(
157 | batch_size, channel, height, width
158 | )
159 | if self.residual_connection:
160 | hidden_states = hidden_states + residual
161 |
162 | hidden_states = hidden_states / self.rescale_output_factor
163 |
164 | if is_self:
165 | controller.cur_self_layer += 1
166 |
167 | return hidden_states
168 |
169 | return forward
170 |
171 | def modify_forward(net, count):
172 | for name, subnet in net.named_children():
173 | if net.__class__.__name__ == "Attention": # spatial Transformer layer
174 | net.forward = attn_forward(net)
175 | return count + 1
176 | elif hasattr(net, "children"):
177 | count = modify_forward(subnet, count)
178 | return count
179 |
180 | cross_att_count = 0
181 | for net_name, net in unet.named_children():
182 | cross_att_count += modify_forward(net, 0)
183 | controller.num_self_layers = cross_att_count // 2
184 |
185 |
186 | def register_attn_control_flux(unet, controller, cache=None):
187 | def attn_forward(self):
188 |
189 | def forward(
190 | hidden_states,
191 | encoder_hidden_states=None,
192 | attention_mask=None,
193 | image_rotary_emb=None,
194 | *args,
195 | **kwargs,
196 | ):
197 | batch_size, _, _ = (
198 | hidden_states.shape
199 | if encoder_hidden_states is None
200 | else encoder_hidden_states.shape
201 | )
202 |
203 | # `sample` projections.
204 | query = self.to_q(hidden_states)
205 | key = self.to_k(hidden_states)
206 | value = self.to_v(hidden_states)
207 |
208 | inner_dim = key.shape[-1]
209 | head_dim = inner_dim // self.heads
210 |
211 | query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
212 | key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
213 | value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
214 |
215 | if self.norm_q is not None:
216 | query = self.norm_q(query)
217 | if self.norm_k is not None:
218 | key = self.norm_k(key)
219 |
220 | # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
221 | if encoder_hidden_states is not None:
222 | # `context` projections.
223 | encoder_hidden_states_query_proj = self.add_q_proj(
224 | encoder_hidden_states
225 | )
226 | encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
227 | encoder_hidden_states_value_proj = self.add_v_proj(
228 | encoder_hidden_states
229 | )
230 |
231 | encoder_hidden_states_query_proj = (
232 | encoder_hidden_states_query_proj.view(
233 | batch_size, -1, self.heads, head_dim
234 | ).transpose(1, 2)
235 | )
236 | encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
237 | batch_size, -1, self.heads, head_dim
238 | ).transpose(1, 2)
239 | encoder_hidden_states_value_proj = (
240 | encoder_hidden_states_value_proj.view(
241 | batch_size, -1, self.heads, head_dim
242 | ).transpose(1, 2)
243 | )
244 |
245 | if self.norm_added_q is not None:
246 | encoder_hidden_states_query_proj = self.norm_added_q(
247 | encoder_hidden_states_query_proj
248 | )
249 | if self.norm_added_k is not None:
250 | encoder_hidden_states_key_proj = self.norm_added_k(
251 | encoder_hidden_states_key_proj
252 | )
253 |
254 | # attention
255 | query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
256 | key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
257 | value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
258 |
259 | if image_rotary_emb is not None:
260 | from diffusers.models.embeddings import apply_rotary_emb
261 |
262 | query = apply_rotary_emb(query, image_rotary_emb)
263 | key = apply_rotary_emb(key, image_rotary_emb)
264 |
265 | hidden_states = F.scaled_dot_product_attention(
266 | query, key, value, dropout_p=0.0, is_causal=False
267 | )
268 | if controller.cur_self_layer in controller.self_layers:
269 | # print("cache added")
270 | cache.add(query, key, value, hidden_states)
271 | # if encoder_hidden_states is None:
272 | controller.cur_self_layer += 1
273 |
274 | hidden_states = hidden_states.transpose(1, 2).reshape(
275 | batch_size, -1, self.heads * head_dim
276 | )
277 |
278 | hidden_states = hidden_states.to(query.dtype)
279 |
280 | if encoder_hidden_states is not None:
281 | encoder_hidden_states, hidden_states = (
282 | hidden_states[:, : encoder_hidden_states.shape[1]],
283 | hidden_states[:, encoder_hidden_states.shape[1] :],
284 | )
285 |
286 | # linear proj
287 | hidden_states = self.to_out[0](hidden_states)
288 | # dropout
289 | hidden_states = self.to_out[1](hidden_states)
290 | encoder_hidden_states = self.to_add_out(encoder_hidden_states)
291 |
292 | return hidden_states, encoder_hidden_states
293 | else:
294 | return hidden_states
295 |
296 | return forward
297 |
298 | def modify_forward(net, count):
299 | # print(net.named_children())
300 | for name, subnet in net.named_children():
301 | if net.__class__.__name__ == "Attention": # spatial Transformer layer
302 | net.forward = attn_forward(net)
303 | return count + 1
304 | elif hasattr(net, "children"):
305 | count = modify_forward(subnet, count)
306 | return count
307 |
308 | cross_att_count = 0
309 | cross_att_count += modify_forward(unet, 0)
310 | controller.num_self_layers += cross_att_count
311 |
312 |
313 | def load_image(image_path, size=None, mode="RGB"):
314 | img = Image.open(image_path).convert(mode)
315 | if size is None:
316 | width, height = img.size
317 | new_width = (width // 64) * 64
318 | new_height = (height // 64) * 64
319 | size = (new_width, new_height)
320 | img = img.resize(size, Image.BICUBIC)
321 | return ToTensor()(img).unsqueeze(0)
322 |
323 |
324 | def adain(source, target, eps=1e-6):
325 | source_mean, source_std = torch.mean(source, dim=(2, 3), keepdim=True), torch.std(
326 | source, dim=(2, 3), keepdim=True
327 | )
328 | target_mean, target_std = torch.mean(
329 | target, dim=(0, 2, 3), keepdim=True
330 | ), torch.std(target, dim=(0, 2, 3), keepdim=True)
331 | normalized_source = (source - source_mean) / (source_std + eps)
332 | transferred_source = normalized_source * target_std + target_mean
333 |
334 | return transferred_source
335 |
336 |
337 | def adain_flux(source, target, eps=1e-6):
338 | source_mean, source_std = torch.mean(source, dim=1, keepdim=True), torch.std(
339 | source, dim=1, keepdim=True
340 | )
341 | target_mean, target_std = torch.mean(
342 | target, dim=(0, 1), keepdim=True
343 | ), torch.std(target, dim=(0, 1), keepdim=True)
344 | normalized_source = (source - source_mean) / (source_std + eps)
345 | transferred_source = normalized_source * target_std + target_mean
346 |
347 | return transferred_source
348 |
349 |
350 | class Controller:
351 | def step(self):
352 | self.cur_self_layer = 0
353 |
354 | def __init__(self, self_layers=(0, 16)):
355 | self.num_self_layers = -1
356 | self.cur_self_layer = 0
357 | self.self_layers = list(range(*self_layers))
358 |
359 |
360 | class DataCache:
361 | def __init__(self):
362 | self.q = []
363 | self.k = []
364 | self.v = []
365 | self.out = []
366 |
367 | def clear(self):
368 | self.q.clear()
369 | self.k.clear()
370 | self.v.clear()
371 | self.out.clear()
372 |
373 | def add(self, q, k, v, out):
374 | self.q.append(q)
375 | self.k.append(k)
376 | self.v.append(v)
377 | self.out.append(out)
378 |
379 | def get(self):
380 | return self.q.copy(), self.k.copy(), self.v.copy(), self.out.copy()
381 |
382 |
383 |
384 | # def show_image(path, title, display_height=3, title_fontsize=12):
385 | # img = Image.open(path)
386 | # img_width, img_height = img.size
387 |
388 | # aspect_ratio = img_width / img_height
389 | # display_width = display_height * aspect_ratio
390 |
391 | # plt.figure(figsize=(display_width, display_height))
392 | # plt.imshow(img)
393 | # plt.title(title,
394 | # fontsize=title_fontsize,
395 | # fontweight='bold',
396 | # pad=20)
397 | # plt.axis('off')
398 | # plt.tight_layout()
399 | # plt.show()
400 |
--------------------------------------------------------------------------------
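A minimal sketch of how `Controller`, `DataCache`, and `register_attn_control` cooperate: the UNet's attention forwards are patched so that one forward pass caches the (q, k, v, out) tensors of the selected self-attention layers, which are the features consumed by `ad_loss`. The checkpoint path is a placeholder.

```python
import torch
from diffusers import UNet2DConditionModel

from utils import Controller, DataCache, register_attn_control

# Placeholder path to a diffusers-format SD 1.5 checkpoint.
unet = UNet2DConditionModel.from_pretrained(
    "./models/stable-diffusion-v1-5", subfolder="unet"
)

controller = Controller(self_layers=(0, 16))  # capture the first 16 self-attention layers
cache = DataCache()
register_attn_control(unet, controller=controller, cache=cache)
print("self-attention layers found:", controller.num_self_layers)

latent = torch.randn(1, 4, 64, 64)            # a 512x512 image in latent space
t = torch.tensor([500])
text_embeds = torch.randn(1, 77, 768)         # stand-in for CLIP text embeddings

cache.clear()
controller.step()                             # reset the per-pass layer counter
with torch.no_grad():
    unet(latent, t, encoder_hidden_states=text_embeds)

q_list, k_list, v_list, out_list = cache.get()
print(len(q_list), "layers cached; q[0] shape:", tuple(q_list[0].shape))
```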
/workflows/style_t2i_generation_flux.json:
--------------------------------------------------------------------------------
1 | {"last_node_id":6,"last_link_id":5,"nodes":[{"id":6,"type":"ResizeImage","pos":[254.7171630859375,483.82080078125],"size":[315,58],"flags":{},"order":3,"mode":0,"inputs":[{"name":"image","type":"IMAGE","link":1,"localized_name":"image"}],"outputs":[{"name":"image","type":"IMAGE","links":[2],"slot_index":0,"localized_name":"image"}],"properties":{"Node name for S&R":"ResizeImage"},"widgets_values":[512]},{"id":5,"type":"PreviewImage","pos":[1068.2235107421875,130.05441284179688],"size":[540.3896484375,543.4026489257812],"flags":{},"order":5,"mode":0,"inputs":[{"name":"images","type":"IMAGE","link":5,"localized_name":"images"}],"outputs":[],"properties":{"Node name for S&R":"PreviewImage"},"widgets_values":[]},{"id":3,"type":"LoadPILImage","pos":[326.92474365234375,104.5999755859375],"size":[315,294],"flags":{},"order":0,"mode":0,"inputs":[],"outputs":[{"name":"image","type":"IMAGE","links":[3],"slot_index":0,"localized_name":"image"}],"title":"Style Reference","properties":{"Node name for S&R":"LoadPILImage"},"widgets_values":["40.jpg","image"]},{"id":2,"type":"LoadPILImage","pos":[680.6911010742188,103.04146575927734],"size":[315,294],"flags":{},"order":1,"mode":0,"inputs":[],"outputs":[{"name":"image","type":"IMAGE","links":[1],"slot_index":0,"localized_name":"image"}],"title":"Content Image","properties":{"Node name for S&R":"LoadPILImage"},"widgets_values":["lecun.png","image"]},{"id":1,"type":"LoadDistiller","pos":[256.7950439453125,609.0155029296875],"size":[315,82],"flags":{},"order":2,"mode":0,"inputs":[],"outputs":[{"name":"distiller","type":"DISTILLER","links":[4],"slot_index":0,"localized_name":"distiller"}],"properties":{"Node name for S&R":"LoadDistiller"},"widgets_values":["stable-diffusion-v1-5","bf16"]},{"id":4,"type":"ADOptimizer","pos":[620.4314575195312,465.638916015625],"size":[415.8000183105469,242],"flags":{},"order":4,"mode":0,"inputs":[{"name":"distiller","type":"DISTILLER","link":4,"localized_name":"distiller"},{"name":"content","type":"IMAGE","link":2,"localized_name":"content"},{"name":"style","type":"IMAGE","link":3,"localized_name":"style"}],"outputs":[{"name":"image","type":"IMAGE","links":[5],"slot_index":0,"localized_name":"image"}],"properties":{"Node name for S&R":"ADOptimizer"},"widgets_values":[300,0.23,0.05,512,512,2025,"fixed"]}],"links":[[1,2,0,6,0,"IMAGE"],[2,6,0,4,1,"IMAGE"],[3,3,0,4,2,"IMAGE"],[4,1,0,4,0,"DISTILLER"],[5,4,0,5,0,"IMAGE"]],"groups":[{"id":1,"title":"Attention Distillation for Style Transfer","bounding":[150.79759216308594,-23.581954956054688,1556.8001708984375,755.0857543945312],"color":"#3f789e","font_size":24,"flags":{}}],"config":{},"extra":{"ds":{"scale":0.9090909090909092,"offset":[116.54605349790273,20.61595758295544]},"node_versions":{"comfy-core":"0.3.12"}},"version":0.4}
--------------------------------------------------------------------------------
/workflows/style_t2i_generation_sd15.json:
--------------------------------------------------------------------------------
1 | {"last_node_id":6,"last_link_id":5,"nodes":[{"id":6,"type":"PreviewImage","pos":[1050.10546875,266.7724609375],"size":[529.428466796875,492.2856750488281],"flags":{},"order":5,"mode":0,"inputs":[{"name":"images","type":"IMAGE","link":5}],"outputs":[],"properties":{"Node name for S&R":"PreviewImage"},"widgets_values":[]},{"id":5,"type":"LoadPILImage","pos":[525.5338134765625,133.05821228027344],"size":[315,294],"flags":{},"order":0,"mode":0,"inputs":[],"outputs":[{"name":"image","type":"IMAGE","links":[1],"slot_index":0}],"title":"Style Reference","properties":{"Node name for S&R":"LoadPILImage"},"widgets_values":["40.jpg","image"]},{"id":3,"type":"PureText","pos":[25.32625961303711,601.0584106445312],"size":[400,200],"flags":{},"order":1,"mode":0,"inputs":[],"outputs":[{"name":"CONDITIONING","type":"CONDITIONING","links":[4],"slot_index":0}],"title":"Negative prompt","properties":{"Node name for S&R":"PureText"},"widgets_values":["blur, low quality"]},{"id":2,"type":"PureText","pos":[27.923683166503906,327.5516357421875],"size":[400,200],"flags":{},"order":2,"mode":0,"inputs":[],"outputs":[{"name":"CONDITIONING","type":"CONDITIONING","links":[3],"slot_index":0}],"title":"Positive prompt","properties":{"Node name for S&R":"PureText"},"widgets_values":["A portrait of Mr. Donald Trump"]},{"id":1,"type":"LoadDistiller","pos":[75.71586608886719,163.65560913085938],"size":[315,82],"flags":{},"order":3,"mode":0,"inputs":[],"outputs":[{"name":"distiller","type":"DISTILLER","links":[2],"slot_index":0}],"properties":{"Node name for S&R":"LoadDistiller"},"widgets_values":["stable-diffusion-v1-5","bf16"]},{"id":4,"type":"ADSampler","pos":[492.3912658691406,484.4866943359375],"size":[504,334],"flags":{},"order":4,"mode":0,"inputs":[{"name":"distiller","type":"DISTILLER","link":2},{"name":"style","type":"IMAGE","link":1},{"name":"positive","type":"CONDITIONING","link":3},{"name":"negative","type":"CONDITIONING","link":4}],"outputs":[{"name":"images","type":"IMAGE","links":[5],"slot_index":0}],"properties":{"Node name for S&R":"ADSampler"},"widgets_values":[50,0.015,2,7.5,0,1,2025,"increment",512,512]}],"links":[[1,5,0,4,1,"IMAGE"],[2,1,0,4,0,"DISTILLER"],[3,2,0,4,2,"CONDITIONING"],[4,3,0,4,3,"CONDITIONING"],[5,4,0,6,0,"IMAGE"]],"groups":[{"id":1,"title":"Attention Distillation for Style-Specific Text-to-Image Generation","bounding":[-50.282806396484375,25.119415283203125,1678.1817626953125,820.259765625],"color":"#3f789e","font_size":24,"flags":{}}],"config":{},"extra":{"ds":{"scale":1,"offset":[147.782927316091,-0.36286816241710085]},"node_versions":{"comfy-core":"0.3.12"}},"version":0.4}
--------------------------------------------------------------------------------
/workflows/style_t2i_generation_sdxl.json:
--------------------------------------------------------------------------------
1 | {"last_node_id":6,"last_link_id":5,"nodes":[{"id":5,"type":"LoadPILImage","pos":[642.10498046875,143.9153289794922],"size":[315,294],"flags":{},"order":0,"mode":0,"inputs":[],"outputs":[{"name":"image","type":"IMAGE","links":[1],"slot_index":0,"localized_name":"image"}],"title":"Style Reference","properties":{"Node name for S&R":"LoadPILImage"},"widgets_values":["1.png","image"]},{"id":1,"type":"LoadDistiller","pos":[192.28724670410156,174.5127410888672],"size":[315,82],"flags":{},"order":1,"mode":0,"inputs":[],"outputs":[{"name":"distiller","type":"DISTILLER","links":[2],"slot_index":0,"localized_name":"distiller"}],"properties":{"Node name for S&R":"LoadDistiller"},"widgets_values":["stable-diffusion-xl-base-1.0","bf16"]},{"id":3,"type":"PureText","pos":[141.8976287841797,611.9155883789062],"size":[400,200],"flags":{},"order":2,"mode":0,"inputs":[],"outputs":[{"name":"CONDITIONING","type":"CONDITIONING","links":[4],"slot_index":0,"localized_name":"CONDITIONING"}],"title":"Negative prompt","properties":{"Node name for S&R":"PureText"},"widgets_values":[""]},{"id":6,"type":"PreviewImage","pos":[1166.67724609375,277.6296081542969],"size":[529.428466796875,492.2856750488281],"flags":{},"order":5,"mode":0,"inputs":[{"name":"images","type":"IMAGE","link":5,"localized_name":"images"}],"outputs":[],"properties":{"Node name for S&R":"PreviewImage"},"widgets_values":[]},{"id":2,"type":"PureText","pos":[144.49505615234375,338.4087829589844],"size":[400,200],"flags":{},"order":3,"mode":0,"inputs":[],"outputs":[{"name":"CONDITIONING","type":"CONDITIONING","links":[3],"slot_index":0,"localized_name":"CONDITIONING"}],"title":"Positive prompt","properties":{"Node name for S&R":"PureText"},"widgets_values":["A photo of Big Ben, London."]},{"id":4,"type":"ADSampler","pos":[608.9625244140625,495.3438415527344],"size":[504,334],"flags":{},"order":4,"mode":0,"inputs":[{"name":"distiller","type":"DISTILLER","link":2,"localized_name":"distiller"},{"name":"style","type":"IMAGE","link":1,"localized_name":"style"},{"name":"positive","type":"CONDITIONING","link":3,"localized_name":"positive"},{"name":"negative","type":"CONDITIONING","link":4,"localized_name":"negative"}],"outputs":[{"name":"images","type":"IMAGE","links":[5],"slot_index":0,"localized_name":"images"}],"properties":{"Node name for S&R":"ADSampler"},"widgets_values":[50,0.015,2,7.5,1,2025,"fixed",1024,1024]}],"links":[[1,5,0,4,1,"IMAGE"],[2,1,0,4,0,"DISTILLER"],[3,2,0,4,2,"CONDITIONING"],[4,3,0,4,3,"CONDITIONING"],[5,4,0,6,0,"IMAGE"]],"groups":[{"id":1,"title":"Attention Distillation for Style-Specific Text-to-Image Generation (SDXL)","bounding":[66.2885971069336,35.97654342651367,1678.1817626953125,820.259765625],"color":"#88A","font_size":24,"flags":{}}],"config":{},"extra":{"ds":{"scale":0.9090909090909092,"offset":[116.54605349790273,20.61595758295544]},"node_versions":{"comfy-core":"0.3.12"}},"version":0.4}
--------------------------------------------------------------------------------
/workflows/style_transfer_sd15.json:
--------------------------------------------------------------------------------
1 | {"last_node_id":6,"last_link_id":5,"nodes":[{"id":6,"type":"ResizeImage","pos":[254.7171630859375,483.82080078125],"size":[315,58],"flags":{},"order":3,"mode":0,"inputs":[{"name":"image","type":"IMAGE","link":1}],"outputs":[{"name":"image","type":"IMAGE","links":[2],"slot_index":0}],"properties":{"Node name for S&R":"ResizeImage"},"widgets_values":[512]},{"id":5,"type":"PreviewImage","pos":[1068.2235107421875,130.05441284179688],"size":[540.3896484375,543.4026489257812],"flags":{},"order":5,"mode":0,"inputs":[{"name":"images","type":"IMAGE","link":5}],"outputs":[],"properties":{"Node name for S&R":"PreviewImage"}},{"id":3,"type":"LoadPILImage","pos":[326.92474365234375,104.5999755859375],"size":[315,294],"flags":{},"order":0,"mode":0,"inputs":[],"outputs":[{"name":"image","type":"IMAGE","links":[3],"slot_index":0}],"title":"Style Reference","properties":{"Node name for S&R":"LoadPILImage"},"widgets_values":["40.jpg","image"]},{"id":2,"type":"LoadPILImage","pos":[680.6911010742188,103.04146575927734],"size":[315,294],"flags":{},"order":1,"mode":0,"inputs":[],"outputs":[{"name":"image","type":"IMAGE","links":[1],"slot_index":0}],"title":"Content Image","properties":{"Node name for S&R":"LoadPILImage"},"widgets_values":["lecun.png","image"]},{"id":1,"type":"LoadDistiller","pos":[256.7950439453125,609.0155029296875],"size":[315,82],"flags":{},"order":2,"mode":0,"inputs":[],"outputs":[{"name":"distiller","type":"DISTILLER","links":[4],"slot_index":0}],"properties":{"Node name for S&R":"LoadDistiller"},"widgets_values":["stable-diffusion-v1-5","bf16"]},{"id":4,"type":"ADOptimizer","pos":[620.4314575195312,465.638916015625],"size":[415.8000183105469,242],"flags":{},"order":4,"mode":0,"inputs":[{"name":"distiller","type":"DISTILLER","link":4},{"name":"content","type":"IMAGE","link":2},{"name":"style","type":"IMAGE","link":3}],"outputs":[{"name":"image","type":"IMAGE","links":[5],"slot_index":0}],"properties":{"Node name for S&R":"ADOptimizer"},"widgets_values":[300,0.23,0.05,512,512,2025,"fixed"]}],"links":[[1,2,0,6,0,"IMAGE"],[2,6,0,4,1,"IMAGE"],[3,3,0,4,2,"IMAGE"],[4,1,0,4,0,"DISTILLER"],[5,4,0,5,0,"IMAGE"]],"groups":[{"id":1,"title":"Attention Distillation for Style Transfer","bounding":[150.79759216308594,-23.581954956054688,1556.8001708984375,755.0857543945312],"color":"#3f789e","font_size":24,"flags":{}}],"config":{},"extra":{"ds":{"scale":0.8264462809917354,"offset":[279.35430419292214,247.5915815675024]},"node_versions":{"comfy-core":"0.3.12"}},"version":0.4}
--------------------------------------------------------------------------------