├── __init__.py ├── utils.py ├── README.md ├── patch.py └── sito.py /__init__.py: -------------------------------------------------------------------------------- 1 | from . import sito, patch 2 | from .patch import apply_patch, remove_patch 3 | 4 | __all__ = ["sito", "patch", "apply_patch", "remove_patch"] 5 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def isinstance_str(x: object, cls_name: str): 5 | """ 6 | Checks whether x has any class *named* cls_name in its ancestry. 7 | Doesn't require access to the class's implementation. 8 | 9 | Useful for patching! 10 | """ 11 | for _cls in x.__class__.__mro__: 12 | if _cls.__name__ == cls_name: 13 | return True 14 | 15 | return False 16 | 17 | 18 | def init_generator(device: torch.device, fallback: torch.Generator = None): 19 | """ 20 | Forks the current default random generator given device. 21 | """ 22 | if device.type == "cpu": 23 | return torch.Generator(device="cpu").set_state(torch.get_rng_state()) 24 | elif device.type == "cuda": 25 | return torch.Generator(device=device).set_state(torch.cuda.get_rng_state()) 26 | else: 27 | if fallback is None: 28 | return init_generator(torch.device("cpu")) 29 | else: 30 | return fallback 31 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | #

SiTo: Training-Free and Hardware-Friendly Acceleration for Diffusion Models via Similarity-based Token Pruning (AAAI-2025)

2 | 3 |

4 | Overall Workflow of the CPA-Enhancer Framework 5 |
6 |

7 | 8 | 9 | ## 🔥 News 10 | - `2024/12/10`🤗🤗 Our [paper](https://www.researchgate.net/publication/387204421_Training-Free_and_Hardware-Friendly_Acceleration_for_Diffusion_Models_via_Similarity-based_Token_Pruning) is accepted by AAAI-2025 11 | - `2025/1/18` 💥💥 We release the [code](https://github.com/EvelynZhang-epiclab/SiTo) for our work about accelerating diffusion models for FREE. 🎉 **The zero-shot evaluation shows SiTo leads to 1.90x and 1.75x acceleration on COCO30K and ImageNet with 1.33 and 1.15 FID reduction at the same time. Besides, SiTo has no training requirements and does not require any calibration data, making it plug-and-play in real-world applications.** 12 | - `2025/1/21` 💿💿 We publish the required image for running our code. Click [here](https://www.codewithgpu.com/i/EvelynZhang-epiclab/SiTo/SiTo-SD) to use it directly. 13 | 14 | ## 🌸 Abstract 15 |
16 | 17 | CLICK for full abstract 18 | 19 | > The excellent performance of diffusion models in image generation is always accompanied by overlarge computation costs, which have prevented the application of diffusion models in edge devices and interactive applications. Previous works mainly focus on using fewer sampling steps and compressing the denoising network of diffusion models, while this paper proposes to accelerate diffusion models by introducing **SiTo, a similarity-based token pruning method** that adaptive prunes the redundant tokens in the input data. SiTo is designed to maximize the similarity between model prediction with and without token pruning by using cheap and hardware-friendly operations, leading to significant acceleration ratios without performance drop, and even sometimes improvements in the generation quality. For instance, **the zero-shot evaluation shows SiTo leads to 1.90x and 1.75x acceleration on COCO30K and ImageNet with 1.33 and 1.15 FID reduction** at the same time. Besides, SiTo has **no training requirements and does not require any calibration data**, making it plug-and-play in real-world applications. 20 | 21 |
22 | 23 | 24 | ## 🚀Overview 25 | 26 | SiTo has a three-stage pipeline. 27 | - SiTo carefully selects a set of **base tokens** which are utilized as the base to select and recover the pruned tokens. 28 | - SiTo selects the tokens that have the highest similarity to the base tokens as the **pruned tokens**. 29 | - SiTo feeds the unpruned tokens to the neural layers and **recovers the pruned tokens** by directly copying their most similar base tokens. 30 | 31 |

32 | Overall Workflow of the CPA-Enhancer Framework 33 |
34 | The pipeline of SiTo on the example of self-attention. (a) Base Token Selection: We compute the Cosine Similarity between all the tokens. For each token, we sum its similarity to all the tokens as the SimScore. Then, Gaussian Noise is added to the SimScore introduces randomness, preventing identical base and pruned token choices across timesteps. Finally, the token that has the highest Noise SimScore in an image patch is selected as a base token. (b) Pruned Token Selection: The tokens with the highest similarity to the base tokens are selected as pruned tokens. (c) Pruned Token Recovery: The unpruned tokens are fed to the neural layers. Then, the pruned tokens are recovered by copying from their most similar base tokens. 35 |

36 | 37 | ## 📊Result 38 | ### Qualitative Result 39 |

40 | Overall Workflow of the CPA-Enhancer Framework 41 |
42 | Visual comparisons with the manually crafted challenging prompts. We apply ToMeSD and SiTo on stable diffusion v1.5, achieving similar speed-up ratios of 1.63 and 1.65, respectively. Under these comparable conditions, our method generated more realistic, detailed images that better aligned with the original images and text prompts. 43 |

44 | 45 | ### Quantitative Result 46 |

47 | Overall Workflow of the CPA-Enhancer Framework 48 |
49 | Comparison between the proposed SiTo and ToMeSD with SD v1.5 and SD v2 on ImageNet and COCO30k. 50 |

51 | 52 | ## 🛠 Usage 53 | ### Dependencies 54 | To run SiTo for SD, PyTorch version `1.12.1` or higher is required (due to the use of `scatter_reduce`). You can download it from [here](https://pytorch.org/get-started/locally/). 55 | 56 | ### Installation 57 | ```shell 58 | git clone https://github.com/EvelynZhang-epiclab/SiTo.git 59 | ``` 60 | ### Apply SiTo 61 | Applying SiTo is very simple, you just need the following two steps (and no additional training is required): 62 | 63 | Step1: Add our code package `sito` in the `scripts`. 64 | 65 | Step2:Apply SiTo in SD v1 and SD v2: 66 | Add the following code at the respective lines of [SD v1](https://github.com/runwayml/stable-diffusion/blob/08ab4d326c96854026c4eb3454cd3b02109ee982/scripts/txt2img.py#L241) or [SD v2](https://github.com/Stability-AI/stablediffusion/blob/fc1488421a2761937b9d54784194157882cbc3b1/scripts/txt2img.py#L220): 67 | 68 | ```python 69 | import sito 70 | sito.apply_patch(model,prune_ratio=0.5) 71 | ``` 72 | As follows, we also provide more fine-grained parameter control. 73 | ```python 74 | import sito 75 | sito.apply_patch(model, 76 | prune_ratio = 0.7, # The pruning ratio 77 | max_downsample_ratio = 1, # The number of layers to prune in the Unet. It is recommended to prune only the first layer (see Fig. 7 in the paper for details) 78 | prune_selfattn_flag = True, # Whether to prune the self-attention layers. Recommended 79 | prune_crossattn_flag = False, # Whether to prune the cross-attention layers. Not recommended 80 | prune_mlp_flag: bool = False, # Whether to prune the MLP layers. Strongly not recommended (see Tab. 2 in the paper for details) 81 | sx: int = 2, sy: int = 2, # Patch size 82 | noise_alpha= 0.1, # Controls the noise level 83 | sim_beta:float = 1 84 | ) 85 | ``` 86 | ### Run SiTo 87 | 88 | ~~~python 89 | # After setting up the environment, compile it. 90 | pip install -v -e . 91 | ~~~ 92 | 93 | - Generate an image based on a prompt. 94 | ~~~ 95 | python scripts/txt2img.py --n_iter 1 --n_samples 1 --W 512 --H 512 --ddim_steps 50 --plms --skip_grid --prompt "a photograph of an astronaut riding a horse" 96 | ~~~ 97 | 98 | - Read prompts from a `.txt` file to generate images. Use `--from-file imagenet.txt` for generating ImageNet 2k images, and `--from-file coco30k.txt` for generating COCO 30k images. 99 | 100 | ~~~ 101 | python scripts/txt2img.py --n_iter 2 --n_samples 4 --W 512 --H 512 --ddim_steps 50 --plms --skip_grid --from-file imagenet.txt 102 | ~~~ 103 | 104 | - When measuring speed, set `n_iter` to at least 2 (because at least one iteration is required for warm-up). Enable both `--skip_save` and `--skip_grid` to avoid saving images. 105 | 106 | ~~~ 107 | python scripts/txt2img.py --n_iter 3 --n_samples 8 --W 512 --H 512 --ddim_steps 50 --plms --skip_save --skip_grid --from-file xxx.txt 108 | ~~~ 109 | 110 | 111 | ## 📐 Evaluation 112 | 113 | ### FID 114 | 115 | This implementation references the [pytorch-fid](https://github.com/mseitzer/pytorch-fid) repository. 116 | Modify [this line](https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/fid_score.py#L146C3-L146C64]) in `pytorch_fid/fid_score.py` as follows: 117 | ~~~python 118 | dataset = ImagePathDataset(files, transforms=my_transform) 119 | my_transform = TF.Compose([ 120 | TF.Resize((512, 512)), 121 | TF.ToTensor(), 122 | TF.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 123 | ]) 124 | ~~~ 125 | 126 | - Download Data 127 | 128 | [https://drive.google.com/drive/folders/1pIZKbvRx1WFAx_6sAaFHW4K6jugc_b3c?usp=drive_link](https://drive.google.com/drive/folders/1pIZKbvRx1WFAx_6sAaFHW4K6jugc_b3c?usp=drive_link) 129 | 130 | ~~~python 131 | python -m pytorch_fid path/to/[datasets].npz path/to/images 132 | ~~~ 133 | 134 | ### Time 135 | 136 | - It is recommended to use torch.cuda.Event to measure time (since GPUs perform parallel computation, using the regular time.time can be inaccurate). Refer to the following code: 137 | 138 | ~~~py 139 | time0= torch.cuda.Event(enable_timing=True) 140 | time1= torch.cuda.Event(enable_timing=True) 141 | time0.record() 142 | # Place the code segment that needs to be measured for time here. 143 | time1.record() 144 | torch.cuda.synchronize() 145 | time_consume=time0.elapsed_time(time1) 146 | ~~~ 147 | 148 | _Note: When measuring speed, it is recommended to perform a warm-up (for example, exclude the time taken for the first few iterations from the statistics)._ 149 | 150 | ## 💐 Acknowledgments 151 | 152 | Special thanks to the creators of [ToMeSD](https://github.com/dbolya/tomesd) upon which this code is built, for their valuable work in advancing diffusion model acceleration. 153 | 154 | ## 🔗 Citation 155 | If you use this codebase, or SiTo inspires your work, we would greatly appreciate it if you could star the repository and cite it using the following BibTeX entry. 156 | ``` 157 | @inproceedings{zhang2025training, 158 | title={Training-free and hardware-friendly acceleration for diffusion models via similarity-based token pruning}, 159 | author={Zhang, Evelyn and Tang, Jiayi and Ning, Xuefei and Zhang, Linfeng}, 160 | booktitle={Proceedings of the AAAI Conference on Artificial Intelligence}, 161 | volume={39}, 162 | number={9}, 163 | pages={9878--9886}, 164 | year={2025} 165 | } 166 | ``` 167 | ## :e-mail: Contact 168 | If you have more questions or are seeking collaboration, feel free to contact me via email at [`evelynzhang2002@163.com`](mailto:yuweizhang2002@gmail.com). 169 | -------------------------------------------------------------------------------- /patch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | from typing import Type, Dict, Any, Tuple, Callable 4 | import copy 5 | from . import sito 6 | from .utils import isinstance_str, init_generator 7 | 8 | import torch.nn.functional as F 9 | import time 10 | 11 | def select_sito_method(x: torch.Tensor, sito_info: Dict[str, Any]) -> Tuple[Callable, ...]: 12 | args = sito_info["args"] 13 | current_timestep = sito_info["timestep"] 14 | original_h, original_w = sito_info["size"] # 64,64 15 | original_tokens = original_h * original_w 16 | downsaple_ratio = int(math.ceil(math.sqrt(original_tokens // x.shape[1]))) 17 | 18 | if (downsaple_ratio <= args["max_downsample_ratio"]): 19 | w = int(math.ceil(original_w / downsaple_ratio)) 20 | h = int(math.ceil(original_h / downsaple_ratio)) 21 | num_prune = int(x.shape[1] * args["prune_ratio"]) 22 | p, r = sito.prune_and_recover_tokens(x, num_prune, w=w, h=h, sx=args['sx'], sy=args['sy'],noise_alpha=args['noise_alpha'],sim_beta=args['sim_beta'],current_timestep=current_timestep) 23 | 24 | else: 25 | p, r = (sito.do_nothing, sito.do_nothing) 26 | p_a, r_a = (p, r) if args["prune_selfattn_flag"] else (sito.do_nothing, sito.do_nothing) 27 | p_c, r_c = (p, r) if args["prune_crossattn_flag"] else (sito.do_nothing, sito.do_nothing) 28 | p_m, r_m = (p, r) if args["prune_mlp_flag"] else (sito.do_nothing, sito.do_nothing) 29 | return p_a, p_c, p_m, r_a, r_c, r_m 30 | 31 | 32 | def make_sito_block(block_class: Type[torch.nn.Module]) -> Type[torch.nn.Module]: 33 | class sitoBlock(block_class): 34 | _parent = block_class 35 | def _forward(self, x: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor: 36 | p_a, p_c, p_m, r_a, r_c, r_m = select_sito_method(x, self._sito_info) # 选择方法 37 | # self-attention 38 | prune_x = p_a(self.norm1(x)) 39 | out1 = self.attn1(prune_x, context=context if self.disable_self_attn else None) 40 | x = r_a(out1) + x 41 | # cross-attention 42 | prop_x = p_c(self.norm2(x)) 43 | out2= self.attn2(prop_x, context=context) 44 | x = r_c(out2) + x 45 | # MLP 46 | x = r_m(self.ff(p_m(self.norm3(x)))) + x 47 | return x 48 | return sitoBlock 49 | 50 | 51 | def make_diffusers_sito_block(block_class: Type[torch.nn.Module]) -> Type[torch.nn.Module]: 52 | """ 53 | Make a patched class for a diffusers model. 54 | This patch applies ToMe to the forward function of the block. 55 | """ 56 | 57 | class sitoBlock(block_class): 58 | # Save for unpatching later 59 | _parent = block_class 60 | 61 | def forward( 62 | self, 63 | hidden_states, 64 | attention_mask=None, 65 | encoder_hidden_states=None, 66 | encoder_attention_mask=None, 67 | timestep=None, 68 | cross_attention_kwargs=None, 69 | class_labels=None, 70 | ) -> torch.Tensor: 71 | # (1) sito 72 | p_a, p_c, p_m, r_a, r_c, r_m = select_sito_method(hidden_states, self._sito_info) 73 | 74 | if self.use_ada_layer_norm: 75 | norm_hidden_states = self.norm1(hidden_states, timestep) 76 | elif self.use_ada_layer_norm_zero: 77 | norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( 78 | hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype 79 | ) 80 | else: 81 | norm_hidden_states = self.norm1(hidden_states) 82 | 83 | # (2) sito p_a 84 | norm_hidden_states = p_a(norm_hidden_states) 85 | # 1. Self-Attention 86 | cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} 87 | attn_output = self.attn1( 88 | norm_hidden_states, 89 | encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, 90 | attention_mask=attention_mask, 91 | **cross_attention_kwargs, 92 | ) 93 | if self.use_ada_layer_norm_zero: 94 | attn_output = gate_msa.unsqueeze(1) * attn_output 95 | 96 | # (3) sito r_a 97 | hidden_states = r_a(attn_output) + hidden_states 98 | 99 | if self.attn2 is not None: 100 | norm_hidden_states = ( 101 | self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) 102 | ) 103 | # (4) sito p_c 104 | norm_hidden_states = p_c(norm_hidden_states) 105 | 106 | # 2. Cross-Attention 107 | attn_output = self.attn2( 108 | norm_hidden_states, 109 | encoder_hidden_states=encoder_hidden_states, 110 | attention_mask=encoder_attention_mask, 111 | **cross_attention_kwargs, 112 | ) 113 | # (5) sito r_c 114 | hidden_states = r_c(attn_output) + hidden_states 115 | 116 | # 3. Feed-forward 117 | norm_hidden_states = self.norm3(hidden_states) 118 | 119 | if self.use_ada_layer_norm_zero: 120 | norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] 121 | 122 | # (6) sito p_m 123 | norm_hidden_states = p_m(norm_hidden_states) 124 | 125 | ff_output = self.ff(norm_hidden_states) 126 | 127 | if self.use_ada_layer_norm_zero: 128 | ff_output = gate_mlp.unsqueeze(1) * ff_output 129 | 130 | # (7) sito r_m 131 | hidden_states = r_m(ff_output) + hidden_states 132 | 133 | return hidden_states 134 | 135 | return sitoBlock 136 | 137 | 138 | def hook_sito_model(model: torch.nn.Module): 139 | def hook(module, args): 140 | module._sito_info["size"] = (args[0].shape[2], args[0].shape[3]) 141 | module._sito_info["timestep"] = args[1][0].cpu().item() 142 | return None 143 | model._sito_info["hooks"].append(model.register_forward_pre_hook(hook)) 144 | 145 | def apply_patch( 146 | model: torch.nn.Module, 147 | prune_ratio: int = 0.5, 148 | max_downsample_ratio: int = 1, 149 | prune_selfattn_flag: bool = True, 150 | prune_crossattn_flag: bool = False, 151 | prune_mlp_flag: bool = False, 152 | sx: int = 2, sy: int = 2, 153 | noise_alpha:float = 0.1, 154 | sim_beta:float = 1 155 | ): 156 | remove_patch(model) 157 | 158 | is_diffusers = isinstance_str(model, "DiffusionPipeline") or isinstance_str(model, "ModelMixin") 159 | 160 | if not is_diffusers: 161 | if not hasattr(model, "model") or not hasattr(model.model, "diffusion_model"): 162 | raise RuntimeError("Provided model was not a Stable Diffusion / Latent Diffusion model, as expected.") 163 | diffusion_model = model.model.diffusion_model 164 | else: 165 | diffusion_model = model.unet if hasattr(model, "unet") else model 166 | 167 | diffusion_model._sito_info = { 168 | "name": None, 169 | "size": None, 170 | "timestep": None, 171 | "hooks": [], 172 | "args": { 173 | "prune_selfattn_flag": prune_selfattn_flag, 174 | "prune_crossattn_flag": prune_crossattn_flag, 175 | "prune_mlp_flag": prune_mlp_flag, 176 | "prune_ratio": prune_ratio, 177 | "max_downsample_ratio": max_downsample_ratio, 178 | "sx": sx, "sy": sy, 179 | "noise_alpha":noise_alpha, 180 | "sim_beta":sim_beta 181 | } 182 | } 183 | hook_sito_model(diffusion_model) # 添加size属性 184 | for x, module in diffusion_model.named_modules(): 185 | if isinstance_str(module, "BasicTransformerBlock"): 186 | make_sito_block_fn = make_diffusers_sito_block if is_diffusers else make_sito_block 187 | module.__class__ = make_sito_block_fn( 188 | module.__class__) 189 | module._sito_info = diffusion_model._sito_info 190 | module._myname = x 191 | if not hasattr(module, "disable_self_attn") and not is_diffusers: 192 | module.disable_self_attn = False 193 | if not hasattr(module, "use_ada_layer_norm_zero") and is_diffusers: 194 | module.use_ada_layer_norm = False 195 | module.use_ada_layer_norm_zero = False 196 | return model 197 | 198 | def remove_patch(model: torch.nn.Module): 199 | """ Removes a patch from a sito Diffusion module if it was already patched. """ 200 | model = model.unet if hasattr(model, "unet") else model 201 | for _, module in model.named_modules(): 202 | if hasattr(module, "_sito_info"): 203 | for hook in module._sito_info["hooks"]: 204 | hook.remove() 205 | module._sito_info["hooks"].clear() 206 | 207 | if module.__class__.__name__ == "sitoBlock": 208 | module.__class__ = module._parent 209 | 210 | return model 211 | 212 | ''' 213 | Unet Name 214 | input_blocks.1.1.transformer_blocks.0 215 | input_blocks.2.1.transformer_blocks.0 216 | 217 | input_blocks.4.1.transformer_blocks.0 218 | input_blocks.5.1.transformer_blocks.0 219 | 220 | input_blocks.7.1.transformer_blocks.0 221 | input_blocks.8.1.transformer_blocks.0 222 | 223 | middle_block.1.transformer_blocks.0 224 | 225 | output_blocks.3.1.transformer_blocks.0 226 | output_blocks.4.1.transformer_blocks.0 227 | output_blocks.5.1.transformer_blocks.0 228 | 229 | output_blocks.6.1.transformer_blocks.0 230 | output_blocks.7.1.transformer_blocks.0 231 | output_blocks.8.1.transformer_blocks.0 232 | 233 | output_blocks.9.1.transformer_blocks.0 234 | output_blocks.10.1.transformer_blocks.0 235 | output_blocks.11.1.transformer_blocks.0 236 | ''' -------------------------------------------------------------------------------- /sito.py: -------------------------------------------------------------------------------- 1 | import time 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import torch 5 | from typing import Tuple, Callable, Optional, Dict, Any 6 | import math 7 | from .utils import init_generator 8 | 9 | 10 | def do_nothing(x: torch.Tensor, mode: str = None): 11 | return x 12 | 13 | count_num = torch.zeros((2, 4096), dtype=torch.int32, device="cuda") 14 | 15 | 16 | def do_nothing(x: torch.Tensor, mode: str = None): 17 | return x 18 | 19 | def find_patch_max_indices(tensor, sx, sy): # tensor: (B,N) 20 | b, N = tensor.size() 21 | n = int(math.sqrt(N)) 22 | tensor = tensor.view(b, n, n) 23 | h_patches = n // sy 24 | w_patches = n // sx 25 | tensor = tensor[:, :h_patches * sy, :w_patches * sx] 26 | tensor_reshaped = tensor.view(b, h_patches, sy, w_patches, sx) 27 | tensor_reshaped = tensor_reshaped.permute(0, 1, 3, 2, 4).contiguous() 28 | tensor_reshaped = tensor_reshaped.view(b, h_patches, w_patches, sy * sx) 29 | _, max_indices = tensor_reshaped.max(dim=-1, keepdim=True) 30 | return max_indices 31 | 32 | def duplicate_half_tensor(tensor): 33 | B, _, _,_ = tensor.shape 34 | if B % 2 != 0: 35 | raise ValueError("B must be even for this operation") 36 | first_half = tensor[:B // 2] 37 | tensor[B // 2:] = first_half 38 | return tensor 39 | 40 | def prune_and_recover_tokens(metric: torch.Tensor, 41 | num_prune: int, 42 | w: int, 43 | h: int, 44 | sx: int, 45 | sy: int, 46 | sim_beta: float, 47 | noise_alpha: float, 48 | current_timestep: int 49 | ) -> Tuple[Callable, Callable]: 50 | B, N, C = metric.shape 51 | if num_prune <= 0: # 如果r<0, 什么也不做 52 | return do_nothing, do_nothing 53 | gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather 54 | with torch.no_grad(): 55 | metric = metric.to(torch.float16) 56 | metric = F.normalize(metric, p=2, dim=-1) 57 | consine_graph = torch.matmul(metric, metric.transpose(-1, -2)) 58 | # my_norm = metric.norm(dim=-1, keepdim=True) 59 | # my_norm = my_norm.to(torch.float16) 60 | # metric = metric / my_norm 61 | # consine_graph = metric @ metric.transpose(-1, -2) 62 | hsy, wsx = h // sy, w // sx 63 | # ############################################################## 64 | # Method 1: Select the highest score based on (sim_beta * SimScore + noise_alpha * Noise) within each patch. 65 | dst_score = consine_graph.sum(-1) 66 | noise_score= torch.randn(B, N, dtype=metric.dtype, device=metric.device) 67 | dst_score=sim_beta*dst_score+noise_alpha*noise_score 68 | rand_idx = find_patch_max_indices(dst_score, sx, sy) # [B,hsy,wsx,1] 69 | # Align CFG 70 | # rand_idx = duplicate_half_tensor(rand_idx) 71 | ############################################################## 72 | # Method 2:LocalRandom 73 | # generator = init_generator(metric.device) 74 | # rand_idx = torch.randint(sy * sx, size=(hsy, wsx, 1), device=generator.device, generator=generator).to(metric.device) # [hsy, wsx, 1] 75 | # rand_idx = rand_idx.unsqueeze(0).expand(B, -1, -1, -1) 76 | ############################################################# 77 | # Method3: Fixed selection of the top-left corner within each patch. 78 | # rand_idx = torch.zeros(hsy, wsx, 1, device=metric.device, dtype=torch.int64) 79 | # rand_idx = rand_idx.unsqueeze(0).expand(B,-1,-1,-1) 80 | ############################################################## 81 | idx_buffer_view = torch.zeros(B, hsy, wsx, sy * sx, device=metric.device, dtype=torch.int64) 82 | idx_buffer_view.scatter_(dim=-1, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=rand_idx.dtype)) 83 | idx_buffer_view = idx_buffer_view.view(B, hsy, wsx, sy, sx).transpose(2, 3).reshape(B, hsy * sy, wsx * sx) 84 | if (hsy * sy) < h or (wsx * sx) < w: 85 | idx_buffer = torch.zeros(B, h, w, device=metric.device, dtype=torch.int64) 86 | idx_buffer[:, :(hsy * sy), :(wsx * sx)] = idx_buffer_view # (B,h,w) 87 | else: 88 | idx_buffer = idx_buffer_view # (B,h,w) 89 | rand_idx = idx_buffer.reshape(B, -1).argsort(dim=1) # (B,N) 90 | del idx_buffer, idx_buffer_view 91 | 92 | ############################################################# 93 | # Method 4: Randomly select within the global scope. 94 | # random_permutation = torch.randperm(N).to(metric.device) 95 | # rand_idx=random_permutation.unsqueeze(0).expand(B,-1) 96 | ############################################################ 97 | # Method 5: Select the maximum SimScore within the global scope. 98 | # dst_score = consine_graph.sum(-1) # (B,N) 99 | # _, rand_idx = torch.sort(dst_score, dim=1, descending=True) 100 | ############################################################# 101 | num_dst = hsy * wsx 102 | a_idx = rand_idx[:, num_dst:] # src (B,N-num_dst) 0~N 103 | b_idx = rand_idx[:, :num_dst] # dst (B,num_dst) 0~N 104 | 105 | score_sim_1 = gather(consine_graph, dim=1, index=a_idx.unsqueeze(-1).expand(B, -1, N)) # (B,N-num_dst,N) 106 | score_sim = gather(score_sim_1, dim=-1, 107 | index=b_idx.unsqueeze(1).expand(B, score_sim_1.shape[1], -1)) # (B,N-num_dst,num_dst) 108 | score_sim_value, score_sim_index = torch.max(score_sim, dim=2) # (B,N-num_dst) 0~num_dst 109 | total_score = - score_sim_value 110 | 111 | edge_idx = total_score.argsort(dim=-1) # The size is (B, N-num_dst), sorted in ascending order, with values as indices (0~N-num_dst). 112 | unm_idx = edge_idx[..., num_prune:] # Size: (B,N-num_dst-num_prune) Value:(0~N-num_dst) [8,1024] 0~3071 113 | src_idx = edge_idx[..., :num_prune] # Size: (B,num_prune) Value:(0~N-num_dst) [8,2048] 0~3071 1896 114 | a_idx_tmp = a_idx.expand(B, N - num_dst) # (8,3072) 0~4095 115 | a_unm_idx = gather(a_idx_tmp, dim=1, index=unm_idx) # Size: (B,N-num_dst-unm_prop) Value:(0~N) 116 | a_src_idx = gather(a_idx_tmp, dim=1, 117 | index=src_idx) # Size:(B,num_prune) Value(0~N) idx 1896 out of bound with size 1024 118 | ################################################################################# 119 | # Count and visualize the number of pruning occurrences. 120 | """ 121 | global count_num 122 | count_num.scatter_add_(1, a_src_idx, torch.ones_like(a_src_idx, dtype=torch.int32)) 123 | 124 | if current_timestep ==1: 125 | count_num.to("cpu") 126 | n = int(np.sqrt(N)) 127 | 128 | max_value = count_num.max().item() 129 | #print('max_value',max_value) 130 | 131 | value_range = torch.arange(1, max_value + 1) 132 | histograms = torch.zeros((B, max_value + 1), dtype=torch.int32) 133 | 134 | for b in range(B): 135 | histograms[b] = torch.bincount(count_num[b], minlength=max_value + 1) 136 | 137 | for b in range(B): 138 | non_zero_mask = histograms[b][1:] > 0 139 | non_zero_values = value_range[non_zero_mask] 140 | non_zero_counts = histograms[b][1:][non_zero_mask] 141 | 142 | plt.figure(figsize=(6, 6)) 143 | plt.bar(non_zero_values.numpy(), non_zero_counts.numpy(),color='#b5e48c') 144 | plt.xlim(0, 250) 145 | 146 | plt.xticks(fontsize=20) 147 | plt.yticks(fontsize=20) 148 | # plt.xlabel('Value') 149 | # plt.ylabel('Count') 150 | # plt.title(f'Distribution of Values for B = {b} (non-zero counts)') 151 | if b==1: 152 | plt.savefig(f'/root/autodl-tmp/stable-diffusion-main/outputs/txt2img-samples/dst_1/test/samples/with_noise_prunenum.svg',format='svg') # 保存图片 153 | plt.close() 154 | 155 | 156 | for b in range(B): 157 | plt.figure(figsize=(6, 6)) 158 | count_matrix = count_num[b].reshape(n, n).cpu().numpy() 159 | 160 | heatmap = plt.imshow(count_matrix, cmap='summer', interpolation='nearest', vmin=0, vmax=250) 161 | cbar = plt.colorbar(heatmap) 162 | 163 | # Adjust the layout to ensure the colorbar and heatmap align properly 164 | plt.subplots_adjust(left=0.1, right=0.8, top=0.6, bottom=0.2) 165 | 166 | if b == 1: 167 | plt.savefig(f'/root/autodl-tmp/stable-diffusion-main/outputs/txt2img-samples/dst_1/test/samples/with_noise_heatmap.svg', format='svg') 168 | plt.close() 169 | """ 170 | ################################################################################################### 171 | combined = torch.cat((a_unm_idx, b_idx), dim=1) # (B,N-num_prune) 172 | weight = gather(consine_graph, dim=1, index=combined.unsqueeze(-1).expand(B, -1, N)) # (B,N-num_prune,N) 173 | weight_prop = gather(weight, dim=2, 174 | index=a_src_idx.unsqueeze(1).expand(-1, weight.shape[1], -1)) # (B,N-prune,num_prune) 175 | _, max_indices = torch.max(weight_prop, dim=1) # (B,num_prune) 0~N-num_prop 176 | 177 | def prune_tokens(x: torch.Tensor) -> torch.Tensor: # x: (B,N,C) 178 | B, N, C = x.shape # 179 | unm = gather(x, dim=-2, 180 | index=a_unm_idx.unsqueeze(2).expand(B, N - num_dst - num_prune, C)) # (B, N-num_dst-num_prune, C) 181 | dst = gather(x, dim=-2, index=b_idx.unsqueeze(2).expand(B, num_dst, C)) # (B,num_dst,C) 182 | result = torch.cat([unm, dst], dim=1) 183 | return result # (B,N-num_prune,c) 184 | 185 | def recover_tokens(x: torch.Tensor) -> torch.Tensor: # (B,N-num_prune,c) 186 | unm_len = a_unm_idx.shape[1] # N-num_dst-num_prune 187 | # unm: (B, N-num_dst-num_prune,C) dst: (B,num_dst,C) 188 | unm, dst = x[..., :unm_len, :], x[..., unm_len:, :] 189 | _, _, c = unm.shape 190 | # weight_prop: (B,num_prune,num_dst) 191 | # dst: (B,num_dst,C) 192 | out = torch.zeros(B, N, c, device=x.device, dtype=x.dtype) 193 | src = torch.gather(x, 1, max_indices.unsqueeze(-1).expand(-1, -1, c)) # (B,num_prune,c) 194 | out.scatter_(dim=-2, index=b_idx.unsqueeze(2).expand(B, num_dst, c), src=dst) 195 | out.scatter_(dim=-2, index=a_unm_idx.unsqueeze(2).expand(B, unm_len, c), src=unm) 196 | out.scatter_(dim=-2, index=a_src_idx.unsqueeze(2).expand(B, num_prune, c), src=src) 197 | return out # (B,N,C) 198 | 199 | return prune_tokens, recover_tokens 200 | 201 | --------------------------------------------------------------------------------