├── README.md ├── app.py ├── compare3.png ├── merge.py ├── requirements.txt ├── test_notebook.ipynb └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: ToDo 3 | emoji: ⚡️ 4 | app_file: app.py 5 | sdk: gradio 6 | sdk_version: 4.19.2 7 | --- 8 | # ToDo: Token Downsampling for Efficient Generation of High-Resolution Images 9 | --- 10 | 11 | We provide a [HuggingFace Spaces demo](https://huggingface.co/spaces/aningineer/ToDo) for our recently proposed method, ["ToDo: Token Downsampling for Efficient Generation of High-Resolution Images"](https://arxiv.org/abs/2402.13573), and compare it against a popular token merging method called ToMe. 12 | 13 | If you consider our research to be helpful, please consider citing us: 14 | ``` 15 | @misc{smith2024todo, 16 | title={ToDo: Token Downsampling for Efficient Generation of High-Resolution Images}, 17 | author={Ethan Smith and Nayan Saxena and Aninda Saha}, 18 | year={2024}, 19 | eprint={2402.13573}, 20 | archivePrefix={arXiv} 21 | } 22 | ``` 23 | 24 | ![GEuoFn1bMAABQqD](https://github.com/ethansmith2000/ImprovedTokenMerge/assets/98723285/82e03423-81e6-47da-afa4-9c1b2c1c4aeb) 25 | 26 | blog post: https://sweet-hall-e72.notion.site/ToDo-Token-Downsampling-for-Efficient-Generation-of-High-Resolution-Images-b41be1ac8ddc46be8cd687e67dee2d84?pvs=4 27 | 28 | hf demo: https://huggingface.co/spaces/aningineer/ToDo 29 | 30 | heavily inspired by https://github.com/dbolya/tomesd by @dbolya, a big thanks to the original authors. 31 | 32 | This project aims to adress some of the shortcomings of Token Merging for Stable Diffusion. Namely consistenly faster inference without quality loss. 33 | I found with the original that you would have to use a high merging ratio to get really any speedups at all, and by then quality was tarnished. Benchmarks here: https://github.com/dbolya/tomesd/issues/19#issuecomment-1507593483 34 | 35 | 36 | I propose two changes to the original to solve this. 37 | 1. Merging Method 38 | - the original calculates a similarity matrix of the input tokens and merges those with highest similarity 39 | - an issue here is that similarity calculation is O(n2) time, for ViT where token merging was proposed, you only had to do this a few times so it was quite efficient 40 | - here it needs to be done at every step, and the computation ends up being nearly as costly as attention itself 41 | - We can leverage a simple obsevation that nearby tokens tend to be similar to each other. 42 | - therefore we can merge tokens via downsampling which is very cheap and seems to be a good approximation 43 | - this can be analogized to grid-based subsampling of an image when using a nearest-neighbor downsample method, this is similar to what DiNAT (dilated neigborhood attention) does except for the fact we are still making use of global context 44 | 2. Merge Targets 45 | - the original merges the input tokens to attention, and then "unmerges" the resulting tokens to the original size 46 | - this operation seems to be quite lossy 47 | - instead i propose simply downsampling keys/values of the attention operation. both the QK calculation and QK * V can still drastically be reduced from the typical O(n2) scaling of attention, without needing to unmerge anything 48 | - queries are left fully intact, they just attend more sparsely to the image 49 | - attention for images, especially at larger resolutions, seems to be very sparse in general (QK matrix is low rank) so it does not appear that we lose too much from this 50 | 51 | putting this altogether we can get tangible speedups of ~1.5x at typical sizes like 768-1024 and up to 3x and beyond at 1536 to 2048 range, in combination with flash attention 52 | 53 | 54 | # Setup 🛠 55 | ``` 56 | pip install -r requirements.txt 57 | ``` 58 | 59 | # Inference 🚀 60 | See the provided notebook, or gradio demo which you can run with python app.py 61 | 62 | 63 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import time 2 | import spaces 3 | import gradio as gr 4 | import torch 5 | import diffusers 6 | from utils import patch_attention_proc, remove_patch 7 | import math 8 | import numpy as np 9 | from PIL import Image 10 | from threading import Semaphore 11 | 12 | # Globals 13 | css = """ 14 | h1 { 15 | text-align: center; 16 | display: block; 17 | } 18 | """ 19 | 20 | # Pipeline 21 | pipe = diffusers.StableDiffusionPipeline.from_pretrained("Lykon/DreamShaper").to("cuda", torch.float16) 22 | pipe.scheduler = diffusers.EulerDiscreteScheduler.from_config(pipe.scheduler.config) 23 | pipe.safety_checker = None 24 | 25 | semaphore = Semaphore() # for preventing collisions of two simultaneous button presses 26 | 27 | @spaces.GPU 28 | def generate_baseline(prompt, seed, steps, height_width, negative_prompt, guidance_scale, method): 29 | 30 | semaphore.acquire() 31 | 32 | downsample_factor = 2 33 | ratio = 0.38 34 | merge_method = "downsample" if method == "todo" else "similarity" 35 | merge_tokens = "keys/values" if method == "todo" else "all" 36 | 37 | if height_width == 1024: 38 | downsample_factor = 2 39 | ratio = 0.75 40 | downsample_factor_level_2 = 1 41 | ratio_level_2 = 0.0 42 | elif height_width == 1536: 43 | downsample_factor = 3 44 | ratio = 0.89 45 | downsample_factor_level_2 = 1 46 | ratio_level_2 = 0.0 47 | elif height_width == 2048: 48 | downsample_factor = 4 49 | ratio = 0.9375 50 | downsample_factor_level_2 = 1 51 | ratio_level_2 = 0.0 52 | 53 | token_merge_args = {"ratio": ratio, 54 | "merge_tokens": merge_tokens, 55 | "merge_method": merge_method, 56 | "downsample_method": "nearest", 57 | "downsample_factor": downsample_factor, 58 | "timestep_threshold_switch": 0.0, 59 | "timestep_threshold_stop": 0.0, 60 | "downsample_factor_level_2": downsample_factor_level_2, 61 | "ratio_level_2": ratio_level_2 62 | } 63 | 64 | torch.manual_seed(seed) 65 | start_time_base = time.time() 66 | remove_patch(pipe) 67 | base_img = pipe(prompt, 68 | num_inference_steps=steps, height=height_width, width=height_width, 69 | negative_prompt=negative_prompt, 70 | guidance_scale=guidance_scale).images[0] 71 | end_time_base = time.time() 72 | 73 | result = f"Baseline image: {end_time_base-start_time_base:.2f} sec" 74 | 75 | semaphore.release() 76 | 77 | return base_img, result 78 | 79 | 80 | @spaces.GPU 81 | def generate_merged(prompt, seed, steps, height_width, negative_prompt, guidance_scale, method): 82 | 83 | semaphore.acquire() 84 | 85 | downsample_factor = 2 86 | ratio = 0.38 87 | merge_method = "downsample" if method == "todo" else "similarity" 88 | merge_tokens = "keys/values" if method == "todo" else "all" 89 | 90 | if height_width == 1024: 91 | downsample_factor = 2 92 | ratio = 0.75 93 | downsample_factor_level_2 = 1 94 | ratio_level_2 = 0.0 95 | elif height_width == 1536: 96 | downsample_factor = 3 97 | ratio = 0.89 98 | downsample_factor_level_2 = 1 99 | ratio_level_2 = 0.0 100 | elif height_width == 2048: 101 | downsample_factor = 4 102 | ratio = 0.9375 103 | downsample_factor_level_2 = 1 104 | ratio_level_2 = 0.0 105 | 106 | token_merge_args = {"ratio": ratio, 107 | "merge_tokens": merge_tokens, 108 | "merge_method": merge_method, 109 | "downsample_method": "nearest", 110 | "downsample_factor": downsample_factor, 111 | "timestep_threshold_switch": 0.0, 112 | "timestep_threshold_stop": 0.0, 113 | "downsample_factor_level_2": downsample_factor_level_2, 114 | "ratio_level_2": ratio_level_2 115 | } 116 | 117 | patch_attention_proc(pipe.unet, token_merge_args=token_merge_args) 118 | torch.manual_seed(seed) 119 | start_time_merge = time.time() 120 | merged_img = pipe(prompt, 121 | num_inference_steps=steps, height=height_width, width=height_width, 122 | negative_prompt=negative_prompt, 123 | guidance_scale=guidance_scale).images[0] 124 | end_time_merge = time.time() 125 | 126 | result = f"{'ToDo' if method == 'todo' else 'ToMe'} image: {end_time_merge-start_time_merge:.2f} sec" 127 | 128 | semaphore.release() 129 | 130 | return merged_img, result 131 | 132 | 133 | 134 | with gr.Blocks(css=css) as demo: 135 | gr.Markdown("# ToDo: Token Downsampling for Efficient Generation of High-Resolution Images") 136 | prompt = gr.Textbox(interactive=True, label="prompt") 137 | negative_prompt = gr.Textbox(interactive=True, label="negative_prompt") 138 | 139 | with gr.Row(): 140 | method = gr.Dropdown(["todo", "tome"], value="todo", label="method", info="Choose Your Desired Method (Default: todo)") 141 | height_width = gr.Dropdown([1024, 1536, 2048], value=1024, label="height/width", info="Choose Your Desired Height/Width (Default: 1024)") 142 | 143 | with gr.Row(): 144 | guidance_scale = gr.Number(label="guidance_scale", value=7.5, precision=1) 145 | steps = gr.Number(label="steps", value=20, precision=0) 146 | seed = gr.Number(label="seed", value=1, precision=0) 147 | 148 | with gr.Row(): 149 | with gr.Column(): 150 | base_result = gr.Textbox(label="Baseline Runtime") 151 | base_image = gr.Image(label=f"baseline_image", type="pil", interactive=False) 152 | gen = gr.Button("Generate Baseline") 153 | gen.click(generate_baseline, inputs=[prompt, seed, steps, height_width, negative_prompt, 154 | guidance_scale, method], outputs=[base_image, base_result]) 155 | with gr.Column(): 156 | output_result = gr.Textbox(label="Runtime") 157 | output_image = gr.Image(label=f"image", type="pil", interactive=False) 158 | gen = gr.Button("Generate") 159 | gen.click(generate_merged, inputs=[prompt, seed, steps, height_width, negative_prompt, 160 | guidance_scale, method], outputs=[output_image, output_result]) 161 | 162 | demo.launch(share=True) -------------------------------------------------------------------------------- /compare3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ethansmith2000/ImprovedTokenMerge/81928e4fe47b2be47a983dcfdc0586c3cb138e61/compare3.png -------------------------------------------------------------------------------- /merge.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Tuple, Callable 3 | from diffusers.models.attention_processor import XFormersAttnProcessor, Attention 4 | import xformers, xformers.ops 5 | from typing import Optional 6 | import math 7 | import torch.nn.functional as F 8 | from diffusers.utils import USE_PEFT_BACKEND 9 | from diffusers.utils.import_utils import is_xformers_available 10 | 11 | if is_xformers_available(): 12 | import xformers 13 | import xformers.ops 14 | xformers_is_available = True 15 | else: 16 | xformers_is_available = False 17 | 18 | 19 | if hasattr(F, "scaled_dot_product_attention"): 20 | torch2_is_available = True 21 | else: 22 | torch2_is_available = False 23 | 24 | 25 | def init_generator(device: torch.device, fallback: torch.Generator = None): 26 | """ 27 | Forks the current default random generator given device. 28 | """ 29 | if device.type == "cpu": 30 | return torch.Generator(device="cpu").set_state(torch.get_rng_state()) 31 | elif device.type == "cuda": 32 | return torch.Generator(device=device).set_state(torch.cuda.get_rng_state()) 33 | else: 34 | if fallback is None: 35 | return init_generator(torch.device("cpu")) 36 | else: 37 | return fallback 38 | 39 | 40 | def do_nothing(x: torch.Tensor, mode: str = None): 41 | return x 42 | 43 | 44 | def mps_gather_workaround(input, dim, index): 45 | if input.shape[-1] == 1: 46 | return torch.gather( 47 | input.unsqueeze(-1), 48 | dim - 1 if dim < 0 else dim, 49 | index.unsqueeze(-1) 50 | ).squeeze(-1) 51 | else: 52 | return torch.gather(input, dim, index) 53 | 54 | 55 | def up_or_downsample(item, cur_w, cur_h, new_w, new_h, method): 56 | batch_size = item.shape[0] 57 | 58 | item = item.reshape(batch_size, cur_h, cur_w, -1) 59 | item = item.permute(0, 3, 1, 2) 60 | df = cur_h // new_h 61 | if method in "max_pool": 62 | item = F.max_pool2d(item, kernel_size=df, stride=df, padding=0) 63 | elif method in "avg_pool": 64 | item = F.avg_pool2d(item, kernel_size=df, stride=df, padding=0) 65 | else: 66 | item = F.interpolate(item, size=(new_h, new_w), mode=method) 67 | item = item.permute(0, 2, 3, 1) 68 | item = item.reshape(batch_size, new_h * new_w, -1) 69 | 70 | return item 71 | 72 | 73 | def compute_merge(x: torch.Tensor, tome_info): 74 | original_h, original_w = tome_info["size"] 75 | original_tokens = original_h * original_w 76 | downsample = int(math.ceil(math.sqrt(original_tokens // x.shape[1]))) 77 | dim = x.shape[-1] 78 | if dim == 320: 79 | cur_level = "level_1" 80 | downsample_factor = tome_info['args']['downsample_factor'] 81 | ratio = tome_info['args']['ratio'] 82 | elif dim == 640: 83 | cur_level = "level_2" 84 | downsample_factor = tome_info['args']['downsample_factor_level_2'] 85 | ratio = tome_info['args']['ratio_level_2'] 86 | else: 87 | cur_level = "other" 88 | downsample_factor = 1 89 | ratio = 0.0 90 | 91 | args = tome_info["args"] 92 | 93 | cur_h, cur_w = original_h // downsample, original_w // downsample 94 | new_h, new_w = cur_h // downsample_factor, cur_w // downsample_factor 95 | 96 | if tome_info['timestep'] / 1000 > tome_info['args']['timestep_threshold_switch']: 97 | merge_method = args["merge_method"] 98 | else: 99 | merge_method = args["secondary_merge_method"] 100 | 101 | if cur_level != "other" and tome_info['timestep'] / 1000 > tome_info['args']['timestep_threshold_stop']: 102 | if merge_method == "downsample" and downsample_factor > 1: 103 | m = lambda x: up_or_downsample(x, cur_w, cur_h, new_w, new_h, args["downsample_method"]) 104 | u = lambda x: up_or_downsample(x, new_w, new_h, cur_w, cur_h, args["downsample_method"]) 105 | elif merge_method == "similarity" and ratio > 0.0: 106 | w = int(math.ceil(original_w / downsample)) 107 | h = int(math.ceil(original_h / downsample)) 108 | r = int(x.shape[1] * ratio) 109 | 110 | # Re-init the generator if it hasn't already been initialized or device has changed. 111 | if args["generator"] is None: 112 | args["generator"] = init_generator(x.device) 113 | elif args["generator"].device != x.device: 114 | args["generator"] = init_generator(x.device, fallback=args["generator"]) 115 | 116 | # If the batch size is odd, then it's not possible for prompted and unprompted images to be in the same 117 | # batch, which causes artifacts with use_rand, so force it to be off. 118 | use_rand = False if x.shape[0] % 2 == 1 else args["use_rand"] 119 | m, u = bipartite_soft_matching_random2d(x, w, h, args["sx"], args["sy"], r, 120 | no_rand=not use_rand, generator=args["generator"]) 121 | else: 122 | m, u = (do_nothing, do_nothing) 123 | else: 124 | m, u = (do_nothing, do_nothing) 125 | 126 | merge_fn, unmerge_fn = (m, u) 127 | 128 | return merge_fn, unmerge_fn 129 | 130 | 131 | def bipartite_soft_matching_random2d(metric: torch.Tensor, 132 | w: int, 133 | h: int, 134 | sx: int, 135 | sy: int, 136 | r: int, 137 | no_rand: bool = False, 138 | generator: torch.Generator = None) -> Tuple[Callable, Callable]: 139 | """ 140 | Partitions the tokens into src and dst and merges r tokens from src to dst. 141 | Dst tokens are partitioned by choosing one randomy in each (sx, sy) region. 142 | 143 | Args: 144 | - metric [B, N, C]: metric to use for similarity 145 | - w: image width in tokens 146 | - h: image height in tokens 147 | - sx: stride in the x dimension for dst, must divide w 148 | - sy: stride in the y dimension for dst, must divide h 149 | - r: number of tokens to remove (by merging) 150 | - no_rand: if true, disable randomness (use top left corner only) 151 | - rand_seed: if no_rand is false, and if not None, sets random seed. 152 | """ 153 | B, N, _ = metric.shape 154 | 155 | if r <= 0: 156 | return do_nothing, do_nothing 157 | 158 | with torch.no_grad(): 159 | hsy, wsx = h // sy, w // sx 160 | 161 | # For each sy by sx kernel, randomly assign one token to be dst and the rest src 162 | if no_rand: 163 | rand_idx = torch.zeros(hsy, wsx, 1, device=metric.device, dtype=torch.int64) 164 | else: 165 | rand_idx = torch.randint(sy * sx, size=(hsy, wsx, 1), device=generator.device, generator=generator).to( 166 | metric.device) 167 | 168 | # The image might not divide sx and sy, so we need to work on a view of the top left if the idx buffer instead 169 | idx_buffer_view = torch.zeros(hsy, wsx, sy * sx, device=metric.device, dtype=torch.int64) 170 | idx_buffer_view.scatter_(dim=2, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=rand_idx.dtype)) 171 | idx_buffer_view = idx_buffer_view.view(hsy, wsx, sy, sx).transpose(1, 2).reshape(hsy * sy, wsx * sx) 172 | 173 | # Image is not divisible by sx or sy so we need to move it into a new buffer 174 | if (hsy * sy) < h or (wsx * sx) < w: 175 | idx_buffer = torch.zeros(h, w, device=metric.device, dtype=torch.int64) 176 | idx_buffer[:(hsy * sy), :(wsx * sx)] = idx_buffer_view 177 | else: 178 | idx_buffer = idx_buffer_view 179 | 180 | # We set dst tokens to be -1 and src to be 0, so an argsort gives us dst|src indices 181 | rand_idx = idx_buffer.reshape(1, -1, 1).argsort(dim=1) 182 | 183 | # We're finished with these 184 | del idx_buffer, idx_buffer_view 185 | 186 | # rand_idx is currently dst|src, so split them 187 | num_dst = hsy * wsx 188 | a_idx = rand_idx[:, num_dst:, :] # src 189 | b_idx = rand_idx[:, :num_dst, :] # dst 190 | 191 | def split(x): 192 | C = x.shape[-1] 193 | src = torch.gather(x, dim=1, index=a_idx.expand(B, N - num_dst, C)) 194 | dst = torch.gather(x, dim=1, index=b_idx.expand(B, num_dst, C)) 195 | return src, dst 196 | 197 | # Cosine similarity between A and B 198 | metric = metric / metric.norm(dim=-1, keepdim=True) 199 | a, b = split(metric) 200 | scores = a @ b.transpose(-1, -2) 201 | 202 | # Can't reduce more than the # tokens in src 203 | r = min(a.shape[1], r) 204 | 205 | # Find the most similar greedily 206 | node_max, node_idx = scores.max(dim=-1) 207 | edge_idx = node_max.argsort(dim=-1, descending=True)[..., None] 208 | 209 | unm_idx = edge_idx[..., r:, :] # Unmerged Tokens 210 | src_idx = edge_idx[..., :r, :] # Merged Tokens 211 | dst_idx = torch.gather(node_idx[..., None], dim=-2, index=src_idx) 212 | 213 | def merge(x: torch.Tensor, mode="mean") -> torch.Tensor: 214 | src, dst = split(x) 215 | n, t1, c = src.shape 216 | 217 | unm = torch.gather(src, dim=-2, index=unm_idx.expand(n, t1 - r, c)) 218 | src = torch.gather(src, dim=-2, index=src_idx.expand(n, r, c)) 219 | dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode) 220 | 221 | return torch.cat([unm, dst], dim=1) 222 | 223 | def unmerge(x: torch.Tensor) -> torch.Tensor: 224 | unm_len = unm_idx.shape[1] 225 | unm, dst = x[..., :unm_len, :], x[..., unm_len:, :] 226 | _, _, c = unm.shape 227 | 228 | src = torch.gather(dst, dim=-2, index=dst_idx.expand(B, r, c)) 229 | 230 | # Combine back to the original shape 231 | out = torch.zeros(B, N, c, device=x.device, dtype=x.dtype) 232 | out.scatter_(dim=-2, index=b_idx.expand(B, num_dst, c), src=dst) 233 | out.scatter_(dim=-2, 234 | index=torch.gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=unm_idx).expand(B, unm_len, c), 235 | src=unm) 236 | out.scatter_(dim=-2, 237 | index=torch.gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=src_idx).expand(B, r, c), 238 | src=src) 239 | 240 | return out 241 | 242 | return merge, unmerge 243 | 244 | 245 | class TokenMergeAttentionProcessor: 246 | def __init__(self): 247 | # priortize torch2's flash attention, if not fall back to xformers then regular attention 248 | if torch2_is_available: 249 | self.attn_method = "torch2" 250 | elif xformers_is_available: 251 | self.attn_method = "xformers" 252 | else: 253 | self.attn_method = "regular" 254 | 255 | def torch2_attention(self, attn, query, key, value, attention_mask, batch_size): 256 | inner_dim=key.shape[-1] 257 | head_dim = inner_dim // attn.heads 258 | 259 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 260 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 261 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 262 | 263 | hidden_states = F.scaled_dot_product_attention( 264 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False 265 | ) 266 | 267 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 268 | 269 | return hidden_states 270 | 271 | def xformers_attention(self, attn, query, key, value, attention_mask, batch_size): 272 | query = attn.head_to_batch_dim(query).contiguous() 273 | key = attn.head_to_batch_dim(key).contiguous() 274 | value = attn.head_to_batch_dim(value).contiguous() 275 | 276 | if attention_mask is not None: 277 | attention_mask = attention_mask.reshape(batch_size * attn.heads, -1, attention_mask.shape[-1]) 278 | 279 | hidden_states = xformers.ops.memory_efficient_attention( 280 | query, key, value, attn_bias=attention_mask, scale=attn.scale 281 | ) 282 | 283 | hidden_states = attn.batch_to_head_dim(hidden_states) 284 | 285 | return hidden_states 286 | 287 | 288 | def regular_attention(self, attn, query, key, value, attention_mask, batch_size): 289 | query = attn.head_to_batch_dim(query) 290 | key = attn.head_to_batch_dim(key) 291 | value = attn.head_to_batch_dim(value) 292 | 293 | if attention_mask is not None: 294 | attention_mask = attention_mask.reshape(batch_size * attn.heads, -1, attention_mask.shape[-1]) 295 | 296 | attention_probs = attn.get_attention_scores(query, key, attention_mask) 297 | hidden_states = torch.bmm(attention_probs, value) 298 | hidden_states = attn.batch_to_head_dim(hidden_states) 299 | 300 | return hidden_states 301 | 302 | 303 | def __call__( 304 | self, 305 | attn: Attention, 306 | hidden_states: torch.FloatTensor, 307 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 308 | attention_mask: Optional[torch.FloatTensor] = None, 309 | temb: Optional[torch.FloatTensor] = None, 310 | scale: float = 1.0, 311 | ) -> torch.FloatTensor: 312 | residual = hidden_states 313 | if attn.spatial_norm is not None: 314 | hidden_states = attn.spatial_norm(hidden_states, temb) 315 | 316 | input_ndim = hidden_states.ndim 317 | 318 | if input_ndim == 4: 319 | batch_size, channel, height, width = hidden_states.shape 320 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 321 | 322 | batch_size, sequence_length, _ = ( 323 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 324 | ) 325 | 326 | if attention_mask is not None: 327 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 328 | # scaled_dot_product_attention expects attention_mask shape to be 329 | # (batch, heads, source_length, target_length) 330 | attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) 331 | 332 | if attn.group_norm is not None: 333 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 334 | 335 | args = () if USE_PEFT_BACKEND else (scale,) 336 | 337 | if self._tome_info['args']['merge_tokens'] == "all": 338 | merge_fn, unmerge_fn = compute_merge(hidden_states, self._tome_info) 339 | hidden_states = merge_fn(hidden_states) 340 | 341 | query = attn.to_q(hidden_states, *args) 342 | 343 | if encoder_hidden_states is None: 344 | encoder_hidden_states = hidden_states 345 | elif attn.norm_cross: 346 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 347 | 348 | if self._tome_info['args']['merge_tokens'] == "keys/values": 349 | merge_fn, _ = compute_merge(encoder_hidden_states, self._tome_info) 350 | encoder_hidden_states = merge_fn(encoder_hidden_states) 351 | 352 | key = attn.to_k(encoder_hidden_states, *args) 353 | value = attn.to_v(encoder_hidden_states, *args) 354 | 355 | if self.attn_method == "torch2": 356 | hidden_states = self.torch2_attention(attn, query, key, value, attention_mask, batch_size) 357 | elif self.attn_method == "xformers": 358 | hidden_states = self.xformers_attention(attn, query, key, value, attention_mask, batch_size) 359 | else: 360 | hidden_states = self.regular_attention(attn, query, key, value, attention_mask, batch_size) 361 | 362 | hidden_states = hidden_states.to(query.dtype) 363 | 364 | # linear proj 365 | hidden_states = attn.to_out[0](hidden_states, *args) 366 | # dropout 367 | hidden_states = attn.to_out[1](hidden_states) 368 | 369 | if self._tome_info['args']['merge_tokens'] == "all": 370 | hidden_states = unmerge_fn(hidden_states) 371 | 372 | if input_ndim == 4: 373 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 374 | 375 | if attn.residual_connection: 376 | hidden_states = hidden_states + residual 377 | 378 | hidden_states = hidden_states / attn.rescale_output_factor 379 | 380 | return hidden_states 381 | 382 | 383 | 384 | 385 | 386 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | diffusers 2 | transformers 3 | accelerate 4 | xformers -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from merge import TokenMergeAttentionProcessor 3 | from diffusers.utils.import_utils import is_xformers_available 4 | from diffusers.models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor, AttnProcessor 5 | import torch.nn.functional as F 6 | 7 | if is_xformers_available(): 8 | xformers_is_available = True 9 | else: 10 | xformers_is_available = False 11 | 12 | if hasattr(F, "scaled_dot_product_attention"): 13 | torch2_is_available = True 14 | else: 15 | torch2_is_available = False 16 | 17 | 18 | def hook_tome_model(model: torch.nn.Module): 19 | """ Adds a forward pre hook to get the image size. This hook can be removed with remove_patch. """ 20 | 21 | def hook(module, args): 22 | module._tome_info["size"] = (args[0].shape[2], args[0].shape[3]) 23 | module._tome_info["timestep"] = args[1].item() 24 | return None 25 | 26 | model._tome_info["hooks"].append(model.register_forward_pre_hook(hook)) 27 | 28 | def remove_patch(pipe: torch.nn.Module): 29 | """ Removes a patch from a ToMe Diffusion module if it was already patched. """ 30 | 31 | if hasattr(pipe.unet, "_tome_info"): 32 | del pipe.unet._tome_info 33 | 34 | for n,m in pipe.unet.named_modules(): 35 | if hasattr(m, "processor"): 36 | m.processor = AttnProcessor2_0() 37 | 38 | def patch_attention_proc(unet, token_merge_args={}): 39 | unet._tome_info = { 40 | "size": None, 41 | "timestep": None, 42 | "hooks": [], 43 | "args": { 44 | "ratio": token_merge_args.get("ratio", 0.5), # ratio of tokens to merge 45 | "sx": token_merge_args.get("sx", 2), # stride x for sim calculation 46 | "sy": token_merge_args.get("sy", 2), # stride y for sim calculation 47 | "use_rand": token_merge_args.get("use_rand", True), 48 | "generator": None, 49 | 50 | "merge_tokens": token_merge_args.get("merge_tokens", "keys/values"), # ["all", "keys/values"] 51 | "merge_method": token_merge_args.get("merge_method", "downsample"), # ["none","similarity", "downsample"] 52 | "downsample_method": token_merge_args.get("downsample_method", "nearest-exact"), 53 | # native torch interpolation methods ["nearest", "linear", "bilinear", "bicubic", "nearest-exact"] 54 | "downsample_factor": token_merge_args.get("downsample_factor", 2), # amount to downsample by 55 | "timestep_threshold_switch": token_merge_args.get("timestep_threshold_switch", 0.2), 56 | # timestep to switch to secondary method, 0.2 means 20% steps remaining 57 | "timestep_threshold_stop": token_merge_args.get("timestep_threshold_stop", 0.0), 58 | # timestep to stop merging, 0.0 means stop at 0 steps remaining 59 | "secondary_merge_method": token_merge_args.get("secondary_merge_method", "similarity"), 60 | # ["none", "similarity", "downsample"] 61 | 62 | "downsample_factor_level_2": token_merge_args.get("downsample_factor_level_2", 1), # amount to downsample by at the 2nd down block of unet 63 | "ratio_level_2": token_merge_args.get("ratio_level_2", 0.5), # ratio of tokens to merge at the 2nd down block of unet 64 | } 65 | } 66 | hook_tome_model(unet) 67 | attn_modules = [module for name, module in unet.named_modules() if module.__class__.__name__ == 'BasicTransformerBlock'] 68 | 69 | for i, module in enumerate(attn_modules): 70 | module.attn1.processor = TokenMergeAttentionProcessor() 71 | module.attn1.processor._tome_info = unet._tome_info 72 | 73 | 74 | def remove_patch(pipe: torch.nn.Module): 75 | """ Removes a patch from a ToMe Diffusion module if it was already patched. """ 76 | 77 | # this will remove our custom class 78 | if torch2_is_available: 79 | for n,m in pipe.unet.named_modules(): 80 | if hasattr(m, "processor"): 81 | m.processor = AttnProcessor2_0() 82 | 83 | elif xformers_is_available: 84 | pipe.enable_xformers_memory_efficient_attention() 85 | 86 | else: 87 | for n,m in pipe.unet.named_modules(): 88 | if hasattr(m, "processor"): 89 | m.processor = AttnProcessor() 90 | --------------------------------------------------------------------------------