├── EasyCache4HunyuanVideo ├── README.md ├── easycache_sample_video.py ├── hyvideo_svg_easycache.py ├── tools │ └── video_metrics.py └── videos │ ├── baseline_544p.gif │ ├── baseline_720p.gif │ ├── easycache_544p.gif │ └── svg_with_easycache_720p.gif ├── EasyCache4Wan2.1 ├── README.md ├── easycache_generate.py ├── example │ └── grogu.png ├── tools │ └── video_metrics.py └── videos │ ├── i2v_easycache_14b_720p.gif │ ├── i2v_gt_14b_720p.gif │ ├── t2v_easycache_14b_720p.gif │ └── t2v_gt_14b_720p.gif ├── LICENSE ├── README.md └── demo ├── gt ├── 6.gif └── 7.gif ├── our ├── 6.gif └── 7.gif ├── pab ├── 6.gif └── 7.gif └── teacache ├── 6.gif └── 7.gif /EasyCache4HunyuanVideo/README.md: -------------------------------------------------------------------------------- 1 |
2 |

Less is Enough: Training-Free Video Diffusion Acceleration via Runtime-Adaptive Caching

3 | 4 | Xin Zhou1\*, 5 | Dingkang Liang1\*, 6 | Kaijin Chen1, Tianrui Feng1, 7 | Xiwu Chen2, Hongkai Lin1,
8 | Yikang Ding2, Feiyang Tan2, 9 | Hengshuang Zhao3, 10 | Xiang Bai1† 11 | 12 | 1 Huazhong University of Science and Technology, 2 MEGVII Technology, 3 University of Hong Kong
13 | 14 | (\*) Equal contribution. (†) Corresponding author. 15 | 16 | [![Project](https://img.shields.io/badge/Homepage-project-orange.svg?logo=googlehome)](https://H-EmbodVis.github.io/EasyCache/) 17 | [![Code License](https://img.shields.io/badge/Code%20License-Apache_2.0-green.svg)](https://github.com/LMD0311/EasyCache/blob/main/LICENSE) 18 | 19 |
20 | 21 | --- 22 | 23 | This document provides the implementation for accelerating the [**HunyuanVideo**](https://github.com/Tencent/HunyuanVideo) model using **EasyCache**. 24 | 25 | ### ✨ Visual Comparison 26 | 27 | EasyCache significantly accelerates inference speed while maintaining high visual fidelity. 28 | 29 | **Prompt: "A cat walks on the grass, realistic style." (Base Acceleration)** 30 | 31 | | HunyuanVideo (Baseline, 544p, H20) | EasyCache (Ours) | 32 | | :---: | :---: | 33 | | ![Baseline Video](./videos/baseline_544p.gif) | ![Our Video](./videos/easycache_544p.gif) | 34 | | **Inference Time: ~2327s** | **Inference Time: ~1025s (2.3x Speedup)** | 35 | 36 | **Prompt: "A young man at his 20s is sitting on a piece of cloud in the sky, reading a book." (SVG with EasyCache)** 37 | 38 | | HunyuanVideo (Baseline, 720p, H20) | SVG with EasyCache (Ours) | 39 | |:---:|:---:| 40 | | ![Baseline 720p GIF](./videos/baseline_720p.gif) | ![EasyCache+SVG 720p GIF](./videos/svg_with_easycache_720p.gif) | 41 | | **Inference Time: ~6572s** | **Inference Time: ~1773s (3.71x Speedup)** | 42 | 43 | 44 | --- 45 | 46 | ### 🚀 Usage Instructions 47 | 48 | This section provides instructions for two settings: base acceleration with EasyCache alone and combined acceleration using EasyCache with SVG. 49 | 50 | #### **1. Base Acceleration (EasyCache Only)** 51 | 52 | **a. Prerequisites** ⚙️ 53 | 54 | Before you begin, please follow the instructions in the [official HunyuanVideo repository](https://github.com/Tencent/HunyuanVideo) to configure the required environment and download the pretrained model weights. 55 | 56 | **b. Copy Files** 📂 57 | 58 | Copy `easycache_sample_video.py` into the root directory of your local `HunyuanVideo` project. 59 | 60 | **c. Run Inference** ▶️ 61 | 62 | Execute the following command from the root of the `HunyuanVideo` project to generate a video. To generate videos in 720p resolution, set the `--video-size` argument to `720 1280`. You can also specify your own custom prompts. 63 | 64 | ```bash 65 | python3 easycache_sample_video.py \ 66 | --video-size 544 960 \ 67 | --video-length 129 \ 68 | --infer-steps 50 \ 69 | --prompt "A cat walks on the grass, realistic style." \ 70 | --flow-reverse \ 71 | --use-cpu-offload \ 72 | --save-path ./results \ 73 | --seed 42 74 | ``` 75 | 76 | #### **2. Combined Acceleration (SVG with EasyCache)** 77 | 78 | **a. Prerequisites** ⚙️ 79 | 80 | Ensure you have set up the environments for both [HunyuanVideo](https://github.com/Tencent/HunyuanVideo) and [SVG](https://github.com/svg-project/Sparse-VideoGen). 81 | 82 | **b. Copy Files** 📂 83 | 84 | Copy `hyvideo_svg_easycache.py` into the root directory of your local `HunyuanVideo` project. 85 | 86 | **c. Run Inference** ▶️ 87 | 88 | Execute the following command to generate a 720p video using both SVG and EasyCache for maximum acceleration. You can also specify your own custom prompts. 89 | 90 | ```bash 91 | python3 hyvideo_svg_easycache.py \ 92 | --video-size 720 1280 \ 93 | --video-length 129 \ 94 | --infer-steps 50 \ 95 | --prompt "A young man at his 20s is sitting on a piece of cloud in the sky, reading a book." 
\ 96 | --embedded-cfg-scale 6.0 \ 97 | --flow-shift 7.0 \ 98 | --flow-reverse \ 99 | --use-cpu-offload \ 100 | --save-path ./results \ 101 | --output_path ./results \ 102 | --pattern "SVG" \ 103 | --num_sampled_rows 64 \ 104 | --sparsity 0.2 \ 105 | --first_times_fp 0.055 \ 106 | --first_layers_fp 0.025 \ 107 | --record_attention \ 108 | --seed 42 109 | ``` 110 | 111 | ### 📊 Evaluating Video Similarity 112 | 113 | We provide a simple script to quickly evaluate the similarity between two videos (e.g., the baseline result and your generated result) using common metrics. 114 | 115 | **Usage** 116 | 117 | ```bash 118 | # install required packages. 119 | pip install lpips numpy tqdm torchmetrics 120 | 121 | python tools/video_metrics.py --original_video video1.mp4 --generated_video video2.mp4 122 | ``` 123 | 124 | - `--original_video`: Path to the first video (e.g., the baseline). 125 | - `--generated_video`: Path to the second video (e.g., the one generated with EasyCache). 126 | 127 | ## 🌹 Acknowledgements 128 | We would like to thank the contributors to the [HunyuanVideo](https://github.com/Tencent-Hunyuan/HunyuanVideo), and [SVG](https://github.com/svg-project/Sparse-VideoGen) repositories, for their open research and exploration. 129 | 130 | ## 📖 Citation 131 | 132 | If you find this repository useful in your research, please consider giving a star ⭐ and a citation. 133 | ```bibtex 134 | @article{zhou2025easycache, 135 | title={Less is Enough: Training-Free Video Diffusion Acceleration via Runtime-Adaptive Caching}, 136 | author={Zhou, Xin and Liang, Dingkang and Chen, Kaijin and and Feng, Tianrui and Chen, Xiwu and Lin, Hongkai and Ding, Yikang and Tan, Feiyang and Zhao, Hengshuang and Bai, Xiang}, 137 | journal={arXiv preprint arXiv:2507.02860}, 138 | year={2025} 139 | } 140 | ``` 141 | -------------------------------------------------------------------------------- /EasyCache4HunyuanVideo/easycache_sample_video.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The Tecent Hunyuan Team Authors. All rights reserved. 2 | # Copyright 2025 The Huazhong University of Science and Technology VLRLab Authors. All rights reserved. 3 | 4 | import os 5 | import time 6 | from pathlib import Path 7 | from loguru import logger 8 | from datetime import datetime 9 | 10 | from hyvideo.utils.file_utils import save_videos_grid 11 | from hyvideo.config import parse_args 12 | from hyvideo.inference import HunyuanVideoSampler 13 | 14 | from hyvideo.modules.modulate_layers import modulate 15 | from hyvideo.modules.attenion import attention, parallel_attention, get_cu_seqlens 16 | from typing import Any, List, Tuple, Optional, Union, Dict 17 | import torch 18 | import json 19 | import numpy as np 20 | import portalocker 21 | import json 22 | import random 23 | from tqdm import tqdm 24 | from torch.utils.data import Dataset, DataLoader 25 | 26 | 27 | def easycache_forward( 28 | self, 29 | x: torch.Tensor, 30 | t: torch.Tensor, # Should be in range(0, 1000). 31 | text_states: torch.Tensor = None, 32 | text_mask: torch.Tensor = None, # Now we don't use it. 33 | text_states_2: Optional[torch.Tensor] = None, # Text embedding for modulation. 34 | freqs_cos: Optional[torch.Tensor] = None, 35 | freqs_sin: Optional[torch.Tensor] = None, 36 | guidance: torch.Tensor = None, # Guidance for modulation, should be cfg_scale x 1000. 
37 | return_dict: bool = True, 38 | ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: 39 | torch.cuda.synchronize() 40 | start_time = time.time() 41 | 42 | out = {} 43 | raw_input = x.clone() 44 | img = x 45 | txt = text_states 46 | _, _, ot, oh, ow = x.shape 47 | tt, th, tw = ( 48 | ot // self.patch_size[0], 49 | oh // self.patch_size[1], 50 | ow // self.patch_size[2], 51 | ) 52 | 53 | # Prepare modulation vectors. 54 | vec = self.time_in(t) 55 | 56 | # text modulation 57 | vec = vec + self.vector_in(text_states_2) 58 | 59 | # guidance modulation 60 | if self.guidance_embed: 61 | if guidance is None: 62 | raise ValueError( 63 | "Didn't get guidance strength for guidance distilled model." 64 | ) 65 | 66 | # our timestep_embedding is merged into guidance_in(TimestepEmbedder) 67 | vec = vec + self.guidance_in(guidance) 68 | 69 | if self.cnt < self.ret_steps or self.cnt >= self.num_steps - 1: 70 | should_calc = True 71 | self.accumulated_error = 0 72 | else: 73 | # Check if previous inputs and outputs exist 74 | if hasattr(self, 'previous_raw_input') and hasattr(self, 'previous_output') \ 75 | and self.previous_raw_input is not None and self.previous_output is not None: 76 | 77 | raw_input_change = (raw_input - self.previous_raw_input).abs().mean() 78 | 79 | if hasattr(self, 'k') and self.k is not None: 80 | 81 | output_norm = self.previous_output.abs().mean() 82 | pred_change = self.k * (raw_input_change / output_norm) 83 | self.accumulated_error += pred_change 84 | 85 | if self.accumulated_error < self.thresh: 86 | should_calc = False 87 | else: 88 | should_calc = True 89 | self.accumulated_error = 0 90 | else: 91 | should_calc = True 92 | else: 93 | should_calc = True 94 | 95 | self.previous_raw_input = raw_input.clone() # (1, 16, 33, 68, 120) 96 | 97 | if not should_calc and self.cache is not None: 98 | result = raw_input + self.cache 99 | self.cnt += 1 100 | 101 | if self.cnt >= self.num_steps: 102 | self.cnt = 0 103 | 104 | torch.cuda.synchronize() 105 | end_time = time.time() 106 | self.total_time += (end_time - start_time) 107 | 108 | if return_dict: 109 | out["x"] = result 110 | return out 111 | return result 112 | 113 | img = self.img_in(img) 114 | if self.text_projection == "linear": 115 | txt = self.txt_in(txt) 116 | elif self.text_projection == "single_refiner": 117 | txt = self.txt_in(txt, t, text_mask if self.use_attention_mask else None) 118 | else: 119 | raise NotImplementedError(f"Unsupported text_projection: {self.text_projection}") 120 | 121 | txt_seq_len = txt.shape[1] 122 | img_seq_len = img.shape[1] 123 | 124 | # Compute cu_squlens and max_seqlen for flash attention 125 | cu_seqlens_q = get_cu_seqlens(text_mask, img_seq_len) 126 | cu_seqlens_kv = cu_seqlens_q 127 | max_seqlen_q = img_seq_len + txt_seq_len 128 | max_seqlen_kv = max_seqlen_q 129 | 130 | freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None 131 | 132 | # --------------------- Pass through DiT blocks ------------------------ 133 | for _, block in enumerate(self.double_blocks): 134 | double_block_args = [ 135 | img, 136 | txt, 137 | vec, 138 | cu_seqlens_q, 139 | cu_seqlens_kv, 140 | max_seqlen_q, 141 | max_seqlen_kv, 142 | freqs_cis, 143 | ] 144 | img, txt = block(*double_block_args) 145 | 146 | # Merge txt and img to pass through single stream blocks. 
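# The single-stream blocks below operate on this concatenated image+text sequence;
# the image tokens are sliced back out afterwards (x[:, :img_seq_len]) before the final layer.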
147 | x = torch.cat((img, txt), 1) 148 | if len(self.single_blocks) > 0: 149 | for _, block in enumerate(self.single_blocks): 150 | single_block_args = [ 151 | x, 152 | vec, 153 | txt_seq_len, 154 | cu_seqlens_q, 155 | cu_seqlens_kv, 156 | max_seqlen_q, 157 | max_seqlen_kv, 158 | (freqs_cos, freqs_sin), 159 | ] 160 | x = block(*single_block_args) 161 | 162 | img = x[:, :img_seq_len, ...] 163 | 164 | # ---------------------------- Final layer ------------------------------ 165 | img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) 166 | 167 | result = self.unpatchify(img, tt, th, tw) 168 | 169 | # store the cache for next step 170 | self.cache = result - raw_input 171 | if hasattr(self, 'previous_output') and self.previous_output is not None: 172 | output_change = (result - self.previous_output).abs().mean() 173 | if hasattr(self, 'prev_prev_raw_input') and self.prev_prev_raw_input is not None: 174 | input_change = (self.previous_raw_input - self.prev_prev_raw_input).abs().mean() 175 | self.k = output_change / input_change 176 | 177 | # update the previous state 178 | self.prev_prev_raw_input = getattr(self, 'previous_raw_input', None) 179 | self.previous_output = result.clone() 180 | 181 | self.cnt += 1 182 | if self.cnt >= self.num_steps: 183 | self.cnt = 0 184 | 185 | torch.cuda.synchronize() 186 | end_time = time.time() 187 | self.total_time += (end_time - start_time) 188 | 189 | if return_dict: 190 | out["x"] = result 191 | return out 192 | return result 193 | 194 | 195 | def main(): 196 | args = parse_args() 197 | 198 | print(args) 199 | models_root_path = Path(args.model_base) 200 | if not models_root_path.exists(): 201 | raise ValueError(f"`models_root` not exists: {models_root_path}") 202 | 203 | # Create save folder to save the samples 204 | os.makedirs(args.save_path, exist_ok=True) 205 | 206 | # Load models 207 | hunyuan_video_sampler = HunyuanVideoSampler.from_pretrained(models_root_path, args=args) 208 | 209 | # Get the updated args 210 | args = hunyuan_video_sampler.args 211 | 212 | hunyuan_video_sampler.pipeline.transformer.__class__.cnt = 0 213 | hunyuan_video_sampler.pipeline.transformer.__class__.num_steps = args.infer_steps 214 | hunyuan_video_sampler.pipeline.transformer.__class__.thresh = 0.025 215 | hunyuan_video_sampler.pipeline.transformer.__class__.forward = easycache_forward 216 | hunyuan_video_sampler.pipeline.transformer.__class__.ret_steps = 5 217 | hunyuan_video_sampler.pipeline.transformer.__class__.k = None 218 | hunyuan_video_sampler.pipeline.transformer.__class__.total_time = 0.0 219 | 220 | # record time cost for DiTs 221 | generation_time = [] 222 | time_cost = { 223 | "GPU_Device": torch.cuda.get_device_name(0), 224 | "number_prompt": None, 225 | "avg_cost_time": None 226 | } 227 | 228 | hunyuan_video_sampler.pipeline.transformer.total_time = 0.0 229 | outputs = hunyuan_video_sampler.predict( 230 | prompt=args.prompt, 231 | height=args.video_size[0], 232 | width=args.video_size[1], 233 | video_length=args.video_length, 234 | seed=args.seed, 235 | negative_prompt=args.neg_prompt, 236 | infer_steps=args.infer_steps, 237 | guidance_scale=args.cfg_scale, 238 | num_videos_per_prompt=1, 239 | flow_shift=args.flow_shift, 240 | batch_size=args.batch_size, 241 | embedded_guidance_scale=args.embedded_cfg_scale 242 | ) 243 | 244 | generation_time.append(hunyuan_video_sampler.pipeline.transformer.total_time) 245 | samples = outputs['samples'] 246 | 247 | # Save samples 248 | if 'LOCAL_RANK' not in os.environ or int(os.environ['LOCAL_RANK']) == 
0: 249 | for i, sample in enumerate(samples): 250 | sample = samples[i].unsqueeze(0) 251 | time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%H:%M:%S") 252 | save_path = f"{args.save_path}/{time_flag}_seed{outputs['seeds'][i]}_{outputs['prompts'][i][:100].replace('/', '')}.mp4" 253 | save_videos_grid(sample, save_path, fps=24) 254 | logger.info(f'Sample save to: {save_path}') 255 | 256 | if generation_time: 257 | time_cost["number_prompt"] = len(generation_time) 258 | time_cost["avg_cost_time"] = sum(generation_time) / len(generation_time) if generation_time else 0 259 | 260 | print( 261 | f"GPU_Device: {time_cost['GPU_Device']}, number_prompt: {time_cost['number_prompt']}, avg_cost_time: {time_cost['avg_cost_time']}") 262 | 263 | try: 264 | with open(f"{args.save_path}/time_cost.json", "a+") as f: 265 | portalocker.lock(f, portalocker.LOCK_EX) 266 | f.seek(0) 267 | try: 268 | existing_data = json.load(f) 269 | except (json.JSONDecodeError, FileNotFoundError): 270 | existing_data = [] 271 | 272 | existing_data.append(time_cost) 273 | f.seek(0) 274 | f.truncate() 275 | json.dump(existing_data, f, indent=4) 276 | except Exception as e: 277 | print(f"Error writing time cost to file: {e}") 278 | 279 | 280 | if __name__ == "__main__": 281 | main() 282 | -------------------------------------------------------------------------------- /EasyCache4HunyuanVideo/hyvideo_svg_easycache.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The Tencent Hunyuan Team Authors. All rights reserved. 2 | # Copyright 2025 The SVG Team Authors. All rights reserve. 3 | # Copyright 2025 The Huazhong University of Science and Technology VLRLab Authors. All rights reserved. 4 | 5 | import os 6 | import time 7 | import math 8 | import json 9 | from pathlib import Path 10 | from loguru import logger 11 | from datetime import datetime 12 | 13 | import torch 14 | from svg.models.hyvideo.utils.file_utils import save_videos_grid 15 | from svg.models.hyvideo.config import parse_args 16 | from svg.models.hyvideo.inference import HunyuanVideoSampler 17 | from torch.utils.data import Dataset, DataLoader 18 | from typing import Any, List, Tuple, Optional, Union, Dict 19 | import portalocker 20 | from tqdm import tqdm 21 | 22 | 23 | def get_cu_seqlens(text_mask, img_len): 24 | """Calculate cu_seqlens_q, cu_seqlens_kv using text_mask and img_len 25 | 26 | Args: 27 | text_mask (torch.Tensor): the mask of text 28 | img_len (int): the length of image 29 | 30 | Returns: 31 | torch.Tensor: the calculated cu_seqlens for flash attention 32 | """ 33 | batch_size = text_mask.shape[0] 34 | text_len = text_mask.sum(dim=1) 35 | max_len = text_mask.shape[1] + img_len 36 | 37 | cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda") 38 | 39 | for i in range(batch_size): 40 | s = text_len[i] + img_len 41 | s1 = i * max_len + s 42 | s2 = (i + 1) * max_len 43 | cu_seqlens[2 * i + 1] = s1 44 | cu_seqlens[2 * i + 2] = s2 45 | 46 | return cu_seqlens 47 | 48 | 49 | @torch.compile() 50 | def easycache_forward( 51 | self, 52 | x: torch.Tensor, 53 | t: torch.Tensor, # Should be in range(0, 1000). 54 | text_states: torch.Tensor = None, 55 | text_mask: torch.Tensor = None, # Now we don't use it. 56 | text_states_2: Optional[torch.Tensor] = None, # Text embedding for modulation. 57 | freqs_cos: Optional[torch.Tensor] = None, 58 | freqs_sin: Optional[torch.Tensor] = None, 59 | guidance: torch.Tensor = None, # Guidance for modulation, should be cfg_scale x 1000. 
60 | return_dict: bool = True, 61 | ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: 62 | torch.cuda.synchronize() 63 | start_time = time.time() 64 | 65 | out = {} 66 | raw_input = x.clone() 67 | img = x 68 | txt = text_states 69 | _, _, ot, oh, ow = x.shape 70 | tt, th, tw = ( 71 | ot // self.patch_size[0], 72 | oh // self.patch_size[1], 73 | ow // self.patch_size[2], 74 | ) 75 | 76 | # Prepare modulation vectors. 77 | vec = self.time_in(t) 78 | 79 | # text modulation 80 | vec = vec + self.vector_in(text_states_2) 81 | 82 | # guidance modulation 83 | if self.guidance_embed: 84 | if guidance is None: 85 | raise ValueError( 86 | "Didn't get guidance strength for guidance distilled model." 87 | ) 88 | 89 | # our timestep_embedding is merged into guidance_in(TimestepEmbedder) 90 | vec = vec + self.guidance_in(guidance) 91 | 92 | if self.cnt < self.ret_steps or self.cnt >= self.num_steps - 1: 93 | should_calc = True 94 | self.accumulated_error = 0 95 | else: 96 | # Check if previous inputs and outputs exist 97 | if hasattr(self, 'previous_raw_input') and hasattr(self, 'previous_output') \ 98 | and self.previous_raw_input is not None and self.previous_output is not None: 99 | 100 | raw_input_change = (raw_input - self.previous_raw_input).abs().mean() 101 | 102 | if hasattr(self, 'k') and self.k is not None: 103 | 104 | output_norm = self.previous_output.abs().mean() 105 | pred_change = self.k * (raw_input_change / output_norm) 106 | self.accumulated_error += pred_change 107 | 108 | if self.accumulated_error < self.thresh: 109 | should_calc = False 110 | else: 111 | should_calc = True 112 | self.accumulated_error = 0 113 | else: 114 | should_calc = True 115 | else: 116 | should_calc = True 117 | 118 | self.previous_raw_input = raw_input.clone() # (1, 16, 33, 68, 120) 119 | 120 | if not should_calc and self.cache is not None: 121 | result = raw_input + self.cache 122 | self.cnt += 1 123 | 124 | if self.cnt >= self.num_steps: 125 | self.cnt = 0 126 | 127 | torch.cuda.synchronize() 128 | end_time = time.time() 129 | self.total_time += (end_time - start_time) 130 | 131 | if return_dict: 132 | out["x"] = result 133 | return out 134 | return result 135 | 136 | img = self.img_in(img) 137 | if self.text_projection == "linear": 138 | txt = self.txt_in(txt) 139 | elif self.text_projection == "single_refiner": 140 | txt = self.txt_in(txt, t, text_mask if self.use_attention_mask else None) 141 | else: 142 | raise NotImplementedError(f"Unsupported text_projection: {self.text_projection}") 143 | 144 | txt_seq_len = txt.shape[1] 145 | img_seq_len = img.shape[1] 146 | 147 | # Compute cu_squlens and max_seqlen for flash attention 148 | cu_seqlens_q = get_cu_seqlens(text_mask, img_seq_len) 149 | cu_seqlens_kv = cu_seqlens_q 150 | max_seqlen_q = img_seq_len + txt_seq_len 151 | max_seqlen_kv = max_seqlen_q 152 | 153 | freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None 154 | 155 | # --------------------- Pass through DiT blocks ------------------------ 156 | for _, block in enumerate(self.double_blocks): 157 | double_block_args = [ 158 | img, 159 | txt, 160 | vec, 161 | cu_seqlens_q, 162 | cu_seqlens_kv, 163 | max_seqlen_q, 164 | max_seqlen_kv, 165 | freqs_cis, 166 | t, 167 | ] 168 | img, txt = block(*double_block_args) 169 | 170 | # Merge txt and img to pass through single stream blocks. 
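# Note: unlike the base EasyCache script, each block call in this SVG variant additionally
# passes the timestep t, which the replaced sparse-attention forward expects.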
171 | x = torch.cat((img, txt), 1) 172 | if len(self.single_blocks) > 0: 173 | for _, block in enumerate(self.single_blocks): 174 | single_block_args = [ 175 | x, 176 | vec, 177 | txt_seq_len, 178 | cu_seqlens_q, 179 | cu_seqlens_kv, 180 | max_seqlen_q, 181 | max_seqlen_kv, 182 | (freqs_cos, freqs_sin), 183 | t, 184 | ] 185 | x = block(*single_block_args) 186 | img = x[:, :img_seq_len, ...] 187 | 188 | # ---------------------------- Final layer ------------------------------ 189 | img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) 190 | 191 | result = self.unpatchify(img, tt, th, tw) 192 | 193 | # store the cache for next step 194 | self.cache = result - raw_input 195 | if hasattr(self, 'previous_output') and self.previous_output is not None: 196 | output_change = (result - self.previous_output).abs().mean() 197 | if hasattr(self, 'prev_prev_raw_input') and self.prev_prev_raw_input is not None: 198 | input_change = (self.previous_raw_input - self.prev_prev_raw_input).abs().mean() 199 | self.k = output_change / input_change 200 | 201 | # update the previous state 202 | self.prev_prev_raw_input = getattr(self, 'previous_raw_input', None) 203 | self.previous_output = result.clone() 204 | 205 | self.cnt += 1 206 | if self.cnt >= self.num_steps: 207 | self.cnt = 0 208 | 209 | torch.cuda.synchronize() 210 | end_time = time.time() 211 | self.total_time += (end_time - start_time) 212 | 213 | if return_dict: 214 | out["x"] = result 215 | return out 216 | return result 217 | 218 | 219 | def sparsity_to_width(sparsity, context_length, num_frame, frame_size): 220 | seq_len = context_length + num_frame * frame_size 221 | total_elements = seq_len ** 2 222 | 223 | sparsity = (sparsity * total_elements - 2 * seq_len * context_length) / total_elements 224 | 225 | width = seq_len * (1 - math.sqrt(1 - sparsity)) 226 | width_frame = width / frame_size 227 | 228 | return width_frame 229 | 230 | 231 | def main(): 232 | args = parse_args() 233 | print(args) 234 | models_root_path = Path("./HunyuanVideo") 235 | if not models_root_path.exists(): 236 | raise ValueError(f"`models_root` not exists: {models_root_path}") 237 | 238 | # Create save folder to save the samples 239 | os.makedirs(args.save_path, exist_ok=True) 240 | 241 | # Load models 242 | hunyuan_video_sampler = HunyuanVideoSampler.from_pretrained(models_root_path, args=args) 243 | 244 | # Get the updated args 245 | args = hunyuan_video_sampler.args 246 | 247 | # Sparsity Related 248 | transformer = hunyuan_video_sampler.pipeline.transformer 249 | for _, block in enumerate(transformer.double_blocks): 250 | block.sparse_args = args 251 | for _, block in enumerate(transformer.single_blocks): 252 | block.sparse_args = args 253 | transformer.sparse_args = args 254 | 255 | print( 256 | f"Memory: {torch.cuda.memory_allocated() // 1024 ** 2} / {torch.cuda.max_memory_allocated() // 1024 ** 2} MB before Inference") 257 | 258 | cfg_size, num_head, head_dim, dtype, device = 1, 24, 128, torch.bfloat16, "cuda" 259 | context_length, num_frame, frame_size = 256, 33, 3600 260 | 261 | # Calculation 262 | spatial_width = temporal_width = sparsity_to_width(args.sparsity, context_length, num_frame, frame_size) 263 | 264 | print(f"Spatial_width: {spatial_width}, Temporal_width: {temporal_width}. 
Sparsity: {args.sparsity}") 265 | 266 | save_path = args.output_path 267 | if args.pattern == "SVG": 268 | masks = ["spatial", "temporal"] 269 | 270 | def get_attention_mask(mask_name): 271 | 272 | context_length = 256 273 | num_frame = 33 274 | frame_size = 3600 275 | attention_mask = torch.zeros( 276 | (context_length + num_frame * frame_size, context_length + num_frame * frame_size), device="cpu") 277 | 278 | # TODO: fix hard coded mask 279 | if mask_name == "spatial": 280 | pixel_attn_mask = torch.zeros_like(attention_mask[:-context_length, :-context_length], dtype=torch.bool, 281 | device="cpu") 282 | block_size, block_thres = 128, frame_size * 1.5 283 | num_block = math.ceil(num_frame * frame_size / block_size) 284 | for i in range(num_block): 285 | for j in range(num_block): 286 | if abs(i - j) < block_thres // block_size: 287 | pixel_attn_mask[i * block_size: (i + 1) * block_size, 288 | j * block_size: (j + 1) * block_size] = 1 289 | attention_mask[:-context_length, :-context_length] = pixel_attn_mask 290 | 291 | attention_mask[-context_length:, :] = 1 292 | attention_mask[:, -context_length:] = 1 293 | 294 | else: 295 | pixel_attn_mask = torch.zeros_like(attention_mask[:-context_length, :-context_length], dtype=torch.bool, 296 | device=device) 297 | 298 | block_size, block_thres = 128, frame_size * 1.5 299 | num_block = math.ceil(num_frame * frame_size / block_size) 300 | for i in range(num_block): 301 | for j in range(num_block): 302 | if abs(i - j) < block_thres // block_size: 303 | pixel_attn_mask[i * block_size: (i + 1) * block_size, 304 | j * block_size: (j + 1) * block_size] = 1 305 | 306 | pixel_attn_mask = pixel_attn_mask.reshape(frame_size, num_frame, frame_size, num_frame).permute(1, 0, 3, 307 | 2).reshape( 308 | frame_size * num_frame, frame_size * num_frame) 309 | attention_mask[:-context_length, :-context_length] = pixel_attn_mask 310 | 311 | attention_mask[-context_length:, :] = 1 312 | attention_mask[:, -context_length:] = 1 313 | attention_mask = attention_mask[:args.sample_mse_max_row].cuda() 314 | return attention_mask 315 | 316 | hunyuan_video_sampler.pipeline.transformer.__class__.cnt = 0 317 | hunyuan_video_sampler.pipeline.transformer.__class__.num_steps = args.infer_steps 318 | hunyuan_video_sampler.pipeline.transformer.__class__.thresh = 0.025 319 | hunyuan_video_sampler.pipeline.transformer.__class__.forward = easycache_forward 320 | hunyuan_video_sampler.pipeline.transformer.__class__.ret_steps = 5 321 | hunyuan_video_sampler.pipeline.transformer.__class__.k = None 322 | hunyuan_video_sampler.pipeline.transformer.__class__.total_time = 0.0 323 | 324 | if args.pattern == "SVG": 325 | from svg.models.hyvideo.modules.attenion import Hunyuan_SparseAttn, prepare_flexattention 326 | from svg.models.hyvideo.modules.custom_models import replace_sparse_forward 327 | 328 | AttnModule = Hunyuan_SparseAttn 329 | AttnModule.num_sampled_rows = args.num_sampled_rows 330 | AttnModule.sample_mse_max_row = args.sample_mse_max_row 331 | AttnModule.attention_masks = [get_attention_mask(mask_name) for mask_name in masks] 332 | AttnModule.first_layers_fp = args.first_layers_fp 333 | AttnModule.first_times_fp = args.first_times_fp 334 | 335 | generation_time = [] 336 | time_cost = { 337 | "GPU_Device": torch.cuda.get_device_name(0), 338 | "number_prompt": None, 339 | "avg_cost_time": None 340 | } 341 | # Start sampling 342 | if args.pattern == "SVG": 343 | # We need to get the prompt len in advance, since HunyuanVideo handle the attention mask in a special way 344 | prompt_mask 
= hunyuan_video_sampler.get_prompt_mask( 345 | prompt=args.prompt, 346 | height=args.video_size[0], 347 | width=args.video_size[1], 348 | video_length=args.video_length, 349 | negative_prompt=args.neg_prompt, 350 | infer_steps=args.infer_steps, 351 | guidance_scale=args.cfg_scale, 352 | num_videos_per_prompt=args.num_videos, 353 | embedded_guidance_scale=args.embedded_cfg_scale 354 | ) 355 | prompt_len = prompt_mask.sum() 356 | 357 | block_mask = prepare_flexattention( 358 | cfg_size, num_head, head_dim, dtype, device, 359 | context_length, prompt_len, num_frame, frame_size, 360 | diag_width=spatial_width, multiplier=temporal_width 361 | ) 362 | AttnModule.block_mask = block_mask 363 | replace_sparse_forward() 364 | 365 | hunyuan_video_sampler.pipeline.transformer.total_time = 0.0 366 | outputs = hunyuan_video_sampler.predict( 367 | prompt=args.prompt, 368 | height=args.video_size[0], 369 | width=args.video_size[1], 370 | video_length=args.video_length, 371 | seed=args.seed, 372 | negative_prompt=args.neg_prompt, 373 | infer_steps=args.infer_steps, 374 | guidance_scale=args.cfg_scale, 375 | num_videos_per_prompt=1, 376 | flow_shift=args.flow_shift, 377 | batch_size=args.batch_size, 378 | embedded_guidance_scale=args.embedded_cfg_scale 379 | ) 380 | generation_time.append(hunyuan_video_sampler.pipeline.transformer.total_time) 381 | samples = outputs['samples'] 382 | 383 | # Save samples 384 | if 'LOCAL_RANK' not in os.environ or int(os.environ['LOCAL_RANK']) == 0: 385 | for i, sample in enumerate(samples): 386 | sample = samples[i].unsqueeze(0) 387 | time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%H:%M:%S") 388 | save_path = f"{args.save_path}/{time_flag}_seed{outputs['seeds'][i]}_{outputs['prompts'][i][:100].replace('/', '')}.mp4" 389 | save_videos_grid(sample, save_path, fps=24) 390 | logger.info(f'Sample save to: {save_path}') 391 | 392 | if generation_time: 393 | time_cost["number_prompt"] = len(generation_time) 394 | time_cost["avg_cost_time"] = sum(generation_time) / len(generation_time) if generation_time else 0 395 | 396 | print( 397 | f"GPU_Device: {time_cost['GPU_Device']}, number_prompt: {time_cost['number_prompt']}, avg_cost_time: {time_cost['avg_cost_time']}") 398 | 399 | try: 400 | with open(f"{args.save_path}/time_cost.json", "a+") as f: 401 | portalocker.lock(f, portalocker.LOCK_EX) 402 | f.seek(0) 403 | try: 404 | existing_data = json.load(f) 405 | except (json.JSONDecodeError, FileNotFoundError): 406 | existing_data = [] 407 | 408 | existing_data.append(time_cost) 409 | f.seek(0) 410 | f.truncate() 411 | json.dump(existing_data, f, indent=4) 412 | except Exception as e: 413 | print(f"Error writing time cost to file: {e}") 414 | 415 | 416 | if __name__ == "__main__": 417 | main() 418 | -------------------------------------------------------------------------------- /EasyCache4HunyuanVideo/tools/video_metrics.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import argparse 4 | import torch 5 | import lpips 6 | import numpy as np 7 | from tqdm import tqdm 8 | from torchmetrics.image import StructuralSimilarityIndexMeasure 9 | 10 | def load_video_frames(path, resize_to=None): 11 | """ 12 | Load all frames from a video file as a list of HxWx3 uint8 arrays. 13 | Optionally resize each frame to `resize_to` (w, h). 
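The frames are stacked and returned as a single (N, H, W, 3) uint8 RGB array.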
14 | """ 15 | 16 | cap = cv2.VideoCapture(path) 17 | frames = [] 18 | while True: 19 | ret, img = cap.read() 20 | if not ret: 21 | break 22 | if resize_to is not None: 23 | img = cv2.resize(img, resize_to) 24 | frames.append(np.expand_dims(cv2.cvtColor(img, cv2.COLOR_BGR2RGB), axis=0)) 25 | cap.release() 26 | return np.concatenate(frames) 27 | 28 | 29 | def compute_video_metrics(frames_gt, frames_gen, 30 | device, ssim_metric, lpips_fn): 31 | """ 32 | Compute PSNR, SSIM, LPIPS for two lists of frames (uint8 BGR). 33 | All computations on `device`. 34 | Returns (psnr, ssim, lpips) scalars. 35 | """ 36 | # ensure same frame count 37 | # convert to tensors [N,3,H,W], normalize to [0,1] 38 | gt_t = torch.from_numpy(frames_gt).float().to(device).permute(0, 3, 1, 2).div_(255).contiguous() 39 | 40 | gen_t = torch.from_numpy(frames_gen).float().to(device).permute(0, 3, 1, 2).div_(255).contiguous() 41 | 42 | # PSNR (data_range=1.0): -10 * log10(mse) 43 | mse = torch.mean((gt_t - gen_t) ** 2) 44 | psnr = -10.0 * torch.log10(mse) 45 | 46 | # SSIM: returns average over batch 47 | ssim_val = ssim_metric(gen_t, gt_t) 48 | 49 | # LPIPS: expects [-1,1] 50 | with torch.no_grad(): 51 | lpips_val = lpips_fn(gt_t * 2.0 - 1.0, gen_t * 2.0 - 1.0).mean() 52 | 53 | return psnr.item(), ssim_val.item(), lpips_val.item() 54 | 55 | 56 | def main(): 57 | parser = argparse.ArgumentParser( 58 | description="Compute PSNR/SSIM/LPIPS on GPU for two folders of .mp4 videos" 59 | ) 60 | parser.add_argument("--original_video", required=True, 61 | help="ground-truth .mp4 videos") 62 | parser.add_argument("--generated_video", required=True, 63 | help="generated .mp4 videos") 64 | parser.add_argument("--device", default="cuda", 65 | help="Torch device, e.g. 'cuda' or 'cpu'") 66 | parser.add_argument("--lpips_net", default="alex", choices=["alex", "vgg"], 67 | help="Backbone for LPIPS") 68 | args = parser.parse_args() 69 | 70 | device = torch.device(args.device if torch.cuda.is_available() or args.device == "cpu" else "cpu") 71 | # instantiate metrics on device 72 | ssim_metric = StructuralSimilarityIndexMeasure(data_range=1.0).to(device) 73 | lpips_fn = lpips.LPIPS(net=args.lpips_net, spatial=True).to(device) 74 | 75 | # gather .mp4 filenames 76 | gt_files = args.original_video 77 | gen_set = args.generated_video 78 | 79 | psnrs, ssims, lpips_vals = [], [], [] 80 | for fname in tqdm([gt_files], desc="Videos"): 81 | path_gt = gt_files 82 | path_gen = gen_set 83 | 84 | # load frames; resize generated to match GT dimensions 85 | frames_gt = load_video_frames(path_gt) 86 | frames_gen = load_video_frames(path_gen) 87 | 88 | res = compute_video_metrics(frames_gt, frames_gen, 89 | device, ssim_metric, lpips_fn) 90 | if res is None: 91 | continue 92 | p, s, l = res 93 | psnrs.append(p); 94 | ssims.append(s); 95 | lpips_vals.append(l) 96 | 97 | if not psnrs: 98 | print("No valid videos processed.") 99 | return 100 | 101 | print("\n=== Overall Averages ===") 102 | print(f"Average PSNR : {np.mean(psnrs):.2f} dB") 103 | print(f"Average SSIM : {np.mean(ssims):.4f}") 104 | print(f"Average LPIPS: {np.mean(lpips_vals):.4f}") 105 | 106 | 107 | if __name__ == "__main__": 108 | main() 109 | -------------------------------------------------------------------------------- /EasyCache4HunyuanVideo/videos/baseline_544p.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/H-EmbodVis/EasyCache/73b04c13c05dc446b712f91cdf0f367399a752e4/EasyCache4HunyuanVideo/videos/baseline_544p.gif 
-------------------------------------------------------------------------------- /EasyCache4HunyuanVideo/videos/baseline_720p.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/H-EmbodVis/EasyCache/73b04c13c05dc446b712f91cdf0f367399a752e4/EasyCache4HunyuanVideo/videos/baseline_720p.gif -------------------------------------------------------------------------------- /EasyCache4HunyuanVideo/videos/easycache_544p.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/H-EmbodVis/EasyCache/73b04c13c05dc446b712f91cdf0f367399a752e4/EasyCache4HunyuanVideo/videos/easycache_544p.gif -------------------------------------------------------------------------------- /EasyCache4HunyuanVideo/videos/svg_with_easycache_720p.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/H-EmbodVis/EasyCache/73b04c13c05dc446b712f91cdf0f367399a752e4/EasyCache4HunyuanVideo/videos/svg_with_easycache_720p.gif -------------------------------------------------------------------------------- /EasyCache4Wan2.1/README.md: -------------------------------------------------------------------------------- 1 |
2 |

Less is Enough: Training-Free Video Diffusion Acceleration via Runtime-Adaptive Caching

3 | 4 | Xin Zhou1\*, 5 | Dingkang Liang1\*, 6 | Kaijin Chen1, Tianrui Feng1, 7 | Xiwu Chen2, Hongkai Lin1,
8 | Yikang Ding2, Feiyang Tan2, 9 | Hengshuang Zhao3, 10 | Xiang Bai1† 11 | 12 | 1 Huazhong University of Science and Technology, 2 MEGVII Technology, 3 University of Hong Kong
13 | 14 | (\*) Equal contribution. (†) Corresponding author. 15 | 16 | [![Project](https://img.shields.io/badge/Homepage-project-orange.svg?logo=googlehome)](https://H-EmbodVis.github.io/EasyCache/) 17 | [![Code License](https://img.shields.io/badge/Code%20License-Apache_2.0-green.svg)](https://github.com/LMD0311/EasyCache/blob/main/LICENSE) 18 | 19 |
20 | 21 | --- 22 | 23 | This document provides the implementation for accelerating the [**Wan2.1**](https://github.com/Wan-Video/Wan2.1) model using **EasyCache**. 24 | 25 | ### ✨ Visual Comparison 26 | 27 | EasyCache significantly accelerates inference speed while maintaining high visual fidelity. 28 | 29 | **Prompt: "A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about."** 30 | 31 | | Wan2.1-14B (Baseline, 720p, H20) | EasyCache (Ours, 720p, H20) | 32 | | :---: | :---: | 33 | | ![Baseline Video](./videos/t2v_gt_14b_720p.gif) | ![Our Video](./videos/t2v_easycache_14b_720p.gif) | 34 | | **Inference Time: ~6862s** | **Inference Time: ~2884s (~2.4x Speedup)** | 35 | 36 | 37 | **Prompt: "A cute green alien child with large ears, wearing a brown robe, sits on a chair and eats a blue cookie at a table, with crumbs scattered on the robe, in a cozy indoor setting."** 38 | 39 | | Wan2.1-14B I2V (Baseline, 720p, H20) | EasyCache (Ours, 720p, H20) | 40 | | :---: | :---: | 41 | | ![Baseline Video](./videos/i2v_gt_14b_720p.gif) | ![Our Video](./videos/i2v_easycache_14b_720p.gif) | 42 | | **Inference Time: ~5302s** | **Inference Time: ~2397s (~2.2x Speedup)** | 43 | 44 | --- 45 | 46 | ### 🚀 Usage Instructions 47 | 48 | #### **1. EasyCache Acceleration for Wan2.1 T2V** 49 | 50 | **a. Prerequisites** ⚙️ 51 | 52 | Before you begin, please follow the instructions in the [official Wan2.1 repository](https://github.com/Wan-Video/Wan2.1) to configure the required environment and download the pretrained model weights. 53 | 54 | **b. Copy Files** 📂 55 | 56 | Copy `easycache_generate.py` into the root directory of your local `Wan2.1` project. 57 | 58 | **c. Run Inference** ▶️ 59 | 60 | Execute the following command from the root of the `Wan2.1` project to generate a video. To generate videos in 720p resolution, set the `--size` argument to `1280*720`. You can also specify your own custom prompts. 61 | 62 | ```bash 63 | python easycache_generate.py \ 64 | --task t2v-14B \ 65 | --size "1280*720" \ 66 | --ckpt_dir ./Wan2.1-T2V-14B \ 67 | --prompt "A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about." \ 68 | --base_seed 0 69 | ``` 70 | #### **2. EasyCache Acceleration for Wan2.1 I2V** 71 | Execute the following command from the root of the `Wan2.1` project to generate a video. To generate videos in 480p resolution, set the `--size` argument to `832*480` and set `--ckpt_dir` as `./Wan2.1-I2V-14B-480P`. You can also specify your own custom prompts and images. 72 | 73 | ```bash 74 | python easycache_generate.py \ 75 | --task i2v-14B \ 76 | --size "1280*720" \ 77 | --ckpt_dir ./Wan2.1-I2V-14B-720P \ 78 | --image examples/grogu.png \ 79 | --prompt "A cute green alien child with large ears, wearing a brown robe, sits on a chair and eats a blue cookie at a table, with crumbs scattered on the robe, in a cozy indoor setting." 
\ 80 | --base_seed 0 81 | ``` 82 | 83 | 84 | ### 📊 Evaluating Video Similarity 85 | 86 | We provide a simple script to quickly evaluate the similarity between two videos (e.g., the baseline result and your generated result) using common metrics. 87 | 88 | **Usage** 89 | 90 | ```bash 91 | # install required packages. 92 | pip install lpips numpy tqdm torchmetrics 93 | 94 | python tools/video_metrics.py --original_video video1.mp4 --generated_video video2.mp4 95 | ``` 96 | 97 | - `--original_video`: Path to the first video (e.g., the baseline). 98 | - `--generated_video`: Path to the second video (e.g., the one generated with EasyCache). 99 | 100 | ## 🌹 Acknowledgements 101 | We would like to thank the contributors to the [Wan2.1](https://github.com/Wan-Video/Wan2.1) repository, for the open research and exploration. 102 | 103 | ## 📖 Citation 104 | 105 | If you find this repository useful in your research, please consider giving a star ⭐ and a citation. 106 | ```bibtex 107 | @article{zhou2025easycache, 108 | title={Less is Enough: Training-Free Video Diffusion Acceleration via Runtime-Adaptive Caching}, 109 | author={Zhou, Xin and Liang, Dingkang and Chen, Kaijin and and Feng, Tianrui and Chen, Xiwu and Lin, Hongkai and Ding, Yikang and Tan, Feiyang and Zhao, Hengshuang and Bai, Xiang}, 110 | journal={arXiv preprint arXiv:2507.02860}, 111 | year={2025} 112 | } 113 | ``` 114 | -------------------------------------------------------------------------------- /EasyCache4Wan2.1/easycache_generate.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 2 | # Copyright 2025 The Huazhong University of Science and Technology VLRLab Authors. All rights reserved. 3 | 4 | import argparse 5 | from datetime import datetime 6 | import logging 7 | import os 8 | import sys 9 | import warnings 10 | import json 11 | from time import time 12 | import portalocker 13 | 14 | warnings.filterwarnings('ignore') 15 | 16 | import torch, random 17 | import torch.distributed as dist 18 | from torch.utils.data import DataLoader, Dataset 19 | from PIL import Image 20 | 21 | import wan 22 | from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS, SUPPORTED_SIZES 23 | from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander 24 | from wan.utils.utils import cache_video, cache_image, str2bool 25 | 26 | import gc 27 | from contextlib import contextmanager 28 | import torchvision.transforms.functional as TF 29 | import torch.cuda.amp as amp 30 | import numpy as np 31 | import math 32 | from wan.modules.model import sinusoidal_embedding_1d 33 | from wan.utils.fm_solvers import (FlowDPMSolverMultistepScheduler, 34 | get_sampling_sigmas, retrieve_timesteps) 35 | from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler 36 | from tqdm import tqdm 37 | 38 | EXAMPLE_PROMPT = { 39 | "t2v-1.3B": { 40 | "prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", 41 | }, 42 | "t2v-14B": { 43 | "prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", 44 | }, 45 | "t2i-14B": { 46 | "prompt": "一个朴素端庄的美人", 47 | }, 48 | "i2v-14B": { 49 | "prompt": 50 | "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. 
Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.", 51 | "image": 52 | "examples/i2v_input.JPG", 53 | }, 54 | } 55 | 56 | 57 | def t2v_generate(self, 58 | input_prompt, 59 | size=(1280, 720), 60 | frame_num=81, 61 | shift=5.0, 62 | sample_solver='unipc', 63 | sampling_steps=50, 64 | guide_scale=5.0, 65 | n_prompt="", 66 | seed=-1, 67 | offload_model=True): 68 | r""" 69 | Generates video frames from text prompt using diffusion process. 70 | 71 | Args: 72 | input_prompt (`str`): 73 | Text prompt for content generation 74 | size (tupele[`int`], *optional*, defaults to (1280,720)): 75 | Controls video resolution, (width,height). 76 | frame_num (`int`, *optional*, defaults to 81): 77 | How many frames to sample from a video. The number should be 4n+1 78 | shift (`float`, *optional*, defaults to 5.0): 79 | Noise schedule shift parameter. Affects temporal dynamics 80 | sample_solver (`str`, *optional*, defaults to 'unipc'): 81 | Solver used to sample the video. 82 | sampling_steps (`int`, *optional*, defaults to 40): 83 | Number of diffusion sampling steps. Higher values improve quality but slow generation 84 | guide_scale (`float`, *optional*, defaults 5.0): 85 | Classifier-free guidance scale. Controls prompt adherence vs. creativity 86 | n_prompt (`str`, *optional*, defaults to ""): 87 | Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt` 88 | seed (`int`, *optional*, defaults to -1): 89 | Random seed for noise generation. If -1, use random seed. 90 | offload_model (`bool`, *optional*, defaults to True): 91 | If True, offloads models to CPU during generation to save VRAM 92 | 93 | Returns: 94 | torch.Tensor: 95 | Generated video frames tensor. 
Dimensions: (C, N H, W) where: 96 | - C: Color channels (3 for RGB) 97 | - N: Number of frames (81) 98 | - H: Frame height (from size) 99 | - W: Frame width from size) 100 | """ 101 | # preprocess 102 | F = frame_num 103 | target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1, 104 | size[1] // self.vae_stride[1], 105 | size[0] // self.vae_stride[2]) 106 | 107 | seq_len = math.ceil((target_shape[2] * target_shape[3]) / 108 | (self.patch_size[1] * self.patch_size[2]) * 109 | target_shape[1] / self.sp_size) * self.sp_size 110 | 111 | if n_prompt == "": 112 | n_prompt = self.sample_neg_prompt 113 | seed = seed if seed >= 0 else random.randint(0, sys.maxsize) 114 | seed_g = torch.Generator(device=self.device) 115 | seed_g.manual_seed(seed) 116 | 117 | if not self.t5_cpu: 118 | self.text_encoder.model.to(self.device) 119 | context = self.text_encoder([input_prompt], self.device) 120 | context_null = self.text_encoder([n_prompt], self.device) 121 | if offload_model: 122 | self.text_encoder.model.cpu() 123 | else: 124 | context = self.text_encoder([input_prompt], torch.device('cpu')) 125 | context_null = self.text_encoder([n_prompt], torch.device('cpu')) 126 | context = [t.to(self.device) for t in context] 127 | context_null = [t.to(self.device) for t in context_null] 128 | 129 | noise = [ 130 | torch.randn( 131 | target_shape[0], 132 | target_shape[1], 133 | target_shape[2], 134 | target_shape[3], 135 | dtype=torch.float32, 136 | device=self.device, 137 | generator=seed_g) 138 | ] 139 | 140 | @contextmanager 141 | def noop_no_sync(): 142 | yield 143 | 144 | no_sync = getattr(self.model, 'no_sync', noop_no_sync) 145 | 146 | # evaluation mode 147 | with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync(): 148 | 149 | if sample_solver == 'unipc': 150 | sample_scheduler = FlowUniPCMultistepScheduler( 151 | num_train_timesteps=self.num_train_timesteps, 152 | shift=1, 153 | use_dynamic_shifting=False) 154 | sample_scheduler.set_timesteps( 155 | sampling_steps, device=self.device, shift=shift) 156 | timesteps = sample_scheduler.timesteps 157 | elif sample_solver == 'dpm++': 158 | sample_scheduler = FlowDPMSolverMultistepScheduler( 159 | num_train_timesteps=self.num_train_timesteps, 160 | shift=1, 161 | use_dynamic_shifting=False) 162 | sampling_sigmas = get_sampling_sigmas(sampling_steps, shift) 163 | timesteps, _ = retrieve_timesteps( 164 | sample_scheduler, 165 | device=self.device, 166 | sigmas=sampling_sigmas) 167 | else: 168 | raise NotImplementedError("Unsupported solver.") 169 | 170 | # sample videos 171 | latents = noise 172 | 173 | arg_c = {'context': context, 'seq_len': seq_len} 174 | arg_null = {'context': context_null, 'seq_len': seq_len} 175 | 176 | for _, t in enumerate(tqdm(timesteps)): 177 | torch.cuda.synchronize() 178 | start_time = time() 179 | latent_model_input = latents 180 | timestep = [t] 181 | 182 | timestep = torch.stack(timestep) 183 | 184 | self.model.to(self.device) 185 | noise_pred_cond = self.model( 186 | latent_model_input, t=timestep, **arg_c)[0] 187 | noise_pred_uncond = self.model( 188 | latent_model_input, t=timestep, **arg_null)[0] 189 | 190 | noise_pred = noise_pred_uncond + guide_scale * ( 191 | noise_pred_cond - noise_pred_uncond) 192 | 193 | torch.cuda.synchronize() 194 | self.cost_time += (time() - start_time) 195 | 196 | temp_x0 = sample_scheduler.step( 197 | noise_pred.unsqueeze(0), 198 | t, 199 | latents[0].unsqueeze(0), 200 | return_dict=False, 201 | generator=seed_g)[0] 202 | latents = [temp_x0.squeeze(0)] 203 | 204 | x0 = 
latents 205 | if offload_model: 206 | self.model.cpu() 207 | torch.cuda.empty_cache() 208 | if self.rank == 0: 209 | videos = self.vae.decode(x0) 210 | 211 | del noise, latents 212 | del sample_scheduler 213 | if offload_model: 214 | gc.collect() 215 | torch.cuda.synchronize() 216 | if dist.is_initialized(): 217 | dist.barrier() 218 | 219 | return videos[0] if self.rank == 0 else None 220 | 221 | 222 | def i2v_generate(self, 223 | input_prompt, 224 | img, 225 | max_area=720 * 1280, 226 | frame_num=81, 227 | shift=5.0, 228 | sample_solver='unipc', 229 | sampling_steps=40, 230 | guide_scale=5.0, 231 | n_prompt="", 232 | seed=-1, 233 | offload_model=True): 234 | r""" 235 | Generates video frames from input image and text prompt using diffusion process. 236 | 237 | Args: 238 | input_prompt (`str`): 239 | Text prompt for content generation. 240 | img (PIL.Image.Image): 241 | Input image tensor. Shape: [3, H, W] 242 | max_area (`int`, *optional*, defaults to 720*1280): 243 | Maximum pixel area for latent space calculation. Controls video resolution scaling 244 | frame_num (`int`, *optional*, defaults to 81): 245 | How many frames to sample from a video. The number should be 4n+1 246 | shift (`float`, *optional*, defaults to 5.0): 247 | Noise schedule shift parameter. Affects temporal dynamics 248 | [NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0. 249 | sample_solver (`str`, *optional*, defaults to 'unipc'): 250 | Solver used to sample the video. 251 | sampling_steps (`int`, *optional*, defaults to 40): 252 | Number of diffusion sampling steps. Higher values improve quality but slow generation 253 | guide_scale (`float`, *optional*, defaults 5.0): 254 | Classifier-free guidance scale. Controls prompt adherence vs. creativity 255 | n_prompt (`str`, *optional*, defaults to ""): 256 | Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt` 257 | seed (`int`, *optional*, defaults to -1): 258 | Random seed for noise generation. If -1, use random seed 259 | offload_model (`bool`, *optional*, defaults to True): 260 | If True, offloads models to CPU during generation to save VRAM 261 | 262 | Returns: 263 | torch.Tensor: 264 | Generated video frames tensor. 
Dimensions: (C, N H, W) where: 265 | - C: Color channels (3 for RGB) 266 | - N: Number of frames (81) 267 | - H: Frame height (from max_area) 268 | - W: Frame width from max_area) 269 | """ 270 | img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device) 271 | 272 | F = frame_num 273 | h, w = img.shape[1:] 274 | aspect_ratio = h / w 275 | lat_h = round( 276 | np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] // 277 | self.patch_size[1] * self.patch_size[1]) 278 | lat_w = round( 279 | np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] // 280 | self.patch_size[2] * self.patch_size[2]) 281 | h = lat_h * self.vae_stride[1] 282 | w = lat_w * self.vae_stride[2] 283 | 284 | max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // ( 285 | self.patch_size[1] * self.patch_size[2]) 286 | max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size 287 | 288 | seed = seed if seed >= 0 else random.randint(0, sys.maxsize) 289 | seed_g = torch.Generator(device=self.device) 290 | seed_g.manual_seed(seed) 291 | noise = torch.randn( 292 | self.vae.model.z_dim, 293 | (F - 1) // self.vae_stride[0] + 1, 294 | lat_h, 295 | lat_w, 296 | dtype=torch.float32, 297 | generator=seed_g, 298 | device=self.device) 299 | 300 | msk = torch.ones(1, F, lat_h, lat_w, device=self.device) 301 | msk[:, 1:] = 0 302 | msk = torch.concat([ 303 | torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:] 304 | ], 305 | dim=1) 306 | msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w) 307 | msk = msk.transpose(1, 2)[0] 308 | 309 | if n_prompt == "": 310 | n_prompt = self.sample_neg_prompt 311 | 312 | # preprocess 313 | if not self.t5_cpu: 314 | self.text_encoder.model.to(self.device) 315 | context = self.text_encoder([input_prompt], self.device) 316 | context_null = self.text_encoder([n_prompt], self.device) 317 | if offload_model: 318 | self.text_encoder.model.cpu() 319 | else: 320 | context = self.text_encoder([input_prompt], torch.device('cpu')) 321 | context_null = self.text_encoder([n_prompt], torch.device('cpu')) 322 | context = [t.to(self.device) for t in context] 323 | context_null = [t.to(self.device) for t in context_null] 324 | 325 | self.clip.model.to(self.device) 326 | clip_context = self.clip.visual([img[:, None, :, :]]) 327 | if offload_model: 328 | self.clip.model.cpu() 329 | 330 | y = self.vae.encode([ 331 | torch.concat([ 332 | torch.nn.functional.interpolate( 333 | img[None].cpu(), size=(h, w), mode='bicubic').transpose( 334 | 0, 1), 335 | torch.zeros(3, F - 1, h, w) 336 | ], 337 | dim=1).to(self.device) 338 | ])[0] 339 | y = torch.concat([msk, y]) 340 | 341 | @contextmanager 342 | def noop_no_sync(): 343 | yield 344 | 345 | no_sync = getattr(self.model, 'no_sync', noop_no_sync) 346 | 347 | # evaluation mode 348 | with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync(): 349 | 350 | if sample_solver == 'unipc': 351 | sample_scheduler = FlowUniPCMultistepScheduler( 352 | num_train_timesteps=self.num_train_timesteps, 353 | shift=1, 354 | use_dynamic_shifting=False) 355 | sample_scheduler.set_timesteps( 356 | sampling_steps, device=self.device, shift=shift) 357 | timesteps = sample_scheduler.timesteps 358 | elif sample_solver == 'dpm++': 359 | sample_scheduler = FlowDPMSolverMultistepScheduler( 360 | num_train_timesteps=self.num_train_timesteps, 361 | shift=1, 362 | use_dynamic_shifting=False) 363 | sampling_sigmas = get_sampling_sigmas(sampling_steps, shift) 364 | timesteps, _ = retrieve_timesteps( 365 | sample_scheduler, 366 | device=self.device, 367 | 
sigmas=sampling_sigmas) 368 | else: 369 | raise NotImplementedError("Unsupported solver.") 370 | 371 | # sample videos 372 | latent = noise 373 | 374 | arg_c = { 375 | 'context': [context[0]], 376 | 'clip_fea': clip_context, 377 | 'seq_len': max_seq_len, 378 | 'y': [y], 379 | # 'cond_flag': True, 380 | } 381 | 382 | arg_null = { 383 | 'context': context_null, 384 | 'clip_fea': clip_context, 385 | 'seq_len': max_seq_len, 386 | 'y': [y], 387 | # 'cond_flag': False, 388 | } 389 | 390 | if offload_model: 391 | torch.cuda.empty_cache() 392 | 393 | self.model.to(self.device) 394 | for _, t in enumerate(tqdm(timesteps)): 395 | torch.cuda.synchronize() 396 | start_time = time() 397 | latent_model_input = [latent.to(self.device)] 398 | timestep = [t] 399 | 400 | timestep = torch.stack(timestep).to(self.device) 401 | 402 | noise_pred_cond = self.model( 403 | latent_model_input, t=timestep, **arg_c)[0].to( 404 | torch.device('cpu') if offload_model else self.device) 405 | if offload_model: 406 | torch.cuda.empty_cache() 407 | noise_pred_uncond = self.model( 408 | latent_model_input, t=timestep, **arg_null)[0].to( 409 | torch.device('cpu') if offload_model else self.device) 410 | if offload_model: 411 | torch.cuda.empty_cache() 412 | 413 | noise_pred = noise_pred_uncond + guide_scale * ( 414 | noise_pred_cond - noise_pred_uncond) 415 | 416 | latent = latent.to( 417 | torch.device('cpu') if offload_model else self.device) 418 | 419 | torch.cuda.synchronize() 420 | self.cost_time += (time() - start_time) 421 | 422 | temp_x0 = sample_scheduler.step( 423 | noise_pred.unsqueeze(0), 424 | t, 425 | latent.unsqueeze(0), 426 | return_dict=False, 427 | generator=seed_g)[0] 428 | latent = temp_x0.squeeze(0) 429 | 430 | x0 = [latent.to(self.device)] 431 | del latent_model_input, timestep 432 | 433 | if offload_model: 434 | self.model.cpu() 435 | torch.cuda.empty_cache() 436 | 437 | if self.rank == 0: 438 | videos = self.vae.decode(x0) 439 | 440 | del noise, latent 441 | del sample_scheduler 442 | if offload_model: 443 | gc.collect() 444 | torch.cuda.synchronize() 445 | if dist.is_initialized(): 446 | dist.barrier() 447 | 448 | return videos[0] if self.rank == 0 else None 449 | 450 | 451 | def easycache_forward( 452 | self, 453 | x, 454 | t, 455 | context, 456 | seq_len, 457 | clip_fea=None, 458 | y=None, 459 | ): 460 | """ 461 | Args: 462 | x (List[Tensor]): List of input video tensors with shape [C_in, F, H, W] 463 | t (Tensor): Diffusion timesteps tensor of shape [B] 464 | context (List[Tensor]): List of text embeddings each with shape [L, C] 465 | seq_len (int): Maximum sequence length for positional encoding 466 | clip_fea (Tensor, optional): CLIP image features for image-to-video mode 467 | y (List[Tensor], optional): Conditional video inputs for image-to-video mode 468 | 469 | Returns: 470 | List[Tensor]: List of denoised video tensors with original input shapes 471 | """ 472 | if self.model_type == 'i2v': 473 | assert clip_fea is not None and y is not None 474 | 475 | # Store original raw input for end-to-end caching 476 | raw_input = [u.clone() for u in x] 477 | 478 | # params 479 | device = self.patch_embedding.weight.device 480 | if self.freqs.device != device: 481 | self.freqs = self.freqs.to(device) 482 | 483 | if y is not None: 484 | x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] 485 | 486 | # Track which type of step (even=condition, odd=uncondition) 487 | self.is_even = (self.cnt % 2 == 0) 488 | 489 | # Only make decision on even (condition) steps 490 | if self.is_even: 491 | # Always compute 
first ret_steps and last steps 492 | if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps: 493 | self.should_calc_current_pair = True 494 | self.accumulated_error_even = 0 495 | else: 496 | # Check if we have previous step data for comparison 497 | if hasattr(self, 'previous_raw_input_even') and hasattr(self, 'previous_raw_output_even') and \ 498 | self.previous_raw_input_even is not None and self.previous_raw_output_even is not None: 499 | # Calculate input changes 500 | raw_input_change = torch.cat([ 501 | (u - v).flatten() for u, v in zip(raw_input, self.previous_raw_input_even) 502 | ]).abs().mean() 503 | 504 | # Compute predicted change if we have k factors 505 | if hasattr(self, 'k') and self.k is not None: 506 | # Calculate output norm for relative comparison 507 | output_norm = torch.cat([ 508 | u.flatten() for u in self.previous_raw_output_even 509 | ]).abs().mean() 510 | pred_change = self.k * (raw_input_change / output_norm) 511 | combined_pred_change = pred_change 512 | # Accumulate predicted error 513 | if not hasattr(self, 'accumulated_error_even'): 514 | self.accumulated_error_even = 0 515 | self.accumulated_error_even += combined_pred_change 516 | # Decide if we need full calculation 517 | if self.accumulated_error_even < self.thresh: 518 | self.should_calc_current_pair = False 519 | else: 520 | self.should_calc_current_pair = True 521 | self.accumulated_error_even = 0 522 | else: 523 | # First time after ret_steps or missing k factors, need to calculate 524 | self.should_calc_current_pair = True 525 | else: 526 | # No previous data yet, must calculate 527 | self.should_calc_current_pair = True 528 | 529 | # Store current input state 530 | self.previous_raw_input_even = [u.clone() for u in raw_input] 531 | 532 | # Check if we can use cached output and return early 533 | if self.is_even and not self.should_calc_current_pair and \ 534 | hasattr(self, 'previous_raw_output_even') and self.previous_raw_output_even is not None: 535 | # Use cached output directly 536 | self.cnt += 1 537 | # Check if we've reached the end of sampling 538 | if self.cnt >= self.num_steps: 539 | self.cnt = 0 540 | 541 | return [(u + v).float() for u, v in zip(raw_input, self.cache_even)] 542 | 543 | elif not self.is_even and not self.should_calc_current_pair and \ 544 | hasattr(self, 'previous_raw_output_odd') and self.previous_raw_output_odd is not None: 545 | # Use cached output directly 546 | self.cnt += 1 547 | 548 | # Check if we've reached the end of sampling 549 | if self.cnt >= self.num_steps: 550 | self.cnt = 0 551 | 552 | # return [u.float() for u in self.previous_raw_output_odd] 553 | return [(u + v).float() for u, v in zip(raw_input, self.cache_odd)] 554 | 555 | # Continue with normal processing since we need to calculate 556 | # embeddings 557 | x = [self.patch_embedding(u.unsqueeze(0)) for u in x] 558 | grid_sizes = torch.stack( 559 | [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) 560 | x = [u.flatten(2).transpose(1, 2) for u in x] 561 | seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) 562 | assert seq_lens.max() <= seq_len 563 | x = torch.cat([ 564 | torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], 565 | dim=1) for u in x 566 | ]) 567 | 568 | # time embeddings 569 | with amp.autocast(dtype=torch.float32): 570 | e = self.time_embedding( 571 | sinusoidal_embedding_1d(self.freq_dim, t).float()) 572 | e0 = self.time_projection(e).unflatten(1, (6, self.dim)) 573 | assert e.dtype == torch.float32 and e0.dtype == torch.float32 574 | 575 | # 
context 576 | context_lens = None 577 | context = self.text_embedding( 578 | torch.stack([ 579 | torch.cat( 580 | [u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) 581 | for u in context 582 | ])) 583 | 584 | if clip_fea is not None: 585 | context_clip = self.img_emb(clip_fea) # bs x 257 x dim 586 | context = torch.concat([context_clip, context], dim=1) 587 | 588 | # arguments 589 | kwargs = dict( 590 | e=e0, 591 | seq_lens=seq_lens, 592 | grid_sizes=grid_sizes, 593 | freqs=self.freqs, 594 | context=context, 595 | context_lens=context_lens) 596 | 597 | # Apply transformer blocks 598 | for block in self.blocks: 599 | x = block(x, **kwargs) 600 | 601 | # Apply head 602 | x = self.head(x, e) 603 | 604 | # Unpatchify 605 | output = self.unpatchify(x, grid_sizes) 606 | 607 | # Update cache and calculate change rates if needed 608 | if self.is_even: # Condition path 609 | # If we have previous output, calculate k factors for future predictions 610 | if hasattr(self, 'previous_raw_output_even') and self.previous_raw_output_even is not None: 611 | # Calculate output change at the raw level 612 | output_change = torch.cat([ 613 | (u - v).flatten() for u, v in zip(output, self.previous_raw_output_even) 614 | ]).abs().mean() 615 | 616 | # Check if we have previous input state for comparison 617 | if hasattr(self, 'prev_prev_raw_input_even') and self.prev_prev_raw_input_even is not None: 618 | # Calculate input change 619 | input_change = torch.cat([ 620 | (u - v).flatten() for u, v in zip( 621 | self.previous_raw_input_even, self.prev_prev_raw_input_even 622 | ) 623 | ]).abs().mean() 624 | 625 | self.k = output_change / input_change 626 | 627 | # Update history 628 | self.prev_prev_raw_input_even = getattr(self, 'previous_raw_input_even', None) 629 | self.previous_raw_output_even = [u.clone() for u in output] 630 | self.cache_even = [u - v for u, v in zip(output, raw_input)] 631 | 632 | else: # Uncondition path 633 | # Store output for unconditional path 634 | self.previous_raw_output_odd = [u.clone() for u in output] 635 | self.cache_odd = [u - v for u, v in zip(output, raw_input)] 636 | 637 | # Update counter 638 | self.cnt += 1 639 | if self.cnt >= self.num_steps: 640 | self.cnt = 0 641 | self.skip_cond_step = [] 642 | self.skip_uncond_step = [] 643 | 644 | return [u.float() for u in output] 645 | 646 | 647 | def _validate_args(args): 648 | # Basic check 649 | assert args.ckpt_dir is not None, "Please specify the checkpoint directory." 650 | assert args.task in WAN_CONFIGS, f"Unsupport task: {args.task}" 651 | assert args.task in EXAMPLE_PROMPT, f"Unsupport task: {args.task}" 652 | 653 | # The default sampling steps are 40 for image-to-video tasks and 50 for text-to-video tasks. 654 | if args.sample_steps is None: 655 | args.sample_steps = 40 if "i2v" in args.task else 50 656 | 657 | if args.sample_shift is None: 658 | args.sample_shift = 5.0 659 | if "i2v" in args.task and args.size in ["832*480", "480*832"]: 660 | args.sample_shift = 3.0 661 | 662 | # The default number of frames are 1 for text-to-image tasks and 81 for other tasks. 663 | if args.frame_num is None: 664 | args.frame_num = 1 if "t2i" in args.task else 81 665 | 666 | # T2I frame_num check 667 | if "t2i" in args.task: 668 | assert args.frame_num == 1, f"Unsupport frame_num {args.frame_num} for task {args.task}" 669 | 670 | args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint( 671 | 0, sys.maxsize) 672 | # Size check 673 | assert args.size in SUPPORTED_SIZES[ 674 | args. 
675 | task], f"Unsupport size {args.size} for task {args.task}, supported sizes are: {', '.join(SUPPORTED_SIZES[args.task])}" 676 | 677 | 678 | def _parse_args(): 679 | parser = argparse.ArgumentParser( 680 | description="Generate a image or video from a text prompt or image using Wan" 681 | ) 682 | parser.add_argument( 683 | "--task", 684 | type=str, 685 | default="t2v-14B", 686 | choices=list(WAN_CONFIGS.keys()), 687 | help="The task to run.") 688 | parser.add_argument( 689 | "--size", 690 | type=str, 691 | default="1280*720", 692 | choices=list(SIZE_CONFIGS.keys()), 693 | help="The area (width*height) of the generated video. For the I2V task, the aspect ratio of the output video will follow that of the input image." 694 | ) 695 | parser.add_argument( 696 | "--frame_num", 697 | type=int, 698 | default=None, 699 | help="How many frames to sample from a image or video. The number should be 4n+1" 700 | ) 701 | parser.add_argument( 702 | "--ckpt_dir", 703 | type=str, 704 | default="./model_weights/Wan2.1-T2V-1.3B", 705 | help="The path to the checkpoint directory.") 706 | parser.add_argument( 707 | "--offload_model", 708 | type=str2bool, 709 | default=None, 710 | help="Whether to offload the model to CPU after each model forward, reducing GPU memory usage." 711 | ) 712 | parser.add_argument( 713 | "--ulysses_size", 714 | type=int, 715 | default=1, 716 | help="The size of the ulysses parallelism in DiT.") 717 | parser.add_argument( 718 | "--ring_size", 719 | type=int, 720 | default=1, 721 | help="The size of the ring attention parallelism in DiT.") 722 | parser.add_argument( 723 | "--t5_fsdp", 724 | action="store_true", 725 | default=False, 726 | help="Whether to use FSDP for T5.") 727 | parser.add_argument( 728 | "--t5_cpu", 729 | action="store_true", 730 | default=False, 731 | help="Whether to place T5 model on CPU.") 732 | parser.add_argument( 733 | "--dit_fsdp", 734 | action="store_true", 735 | default=False, 736 | help="Whether to use FSDP for DiT.") 737 | parser.add_argument( 738 | "--save_file", 739 | type=str, 740 | default=None, 741 | help="The file to save the generated image or video to.") 742 | parser.add_argument( 743 | "--prompt", 744 | type=str, 745 | default=None, 746 | help="The prompt to generate the image or video from.") 747 | parser.add_argument( 748 | "--use_prompt_extend", 749 | action="store_true", 750 | default=False, 751 | help="Whether to use prompt extend.") 752 | parser.add_argument( 753 | "--prompt_extend_method", 754 | type=str, 755 | default="local_qwen", 756 | choices=["dashscope", "local_qwen"], 757 | help="The prompt extend method to use.") 758 | parser.add_argument( 759 | "--prompt_extend_model", 760 | type=str, 761 | default=None, 762 | help="The prompt extend model to use.") 763 | parser.add_argument( 764 | "--prompt_extend_target_lang", 765 | type=str, 766 | default="ch", 767 | choices=["ch", "en"], 768 | help="The target language of prompt extend.") 769 | parser.add_argument( 770 | "--base_seed", 771 | type=int, 772 | default=-1, 773 | help="The seed to use for generating the image or video.") 774 | parser.add_argument( 775 | "--image", 776 | type=str, 777 | default=None, 778 | help="The image to generate the video from.") 779 | parser.add_argument( 780 | "--sample_solver", 781 | type=str, 782 | default='unipc', 783 | choices=['unipc', 'dpm++'], 784 | help="The solver used to sample.") 785 | parser.add_argument( 786 | "--sample_steps", type=int, default=None, help="The sampling steps.") 787 | parser.add_argument( 788 | "--sample_shift", 789 | 
type=float, 790 | default=None, 791 | help="Sampling shift factor for flow matching schedulers.") 792 | parser.add_argument( 793 | "--sample_guide_scale", 794 | type=float, 795 | default=5.0, 796 | help="Classifier free guidance scale.") 797 | parser.add_argument( 798 | "--thresh", 799 | type=float, 800 | default=0.05, 801 | help="Higher speedup will cause to worse quality -- 0.1 for 2.0x speedup -- 0.2 for 3.0x speedup") 802 | parser.add_argument( 803 | "--ret_steps", 804 | default=10, 805 | type=int, 806 | help="Number of steps to retain in the cache. Default is 10.") 807 | parser.add_argument( 808 | "--alpha", 809 | default=0., 810 | type=float, 811 | help="Averaging factor for the cache update. Default is 0.5.") 812 | parser.add_argument( 813 | "--beta", 814 | default=1.0, 815 | type=float, 816 | help="Averaging factor for the k_t and k_x update. Default is 1.0.") 817 | parser.add_argument( 818 | "--start_idx", 819 | type=int, 820 | default=0) 821 | parser.add_argument( 822 | "--end_idx", 823 | type=int, 824 | default=946) 825 | parser.add_argument( 826 | "--out_dir", 827 | type=str, 828 | default="./output", 829 | ) 830 | 831 | args = parser.parse_args() 832 | 833 | _validate_args(args) 834 | 835 | return args 836 | 837 | 838 | def _init_logging(rank): 839 | # logging 840 | if rank == 0: 841 | # set format 842 | logging.basicConfig( 843 | level=logging.INFO, 844 | format="[%(asctime)s] %(levelname)s: %(message)s", 845 | handlers=[logging.StreamHandler(stream=sys.stdout)]) 846 | else: 847 | logging.basicConfig(level=logging.ERROR) 848 | 849 | 850 | def generate(args): 851 | rank = int(os.getenv("RANK", 0)) 852 | world_size = int(os.getenv("WORLD_SIZE", 1)) 853 | local_rank = int(os.getenv("LOCAL_RANK", 0)) 854 | device = local_rank 855 | _init_logging(rank) 856 | 857 | if args.offload_model is None: 858 | args.offload_model = False if world_size > 1 else True 859 | logging.info( 860 | f"offload_model is not specified, set to {args.offload_model}.") 861 | if world_size > 1: 862 | torch.cuda.set_device(local_rank) 863 | dist.init_process_group( 864 | backend="nccl", 865 | init_method="env://", 866 | rank=rank, 867 | world_size=world_size) 868 | else: 869 | assert not ( 870 | args.t5_fsdp or args.dit_fsdp 871 | ), f"t5_fsdp and dit_fsdp are not supported in non-distributed environments." 872 | assert not ( 873 | args.ulysses_size > 1 or args.ring_size > 1 874 | ), f"context parallel are not supported in non-distributed environments." 875 | 876 | if args.ulysses_size > 1 or args.ring_size > 1: 877 | assert args.ulysses_size * args.ring_size == world_size, f"The number of ulysses_size and ring_size should be equal to the world size." 
878 | from xfuser.core.distributed import (initialize_model_parallel, 879 | init_distributed_environment) 880 | init_distributed_environment( 881 | rank=dist.get_rank(), world_size=dist.get_world_size()) 882 | 883 | initialize_model_parallel( 884 | sequence_parallel_degree=dist.get_world_size(), 885 | ring_degree=args.ring_size, 886 | ulysses_degree=args.ulysses_size, 887 | ) 888 | 889 | if args.use_prompt_extend: 890 | if args.prompt_extend_method == "dashscope": 891 | prompt_expander = DashScopePromptExpander( 892 | model_name=args.prompt_extend_model, is_vl="i2v" in args.task) 893 | elif args.prompt_extend_method == "local_qwen": 894 | prompt_expander = QwenPromptExpander( 895 | model_name=args.prompt_extend_model, 896 | is_vl="i2v" in args.task, 897 | device=rank) 898 | else: 899 | raise NotImplementedError( 900 | f"Unsupport prompt_extend_method: {args.prompt_extend_method}") 901 | 902 | cfg = WAN_CONFIGS[args.task] 903 | if args.ulysses_size > 1: 904 | assert cfg.num_heads % args.ulysses_size == 0, f"`num_heads` must be divisible by `ulysses_size`." 905 | 906 | logging.info(f"Generation job args: {args}") 907 | logging.info(f"Generation model config: {cfg}") 908 | 909 | if dist.is_initialized(): 910 | base_seed = [args.base_seed] if rank == 0 else [None] 911 | dist.broadcast_object_list(base_seed, src=0) 912 | args.base_seed = base_seed[0] 913 | 914 | if "t2v" in args.task or "t2i" in args.task: 915 | if args.prompt is None: 916 | args.prompt = EXAMPLE_PROMPT[args.task]["prompt"] 917 | logging.info(f"Input prompt: {args.prompt}") 918 | if args.use_prompt_extend: 919 | logging.info("Extending prompt ...") 920 | if rank == 0: 921 | prompt_output = prompt_expander( 922 | args.prompt, 923 | tar_lang=args.prompt_extend_target_lang, 924 | seed=args.base_seed) 925 | if prompt_output.status == False: 926 | logging.info( 927 | f"Extending prompt failed: {prompt_output.message}") 928 | logging.info("Falling back to original prompt.") 929 | input_prompt = args.prompt 930 | else: 931 | input_prompt = prompt_output.prompt 932 | input_prompt = [input_prompt] 933 | else: 934 | input_prompt = [None] 935 | if dist.is_initialized(): 936 | dist.broadcast_object_list(input_prompt, src=0) 937 | args.prompt = input_prompt[0] 938 | logging.info(f"Extended prompt: {args.prompt}") 939 | 940 | logging.info("Creating WanT2V pipeline.") 941 | wan_t2v = wan.WanT2V( 942 | config=cfg, 943 | checkpoint_dir=args.ckpt_dir, 944 | device_id=device, 945 | rank=rank, 946 | t5_fsdp=args.t5_fsdp, 947 | dit_fsdp=args.dit_fsdp, 948 | use_usp=(args.ulysses_size > 1 or args.ring_size > 1), 949 | t5_cpu=args.t5_cpu, 950 | ) 951 | 952 | generation_time = [] 953 | time_cost = {"GPU_Device": torch.cuda.get_device_name(0), "number_prompt": None, "avg_cost_time": None} 954 | wan_t2v.__class__.cost_time = 0 955 | wan_t2v.__class__.generate = t2v_generate 956 | wan_t2v.model.__class__.forward = easycache_forward 957 | wan_t2v.model.__class__.cnt = 0 958 | wan_t2v.model.__class__.skip_cond_step = [] 959 | wan_t2v.model.__class__.skip_uncond_step = [] 960 | wan_t2v.model.__class__.num_steps = args.sample_steps * 2 961 | wan_t2v.model.__class__.thresh = args.thresh 962 | wan_t2v.model.__class__.accumulated_error_even = 0 963 | wan_t2v.model.__class__.should_calc_current_pair = True 964 | wan_t2v.model.__class__.k = None 965 | 966 | wan_t2v.model.__class__.previous_raw_input_even = None 967 | wan_t2v.model.__class__.previous_raw_output_even = None 968 | wan_t2v.model.__class__.previous_raw_output_odd = None 969 | 
wan_t2v.model.__class__.prev_prev_raw_input_even = None 970 | wan_t2v.model.__class__.cache_even = None 971 | wan_t2v.model.__class__.cache_odd = None 972 | 973 | wan_t2v.cost_time = 0 974 | wan_t2v.model.__class__.ret_steps = 10 * 2 975 | wan_t2v.model.__class__.cutoff_steps = args.sample_steps * 2 - 2 976 | 977 | print( 978 | f"Generating {'image' if 't2i' in args.task else 'video'} ...") 979 | 980 | # start_time = time() 981 | video = wan_t2v.generate( 982 | args.prompt, 983 | size=SIZE_CONFIGS[args.size], 984 | frame_num=args.frame_num, 985 | shift=args.sample_shift, 986 | sample_solver=args.sample_solver, 987 | sampling_steps=args.sample_steps, 988 | guide_scale=args.sample_guide_scale, 989 | seed=args.base_seed, 990 | offload_model=args.offload_model) 991 | generation_time.append(wan_t2v.cost_time) 992 | if rank == 0: 993 | if args.save_file is None: 994 | formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S") 995 | formatted_prompt = args.prompt.replace(" ", "_").replace("/", "_")[:50] 996 | suffix = '.png' if "t2i" in args.task else '.mp4' 997 | args.save_file = f"{args.task}_easycache_thresh{args.thresh}_step{args.sample_steps}_{formatted_prompt}_{formatted_time}" + suffix 998 | 999 | if "t2i" in args.task: 1000 | logging.info(f"Saving generated image to {args.save_file}") 1001 | cache_image( 1002 | tensor=video.squeeze(1)[None], 1003 | save_file=args.save_file, 1004 | nrow=1, 1005 | normalize=True, 1006 | value_range=(-1, 1)) 1007 | else: 1008 | logging.info(f"Saving generated video to {args.save_file}") 1009 | cache_video( 1010 | tensor=video[None], 1011 | save_file=args.save_file, 1012 | fps=cfg.sample_fps, 1013 | nrow=1, 1014 | normalize=True, 1015 | value_range=(-1, 1)) 1016 | logging.info("Finished.") 1017 | 1018 | 1019 | else: 1020 | if args.prompt is None: 1021 | args.prompt = EXAMPLE_PROMPT[args.task]["prompt"] 1022 | if args.image is None: 1023 | args.image = EXAMPLE_PROMPT[args.task]["image"] 1024 | print(f"Input prompt: {args.prompt}") 1025 | print(f"Input image: {args.image}") 1026 | 1027 | img = Image.open(args.image).convert("RGB") 1028 | if args.use_prompt_extend: 1029 | print("Extending prompt ...") 1030 | if rank == 0: 1031 | prompt_output = prompt_expander( 1032 | args.prompt, 1033 | tar_lang=args.prompt_extend_target_lang, 1034 | image=img, 1035 | seed=args.base_seed) 1036 | if prompt_output.status == False: 1037 | print( 1038 | f"Extending prompt failed: {prompt_output.message}") 1039 | print("Falling back to original prompt.") 1040 | input_prompt = args.prompt 1041 | else: 1042 | input_prompt = prompt_output.prompt 1043 | input_prompt = [input_prompt] 1044 | else: 1045 | input_prompt = [None] 1046 | if dist.is_initialized(): 1047 | dist.broadcast_object_list(input_prompt, src=0) 1048 | args.prompt = input_prompt[0] 1049 | print(f"Extended prompt: {args.prompt}") 1050 | 1051 | print("Creating WanI2V pipeline.") 1052 | wan_i2v = wan.WanI2V( 1053 | config=cfg, 1054 | checkpoint_dir=args.ckpt_dir, 1055 | device_id=device, 1056 | rank=rank, 1057 | t5_fsdp=args.t5_fsdp, 1058 | dit_fsdp=args.dit_fsdp, 1059 | use_usp=(args.ulysses_size > 1 or args.ring_size > 1), 1060 | t5_cpu=args.t5_cpu, 1061 | ) 1062 | generation_time = [] 1063 | time_cost = {"GPU_Device": torch.cuda.get_device_name(0), "number_prompt": None, "avg_cost_time": None} 1064 | wan_i2v.__class__.generate = i2v_generate 1065 | wan_i2v.model.__class__.forward = easycache_forward 1066 | wan_i2v.model.__class__.cnt = 0 1067 | wan_i2v.model.__class__.num_steps = args.sample_steps * 2 1068 | 
wan_i2v.model.__class__.thresh = args.thresh 1069 | 1070 | wan_i2v.model.__class__.accumulated_error_even = 0 1071 | wan_i2v.model.__class__.should_calc_current_pair = True 1072 | wan_i2v.model.__class__.k = None 1073 | 1074 | wan_i2v.model.__class__.previous_raw_input_even = None 1075 | wan_i2v.model.__class__.previous_raw_output_even = None 1076 | wan_i2v.model.__class__.previous_raw_output_odd = None 1077 | wan_i2v.model.__class__.prev_prev_raw_input_even = None 1078 | wan_i2v.model.__class__.cache_even = None 1079 | wan_i2v.model.__class__.cache_odd = None 1080 | 1081 | wan_i2v.cost_time = 0 1082 | 1083 | wan_i2v.model.__class__.ret_steps = 10 * 2 1084 | wan_i2v.model.__class__.cutoff_steps = args.sample_steps * 2 - 2 1085 | 1086 | print("Generating video ...") 1087 | video = wan_i2v.generate( 1088 | args.prompt, 1089 | img, 1090 | max_area=MAX_AREA_CONFIGS[args.size], 1091 | frame_num=args.frame_num, 1092 | shift=args.sample_shift, 1093 | sample_solver=args.sample_solver, 1094 | sampling_steps=args.sample_steps, 1095 | guide_scale=args.sample_guide_scale, 1096 | seed=args.base_seed, 1097 | offload_model=args.offload_model) 1098 | generation_time.append(wan_i2v.cost_time) 1099 | 1100 | if rank == 0: 1101 | if args.save_file is None: 1102 | formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S") 1103 | formatted_prompt = args.prompt.replace(" ", "_").replace("/", "_")[:50] 1104 | suffix = '.png' if "t2i" in args.task else '.mp4' 1105 | args.save_file = f"{args.task}_easycache_thresh{args.thresh}_step{args.sample_steps}_{formatted_prompt}_{formatted_time}" + suffix 1106 | 1107 | if "t2i" in args.task: 1108 | logging.info(f"Saving generated image to {args.save_file}") 1109 | cache_image( 1110 | tensor=video.squeeze(1)[None], 1111 | save_file=args.save_file, 1112 | nrow=1, 1113 | normalize=True, 1114 | value_range=(-1, 1)) 1115 | else: 1116 | logging.info(f"Saving generated video to {args.save_file}") 1117 | cache_video( 1118 | tensor=video[None], 1119 | save_file=args.save_file, 1120 | fps=cfg.sample_fps, 1121 | nrow=1, 1122 | normalize=True, 1123 | value_range=(-1, 1)) 1124 | logging.info("Finished.") 1125 | 1126 | time_cost["number_prompt"] = len(generation_time) 1127 | time_cost["avg_cost_time"] = sum(generation_time) / (len(generation_time)) if len(generation_time) > 0 else 0 1128 | 1129 | print( 1130 | f"GPU_Device:{time_cost['GPU_Device']}, number_prompt: {time_cost['number_prompt']}, avg_cost_time: {time_cost['avg_cost_time']}") 1131 | try: 1132 | with open(f"./{args.out_dir}/1time_cost.json", "a+") as f: 1133 | portalocker.lock(f, portalocker.LOCK_EX) 1134 | f.seek(0) 1135 | try: 1136 | existing_data = json.load(f) 1137 | except (json.JSONDecodeError, FileNotFoundError): 1138 | existing_data = [] 1139 | existing_data.append(time_cost) 1140 | f.seek(0) 1141 | f.truncate() 1142 | json.dump(existing_data, f, indent=4) 1143 | except Exception as e: 1144 | print(f"Error saving time cost data: {e}") 1145 | 1146 | 1147 | if __name__ == "__main__": 1148 | args = _parse_args() 1149 | generate(args) 1150 | -------------------------------------------------------------------------------- /EasyCache4Wan2.1/example/grogu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/H-EmbodVis/EasyCache/73b04c13c05dc446b712f91cdf0f367399a752e4/EasyCache4Wan2.1/example/grogu.png -------------------------------------------------------------------------------- /EasyCache4Wan2.1/tools/video_metrics.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import argparse 4 | import torch 5 | import lpips 6 | import numpy as np 7 | from tqdm import tqdm 8 | from torchmetrics.image import StructuralSimilarityIndexMeasure 9 | 10 | def load_video_frames(path, resize_to=None): 11 | """ 12 | Load all frames from a video file as a list of HxWx3 uint8 arrays. 13 | Optionally resize each frame to `resize_to` (w, h). 14 | """ 15 | 16 | cap = cv2.VideoCapture(path) 17 | frames = [] 18 | while True: 19 | ret, img = cap.read() 20 | if not ret: 21 | break 22 | if resize_to is not None: 23 | img = cv2.resize(img, resize_to) 24 | frames.append(np.expand_dims(cv2.cvtColor(img, cv2.COLOR_BGR2RGB), axis=0)) 25 | cap.release() 26 | return np.concatenate(frames) 27 | 28 | 29 | def compute_video_metrics(frames_gt, frames_gen, 30 | device, ssim_metric, lpips_fn): 31 | """ 32 | Compute PSNR, SSIM, LPIPS for two lists of frames (uint8 BGR). 33 | All computations on `device`. 34 | Returns (psnr, ssim, lpips) scalars. 35 | """ 36 | # ensure same frame count 37 | # convert to tensors [N,3,H,W], normalize to [0,1] 38 | gt_t = torch.from_numpy(frames_gt).float().to(device).permute(0, 3, 1, 2).div_(255).contiguous() 39 | 40 | gen_t = torch.from_numpy(frames_gen).float().to(device).permute(0, 3, 1, 2).div_(255).contiguous() 41 | 42 | # PSNR (data_range=1.0): -10 * log10(mse) 43 | mse = torch.mean((gt_t - gen_t) ** 2) 44 | psnr = -10.0 * torch.log10(mse) 45 | 46 | # SSIM: returns average over batch 47 | ssim_val = ssim_metric(gen_t, gt_t) 48 | 49 | # LPIPS: expects [-1,1] 50 | with torch.no_grad(): 51 | lpips_val = lpips_fn(gt_t * 2.0 - 1.0, gen_t * 2.0 - 1.0).mean() 52 | 53 | return psnr.item(), ssim_val.item(), lpips_val.item() 54 | 55 | 56 | def main(): 57 | parser = argparse.ArgumentParser( 58 | description="Compute PSNR/SSIM/LPIPS on GPU for two folders of .mp4 videos" 59 | ) 60 | parser.add_argument("--original_video", required=True, 61 | help="ground-truth .mp4 videos") 62 | parser.add_argument("--generated_video", required=True, 63 | help="generated .mp4 videos") 64 | parser.add_argument("--device", default="cuda", 65 | help="Torch device, e.g. 
'cuda' or 'cpu'") 66 | parser.add_argument("--lpips_net", default="alex", choices=["alex", "vgg"], 67 | help="Backbone for LPIPS") 68 | args = parser.parse_args() 69 | 70 | device = torch.device(args.device if torch.cuda.is_available() or args.device == "cpu" else "cpu") 71 | # instantiate metrics on device 72 | ssim_metric = StructuralSimilarityIndexMeasure(data_range=1.0).to(device) 73 | lpips_fn = lpips.LPIPS(net=args.lpips_net, spatial=True).to(device) 74 | 75 | # gather .mp4 filenames 76 | gt_files = args.original_video 77 | gen_set = args.generated_video 78 | 79 | psnrs, ssims, lpips_vals = [], [], [] 80 | for fname in tqdm([gt_files], desc="Videos"): 81 | path_gt = gt_files 82 | path_gen = gen_set 83 | 84 | # load frames; resize generated to match GT dimensions 85 | frames_gt = load_video_frames(path_gt) 86 | frames_gen = load_video_frames(path_gen) 87 | 88 | res = compute_video_metrics(frames_gt, frames_gen, 89 | device, ssim_metric, lpips_fn) 90 | if res is None: 91 | continue 92 | p, s, l = res 93 | psnrs.append(p); 94 | ssims.append(s); 95 | lpips_vals.append(l) 96 | 97 | if not psnrs: 98 | print("No valid videos processed.") 99 | return 100 | 101 | print("\n=== Overall Averages ===") 102 | print(f"Average PSNR : {np.mean(psnrs):.2f} dB") 103 | print(f"Average SSIM : {np.mean(ssims):.4f}") 104 | print(f"Average LPIPS: {np.mean(lpips_vals):.4f}") 105 | 106 | 107 | if __name__ == "__main__": 108 | main() 109 | -------------------------------------------------------------------------------- /EasyCache4Wan2.1/videos/i2v_easycache_14b_720p.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/H-EmbodVis/EasyCache/73b04c13c05dc446b712f91cdf0f367399a752e4/EasyCache4Wan2.1/videos/i2v_easycache_14b_720p.gif -------------------------------------------------------------------------------- /EasyCache4Wan2.1/videos/i2v_gt_14b_720p.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/H-EmbodVis/EasyCache/73b04c13c05dc446b712f91cdf0f367399a752e4/EasyCache4Wan2.1/videos/i2v_gt_14b_720p.gif -------------------------------------------------------------------------------- /EasyCache4Wan2.1/videos/t2v_easycache_14b_720p.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/H-EmbodVis/EasyCache/73b04c13c05dc446b712f91cdf0f367399a752e4/EasyCache4Wan2.1/videos/t2v_easycache_14b_720p.gif -------------------------------------------------------------------------------- /EasyCache4Wan2.1/videos/t2v_gt_14b_720p.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/H-EmbodVis/EasyCache/73b04c13c05dc446b712f91cdf0f367399a752e4/EasyCache4Wan2.1/videos/t2v_gt_14b_720p.gif -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | Less is Enough: Training-Free Video Diffusion Acceleration via Runtime-Adaptive Caching
3 | 4 | Xin Zhou1\*, 5 | Dingkang Liang1\*, 6 | Kaijin Chen1, Tianrui Feng1, 7 | Xiwu Chen2, Hongkai Lin1,
8 | Yikang Ding2, Feiyang Tan2, 9 | Hengshuang Zhao3, 10 | Xiang Bai1† 11 | 12 | 1 Huazhong University of Science and Technology, 2 MEGVII Technology, 3 The University of Hong Kong
13 | 14 | (\*) Equal contribution. (†) Corresponding author. 15 | 16 | [![arXiv](https://img.shields.io/badge/Arxiv-2507.02860-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2507.02860) 17 | [![Project](https://img.shields.io/badge/Homepage-project-orange.svg?logo=googlehome)](https://H-EmbodVis.github.io/EasyCache/) 18 | [![Code License](https://img.shields.io/badge/Code%20License-Apache_2.0-green.svg)](https://github.com/LMD0311/EasyCache/blob/main/LICENSE) 19 | 20 |
21 | 
22 | ## 🎬 Visual Comparisons 
23 | Video synchronization issues may occur due to network load; for improved visualization, see the [project page](https://H-EmbodVis.github.io/EasyCache/). 
24 | 
25 | **Prompt: "Grassland at dusk, wild horses galloping, golden light flickering across manes."** 
26 | *(HunyuanVideo)* 
27 | 
28 | | Baseline | Ours (2.28x) | TeaCache (1.68x) | PAB (1.19x) | 
29 | | :---: | :---: | :---: | :---: | 
30 | | ![Baseline Video](./demo/gt/6.gif) | ![Our Video](./demo/our/6.gif) | ![TeaCache Video](./demo/teacache/6.gif) | ![PAB Video](./demo/pab/6.gif) | 
31 | 
32 | **Prompt: "A top-down view of a barista creating latte art, skillfully pouring milk to form the letters 'TPAMI' on coffee."** 
33 | *(Wan2.1-14B)* 
34 | 
35 | | Baseline | Ours (2.63x) | TeaCache (1.46x) | PAB (2.10x) | 
36 | | :---: | :---: | :---: | :---: | 
37 | | ![Baseline Latte](./demo/gt/7.gif) | ![Our Latte](./demo/our/7.gif) | ![TeaCache Latte](./demo/teacache/7.gif) | ![PAB Latte](./demo/pab/7.gif) | 
38 | 
39 | **Compatibility with SVG** 
40 | 
41 | SVG with EasyCache on HunyuanVideo can achieve more than 3x speedup. 
42 | 
43 | https://github.com/user-attachments/assets/248ab88f-dfa8-4980-9b51-5c081e27db9a 
44 | 
45 | 
46 | ## 📰 News 
47 | - **If you like our project, please give us a star ⭐ on GitHub for the latest updates.** 
48 | - **[2025/07/06]** 🔥 EasyCache for [**Wan2.1**](https://github.com/H-EmbodVis/EasyCache/tree/main/EasyCache4Wan2.1) I2V is released. 
49 | - **[2025/07/05]** 🔥 EasyCache for [**Wan2.1**](https://github.com/H-EmbodVis/EasyCache/tree/main/EasyCache4Wan2.1) T2V is released. 
50 | - **[2025/07/04]** 🎉 The [**paper**](https://arxiv.org/abs/2507.02860) of EasyCache is released. 
51 | - **[2025/07/03]** 🔥 EasyCache for Sparse-VideoGen on [**HunyuanVideo**](https://github.com/H-EmbodVis/EasyCache/tree/main/EasyCache4HunyuanVideo) is released. 
52 | - **[2025/07/02]** 🔥 EasyCache for [**HunyuanVideo**](https://github.com/H-EmbodVis/EasyCache/tree/main/EasyCache4HunyuanVideo) is released. 
53 | 
54 | ## Abstract 
55 | Video generation models have demonstrated remarkable performance, yet their broader adoption remains constrained by slow inference speeds and substantial computational costs, primarily due to the iterative nature of the denoising process. Addressing this bottleneck is essential for democratizing advanced video synthesis technologies and enabling their integration into real-world applications. This work proposes EasyCache, a training-free acceleration framework for video diffusion models. EasyCache introduces a lightweight, runtime-adaptive caching mechanism that dynamically reuses previously computed transformation vectors, avoiding redundant computations during inference. Unlike prior approaches, EasyCache requires no offline profiling, pre-computation, or extensive parameter tuning. We conduct comprehensive studies on various large-scale video generation models, including OpenSora, Wan2.1, and HunyuanVideo. Our method achieves leading acceleration performance, reducing inference time by 2.1-3.3× compared to the original baselines while maintaining high visual fidelity, with a PSNR improvement of up to 36% over the previous SOTA method. This makes EasyCache an efficient and highly accessible solution for high-quality video generation in both research and practical applications. 
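
The runtime-adaptive mechanism described in the abstract is what `easycache_forward` in `EasyCache4Wan2.1/easycache_generate.py` (shown earlier in this dump) implements: the transformer is only re-run when the accumulated, predicted change of its output — estimated from the observed input drift and a running change-rate `k` — exceeds a threshold; otherwise the cached input-to-output residual is reused. The snippet below is a minimal, single-path sketch of that decision rule with hypothetical names (`should_recompute`, `state`); it omits the conditional/unconditional CFG pairing and the warm-up (`ret_steps`) and cutoff handling of the real code.

```python
import torch


def should_recompute(x_t: torch.Tensor, state: dict, thresh: float = 0.05) -> bool:
    """Decide whether the full transformer forward is needed at this step.

    `state` is a plain dict of running statistics (illustrative field names):
      prev_x   -- input latent seen at the previous step (refreshed every step)
      prev_out -- raw model output of the last *full* forward
      k        -- observed ratio of output change to input change
      err      -- accumulated predicted relative output change
    """
    if state.get("k") is None or state.get("prev_x") is None:
        return True  # no statistics yet: always run the model

    # Relative drift of the input latent since the previous step.
    input_change = (x_t - state["prev_x"]).abs().mean()
    output_norm = state["prev_out"].abs().mean()

    # Predicted relative output change, accumulated over skipped steps.
    state["err"] = state.get("err", 0.0) + state["k"] * (input_change / output_norm)

    if state["err"] < thresh:
        return False  # cheap step: reuse x_t + cached (output - input) residual
    state["err"] = 0.0  # error budget spent: pay for a full forward
    return True
```

After every full forward the caller would refresh `prev_out`, re-estimate `k` as the ratio of observed output change to input change, and rebuild the cached residual; the production code additionally keeps separate residuals for the conditional and unconditional CFG passes (`cache_even` / `cache_odd`).
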
56 | 
57 | 
58 | ## 🚀 Main Performance 
59 | 
60 | We validated the performance of EasyCache on leading video generation models and compared it with other state-of-the-art training-free acceleration methods. 
61 | 
62 | ### Comparison with SOTA Methods 
63 | 
64 | Tested on VBench prompts with an NVIDIA A800. 
65 | 
66 | **Performance on HunyuanVideo:** 
67 | | Method | Latency (s) ↓ | Speedup ↑ | PSNR ↑ | SSIM ↑ | LPIPS ↓ | 
68 | |:---:|:---:|:---:|:---:|:---:|:---:| 
69 | | HunyuanVideo (Baseline) | 1124.30 | 1.00x | - | - | - | 
70 | | PAB | 958.23 | 1.17x | 18.58 | 0.7023 | 0.3827 | 
71 | | TeaCache | 674.04 | 1.67x | 23.85 | 0.8185 | 0.1730 | 
72 | | SVG | 802.70 | 1.40x | 26.57 | 0.8596 | 0.1368 | 
73 | | **EasyCache (Ours)** | **507.97** | **2.21x** | **32.66** | **0.9313** | **0.0533** | 
74 | 
75 | **Performance on Wan2.1-1.3B:** 
76 | 
77 | | Method | Latency (s) ↓ | Speedup ↑ | PSNR ↑ | SSIM ↑ | LPIPS ↓ | 
78 | |:---:|:---:|:---:|:---:|:---:|:---:| 
79 | | Wan2.1 (Baseline) | 175.35 | 1.00x | - | - | - | 
80 | | PAB | 102.03 | 1.72x | 18.84 | 0.6484 | 0.3010 | 
81 | | TeaCache | 87.77 | 2.00x | 22.57 | 0.8057 | 0.1277 | 
82 | | **EasyCache (Ours)** | **69.11** | **2.54x** | **25.24** | **0.8337** | **0.0952** | 
83 | 
84 | ### Compatibility with Other Acceleration Techniques 
85 | 
86 | EasyCache is orthogonal to other acceleration techniques, such as the efficient attention mechanism SVG, and can be combined with them for even greater performance gains. 
87 | 
88 | **Combined Performance on HunyuanVideo (720p):** 
89 | *Tested on NVIDIA H20 GPUs.* 
90 | | Method | Latency (s) ↓ | Speedup ↑ | PSNR (dB) ↑ | 
91 | |:---:|:---:|:---:|:---:| 
92 | | Baseline | 6594 | 1.00x | - | 
93 | | SVG | 3474 | 1.90x | 27.56 | 
94 | | SVG (w/ TeaCache) | 2071 | 3.18x | 22.65 | 
95 | | SVG (w/ **Ours**) | **1981** | **3.33x** | **27.26** | 
96 | 
97 | 
98 | ## 🛠️ Usage 
99 | Detailed instructions for each supported model are provided in their respective directories. We are continuously working to extend support to more models. 
100 | 
101 | ### HunyuanVideo 
102 | 1. **Prerequisites**: Set up the environment and download weights from the official HunyuanVideo repository. 
103 | 2. **Copy Files**: Place the EasyCache script files into your local HunyuanVideo project directory. 
104 | 3. **Run**: Execute the provided Python script to run inference with acceleration. 
105 | **For complete instructions, please refer to the [README](./EasyCache4HunyuanVideo/README.md).** 
106 | 
107 | ### Wan2.1 
108 | 1. **Prerequisites**: Set up the environment and download weights from the official Wan2.1 repository. 
109 | 2. **Copy Files**: Place the EasyCache script files into your local Wan2.1 project directory. 
110 | 3. **Run**: Execute the provided Python script to run inference with acceleration; an illustrative command sketch is given below the Acknowledgements. 
111 | **For complete instructions, please refer to the [README](./EasyCache4Wan2.1/README.md).** 
112 | 
113 | ## 🎯 To Do 
114 | 
115 | - [x] Support HunyuanVideo 
116 | - [x] Support Sparse-VideoGen on HunyuanVideo 
117 | - [x] Support Wan2.1 T2V 
118 | - [x] Support Wan2.1 I2V 
119 | - [ ] Support FLUX 
120 | 
121 | ## 🌹 Acknowledgements 
122 | We would like to thank the contributors to the [Wan2.1](https://github.com/Wan-Video/Wan2.1), [HunyuanVideo](https://github.com/Tencent-Hunyuan/HunyuanVideo), [OpenSora](https://github.com/hpcaitech/Open-Sora), and [SVG](https://github.com/svg-project/Sparse-VideoGen) repositories for their open research and exploration. 
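
As a quick companion to the Wan2.1 steps above — and to show how numbers like the PSNR/SSIM/LPIPS values in the tables can be reproduced on your own renders — the sketch below chains the two scripts shipped in `EasyCache4Wan2.1`. The checkpoint directory, prompt, and video filenames are placeholders; the per-model READMEs remain the authoritative reference for the exact commands.

```bash
# Hypothetical Wan2.1 T2V run with EasyCache; flags follow the argparse in
# easycache_generate.py (run from the root of your local Wan2.1 project).
python easycache_generate.py \
    --task t2v-14B \
    --size "1280*720" \
    --ckpt_dir ./model_weights/Wan2.1-T2V-14B \
    --prompt "your text prompt here" \
    --sample_steps 50 \
    --thresh 0.05

# Compare the accelerated render against a baseline render of the same prompt
# (path assumes the tools/ folder was copied alongside the script).
python tools/video_metrics.py \
    --original_video baseline.mp4 \
    --generated_video easycache.mp4 \
    --device cuda \
    --lpips_net alex
```
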
123 | 
124 | ## 📖 Citation 
125 | 
126 | If you find this repository useful in your research, please consider giving a star ⭐ and a citation. 
127 | ```bibtex 
128 | @article{zhou2025easycache, 
129 | title={Less is Enough: Training-Free Video Diffusion Acceleration via Runtime-Adaptive Caching}, 
130 | author={Zhou, Xin and Liang, Dingkang and Chen, Kaijin and Feng, Tianrui and Chen, Xiwu and Lin, Hongkai and Ding, Yikang and Tan, Feiyang and Zhao, Hengshuang and Bai, Xiang}, 
131 | journal={arXiv preprint arXiv:2507.02860}, 
132 | year={2025} 
133 | } 
134 | ``` 
135 | 
--------------------------------------------------------------------------------
/demo/gt/6.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/H-EmbodVis/EasyCache/73b04c13c05dc446b712f91cdf0f367399a752e4/demo/gt/6.gif
--------------------------------------------------------------------------------
/demo/gt/7.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/H-EmbodVis/EasyCache/73b04c13c05dc446b712f91cdf0f367399a752e4/demo/gt/7.gif
--------------------------------------------------------------------------------
/demo/our/6.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/H-EmbodVis/EasyCache/73b04c13c05dc446b712f91cdf0f367399a752e4/demo/our/6.gif
--------------------------------------------------------------------------------
/demo/our/7.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/H-EmbodVis/EasyCache/73b04c13c05dc446b712f91cdf0f367399a752e4/demo/our/7.gif
--------------------------------------------------------------------------------
/demo/pab/6.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/H-EmbodVis/EasyCache/73b04c13c05dc446b712f91cdf0f367399a752e4/demo/pab/6.gif
--------------------------------------------------------------------------------
/demo/pab/7.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/H-EmbodVis/EasyCache/73b04c13c05dc446b712f91cdf0f367399a752e4/demo/pab/7.gif
--------------------------------------------------------------------------------
/demo/teacache/6.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/H-EmbodVis/EasyCache/73b04c13c05dc446b712f91cdf0f367399a752e4/demo/teacache/6.gif
--------------------------------------------------------------------------------
/demo/teacache/7.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/H-EmbodVis/EasyCache/73b04c13c05dc446b712f91cdf0f367399a752e4/demo/teacache/7.gif
--------------------------------------------------------------------------------