├── LICENSE
├── README.md
├── demo
│   ├── anime_lora_sun.jpg
│   ├── controlnet_lora.py
│   └── util.py
├── pics
│   ├── boy.png
│   ├── cat.gif
│   ├── cat2.gif
│   ├── comparison.jpg
│   ├── dpm20.gif
│   ├── face_sun_small.gif
│   ├── guilin_sun_small.gif
│   └── sun4.gif
└── speed_up_net
    ├── __init__.py
    ├── attention_processor.py
    └── sun_pipe.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Ant Group
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # speedup-plugin-for-stable-diffusions
2 |
3 | This repo is the official PyTorch implementation of SpeedUpNet (SUN).
4 |
5 | Paper: [SpeedUpNet: A Plug-and-Play Hyper-Network for Accelerating Text-to-Image Diffusion Models](https://arxiv.org/pdf/2312.08887.pdf)
6 |
7 | Project Page: [SpeedUpNet](https://williechai.github.io/speedup-plugin-for-stable-diffusions.github.io/)
8 |
9 |
10 |
11 | ## 10x speed-up on Stable Diffusion
12 | With SUN plugged in, a pre-trained Stable Diffusion model generates high-quality images in only 4 steps. Measured on a MacBook Pro (M1 Pro):
13 |
14 | DPM-Solver++, 20 steps: 16 seconds (baseline)
15 |
16 |
17 | +SUN, 4 steps: 2 seconds
18 |
19 |
20 | See more on our [webpage](https://williechai.github.io/speedup-plugin-for-stable-diffusions.github.io/).
21 |
22 | ## Realtime Controllable Generation
23 |
24 | SUN is compatible with controllable-generation tools such as ControlNet.
25 | Real-time rendering can be achieved on high-end consumer-grade graphics cards.
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 | ## Usage
35 |
36 | ```
37 | cd demo
38 |
39 | # prepare models: base SD v1.5, the Canny ControlNet, and the SUN adapter
40 | # (paths are set at the top of controlnet_lora.py; adapter download link below)
41 | python controlnet_lora.py
42 | ```
43 |
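For reference, the core of `demo/controlnet_lora.py` is sketched below: build a standard diffusers ControlNet pipeline, wrap it with `SUNPipe`, and sample with only 4 steps. This is a condensed sketch; the model paths, device string, prompt, and output filename are placeholders to adapt to your setup, and the adapter file is the one linked in the next section.

```python
import torch
from PIL import Image
from diffusers import DPMSolverMultistepScheduler
from controlnet_aux import CannyDetector

import sys
sys.path.insert(0, "..")  # run from demo/ so that speed_up_net is importable
from speed_up_net.sun_pipe import SUNPipe
from util import get_stable_diffusion_controlnet_pipe  # demo/util.py

# standard diffusers ControlNet pipeline (fp16, "trailing" timestep spacing)
pipe = get_stable_diffusion_controlnet_pipe(
    base_model_path="../models/sd_v15",                      # placeholder path
    controlnet_model_path="../models/sd-controlnet-canny",   # placeholder path
    scheluler_cls=DPMSolverMultistepScheduler,
).to("mps")  # or "cuda"

# wrap the pipeline with the downloaded SUN adapter; add_pos=True matches the released weights
pipe = SUNPipe(pipe, "../test_cache/sun_adapter_sdv15_4step_addpos.safetensors", add_pos=True)

# any Canny edge map works as the ControlNet condition
control_image = CannyDetector()(Image.open("../pics/boy.png").resize((512, 512)))

image = pipe(
    prompt="closeup face photo of a young man in a sweater",
    negative_prompt="worst quality, low quality, blurry",
    num_inference_steps=4,               # 4 steps are enough with SUN
    image=control_image,                 # ControlNet conditioning image
    controlnet_conditioning_scale=0.75,
    generator=torch.manual_seed(0),
).images[0]
image.save("result.jpg")
```

The same `SUNPipe` wrapper also works with the plain `StableDiffusionPipeline` and with the img2img and inpainting pipelines; see `speed_up_net/sun_pipe.py` for the supported pipeline types.
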
44 | ## Download SUN adapter
45 |
46 | https://huggingface.co/Williechai/SpeedUpNet/tree/main
47 |
48 | ## Update
49 |
50 | **`2023.12.15`**: README released.
51 |
52 |
53 | ## Citation
54 | If you find this work helpful in your research, please cite:
55 | ```
56 | @misc{chai2023speedupnet,
57 | title={SpeedUpNet: A Plug-and-Play Hyper-Network for Accelerating Text-to-Image Diffusion Models},
58 | author={Weilong Chai and DanDan Zheng and Jiajiong Cao and Zhiquan Chen and Changbao Wang and Chenguang Ma},
59 | year={2023},
60 | eprint={2312.08887},
61 | archivePrefix={arXiv},
62 | primaryClass={cs.CV}
63 | }
64 | ```
65 |
66 | ## Contact
67 | If you have any questions, feel free to open an issue or directly contact me via: `weilong.cwl@antgroup.com`.
68 |
--------------------------------------------------------------------------------
/demo/anime_lora_sun.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/williechai/speedup-plugin-for-stable-diffusions/be6f39dfdbc3a6454c92d0c12ee2209416175c8e/demo/anime_lora_sun.jpg
--------------------------------------------------------------------------------
/demo/controlnet_lora.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from diffusers import DPMSolverMultistepScheduler
3 | from tqdm.auto import tqdm
4 | from transformers import CLIPTokenizer, CLIPTextModel
5 | from PIL import Image
6 | import numpy as np
7 | from IPython import embed
8 | from diffusers.image_processor import VaeImageProcessor
9 | from diffusers import ControlNetModel
10 | import time
11 |
12 | from controlnet_aux import CannyDetector, OpenposeDetector
13 | import cv2
14 |
15 | import sys
16 | sys.path.insert(0, '..')
17 | from speed_up_net.sun_pipe import SUNPipe
18 |
19 | from util import get_stable_diffusion_controlnet_pipe, pil_up_crop_square
20 |
21 |
22 | weight_dtype = torch.float16
23 | torch_device = "mps"
24 |
25 | # path to the base SD v1.5 model
26 | model_path = "../models/sd_v15"
27 | # optional style UNet (not used by default; pass it as unet_model_path below to enable)
28 | # https://huggingface.co/SG161222/Realistic_Vision_V5.1_noVAE
29 | style_unet_model_path = "../models/Realistic_Vision_V5.1_noVAE"
30 | # ControlNet (Canny) model
31 | controlnet_model_path = "../models/sd-controlnet-canny"
32 |
33 | pipe = get_stable_diffusion_controlnet_pipe(
34 | base_model_path=model_path,
35 | controlnet_model_path=controlnet_model_path,
36 |     unet_model_path=None,  # or style_unet_model_path to swap in the style UNet
37 | scheluler_cls=DPMSolverMultistepScheduler,
38 | ).to(torch_device)
39 |
40 | # Download the SUN adapter as described in README.md ("Download SUN adapter")
41 | adapter_path = "../test_cache/sun_adapter_sdv15_4step_addpos.safetensors"
42 | pipe = SUNPipe(pipe, adapter_path, add_pos=True)
43 |
44 | # https://civitai.com/models/24833/minimalist-anime-style
45 | lora_path = "../models/anime_minimalist_v1-000020.safetensors"
46 |
47 | if lora_path:
48 | pipe.load_lora_weights(lora_path)
49 |
50 | canny_detecter = CannyDetector()
51 |
52 | img_pil = pil_up_crop_square(Image.open("../pics/boy.png")).resize((512, 512))
53 |
54 | default_negative_prompt = "worst quality, low quality, blurry, bad hand, watermark, multiple limbs, deformed fingers, bad fingers, ugly, monochrome, horror, geometry, bad anatomy, bad limbs, Blurry pupil, bad shading, error, bad composition, Extra fingers, strange fingers, Extra ears, extra leg, bad leg, disability, Blurry eyes, bad eyes, Twisted body, confusion, bad legs"
55 | prompt = "anime minimalist, 1 20 y.o man, solo, closeup face photo in sweater, cleavage, pale skin"
56 | negative_prompt = default_negative_prompt
57 |
58 |
59 |
60 |
61 | with torch.no_grad():
62 | ctn_img = canny_detecter(img_pil)
63 | image = pipe(
64 | prompt=prompt,
65 | negative_prompt=negative_prompt,
66 | num_inference_steps=4,
67 | generator=torch.manual_seed(0),
68 | eta=0.0,
69 | image=ctn_img,
70 | controlnet_conditioning_scale=0.75,
71 | ).images[0]
72 |
73 |
74 | image.save("anime_lora_sun.jpg")
75 |
--------------------------------------------------------------------------------
/demo/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | from speed_up_net.sun_pipe import SUNPipe
3 |
4 | from diffusers import DDIMScheduler, UNet2DConditionModel, ControlNetModel
5 | from diffusers import StableDiffusionPipeline
6 | from diffusers import StableDiffusionControlNetPipeline
7 | from diffusers import StableDiffusionImg2ImgPipeline
8 | from diffusers import StableDiffusionInpaintPipeline
9 | from PIL import Image
10 | import cv2
11 | import numpy as np
12 |
13 | import torch
14 |
15 |
16 | def get_stable_diffusion_pipe(
17 | base_model_path,
18 | unet_model_path=None,
19 | scheluler_cls=DDIMScheduler
20 | ):
21 |
22 | pipe = StableDiffusionPipeline.from_pretrained(
23 | base_model_path, safety_checker=None)
24 | pipe.scheduler = scheluler_cls.from_config(pipe.scheduler.config)
25 |     pipe.scheduler.register_to_config(timestep_spacing="trailing")  # SUN expects "trailing" timestep spacing (SUNPipe sets this as well)
26 |
27 | torch_dtype = pipe.unet.dtype
28 | torch_device = pipe.unet.device
29 | if unet_model_path:
30 | unet = UNet2DConditionModel.from_pretrained(
31 | unet_model_path, subfolder="unet",
32 |             use_safetensors=not os.path.exists(os.path.join(unet_model_path, "unet", "diffusion_pytorch_model.bin"))
33 | ).to(torch_dtype).to(torch_device)
34 | pipe.unet = unet
35 |
36 | return pipe
37 |
38 | def get_stable_diffusion_img2img(
39 | base_model_path,
40 | unet_model_path=None,
41 | scheluler_cls=DDIMScheduler
42 | ):
43 |
44 | pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
45 | base_model_path, safety_checker=None, torch_dtype=torch.float16)
46 | pipe.scheduler = scheluler_cls.from_config(pipe.scheduler.config)
47 | pipe.scheduler.register_to_config(timestep_spacing="trailing")
48 | pipe.vae.to(torch.float16)
49 |
50 | torch_dtype = pipe.unet.dtype
51 | torch_device = pipe.unet.device
52 | if unet_model_path:
53 | unet = UNet2DConditionModel.from_pretrained(
54 | unet_model_path, subfolder="unet",
55 |             use_safetensors=not os.path.exists(os.path.join(unet_model_path, "unet", "diffusion_pytorch_model.bin"))
56 | ).to(torch_dtype).to(torch_device)
57 | pipe.unet = unet
58 |
59 | return pipe
60 |
61 | def get_stable_diffusion_controlnet_pipe(
62 | base_model_path,
63 | controlnet_model_path,
64 | unet_model_path=None,
65 | scheluler_cls=DDIMScheduler
66 | ):
67 | #controlnet_model_path = os.path.join(model_root, "models--lllyasviel--control_v11p_sd15_openpose/")
68 | controlnet = ControlNetModel.from_pretrained(
69 | controlnet_model_path
70 | )
71 |
72 | pipe = StableDiffusionControlNetPipeline.from_pretrained(
73 | base_model_path, controlnet=controlnet, safety_checker=None)
74 |
75 | pipe.to(torch.float16)
76 |
77 | pipe.scheduler = scheluler_cls.from_config(pipe.scheduler.config)
78 | pipe.scheduler.register_to_config(timestep_spacing="trailing")
79 |
80 | torch_dtype = pipe.unet.dtype
81 | torch_device = pipe.unet.device
82 | if unet_model_path:
83 | unet = UNet2DConditionModel.from_pretrained(
84 | unet_model_path, subfolder="unet",
85 |             use_safetensors=not os.path.exists(os.path.join(unet_model_path, "unet", "diffusion_pytorch_model.bin"))
86 | ).to(torch_dtype).to(torch_device)
87 | pipe.unet = unet
88 |
89 | return pipe
90 |
91 | def get_stable_diffusion_inpainting_pipe(
92 | base_model_path,
93 | scheluler_cls=DDIMScheduler
94 | ):
95 |
96 | pipe = StableDiffusionInpaintPipeline.from_pretrained(
97 | base_model_path, safety_checker=None)
98 | pipe.scheduler = scheluler_cls.from_config(pipe.scheduler.config)
99 | pipe.scheduler.register_to_config(timestep_spacing="trailing")
100 |
101 | #torch_dtype = pipe.unet.dtype
102 | #torch_device = pipe.unet.device
103 | return pipe
104 |
105 | def pil_up_crop_square(img):
106 | size = img.size
107 | short_size = size[1] if size[0] > size[1] else size[0]
108 | crop_left = (size[0] - short_size) // 2
109 | crop_right = crop_left + short_size
110 | crop_up = 0
111 | crop_down = crop_up + short_size
112 | return img.crop((crop_left, crop_up, crop_right, crop_down))
113 |
114 | def pil_center_crop_square(img, ratio=1.0):
115 | size = img.size
116 | short_size = size[1] if size[0] > size[1] else size[0]
117 | center_size = int(short_size * ratio)
118 | crop_left = (size[0] - center_size) // 2
119 | crop_right = crop_left + center_size
120 | crop_up = (size[1] - center_size) // 2
121 | crop_down = crop_up + center_size
122 | return img.crop((crop_left, crop_up, crop_right, crop_down))
123 |
124 | def cv2_center_crop_square(img, ratio=1.0):
125 | size = (img.shape[1], img.shape[0])
126 | short_size = size[1] if size[0] > size[1] else size[0]
127 | center_size = int(short_size * ratio)
128 | crop_left = (size[0] - center_size) // 2
129 | crop_right = crop_left + center_size
130 | crop_up = (size[1] - center_size) // 2
131 | crop_down = crop_up + center_size
132 | #return img.crop((crop_left, crop_up, crop_right, crop_down))
133 | return img[crop_up:crop_down, crop_left:crop_right]
134 |
135 | def tensor_to_pil(tensor):
136 | image = (tensor / 2 + 0.5).clamp(0, 1).squeeze()
137 | image = (image.permute(1, 2, 0) * 255).round().to(torch.uint8).cpu().numpy()
138 | image = Image.fromarray(image)
139 | return image
140 |
141 | def tensor_to_cv2(tensor):
142 |     image = ((tensor / 2 + 0.5).clamp(0, 1) * 255).round().squeeze().to(torch.uint8)
143 |     image = image.cpu().permute(1, 2, 0).numpy()
144 |     image = image[:, :, ::-1]  # RGB -> BGR for OpenCV
145 |     return image
146 |
147 | def cv2pil(mat):
148 | return Image.fromarray(cv2.cvtColor(mat, cv2.COLOR_BGR2RGB))
149 |
150 | def pil2cv2(img):
151 |     return cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)  # PIL is RGB, OpenCV expects BGR
152 |
153 | def pil_to_tensor(img):
154 | return torch.from_numpy((np.asarray(img).astype(np.float32) / 255)).permute(2, 0, 1)
155 |
156 |
157 |
158 |
--------------------------------------------------------------------------------
/pics/boy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/williechai/speedup-plugin-for-stable-diffusions/be6f39dfdbc3a6454c92d0c12ee2209416175c8e/pics/boy.png
--------------------------------------------------------------------------------
/pics/cat.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/williechai/speedup-plugin-for-stable-diffusions/be6f39dfdbc3a6454c92d0c12ee2209416175c8e/pics/cat.gif
--------------------------------------------------------------------------------
/pics/cat2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/williechai/speedup-plugin-for-stable-diffusions/be6f39dfdbc3a6454c92d0c12ee2209416175c8e/pics/cat2.gif
--------------------------------------------------------------------------------
/pics/comparison.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/williechai/speedup-plugin-for-stable-diffusions/be6f39dfdbc3a6454c92d0c12ee2209416175c8e/pics/comparison.jpg
--------------------------------------------------------------------------------
/pics/dpm20.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/williechai/speedup-plugin-for-stable-diffusions/be6f39dfdbc3a6454c92d0c12ee2209416175c8e/pics/dpm20.gif
--------------------------------------------------------------------------------
/pics/face_sun_small.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/williechai/speedup-plugin-for-stable-diffusions/be6f39dfdbc3a6454c92d0c12ee2209416175c8e/pics/face_sun_small.gif
--------------------------------------------------------------------------------
/pics/guilin_sun_small.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/williechai/speedup-plugin-for-stable-diffusions/be6f39dfdbc3a6454c92d0c12ee2209416175c8e/pics/guilin_sun_small.gif
--------------------------------------------------------------------------------
/pics/sun4.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/williechai/speedup-plugin-for-stable-diffusions/be6f39dfdbc3a6454c92d0c12ee2209416175c8e/pics/sun4.gif
--------------------------------------------------------------------------------
/speed_up_net/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/speed_up_net/attention_processor.py:
--------------------------------------------------------------------------------
1 | # modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | def is_torch2_available():
7 | return hasattr(F, "scaled_dot_product_attention")
8 |
9 |
10 |
11 | class AttnProcessor(nn.Module):
12 | r"""
13 | Default processor for performing attention-related computations.
14 | """
15 | def __init__(
16 | self,
17 | hidden_size=None,
18 | cross_attention_dim=None,
19 | ):
20 | super().__init__()
21 |
22 | def __call__(
23 | self,
24 | attn,
25 | hidden_states,
26 | encoder_hidden_states=None,
27 | attention_mask=None,
28 | temb=None,
29 | ):
30 | residual = hidden_states
31 |
32 | if attn.spatial_norm is not None:
33 | hidden_states = attn.spatial_norm(hidden_states, temb)
34 |
35 | input_ndim = hidden_states.ndim
36 |
37 | if input_ndim == 4:
38 | batch_size, channel, height, width = hidden_states.shape
39 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
40 |
41 | batch_size, sequence_length, _ = (
42 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
43 | )
44 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
45 |
46 | if attn.group_norm is not None:
47 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
48 |
49 | query = attn.to_q(hidden_states)
50 |
51 | if encoder_hidden_states is None:
52 | encoder_hidden_states = hidden_states
53 | elif attn.norm_cross:
54 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
55 |
56 | key = attn.to_k(encoder_hidden_states)
57 | value = attn.to_v(encoder_hidden_states)
58 |
59 | query = attn.head_to_batch_dim(query)
60 | key = attn.head_to_batch_dim(key)
61 | value = attn.head_to_batch_dim(value)
62 |
63 | attention_probs = attn.get_attention_scores(query, key, attention_mask)
64 | hidden_states = torch.bmm(attention_probs, value)
65 | hidden_states = attn.batch_to_head_dim(hidden_states)
66 |
67 | # linear proj
68 | hidden_states = attn.to_out[0](hidden_states)
69 | # dropout
70 | hidden_states = attn.to_out[1](hidden_states)
71 |
72 | if input_ndim == 4:
73 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
74 |
75 | if attn.residual_connection:
76 | hidden_states = hidden_states + residual
77 |
78 | hidden_states = hidden_states / attn.rescale_output_factor
79 |
80 | return hidden_states
81 |
82 |
83 | def compute_attention_v1(attn, query, key, value, attention_mask):
84 |
85 | query = attn.head_to_batch_dim(query)
86 | key = attn.head_to_batch_dim(key)
87 | value = attn.head_to_batch_dim(value)
88 |
89 | attention_probs = attn.get_attention_scores(query, key, attention_mask)
90 | hidden_states = torch.bmm(attention_probs, value)
91 | hidden_states = attn.batch_to_head_dim(hidden_states)
92 | return hidden_states
93 |
94 | def compute_attention_v2(attn, query, key, value, attention_mask):
95 | batch_size = query.shape[0]
96 | inner_dim = key.shape[-1]
97 | head_dim = inner_dim // attn.heads
98 |
99 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
100 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
101 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
102 |
103 | # the output of sdp = (batch, num_heads, seq_len, head_dim)
104 | # TODO: add support for attn.scale when we move to Torch 2.1
105 | hidden_states = F.scaled_dot_product_attention(
106 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
107 | )
108 |
109 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
110 | hidden_states = hidden_states.to(query.dtype)
111 |
112 | return hidden_states
113 |
114 |
115 | from diffusers.models.lora import LoRACompatibleLinear
116 |
117 | class SUNAttnProcessor(nn.Module):
118 | r"""
119 | Attention processor for SUN.
120 | Args:
121 | hidden_size (`int`):
122 | The hidden size of the attention layer.
123 | cross_attention_dim (`int`):
124 | The number of channels in the `encoder_hidden_states`.
125 | num_tokens (`int`, defaults to 77):
126 | The context length of the original text-encoder features.
127 | add_pos (`bool`, defaults to False):
128 | Whether to use extra parameters for positive prompt.
129 | """
130 |
131 | def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=77, param_scale=False, add_pos=False):
132 | super().__init__()
133 |
134 | self.hidden_size = hidden_size
135 | self.cross_attention_dim = cross_attention_dim
136 | self.num_tokens = num_tokens
137 | self.add_pos = add_pos
138 |
139 | self.to_k_neg = LoRACompatibleLinear(cross_attention_dim or hidden_size, hidden_size, bias=False)
140 | self.to_v_neg = LoRACompatibleLinear(cross_attention_dim or hidden_size, hidden_size, bias=False)
141 | self.to_k_neg_alpha = nn.Parameter(torch.ones(hidden_size))
142 | self.to_k_neg_beta = nn.Parameter(torch.zeros(hidden_size))
143 |
144 | if self.add_pos:
145 | self.to_k_extra_pos = LoRACompatibleLinear(cross_attention_dim or hidden_size, hidden_size, bias=False)
146 | self.to_v_extra_pos = LoRACompatibleLinear(cross_attention_dim or hidden_size, hidden_size, bias=False)
147 | self.to_k_extra_pos_alpha = nn.Parameter(torch.ones(hidden_size))
148 | self.to_k_extra_pos_beta = nn.Parameter(torch.zeros(hidden_size))
149 |
150 | if param_scale:
151 | self.scale = nn.Parameter(torch.Tensor([scale]).float())
152 | else:
153 | self.scale = scale
154 |
155 | def set_cfg_scale(self, scale):
156 | g = (scale - 1) / scale
157 | self.scale = g
158 |
159 | def __call__(
160 | self,
161 | attn,
162 | hidden_states,
163 | encoder_hidden_states=None,
164 | attention_mask=None,
165 | temb=None,
166 | ):
167 | residual = hidden_states
168 |
169 | if attn.spatial_norm is not None:
170 | hidden_states = attn.spatial_norm(hidden_states, temb)
171 |
172 | input_ndim = hidden_states.ndim
173 |
174 | if input_ndim == 4:
175 | batch_size, channel, height, width = hidden_states.shape
176 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
177 |
178 | batch_size, sequence_length, _ = (
179 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
180 | )
181 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
182 |
183 | if attn.group_norm is not None:
184 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
185 |
186 | query = attn.to_q(hidden_states)
187 |
188 | if encoder_hidden_states is None:
189 | encoder_hidden_states = hidden_states
190 | else:
191 | end_pos = encoder_hidden_states.shape[1] - self.num_tokens
192 | encoder_hidden_states, neg_hidden_states = encoder_hidden_states[:, :end_pos, :], encoder_hidden_states[:, end_pos:, :]
193 | if attn.norm_cross:
194 | assert False
195 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
196 |
197 | key = attn.to_k(encoder_hidden_states)
198 | value = attn.to_v(encoder_hidden_states)
199 |
200 | neg_key = self.to_k_neg(neg_hidden_states)
201 | neg_value = self.to_v_neg(neg_hidden_states)
202 |
203 | if self.add_pos:
204 | pos_key = self.to_k_extra_pos(encoder_hidden_states)
205 | pos_value = self.to_v_extra_pos(encoder_hidden_states)
206 |
207 | if is_torch2_available():
208 | # print('v2, attn_mask:{}'.format(attention_mask))
209 | hidden_states = compute_attention_v2(attn, query, key, value, attention_mask)
210 | neg_hidden_states = compute_attention_v2(attn, query, neg_key, neg_value, None)
211 | if self.add_pos:
212 | pos_hidden_states = compute_attention_v2(attn, query, pos_key, pos_value, attention_mask)
213 | else:
214 | # print('v1, attn_mask:{}'.format(attention_mask))
215 | hidden_states = compute_attention_v1(attn, query, key, value, attention_mask)
216 | neg_hidden_states = compute_attention_v1(attn, query, neg_key, neg_value, None)
217 | if self.add_pos:
218 | pos_hidden_states = compute_attention_v1(attn, query, pos_key, pos_value, attention_mask)
219 | if torch.is_tensor(self.scale):
220 | shape = [-1] + [1] * (neg_hidden_states.ndim - 1)
221 | scale = self.scale.view(*shape)
222 | else:
223 | scale = self.scale
224 |
225 |         # rescale the negative branch to the text branch's norm, then apply a learned per-channel affine
226 |         origin_norm = torch.norm(hidden_states, p=2)
227 |         neg_norm = torch.norm(neg_hidden_states, p=2)
228 |         neg_hidden_states = origin_norm * (neg_hidden_states / neg_norm)
229 |         neg_hidden_states = self.to_k_neg_alpha * neg_hidden_states + self.to_k_neg_beta
230 |
231 | if self.add_pos:
232 | pos_norm = torch.norm(pos_hidden_states, p=2)
233 | pos_hidden_states = origin_norm * (pos_hidden_states / pos_norm)
234 | pos_hidden_states = self.to_k_extra_pos_alpha * pos_hidden_states + self.to_k_extra_pos_beta
235 |
236 | hidden_states = hidden_states - neg_hidden_states + pos_hidden_states
237 |
238 | # linear proj
239 | hidden_states = attn.to_out[0](hidden_states)
240 | # dropout
241 | hidden_states = attn.to_out[1](hidden_states)
242 |
243 | if input_ndim == 4:
244 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
245 |
246 | if attn.residual_connection:
247 | hidden_states = hidden_states + residual
248 |
249 | hidden_states = hidden_states / attn.rescale_output_factor
250 |
251 | return hidden_states
252 |
253 | class AttnProcessor2_0(torch.nn.Module):
254 | r"""
255 | Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
256 | """
257 | def __init__(
258 | self,
259 | hidden_size=None,
260 | cross_attention_dim=None,
261 | ):
262 | super().__init__()
263 | if not hasattr(F, "scaled_dot_product_attention"):
264 | raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
265 |
266 | def __call__(
267 | self,
268 | attn,
269 | hidden_states,
270 | encoder_hidden_states=None,
271 | attention_mask=None,
272 | temb=None,
273 | ):
274 | residual = hidden_states
275 |
276 | if attn.spatial_norm is not None:
277 | hidden_states = attn.spatial_norm(hidden_states, temb)
278 |
279 | input_ndim = hidden_states.ndim
280 |
281 | if input_ndim == 4:
282 | batch_size, channel, height, width = hidden_states.shape
283 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
284 |
285 | batch_size, sequence_length, _ = (
286 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
287 | )
288 |
289 | if attention_mask is not None:
290 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
291 | # scaled_dot_product_attention expects attention_mask shape to be
292 | # (batch, heads, source_length, target_length)
293 | attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
294 |
295 | if attn.group_norm is not None:
296 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
297 |
298 | query = attn.to_q(hidden_states)
299 |
300 | if encoder_hidden_states is None:
301 | encoder_hidden_states = hidden_states
302 | elif attn.norm_cross:
303 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
304 |
305 | key = attn.to_k(encoder_hidden_states)
306 | value = attn.to_v(encoder_hidden_states)
307 |
308 | inner_dim = key.shape[-1]
309 | head_dim = inner_dim // attn.heads
310 |
311 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
312 |
313 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
314 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
315 |
316 | # the output of sdp = (batch, num_heads, seq_len, head_dim)
317 | # TODO: add support for attn.scale when we move to Torch 2.1
318 | hidden_states = F.scaled_dot_product_attention(
319 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
320 | )
321 |
322 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
323 | hidden_states = hidden_states.to(query.dtype)
324 |
325 | # linear proj
326 | hidden_states = attn.to_out[0](hidden_states)
327 | # dropout
328 | hidden_states = attn.to_out[1](hidden_states)
329 |
330 | if input_ndim == 4:
331 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
332 |
333 | if attn.residual_connection:
334 | hidden_states = hidden_states + residual
335 |
336 | hidden_states = hidden_states / attn.rescale_output_factor
337 |
338 | return hidden_states
339 |
340 |
341 | SUNAttnProcessor2_0 = SUNAttnProcessor  # SUNAttnProcessor already dispatches between the torch 1.x and 2.0 attention paths internally
342 |
--------------------------------------------------------------------------------
/speed_up_net/sun_pipe.py:
--------------------------------------------------------------------------------
1 | from typing import Union, List, Optional
2 | import torch
3 | import PIL
4 | import numpy as np
5 | import os
6 | import logging
7 | import json
8 | #from .attention_processor import is_torch2_available
9 | from peft import LoraModel, LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
10 | from safetensors.torch import load_file as safe_load
11 |
12 | import torch.nn.functional as F
13 | def is_torch2_available():
14 | return hasattr(F, "scaled_dot_product_attention")
15 |
16 | if is_torch2_available():
17 | from .attention_processor import AttnProcessor2_0 as AttnProcessor
18 | else:
19 | from .attention_processor import AttnProcessor
20 |
21 | from .attention_processor import SUNAttnProcessor
22 | from IPython import embed
23 |
24 | from diffusers import (
25 | StableDiffusionControlNetInpaintPipeline, StableDiffusionInpaintPipeline,
26 | StableDiffusionControlNetPipeline, StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
27 | )
28 |
29 | def _build_adapter(unet, ip_param_scale=False, add_pos=False):
30 | attn_procs = {}
31 |
32 | unet_sd = unet.state_dict()
33 | for name in unet.attn_processors.keys():
34 | cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
35 | if name.startswith("mid_block"):
36 | hidden_size = unet.config.block_out_channels[-1]
37 | elif name.startswith("up_blocks"):
38 | block_id = int(name[len("up_blocks.")])
39 | hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
40 | elif name.startswith("down_blocks"):
41 | block_id = int(name[len("down_blocks.")])
42 | hidden_size = unet.config.block_out_channels[block_id]
43 | if cross_attention_dim is None:
44 | attn_procs[name] = AttnProcessor()
45 | else:
46 | attn_procs[name] = SUNAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, num_tokens=77, param_scale=ip_param_scale, add_pos=add_pos)
47 |
48 | unet.set_attn_processor(attn_procs)
49 | return unet
50 |
51 | def _load_adapter(adapt_unet, adapter_path):
52 | if adapter_path.endswith("safetensors"):
53 | state_dict = safe_load(adapter_path, "cpu")
54 | else:
55 | state_dict = torch.load(adapter_path, map_location='cpu')
56 | adapter_modules = torch.nn.ModuleList(adapt_unet.attn_processors.values())
57 | adapter_modules.load_state_dict(state_dict)
58 | return adapt_unet
59 |
60 |
61 | def _load_lora_dir_format(pipe, lora_dir):
62 | torch_device = pipe.unet.device
63 | torch_dtype = pipe.unet.dtype
64 |
65 | loraModelPath = os.path.join(lora_dir, "lora.pt")
66 | loraJsonPath = os.path.join(lora_dir, "lora_config.json")
67 |
68 | if not os.path.exists(loraJsonPath) or not os.path.exists(loraModelPath):
69 |         logging.error(f"{lora_dir}: LoRA files not found")
70 | exit(27)
71 |
72 | with open(loraJsonPath, "r") as f:
73 | lora_config = json.load(f)
74 |
75 | lora_checkpoint = torch.load(loraModelPath, map_location="cpu")
76 |
77 | unet_lora = {k: v for k, v in lora_checkpoint.items() if "text_encoder_" not in k}
78 | text_encoder_lora = {k.replace("text_encoder_", ""): v for k, v in lora_checkpoint.items() if "text_encoder_" in k}
79 |
80 | unet_config = LoraConfig(**lora_config["peft_config"])
81 | pipe.unet = LoraModel(unet_config, pipe.unet).to(torch_device).to(torch_dtype)
82 | set_peft_model_state_dict(pipe.unet, unet_lora)
83 |
84 | if "text_encoder_peft_config" in lora_config:
85 | text_encoder_config = LoraConfig(**lora_config["text_encoder_peft_config"])
86 | pipe.text_encoder = LoraModel(text_encoder_config, pipe.text_encoder).to(torch_device).to(torch_dtype)
87 | set_peft_model_state_dict(pipe.text_encoder, text_encoder_lora)
88 |
89 |
90 | class SUNPipe:
91 | def __init__(self, sd_pipe, sun_adapter_path, add_pos=False):
92 | self.pipe = sd_pipe
93 | unet = self.pipe.unet
94 | torch_device = unet.device
95 | torch_dtype = unet.dtype
96 | unet = _build_adapter(unet, ip_param_scale=False, add_pos=add_pos).to(torch_device, torch_dtype)
97 | unet = _load_adapter(unet, sun_adapter_path).to(torch_device, torch_dtype)
98 | self.pipe.unet = unet
99 | self.torch_device = torch_device
100 | self.pipe.scheduler.register_to_config(timestep_spacing="trailing")
101 |
102 | def load_lora_weights(self, lora_path):
103 | if lora_path.endswith(".safetensors"):
104 | self.pipe.load_lora_weights(lora_path)
105 | elif os.path.isdir(lora_path):
106 | pt_path = os.path.join(lora_path, "lora.pt")
107 | json_path = os.path.join(lora_path, "lora_config.json")
108 | assert os.path.exists(pt_path)
109 | assert os.path.exists(json_path)
110 | #assert False, "not supported yet"
111 | _load_lora_dir_format(self.pipe, lora_path)
112 | else:
113 | raise ValueError("lora path unrecognized")
114 |
115 | @torch.no_grad()
116 | def __call__(
117 | self,
118 | prompt: Union[str, List[str]] = None,
119 | image: Union[torch.Tensor, PIL.Image.Image] = None,
120 | mask_image: Union[torch.Tensor, PIL.Image.Image] = None,
121 | control_image: Union[
122 | torch.FloatTensor,
123 | PIL.Image.Image,
124 | np.ndarray,
125 | List[torch.FloatTensor],
126 | List[PIL.Image.Image],
127 | List[np.ndarray],
128 | ] = None,
129 | height: Optional[int] = None,
130 | width: Optional[int] = None,
131 | strength: float = 1.0,
132 | num_inference_steps: int = 50,
133 | negative_prompt: Optional[Union[str, List[str]]] = None,
134 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
135 | controlnet_conditioning_scale: Union[float, List[float]] = 0.5,
136 | eta: float = 0.0,
137 | prompt_embeds: Optional[torch.FloatTensor] = None,
138 | negative_prompt_embeds: Optional[torch.FloatTensor] = None,
139 | ):
140 | if prompt is not None and isinstance(prompt, str):
141 | bsz = 1
142 | elif prompt is not None and isinstance(prompt, list):
143 | bsz = len(prompt)
144 | else:
145 | bsz = prompt_embeds.shape[0]
146 |
147 | with torch.inference_mode():
148 | pos_neg_embed_cat_in_dim0 = self.pipe._encode_prompt(
149 | prompt,
150 | device=self.torch_device,
151 | num_images_per_prompt=1,
152 | do_classifier_free_guidance=True,
153 | negative_prompt=negative_prompt,
154 | prompt_embeds=prompt_embeds,
155 | negative_prompt_embeds=negative_prompt_embeds
156 | )
157 | assert pos_neg_embed_cat_in_dim0.size(0) == bsz * 2
158 |             neg_embed, pos_embed = pos_neg_embed_cat_in_dim0.chunk(2, dim=0)  # diffusers returns [negative, positive] along the batch dim
159 |             pos_neg_embed_cat_in_dim1 = torch.cat([pos_embed, neg_embed], dim=1)  # pack along the token dim; SUNAttnProcessor splits them again
160 |             assert pos_neg_embed_cat_in_dim1.size(0) == bsz
161 |             assert pos_neg_embed_cat_in_dim1.size(1) == 77 * 2  # 77 positive + 77 negative tokens
162 |
163 | if isinstance(self.pipe, StableDiffusionControlNetInpaintPipeline):
164 |
165 | images = self.pipe(
166 | image=image,
167 | mask_image=mask_image,
168 | control_image=control_image,
169 | height=height,
170 | width=width,
171 | strength=strength,
172 | num_inference_steps=num_inference_steps,
173 | prompt_embeds=pos_neg_embed_cat_in_dim1,
174 | guidance_scale=0.0,
175 | generator=generator,
176 | controlnet_conditioning_scale=controlnet_conditioning_scale,
177 | eta=eta
178 | )
179 |
180 | elif isinstance(self.pipe, StableDiffusionInpaintPipeline):
181 | assert control_image is None
182 | images = self.pipe(
183 | image=image,
184 | mask_image=mask_image,
185 | height=height,
186 | width=width,
187 | strength=strength,
188 | num_inference_steps=num_inference_steps,
189 | prompt_embeds=pos_neg_embed_cat_in_dim1,
190 | guidance_scale=0.0,
191 | generator=generator,
192 | eta=eta
193 | )
194 | elif isinstance(self.pipe, StableDiffusionControlNetPipeline):
195 | assert control_image is None
196 | assert mask_image is None
197 | images = self.pipe(
198 | image=image,
199 | height=height,
200 | width=width,
201 | num_inference_steps=num_inference_steps,
202 | prompt_embeds=pos_neg_embed_cat_in_dim1,
203 | guidance_scale=0.0,
204 | generator=generator,
205 | controlnet_conditioning_scale=controlnet_conditioning_scale,
206 | eta=eta
207 | )
208 | elif isinstance(self.pipe, StableDiffusionPipeline):
209 | assert control_image is None
210 | assert mask_image is None
211 | assert image is None
212 | images = self.pipe(
213 | height=height,
214 | width=width,
215 | num_inference_steps=num_inference_steps,
216 | prompt_embeds=pos_neg_embed_cat_in_dim1,
217 | guidance_scale=0.0,
218 | generator=generator,
219 | eta=eta
220 | )
221 | elif isinstance(self.pipe, StableDiffusionImg2ImgPipeline):
222 | assert control_image is None
223 | assert mask_image is None
224 | assert image is not None
225 | images = self.pipe(
226 | image=image,
227 | strength=strength,
228 | num_inference_steps=num_inference_steps,
229 | prompt_embeds=pos_neg_embed_cat_in_dim1,
230 | guidance_scale=0.0,
231 | generator=generator,
232 | eta=eta
233 | )
234 |
235 | else:
236 | raise ValueError("not supported {}".format(type(self.pipe)))
237 |
238 | return images
239 |
240 |
241 |
242 |
--------------------------------------------------------------------------------