├── assets
│   ├── gif1.gif
│   └── gif2.gif
├── examples
│   ├── DualParal_Wan.py
│   └── Wan-Video.py
├── readme.md
├── requirements.txt
└── src
    ├── __init__.py
    ├── distribution_utils.py
    └── pipelines
        ├── __init__.py
        ├── base_pipeline.py
        └── pipeline_Wan.py
/assets/gif1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DualParal-Project/DualParal/a6cd8a1fb8eaa3c703a13510c22245bb2deb662b/assets/gif1.gif
--------------------------------------------------------------------------------
/assets/gif2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DualParal-Project/DualParal/a6cd8a1fb8eaa3c703a13510c22245bb2deb662b/assets/gif2.gif
--------------------------------------------------------------------------------
/examples/DualParal_Wan.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
4 | os.environ["HF_DATASETS_CACHE"] = "../Checkpoint/"
5 | os.environ["HF_HOME"] = "../Checkpoint/"
6 | os.environ["HUGGINGFACE_HUB_CACHE"] = "../Checkpoint/"
7 | os.environ["TRANSFORMERS_CACHE"] = "../Checkpoint/"
8 | # sys.path.append('../../')
9 | import time
10 | import copy
11 | import argparse
12 | import numpy as np
13 | import multiprocessing
14 | from PIL import Image
15 | from tqdm import tqdm
16 | from datetime import datetime
17 |
18 | import torch
19 | import torch.distributed as dist
20 | from pytorch_lightning import seed_everything
21 | from diffusers.utils import export_to_video
22 | from diffusers import AutoencoderKLWan, WanPipeline
23 | from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
24 |
25 | from src.pipelines import DualParalWanPipeline, QueueLatents, Queue
26 | from src.distribution_utils import RuntimeConfig, DistController, export_to_images, memory_check
27 |
28 | def _parse_args():
29 | # For basic args
30 | parser = argparse.ArgumentParser(description="DualParal with Wan")
31 | parser.add_argument("--dtype", type=str, default="bf16",
32 | help="Model dtype (float64, float32, float16, fp32, fp16, half, bf16)")
33 | parser.add_argument("--seed", type=int, default=12345,
34 | help="The seed to use for generating the image or video.")
35 | parser.add_argument("--save_file", type=str, default="../results/",
36 | help="The file to save the generated image or video to.")
37 | parser.add_argument("--verbose", action="store_true", default=False,
38 | help="Enable verbose mode")
39 | parser.add_argument("--export_image", action="store_true", default=False,
40 | help="Enable exporting video frames.")
41 |
42 | # For Wan-Video model
43 | parser.add_argument("--model_id", type=str, default="Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
44 | help="Model Id for Wan-2.1 Video.")
45 | parser.add_argument("--height", type=int, default=480,
46 | help="Height of generating videos")
47 | parser.add_argument("--width", type=int, default=832,
48 | help="Width of generating videos")
49 | parser.add_argument("--sample_steps", type=int, default=50,
50 | help="The sampling steps.")
51 | parser.add_argument("--sample_shift", type=float, default=None,
52 | help="Sampling shift factor for flow matching schedulers.")
53 | parser.add_argument("--sample_guide_scale", type=float, default=5.0,
54 | help="Classifier free guidance scale.")
55 | parser.add_argument("--prompt", type=str, default="A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window.",
56 | help="The prompt to generate the image or video from.")
57 |
58 | # For DualParal
59 | parser.add_argument("--num_per_block", type=int, default=10,
60 | help="How many latents per block in DualParal.")
61 | parser.add_argument("--latents_num", type=int, default=30,
62 | help="How many latents to sample from an image or video. The total number of frames is equal to (latents_num-1)*4+1.")
63 | parser.add_argument("--num_cat", type=int, default=5,
64 | help="How many latents to concatenate from the previous and subsequent blocks, respectively.")
65 |
66 | args = parser.parse_args()
67 | return args
68 |
69 | def prepare_model(args, parallel_config, runtime_config):
70 | model_id = args.model_id
71 | vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
72 | flow_shift = 3.0 # 5.0 for 720P, 3.0 for 480P
73 | scheduler = UniPCMultistepScheduler(prediction_type='flow_prediction', use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=flow_shift)
74 | model = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=runtime_config.dtype)
75 | model.scheduler = scheduler
76 |
77 | Pipe = DualParalWanPipeline(model, parallel_config, runtime_config)
78 | Pipe.to_device(device=parallel_config.device, dtype=runtime_config.dtype, non_blocking=True)
79 |
80 | num_channels_latents = Pipe.model.transformer.config.in_channels
81 | args.height = int(args.height) // Pipe.model.vae_scale_factor_spatial
82 | args.width = int(args.width) // Pipe.model.vae_scale_factor_spatial
83 | return Pipe, num_channels_latents
84 |
85 | def update(ID, QueueWan, cnt_, Pipe, z):
86 | test = QueueWan.get(ID)
87 | scheduler = Pipe.get_scheduler_dict(QueueWan.begin + ID)
88 | z = scheduler.step(z, Pipe.timesteps[test.denoise_time-1], test.z, return_dict=False)[0]
89 | test.z.copy_(z, non_blocking=True)
90 |
91 | if ID + QueueWan.begin in cnt_:
92 | cnt_[ID + QueueWan.begin] += 1
93 | else:
94 | cnt_[ID + QueueWan.begin] = 1
95 |
96 | def main(args, rank, world_size):
97 | #---------------Model Preparation--------------------
98 | torch.set_grad_enabled(False)
99 | torch.backends.cudnn.enabled = True
100 | torch.backends.cudnn.benchmark = True
101 | seed_everything(args.seed)
102 | parallel_config = DistController(rank, world_size)
103 | runtime_config = RuntimeConfig(args.seed, args.dtype)
104 |
105 | prompt = args.prompt
106 | negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
107 |
108 | Pipe, out_channels = prepare_model(args, parallel_config, runtime_config)
109 | Pipe.get_model_args(
110 | prompt=prompt,
111 | num_inference_steps=args.sample_steps,
112 | negative_prompt=negative_prompt,
113 | guidance_scale=args.sample_guide_scale,
114 | )
115 | timesteps = Pipe.timesteps
116 | latent_size = (args.num_per_block, args.height, args.width)
117 | #---------------Warm Up--------------------
118 | print(f"[{Pipe.parallel_config.device}]--------------Warm Up-----------------")
119 | latent_size_warmup = (1, out_channels, args.num_per_block, args.height, args.width)
120 | Block_latents_size = None
121 | for iteration in range(1):
122 | warmup_block = QueueLatents(1, out_channels, latent_size).to(Pipe.parallel_config.device, Pipe.runtime_config.dtype, non_blocking=True)
123 | if Pipe.parallel_config.rank==0:
124 | output = Pipe.onestep(warmup_block, latent_size=warmup_block.z.size())
125 | Block_latents_size = output.size()
126 | Pipe.parallel_config.pipeline_send(tensor=output, dtype=Pipe.runtime_config.dtype, verbose=False)
127 | output = Pipe.parallel_config.pipeline_recv(dtype=Pipe.runtime_config.dtype, dimension=warmup_block.z.dim(), verbose=False)
128 | else:
129 | output = Pipe.parallel_config.pipeline_recv(dtype=Pipe.runtime_config.dtype, dimension=3, verbose=False)
130 | warmup_block.z = output
131 | Block_latents_size = output.size()
132 | output = Pipe.onestep(warmup_block, latent_size=latent_size_warmup)
133 | Pipe.parallel_config.pipeline_send(tensor=output, dtype=Pipe.runtime_config.dtype, verbose=False)
134 | del output
135 | print(f"[{Pipe.parallel_config.device}]-------------Warm Up End----------------------")
136 |
137 | #---------------DualParal--------------------
138 | cnt, video, cnt_ = 0, None, {} # cnt for counting total latents in queue, video for concatenating video latents
139 | Pipe.parallel_config.init_buffer(timesteps)
140 |
141 | QueueWan = Queue(num_per_block=args.num_per_block, num_cat=args.num_cat, lenth_of_Queue=len(timesteps))
142 | Block_ref = copy.deepcopy(QueueWan)
143 | latent_tmp = (args.num_per_block + args.num_cat, args.height, args.width)
144 | test = QueueLatents(1, out_channels, latent_tmp)
145 | Block_ref.add_block(test)
146 | if Pipe.parallel_config.rank>0:
147 | num_frames = Block_latents_size[1]//args.num_per_block*Block_ref.get_size(0)
148 | tensor_shape = torch.tensor([Block_latents_size[0], num_frames, Block_latents_size[2]], dtype=torch.int64).contiguous()
149 | Pipe.parallel_config.modify_recv_queue(iteration=-1, idx=0, dtype=Pipe.runtime_config.dtype, tensor_shape=tensor_shape, verbose=False)
150 |
151 | torch.cuda.synchronize()
152 | torch.cuda.empty_cache()
153 | start_time = time.time()
154 | template_z = torch.randn(1, out_channels, args.num_per_block + args.num_cat, args.height, args.width, pin_memory=True)
155 | for iteration in tqdm(range(len(timesteps)+args.latents_num//args.num_per_block)):
156 | Pipe.cache = []
157 | now_iteration = (iteration-1+2)%2
158 | prev_get = Pipe.parallel_config.init_get_start(now_iteration)
159 | if iteration == 1: prev_get = 0 # Corner Case
160 |
161 | # Change Time in Block Ref
162 | if Pipe.parallel_config.rank > 0:
163 | for idx in range(Block_ref.end-Block_ref.begin, -1, -1):
164 | test = Block_ref.get(idx)
165 | test.denoise_time += 1
166 | if Block_ref.check_first(): Block_ref.del_prev_first()
167 |
168 | # Add block
169 | add_block=False
170 | if cnt < args.latents_num:
171 | latent_tmp = latent_size
172 | if iteration==0:
173 | latent_tmp = (args.num_per_block + args.num_cat, args.height, args.width)
174 | select_indice = torch.randperm(template_z.size(2) - args.num_cat)
175 | template_z = torch.cat((template_z[:, :, -args.num_cat:], template_z[:, :, select_indice]), dim=2)
176 | test = QueueLatents(1, out_channels, latent_tmp, template_z).to(Pipe.parallel_config.device, Pipe.runtime_config.dtype, non_blocking=True)
177 | QueueWan.add_block(test)
178 | cnt += args.num_per_block
179 | add_block=True
180 | Pipe.add_scheduler(QueueWan.end)
181 | if cnt < args.latents_num:
182 | test = QueueLatents(1, out_channels, latent_size)
183 | Block_ref.add_block(test)
184 | if args.verbose: print(f"[{parallel_config.device}]--------DualParal in iteration-{iteration} is Begin with Queue from {QueueWan.begin} to {QueueWan.end}-----------")
185 | # Prepare for receiving latents between two GPUs
186 | if Pipe.parallel_config.rank > 0:
187 | for idx in range(Block_ref.end-Block_ref.begin, -1, -1):
188 | num = Block_latents_size[1]//args.num_per_block * Block_ref.get_size(idx)
189 | tensor_shape = torch.tensor([Block_latents_size[0], num, Block_latents_size[2]], dtype=torch.int64).contiguous()
190 | Pipe.parallel_config.modify_recv_queue(iteration, idx, dtype=Pipe.runtime_config.dtype, tensor_shape=tensor_shape, verbose=False)
191 |
192 | # Del First Block in QueueWan
193 | del_begin = 0
194 | if QueueWan.check_first():
195 | QueueWan.del_prev_first()
196 | ID = QueueWan.begin - 2
197 | if ID >= 0: Pipe.del_scheduler(ID)
198 | del_begin = 1
199 |
200 | for idx in range(QueueWan.end-QueueWan.begin, -1, -1):
201 | now_end = ( idx==0 and iteration==(len(timesteps)+args.latents_num//args.num_per_block-1) )
202 | get_next = False
203 | # Receving
204 | if Pipe.parallel_config.rank != 0:
205 | x = Pipe.parallel_config.recv_next(iteration-1, idx, queue_lenth=Block_ref.end-Block_ref.begin+1,
206 | force=True, end=(Block_ref.end-Block_ref.begin+1==0), verbose=False)
207 | input_block_tmp = QueueWan.get(idx)
208 | input_block = QueueLatents(1, out_channels, None, None)
209 | input_block.denoise_time = input_block_tmp.denoise_time
210 | input_block.z = x
211 | else:
212 | latent_size_tmp = latent_size_warmup
213 | if QueueWan.begin+idx == 0:
214 | latent_size_tmp = (1, out_channels, args.num_per_block + args.num_cat, args.height, args.width)
215 | tensor_shape = torch.tensor(latent_size_tmp, dtype=torch.int64).contiguous()
216 | Pipe.parallel_config.modify_recv_queue(iteration, idx, dtype=Pipe.runtime_config.dtype, tensor_shape=tensor_shape, verbose=False)
217 | force = False
218 | if prev_get >= 0:
219 | # Make sure the gap between communications is larger than world_size so that the output
220 | # has already come back from the last device
221 | if (abs(prev_get+1 + QueueWan.end-idx+int(add_block)) <= Pipe.parallel_config.world_size or iteration==1)\
222 | and idx == QueueWan.end-QueueWan.begin and add_block==True:
223 | get_next = True
224 | else:
225 | force = (now_iteration==(iteration+1)%2 and prev_get==idx)
226 | z = Pipe.parallel_config.recv_next(now_iteration, prev_get, queue_lenth=QueueWan.end-QueueWan.begin+1,
227 | force=force, end=now_end, verbose=False)
228 | if z is not None:
229 | get_next = True
230 | ID = prev_get - del_begin
231 | if now_iteration%2 == iteration%2:
232 | ID = prev_get
233 | update(ID, QueueWan, cnt_, Pipe, z)
234 | prev_get -= 1
235 | else: get_next = True
236 |
237 | if prev_get < 0:
238 | now_iteration = now_iteration^1
239 | prev_get = Pipe.parallel_config.init_get_start(now_iteration)
240 | input_block = QueueLatents(1, out_channels, None, None)
241 | L, R, latents_z, denoise_time = QueueWan.prepare_for_forward(idx)
242 | input_block.z = latents_z
243 | input_block.denoise_time = denoise_time
244 |
245 | size_frames = QueueWan.get_size(idx)
246 | latent_size_tmp = (latent_size_warmup[0], latent_size_warmup[1], size_frames, latent_size_warmup[3], latent_size_warmup[4])
247 | # Pipe.cache = []
248 | x = Pipe.onestep(input_block, latent_size=latent_size_tmp, latent_num=Block_latents_size[1], select=args.num_cat, select_all=args.num_per_block, verbose=False)
249 |
250 | if Pipe.parallel_config.rank != Pipe.parallel_config.world_size-1:
251 | # check receiving next
252 | if parallel_config.rank==0 and not get_next:
253 | z = Pipe.parallel_config.recv_next(now_iteration, prev_get, queue_lenth=QueueWan.end-QueueWan.begin+1,
254 | force=True, end=now_end, verbose=False)
255 | ID = prev_get - del_begin
256 | if now_iteration%2 == iteration%2:
257 | ID = prev_get
258 | update(ID, QueueWan, cnt_, Pipe, z)
259 | prev_get -= 1
260 | if prev_get < 0:
261 | now_iteration = now_iteration^1
262 | prev_get = Pipe.parallel_config.init_get_start(now_iteration)
263 | if args.verbose: print(f"[{Pipe.parallel_config.device}] ready to send X with size {x.size()}")
264 | Pipe.parallel_config.pipeline_isend(tensor=x, dtype=Pipe.runtime_config.dtype, verbose=False)
265 | else:
266 | size_latent = QueueWan.get_size(idx, itself=True)
267 | x = x[:, :, -size_latent:].clone().contiguous()
268 | if args.verbose: print(f"[{Pipe.parallel_config.device}] ready to send X with size {x.size()}, with sum {x.sum()}")
269 | Pipe.parallel_config.pipeline_isend(tensor=x, dtype=Pipe.runtime_config.dtype, verbose=False)
270 | # Update denoise_time in Queue
271 | input_block = QueueWan.get(idx)
272 | input_block.denoise_time += 1
273 |
274 | # Extract First Block
275 | if del_begin==1 and Pipe.parallel_config.rank==0:
276 | first_block = QueueWan.get(-1)
277 | z = Pipe.parallel_config.recv_next(iteration-1, idx=0, verbose=False, queue_lenth=QueueWan.end-QueueWan.begin+1,
278 | force=True, end=(iteration==(len(timesteps)+args.latents_num//args.num_per_block-1)))
279 | if z is not None:
280 | prev_get = -1
281 | cnt_[prev_get + QueueWan.begin] += 1
282 | scheduler = Pipe.get_scheduler_dict(QueueWan.begin - 1)
283 | z = scheduler.step(z, Pipe.timesteps[first_block.denoise_time-1], first_block.z, return_dict=False)[0]
284 | first_block.z = z
285 | video_ = first_block.z
286 | video = video_ if video is None else torch.cat((video, video_), dim=2)
287 | if args.verbose: print(f"[{Pipe.parallel_config.device}] Now {iteration} video size: ", video.size())
288 |
289 | torch.cuda.synchronize()
290 | print(f"[{Pipe.parallel_config.device}] Whole inference time {time.time()-start_time:.6f}s")
291 | if Pipe.parallel_config.rank==0:
292 | print("Video latents size: ", video.size())
293 | video = Pipe.get_video(video)[0]
294 | print(f"Final Video shape: {video.shape}")
295 | export_to_video(video, args.save_file+"output.mp4", fps=16)
296 | if args.export_image:
297 | export_to_images(video, args.save_file + "frames/")
298 |
299 | if __name__ == "__main__":
300 | args = _parse_args()
301 | multiprocessing.set_start_method('spawn')
302 | num_processes = torch.cuda.device_count()
303 | processes = []
304 |
305 | for rank in range(num_processes):
306 | p = multiprocessing.Process(target=main, args=(args, rank, num_processes))
307 | p.start()
308 | processes.append(p)
309 |
310 | for p in processes:
311 | p.join()
--------------------------------------------------------------------------------
/examples/Wan-Video.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
4 | os.environ["HF_DATASETS_CACHE"] = "../Checkpoint/"
5 | os.environ["HF_HOME"] = "../Checkpoint/"
6 | os.environ["HUGGINGFACE_HUB_CACHE"] = "../Checkpoint/"
7 | os.environ["TRANSFORMERS_CACHE"] = "../Checkpoint/"
8 | import time
9 | import torch
10 | from diffusers.utils import export_to_video
11 | from diffusers import AutoencoderKLWan, WanPipeline
12 | from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
13 |
14 | # Available models: Wan-AI/Wan2.1-T2V-14B-Diffusers, Wan-AI/Wan2.1-T2V-1.3B-Diffusers
15 | model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
16 | vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.bfloat16)
17 | flow_shift = 3.0 # 5.0 for 720P, 3.0 for 480P
18 | scheduler = UniPCMultistepScheduler(prediction_type='flow_prediction', use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=flow_shift)
19 | pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
20 | pipe.scheduler = scheduler
21 | pipe.to("cuda")
22 |
23 | prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
24 | negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
25 |
26 | start = time.time()
27 | output = pipe(
28 | prompt=prompt,
29 | negative_prompt=negative_prompt,
30 | height=480,
31 | width=832,
32 | num_frames=24,
33 | guidance_scale=5.0,
34 | ).frames[0]
35 | print(f"Using {time.time()-start:.6f}s")
36 | export_to_video(output, "output.mp4", fps=16)
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Minute-Long Videos with Dual Parallelisms
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 | ## 📚 TL;DR (Too Long; Didn't Read)
12 | **DualParal** is a distributed inference strategy for Diffusion Transformers (DiT)-based video diffusion models. It achieves high efficiency by parallelizing both temporal frames and model layers with the help of a *block-wise denoising scheme*.
13 | Feel free to visit our [paper](https://arxiv.org/abs/2505.21070) for more information.
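For intuition, here is a minimal single-process sketch of the block-wise denoising idea (toy code for illustration only; the actual distributed pipeline lives in `examples/DualParal_Wan.py` and `src/pipelines/base_pipeline.py`): new noisy blocks enter at the tail of a queue, every queued block advances by one denoising step per iteration, and fully denoised blocks leave from the head.
```python
from collections import deque

num_steps = 4       # denoising steps per block (cf. --sample_steps)
total_blocks = 3    # blocks to generate (cf. latents_num // num_per_block)

queue = deque()     # each entry: denoising steps taken so far by that block
added = finished = 0
while finished < total_blocks:
    if added < total_blocks:      # a fresh noisy block joins at the tail
        queue.append(0)
        added += 1
    for i in range(len(queue)):   # every block advances one step per iteration;
        queue[i] += 1             # in DualParal these steps run on different
                                  # pipeline stages (GPUs) at the same time
    if queue and queue[0] == num_steps:
        finished += 1             # head block is fully denoised: pop and decode it
        queue.popleft()
```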
14 |
15 | ## 🎥 Demo: more video samples on our [project page](https://dualparal-project.github.io/dualparal.github.io/)!
16 |
17 | ![](assets/gif1.gif)
18 |
19 | A white-suited astronaut with a gold visor spins in dark space, tethered by a drifting cable. Stars twinkle around him as Earth glows blue in the distance. His suit reflects faint starlight against the vastness of the cosmos.
20 |
21 | ![](assets/gif2.gif)
22 |
23 | A flock of birds glides through the warm sunset sky, wings outstretched. Their feathers catch golden light as they soar above silhouetted treetops, with the sky glowing in soft hues of amber and pink.
24 |
25 |
26 |
27 | ## 🛠️ Setup
28 | ```bash
29 | conda create -n DualParal python=3.10
30 | conda activate DualParal
31 | # Ensure torch >= 2.4.0 matching your CUDA version; the following uses CUDA 12.1 as an example
32 | pip install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/cu121
33 | pip install -r requirements.txt
34 | ```
35 |
36 | ## 🚀 Usage
37 | ### **Quick Start: DualParal on multiple GPUs with Wan2.1-1.3B (480p)**
38 | ```bash
39 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m examples.DualParal_Wan --sample_steps 50 --num_per_block 8 --latents_num 40 --num_cat 8
40 | ```
41 |
42 | ### **Major parameters**
43 | - **Basic Args**
44 |
45 | | Parameter | Description |
46 | | ----------- | -------------------------------------- |
47 | | `dtype` | Model dtype (float64, float32, float16, fp32, fp16, half, bf16) |
48 | | `seed` | The seed to use for generating the video. |
49 | | `save_file` | The file to save the generated video to. |
50 | | `verbose` | Enable verbose mode for debug. |
51 | | `export_image` | Enable exporting video frames. |
52 |
53 | - **Model Args**
54 |
55 | | Parameter | Description |
56 | | ----------- | -------------------------------------- |
57 | | `model_id` | Model Id for Wan-2.1 Video (Wan-AI/Wan2.1-T2V-1.3B-Diffusers, or Wan-AI/Wan2.1-T2V-14B-Diffusers). |
58 | | `height` | Height of the generated video. |
59 | | `width` | Width of the generated video. |
60 | | `sample_steps` | The number of sampling steps. |
61 | | `sample_shift` | Sampling shift factor for flow-matching schedulers. |
62 | | `sample_guide_scale` | Classifier-free guidance scale. |
63 |
64 | - **Major Args for DualParal**
65 |
66 | | Parameter | Description |
67 | | ----------- | -------------------------------------- |
68 | | `prompt` | The prompt to generate the video from. |
69 | | `num_per_block` | The number of latents per block in DualParal. |
70 | | `latents_num` | The total number of latents sampled for the video. `latents_num` **must** be divisible by `num_per_block`. The total number of video frames is (`latents_num` - 1) $\times$ 4 + 1 (see the worked example below). |
71 | | `num_cat` | The number of latents to concatenate from the previous and subsequent blocks, respectively. Increasing it (up to at most `num_per_block`) leads to better global consistency and temporal coherence. Note that $Num_C$ in the paper is equal to 2 $\times$ `num_cat`. |
72 |
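A quick worked example with the Quick Start settings (`--num_per_block 8 --latents_num 40 --num_cat 8`); the assertions simply restate the constraints from the table above and are not checks performed by the script:
```python
num_per_block, latents_num, num_cat = 8, 40, 8

assert latents_num % num_per_block == 0    # latents_num must be divisible by num_per_block
assert num_cat <= num_per_block            # num_cat should not exceed num_per_block

total_frames = (latents_num - 1) * 4 + 1   # 157 frames in the final video
num_blocks = latents_num // num_per_block  # 5 blocks in the DualParal queue
print(total_frames, num_blocks)            # 157 5
```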
73 | ### Further experiments
74 | - **Original Wan implementation with single GPU**
75 | ```bash
76 | python examples/Wan-Video.py
77 | ```
78 |
79 | - **DualParal on multiple GPUs with Wan2.1-14B (720p)**
80 | ```bash
81 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m examples.DualParal_Wan --model_id Wan-AI/Wan2.1-T2V-14B-Diffusers --height 720 --width 1280 --sample_steps 50 --num_per_block 8 --latents_num 40 --num_cat 8
82 | ```
83 |
84 | ## ☀️ Acknowledgements
85 | Our project is based on the [Wan2.1](https://github.com/Wan-Video/Wan2.1) model. We would like to thank the authors for their excellent work! ❤️
86 |
87 | ## 🔗 Citation
88 | ```
89 | @misc{wang2025minutelongvideosdualparallelisms,
90 | title={Minute-Long Videos with Dual Parallelisms},
91 | author={Zeqing Wang and Bowen Zheng and Xingyi Yang and Zhenxiong Tan and Yuecong Xu and Xinchao Wang},
92 | year={2025},
93 | eprint={2505.21070},
94 | archivePrefix={arXiv},
95 | primaryClass={cs.CV},
96 | url={https://arxiv.org/abs/2505.21070},
97 | }
98 | ```
99 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=2.4.0
2 | torchvision>=0.19.0
3 | opencv-python>=4.9.0.80
4 | diffusers>=0.31.0
5 | transformers>=4.49.0
6 | tokenizers>=0.20.3
7 | accelerate>=1.1.1
8 | tqdm
9 | imageio
10 | easydict
11 | ftfy
12 | dashscope
13 | imageio-ffmpeg
14 | packaging
15 | ninja
16 | gradio>=5.0.0
17 | numpy>=1.23.5,<2
18 | pytorch_lightning
19 | flash-attn
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DualParal-Project/DualParal/a6cd8a1fb8eaa3c703a13510c22245bb2deb662b/src/__init__.py
--------------------------------------------------------------------------------
/src/distribution_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import numpy as np
4 | from PIL import Image
5 | import torch
6 | import torch.distributed as dist
7 |
8 | class RuntimeConfig():
9 | def __init__(self,
10 | seed: int = 42,
11 | dtype: torch.dtype = torch.float16
12 | ):
13 | self.seed = seed
14 | self.dtype = self.to_torch_dtype(dtype)
15 |
16 | def to_torch_dtype(self, dtype):
17 | if isinstance(dtype, torch.dtype):
18 | return dtype
19 | elif isinstance(dtype, str):
20 | dtype_mapping = {
21 | "float64": torch.float64,
22 | "float32": torch.float32,
23 | "float16": torch.float16,
24 | "fp32": torch.float32,
25 | "fp16": torch.float16,
26 | "half": torch.float16,
27 | "bf16": torch.bfloat16,
28 | }
29 | if dtype not in dtype_mapping:
30 | raise ValueError
31 | dtype = dtype_mapping[dtype]
32 | return dtype
33 | else:
34 | raise ValueError
35 |
36 | class DistController(object):
37 | def __init__(self, rank: int, world_size: int) -> None:
38 | super().__init__()
39 | self.rank = rank
40 | self.world_size = world_size
41 | self.device = torch.device(f"cuda:{self.rank}")
42 |
43 | self.prev = self.rank-1 if self.rank-1>=0 else self.world_size-1
44 | self.next = self.rank+1 if self.rank+1<self.world_size else 0
45 |
46 | self.buffer_recv = {0: [], 1: []}
47 | self.recv_queue = {0: [], 1: []}
48 | self.gpu_group_send, self.gpu_group_receive = None, None
49 |
50 | self.init_dist()
51 | if self.world_size>1: self.init_group()
52 |
53 | def init_get_start(self, iteration):
54 | for i in range(len(self.buffer_recv[iteration])-1, -1, -1):
55 | if self.recv_queue[iteration][i] is not None:
56 | return i
57 | return -1
58 |
59 | def init_buffer(self, timesteps):
60 | for i in range(len(timesteps)+1):
61 | for key in self.buffer_recv:
62 | self.buffer_recv[key].append(None)
63 | self.recv_queue[key].append(None)
64 | # print(f"[{self.device}] Init Buffer----")
65 |
66 | def modify_recv_queue(self, iteration, idx, tensor_shape=None, dtype=torch.float16, verbose=False):
67 | iteration = (iteration+2)%2
68 |
69 | assert self.buffer_recv[iteration][idx] is None, f"[{self.device}] Buffer_Queue at iteration {iteration} and index {idx} is not None!"
70 | assert self.recv_queue[iteration][idx] is None, f"[{self.device}] Recv_Queue at iteration {iteration} and index {idx} is not None!"
71 |
72 | tensor = torch.empty(tensor_shape.tolist(), device="cpu", dtype=dtype, pin_memory=True)
73 | self.buffer_recv[iteration][idx] = tensor
74 | if verbose: print(f"[{self.device}] idx {idx} on iteration {iteration} modify with size {tensor_shape}")
75 |
76 | def recv_next(self, iteration, idx, queue_lenth=0, force=False, end=False, TIMEOUT=50, verbose=False):
77 | '''
78 | queue_lenth: current length of the queue
79 | '''
80 | iteration = (iteration+2)%2
81 | req = self.recv_queue[iteration][idx]
82 | if req is None:
83 | if self.buffer_recv[iteration][idx] is None:
84 | # no tensor in buffer_recv means that 'idx' has already been received
85 | return None
86 | else:
87 | self.buffer_recv[iteration][idx] = self.buffer_recv[iteration][idx].to(self.device, non_blocking=True).contiguous()
88 | req = self._pipeline_irecv(self.buffer_recv[iteration][idx])
89 |
90 | if verbose: start_time = time.time()
91 | if not req.is_completed() and not force:
92 | return None
93 | req.wait()
94 | ans = self.buffer_recv[iteration][idx]
95 |
96 | self.recv_queue[iteration][idx] = None
97 | self.buffer_recv[iteration][idx] = None
98 | if verbose: print(f"[{self.device}] Request status of idx{idx} after wait:", req.is_completed(), " with size: ", ans.size(), f" with time: {time.time()-start_time:.6f}s")
99 | if end: return ans
100 |
101 | if verbose: start = time.time()
102 | next_id = idx-1 #previous one in backward order
103 | if next_id < 0:
104 | next_id = 0 if queue_lenth==0 else (next_id + queue_lenth)%queue_lenth
105 | self.buffer_recv[iteration^1][next_id] = self.buffer_recv[iteration^1][next_id].to(self.device, non_blocking=True)
106 | self.recv_queue[iteration^1][next_id] = self._pipeline_irecv(self.buffer_recv[iteration^1][next_id])
107 | tmp = self.recv_queue[iteration^1][next_id]
108 | else:
109 | self.buffer_recv[iteration][next_id] = self.buffer_recv[iteration][next_id].to(self.device, non_blocking=True)
110 | self.recv_queue[iteration][next_id] = self._pipeline_irecv(self.buffer_recv[iteration][next_id],)
111 | tmp = self.recv_queue[iteration][next_id]
112 | if verbose: print(f"[{self.device}] Stream status {tmp.is_completed()} receiving {time.time()-start:.6f}s")
113 | return ans
114 |
115 | def pipeline_isend(self, tensor, dtype, verbose=False) -> None:
116 | tensor_shape = tensor.size()
117 | if not tensor.is_contiguous():
118 | tensor = tensor.contiguous()
119 | if tensor.dtype != dtype:
120 | tensor = tensor.to(dtype)
121 | req = self._pipeline_isend(tensor)
122 |
123 |
124 | def _pipeline_irecv(self, tensor: torch.tensor):
125 | return torch.distributed.irecv(
126 | tensor,
127 | src=self.prev,
128 | group=self.gpu_group_receive,
129 | )
130 |
131 | def _pipeline_isend(self, tensor: torch.tensor):
132 | return torch.distributed.isend(
133 | tensor,
134 | dst=self.next,
135 | group=self.gpu_group_send
136 | )
137 |
138 | def init_dist(self):
139 | torch.cuda.set_device(self.device)
140 | print(f"Rank {self.rank} (world size {self.world_size}, with prev {self.prev} and next {self.next}) is running.")
141 | os.environ['MASTER_ADDR'] = '127.0.0.1'
142 | os.environ['MASTER_PORT'] = os.getenv('MASTER_PORT', '29500')
143 | dist.init_process_group("nccl", rank=self.rank, world_size=self.world_size)
144 |
145 | def init_group(self):
146 | group = list(range(self.world_size))
147 | if len(group) > 2 or len(group) == 1:
148 | device_group = torch.distributed.new_group(group, backend="nccl")
149 | self.gpu_group_receive = device_group
150 | self.gpu_group_send = device_group
151 | elif len(group) == 2:
152 | # when pipeline parallelism is 2, we need to create two groups to avoid
153 | # communication stall.
154 | # *_group_0_1 represents the group for communication from device 0 to
155 | # device 1.
156 | # *_group_1_0 represents the group for communication from device 1 to
157 | # device 0.
158 | device_group_0_1 = torch.distributed.new_group(group, backend="nccl")
159 | device_group_1_0 = torch.distributed.new_group(group, backend="nccl")
160 | groups = [device_group_0_1, device_group_1_0]
161 | self.gpu_group_send = groups[self.rank]
162 | self.gpu_group_receive = groups[(self.rank+1)%2]
163 |
164 | def pipeline_send(self, tensor: torch.Tensor, dtype, verbose=False) -> None:
165 | if verbose: start_time = time.time()
166 |
167 | tensor_shape = torch.tensor(tensor.size(), device=tensor.device, dtype=torch.int64).contiguous()
168 | if verbose: print(f"[{self.device}] Send size {tensor_shape} with dimension {tensor.dim()}")
169 | self._pipeline_isend(tensor_shape).wait()
170 | if verbose: print(f"[{self.device}] Success Send size {tensor_shape}")
171 | tensor = tensor.contiguous().to(dtype)
172 | self._pipeline_isend(tensor).wait()
173 |
174 | if verbose: print(f"[{self.device}] Sending Tensor with shape {tensor.size()} ({tensor.device}, {tensor.dtype}, {tensor.sum()}) in {time.time()-start_time:.6f}s")
175 | del tensor_shape
176 | del tensor
177 |
178 | def pipeline_recv(self, dtype, dimension=3, verbose=False) -> torch.Tensor:
179 | if verbose: start_time = time.time()
180 |
181 | tensor_shape = torch.empty(dimension, device=self.device, dtype=torch.int64)
182 | if verbose: print(f"[{self.device}] Ready size {tensor_shape} with dimension {dimension}")
183 | self._pipeline_irecv(tensor_shape).wait()
184 | if verbose: print(f"[{self.device}] Got size {tensor_shape}")
185 | tensor = torch.empty(tensor_shape.tolist(), device=self.device, dtype=dtype) # assume the data type is float32; adjust to the appropriate type
186 | self._pipeline_irecv(tensor).wait()
187 |
188 | del tensor_shape
189 | if verbose: print(f"[{self.device}] Receiving Tensor with shape {tensor.shape}(dtype {dtype}, sum {tensor.sum()}) in {time.time()-start_time:.6f}s")
190 | return tensor
191 |
192 | def export_to_images(video_frames, output_dir: str = None,):
193 | os.makedirs(output_dir, exist_ok=True)
194 |
195 | output_paths = []
196 |
197 | if isinstance(video_frames[0], np.ndarray):
198 | video_frames = [(frame * 255).astype(np.uint8) for frame in video_frames]
199 |
200 | elif isinstance(video_frames[0], Image.Image):
201 | video_frames = [np.array(frame) for frame in video_frames]
202 |
203 | for i, frame in enumerate(video_frames):
204 | image_path = os.path.join(output_dir, f"frame_{i:04d}.png")
205 | Image.fromarray(frame).save(image_path)
206 | output_paths.append(image_path)
207 |
208 | return output_paths
209 |
210 | def memory_check(device, info = ""):
211 | allocated_memory = torch.cuda.memory_allocated()
212 | cached_memory = torch.cuda.memory_reserved()
213 | max_allocated_memory = torch.cuda.max_memory_allocated()
214 | max_cached_memory = torch.cuda.max_memory_reserved()
215 |
216 | print(f"[{device}] {info} Allocated Memory: {allocated_memory / (1024 ** 2):.2f} MB")
217 | print(f"[{device}] {info} Cached Memory: {cached_memory / (1024 ** 2):.2f} MB")
218 | print(f"[{device}] {info} Max Allocated Memory: {max_allocated_memory / (1024 ** 2):.2f} MB")
219 | print(f"[{device}] {info} Max Cached Memory: {max_cached_memory / (1024 ** 2):.2f} MB")
220 |
221 |
--------------------------------------------------------------------------------
/src/pipelines/__init__.py:
--------------------------------------------------------------------------------
1 | from src.pipelines.pipeline_Wan import DualParalWanPipeline
2 | from src.pipelines.base_pipeline import Queue, QueueLatents, DualParalPipelineBaseWrapper
3 | __all__ = [
4 | "Queue",
5 | "QueueLatents",
6 | "DualParalPipelineBaseWrapper",
7 | "DualParalWanPipeline",
8 | ]
--------------------------------------------------------------------------------
/src/pipelines/base_pipeline.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from src.distribution_utils import DistController, RuntimeConfig
3 |
4 | class QueueLatents(object):
5 | def __init__(self,
6 | batch_size,
7 | out_channels,
8 | latent_size,
9 | z = None
10 | ):
11 | self.z = torch.randn(batch_size, out_channels, *latent_size, pin_memory=True) if latent_size is not None else None
12 | self.z = z[:, :, -latent_size[0]:].clone() if z is not None else self.z
13 | self.denoise_time = 0
14 |
15 | def to(self, *args, **kwargs):
16 | self.z = self.z.to(*args, **kwargs).contiguous() if self.z is not None else None
17 | return self
18 |
19 | class Queue(object):
20 | def __init__(self, num_per_block=10, num_cat=5, lenth_of_Queue=50):
21 | self.num_per_block = num_per_block
22 | self.num_cat = num_cat
23 | self.lenth_of_Queue = lenth_of_Queue
24 | self.begin, self.end = 0, -1
25 | self.queue = []
26 |
27 | def get(self, idx):
28 | return self.queue[idx+self.begin]
29 |
30 | def get_size(self, idx, itself=False):
31 | idx = self.begin+idx
32 | if itself:
33 | return self.queue[idx].z.size(2)
34 |
35 | prev = self.queue[idx-1] if idx-1>=0 else None
36 | L, R = 0, self.queue[idx].z.size(2)
37 | if prev is not None:
38 | R += self.num_cat
39 | return R
40 |
41 | def prepare_for_forward(self, idx):
42 | idx = self.begin + idx
43 | prev = self.queue[idx - 1] if idx-1 >=0 else None
44 | now = self.queue[idx].z.clone()
45 | L, R = 0, now.size(2)
46 |
47 |
48 | if prev is not None and self.num_cat > 0:
49 | now = torch.cat((prev.z[:,:,-self.num_cat:], now), dim=2)
50 | now = now[:, :, -self.num_cat - now.size(2):]
51 | L, R = L + self.num_cat, R + self.num_cat
52 |
53 | return L, R, now, self.queue[idx].denoise_time
54 |
55 | def update(self, idx, block):
56 | self.queue[idx + self.begin].z.copy_(block.z, non_blocking=True)
57 |
58 | def add_block(self, block):
59 | self.queue.append(block)
60 | self.end += 1
61 |
62 | def check_first(self, is_first=False):
63 | if self.begin>self.end:
64 | return False
65 | first_block = self.queue[self.begin]
66 | return first_block.denoise_time==(self.lenth_of_Queue - int(is_first==True))
67 |
68 | def del_prev_first(self):
69 | '''
70 | Called after check_first()
71 | Advances self.begin by 1
72 | '''
73 | if self.begin-1>=0 and self.queue[self.begin-1] is not None:
74 | self.queue[self.begin-1] = None
75 | self.begin += 1
76 |
77 | def print_queue(self, device):
78 | print(f"[{device}] LOOK Queue: ", end=" ")
79 | for i in range(self.begin, self.end+1):
80 | print(self.queue[i].denoise_time, end=", ")
81 | print()
82 |
83 | class DualParalPipelineBaseWrapper(object):
84 | def __init__(
85 | self,
86 | parallel_config: DistController,
87 | runtime_config: RuntimeConfig,
88 | ):
89 | self.runtime_config = runtime_config
90 | self.parallel_config = parallel_config
91 |
92 | def _split_transformer_backbone(self):
93 | world_size = self.parallel_config.world_size
94 | local_rank = self.parallel_config.rank
95 |
96 | lenth_of_blocks = len(self.transformer)
97 | lenth_of_blocks_per_device = lenth_of_blocks//world_size
98 | mod_of_blocks_per_device = lenth_of_blocks%world_size
99 | range_of_block = {}
100 | start_q, end_q = 0, lenth_of_blocks_per_device
101 | for i in range(world_size):
102 | end_q = end_q + (mod_of_blocks_per_device>0)
103 | range_of_block[i] = (start_q, end_q)
104 | start_q = end_q
105 | end_q = end_q +lenth_of_blocks_per_device
106 | mod_of_blocks_per_device -= 1
107 |
108 | start_range, final_range = range_of_block[local_rank][0], range_of_block[local_rank][1]
109 | if self.parallel_config.rank==self.parallel_config.world_size-1:
110 | final_range = lenth_of_blocks
111 | range_of_blocks = range(
112 | start_range, final_range
113 | )
114 | self.transformer_ = self.transformer[range_of_blocks.start:range_of_blocks.stop]
115 | del self.transformer
116 | self.transformer = self.transformer_
117 | del self.transformer_
118 |
119 | def forward(self):
120 | pass
121 |
122 | def to(self, *args, **kwargs):
123 | pass
--------------------------------------------------------------------------------
/src/pipelines/pipeline_Wan.py:
--------------------------------------------------------------------------------
1 | import time
2 | import copy
3 | from typing import Any, Callable, Dict, List, Optional, Union, Tuple
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from diffusers import WanPipeline
9 | from diffusers.models.attention_processor import Attention
10 | from diffusers.models.transformers.transformer_wan import WanTransformerBlock
11 |
12 | from src.pipelines.base_pipeline import DualParalPipelineBaseWrapper
13 | from src.distribution_utils import DistController, RuntimeConfig
14 |
15 | class DualParal_WanAttnProcessor:
16 | def __init__(self):
17 | if not hasattr(F, "scaled_dot_product_attention"):
18 | raise ImportError("DualParal_WanAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
19 |
20 | def __call__(
21 | self,
22 | attn: Attention,
23 | hidden_states: torch.Tensor,
24 | encoder_hidden_states: Optional[torch.Tensor] = None,
25 | attention_mask: Optional[torch.Tensor] = None,
26 | rotary_emb: Optional[torch.Tensor] = None,
27 | cache = None,
28 | latent_num = None,
29 | select = None,
30 | select_all = None,
31 | out_dim = None
32 | ) -> torch.Tensor:
33 | encoder_hidden_states_img = None
34 | if attn.add_k_proj is not None:
35 | encoder_hidden_states_img = encoder_hidden_states[:, :257]
36 | encoder_hidden_states = encoder_hidden_states[:, 257:]
37 | if encoder_hidden_states is None:
38 | encoder_hidden_states = hidden_states
39 |
40 | use_cache = False
41 | cache2, cache3 = None, None
42 | query = attn.to_q(hidden_states)
43 | key = attn.to_k(encoder_hidden_states)
44 | value = attn.to_v(encoder_hidden_states)
45 | if latent_num is not None:
46 | cache2, cache3 = key[:, -latent_num:], value[:, -latent_num:]
47 | cutoff = latent_num//select_all*select
48 | cache2, cache3 = cache2[:, :cutoff].clone(), cache3[:, :cutoff].clone()
49 |
50 | if cache is not None and select>0:
51 | use_cache = True
52 | k_, v_ = cache[0], cache[1]
53 | key = torch.cat((key, k_), dim=1)
54 | value = torch.cat((value, v_), dim=1)
55 | cache = None
56 | if cache2 is not None and select>0:
57 | cache = (cache2, cache3)
58 | else:
59 | cache = None
60 | if attn.norm_q is not None:
61 | query = attn.norm_q(query)
62 | if attn.norm_k is not None:
63 | key = attn.norm_k(key)
64 |
65 | query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
66 | key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
67 | value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
68 |
69 | if rotary_emb is not None:
70 | def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
71 | x_rotated = torch.view_as_complex(hidden_states.to(torch.float64).unflatten(3, (-1, 2)))
72 | x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
73 | return x_out.type_as(hidden_states)
74 | query = apply_rotary_emb(query, rotary_emb[0])
75 | key = apply_rotary_emb(key, rotary_emb[1])
76 |
77 | L, S = query.size(-2), key.size(-2)
78 |
79 | attn_mask = None
80 | hidden_states = F.scaled_dot_product_attention(
81 | query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False,
82 | )
83 |
84 | hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
85 | if latent_num is not None:
86 | hidden_states = hidden_states[:, :out_dim] #delete cache
87 |
88 | hidden_states = hidden_states.type_as(query)
89 |
90 | hidden_states = attn.to_out[0](hidden_states)
91 | hidden_states = attn.to_out[1](hidden_states)
92 | return hidden_states, cache
93 |
94 | class DualParal_WanTransformerBlock(nn.Module):
95 | def __init__(
96 | self,
97 | block: WanTransformerBlock,
98 | ):
99 | super().__init__()
100 | self.block = block
101 |
102 | @torch.no_grad()
103 | def forward(
104 | self,
105 | hidden_states: torch.Tensor,
106 | encoder_hidden_states: torch.Tensor,
107 | temb: torch.Tensor,
108 | rotary_emb: torch.Tensor,
109 | cache: Optional[Tuple] = None,
110 | latent_num = None,
111 | select = None,
112 | select_all = None,
113 | attention_mask = None,
114 | ) -> torch.Tensor:
115 | shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
116 | self.block.scale_shift_table + temb.float()
117 | ).chunk(6, dim=1)
118 | # 1. Self-attention
119 | norm_hidden_states = (self.block.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
120 | attn_output, cache = self.block.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb, cache=cache, \
121 | latent_num=latent_num, select=select, select_all=select_all, attention_mask=attention_mask, out_dim=hidden_states.size(1))
122 | hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
123 |
124 | # 2. Cross-attention
125 | norm_hidden_states = self.block.norm2(hidden_states.float()).type_as(hidden_states)
126 | attn_output, _ = self.block.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states, \
127 | latent_num=None, select=select, select_all=select_all, attention_mask=attention_mask)
128 | hidden_states = hidden_states + attn_output
129 |
130 | # 3. Feed-forward
131 | norm_hidden_states = (self.block.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
132 | hidden_states
133 | )
134 | ff_output = self.block.ffn(norm_hidden_states)
135 | hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
136 | return hidden_states, cache
137 |
138 | class DualParalWanPipeline(DualParalPipelineBaseWrapper):
139 | def __init__(
140 | self,
141 | model: WanPipeline,
142 | parallel_config: DistController,
143 | runtimeconfig: RuntimeConfig,
144 | ):
145 | DualParalPipelineBaseWrapper.__init__(self, parallel_config, runtimeconfig)
146 | device, dtype = self.parallel_config.device, self.runtime_config.dtype
147 |
148 | self.tokenizer = model.tokenizer
149 | self.text_encoder = model.text_encoder
150 | self.transformer = model.transformer.blocks # Only the DiT blocks
151 | self.vae = model.vae
152 | self.scheduler = model.scheduler
153 | self.scheduler_dict = {}
154 | self.model = model # All other functions and properties are accessed through self.model
155 | self.cache = []
156 | self.attention_mask = []
157 | self.tmp = None
158 |
159 | pretransformer, finaltransformer = False, False
160 | if self.parallel_config.world_size==1:
161 | pretransformer, finaltransformer = True, True
162 | elif self.parallel_config.rank==0:
163 | pretransformer, finaltransformer = True, False
164 | elif self.parallel_config.rank>0 and self.parallel_config.rank<self.parallel_config.world_size-1:
165 | pretransformer, finaltransformer = False, False
166 | else:
167 | pretransformer, finaltransformer = False, True
206 | if select is not None and select>0:
207 | # for position
208 | num_frames += select
209 | latent_size = torch.Size([batch_size, num_channels, num_frames, height, width])
210 |
211 | self.tmp = torch.empty(*latent_size).to(self.parallel_config.device, self.runtime_config.dtype)
212 | rotary_emb_2 = self.model.transformer.rope(self.tmp)
213 | rotary_emb = (rotary_emb_1, rotary_emb_2)
214 |
215 | temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.model.transformer.condition_embedder(
216 | t, self.encoder_hidden_states, self.encoder_hidden_states_image
217 | )
218 | timestep_proj = timestep_proj.unflatten(1, (6, -1))
219 |
220 | if encoder_hidden_states_image is not None:
221 | encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
222 |
223 |
224 | if self.pretransformer:
225 | z = torch.cat([z, z], 0)
226 | z = self.model.transformer.patch_embedding(z)
227 | z = z.flatten(2).transpose(1, 2)
228 |
229 | hidden_states = z.contiguous()
230 | use_cache_ = (len(self.cache) > 0)
231 | for now_idx, block in enumerate(self.transformer):
232 | cache = None if not use_cache_ else self.cache[now_idx]
233 | size_tmp = hidden_states.size()
234 | hidden_states, cache = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, cache=cache,\
235 | latent_num=latent_num, select=select, select_all=select_all, attention_mask=self.attention_mask)
236 | if not use_cache_:
237 | self.cache.append(cache)
238 | else:
239 | self.cache[now_idx] = cache
240 |
241 | if self.finaltransformer:
242 | shift, scale = (self.model.transformer.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
243 | shift = shift.to(hidden_states.device)
244 | scale = scale.to(hidden_states.device)
245 | hidden_states = hidden_states.contiguous()
246 | hidden_states = (self.model.transformer.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
247 | hidden_states = self.model.transformer.proj_out(hidden_states)
248 | hidden_states = hidden_states.reshape(
249 | batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
250 | )
251 | hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
252 | noise_pred = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
253 | noise_pred = noise_pred
254 | noise_pred, noise_uncond = noise_pred[0], noise_pred[1]
255 | noise_pred = noise_uncond + self.guidance_scale * (noise_pred - noise_uncond)
256 | noise_pred = noise_pred.unsqueeze(0)
257 | if verbose: print(f"[{self.parallel_config.device}] Final (use cache {use_cache_}) time {time.time()-start_time:.6f}s")
258 | return noise_pred
259 | else:
260 | if verbose: print(f"[{self.parallel_config.device}] Final (use cache {use_cache_}) time {time.time()-start_time:.6f}s")
261 | return hidden_states
262 |
263 | def get_model_args(
264 | self,
265 | prompt: Union[str, List[str]] = None,
266 | negative_prompt: Union[str, List[str]] = None,
267 | num_inference_steps: int = 50,
268 | guidance_scale: float = 5.0,
269 | num_videos_per_prompt: Optional[int] = 1,
270 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
271 | latents: Optional[torch.Tensor] = None,
272 | prompt_embeds: Optional[torch.Tensor] = None,
273 | negative_prompt_embeds: Optional[torch.Tensor] = None,
274 | do_classifier_free_guidance: bool = True,
275 | attention_kwargs: Optional[Dict[str, Any]] = None,
276 | max_sequence_length: int = 512,
277 | ):
278 | device, dtype = self.parallel_config.device, self.runtime_config.dtype
279 | self.guidance_scale = guidance_scale
280 |
281 | # 1. Get Prompt Embedding
282 | if prompt is not None and isinstance(prompt, str):
283 | batch_size = 1
284 | elif prompt is not None and isinstance(prompt, list):
285 | batch_size = len(prompt)
286 | else:
287 | batch_size = prompt_embeds.shape[0]
288 | prompt_embeds, negative_prompt_embeds = self.model.encode_prompt(
289 | prompt=prompt,
290 | negative_prompt=negative_prompt,
291 | do_classifier_free_guidance=do_classifier_free_guidance,
292 | num_videos_per_prompt=num_videos_per_prompt,
293 | prompt_embeds=prompt_embeds,
294 | negative_prompt_embeds=negative_prompt_embeds,
295 | max_sequence_length=max_sequence_length,
296 | device=device,
297 | dtype=dtype,
298 | )
299 | prompt_embeds = prompt_embeds.to(device, dtype, non_blocking=True)
300 | negative_prompt_embeds = negative_prompt_embeds.to(device, dtype, non_blocking=True) if negative_prompt_embeds is not None else None
301 | if do_classifier_free_guidance:
302 | self.encoder_hidden_states = torch.cat([prompt_embeds, negative_prompt_embeds], dim=0)
303 | else:
304 | self.encoder_hidden_states = prompt_embeds
305 | self.encoder_hidden_states_image = None
306 |
307 | del self.text_encoder
308 | del self.tokenizer
309 | del self.model.text_encoder
310 | del self.model.tokenizer
311 | torch.cuda.empty_cache()
312 |
313 | # 2. Prepare Timesteps
314 | self.scheduler.set_timesteps(num_inference_steps, device=device)
315 | self.timesteps = self.scheduler.timesteps
316 |
317 | def to_device(self, device, dtype, **kwargs):
318 | self.model = self.model.to(device, dtype, **kwargs) if self.model is not None else None
319 | self.vae = self.vae.to(dtype)
320 | return self
321 |
322 | def add_scheduler(self, idx):
323 | self.scheduler_dict[idx] = copy.deepcopy(self.scheduler)
324 |
325 | def del_scheduler(self, idx):
326 | del self.scheduler_dict[idx]
327 |
328 | def get_scheduler_dict(self, idx):
329 | return self.scheduler_dict[idx]
330 |
331 | @torch.no_grad()
332 | def get_video(
333 | self,
334 | latents,
335 | output_type: Optional[str] = "np",
336 | verbose=False):
337 | latents = latents.to(self.vae.dtype)
338 | latents_mean = (
339 | torch.tensor(self.vae.config.latents_mean)
340 | .view(1, self.vae.config.z_dim, 1, 1, 1)
341 | .to(latents.device, latents.dtype, non_blocking=True)
342 | )
343 | latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
344 | latents.device, latents.dtype
345 | )
346 | latents = latents / latents_std + latents_mean
347 | video = self.vae.decode(latents, return_dict=False)[0]
348 | video = self.model.video_processor.postprocess_video(video, output_type=output_type)
349 | return video
--------------------------------------------------------------------------------