├── teaser.png
├── main.py
├── README.md
└── cross_processor.py
/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OliverRensu/GRAT/HEAD/teaser.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import torch
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
from diffusers.utils import export_to_video
from cross_processor import CrossAttnProcessor2_0, init_local_mask_flex

model_id = "hunyuanvideo-community/HunyuanVideo"
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
    model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)
height = 768
width = 1280
frame = 128
device = torch.device('cuda')
pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16)
pipe.vae.enable_tiling()
pipe.enable_model_cpu_offload()
# pipe.to(device)

prompt = "A sleek white yacht gliding across a crystal-blue sea at sunset, camera circles the vessel as golden light sparkles on gentle waves, slight lens distortion."
# Number of prompt tokens treated as attendable text; the rest of the 256-token text slot is masked out.
attenable = len(pipe.tokenizer(prompt)['input_ids'])

# Token grid seen by the transformer: (frame // 4) x (height // 16) x (width // 16), partitioned into 4 x 8 x 8 groups.
group_t, group_h, group_w = 4, 8, 8
mask = init_local_mask_flex(
    frame // 4, height // 16, width // 16,
    text_length=256, attenable_text=attenable,
    group_t=group_t, group_h=group_h, group_w=group_w, device=device,
)

# Keep the default processors in the token refiner; use GRAT attention everywhere else.
attn_processors = {}
for k, v in transformer.attn_processors.items():
    if "token_refiner" in k:
        attn_processors[k] = v
    else:
        attn_processors[k] = CrossAttnProcessor2_0(
            mask, frame // 4, height // 16, width // 16, group_t, group_h, group_w, text_length=256
        )
transformer.set_attn_processor(attn_processors)

output = pipe(
    prompt=prompt,
    height=height,
    width=width,
    num_frames=frame,
    num_inference_steps=30,
).frames[0]
export_to_video(output, "output.mp4", fps=24)  # 128 frames at 24 fps, roughly 5 seconds
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# GRAT
[![arXiv](https://img.shields.io/badge/arXiv%20paper-2505.14687-b31b1b.svg)](https://arxiv.org/abs/2505.14687)
[![project page](https://img.shields.io/badge/project%20page-lightblue)](https://oliverrensu.github.io/project/GRAT/)


This repository is the official implementation of our paper [Grouping First, Attending Smartly: Training-Free Acceleration for Diffusion Transformers](https://arxiv.org/abs/2505.14687).

## Introduction
Diffusion-based Transformers have demonstrated impressive generative capabilities, but their high computational costs hinder practical deployment—for example, generating an $8192\times8192$ image can take over an hour on an A100 GPU.
In this work, we propose GRAT (GRouping first, ATtending smartly), a training-free attention acceleration strategy for fast image and video generation without compromising output quality.
The key insight is to exploit the inherent sparsity in learned attention maps (which tend to be locally focused) in pretrained Diffusion Transformers and to leverage better GPU parallelism.
Specifically, GRAT first partitions contiguous tokens into non-overlapping groups, aligning both with GPU execution patterns and the local attention structures learned in pretrained generative Transformers.
It then accelerates attention by having all query tokens within the same group share a common set of attendable key and value tokens. These key and value tokens are further restricted to structured regions, such as surrounding blocks or criss-cross regions, significantly reducing computational overhead (e.g., attaining a 35.8$\times$ speedup over full attention when generating $8192\times8192$ images) while preserving essential attention patterns and long-range context.
We validate GRAT on pretrained Flux and HunyuanVideo for image and video generation, respectively.
In both cases, GRAT achieves substantially faster inference without any fine-tuning, while maintaining the performance of full attention.
We hope GRAT will inspire future research on accelerating Diffusion Transformers for scalable visual generation.
![teaser](teaser.png)
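To make the criss-cross rule concrete, here is a minimal, self-contained sketch (toy group counts; the actual mask construction lives in `cross_processor.py`) that marks which key/value groups each query group may attend to:

```python
import torch

# Toy number of groups along the temporal, height, and width axes.
gt, gh, gw = 2, 3, 4
idx = torch.arange(gt * gh * gw)
t = idx // (gh * gw)
h = (idx % (gh * gw)) // gw
w = idx % gw

# attendable[i, j] is True if query group i may attend to key/value group j,
# i.e. the two groups share a temporal, row, or column index (criss-cross).
attendable = (t[:, None] == t[None, :]) | (h[:, None] == h[None, :]) | (w[:, None] == w[None, :])
print(f"kept {attendable.float().mean().item():.2%} of group pairs")
```

Under this rule, the 8 × 6 × 10 grid of groups used in `main.py` keeps roughly a third of all group pairs, and each group is a contiguous block of tokens that aligns with GPU execution patterns.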
## Generate a video
```bash
python main.py
```
This generates a 5-second, 128-frame video at 1280×768 resolution (roughly 720p).


## Reference
If you have any questions, feel free to contact [Sucheng Ren](mailto:oliverrensu@gmail.com).

```
@article{ren2025grat,
  title={Grouping First, Attending Smartly: Training-Free Acceleration for Diffusion Transformers},
  author={Ren, Sucheng and Yu, Qihang and He, Ju and Yuille, Alan and Chen, Liang-Chieh},
  journal={arXiv preprint arXiv:2505.14687},
  year={2025}
}
```

## Acknowledgement
[CLEAR](https://github.com/Huage001/CLEAR)

--------------------------------------------------------------------------------
/cross_processor.py:
--------------------------------------------------------------------------------
import torch.nn.functional as F
import torch
from diffusers.models.attention_processor import Attention, AttentionProcessor
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
create_block_mask = torch.compile(create_block_mask)
from typing import Optional
from functools import partial, lru_cache


@lru_cache
def init_local_mask_flex(frames, height, width, text_length, attenable_text, group_t, group_h, group_w, device):
    total_length = height * width * frames
    cell_size = group_t * group_h * group_w
    t, h, w = frames // group_t, height // group_h, width // group_w

    def local_mask(b, h_, q_idx, kv_idx):
        # Group index of each token (tokens are clusterified, so each block of
        # `cell_size` consecutive image tokens belongs to one group).
        q_y = q_idx // cell_size
        kv_y = kv_idx // cell_size
        q_t = q_y // (h * w)
        q_h = (q_y % (h * w)) // w
        q_w = (q_y % (h * w)) % w

        kv_t = kv_y // (h * w)
        kv_h = (kv_y % (h * w)) // w
        kv_w = (kv_y % (h * w)) % w

        # Text: every query attends to the attendable (non-padded) text tokens,
        # and text queries additionally attend to all image tokens.
        text_kv = torch.logical_and(kv_idx >= total_length, kv_idx < total_length + attenable_text)
        text = torch.logical_or(text_kv, torch.logical_and(q_idx >= total_length, kv_idx < total_length))

        # Image: a query group attends to key/value groups that share its
        # temporal, row, or column group index (criss-cross pattern).
        image = torch.logical_and(
            torch.logical_or(torch.logical_or(q_t == kv_t, q_w == kv_w), q_h == kv_h),
            torch.logical_and(q_idx < total_length, kv_idx < total_length),
        )
        return torch.logical_or(text, image)

    mask = create_block_mask(
        local_mask, None, None, total_length + text_length, total_length + text_length, device=device
    )
    return mask


class CrossAttnProcessor2_0:
    def __init__(self, mask, t, h, w, group_t, group_h, group_w, text_length=256):
        self.t, self.h, self.w = t, h, w
        self.group_t, self.group_h, self.group_w = group_t, group_h, group_w
        self.text_length = text_length
        # flex_attention bound to the precomputed block mask and compiled once.
        self.flex_attn = torch.compile(partial(flex_attention, block_mask=mask))

    def clusterify(self, x):
        bsz, head, n, c = x.shape
        p_t, p_h, p_w = self.group_t, self.group_h, self.group_w
        t, h, w = self.t, self.h, self.w
        t_, h_, w_ = t // p_t, h // p_h, w // p_w
        x = x.reshape(bsz, head, t_, p_t, h_, p_h, w_, p_w, c)
        x = torch.einsum('nxtahpwqc->nxthwapqc', x)
        x = x.reshape(bsz, head, -1, c)
        return x

    def unclusterify(self, x):
        bsz, head, n, c = x.shape
        p_t, p_h, p_w = self.group_t, self.group_h, self.group_w
        t, h, w = self.t, self.h, self.w
        t_, h_, w_ = t // p_t, h // p_h, w // p_w
        x = x.reshape(bsz, head, t_, h_, w_, p_t, p_h, p_w, c)
        x = torch.einsum('nxthwapqc->nxtahpwqc', x)
        x = x.reshape(bsz, head, -1, c)
        return x
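    # Shape walkthrough for the settings in main.py: a 128-frame, 768x1280
    # generation gives a 32 x 48 x 80 token grid; with 4 x 8 x 8 groups,
    # clusterify reshapes (B, heads, 32*48*80, C) so that each run of
    # 4*8*8 = 256 consecutive tokens is one group, which is exactly the
    # `q_idx // cell_size` group id used by local_mask above. unclusterify
    # restores the original (t, h, w) token order after attention.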
    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if attn.add_q_proj is None and encoder_hidden_states is not None:
            hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)

        # 1. QKV projections
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)

        # 2. QK normalization
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # 3. Rotational positional embeddings applied to latent stream
        if image_rotary_emb is not None:
            from diffusers.models.embeddings import apply_rotary_emb

            if attn.add_q_proj is None and encoder_hidden_states is not None:
                query = torch.cat(
                    [
                        apply_rotary_emb(query[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
                        query[:, :, -encoder_hidden_states.shape[1] :],
                    ],
                    dim=2,
                )
                key = torch.cat(
                    [
                        apply_rotary_emb(key[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
                        key[:, :, -encoder_hidden_states.shape[1] :],
                    ],
                    dim=2,
                )
            else:
                query = apply_rotary_emb(query, image_rotary_emb)
                key = apply_rotary_emb(key, image_rotary_emb)

        # 4. Encoder condition QKV projection and normalization
        if attn.add_q_proj is not None and encoder_hidden_states is not None:
            encoder_query = attn.add_q_proj(encoder_hidden_states)
            encoder_key = attn.add_k_proj(encoder_hidden_states)
            encoder_value = attn.add_v_proj(encoder_hidden_states)

            encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
            encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
            encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2)

            if attn.norm_added_q is not None:
                encoder_query = attn.norm_added_q(encoder_query)
            if attn.norm_added_k is not None:
                encoder_key = attn.norm_added_k(encoder_key)

            query = torch.cat([query, encoder_query], dim=2)
            key = torch.cat([key, encoder_key], dim=2)
            value = torch.cat([value, encoder_value], dim=2)

        # 5. Attention: clusterify the image tokens and run flex_attention with
        #    the grouped criss-cross block mask.
        query_image, query_text = query[:, :, :-self.text_length], query[:, :, -self.text_length:]  # b h n c
        key_image, key_text = key[:, :, :-self.text_length], key[:, :, -self.text_length:]
        value_image, value_text = value[:, :, :-self.text_length], value[:, :, -self.text_length:]

        query_image = self.clusterify(query_image)
        key_image = self.clusterify(key_image)
        value_image = self.clusterify(value_image)
        query = torch.cat([query_image, query_text], dim=2)
        key = torch.cat([key_image, key_text], dim=2)
        value = torch.cat([value_image, value_text], dim=2)

        hidden_states = self.flex_attn(query, key, value)
        hidden_states_image, hidden_states_text = hidden_states[:, :, :-self.text_length], hidden_states[:, :, -self.text_length:]  # b h n c
        hidden_states_image = self.unclusterify(hidden_states_image)
        hidden_states = torch.cat([hidden_states_image, hidden_states_text], dim=2)
        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
        hidden_states = hidden_states.to(query.dtype)

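        # The [image tokens, text tokens] layout is preserved: unclusterify puts
        # the image tokens back in their original (t, h, w) order and the text
        # stream stays at the tail, so the output projection below can split it
        # off via encoder_hidden_states.shape[1].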
        # 6. Output projection
        if encoder_hidden_states is not None:
            hidden_states, encoder_hidden_states = (
                hidden_states[:, : -encoder_hidden_states.shape[1]],
                hidden_states[:, -encoder_hidden_states.shape[1] :],
            )

            if getattr(attn, "to_out", None) is not None:
                hidden_states = attn.to_out[0](hidden_states)
                hidden_states = attn.to_out[1](hidden_states)

            if getattr(attn, "to_add_out", None) is not None:
                encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        return hidden_states, encoder_hidden_states
--------------------------------------------------------------------------------