├── README.md ├── assets ├── dit.png └── sd.png ├── models ├── __init__.py ├── attention.py ├── ptp_utils.py └── transformer_2d.py ├── pipelines ├── DiTconditionPipeline.py └── __init__.py └── vis.py /README.md: -------------------------------------------------------------------------------- 1 | # DiT-Visualization 2 | 3 | This project explores how the internal features of DiT-based diffusion models differ from those of UNet-based diffusion models. We find that DiT-based models keep a consistent feature scale across all of their layers, whereas UNet-based models show large changes in feature scale and resolution from layer to layer. 4 | 5 | ## Contact 6 | I'm working on accelerating DiT training and inference through feature- and token-level compression, so don't hesitate to contact me if you're interested in working on **high-impact open source projects**! 7 | 8 | 9 | 10 | ## Visualization 11 | DiT visualization: 12 | ![DiT Visualization](assets/dit.png) 13 | 14 | SD visualization: 15 | ![SD Visualization](assets/sd.png) 16 | 17 | ## Acknowledgements 18 | This project uses code from the following repositories: 19 | - [diffusers](https://github.com/huggingface/diffusers) 20 | - [Plug-and-Play](https://github.com/MichalGeyer/plug-and-play) 21 | - [PixArt](https://github.com/PixArt-alpha/PixArt-alpha?tab=readme-ov-file) 22 | 23 | ## Citation 24 | 25 | If you use this project in your research, please cite the following: 26 | 27 | ```bibtex 28 | @misc{guo2024dit, 29 | author = {Qin Guo and Dongxu Yue}, 30 | title = {DiT-Visualization}, 31 | year = {2024}, 32 | howpublished = {\url{https://github.com/guoqincode/DiT-Visualization}}, 33 | note = {Exploring the feature-level differences between DiT-based and UNet-based diffusion models, built on code from diffusers, Plug-and-Play, and PixArt} 34 | } 35 | ``` -------------------------------------------------------------------------------- /assets/dit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoqincode/DiT-Visualization/14aee09086f8b4766c288401b1eb84dcdf338c28/assets/dit.png -------------------------------------------------------------------------------- /assets/sd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoqincode/DiT-Visualization/14aee09086f8b4766c288401b1eb84dcdf338c28/assets/sd.png -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /models/attention.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn 6 | 7 | from diffusers.utils.constants import USE_PEFT_BACKEND 8 | from diffusers.utils.torch_utils import maybe_allow_in_graph 9 | from diffusers.models.activations import GEGLU, GELU, ApproximateGELU 10 | from diffusers.models.attention_processor import Attention 11 | from diffusers.models.embeddings import SinusoidalPositionalEmbedding 12 | from diffusers.models.lora import LoRACompatibleLinear 13 | from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm 14 | from diffusers.models.attention import FeedForward, GatedSelfAttentionDense, _chunked_feed_forward
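# Module note: this file mirrors `BasicTransformerBlock` from `diffusers.models.attention` (diffusers 0.26.x).
# It is kept as a local copy so that the custom `Transformer2DModel` in models/transformer_2d.py can build its
# blocks from it, which in turn lets the attention processors registered in models/ptp_utils.py hook each
# block's self-attention (`attn1`) and store the attention maps for visualization.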
15 | 16 | @maybe_allow_in_graph 17 | class BasicTransformerBlock(nn.Module): 18 | 19 | def __init__( 20 | self, 21 | dim: int, 22 | num_attention_heads: int, 23 | attention_head_dim: int, 24 | dropout=0.0, 25 | cross_attention_dim: Optional[int] = None, 26 | activation_fn: str = "geglu", 27 | num_embeds_ada_norm: Optional[int] = None, 28 | attention_bias: bool = False, 29 | only_cross_attention: bool = False, 30 | double_self_attention: bool = False, 31 | upcast_attention: bool = False, 32 | norm_elementwise_affine: bool = True, 33 | norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'layer_norm_i2vgen' 34 | norm_eps: float = 1e-5, 35 | final_dropout: bool = False, 36 | attention_type: str = "default", 37 | positional_embeddings: Optional[str] = None, 38 | num_positional_embeddings: Optional[int] = None, 39 | ada_norm_continous_conditioning_embedding_dim: Optional[int] = None, 40 | ada_norm_bias: Optional[int] = None, 41 | ff_inner_dim: Optional[int] = None, 42 | ff_bias: bool = True, 43 | attention_out_bias: bool = True, 44 | ): 45 | super().__init__() 46 | self.only_cross_attention = only_cross_attention 47 | 48 | self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" 49 | self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" 50 | self.use_ada_layer_norm_single = norm_type == "ada_norm_single" 51 | self.use_layer_norm = norm_type == "layer_norm" 52 | self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous" 53 | 54 | if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: 55 | raise ValueError( 56 | f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" 57 | f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." 58 | ) 59 | 60 | self.norm_type = norm_type 61 | self.num_embeds_ada_norm = num_embeds_ada_norm 62 | 63 | if positional_embeddings and (num_positional_embeddings is None): 64 | raise ValueError( 65 | "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined." 66 | ) 67 | 68 | if positional_embeddings == "sinusoidal": 69 | self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings) 70 | else: 71 | self.pos_embed = None 72 | 73 | # Define 3 blocks. Each block has its own normalization layer. 74 | # 1. Self-Attn 75 | if norm_type == "ada_norm": 76 | self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) 77 | elif norm_type == "ada_norm_zero": 78 | self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) 79 | elif norm_type == "ada_norm_continuous": 80 | self.norm1 = AdaLayerNormContinuous( 81 | dim, 82 | ada_norm_continous_conditioning_embedding_dim, 83 | norm_elementwise_affine, 84 | norm_eps, 85 | ada_norm_bias, 86 | "rms_norm", 87 | ) 88 | else: 89 | self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) 90 | 91 | self.attn1 = Attention( 92 | query_dim=dim, 93 | heads=num_attention_heads, 94 | dim_head=attention_head_dim, 95 | dropout=dropout, 96 | bias=attention_bias, 97 | cross_attention_dim=cross_attention_dim if only_cross_attention else None, 98 | upcast_attention=upcast_attention, 99 | out_bias=attention_out_bias, 100 | ) 101 | 102 | # 2. Cross-Attn 103 | if cross_attention_dim is not None or double_self_attention: 104 | # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. 105 | # I.e. 
the number of returned modulation chunks from AdaLayerZero would not make sense if returned during 106 | # the second cross attention block. 107 | if norm_type == "ada_norm": 108 | self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) 109 | elif norm_type == "ada_norm_continuous": 110 | self.norm2 = AdaLayerNormContinuous( 111 | dim, 112 | ada_norm_continous_conditioning_embedding_dim, 113 | norm_elementwise_affine, 114 | norm_eps, 115 | ada_norm_bias, 116 | "rms_norm", 117 | ) 118 | else: 119 | self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) 120 | 121 | self.attn2 = Attention( 122 | query_dim=dim, 123 | cross_attention_dim=cross_attention_dim if not double_self_attention else None, 124 | heads=num_attention_heads, 125 | dim_head=attention_head_dim, 126 | dropout=dropout, 127 | bias=attention_bias, 128 | upcast_attention=upcast_attention, 129 | out_bias=attention_out_bias, 130 | ) # is self-attn if encoder_hidden_states is none 131 | else: 132 | self.norm2 = None 133 | self.attn2 = None 134 | 135 | # 3. Feed-forward 136 | if norm_type == "ada_norm_continuous": 137 | self.norm3 = AdaLayerNormContinuous( 138 | dim, 139 | ada_norm_continous_conditioning_embedding_dim, 140 | norm_elementwise_affine, 141 | norm_eps, 142 | ada_norm_bias, 143 | "layer_norm", 144 | ) 145 | 146 | elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm", "ada_norm_continuous"]: 147 | self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) 148 | elif norm_type == "layer_norm_i2vgen": 149 | self.norm3 = None 150 | 151 | self.ff = FeedForward( 152 | dim, 153 | dropout=dropout, 154 | activation_fn=activation_fn, 155 | final_dropout=final_dropout, 156 | inner_dim=ff_inner_dim, 157 | bias=ff_bias, 158 | ) 159 | 160 | # 4. Fuser 161 | if attention_type == "gated" or attention_type == "gated-text-image": 162 | self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim) 163 | 164 | # 5. Scale-shift for PixArt-Alpha. 165 | if norm_type == "ada_norm_single": 166 | self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) 167 | 168 | # let chunk size default to None 169 | self._chunk_size = None 170 | self._chunk_dim = 0 171 | 172 | def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): 173 | # Sets chunk feed-forward 174 | self._chunk_size = chunk_size 175 | self._chunk_dim = dim 176 | 177 | def forward( 178 | self, 179 | hidden_states: torch.FloatTensor, 180 | attention_mask: Optional[torch.FloatTensor] = None, 181 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 182 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 183 | timestep: Optional[torch.LongTensor] = None, 184 | cross_attention_kwargs: Dict[str, Any] = None, 185 | class_labels: Optional[torch.LongTensor] = None, 186 | added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, 187 | ) -> torch.FloatTensor: 188 | # Notice that normalization is always applied before the real computation in the following blocks. 189 | # 0. 
Self-Attention 190 | batch_size = hidden_states.shape[0] 191 | 192 | if self.norm_type == "ada_norm": 193 | norm_hidden_states = self.norm1(hidden_states, timestep) 194 | elif self.norm_type == "ada_norm_zero": 195 | norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( 196 | hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype 197 | ) 198 | elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]: 199 | norm_hidden_states = self.norm1(hidden_states) 200 | elif self.norm_type == "ada_norm_continuous": 201 | norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"]) 202 | elif self.norm_type == "ada_norm_single": 203 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( 204 | self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) 205 | ).chunk(6, dim=1) 206 | norm_hidden_states = self.norm1(hidden_states) 207 | norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa 208 | norm_hidden_states = norm_hidden_states.squeeze(1) 209 | else: 210 | raise ValueError("Incorrect norm used") 211 | 212 | if self.pos_embed is not None: 213 | norm_hidden_states = self.pos_embed(norm_hidden_states) 214 | 215 | # 1. Retrieve lora scale. 216 | lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 217 | 218 | # 2. Prepare GLIGEN inputs 219 | cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} 220 | gligen_kwargs = cross_attention_kwargs.pop("gligen", None) 221 | 222 | attn_output = self.attn1( 223 | norm_hidden_states, 224 | encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, 225 | attention_mask=attention_mask, 226 | **cross_attention_kwargs, 227 | ) 228 | if self.norm_type == "ada_norm_zero": 229 | attn_output = gate_msa.unsqueeze(1) * attn_output 230 | elif self.norm_type == "ada_norm_single": 231 | attn_output = gate_msa * attn_output 232 | 233 | hidden_states = attn_output + hidden_states 234 | if hidden_states.ndim == 4: 235 | hidden_states = hidden_states.squeeze(1) 236 | 237 | 238 | # 2.5 GLIGEN Control 239 | if gligen_kwargs is not None: 240 | hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) 241 | 242 | # 3. Cross-Attention 243 | if self.attn2 is not None: 244 | if self.norm_type == "ada_norm": 245 | norm_hidden_states = self.norm2(hidden_states, timestep) 246 | elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]: 247 | norm_hidden_states = self.norm2(hidden_states) 248 | elif self.norm_type == "ada_norm_single": 249 | # For PixArt norm2 isn't applied here: 250 | # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 251 | norm_hidden_states = hidden_states 252 | elif self.norm_type == "ada_norm_continuous": 253 | norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"]) 254 | else: 255 | raise ValueError("Incorrect norm") 256 | 257 | if self.pos_embed is not None and self.norm_type != "ada_norm_single": 258 | norm_hidden_states = self.pos_embed(norm_hidden_states) 259 | 260 | attn_output = self.attn2( 261 | norm_hidden_states, 262 | encoder_hidden_states=encoder_hidden_states, 263 | attention_mask=encoder_attention_mask, 264 | **cross_attention_kwargs, 265 | ) 266 | hidden_states = attn_output + hidden_states 267 | 268 | # 4. 
Feed-forward 269 | # i2vgen doesn't have this norm 🤷‍♂️ 270 | if self.norm_type == "ada_norm_continuous": 271 | norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"]) 272 | elif not self.norm_type == "ada_norm_single": 273 | norm_hidden_states = self.norm3(hidden_states) 274 | 275 | if self.norm_type == "ada_norm_zero": 276 | norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] 277 | 278 | if self.norm_type == "ada_norm_single": 279 | norm_hidden_states = self.norm2(hidden_states) 280 | norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp 281 | 282 | if self._chunk_size is not None: 283 | # "feed_forward_chunk_size" can be used to save memory 284 | ff_output = _chunked_feed_forward( 285 | self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size, lora_scale=lora_scale 286 | ) 287 | else: 288 | ff_output = self.ff(norm_hidden_states, scale=lora_scale) 289 | 290 | if self.norm_type == "ada_norm_zero": 291 | ff_output = gate_mlp.unsqueeze(1) * ff_output 292 | elif self.norm_type == "ada_norm_single": 293 | ff_output = gate_mlp * ff_output 294 | 295 | hidden_states = ff_output + hidden_states 296 | if hidden_states.ndim == 4: 297 | hidden_states = hidden_states.squeeze(1) 298 | 299 | return hidden_states -------------------------------------------------------------------------------- /models/ptp_utils.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn 5 | from typing import Optional, List 6 | from diffusers.models.attention_processor import Attention 7 | from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0 8 | from diffusers.utils.constants import USE_PEFT_BACKEND 9 | from matplotlib import pyplot as plt 10 | from torch import einsum 11 | from einops import rearrange 12 | 13 | class Hack_AttnProcessor: 14 | 15 | def __init__(self, attnstore, layer_in_dit): 16 | super().__init__() 17 | self.attnstore = attnstore 18 | self.layer_in_dit = layer_in_dit 19 | 20 | def __call__( 21 | self, 22 | attn: Attention, 23 | hidden_states: torch.FloatTensor, 24 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 25 | attention_mask: Optional[torch.FloatTensor] = None, 26 | temb: Optional[torch.FloatTensor] = None, 27 | scale: float = 1.0, 28 | ) -> torch.Tensor: 29 | residual = hidden_states 30 | 31 | args = () if USE_PEFT_BACKEND else (scale,) 32 | 33 | if attn.spatial_norm is not None: 34 | hidden_states = attn.spatial_norm(hidden_states, temb) 35 | 36 | batch_size, sequence_length, _ = hidden_states.shape 37 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length,batch_size) 38 | 39 | if attn.group_norm is not None: 40 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 41 | 42 | # query = attn.to_q(hidden_states) 43 | query = attn.to_q(hidden_states, *args) 44 | 45 | is_cross = encoder_hidden_states is not None 46 | # encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states 47 | if encoder_hidden_states is None: 48 | encoder_hidden_states = hidden_states 49 | elif attn.norm_cross: 50 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 51 | 52 | 53 | key = attn.to_k(encoder_hidden_states, *args) 54 | value = attn.to_v(encoder_hidden_states, *args) 55 | 56 | query = attn.head_to_batch_dim(query) 57 | key = attn.head_to_batch_dim(key) 58 
| value = attn.head_to_batch_dim(value) 59 | 60 | # import pdb 61 | # pdb.set_trace() 62 | 63 | # print(f"query:{query.size()}") 64 | # print(f"key:{key.size()}") 65 | ''' 66 | query:torch.Size([32, 1024, 72]) 67 | key:torch.Size([32, 1024, 72]) 68 | ''' 69 | # sim = einsum('b i d, b j d -> b i j', query, key) * scale 70 | # self_attn_map = sim.softmax(dim=-1) 71 | # self_attn_map = rearrange(self_attn_map, 'h n m -> n (h m)') 72 | 73 | attention_probs = attn.get_attention_scores(query, key, attention_mask) 74 | 75 | self.attnstore(attention_probs, is_cross, self.layer_in_dit) 76 | 77 | hidden_states = torch.bmm(attention_probs, value) 78 | hidden_states = attn.batch_to_head_dim(hidden_states) 79 | 80 | # linear proj 81 | hidden_states = attn.to_out[0](hidden_states) 82 | # dropout 83 | hidden_states = attn.to_out[1](hidden_states) 84 | 85 | if attn.residual_connection: 86 | hidden_states = hidden_states + residual 87 | 88 | hidden_states = hidden_states / attn.rescale_output_factor 89 | 90 | return hidden_states 91 | 92 | def register_attention_control(model, controller): 93 | # indices = [int(item.split('.')[1]) for item in processor_list] 94 | 95 | attn_procs = {} 96 | cross_att_count = 0 97 | 98 | # print(model.transformer.attn_processors.keys()) 99 | ''' 100 | dict_keys(['transformer_blocks.0.attn1.processor', 'transformer_blocks.0.attn2.processor', 'transformer_blocks.1.attn1.processor', 'transformer_blocks.1.attn2.processor', 'transformer_blocks.2.attn1.processor', 'transformer_blocks.2.attn2.processor', 'transformer_blocks.3.attn1.processor', 'transformer_blocks.3.attn2.processor', 'transformer_blocks.4.attn1.processor', 'transformer_blocks.4.attn2.processor', 'transformer_blocks.5.attn1.processor', 'transformer_blocks.5.attn2.processor', 'transformer_blocks.6.attn1.processor', 'transformer_blocks.6.attn2.processor', 'transformer_blocks.7.attn1.processor', 'transformer_blocks.7.attn2.processor', 'transformer_blocks.8.attn1.processor', 'transformer_blocks.8.attn2.processor', 'transformer_blocks.9.attn1.processor', 'transformer_blocks.9.attn2.processor', 'transformer_blocks.10.attn1.processor', 'transformer_blocks.10.attn2.processor', 'transformer_blocks.11.attn1.processor', 'transformer_blocks.11.attn2.processor', 'transformer_blocks.12.attn1.processor', 'transformer_blocks.12.attn2.processor', 'transformer_blocks.13.attn1.processor', 'transformer_blocks.13.attn2.processor', 'transformer_blocks.14.attn1.processor', 'transformer_blocks.14.attn2.processor', 'transformer_blocks.15.attn1.processor', 'transformer_blocks.15.attn2.processor', 'transformer_blocks.16.attn1.processor', 'transformer_blocks.16.attn2.processor', 'transformer_blocks.17.attn1.processor', 'transformer_blocks.17.attn2.processor', 'transformer_blocks.18.attn1.processor', 'transformer_blocks.18.attn2.processor', 'transformer_blocks.19.attn1.processor', 'transformer_blocks.19.attn2.processor', 'transformer_blocks.20.attn1.processor', 'transformer_blocks.20.attn2.processor', 'transformer_blocks.21.attn1.processor', 'transformer_blocks.21.attn2.processor', 'transformer_blocks.22.attn1.processor', 'transformer_blocks.22.attn2.processor', 'transformer_blocks.23.attn1.processor', 'transformer_blocks.23.attn2.processor', 'transformer_blocks.24.attn1.processor', 'transformer_blocks.24.attn2.processor', 'transformer_blocks.25.attn1.processor', 'transformer_blocks.25.attn2.processor', 'transformer_blocks.26.attn1.processor', 'transformer_blocks.26.attn2.processor', 'transformer_blocks.27.attn1.processor', 
'transformer_blocks.27.attn2.processor']) 101 | ''' 102 | 103 | for name in model.transformer.attn_processors.keys(): 104 | 105 | # import pdb 106 | # pdb.set_trace() 107 | if 'fuser' in name: continue 108 | 109 | layer_in_dit = int(name.split('.')[1]) 110 | 111 | if 'attn1' in name: 112 | cross_att_count += 1 # counts the hooked attn1 (self-attention) processors 113 | attn_procs[name] = Hack_AttnProcessor( 114 | attnstore=controller, layer_in_dit=layer_in_dit 115 | ) 116 | else: 117 | attn_procs[name] = AttnProcessor2_0() 118 | 119 | # set_attn_processor needs to be implemented by the transformer (see models/transformer_2d.py) 120 | # import pdb 121 | # pdb.set_trace() 122 | model.transformer.set_attn_processor(attn_procs) 123 | controller.num_att_layers = cross_att_count 124 | 125 | 126 | 127 | class AttentionControl(abc.ABC): 128 | 129 | def step_callback(self, x_t): 130 | return x_t 131 | 132 | def between_steps(self): 133 | return 134 | 135 | @property 136 | def num_uncond_att_layers(self): 137 | return 0 138 | 139 | @abc.abstractmethod 140 | def forward(self, attn, is_cross: bool, layer_in_dit: int): 141 | raise NotImplementedError 142 | 143 | def __call__(self, attn, is_cross: bool, layer_in_dit: int): 144 | if self.cur_att_layer >= self.num_uncond_att_layers: 145 | # conditional attention 146 | # h = attn.shape[0] 147 | # self[h//2:].forward(attn[h//2:], is_cross, layer_in_dit) 148 | attn = self.forward(attn, is_cross, layer_in_dit) 149 | 150 | self.cur_att_layer += 1 151 | if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers: 152 | self.cur_att_layer = 0 153 | self.cur_step += 1 154 | self.between_steps() 155 | 156 | def reset(self): 157 | self.cur_step = 0 158 | self.cur_att_layer = 0 159 | 160 | def __init__(self): 161 | self.cur_step = 0 162 | self.num_att_layers = -1 163 | self.cur_att_layer = 0 164 | 165 | 166 | class EmptyControl(AttentionControl): 167 | 168 | def forward(self, attn, is_cross: bool, layer_in_dit: int): 169 | return attn 170 | 171 | 172 | class AttentionStore(AttentionControl): 173 | 174 | @staticmethod 175 | def get_empty_store(): 176 | return { 177 | "0_self": [], "0_cross": [], 178 | "1_self": [], "1_cross": [], 179 | "2_self": [], "2_cross": [], 180 | "3_self": [], "3_cross": [], 181 | "4_self": [], "4_cross": [], 182 | "5_self": [], "5_cross": [], 183 | "6_self": [], "6_cross": [], 184 | "7_self": [], "7_cross": [], 185 | "8_self": [], "8_cross": [], 186 | "9_self": [], "9_cross": [], 187 | "10_self": [], "10_cross": [], 188 | "11_self": [], "11_cross": [], 189 | "12_self": [], "12_cross": [], 190 | "13_self": [], "13_cross": [], 191 | "14_self": [], "14_cross": [], 192 | "15_self": [], "15_cross": [], 193 | "16_self": [], "16_cross": [], 194 | "17_self": [], "17_cross": [], 195 | "18_self": [], "18_cross": [], 196 | "19_self": [], "19_cross": [], 197 | "20_self": [], "20_cross": [], 198 | "21_self": [], "21_cross": [], 199 | "22_self": [], "22_cross": [], 200 | "23_self": [], "23_cross": [], 201 | "24_self": [], "24_cross": [], 202 | "25_self": [], "25_cross": [], 203 | "26_self": [], "26_cross": [], 204 | "27_self": [], "27_cross": [] 205 | } 206 | 207 | # return {"down_cross": [], "mid_cross": [], "up_cross": [], 208 | # "down_self": [], "mid_self": [], "up_self": []} 209 | 210 | def forward(self, attn, is_cross: bool, layer_in_dit: int): 211 | key = f"{layer_in_dit}_{'cross' if is_cross else 'self'}" 212 | if attn.shape[1] <= 32 ** 2: # avoid memory overhead 213 | self.step_store[key].append(attn) 214 | return attn 215 | 216 | def between_steps(self): 217 | self.attention_store = self.step_store 218 | if self.save_global_store: 219 | with
torch.no_grad(): 220 | if len(self.global_store) == 0: 221 | self.global_store = self.step_store 222 | else: 223 | for key in self.global_store: 224 | for i in range(len(self.global_store[key])): 225 | self.global_store[key][i] += self.step_store[key][i].detach() 226 | self.step_store = self.get_empty_store() 227 | self.step_store = self.get_empty_store() 228 | 229 | def get_average_attention(self): 230 | average_attention = self.attention_store 231 | return average_attention 232 | 233 | def get_average_global_attention(self): 234 | average_attention = {key: [item / self.cur_step for item in self.global_store[key]] for key in 235 | self.attention_store} 236 | return average_attention 237 | 238 | def reset(self): 239 | super(AttentionStore, self).reset() 240 | self.step_store = self.get_empty_store() 241 | self.attention_store = {} 242 | self.global_store = {} 243 | 244 | def __init__(self, save_global_store=False): 245 | ''' 246 | Initialize an empty AttentionStore 247 | :param step_index: used to visualize only a specific step in the diffusion process 248 | ''' 249 | super(AttentionStore, self).__init__() 250 | self.save_global_store = save_global_store 251 | self.step_store = self.get_empty_store() 252 | self.attention_store = {} 253 | self.global_store = {} 254 | self.curr_step_index = 0 255 | 256 | 257 | def get_self_attention_map(attention_store: AttentionStore, 258 | tgt_res: int, 259 | from_which_layer: int, 260 | is_cross: bool, 261 | ): 262 | from_which_layer = f"{str(from_which_layer)}_{'cross' if is_cross else 'self'}" 263 | # import pdb 264 | # pdb.set_trace() 265 | attn_map = attention_store.attention_store[from_which_layer][0] 266 | # import pdb 267 | # pdb.set_trace() 268 | # for conditional score attention map # 269 | h = attn_map.shape[0] 270 | attn_map = attn_map[h//2:] # CFG torch.Size([16, 1024, 1024]) 271 | # for conditional score attention map # 272 | self_attn_map = rearrange(attn_map, 'h n m -> n (h m)') 273 | 274 | 275 | # attn_map = attn_map.mean(dim=0) 276 | # attn_map = torch.nn.functional.interpolate(attn_map.unsqueeze(0).unsqueeze(0).cuda(),tgt_res,mode='bilinear').cpu() 277 | # attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min()) 278 | # attn_map = attn_map.squeeze(0).squeeze(0).numpy() 279 | 280 | return self_attn_map 281 | 282 | def save_attention_map_as_image(attn_map, save_path, title=None, cmap='hot'): 283 | attn_map_np = attn_map.cpu().numpy() if torch.is_tensor(attn_map) else attn_map 284 | fig, ax = plt.subplots() 285 | 286 | cax = ax.imshow(attn_map_np, cmap=cmap, interpolation='nearest') 287 | fig.colorbar(cax) 288 | 289 | if title: 290 | ax.set_title(title) 291 | ax.axis('off') 292 | 293 | plt.savefig(save_path, bbox_inches='tight', pad_inches=0.1) 294 | plt.close(fig) 295 | 296 | 297 | def aggregate_attention(attention_store: AttentionStore, 298 | res: int, 299 | from_where: List[str], 300 | is_cross: bool, 301 | select: int) -> torch.Tensor: 302 | """ Aggregates the attention across the different layers and heads at the specified resolution. 
""" 303 | out = [] 304 | attention_maps = attention_store.get_average_attention() 305 | num_pixels = res ** 2 306 | for location in from_where: 307 | for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]: 308 | if item.shape[1] == num_pixels: 309 | cross_maps = item.reshape(1, -1, res, res, item.shape[-1])[select] 310 | out.append(cross_maps) 311 | out = torch.cat(out, dim=0) 312 | out = out.sum(0) / out.shape[0] 313 | return out 314 | 315 | -------------------------------------------------------------------------------- /models/transformer_2d.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Any, Dict, Optional, Union 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn 6 | 7 | from diffusers.configuration_utils import ConfigMixin, register_to_config 8 | from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version 9 | # from diffusers.models.attention import BasicTransformerBlock 10 | from .attention import BasicTransformerBlock 11 | from diffusers.models.embeddings import ImagePositionalEmbeddings, PatchEmbed, PixArtAlphaTextProjection 12 | from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear 13 | from diffusers.models.modeling_utils import ModelMixin 14 | from diffusers.models.normalization import AdaLayerNormSingle 15 | from diffusers.models.transformers.transformer_2d import Transformer2DModelOutput 16 | from diffusers.models.attention_processor import AttentionProcessor 17 | 18 | class Transformer2DModel(ModelMixin, ConfigMixin): 19 | 20 | _supports_gradient_checkpointing = True 21 | 22 | @register_to_config 23 | def __init__( 24 | self, 25 | num_attention_heads: int = 16, 26 | attention_head_dim: int = 88, 27 | in_channels: Optional[int] = None, 28 | out_channels: Optional[int] = None, 29 | num_layers: int = 1, 30 | dropout: float = 0.0, 31 | norm_num_groups: int = 32, 32 | cross_attention_dim: Optional[int] = None, 33 | attention_bias: bool = False, 34 | sample_size: Optional[int] = None, 35 | num_vector_embeds: Optional[int] = None, 36 | patch_size: Optional[int] = None, 37 | activation_fn: str = "geglu", 38 | num_embeds_ada_norm: Optional[int] = None, 39 | use_linear_projection: bool = False, 40 | only_cross_attention: bool = False, 41 | double_self_attention: bool = False, 42 | upcast_attention: bool = False, 43 | norm_type: str = "layer_norm", 44 | norm_elementwise_affine: bool = True, 45 | norm_eps: float = 1e-5, 46 | attention_type: str = "default", 47 | caption_channels: int = None, 48 | ): 49 | super().__init__() 50 | self.use_linear_projection = use_linear_projection 51 | self.num_attention_heads = num_attention_heads 52 | self.attention_head_dim = attention_head_dim 53 | inner_dim = num_attention_heads * attention_head_dim 54 | 55 | conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv 56 | linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear 57 | 58 | # 1. 
Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` 59 | # Define whether input is continuous or discrete depending on configuration 60 | self.is_input_continuous = (in_channels is not None) and (patch_size is None) 61 | self.is_input_vectorized = num_vector_embeds is not None 62 | self.is_input_patches = in_channels is not None and patch_size is not None 63 | 64 | if norm_type == "layer_norm" and num_embeds_ada_norm is not None: 65 | deprecation_message = ( 66 | f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" 67 | " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config." 68 | " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" 69 | " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" 70 | " would be very nice if you could open a Pull request for the `transformer/config.json` file" 71 | ) 72 | deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False) 73 | norm_type = "ada_norm" 74 | 75 | if self.is_input_continuous and self.is_input_vectorized: 76 | raise ValueError( 77 | f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" 78 | " sure that either `in_channels` or `num_vector_embeds` is None." 79 | ) 80 | elif self.is_input_vectorized and self.is_input_patches: 81 | raise ValueError( 82 | f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make" 83 | " sure that either `num_vector_embeds` or `num_patches` is None." 84 | ) 85 | elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches: 86 | raise ValueError( 87 | f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:" 88 | f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None." 89 | ) 90 | 91 | # 2. 
Define input layers 92 | if self.is_input_continuous: 93 | self.in_channels = in_channels 94 | 95 | self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) 96 | if use_linear_projection: 97 | self.proj_in = linear_cls(in_channels, inner_dim) 98 | else: 99 | self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) 100 | elif self.is_input_vectorized: 101 | assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" 102 | assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed" 103 | 104 | self.height = sample_size 105 | self.width = sample_size 106 | self.num_vector_embeds = num_vector_embeds 107 | self.num_latent_pixels = self.height * self.width 108 | 109 | self.latent_image_embedding = ImagePositionalEmbeddings( 110 | num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width 111 | ) 112 | elif self.is_input_patches: 113 | assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size" 114 | 115 | self.height = sample_size 116 | self.width = sample_size 117 | 118 | self.patch_size = patch_size 119 | interpolation_scale = self.config.sample_size // 64 # => 64 (= 512 pixart) has interpolation scale 1 120 | interpolation_scale = max(interpolation_scale, 1) 121 | self.pos_embed = PatchEmbed( 122 | height=sample_size, 123 | width=sample_size, 124 | patch_size=patch_size, 125 | in_channels=in_channels, 126 | embed_dim=inner_dim, 127 | interpolation_scale=interpolation_scale, 128 | ) 129 | 130 | # 3. Define transformers blocks 131 | self.transformer_blocks = nn.ModuleList( 132 | [ 133 | BasicTransformerBlock( 134 | inner_dim, 135 | num_attention_heads, 136 | attention_head_dim, 137 | dropout=dropout, 138 | cross_attention_dim=cross_attention_dim, 139 | activation_fn=activation_fn, 140 | num_embeds_ada_norm=num_embeds_ada_norm, 141 | attention_bias=attention_bias, 142 | only_cross_attention=only_cross_attention, 143 | double_self_attention=double_self_attention, 144 | upcast_attention=upcast_attention, 145 | norm_type=norm_type, 146 | norm_elementwise_affine=norm_elementwise_affine, 147 | norm_eps=norm_eps, 148 | attention_type=attention_type, 149 | ) 150 | for d in range(num_layers) 151 | ] 152 | ) 153 | 154 | # 4. 
Define output layers 155 | self.out_channels = in_channels if out_channels is None else out_channels # for DiT, out_channels=8 156 | if self.is_input_continuous: 157 | # TODO: should use out_channels for continuous projections 158 | if use_linear_projection: 159 | self.proj_out = linear_cls(inner_dim, in_channels) 160 | else: 161 | self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) 162 | elif self.is_input_vectorized: 163 | self.norm_out = nn.LayerNorm(inner_dim) 164 | self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) 165 | elif self.is_input_patches and norm_type != "ada_norm_single": 166 | self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) 167 | self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim) 168 | self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) 169 | elif self.is_input_patches and norm_type == "ada_norm_single": 170 | self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) 171 | self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5) 172 | self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) 173 | 174 | # 5. PixArt-Alpha blocks. 175 | self.adaln_single = None 176 | self.use_additional_conditions = False 177 | if norm_type == "ada_norm_single": 178 | self.use_additional_conditions = self.config.sample_size == 128 179 | # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use 180 | # additional conditions until we find better name 181 | self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions) 182 | 183 | self.caption_projection = None 184 | if caption_channels is not None: 185 | self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim) 186 | 187 | self.gradient_checkpointing = False 188 | 189 | def _set_gradient_checkpointing(self, module, value=False): 190 | if hasattr(module, "gradient_checkpointing"): 191 | module.gradient_checkpointing = value 192 | 193 | def forward( 194 | self, 195 | hidden_states: torch.Tensor, 196 | encoder_hidden_states: Optional[torch.Tensor] = None, 197 | timestep: Optional[torch.LongTensor] = None, 198 | added_cond_kwargs: Dict[str, torch.Tensor] = None, 199 | class_labels: Optional[torch.LongTensor] = None, 200 | cross_attention_kwargs: Dict[str, Any] = None, 201 | attention_mask: Optional[torch.Tensor] = None, 202 | encoder_attention_mask: Optional[torch.Tensor] = None, 203 | return_dict: bool = True, 204 | ): 205 | # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. 206 | # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. 207 | # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. 208 | # expects mask of shape: 209 | # [batch, key_tokens] 210 | # adds singleton query_tokens dimension: 211 | # [batch, 1, key_tokens] 212 | # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: 213 | # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) 214 | # [batch * heads, query_tokens, key_tokens] (e.g. 
xformers or classic attn) 215 | if attention_mask is not None and attention_mask.ndim == 2: 216 | # assume that mask is expressed as: 217 | # (1 = keep, 0 = discard) 218 | # convert mask into a bias that can be added to attention scores: 219 | # (keep = +0, discard = -10000.0) 220 | attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 221 | attention_mask = attention_mask.unsqueeze(1) 222 | 223 | # convert encoder_attention_mask to a bias the same way we do for attention_mask 224 | if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: 225 | encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 226 | encoder_attention_mask = encoder_attention_mask.unsqueeze(1) 227 | 228 | # Retrieve lora scale. 229 | lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 230 | 231 | # 1. Input 232 | if self.is_input_continuous: 233 | batch, _, height, width = hidden_states.shape 234 | residual = hidden_states 235 | 236 | hidden_states = self.norm(hidden_states) 237 | if not self.use_linear_projection: 238 | hidden_states = ( 239 | self.proj_in(hidden_states, scale=lora_scale) 240 | if not USE_PEFT_BACKEND 241 | else self.proj_in(hidden_states) 242 | ) 243 | inner_dim = hidden_states.shape[1] 244 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) 245 | else: 246 | inner_dim = hidden_states.shape[1] 247 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) 248 | hidden_states = ( 249 | self.proj_in(hidden_states, scale=lora_scale) 250 | if not USE_PEFT_BACKEND 251 | else self.proj_in(hidden_states) 252 | ) 253 | 254 | elif self.is_input_vectorized: 255 | hidden_states = self.latent_image_embedding(hidden_states) 256 | elif self.is_input_patches: 257 | height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size 258 | hidden_states = self.pos_embed(hidden_states) 259 | 260 | if self.adaln_single is not None: 261 | if self.use_additional_conditions and added_cond_kwargs is None: 262 | raise ValueError( 263 | "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`." 264 | ) 265 | batch_size = hidden_states.shape[0] 266 | timestep, embedded_timestep = self.adaln_single( 267 | timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype 268 | ) 269 | 270 | # 2. 
Blocks 271 | if self.caption_projection is not None: 272 | batch_size = hidden_states.shape[0] 273 | encoder_hidden_states = self.caption_projection(encoder_hidden_states) 274 | encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1]) 275 | 276 | # print(f"before all attention hidden_states:{hidden_states.size()}") 277 | 278 | for block in self.transformer_blocks: 279 | if self.training and self.gradient_checkpointing: 280 | 281 | def create_custom_forward(module, return_dict=None): 282 | def custom_forward(*inputs): 283 | if return_dict is not None: 284 | return module(*inputs, return_dict=return_dict) 285 | else: 286 | return module(*inputs) 287 | 288 | return custom_forward 289 | 290 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 291 | hidden_states = torch.utils.checkpoint.checkpoint( 292 | create_custom_forward(block), 293 | hidden_states, 294 | attention_mask, 295 | encoder_hidden_states, 296 | encoder_attention_mask, 297 | timestep, 298 | cross_attention_kwargs, 299 | class_labels, 300 | **ckpt_kwargs, 301 | ) 302 | else: 303 | hidden_states = block( 304 | hidden_states, 305 | attention_mask=attention_mask, 306 | encoder_hidden_states=encoder_hidden_states, 307 | encoder_attention_mask=encoder_attention_mask, 308 | timestep=timestep, 309 | cross_attention_kwargs=cross_attention_kwargs, 310 | class_labels=class_labels, 311 | ) 312 | 313 | # print(f"after all attention hidden_states:{hidden_states.size()}") 314 | 315 | # 3. Output 316 | if self.is_input_continuous: 317 | if not self.use_linear_projection: 318 | hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() 319 | hidden_states = ( 320 | self.proj_out(hidden_states, scale=lora_scale) 321 | if not USE_PEFT_BACKEND 322 | else self.proj_out(hidden_states) 323 | ) 324 | else: 325 | hidden_states = ( 326 | self.proj_out(hidden_states, scale=lora_scale) 327 | if not USE_PEFT_BACKEND 328 | else self.proj_out(hidden_states) 329 | ) 330 | hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() 331 | 332 | output = hidden_states + residual 333 | elif self.is_input_vectorized: 334 | hidden_states = self.norm_out(hidden_states) 335 | logits = self.out(hidden_states) 336 | # (batch, self.num_vector_embeds - 1, self.num_latent_pixels) 337 | logits = logits.permute(0, 2, 1) 338 | 339 | # log(p(x_0)) 340 | output = F.log_softmax(logits.double(), dim=1).float() 341 | 342 | if self.is_input_patches: 343 | if self.config.norm_type != "ada_norm_single": 344 | conditioning = self.transformer_blocks[0].norm1.emb( 345 | timestep, class_labels, hidden_dtype=hidden_states.dtype 346 | ) 347 | shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) 348 | hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] 349 | hidden_states = self.proj_out_2(hidden_states) 350 | elif self.config.norm_type == "ada_norm_single": 351 | shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) 352 | hidden_states = self.norm_out(hidden_states) 353 | # Modulation 354 | hidden_states = hidden_states * (1 + scale) + shift 355 | hidden_states = self.proj_out(hidden_states) 356 | hidden_states = hidden_states.squeeze(1) 357 | 358 | # unpatchify 359 | if self.adaln_single is None: 360 | height = width = int(hidden_states.shape[1] ** 0.5) 361 | hidden_states = hidden_states.reshape( 362 | shape=(-1, height, width, 
self.patch_size, self.patch_size, self.out_channels) 363 | ) 364 | # print(f"after reshape hidden_states:{hidden_states.size()}") 365 | hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) 366 | # print(f"after einsum hidden_states:{hidden_states.size()}") 367 | output = hidden_states.reshape( 368 | shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) 369 | ) 370 | # print(f"after output hidden_states:{output.size()}") 371 | 372 | 373 | if not return_dict: 374 | return (output,) 375 | 376 | return Transformer2DModelOutput(sample=output) 377 | 378 | # https://github.com/huggingface/diffusers/blob/6133d98ff70eafad7b9f65da50a450a965d1957f/src/diffusers/models/unets/unet_2d_condition.py#L693 379 | @property 380 | def attn_processors(self) -> Dict[str, AttentionProcessor]: 381 | r""" 382 | Returns: 383 | `dict` of attention processors: A dictionary containing all attention processors used in the model with 384 | indexed by its weight name. 385 | """ 386 | # set recursively 387 | processors = {} 388 | 389 | def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): 390 | if hasattr(module, "get_processor"): 391 | processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) 392 | 393 | for sub_name, child in module.named_children(): 394 | fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) 395 | 396 | return processors 397 | 398 | for name, module in self.named_children(): 399 | fn_recursive_add_processors(name, module, processors) 400 | 401 | return processors 402 | 403 | # https://github.com/huggingface/diffusers/blob/6133d98ff70eafad7b9f65da50a450a965d1957f/src/diffusers/models/unets/unet_2d_condition.py#L716 404 | def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): 405 | r""" 406 | Sets the attention processor to use to compute attention. 407 | 408 | Parameters: 409 | processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): 410 | The instantiated processor class or a dictionary of processor classes that will be set as the processor 411 | for **all** `Attention` layers. 412 | 413 | If `processor` is a dict, the key needs to define the path to the corresponding cross attention 414 | processor. This is strongly recommended when setting trainable attention processors. 415 | 416 | """ 417 | count = len(self.attn_processors.keys()) 418 | 419 | if isinstance(processor, dict) and len(processor) != count: 420 | raise ValueError( 421 | f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" 422 | f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 
423 | ) 424 | 425 | def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): 426 | if hasattr(module, "set_processor"): 427 | if not isinstance(processor, dict): 428 | module.set_processor(processor) 429 | else: 430 | module.set_processor(processor.pop(f"{name}.processor")) 431 | 432 | for sub_name, child in module.named_children(): 433 | fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) 434 | 435 | for name, module in self.named_children(): 436 | fn_recursive_attn_processor(name, module, processor) 437 | 438 | -------------------------------------------------------------------------------- /pipelines/DiTconditionPipeline.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import re 3 | import urllib.parse as ul 4 | from typing import Callable, List, Optional, Tuple, Union 5 | from typing import Any, Dict 6 | 7 | from diffusers.models import AutoencoderKL, Transformer2DModel 8 | from diffusers.pipelines.pipeline_utils import ImagePipelineOutput 9 | from diffusers.schedulers import DPMSolverMultistepScheduler 10 | import torch 11 | import torch.nn.functional as F 12 | from transformers import T5EncoderModel, T5Tokenizer 13 | 14 | from models.transformer_2d import Transformer2DModel 15 | # from DiTcondition.models.transformer_2d import DiTcondition as Transformer2DModel 16 | from diffusers.pipelines import PixArtAlphaPipeline 17 | from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha import ASPECT_RATIO_1024_BIN, ASPECT_RATIO_512_BIN, retrieve_timesteps 18 | from diffusers.models.attention import GatedSelfAttentionDense 19 | from copy import deepcopy 20 | 21 | class DiTconditionPipeline(PixArtAlphaPipeline): 22 | def __init__( 23 | self, 24 | tokenizer: T5Tokenizer, 25 | text_encoder: T5EncoderModel, 26 | vae: AutoencoderKL, 27 | transformer: Transformer2DModel, 28 | scheduler: DPMSolverMultistepScheduler): 29 | super().__init__(tokenizer, text_encoder, vae, transformer, scheduler) 30 | 31 | 32 | @torch.no_grad() 33 | def __call__( 34 | self, 35 | prompt: Union[str, List[str]] = None, 36 | negative_prompt: str = "", 37 | num_inference_steps: int = 20, 38 | timesteps: List[int] = None, 39 | guidance_scale: float = 4.5, 40 | num_images_per_prompt: Optional[int] = 1, 41 | height: Optional[int] = None, 42 | width: Optional[int] = None, 43 | eta: float = 0, 44 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 45 | latents: Optional[torch.FloatTensor] = None, 46 | prompt_embeds: Optional[torch.FloatTensor] = None, 47 | prompt_attention_mask: Optional[torch.FloatTensor] = None, 48 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 49 | negative_prompt_attention_mask: Optional[torch.FloatTensor] = None, 50 | output_type: Optional[str] = "pil", 51 | return_dict: bool = True, 52 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 53 | callback_steps: int = 1, 54 | clean_caption: bool = True, 55 | use_resolution_binning: bool = True, 56 | **kwargs 57 | ) -> Union[ImagePipelineOutput, Tuple]: 58 | return super().__call__( 59 | prompt, 60 | negative_prompt, 61 | num_inference_steps, 62 | timesteps, 63 | guidance_scale, 64 | num_images_per_prompt, 65 | height, 66 | width, 67 | eta, 68 | generator, 69 | latents, 70 | prompt_embeds, 71 | prompt_attention_mask, 72 | negative_prompt_embeds, 73 | negative_prompt_attention_mask, 74 | output_type, 75 | return_dict, 76 | callback, 77 | callback_steps, 78 | clean_caption, 79 | use_resolution_binning, 
80 | **kwargs) -------------------------------------------------------------------------------- /pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /vis.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | # import ipdb 4 | from diffusers import PixArtAlphaPipeline 5 | from models.transformer_2d import Transformer2DModel 6 | from models.ptp_utils import register_attention_control, AttentionStore 7 | from models.ptp_utils import get_self_attention_map, save_attention_map_as_image 8 | import copy 9 | from sklearn.decomposition import PCA 10 | from torchvision import transforms as T 11 | from math import sqrt 12 | from PIL import Image 13 | import numpy as np 14 | 15 | # from diffusers.models.transformers.transformer_2d import Transformer2DModel 16 | # torch2.x diffusers==0.26.3 17 | 18 | def visualize_and_save_features_pca(feature_maps_fit_data, feature_maps_transform_data, transform_experiments, t, save_dir): 19 | feature_maps_fit_data = feature_maps_fit_data.cpu().numpy() 20 | pca = PCA(n_components=3) 21 | pca.fit(feature_maps_fit_data) 22 | feature_maps_pca = pca.transform(feature_maps_transform_data.cpu().numpy()) # N X 3 23 | feature_maps_pca = feature_maps_pca.reshape(len(transform_experiments), -1, 3) # B x (H * W) x 3 24 | for i, experiment in enumerate(transform_experiments): 25 | pca_img = feature_maps_pca[i] # (H * W) x 3 26 | h = w = int(sqrt(pca_img.shape[0])) 27 | pca_img = pca_img.reshape(h, w, 3) 28 | pca_img_min = pca_img.min(axis=(0, 1)) 29 | pca_img_max = pca_img.max(axis=(0, 1)) 30 | pca_img = (pca_img - pca_img_min) / (pca_img_max - pca_img_min) 31 | pca_img = Image.fromarray((pca_img * 255).astype(np.uint8)) 32 | pca_img = T.Resize(512, interpolation=T.InterpolationMode.NEAREST)(pca_img) 33 | pca_img.save(os.path.join(save_dir, f"{experiment}_layer_{t}.png")) 34 | 35 | 36 | generator = torch.Generator("cuda").manual_seed(1024) 37 | pixart_transformer = Transformer2DModel.from_pretrained("PixArt-alpha/PixArt-XL-2-512x512", subfolder="transformer", torch_dtype=torch.float16) 38 | pipe = PixArtAlphaPipeline.from_pretrained( 39 | "PixArt-alpha/PixArt-XL-2-512x512", 40 | transformer=pixart_transformer, 41 | torch_dtype=torch.float16) 42 | pipe = pipe.to("cuda") 43 | 44 | controller = AttentionStore() 45 | register_attention_control(pipe, controller) 46 | 47 | prompt = "An astronaut riding a horse." 48 | negative_prompt = 'worst quality, low quality, bad anatomy, watermark, text, blurry' 49 | 50 | image = pipe(prompt=prompt, negative_prompt=negative_prompt, height=512, width=512).images[0] 51 | os.makedirs('./self_attn_maps', exist_ok=True) # make sure the output directory exists before saving maps 52 | for i in range(28): 53 | attn_map = get_self_attention_map(controller, 256, i, False) # conditional self-attention of block i, shape (tokens, heads * tokens) 54 | transform_attn_maps = copy.deepcopy(attn_map) 55 | visualize_and_save_features_pca( 56 | torch.cat([attn_map], dim=0), 57 | torch.cat([transform_attn_maps], dim=0), 58 | ['debug'], 59 | i, 60 | './self_attn_maps' 61 | ) 62 | 63 | 64 | image.save('generated_img.png') 65 | --------------------------------------------------------------------------------
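The per-layer maps collected by `AttentionStore` make the README's claim about feature scales easy to check. Below is a minimal sketch, assuming `vis.py` has just been run in the same Python session so that `controller` still holds the stored self-attention maps; the loop variables are illustrative names, not part of the repository code:

```python
# Print the token-grid size of every stored self-attention map. For
# PixArt-XL-2-512x512 each of the 28 DiT blocks attends over the same
# 32 x 32 = 1024 token grid, which is the "consistent feature scale"
# described in the README; a UNet's self-attention layers instead run at
# several different resolutions (roughly 64x64 down to 8x8 for Stable
# Diffusion at 512x512).
from math import isqrt

for layer in range(28):
    maps = controller.attention_store[f"{layer}_self"]
    if not maps:
        continue
    num_tokens = maps[0].shape[1]   # maps are stored as (batch * heads, tokens, tokens)
    side = isqrt(num_tokens)        # side length of the square token grid
    print(f"block {layer:2d}: {num_tokens} tokens ({side}x{side})")
```

The same counts explain why `AttentionStore.forward` can keep every DiT map under its `32 ** 2` token threshold, whereas a comparable hook on a UNet would have to drop its highest-resolution self-attention maps.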