├── wan
    ├── distributed
    │   ├── __init__.py
    │   └── fsdp.py
    ├── __init__.py
    ├── utils
    │   ├── __init__.py
    │   ├── segvideo.py
    │   ├── utils.py
    │   ├── vace_processor.py
    │   ├── qwen_vl_utils.py
    │   └── multitalk_utils.py
    ├── modules
    │   ├── __init__.py
    │   ├── tokenizers.py
    │   ├── xlm_roberta.py
    │   ├── vace_model.py
    │   └── attention.py
    ├── configs
    │   ├── shared_config.py
    │   ├── wan_t2v_14B.py
    │   ├── wan_i2v_14B.py
    │   ├── wan_t2v_1_3B.py
    │   ├── wan_multitalk_14B.py
    │   └── __init__.py
    ├── wan_lora.py
    ├── text2video.py
    ├── image2video.py
    └── first_last_frame2video.py
├── src
    ├── vram_management
    │   ├── __init__.py
    │   └── layers.py
    ├── audio_analysis
    │   ├── torch_utils.py
    │   └── wav2vec2.py
    └── utils.py
├── assets
    ├── logo.jpg
    ├── logo2.jpg
    └── pipeline.png
├── examples
    ├── single
    │   ├── 1.wav
    │   ├── ref_image.png
    │   └── ref_video.mp4
    ├── single_example_video.json
    └── single_example_image.json
├── requirements.txt
├── kokoro
    ├── __init__.py
    ├── __main__.py
    ├── model.py
    ├── modules.py
    └── custom_stft.py
├── README.md
└── LICENSE.txt
/wan/distributed/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/src/vram_management/__init__.py:
--------------------------------------------------------------------------------
1 | from .layers import *
2 | 
--------------------------------------------------------------------------------
/assets/logo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bmwas/InfiniteTalk/HEAD/assets/logo.jpg
--------------------------------------------------------------------------------
/assets/logo2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bmwas/InfiniteTalk/HEAD/assets/logo2.jpg
--------------------------------------------------------------------------------
/assets/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bmwas/InfiniteTalk/HEAD/assets/pipeline.png
--------------------------------------------------------------------------------
/examples/single/1.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bmwas/InfiniteTalk/HEAD/examples/single/1.wav
--------------------------------------------------------------------------------
/examples/single/ref_image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bmwas/InfiniteTalk/HEAD/examples/single/ref_image.png
--------------------------------------------------------------------------------
/examples/single/ref_video.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bmwas/InfiniteTalk/HEAD/examples/single/ref_video.mp4
--------------------------------------------------------------------------------
/examples/single_example_video.json:
--------------------------------------------------------------------------------
1 | {
2 |     "prompt": "A man is talking",
3 |     "cond_video": "examples/single/ref_video.mp4",
4 |     "cond_audio": {
5 |         "person1": "examples/single/1.wav"
6 |     }
7 | }
8 | 
--------------------------------------------------------------------------------
/wan/__init__.py:
--------------------------------------------------------------------------------
1 | from . 
import configs, distributed, modules 2 | from .first_last_frame2video import WanFLF2V 3 | from .image2video import WanI2V 4 | from .text2video import WanT2V 5 | from .vace import WanVace, WanVaceMP 6 | from .multitalk import InfiniteTalkPipeline 7 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | opencv-python>=4.9.0.80 2 | diffusers>=0.31.0 3 | transformers>=4.49.0 4 | tokenizers>=0.20.3 5 | accelerate>=1.1.1 6 | tqdm 7 | imageio 8 | easydict 9 | ftfy 10 | dashscope 11 | imageio-ffmpeg 12 | scikit-image 13 | loguru 14 | gradio>=5.0.0 15 | numpy>=1.23.5,<2 16 | xfuser>=0.4.1 17 | pyloudnorm 18 | optimum-quanto==0.2.6 19 | scenedetect 20 | moviepy==1.0.3 -------------------------------------------------------------------------------- /wan/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .fm_solvers import ( 2 | FlowDPMSolverMultistepScheduler, 3 | get_sampling_sigmas, 4 | retrieve_timesteps, 5 | ) 6 | from .fm_solvers_unipc import FlowUniPCMultistepScheduler 7 | from .vace_processor import VaceVideoProcessor 8 | 9 | __all__ = [ 10 | 'HuggingfaceTokenizer', 'get_sampling_sigmas', 'retrieve_timesteps', 11 | 'FlowDPMSolverMultistepScheduler', 'FlowUniPCMultistepScheduler', 12 | 'VaceVideoProcessor' 13 | ] 14 | -------------------------------------------------------------------------------- /wan/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .attention import flash_attention 2 | from .model import WanModel 3 | from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model 4 | from .tokenizers import HuggingfaceTokenizer 5 | from .vace_model import VaceWanModel 6 | from .vae import WanVAE 7 | 8 | __all__ = [ 9 | 'WanVAE', 10 | 'WanModel', 11 | 'VaceWanModel', 12 | 'T5Model', 13 | 'T5Encoder', 14 | 'T5Decoder', 15 | 'T5EncoderModel', 16 | 'HuggingfaceTokenizer', 17 | 'flash_attention', 18 | ] 19 | -------------------------------------------------------------------------------- /examples/single_example_image.json: -------------------------------------------------------------------------------- 1 | { 2 | "prompt": "A woman is passionately singing into a professional microphone in a recording studio. She wears large black headphones and a dark cardigan over a gray top. Her long, wavy brown hair frames her face as she looks slightly upwards, her mouth open mid-song. The studio is equipped with various audio equipment, including a mixing console and a keyboard, with soundproofing panels on the walls. The lighting is warm and focused on her, creating a professional and intimate atmosphere. 
A close-up shot captures her expressive performance.", 3 | "cond_video": "examples/single/ref_image.png", 4 | "cond_audio": { 5 | "person1": "examples/single/1.wav" 6 | } 7 | } 8 | -------------------------------------------------------------------------------- /src/audio_analysis/torch_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def get_mask_from_lengths(lengths, max_len=None): 6 | lengths = lengths.to(torch.long) 7 | if max_len is None: 8 | max_len = torch.max(lengths).item() 9 | 10 | ids = torch.arange(0, max_len).unsqueeze(0).expand(lengths.shape[0], -1).to(lengths.device) 11 | mask = ids < lengths.unsqueeze(1).expand(-1, max_len) 12 | 13 | return mask 14 | 15 | 16 | def linear_interpolation(features, seq_len): 17 | features = features.transpose(1, 2) 18 | output_features = F.interpolate(features, size=seq_len, align_corners=True, mode='linear') 19 | return output_features.transpose(1, 2) 20 | 21 | -------------------------------------------------------------------------------- /wan/configs/shared_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 2 | import torch 3 | from easydict import EasyDict 4 | 5 | #------------------------ Wan shared config ------------------------# 6 | wan_shared_cfg = EasyDict() 7 | 8 | # t5 9 | wan_shared_cfg.t5_model = 'umt5_xxl' 10 | wan_shared_cfg.t5_dtype = torch.bfloat16 11 | wan_shared_cfg.text_len = 512 12 | 13 | # transformer 14 | wan_shared_cfg.param_dtype = torch.bfloat16 15 | 16 | # inference 17 | wan_shared_cfg.num_train_timesteps = 1000 18 | wan_shared_cfg.sample_fps = 16 19 | wan_shared_cfg.sample_neg_prompt = '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走' -------------------------------------------------------------------------------- /kokoro/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.9.4' 2 | 3 | from loguru import logger 4 | import sys 5 | 6 | # Remove default handler 7 | logger.remove() 8 | 9 | # Add custom handler with clean format including module and line number 10 | logger.add( 11 | sys.stderr, 12 | format="{time:HH:mm:ss} | {module:>16}:{line} | {level: >8} | {message}", 13 | colorize=True, 14 | level="INFO" # "DEBUG" to enable logger.debug("message") and up prints 15 | # "ERROR" to enable only logger.error("message") prints 16 | # etc 17 | ) 18 | 19 | # Disable before release or as needed 20 | logger.disable("kokoro") 21 | 22 | from .model import KModel 23 | from .pipeline import KPipeline 24 | -------------------------------------------------------------------------------- /wan/configs/wan_t2v_14B.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
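# Checkpoint names and architecture hyperparameters (T5 encoder, VAE stride, DiT width/depth) for the Wan 14B text-to-video model.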
2 | from easydict import EasyDict 3 | 4 | from .shared_config import wan_shared_cfg 5 | 6 | #------------------------ Wan T2V 14B ------------------------# 7 | 8 | t2v_14B = EasyDict(__name__='Config: Wan T2V 14B') 9 | t2v_14B.update(wan_shared_cfg) 10 | 11 | # t5 12 | t2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth' 13 | t2v_14B.t5_tokenizer = 'google/umt5-xxl' 14 | 15 | # vae 16 | t2v_14B.vae_checkpoint = 'Wan2.1_VAE.pth' 17 | t2v_14B.vae_stride = (4, 8, 8) 18 | 19 | # transformer 20 | t2v_14B.patch_size = (1, 2, 2) 21 | t2v_14B.dim = 5120 22 | t2v_14B.ffn_dim = 13824 23 | t2v_14B.freq_dim = 256 24 | t2v_14B.num_heads = 40 25 | t2v_14B.num_layers = 40 26 | t2v_14B.window_size = (-1, -1) 27 | t2v_14B.qk_norm = True 28 | t2v_14B.cross_attn_norm = True 29 | t2v_14B.eps = 1e-6 30 | -------------------------------------------------------------------------------- /wan/configs/wan_i2v_14B.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 2 | import torch 3 | from easydict import EasyDict 4 | 5 | from .shared_config import wan_shared_cfg 6 | 7 | #------------------------ Wan I2V 14B ------------------------# 8 | 9 | i2v_14B = EasyDict(__name__='Config: Wan I2V 14B') 10 | i2v_14B.update(wan_shared_cfg) 11 | i2v_14B.sample_neg_prompt = "镜头晃动," + i2v_14B.sample_neg_prompt 12 | 13 | i2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth' 14 | i2v_14B.t5_tokenizer = 'google/umt5-xxl' 15 | 16 | # clip 17 | i2v_14B.clip_model = 'clip_xlm_roberta_vit_h_14' 18 | i2v_14B.clip_dtype = torch.float16 19 | i2v_14B.clip_checkpoint = 'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth' 20 | i2v_14B.clip_tokenizer = 'xlm-roberta-large' 21 | 22 | # vae 23 | i2v_14B.vae_checkpoint = 'Wan2.1_VAE.pth' 24 | i2v_14B.vae_stride = (4, 8, 8) 25 | -------------------------------------------------------------------------------- /wan/configs/wan_t2v_1_3B.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 2 | from easydict import EasyDict 3 | 4 | from .shared_config import wan_shared_cfg 5 | 6 | #------------------------ Wan T2V 1.3B ------------------------# 7 | 8 | t2v_1_3B = EasyDict(__name__='Config: Wan T2V 1.3B') 9 | t2v_1_3B.update(wan_shared_cfg) 10 | 11 | # t5 12 | t2v_1_3B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth' 13 | t2v_1_3B.t5_tokenizer = 'google/umt5-xxl' 14 | 15 | # vae 16 | t2v_1_3B.vae_checkpoint = 'Wan2.1_VAE.pth' 17 | t2v_1_3B.vae_stride = (4, 8, 8) 18 | 19 | # transformer 20 | t2v_1_3B.patch_size = (1, 2, 2) 21 | t2v_1_3B.dim = 1536 22 | t2v_1_3B.ffn_dim = 8960 23 | t2v_1_3B.freq_dim = 256 24 | t2v_1_3B.num_heads = 12 25 | t2v_1_3B.num_layers = 30 26 | t2v_1_3B.window_size = (-1, -1) 27 | t2v_1_3B.qk_norm = True 28 | t2v_1_3B.cross_attn_norm = True 29 | t2v_1_3B.eps = 1e-6 30 | -------------------------------------------------------------------------------- /wan/distributed/fsdp.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
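# FSDP helpers: shard_model wraps each transformer block in FullyShardedDataParallel; free_model releases the sharded flat parameters afterwards.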
2 | import gc 3 | from functools import partial 4 | 5 | import torch 6 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 7 | from torch.distributed.fsdp import MixedPrecision, ShardingStrategy 8 | from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy 9 | from torch.distributed.utils import _free_storage 10 | 11 | 12 | def shard_model( 13 | model, 14 | device_id, 15 | param_dtype=torch.bfloat16, 16 | reduce_dtype=torch.float32, 17 | buffer_dtype=torch.float32, 18 | process_group=None, 19 | sharding_strategy=ShardingStrategy.FULL_SHARD, 20 | sync_module_states=True, 21 | ): 22 | model = FSDP( 23 | module=model, 24 | process_group=process_group, 25 | sharding_strategy=sharding_strategy, 26 | auto_wrap_policy=partial( 27 | lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks), 28 | # mixed_precision=MixedPrecision( 29 | # param_dtype=param_dtype, 30 | # reduce_dtype=reduce_dtype, 31 | # buffer_dtype=buffer_dtype), 32 | device_id=device_id, 33 | sync_module_states=sync_module_states) 34 | return model 35 | 36 | 37 | def free_model(model): 38 | for m in model.modules(): 39 | if isinstance(m, FSDP): 40 | _free_storage(m._handle.flat_param.data) 41 | del model 42 | gc.collect() 43 | torch.cuda.empty_cache() 44 | -------------------------------------------------------------------------------- /wan/configs/wan_multitalk_14B.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 2 | import torch 3 | from easydict import EasyDict 4 | 5 | from .shared_config import wan_shared_cfg 6 | 7 | #------------------------ Wan I2V 14B ------------------------# 8 | 9 | multitalk_14B = EasyDict(__name__='Config: Wan MultiTalk AI2V 14B') 10 | multitalk_14B.update(wan_shared_cfg) 11 | multitalk_14B.sample_neg_prompt = 'bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards' 12 | 13 | multitalk_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth' 14 | multitalk_14B.t5_tokenizer = 'google/umt5-xxl' 15 | 16 | # clip 17 | multitalk_14B.clip_model = 'clip_xlm_roberta_vit_h_14' 18 | multitalk_14B.clip_dtype = torch.float16 19 | multitalk_14B.clip_checkpoint = 'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth' 20 | multitalk_14B.clip_tokenizer = 'xlm-roberta-large' 21 | 22 | # vae 23 | multitalk_14B.vae_checkpoint = 'Wan2.1_VAE.pth' 24 | multitalk_14B.vae_stride = (4, 8, 8) 25 | 26 | # transformer 27 | multitalk_14B.patch_size = (1, 2, 2) 28 | multitalk_14B.dim = 5120 29 | multitalk_14B.ffn_dim = 13824 30 | multitalk_14B.freq_dim = 256 31 | multitalk_14B.num_heads = 40 32 | multitalk_14B.num_layers = 40 33 | multitalk_14B.window_size = (-1, -1) 34 | multitalk_14B.qk_norm = True 35 | multitalk_14B.cross_attn_norm = True 36 | multitalk_14B.eps = 1e-6 37 | -------------------------------------------------------------------------------- /wan/configs/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
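# Config registry: maps task names (t2v/i2v/flf2v/t2i/vace/infinitetalk) to their EasyDict configs and lists the resolutions each task supports.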
2 | import copy 3 | import os 4 | 5 | os.environ['TOKENIZERS_PARALLELISM'] = 'false' 6 | 7 | from .wan_i2v_14B import i2v_14B 8 | from .wan_t2v_1_3B import t2v_1_3B 9 | from .wan_t2v_14B import t2v_14B 10 | from .wan_multitalk_14B import multitalk_14B 11 | 12 | # the config of t2i_14B is the same as t2v_14B 13 | t2i_14B = copy.deepcopy(t2v_14B) 14 | t2i_14B.__name__ = 'Config: Wan T2I 14B' 15 | 16 | # the config of flf2v_14B is the same as i2v_14B 17 | flf2v_14B = copy.deepcopy(i2v_14B) 18 | flf2v_14B.__name__ = 'Config: Wan FLF2V 14B' 19 | flf2v_14B.sample_neg_prompt = "镜头切换," + flf2v_14B.sample_neg_prompt 20 | 21 | WAN_CONFIGS = { 22 | 't2v-14B': t2v_14B, 23 | 't2v-1.3B': t2v_1_3B, 24 | 'i2v-14B': i2v_14B, 25 | 't2i-14B': t2i_14B, 26 | 'flf2v-14B': flf2v_14B, 27 | 'vace-1.3B': t2v_1_3B, 28 | 'vace-14B': t2v_14B, 29 | 'infinitetalk-14B': multitalk_14B, 30 | } 31 | 32 | SIZE_CONFIGS = { 33 | '720*1280': (720, 1280), 34 | '1280*720': (1280, 720), 35 | '480*832': (480, 832), 36 | '832*480': (832, 480), 37 | '1024*1024': (1024, 1024), 38 | 'infinitetalk-480': (640, 640), 39 | 'infinitetalk-720': (960, 960), 40 | } 41 | 42 | MAX_AREA_CONFIGS = { 43 | '720*1280': 720 * 1280, 44 | '1280*720': 1280 * 720, 45 | '480*832': 480 * 832, 46 | '832*480': 832 * 480, 47 | } 48 | 49 | SUPPORTED_SIZES = { 50 | 't2v-14B': ('720*1280', '1280*720', '480*832', '832*480'), 51 | 't2v-1.3B': ('480*832', '832*480'), 52 | 'i2v-14B': ('720*1280', '1280*720', '480*832', '832*480'), 53 | 'flf2v-14B': ('720*1280', '1280*720', '480*832', '832*480'), 54 | 't2i-14B': tuple(SIZE_CONFIGS.keys()), 55 | 'vace-1.3B': ('480*832', '832*480'), 56 | 'vace-14B': ('720*1280', '1280*720', '480*832', '832*480'), 57 | 'infinitetalk-14B': ('infinitetalk-480', 'infinitetalk-720'), 58 | } 59 | -------------------------------------------------------------------------------- /wan/utils/segvideo.py: -------------------------------------------------------------------------------- 1 | from scenedetect import SceneManager, open_video, ContentDetector, AdaptiveDetector, ThresholdDetector 2 | from moviepy.editor import * 3 | import copy,os,time,datetime 4 | 5 | def build_manager(): 6 | scene_manager = SceneManager() 7 | scene_manager.add_detector(ContentDetector()) 8 | scene_manager.add_detector(AdaptiveDetector()) 9 | scene_manager.add_detector(ThresholdDetector()) 10 | return scene_manager 11 | 12 | def seg_video(video_path, scene_list, output_dir): 13 | output_fp_list = [] 14 | with VideoFileClip(video_path) as video: 15 | for (start_time,end_time) in scene_list: 16 | if end_time-start_time > 0.5: 17 | start_time = start_time + 0.05 18 | end_time = end_time - 0.05 19 | video_clip = video.subclip(start_time, end_time) 20 | vid = video_path.split('/')[-1].rstrip('.mp4').split('___')[0] 21 | output_fp = os.path.join(output_dir, f'{vid}_{str(start_time)}_{str(end_time)}.mp4') 22 | video_clip.write_videofile(output_fp) 23 | output_fp_list.append(output_fp) 24 | video.close() 25 | return output_fp_list 26 | 27 | def shot_detect(video_path, output_dir): 28 | 29 | os.makedirs(output_dir, exist_ok=True) 30 | print(f'start process {video_path}') 31 | start_time = time.time() 32 | attribs = {} 33 | attribs['filepath'] = video_path 34 | try: 35 | video = open_video(video_path) 36 | scene_manager = build_manager() 37 | scene_manager.detect_scenes(video,show_progress=False) 38 | stamps = scene_manager.get_scene_list() 39 | scene_list = [] 40 | for stamp in stamps: 41 | start, end = stamp 42 | scene_list.append((start.get_seconds(), end.get_seconds())) 43 | 
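        # scene_list holds (start_seconds, end_seconds) for every detected shot; seg_video below only exports segments longer than 0.5 s.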
44 | attribs['shot_stamps'] = scene_list 45 | output_fp_list = seg_video(video_path, scene_list, output_dir) 46 | 47 | except Exception as e: 48 | print([e, video_path]) 49 | 50 | 51 | 52 | print(f"process {video_path} Done with {time.time()-start_time:.2f} seconds used.") 53 | return scene_list, output_fp_list 54 | 55 | 56 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | 3 | import torch 4 | 5 | @contextmanager 6 | def init_weights_on_device(device=torch.device("meta"), include_buffers: bool = False): 7 | old_register_parameter = torch.nn.Module.register_parameter 8 | if include_buffers: 9 | old_register_buffer = torch.nn.Module.register_buffer 10 | 11 | def register_empty_parameter(module, name, param): 12 | old_register_parameter(module, name, param) 13 | if param is not None: 14 | param_cls = type(module._parameters[name]) 15 | kwargs = module._parameters[name].__dict__ 16 | kwargs["requires_grad"] = param.requires_grad 17 | module._parameters[name] = param_cls( 18 | module._parameters[name].to(device), **kwargs 19 | ) 20 | 21 | def register_empty_buffer(module, name, buffer, persistent=True): 22 | old_register_buffer(module, name, buffer, persistent=persistent) 23 | if buffer is not None: 24 | module._buffers[name] = module._buffers[name].to(device) 25 | 26 | def patch_tensor_constructor(fn): 27 | def wrapper(*args, **kwargs): 28 | kwargs["device"] = device 29 | return fn(*args, **kwargs) 30 | 31 | return wrapper 32 | 33 | if include_buffers: 34 | tensor_constructors_to_patch = { 35 | torch_function_name: getattr(torch, torch_function_name) 36 | for torch_function_name in ["empty", "zeros", "ones", "full"] 37 | } 38 | else: 39 | tensor_constructors_to_patch = {} 40 | 41 | try: 42 | torch.nn.Module.register_parameter = register_empty_parameter 43 | if include_buffers: 44 | torch.nn.Module.register_buffer = register_empty_buffer 45 | for torch_function_name in tensor_constructors_to_patch.keys(): 46 | setattr( 47 | torch, 48 | torch_function_name, 49 | patch_tensor_constructor(getattr(torch, torch_function_name)), 50 | ) 51 | yield 52 | finally: 53 | torch.nn.Module.register_parameter = old_register_parameter 54 | if include_buffers: 55 | torch.nn.Module.register_buffer = old_register_buffer 56 | for ( 57 | torch_function_name, 58 | old_torch_function, 59 | ) in tensor_constructors_to_patch.items(): 60 | setattr(torch, torch_function_name, old_torch_function) -------------------------------------------------------------------------------- /wan/modules/tokenizers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
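# HuggingfaceTokenizer: thin wrapper around transformers.AutoTokenizer with optional text cleaning ('whitespace', 'lower', 'canonicalize') and fixed-length padding/truncation.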
2 | import html 3 | import string 4 | 5 | import ftfy 6 | import regex as re 7 | from transformers import AutoTokenizer 8 | 9 | __all__ = ['HuggingfaceTokenizer'] 10 | 11 | 12 | def basic_clean(text): 13 | text = ftfy.fix_text(text) 14 | text = html.unescape(html.unescape(text)) 15 | return text.strip() 16 | 17 | 18 | def whitespace_clean(text): 19 | text = re.sub(r'\s+', ' ', text) 20 | text = text.strip() 21 | return text 22 | 23 | 24 | def canonicalize(text, keep_punctuation_exact_string=None): 25 | text = text.replace('_', ' ') 26 | if keep_punctuation_exact_string: 27 | text = keep_punctuation_exact_string.join( 28 | part.translate(str.maketrans('', '', string.punctuation)) 29 | for part in text.split(keep_punctuation_exact_string)) 30 | else: 31 | text = text.translate(str.maketrans('', '', string.punctuation)) 32 | text = text.lower() 33 | text = re.sub(r'\s+', ' ', text) 34 | return text.strip() 35 | 36 | 37 | class HuggingfaceTokenizer: 38 | 39 | def __init__(self, name, seq_len=None, clean=None, **kwargs): 40 | assert clean in (None, 'whitespace', 'lower', 'canonicalize') 41 | self.name = name 42 | self.seq_len = seq_len 43 | self.clean = clean 44 | 45 | # init tokenizer 46 | self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs) 47 | self.vocab_size = self.tokenizer.vocab_size 48 | 49 | def __call__(self, sequence, **kwargs): 50 | return_mask = kwargs.pop('return_mask', False) 51 | 52 | # arguments 53 | _kwargs = {'return_tensors': 'pt'} 54 | if self.seq_len is not None: 55 | _kwargs.update({ 56 | 'padding': 'max_length', 57 | 'truncation': True, 58 | 'max_length': self.seq_len 59 | }) 60 | _kwargs.update(**kwargs) 61 | 62 | # tokenization 63 | if isinstance(sequence, str): 64 | sequence = [sequence] 65 | if self.clean: 66 | sequence = [self._clean(u) for u in sequence] 67 | ids = self.tokenizer(sequence, **_kwargs) 68 | 69 | # output 70 | if return_mask: 71 | return ids.input_ids, ids.attention_mask 72 | else: 73 | return ids.input_ids 74 | 75 | def _clean(self, text): 76 | if self.clean == 'whitespace': 77 | text = whitespace_clean(basic_clean(text)) 78 | elif self.clean == 'lower': 79 | text = whitespace_clean(basic_clean(text)).lower() 80 | elif self.clean == 'canonicalize': 81 | text = canonicalize(basic_clean(text)) 82 | return text 83 | -------------------------------------------------------------------------------- /kokoro/__main__.py: -------------------------------------------------------------------------------- 1 | """Kokoro TTS CLI 2 | Example usage: 3 | python3 -m kokoro --text "The sky above the port was the color of television, tuned to a dead channel." 
-o file.wav --debug 4 | 5 | echo "Bom dia mundo, como vão vocês" > text.txt 6 | python3 -m kokoro -i text.txt -l p --voice pm_alex > audio.wav 7 | 8 | Common issues: 9 | pip not installed: `uv pip install pip` 10 | (Temporary workaround while https://github.com/explosion/spaCy/issues/13747 is not fixed) 11 | 12 | espeak not installed: `apt-get install espeak-ng` 13 | """ 14 | 15 | import argparse 16 | import wave 17 | from pathlib import Path 18 | from typing import Generator, TYPE_CHECKING 19 | 20 | import numpy as np 21 | from loguru import logger 22 | 23 | languages = [ 24 | "a", # American English 25 | "b", # British English 26 | "h", # Hindi 27 | "e", # Spanish 28 | "f", # French 29 | "i", # Italian 30 | "p", # Brazilian Portuguese 31 | "j", # Japanese 32 | "z", # Mandarin Chinese 33 | ] 34 | 35 | if TYPE_CHECKING: 36 | from kokoro import KPipeline 37 | 38 | 39 | def generate_audio( 40 | text: str, kokoro_language: str, voice: str, speed=1 41 | ) -> Generator["KPipeline.Result", None, None]: 42 | from kokoro import KPipeline 43 | 44 | if not voice.startswith(kokoro_language): 45 | logger.warning(f"Voice {voice} is not made for language {kokoro_language}") 46 | pipeline = KPipeline(lang_code=kokoro_language) 47 | yield from pipeline(text, voice=voice, speed=speed, split_pattern=r"\n+") 48 | 49 | 50 | def generate_and_save_audio( 51 | output_file: Path, text: str, kokoro_language: str, voice: str, speed=1 52 | ) -> None: 53 | with wave.open(str(output_file.resolve()), "wb") as wav_file: 54 | wav_file.setnchannels(1) # Mono audio 55 | wav_file.setsampwidth(2) # 2 bytes per sample (16-bit audio) 56 | wav_file.setframerate(24000) # Sample rate 57 | 58 | for result in generate_audio( 59 | text, kokoro_language=kokoro_language, voice=voice, speed=speed 60 | ): 61 | logger.debug(result.phonemes) 62 | if result.audio is None: 63 | continue 64 | audio_bytes = (result.audio.numpy() * 32767).astype(np.int16).tobytes() 65 | wav_file.writeframes(audio_bytes) 66 | 67 | 68 | def main() -> None: 69 | parser = argparse.ArgumentParser() 70 | parser.add_argument( 71 | "-m", 72 | "--voice", 73 | default="af_heart", 74 | help="Voice to use", 75 | ) 76 | parser.add_argument( 77 | "-l", 78 | "--language", 79 | help="Language to use (defaults to the one corresponding to the voice)", 80 | choices=languages, 81 | ) 82 | parser.add_argument( 83 | "-o", 84 | "--output-file", 85 | "--output_file", 86 | type=Path, 87 | help="Path to output WAV file", 88 | required=True, 89 | ) 90 | parser.add_argument( 91 | "-i", 92 | "--input-file", 93 | "--input_file", 94 | type=Path, 95 | help="Path to input text file (default: stdin)", 96 | ) 97 | parser.add_argument( 98 | "-t", 99 | "--text", 100 | help="Text to use instead of reading from stdin", 101 | ) 102 | parser.add_argument( 103 | "-s", 104 | "--speed", 105 | type=float, 106 | default=1.0, 107 | help="Speech speed", 108 | ) 109 | parser.add_argument( 110 | "--debug", 111 | action="store_true", 112 | help="Print DEBUG messages to console", 113 | ) 114 | args = parser.parse_args() 115 | if args.debug: 116 | logger.level("DEBUG") 117 | logger.debug(args) 118 | 119 | lang = args.language or args.voice[0] 120 | 121 | if args.text is not None and args.input_file is not None: 122 | raise Exception("You cannot specify both 'text' and 'input_file'") 123 | elif args.text: 124 | text = args.text 125 | elif args.input_file: 126 | file: Path = args.input_file 127 | text = file.read_text() 128 | else: 129 | import sys 130 | print("Press Ctrl+D to stop reading input and start 
generating", flush=True) 131 | text = '\n'.join(sys.stdin) 132 | 133 | logger.debug(f"Input text: {text!r}") 134 | 135 | out_file: Path = args.output_file 136 | if not out_file.suffix == ".wav": 137 | logger.warning("The output file name should end with .wav") 138 | generate_and_save_audio( 139 | output_file=out_file, 140 | text=text, 141 | kokoro_language=lang, 142 | voice=args.voice, 143 | speed=args.speed, 144 | ) 145 | 146 | 147 | if __name__ == "__main__": 148 | main() 149 | -------------------------------------------------------------------------------- /wan/wan_lora.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from safetensors import safe_open 4 | from loguru import logger 5 | import gc 6 | from functools import lru_cache 7 | from tqdm import tqdm 8 | 9 | @lru_cache(maxsize=None) 10 | def GET_DTYPE(): 11 | RUNNING_FLAG = os.getenv("DTYPE") 12 | return RUNNING_FLAG 13 | 14 | class WanLoraWrapper: 15 | def __init__(self, wan_model): 16 | self.model = wan_model 17 | self.lora_metadata = {} 18 | # self.override_dict = {} # On CPU 19 | 20 | def load_lora(self, lora_path, lora_name=None): 21 | if lora_name is None: 22 | lora_name = os.path.basename(lora_path).split(".")[0] 23 | 24 | if lora_name in self.lora_metadata: 25 | logger.info(f"LoRA {lora_name} already loaded, skipping...") 26 | return lora_name 27 | 28 | self.lora_metadata[lora_name] = {"path": lora_path} 29 | logger.info(f"Registered LoRA metadata for: {lora_name} from {lora_path}") 30 | 31 | return lora_name 32 | 33 | def _load_lora_file(self, file_path, param_dtype): 34 | with safe_open(file_path, framework="pt") as f: 35 | tensor_dict = {key: f.get_tensor(key).to(param_dtype) for key in f.keys()} 36 | return tensor_dict 37 | 38 | def apply_lora(self, lora_name, alpha=1.0, param_dtype=torch.bfloat16, device='cpu'): 39 | if lora_name not in self.lora_metadata: 40 | logger.info(f"LoRA {lora_name} not found. Please load it first.") 41 | 42 | 43 | 44 | lora_weights = self._load_lora_file(self.lora_metadata[lora_name]["path"], param_dtype) 45 | # weight_dict = self.model.original_weight_dict 46 | self._apply_lora_weights(lora_weights, alpha, device) 47 | # self.model._init_weights(weight_dict) 48 | 49 | logger.info(f"Applied LoRA: {lora_name} with alpha={alpha}") 50 | return True 51 | 52 | def get_parameter_by_name(self, model, param_name): 53 | parts = param_name.split('.') 54 | current = model 55 | for part in parts: 56 | if part.isdigit(): 57 | current = current[int(part)] 58 | else: 59 | current = getattr(current, part) 60 | return current 61 | 62 | @torch.no_grad() 63 | def _apply_lora_weights(self, lora_weights, alpha, device): 64 | lora_pairs = {} 65 | prefix = "diffusion_model." 
66 | 67 | for key in lora_weights.keys(): 68 | if key.endswith("lora_down.weight") and key.startswith(prefix): 69 | base_name = key[len(prefix) :].replace("lora_down.weight", "weight") 70 | b_key = key.replace("lora_down.weight", "lora_up.weight") 71 | if b_key in lora_weights: 72 | lora_pairs[base_name] = (key, b_key) 73 | elif key.endswith("diff_b") and key.startswith(prefix): 74 | base_name = key[len(prefix) :].replace("diff_b", "bias") 75 | lora_pairs[base_name] = (key) 76 | elif key.endswith("diff") and key.startswith(prefix): 77 | base_name = key[len(prefix) :].replace("diff", "weight") 78 | lora_pairs[base_name] = (key) 79 | 80 | applied_count = 0 81 | for name in tqdm(lora_pairs.keys(), desc="Loading LoRA weights"): 82 | param = self.get_parameter_by_name(self.model, name) 83 | if device == 'cpu': 84 | dtype = torch.float32 85 | else: 86 | dtype = param.dtype 87 | if isinstance(lora_pairs[name], tuple): 88 | name_lora_A, name_lora_B = lora_pairs[name] 89 | lora_A = lora_weights[name_lora_A].to(device, dtype) 90 | lora_B = lora_weights[name_lora_B].to(device, dtype) 91 | delta = torch.matmul(lora_B, lora_A) * alpha 92 | delta = delta.to(param.device, param.dtype) 93 | param.add_(delta) 94 | else: 95 | name_lora = lora_pairs[name] 96 | delta = lora_weights[name_lora].to(param.device, dtype)* alpha 97 | delta = delta.to(param.device, param.dtype) 98 | param.add_(delta) 99 | applied_count += 1 100 | 101 | 102 | logger.info(f"Applied {applied_count} LoRA weight adjustments") 103 | if applied_count == 0: 104 | logger.info( 105 | "Warning: No LoRA weights were applied. Expected naming conventions: 'diffusion_model..lora_A.weight' and 'diffusion_model..lora_B.weight'. Please verify the LoRA weight file." 106 | ) 107 | 108 | 109 | def list_loaded_loras(self): 110 | return list(self.lora_metadata.keys()) 111 | 112 | def get_current_lora(self): 113 | return self.model.current_lora -------------------------------------------------------------------------------- /src/audio_analysis/wav2vec2.py: -------------------------------------------------------------------------------- 1 | from transformers import Wav2Vec2Config, Wav2Vec2Model 2 | from transformers.modeling_outputs import BaseModelOutput 3 | 4 | from src.audio_analysis.torch_utils import linear_interpolation 5 | 6 | # the implementation of Wav2Vec2Model is borrowed from 7 | # https://github.com/huggingface/transformers/blob/HEAD/src/transformers/models/wav2vec2/modeling_wav2vec2.py 8 | # initialize our encoder with the pre-trained wav2vec 2.0 weights. 
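# The subclass below keeps the pretrained architecture but linearly interpolates the CNN features to a caller-supplied seq_len before the transformer encoder runs.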
9 | class Wav2Vec2Model(Wav2Vec2Model): 10 | def __init__(self, config: Wav2Vec2Config): 11 | super().__init__(config) 12 | 13 | def forward( 14 | self, 15 | input_values, 16 | seq_len, 17 | attention_mask=None, 18 | mask_time_indices=None, 19 | output_attentions=None, 20 | output_hidden_states=None, 21 | return_dict=None, 22 | ): 23 | self.config.output_attentions = True 24 | 25 | output_hidden_states = ( 26 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 27 | ) 28 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 29 | 30 | extract_features = self.feature_extractor(input_values) 31 | extract_features = extract_features.transpose(1, 2) 32 | extract_features = linear_interpolation(extract_features, seq_len=seq_len) 33 | 34 | if attention_mask is not None: 35 | # compute reduced attention_mask corresponding to feature vectors 36 | attention_mask = self._get_feature_vector_attention_mask( 37 | extract_features.shape[1], attention_mask, add_adapter=False 38 | ) 39 | 40 | hidden_states, extract_features = self.feature_projection(extract_features) 41 | hidden_states = self._mask_hidden_states( 42 | hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask 43 | ) 44 | 45 | encoder_outputs = self.encoder( 46 | hidden_states, 47 | attention_mask=attention_mask, 48 | output_attentions=output_attentions, 49 | output_hidden_states=output_hidden_states, 50 | return_dict=return_dict, 51 | ) 52 | 53 | hidden_states = encoder_outputs[0] 54 | 55 | if self.adapter is not None: 56 | hidden_states = self.adapter(hidden_states) 57 | 58 | if not return_dict: 59 | return (hidden_states, ) + encoder_outputs[1:] 60 | return BaseModelOutput( 61 | last_hidden_state=hidden_states, 62 | hidden_states=encoder_outputs.hidden_states, 63 | attentions=encoder_outputs.attentions, 64 | ) 65 | 66 | 67 | def feature_extract( 68 | self, 69 | input_values, 70 | seq_len, 71 | ): 72 | extract_features = self.feature_extractor(input_values) 73 | extract_features = extract_features.transpose(1, 2) 74 | extract_features = linear_interpolation(extract_features, seq_len=seq_len) 75 | 76 | return extract_features 77 | 78 | def encode( 79 | self, 80 | extract_features, 81 | attention_mask=None, 82 | mask_time_indices=None, 83 | output_attentions=None, 84 | output_hidden_states=None, 85 | return_dict=None, 86 | ): 87 | self.config.output_attentions = True 88 | 89 | output_hidden_states = ( 90 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 91 | ) 92 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 93 | 94 | if attention_mask is not None: 95 | # compute reduced attention_mask corresponding to feature vectors 96 | attention_mask = self._get_feature_vector_attention_mask( 97 | extract_features.shape[1], attention_mask, add_adapter=False 98 | ) 99 | 100 | 101 | hidden_states, extract_features = self.feature_projection(extract_features) 102 | hidden_states = self._mask_hidden_states( 103 | hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask 104 | ) 105 | 106 | encoder_outputs = self.encoder( 107 | hidden_states, 108 | attention_mask=attention_mask, 109 | output_attentions=output_attentions, 110 | output_hidden_states=output_hidden_states, 111 | return_dict=return_dict, 112 | ) 113 | 114 | hidden_states = encoder_outputs[0] 115 | 116 | if self.adapter is not None: 117 | hidden_states = self.adapter(hidden_states) 118 | 
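        # Same return convention as the stock Hugging Face model: a plain tuple, or a BaseModelOutput when return_dict is set.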
119 | if not return_dict: 120 | return (hidden_states, ) + encoder_outputs[1:] 121 | return BaseModelOutput( 122 | last_hidden_state=hidden_states, 123 | hidden_states=encoder_outputs.hidden_states, 124 | attentions=encoder_outputs.attentions, 125 | ) 126 | -------------------------------------------------------------------------------- /wan/modules/xlm_roberta.py: -------------------------------------------------------------------------------- 1 | # Modified from transformers.models.xlm_roberta.modeling_xlm_roberta 2 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | __all__ = ['XLMRoberta', 'xlm_roberta_large'] 8 | 9 | 10 | class SelfAttention(nn.Module): 11 | 12 | def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5): 13 | assert dim % num_heads == 0 14 | super().__init__() 15 | self.dim = dim 16 | self.num_heads = num_heads 17 | self.head_dim = dim // num_heads 18 | self.eps = eps 19 | 20 | # layers 21 | self.q = nn.Linear(dim, dim) 22 | self.k = nn.Linear(dim, dim) 23 | self.v = nn.Linear(dim, dim) 24 | self.o = nn.Linear(dim, dim) 25 | self.dropout = nn.Dropout(dropout) 26 | 27 | def forward(self, x, mask): 28 | """ 29 | x: [B, L, C]. 30 | """ 31 | b, s, c, n, d = *x.size(), self.num_heads, self.head_dim 32 | 33 | # compute query, key, value 34 | q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3) 35 | k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3) 36 | v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3) 37 | 38 | # compute attention 39 | p = self.dropout.p if self.training else 0.0 40 | x = F.scaled_dot_product_attention(q, k, v, mask, p) 41 | x = x.permute(0, 2, 1, 3).reshape(b, s, c) 42 | 43 | # output 44 | x = self.o(x) 45 | x = self.dropout(x) 46 | return x 47 | 48 | 49 | class AttentionBlock(nn.Module): 50 | 51 | def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5): 52 | super().__init__() 53 | self.dim = dim 54 | self.num_heads = num_heads 55 | self.post_norm = post_norm 56 | self.eps = eps 57 | 58 | # layers 59 | self.attn = SelfAttention(dim, num_heads, dropout, eps) 60 | self.norm1 = nn.LayerNorm(dim, eps=eps) 61 | self.ffn = nn.Sequential( 62 | nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim), 63 | nn.Dropout(dropout)) 64 | self.norm2 = nn.LayerNorm(dim, eps=eps) 65 | 66 | def forward(self, x, mask): 67 | if self.post_norm: 68 | x = self.norm1(x + self.attn(x, mask)) 69 | x = self.norm2(x + self.ffn(x)) 70 | else: 71 | x = x + self.attn(self.norm1(x), mask) 72 | x = x + self.ffn(self.norm2(x)) 73 | return x 74 | 75 | 76 | class XLMRoberta(nn.Module): 77 | """ 78 | XLMRobertaModel with no pooler and no LM head. 
79 | """ 80 | 81 | def __init__(self, 82 | vocab_size=250002, 83 | max_seq_len=514, 84 | type_size=1, 85 | pad_id=1, 86 | dim=1024, 87 | num_heads=16, 88 | num_layers=24, 89 | post_norm=True, 90 | dropout=0.1, 91 | eps=1e-5): 92 | super().__init__() 93 | self.vocab_size = vocab_size 94 | self.max_seq_len = max_seq_len 95 | self.type_size = type_size 96 | self.pad_id = pad_id 97 | self.dim = dim 98 | self.num_heads = num_heads 99 | self.num_layers = num_layers 100 | self.post_norm = post_norm 101 | self.eps = eps 102 | 103 | # embeddings 104 | self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id) 105 | self.type_embedding = nn.Embedding(type_size, dim) 106 | self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id) 107 | self.dropout = nn.Dropout(dropout) 108 | 109 | # blocks 110 | self.blocks = nn.ModuleList([ 111 | AttentionBlock(dim, num_heads, post_norm, dropout, eps) 112 | for _ in range(num_layers) 113 | ]) 114 | 115 | # norm layer 116 | self.norm = nn.LayerNorm(dim, eps=eps) 117 | 118 | def forward(self, ids): 119 | """ 120 | ids: [B, L] of torch.LongTensor. 121 | """ 122 | b, s = ids.shape 123 | mask = ids.ne(self.pad_id).long() 124 | 125 | # embeddings 126 | x = self.token_embedding(ids) + \ 127 | self.type_embedding(torch.zeros_like(ids)) + \ 128 | self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask) 129 | if self.post_norm: 130 | x = self.norm(x) 131 | x = self.dropout(x) 132 | 133 | # blocks 134 | mask = torch.where( 135 | mask.view(b, 1, 1, s).gt(0), 0.0, 136 | torch.finfo(x.dtype).min) 137 | for block in self.blocks: 138 | x = block(x, mask) 139 | 140 | # output 141 | if not self.post_norm: 142 | x = self.norm(x) 143 | return x 144 | 145 | 146 | def xlm_roberta_large(pretrained=False, 147 | return_tokenizer=False, 148 | device='cpu', 149 | **kwargs): 150 | """ 151 | XLMRobertaLarge adapted from Huggingface. 152 | """ 153 | # params 154 | cfg = dict( 155 | vocab_size=250002, 156 | max_seq_len=514, 157 | type_size=1, 158 | pad_id=1, 159 | dim=1024, 160 | num_heads=16, 161 | num_layers=24, 162 | post_norm=True, 163 | dropout=0.1, 164 | eps=1e-5) 165 | cfg.update(**kwargs) 166 | 167 | # init a model on device 168 | with torch.device(device): 169 | model = XLMRoberta(**cfg) 170 | return model 171 | -------------------------------------------------------------------------------- /wan/utils/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 2 | import argparse 3 | import binascii 4 | import os 5 | import os.path as osp 6 | import cv2 7 | 8 | import imageio 9 | import torch 10 | import torchvision 11 | from PIL import Image 12 | import librosa 13 | import soundfile as sf 14 | import subprocess 15 | from decord import VideoReader, cpu 16 | import gc 17 | 18 | __all__ = ['cache_video', 'cache_image', 'str2bool'] 19 | 20 | 21 | def rand_name(length=8, suffix=''): 22 | name = binascii.b2a_hex(os.urandom(length)).decode('utf-8') 23 | if suffix: 24 | if not suffix.startswith('.'): 25 | suffix = '.' + suffix 26 | name += suffix 27 | return name 28 | 29 | 30 | 31 | def str2bool(v): 32 | """ 33 | Convert a string to a boolean. 34 | 35 | Supported true values: 'yes', 'true', 't', 'y', '1' 36 | Supported false values: 'no', 'false', 'f', 'n', '0' 37 | 38 | Args: 39 | v (str): String to convert. 40 | 41 | Returns: 42 | bool: Converted boolean value. 
43 | 44 | Raises: 45 | argparse.ArgumentTypeError: If the value cannot be converted to boolean. 46 | """ 47 | if isinstance(v, bool): 48 | return v 49 | v_lower = v.lower() 50 | if v_lower in ('yes', 'true', 't', 'y', '1'): 51 | return True 52 | elif v_lower in ('no', 'false', 'f', 'n', '0'): 53 | return False 54 | else: 55 | raise argparse.ArgumentTypeError('Boolean value expected (True/False)') 56 | 57 | def cache_video(tensor, 58 | save_file=None, 59 | fps=30, 60 | suffix='.mp4', 61 | nrow=8, 62 | normalize=True, 63 | value_range=(-1, 1), 64 | retry=5): 65 | # cache file 66 | cache_file = osp.join('/tmp', rand_name( 67 | suffix=suffix)) if save_file is None else save_file 68 | 69 | # save to cache 70 | error = None 71 | for _ in range(retry): 72 | try: 73 | # preprocess 74 | tensor = tensor.clamp(min(value_range), max(value_range)) 75 | tensor = torch.stack([ 76 | torchvision.utils.make_grid( 77 | u, nrow=nrow, normalize=normalize, value_range=value_range) 78 | for u in tensor.unbind(2) 79 | ], 80 | dim=1).permute(1, 2, 3, 0) 81 | tensor = (tensor * 255).type(torch.uint8).cpu() 82 | 83 | # write video 84 | writer = imageio.get_writer( 85 | cache_file, fps=fps, codec='libx264', quality=8) 86 | for frame in tensor.numpy(): 87 | writer.append_data(frame) 88 | writer.close() 89 | return cache_file 90 | except Exception as e: 91 | error = e 92 | continue 93 | else: 94 | print(f'cache_video failed, error: {error}', flush=True) 95 | return None 96 | 97 | 98 | def cache_image(tensor, 99 | save_file, 100 | nrow=8, 101 | normalize=True, 102 | value_range=(-1, 1), 103 | retry=5): 104 | # cache file 105 | suffix = osp.splitext(save_file)[1] 106 | if suffix.lower() not in [ 107 | '.jpg', '.jpeg', '.png', '.tiff', '.gif', '.webp' 108 | ]: 109 | suffix = '.png' 110 | 111 | # save to cache 112 | error = None 113 | for _ in range(retry): 114 | try: 115 | tensor = tensor.clamp(min(value_range), max(value_range)) 116 | torchvision.utils.save_image( 117 | tensor, 118 | save_file, 119 | nrow=nrow, 120 | normalize=normalize, 121 | value_range=value_range) 122 | return save_file 123 | except Exception as e: 124 | error = e 125 | continue 126 | 127 | def convert_video_to_h264(input_video_path, output_video_path): 128 | subprocess.run( 129 | ['ffmpeg', '-i', input_video_path, '-c:v', 'libx264', '-c:a', 'copy', output_video_path], 130 | stdout=subprocess.PIPE, 131 | stderr=subprocess.PIPE 132 | ) 133 | 134 | 135 | def is_video(path): 136 | video_exts = ['.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', '.webm', '.mpeg', '.mpg'] 137 | return os.path.splitext(path)[1].lower() in video_exts 138 | 139 | 140 | def extract_specific_frames(video_path, frame_id): 141 | if is_video(video_path): 142 | vr = VideoReader(video_path, ctx=cpu(0)) 143 | if frame_id < vr._num_frame: 144 | frame = vr[frame_id].asnumpy() # RGB 145 | else: 146 | frame = vr[-1].asnumpy() 147 | del vr 148 | gc.collect() 149 | frame = Image.fromarray(frame) 150 | else: 151 | frame = Image.open(video_path).convert("RGB") 152 | return frame 153 | 154 | def get_video_codec(video_path): 155 | result = subprocess.run( 156 | ['ffprobe', '-v', 'error', '-select_streams', 'v:0', 157 | '-show_entries', 'stream=codec_name', '-of', 'default=nw=1:nk=1', video_path], 158 | stdout=subprocess.PIPE, 159 | stderr=subprocess.PIPE 160 | ) 161 | codec = result.stdout.decode().strip() 162 | return codec 163 | 164 | 165 | 166 | def split_wav_librosa(wav_path, segments, save_dir): 167 | y, sr = librosa.load(wav_path, sr=None) 168 | filename = 
wav_path.split('/')[-1].split('.')[0] 169 | save_list = [] 170 | for idx, (start, end) in enumerate(segments): 171 | start_sample = int(start * sr) 172 | end_sample = int(end * sr) 173 | segment = y[start_sample:end_sample] 174 | out_path = os.path.join(save_dir, filename + str(start) + '_' + str(end) + '.wav') 175 | sf.write(out_path, segment, sr) 176 | print(f"Saved {out_path}: {start}s to {end}s") 177 | save_list.append(out_path) 178 | return save_list 179 | 180 | -------------------------------------------------------------------------------- /kokoro/model.py: -------------------------------------------------------------------------------- 1 | from .istftnet import Decoder 2 | from .modules import CustomAlbert, ProsodyPredictor, TextEncoder 3 | from dataclasses import dataclass 4 | from huggingface_hub import hf_hub_download 5 | from loguru import logger 6 | from transformers import AlbertConfig 7 | from typing import Dict, Optional, Union 8 | import json 9 | import torch 10 | import os 11 | 12 | class KModel(torch.nn.Module): 13 | ''' 14 | KModel is a torch.nn.Module with 2 main responsibilities: 15 | 1. Init weights, downloading config.json + model.pth from HF if needed 16 | 2. forward(phonemes: str, ref_s: FloatTensor) -> (audio: FloatTensor) 17 | 18 | You likely only need one KModel instance, and it can be reused across 19 | multiple KPipelines to avoid redundant memory allocation. 20 | 21 | Unlike KPipeline, KModel is language-blind. 22 | 23 | KModel stores self.vocab and thus knows how to map phonemes -> input_ids, 24 | so there is no need to repeatedly download config.json outside of KModel. 25 | ''' 26 | 27 | MODEL_NAMES = { 28 | 'hexgrad/Kokoro-82M': 'kokoro-v1_0.pth', 29 | 'hexgrad/Kokoro-82M-v1.1-zh': 'kokoro-v1_1-zh.pth', 30 | } 31 | 32 | def __init__( 33 | self, 34 | repo_id: Optional[str] = None, 35 | config: Union[Dict, str, None] = None, 36 | model: Optional[str] = None, 37 | disable_complex: bool = False 38 | ): 39 | super().__init__() 40 | if repo_id is None: 41 | repo_id = 'hexgrad/Kokoro-82M' 42 | print(f"WARNING: Defaulting repo_id to {repo_id}. 
Pass repo_id='{repo_id}' to suppress this warning.") 43 | self.repo_id = repo_id 44 | if not isinstance(config, dict): 45 | if not config: 46 | logger.debug("No config provided, downloading from HF") 47 | config = hf_hub_download(repo_id=repo_id, filename='config.json') 48 | with open(config, 'r', encoding='utf-8') as r: 49 | config = json.load(r) 50 | logger.debug(f"Loaded config: {config}") 51 | self.vocab = config['vocab'] 52 | self.bert = CustomAlbert(AlbertConfig(vocab_size=config['n_token'], **config['plbert'])) 53 | self.bert_encoder = torch.nn.Linear(self.bert.config.hidden_size, config['hidden_dim']) 54 | self.context_length = self.bert.config.max_position_embeddings 55 | self.predictor = ProsodyPredictor( 56 | style_dim=config['style_dim'], d_hid=config['hidden_dim'], 57 | nlayers=config['n_layer'], max_dur=config['max_dur'], dropout=config['dropout'] 58 | ) 59 | self.text_encoder = TextEncoder( 60 | channels=config['hidden_dim'], kernel_size=config['text_encoder_kernel_size'], 61 | depth=config['n_layer'], n_symbols=config['n_token'] 62 | ) 63 | self.decoder = Decoder( 64 | dim_in=config['hidden_dim'], style_dim=config['style_dim'], 65 | dim_out=config['n_mels'], disable_complex=disable_complex, **config['istftnet'] 66 | ) 67 | if not model: 68 | try: 69 | model = hf_hub_download(repo_id=repo_id, filename=KModel.MODEL_NAMES[repo_id]) 70 | except: 71 | model = os.path.join(repo_id, 'kokoro-v1_0.pth') 72 | for key, state_dict in torch.load(model, map_location='cpu', weights_only=True).items(): 73 | assert hasattr(self, key), key 74 | try: 75 | getattr(self, key).load_state_dict(state_dict) 76 | except: 77 | logger.debug(f"Did not load {key} from state_dict") 78 | state_dict = {k[7:]: v for k, v in state_dict.items()} 79 | getattr(self, key).load_state_dict(state_dict, strict=False) 80 | 81 | @property 82 | def device(self): 83 | return self.bert.device 84 | 85 | @dataclass 86 | class Output: 87 | audio: torch.FloatTensor 88 | pred_dur: Optional[torch.LongTensor] = None 89 | 90 | @torch.no_grad() 91 | def forward_with_tokens( 92 | self, 93 | input_ids: torch.LongTensor, 94 | ref_s: torch.FloatTensor, 95 | speed: float = 1 96 | ) -> tuple[torch.FloatTensor, torch.LongTensor]: 97 | input_lengths = torch.full( 98 | (input_ids.shape[0],), 99 | input_ids.shape[-1], 100 | device=input_ids.device, 101 | dtype=torch.long 102 | ) 103 | 104 | text_mask = torch.arange(input_lengths.max()).unsqueeze(0).expand(input_lengths.shape[0], -1).type_as(input_lengths) 105 | text_mask = torch.gt(text_mask+1, input_lengths.unsqueeze(1)).to(self.device) 106 | bert_dur = self.bert(input_ids, attention_mask=(~text_mask).int()) 107 | d_en = self.bert_encoder(bert_dur).transpose(-1, -2) 108 | s = ref_s[:, 128:] 109 | d = self.predictor.text_encoder(d_en, s, input_lengths, text_mask) 110 | x, _ = self.predictor.lstm(d) 111 | duration = self.predictor.duration_proj(x) 112 | duration = torch.sigmoid(duration).sum(axis=-1) / speed 113 | pred_dur = torch.round(duration).clamp(min=1).long().squeeze() 114 | indices = torch.repeat_interleave(torch.arange(input_ids.shape[1], device=self.device), pred_dur) 115 | pred_aln_trg = torch.zeros((input_ids.shape[1], indices.shape[0]), device=self.device) 116 | pred_aln_trg[indices, torch.arange(indices.shape[0])] = 1 117 | pred_aln_trg = pred_aln_trg.unsqueeze(0).to(self.device) 118 | en = d.transpose(-1, -2) @ pred_aln_trg 119 | F0_pred, N_pred = self.predictor.F0Ntrain(en, s) 120 | t_en = self.text_encoder(input_ids, input_lengths, text_mask) 121 | asr = t_en @ pred_aln_trg 
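        # ref_s packs the voice embedding: dims [128:] feed the prosody predictor above, dims [:128] condition the decoder below.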
122 | audio = self.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze() 123 | return audio, pred_dur 124 | 125 | def forward( 126 | self, 127 | phonemes: str, 128 | ref_s: torch.FloatTensor, 129 | speed: float = 1, 130 | return_output: bool = False 131 | ) -> Union['KModel.Output', torch.FloatTensor]: 132 | input_ids = list(filter(lambda i: i is not None, map(lambda p: self.vocab.get(p), phonemes))) 133 | logger.debug(f"phonemes: {phonemes} -> input_ids: {input_ids}") 134 | assert len(input_ids)+2 <= self.context_length, (len(input_ids)+2, self.context_length) 135 | input_ids = torch.LongTensor([[0, *input_ids, 0]]).to(self.device) 136 | ref_s = ref_s.to(self.device) 137 | audio, pred_dur = self.forward_with_tokens(input_ids, ref_s, speed) 138 | audio = audio.squeeze().cpu() 139 | pred_dur = pred_dur.cpu() if pred_dur is not None else None 140 | logger.debug(f"pred_dur: {pred_dur}") 141 | return self.Output(audio=audio, pred_dur=pred_dur) if return_output else audio 142 | 143 | class KModelForONNX(torch.nn.Module): 144 | def __init__(self, kmodel: KModel): 145 | super().__init__() 146 | self.kmodel = kmodel 147 | 148 | def forward( 149 | self, 150 | input_ids: torch.LongTensor, 151 | ref_s: torch.FloatTensor, 152 | speed: float = 1 153 | ) -> tuple[torch.FloatTensor, torch.LongTensor]: 154 | waveform, duration = self.kmodel.forward_with_tokens(input_ids, ref_s, speed) 155 | return waveform, duration 156 | -------------------------------------------------------------------------------- /kokoro/modules.py: -------------------------------------------------------------------------------- 1 | # https://github.com/yl4579/StyleTTS2/blob/main/models.py 2 | from .istftnet import AdainResBlk1d 3 | from torch.nn.utils import weight_norm 4 | from transformers import AlbertModel 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | class LinearNorm(nn.Module): 12 | def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'): 13 | super(LinearNorm, self).__init__() 14 | self.linear_layer = nn.Linear(in_dim, out_dim, bias=bias) 15 | nn.init.xavier_uniform_(self.linear_layer.weight, gain=nn.init.calculate_gain(w_init_gain)) 16 | 17 | def forward(self, x): 18 | return self.linear_layer(x) 19 | 20 | 21 | class LayerNorm(nn.Module): 22 | def __init__(self, channels, eps=1e-5): 23 | super().__init__() 24 | self.channels = channels 25 | self.eps = eps 26 | self.gamma = nn.Parameter(torch.ones(channels)) 27 | self.beta = nn.Parameter(torch.zeros(channels)) 28 | 29 | def forward(self, x): 30 | x = x.transpose(1, -1) 31 | x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) 32 | return x.transpose(1, -1) 33 | 34 | 35 | class TextEncoder(nn.Module): 36 | def __init__(self, channels, kernel_size, depth, n_symbols, actv=nn.LeakyReLU(0.2)): 37 | super().__init__() 38 | self.embedding = nn.Embedding(n_symbols, channels) 39 | padding = (kernel_size - 1) // 2 40 | self.cnn = nn.ModuleList() 41 | for _ in range(depth): 42 | self.cnn.append(nn.Sequential( 43 | weight_norm(nn.Conv1d(channels, channels, kernel_size=kernel_size, padding=padding)), 44 | LayerNorm(channels), 45 | actv, 46 | nn.Dropout(0.2), 47 | )) 48 | self.lstm = nn.LSTM(channels, channels//2, 1, batch_first=True, bidirectional=True) 49 | 50 | def forward(self, x, input_lengths, m): 51 | x = self.embedding(x) # [B, T, emb] 52 | x = x.transpose(1, 2) # [B, emb, T] 53 | m = m.unsqueeze(1) 54 | x.masked_fill_(m, 0.0) 55 | for c in self.cnn: 56 | x = c(x) 57 | 
x.masked_fill_(m, 0.0) 58 | x = x.transpose(1, 2) # [B, T, chn] 59 | lengths = input_lengths if input_lengths.device == torch.device('cpu') else input_lengths.to('cpu') 60 | x = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False) 61 | self.lstm.flatten_parameters() 62 | x, _ = self.lstm(x) 63 | x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True) 64 | x = x.transpose(-1, -2) 65 | x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]], device=x.device) 66 | x_pad[:, :, :x.shape[-1]] = x 67 | x = x_pad 68 | x.masked_fill_(m, 0.0) 69 | return x 70 | 71 | 72 | class AdaLayerNorm(nn.Module): 73 | def __init__(self, style_dim, channels, eps=1e-5): 74 | super().__init__() 75 | self.channels = channels 76 | self.eps = eps 77 | self.fc = nn.Linear(style_dim, channels*2) 78 | 79 | def forward(self, x, s): 80 | x = x.transpose(-1, -2) 81 | x = x.transpose(1, -1) 82 | h = self.fc(s) 83 | h = h.view(h.size(0), h.size(1), 1) 84 | gamma, beta = torch.chunk(h, chunks=2, dim=1) 85 | gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1) 86 | x = F.layer_norm(x, (self.channels,), eps=self.eps) 87 | x = (1 + gamma) * x + beta 88 | return x.transpose(1, -1).transpose(-1, -2) 89 | 90 | 91 | class ProsodyPredictor(nn.Module): 92 | def __init__(self, style_dim, d_hid, nlayers, max_dur=50, dropout=0.1): 93 | super().__init__() 94 | self.text_encoder = DurationEncoder(sty_dim=style_dim, d_model=d_hid,nlayers=nlayers, dropout=dropout) 95 | self.lstm = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True) 96 | self.duration_proj = LinearNorm(d_hid, max_dur) 97 | self.shared = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True) 98 | self.F0 = nn.ModuleList() 99 | self.F0.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout)) 100 | self.F0.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout)) 101 | self.F0.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout)) 102 | self.N = nn.ModuleList() 103 | self.N.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout)) 104 | self.N.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout)) 105 | self.N.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout)) 106 | self.F0_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0) 107 | self.N_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0) 108 | 109 | def forward(self, texts, style, text_lengths, alignment, m): 110 | d = self.text_encoder(texts, style, text_lengths, m) 111 | m = m.unsqueeze(1) 112 | lengths = text_lengths if text_lengths.device == torch.device('cpu') else text_lengths.to('cpu') 113 | x = nn.utils.rnn.pack_padded_sequence(d, lengths, batch_first=True, enforce_sorted=False) 114 | self.lstm.flatten_parameters() 115 | x, _ = self.lstm(x) 116 | x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True) 117 | x_pad = torch.zeros([x.shape[0], m.shape[-1], x.shape[-1]], device=x.device) 118 | x_pad[:, :x.shape[1], :] = x 119 | x = x_pad 120 | duration = self.duration_proj(nn.functional.dropout(x, 0.5, training=False)) 121 | en = (d.transpose(-1, -2) @ alignment) 122 | return duration.squeeze(-1), en 123 | 124 | def F0Ntrain(self, x, s): 125 | x, _ = self.shared(x.transpose(-1, -2)) 126 | F0 = x.transpose(-1, -2) 127 | for block in self.F0: 128 | F0 = block(F0, s) 129 | F0 = self.F0_proj(F0) 130 | N = x.transpose(-1, -2) 131 | for block in self.N: 132 | N = block(N, s) 133 | N = self.N_proj(N) 134 | return F0.squeeze(1), 
N.squeeze(1) 135 | 136 | 137 | class DurationEncoder(nn.Module): 138 | def __init__(self, sty_dim, d_model, nlayers, dropout=0.1): 139 | super().__init__() 140 | self.lstms = nn.ModuleList() 141 | for _ in range(nlayers): 142 | self.lstms.append(nn.LSTM(d_model + sty_dim, d_model // 2, num_layers=1, batch_first=True, bidirectional=True, dropout=dropout)) 143 | self.lstms.append(AdaLayerNorm(sty_dim, d_model)) 144 | self.dropout = dropout 145 | self.d_model = d_model 146 | self.sty_dim = sty_dim 147 | 148 | def forward(self, x, style, text_lengths, m): 149 | masks = m 150 | x = x.permute(2, 0, 1) 151 | s = style.expand(x.shape[0], x.shape[1], -1) 152 | x = torch.cat([x, s], axis=-1) 153 | x.masked_fill_(masks.unsqueeze(-1).transpose(0, 1), 0.0) 154 | x = x.transpose(0, 1) 155 | x = x.transpose(-1, -2) 156 | for block in self.lstms: 157 | if isinstance(block, AdaLayerNorm): 158 | x = block(x.transpose(-1, -2), style).transpose(-1, -2) 159 | x = torch.cat([x, s.permute(1, 2, 0)], axis=1) 160 | x.masked_fill_(masks.unsqueeze(-1).transpose(-1, -2), 0.0) 161 | else: 162 | lengths = text_lengths if text_lengths.device == torch.device('cpu') else text_lengths.to('cpu') 163 | x = x.transpose(-1, -2) 164 | x = nn.utils.rnn.pack_padded_sequence( 165 | x, lengths, batch_first=True, enforce_sorted=False) 166 | block.flatten_parameters() 167 | x, _ = block(x) 168 | x, _ = nn.utils.rnn.pad_packed_sequence( 169 | x, batch_first=True) 170 | x = F.dropout(x, p=self.dropout, training=False) 171 | x = x.transpose(-1, -2) 172 | x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]], device=x.device) 173 | x_pad[:, :, :x.shape[-1]] = x 174 | x = x_pad 175 | 176 | return x.transpose(-1, -2) 177 | 178 | 179 | # https://github.com/yl4579/StyleTTS2/blob/main/Utils/PLBERT/util.py 180 | class CustomAlbert(AlbertModel): 181 | def forward(self, *args, **kwargs): 182 | outputs = super().forward(*args, **kwargs) 183 | return outputs.last_hidden_state 184 | -------------------------------------------------------------------------------- /kokoro/custom_stft.py: -------------------------------------------------------------------------------- 1 | from attr import attr 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | class CustomSTFT(nn.Module): 8 | """ 9 | STFT/iSTFT without unfold/complex ops, using conv1d and conv_transpose1d. 
10 | 11 | - forward STFT => Real-part conv1d + Imag-part conv1d 12 | - inverse STFT => Real-part conv_transpose1d + Imag-part conv_transpose1d + sum 13 | - avoids F.unfold, so easier to export to ONNX 14 | - uses replicate or constant padding for 'center=True' to approximate 'reflect' 15 | (reflect is not supported for dynamic shapes in ONNX) 16 | """ 17 | 18 | def __init__( 19 | self, 20 | filter_length=800, 21 | hop_length=200, 22 | win_length=800, 23 | window="hann", 24 | center=True, 25 | pad_mode="replicate", # or 'constant' 26 | ): 27 | super().__init__() 28 | self.filter_length = filter_length 29 | self.hop_length = hop_length 30 | self.win_length = win_length 31 | self.n_fft = filter_length 32 | self.center = center 33 | self.pad_mode = pad_mode 34 | 35 | # Number of frequency bins for real-valued STFT with onesided=True 36 | self.freq_bins = self.n_fft // 2 + 1 37 | 38 | # Build window 39 | assert window == 'hann', window 40 | window_tensor = torch.hann_window(win_length, periodic=True, dtype=torch.float32) 41 | if self.win_length < self.n_fft: 42 | # Zero-pad up to n_fft 43 | extra = self.n_fft - self.win_length 44 | window_tensor = F.pad(window_tensor, (0, extra)) 45 | elif self.win_length > self.n_fft: 46 | window_tensor = window_tensor[: self.n_fft] 47 | self.register_buffer("window", window_tensor) 48 | 49 | # Precompute forward DFT (real, imag) 50 | # PyTorch stft uses e^{-j 2 pi k n / N} => real=cos(...), imag=-sin(...) 51 | n = np.arange(self.n_fft) 52 | k = np.arange(self.freq_bins) 53 | angle = 2 * np.pi * np.outer(k, n) / self.n_fft # shape (freq_bins, n_fft) 54 | dft_real = np.cos(angle) 55 | dft_imag = -np.sin(angle) # note negative sign 56 | 57 | # Combine window and dft => shape (freq_bins, filter_length) 58 | # We'll make 2 conv weight tensors of shape (freq_bins, 1, filter_length). 59 | forward_window = window_tensor.numpy() # shape (n_fft,) 60 | forward_real = dft_real * forward_window # (freq_bins, n_fft) 61 | forward_imag = dft_imag * forward_window 62 | 63 | # Convert to PyTorch 64 | forward_real_torch = torch.from_numpy(forward_real).float() 65 | forward_imag_torch = torch.from_numpy(forward_imag).float() 66 | 67 | # Register as Conv1d weight => (out_channels, in_channels, kernel_size) 68 | # out_channels = freq_bins, in_channels=1, kernel_size=n_fft 69 | self.register_buffer( 70 | "weight_forward_real", forward_real_torch.unsqueeze(1) 71 | ) 72 | self.register_buffer( 73 | "weight_forward_imag", forward_imag_torch.unsqueeze(1) 74 | ) 75 | 76 | # Precompute inverse DFT 77 | # Real iFFT formula => scale = 1/n_fft, doubling for bins 1..freq_bins-2 if n_fft even, etc. 78 | # For simplicity, we won't do the "DC/nyquist not doubled" approach here. 79 | # If you want perfect real iSTFT, you can add that logic. 80 | # This version just yields good approximate reconstruction with Hann + typical overlap. 81 | inv_scale = 1.0 / self.n_fft 82 | n = np.arange(self.n_fft) 83 | angle_t = 2 * np.pi * np.outer(n, k) / self.n_fft # shape (n_fft, freq_bins) 84 | idft_cos = np.cos(angle_t).T # => (freq_bins, n_fft) 85 | idft_sin = np.sin(angle_t).T # => (freq_bins, n_fft) 86 | 87 | # Multiply by window again for typical overlap-add 88 | # We also incorporate the scale factor 1/n_fft 89 | inv_window = window_tensor.numpy() * inv_scale 90 | backward_real = idft_cos * inv_window # (freq_bins, n_fft) 91 | backward_imag = idft_sin * inv_window 92 | 93 | # We'll implement iSTFT as real+imag conv_transpose with stride=hop. 
94 | self.register_buffer( 95 | "weight_backward_real", torch.from_numpy(backward_real).float().unsqueeze(1) 96 | ) 97 | self.register_buffer( 98 | "weight_backward_imag", torch.from_numpy(backward_imag).float().unsqueeze(1) 99 | ) 100 | 101 | 102 | 103 | def transform(self, waveform: torch.Tensor): 104 | """ 105 | Forward STFT => returns magnitude, phase 106 | Output shape => (batch, freq_bins, frames) 107 | """ 108 | # waveform shape => (B, T). conv1d expects (B, 1, T). 109 | # Optional center pad 110 | if self.center: 111 | pad_len = self.n_fft // 2 112 | waveform = F.pad(waveform, (pad_len, pad_len), mode=self.pad_mode) 113 | 114 | x = waveform.unsqueeze(1) # => (B, 1, T) 115 | # Convolution to get real part => shape (B, freq_bins, frames) 116 | real_out = F.conv1d( 117 | x, 118 | self.weight_forward_real, 119 | bias=None, 120 | stride=self.hop_length, 121 | padding=0, 122 | ) 123 | # Imag part 124 | imag_out = F.conv1d( 125 | x, 126 | self.weight_forward_imag, 127 | bias=None, 128 | stride=self.hop_length, 129 | padding=0, 130 | ) 131 | 132 | # magnitude, phase 133 | magnitude = torch.sqrt(real_out**2 + imag_out**2 + 1e-14) 134 | phase = torch.atan2(imag_out, real_out) 135 | # Handle the case where imag_out is 0 and real_out is negative to correct ONNX atan2 to match PyTorch 136 | # In this case, PyTorch returns pi, ONNX returns -pi 137 | correction_mask = (imag_out == 0) & (real_out < 0) 138 | phase[correction_mask] = torch.pi 139 | return magnitude, phase 140 | 141 | 142 | def inverse(self, magnitude: torch.Tensor, phase: torch.Tensor, length=None): 143 | """ 144 | Inverse STFT => returns waveform shape (B, T). 145 | """ 146 | # magnitude, phase => (B, freq_bins, frames) 147 | # Re-create real/imag => shape (B, freq_bins, frames) 148 | real_part = magnitude * torch.cos(phase) 149 | imag_part = magnitude * torch.sin(phase) 150 | 151 | # conv_transpose wants shape (B, freq_bins, frames). We'll treat "frames" as time dimension 152 | # so we do (B, freq_bins, frames) => (B, freq_bins, frames) 153 | # But PyTorch conv_transpose1d expects (B, in_channels, input_length) 154 | real_part = real_part # (B, freq_bins, frames) 155 | imag_part = imag_part 156 | 157 | # real iSTFT => convolve with "backward_real", "backward_imag", and sum 158 | # We'll do 2 conv_transpose calls, each giving (B, 1, time), 159 | # then add them => (B, 1, time). 160 | real_rec = F.conv_transpose1d( 161 | real_part, 162 | self.weight_backward_real, # shape (freq_bins, 1, filter_length) 163 | bias=None, 164 | stride=self.hop_length, 165 | padding=0, 166 | ) 167 | imag_rec = F.conv_transpose1d( 168 | imag_part, 169 | self.weight_backward_imag, 170 | bias=None, 171 | stride=self.hop_length, 172 | padding=0, 173 | ) 174 | # sum => (B, 1, time) 175 | waveform = real_rec - imag_rec # typical real iFFT has minus for imaginary part 176 | 177 | # If we used "center=True" in forward, we should remove pad 178 | if self.center: 179 | pad_len = self.n_fft // 2 180 | # Because of transposed convolution, total length might have extra samples 181 | # We remove `pad_len` from start & end if possible 182 | waveform = waveform[..., pad_len:-pad_len] 183 | 184 | # If a specific length is desired, clamp 185 | if length is not None: 186 | waveform = waveform[..., :length] 187 | 188 | # shape => (B, T) 189 | return waveform 190 | 191 | def forward(self, x: torch.Tensor): 192 | """ 193 | Full STFT -> iSTFT pass: returns time-domain reconstruction. 194 | Same interface as your original code. 
195 | """ 196 | mag, phase = self.transform(x) 197 | return self.inverse(mag, phase, length=x.shape[-1]) 198 | -------------------------------------------------------------------------------- /src/vram_management/layers.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import torch 4 | 5 | from src.utils import init_weights_on_device 6 | import optimum.quanto.nn.qlinear as qlinear 7 | 8 | def cast_to(weight, dtype, device): 9 | r = torch.empty_like(weight, dtype=dtype, device=device) 10 | r.copy_(weight) 11 | return r 12 | 13 | def cast_to_device(weight, device): 14 | if hasattr(weight, '__class__') and 'optimum.quanto' in str(weight.__class__): 15 | return weight.to(device) 16 | else: 17 | r = torch.empty_like(weight, device=device) 18 | r.copy_(weight) 19 | return r 20 | 21 | class AutoWrappedModule(torch.nn.Module): 22 | def __init__( 23 | self, 24 | module: torch.nn.Module, 25 | offload_dtype, 26 | offload_device, 27 | onload_dtype, 28 | onload_device, 29 | computation_dtype, 30 | computation_device, 31 | ): 32 | super().__init__() 33 | self.module = module.to(dtype=offload_dtype, device=offload_device) 34 | self.offload_dtype = offload_dtype 35 | self.offload_device = offload_device 36 | self.onload_dtype = onload_dtype 37 | self.onload_device = onload_device 38 | self.computation_dtype = computation_dtype 39 | self.computation_device = computation_device 40 | self.state = 0 41 | 42 | def offload(self): 43 | if self.state == 1 and ( 44 | self.offload_dtype != self.onload_dtype 45 | or self.offload_device != self.onload_device 46 | ): 47 | self.module.to(dtype=self.offload_dtype, device=self.offload_device) 48 | self.state = 0 49 | 50 | def onload(self): 51 | if self.state == 0 and ( 52 | self.offload_dtype != self.onload_dtype 53 | or self.offload_device != self.onload_device 54 | ): 55 | self.module.to(dtype=self.onload_dtype, device=self.onload_device) 56 | self.state = 1 57 | 58 | def forward(self, *args, **kwargs): 59 | if ( 60 | self.onload_dtype == self.computation_dtype 61 | and self.onload_device == self.computation_device 62 | ): 63 | module = self.module 64 | else: 65 | module = copy.deepcopy(self.module).to( 66 | dtype=self.computation_dtype, device=self.computation_device 67 | ) 68 | return module(*args, **kwargs) 69 | 70 | 71 | 72 | class AutoWrappedQLinear(qlinear.QLinear): 73 | def __init__( 74 | self, 75 | module: qlinear.QLinear, 76 | offload_dtype, 77 | offload_device, 78 | onload_dtype, 79 | onload_device, 80 | computation_dtype, 81 | computation_device, 82 | ): 83 | with init_weights_on_device(device=torch.device("meta")): 84 | super().__init__( 85 | in_features=module.in_features, 86 | out_features=module.out_features, 87 | bias=module.bias is not None, 88 | device=offload_device, 89 | ) 90 | self.weight = module.weight 91 | self.bias = module.bias 92 | self.offload_device = offload_device 93 | 94 | self.onload_device = onload_device 95 | self.computation_device = computation_device 96 | self.state = 0 97 | 98 | def offload(self): 99 | if self.state == 1 and ( 100 | self.offload_device != self.onload_device 101 | ): 102 | self.to(device=self.offload_device) 103 | self.state = 0 104 | 105 | def onload(self): 106 | if self.state == 0 and ( 107 | self.offload_device != self.onload_device 108 | ): 109 | self.to(device=self.onload_device) 110 | self.state = 1 111 | 112 | def forward(self, x, *args, **kwargs): 113 | if ( 114 | self.onload_device == self.computation_device 115 | ): 116 | 117 | return 
torch.nn.functional.linear(x, self.weight, bias=self.bias) 118 | else: 119 | 120 | qweight = cast_to_device(self.weight, self.computation_device) 121 | bias = ( 122 | None 123 | if self.bias is None 124 | else cast_to_device(self.bias, self.computation_device) 125 | ) 126 | return torch.nn.functional.linear(x, qweight, bias) 127 | 128 | class AutoWrappedLinear(torch.nn.Linear): 129 | def __init__( 130 | self, 131 | module: torch.nn.Linear, 132 | offload_dtype, 133 | offload_device, 134 | onload_dtype, 135 | onload_device, 136 | computation_dtype, 137 | computation_device, 138 | ): 139 | with init_weights_on_device(device=torch.device("meta")): 140 | super().__init__( 141 | in_features=module.in_features, 142 | out_features=module.out_features, 143 | bias=module.bias is not None, 144 | dtype=offload_dtype, 145 | device=offload_device, 146 | ) 147 | self.weight = module.weight 148 | self.bias = module.bias 149 | self.offload_dtype = offload_dtype 150 | self.offload_device = offload_device 151 | self.onload_dtype = onload_dtype 152 | self.onload_device = onload_device 153 | self.computation_dtype = computation_dtype 154 | self.computation_device = computation_device 155 | self.state = 0 156 | 157 | def offload(self): 158 | if self.state == 1 and ( 159 | self.offload_dtype != self.onload_dtype 160 | or self.offload_device != self.onload_device 161 | ): 162 | self.to(dtype=self.offload_dtype, device=self.offload_device) 163 | self.state = 0 164 | 165 | def onload(self): 166 | if self.state == 0 and ( 167 | self.offload_dtype != self.onload_dtype 168 | or self.offload_device != self.onload_device 169 | ): 170 | self.to(dtype=self.onload_dtype, device=self.onload_device) 171 | self.state = 1 172 | 173 | def forward(self, x, *args, **kwargs): 174 | if ( 175 | self.onload_dtype == self.computation_dtype 176 | and self.onload_device == self.computation_device 177 | ): 178 | weight, bias = self.weight, self.bias 179 | else: 180 | weight = cast_to( 181 | self.weight, self.computation_dtype, self.computation_device 182 | ) 183 | bias = ( 184 | None 185 | if self.bias is None 186 | else cast_to(self.bias, self.computation_dtype, self.computation_device) 187 | ) 188 | return torch.nn.functional.linear(x, weight, bias) 189 | 190 | 191 | def enable_vram_management_recursively( 192 | model: torch.nn.Module, 193 | module_map: dict, 194 | module_config: dict, 195 | max_num_param=None, 196 | overflow_module_config: dict = None, 197 | total_num_param=0, 198 | ): 199 | for name, module in model.named_children(): 200 | for source_module, target_module in module_map.items(): 201 | if isinstance(module, source_module): 202 | num_param = sum(p.numel() for p in module.parameters()) 203 | # print(str(module) + ':' + str(num_param)) 204 | if ( 205 | max_num_param is not None 206 | and total_num_param + num_param > max_num_param 207 | ): 208 | # print(str(module) + '-->\t\t num:' + str(num_param) + "\t total:" + str(total_num_param)) 209 | module_config_ = overflow_module_config 210 | else: 211 | module_config_ = module_config 212 | module_ = target_module(module, **module_config_) 213 | setattr(model, name, module_) 214 | total_num_param += num_param 215 | break 216 | else: 217 | total_num_param = enable_vram_management_recursively( 218 | module, 219 | module_map, 220 | module_config, 221 | max_num_param, 222 | overflow_module_config, 223 | total_num_param, 224 | ) 225 | return total_num_param 226 | 227 | 228 | def enable_vram_management( 229 | model: torch.nn.Module, 230 | module_map: dict, 231 | module_config: dict, 
232 | max_num_param=None, 233 | overflow_module_config: dict = None, 234 | ): 235 | enable_vram_management_recursively( 236 | model, 237 | module_map, 238 | module_config, 239 | max_num_param, 240 | overflow_module_config, 241 | total_num_param=0, 242 | ) 243 | model.vram_management_enabled = True 244 | -------------------------------------------------------------------------------- /wan/modules/vace_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 2 | import torch 3 | import torch.cuda.amp as amp 4 | import torch.nn as nn 5 | from diffusers.configuration_utils import register_to_config 6 | 7 | from .model import WanAttentionBlock, WanModel, sinusoidal_embedding_1d 8 | 9 | 10 | class VaceWanAttentionBlock(WanAttentionBlock): 11 | 12 | def __init__(self, 13 | cross_attn_type, 14 | dim, 15 | ffn_dim, 16 | num_heads, 17 | window_size=(-1, -1), 18 | qk_norm=True, 19 | cross_attn_norm=False, 20 | eps=1e-6, 21 | block_id=0): 22 | super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, 23 | qk_norm, cross_attn_norm, eps) 24 | self.block_id = block_id 25 | if block_id == 0: 26 | self.before_proj = nn.Linear(self.dim, self.dim) 27 | nn.init.zeros_(self.before_proj.weight) 28 | nn.init.zeros_(self.before_proj.bias) 29 | self.after_proj = nn.Linear(self.dim, self.dim) 30 | nn.init.zeros_(self.after_proj.weight) 31 | nn.init.zeros_(self.after_proj.bias) 32 | 33 | def forward(self, c, x, **kwargs): 34 | if self.block_id == 0: 35 | c = self.before_proj(c) + x 36 | 37 | c = super().forward(c, **kwargs) 38 | c_skip = self.after_proj(c) 39 | return c, c_skip 40 | 41 | 42 | class BaseWanAttentionBlock(WanAttentionBlock): 43 | 44 | def __init__(self, 45 | cross_attn_type, 46 | dim, 47 | ffn_dim, 48 | num_heads, 49 | window_size=(-1, -1), 50 | qk_norm=True, 51 | cross_attn_norm=False, 52 | eps=1e-6, 53 | block_id=None): 54 | super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, 55 | qk_norm, cross_attn_norm, eps) 56 | self.block_id = block_id 57 | 58 | def forward(self, x, hints, context_scale=1.0, **kwargs): 59 | x = super().forward(x, **kwargs) 60 | if self.block_id is not None: 61 | x = x + hints[self.block_id] * context_scale 62 | return x 63 | 64 | 65 | class VaceWanModel(WanModel): 66 | 67 | @register_to_config 68 | def __init__(self, 69 | vace_layers=None, 70 | vace_in_dim=None, 71 | model_type='vace', 72 | patch_size=(1, 2, 2), 73 | text_len=512, 74 | in_dim=16, 75 | dim=2048, 76 | ffn_dim=8192, 77 | freq_dim=256, 78 | text_dim=4096, 79 | out_dim=16, 80 | num_heads=16, 81 | num_layers=32, 82 | window_size=(-1, -1), 83 | qk_norm=True, 84 | cross_attn_norm=True, 85 | eps=1e-6): 86 | super().__init__(model_type, patch_size, text_len, in_dim, dim, ffn_dim, 87 | freq_dim, text_dim, out_dim, num_heads, num_layers, 88 | window_size, qk_norm, cross_attn_norm, eps) 89 | 90 | self.vace_layers = [i for i in range(0, self.num_layers, 2) 91 | ] if vace_layers is None else vace_layers 92 | self.vace_in_dim = self.in_dim if vace_in_dim is None else vace_in_dim 93 | 94 | assert 0 in self.vace_layers 95 | self.vace_layers_mapping = { 96 | i: n for n, i in enumerate(self.vace_layers) 97 | } 98 | 99 | # blocks 100 | self.blocks = nn.ModuleList([ 101 | BaseWanAttentionBlock( 102 | 't2v_cross_attn', 103 | self.dim, 104 | self.ffn_dim, 105 | self.num_heads, 106 | self.window_size, 107 | self.qk_norm, 108 | self.cross_attn_norm, 109 | self.eps, 110 | 
block_id=self.vace_layers_mapping[i] 111 | if i in self.vace_layers else None) 112 | for i in range(self.num_layers) 113 | ]) 114 | 115 | # vace blocks 116 | self.vace_blocks = nn.ModuleList([ 117 | VaceWanAttentionBlock( 118 | 't2v_cross_attn', 119 | self.dim, 120 | self.ffn_dim, 121 | self.num_heads, 122 | self.window_size, 123 | self.qk_norm, 124 | self.cross_attn_norm, 125 | self.eps, 126 | block_id=i) for i in self.vace_layers 127 | ]) 128 | 129 | # vace patch embeddings 130 | self.vace_patch_embedding = nn.Conv3d( 131 | self.vace_in_dim, 132 | self.dim, 133 | kernel_size=self.patch_size, 134 | stride=self.patch_size) 135 | 136 | def forward_vace(self, x, vace_context, seq_len, kwargs): 137 | # embeddings 138 | c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context] 139 | c = [u.flatten(2).transpose(1, 2) for u in c] 140 | c = torch.cat([ 141 | torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], 142 | dim=1) for u in c 143 | ]) 144 | 145 | # arguments 146 | new_kwargs = dict(x=x) 147 | new_kwargs.update(kwargs) 148 | 149 | hints = [] 150 | for block in self.vace_blocks: 151 | c, c_skip = block(c, **new_kwargs) 152 | hints.append(c_skip) 153 | return hints 154 | 155 | def forward( 156 | self, 157 | x, 158 | t, 159 | vace_context, 160 | context, 161 | seq_len, 162 | vace_context_scale=1.0, 163 | clip_fea=None, 164 | y=None, 165 | ): 166 | r""" 167 | Forward pass through the diffusion model 168 | 169 | Args: 170 | x (List[Tensor]): 171 | List of input video tensors, each with shape [C_in, F, H, W] 172 | t (Tensor): 173 | Diffusion timesteps tensor of shape [B] 174 | context (List[Tensor]): 175 | List of text embeddings each with shape [L, C] 176 | seq_len (`int`): 177 | Maximum sequence length for positional encoding 178 | clip_fea (Tensor, *optional*): 179 | CLIP image features for image-to-video mode 180 | y (List[Tensor], *optional*): 181 | Conditional video inputs for image-to-video mode, same shape as x 182 | 183 | Returns: 184 | List[Tensor]: 185 | List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8] 186 | """ 187 | # if self.model_type == 'i2v': 188 | # assert clip_fea is not None and y is not None 189 | # params 190 | device = self.patch_embedding.weight.device 191 | if self.freqs.device != device: 192 | self.freqs = self.freqs.to(device) 193 | 194 | # if y is not None: 195 | # x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] 196 | 197 | # embeddings 198 | x = [self.patch_embedding(u.unsqueeze(0)) for u in x] 199 | grid_sizes = torch.stack( 200 | [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) 201 | x = [u.flatten(2).transpose(1, 2) for u in x] 202 | seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) 203 | assert seq_lens.max() <= seq_len 204 | x = torch.cat([ 205 | torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], 206 | dim=1) for u in x 207 | ]) 208 | 209 | # time embeddings 210 | with amp.autocast(dtype=torch.float32): 211 | e = self.time_embedding( 212 | sinusoidal_embedding_1d(self.freq_dim, t).float()) 213 | e0 = self.time_projection(e).unflatten(1, (6, self.dim)) 214 | assert e.dtype == torch.float32 and e0.dtype == torch.float32 215 | 216 | # context 217 | context_lens = None 218 | context = self.text_embedding( 219 | torch.stack([ 220 | torch.cat( 221 | [u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) 222 | for u in context 223 | ])) 224 | 225 | # if clip_fea is not None: 226 | # context_clip = self.img_emb(clip_fea) # bs x 257 x dim 227 | # context = 
torch.concat([context_clip, context], dim=1) 228 | 229 | # arguments 230 | kwargs = dict( 231 | e=e0, 232 | seq_lens=seq_lens, 233 | grid_sizes=grid_sizes, 234 | freqs=self.freqs, 235 | context=context, 236 | context_lens=context_lens) 237 | 238 | hints = self.forward_vace(x, vace_context, seq_len, kwargs) 239 | kwargs['hints'] = hints 240 | kwargs['context_scale'] = vace_context_scale 241 | 242 | for block in self.blocks: 243 | x = block(x, **kwargs) 244 | 245 | # head 246 | x = self.head(x, e) 247 | 248 | # unpatchify 249 | x = self.unpatchify(x, grid_sizes) 250 | return [u.float() for u in x] 251 | -------------------------------------------------------------------------------- /wan/text2video.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 2 | import gc 3 | import logging 4 | import math 5 | import os 6 | import random 7 | import sys 8 | import types 9 | from contextlib import contextmanager 10 | from functools import partial 11 | 12 | import torch 13 | import torch.cuda.amp as amp 14 | import torch.distributed as dist 15 | from tqdm import tqdm 16 | 17 | from .distributed.fsdp import shard_model 18 | from .modules.model import WanModel 19 | from .modules.t5 import T5EncoderModel 20 | from .modules.vae import WanVAE 21 | from .utils.fm_solvers import ( 22 | FlowDPMSolverMultistepScheduler, 23 | get_sampling_sigmas, 24 | retrieve_timesteps, 25 | ) 26 | from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler 27 | 28 | 29 | class WanT2V: 30 | 31 | def __init__( 32 | self, 33 | config, 34 | checkpoint_dir, 35 | device_id=0, 36 | rank=0, 37 | t5_fsdp=False, 38 | dit_fsdp=False, 39 | use_usp=False, 40 | t5_cpu=False, 41 | ): 42 | r""" 43 | Initializes the Wan text-to-video generation model components. 44 | 45 | Args: 46 | config (EasyDict): 47 | Object containing model parameters initialized from config.py 48 | checkpoint_dir (`str`): 49 | Path to directory containing model checkpoints 50 | device_id (`int`, *optional*, defaults to 0): 51 | Id of target GPU device 52 | rank (`int`, *optional*, defaults to 0): 53 | Process rank for distributed training 54 | t5_fsdp (`bool`, *optional*, defaults to False): 55 | Enable FSDP sharding for T5 model 56 | dit_fsdp (`bool`, *optional*, defaults to False): 57 | Enable FSDP sharding for DiT model 58 | use_usp (`bool`, *optional*, defaults to False): 59 | Enable distribution strategy of USP. 60 | t5_cpu (`bool`, *optional*, defaults to False): 61 | Whether to place T5 model on CPU. Only works without t5_fsdp. 
62 | """ 63 | self.device = torch.device(f"cuda:{device_id}") 64 | self.config = config 65 | self.rank = rank 66 | self.t5_cpu = t5_cpu 67 | 68 | self.num_train_timesteps = config.num_train_timesteps 69 | self.param_dtype = config.param_dtype 70 | 71 | shard_fn = partial(shard_model, device_id=device_id) 72 | self.text_encoder = T5EncoderModel( 73 | text_len=config.text_len, 74 | dtype=config.t5_dtype, 75 | device=torch.device('cpu'), 76 | checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint), 77 | tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer), 78 | shard_fn=shard_fn if t5_fsdp else None) 79 | 80 | self.vae_stride = config.vae_stride 81 | self.patch_size = config.patch_size 82 | self.vae = WanVAE( 83 | vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), 84 | device=self.device) 85 | 86 | logging.info(f"Creating WanModel from {checkpoint_dir}") 87 | self.model = WanModel.from_pretrained(checkpoint_dir) 88 | self.model.eval().requires_grad_(False) 89 | 90 | if use_usp: 91 | from xfuser.core.distributed import get_sequence_parallel_world_size 92 | 93 | from .distributed.xdit_context_parallel import ( 94 | usp_attn_forward, 95 | usp_dit_forward, 96 | ) 97 | for block in self.model.blocks: 98 | block.self_attn.forward = types.MethodType( 99 | usp_attn_forward, block.self_attn) 100 | self.model.forward = types.MethodType(usp_dit_forward, self.model) 101 | self.sp_size = get_sequence_parallel_world_size() 102 | else: 103 | self.sp_size = 1 104 | 105 | if dist.is_initialized(): 106 | dist.barrier() 107 | if dit_fsdp: 108 | self.model = shard_fn(self.model) 109 | else: 110 | self.model.to(self.device) 111 | 112 | self.sample_neg_prompt = config.sample_neg_prompt 113 | 114 | def generate(self, 115 | input_prompt, 116 | size=(1280, 720), 117 | frame_num=81, 118 | shift=5.0, 119 | sample_solver='unipc', 120 | sampling_steps=50, 121 | guide_scale=5.0, 122 | n_prompt="", 123 | seed=-1, 124 | offload_model=True): 125 | r""" 126 | Generates video frames from text prompt using diffusion process. 127 | 128 | Args: 129 | input_prompt (`str`): 130 | Text prompt for content generation 131 | size (tuple[`int`], *optional*, defaults to (1280,720)): 132 | Controls video resolution, (width,height). 133 | frame_num (`int`, *optional*, defaults to 81): 134 | How many frames to sample from a video. The number should be 4n+1 135 | shift (`float`, *optional*, defaults to 5.0): 136 | Noise schedule shift parameter. Affects temporal dynamics 137 | sample_solver (`str`, *optional*, defaults to 'unipc'): 138 | Solver used to sample the video. 139 | sampling_steps (`int`, *optional*, defaults to 50): 140 | Number of diffusion sampling steps. Higher values improve quality but slow generation 141 | guide_scale (`float`, *optional*, defaults to 5.0): 142 | Classifier-free guidance scale. Controls prompt adherence vs. creativity 143 | n_prompt (`str`, *optional*, defaults to ""): 144 | Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt` 145 | seed (`int`, *optional*, defaults to -1): 146 | Random seed for noise generation. If -1, use random seed. 147 | offload_model (`bool`, *optional*, defaults to True): 148 | If True, offloads models to CPU during generation to save VRAM 149 | 150 | Returns: 151 | torch.Tensor: 152 | Generated video frames tensor. 
Dimensions: (C, N, H, W) where: 153 | - C: Color channels (3 for RGB) 154 | - N: Number of frames (81) 155 | - H: Frame height (from size) 156 | - W: Frame width (from size) 157 | """ 158 | # preprocess 159 | F = frame_num 160 | target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1, 161 | size[1] // self.vae_stride[1], 162 | size[0] // self.vae_stride[2]) 163 | 164 | seq_len = math.ceil((target_shape[2] * target_shape[3]) / 165 | (self.patch_size[1] * self.patch_size[2]) * 166 | target_shape[1] / self.sp_size) * self.sp_size 167 | 168 | if n_prompt == "": 169 | n_prompt = self.sample_neg_prompt 170 | seed = seed if seed >= 0 else random.randint(0, sys.maxsize) 171 | seed_g = torch.Generator(device=self.device) 172 | seed_g.manual_seed(seed) 173 | 174 | if not self.t5_cpu: 175 | self.text_encoder.model.to(self.device) 176 | context = self.text_encoder([input_prompt], self.device) 177 | context_null = self.text_encoder([n_prompt], self.device) 178 | if offload_model: 179 | self.text_encoder.model.cpu() 180 | else: 181 | context = self.text_encoder([input_prompt], torch.device('cpu')) 182 | context_null = self.text_encoder([n_prompt], torch.device('cpu')) 183 | context = [t.to(self.device) for t in context] 184 | context_null = [t.to(self.device) for t in context_null] 185 | 186 | noise = [ 187 | torch.randn( 188 | target_shape[0], 189 | target_shape[1], 190 | target_shape[2], 191 | target_shape[3], 192 | dtype=torch.float32, 193 | device=self.device, 194 | generator=seed_g) 195 | ] 196 | 197 | @contextmanager 198 | def noop_no_sync(): 199 | yield 200 | 201 | no_sync = getattr(self.model, 'no_sync', noop_no_sync) 202 | 203 | # evaluation mode 204 | with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync(): 205 | 206 | if sample_solver == 'unipc': 207 | sample_scheduler = FlowUniPCMultistepScheduler( 208 | num_train_timesteps=self.num_train_timesteps, 209 | shift=1, 210 | use_dynamic_shifting=False) 211 | sample_scheduler.set_timesteps( 212 | sampling_steps, device=self.device, shift=shift) 213 | timesteps = sample_scheduler.timesteps 214 | elif sample_solver == 'dpm++': 215 | sample_scheduler = FlowDPMSolverMultistepScheduler( 216 | num_train_timesteps=self.num_train_timesteps, 217 | shift=1, 218 | use_dynamic_shifting=False) 219 | sampling_sigmas = get_sampling_sigmas(sampling_steps, shift) 220 | timesteps, _ = retrieve_timesteps( 221 | sample_scheduler, 222 | device=self.device, 223 | sigmas=sampling_sigmas) 224 | else: 225 | raise NotImplementedError("Unsupported solver.") 226 | 227 | # sample videos 228 | latents = noise 229 | 230 | arg_c = {'context': context, 'seq_len': seq_len} 231 | arg_null = {'context': context_null, 'seq_len': seq_len} 232 | 233 | for _, t in enumerate(tqdm(timesteps)): 234 | latent_model_input = latents 235 | timestep = [t] 236 | 237 | timestep = torch.stack(timestep) 238 | 239 | self.model.to(self.device) 240 | noise_pred_cond = self.model( 241 | latent_model_input, t=timestep, **arg_c)[0] 242 | noise_pred_uncond = self.model( 243 | latent_model_input, t=timestep, **arg_null)[0] 244 | 245 | noise_pred = noise_pred_uncond + guide_scale * ( 246 | noise_pred_cond - noise_pred_uncond) 247 | 248 | temp_x0 = sample_scheduler.step( 249 | noise_pred.unsqueeze(0), 250 | t, 251 | latents[0].unsqueeze(0), 252 | return_dict=False, 253 | generator=seed_g)[0] 254 | latents = [temp_x0.squeeze(0)] 255 | 256 | x0 = latents 257 | if offload_model: 258 | self.model.cpu() 259 | torch.cuda.empty_cache() 260 | if self.rank == 0: 261 | videos = 
self.vae.decode(x0) 262 | 263 | del noise, latents 264 | del sample_scheduler 265 | if offload_model: 266 | gc.collect() 267 | torch.cuda.synchronize() 268 | if dist.is_initialized(): 269 | dist.barrier() 270 | 271 | return videos[0] if self.rank == 0 else None 272 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | 4 | InfiniteTalk 5 | 6 | 7 | InfiniteTalk: Audio-driven Video Generation for Sparse-Frame Video Dubbing 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | > **TL;DR:** InfiniteTalk is an unlimited-length talking video generation model that supports both audio-driven video-to-video and image-to-video generation 19 | 20 | 21 | 22 | 
23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | ## 🔥 Latest News 31 | 32 | * August 18, 2025: We release the [Technical Report]() of **InfiniteTalk**. The Gradio demo and the [ComfyUI](https://github.com/MeiGen-AI/InfiniteTalk/tree/comfyui) branch have been released. 33 | * August 18, 2025: We release the [project page]() of **InfiniteTalk**. 34 | 35 | 36 | ## ✨ Key Features 37 | We propose **InfiniteTalk**, a novel sparse-frame video dubbing framework. Given an input video and audio track, InfiniteTalk synthesizes a new video with accurate lip synchronization while simultaneously aligning head movements, body posture, and facial expressions with the audio. Unlike traditional dubbing methods that focus solely on the lips, InfiniteTalk enables infinite-length video generation with accurate lip synchronization and consistent identity preservation. Besides, InfiniteTalk can also be used as an image-audio-to-video model, taking an image and an audio clip as input. 38 | - 💬 **Sparse-frame Video Dubbing** – Synchronizes not only the lips, but also the head, body, and expressions 39 | - ⏱️ **Infinite-Length Generation** – Supports unlimited video duration 40 | - ✨ **Stability** – Reduces hand/body distortions compared to MultiTalk 41 | - 🚀 **Lip Accuracy** – Achieves superior lip synchronization compared to MultiTalk 42 | 43 | 44 | 45 | ## 🌐 Community Works 46 | - 47 | 48 | 49 | ## 📑 Todo List 50 | 51 | - [x] Release the technical report 52 | - [x] Inference 53 | - [x] Checkpoints 54 | - [x] Multi-GPU Inference 55 | - [ ] Inference acceleration 56 | - [x] TeaCache 57 | - [x] int8 quantization 58 | - [ ] LCM distillation 59 | - [ ] Sparse Attention 60 | - [x] Run with very low VRAM 61 | - [x] Gradio demo 62 | - [x] ComfyUI 63 | 64 | 65 | ## Quick Start 66 | 67 | ### 🛠️ Installation 68 | 69 | #### 1. Create a conda environment and install PyTorch, xformers 70 | ``` 71 | conda create -n multitalk python=3.10 72 | conda activate multitalk 73 | pip install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/cu121 74 | pip install -U xformers==0.0.28 --index-url https://download.pytorch.org/whl/cu121 75 | ``` 76 | #### 2. Flash-attn installation: 77 | ``` 78 | pip install misaki[en] 79 | pip install ninja 80 | pip install psutil 81 | pip install packaging 82 | pip install flash_attn==2.7.4.post1 83 | ``` 84 | 85 | #### 3. Other dependencies 86 | ``` 87 | pip install -r requirements.txt 88 | conda install -c conda-forge librosa 89 | ``` 90 | 91 | #### 4. FFmpeg installation 92 | ``` 93 | conda install -c conda-forge ffmpeg 94 | ``` 95 | or 96 | ``` 97 | sudo yum install ffmpeg ffmpeg-devel 98 | ``` 99 | 100 | ### 🧱 Model Preparation 101 | 102 | #### 1. 
Model Download 103 | 104 | | Models | Download Link | Notes | 105 | | --------------|-------------------------------------------------------------------------------|-------------------------------| 106 | | Wan2.1-I2V-14B-480P | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-480P) | Base model | 107 | | chinese-wav2vec2-base | 🤗 [Huggingface](https://huggingface.co/TencentGameMate/chinese-wav2vec2-base) | Audio encoder | 108 | | MeiGen-InfiniteTalk | 🤗 [Huggingface]() | Our audio condition weights | 109 | 110 | Download the models using huggingface-cli: 111 | ``` sh 112 | huggingface-cli download Wan-AI/Wan2.1-I2V-14B-480P --local-dir ./weights/Wan2.1-I2V-14B-480P 113 | huggingface-cli download TencentGameMate/chinese-wav2vec2-base --local-dir ./weights/chinese-wav2vec2-base 114 | huggingface-cli download TencentGameMate/chinese-wav2vec2-base model.safetensors --revision refs/pr/1 --local-dir ./weights/chinese-wav2vec2-base 115 | ``` 116 | 117 | #### 2. Link or Copy MultiTalk Model to Wan2.1-I2V-14B-480P Directory 118 | 119 | Link through: 120 | ``` 121 | 122 | ``` 123 | ### 🔑 Quick Inference 124 | 125 | Our model is compatible with both 480P and 720P resolutions. 126 | > Some tips 127 | > - Lip synchronization accuracy: Audio CFG works optimally between 3–5. Increase the audio CFG value for better synchronization. 128 | > - FusionX: While it enables faster inference and higher quality, the FusionX LoRA exacerbates color shift over 1 minute and reduces identity preservation in videos. 129 | > - V2V generation: Enables unlimited-length generation. The model mimics the original video's camera movement, though not identically. Using SDEdit improves camera movement accuracy significantly but introduces color shift and is best suited for short clips. Improvements for long-video camera control are planned. 130 | > - I2V generation: Generates good results from a single image for up to 1 minute. Beyond 1 minute, color shifts become more pronounced. One trick for high-quality generation beyond 1 minute is to turn the image into a video by slowly translating or zooming into it. 131 | 132 | 133 | #### Usage of InfiniteTalk 134 | ``` 135 | --mode streaming: long video generation. 136 | --mode clip: generate a short video with one chunk. 137 | --use_teacache: run with TeaCache. 138 | --size infinitetalk-480: generate 480P video. 139 | --size infinitetalk-720: generate 720P video. 140 | --use_apg: run with APG. 141 | --teacache_thresh: a coefficient used for TeaCache acceleration. 142 | --sample_text_guide_scale: when not using LoRA, the optimal value is 5. After applying LoRA, the recommended value is 1. 143 | --sample_audio_guide_scale: when not using LoRA, the optimal value is 4. After applying LoRA, the recommended value is 2. 144 | ``` 145 | 
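For illustration, several of these flags can be combined in a single command. The following is only a sketch: the paths reuse the ones from the examples below, and the `--teacache_thresh` value here is an assumed placeholder rather than a documented recommendation.

```
# Illustrative sketch only: streaming generation with TeaCache and APG enabled.
# The 0.2 threshold is an assumed value, not an official recommendation.
python generate_infinitetalk.py \
    --ckpt_dir weights/Wan2.1-I2V-14B-480P \
    --wav2vec_dir 'weights/chinese-wav2vec2-base' \
    --infinitetalk_dir weights/InfiniteTalk/single/infinitetalk.safetensors \
    --input_json examples/single_example_image.json \
    --size infinitetalk-480 \
    --sample_steps 40 \
    --mode streaming \
    --use_teacache \
    --teacache_thresh 0.2 \
    --use_apg \
    --motion_frame 9 \
    --save_file infinitetalk_res_teacache
```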
146 | #### 1. Inference 147 | 148 | ##### 1) Run with single GPU 149 | 150 | 151 | ``` 152 | python generate_infinitetalk.py \ 153 | --ckpt_dir weights/Wan2.1-I2V-14B-480P \ 154 | --wav2vec_dir 'weights/chinese-wav2vec2-base' \ 155 | --infinitetalk_dir weights/InfiniteTalk/single/infinitetalk.safetensors \ 156 | --input_json examples/single_example_image.json \ 157 | --size infinitetalk-480 \ 158 | --sample_steps 40 \ 159 | --mode streaming \ 160 | --motion_frame 9 \ 161 | --save_file infinitetalk_res 162 | 163 | ``` 164 | 165 | ##### 2) Run with 720P 166 | 167 | If you want to run at 720P, set `--size infinitetalk-720`: 168 | 169 | ``` 170 | python generate_infinitetalk.py \ 171 | --ckpt_dir weights/Wan2.1-I2V-14B-480P \ 172 | --wav2vec_dir 'weights/chinese-wav2vec2-base' \ 173 | --infinitetalk_dir weights/InfiniteTalk/single/infinitetalk.safetensors \ 174 | --input_json examples/single_example_image.json \ 175 | --size infinitetalk-720 \ 176 | --sample_steps 40 \ 177 | --mode streaming \ 178 | --motion_frame 9 \ 179 | --save_file infinitetalk_res_720p 180 | 181 | ``` 182 | 183 | ##### 3) Run with very low VRAM 184 | 185 | If you want to run with very low VRAM, set `--num_persistent_param_in_dit 0`: 186 | 187 | 188 | ``` 189 | python generate_infinitetalk.py \ 190 | --ckpt_dir weights/Wan2.1-I2V-14B-480P \ 191 | --wav2vec_dir 'weights/chinese-wav2vec2-base' \ 192 | --infinitetalk_dir weights/InfiniteTalk/single/infinitetalk.safetensors \ 193 | --input_json examples/single_example_image.json \ 194 | --size infinitetalk-480 \ 195 | --sample_steps 40 \ 196 | --num_persistent_param_in_dit 0 \ 197 | --mode streaming \ 198 | --motion_frame 9 \ 199 | --save_file infinitetalk_res_lowvram 200 | ``` 201 | 202 | ##### 4) Multi-GPU inference 203 | 204 | ``` 205 | GPU_NUM=8 206 | torchrun --nproc_per_node=$GPU_NUM --standalone generate_infinitetalk.py \ 207 | --ckpt_dir weights/Wan2.1-I2V-14B-480P \ 208 | --wav2vec_dir 'weights/chinese-wav2vec2-base' \ 209 | --infinitetalk_dir weights/InfiniteTalk/single/infinitetalk.safetensors \ 210 | --dit_fsdp --t5_fsdp \ 211 | --ulysses_size=$GPU_NUM \ 212 | --input_json examples/single_example_image.json \ 213 | --size infinitetalk-480 \ 214 | --sample_steps 40 \ 215 | --mode streaming \ 216 | --motion_frame 9 \ 217 | --save_file infinitetalk_res_multigpu 218 | ``` 219 | 220 | 221 | 222 | 223 | #### 2. Run with FusioniX or Lightx2v (requires only 4~8 steps) 224 | 225 | [FusioniX](https://huggingface.co/vrgamedevgirl84/Wan14BT2VFusioniX/blob/main/FusionX_LoRa/Wan2.1_I2V_14B_FusionX_LoRA.safetensors) requires 8 steps and [lightx2v](https://huggingface.co/Kijai/WanVideo_comfy/blob/main/Wan21_T2V_14B_lightx2v_cfg_step_distill_lora_rank32.safetensors) requires only 4 steps. 226 | 227 | ``` 228 | python generate_infinitetalk.py \ 229 | --ckpt_dir weights/Wan2.1-I2V-14B-480P \ 230 | --wav2vec_dir 'weights/chinese-wav2vec2-base' \ 231 | --infinitetalk_dir weights/InfiniteTalk/single/infinitetalk.safetensors \ 232 | --lora_dir weights/Wan2.1_I2V_14B_FusionX_LoRA.safetensors \ 233 | --input_json examples/single_example_image.json \ 234 | --lora_scale 1.0 \ 235 | --size infinitetalk-480 \ 236 | --sample_text_guide_scale 1.0 \ 237 | --sample_audio_guide_scale 2.0 \ 238 | --sample_steps 8 \ 239 | --mode streaming \ 240 | --motion_frame 9 \ 241 | --sample_shift 2 \ 242 | --num_persistent_param_in_dit 0 \ 243 | --save_file infinitetalk_res_lora 244 | ``` 245 | 246 | 247 | 248 | 
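Since lightx2v needs only 4 steps, the same pattern can be adapted for it. A hedged sketch, assuming the lightx2v LoRA linked above has been downloaded into `weights/` under its original filename (the local path, the retained `--sample_shift 2`, and the post-LoRA guide scales reused from the flag guidance above are assumptions, not settings documented for this exact combination):

```
# Hedged sketch for the 4-step lightx2v LoRA; the local filename is assumed.
python generate_infinitetalk.py \
    --ckpt_dir weights/Wan2.1-I2V-14B-480P \
    --wav2vec_dir 'weights/chinese-wav2vec2-base' \
    --infinitetalk_dir weights/InfiniteTalk/single/infinitetalk.safetensors \
    --lora_dir weights/Wan21_T2V_14B_lightx2v_cfg_step_distill_lora_rank32.safetensors \
    --input_json examples/single_example_image.json \
    --lora_scale 1.0 \
    --size infinitetalk-480 \
    --sample_text_guide_scale 1.0 \
    --sample_audio_guide_scale 2.0 \
    --sample_steps 4 \
    --mode streaming \
    --motion_frame 9 \
    --sample_shift 2 \
    --num_persistent_param_in_dit 0 \
    --save_file infinitetalk_res_lightx2v
```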
#### 3. Run with the quantization model (only supports single-GPU inference) 249 | 250 | ``` 251 | python generate_infinitetalk.py \ 252 | --ckpt_dir weights/Wan2.1-I2V-14B-480P \ 253 | --wav2vec_dir 'weights/chinese-wav2vec2-base' \ 254 | --infinitetalk_dir weights/InfiniteTalk/single/infinitetalk.safetensors \ 255 | --input_json examples/single_example_image.json \ 256 | --size infinitetalk-480 \ 257 | --sample_steps 40 \ 258 | --mode streaming \ 259 | --quant fp8 \ 260 | --quant_dir weights/InfiniteTalk/quant_models/infinitetalk_single_fp8.safetensors \ 261 | --motion_frame 9 \ 262 | --num_persistent_param_in_dit 0 \ 263 | --save_file infinitetalk_res_quant 264 | ``` 265 | 266 | 267 | #### 4. Run with Gradio 268 | 269 | 270 | 271 | ``` 272 | python app.py \ 273 | --ckpt_dir weights/Wan2.1-I2V-14B-480P \ 274 | --wav2vec_dir 'weights/chinese-wav2vec2-base' \ 275 | --infinitetalk_dir weights/InfiniteTalk/single/infinitetalk.safetensors \ 276 | --input_json examples/single_example_image.json \ 277 | --num_persistent_param_in_dit 0 \ 278 | --motion_frame 9 279 | ``` 280 | 281 | 282 | 283 | ## 📚 Citation 284 | 285 | If you find our work useful in your research, please consider citing: 286 | 287 | ``` 288 | 289 | ``` 290 | 291 | ## 📜 License 292 | The models in this repository are licensed under the Apache 2.0 License. We claim no rights over your generated contents, 293 | granting you the freedom to use them while ensuring that your usage complies with the provisions of this license. 294 | You are fully accountable for your use of the models, which must not involve sharing any content that violates applicable laws, 295 | causes harm to individuals or groups, disseminates personal information intended for harm, spreads misinformation, or targets vulnerable populations. 296 | 297 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /wan/utils/vace_processor.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | import torchvision.transforms.functional as TF 6 | from PIL import Image 7 | 8 | 9 | class VaceImageProcessor(object): 10 | 11 | def __init__(self, downsample=None, seq_len=None): 12 | self.downsample = downsample 13 | self.seq_len = seq_len 14 | 15 | def _pillow_convert(self, image, cvt_type='RGB'): 16 | if image.mode != cvt_type: 17 | if image.mode == 'P': 18 | image = image.convert(f'{cvt_type}A') 19 | if image.mode == f'{cvt_type}A': 20 | bg = Image.new( 21 | cvt_type, 22 | size=(image.width, image.height), 23 | color=(255, 255, 255)) 24 | bg.paste(image, (0, 0), mask=image) 25 | image = bg 26 | else: 27 | image = image.convert(cvt_type) 28 | return image 29 | 30 | def _load_image(self, img_path): 31 | if img_path is None or img_path == '': 32 | return None 33 | img = Image.open(img_path) 34 | img = self._pillow_convert(img) 35 | return img 36 | 37 | def _resize_crop(self, img, oh, ow, normalize=True): 38 | """ 39 | Resize, center crop, convert to tensor, and normalize. 40 | """ 41 | # resize and crop 42 | iw, ih = img.size 43 | if iw != ow or ih != oh: 44 | # resize 45 | scale = max(ow / iw, oh / ih) 46 | img = img.resize((round(scale * iw), round(scale * ih)), 47 | resample=Image.Resampling.LANCZOS) 48 | assert img.width >= ow and img.height >= oh 49 | 50 | # center crop 51 | x1 = (img.width - ow) // 2 52 | y1 = (img.height - oh) // 2 53 | img = img.crop((x1, y1, x1 + ow, y1 + oh)) 54 | 55 | # normalize 56 | if normalize: 57 | img = TF.to_tensor(img).sub_(0.5).div_(0.5).unsqueeze(1) 58 | return img 59 | 60 | def _image_preprocess(self, img, oh, ow, normalize=True, **kwargs): 61 | return self._resize_crop(img, oh, ow, normalize) 62 | 63 | def load_image(self, data_key, **kwargs): 64 | return self.load_image_batch(data_key, **kwargs) 65 | 66 | def load_image_pair(self, data_key, data_key2, **kwargs): 67 | return self.load_image_batch(data_key, data_key2, **kwargs) 68 | 69 | def load_image_batch(self, 70 | *data_key_batch, 71 | normalize=True, 72 | seq_len=None, 73 | **kwargs): 74 | seq_len = self.seq_len if seq_len is None else seq_len 75 | imgs = [] 76 | for data_key in data_key_batch: 77 | img = self._load_image(data_key) 78 | imgs.append(img) 79 | w, h = imgs[0].size 80 | dh, dw = self.downsample[1:] 81 | 82 | # compute output size 83 | scale = min(1., np.sqrt(seq_len / ((h / dh) * (w / dw)))) 84 | oh = int(h * scale) // dh * dh 85 | ow = int(w * scale) // dw * dw 86 | assert (oh // dh) * (ow // dw) <= seq_len 87 | imgs = [self._image_preprocess(img, oh, ow, normalize) for img in imgs] 88 | return *imgs, (oh, ow) 89 | 90 | 91 | class VaceVideoProcessor(object): 92 | 93 | def __init__(self, downsample, min_area, max_area, min_fps, max_fps, 94 | zero_start, seq_len, keep_last, **kwargs): 95 | self.downsample = downsample 96 | self.min_area = min_area 97 | self.max_area = max_area 98 | self.min_fps = min_fps 99 | self.max_fps = max_fps 100 | self.zero_start = zero_start 101 | self.keep_last = keep_last 102 | self.seq_len = seq_len 103 | assert seq_len >= min_area / (self.downsample[1] * self.downsample[2]) 104 | 105 | def set_area(self, area): 106 | self.min_area = area 107 | self.max_area = area 108 | 109 | def set_seq_len(self, seq_len): 110 | self.seq_len = seq_len 111 | 112 | @staticmethod 113 | def resize_crop(video: torch.Tensor, oh: int, ow: int): 114 | """ 115 | Resize, center crop and normalize for decord loaded video (torch.Tensor type) 116 | 117 | Parameters: 118 | video - video to process 
(torch.Tensor): Tensor from `reader.get_batch(frame_ids)`, in shape of (T, H, W, C) 119 | oh - target height (int) 120 | ow - target width (int) 121 | 122 | Returns: 123 | The processed video (torch.Tensor): Normalized tensor range [-1, 1], in shape of (C, T, H, W) 124 | 125 | Raises: 126 | """ 127 | # permute ([t, h, w, c] -> [t, c, h, w]) 128 | video = video.permute(0, 3, 1, 2) 129 | 130 | # resize and crop 131 | ih, iw = video.shape[2:] 132 | if ih != oh or iw != ow: 133 | # resize 134 | scale = max(ow / iw, oh / ih) 135 | video = F.interpolate( 136 | video, 137 | size=(round(scale * ih), round(scale * iw)), 138 | mode='bicubic', 139 | antialias=True) 140 | assert video.size(3) >= ow and video.size(2) >= oh 141 | 142 | # center crop 143 | x1 = (video.size(3) - ow) // 2 144 | y1 = (video.size(2) - oh) // 2 145 | video = video[:, :, y1:y1 + oh, x1:x1 + ow] 146 | 147 | # permute ([t, c, h, w] -> [c, t, h, w]) and normalize 148 | video = video.transpose(0, 1).float().div_(127.5).sub_(1.) 149 | return video 150 | 151 | def _video_preprocess(self, video, oh, ow): 152 | return self.resize_crop(video, oh, ow) 153 | 154 | def _get_frameid_bbox_default(self, fps, frame_timestamps, h, w, crop_box, 155 | rng): 156 | target_fps = min(fps, self.max_fps) 157 | duration = frame_timestamps[-1].mean() 158 | x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box 159 | h, w = y2 - y1, x2 - x1 160 | ratio = h / w 161 | df, dh, dw = self.downsample 162 | 163 | area_z = min(self.seq_len, self.max_area / (dh * dw), 164 | (h // dh) * (w // dw)) 165 | of = min((int(duration * target_fps) - 1) // df + 1, 166 | int(self.seq_len / area_z)) 167 | 168 | # deduce target shape of the [latent video] 169 | target_area_z = min(area_z, int(self.seq_len / of)) 170 | oh = round(np.sqrt(target_area_z * ratio)) 171 | ow = int(target_area_z / oh) 172 | of = (of - 1) * df + 1 173 | oh *= dh 174 | ow *= dw 175 | 176 | # sample frame ids 177 | target_duration = of / target_fps 178 | begin = 0. 
if self.zero_start else rng.uniform( 179 | 0, duration - target_duration) 180 | timestamps = np.linspace(begin, begin + target_duration, of) 181 | frame_ids = np.argmax( 182 | np.logical_and(timestamps[:, None] >= frame_timestamps[None, :, 0], 183 | timestamps[:, None] < frame_timestamps[None, :, 1]), 184 | axis=1).tolist() 185 | return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps 186 | 187 | def _get_frameid_bbox_adjust_last(self, fps, frame_timestamps, h, w, 188 | crop_box, rng): 189 | duration = frame_timestamps[-1].mean() 190 | x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box 191 | h, w = y2 - y1, x2 - x1 192 | ratio = h / w 193 | df, dh, dw = self.downsample 194 | 195 | area_z = min(self.seq_len, self.max_area / (dh * dw), 196 | (h // dh) * (w // dw)) 197 | of = min((len(frame_timestamps) - 1) // df + 1, 198 | int(self.seq_len / area_z)) 199 | 200 | # deduce target shape of the [latent video] 201 | target_area_z = min(area_z, int(self.seq_len / of)) 202 | oh = round(np.sqrt(target_area_z * ratio)) 203 | ow = int(target_area_z / oh) 204 | of = (of - 1) * df + 1 205 | oh *= dh 206 | ow *= dw 207 | 208 | # sample frame ids 209 | target_duration = duration 210 | target_fps = of / target_duration 211 | timestamps = np.linspace(0., target_duration, of) 212 | frame_ids = np.argmax( 213 | np.logical_and(timestamps[:, None] >= frame_timestamps[None, :, 0], 214 | timestamps[:, None] <= frame_timestamps[None, :, 1]), 215 | axis=1).tolist() 216 | # print(oh, ow, of, target_duration, target_fps, len(frame_timestamps), len(frame_ids)) 217 | return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps 218 | 219 | def _get_frameid_bbox(self, fps, frame_timestamps, h, w, crop_box, rng): 220 | if self.keep_last: 221 | return self._get_frameid_bbox_adjust_last(fps, frame_timestamps, h, 222 | w, crop_box, rng) 223 | else: 224 | return self._get_frameid_bbox_default(fps, frame_timestamps, h, w, 225 | crop_box, rng) 226 | 227 | def load_video(self, data_key, crop_box=None, seed=2024, **kwargs): 228 | return self.load_video_batch( 229 | data_key, crop_box=crop_box, seed=seed, **kwargs) 230 | 231 | def load_video_pair(self, 232 | data_key, 233 | data_key2, 234 | crop_box=None, 235 | seed=2024, 236 | **kwargs): 237 | return self.load_video_batch( 238 | data_key, data_key2, crop_box=crop_box, seed=seed, **kwargs) 239 | 240 | def load_video_batch(self, 241 | *data_key_batch, 242 | crop_box=None, 243 | seed=2024, 244 | **kwargs): 245 | rng = np.random.default_rng(seed + hash(data_key_batch[0]) % 10000) 246 | # read video 247 | import decord 248 | decord.bridge.set_bridge('torch') 249 | readers = [] 250 | for data_k in data_key_batch: 251 | reader = decord.VideoReader(data_k) 252 | readers.append(reader) 253 | 254 | fps = readers[0].get_avg_fps() 255 | length = min([len(r) for r in readers]) 256 | frame_timestamps = [ 257 | readers[0].get_frame_timestamp(i) for i in range(length) 258 | ] 259 | frame_timestamps = np.array(frame_timestamps, dtype=np.float32) 260 | h, w = readers[0].next().shape[:2] 261 | frame_ids, (x1, x2, y1, y2), (oh, ow), fps = self._get_frameid_bbox( 262 | fps, frame_timestamps, h, w, crop_box, rng) 263 | 264 | # preprocess video 265 | videos = [ 266 | reader.get_batch(frame_ids)[:, y1:y2, x1:x2, :] 267 | for reader in readers 268 | ] 269 | videos = [self._video_preprocess(video, oh, ow) for video in videos] 270 | return *videos, frame_ids, (oh, ow), fps 271 | # return videos if len(videos) > 1 else videos[0] 272 | 273 | 274 | def prepare_source(src_video, src_mask, 
src_ref_images, num_frames, image_size, 275 | device): 276 | for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)): 277 | if sub_src_video is None and sub_src_mask is None: 278 | src_video[i] = torch.zeros( 279 | (3, num_frames, image_size[0], image_size[1]), device=device) 280 | src_mask[i] = torch.ones( 281 | (1, num_frames, image_size[0], image_size[1]), device=device) 282 | for i, ref_images in enumerate(src_ref_images): 283 | if ref_images is not None: 284 | for j, ref_img in enumerate(ref_images): 285 | if ref_img is not None and ref_img.shape[-2:] != image_size: 286 | canvas_height, canvas_width = image_size 287 | ref_height, ref_width = ref_img.shape[-2:] 288 | white_canvas = torch.ones( 289 | (3, 1, canvas_height, canvas_width), 290 | device=device) # [-1, 1] 291 | scale = min(canvas_height / ref_height, 292 | canvas_width / ref_width) 293 | new_height = int(ref_height * scale) 294 | new_width = int(ref_width * scale) 295 | resized_image = F.interpolate( 296 | ref_img.squeeze(1).unsqueeze(0), 297 | size=(new_height, new_width), 298 | mode='bilinear', 299 | align_corners=False).squeeze(0).unsqueeze(1) 300 | top = (canvas_height - new_height) // 2 301 | left = (canvas_width - new_width) // 2 302 | white_canvas[:, :, top:top + new_height, 303 | left:left + new_width] = resized_image 304 | src_ref_images[i][j] = white_canvas 305 | return src_video, src_mask, src_ref_images 306 | -------------------------------------------------------------------------------- /wan/utils/qwen_vl_utils.py: -------------------------------------------------------------------------------- 1 | # Copied from https://github.com/kq-chen/qwen-vl-utils 2 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 3 | from __future__ import annotations 4 | 5 | import base64 6 | import logging 7 | import math 8 | import os 9 | import sys 10 | import time 11 | import warnings 12 | from functools import lru_cache 13 | from io import BytesIO 14 | 15 | import requests 16 | import torch 17 | import torchvision 18 | from packaging import version 19 | from PIL import Image 20 | from torchvision import io, transforms 21 | from torchvision.transforms import InterpolationMode 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | IMAGE_FACTOR = 28 26 | MIN_PIXELS = 4 * 28 * 28 27 | MAX_PIXELS = 16384 * 28 * 28 28 | MAX_RATIO = 200 29 | 30 | VIDEO_MIN_PIXELS = 128 * 28 * 28 31 | VIDEO_MAX_PIXELS = 768 * 28 * 28 32 | VIDEO_TOTAL_PIXELS = 24576 * 28 * 28 33 | FRAME_FACTOR = 2 34 | FPS = 2.0 35 | FPS_MIN_FRAMES = 4 36 | FPS_MAX_FRAMES = 768 37 | 38 | 39 | def round_by_factor(number: int, factor: int) -> int: 40 | """Returns the closest integer to 'number' that is divisible by 'factor'.""" 41 | return round(number / factor) * factor 42 | 43 | 44 | def ceil_by_factor(number: int, factor: int) -> int: 45 | """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'.""" 46 | return math.ceil(number / factor) * factor 47 | 48 | 49 | def floor_by_factor(number: int, factor: int) -> int: 50 | """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'.""" 51 | return math.floor(number / factor) * factor 52 | 53 | 54 | def smart_resize(height: int, 55 | width: int, 56 | factor: int = IMAGE_FACTOR, 57 | min_pixels: int = MIN_PIXELS, 58 | max_pixels: int = MAX_PIXELS) -> tuple[int, int]: 59 | """ 60 | Rescales the image so that the following conditions are met: 61 | 62 | 1. 
Both dimensions (height and width) are divisible by 'factor'. 63 | 64 | 2. The total number of pixels is within the range ['min_pixels', 'max_pixels']. 65 | 66 | 3. The aspect ratio of the image is maintained as closely as possible. 67 | """ 68 | if max(height, width) / min(height, width) > MAX_RATIO: 69 | raise ValueError( 70 | f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}" 71 | ) 72 | h_bar = max(factor, round_by_factor(height, factor)) 73 | w_bar = max(factor, round_by_factor(width, factor)) 74 | if h_bar * w_bar > max_pixels: 75 | beta = math.sqrt((height * width) / max_pixels) 76 | h_bar = floor_by_factor(height / beta, factor) 77 | w_bar = floor_by_factor(width / beta, factor) 78 | elif h_bar * w_bar < min_pixels: 79 | beta = math.sqrt(min_pixels / (height * width)) 80 | h_bar = ceil_by_factor(height * beta, factor) 81 | w_bar = ceil_by_factor(width * beta, factor) 82 | return h_bar, w_bar 83 | 84 | 85 | def fetch_image(ele: dict[str, str | Image.Image], 86 | size_factor: int = IMAGE_FACTOR) -> Image.Image: 87 | if "image" in ele: 88 | image = ele["image"] 89 | else: 90 | image = ele["image_url"] 91 | image_obj = None 92 | if isinstance(image, Image.Image): 93 | image_obj = image 94 | elif image.startswith("http://") or image.startswith("https://"): 95 | image_obj = Image.open(requests.get(image, stream=True).raw) 96 | elif image.startswith("file://"): 97 | image_obj = Image.open(image[7:]) 98 | elif image.startswith("data:image"): 99 | if "base64," in image: 100 | _, base64_data = image.split("base64,", 1) 101 | data = base64.b64decode(base64_data) 102 | image_obj = Image.open(BytesIO(data)) 103 | else: 104 | image_obj = Image.open(image) 105 | if image_obj is None: 106 | raise ValueError( 107 | f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}" 108 | ) 109 | image = image_obj.convert("RGB") 110 | ## resize 111 | if "resized_height" in ele and "resized_width" in ele: 112 | resized_height, resized_width = smart_resize( 113 | ele["resized_height"], 114 | ele["resized_width"], 115 | factor=size_factor, 116 | ) 117 | else: 118 | width, height = image.size 119 | min_pixels = ele.get("min_pixels", MIN_PIXELS) 120 | max_pixels = ele.get("max_pixels", MAX_PIXELS) 121 | resized_height, resized_width = smart_resize( 122 | height, 123 | width, 124 | factor=size_factor, 125 | min_pixels=min_pixels, 126 | max_pixels=max_pixels, 127 | ) 128 | image = image.resize((resized_width, resized_height)) 129 | 130 | return image 131 | 132 | 133 | def smart_nframes( 134 | ele: dict, 135 | total_frames: int, 136 | video_fps: int | float, 137 | ) -> int: 138 | """calculate the number of frames for video used for model inputs. 139 | 140 | Args: 141 | ele (dict): a dict contains the configuration of video. 142 | support either `fps` or `nframes`: 143 | - nframes: the number of frames to extract for model inputs. 144 | - fps: the fps to extract frames for model inputs. 145 | - min_frames: the minimum number of frames of the video, only used when fps is provided. 146 | - max_frames: the maximum number of frames of the video, only used when fps is provided. 147 | total_frames (int): the original total number of frames of the video. 148 | video_fps (int | float): the original fps of the video. 149 | 150 | Raises: 151 | ValueError: nframes should in interval [FRAME_FACTOR, total_frames]. 152 | 153 | Returns: 154 | int: the number of frames for video used for model inputs. 
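    Illustrative example (assuming a 10-second clip at 30 fps with the default
    `fps=2.0`): `nframes` is 300 / 30 * 2.0 = 20, which already lies inside
    [min_frames, max_frames] and is a multiple of `FRAME_FACTOR`, so 20 frames
    are sampled:

        >>> smart_nframes({"fps": 2.0}, total_frames=300, video_fps=30)
        20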
155 | """ 156 | assert not ("fps" in ele and 157 | "nframes" in ele), "Only accept either `fps` or `nframes`" 158 | if "nframes" in ele: 159 | nframes = round_by_factor(ele["nframes"], FRAME_FACTOR) 160 | else: 161 | fps = ele.get("fps", FPS) 162 | min_frames = ceil_by_factor( 163 | ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR) 164 | max_frames = floor_by_factor( 165 | ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)), 166 | FRAME_FACTOR) 167 | nframes = total_frames / video_fps * fps 168 | nframes = min(max(nframes, min_frames), max_frames) 169 | nframes = round_by_factor(nframes, FRAME_FACTOR) 170 | if not (FRAME_FACTOR <= nframes and nframes <= total_frames): 171 | raise ValueError( 172 | f"nframes should in interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}." 173 | ) 174 | return nframes 175 | 176 | 177 | def _read_video_torchvision(ele: dict,) -> torch.Tensor: 178 | """read video using torchvision.io.read_video 179 | 180 | Args: 181 | ele (dict): a dict contains the configuration of video. 182 | support keys: 183 | - video: the path of video. support "file://", "http://", "https://" and local path. 184 | - video_start: the start time of video. 185 | - video_end: the end time of video. 186 | Returns: 187 | torch.Tensor: the video tensor with shape (T, C, H, W). 188 | """ 189 | video_path = ele["video"] 190 | if version.parse(torchvision.__version__) < version.parse("0.19.0"): 191 | if "http://" in video_path or "https://" in video_path: 192 | warnings.warn( 193 | "torchvision < 0.19.0 does not support http/https video path, please upgrade to 0.19.0." 194 | ) 195 | if "file://" in video_path: 196 | video_path = video_path[7:] 197 | st = time.time() 198 | video, audio, info = io.read_video( 199 | video_path, 200 | start_pts=ele.get("video_start", 0.0), 201 | end_pts=ele.get("video_end", None), 202 | pts_unit="sec", 203 | output_format="TCHW", 204 | ) 205 | total_frames, video_fps = video.size(0), info["video_fps"] 206 | logger.info( 207 | f"torchvision: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s" 208 | ) 209 | nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps) 210 | idx = torch.linspace(0, total_frames - 1, nframes).round().long() 211 | video = video[idx] 212 | return video 213 | 214 | 215 | def is_decord_available() -> bool: 216 | import importlib.util 217 | 218 | return importlib.util.find_spec("decord") is not None 219 | 220 | 221 | def _read_video_decord(ele: dict,) -> torch.Tensor: 222 | """read video using decord.VideoReader 223 | 224 | Args: 225 | ele (dict): a dict contains the configuration of video. 226 | support keys: 227 | - video: the path of video. support "file://", "http://", "https://" and local path. 228 | - video_start: the start time of video. 229 | - video_end: the end time of video. 230 | Returns: 231 | torch.Tensor: the video tensor with shape (T, C, H, W). 
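    Illustrative example (assumes decord is installed and "sample.mp4" is a
    local RGB video; the sampled frame count depends on the clip and the
    requested fps):

        >>> video = _read_video_decord({"video": "sample.mp4", "fps": 2.0})
        >>> video.ndim, video.shape[1]   # (T, C, H, W) layout
        (4, 3)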
232 | """ 233 | import decord 234 | video_path = ele["video"] 235 | st = time.time() 236 | vr = decord.VideoReader(video_path) 237 | # TODO: support start_pts and end_pts 238 | if 'video_start' in ele or 'video_end' in ele: 239 | raise NotImplementedError( 240 | "not support start_pts and end_pts in decord for now.") 241 | total_frames, video_fps = len(vr), vr.get_avg_fps() 242 | logger.info( 243 | f"decord: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s" 244 | ) 245 | nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps) 246 | idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist() 247 | video = vr.get_batch(idx).asnumpy() 248 | video = torch.tensor(video).permute(0, 3, 1, 2) # Convert to TCHW format 249 | return video 250 | 251 | 252 | VIDEO_READER_BACKENDS = { 253 | "decord": _read_video_decord, 254 | "torchvision": _read_video_torchvision, 255 | } 256 | 257 | FORCE_QWENVL_VIDEO_READER = os.getenv("FORCE_QWENVL_VIDEO_READER", None) 258 | 259 | 260 | @lru_cache(maxsize=1) 261 | def get_video_reader_backend() -> str: 262 | if FORCE_QWENVL_VIDEO_READER is not None: 263 | video_reader_backend = FORCE_QWENVL_VIDEO_READER 264 | elif is_decord_available(): 265 | video_reader_backend = "decord" 266 | else: 267 | video_reader_backend = "torchvision" 268 | print( 269 | f"qwen-vl-utils using {video_reader_backend} to read video.", 270 | file=sys.stderr) 271 | return video_reader_backend 272 | 273 | 274 | def fetch_video( 275 | ele: dict, 276 | image_factor: int = IMAGE_FACTOR) -> torch.Tensor | list[Image.Image]: 277 | if isinstance(ele["video"], str): 278 | video_reader_backend = get_video_reader_backend() 279 | video = VIDEO_READER_BACKENDS[video_reader_backend](ele) 280 | nframes, _, height, width = video.shape 281 | 282 | min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS) 283 | total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS) 284 | max_pixels = max( 285 | min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR), 286 | int(min_pixels * 1.05)) 287 | max_pixels = ele.get("max_pixels", max_pixels) 288 | if "resized_height" in ele and "resized_width" in ele: 289 | resized_height, resized_width = smart_resize( 290 | ele["resized_height"], 291 | ele["resized_width"], 292 | factor=image_factor, 293 | ) 294 | else: 295 | resized_height, resized_width = smart_resize( 296 | height, 297 | width, 298 | factor=image_factor, 299 | min_pixels=min_pixels, 300 | max_pixels=max_pixels, 301 | ) 302 | video = transforms.functional.resize( 303 | video, 304 | [resized_height, resized_width], 305 | interpolation=InterpolationMode.BICUBIC, 306 | antialias=True, 307 | ).float() 308 | return video 309 | else: 310 | assert isinstance(ele["video"], (list, tuple)) 311 | process_info = ele.copy() 312 | process_info.pop("type", None) 313 | process_info.pop("video", None) 314 | images = [ 315 | fetch_image({ 316 | "image": video_element, 317 | **process_info 318 | }, 319 | size_factor=image_factor) 320 | for video_element in ele["video"] 321 | ] 322 | nframes = ceil_by_factor(len(images), FRAME_FACTOR) 323 | if len(images) < nframes: 324 | images.extend([images[-1]] * (nframes - len(images))) 325 | return images 326 | 327 | 328 | def extract_vision_info( 329 | conversations: list[dict] | list[list[dict]]) -> list[dict]: 330 | vision_infos = [] 331 | if isinstance(conversations[0], dict): 332 | conversations = [conversations] 333 | for conversation in conversations: 334 | for message in conversation: 335 | if isinstance(message["content"], 
list): 336 | for ele in message["content"]: 337 | if ("image" in ele or "image_url" in ele or 338 | "video" in ele or 339 | ele["type"] in ("image", "image_url", "video")): 340 | vision_infos.append(ele) 341 | return vision_infos 342 | 343 | 344 | def process_vision_info( 345 | conversations: list[dict] | list[list[dict]], 346 | ) -> tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] | 347 | None]: 348 | vision_infos = extract_vision_info(conversations) 349 | ## Read images or videos 350 | image_inputs = [] 351 | video_inputs = [] 352 | for vision_info in vision_infos: 353 | if "image" in vision_info or "image_url" in vision_info: 354 | image_inputs.append(fetch_image(vision_info)) 355 | elif "video" in vision_info: 356 | video_inputs.append(fetch_video(vision_info)) 357 | else: 358 | raise ValueError("image, image_url or video should in content.") 359 | if len(image_inputs) == 0: 360 | image_inputs = None 361 | if len(video_inputs) == 0: 362 | video_inputs = None 363 | return image_inputs, video_inputs 364 | -------------------------------------------------------------------------------- /wan/image2video.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 2 | import gc 3 | import logging 4 | import math 5 | import os 6 | import random 7 | import sys 8 | import types 9 | from contextlib import contextmanager 10 | from functools import partial 11 | 12 | import numpy as np 13 | import torch 14 | import torch.cuda.amp as amp 15 | import torch.distributed as dist 16 | import torchvision.transforms.functional as TF 17 | from tqdm import tqdm 18 | 19 | from .distributed.fsdp import shard_model 20 | from .modules.clip import CLIPModel 21 | from .modules.model import WanModel 22 | from .modules.t5 import T5EncoderModel 23 | from .modules.vae import WanVAE 24 | from .utils.fm_solvers import ( 25 | FlowDPMSolverMultistepScheduler, 26 | get_sampling_sigmas, 27 | retrieve_timesteps, 28 | ) 29 | from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler 30 | 31 | 32 | class WanI2V: 33 | 34 | def __init__( 35 | self, 36 | config, 37 | checkpoint_dir, 38 | device_id=0, 39 | rank=0, 40 | t5_fsdp=False, 41 | dit_fsdp=False, 42 | use_usp=False, 43 | t5_cpu=False, 44 | init_on_cpu=True, 45 | ): 46 | r""" 47 | Initializes the image-to-video generation model components. 48 | 49 | Args: 50 | config (EasyDict): 51 | Object containing model parameters initialized from config.py 52 | checkpoint_dir (`str`): 53 | Path to directory containing model checkpoints 54 | device_id (`int`, *optional*, defaults to 0): 55 | Id of target GPU device 56 | rank (`int`, *optional*, defaults to 0): 57 | Process rank for distributed training 58 | t5_fsdp (`bool`, *optional*, defaults to False): 59 | Enable FSDP sharding for T5 model 60 | dit_fsdp (`bool`, *optional*, defaults to False): 61 | Enable FSDP sharding for DiT model 62 | use_usp (`bool`, *optional*, defaults to False): 63 | Enable distribution strategy of USP. 64 | t5_cpu (`bool`, *optional*, defaults to False): 65 | Whether to place T5 model on CPU. Only works without t5_fsdp. 66 | init_on_cpu (`bool`, *optional*, defaults to True): 67 | Enable initializing Transformer Model on CPU. Only works without FSDP or USP. 
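        Illustrative usage (sketch only; `config` stands for an EasyDict built
        from one of the files under wan/configs, and the checkpoint path is a
        placeholder for a locally downloaded model directory):

            >>> pipe = WanI2V(config, checkpoint_dir="./weights", device_id=0)
            >>> video = pipe.generate("A man is talking", img)  # img: PIL.Image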
68 | """ 69 | self.device = torch.device(f"cuda:{device_id}") 70 | self.config = config 71 | self.rank = rank 72 | self.use_usp = use_usp 73 | self.t5_cpu = t5_cpu 74 | 75 | self.num_train_timesteps = config.num_train_timesteps 76 | self.param_dtype = config.param_dtype 77 | 78 | shard_fn = partial(shard_model, device_id=device_id) 79 | self.text_encoder = T5EncoderModel( 80 | text_len=config.text_len, 81 | dtype=config.t5_dtype, 82 | device=torch.device('cpu'), 83 | checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint), 84 | tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer), 85 | shard_fn=shard_fn if t5_fsdp else None, 86 | ) 87 | 88 | self.vae_stride = config.vae_stride 89 | self.patch_size = config.patch_size 90 | self.vae = WanVAE( 91 | vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), 92 | device=self.device) 93 | 94 | self.clip = CLIPModel( 95 | dtype=config.clip_dtype, 96 | device=self.device, 97 | checkpoint_path=os.path.join(checkpoint_dir, 98 | config.clip_checkpoint), 99 | tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer)) 100 | 101 | logging.info(f"Creating WanModel from {checkpoint_dir}") 102 | self.model = WanModel.from_pretrained(checkpoint_dir) 103 | self.model.eval().requires_grad_(False) 104 | 105 | if t5_fsdp or dit_fsdp or use_usp: 106 | init_on_cpu = False 107 | 108 | if use_usp: 109 | from xfuser.core.distributed import get_sequence_parallel_world_size 110 | 111 | from .distributed.xdit_context_parallel import ( 112 | usp_attn_forward, 113 | usp_dit_forward, 114 | ) 115 | for block in self.model.blocks: 116 | block.self_attn.forward = types.MethodType( 117 | usp_attn_forward, block.self_attn) 118 | self.model.forward = types.MethodType(usp_dit_forward, self.model) 119 | self.sp_size = get_sequence_parallel_world_size() 120 | else: 121 | self.sp_size = 1 122 | 123 | if dist.is_initialized(): 124 | dist.barrier() 125 | if dit_fsdp: 126 | self.model = shard_fn(self.model) 127 | else: 128 | if not init_on_cpu: 129 | self.model.to(self.device) 130 | 131 | self.sample_neg_prompt = config.sample_neg_prompt 132 | 133 | def generate(self, 134 | input_prompt, 135 | img, 136 | max_area=720 * 1280, 137 | frame_num=81, 138 | shift=5.0, 139 | sample_solver='unipc', 140 | sampling_steps=40, 141 | guide_scale=5.0, 142 | n_prompt="", 143 | seed=-1, 144 | offload_model=True): 145 | r""" 146 | Generates video frames from input image and text prompt using diffusion process. 147 | 148 | Args: 149 | input_prompt (`str`): 150 | Text prompt for content generation. 151 | img (PIL.Image.Image): 152 | Input image tensor. Shape: [3, H, W] 153 | max_area (`int`, *optional*, defaults to 720*1280): 154 | Maximum pixel area for latent space calculation. Controls video resolution scaling 155 | frame_num (`int`, *optional*, defaults to 81): 156 | How many frames to sample from a video. The number should be 4n+1 157 | shift (`float`, *optional*, defaults to 5.0): 158 | Noise schedule shift parameter. Affects temporal dynamics 159 | [NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0. 160 | sample_solver (`str`, *optional*, defaults to 'unipc'): 161 | Solver used to sample the video. 162 | sampling_steps (`int`, *optional*, defaults to 40): 163 | Number of diffusion sampling steps. Higher values improve quality but slow generation 164 | guide_scale (`float`, *optional*, defaults 5.0): 165 | Classifier-free guidance scale. Controls prompt adherence vs. 
creativity 166 | n_prompt (`str`, *optional*, defaults to ""): 167 | Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt` 168 | seed (`int`, *optional*, defaults to -1): 169 | Random seed for noise generation. If -1, use random seed 170 | offload_model (`bool`, *optional*, defaults to True): 171 | If True, offloads models to CPU during generation to save VRAM 172 | 173 | Returns: 174 | torch.Tensor: 175 | Generated video frames tensor. Dimensions: (C, N H, W) where: 176 | - C: Color channels (3 for RGB) 177 | - N: Number of frames (81) 178 | - H: Frame height (from max_area) 179 | - W: Frame width from max_area) 180 | """ 181 | img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device) 182 | 183 | F = frame_num 184 | h, w = img.shape[1:] 185 | aspect_ratio = h / w 186 | lat_h = round( 187 | np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] // 188 | self.patch_size[1] * self.patch_size[1]) 189 | lat_w = round( 190 | np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] // 191 | self.patch_size[2] * self.patch_size[2]) 192 | h = lat_h * self.vae_stride[1] 193 | w = lat_w * self.vae_stride[2] 194 | 195 | max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // ( 196 | self.patch_size[1] * self.patch_size[2]) 197 | max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size 198 | 199 | seed = seed if seed >= 0 else random.randint(0, sys.maxsize) 200 | seed_g = torch.Generator(device=self.device) 201 | seed_g.manual_seed(seed) 202 | noise = torch.randn( 203 | 16, (F - 1) // 4 + 1, 204 | lat_h, 205 | lat_w, 206 | dtype=torch.float32, 207 | generator=seed_g, 208 | device=self.device) 209 | 210 | msk = torch.ones(1, 81, lat_h, lat_w, device=self.device) 211 | msk[:, 1:] = 0 212 | msk = torch.concat([ 213 | torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:] 214 | ], 215 | dim=1) 216 | msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w) 217 | msk = msk.transpose(1, 2)[0] 218 | 219 | if n_prompt == "": 220 | n_prompt = self.sample_neg_prompt 221 | 222 | # preprocess 223 | if not self.t5_cpu: 224 | self.text_encoder.model.to(self.device) 225 | context = self.text_encoder([input_prompt], self.device) 226 | context_null = self.text_encoder([n_prompt], self.device) 227 | if offload_model: 228 | self.text_encoder.model.cpu() 229 | else: 230 | context = self.text_encoder([input_prompt], torch.device('cpu')) 231 | context_null = self.text_encoder([n_prompt], torch.device('cpu')) 232 | context = [t.to(self.device) for t in context] 233 | context_null = [t.to(self.device) for t in context_null] 234 | 235 | self.clip.model.to(self.device) 236 | clip_context = self.clip.visual([img[:, None, :, :]]) 237 | if offload_model: 238 | self.clip.model.cpu() 239 | 240 | y = self.vae.encode([ 241 | torch.concat([ 242 | torch.nn.functional.interpolate( 243 | img[None].cpu(), size=(h, w), mode='bicubic').transpose( 244 | 0, 1), 245 | torch.zeros(3, F - 1, h, w) 246 | ], 247 | dim=1).to(self.device) 248 | ])[0] 249 | y = torch.concat([msk, y]) 250 | 251 | @contextmanager 252 | def noop_no_sync(): 253 | yield 254 | 255 | no_sync = getattr(self.model, 'no_sync', noop_no_sync) 256 | 257 | # evaluation mode 258 | with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync(): 259 | 260 | if sample_solver == 'unipc': 261 | sample_scheduler = FlowUniPCMultistepScheduler( 262 | num_train_timesteps=self.num_train_timesteps, 263 | shift=1, 264 | use_dynamic_shifting=False) 265 | sample_scheduler.set_timesteps( 266 | sampling_steps, 
device=self.device, shift=shift) 267 | timesteps = sample_scheduler.timesteps 268 | elif sample_solver == 'dpm++': 269 | sample_scheduler = FlowDPMSolverMultistepScheduler( 270 | num_train_timesteps=self.num_train_timesteps, 271 | shift=1, 272 | use_dynamic_shifting=False) 273 | sampling_sigmas = get_sampling_sigmas(sampling_steps, shift) 274 | timesteps, _ = retrieve_timesteps( 275 | sample_scheduler, 276 | device=self.device, 277 | sigmas=sampling_sigmas) 278 | else: 279 | raise NotImplementedError("Unsupported solver.") 280 | 281 | # sample videos 282 | latent = noise 283 | 284 | arg_c = { 285 | 'context': [context[0]], 286 | 'clip_fea': clip_context, 287 | 'seq_len': max_seq_len, 288 | 'y': [y], 289 | } 290 | 291 | arg_null = { 292 | 'context': context_null, 293 | 'clip_fea': clip_context, 294 | 'seq_len': max_seq_len, 295 | 'y': [y], 296 | } 297 | 298 | if offload_model: 299 | torch.cuda.empty_cache() 300 | 301 | self.model.to(self.device) 302 | for _, t in enumerate(tqdm(timesteps)): 303 | latent_model_input = [latent.to(self.device)] 304 | timestep = [t] 305 | 306 | timestep = torch.stack(timestep).to(self.device) 307 | 308 | noise_pred_cond = self.model( 309 | latent_model_input, t=timestep, **arg_c)[0].to( 310 | torch.device('cpu') if offload_model else self.device) 311 | if offload_model: 312 | torch.cuda.empty_cache() 313 | noise_pred_uncond = self.model( 314 | latent_model_input, t=timestep, **arg_null)[0].to( 315 | torch.device('cpu') if offload_model else self.device) 316 | if offload_model: 317 | torch.cuda.empty_cache() 318 | noise_pred = noise_pred_uncond + guide_scale * ( 319 | noise_pred_cond - noise_pred_uncond) 320 | 321 | latent = latent.to( 322 | torch.device('cpu') if offload_model else self.device) 323 | 324 | temp_x0 = sample_scheduler.step( 325 | noise_pred.unsqueeze(0), 326 | t, 327 | latent.unsqueeze(0), 328 | return_dict=False, 329 | generator=seed_g)[0] 330 | latent = temp_x0.squeeze(0) 331 | 332 | x0 = [latent.to(self.device)] 333 | del latent_model_input, timestep 334 | 335 | if offload_model: 336 | self.model.cpu() 337 | torch.cuda.empty_cache() 338 | 339 | if self.rank == 0: 340 | videos = self.vae.decode(x0) 341 | 342 | del noise, latent 343 | del sample_scheduler 344 | if offload_model: 345 | gc.collect() 346 | torch.cuda.synchronize() 347 | if dist.is_initialized(): 348 | dist.barrier() 349 | 350 | return videos[0] if self.rank == 0 else None 351 | -------------------------------------------------------------------------------- /wan/modules/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
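# Module overview: this file wraps the attention kernels used across the models.
# - flash_attention packs q/k/v into variable-length form and calls flash-attn
#   v3 (flash_attn_interface) when available, otherwise flash-attn v2.
# - attention dispatches to flash_attention and falls back to
#   torch.nn.functional.scaled_dot_product_attention when neither flash-attn
#   version is installed (in that case the q_lens/k_lens padding masks are
#   ignored).
# - SingleStreamAttention and SingleStreamMutiAttention are audio
#   cross-attention blocks built on xformers memory_efficient_attention; the
#   multi-speaker variant additionally applies a 1D RoPE to route visual
#   queries to per-speaker audio token ranges.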
2 | import torch 3 | import torch.nn as nn 4 | from einops import rearrange, repeat 5 | from ..utils.multitalk_utils import RotaryPositionalEmbedding1D, normalize_and_scale, split_token_counts_and_frame_ids 6 | from xfuser.core.distributed import ( 7 | get_sequence_parallel_rank, 8 | get_sequence_parallel_world_size, 9 | get_sp_group, 10 | ) 11 | import xformers.ops 12 | 13 | try: 14 | import flash_attn_interface 15 | FLASH_ATTN_3_AVAILABLE = True 16 | except ModuleNotFoundError: 17 | FLASH_ATTN_3_AVAILABLE = False 18 | 19 | try: 20 | import flash_attn 21 | FLASH_ATTN_2_AVAILABLE = True 22 | except ModuleNotFoundError: 23 | FLASH_ATTN_2_AVAILABLE = False 24 | 25 | import warnings 26 | 27 | __all__ = [ 28 | 'flash_attention', 29 | 'attention', 30 | ] 31 | 32 | 33 | def flash_attention( 34 | q, 35 | k, 36 | v, 37 | q_lens=None, 38 | k_lens=None, 39 | dropout_p=0., 40 | softmax_scale=None, 41 | q_scale=None, 42 | causal=False, 43 | window_size=(-1, -1), 44 | deterministic=False, 45 | dtype=torch.bfloat16, 46 | version=None, 47 | ): 48 | """ 49 | q: [B, Lq, Nq, C1]. 50 | k: [B, Lk, Nk, C1]. 51 | v: [B, Lk, Nk, C2]. Nq must be divisible by Nk. 52 | q_lens: [B]. 53 | k_lens: [B]. 54 | dropout_p: float. Dropout probability. 55 | softmax_scale: float. The scaling of QK^T before applying softmax. 56 | causal: bool. Whether to apply causal attention mask. 57 | window_size: (left right). If not (-1, -1), apply sliding window local attention. 58 | deterministic: bool. If True, slightly slower and uses more memory. 59 | dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16. 60 | """ 61 | half_dtypes = (torch.float16, torch.bfloat16) 62 | assert dtype in half_dtypes 63 | assert q.device.type == 'cuda' and q.size(-1) <= 256 64 | 65 | # params 66 | b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype 67 | 68 | def half(x): 69 | return x if x.dtype in half_dtypes else x.to(dtype) 70 | 71 | # preprocess query 72 | if q_lens is None: 73 | q = half(q.flatten(0, 1)) 74 | q_lens = torch.tensor( 75 | [lq] * b, dtype=torch.int32).to( 76 | device=q.device, non_blocking=True) 77 | else: 78 | q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)])) 79 | 80 | # preprocess key, value 81 | if k_lens is None: 82 | k = half(k.flatten(0, 1)) 83 | v = half(v.flatten(0, 1)) 84 | k_lens = torch.tensor( 85 | [lk] * b, dtype=torch.int32).to( 86 | device=k.device, non_blocking=True) 87 | else: 88 | k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)])) 89 | v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)])) 90 | 91 | q = q.to(v.dtype) 92 | k = k.to(v.dtype) 93 | 94 | if q_scale is not None: 95 | q = q * q_scale 96 | 97 | if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE: 98 | warnings.warn( 99 | 'Flash attention 3 is not available, use flash attention 2 instead.' 100 | ) 101 | 102 | # apply attention 103 | if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE: 104 | # Note: dropout_p, window_size are not supported in FA3 now. 
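        # q/k/v were flattened above into packed (total_tokens, heads, head_dim)
        # layout, so the varlen kernel needs cumulative sequence lengths
        # ([0, l1, l1+l2, ...]) built from q_lens/k_lens; the packed output is
        # unflattened back to (B, Lq, ...) at the end of the call below.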
105 | x = flash_attn_interface.flash_attn_varlen_func( 106 | q=q, 107 | k=k, 108 | v=v, 109 | cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum( 110 | 0, dtype=torch.int32).to(q.device, non_blocking=True), 111 | cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum( 112 | 0, dtype=torch.int32).to(q.device, non_blocking=True), 113 | seqused_q=None, 114 | seqused_k=None, 115 | max_seqlen_q=lq, 116 | max_seqlen_k=lk, 117 | softmax_scale=softmax_scale, 118 | causal=causal, 119 | deterministic=deterministic)[0].unflatten(0, (b, lq)) 120 | else: 121 | assert FLASH_ATTN_2_AVAILABLE 122 | x = flash_attn.flash_attn_varlen_func( 123 | q=q, 124 | k=k, 125 | v=v, 126 | cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum( 127 | 0, dtype=torch.int32).to(q.device, non_blocking=True), 128 | cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum( 129 | 0, dtype=torch.int32).to(q.device, non_blocking=True), 130 | max_seqlen_q=lq, 131 | max_seqlen_k=lk, 132 | dropout_p=dropout_p, 133 | softmax_scale=softmax_scale, 134 | causal=causal, 135 | window_size=window_size, 136 | deterministic=deterministic).unflatten(0, (b, lq)) 137 | 138 | # output 139 | return x.type(out_dtype) 140 | 141 | 142 | def attention( 143 | q, 144 | k, 145 | v, 146 | q_lens=None, 147 | k_lens=None, 148 | dropout_p=0., 149 | softmax_scale=None, 150 | q_scale=None, 151 | causal=False, 152 | window_size=(-1, -1), 153 | deterministic=False, 154 | dtype=torch.bfloat16, 155 | fa_version=None, 156 | ): 157 | if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE: 158 | return flash_attention( 159 | q=q, 160 | k=k, 161 | v=v, 162 | q_lens=q_lens, 163 | k_lens=k_lens, 164 | dropout_p=dropout_p, 165 | softmax_scale=softmax_scale, 166 | q_scale=q_scale, 167 | causal=causal, 168 | window_size=window_size, 169 | deterministic=deterministic, 170 | dtype=dtype, 171 | version=fa_version, 172 | ) 173 | else: 174 | if q_lens is not None or k_lens is not None: 175 | warnings.warn( 176 | 'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.' 
177 | ) 178 | attn_mask = None 179 | 180 | q = q.transpose(1, 2).to(dtype) 181 | k = k.transpose(1, 2).to(dtype) 182 | v = v.transpose(1, 2).to(dtype) 183 | 184 | out = torch.nn.functional.scaled_dot_product_attention( 185 | q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p) 186 | 187 | out = out.transpose(1, 2).contiguous() 188 | return out 189 | 190 | 191 | class SingleStreamAttention(nn.Module): 192 | def __init__( 193 | self, 194 | dim: int, 195 | encoder_hidden_states_dim: int, 196 | num_heads: int, 197 | qkv_bias: bool, 198 | qk_norm: bool, 199 | norm_layer: nn.Module, 200 | attn_drop: float = 0.0, 201 | proj_drop: float = 0.0, 202 | eps: float = 1e-6, 203 | ) -> None: 204 | super().__init__() 205 | assert dim % num_heads == 0, "dim should be divisible by num_heads" 206 | self.dim = dim 207 | self.encoder_hidden_states_dim = encoder_hidden_states_dim 208 | self.num_heads = num_heads 209 | self.head_dim = dim // num_heads 210 | self.scale = self.head_dim**-0.5 211 | self.qk_norm = qk_norm 212 | 213 | self.q_linear = nn.Linear(dim, dim, bias=qkv_bias) 214 | 215 | self.q_norm = norm_layer(self.head_dim, eps=eps) if qk_norm else nn.Identity() 216 | self.k_norm = norm_layer(self.head_dim,eps=eps) if qk_norm else nn.Identity() 217 | 218 | self.attn_drop = nn.Dropout(attn_drop) 219 | self.proj = nn.Linear(dim, dim) 220 | self.proj_drop = nn.Dropout(proj_drop) 221 | 222 | self.kv_linear = nn.Linear(encoder_hidden_states_dim, dim * 2, bias=qkv_bias) 223 | 224 | self.add_q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() 225 | self.add_k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() 226 | 227 | def forward(self, x: torch.Tensor, encoder_hidden_states: torch.Tensor, shape=None, enable_sp=False, kv_seq=None) -> torch.Tensor: 228 | 229 | N_t, N_h, N_w = shape 230 | if not enable_sp: 231 | x = rearrange(x, "B (N_t S) C -> (B N_t) S C", N_t=N_t) 232 | 233 | # get q for hidden_state 234 | B, N, C = x.shape 235 | q = self.q_linear(x) 236 | q_shape = (B, N, self.num_heads, self.head_dim) 237 | q = q.view(q_shape).permute((0, 2, 1, 3)) 238 | 239 | if self.qk_norm: 240 | q = self.q_norm(q) 241 | 242 | # get kv from encoder_hidden_states 243 | _, N_a, _ = encoder_hidden_states.shape 244 | encoder_kv = self.kv_linear(encoder_hidden_states) 245 | encoder_kv_shape = (B, N_a, 2, self.num_heads, self.head_dim) 246 | encoder_kv = encoder_kv.view(encoder_kv_shape).permute((2, 0, 3, 1, 4)) 247 | encoder_k, encoder_v = encoder_kv.unbind(0) 248 | 249 | if self.qk_norm: 250 | encoder_k = self.add_k_norm(encoder_k) 251 | 252 | 253 | q = rearrange(q, "B H M K -> B M H K") 254 | encoder_k = rearrange(encoder_k, "B H M K -> B M H K") 255 | encoder_v = rearrange(encoder_v, "B H M K -> B M H K") 256 | 257 | if enable_sp: 258 | # context parallel 259 | sp_size = get_sequence_parallel_world_size() 260 | sp_rank = get_sequence_parallel_rank() 261 | visual_seqlen, _ = split_token_counts_and_frame_ids(N_t, N_h * N_w, sp_size, sp_rank) 262 | assert kv_seq is not None, f"kv_seq should not be None." 
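            # Under sequence parallelism each rank holds only a slice of the
            # visual tokens, so a block-diagonal bias is built from the
            # per-frame token counts on this rank (visual_seqlen) and the
            # per-frame audio lengths (kv_seq); each frame's visual tokens then
            # attend only to that frame's audio tokens.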
263 | attn_bias = xformers.ops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(visual_seqlen, kv_seq) 264 | else: 265 | attn_bias = None 266 | x = xformers.ops.memory_efficient_attention(q, encoder_k, encoder_v, attn_bias=attn_bias, op=None,) 267 | x = rearrange(x, "B M H K -> B H M K") 268 | 269 | # linear transform 270 | x_output_shape = (B, N, C) 271 | x = x.transpose(1, 2) 272 | x = x.reshape(x_output_shape) 273 | x = self.proj(x) 274 | x = self.proj_drop(x) 275 | 276 | if not enable_sp: 277 | # reshape x to origin shape 278 | x = rearrange(x, "(B N_t) S C -> B (N_t S) C", N_t=N_t) 279 | 280 | return x 281 | 282 | class SingleStreamMutiAttention(SingleStreamAttention): 283 | def __init__( 284 | self, 285 | dim: int, 286 | encoder_hidden_states_dim: int, 287 | num_heads: int, 288 | qkv_bias: bool, 289 | qk_norm: bool, 290 | norm_layer: nn.Module, 291 | attn_drop: float = 0.0, 292 | proj_drop: float = 0.0, 293 | eps: float = 1e-6, 294 | class_range: int = 24, 295 | class_interval: int = 4, 296 | ) -> None: 297 | super().__init__( 298 | dim=dim, 299 | encoder_hidden_states_dim=encoder_hidden_states_dim, 300 | num_heads=num_heads, 301 | qkv_bias=qkv_bias, 302 | qk_norm=qk_norm, 303 | norm_layer=norm_layer, 304 | attn_drop=attn_drop, 305 | proj_drop=proj_drop, 306 | eps=eps, 307 | ) 308 | self.class_interval = class_interval 309 | self.class_range = class_range 310 | self.rope_h1 = (0, self.class_interval) 311 | self.rope_h2 = (self.class_range - self.class_interval, self.class_range) 312 | self.rope_bak = int(self.class_range // 2) 313 | 314 | self.rope_1d = RotaryPositionalEmbedding1D(self.head_dim) 315 | 316 | def forward(self, 317 | x: torch.Tensor, 318 | encoder_hidden_states: torch.Tensor, 319 | shape=None, 320 | x_ref_attn_map=None, 321 | human_num=None) -> torch.Tensor: 322 | 323 | encoder_hidden_states = encoder_hidden_states.squeeze(0) 324 | if human_num == 1: 325 | return super().forward(x, encoder_hidden_states, shape) 326 | 327 | N_t, _, _ = shape 328 | x = rearrange(x, "B (N_t S) C -> (B N_t) S C", N_t=N_t) 329 | 330 | # get q for hidden_state 331 | B, N, C = x.shape 332 | q = self.q_linear(x) 333 | q_shape = (B, N, self.num_heads, self.head_dim) 334 | q = q.view(q_shape).permute((0, 2, 1, 3)) 335 | 336 | if self.qk_norm: 337 | q = self.q_norm(q) 338 | 339 | 340 | max_values = x_ref_attn_map.max(1).values[:, None, None] 341 | min_values = x_ref_attn_map.min(1).values[:, None, None] 342 | max_min_values = torch.cat([max_values, min_values], dim=2) 343 | 344 | human1_max_value, human1_min_value = max_min_values[0, :, 0].max(), max_min_values[0, :, 1].min() 345 | human2_max_value, human2_min_value = max_min_values[1, :, 0].max(), max_min_values[1, :, 1].min() 346 | 347 | human1 = normalize_and_scale(x_ref_attn_map[0], (human1_min_value, human1_max_value), (self.rope_h1[0], self.rope_h1[1])) 348 | human2 = normalize_and_scale(x_ref_attn_map[1], (human2_min_value, human2_max_value), (self.rope_h2[0], self.rope_h2[1])) 349 | back = torch.full((x_ref_attn_map.size(1),), self.rope_bak, dtype=human1.dtype).to(human1.device) 350 | max_indices = x_ref_attn_map.argmax(dim=0) 351 | normalized_map = torch.stack([human1, human2, back], dim=1) 352 | normalized_pos = normalized_map[range(x_ref_attn_map.size(1)), max_indices] # N 353 | 354 | q = rearrange(q, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t) 355 | q = self.rope_1d(q, normalized_pos) 356 | q = rearrange(q, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t) 357 | 358 | _, N_a, _ = encoder_hidden_states.shape 359 | encoder_kv = 
self.kv_linear(encoder_hidden_states) 360 | encoder_kv_shape = (B, N_a, 2, self.num_heads, self.head_dim) 361 | encoder_kv = encoder_kv.view(encoder_kv_shape).permute((2, 0, 3, 1, 4)) 362 | encoder_k, encoder_v = encoder_kv.unbind(0) 363 | 364 | if self.qk_norm: 365 | encoder_k = self.add_k_norm(encoder_k) 366 | 367 | 368 | per_frame = torch.zeros(N_a, dtype=encoder_k.dtype).to(encoder_k.device) 369 | per_frame[:per_frame.size(0)//2] = (self.rope_h1[0] + self.rope_h1[1]) / 2 370 | per_frame[per_frame.size(0)//2:] = (self.rope_h2[0] + self.rope_h2[1]) / 2 371 | encoder_pos = torch.concat([per_frame]*N_t, dim=0) 372 | encoder_k = rearrange(encoder_k, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t) 373 | encoder_k = self.rope_1d(encoder_k, encoder_pos) 374 | encoder_k = rearrange(encoder_k, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t) 375 | 376 | 377 | q = rearrange(q, "B H M K -> B M H K") 378 | encoder_k = rearrange(encoder_k, "B H M K -> B M H K") 379 | encoder_v = rearrange(encoder_v, "B H M K -> B M H K") 380 | x = xformers.ops.memory_efficient_attention(q, encoder_k, encoder_v, attn_bias=None, op=None,) 381 | x = rearrange(x, "B M H K -> B H M K") 382 | 383 | # linear transform 384 | x_output_shape = (B, N, C) 385 | x = x.transpose(1, 2) 386 | x = x.reshape(x_output_shape) 387 | x = self.proj(x) 388 | x = self.proj_drop(x) 389 | 390 | # reshape x to origin shape 391 | x = rearrange(x, "(B N_t) S C -> B (N_t S) C", N_t=N_t) 392 | 393 | return x -------------------------------------------------------------------------------- /wan/first_last_frame2video.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 2 | import gc 3 | import logging 4 | import math 5 | import os 6 | import random 7 | import sys 8 | import types 9 | from contextlib import contextmanager 10 | from functools import partial 11 | 12 | import numpy as np 13 | import torch 14 | import torch.cuda.amp as amp 15 | import torch.distributed as dist 16 | import torchvision.transforms.functional as TF 17 | from tqdm import tqdm 18 | 19 | from .distributed.fsdp import shard_model 20 | from .modules.clip import CLIPModel 21 | from .modules.model import WanModel 22 | from .modules.t5 import T5EncoderModel 23 | from .modules.vae import WanVAE 24 | from .utils.fm_solvers import ( 25 | FlowDPMSolverMultistepScheduler, 26 | get_sampling_sigmas, 27 | retrieve_timesteps, 28 | ) 29 | from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler 30 | 31 | 32 | class WanFLF2V: 33 | 34 | def __init__( 35 | self, 36 | config, 37 | checkpoint_dir, 38 | device_id=0, 39 | rank=0, 40 | t5_fsdp=False, 41 | dit_fsdp=False, 42 | use_usp=False, 43 | t5_cpu=False, 44 | init_on_cpu=True, 45 | ): 46 | r""" 47 | Initializes the image-to-video generation model components. 48 | 49 | Args: 50 | config (EasyDict): 51 | Object containing model parameters initialized from config.py 52 | checkpoint_dir (`str`): 53 | Path to directory containing model checkpoints 54 | device_id (`int`, *optional*, defaults to 0): 55 | Id of target GPU device 56 | rank (`int`, *optional*, defaults to 0): 57 | Process rank for distributed training 58 | t5_fsdp (`bool`, *optional*, defaults to False): 59 | Enable FSDP sharding for T5 model 60 | dit_fsdp (`bool`, *optional*, defaults to False): 61 | Enable FSDP sharding for DiT model 62 | use_usp (`bool`, *optional*, defaults to False): 63 | Enable distribution strategy of USP. 
64 | t5_cpu (`bool`, *optional*, defaults to False): 65 | Whether to place T5 model on CPU. Only works without t5_fsdp. 66 | init_on_cpu (`bool`, *optional*, defaults to True): 67 | Enable initializing Transformer Model on CPU. Only works without FSDP or USP. 68 | """ 69 | self.device = torch.device(f"cuda:{device_id}") 70 | self.config = config 71 | self.rank = rank 72 | self.use_usp = use_usp 73 | self.t5_cpu = t5_cpu 74 | 75 | self.num_train_timesteps = config.num_train_timesteps 76 | self.param_dtype = config.param_dtype 77 | 78 | shard_fn = partial(shard_model, device_id=device_id) 79 | self.text_encoder = T5EncoderModel( 80 | text_len=config.text_len, 81 | dtype=config.t5_dtype, 82 | device=torch.device('cpu'), 83 | checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint), 84 | tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer), 85 | shard_fn=shard_fn if t5_fsdp else None, 86 | ) 87 | 88 | self.vae_stride = config.vae_stride 89 | self.patch_size = config.patch_size 90 | self.vae = WanVAE( 91 | vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), 92 | device=self.device) 93 | 94 | self.clip = CLIPModel( 95 | dtype=config.clip_dtype, 96 | device=self.device, 97 | checkpoint_path=os.path.join(checkpoint_dir, 98 | config.clip_checkpoint), 99 | tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer)) 100 | 101 | logging.info(f"Creating WanModel from {checkpoint_dir}") 102 | self.model = WanModel.from_pretrained(checkpoint_dir) 103 | self.model.eval().requires_grad_(False) 104 | 105 | if t5_fsdp or dit_fsdp or use_usp: 106 | init_on_cpu = False 107 | 108 | if use_usp: 109 | from xfuser.core.distributed import get_sequence_parallel_world_size 110 | 111 | from .distributed.xdit_context_parallel import ( 112 | usp_attn_forward, 113 | usp_dit_forward, 114 | ) 115 | for block in self.model.blocks: 116 | block.self_attn.forward = types.MethodType( 117 | usp_attn_forward, block.self_attn) 118 | self.model.forward = types.MethodType(usp_dit_forward, self.model) 119 | self.sp_size = get_sequence_parallel_world_size() 120 | else: 121 | self.sp_size = 1 122 | 123 | if dist.is_initialized(): 124 | dist.barrier() 125 | if dit_fsdp: 126 | self.model = shard_fn(self.model) 127 | else: 128 | if not init_on_cpu: 129 | self.model.to(self.device) 130 | 131 | self.sample_neg_prompt = config.sample_neg_prompt 132 | 133 | def generate(self, 134 | input_prompt, 135 | first_frame, 136 | last_frame, 137 | max_area=720 * 1280, 138 | frame_num=81, 139 | shift=16, 140 | sample_solver='unipc', 141 | sampling_steps=50, 142 | guide_scale=5.5, 143 | n_prompt="", 144 | seed=-1, 145 | offload_model=True): 146 | r""" 147 | Generates video frames from input first-last frame and text prompt using diffusion process. 148 | 149 | Args: 150 | input_prompt (`str`): 151 | Text prompt for content generation. 152 | first_frame (PIL.Image.Image): 153 | Input image tensor. Shape: [3, H, W] 154 | last_frame (PIL.Image.Image): 155 | Input image tensor. Shape: [3, H, W] 156 | [NOTE] If the sizes of first_frame and last_frame are mismatched, last_frame will be cropped & resized 157 | to match first_frame. 158 | max_area (`int`, *optional*, defaults to 720*1280): 159 | Maximum pixel area for latent space calculation. Controls video resolution scaling 160 | frame_num (`int`, *optional*, defaults to 81): 161 | How many frames to sample from a video. The number should be 4n+1 162 | shift (`float`, *optional*, defaults to 5.0): 163 | Noise schedule shift parameter. 
Affects temporal dynamics 164 | [NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0. 165 | sample_solver (`str`, *optional*, defaults to 'unipc'): 166 | Solver used to sample the video. 167 | sampling_steps (`int`, *optional*, defaults to 40): 168 | Number of diffusion sampling steps. Higher values improve quality but slow generation 169 | guide_scale (`float`, *optional*, defaults 5.0): 170 | Classifier-free guidance scale. Controls prompt adherence vs. creativity 171 | n_prompt (`str`, *optional*, defaults to ""): 172 | Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt` 173 | seed (`int`, *optional*, defaults to -1): 174 | Random seed for noise generation. If -1, use random seed 175 | offload_model (`bool`, *optional*, defaults to True): 176 | If True, offloads models to CPU during generation to save VRAM 177 | 178 | Returns: 179 | torch.Tensor: 180 | Generated video frames tensor. Dimensions: (C, N H, W) where: 181 | - C: Color channels (3 for RGB) 182 | - N: Number of frames (81) 183 | - H: Frame height (from max_area) 184 | - W: Frame width from max_area) 185 | """ 186 | first_frame_size = first_frame.size 187 | last_frame_size = last_frame.size 188 | first_frame = TF.to_tensor(first_frame).sub_(0.5).div_(0.5).to( 189 | self.device) 190 | last_frame = TF.to_tensor(last_frame).sub_(0.5).div_(0.5).to( 191 | self.device) 192 | 193 | F = frame_num 194 | first_frame_h, first_frame_w = first_frame.shape[1:] 195 | aspect_ratio = first_frame_h / first_frame_w 196 | lat_h = round( 197 | np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] // 198 | self.patch_size[1] * self.patch_size[1]) 199 | lat_w = round( 200 | np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] // 201 | self.patch_size[2] * self.patch_size[2]) 202 | first_frame_h = lat_h * self.vae_stride[1] 203 | first_frame_w = lat_w * self.vae_stride[2] 204 | if first_frame_size != last_frame_size: 205 | # 1. resize 206 | last_frame_resize_ratio = max( 207 | first_frame_size[0] / last_frame_size[0], 208 | first_frame_size[1] / last_frame_size[1]) 209 | last_frame_size = [ 210 | round(last_frame_size[0] * last_frame_resize_ratio), 211 | round(last_frame_size[1] * last_frame_resize_ratio), 212 | ] 213 | # 2. 
center crop 214 | last_frame = TF.center_crop(last_frame, last_frame_size) 215 | 216 | max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // ( 217 | self.patch_size[1] * self.patch_size[2]) 218 | max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size 219 | 220 | seed = seed if seed >= 0 else random.randint(0, sys.maxsize) 221 | seed_g = torch.Generator(device=self.device) 222 | seed_g.manual_seed(seed) 223 | noise = torch.randn( 224 | 16, (F - 1) // 4 + 1, 225 | lat_h, 226 | lat_w, 227 | dtype=torch.float32, 228 | generator=seed_g, 229 | device=self.device) 230 | 231 | msk = torch.ones(1, 81, lat_h, lat_w, device=self.device) 232 | msk[:, 1:-1] = 0 233 | msk = torch.concat([ 234 | torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:] 235 | ], 236 | dim=1) 237 | msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w) 238 | msk = msk.transpose(1, 2)[0] 239 | 240 | if n_prompt == "": 241 | n_prompt = self.sample_neg_prompt 242 | 243 | # preprocess 244 | if not self.t5_cpu: 245 | self.text_encoder.model.to(self.device) 246 | context = self.text_encoder([input_prompt], self.device) 247 | context_null = self.text_encoder([n_prompt], self.device) 248 | if offload_model: 249 | self.text_encoder.model.cpu() 250 | else: 251 | context = self.text_encoder([input_prompt], torch.device('cpu')) 252 | context_null = self.text_encoder([n_prompt], torch.device('cpu')) 253 | context = [t.to(self.device) for t in context] 254 | context_null = [t.to(self.device) for t in context_null] 255 | 256 | self.clip.model.to(self.device) 257 | clip_context = self.clip.visual( 258 | [first_frame[:, None, :, :], last_frame[:, None, :, :]]) 259 | if offload_model: 260 | self.clip.model.cpu() 261 | 262 | y = self.vae.encode([ 263 | torch.concat([ 264 | torch.nn.functional.interpolate( 265 | first_frame[None].cpu(), 266 | size=(first_frame_h, first_frame_w), 267 | mode='bicubic').transpose(0, 1), 268 | torch.zeros(3, F - 2, first_frame_h, first_frame_w), 269 | torch.nn.functional.interpolate( 270 | last_frame[None].cpu(), 271 | size=(first_frame_h, first_frame_w), 272 | mode='bicubic').transpose(0, 1), 273 | ], 274 | dim=1).to(self.device) 275 | ])[0] 276 | y = torch.concat([msk, y]) 277 | 278 | @contextmanager 279 | def noop_no_sync(): 280 | yield 281 | 282 | no_sync = getattr(self.model, 'no_sync', noop_no_sync) 283 | 284 | # evaluation mode 285 | with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync(): 286 | 287 | if sample_solver == 'unipc': 288 | sample_scheduler = FlowUniPCMultistepScheduler( 289 | num_train_timesteps=self.num_train_timesteps, 290 | shift=1, 291 | use_dynamic_shifting=False) 292 | sample_scheduler.set_timesteps( 293 | sampling_steps, device=self.device, shift=shift) 294 | timesteps = sample_scheduler.timesteps 295 | elif sample_solver == 'dpm++': 296 | sample_scheduler = FlowDPMSolverMultistepScheduler( 297 | num_train_timesteps=self.num_train_timesteps, 298 | shift=1, 299 | use_dynamic_shifting=False) 300 | sampling_sigmas = get_sampling_sigmas(sampling_steps, shift) 301 | timesteps, _ = retrieve_timesteps( 302 | sample_scheduler, 303 | device=self.device, 304 | sigmas=sampling_sigmas) 305 | else: 306 | raise NotImplementedError("Unsupported solver.") 307 | 308 | # sample videos 309 | latent = noise 310 | 311 | arg_c = { 312 | 'context': [context[0]], 313 | 'clip_fea': clip_context, 314 | 'seq_len': max_seq_len, 315 | 'y': [y], 316 | } 317 | 318 | arg_null = { 319 | 'context': context_null, 320 | 'clip_fea': clip_context, 321 | 'seq_len': 
max_seq_len, 322 | 'y': [y], 323 | } 324 | 325 | if offload_model: 326 | torch.cuda.empty_cache() 327 | 328 | self.model.to(self.device) 329 | for _, t in enumerate(tqdm(timesteps)): 330 | latent_model_input = [latent.to(self.device)] 331 | timestep = [t] 332 | 333 | timestep = torch.stack(timestep).to(self.device) 334 | 335 | noise_pred_cond = self.model( 336 | latent_model_input, t=timestep, **arg_c)[0].to( 337 | torch.device('cpu') if offload_model else self.device) 338 | if offload_model: 339 | torch.cuda.empty_cache() 340 | noise_pred_uncond = self.model( 341 | latent_model_input, t=timestep, **arg_null)[0].to( 342 | torch.device('cpu') if offload_model else self.device) 343 | if offload_model: 344 | torch.cuda.empty_cache() 345 | noise_pred = noise_pred_uncond + guide_scale * ( 346 | noise_pred_cond - noise_pred_uncond) 347 | 348 | latent = latent.to( 349 | torch.device('cpu') if offload_model else self.device) 350 | 351 | temp_x0 = sample_scheduler.step( 352 | noise_pred.unsqueeze(0), 353 | t, 354 | latent.unsqueeze(0), 355 | return_dict=False, 356 | generator=seed_g)[0] 357 | latent = temp_x0.squeeze(0) 358 | 359 | x0 = [latent.to(self.device)] 360 | del latent_model_input, timestep 361 | 362 | if offload_model: 363 | self.model.cpu() 364 | torch.cuda.empty_cache() 365 | 366 | if self.rank == 0: 367 | videos = self.vae.decode(x0) 368 | 369 | del noise, latent 370 | del sample_scheduler 371 | if offload_model: 372 | gc.collect() 373 | torch.cuda.synchronize() 374 | if dist.is_initialized(): 375 | dist.barrier() 376 | 377 | return videos[0] if self.rank == 0 else None 378 | -------------------------------------------------------------------------------- /wan/utils/multitalk_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from einops import rearrange 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from xfuser.core.distributed import ( 8 | get_sequence_parallel_rank, 9 | get_sequence_parallel_world_size, 10 | get_sp_group, 11 | ) 12 | from einops import rearrange, repeat 13 | from functools import lru_cache 14 | import imageio 15 | import uuid 16 | from tqdm import tqdm 17 | import numpy as np 18 | import subprocess 19 | import soundfile as sf 20 | import torchvision 21 | import binascii 22 | import os.path as osp 23 | from skimage import color 24 | 25 | VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv") 26 | ASPECT_RATIO_627 = { 27 | '0.26': ([320, 1216], 1), '0.38': ([384, 1024], 1), '0.50': ([448, 896], 1), '0.67': ([512, 768], 1), 28 | '0.82': ([576, 704], 1), '1.00': ([640, 640], 1), '1.22': ([704, 576], 1), '1.50': ([768, 512], 1), 29 | '1.86': ([832, 448], 1), '2.00': ([896, 448], 1), '2.50': ([960, 384], 1), '2.83': ([1088, 384], 1), 30 | '3.60': ([1152, 320], 1), '3.80': ([1216, 320], 1), '4.00': ([1280, 320], 1)} 31 | 32 | 33 | ASPECT_RATIO_960 = { 34 | '0.22': ([448, 2048], 1), '0.29': ([512, 1792], 1), '0.36': ([576, 1600], 1), '0.45': ([640, 1408], 1), 35 | '0.55': ([704, 1280], 1), '0.63': ([768, 1216], 1), '0.76': ([832, 1088], 1), '0.88': ([896, 1024], 1), 36 | '1.00': ([960, 960], 1), '1.14': ([1024, 896], 1), '1.31': ([1088, 832], 1), '1.50': ([1152, 768], 1), 37 | '1.58': ([1216, 768], 1), '1.82': ([1280, 704], 1), '1.91': ([1344, 704], 1), '2.20': ([1408, 640], 1), 38 | '2.30': ([1472, 640], 1), '2.67': ([1536, 576], 1), '2.89': ([1664, 576], 1), '3.62': ([1856, 512], 1), 39 | '3.75': ([1920, 512], 1)} 40 | 41 | 42 | 43 | def torch_gc(): 44 | torch.cuda.empty_cache() 45 | torch.cuda.ipc_collect() 46 
| 47 | 48 | 49 | def split_token_counts_and_frame_ids(T, token_frame, world_size, rank): 50 | 51 | S = T * token_frame 52 | split_sizes = [S // world_size + (1 if i < S % world_size else 0) for i in range(world_size)] 53 | start = sum(split_sizes[:rank]) 54 | end = start + split_sizes[rank] 55 | counts = [0] * T 56 | for idx in range(start, end): 57 | t = idx // token_frame 58 | counts[t] += 1 59 | 60 | counts_filtered = [] 61 | frame_ids = [] 62 | for t, c in enumerate(counts): 63 | if c > 0: 64 | counts_filtered.append(c) 65 | frame_ids.append(t) 66 | return counts_filtered, frame_ids 67 | 68 | 69 | def normalize_and_scale(column, source_range, target_range, epsilon=1e-8): 70 | 71 | source_min, source_max = source_range 72 | new_min, new_max = target_range 73 | 74 | normalized = (column - source_min) / (source_max - source_min + epsilon) 75 | scaled = normalized * (new_max - new_min) + new_min 76 | return scaled 77 | 78 | 79 | @torch.compile 80 | def calculate_x_ref_attn_map(visual_q, ref_k, ref_target_masks, mode='mean', attn_bias=None): 81 | 82 | ref_k = ref_k.to(visual_q.dtype).to(visual_q.device) 83 | scale = 1.0 / visual_q.shape[-1] ** 0.5 84 | visual_q = visual_q * scale 85 | visual_q = visual_q.transpose(1, 2) 86 | ref_k = ref_k.transpose(1, 2) 87 | attn = visual_q @ ref_k.transpose(-2, -1) 88 | 89 | if attn_bias is not None: 90 | attn = attn + attn_bias 91 | 92 | x_ref_attn_map_source = attn.softmax(-1) # B, H, x_seqlens, ref_seqlens 93 | 94 | 95 | x_ref_attn_maps = [] 96 | ref_target_masks = ref_target_masks.to(visual_q.dtype) 97 | x_ref_attn_map_source = x_ref_attn_map_source.to(visual_q.dtype) 98 | 99 | for class_idx, ref_target_mask in enumerate(ref_target_masks): 100 | torch_gc() 101 | ref_target_mask = ref_target_mask[None, None, None, ...] 102 | x_ref_attnmap = x_ref_attn_map_source * ref_target_mask 103 | x_ref_attnmap = x_ref_attnmap.sum(-1) / ref_target_mask.sum() # B, H, x_seqlens, ref_seqlens --> B, H, x_seqlens 104 | x_ref_attnmap = x_ref_attnmap.permute(0, 2, 1) # B, x_seqlens, H 105 | 106 | if mode == 'mean': 107 | x_ref_attnmap = x_ref_attnmap.mean(-1) # B, x_seqlens 108 | elif mode == 'max': 109 | x_ref_attnmap = x_ref_attnmap.max(-1)[0] # B, x_seqlens (keep the values; torch.max over a dim returns (values, indices)) 110 | 111 | x_ref_attn_maps.append(x_ref_attnmap) 112 | 113 | del attn 114 | del x_ref_attn_map_source 115 | torch_gc() 116 | 117 | return torch.concat(x_ref_attn_maps, dim=0) 118 | 119 | 120 | def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, split_num=2, enable_sp=False): 121 | """Args: 122 | visual_q (torch.tensor): B M H K 123 | ref_k (torch.tensor): B M H K 124 | shape (tuple): (N_t, N_h, N_w) 125 | ref_target_masks: [B, N_h * N_w] 126 | """ 127 | 128 | N_t, N_h, N_w = shape 129 | if enable_sp: 130 | ref_k = get_sp_group().all_gather(ref_k, dim=1) 131 | 132 | x_seqlens = N_h * N_w 133 | ref_k = ref_k[:, :x_seqlens] 134 | _, seq_lens, heads, _ = visual_q.shape 135 | class_num, _ = ref_target_masks.shape 136 | x_ref_attn_maps = torch.zeros(class_num, seq_lens).to(visual_q.device).to(visual_q.dtype) 137 | 138 | split_chunk = heads // split_num 139 | 140 | for i in range(split_num): 141 | x_ref_attn_maps_perhead = calculate_x_ref_attn_map(visual_q[:, :, i*split_chunk:(i+1)*split_chunk, :], ref_k[:, :, i*split_chunk:(i+1)*split_chunk, :], ref_target_masks) 142 | x_ref_attn_maps += x_ref_attn_maps_perhead 143 | 144 | return x_ref_attn_maps / split_num 145 | 146 | 147 | def rotate_half(x): 148 | x = rearrange(x, "... (d r) -> ... 
d r", r=2) 149 | x1, x2 = x.unbind(dim=-1) 150 | x = torch.stack((-x2, x1), dim=-1) 151 | return rearrange(x, "... d r -> ... (d r)") 152 | 153 | 154 | class RotaryPositionalEmbedding1D(nn.Module): 155 | 156 | def __init__(self, 157 | head_dim, 158 | ): 159 | super().__init__() 160 | self.head_dim = head_dim 161 | self.base = 10000 162 | 163 | 164 | @lru_cache(maxsize=32) 165 | def precompute_freqs_cis_1d(self, pos_indices): 166 | 167 | freqs = 1.0 / (self.base ** (torch.arange(0, self.head_dim, 2)[: (self.head_dim // 2)].float() / self.head_dim)) 168 | freqs = freqs.to(pos_indices.device) 169 | freqs = torch.einsum("..., f -> ... f", pos_indices.float(), freqs) 170 | freqs = repeat(freqs, "... n -> ... (n r)", r=2) 171 | return freqs 172 | 173 | def forward(self, x, pos_indices): 174 | """1D RoPE. 175 | 176 | Args: 177 | query (torch.tensor): [B, head, seq, head_dim] 178 | pos_indices (torch.tensor): [seq,] 179 | Returns: 180 | query with the same shape as input. 181 | """ 182 | freqs_cis = self.precompute_freqs_cis_1d(pos_indices) 183 | 184 | x_ = x.float() 185 | 186 | freqs_cis = freqs_cis.float().to(x.device) 187 | cos, sin = freqs_cis.cos(), freqs_cis.sin() 188 | cos, sin = rearrange(cos, 'n d -> 1 1 n d'), rearrange(sin, 'n d -> 1 1 n d') 189 | x_ = (x_ * cos) + (rotate_half(x_) * sin) 190 | 191 | return x_.type_as(x) 192 | 193 | 194 | 195 | def rand_name(length=8, suffix=''): 196 | name = binascii.b2a_hex(os.urandom(length)).decode('utf-8') 197 | if suffix: 198 | if not suffix.startswith('.'): 199 | suffix = '.' + suffix 200 | name += suffix 201 | return name 202 | 203 | def cache_video(tensor, 204 | save_file=None, 205 | fps=30, 206 | suffix='.mp4', 207 | nrow=8, 208 | normalize=True, 209 | value_range=(-1, 1), 210 | retry=5): 211 | 212 | # cache file 213 | cache_file = osp.join('/tmp', rand_name( 214 | suffix=suffix)) if save_file is None else save_file 215 | 216 | # save to cache 217 | error = None 218 | for _ in range(retry): 219 | 220 | # preprocess 221 | tensor = tensor.clamp(min(value_range), max(value_range)) 222 | tensor = torch.stack([ 223 | torchvision.utils.make_grid( 224 | u, nrow=nrow, normalize=normalize, value_range=value_range) 225 | for u in tensor.unbind(2) 226 | ], 227 | dim=1).permute(1, 2, 3, 0) 228 | tensor = (tensor * 255).type(torch.uint8).cpu() 229 | 230 | # write video 231 | writer = imageio.get_writer(cache_file, fps=fps, codec='libx264', quality=10, ffmpeg_params=["-crf", "10"]) 232 | for frame in tensor.numpy(): 233 | writer.append_data(frame) 234 | writer.close() 235 | return cache_file 236 | 237 | def save_video_ffmpeg(gen_video_samples, save_path, vocal_audio_list, fps=25, quality=5, high_quality_save=False): 238 | 239 | def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None): 240 | writer = imageio.get_writer( 241 | save_path, fps=fps, quality=quality, ffmpeg_params=ffmpeg_params 242 | ) 243 | for frame in tqdm(frames, desc="Saving video"): 244 | frame = np.array(frame) 245 | writer.append_data(frame) 246 | writer.close() 247 | save_path_tmp = save_path + "-temp.mp4" 248 | 249 | if high_quality_save: 250 | cache_video( 251 | tensor=gen_video_samples.unsqueeze(0), 252 | save_file=save_path_tmp, 253 | fps=fps, 254 | nrow=1, 255 | normalize=True, 256 | value_range=(-1, 1) 257 | ) 258 | else: 259 | video_audio = (gen_video_samples+1)/2 # C T H W 260 | video_audio = video_audio.permute(1, 2, 3, 0).cpu().numpy() 261 | video_audio = np.clip(video_audio * 255, 0, 255).astype(np.uint8) # to [0, 255] 262 | save_video(video_audio, save_path_tmp, 
fps=fps, quality=quality) 263 | 264 | 265 | # crop audio according to video length 266 | _, T, _, _ = gen_video_samples.shape 267 | duration = T / fps 268 | save_path_crop_audio = save_path + "-cropaudio.wav" 269 | final_command = [ 270 | "ffmpeg", 271 | "-i", 272 | vocal_audio_list[0], 273 | "-t", 274 | f'{duration}', 275 | save_path_crop_audio, 276 | ] 277 | subprocess.run(final_command, check=True) 278 | 279 | save_path = save_path + ".mp4" 280 | if high_quality_save: 281 | final_command = [ 282 | "ffmpeg", 283 | "-y", 284 | "-i", save_path_tmp, 285 | "-i", save_path_crop_audio, 286 | "-c:v", "libx264", 287 | "-crf", "0", 288 | "-preset", "veryslow", 289 | "-c:a", "aac", 290 | "-shortest", 291 | save_path, 292 | ] 293 | subprocess.run(final_command, check=True) 294 | os.remove(save_path_tmp) 295 | os.remove(save_path_crop_audio) 296 | else: 297 | final_command = [ 298 | "ffmpeg", 299 | "-y", 300 | "-i", 301 | save_path_tmp, 302 | "-i", 303 | save_path_crop_audio, 304 | "-c:v", 305 | "libx264", 306 | "-c:a", 307 | "aac", 308 | "-shortest", 309 | save_path, 310 | ] 311 | subprocess.run(final_command, check=True) 312 | os.remove(save_path_tmp) 313 | os.remove(save_path_crop_audio) 314 | 315 | 316 | class MomentumBuffer: 317 | def __init__(self, momentum: float): 318 | self.momentum = momentum 319 | self.running_average = 0 320 | 321 | def update(self, update_value: torch.Tensor): 322 | new_average = self.momentum * self.running_average 323 | self.running_average = update_value + new_average 324 | 325 | 326 | 327 | def project( 328 | v0: torch.Tensor, # [B, C, T, H, W] 329 | v1: torch.Tensor, # [B, C, T, H, W] 330 | ): 331 | dtype = v0.dtype 332 | v0, v1 = v0.double(), v1.double() 333 | v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3, -4]) 334 | v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3, -4], keepdim=True) * v1 335 | v0_orthogonal = v0 - v0_parallel 336 | return v0_parallel.to(dtype), v0_orthogonal.to(dtype) 337 | 338 | 339 | def adaptive_projected_guidance( 340 | diff: torch.Tensor, # [B, C, T, H, W] 341 | pred_cond: torch.Tensor, # [B, C, T, H, W] 342 | momentum_buffer: MomentumBuffer = None, 343 | eta: float = 0.0, 344 | norm_threshold: float = 55, 345 | ): 346 | if momentum_buffer is not None: 347 | momentum_buffer.update(diff) 348 | diff = momentum_buffer.running_average 349 | if norm_threshold > 0: 350 | ones = torch.ones_like(diff) 351 | diff_norm = diff.norm(p=2, dim=[-1, -2, -3, -4], keepdim=True) 352 | print(f"diff_norm: {diff_norm}") 353 | scale_factor = torch.minimum(ones, norm_threshold / diff_norm) 354 | diff = diff * scale_factor 355 | diff_parallel, diff_orthogonal = project(diff, pred_cond) 356 | normalized_update = diff_orthogonal + eta * diff_parallel 357 | return normalized_update 358 | 359 | 360 | 361 | def match_and_blend_colors(source_chunk: torch.Tensor, reference_image: torch.Tensor, strength: float) -> torch.Tensor: 362 | """ 363 | Matches the color of a source video chunk to a reference image and blends with the original. 364 | 365 | Args: 366 | source_chunk (torch.Tensor): The video chunk to be color-corrected (B, C, T, H, W) in range [-1, 1]. 367 | Assumes B=1 (batch size of 1). 368 | reference_image (torch.Tensor): The reference image (B, C, 1, H, W) in range [-1, 1]. 369 | Assumes B=1 and T=1 (single reference frame). 370 | strength (float): The strength of the color correction (0.0 to 1.0). 371 | 0.0 means no correction, 1.0 means full correction. 372 | 373 | Returns: 374 | torch.Tensor: The color-corrected and blended video chunk. 
375 | """ 376 | # print(f"[match_and_blend_colors] Input source_chunk shape: {source_chunk.shape}, reference_image shape: {reference_image.shape}, strength: {strength}") 377 | 378 | if strength == 0.0: 379 | # print(f"[match_and_blend_colors] Strength is 0, returning original source_chunk.") 380 | return source_chunk 381 | 382 | if not 0.0 <= strength <= 1.0: 383 | raise ValueError(f"Strength must be between 0.0 and 1.0, got {strength}") 384 | 385 | device = source_chunk.device 386 | dtype = source_chunk.dtype 387 | 388 | # Squeeze batch dimension, permute to T, H, W, C for skimage 389 | # Source: (1, C, T, H, W) -> (T, H, W, C) 390 | source_np = source_chunk.squeeze(0).permute(1, 2, 3, 0).cpu().numpy() 391 | # Reference: (1, C, 1, H, W) -> (H, W, C) 392 | ref_np = reference_image.squeeze(0).squeeze(1).permute(1, 2, 0).cpu().numpy() # Squeeze T dimension as well 393 | 394 | # Normalize from [-1, 1] to [0, 1] for skimage 395 | source_np_01 = (source_np + 1.0) / 2.0 396 | ref_np_01 = (ref_np + 1.0) / 2.0 397 | 398 | # Clip to ensure values are strictly in [0, 1] after potential float precision issues 399 | source_np_01 = np.clip(source_np_01, 0.0, 1.0) 400 | ref_np_01 = np.clip(ref_np_01, 0.0, 1.0) 401 | 402 | # Convert reference to Lab 403 | try: 404 | ref_lab = color.rgb2lab(ref_np_01) 405 | except ValueError as e: 406 | # Handle potential errors if image data is not valid for conversion 407 | print(f"Warning: Could not convert reference image to Lab: {e}. Skipping color correction for this chunk.") 408 | return source_chunk 409 | 410 | 411 | corrected_frames_np_01 = [] 412 | for i in range(source_np_01.shape[0]): # Iterate over time (T) 413 | source_frame_rgb_01 = source_np_01[i] 414 | 415 | try: 416 | source_lab = color.rgb2lab(source_frame_rgb_01) 417 | except ValueError as e: 418 | print(f"Warning: Could not convert source frame {i} to Lab: {e}. Using original frame.") 419 | corrected_frames_np_01.append(source_frame_rgb_01) 420 | continue 421 | 422 | corrected_lab_frame = source_lab.copy() 423 | 424 | # Perform color transfer for L, a, b channels 425 | for j in range(3): # L, a, b 426 | mean_src, std_src = source_lab[:, :, j].mean(), source_lab[:, :, j].std() 427 | mean_ref, std_ref = ref_lab[:, :, j].mean(), ref_lab[:, :, j].std() 428 | 429 | # Avoid division by zero if std_src is 0 430 | if std_src == 0: 431 | # If source channel has no variation, keep it as is, but shift by reference mean 432 | # This case is debatable, could also just copy source or target mean. 433 | # Shifting by target mean helps if source is flat but target isn't. 434 | corrected_lab_frame[:, :, j] = mean_ref 435 | else: 436 | corrected_lab_frame[:, :, j] = (corrected_lab_frame[:, :, j] - mean_src) * (std_ref / std_src) + mean_ref 437 | 438 | try: 439 | fully_corrected_frame_rgb_01 = color.lab2rgb(corrected_lab_frame) 440 | except ValueError as e: 441 | print(f"Warning: Could not convert corrected frame {i} back to RGB: {e}. 
Using original frame.") 442 | corrected_frames_np_01.append(source_frame_rgb_01) 443 | continue 444 | 445 | # Clip again after lab2rgb as it can go slightly out of [0,1] 446 | fully_corrected_frame_rgb_01 = np.clip(fully_corrected_frame_rgb_01, 0.0, 1.0) 447 | 448 | # Blend with original source frame (in [0,1] RGB) 449 | blended_frame_rgb_01 = (1 - strength) * source_frame_rgb_01 + strength * fully_corrected_frame_rgb_01 450 | corrected_frames_np_01.append(blended_frame_rgb_01) 451 | 452 | corrected_chunk_np_01 = np.stack(corrected_frames_np_01, axis=0) 453 | 454 | # Convert back to [-1, 1] 455 | corrected_chunk_np_minus1_1 = (corrected_chunk_np_01 * 2.0) - 1.0 456 | 457 | # Permute back to (C, T, H, W), add batch dim, and convert to original torch.Tensor type and device 458 | # (T, H, W, C) -> (C, T, H, W) 459 | corrected_chunk_tensor = torch.from_numpy(corrected_chunk_np_minus1_1).permute(3, 0, 1, 2).unsqueeze(0) 460 | corrected_chunk_tensor = corrected_chunk_tensor.contiguous() # Ensure contiguous memory layout 461 | output_tensor = corrected_chunk_tensor.to(device=device, dtype=dtype) 462 | # print(f"[match_and_blend_colors] Output tensor shape: {output_tensor.shape}") 463 | return output_tensor 464 | --------------------------------------------------------------------------------
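The MomentumBuffer and adaptive_projected_guidance helpers above can stand in for the plain classifier-free-guidance combination used in the sampling loop earlier (noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_uncond)). The sketch below is a minimal, self-contained illustration rather than code from this repository: the tensor shapes, guidance scale, and momentum value are assumptions, and the import assumes the repository and its requirements are installed and on PYTHONPATH.

import torch

from wan.utils.multitalk_utils import MomentumBuffer, adaptive_projected_guidance

guide_scale = 5.0                                 # illustrative guidance scale, not a repo default
momentum_buffer = MomentumBuffer(momentum=-0.75)  # illustrative momentum value

# Dummy conditional / unconditional predictions in the [B, C, T, H, W] layout
# that project() and adaptive_projected_guidance() operate on.
noise_pred_cond = torch.randn(1, 16, 5, 32, 32)
noise_pred_uncond = torch.randn(1, 16, 5, 32, 32)

diff = noise_pred_cond - noise_pred_uncond
update = adaptive_projected_guidance(
    diff, noise_pred_cond, momentum_buffer=momentum_buffer, norm_threshold=55)

# Plain CFG:        noise_pred_uncond + s * (noise_pred_cond - noise_pred_uncond)
# APG-style update: noise_pred_cond + (s - 1) * update
noise_pred = noise_pred_cond + (guide_scale - 1.0) * update
print(noise_pred.shape)  # torch.Size([1, 16, 5, 32, 32])

Similarly, a minimal sketch of calling match_and_blend_colors on dummy data; the chunk length, resolution, and strength value here are illustrative assumptions, not repository defaults.

import torch

from wan.utils.multitalk_utils import match_and_blend_colors

# Dummy video chunk (B=1, C=3, T=8, H=64, W=64) and a single reference frame,
# both in [-1, 1] as the function expects.
chunk = torch.rand(1, 3, 8, 64, 64) * 2 - 1
reference = torch.rand(1, 3, 1, 64, 64) * 2 - 1

corrected = match_and_blend_colors(chunk, reference, strength=0.5)
print(corrected.shape)  # same shape as the input chunk: torch.Size([1, 3, 8, 64, 64])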