├── __init__.py
├── utils.py
├── README.md
├── patch.py
└── sito.py
/__init__.py:
--------------------------------------------------------------------------------
1 | from . import sito, patch
2 | from .patch import apply_patch, remove_patch
3 |
4 | __all__ = ["sito", "patch", "apply_patch", "remove_patch"]
5 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def isinstance_str(x: object, cls_name: str):
5 | """
6 | Checks whether x has any class *named* cls_name in its ancestry.
7 | Doesn't require access to the class's implementation.
8 |
9 | Useful for patching!
10 | """
11 | for _cls in x.__class__.__mro__:
12 | if _cls.__name__ == cls_name:
13 | return True
14 |
15 | return False
16 |
17 |
18 | def init_generator(device: torch.device, fallback: torch.Generator = None):
19 | """
20 | Forks the current default random generator for the given device.
21 | """
22 | if device.type == "cpu":
23 | return torch.Generator(device="cpu").set_state(torch.get_rng_state())
24 | elif device.type == "cuda":
25 | return torch.Generator(device=device).set_state(torch.cuda.get_rng_state())
26 | else:
27 | if fallback is None:
28 | return init_generator(torch.device("cpu"))
29 | else:
30 | return fallback
31 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SiTo: Training-Free and Hardware-Friendly Acceleration for Diffusion Models via Similarity-based Token Pruning (AAAI 2025)
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 | ## 🔥 News
10 | - `2024/12/10` 🤗🤗 Our [paper](https://www.researchgate.net/publication/387204421_Training-Free_and_Hardware-Friendly_Acceleration_for_Diffusion_Models_via_Similarity-based_Token_Pruning) was accepted to AAAI 2025.
11 | - `2025/1/18` 💥💥 We release the [code](https://github.com/EvelynZhang-epiclab/SiTo) for our work on accelerating diffusion models for FREE. 🎉 **The zero-shot evaluation shows SiTo leads to 1.90x and 1.75x acceleration on COCO30K and ImageNet with 1.33 and 1.15 FID reduction at the same time. Besides, SiTo has no training requirements and does not require any calibration data, making it plug-and-play in real-world applications.**
12 | - `2025/1/21` 💿💿 We publish a ready-to-use environment image for running our code. Click [here](https://www.codewithgpu.com/i/EvelynZhang-epiclab/SiTo/SiTo-SD) to use it directly.
13 |
14 | ## 🌸 Abstract
15 |
16 |
18 |
19 | > The excellent performance of diffusion models in image generation is always accompanied by overly large computation costs, which has prevented the application of diffusion models on edge devices and in interactive applications. Previous works mainly focus on using fewer sampling steps and compressing the denoising network of diffusion models, while this paper proposes to accelerate diffusion models by introducing **SiTo, a similarity-based token pruning method** that adaptively prunes the redundant tokens in the input data. SiTo is designed to maximize the similarity between the model predictions with and without token pruning by using cheap and hardware-friendly operations, leading to significant acceleration ratios without performance drops, and sometimes even improvements in the generation quality. For instance, **the zero-shot evaluation shows SiTo leads to 1.90x and 1.75x acceleration on COCO30K and ImageNet with 1.33 and 1.15 FID reduction** at the same time. Besides, SiTo has **no training requirements and does not require any calibration data**, making it plug-and-play in real-world applications.
20 |
21 |
22 |
23 |
24 | ## 🚀Overview
25 |
26 | SiTo has a three-stage pipeline.
27 | - SiTo carefully selects a set of **base tokens** which are utilized as the base to select and recover the pruned tokens.
28 | - SiTo selects the tokens that have the highest similarity to the base tokens as the **pruned tokens**.
29 | - SiTo feeds the unpruned tokens to the neural layers and **recovers the pruned tokens** by directly copying their most similar base tokens.
30 |
31 |
32 |
33 |
34 | The pipeline of SiTo, illustrated on self-attention. (a) Base Token Selection: We compute the cosine similarity between all tokens. For each token, we sum its similarity to all tokens to obtain its SimScore. Gaussian noise is then added to the SimScore to introduce randomness, preventing identical base and pruned token choices across timesteps. Finally, the token with the highest noised SimScore in each image patch is selected as a base token. (b) Pruned Token Selection: The tokens with the highest similarity to the base tokens are selected as pruned tokens. (c) Pruned Token Recovery: The unpruned tokens are fed to the neural layers. Then, the pruned tokens are recovered by copying from their most similar base tokens.
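
For intuition, here is a minimal, self-contained sketch of stage (a) in plain PyTorch (our own toy illustration with made-up sizes, not the repo's implementation; the real code lives in `sito.py`):

```python
import torch
import torch.nn.functional as F

# Toy token grid: B batches of N = n*n tokens with C channels; patch size sx = sy = 2.
B, n, C = 1, 8, 16
sx = sy = 2
x = F.normalize(torch.randn(B, n * n, C), dim=-1)

sim = x @ x.transpose(-1, -2)                            # pairwise cosine similarity, (B, N, N)
sim_score = sim.sum(-1)                                  # SimScore: summed similarity to all tokens, (B, N)
noised = sim_score + 0.1 * torch.randn_like(sim_score)   # add Gaussian noise (noise_alpha = 0.1)

# The token with the highest noised SimScore inside each sy-by-sx patch becomes that patch's base token.
patches = noised.view(B, n // sy, sy, n // sx, sx).permute(0, 1, 3, 2, 4)
base_idx = patches.reshape(B, n // sy, n // sx, sy * sx).argmax(-1)  # (B, n//sy, n//sx)
```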
35 |
36 |
37 | ## 📊Result
38 | ### Qualitative Result
39 |
40 |
41 |
42 | Visual comparisons with manually crafted challenging prompts. We apply ToMeSD and SiTo to Stable Diffusion v1.5, achieving similar speed-up ratios of 1.63 and 1.65, respectively. Under these comparable conditions, our method generates more realistic, detailed images that align better with the original images and text prompts.
43 |
44 |
45 | ### Quantitative Result
46 |
47 |
48 |
49 | Comparison between the proposed SiTo and ToMeSD with SD v1.5 and SD v2 on ImageNet and COCO30k.
50 |
51 |
52 | ## 🛠 Usage
53 | ### Dependencies
54 | To run SiTo on Stable Diffusion, PyTorch version `1.12.1` or higher is required (due to the use of `scatter_reduce`). You can install it from [here](https://pytorch.org/get-started/locally/).
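
If in doubt, a quick check of the installed version (a small sketch; `packaging` ships with most Python environments):

```python
import torch
from packaging import version

assert version.parse(torch.__version__.split("+")[0]) >= version.parse("1.12.1"), \
    "SiTo requires PyTorch >= 1.12.1 (for scatter_reduce)"
```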
55 |
56 | ### Installation
57 | ```shell
58 | git clone https://github.com/EvelynZhang-epiclab/SiTo.git
59 | ```
60 | ### Apply SiTo
61 | Applying SiTo is very simple; you just need the following two steps (and no additional training is required):
62 |
63 | Step 1: Add our code package `sito` to the `scripts` directory.
64 | 
65 | Step 2: Apply SiTo to SD v1 or SD v2.
66 | Add the following code at the respective line of [SD v1](https://github.com/runwayml/stable-diffusion/blob/08ab4d326c96854026c4eb3454cd3b02109ee982/scripts/txt2img.py#L241) or [SD v2](https://github.com/Stability-AI/stablediffusion/blob/fc1488421a2761937b9d54784194157882cbc3b1/scripts/txt2img.py#L220):
67 |
68 | ```python
69 | import sito
70 | sito.apply_patch(model, prune_ratio=0.5)
71 | ```
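
For orientation, the call goes right after the model is built in `scripts/txt2img.py` (a sketch based on the SD v1 script linked above):

```python
# scripts/txt2img.py (SD v1), right after the model is created:
model = load_model_from_config(config, f"{opt.ckpt}")

import sito
sito.apply_patch(model, prune_ratio=0.5)  # patch once, before sampling starts
```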
72 | We also provide more fine-grained parameter control:
73 | ```python
74 | import sito
75 | sito.apply_patch(model,
76 | prune_ratio = 0.7,            # the pruning ratio
77 | max_downsample_ratio = 1,     # prune only UNet layers whose downsampling ratio is at most this value; pruning only the first (highest-resolution) layers is recommended (see Fig. 7 in the paper)
78 | prune_selfattn_flag = True,   # whether to prune the self-attention layers (recommended)
79 | prune_crossattn_flag = False, # whether to prune the cross-attention layers (not recommended)
80 | prune_mlp_flag = False,       # whether to prune the MLP layers (strongly not recommended; see Tab. 2 in the paper)
81 | sx = 2, sy = 2,               # patch size
82 | noise_alpha = 0.1,            # controls the noise level
83 | sim_beta = 1                  # weight of the SimScore term
84 | )
85 | ```
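
To switch back to the unpruned baseline in the same process, `remove_patch` (exported alongside `apply_patch` in `__init__.py`) removes the hooks and restores the original block classes:

```python
import sito

sito.apply_patch(model, prune_ratio=0.5)
# ... run accelerated sampling ...
sito.remove_patch(model)  # model behaves as if it was never patched
```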
86 | ### Run SiTo
87 |
88 | ~~~shell
89 | # After setting up the environment, install the package in editable mode.
90 | pip install -v -e .
91 | ~~~
92 |
93 | - Generate an image based on a prompt.
94 | ~~~
95 | python scripts/txt2img.py --n_iter 1 --n_samples 1 --W 512 --H 512 --ddim_steps 50 --plms --skip_grid --prompt "a photograph of an astronaut riding a horse"
96 | ~~~
97 |
98 | - Read prompts from a `.txt` file to generate images. Use `--from-file imagenet.txt` for generating ImageNet 2k images, and `--from-file coco30k.txt` for generating COCO 30k images.
99 |
100 | ~~~
101 | python scripts/txt2img.py --n_iter 2 --n_samples 4 --W 512 --H 512 --ddim_steps 50 --plms --skip_grid --from-file imagenet.txt
102 | ~~~
103 |
104 | - When measuring speed, set `n_iter` to at least 2 (because at least one iteration is required for warm-up). Enable both `--skip_save` and `--skip_grid` to avoid saving images.
105 |
106 | ~~~
107 | python scripts/txt2img.py --n_iter 3 --n_samples 8 --W 512 --H 512 --ddim_steps 50 --plms --skip_save --skip_grid --from-file xxx.txt
108 | ~~~
109 |
110 |
111 | ## 📐 Evaluation
112 |
113 | ### FID
114 |
115 | This implementation references the [pytorch-fid](https://github.com/mseitzer/pytorch-fid) repository.
116 | Modify [this line](https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/fid_score.py#L146C3-L146C64) in `pytorch_fid/fid_score.py` as follows:
117 | ~~~python
118 | my_transform = TF.Compose([
119 | TF.Resize((512, 512)),
120 | TF.ToTensor(),
121 | TF.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
122 | ])
123 | dataset = ImagePathDataset(files, transforms=my_transform)
124 | ~~~
125 |
126 | - Download Data
127 |
128 | [https://drive.google.com/drive/folders/1pIZKbvRx1WFAx_6sAaFHW4K6jugc_b3c?usp=drive_link](https://drive.google.com/drive/folders/1pIZKbvRx1WFAx_6sAaFHW4K6jugc_b3c?usp=drive_link)
129 |
130 | ~~~shell
131 | python -m pytorch_fid path/to/[datasets].npz path/to/images
132 | ~~~
133 |
134 | ### Time
135 |
136 | - It is recommended to use `torch.cuda.Event` to measure time (GPU kernels execute asynchronously, so the regular `time.time()` can be inaccurate). Refer to the following code:
137 |
138 | ~~~py
139 | time0 = torch.cuda.Event(enable_timing=True)
140 | time1 = torch.cuda.Event(enable_timing=True)
141 | time0.record()
142 | # Place the code segment to be timed here.
143 | time1.record()
144 | torch.cuda.synchronize()  # wait until the recorded events have completed
145 | time_consume = time0.elapsed_time(time1)  # elapsed time in milliseconds
146 | ~~~
147 |
148 | _Note: When measuring speed, it is recommended to perform a warm-up (for example, exclude the time taken for the first few iterations from the statistics)._
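
Putting both points together, a minimal measurement loop might look like this (the helper name `time_fn` and the warm-up/iteration counts are illustrative choices, not part of this repo):

~~~py
import torch

def time_fn(fn, warmup: int = 2, iters: int = 10) -> float:
    """Average GPU time of fn() in milliseconds, excluding warm-up runs."""
    for _ in range(warmup):
        fn()  # warm-up: CUDA context init, autotuning, allocator caches
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()  # make sure both events have completed
    return start.elapsed_time(end) / iters
~~~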
149 |
150 | ## 💐 Acknowledgments
151 |
152 | Special thanks to the creators of [ToMeSD](https://github.com/dbolya/tomesd) upon which this code is built, for their valuable work in advancing diffusion model acceleration.
153 |
154 | ## 🔗 Citation
155 | If you use this codebase, or if SiTo inspires your work, we would greatly appreciate it if you could star the repository and cite it using the following BibTeX entry.
156 | ```
157 | @inproceedings{zhang2025training,
158 | title={Training-free and hardware-friendly acceleration for diffusion models via similarity-based token pruning},
159 | author={Zhang, Evelyn and Tang, Jiayi and Ning, Xuefei and Zhang, Linfeng},
160 | booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
161 | volume={39},
162 | number={9},
163 | pages={9878--9886},
164 | year={2025}
165 | }
166 | ```
167 | ## 📧 Contact
168 | If you have more questions or are seeking collaboration, feel free to contact me via email at [`evelynzhang2002@163.com`](mailto:evelynzhang2002@163.com).
169 |
--------------------------------------------------------------------------------
/patch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | from typing import Type, Dict, Any, Tuple, Callable
4 | import copy
5 | from . import sito
6 | from .utils import isinstance_str, init_generator
7 |
8 | import torch.nn.functional as F
9 | import time
10 |
11 | def select_sito_method(x: torch.Tensor, sito_info: Dict[str, Any]) -> Tuple[Callable, ...]:
12 | args = sito_info["args"]
13 | current_timestep = sito_info["timestep"]
14 | original_h, original_w = sito_info["size"]  # e.g. (64, 64) for 512x512 inputs
15 | original_tokens = original_h * original_w
16 | downsample_ratio = int(math.ceil(math.sqrt(original_tokens // x.shape[1])))
17 | 
18 | if downsample_ratio <= args["max_downsample_ratio"]:
19 | w = int(math.ceil(original_w / downsample_ratio))
20 | h = int(math.ceil(original_h / downsample_ratio))
21 | num_prune = int(x.shape[1] * args["prune_ratio"])
22 | p, r = sito.prune_and_recover_tokens(x, num_prune, w=w, h=h, sx=args['sx'], sy=args['sy'], noise_alpha=args['noise_alpha'], sim_beta=args['sim_beta'], current_timestep=current_timestep)
23 |
24 | else:
25 | p, r = (sito.do_nothing, sito.do_nothing)
26 | p_a, r_a = (p, r) if args["prune_selfattn_flag"] else (sito.do_nothing, sito.do_nothing)
27 | p_c, r_c = (p, r) if args["prune_crossattn_flag"] else (sito.do_nothing, sito.do_nothing)
28 | p_m, r_m = (p, r) if args["prune_mlp_flag"] else (sito.do_nothing, sito.do_nothing)
29 | return p_a, p_c, p_m, r_a, r_c, r_m
30 |
31 |
32 | def make_sito_block(block_class: Type[torch.nn.Module]) -> Type[torch.nn.Module]:
33 | class sitoBlock(block_class):
34 | _parent = block_class
35 | def _forward(self, x: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor:
36 | p_a, p_c, p_m, r_a, r_c, r_m = select_sito_method(x, self._sito_info)  # choose prune/recover functions
37 | # self-attention
38 | prune_x = p_a(self.norm1(x))
39 | out1 = self.attn1(prune_x, context=context if self.disable_self_attn else None)
40 | x = r_a(out1) + x
41 | # cross-attention
42 | prop_x = p_c(self.norm2(x))
43 | out2 = self.attn2(prop_x, context=context)
44 | x = r_c(out2) + x
45 | # MLP
46 | x = r_m(self.ff(p_m(self.norm3(x)))) + x
47 | return x
48 | return sitoBlock
49 |
50 |
51 | def make_diffusers_sito_block(block_class: Type[torch.nn.Module]) -> Type[torch.nn.Module]:
52 | """
53 | Make a patched class for a diffusers model.
54 | This patch applies SiTo to the forward function of the block.
55 | """
56 |
57 | class sitoBlock(block_class):
58 | # Save for unpatching later
59 | _parent = block_class
60 |
61 | def forward(
62 | self,
63 | hidden_states,
64 | attention_mask=None,
65 | encoder_hidden_states=None,
66 | encoder_attention_mask=None,
67 | timestep=None,
68 | cross_attention_kwargs=None,
69 | class_labels=None,
70 | ) -> torch.Tensor:
71 | # (1) sito
72 | p_a, p_c, p_m, r_a, r_c, r_m = select_sito_method(hidden_states, self._sito_info)
73 |
74 | if self.use_ada_layer_norm:
75 | norm_hidden_states = self.norm1(hidden_states, timestep)
76 | elif self.use_ada_layer_norm_zero:
77 | norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
78 | hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
79 | )
80 | else:
81 | norm_hidden_states = self.norm1(hidden_states)
82 |
83 | # (2) sito p_a
84 | norm_hidden_states = p_a(norm_hidden_states)
85 | # 1. Self-Attention
86 | cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
87 | attn_output = self.attn1(
88 | norm_hidden_states,
89 | encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
90 | attention_mask=attention_mask,
91 | **cross_attention_kwargs,
92 | )
93 | if self.use_ada_layer_norm_zero:
94 | attn_output = gate_msa.unsqueeze(1) * attn_output
95 |
96 | # (3) sito r_a
97 | hidden_states = r_a(attn_output) + hidden_states
98 |
99 | if self.attn2 is not None:
100 | norm_hidden_states = (
101 | self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
102 | )
103 | # (4) sito p_c
104 | norm_hidden_states = p_c(norm_hidden_states)
105 |
106 | # 2. Cross-Attention
107 | attn_output = self.attn2(
108 | norm_hidden_states,
109 | encoder_hidden_states=encoder_hidden_states,
110 | attention_mask=encoder_attention_mask,
111 | **cross_attention_kwargs,
112 | )
113 | # (5) sito r_c
114 | hidden_states = r_c(attn_output) + hidden_states
115 |
116 | # 3. Feed-forward
117 | norm_hidden_states = self.norm3(hidden_states)
118 |
119 | if self.use_ada_layer_norm_zero:
120 | norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
121 |
122 | # (6) sito p_m
123 | norm_hidden_states = p_m(norm_hidden_states)
124 |
125 | ff_output = self.ff(norm_hidden_states)
126 |
127 | if self.use_ada_layer_norm_zero:
128 | ff_output = gate_mlp.unsqueeze(1) * ff_output
129 |
130 | # (7) sito r_m
131 | hidden_states = r_m(ff_output) + hidden_states
132 |
133 | return hidden_states
134 |
135 | return sitoBlock
136 |
137 |
138 | def hook_sito_model(model: torch.nn.Module):
139 | def hook(module, args):
140 | module._sito_info["size"] = (args[0].shape[2], args[0].shape[3])
141 | module._sito_info["timestep"] = args[1][0].cpu().item()
142 | return None
143 | model._sito_info["hooks"].append(model.register_forward_pre_hook(hook))
144 |
145 | def apply_patch(
146 | model: torch.nn.Module,
147 | prune_ratio: float = 0.5,
148 | max_downsample_ratio: int = 1,
149 | prune_selfattn_flag: bool = True,
150 | prune_crossattn_flag: bool = False,
151 | prune_mlp_flag: bool = False,
152 | sx: int = 2, sy: int = 2,
153 | noise_alpha: float = 0.1,
154 | sim_beta: float = 1
155 | ):
156 | remove_patch(model)
157 |
158 | is_diffusers = isinstance_str(model, "DiffusionPipeline") or isinstance_str(model, "ModelMixin")
159 |
160 | if not is_diffusers:
161 | if not hasattr(model, "model") or not hasattr(model.model, "diffusion_model"):
162 | raise RuntimeError("Provided model was not a Stable Diffusion / Latent Diffusion model, as expected.")
163 | diffusion_model = model.model.diffusion_model
164 | else:
165 | diffusion_model = model.unet if hasattr(model, "unet") else model
166 |
167 | diffusion_model._sito_info = {
168 | "name": None,
169 | "size": None,
170 | "timestep": None,
171 | "hooks": [],
172 | "args": {
173 | "prune_selfattn_flag": prune_selfattn_flag,
174 | "prune_crossattn_flag": prune_crossattn_flag,
175 | "prune_mlp_flag": prune_mlp_flag,
176 | "prune_ratio": prune_ratio,
177 | "max_downsample_ratio": max_downsample_ratio,
178 | "sx": sx, "sy": sy,
179 | "noise_alpha":noise_alpha,
180 | "sim_beta":sim_beta
181 | }
182 | }
183 | hook_sito_model(diffusion_model)  # register a hook that records the input size and timestep
184 | for x, module in diffusion_model.named_modules():
185 | if isinstance_str(module, "BasicTransformerBlock"):
186 | make_sito_block_fn = make_diffusers_sito_block if is_diffusers else make_sito_block
187 | # Swap in the patched subclass; the original class is kept in _parent for unpatching.
188 | module.__class__ = make_sito_block_fn(module.__class__)
189 | module._sito_info = diffusion_model._sito_info
190 | module._myname = x
191 | if not hasattr(module, "disable_self_attn") and not is_diffusers:
192 | module.disable_self_attn = False
193 | if not hasattr(module, "use_ada_layer_norm_zero") and is_diffusers:
194 | module.use_ada_layer_norm = False
195 | module.use_ada_layer_norm_zero = False
196 | return model
197 |
198 | def remove_patch(model: torch.nn.Module):
199 | """ Removes a patch from a sito Diffusion module if it was already patched. """
200 | model = model.unet if hasattr(model, "unet") else model
201 | for _, module in model.named_modules():
202 | if hasattr(module, "_sito_info"):
203 | for hook in module._sito_info["hooks"]:
204 | hook.remove()
205 | module._sito_info["hooks"].clear()
206 |
207 | if module.__class__.__name__ == "sitoBlock":
208 | module.__class__ = module._parent
209 |
210 | return model
211 |
212 | '''
213 | UNet transformer block names (for reference):
214 | input_blocks.1.1.transformer_blocks.0
215 | input_blocks.2.1.transformer_blocks.0
216 |
217 | input_blocks.4.1.transformer_blocks.0
218 | input_blocks.5.1.transformer_blocks.0
219 |
220 | input_blocks.7.1.transformer_blocks.0
221 | input_blocks.8.1.transformer_blocks.0
222 |
223 | middle_block.1.transformer_blocks.0
224 |
225 | output_blocks.3.1.transformer_blocks.0
226 | output_blocks.4.1.transformer_blocks.0
227 | output_blocks.5.1.transformer_blocks.0
228 |
229 | output_blocks.6.1.transformer_blocks.0
230 | output_blocks.7.1.transformer_blocks.0
231 | output_blocks.8.1.transformer_blocks.0
232 |
233 | output_blocks.9.1.transformer_blocks.0
234 | output_blocks.10.1.transformer_blocks.0
235 | output_blocks.11.1.transformer_blocks.0
236 | '''
--------------------------------------------------------------------------------
/sito.py:
--------------------------------------------------------------------------------
1 | import math
2 | import matplotlib.pyplot as plt  # only used by the commented-out visualization below
3 | import numpy as np  # only used by the commented-out visualization below
4 | import torch
5 | import torch.nn.functional as F
6 | from typing import Tuple, Callable, Optional, Dict, Any
7 | from .utils import init_generator
8 |
9 |
10 | def do_nothing(x: torch.Tensor, mode: str = None):
11 | return x
12 |
13 | # Pruning-count buffer; only used by the commented-out visualization code below.
14 | count_num = torch.zeros((2, 4096), dtype=torch.int32, device="cuda") if torch.cuda.is_available() else None
15 | 
16 | # torch.gather misbehaves with expanded indices on MPS devices; helper borrowed from ToMeSD.
17 | def mps_gather_workaround(input, dim, index):
18 | if input.shape[-1] == 1:
19 | return torch.gather(input.unsqueeze(-1), dim - 1 if dim < 0 else dim, index.unsqueeze(-1)).squeeze(-1)
20 | return torch.gather(input, dim, index)
21 | 
22 | def find_patch_max_indices(tensor, sx, sy):  # tensor: (B, N), one score per token
23 | b, N = tensor.size()
24 | n = int(math.sqrt(N))  # assumes a square token grid
25 | h_patches, w_patches = n // sy, n // sx
26 | tensor = tensor.view(b, n, n)[:, :h_patches * sy, :w_patches * sx]
27 | tensor_reshaped = tensor.view(b, h_patches, sy, w_patches, sx).permute(0, 1, 3, 2, 4).contiguous()
28 | tensor_reshaped = tensor_reshaped.view(b, h_patches, w_patches, sy * sx)
29 | _, max_indices = tensor_reshaped.max(dim=-1, keepdim=True)  # argmax within each patch
30 | return max_indices
31 |
32 | def duplicate_half_tensor(tensor):
33 | B, _, _, _ = tensor.shape
34 | if B % 2 != 0:
35 | raise ValueError("B must be even for this operation")
36 | first_half = tensor[:B // 2]
37 | tensor[B // 2:] = first_half
38 | return tensor
39 |
40 | def prune_and_recover_tokens(metric: torch.Tensor,
41 | num_prune: int,
42 | w: int,
43 | h: int,
44 | sx: int,
45 | sy: int,
46 | sim_beta: float,
47 | noise_alpha: float,
48 | current_timestep: int
49 | ) -> Tuple[Callable, Callable]:
50 | B, N, C = metric.shape
51 | if num_prune <= 0:  # nothing to prune; keep every token
52 | return do_nothing, do_nothing
53 | gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather
54 | with torch.no_grad():
55 | metric = metric.to(torch.float16)
56 | metric = F.normalize(metric, p=2, dim=-1)
57 | cosine_graph = torch.matmul(metric, metric.transpose(-1, -2))  # pairwise cosine similarity (B, N, N)
58 | # my_norm = metric.norm(dim=-1, keepdim=True)
59 | # my_norm = my_norm.to(torch.float16)
60 | # metric = metric / my_norm
61 | # cosine_graph = metric @ metric.transpose(-1, -2)
62 | hsy, wsx = h // sy, w // sx
63 | # ##############################################################
64 | # Method 1: Select the highest score based on (sim_beta * SimScore + noise_alpha * Noise) within each patch.
65 | dst_score = cosine_graph.sum(-1)  # SimScore: each token's total similarity to all tokens
66 | noise_score = torch.randn(B, N, dtype=metric.dtype, device=metric.device)
67 | dst_score = sim_beta * dst_score + noise_alpha * noise_score
68 | rand_idx = find_patch_max_indices(dst_score, sx, sy)  # [B, hsy, wsx, 1]
69 | # Align CFG
70 | # rand_idx = duplicate_half_tensor(rand_idx)
71 | ##############################################################
72 | # Method 2: Local random selection within each patch.
73 | # generator = init_generator(metric.device)
74 | # rand_idx = torch.randint(sy * sx, size=(hsy, wsx, 1), device=generator.device, generator=generator).to(metric.device) # [hsy, wsx, 1]
75 | # rand_idx = rand_idx.unsqueeze(0).expand(B, -1, -1, -1)
76 | #############################################################
77 | # Method 3: Fixed selection of the top-left corner within each patch.
78 | # rand_idx = torch.zeros(hsy, wsx, 1, device=metric.device, dtype=torch.int64)
79 | # rand_idx = rand_idx.unsqueeze(0).expand(B,-1,-1,-1)
80 | ##############################################################
81 | idx_buffer_view = torch.zeros(B, hsy, wsx, sy * sx, device=metric.device, dtype=torch.int64)
82 | idx_buffer_view.scatter_(dim=-1, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=rand_idx.dtype))
83 | idx_buffer_view = idx_buffer_view.view(B, hsy, wsx, sy, sx).transpose(2, 3).reshape(B, hsy * sy, wsx * sx)
84 | if (hsy * sy) < h or (wsx * sx) < w:
85 | idx_buffer = torch.zeros(B, h, w, device=metric.device, dtype=torch.int64)
86 | idx_buffer[:, :(hsy * sy), :(wsx * sx)] = idx_buffer_view # (B,h,w)
87 | else:
88 | idx_buffer = idx_buffer_view # (B,h,w)
89 | rand_idx = idx_buffer.reshape(B, -1).argsort(dim=1) # (B,N)
90 | del idx_buffer, idx_buffer_view
91 |
92 | #############################################################
93 | # Method 4: Randomly select within the global scope.
94 | # random_permutation = torch.randperm(N).to(metric.device)
95 | # rand_idx=random_permutation.unsqueeze(0).expand(B,-1)
96 | ############################################################
97 | # Method 5: Select the maximum SimScore within the global scope.
98 | # dst_score = cosine_graph.sum(-1)  # (B, N)
99 | # _, rand_idx = torch.sort(dst_score, dim=1, descending=True)
100 | #############################################################
101 | num_dst = hsy * wsx
102 | a_idx = rand_idx[:, num_dst:] # src (B,N-num_dst) 0~N
103 | b_idx = rand_idx[:, :num_dst] # dst (B,num_dst) 0~N
104 |
105 | score_sim_1 = gather(cosine_graph, dim=1, index=a_idx.unsqueeze(-1).expand(B, -1, N))  # (B, N-num_dst, N)
106 | score_sim = gather(score_sim_1, dim=-1, index=b_idx.unsqueeze(1).expand(B, score_sim_1.shape[1], -1))  # (B, N-num_dst, num_dst)
107 | # For each candidate token: its highest similarity to any base token, and which base token that is.
108 | score_sim_value, score_sim_index = torch.max(score_sim, dim=2)  # (B, N-num_dst), indices in 0~num_dst
109 | total_score = -score_sim_value  # more similar to a base token => pruned earlier
110 | 
111 | edge_idx = total_score.argsort(dim=-1)  # (B, N-num_dst), ascending; values are indices in 0~N-num_dst
112 | unm_idx = edge_idx[..., num_prune:]  # kept tokens: (B, N-num_dst-num_prune), values in 0~N-num_dst
113 | src_idx = edge_idx[..., :num_prune]  # pruned tokens: (B, num_prune), values in 0~N-num_dst
114 | a_idx_tmp = a_idx.expand(B, N - num_dst)  # (B, N-num_dst), values in 0~N
115 | a_unm_idx = gather(a_idx_tmp, dim=1, index=unm_idx)  # (B, N-num_dst-num_prune), values in 0~N
116 | a_src_idx = gather(a_idx_tmp, dim=1, index=src_idx)  # (B, num_prune), values in 0~N
117 | 
118 | #################################################################################
119 | # Count and visualize the number of pruning occurrences.
120 | """
121 | global count_num
122 | count_num.scatter_add_(1, a_src_idx, torch.ones_like(a_src_idx, dtype=torch.int32))
123 |
124 | if current_timestep ==1:
125 | count_num.to("cpu")
126 | n = int(np.sqrt(N))
127 |
128 | max_value = count_num.max().item()
129 | #print('max_value',max_value)
130 |
131 | value_range = torch.arange(1, max_value + 1)
132 | histograms = torch.zeros((B, max_value + 1), dtype=torch.int32)
133 |
134 | for b in range(B):
135 | histograms[b] = torch.bincount(count_num[b], minlength=max_value + 1)
136 |
137 | for b in range(B):
138 | non_zero_mask = histograms[b][1:] > 0
139 | non_zero_values = value_range[non_zero_mask]
140 | non_zero_counts = histograms[b][1:][non_zero_mask]
141 |
142 | plt.figure(figsize=(6, 6))
143 | plt.bar(non_zero_values.numpy(), non_zero_counts.numpy(),color='#b5e48c')
144 | plt.xlim(0, 250)
145 |
146 | plt.xticks(fontsize=20)
147 | plt.yticks(fontsize=20)
148 | # plt.xlabel('Value')
149 | # plt.ylabel('Count')
150 | # plt.title(f'Distribution of Values for B = {b} (non-zero counts)')
151 | if b==1:
152 | plt.savefig(f'/root/autodl-tmp/stable-diffusion-main/outputs/txt2img-samples/dst_1/test/samples/with_noise_prunenum.svg', format='svg')  # save the figure
153 | plt.close()
154 |
155 |
156 | for b in range(B):
157 | plt.figure(figsize=(6, 6))
158 | count_matrix = count_num[b].reshape(n, n).cpu().numpy()
159 |
160 | heatmap = plt.imshow(count_matrix, cmap='summer', interpolation='nearest', vmin=0, vmax=250)
161 | cbar = plt.colorbar(heatmap)
162 |
163 | # Adjust the layout to ensure the colorbar and heatmap align properly
164 | plt.subplots_adjust(left=0.1, right=0.8, top=0.6, bottom=0.2)
165 |
166 | if b == 1:
167 | plt.savefig(f'/root/autodl-tmp/stable-diffusion-main/outputs/txt2img-samples/dst_1/test/samples/with_noise_heatmap.svg', format='svg')
168 | plt.close()
169 | """
170 | ###################################################################################################
171 | combined = torch.cat((a_unm_idx, b_idx), dim=1) # (B,N-num_prune)
172 | weight = gather(cosine_graph, dim=1, index=combined.unsqueeze(-1).expand(B, -1, N))  # (B, N-num_prune, N)
173 | weight_prop = gather(weight, dim=2,
174 | index=a_src_idx.unsqueeze(1).expand(-1, weight.shape[1], -1)) # (B,N-prune,num_prune)
175 | _, max_indices = torch.max(weight_prop, dim=1)  # (B, num_prune), values in 0~N-num_prune
176 |
177 | def prune_tokens(x: torch.Tensor) -> torch.Tensor: # x: (B,N,C)
178 | B, N, C = x.shape
179 | unm = gather(x, dim=-2,
180 | index=a_unm_idx.unsqueeze(2).expand(B, N - num_dst - num_prune, C)) # (B, N-num_dst-num_prune, C)
181 | dst = gather(x, dim=-2, index=b_idx.unsqueeze(2).expand(B, num_dst, C)) # (B,num_dst,C)
182 | result = torch.cat([unm, dst], dim=1)
183 | return result # (B,N-num_prune,c)
184 |
185 | def recover_tokens(x: torch.Tensor) -> torch.Tensor: # (B,N-num_prune,c)
186 | unm_len = a_unm_idx.shape[1] # N-num_dst-num_prune
187 | # unm: (B, N-num_dst-num_prune,C) dst: (B,num_dst,C)
188 | unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
189 | _, _, c = unm.shape
190 | # weight_prop: (B,num_prune,num_dst)
191 | # dst: (B,num_dst,C)
192 | out = torch.zeros(B, N, c, device=x.device, dtype=x.dtype)
193 | src = torch.gather(x, 1, max_indices.unsqueeze(-1).expand(-1, -1, c)) # (B,num_prune,c)
194 | out.scatter_(dim=-2, index=b_idx.unsqueeze(2).expand(B, num_dst, c), src=dst)
195 | out.scatter_(dim=-2, index=a_unm_idx.unsqueeze(2).expand(B, unm_len, c), src=unm)
196 | out.scatter_(dim=-2, index=a_src_idx.unsqueeze(2).expand(B, num_prune, c), src=src)
197 | return out # (B,N,C)
198 |
199 | return prune_tokens, recover_tokens
200 |
201 |
--------------------------------------------------------------------------------