├── assets ├── gif1.gif └── gif2.gif ├── examples ├── DualParal_Wan.py └── Wan-Video.py ├── readme.md ├── requirements.txt └── src ├── __init__.py ├── distribution_utils.py └── pipelines ├── __init__.py ├── base_pipeline.py └── pipeline_Wan.py /assets/gif1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DualParal-Project/DualParal/a6cd8a1fb8eaa3c703a13510c22245bb2deb662b/assets/gif1.gif -------------------------------------------------------------------------------- /assets/gif2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DualParal-Project/DualParal/a6cd8a1fb8eaa3c703a13510c22245bb2deb662b/assets/gif2.gif -------------------------------------------------------------------------------- /examples/DualParal_Wan.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | os.environ["HF_ENDPOINT"] = "https://hf-mirror.com" 4 | os.environ["HF_DATASETS_CACHE"] = "../Checkpoint/" 5 | os.environ["HF_HOME"] = "../Checkpoint/" 6 | os.environ["HUGGINGFACE_HUB_CACHE"] = "../Checkpoint/" 7 | os.environ["TRANSFORMERS_CACHE"] = "../Checkpoint/" 8 | # sys.path.append('../../') 9 | import time 10 | import copy 11 | import argparse 12 | import numpy as np 13 | import multiprocessing 14 | from PIL import Image 15 | from tqdm import tqdm 16 | from datetime import datetime 17 | 18 | import torch 19 | import torch.distributed as dist 20 | from pytorch_lightning import seed_everything 21 | from diffusers.utils import export_to_video 22 | from diffusers import AutoencoderKLWan, WanPipeline 23 | from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler 24 | 25 | from src.pipelines import DualParalWanPipeline, QueueLatents, Queue 26 | from src.distribution_utils import RuntimeConfig, DistController, export_to_images, memory_check 27 | 28 | def _parse_args(): 29 | # For basic args 30 | parser = argparse.ArgumentParser(description="DualParal with Wan") 31 | parser.add_argument("--dtype", type=str, default="bf16", 32 | help="Model dtype (float64, float32, float16, fp32, fp16, half, bf16)") 33 | parser.add_argument("--seed", type=int, default=12345, 34 | help="The seed to use for generating the image or video.") 35 | parser.add_argument("--save_file", type=str, default="../results/", 36 | help="The file to save the generated image or video to.") 37 | parser.add_argument("--verbose", action="store_true", default=False, 38 | help="Enable verbose mode") 39 | parser.add_argument("--export_image", action="store_true", default=False, 40 | help="Enable exporting video frames.") 41 | 42 | # For Wan-Video model 43 | parser.add_argument("--model_id", type=str, default="Wan-AI/Wan2.1-T2V-1.3B-Diffusers", 44 | help="Model Id for Wan-2.1 Video.") 45 | parser.add_argument("--height", type=int, default=480, 46 | help="Height of generating videos") 47 | parser.add_argument("--width", type=int, default=832, 48 | help="Width of generating videos") 49 | parser.add_argument("--sample_steps", type=int, default=50, 50 | help="The sampling steps.") 51 | parser.add_argument("--sample_shift", type=float, default=None, 52 | help="Sampling shift factor for flow matching schedulers.") 53 | parser.add_argument("--sample_guide_scale", type=float, default=5.0, 54 | help="Classifier free guidance scale.") 55 | parser.add_argument("--prompt", type=str, default="A cat and a dog baking a cake together in a kitchen. 
The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window.", 56 | help="The prompt to generate the image or video from.") 57 | 58 | # For DualParal 59 | parser.add_argument("--num_per_block", type=int, default=10, 60 | help="How many latents per block in DualParal.") 61 | parser.add_argument("--latents_num", type=int, default=30, 62 | help="How many latents to sample from a image or video. The total frames is equal to (latents_num-1)*4+1.") 63 | parser.add_argument("--num_cat", type=int, default=5, 64 | help="How many latents to concat in previous and backward blocks separately.") 65 | 66 | args = parser.parse_args() 67 | return args 68 | 69 | def prepare_model(args, parallel_config, runtime_config): 70 | model_id = args.model_id 71 | vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) 72 | flow_shift = 3.0 # 5.0 for 720P, 3.0 for 480P 73 | scheduler = UniPCMultistepScheduler(prediction_type='flow_prediction', use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=flow_shift) 74 | model = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=runtime_config.dtype) 75 | model.scheduler = scheduler 76 | 77 | Pipe = DualParalWanPipeline(model, parallel_config, runtime_config) 78 | Pipe.to_device(device=parallel_config.device, dtype=runtime_config.dtype, non_blocking=True) 79 | 80 | num_channels_latents = Pipe.model.transformer.config.in_channels 81 | args.height = int(args.height) // Pipe.model.vae_scale_factor_spatial 82 | args.width = int(args.width) // Pipe.model.vae_scale_factor_spatial 83 | return Pipe, num_channels_latents 84 | 85 | def update(ID, QueueWan, cnt_, Pipe, z): 86 | test = QueueWan.get(ID) 87 | scheduler = Pipe.get_scheduler_dict(QueueWan.begin + ID) 88 | z = scheduler.step(z, Pipe.timesteps[test.denoise_time-1], test.z, return_dict=False)[0] 89 | test.z.copy_(z, non_blocking=True) 90 | 91 | if ID + QueueWan.begin in cnt_: 92 | cnt_ [ID + QueueWan.begin] += 1 93 | else: 94 | cnt_ [ID + QueueWan.begin] = 1 95 | 96 | def main(args, rank, world_size): 97 | #---------------Model Preparation-------------------- 98 | torch.set_grad_enabled(False) 99 | torch.backends.cudnn.enabled = True 100 | torch.backends.cudnn.benchmark = True 101 | seed_everything(args.seed) 102 | parallel_config = DistController(rank, world_size) 103 | runtime_config = RuntimeConfig(args.seed, args.dtype) 104 | 105 | prompt = args.prompt 106 | negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" 107 | 108 | Pipe, out_channels = prepare_model(args, parallel_config, runtime_config) 109 | Pipe.get_model_args( 110 | prompt=prompt, 111 | num_inference_steps=args.sample_steps, 112 | negative_prompt=negative_prompt, 113 | guidance_scale=args.sample_guide_scale, 114 | ) 115 | timesteps = Pipe.timesteps 116 | latent_size = (args.num_per_block, args.height, args.width) 117 | #---------------Warm Up-------------------- 118 | print(f"[{Pipe.parallel_config.device}]--------------Warm Up-----------------") 119 | latent_size_warmup = (1, out_channels, args.num_per_block, args.height, args.width) 120 | Block_latents_size = 
None 121 | for iteration in range(1): 122 | warmup_block = QueueLatents(1, out_channels, latent_size).to(Pipe.parallel_config.device, Pipe.runtime_config.dtype, non_blocking=True) 123 | if Pipe.parallel_config.rank==0: 124 | output = Pipe.onestep(warmup_block, latent_size=warmup_block.z.size()) 125 | Block_latents_size = output.size() 126 | Pipe.parallel_config.pipeline_send(tensor=output, dtype=Pipe.runtime_config.dtype, verbose=False) 127 | output = Pipe.parallel_config.pipeline_recv(dtype=Pipe.runtime_config.dtype, dimension=warmup_block.z.dim(), verbose=False) 128 | else: 129 | output = Pipe.parallel_config.pipeline_recv(dtype=Pipe.runtime_config.dtype, dimension=3, verbose=False) 130 | warmup_block.z = output 131 | Block_latents_size = output.size() 132 | output = Pipe.onestep(warmup_block, latent_size=latent_size_warmup) 133 | Pipe.parallel_config.pipeline_send(tensor=output, dtype=Pipe.runtime_config.dtype, verbose=False) 134 | del output 135 | print(f"[{Pipe.parallel_config.device}]-------------Warm Up End----------------------") 136 | 137 | #---------------DualParal-------------------- 138 | cnt, video, cnt_ = 0, None, {} # cnt for counting total latents in queue, video for concating video latents 139 | Pipe.parallel_config.init_buffer(timesteps) 140 | 141 | QueueWan = Queue(num_per_block=args.num_per_block, num_cat=args.num_cat, lenth_of_Queue=len(timesteps)) 142 | Block_ref = copy.deepcopy(QueueWan) 143 | latent_tmp = (args.num_per_block + args.num_cat, args.height, args.width) 144 | test = QueueLatents(1, out_channels, latent_tmp) 145 | Block_ref.add_block(test) 146 | if Pipe.parallel_config.rank>0: 147 | num_frames = Block_latents_size[1]//args.num_per_block*Block_ref.get_size(0) 148 | tensor_shape = torch.tensor([Block_latents_size[0], num_frames, Block_latents_size[2]], dtype=torch.int64).contiguous() 149 | Pipe.parallel_config.modify_recv_queue(iteration=-1, idx=0, dtype=Pipe.runtime_config.dtype, tensor_shape=tensor_shape, verbose=False) 150 | 151 | torch.cuda.synchronize() 152 | torch.cuda.empty_cache() 153 | start_time = time.time() 154 | template_z = torch.randn(1, out_channels, args.num_per_block + args.num_cat, args.height, args.width, pin_memory=True) 155 | for iteration in tqdm(range(len(timesteps)+args.latents_num//args.num_per_block)): 156 | Pipe.cache = [] 157 | now_iteration = (iteration-1+2)%2 158 | prev_get = Pipe.parallel_config.init_get_start(now_iteration) 159 | if iteration == 1: prev_get = 0 # Corner Case 160 | 161 | # Change Time in Block Ref 162 | if Pipe.parallel_config.rank > 0: 163 | for idx in range(Block_ref.end-Block_ref.begin, -1, -1): 164 | test = Block_ref.get(idx) 165 | test.denoise_time += 1 166 | if Block_ref.check_first(): Block_ref.del_prev_first() 167 | 168 | # Add block 169 | add_block=False 170 | if cnt < args.latents_num: 171 | latent_tmp = latent_size 172 | if iteration==0: 173 | latent_tmp = (args.num_per_block + args.num_cat, args.height, args.width) 174 | select_indice = torch.randperm(template_z.size(2) - args.num_cat) 175 | template_z = torch.cat((template_z[:, :, -args.num_cat:], template_z[:, :, select_indice]), dim=2) 176 | test = QueueLatents(1, out_channels, latent_tmp, template_z).to(Pipe.parallel_config.device, Pipe.runtime_config.dtype, non_blocking=True) 177 | QueueWan.add_block(test) 178 | cnt += args.num_per_block 179 | add_block=True 180 | Pipe.add_scheduler(QueueWan.end) 181 | if cnt < args.latents_num: 182 | test = QueueLatents(1, out_channels, latent_size) 183 | Block_ref.add_block(test) 184 | if args.verbose: 
print(f"[{parallel_config.device}]--------DualParal in iteration-{iteration} is Begin with Queue from {QueueWan.begin} to {QueueWan.end}-----------") 185 | # Prepare for receving latents between two GPUs 186 | if Pipe.parallel_config.rank > 0: 187 | for idx in range(Block_ref.end-Block_ref.begin, -1, -1): 188 | num = Block_latents_size[1]//args.num_per_block * Block_ref.get_size(idx) 189 | tensor_shape = torch.tensor([Block_latents_size[0], num, Block_latents_size[2]], dtype=torch.int64).contiguous() 190 | Pipe.parallel_config.modify_recv_queue(iteration, idx, dtype=Pipe.runtime_config.dtype, tensor_shape=tensor_shape, verbose=False) 191 | 192 | # Del First Block in QueueWan 193 | del_begin = 0 194 | if QueueWan.check_first(): 195 | QueueWan.del_prev_first() 196 | ID = QueueWan.begin - 2 197 | if ID >= 0: Pipe.del_scheduler(ID) 198 | del_begin = 1 199 | 200 | for idx in range(QueueWan.end-QueueWan.begin, -1, -1): 201 | now_end = ( idx==0 and iteration==(len(timesteps)+args.latents_num//args.num_per_block-1) ) 202 | get_next = False 203 | # Receving 204 | if Pipe.parallel_config.rank != 0: 205 | x = Pipe.parallel_config.recv_next(iteration-1, idx, queue_lenth=Block_ref.end-Block_ref.begin+1, 206 | force=True, end=(Block_ref.end-Block_ref.begin+1==0), verbose=False) 207 | input_block_tmp = QueueWan.get(idx) 208 | input_block = QueueLatents(1, out_channels, None, None) 209 | input_block.denoise_time = input_block_tmp.denoise_time 210 | input_block.z = x 211 | else: 212 | latent_size_tmp = latent_size_warmup 213 | if QueueWan.begin+idx == 0: 214 | latent_size_tmp = (1, out_channels, args.num_per_block + args.num_cat, args.height, args.width) 215 | tensor_shape = torch.tensor(latent_size_tmp, dtype=torch.int64).contiguous() 216 | Pipe.parallel_config.modify_recv_queue(iteration, idx, dtype=Pipe.runtime_config.dtype, tensor_shape=tensor_shape, verbose=False) 217 | force = False 218 | if prev_get >= 0: 219 | # Make sure the gap between communication large than world_size to make sure output 220 | # is already come out from the last device 221 | if (abs(prev_get+1 + QueueWan.end-idx+int(add_block)) <= Pipe.parallel_config.world_size or iteration==1)\ 222 | and idx == QueueWan.end-QueueWan.begin and add_block==True: 223 | get_next = True 224 | else: 225 | force = (now_iteration==(iteration+1)%2 and prev_get==idx) 226 | z = Pipe.parallel_config.recv_next(now_iteration, prev_get, queue_lenth=QueueWan.end-QueueWan.begin+1, 227 | force=force, end=now_end, verbose=False) 228 | if z is not None: 229 | get_next = True 230 | ID = prev_get - del_begin 231 | if now_iteration%2 == iteration%2: 232 | ID = prev_get 233 | update(ID, QueueWan, cnt_, Pipe, z) 234 | prev_get -= 1 235 | else: get_next = True 236 | 237 | if prev_get < 0: 238 | now_iteration = now_iteration^1 239 | prev_get = Pipe.parallel_config.init_get_start(now_iteration) 240 | input_block = QueueLatents(1, out_channels, None, None) 241 | L, R, latents_z, denoise_time = QueueWan.prepare_for_forward(idx) 242 | input_block.z = latents_z 243 | input_block.denoise_time = denoise_time 244 | 245 | size_frames = QueueWan.get_size(idx) 246 | latent_size_tmp = (latent_size_warmup[0], latent_size_warmup[1], size_frames, latent_size_warmup[3], latent_size_warmup[4]) 247 | # Pipe.cache = [] 248 | x = Pipe.onestep(input_block, latent_size=latent_size_tmp, latent_num=Block_latents_size[1], select=args.num_cat, select_all=args.num_per_block, verbose=False) 249 | 250 | if Pipe.parallel_config.rank != Pipe.parallel_config.world_size-1: 251 | # check recieving next 
252 | if parallel_config.rank==0 and not get_next: 253 | z = Pipe.parallel_config.recv_next(now_iteration, prev_get, queue_lenth=QueueWan.end-QueueWan.begin+1, 254 | force=True, end=now_end, verbose=False) 255 | ID = prev_get - del_begin 256 | if now_iteration%2 == iteration%2: 257 | ID = prev_get 258 | update(ID, QueueWan, cnt_, Pipe, z) 259 | prev_get -= 1 260 | if prev_get < 0: 261 | now_iteration = now_iteration^1 262 | prev_get = Pipe.parallel_config.init_get_start(now_iteration) 263 | if args.verbose: print(f"[{Pipe.parallel_config.device}] ready to send X with size {x.size()}") 264 | Pipe.parallel_config.pipeline_isend(tensor=x, dtype=Pipe.runtime_config.dtype, verbose=False) 265 | else: 266 | size_latent = QueueWan.get_size(idx, itself=True) 267 | x = x[:, :, -size_latent:].clone().contiguous() 268 | if args.verbose: print(f"[{Pipe.parallel_config.device}] ready to send X with size {x.size()}, with sum {x.size()}") 269 | Pipe.parallel_config.pipeline_isend(tensor=x, dtype=Pipe.runtime_config.dtype, verbose=False) 270 | #Update Denoise_Time in Queue 271 | input_block = QueueWan.get(idx) 272 | input_block.denoise_time += 1 273 | 274 | # Extract First Block 275 | if del_begin==1 and Pipe.parallel_config.rank==0: 276 | first_block = QueueWan.get(-1) 277 | z = Pipe.parallel_config.recv_next(iteration-1, idx=0, verbose=False, queue_lenth=QueueWan.end-QueueWan.begin+1, 278 | force=True, end=(iteration==(len(timesteps)+args.latents_num//args.num_per_block-1))) 279 | if z is not None: 280 | prev_get = -1 281 | cnt_ [prev_get + QueueWan.begin] += 1 282 | scheduler = Pipe.get_scheduler_dict(QueueWan.begin - 1) 283 | z = scheduler.step(z, Pipe.timesteps[first_block.denoise_time-1], first_block.z, return_dict=False)[0] 284 | first_block.z = z 285 | video_ = first_block.z 286 | video = video_ if video is None else torch.cat((video, video_), dim=2) 287 | if args.verbose: print(f"[{Pipe.parallel_config.device}] Now {iteration} video size: ", video.size()) 288 | 289 | torch.cuda.synchronize() 290 | print(f"[{Pipe.parallel_config.device}] Whole inference time {time.time()-start_time:.6f}s") 291 | if Pipe.parallel_config.rank==0: 292 | print("Video latents size: ", video.size()) 293 | video = Pipe.get_video(video)[0] 294 | print(f"Final Video shape: {video.shape}") 295 | export_to_video(video, args.save_file+"output.mp4", fps=16) 296 | if args.export_image: 297 | export_to_images(video, args.save_file + "frames/") 298 | 299 | if __name__ == "__main__": 300 | args = _parse_args() 301 | multiprocessing.set_start_method('spawn') 302 | num_processes = torch.cuda.device_count() 303 | processes = [] 304 | 305 | for rank in range(num_processes): 306 | p = multiprocessing.Process(target=main, args=(args, rank, num_processes)) 307 | p.start() 308 | processes.append(p) 309 | 310 | for p in processes: 311 | p.join() -------------------------------------------------------------------------------- /examples/Wan-Video.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | os.environ["HF_ENDPOINT"] = "https://hf-mirror.com" 4 | os.environ["HF_DATASETS_CACHE"] = "../Checkpoint/" 5 | os.environ["HF_HOME"] = "../Checkpoint/" 6 | os.environ["HUGGINGFACE_HUB_CACHE"] = "../Checkpoint/" 7 | os.environ["TRANSFORMERS_CACHE"] = "../Checkpoint/" 8 | import time 9 | import torch 10 | from diffusers.utils import export_to_video 11 | from diffusers import AutoencoderKLWan, WanPipeline 12 | from diffusers.schedulers.scheduling_unipc_multistep import 
UniPCMultistepScheduler 13 | 14 | # Available models: Wan-AI/Wan2.1-T2V-14B-Diffusers, Wan-AI/Wan2.1-T2V-1.3B-Diffusers 15 | model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers" 16 | vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.bfloat16) 17 | flow_shift = 3.0 # 5.0 for 720P, 3.0 for 480P 18 | scheduler = UniPCMultistepScheduler(prediction_type='flow_prediction', use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=flow_shift) 19 | pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16) 20 | pipe.scheduler = scheduler 21 | pipe.to("cuda") 22 | 23 | prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." 24 | negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" 25 | 26 | start = time.time() 27 | output = pipe( 28 | prompt=prompt, 29 | negative_prompt=negative_prompt, 30 | height=480, 31 | width=832, 32 | num_frames=24, 33 | guidance_scale=5.0, 34 | ).frames[0] 35 | print(f"Using {time.time()-start:.6f}s") 36 | export_to_video(output, "output.mp4", fps=16) -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # Minute-Long Videos with Dual Parallelisms 4 | 5 | 6 |
[arXiv](https://arxiv.org/abs/2505.21070) 9 |
10 | 11 | ## 📚 TL;DR (Too Long; Didn't Read) 12 | **DualParal** is a distributed inference strategy for Diffusion Transformer (DiT)-based video diffusion models. It achieves high efficiency by parallelizing both temporal frames and model layers with the help of a *block-wise denoising scheme*. 13 | See our [paper](https://arxiv.org/abs/2505.21070) for more details. 14 | 15 | ## 🎥 Demo: more video samples on our [project page](https://dualparal-project.github.io/dualparal.github.io/)! 16 |
17 | <!-- demo video: assets/gif1.gif --> 18 |
19 | A white-suited astronaut with a gold visor spins in dark space, tethered by a drifting cable. Stars twinkle around him as Earth glows blue in the distance. His suit reflects faint starlight against the vastness of the cosmos. 20 |
21 | <!-- demo video: assets/gif2.gif --> 22 |
23 | A flock of birds glides through the warm sunset sky, wings outstretched. Their feathers catch golden light as they soar above silhouetted treetops, with the sky glowing in soft hues of amber and pink. 24 |
25 |
26 | 27 | ## 🛠️ Setup 28 | ``` 29 | conda create -n DualParal python=3.10 30 | conda activate DualParal 31 | # Ensure torch >= 2.4.0 matching your CUDA version; the following uses CUDA 12.1 as an example 32 | pip install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/cu121 33 | pip install -r requirements.txt 34 | ``` 35 | 36 | ## 🚀 Usage 37 | ### **Quick Start: DualParal on multiple GPUs with Wan2.1-1.3B (480p)** 38 | ```bash 39 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m examples.DualParal_Wan --sample_steps 50 --num_per_block 8 --latents_num 40 --num_cat 8 40 | ``` 41 | 42 | ### **Major parameters** 43 | - **Basic Args** 44 | 45 | | Parameter | Description | 46 | | ----------- | -------------------------------------- | 47 | | `dtype` | Model dtype (float64, float32, float16, fp32, fp16, half, bf16). | 48 | | `seed` | The seed to use for generating the video. | 49 | | `save_file` | The file to save the generated video to. | 50 | | `verbose` | Enable verbose mode for debugging. | 51 | | `export_image` | Enable exporting video frames. | 52 | 53 | - **Model Args** 54 | 55 | | Parameter | Description | 56 | | ----------- | -------------------------------------- | 57 | | `model_id` | Model ID for Wan2.1 Video (Wan-AI/Wan2.1-T2V-1.3B-Diffusers or Wan-AI/Wan2.1-T2V-14B-Diffusers). | 58 | | `height` | Height of the generated videos. | 59 | | `width` | Width of the generated videos. | 60 | | `sample_steps` | The number of sampling steps. | 61 | | `sample_shift` | Sampling shift factor for flow-matching schedulers. | 62 | | `sample_guide_scale` | Classifier-free guidance scale. | 63 | 64 | - **Major Args for DualParal** 65 | 66 | | Parameter | Description | 67 | | ----------- | -------------------------------------- | 68 | | `prompt` | The prompt to generate the video from. | 69 | | `num_per_block` | The number of latents per block in DualParal. | 70 | | `latents_num` | The total number of latents sampled from the video. `latents_num` **must** be divisible by `num_per_block`. The total number of video frames is calculated as (`latents_num` - 1) $\times$ 4 + 1. | 71 | | `num_cat` | The number of latents to concatenate from the previous and subsequent blocks separately. Increasing it (up to `num_per_block`) leads to better global consistency and temporal coherence. Note that $Num_C$ in the paper is equal to 2 $\times$ `num_cat`. | 72 | 73 | ### Further experiments 74 | - **Original Wan implementation with a single GPU** 75 | ```bash 76 | python examples/Wan-Video.py 77 | ``` 78 | 79 | - **DualParal on multiple GPUs with Wan2.1-14B (720p)** 80 | ```bash 81 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m examples.DualParal_Wan --model_id Wan-AI/Wan2.1-T2V-14B-Diffusers --height 720 --width 1280 --sample_steps 50 --num_per_block 8 --latents_num 40 --num_cat 8 82 | ``` 83 | 84 | ## ☀️ Acknowledgements 85 | Our project is based on the [Wan2.1](https://github.com/Wan-Video/Wan2.1) model. We would like to thank the authors for their excellent work! 
❤️ 86 | 87 | ## 🔗 Citation 88 | ``` 89 | @misc{wang2025minutelongvideosdualparallelisms, 90 | title={Minute-Long Videos with Dual Parallelisms}, 91 | author={Zeqing Wang and Bowen Zheng and Xingyi Yang and Zhenxiong Tan and Yuecong Xu and Xinchao Wang}, 92 | year={2025}, 93 | eprint={2505.21070}, 94 | archivePrefix={arXiv}, 95 | primaryClass={cs.CV}, 96 | url={https://arxiv.org/abs/2505.21070}, 97 | } 98 | ``` 99 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=2.4.0 2 | torchvision>=0.19.0 3 | opencv-python>=4.9.0.80 4 | diffusers>=0.31.0 5 | transformers>=4.49.0 6 | tokenizers>=0.20.3 7 | accelerate>=1.1.1 8 | tqdm 9 | imageio 10 | easydict 11 | ftfy 12 | dashscope 13 | imageio-ffmpeg 14 | packaging 15 | ninja 16 | gradio>=5.0.0 17 | numpy>=1.23.5,<2 18 | pytorch_lightning 19 | flash-attn -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DualParal-Project/DualParal/a6cd8a1fb8eaa3c703a13510c22245bb2deb662b/src/__init__.py -------------------------------------------------------------------------------- /src/distribution_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import numpy as np 4 | from PIL import Image 5 | import torch 6 | import torch.distributed as dist 7 | 8 | class RuntimeConfig(): 9 | def __init__(self, 10 | seed: int = 42, 11 | dtype: torch.dtype = torch.float16 12 | ): 13 | self.seed = seed 14 | self.dtype = self.to_torch_dtype(dtype) 15 | 16 | def to_torch_dtype(self, dtype): 17 | if isinstance(dtype, torch.dtype): 18 | return dtype 19 | elif isinstance(dtype, str): 20 | dtype_mapping = { 21 | "float64": torch.float64, 22 | "float32": torch.float32, 23 | "float16": torch.float16, 24 | "fp32": torch.float32, 25 | "fp16": torch.float16, 26 | "half": torch.float16, 27 | "bf16": torch.bfloat16, 28 | } 29 | if dtype not in dtype_mapping: 30 | raise ValueError 31 | dtype = dtype_mapping[dtype] 32 | return dtype 33 | else: 34 | raise ValueError 35 | 36 | class DistController(object): 37 | def __init__(self, rank: int, world_size: int) -> None: 38 | super().__init__() 39 | self.rank = rank 40 | self.world_size = world_size 41 | self.device = torch.device(f"cuda:{self.rank}") 42 | 43 | self.prev = self.rank-1 if self.rank-1>=0 else self.world_size-1 44 | self.next = self.rank+1 if self.rank+11: self.init_group() 52 | 53 | def init_get_start(self, iteration): 54 | for i in range(len(self.buffer_recv[iteration])-1, -1, -1): 55 | if self.recv_queue[iteration][i] is not None: 56 | return i 57 | return -1 58 | 59 | def init_buffer(self, timesteps): 60 | for i in range(len(timesteps)+1): 61 | for key in self.buffer_recv: 62 | self.buffer_recv[key].append(None) 63 | self.recv_queue[key].append(None) 64 | # print(f"[{self.device}] Init Buffer----") 65 | 66 | def modify_recv_queue(self, iteration, idx, tensor_shape=None, dtype=torch.float16, verbose=False): 67 | iteration = (iteration+2)%2 68 | 69 | assert self.buffer_recv[iteration][idx] is None, f"[{self.device}] Buffer_Queue at iteration {iteration} and index {idx} is not None!" 70 | assert self.recv_queue[iteration][idx] is None, f"[{self.device}] Recv_Queue at iteration {iteration} and index {idx} is not None!" 
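# Pre-allocate a pinned CPU buffer with the expected tensor shape; recv_next() later moves it to the GPU and posts the asynchronous irecv into it.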
71 | 72 | tensor = torch.empty(tensor_shape.tolist(), device="cpu", dtype=dtype, pin_memory=True) 73 | self.buffer_recv[iteration][idx] = tensor 74 | if verbose: print(f"[{self.device}] idx {idx} on iteration {iteration} modify with size {tensor_shape}") 75 | 76 | def recv_next(self, iteration, idx, queue_lenth=0, force=False, end=False, TIMEOUT=50, verbose=False): 77 | ''' 78 | queue_lenth = Now queue lenth 79 | ''' 80 | iteration = (iteration+2)%2 81 | req = self.recv_queue[iteration][idx] 82 | if req is None: 83 | if self.buffer_recv[iteration][idx] is None: 84 | # no tensor in buffer_recv means that 'idx' is already recieved 85 | return None 86 | else: 87 | self.buffer_recv[iteration][idx] = self.buffer_recv[iteration][idx].to(self.device, non_blocking=True).contiguous() 88 | req = self._pipeline_irecv(self.buffer_recv[iteration][idx]) 89 | 90 | if verbose: start_time = time.time() 91 | if not req.is_completed() and not force: 92 | return None 93 | req.wait() 94 | ans = self.buffer_recv[iteration][idx] 95 | 96 | self.recv_queue[iteration][idx] = None 97 | self.buffer_recv[iteration][idx] = None 98 | if verbose: print(f"[{self.device}] Request status of idx{idx} after wait:", req.is_completed(), " with size: ", ans.size(), f" with time: {time.time()-start_time:.6f}s") 99 | if end: return ans 100 | 101 | if verbose: start = time.time() 102 | next_id = idx-1 #previous one in backward order 103 | if next_id < 0: 104 | next_id = 0 if queue_lenth==0 else (next_id + queue_lenth)%queue_lenth 105 | self.buffer_recv[iteration^1][next_id] = self.buffer_recv[iteration^1][next_id].to(self.device, non_blocking=True) 106 | self.recv_queue[iteration^1][next_id] = self._pipeline_irecv(self.buffer_recv[iteration^1][next_id]) 107 | tmp = self.recv_queue[iteration^1][next_id] 108 | else: 109 | self.buffer_recv[iteration][next_id] = self.buffer_recv[iteration][next_id].to(self.device, non_blocking=True) 110 | self.recv_queue[iteration][next_id] = self._pipeline_irecv(self.buffer_recv[iteration][next_id],) 111 | tmp = self.recv_queue[iteration][next_id] 112 | if verbose: print(f"[{self.device}] Stream status {tmp.is_completed()} recieving {time.time()-start:.6f}s") 113 | return ans 114 | 115 | def pipeline_isend(self, tensor, dtype, verbose=False) -> None: 116 | tensor_shape = tensor.size() 117 | if not tensor.is_contiguous(): 118 | tensor = tensor.contiguous() 119 | if tensor.dtype != dtype: 120 | tensor = tensor.to(dtype) 121 | req = self._pipeline_isend(tensor) 122 | 123 | 124 | def _pipeline_irecv(self, tensor: torch.tensor): 125 | return torch.distributed.irecv( 126 | tensor, 127 | src=self.prev, 128 | group=self.gpu_group_receive, 129 | ) 130 | 131 | def _pipeline_isend(self, tensor: torch.tensor): 132 | return torch.distributed.isend( 133 | tensor, 134 | dst=self.next, 135 | group=self.gpu_group_send 136 | ) 137 | 138 | def init_dist(self): 139 | torch.cuda.set_device(self.device) 140 | print(f"Rank {self.rank} (world size {self.world_size}, with prev {self.prev} and next {self.next}) is running.") 141 | os.environ['MASTER_ADDR'] = '127.0.0.1' 142 | os.environ['MASTER_PORT'] = os.getenv('MASTER_PORT', '29500') 143 | dist.init_process_group("nccl", rank=self.rank, world_size=self.world_size) 144 | 145 | def init_group(self): 146 | group = list(range(self.world_size)) 147 | if len(group) > 2 or len(group) == 1: 148 | device_group = torch.distributed.new_group(group, backend="nccl") 149 | self.gpu_group_receive = device_group 150 | self.gpu_group_send = device_group 151 | elif len(group) == 2: 152 | # 
when pipeline parallelism is 2, we need to create two groups to avoid 153 | # communication stall. 154 | # *_group_0_1 represents the group for communication from device 0 to 155 | # device 1. 156 | # *_group_1_0 represents the group for communication from device 1 to 157 | # device 0. 158 | device_group_0_1 = torch.distributed.new_group(group, backend="nccl") 159 | device_group_1_0 = torch.distributed.new_group(group, backend="nccl") 160 | groups = [device_group_0_1, device_group_1_0] 161 | self.gpu_group_send = groups[self.rank] 162 | self.gpu_group_receive = groups[(self.rank+1)%2] 163 | 164 | def pipeline_send(self, tensor: torch.Tensor, dtype, verbose=False) -> None: 165 | if verbose: start_time = time.time() 166 | 167 | tensor_shape = torch.tensor(tensor.size(), device=tensor.device, dtype=torch.int64).contiguous() 168 | if verbose: print(f"[{self.device}] Send size {tensor_shape} with dimension {tensor.dim()}") 169 | self._pipeline_isend(tensor_shape).wait() 170 | if verbose: print(f"[{self.device}] Success Send size {tensor_shape}") 171 | tensor = tensor.contiguous().to(dtype) 172 | self._pipeline_isend(tensor).wait() 173 | 174 | if verbose: print(f"[{self.device}] Sending Tensor with shape {tensor.size()} ({tensor.device}, {tensor.dtype}, {tensor.sum()}) in {time.time()-start_time:.6f}s") 175 | del tensor_shape 176 | del tensor 177 | 178 | def pipeline_recv(self, dtype, dimension=3, verbose=False) -> torch.Tensor: 179 | if verbose: start_time = time.time() 180 | 181 | tensor_shape = torch.empty(dimension, device=self.device, dtype=torch.int64) 182 | if verbose: print(f"[{self.device}] Ready size {tensor_shape} with dimension {dimension}") 183 | self._pipeline_irecv(tensor_shape).wait() 184 | if verbose: print(f"[{self.device}] Got size {tensor_shape}") 185 | tensor = torch.empty(tensor_shape.tolist(), device=self.device, dtype=dtype) # assume the data type is float32; adjust to a suitable type 186 | self._pipeline_irecv(tensor).wait() 187 | 188 | del tensor_shape 189 | if verbose: print(f"[{self.device}] Receiving Tensor with shape {tensor.shape}(dtype {dtype}, sum {tensor.sum()}) in {time.time()-start_time:.6f}s") 190 | return tensor 191 | 192 | def export_to_images(video_frames, output_dir: str = None): 193 | os.makedirs(output_dir, exist_ok=True) 194 | 195 | output_paths = [] 196 | 197 | if isinstance(video_frames[0], np.ndarray): 198 | video_frames = [(frame * 255).astype(np.uint8) for frame in video_frames] 199 | 200 | elif isinstance(video_frames[0], Image.Image): 201 | video_frames = [np.array(frame) for frame in video_frames] 202 | 203 | for i, frame in enumerate(video_frames): 204 | image_path = os.path.join(output_dir, f"frame_{i:04d}.png") 205 | Image.fromarray(frame).save(image_path) 206 | output_paths.append(image_path) 207 | 208 | return output_paths 209 | 210 | def memory_check(device, info = ""): 211 | allocated_memory = torch.cuda.memory_allocated() 212 | cached_memory = torch.cuda.memory_reserved() 213 | max_allocated_memory = torch.cuda.max_memory_allocated() 214 | max_cached_memory = torch.cuda.max_memory_reserved() 215 | 216 | print(f"[{device}] {info} Allocated Memory: {allocated_memory / (1024 ** 2):.2f} MB") 217 | print(f"[{device}] {info} Cached Memory: {cached_memory / (1024 ** 2):.2f} MB") 218 | print(f"[{device}] {info} Max Allocated Memory: {max_allocated_memory / (1024 ** 2):.2f} MB") 219 | print(f"[{device}] {info} Max Cached Memory: {max_cached_memory / (1024 ** 2):.2f} MB") 220 | 221 | --------------------------------------------------------------------------------
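A minimal sketch (not part of the repository) of how the point-to-point helpers of `DistController` from `src/distribution_utils.py` can be exercised on two GPUs, mirroring the warm-up phase of `examples/DualParal_Wan.py`. It assumes two CUDA devices and that `DistController.__init__` initializes the NCCL process group as in the file above; shapes and dtypes are illustrative only.
```python
import torch
import torch.multiprocessing as mp

from src.distribution_utils import DistController

def worker(rank, world_size):
    ctrl = DistController(rank, world_size)       # binds cuda:<rank>, resolves prev/next ranks, inits NCCL
    dtype = torch.bfloat16
    if rank == 0:
        # Rank 0 sends a dummy activation to its "next" rank; the shape is sent first, then the tensor.
        x = torch.randn(1, 16, 128, device=ctrl.device, dtype=dtype)
        ctrl.pipeline_send(tensor=x, dtype=dtype)
    else:
        # Rank 1 receives from its "prev" rank; `dimension` must match the sender's tensor.dim().
        y = ctrl.pipeline_recv(dtype=dtype, dimension=3)
        print(f"[rank {rank}] received tensor of shape {tuple(y.shape)}")

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)
```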
/src/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | from src.pipelines.pipeline_Wan import DualParalWanPipeline 2 | from src.pipelines.base_pipeline import Queue, QueueLatents, DualParalPipelineBaseWrapper 3 | __all__ = [ 4 | "Queue", 5 | "QueueLatents", 6 | "DualParalPipelineBaseWrapper", 7 | "DualParalWanPipeline", 8 | ] -------------------------------------------------------------------------------- /src/pipelines/base_pipeline.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from src.distribution_utils import DistController, RuntimeConfig 3 | 4 | class QueueLatents(object): 5 | def __init__(self, 6 | batch_size, 7 | out_channels, 8 | latent_size, 9 | z = None 10 | ): 11 | self.z = torch.randn(batch_size, out_channels, *latent_size, pin_memory=True) if latent_size is not None else None 12 | self.z = z[:, :, -latent_size[0]:].clone() if z is not None else self.z 13 | self.denoise_time = 0 14 | 15 | def to(self, *args, **kwargs): 16 | self.z = self.z.to(*args, **kwargs).contiguous() if self.z is not None else None 17 | return self 18 | 19 | class Queue(object): 20 | def __init__(self, num_per_block=10, num_cat=5, lenth_of_Queue=50): 21 | self.num_per_block = num_per_block 22 | self.num_cat = num_cat 23 | self.lenth_of_Queue = lenth_of_Queue 24 | self.begin, self.end = 0, -1 25 | self.queue = [] 26 | 27 | def get(self, idx): 28 | return self.queue[idx+self.begin] 29 | 30 | def get_size(self, idx, itself=False): 31 | idx = self.begin+idx 32 | if itself: 33 | return self.queue[idx].z.size(2) 34 | 35 | prev = self.queue[idx-1] if idx-1>=0 else None 36 | L, R = 0, self.queue[idx].z.size(2) 37 | if prev is not None: 38 | R += self.num_cat 39 | return R 40 | 41 | def prepare_for_forward(self, idx): 42 | idx = self.begin + idx 43 | prev = self.queue[idx - 1] if idx-1 >=0 else None 44 | now = self.queue[idx].z.clone() 45 | L, R = 0, now.size(2) 46 | 47 | 48 | if prev is not None and self.num_cat > 0: 49 | now = torch.cat((prev.z[:,:,-self.num_cat:], now), dim=2) 50 | now = now[:, :, -self.num_cat - now.size(2):] 51 | L, R = L + self.num_cat, R + self.num_cat 52 | 53 | return L, R, now, self.queue[idx].denoise_time 54 | 55 | def update(self, idx, block): 56 | self.queue[idx + self.begin].z.copy_(block.z, non_blocking=True) 57 | 58 | def add_block(self, block): 59 | self.queue.append(block) 60 | self.end += 1 61 | 62 | def check_first(self, is_first=False): 63 | if self.begin>self.end: 64 | return False 65 | first_block = self.queue[self.begin] 66 | return first_block.denoise_time==(self.lenth_of_Queue - int(is_first==True)) 67 | 68 | def del_prev_first(self): 69 | ''' 70 | Followed by check_first() 71 | Self.begin add 1 72 | ''' 73 | if self.begin-1>=0 and self.queue[self.begin-1] is not None: 74 | self.queue[self.begin-1] = None 75 | self.begin += 1 76 | 77 | def print_queue(self, device): 78 | print(f"[{device}] LOOK Queue: ", end=" ") 79 | for i in range(self.begin, self.end+1): 80 | print(self.queue[i].denoise_time, end=", ") 81 | print() 82 | 83 | class DualParalPipelineBaseWrapper(object): 84 | def __init__( 85 | self, 86 | parallel_config: DistController, 87 | runtime_config: RuntimeConfig, 88 | ): 89 | self.runtime_config = runtime_config 90 | self.parallel_config = parallel_config 91 | 92 | def _split_transformer_backbone(self): 93 | world_size = self.parallel_config.world_size 94 | local_rank = self.parallel_config.rank 95 | 96 | lenth_of_blocks = len(self.transformer) 97 
| lenth_of_blocks_per_device = lenth_of_blocks//world_size 98 | mod_of_blocks_per_device = lenth_of_blocks%world_size 99 | range_of_block = {} 100 | start_q, end_q = 0, lenth_of_blocks_per_device 101 | for i in range(world_size): 102 | end_q = end_q + (mod_of_blocks_per_device>0) 103 | range_of_block[i] = (start_q, end_q) 104 | start_q = end_q 105 | end_q = end_q +lenth_of_blocks_per_device 106 | mod_of_blocks_per_device -= 1 107 | 108 | start_range, final_range = range_of_block[local_rank][0], range_of_block[local_rank][1] 109 | if self.parallel_config.rank==self.parallel_config.world_size-1: 110 | final_range = lenth_of_blocks 111 | range_of_blocks = range( 112 | start_range, final_range 113 | ) 114 | self.transformer_ = self.transformer[range_of_blocks.start:range_of_blocks.stop] 115 | del self.transformer 116 | self.transformer = self.transformer_ 117 | del self.transformer_ 118 | 119 | def forward(self): 120 | pass 121 | 122 | def to(self, *args, **kwargs): 123 | pass -------------------------------------------------------------------------------- /src/pipelines/pipeline_Wan.py: -------------------------------------------------------------------------------- 1 | import time 2 | import copy 3 | from typing import Any, Callable, Dict, List, Optional, Union, Tuple 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from diffusers import WanPipeline 9 | from diffusers.models.attention_processor import Attention 10 | from diffusers.models.transformers.transformer_wan import WanTransformerBlock 11 | 12 | from src.pipelines.base_pipeline import DualParalPipelineBaseWrapper 13 | from src.distribution_utils import DistController, RuntimeConfig 14 | 15 | class DualParal_WanAttnProcessor: 16 | def __init__(self): 17 | if not hasattr(F, "scaled_dot_product_attention"): 18 | raise ImportError("WanAttnProcessor2_0 requires PyTorch 2.0. 
To use it, please upgrade PyTorch to 2.0.") 19 | 20 | def __call__( 21 | self, 22 | attn: Attention, 23 | hidden_states: torch.Tensor, 24 | encoder_hidden_states: Optional[torch.Tensor] = None, 25 | attention_mask: Optional[torch.Tensor] = None, 26 | rotary_emb: Optional[torch.Tensor] = None, 27 | cache = None, 28 | latent_num = None, 29 | select = None, 30 | select_all = None, 31 | out_dim = None 32 | ) -> torch.Tensor: 33 | encoder_hidden_states_img = None 34 | if attn.add_k_proj is not None: 35 | encoder_hidden_states_img = encoder_hidden_states[:, :257] 36 | encoder_hidden_states = encoder_hidden_states[:, 257:] 37 | if encoder_hidden_states is None: 38 | encoder_hidden_states = hidden_states 39 | 40 | use_cache = False 41 | cache2, cache3 = None, None 42 | query = attn.to_q(hidden_states) 43 | key = attn.to_k(encoder_hidden_states) 44 | value = attn.to_v(encoder_hidden_states) 45 | if latent_num is not None: 46 | cache2, cache3 = key[:, -latent_num:], value[:, -latent_num:] 47 | cutoff = latent_num//select_all*select 48 | cache2, cache3 = cache2[:, :cutoff].clone(), cache3[:, :cutoff].clone() 49 | 50 | if cache is not None and select>0: 51 | use_cache = True 52 | k_, v_ = cache[0], cache[1] 53 | key = torch.cat((key, k_), dim=1) 54 | value = torch.cat((value, v_), dim=1) 55 | cache = None 56 | if cache2 is not None and select>0: 57 | cache = (cache2, cache3) 58 | else: 59 | cache = None 60 | if attn.norm_q is not None: 61 | query = attn.norm_q(query) 62 | if attn.norm_k is not None: 63 | key = attn.norm_k(key) 64 | 65 | query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2) 66 | key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2) 67 | value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2) 68 | 69 | if rotary_emb is not None: 70 | def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor): 71 | x_rotated = torch.view_as_complex(hidden_states.to(torch.float64).unflatten(3, (-1, 2))) 72 | x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4) 73 | return x_out.type_as(hidden_states) 74 | query = apply_rotary_emb(query, rotary_emb[0]) 75 | key = apply_rotary_emb(key, rotary_emb[1]) 76 | 77 | L, S = query.size(-2), key.size(-2) 78 | 79 | attn_mask = None 80 | hidden_states = F.scaled_dot_product_attention( 81 | query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False, 82 | ) 83 | 84 | hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) 85 | if latent_num is not None: 86 | hidden_states = hidden_states[:, :out_dim] #delete cache 87 | 88 | hidden_states = hidden_states.type_as(query) 89 | 90 | hidden_states = attn.to_out[0](hidden_states) 91 | hidden_states = attn.to_out[1](hidden_states) 92 | return hidden_states, cache 93 | 94 | class DualParal_WanTransformerBlock(nn.Module): 95 | def __init__( 96 | self, 97 | block: WanTransformerBlock, 98 | ): 99 | super().__init__() 100 | self.block = block 101 | 102 | @torch.no_grad() 103 | def forward( 104 | self, 105 | hidden_states: torch.Tensor, 106 | encoder_hidden_states: torch.Tensor, 107 | temb: torch.Tensor, 108 | rotary_emb: torch.Tensor, 109 | cache: Optional[Tuple] = None, 110 | latent_num = None, 111 | select = None, 112 | select_all = None, 113 | attention_mask = None, 114 | ) -> torch.Tensor: 115 | shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( 116 | self.block.scale_shift_table + temb.float() 117 | ).chunk(6, dim=1) 118 | # 1. 
Self-attention 119 | norm_hidden_states = (self.block.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states) 120 | attn_output, cache = self.block.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb, cache=cache, \ 121 | latent_num=latent_num, select=select, select_all=select_all, attention_mask=attention_mask, out_dim=hidden_states.size(1)) 122 | hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states) 123 | 124 | # 2. Cross-attention 125 | norm_hidden_states = self.block.norm2(hidden_states.float()).type_as(hidden_states) 126 | attn_output, _ = self.block.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states, \ 127 | latent_num=None, select=select, select_all=select_all, attention_mask=attention_mask) 128 | hidden_states = hidden_states + attn_output 129 | 130 | # 3. Feed-forward 131 | norm_hidden_states = (self.block.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as( 132 | hidden_states 133 | ) 134 | ff_output = self.block.ffn(norm_hidden_states) 135 | hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states) 136 | return hidden_states, cache 137 | 138 | class DualParalWanPipeline(DualParalPipelineBaseWrapper): 139 | def __init__( 140 | self, 141 | model: WanPipeline, 142 | parallel_config: DistController, 143 | runtimeconfig: RuntimeConfig, 144 | ): 145 | DualParalPipelineBaseWrapper.__init__(self, parallel_config, runtimeconfig) 146 | device, dtype = self.parallel_config.device, self.runtime_config.dtype 147 | 148 | self.tokenizer = model.tokenizer 149 | self.text_encoder = model.text_encoder 150 | self.transformer = model.transformer.blocks # Only Dit Blocks 151 | self.vae = model.vae 152 | self.scheduler = model.scheduler 153 | self.scheduler_dict = {} 154 | self.model = model # Other function or property all in self.model 155 | self.cache = [] 156 | self.attention_mask = [] 157 | self.tmp = None 158 | 159 | pretransformer, finaltransformer = False, False 160 | if self.parallel_config.world_size==1: 161 | pretransformer, finaltransformer = True, True 162 | elif self.parallel_config.rank==0: 163 | pretransformer, finaltransformer = True, False 164 | elif self.parallel_config.rank>0 and self.parallel_config.rank0: 207 | # for position 208 | num_frames += select 209 | latent_size = torch.Size([batch_size, num_channels, num_frames, height, width]) 210 | 211 | self.tmp = torch.empty(*latent_size).to(self.parallel_config.device, self.runtime_config.dtype) 212 | rotary_emb_2 = self.model.transformer.rope(self.tmp) 213 | rotary_emb = (rotary_emb_1, rotary_emb_2) 214 | 215 | temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.model.transformer.condition_embedder( 216 | t, self.encoder_hidden_states, self.encoder_hidden_states_image 217 | ) 218 | timestep_proj = timestep_proj.unflatten(1, (6, -1)) 219 | 220 | if encoder_hidden_states_image is not None: 221 | encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1) 222 | 223 | 224 | if self.pretransformer: 225 | z = torch.cat([z, z], 0) 226 | z = self.model.transformer.patch_embedding(z) 227 | z = z.flatten(2).transpose(1, 2) 228 | 229 | hidden_states = z.contiguous() 230 | use_cache_ = (len(self.cache) > 0) 231 | for now_idx, block in enumerate(self.transformer): 232 | cache = None if not use_cache_ else self.cache[now_idx] 233 | size_tmp = hidden_states.size() 234 | hidden_states, cache = block(hidden_states, 
encoder_hidden_states, timestep_proj, rotary_emb, cache=cache,\ 235 | latent_num=latent_num, select=select, select_all=select_all, attention_mask=self.attention_mask) 236 | if not use_cache_: 237 | self.cache.append(cache) 238 | else: 239 | self.cache[now_idx] = cache 240 | 241 | if self.finaltransformer: 242 | shift, scale = (self.model.transformer.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1) 243 | shift = shift.to(hidden_states.device) 244 | scale = scale.to(hidden_states.device) 245 | hidden_states = hidden_states.contiguous() 246 | hidden_states = (self.model.transformer.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states) 247 | hidden_states = self.model.transformer.proj_out(hidden_states) 248 | hidden_states = hidden_states.reshape( 249 | batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1 250 | ) 251 | hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) 252 | noise_pred = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) 253 | noise_pred = noise_pred 254 | noise_pred, noise_uncond = noise_pred[0], noise_pred[1] 255 | noise_pred = noise_uncond + self.guidance_scale * (noise_pred - noise_uncond) 256 | noise_pred = noise_pred.unsqueeze(0) 257 | if verbose: print(f"[{self.parallel_config.device}] Final (use cache {use_cache_}) tiem {time.time()-start_time:.6f}s") 258 | return noise_pred 259 | else: 260 | if verbose: print(f"[{self.parallel_config.device}] Final (use cache {use_cache_}) tiem {time.time()-start_time:.6f}s") 261 | return hidden_states 262 | 263 | def get_model_args( 264 | self, 265 | prompt: Union[str, List[str]] = None, 266 | negative_prompt: Union[str, List[str]] = None, 267 | num_inference_steps: int = 50, 268 | guidance_scale: float = 5.0, 269 | num_videos_per_prompt: Optional[int] = 1, 270 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 271 | latents: Optional[torch.Tensor] = None, 272 | prompt_embeds: Optional[torch.Tensor] = None, 273 | negative_prompt_embeds: Optional[torch.Tensor] = None, 274 | do_classifier_free_guidance: bool = True, 275 | attention_kwargs: Optional[Dict[str, Any]] = None, 276 | max_sequence_length: int = 512, 277 | ): 278 | device, dtype = self.parallel_config.device, self.runtime_config.dtype 279 | self.guidance_scale = guidance_scale 280 | 281 | # 1. 
Get Prompt Embedding 282 | if prompt is not None and isinstance(prompt, str): 283 | batch_size = 1 284 | elif prompt is not None and isinstance(prompt, list): 285 | batch_size = len(prompt) 286 | else: 287 | batch_size = prompt_embeds.shape[0] 288 | prompt_embeds, negative_prompt_embeds = self.model.encode_prompt( 289 | prompt=prompt, 290 | negative_prompt=negative_prompt, 291 | do_classifier_free_guidance=do_classifier_free_guidance, 292 | num_videos_per_prompt=num_videos_per_prompt, 293 | prompt_embeds=prompt_embeds, 294 | negative_prompt_embeds=negative_prompt_embeds, 295 | max_sequence_length=max_sequence_length, 296 | device=device, 297 | dtype=dtype, 298 | ) 299 | prompt_embeds = prompt_embeds.to(device, dtype, non_blocking=True) 300 | negative_prompt_embeds = negative_prompt_embeds.to(device, dtype, non_blocking=True) if negative_prompt_embeds is not None else None 301 | if do_classifier_free_guidance: 302 | self.encoder_hidden_states = torch.cat([prompt_embeds, negative_prompt_embeds], dim=0) 303 | else: 304 | self.encoder_hidden_states = prompt_embeds 305 | self.encoder_hidden_states_image = None 306 | 307 | del self.text_encoder 308 | del self.tokenizer 309 | del self.model.text_encoder 310 | del self.model.tokenizer 311 | torch.cuda.empty_cache() 312 | 313 | # 2. Prepare Timesteps 314 | self.scheduler.set_timesteps(num_inference_steps, device=device) 315 | self.timesteps = self.scheduler.timesteps 316 | 317 | def to_device(self, device, dtype, **kwargs): 318 | self.model = self.model.to(device, dtype, **kwargs) if self.model is not None else None 319 | self.vae = self.vae.to(dtype) 320 | return self 321 | 322 | def add_scheduler(self, idx): 323 | self.scheduler_dict[idx] = copy.deepcopy(self.scheduler) 324 | 325 | def del_scheduler(self, idx): 326 | del self.scheduler_dict[idx] 327 | 328 | def get_scheduler_dict(self, idx): 329 | return self.scheduler_dict[idx] 330 | 331 | @torch.no_grad() 332 | def get_video( 333 | self, 334 | latents, 335 | output_type: Optional[str] = "np", 336 | verbose=False): 337 | latents = latents.to(self.vae.dtype) 338 | latents_mean = ( 339 | torch.tensor(self.vae.config.latents_mean) 340 | .view(1, self.vae.config.z_dim, 1, 1, 1) 341 | .to(latents.device, latents.dtype, non_blocking=True) 342 | ) 343 | latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( 344 | latents.device, latents.dtype 345 | ) 346 | latents = latents / latents_std + latents_mean 347 | video = self.vae.decode(latents, return_dict=False)[0] 348 | video = self.model.video_processor.postprocess_video(video, output_type=output_type) 349 | return video --------------------------------------------------------------------------------
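To make the block-wise denoising queue concrete, here is a minimal sketch (not part of the repository) that exercises `Queue` and `QueueLatents` from `src/pipelines/base_pipeline.py` and shows how `prepare_for_forward` prepends the last `num_cat` latent frames of the previous block to the current block for temporal context; the sizes are illustrative only.
```python
import torch

from src.pipelines import Queue, QueueLatents

num_per_block, num_cat = 4, 2
queue = Queue(num_per_block=num_per_block, num_cat=num_cat, lenth_of_Queue=50)

# Two blocks of noisy latents; QueueLatents draws (batch, channels, frames, height, width) noise.
latent_size = (num_per_block, 8, 8)              # (frames, height, width)
queue.add_block(QueueLatents(1, 16, latent_size))
queue.add_block(QueueLatents(1, 16, latent_size))

# Block 0 has no predecessor, so it is forwarded as-is.
L, R, z, t = queue.prepare_for_forward(0)
print(z.shape)   # torch.Size([1, 16, 4, 8, 8])

# Block 1 gets the last `num_cat` latent frames of block 0 prepended in front of it.
L, R, z, t = queue.prepare_for_forward(1)
print(z.shape)   # torch.Size([1, 16, 6, 8, 8]), i.e. num_cat + num_per_block frames
```
This is the mechanism the readme's `num_cat` argument controls: each block is denoised together with latents borrowed from its neighbors, which is what ties consecutive blocks into a temporally coherent video.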