├── EasyCache4HunyuanVideo
│   ├── README.md
│   ├── easycache_sample_video.py
│   ├── hyvideo_svg_easycache.py
│   ├── tools
│   │   └── video_metrics.py
│   └── videos
│       ├── baseline_544p.gif
│       ├── baseline_720p.gif
│       ├── easycache_544p.gif
│       └── svg_with_easycache_720p.gif
├── EasyCache4Wan2.1
│   ├── README.md
│   ├── easycache_generate.py
│   ├── example
│   │   └── grogu.png
│   ├── tools
│   │   └── video_metrics.py
│   └── videos
│       ├── i2v_easycache_14b_720p.gif
│       ├── i2v_gt_14b_720p.gif
│       ├── t2v_easycache_14b_720p.gif
│       └── t2v_gt_14b_720p.gif
├── LICENSE
├── README.md
└── demo
    ├── gt
    │   ├── 6.gif
    │   └── 7.gif
    ├── our
    │   ├── 6.gif
    │   └── 7.gif
    ├── pab
    │   ├── 6.gif
    │   └── 7.gif
    └── teacache
        ├── 6.gif
        └── 7.gif
/EasyCache4HunyuanVideo/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Less is Enough: Training-Free Video Diffusion Acceleration via Runtime-Adaptive Caching
3 |
4 | Xin Zhou<sup>1</sup>\*,
5 | Dingkang Liang<sup>1</sup>\*,
6 | Kaijin Chen<sup>1</sup>, Tianrui Feng<sup>1</sup>,
7 | Xiwu Chen<sup>2</sup>, Hongkai Lin<sup>1</sup>,
8 | Yikang Ding<sup>2</sup>, Feiyang Tan<sup>2</sup>,
9 | Hengshuang Zhao<sup>3</sup>,
10 | Xiang Bai<sup>1</sup>†
11 |
12 | <sup>1</sup> Huazhong University of Science and Technology, <sup>2</sup> MEGVII Technology, <sup>3</sup> University of Hong Kong
13 |
14 | (\*) Equal contribution. (†) Corresponding author.
15 |
16 | [Project Page](https://H-EmbodVis.github.io/EasyCache/)
17 | [License](https://github.com/LMD0311/EasyCache/blob/main/LICENSE)
18 |
19 |
20 |
21 | ---
22 |
23 | This document provides the implementation for accelerating the [**HunyuanVideo**](https://github.com/Tencent/HunyuanVideo) model using **EasyCache**.
24 |
25 | ### ✨ Visual Comparison
26 |
27 | EasyCache significantly accelerates inference while maintaining high visual fidelity.
28 |
29 | **Prompt: "A cat walks on the grass, realistic style." (Base Acceleration)**
30 |
31 | | HunyuanVideo (Baseline, 544p, H20) | EasyCache (Ours) |
32 | | :---: | :---: |
33 | | ![](videos/baseline_544p.gif) | ![](videos/easycache_544p.gif) |
34 | | **Inference Time: ~2327s** | **Inference Time: ~1025s (2.3x Speedup)** |
35 |
36 | **Prompt: "A young man at his 20s is sitting on a piece of cloud in the sky, reading a book." (SVG with EasyCache)**
37 |
38 | | HunyuanVideo (Baseline, 720p, H20) | SVG with EasyCache (Ours) |
39 | |:---:|:---:|
40 | | ![](videos/baseline_720p.gif) | ![](videos/svg_with_easycache_720p.gif) |
41 | | **Inference Time: ~6572s** | **Inference Time: ~1773s (3.71x Speedup)** |
42 |
43 |
44 | ---
45 |
46 | ### 🚀 Usage Instructions
47 |
48 | This section provides instructions for two settings: base acceleration with EasyCache alone and combined acceleration using EasyCache with SVG.
49 |
50 | #### **1. Base Acceleration (EasyCache Only)**
51 |
52 | **a. Prerequisites** ⚙️
53 |
54 | Before you begin, please follow the instructions in the [official HunyuanVideo repository](https://github.com/Tencent/HunyuanVideo) to configure the required environment and download the pretrained model weights.
55 |
56 | **b. Copy Files** 📂
57 |
58 | Copy `easycache_sample_video.py` into the root directory of your local `HunyuanVideo` project.
59 |
60 | **c. Run Inference** ▶️
61 |
62 | Execute the following command from the root of the `HunyuanVideo` project to generate a video. To generate videos in 720p resolution, set the `--video-size` argument to `720 1280`. You can also specify your own custom prompts.
63 |
64 | ```bash
65 | python3 easycache_sample_video.py \
66 | --video-size 544 960 \
67 | --video-length 129 \
68 | --infer-steps 50 \
69 | --prompt "A cat walks on the grass, realistic style." \
70 | --flow-reverse \
71 | --use-cpu-offload \
72 | --save-path ./results \
73 | --seed 42
74 | ```
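
The speed/fidelity trade-off is configured inside `easycache_sample_video.py` rather than on the command line. The annotated excerpt below shows the attributes that control it; the comments are ours, and values other than the defaults shown are illustrative rather than tuned recommendations.

```python
# Excerpt from easycache_sample_video.py: EasyCache is enabled by patching the
# transformer class before sampling starts.
hunyuan_video_sampler.pipeline.transformer.__class__.forward = easycache_forward
hunyuan_video_sampler.pipeline.transformer.__class__.num_steps = args.infer_steps
hunyuan_video_sampler.pipeline.transformer.__class__.thresh = 0.025  # error budget: larger => more skipped steps (faster, less faithful)
hunyuan_video_sampler.pipeline.transformer.__class__.ret_steps = 5   # warm-up steps that are always fully computed
```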
75 |
76 | #### **2. Combined Acceleration (SVG with EasyCache)**
77 |
78 | **a. Prerequisites** ⚙️
79 |
80 | Ensure you have set up the environments for both [HunyuanVideo](https://github.com/Tencent/HunyuanVideo) and [SVG](https://github.com/svg-project/Sparse-VideoGen).
81 |
82 | **b. Copy Files** 📂
83 |
84 | Copy `hyvideo_svg_easycache.py` into the root directory of your local `HunyuanVideo` project.
85 |
86 | **c. Run Inference** ▶️
87 |
88 | Execute the following command to generate a 720p video using both SVG and EasyCache for maximum acceleration. You can also specify your own custom prompts.
89 |
90 | ```bash
91 | python3 hyvideo_svg_easycache.py \
92 | --video-size 720 1280 \
93 | --video-length 129 \
94 | --infer-steps 50 \
95 | --prompt "A young man at his 20s is sitting on a piece of cloud in the sky, reading a book." \
96 | --embedded-cfg-scale 6.0 \
97 | --flow-shift 7.0 \
98 | --flow-reverse \
99 | --use-cpu-offload \
100 | --save-path ./results \
101 | --output_path ./results \
102 | --pattern "SVG" \
103 | --num_sampled_rows 64 \
104 | --sparsity 0.2 \
105 | --first_times_fp 0.055 \
106 | --first_layers_fp 0.025 \
107 | --record_attention \
108 | --seed 42
109 | ```
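
Note: unlike `easycache_sample_video.py`, which reads the model location from `args.model_base`, `hyvideo_svg_easycache.py` currently hard-codes the HunyuanVideo checkpoint root. If your weights live elsewhere, edit this line in the script:

```python
# In hyvideo_svg_easycache.py (main): the checkpoint root is hard-coded.
models_root_path = Path("./HunyuanVideo")
# Point this at your local HunyuanVideo weights directory if it differs.
```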
110 |
111 | ### 📊 Evaluating Video Similarity
112 |
113 | We provide a simple script to quickly evaluate the similarity between two videos (e.g., the baseline result and your generated result) using common metrics.
114 |
115 | **Usage**
116 |
117 | ```bash
118 | # install required packages.
119 | pip install lpips numpy tqdm torchmetrics
120 |
121 | python tools/video_metrics.py --original_video video1.mp4 --generated_video video2.mp4
122 | ```
123 |
124 | - `--original_video`: Path to the first video (e.g., the baseline).
125 | - `--generated_video`: Path to the second video (e.g., the one generated with EasyCache).
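
The metrics can also be called from Python. Below is a minimal sketch using the helpers in `tools/video_metrics.py` (run it from the directory that contains `tools/`; the file names are placeholders, and both videos are assumed to share the same resolution and frame count):

```python
import torch
import lpips
from torchmetrics.image import StructuralSimilarityIndexMeasure

from tools.video_metrics import load_video_frames, compute_video_metrics

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ssim_metric = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)
lpips_fn = lpips.LPIPS(net="alex", spatial=True).to(device)

frames_gt = load_video_frames("baseline.mp4")     # uint8 RGB array, shape [N, H, W, 3]
frames_gen = load_video_frames("easycache.mp4")   # must match frames_gt in shape

psnr, ssim, lpips_val = compute_video_metrics(frames_gt, frames_gen, device, ssim_metric, lpips_fn)
print(f"PSNR: {psnr:.2f} dB, SSIM: {ssim:.4f}, LPIPS: {lpips_val:.4f}")
```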
126 |
127 | ## 🌹 Acknowledgements
128 | We would like to thank the contributors to the [HunyuanVideo](https://github.com/Tencent-Hunyuan/HunyuanVideo) and [SVG](https://github.com/svg-project/Sparse-VideoGen) repositories for their open research and exploration.
129 |
130 | ## 📖 Citation
131 |
132 | If you find this repository useful in your research, please consider giving a star ⭐ and a citation.
133 | ```bibtex
134 | @article{zhou2025easycache,
135 | title={Less is Enough: Training-Free Video Diffusion Acceleration via Runtime-Adaptive Caching},
136 | author={Zhou, Xin and Liang, Dingkang and Chen, Kaijin and Feng, Tianrui and Chen, Xiwu and Lin, Hongkai and Ding, Yikang and Tan, Feiyang and Zhao, Hengshuang and Bai, Xiang},
137 | journal={arXiv preprint arXiv:2507.02860},
138 | year={2025}
139 | }
140 | ```
141 |
--------------------------------------------------------------------------------
/EasyCache4HunyuanVideo/easycache_sample_video.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The Tencent Hunyuan Team Authors. All rights reserved.
2 | # Copyright 2025 The Huazhong University of Science and Technology VLRLab Authors. All rights reserved.
3 |
4 | import os
5 | import time
6 | from pathlib import Path
7 | from loguru import logger
8 | from datetime import datetime
9 |
10 | from hyvideo.utils.file_utils import save_videos_grid
11 | from hyvideo.config import parse_args
12 | from hyvideo.inference import HunyuanVideoSampler
13 |
14 | from hyvideo.modules.modulate_layers import modulate
15 | from hyvideo.modules.attenion import attention, parallel_attention, get_cu_seqlens
16 | from typing import Any, List, Tuple, Optional, Union, Dict
17 | import torch
18 | import json
19 | import numpy as np
20 | import portalocker
21 | import json
22 | import random
23 | from tqdm import tqdm
24 | from torch.utils.data import Dataset, DataLoader
25 |
26 |
27 | def easycache_forward(
28 | self,
29 | x: torch.Tensor,
30 | t: torch.Tensor, # Should be in range(0, 1000).
31 | text_states: torch.Tensor = None,
32 | text_mask: torch.Tensor = None, # Now we don't use it.
33 | text_states_2: Optional[torch.Tensor] = None, # Text embedding for modulation.
34 | freqs_cos: Optional[torch.Tensor] = None,
35 | freqs_sin: Optional[torch.Tensor] = None,
36 | guidance: torch.Tensor = None, # Guidance for modulation, should be cfg_scale x 1000.
37 | return_dict: bool = True,
38 | ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
39 | torch.cuda.synchronize()
40 | start_time = time.time()
41 |
42 | out = {}
43 | raw_input = x.clone()
44 | img = x
45 | txt = text_states
46 | _, _, ot, oh, ow = x.shape
47 | tt, th, tw = (
48 | ot // self.patch_size[0],
49 | oh // self.patch_size[1],
50 | ow // self.patch_size[2],
51 | )
52 |
53 | # Prepare modulation vectors.
54 | vec = self.time_in(t)
55 |
56 | # text modulation
57 | vec = vec + self.vector_in(text_states_2)
58 |
59 | # guidance modulation
60 | if self.guidance_embed:
61 | if guidance is None:
62 | raise ValueError(
63 | "Didn't get guidance strength for guidance distilled model."
64 | )
65 |
66 | # our timestep_embedding is merged into guidance_in(TimestepEmbedder)
67 | vec = vec + self.guidance_in(guidance)
68 |
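# EasyCache decision: the first `ret_steps` steps and the final step are always
# fully computed. In between, the measured input change is scaled by the ratio
# `k` (output change / input change from the last two full steps) to estimate
# how much the output would change; while the accumulated estimate stays below
# `thresh`, the transformer blocks are skipped and the cached residual is reused.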
69 | if self.cnt < self.ret_steps or self.cnt >= self.num_steps - 1:
70 | should_calc = True
71 | self.accumulated_error = 0
72 | else:
73 | # Check if previous inputs and outputs exist
74 | if hasattr(self, 'previous_raw_input') and hasattr(self, 'previous_output') \
75 | and self.previous_raw_input is not None and self.previous_output is not None:
76 |
77 | raw_input_change = (raw_input - self.previous_raw_input).abs().mean()
78 |
79 | if hasattr(self, 'k') and self.k is not None:
80 |
81 | output_norm = self.previous_output.abs().mean()
82 | pred_change = self.k * (raw_input_change / output_norm)
83 | self.accumulated_error += pred_change
84 |
85 | if self.accumulated_error < self.thresh:
86 | should_calc = False
87 | else:
88 | should_calc = True
89 | self.accumulated_error = 0
90 | else:
91 | should_calc = True
92 | else:
93 | should_calc = True
94 |
95 | self.previous_raw_input = raw_input.clone() # (1, 16, 33, 68, 120)
96 |
97 | if not should_calc and self.cache is not None:
98 | result = raw_input + self.cache
99 | self.cnt += 1
100 |
101 | if self.cnt >= self.num_steps:
102 | self.cnt = 0
103 |
104 | torch.cuda.synchronize()
105 | end_time = time.time()
106 | self.total_time += (end_time - start_time)
107 |
108 | if return_dict:
109 | out["x"] = result
110 | return out
111 | return result
112 |
113 | img = self.img_in(img)
114 | if self.text_projection == "linear":
115 | txt = self.txt_in(txt)
116 | elif self.text_projection == "single_refiner":
117 | txt = self.txt_in(txt, t, text_mask if self.use_attention_mask else None)
118 | else:
119 | raise NotImplementedError(f"Unsupported text_projection: {self.text_projection}")
120 |
121 | txt_seq_len = txt.shape[1]
122 | img_seq_len = img.shape[1]
123 |
124 | # Compute cu_seqlens and max_seqlen for flash attention
125 | cu_seqlens_q = get_cu_seqlens(text_mask, img_seq_len)
126 | cu_seqlens_kv = cu_seqlens_q
127 | max_seqlen_q = img_seq_len + txt_seq_len
128 | max_seqlen_kv = max_seqlen_q
129 |
130 | freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
131 |
132 | # --------------------- Pass through DiT blocks ------------------------
133 | for _, block in enumerate(self.double_blocks):
134 | double_block_args = [
135 | img,
136 | txt,
137 | vec,
138 | cu_seqlens_q,
139 | cu_seqlens_kv,
140 | max_seqlen_q,
141 | max_seqlen_kv,
142 | freqs_cis,
143 | ]
144 | img, txt = block(*double_block_args)
145 |
146 | # Merge txt and img to pass through single stream blocks.
147 | x = torch.cat((img, txt), 1)
148 | if len(self.single_blocks) > 0:
149 | for _, block in enumerate(self.single_blocks):
150 | single_block_args = [
151 | x,
152 | vec,
153 | txt_seq_len,
154 | cu_seqlens_q,
155 | cu_seqlens_kv,
156 | max_seqlen_q,
157 | max_seqlen_kv,
158 | (freqs_cos, freqs_sin),
159 | ]
160 | x = block(*single_block_args)
161 |
162 | img = x[:, :img_seq_len, ...]
163 |
164 | # ---------------------------- Final layer ------------------------------
165 | img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
166 |
167 | result = self.unpatchify(img, tt, th, tw)
168 |
169 | # store the cache for next step
170 | self.cache = result - raw_input
171 | if hasattr(self, 'previous_output') and self.previous_output is not None:
172 | output_change = (result - self.previous_output).abs().mean()
173 | if hasattr(self, 'prev_prev_raw_input') and self.prev_prev_raw_input is not None:
174 | input_change = (self.previous_raw_input - self.prev_prev_raw_input).abs().mean()
175 | self.k = output_change / input_change
176 |
177 | # update the previous state
178 | self.prev_prev_raw_input = getattr(self, 'previous_raw_input', None)
179 | self.previous_output = result.clone()
180 |
181 | self.cnt += 1
182 | if self.cnt >= self.num_steps:
183 | self.cnt = 0
184 |
185 | torch.cuda.synchronize()
186 | end_time = time.time()
187 | self.total_time += (end_time - start_time)
188 |
189 | if return_dict:
190 | out["x"] = result
191 | return out
192 | return result
193 |
194 |
195 | def main():
196 | args = parse_args()
197 |
198 | print(args)
199 | models_root_path = Path(args.model_base)
200 | if not models_root_path.exists():
201 | raise ValueError(f"`models_root` does not exist: {models_root_path}")
202 |
203 | # Create save folder to save the samples
204 | os.makedirs(args.save_path, exist_ok=True)
205 |
206 | # Load models
207 | hunyuan_video_sampler = HunyuanVideoSampler.from_pretrained(models_root_path, args=args)
208 |
209 | # Get the updated args
210 | args = hunyuan_video_sampler.args
211 |
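# Enable EasyCache by patching the transformer class: swap in easycache_forward
# and set its caching state (step counter, error threshold `thresh`, warm-up
# steps `ret_steps`, calibration ratio `k`, and the DiT timing accumulator).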
212 | hunyuan_video_sampler.pipeline.transformer.__class__.cnt = 0
213 | hunyuan_video_sampler.pipeline.transformer.__class__.num_steps = args.infer_steps
214 | hunyuan_video_sampler.pipeline.transformer.__class__.thresh = 0.025
215 | hunyuan_video_sampler.pipeline.transformer.__class__.forward = easycache_forward
216 | hunyuan_video_sampler.pipeline.transformer.__class__.ret_steps = 5
217 | hunyuan_video_sampler.pipeline.transformer.__class__.k = None
218 | hunyuan_video_sampler.pipeline.transformer.__class__.total_time = 0.0
219 |
220 | # record time cost for DiTs
221 | generation_time = []
222 | time_cost = {
223 | "GPU_Device": torch.cuda.get_device_name(0),
224 | "number_prompt": None,
225 | "avg_cost_time": None
226 | }
227 |
228 | hunyuan_video_sampler.pipeline.transformer.total_time = 0.0
229 | outputs = hunyuan_video_sampler.predict(
230 | prompt=args.prompt,
231 | height=args.video_size[0],
232 | width=args.video_size[1],
233 | video_length=args.video_length,
234 | seed=args.seed,
235 | negative_prompt=args.neg_prompt,
236 | infer_steps=args.infer_steps,
237 | guidance_scale=args.cfg_scale,
238 | num_videos_per_prompt=1,
239 | flow_shift=args.flow_shift,
240 | batch_size=args.batch_size,
241 | embedded_guidance_scale=args.embedded_cfg_scale
242 | )
243 |
244 | generation_time.append(hunyuan_video_sampler.pipeline.transformer.total_time)
245 | samples = outputs['samples']
246 |
247 | # Save samples
248 | if 'LOCAL_RANK' not in os.environ or int(os.environ['LOCAL_RANK']) == 0:
249 | for i, sample in enumerate(samples):
250 | sample = samples[i].unsqueeze(0)
251 | time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%H:%M:%S")
252 | save_path = f"{args.save_path}/{time_flag}_seed{outputs['seeds'][i]}_{outputs['prompts'][i][:100].replace('/', '')}.mp4"
253 | save_videos_grid(sample, save_path, fps=24)
254 | logger.info(f'Sample save to: {save_path}')
255 |
256 | if generation_time:
257 | time_cost["number_prompt"] = len(generation_time)
258 | time_cost["avg_cost_time"] = sum(generation_time) / len(generation_time) if generation_time else 0
259 |
260 | print(
261 | f"GPU_Device: {time_cost['GPU_Device']}, number_prompt: {time_cost['number_prompt']}, avg_cost_time: {time_cost['avg_cost_time']}")
262 |
263 | try:
264 | with open(f"{args.save_path}/time_cost.json", "a+") as f:
265 | portalocker.lock(f, portalocker.LOCK_EX)
266 | f.seek(0)
267 | try:
268 | existing_data = json.load(f)
269 | except (json.JSONDecodeError, FileNotFoundError):
270 | existing_data = []
271 |
272 | existing_data.append(time_cost)
273 | f.seek(0)
274 | f.truncate()
275 | json.dump(existing_data, f, indent=4)
276 | except Exception as e:
277 | print(f"Error writing time cost to file: {e}")
278 |
279 |
280 | if __name__ == "__main__":
281 | main()
282 |
--------------------------------------------------------------------------------
/EasyCache4HunyuanVideo/hyvideo_svg_easycache.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The Tencent Hunyuan Team Authors. All rights reserved.
2 | # Copyright 2025 The SVG Team Authors. All rights reserved.
3 | # Copyright 2025 The Huazhong University of Science and Technology VLRLab Authors. All rights reserved.
4 |
5 | import os
6 | import time
7 | import math
8 | import json
9 | from pathlib import Path
10 | from loguru import logger
11 | from datetime import datetime
12 |
13 | import torch
14 | from svg.models.hyvideo.utils.file_utils import save_videos_grid
15 | from svg.models.hyvideo.config import parse_args
16 | from svg.models.hyvideo.inference import HunyuanVideoSampler
17 | from torch.utils.data import Dataset, DataLoader
18 | from typing import Any, List, Tuple, Optional, Union, Dict
19 | import portalocker
20 | from tqdm import tqdm
21 |
22 |
23 | def get_cu_seqlens(text_mask, img_len):
24 | """Calculate cu_seqlens_q, cu_seqlens_kv using text_mask and img_len
25 |
26 | Args:
27 | text_mask (torch.Tensor): the mask of text
28 | img_len (int): the length of image
29 |
30 | Returns:
31 | torch.Tensor: the calculated cu_seqlens for flash attention
32 | """
33 | batch_size = text_mask.shape[0]
34 | text_len = text_mask.sum(dim=1)
35 | max_len = text_mask.shape[1] + img_len
36 |
37 | cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda")
38 |
39 | for i in range(batch_size):
40 | s = text_len[i] + img_len
41 | s1 = i * max_len + s
42 | s2 = (i + 1) * max_len
43 | cu_seqlens[2 * i + 1] = s1
44 | cu_seqlens[2 * i + 2] = s2
45 |
46 | return cu_seqlens
47 |
48 |
49 | @torch.compile()
50 | def easycache_forward(
51 | self,
52 | x: torch.Tensor,
53 | t: torch.Tensor, # Should be in range(0, 1000).
54 | text_states: torch.Tensor = None,
55 | text_mask: torch.Tensor = None, # Now we don't use it.
56 | text_states_2: Optional[torch.Tensor] = None, # Text embedding for modulation.
57 | freqs_cos: Optional[torch.Tensor] = None,
58 | freqs_sin: Optional[torch.Tensor] = None,
59 | guidance: torch.Tensor = None, # Guidance for modulation, should be cfg_scale x 1000.
60 | return_dict: bool = True,
61 | ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
62 | torch.cuda.synchronize()
63 | start_time = time.time()
64 |
65 | out = {}
66 | raw_input = x.clone()
67 | img = x
68 | txt = text_states
69 | _, _, ot, oh, ow = x.shape
70 | tt, th, tw = (
71 | ot // self.patch_size[0],
72 | oh // self.patch_size[1],
73 | ow // self.patch_size[2],
74 | )
75 |
76 | # Prepare modulation vectors.
77 | vec = self.time_in(t)
78 |
79 | # text modulation
80 | vec = vec + self.vector_in(text_states_2)
81 |
82 | # guidance modulation
83 | if self.guidance_embed:
84 | if guidance is None:
85 | raise ValueError(
86 | "Didn't get guidance strength for guidance distilled model."
87 | )
88 |
89 | # our timestep_embedding is merged into guidance_in(TimestepEmbedder)
90 | vec = vec + self.guidance_in(guidance)
91 |
92 | if self.cnt < self.ret_steps or self.cnt >= self.num_steps - 1:
93 | should_calc = True
94 | self.accumulated_error = 0
95 | else:
96 | # Check if previous inputs and outputs exist
97 | if hasattr(self, 'previous_raw_input') and hasattr(self, 'previous_output') \
98 | and self.previous_raw_input is not None and self.previous_output is not None:
99 |
100 | raw_input_change = (raw_input - self.previous_raw_input).abs().mean()
101 |
102 | if hasattr(self, 'k') and self.k is not None:
103 |
104 | output_norm = self.previous_output.abs().mean()
105 | pred_change = self.k * (raw_input_change / output_norm)
106 | self.accumulated_error += pred_change
107 |
108 | if self.accumulated_error < self.thresh:
109 | should_calc = False
110 | else:
111 | should_calc = True
112 | self.accumulated_error = 0
113 | else:
114 | should_calc = True
115 | else:
116 | should_calc = True
117 |
118 | self.previous_raw_input = raw_input.clone() # (1, 16, 33, 68, 120)
119 |
120 | if not should_calc and self.cache is not None:
121 | result = raw_input + self.cache
122 | self.cnt += 1
123 |
124 | if self.cnt >= self.num_steps:
125 | self.cnt = 0
126 |
127 | torch.cuda.synchronize()
128 | end_time = time.time()
129 | self.total_time += (end_time - start_time)
130 |
131 | if return_dict:
132 | out["x"] = result
133 | return out
134 | return result
135 |
136 | img = self.img_in(img)
137 | if self.text_projection == "linear":
138 | txt = self.txt_in(txt)
139 | elif self.text_projection == "single_refiner":
140 | txt = self.txt_in(txt, t, text_mask if self.use_attention_mask else None)
141 | else:
142 | raise NotImplementedError(f"Unsupported text_projection: {self.text_projection}")
143 |
144 | txt_seq_len = txt.shape[1]
145 | img_seq_len = img.shape[1]
146 |
147 | # Compute cu_seqlens and max_seqlen for flash attention
148 | cu_seqlens_q = get_cu_seqlens(text_mask, img_seq_len)
149 | cu_seqlens_kv = cu_seqlens_q
150 | max_seqlen_q = img_seq_len + txt_seq_len
151 | max_seqlen_kv = max_seqlen_q
152 |
153 | freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
154 |
155 | # --------------------- Pass through DiT blocks ------------------------
156 | for _, block in enumerate(self.double_blocks):
157 | double_block_args = [
158 | img,
159 | txt,
160 | vec,
161 | cu_seqlens_q,
162 | cu_seqlens_kv,
163 | max_seqlen_q,
164 | max_seqlen_kv,
165 | freqs_cis,
166 | t,
167 | ]
168 | img, txt = block(*double_block_args)
169 |
170 | # Merge txt and img to pass through single stream blocks.
171 | x = torch.cat((img, txt), 1)
172 | if len(self.single_blocks) > 0:
173 | for _, block in enumerate(self.single_blocks):
174 | single_block_args = [
175 | x,
176 | vec,
177 | txt_seq_len,
178 | cu_seqlens_q,
179 | cu_seqlens_kv,
180 | max_seqlen_q,
181 | max_seqlen_kv,
182 | (freqs_cos, freqs_sin),
183 | t,
184 | ]
185 | x = block(*single_block_args)
186 | img = x[:, :img_seq_len, ...]
187 |
188 | # ---------------------------- Final layer ------------------------------
189 | img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
190 |
191 | result = self.unpatchify(img, tt, th, tw)
192 |
193 | # store the cache for next step
194 | self.cache = result - raw_input
195 | if hasattr(self, 'previous_output') and self.previous_output is not None:
196 | output_change = (result - self.previous_output).abs().mean()
197 | if hasattr(self, 'prev_prev_raw_input') and self.prev_prev_raw_input is not None:
198 | input_change = (self.previous_raw_input - self.prev_prev_raw_input).abs().mean()
199 | self.k = output_change / input_change
200 |
201 | # update the previous state
202 | self.prev_prev_raw_input = getattr(self, 'previous_raw_input', None)
203 | self.previous_output = result.clone()
204 |
205 | self.cnt += 1
206 | if self.cnt >= self.num_steps:
207 | self.cnt = 0
208 |
209 | torch.cuda.synchronize()
210 | end_time = time.time()
211 | self.total_time += (end_time - start_time)
212 |
213 | if return_dict:
214 | out["x"] = result
215 | return out
216 | return result
217 |
218 |
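# Map the SVG sparsity target to an attention-window width measured in frames:
# the always-dense text-context rows/columns are factored out of the budget, and
# the banded-window width is then solved as seq_len * (1 - sqrt(1 - sparsity)).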
219 | def sparsity_to_width(sparsity, context_length, num_frame, frame_size):
220 | seq_len = context_length + num_frame * frame_size
221 | total_elements = seq_len ** 2
222 |
223 | sparsity = (sparsity * total_elements - 2 * seq_len * context_length) / total_elements
224 |
225 | width = seq_len * (1 - math.sqrt(1 - sparsity))
226 | width_frame = width / frame_size
227 |
228 | return width_frame
229 |
230 |
231 | def main():
232 | args = parse_args()
233 | print(args)
234 | models_root_path = Path("./HunyuanVideo")
235 | if not models_root_path.exists():
236 | raise ValueError(f"`models_root` does not exist: {models_root_path}")
237 |
238 | # Create save folder to save the samples
239 | os.makedirs(args.save_path, exist_ok=True)
240 |
241 | # Load models
242 | hunyuan_video_sampler = HunyuanVideoSampler.from_pretrained(models_root_path, args=args)
243 |
244 | # Get the updated args
245 | args = hunyuan_video_sampler.args
246 |
247 | # Sparsity Related
248 | transformer = hunyuan_video_sampler.pipeline.transformer
249 | for _, block in enumerate(transformer.double_blocks):
250 | block.sparse_args = args
251 | for _, block in enumerate(transformer.single_blocks):
252 | block.sparse_args = args
253 | transformer.sparse_args = args
254 |
255 | print(
256 | f"Memory: {torch.cuda.memory_allocated() // 1024 ** 2} / {torch.cuda.max_memory_allocated() // 1024 ** 2} MB before Inference")
257 |
258 | cfg_size, num_head, head_dim, dtype, device = 1, 24, 128, torch.bfloat16, "cuda"
259 | context_length, num_frame, frame_size = 256, 33, 3600
260 |
261 | # Calculation
262 | spatial_width = temporal_width = sparsity_to_width(args.sparsity, context_length, num_frame, frame_size)
263 |
264 | print(f"Spatial_width: {spatial_width}, Temporal_width: {temporal_width}. Sparsity: {args.sparsity}")
265 |
266 | save_path = args.output_path
267 | if args.pattern == "SVG":
268 | masks = ["spatial", "temporal"]
269 |
270 | def get_attention_mask(mask_name):
271 |
272 | context_length = 256
273 | num_frame = 33
274 | frame_size = 3600
275 | attention_mask = torch.zeros(
276 | (context_length + num_frame * frame_size, context_length + num_frame * frame_size), device="cpu")
277 |
278 | # TODO: fix hard coded mask
279 | if mask_name == "spatial":
280 | pixel_attn_mask = torch.zeros_like(attention_mask[:-context_length, :-context_length], dtype=torch.bool,
281 | device="cpu")
282 | block_size, block_thres = 128, frame_size * 1.5
283 | num_block = math.ceil(num_frame * frame_size / block_size)
284 | for i in range(num_block):
285 | for j in range(num_block):
286 | if abs(i - j) < block_thres // block_size:
287 | pixel_attn_mask[i * block_size: (i + 1) * block_size,
288 | j * block_size: (j + 1) * block_size] = 1
289 | attention_mask[:-context_length, :-context_length] = pixel_attn_mask
290 |
291 | attention_mask[-context_length:, :] = 1
292 | attention_mask[:, -context_length:] = 1
293 |
294 | else:
295 | pixel_attn_mask = torch.zeros_like(attention_mask[:-context_length, :-context_length], dtype=torch.bool,
296 | device=device)
297 |
298 | block_size, block_thres = 128, frame_size * 1.5
299 | num_block = math.ceil(num_frame * frame_size / block_size)
300 | for i in range(num_block):
301 | for j in range(num_block):
302 | if abs(i - j) < block_thres // block_size:
303 | pixel_attn_mask[i * block_size: (i + 1) * block_size,
304 | j * block_size: (j + 1) * block_size] = 1
305 |
306 | pixel_attn_mask = pixel_attn_mask.reshape(frame_size, num_frame, frame_size, num_frame).permute(1, 0, 3,
307 | 2).reshape(
308 | frame_size * num_frame, frame_size * num_frame)
309 | attention_mask[:-context_length, :-context_length] = pixel_attn_mask
310 |
311 | attention_mask[-context_length:, :] = 1
312 | attention_mask[:, -context_length:] = 1
313 | attention_mask = attention_mask[:args.sample_mse_max_row].cuda()
314 | return attention_mask
315 |
316 | hunyuan_video_sampler.pipeline.transformer.__class__.cnt = 0
317 | hunyuan_video_sampler.pipeline.transformer.__class__.num_steps = args.infer_steps
318 | hunyuan_video_sampler.pipeline.transformer.__class__.thresh = 0.025
319 | hunyuan_video_sampler.pipeline.transformer.__class__.forward = easycache_forward
320 | hunyuan_video_sampler.pipeline.transformer.__class__.ret_steps = 5
321 | hunyuan_video_sampler.pipeline.transformer.__class__.k = None
322 | hunyuan_video_sampler.pipeline.transformer.__class__.total_time = 0.0
323 |
324 | if args.pattern == "SVG":
325 | from svg.models.hyvideo.modules.attenion import Hunyuan_SparseAttn, prepare_flexattention
326 | from svg.models.hyvideo.modules.custom_models import replace_sparse_forward
327 |
328 | AttnModule = Hunyuan_SparseAttn
329 | AttnModule.num_sampled_rows = args.num_sampled_rows
330 | AttnModule.sample_mse_max_row = args.sample_mse_max_row
331 | AttnModule.attention_masks = [get_attention_mask(mask_name) for mask_name in masks]
332 | AttnModule.first_layers_fp = args.first_layers_fp
333 | AttnModule.first_times_fp = args.first_times_fp
334 |
335 | generation_time = []
336 | time_cost = {
337 | "GPU_Device": torch.cuda.get_device_name(0),
338 | "number_prompt": None,
339 | "avg_cost_time": None
340 | }
341 | # Start sampling
342 | if args.pattern == "SVG":
343 | # We need to get the prompt length in advance, since HunyuanVideo handles the attention mask in a special way
344 | prompt_mask = hunyuan_video_sampler.get_prompt_mask(
345 | prompt=args.prompt,
346 | height=args.video_size[0],
347 | width=args.video_size[1],
348 | video_length=args.video_length,
349 | negative_prompt=args.neg_prompt,
350 | infer_steps=args.infer_steps,
351 | guidance_scale=args.cfg_scale,
352 | num_videos_per_prompt=args.num_videos,
353 | embedded_guidance_scale=args.embedded_cfg_scale
354 | )
355 | prompt_len = prompt_mask.sum()
356 |
357 | block_mask = prepare_flexattention(
358 | cfg_size, num_head, head_dim, dtype, device,
359 | context_length, prompt_len, num_frame, frame_size,
360 | diag_width=spatial_width, multiplier=temporal_width
361 | )
362 | AttnModule.block_mask = block_mask
363 | replace_sparse_forward()
364 |
365 | hunyuan_video_sampler.pipeline.transformer.total_time = 0.0
366 | outputs = hunyuan_video_sampler.predict(
367 | prompt=args.prompt,
368 | height=args.video_size[0],
369 | width=args.video_size[1],
370 | video_length=args.video_length,
371 | seed=args.seed,
372 | negative_prompt=args.neg_prompt,
373 | infer_steps=args.infer_steps,
374 | guidance_scale=args.cfg_scale,
375 | num_videos_per_prompt=1,
376 | flow_shift=args.flow_shift,
377 | batch_size=args.batch_size,
378 | embedded_guidance_scale=args.embedded_cfg_scale
379 | )
380 | generation_time.append(hunyuan_video_sampler.pipeline.transformer.total_time)
381 | samples = outputs['samples']
382 |
383 | # Save samples
384 | if 'LOCAL_RANK' not in os.environ or int(os.environ['LOCAL_RANK']) == 0:
385 | for i, sample in enumerate(samples):
386 | sample = samples[i].unsqueeze(0)
387 | time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%H:%M:%S")
388 | save_path = f"{args.save_path}/{time_flag}_seed{outputs['seeds'][i]}_{outputs['prompts'][i][:100].replace('/', '')}.mp4"
389 | save_videos_grid(sample, save_path, fps=24)
390 | logger.info(f'Sample save to: {save_path}')
391 |
392 | if generation_time:
393 | time_cost["number_prompt"] = len(generation_time)
394 | time_cost["avg_cost_time"] = sum(generation_time) / len(generation_time) if generation_time else 0
395 |
396 | print(
397 | f"GPU_Device: {time_cost['GPU_Device']}, number_prompt: {time_cost['number_prompt']}, avg_cost_time: {time_cost['avg_cost_time']}")
398 |
399 | try:
400 | with open(f"{args.save_path}/time_cost.json", "a+") as f:
401 | portalocker.lock(f, portalocker.LOCK_EX)
402 | f.seek(0)
403 | try:
404 | existing_data = json.load(f)
405 | except (json.JSONDecodeError, FileNotFoundError):
406 | existing_data = []
407 |
408 | existing_data.append(time_cost)
409 | f.seek(0)
410 | f.truncate()
411 | json.dump(existing_data, f, indent=4)
412 | except Exception as e:
413 | print(f"Error writing time cost to file: {e}")
414 |
415 |
416 | if __name__ == "__main__":
417 | main()
418 |
--------------------------------------------------------------------------------
/EasyCache4HunyuanVideo/tools/video_metrics.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import argparse
4 | import torch
5 | import lpips
6 | import numpy as np
7 | from tqdm import tqdm
8 | from torchmetrics.image import StructuralSimilarityIndexMeasure
9 |
10 | def load_video_frames(path, resize_to=None):
11 | """
12 | Load all frames from a video file and return them as a single uint8 RGB array of shape [N, H, W, 3].
13 | Optionally resize each frame to `resize_to` (w, h).
14 | """
15 |
16 | cap = cv2.VideoCapture(path)
17 | frames = []
18 | while True:
19 | ret, img = cap.read()
20 | if not ret:
21 | break
22 | if resize_to is not None:
23 | img = cv2.resize(img, resize_to)
24 | frames.append(np.expand_dims(cv2.cvtColor(img, cv2.COLOR_BGR2RGB), axis=0))
25 | cap.release()
26 | return np.concatenate(frames)
27 |
28 |
29 | def compute_video_metrics(frames_gt, frames_gen,
30 | device, ssim_metric, lpips_fn):
31 | """
32 | Compute PSNR, SSIM, and LPIPS for two uint8 RGB frame arrays of shape [N, H, W, 3].
33 | All computations on `device`.
34 | Returns (psnr, ssim, lpips) scalars.
35 | """
36 | # both arrays are assumed to have the same frame count and resolution
37 | # convert to tensors [N,3,H,W], normalize to [0,1]
38 | gt_t = torch.from_numpy(frames_gt).float().to(device).permute(0, 3, 1, 2).div_(255).contiguous()
39 |
40 | gen_t = torch.from_numpy(frames_gen).float().to(device).permute(0, 3, 1, 2).div_(255).contiguous()
41 |
42 | # PSNR (data_range=1.0): -10 * log10(mse)
43 | mse = torch.mean((gt_t - gen_t) ** 2)
44 | psnr = -10.0 * torch.log10(mse)
45 |
46 | # SSIM: returns average over batch
47 | ssim_val = ssim_metric(gen_t, gt_t)
48 |
49 | # LPIPS: expects [-1,1]
50 | with torch.no_grad():
51 | lpips_val = lpips_fn(gt_t * 2.0 - 1.0, gen_t * 2.0 - 1.0).mean()
52 |
53 | return psnr.item(), ssim_val.item(), lpips_val.item()
54 |
55 |
56 | def main():
57 | parser = argparse.ArgumentParser(
58 | description="Compute PSNR/SSIM/LPIPS on GPU for a pair of .mp4 videos"
59 | )
60 | parser.add_argument("--original_video", required=True,
61 | help="path to the ground-truth (baseline) .mp4 video")
62 | parser.add_argument("--generated_video", required=True,
63 | help="path to the generated .mp4 video")
64 | parser.add_argument("--device", default="cuda",
65 | help="Torch device, e.g. 'cuda' or 'cpu'")
66 | parser.add_argument("--lpips_net", default="alex", choices=["alex", "vgg"],
67 | help="Backbone for LPIPS")
68 | args = parser.parse_args()
69 |
70 | device = torch.device(args.device if torch.cuda.is_available() or args.device == "cpu" else "cpu")
71 | # instantiate metrics on device
72 | ssim_metric = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)
73 | lpips_fn = lpips.LPIPS(net=args.lpips_net, spatial=True).to(device)
74 |
75 | # paths to the two videos being compared
76 | gt_files = args.original_video
77 | gen_set = args.generated_video
78 |
79 | psnrs, ssims, lpips_vals = [], [], []
80 | for fname in tqdm([gt_files], desc="Videos"):
81 | path_gt = gt_files
82 | path_gen = gen_set
83 |
84 | # load frames from both videos (no resizing is applied; dimensions must already match)
85 | frames_gt = load_video_frames(path_gt)
86 | frames_gen = load_video_frames(path_gen)
87 |
88 | res = compute_video_metrics(frames_gt, frames_gen,
89 | device, ssim_metric, lpips_fn)
90 | if res is None:
91 | continue
92 | p, s, l = res
93 | psnrs.append(p);
94 | ssims.append(s);
95 | lpips_vals.append(l)
96 |
97 | if not psnrs:
98 | print("No valid videos processed.")
99 | return
100 |
101 | print("\n=== Overall Averages ===")
102 | print(f"Average PSNR : {np.mean(psnrs):.2f} dB")
103 | print(f"Average SSIM : {np.mean(ssims):.4f}")
104 | print(f"Average LPIPS: {np.mean(lpips_vals):.4f}")
105 |
106 |
107 | if __name__ == "__main__":
108 | main()
109 |
--------------------------------------------------------------------------------
/EasyCache4HunyuanVideo/videos/baseline_544p.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/H-EmbodVis/EasyCache/73b04c13c05dc446b712f91cdf0f367399a752e4/EasyCache4HunyuanVideo/videos/baseline_544p.gif
--------------------------------------------------------------------------------
/EasyCache4HunyuanVideo/videos/baseline_720p.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/H-EmbodVis/EasyCache/73b04c13c05dc446b712f91cdf0f367399a752e4/EasyCache4HunyuanVideo/videos/baseline_720p.gif
--------------------------------------------------------------------------------
/EasyCache4HunyuanVideo/videos/easycache_544p.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/H-EmbodVis/EasyCache/73b04c13c05dc446b712f91cdf0f367399a752e4/EasyCache4HunyuanVideo/videos/easycache_544p.gif
--------------------------------------------------------------------------------
/EasyCache4HunyuanVideo/videos/svg_with_easycache_720p.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/H-EmbodVis/EasyCache/73b04c13c05dc446b712f91cdf0f367399a752e4/EasyCache4HunyuanVideo/videos/svg_with_easycache_720p.gif
--------------------------------------------------------------------------------
/EasyCache4Wan2.1/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Less is Enough: Training-Free Video Diffusion Acceleration via Runtime-Adaptive Caching
3 |
4 | Xin Zhou<sup>1</sup>\*,
5 | Dingkang Liang<sup>1</sup>\*,
6 | Kaijin Chen<sup>1</sup>, Tianrui Feng<sup>1</sup>,
7 | Xiwu Chen<sup>2</sup>, Hongkai Lin<sup>1</sup>,
8 | Yikang Ding<sup>2</sup>, Feiyang Tan<sup>2</sup>,
9 | Hengshuang Zhao<sup>3</sup>,
10 | Xiang Bai<sup>1</sup>†
11 |
12 | <sup>1</sup> Huazhong University of Science and Technology, <sup>2</sup> MEGVII Technology, <sup>3</sup> University of Hong Kong
13 |
14 | (\*) Equal contribution. (†) Corresponding author.
15 |
16 | [Project Page](https://H-EmbodVis.github.io/EasyCache/)
17 | [License](https://github.com/LMD0311/EasyCache/blob/main/LICENSE)
18 |
19 |
20 |
21 | ---
22 |
23 | This document provides the implementation for accelerating the [**Wan2.1**](https://github.com/Wan-Video/Wan2.1) model using **EasyCache**.
24 |
25 | ### ✨ Visual Comparison
26 |
27 | EasyCache significantly accelerates inference while maintaining high visual fidelity.
28 |
29 | **Prompt: "A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about."**
30 |
31 | | Wan2.1-14B (Baseline, 720p, H20) | EasyCache (Ours, 720p, H20) |
32 | | :---: | :---: |
33 | | ![](videos/t2v_gt_14b_720p.gif) | ![](videos/t2v_easycache_14b_720p.gif) |
34 | | **Inference Time: ~6862s** | **Inference Time: ~2884s (~2.4x Speedup)** |
35 |
36 |
37 | **Prompt: "A cute green alien child with large ears, wearing a brown robe, sits on a chair and eats a blue cookie at a table, with crumbs scattered on the robe, in a cozy indoor setting."**
38 |
39 | | Wan2.1-14B I2V (Baseline, 720p, H20) | EasyCache (Ours, 720p, H20) |
40 | | :---: | :---: |
41 | | ![](videos/i2v_gt_14b_720p.gif) | ![](videos/i2v_easycache_14b_720p.gif) |
42 | | **Inference Time: ~5302s** | **Inference Time: ~2397s (~2.2x Speedup)** |
43 |
44 | ---
45 |
46 | ### 🚀 Usage Instructions
47 |
48 | #### **1. EasyCache Acceleration for Wan2.1 T2V**
49 |
50 | **a. Prerequisites** ⚙️
51 |
52 | Before you begin, please follow the instructions in the [official Wan2.1 repository](https://github.com/Wan-Video/Wan2.1) to configure the required environment and download the pretrained model weights.
53 |
54 | **b. Copy Files** 📂
55 |
56 | Copy `easycache_generate.py` into the root directory of your local `Wan2.1` project.
57 |
58 | **c. Run Inference** ▶️
59 |
60 | Execute the following command from the root of the `Wan2.1` project to generate a video. To generate videos in 720p resolution, set the `--size` argument to `1280*720`. You can also specify your own custom prompts.
61 |
62 | ```bash
63 | python easycache_generate.py \
64 | --task t2v-14B \
65 | --size "1280*720" \
66 | --ckpt_dir ./Wan2.1-T2V-14B \
67 | --prompt "A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about." \
68 | --base_seed 0
69 | ```
70 | #### **2. EasyCache Acceleration for Wan2.1 I2V**
71 | Execute the following command from the root of the `Wan2.1` project to generate a video. To generate videos in 480p resolution, set the `--size` argument to `832*480` and set `--ckpt_dir` to `./Wan2.1-I2V-14B-480P`. You can also specify your own custom prompts and images.
72 |
73 | ```bash
74 | python easycache_generate.py \
75 | --task i2v-14B \
76 | --size "1280*720" \
77 | --ckpt_dir ./Wan2.1-I2V-14B-720P \
78 | --image examples/grogu.png \
79 | --prompt "A cute green alien child with large ears, wearing a brown robe, sits on a chair and eats a blue cookie at a table, with crumbs scattered on the robe, in a cozy indoor setting." \
80 | --base_seed 0
81 | ```
82 |
83 |
84 | ### 📊 Evaluating Video Similarity
85 |
86 | We provide a simple script to quickly evaluate the similarity between two videos (e.g., the baseline result and your generated result) using common metrics.
87 |
88 | **Usage**
89 |
90 | ```bash
91 | # install required packages.
92 | pip install lpips numpy tqdm torchmetrics
93 |
94 | python tools/video_metrics.py --original_video video1.mp4 --generated_video video2.mp4
95 | ```
96 |
97 | - `--original_video`: Path to the first video (e.g., the baseline).
98 | - `--generated_video`: Path to the second video (e.g., the one generated with EasyCache).
99 |
100 | ## 🌹 Acknowledgements
101 | We would like to thank the contributors to the [Wan2.1](https://github.com/Wan-Video/Wan2.1) repository for their open research and exploration.
102 |
103 | ## 📖 Citation
104 |
105 | If you find this repository useful in your research, please consider giving a star ⭐ and a citation.
106 | ```bibtex
107 | @article{zhou2025easycache,
108 | title={Less is Enough: Training-Free Video Diffusion Acceleration via Runtime-Adaptive Caching},
109 | author={Zhou, Xin and Liang, Dingkang and Chen, Kaijin and Feng, Tianrui and Chen, Xiwu and Lin, Hongkai and Ding, Yikang and Tan, Feiyang and Zhao, Hengshuang and Bai, Xiang},
110 | journal={arXiv preprint arXiv:2507.02860},
111 | year={2025}
112 | }
113 | ```
114 |
--------------------------------------------------------------------------------
/EasyCache4Wan2.1/easycache_generate.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2 | # Copyright 2025 The Huazhong University of Science and Technology VLRLab Authors. All rights reserved.
3 |
4 | import argparse
5 | from datetime import datetime
6 | import logging
7 | import os
8 | import sys
9 | import warnings
10 | import json
11 | from time import time
12 | import portalocker
13 |
14 | warnings.filterwarnings('ignore')
15 |
16 | import torch, random
17 | import torch.distributed as dist
18 | from torch.utils.data import DataLoader, Dataset
19 | from PIL import Image
20 |
21 | import wan
22 | from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS, SUPPORTED_SIZES
23 | from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
24 | from wan.utils.utils import cache_video, cache_image, str2bool
25 |
26 | import gc
27 | from contextlib import contextmanager
28 | import torchvision.transforms.functional as TF
29 | import torch.cuda.amp as amp
30 | import numpy as np
31 | import math
32 | from wan.modules.model import sinusoidal_embedding_1d
33 | from wan.utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
34 | get_sampling_sigmas, retrieve_timesteps)
35 | from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
36 | from tqdm import tqdm
37 |
38 | EXAMPLE_PROMPT = {
39 | "t2v-1.3B": {
40 | "prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
41 | },
42 | "t2v-14B": {
43 | "prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
44 | },
45 | "t2i-14B": {
46 | "prompt": "一个朴素端庄的美人",
47 | },
48 | "i2v-14B": {
49 | "prompt":
50 | "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.",
51 | "image":
52 | "examples/i2v_input.JPG",
53 | },
54 | }
55 |
56 |
57 | def t2v_generate(self,
58 | input_prompt,
59 | size=(1280, 720),
60 | frame_num=81,
61 | shift=5.0,
62 | sample_solver='unipc',
63 | sampling_steps=50,
64 | guide_scale=5.0,
65 | n_prompt="",
66 | seed=-1,
67 | offload_model=True):
68 | r"""
69 | Generates video frames from text prompt using diffusion process.
70 |
71 | Args:
72 | input_prompt (`str`):
73 | Text prompt for content generation
74 | size (tuple[`int`], *optional*, defaults to (1280, 720)):
75 | Controls video resolution, (width,height).
76 | frame_num (`int`, *optional*, defaults to 81):
77 | How many frames to sample from a video. The number should be 4n+1
78 | shift (`float`, *optional*, defaults to 5.0):
79 | Noise schedule shift parameter. Affects temporal dynamics
80 | sample_solver (`str`, *optional*, defaults to 'unipc'):
81 | Solver used to sample the video.
82 | sampling_steps (`int`, *optional*, defaults to 50):
83 | Number of diffusion sampling steps. Higher values improve quality but slow generation
84 | guide_scale (`float`, *optional*, defaults to 5.0):
85 | Classifier-free guidance scale. Controls prompt adherence vs. creativity
86 | n_prompt (`str`, *optional*, defaults to ""):
87 | Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
88 | seed (`int`, *optional*, defaults to -1):
89 | Random seed for noise generation. If -1, use random seed.
90 | offload_model (`bool`, *optional*, defaults to True):
91 | If True, offloads models to CPU during generation to save VRAM
92 |
93 | Returns:
94 | torch.Tensor:
95 | Generated video frames tensor. Dimensions: (C, N, H, W) where:
96 | - C: Color channels (3 for RGB)
97 | - N: Number of frames (81)
98 | - H: Frame height (from size)
99 | - W: Frame width (from size)
100 | """
101 | # preprocess
102 | F = frame_num
103 | target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
104 | size[1] // self.vae_stride[1],
105 | size[0] // self.vae_stride[2])
106 |
107 | seq_len = math.ceil((target_shape[2] * target_shape[3]) /
108 | (self.patch_size[1] * self.patch_size[2]) *
109 | target_shape[1] / self.sp_size) * self.sp_size
110 |
111 | if n_prompt == "":
112 | n_prompt = self.sample_neg_prompt
113 | seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
114 | seed_g = torch.Generator(device=self.device)
115 | seed_g.manual_seed(seed)
116 |
117 | if not self.t5_cpu:
118 | self.text_encoder.model.to(self.device)
119 | context = self.text_encoder([input_prompt], self.device)
120 | context_null = self.text_encoder([n_prompt], self.device)
121 | if offload_model:
122 | self.text_encoder.model.cpu()
123 | else:
124 | context = self.text_encoder([input_prompt], torch.device('cpu'))
125 | context_null = self.text_encoder([n_prompt], torch.device('cpu'))
126 | context = [t.to(self.device) for t in context]
127 | context_null = [t.to(self.device) for t in context_null]
128 |
129 | noise = [
130 | torch.randn(
131 | target_shape[0],
132 | target_shape[1],
133 | target_shape[2],
134 | target_shape[3],
135 | dtype=torch.float32,
136 | device=self.device,
137 | generator=seed_g)
138 | ]
139 |
140 | @contextmanager
141 | def noop_no_sync():
142 | yield
143 |
144 | no_sync = getattr(self.model, 'no_sync', noop_no_sync)
145 |
146 | # evaluation mode
147 | with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync():
148 |
149 | if sample_solver == 'unipc':
150 | sample_scheduler = FlowUniPCMultistepScheduler(
151 | num_train_timesteps=self.num_train_timesteps,
152 | shift=1,
153 | use_dynamic_shifting=False)
154 | sample_scheduler.set_timesteps(
155 | sampling_steps, device=self.device, shift=shift)
156 | timesteps = sample_scheduler.timesteps
157 | elif sample_solver == 'dpm++':
158 | sample_scheduler = FlowDPMSolverMultistepScheduler(
159 | num_train_timesteps=self.num_train_timesteps,
160 | shift=1,
161 | use_dynamic_shifting=False)
162 | sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
163 | timesteps, _ = retrieve_timesteps(
164 | sample_scheduler,
165 | device=self.device,
166 | sigmas=sampling_sigmas)
167 | else:
168 | raise NotImplementedError("Unsupported solver.")
169 |
170 | # sample videos
171 | latents = noise
172 |
173 | arg_c = {'context': context, 'seq_len': seq_len}
174 | arg_null = {'context': context_null, 'seq_len': seq_len}
175 |
176 | for _, t in enumerate(tqdm(timesteps)):
177 | torch.cuda.synchronize()
178 | start_time = time()
179 | latent_model_input = latents
180 | timestep = [t]
181 |
182 | timestep = torch.stack(timestep)
183 |
184 | self.model.to(self.device)
185 | noise_pred_cond = self.model(
186 | latent_model_input, t=timestep, **arg_c)[0]
187 | noise_pred_uncond = self.model(
188 | latent_model_input, t=timestep, **arg_null)[0]
189 |
190 | noise_pred = noise_pred_uncond + guide_scale * (
191 | noise_pred_cond - noise_pred_uncond)
192 |
193 | torch.cuda.synchronize()
194 | self.cost_time += (time() - start_time)
195 |
196 | temp_x0 = sample_scheduler.step(
197 | noise_pred.unsqueeze(0),
198 | t,
199 | latents[0].unsqueeze(0),
200 | return_dict=False,
201 | generator=seed_g)[0]
202 | latents = [temp_x0.squeeze(0)]
203 |
204 | x0 = latents
205 | if offload_model:
206 | self.model.cpu()
207 | torch.cuda.empty_cache()
208 | if self.rank == 0:
209 | videos = self.vae.decode(x0)
210 |
211 | del noise, latents
212 | del sample_scheduler
213 | if offload_model:
214 | gc.collect()
215 | torch.cuda.synchronize()
216 | if dist.is_initialized():
217 | dist.barrier()
218 |
219 | return videos[0] if self.rank == 0 else None
220 |
221 |
222 | def i2v_generate(self,
223 | input_prompt,
224 | img,
225 | max_area=720 * 1280,
226 | frame_num=81,
227 | shift=5.0,
228 | sample_solver='unipc',
229 | sampling_steps=40,
230 | guide_scale=5.0,
231 | n_prompt="",
232 | seed=-1,
233 | offload_model=True):
234 | r"""
235 | Generates video frames from input image and text prompt using diffusion process.
236 |
237 | Args:
238 | input_prompt (`str`):
239 | Text prompt for content generation.
240 | img (PIL.Image.Image):
241 | Input image tensor. Shape: [3, H, W]
242 | max_area (`int`, *optional*, defaults to 720*1280):
243 | Maximum pixel area for latent space calculation. Controls video resolution scaling
244 | frame_num (`int`, *optional*, defaults to 81):
245 | How many frames to sample from a video. The number should be 4n+1
246 | shift (`float`, *optional*, defaults to 5.0):
247 | Noise schedule shift parameter. Affects temporal dynamics
248 | [NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0.
249 | sample_solver (`str`, *optional*, defaults to 'unipc'):
250 | Solver used to sample the video.
251 | sampling_steps (`int`, *optional*, defaults to 40):
252 | Number of diffusion sampling steps. Higher values improve quality but slow generation
253 | guide_scale (`float`, *optional*, defaults to 5.0):
254 | Classifier-free guidance scale. Controls prompt adherence vs. creativity
255 | n_prompt (`str`, *optional*, defaults to ""):
256 | Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
257 | seed (`int`, *optional*, defaults to -1):
258 | Random seed for noise generation. If -1, use random seed
259 | offload_model (`bool`, *optional*, defaults to True):
260 | If True, offloads models to CPU during generation to save VRAM
261 |
262 | Returns:
263 | torch.Tensor:
264 | Generated video frames tensor. Dimensions: (C, N, H, W) where:
265 | - C: Color channels (3 for RGB)
266 | - N: Number of frames (81)
267 | - H: Frame height (from max_area)
268 | - W: Frame width (from max_area)
269 | """
270 | img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device)
271 |
272 | F = frame_num
273 | h, w = img.shape[1:]
274 | aspect_ratio = h / w
275 | lat_h = round(
276 | np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] //
277 | self.patch_size[1] * self.patch_size[1])
278 | lat_w = round(
279 | np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] //
280 | self.patch_size[2] * self.patch_size[2])
281 | h = lat_h * self.vae_stride[1]
282 | w = lat_w * self.vae_stride[2]
283 |
284 | max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // (
285 | self.patch_size[1] * self.patch_size[2])
286 | max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size
287 |
288 | seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
289 | seed_g = torch.Generator(device=self.device)
290 | seed_g.manual_seed(seed)
291 | noise = torch.randn(
292 | self.vae.model.z_dim,
293 | (F - 1) // self.vae_stride[0] + 1,
294 | lat_h,
295 | lat_w,
296 | dtype=torch.float32,
297 | generator=seed_g,
298 | device=self.device)
299 |
300 | msk = torch.ones(1, F, lat_h, lat_w, device=self.device)
301 | msk[:, 1:] = 0
302 | msk = torch.concat([
303 | torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]
304 | ],
305 | dim=1)
306 | msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
307 | msk = msk.transpose(1, 2)[0]
308 |
309 | if n_prompt == "":
310 | n_prompt = self.sample_neg_prompt
311 |
312 | # preprocess
313 | if not self.t5_cpu:
314 | self.text_encoder.model.to(self.device)
315 | context = self.text_encoder([input_prompt], self.device)
316 | context_null = self.text_encoder([n_prompt], self.device)
317 | if offload_model:
318 | self.text_encoder.model.cpu()
319 | else:
320 | context = self.text_encoder([input_prompt], torch.device('cpu'))
321 | context_null = self.text_encoder([n_prompt], torch.device('cpu'))
322 | context = [t.to(self.device) for t in context]
323 | context_null = [t.to(self.device) for t in context_null]
324 |
325 | self.clip.model.to(self.device)
326 | clip_context = self.clip.visual([img[:, None, :, :]])
327 | if offload_model:
328 | self.clip.model.cpu()
329 |
330 | y = self.vae.encode([
331 | torch.concat([
332 | torch.nn.functional.interpolate(
333 | img[None].cpu(), size=(h, w), mode='bicubic').transpose(
334 | 0, 1),
335 | torch.zeros(3, F - 1, h, w)
336 | ],
337 | dim=1).to(self.device)
338 | ])[0]
339 | y = torch.concat([msk, y])
340 |
341 | @contextmanager
342 | def noop_no_sync():
343 | yield
344 |
345 | no_sync = getattr(self.model, 'no_sync', noop_no_sync)
346 |
347 | # evaluation mode
348 | with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync():
349 |
350 | if sample_solver == 'unipc':
351 | sample_scheduler = FlowUniPCMultistepScheduler(
352 | num_train_timesteps=self.num_train_timesteps,
353 | shift=1,
354 | use_dynamic_shifting=False)
355 | sample_scheduler.set_timesteps(
356 | sampling_steps, device=self.device, shift=shift)
357 | timesteps = sample_scheduler.timesteps
358 | elif sample_solver == 'dpm++':
359 | sample_scheduler = FlowDPMSolverMultistepScheduler(
360 | num_train_timesteps=self.num_train_timesteps,
361 | shift=1,
362 | use_dynamic_shifting=False)
363 | sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
364 | timesteps, _ = retrieve_timesteps(
365 | sample_scheduler,
366 | device=self.device,
367 | sigmas=sampling_sigmas)
368 | else:
369 | raise NotImplementedError("Unsupported solver.")
370 |
371 | # sample videos
372 | latent = noise
373 |
374 | arg_c = {
375 | 'context': [context[0]],
376 | 'clip_fea': clip_context,
377 | 'seq_len': max_seq_len,
378 | 'y': [y],
379 | # 'cond_flag': True,
380 | }
381 |
382 | arg_null = {
383 | 'context': context_null,
384 | 'clip_fea': clip_context,
385 | 'seq_len': max_seq_len,
386 | 'y': [y],
387 | # 'cond_flag': False,
388 | }
389 |
390 | if offload_model:
391 | torch.cuda.empty_cache()
392 |
393 | self.model.to(self.device)
394 | for _, t in enumerate(tqdm(timesteps)):
395 | torch.cuda.synchronize()
396 | start_time = time()
397 | latent_model_input = [latent.to(self.device)]
398 | timestep = [t]
399 |
400 | timestep = torch.stack(timestep).to(self.device)
401 |
402 | noise_pred_cond = self.model(
403 | latent_model_input, t=timestep, **arg_c)[0].to(
404 | torch.device('cpu') if offload_model else self.device)
405 | if offload_model:
406 | torch.cuda.empty_cache()
407 | noise_pred_uncond = self.model(
408 | latent_model_input, t=timestep, **arg_null)[0].to(
409 | torch.device('cpu') if offload_model else self.device)
410 | if offload_model:
411 | torch.cuda.empty_cache()
412 |
413 | noise_pred = noise_pred_uncond + guide_scale * (
414 | noise_pred_cond - noise_pred_uncond)
415 |
416 | latent = latent.to(
417 | torch.device('cpu') if offload_model else self.device)
418 |
419 | torch.cuda.synchronize()
420 | self.cost_time += (time() - start_time)
421 |
422 | temp_x0 = sample_scheduler.step(
423 | noise_pred.unsqueeze(0),
424 | t,
425 | latent.unsqueeze(0),
426 | return_dict=False,
427 | generator=seed_g)[0]
428 | latent = temp_x0.squeeze(0)
429 |
430 | x0 = [latent.to(self.device)]
431 | del latent_model_input, timestep
432 |
433 | if offload_model:
434 | self.model.cpu()
435 | torch.cuda.empty_cache()
436 |
437 | if self.rank == 0:
438 | videos = self.vae.decode(x0)
439 |
440 | del noise, latent
441 | del sample_scheduler
442 | if offload_model:
443 | gc.collect()
444 | torch.cuda.synchronize()
445 | if dist.is_initialized():
446 | dist.barrier()
447 |
448 | return videos[0] if self.rank == 0 else None
449 |
450 |
451 | def easycache_forward(
452 | self,
453 | x,
454 | t,
455 | context,
456 | seq_len,
457 | clip_fea=None,
458 | y=None,
459 | ):
460 | """
461 | Args:
462 | x (List[Tensor]): List of input video tensors with shape [C_in, F, H, W]
463 | t (Tensor): Diffusion timesteps tensor of shape [B]
464 | context (List[Tensor]): List of text embeddings each with shape [L, C]
465 | seq_len (int): Maximum sequence length for positional encoding
466 | clip_fea (Tensor, optional): CLIP image features for image-to-video mode
467 | y (List[Tensor], optional): Conditional video inputs for image-to-video mode
468 |
469 | Returns:
470 | List[Tensor]: List of denoised video tensors with original input shapes
471 | """
472 | if self.model_type == 'i2v':
473 | assert clip_fea is not None and y is not None
474 |
475 | # Store original raw input for end-to-end caching
476 | raw_input = [u.clone() for u in x]
477 |
478 | # params
479 | device = self.patch_embedding.weight.device
480 | if self.freqs.device != device:
481 | self.freqs = self.freqs.to(device)
482 |
483 | if y is not None:
484 | x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
485 |
486 | # Track which type of step (even=condition, odd=uncondition)
487 | self.is_even = (self.cnt % 2 == 0)
488 |
489 | # Only make decision on even (condition) steps
490 | if self.is_even:
491 | # Always compute first ret_steps and last steps
492 | if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
493 | self.should_calc_current_pair = True
494 | self.accumulated_error_even = 0
495 | else:
496 | # Check if we have previous step data for comparison
497 | if hasattr(self, 'previous_raw_input_even') and hasattr(self, 'previous_raw_output_even') and \
498 | self.previous_raw_input_even is not None and self.previous_raw_output_even is not None:
499 | # Calculate input changes
500 | raw_input_change = torch.cat([
501 | (u - v).flatten() for u, v in zip(raw_input, self.previous_raw_input_even)
502 | ]).abs().mean()
503 |
504 | # Compute predicted change if we have k factors
505 | if hasattr(self, 'k') and self.k is not None:
506 | # Calculate output norm for relative comparison
507 | output_norm = torch.cat([
508 | u.flatten() for u in self.previous_raw_output_even
509 | ]).abs().mean()
510 | pred_change = self.k * (raw_input_change / output_norm)
511 | combined_pred_change = pred_change
512 | # Accumulate predicted error
513 | if not hasattr(self, 'accumulated_error_even'):
514 | self.accumulated_error_even = 0
515 | self.accumulated_error_even += combined_pred_change
516 | # Decide if we need full calculation
517 | if self.accumulated_error_even < self.thresh:
518 | self.should_calc_current_pair = False
519 | else:
520 | self.should_calc_current_pair = True
521 | self.accumulated_error_even = 0
522 | else:
523 | # First time after ret_steps or missing k factors, need to calculate
524 | self.should_calc_current_pair = True
525 | else:
526 | # No previous data yet, must calculate
527 | self.should_calc_current_pair = True
528 |
529 | # Store current input state
530 | self.previous_raw_input_even = [u.clone() for u in raw_input]
531 |
532 | # Check if we can use cached output and return early
533 | if self.is_even and not self.should_calc_current_pair and \
534 | hasattr(self, 'previous_raw_output_even') and self.previous_raw_output_even is not None:
535 | # Use cached output directly
536 | self.cnt += 1
537 | # Check if we've reached the end of sampling
538 | if self.cnt >= self.num_steps:
539 | self.cnt = 0
540 |
541 | return [(u + v).float() for u, v in zip(raw_input, self.cache_even)]
542 |
543 | elif not self.is_even and not self.should_calc_current_pair and \
544 | hasattr(self, 'previous_raw_output_odd') and self.previous_raw_output_odd is not None:
545 | # Use cached output directly
546 | self.cnt += 1
547 |
548 | # Check if we've reached the end of sampling
549 | if self.cnt >= self.num_steps:
550 | self.cnt = 0
551 |
552 | # return [u.float() for u in self.previous_raw_output_odd]
553 | return [(u + v).float() for u, v in zip(raw_input, self.cache_odd)]
554 |
555 | # Continue with normal processing since we need to calculate
556 | # embeddings
557 | x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
558 | grid_sizes = torch.stack(
559 | [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
560 | x = [u.flatten(2).transpose(1, 2) for u in x]
561 | seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
562 | assert seq_lens.max() <= seq_len
563 | x = torch.cat([
564 | torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
565 | dim=1) for u in x
566 | ])
567 |
568 | # time embeddings
569 | with amp.autocast(dtype=torch.float32):
570 | e = self.time_embedding(
571 | sinusoidal_embedding_1d(self.freq_dim, t).float())
572 | e0 = self.time_projection(e).unflatten(1, (6, self.dim))
573 | assert e.dtype == torch.float32 and e0.dtype == torch.float32
574 |
575 | # context
576 | context_lens = None
577 | context = self.text_embedding(
578 | torch.stack([
579 | torch.cat(
580 | [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
581 | for u in context
582 | ]))
583 |
584 | if clip_fea is not None:
585 | context_clip = self.img_emb(clip_fea) # bs x 257 x dim
586 | context = torch.concat([context_clip, context], dim=1)
587 |
588 | # arguments
589 | kwargs = dict(
590 | e=e0,
591 | seq_lens=seq_lens,
592 | grid_sizes=grid_sizes,
593 | freqs=self.freqs,
594 | context=context,
595 | context_lens=context_lens)
596 |
597 | # Apply transformer blocks
598 | for block in self.blocks:
599 | x = block(x, **kwargs)
600 |
601 | # Apply head
602 | x = self.head(x, e)
603 |
604 | # Unpatchify
605 | output = self.unpatchify(x, grid_sizes)
606 |
607 | # Update cache and calculate change rates if needed
608 | if self.is_even: # Condition path
609 | # If we have previous output, calculate k factors for future predictions
610 | if hasattr(self, 'previous_raw_output_even') and self.previous_raw_output_even is not None:
611 | # Calculate output change at the raw level
612 | output_change = torch.cat([
613 | (u - v).flatten() for u, v in zip(output, self.previous_raw_output_even)
614 | ]).abs().mean()
615 |
616 | # Check if we have previous input state for comparison
617 | if hasattr(self, 'prev_prev_raw_input_even') and self.prev_prev_raw_input_even is not None:
618 | # Calculate input change
619 | input_change = torch.cat([
620 | (u - v).flatten() for u, v in zip(
621 | self.previous_raw_input_even, self.prev_prev_raw_input_even
622 | )
623 | ]).abs().mean()
624 |
625 | self.k = output_change / input_change
626 |
627 | # Update history
628 | self.prev_prev_raw_input_even = getattr(self, 'previous_raw_input_even', None)
629 | self.previous_raw_output_even = [u.clone() for u in output]
630 | self.cache_even = [u - v for u, v in zip(output, raw_input)]
631 |
632 | else: # Uncondition path
633 | # Store output for unconditional path
634 | self.previous_raw_output_odd = [u.clone() for u in output]
635 | self.cache_odd = [u - v for u, v in zip(output, raw_input)]
636 |
637 | # Update counter
638 | self.cnt += 1
639 | if self.cnt >= self.num_steps:
640 | self.cnt = 0
641 | self.skip_cond_step = []
642 | self.skip_uncond_step = []
643 |
644 | return [u.float() for u in output]
645 |
646 |
647 | def _validate_args(args):
648 | # Basic check
649 | assert args.ckpt_dir is not None, "Please specify the checkpoint directory."
650 |     assert args.task in WAN_CONFIGS, f"Unsupported task: {args.task}"
651 |     assert args.task in EXAMPLE_PROMPT, f"Unsupported task: {args.task}"
652 |
653 |     # The default number of sampling steps is 40 for image-to-video tasks and 50 for text-to-video tasks.
654 | if args.sample_steps is None:
655 | args.sample_steps = 40 if "i2v" in args.task else 50
656 |
657 | if args.sample_shift is None:
658 | args.sample_shift = 5.0
659 | if "i2v" in args.task and args.size in ["832*480", "480*832"]:
660 | args.sample_shift = 3.0
661 |
662 |     # The default number of frames is 1 for text-to-image tasks and 81 for other tasks.
663 | if args.frame_num is None:
664 | args.frame_num = 1 if "t2i" in args.task else 81
665 |
666 | # T2I frame_num check
667 | if "t2i" in args.task:
668 |         assert args.frame_num == 1, f"Unsupported frame_num {args.frame_num} for task {args.task}"
669 |
670 | args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint(
671 | 0, sys.maxsize)
672 | # Size check
673 |     assert args.size in SUPPORTED_SIZES[args.task], \
674 |         f"Unsupported size {args.size} for task {args.task}, " \
675 |         f"supported sizes are: {', '.join(SUPPORTED_SIZES[args.task])}"
676 |
677 |
678 | def _parse_args():
679 | parser = argparse.ArgumentParser(
680 |         description="Generate an image or video from a text prompt or image using Wan"
681 | )
682 | parser.add_argument(
683 | "--task",
684 | type=str,
685 | default="t2v-14B",
686 | choices=list(WAN_CONFIGS.keys()),
687 | help="The task to run.")
688 | parser.add_argument(
689 | "--size",
690 | type=str,
691 | default="1280*720",
692 | choices=list(SIZE_CONFIGS.keys()),
693 | help="The area (width*height) of the generated video. For the I2V task, the aspect ratio of the output video will follow that of the input image."
694 | )
695 | parser.add_argument(
696 | "--frame_num",
697 | type=int,
698 | default=None,
699 |         help="How many frames to sample from an image or video. The number should be 4n+1."
700 | )
701 | parser.add_argument(
702 | "--ckpt_dir",
703 | type=str,
704 | default="./model_weights/Wan2.1-T2V-1.3B",
705 | help="The path to the checkpoint directory.")
706 | parser.add_argument(
707 | "--offload_model",
708 | type=str2bool,
709 | default=None,
710 | help="Whether to offload the model to CPU after each model forward, reducing GPU memory usage."
711 | )
712 | parser.add_argument(
713 | "--ulysses_size",
714 | type=int,
715 | default=1,
716 | help="The size of the ulysses parallelism in DiT.")
717 | parser.add_argument(
718 | "--ring_size",
719 | type=int,
720 | default=1,
721 | help="The size of the ring attention parallelism in DiT.")
722 | parser.add_argument(
723 | "--t5_fsdp",
724 | action="store_true",
725 | default=False,
726 | help="Whether to use FSDP for T5.")
727 | parser.add_argument(
728 | "--t5_cpu",
729 | action="store_true",
730 | default=False,
731 | help="Whether to place T5 model on CPU.")
732 | parser.add_argument(
733 | "--dit_fsdp",
734 | action="store_true",
735 | default=False,
736 | help="Whether to use FSDP for DiT.")
737 | parser.add_argument(
738 | "--save_file",
739 | type=str,
740 | default=None,
741 | help="The file to save the generated image or video to.")
742 | parser.add_argument(
743 | "--prompt",
744 | type=str,
745 | default=None,
746 | help="The prompt to generate the image or video from.")
747 | parser.add_argument(
748 | "--use_prompt_extend",
749 | action="store_true",
750 | default=False,
751 | help="Whether to use prompt extend.")
752 | parser.add_argument(
753 | "--prompt_extend_method",
754 | type=str,
755 | default="local_qwen",
756 | choices=["dashscope", "local_qwen"],
757 | help="The prompt extend method to use.")
758 | parser.add_argument(
759 | "--prompt_extend_model",
760 | type=str,
761 | default=None,
762 | help="The prompt extend model to use.")
763 | parser.add_argument(
764 | "--prompt_extend_target_lang",
765 | type=str,
766 | default="ch",
767 | choices=["ch", "en"],
768 | help="The target language of prompt extend.")
769 | parser.add_argument(
770 | "--base_seed",
771 | type=int,
772 | default=-1,
773 | help="The seed to use for generating the image or video.")
774 | parser.add_argument(
775 | "--image",
776 | type=str,
777 | default=None,
778 | help="The image to generate the video from.")
779 | parser.add_argument(
780 | "--sample_solver",
781 | type=str,
782 | default='unipc',
783 | choices=['unipc', 'dpm++'],
784 | help="The solver used to sample.")
785 | parser.add_argument(
786 | "--sample_steps", type=int, default=None, help="The sampling steps.")
787 | parser.add_argument(
788 | "--sample_shift",
789 | type=float,
790 | default=None,
791 | help="Sampling shift factor for flow matching schedulers.")
792 | parser.add_argument(
793 | "--sample_guide_scale",
794 | type=float,
795 | default=5.0,
796 | help="Classifier free guidance scale.")
797 | parser.add_argument(
798 | "--thresh",
799 | type=float,
800 | default=0.05,
801 |         help="Caching error threshold. Higher values give larger speedups at the cost of quality -- roughly 0.1 for 2.0x speedup, 0.2 for 3.0x speedup.")
802 | parser.add_argument(
803 | "--ret_steps",
804 | default=10,
805 | type=int,
806 |         help="Number of initial warm-up steps that are always fully computed. Default is 10.")
807 | parser.add_argument(
808 | "--alpha",
809 | default=0.,
810 | type=float,
811 |         help="Averaging factor for the cache update. Default is 0.0.")
812 | parser.add_argument(
813 | "--beta",
814 | default=1.0,
815 | type=float,
816 | help="Averaging factor for the k_t and k_x update. Default is 1.0.")
817 | parser.add_argument(
818 | "--start_idx",
819 | type=int,
820 | default=0)
821 | parser.add_argument(
822 | "--end_idx",
823 | type=int,
824 | default=946)
825 | parser.add_argument(
826 | "--out_dir",
827 | type=str,
828 | default="./output",
829 | )
830 |
831 | args = parser.parse_args()
832 |
833 | _validate_args(args)
834 |
835 | return args
836 |
837 |
838 | def _init_logging(rank):
839 | # logging
840 | if rank == 0:
841 | # set format
842 | logging.basicConfig(
843 | level=logging.INFO,
844 | format="[%(asctime)s] %(levelname)s: %(message)s",
845 | handlers=[logging.StreamHandler(stream=sys.stdout)])
846 | else:
847 | logging.basicConfig(level=logging.ERROR)
848 |
849 |
850 | def generate(args):
851 | rank = int(os.getenv("RANK", 0))
852 | world_size = int(os.getenv("WORLD_SIZE", 1))
853 | local_rank = int(os.getenv("LOCAL_RANK", 0))
854 | device = local_rank
855 | _init_logging(rank)
856 |
857 | if args.offload_model is None:
858 | args.offload_model = False if world_size > 1 else True
859 | logging.info(
860 | f"offload_model is not specified, set to {args.offload_model}.")
861 | if world_size > 1:
862 | torch.cuda.set_device(local_rank)
863 | dist.init_process_group(
864 | backend="nccl",
865 | init_method="env://",
866 | rank=rank,
867 | world_size=world_size)
868 | else:
869 |         assert not (
870 |             args.t5_fsdp or args.dit_fsdp
871 |         ), "t5_fsdp and dit_fsdp are not supported in non-distributed environments."
872 |         assert not (
873 |             args.ulysses_size > 1 or args.ring_size > 1
874 |         ), "context parallel is not supported in non-distributed environments."
875 |
876 | if args.ulysses_size > 1 or args.ring_size > 1:
877 |         assert args.ulysses_size * args.ring_size == world_size, "The product of ulysses_size and ring_size should equal the world size."
878 | from xfuser.core.distributed import (initialize_model_parallel,
879 | init_distributed_environment)
880 | init_distributed_environment(
881 | rank=dist.get_rank(), world_size=dist.get_world_size())
882 |
883 | initialize_model_parallel(
884 | sequence_parallel_degree=dist.get_world_size(),
885 | ring_degree=args.ring_size,
886 | ulysses_degree=args.ulysses_size,
887 | )
888 |
889 | if args.use_prompt_extend:
890 | if args.prompt_extend_method == "dashscope":
891 | prompt_expander = DashScopePromptExpander(
892 | model_name=args.prompt_extend_model, is_vl="i2v" in args.task)
893 | elif args.prompt_extend_method == "local_qwen":
894 | prompt_expander = QwenPromptExpander(
895 | model_name=args.prompt_extend_model,
896 | is_vl="i2v" in args.task,
897 | device=rank)
898 | else:
899 | raise NotImplementedError(
900 |                 f"Unsupported prompt_extend_method: {args.prompt_extend_method}")
901 |
902 | cfg = WAN_CONFIGS[args.task]
903 | if args.ulysses_size > 1:
904 | assert cfg.num_heads % args.ulysses_size == 0, f"`num_heads` must be divisible by `ulysses_size`."
905 |
906 | logging.info(f"Generation job args: {args}")
907 | logging.info(f"Generation model config: {cfg}")
908 |
909 | if dist.is_initialized():
910 | base_seed = [args.base_seed] if rank == 0 else [None]
911 | dist.broadcast_object_list(base_seed, src=0)
912 | args.base_seed = base_seed[0]
913 |
914 | if "t2v" in args.task or "t2i" in args.task:
915 | if args.prompt is None:
916 | args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
917 | logging.info(f"Input prompt: {args.prompt}")
918 | if args.use_prompt_extend:
919 | logging.info("Extending prompt ...")
920 | if rank == 0:
921 | prompt_output = prompt_expander(
922 | args.prompt,
923 | tar_lang=args.prompt_extend_target_lang,
924 | seed=args.base_seed)
925 | if prompt_output.status == False:
926 | logging.info(
927 | f"Extending prompt failed: {prompt_output.message}")
928 | logging.info("Falling back to original prompt.")
929 | input_prompt = args.prompt
930 | else:
931 | input_prompt = prompt_output.prompt
932 | input_prompt = [input_prompt]
933 | else:
934 | input_prompt = [None]
935 | if dist.is_initialized():
936 | dist.broadcast_object_list(input_prompt, src=0)
937 | args.prompt = input_prompt[0]
938 | logging.info(f"Extended prompt: {args.prompt}")
939 |
940 | logging.info("Creating WanT2V pipeline.")
941 | wan_t2v = wan.WanT2V(
942 | config=cfg,
943 | checkpoint_dir=args.ckpt_dir,
944 | device_id=device,
945 | rank=rank,
946 | t5_fsdp=args.t5_fsdp,
947 | dit_fsdp=args.dit_fsdp,
948 | use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
949 | t5_cpu=args.t5_cpu,
950 | )
951 |
952 | generation_time = []
953 | time_cost = {"GPU_Device": torch.cuda.get_device_name(0), "number_prompt": None, "avg_cost_time": None}
954 | wan_t2v.__class__.cost_time = 0
955 | wan_t2v.__class__.generate = t2v_generate
956 | wan_t2v.model.__class__.forward = easycache_forward
957 | wan_t2v.model.__class__.cnt = 0
958 | wan_t2v.model.__class__.skip_cond_step = []
959 | wan_t2v.model.__class__.skip_uncond_step = []
960 | wan_t2v.model.__class__.num_steps = args.sample_steps * 2
961 | wan_t2v.model.__class__.thresh = args.thresh
962 | wan_t2v.model.__class__.accumulated_error_even = 0
963 | wan_t2v.model.__class__.should_calc_current_pair = True
964 | wan_t2v.model.__class__.k = None
965 |
966 | wan_t2v.model.__class__.previous_raw_input_even = None
967 | wan_t2v.model.__class__.previous_raw_output_even = None
968 | wan_t2v.model.__class__.previous_raw_output_odd = None
969 | wan_t2v.model.__class__.prev_prev_raw_input_even = None
970 | wan_t2v.model.__class__.cache_even = None
971 | wan_t2v.model.__class__.cache_odd = None
972 |
973 | wan_t2v.cost_time = 0
974 | wan_t2v.model.__class__.ret_steps = 10 * 2
975 | wan_t2v.model.__class__.cutoff_steps = args.sample_steps * 2 - 2
976 |
977 | print(
978 | f"Generating {'image' if 't2i' in args.task else 'video'} ...")
979 |
980 | # start_time = time()
981 | video = wan_t2v.generate(
982 | args.prompt,
983 | size=SIZE_CONFIGS[args.size],
984 | frame_num=args.frame_num,
985 | shift=args.sample_shift,
986 | sample_solver=args.sample_solver,
987 | sampling_steps=args.sample_steps,
988 | guide_scale=args.sample_guide_scale,
989 | seed=args.base_seed,
990 | offload_model=args.offload_model)
991 | generation_time.append(wan_t2v.cost_time)
992 | if rank == 0:
993 | if args.save_file is None:
994 | formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S")
995 | formatted_prompt = args.prompt.replace(" ", "_").replace("/", "_")[:50]
996 | suffix = '.png' if "t2i" in args.task else '.mp4'
997 | args.save_file = f"{args.task}_easycache_thresh{args.thresh}_step{args.sample_steps}_{formatted_prompt}_{formatted_time}" + suffix
998 |
999 | if "t2i" in args.task:
1000 | logging.info(f"Saving generated image to {args.save_file}")
1001 | cache_image(
1002 | tensor=video.squeeze(1)[None],
1003 | save_file=args.save_file,
1004 | nrow=1,
1005 | normalize=True,
1006 | value_range=(-1, 1))
1007 | else:
1008 | logging.info(f"Saving generated video to {args.save_file}")
1009 | cache_video(
1010 | tensor=video[None],
1011 | save_file=args.save_file,
1012 | fps=cfg.sample_fps,
1013 | nrow=1,
1014 | normalize=True,
1015 | value_range=(-1, 1))
1016 | logging.info("Finished.")
1017 |
1018 |
1019 | else:
1020 | if args.prompt is None:
1021 | args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
1022 | if args.image is None:
1023 | args.image = EXAMPLE_PROMPT[args.task]["image"]
1024 | print(f"Input prompt: {args.prompt}")
1025 | print(f"Input image: {args.image}")
1026 |
1027 | img = Image.open(args.image).convert("RGB")
1028 | if args.use_prompt_extend:
1029 | print("Extending prompt ...")
1030 | if rank == 0:
1031 | prompt_output = prompt_expander(
1032 | args.prompt,
1033 | tar_lang=args.prompt_extend_target_lang,
1034 | image=img,
1035 | seed=args.base_seed)
1036 | if prompt_output.status == False:
1037 | print(
1038 | f"Extending prompt failed: {prompt_output.message}")
1039 | print("Falling back to original prompt.")
1040 | input_prompt = args.prompt
1041 | else:
1042 | input_prompt = prompt_output.prompt
1043 | input_prompt = [input_prompt]
1044 | else:
1045 | input_prompt = [None]
1046 | if dist.is_initialized():
1047 | dist.broadcast_object_list(input_prompt, src=0)
1048 | args.prompt = input_prompt[0]
1049 | print(f"Extended prompt: {args.prompt}")
1050 |
1051 | print("Creating WanI2V pipeline.")
1052 | wan_i2v = wan.WanI2V(
1053 | config=cfg,
1054 | checkpoint_dir=args.ckpt_dir,
1055 | device_id=device,
1056 | rank=rank,
1057 | t5_fsdp=args.t5_fsdp,
1058 | dit_fsdp=args.dit_fsdp,
1059 | use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
1060 | t5_cpu=args.t5_cpu,
1061 | )
1062 | generation_time = []
1063 | time_cost = {"GPU_Device": torch.cuda.get_device_name(0), "number_prompt": None, "avg_cost_time": None}
1064 | wan_i2v.__class__.generate = i2v_generate
1065 | wan_i2v.model.__class__.forward = easycache_forward
1066 | wan_i2v.model.__class__.cnt = 0
1067 | wan_i2v.model.__class__.num_steps = args.sample_steps * 2
1068 | wan_i2v.model.__class__.thresh = args.thresh
1069 |
1070 | wan_i2v.model.__class__.accumulated_error_even = 0
1071 | wan_i2v.model.__class__.should_calc_current_pair = True
1072 | wan_i2v.model.__class__.k = None
1073 |
1074 | wan_i2v.model.__class__.previous_raw_input_even = None
1075 | wan_i2v.model.__class__.previous_raw_output_even = None
1076 | wan_i2v.model.__class__.previous_raw_output_odd = None
1077 | wan_i2v.model.__class__.prev_prev_raw_input_even = None
1078 | wan_i2v.model.__class__.cache_even = None
1079 | wan_i2v.model.__class__.cache_odd = None
1080 |
1081 | wan_i2v.cost_time = 0
1082 |
1083 | wan_i2v.model.__class__.ret_steps = 10 * 2
1084 | wan_i2v.model.__class__.cutoff_steps = args.sample_steps * 2 - 2
1085 |
1086 | print("Generating video ...")
1087 | video = wan_i2v.generate(
1088 | args.prompt,
1089 | img,
1090 | max_area=MAX_AREA_CONFIGS[args.size],
1091 | frame_num=args.frame_num,
1092 | shift=args.sample_shift,
1093 | sample_solver=args.sample_solver,
1094 | sampling_steps=args.sample_steps,
1095 | guide_scale=args.sample_guide_scale,
1096 | seed=args.base_seed,
1097 | offload_model=args.offload_model)
1098 | generation_time.append(wan_i2v.cost_time)
1099 |
1100 | if rank == 0:
1101 | if args.save_file is None:
1102 | formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S")
1103 | formatted_prompt = args.prompt.replace(" ", "_").replace("/", "_")[:50]
1104 | suffix = '.png' if "t2i" in args.task else '.mp4'
1105 | args.save_file = f"{args.task}_easycache_thresh{args.thresh}_step{args.sample_steps}_{formatted_prompt}_{formatted_time}" + suffix
1106 |
1107 | if "t2i" in args.task:
1108 | logging.info(f"Saving generated image to {args.save_file}")
1109 | cache_image(
1110 | tensor=video.squeeze(1)[None],
1111 | save_file=args.save_file,
1112 | nrow=1,
1113 | normalize=True,
1114 | value_range=(-1, 1))
1115 | else:
1116 | logging.info(f"Saving generated video to {args.save_file}")
1117 | cache_video(
1118 | tensor=video[None],
1119 | save_file=args.save_file,
1120 | fps=cfg.sample_fps,
1121 | nrow=1,
1122 | normalize=True,
1123 | value_range=(-1, 1))
1124 | logging.info("Finished.")
1125 |
1126 | time_cost["number_prompt"] = len(generation_time)
1127 | time_cost["avg_cost_time"] = sum(generation_time) / (len(generation_time)) if len(generation_time) > 0 else 0
1128 |
1129 | print(
1130 | f"GPU_Device:{time_cost['GPU_Device']}, number_prompt: {time_cost['number_prompt']}, avg_cost_time: {time_cost['avg_cost_time']}")
1131 | try:
1132 | with open(f"./{args.out_dir}/1time_cost.json", "a+") as f:
1133 | portalocker.lock(f, portalocker.LOCK_EX)
1134 | f.seek(0)
1135 | try:
1136 | existing_data = json.load(f)
1137 | except (json.JSONDecodeError, FileNotFoundError):
1138 | existing_data = []
1139 | existing_data.append(time_cost)
1140 | f.seek(0)
1141 | f.truncate()
1142 | json.dump(existing_data, f, indent=4)
1143 | except Exception as e:
1144 | print(f"Error saving time cost data: {e}")
1145 |
1146 |
1147 | if __name__ == "__main__":
1148 | args = _parse_args()
1149 | generate(args)
1150 |
--------------------------------------------------------------------------------
/EasyCache4Wan2.1/example/grogu.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/H-EmbodVis/EasyCache/73b04c13c05dc446b712f91cdf0f367399a752e4/EasyCache4Wan2.1/example/grogu.png
--------------------------------------------------------------------------------
/EasyCache4Wan2.1/tools/video_metrics.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import argparse
4 | import torch
5 | import lpips
6 | import numpy as np
7 | from tqdm import tqdm
8 | from torchmetrics.image import StructuralSimilarityIndexMeasure
9 |
10 | def load_video_frames(path, resize_to=None):
11 | """
12 |     Load all frames from a video file as a single [N, H, W, 3] uint8 RGB array.
13 | Optionally resize each frame to `resize_to` (w, h).
14 | """
15 |
16 | cap = cv2.VideoCapture(path)
17 | frames = []
18 | while True:
19 | ret, img = cap.read()
20 | if not ret:
21 | break
22 | if resize_to is not None:
23 | img = cv2.resize(img, resize_to)
24 | frames.append(np.expand_dims(cv2.cvtColor(img, cv2.COLOR_BGR2RGB), axis=0))
25 | cap.release()
26 | return np.concatenate(frames)
27 |
28 |
29 | def compute_video_metrics(frames_gt, frames_gen,
30 | device, ssim_metric, lpips_fn):
31 | """
32 |     Compute PSNR, SSIM, and LPIPS for two [N, H, W, 3] uint8 RGB frame arrays.
33 | All computations on `device`.
34 | Returns (psnr, ssim, lpips) scalars.
35 | """
36 |     # assumes both videos have the same frame count and resolution
37 |     # convert to tensors [N,3,H,W], normalize to [0,1]
38 | gt_t = torch.from_numpy(frames_gt).float().to(device).permute(0, 3, 1, 2).div_(255).contiguous()
39 |
40 | gen_t = torch.from_numpy(frames_gen).float().to(device).permute(0, 3, 1, 2).div_(255).contiguous()
41 |
42 | # PSNR (data_range=1.0): -10 * log10(mse)
43 | mse = torch.mean((gt_t - gen_t) ** 2)
44 | psnr = -10.0 * torch.log10(mse)
45 |
46 | # SSIM: returns average over batch
47 | ssim_val = ssim_metric(gen_t, gt_t)
48 |
49 | # LPIPS: expects [-1,1]
50 | with torch.no_grad():
51 | lpips_val = lpips_fn(gt_t * 2.0 - 1.0, gen_t * 2.0 - 1.0).mean()
52 |
53 | return psnr.item(), ssim_val.item(), lpips_val.item()
54 |
55 |
56 | def main():
57 |     parser = argparse.ArgumentParser(
58 |         description="Compute PSNR/SSIM/LPIPS on GPU for a pair of .mp4 videos"
59 |     )
60 |     parser.add_argument("--original_video", required=True,
61 |                         help="Path to the ground-truth .mp4 video")
62 |     parser.add_argument("--generated_video", required=True,
63 |                         help="Path to the generated .mp4 video")
64 | parser.add_argument("--device", default="cuda",
65 | help="Torch device, e.g. 'cuda' or 'cpu'")
66 | parser.add_argument("--lpips_net", default="alex", choices=["alex", "vgg"],
67 | help="Backbone for LPIPS")
68 | args = parser.parse_args()
69 |
70 | device = torch.device(args.device if torch.cuda.is_available() or args.device == "cpu" else "cpu")
71 | # instantiate metrics on device
72 | ssim_metric = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)
73 | lpips_fn = lpips.LPIPS(net=args.lpips_net, spatial=True).to(device)
74 |
75 |     # paths of the single pair of videos to compare
76 |     gt_files = args.original_video
77 |     gen_set = args.generated_video
78 |
79 | psnrs, ssims, lpips_vals = [], [], []
80 | for fname in tqdm([gt_files], desc="Videos"):
81 | path_gt = gt_files
82 | path_gen = gen_set
83 |
84 |         # load frames; resize generated frames to match the GT resolution (W, H)
85 |         frames_gt = load_video_frames(path_gt)
86 |         frames_gen = load_video_frames(path_gen, resize_to=(frames_gt.shape[2], frames_gt.shape[1]))
87 |
88 | res = compute_video_metrics(frames_gt, frames_gen,
89 | device, ssim_metric, lpips_fn)
90 | if res is None:
91 | continue
92 | p, s, l = res
93 |         psnrs.append(p)
94 |         ssims.append(s)
95 |         lpips_vals.append(l)
96 |
97 | if not psnrs:
98 | print("No valid videos processed.")
99 | return
100 |
101 | print("\n=== Overall Averages ===")
102 | print(f"Average PSNR : {np.mean(psnrs):.2f} dB")
103 | print(f"Average SSIM : {np.mean(ssims):.4f}")
104 | print(f"Average LPIPS: {np.mean(lpips_vals):.4f}")
105 |
106 |
107 | if __name__ == "__main__":
108 | main()
109 |
--------------------------------------------------------------------------------
/EasyCache4Wan2.1/videos/i2v_easycache_14b_720p.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/H-EmbodVis/EasyCache/73b04c13c05dc446b712f91cdf0f367399a752e4/EasyCache4Wan2.1/videos/i2v_easycache_14b_720p.gif
--------------------------------------------------------------------------------
/EasyCache4Wan2.1/videos/i2v_gt_14b_720p.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/H-EmbodVis/EasyCache/73b04c13c05dc446b712f91cdf0f367399a752e4/EasyCache4Wan2.1/videos/i2v_gt_14b_720p.gif
--------------------------------------------------------------------------------
/EasyCache4Wan2.1/videos/t2v_easycache_14b_720p.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/H-EmbodVis/EasyCache/73b04c13c05dc446b712f91cdf0f367399a752e4/EasyCache4Wan2.1/videos/t2v_easycache_14b_720p.gif
--------------------------------------------------------------------------------
/EasyCache4Wan2.1/videos/t2v_gt_14b_720p.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/H-EmbodVis/EasyCache/73b04c13c05dc446b712f91cdf0f367399a752e4/EasyCache4Wan2.1/videos/t2v_gt_14b_720p.gif
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
Less is Enough: Training-Free Video Diffusion Acceleration via Runtime-Adaptive Caching
3 |
4 |
Xin Zhou1\*,
5 |
Dingkang Liang1\*,
6 | Kaijin Chen
1, Tianrui Feng
1,
7 |
Xiwu Chen2, Hongkai Lin
1,
8 |
Yikang Ding2, Feiyang Tan
2,
9 |
Hengshuang Zhao3,
10 |
Xiang Bai1†
11 |
12 |
1 Huazhong University of Science and Technology,
2 MEGVII Technology,
3 The University of Hong Kong
13 |
14 | (\*) Equal contribution. (†) Corresponding author.
15 |
16 | [](https://arxiv.org/abs/2507.02860)
17 | [](https://H-EmbodVis.github.io/EasyCache/)
18 | [](https://github.com/LMD0311/EasyCache/blob/main/LICENSE)
19 |
20 |
21 |
22 | ## 🎬 Visual Comparisons
23 | Video synchronization issues may occur due to network load; for improved visualization, please see the [project page](https://H-EmbodVis.github.io/EasyCache/).
24 |
25 | **Prompt: "Grassland at dusk, wild horses galloping, golden light flickering across manes."**
26 | *(HunyuanVideo)*
27 |
28 | | Baseline | Ours (2.28x) | TeaCache (1.68x) | PAB (1.19x) |
29 | | :---: | :---: | :---: | :---: |
30 | |  |  |  |  |
31 |
32 | **Prompt: "A top-down view of a barista creating latte art, skillfully pouring milk to form the letters 'TPAMI' on coffee."**
33 | *(Wan2.1-14B)*
34 |
35 | | Baseline | Ours (2.63x) | TeaCache (1.46x) | PAB (2.10x) |
36 | | :---: | :---: | :---: | :---: |
37 | |  |  |  |  |
38 |
39 | **Compatibility with SVG**
40 |
41 | SVG with EasyCache on HunyuanVideo can achieve more than a 3x speedup.
42 |
43 | https://github.com/user-attachments/assets/248ab88f-dfa8-4980-9b51-5c081e27db9a
44 |
45 |
46 | ## 📰 News
47 | - **If you like our project, please give us a star ⭐ on GitHub for the latest updates.**
48 | - **[2025/07/06]** 🔥 EasyCache for [**Wan2.1**](https://github.com/H-EmbodVis/EasyCache/tree/main/EasyCache4Wan2.1) I2V is released.
49 | - **[2025/07/05]** 🔥 EasyCache for [**Wan2.1**](https://github.com/H-EmbodVis/EasyCache/tree/main/EasyCache4Wan2.1) T2V is released.
50 | - **[2025/07/04]** 🎉 Release the [**paper**](https://arxiv.org/abs/2507.02860) of EasyCache.
51 | - **[2025/07/03]** 🔥 EasyCache for Sparse-VideoGen on [**HunyuanVideo**](https://github.com/H-EmbodVis/EasyCache/tree/main/EasyCache4HunyuanVideo) is released.
52 | - **[2025/07/02]** 🔥 EasyCache for [**HunyuanVideo**](https://github.com/H-EmbodVis/EasyCache/tree/main/EasyCache4HunyuanVideo) is released.
53 |
54 | ## Abstract
55 | Video generation models have demonstrated remarkable performance, yet their broader adoption remains constrained by slow inference speeds and substantial computational costs, primarily due to the iterative nature of the denoising process. Addressing this bottleneck is essential for democratizing advanced video synthesis technologies and enabling their integration into real-world applications. This work proposes EasyCache, a training-free acceleration framework for video diffusion models. EasyCache introduces a lightweight, runtime-adaptive caching mechanism that dynamically reuses previously computed transformation vectors, avoiding redundant computations during inference. Unlike prior approaches, EasyCache requires no offline profiling, pre-computation, or extensive parameter tuning. We conduct comprehensive studies on various large-scale video generation models, including OpenSora, Wan2.1, and HunyuanVideo. Our method achieves leading acceleration performance, reducing inference time by 2.1-3.3× compared to the original baselines while maintaining high visual fidelity, with up to a 36% PSNR improvement over the previous SOTA method. This makes EasyCache an efficient and highly accessible solution for high-quality video generation in both research and practical applications.
56 |
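At its core (see `easycache_forward` in `EasyCache4Wan2.1/easycache_generate.py`), EasyCache measures how much the model's raw input changes from one denoising step to the next, scales that change by an online estimate of the output-to-input change ratio, and accumulates the predicted output change. As long as the accumulated estimate stays below a threshold, the expensive transformer forward is skipped and a cached input-to-output residual is reused. Below is a minimal, simplified sketch of this decision rule; the names `EasyCacheState` and `full_forward` are illustrative placeholders, not the repository's API:

```python
import torch


class EasyCacheState:
    """Simplified sketch of EasyCache's runtime-adaptive skip rule (illustrative only)."""

    def __init__(self, thresh=0.05, warmup_steps=10):
        self.thresh = thresh              # error budget before a full forward is forced
        self.warmup_steps = warmup_steps  # initial steps that are always fully computed
        self.step = 0
        self.accumulated_error = 0.0
        self.k = None                     # output-change / input-change ratio from the last full step
        self.prev_input = None
        self.prev_output = None
        self.cache = None                 # cached residual: last full output minus its input

    def __call__(self, x, full_forward):
        must_compute = self.step < self.warmup_steps or self.k is None or self.cache is None
        if not must_compute:
            # Predict how much the output would change, given how much the input changed.
            input_change = (x - self.prev_input).abs().mean()
            output_norm = self.prev_output.abs().mean()
            self.accumulated_error += self.k * (input_change / output_norm)
            must_compute = self.accumulated_error >= self.thresh

        if must_compute:
            out = full_forward(x)  # expensive transformer forward
            if self.prev_output is not None:
                out_change = (out - self.prev_output).abs().mean()
                in_change = (x - self.prev_input).abs().mean().clamp_min(1e-8)
                self.k = out_change / in_change  # refresh the change-rate estimate
            self.prev_output = out.detach().clone()
            self.cache = out - x                 # cache the transformation residual
            self.accumulated_error = 0.0
        else:
            out = x + self.cache                 # reuse the cached residual, skipping the network

        self.prev_input = x.detach().clone()
        self.step += 1
        return out
```

The released implementation additionally tracks the conditional and unconditional passes of classifier-free guidance separately and always computes the final denoising steps in full.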
57 |
58 | ## 🚀 Main Performance
59 |
60 | We validated the performance of EasyCache on leading video generation models and compared it with other state-of-the-art training-free acceleration methods.
61 |
62 | ### Comparison with SOTA Methods
63 |
64 | Tested on VBench prompts with an NVIDIA A800 GPU.
65 |
66 | **Performance on HunyuanVideo:**
67 | | Method | Latency (s)↓ | Speedup ↑ | PSNR ↑ | SSIM ↑ | LPIPS ↓ |
68 | |:---:|:---:|:---:|:---:|:---:|:---:|
69 | | HunyuanVideo (Baseline) | 1124.30 | 1.00x | - | - | - |
70 | | PAB | 958.23 | 1.17x | 18.58 | 0.7023 | 0.3827 |
71 | | TeaCache | 674.04 | 1.67x | 23.85 | 0.8185 | 0.1730 |
72 | | SVG | 802.70 | 1.40x | 26.57 | 0.8596 | 0.1368 |
73 | | **EasyCache (Ours)** | **507.97** | **2.21x** | **32.66** | **0.9313** | **0.0533** |
74 |
75 | **Performance on Wan2.1-1.3B:**
76 |
77 | | Method | Latency (s)↓ | Speedup ↑ | PSNR ↑ | SSIM ↑ | LPIPS ↓ |
78 | |:---:|:---:|:---:|:---:|:---:|:---:|
79 | | Wan2.1 (Baseline) | 175.35 | 1.00x | - | - | - |
80 | | PAB | 102.03 | 1.72x | 18.84 | 0.6484 | 0.3010 |
81 | | TeaCache | 87.77 | 2.00x | 22.57 | 0.8057 | 0.1277 |
82 | | **EasyCache (Ours)** | **69.11** | **2.54x** | **25.24** | **0.8337** | **0.0952** |
83 |
84 | ### Compatibility with Other Acceleration Techniques
85 |
86 | EasyCache is orthogonal to other acceleration techniques, such as the efficient attention mechanism SVG, and can be combined with them for even greater performance gains.
87 |
88 | **Combined Performance on HunyuanVideo (720p):**
89 | *Tested on NVIDIA H20 GPUs.*
90 | | Method | Latency (s)↓ | Speedup ↑ | PSNR (dB) ↑ |
91 | |:---:|:---:|:---:|:---:|
92 | | Baseline | 6594 | 1.00x | - |
93 | | SVG | 3474 | 1.90x | 27.56 |
94 | | SVG (w/ TeaCache) | 2071 | 3.18x | 22.65 |
95 | | SVG (w/ **Ours**) | **1981** | **3.33x** | **27.26** |
96 |
97 |
98 | ## 🛠️ Usage
99 | Detailed instructions for each supported model are provided in their respective directories. We are continuously working to extend support to more models.
100 |
101 | ### HunyuanVideo
102 | 1. **Prerequisites**: Set up the environment and download weights from the official HunyuanVideo repository.
103 | 2. **Copy Files**: Place the EasyCache script files into your local HunyuanVideo project directory.
104 | 3. **Run**: Execute the provided Python script to run inference with acceleration.
105 | **For complete instructions, please refer to the [README](./EasyCache4HunyuanVideo/README.md).**
106 |
107 | ### Wan2.1
108 | 1. **Prerequisites**: Set up the environment and download weights from the official Wan2.1 repository.
109 | 2. **Copy Files**: Place the EasyCache script files into your local Wan2.1 project directory.
110 | 3. **Run**: Execute the provided Python script to run inference with acceleration; the script enables EasyCache at runtime, as sketched below.
111 | **For complete instructions, please refer to the [README](./EasyCache4Wan2.1/README.md).**
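
For reference, the Wan2.1 script enables EasyCache without modifying the Wan2.1 package itself: it swaps the DiT's `forward` for `easycache_forward` and initializes the cache state as class attributes before calling `generate`. The condensed sketch below mirrors what `EasyCache4Wan2.1/easycache_generate.py` does inline for the T2V task; the helper name `enable_easycache` is hypothetical, and `easycache_forward` is assumed to be importable from that script.

```python
def enable_easycache(wan_t2v, args):
    """Hypothetical helper mirroring the inline setup in easycache_generate.py (T2V)."""
    model_cls = wan_t2v.model.__class__
    model_cls.forward = easycache_forward               # swap in the caching forward pass
    model_cls.cnt = 0                                   # step counter over cond + uncond passes
    model_cls.num_steps = args.sample_steps * 2         # cond/uncond passes are counted separately
    model_cls.thresh = args.thresh                      # error budget, e.g. 0.05
    model_cls.ret_steps = 10 * 2                        # warm-up steps, always fully computed
    model_cls.cutoff_steps = args.sample_steps * 2 - 2  # final steps, always fully computed
    model_cls.accumulated_error_even = 0
    model_cls.should_calc_current_pair = True
    model_cls.k = None
    model_cls.previous_raw_input_even = None
    model_cls.previous_raw_output_even = None
    model_cls.previous_raw_output_odd = None
    model_cls.prev_prev_raw_input_even = None
    model_cls.cache_even = None
    model_cls.cache_odd = None
```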
112 |
113 | ## 🎯 To Do
114 |
115 | - [x] Support HunyuanVideo
116 | - [x] Support Sparse-VideoGen on HunyuanVideo
117 | - [x] Support Wan2.1 T2V
118 | - [x] Support Wan2.1 I2V
119 | - [ ] Support FLUX
120 |
121 | ## 🌹 Acknowledgements
122 | We would like to thank the contributors to the [Wan2.1](https://github.com/Wan-Video/Wan2.1), [HunyuanVideo](https://github.com/Tencent-Hunyuan/HunyuanVideo), [OpenSora](https://github.com/hpcaitech/Open-Sora), and [SVG](https://github.com/svg-project/Sparse-VideoGen) repositories for their open research and exploration.
123 |
124 | ## 📖 Citation
125 |
126 | If you find this repository useful in your research, please consider giving a star ⭐ and a citation.
127 | ```bibtex
128 | @article{zhou2025easycache,
129 | title={Less is Enough: Training-Free Video Diffusion Acceleration via Runtime-Adaptive Caching},
130 |   author={Zhou, Xin and Liang, Dingkang and Chen, Kaijin and Feng, Tianrui and Chen, Xiwu and Lin, Hongkai and Ding, Yikang and Tan, Feiyang and Zhao, Hengshuang and Bai, Xiang},
131 | journal={arXiv preprint arXiv:2507.02860},
132 | year={2025}
133 | }
134 | ```
135 |
--------------------------------------------------------------------------------
/demo/gt/6.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/H-EmbodVis/EasyCache/73b04c13c05dc446b712f91cdf0f367399a752e4/demo/gt/6.gif
--------------------------------------------------------------------------------
/demo/gt/7.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/H-EmbodVis/EasyCache/73b04c13c05dc446b712f91cdf0f367399a752e4/demo/gt/7.gif
--------------------------------------------------------------------------------
/demo/our/6.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/H-EmbodVis/EasyCache/73b04c13c05dc446b712f91cdf0f367399a752e4/demo/our/6.gif
--------------------------------------------------------------------------------
/demo/our/7.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/H-EmbodVis/EasyCache/73b04c13c05dc446b712f91cdf0f367399a752e4/demo/our/7.gif
--------------------------------------------------------------------------------
/demo/pab/6.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/H-EmbodVis/EasyCache/73b04c13c05dc446b712f91cdf0f367399a752e4/demo/pab/6.gif
--------------------------------------------------------------------------------
/demo/pab/7.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/H-EmbodVis/EasyCache/73b04c13c05dc446b712f91cdf0f367399a752e4/demo/pab/7.gif
--------------------------------------------------------------------------------
/demo/teacache/6.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/H-EmbodVis/EasyCache/73b04c13c05dc446b712f91cdf0f367399a752e4/demo/teacache/6.gif
--------------------------------------------------------------------------------
/demo/teacache/7.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/H-EmbodVis/EasyCache/73b04c13c05dc446b712f91cdf0f367399a752e4/demo/teacache/7.gif
--------------------------------------------------------------------------------