├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE.txt
├── README.md
├── assets
│   ├── logo2.png
│   └── main_pipeline.jpg
├── generate_video.py
├── generate_video_df.py
├── requirements.txt
├── skycaptioner_v1
│   ├── README.md
│   ├── examples
│   │   ├── data
│   │   │   ├── 1.mp4
│   │   │   ├── 2.mp4
│   │   │   ├── 3.mp4
│   │   │   └── 4.mp4
│   │   ├── test.csv
│   │   └── test_result.csv
│   ├── infer_fusion_caption.sh
│   ├── infer_struct_caption.sh
│   ├── requirements.txt
│   └── scripts
│       ├── utils.py
│       ├── vllm_fusion_caption.py
│       └── vllm_struct_caption.py
└── skyreels_v2_infer
    ├── __init__.py
    ├── distributed
    │   ├── __init__.py
    │   └── xdit_context_parallel.py
    ├── modules
    │   ├── __init__.py
    │   ├── attention.py
    │   ├── clip.py
    │   ├── t5.py
    │   ├── tokenizers.py
    │   ├── transformer.py
    │   ├── vae.py
    │   └── xlm_roberta.py
    ├── pipelines
    │   ├── __init__.py
    │   ├── diffusion_forcing_pipeline.py
    │   ├── image2video_pipeline.py
    │   ├── prompt_enhancer.py
    │   └── text2video_pipeline.py
    └── scheduler
        ├── __init__.py
        └── fm_solvers_unipc.py
/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | checkpoint/* 3 | checkpoint 4 | results/* 5 | .DS_Store 6 | results/* 7 | *.png 8 | *.jpg 9 | *.mp4 10 | *.log* 11 | *.json 12 | scripts/transformer/* 13 | compile_cache 14 | scripts/.gradio/* 15 | *.pkl 16 | # *.csv 17 | *.jsonl 18 | out/* 19 | model/ -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/asottile/reorder-python-imports.git 3 | rev: v3.8.3 4 | hooks: 5 | - id: reorder-python-imports 6 | name: Reorder Python imports 7 | types: [file, python] 8 | - repo: https://github.com/psf/black.git 9 | rev: 22.8.0 10 | hooks: 11 | - id: black 12 | additional_dependencies: ['click==8.0.4'] 13 | args: [--line-length=120] 14 | types: [file, python] 15 | - repo: https://github.com/pre-commit/pre-commit-hooks.git 16 | rev: v4.3.0 17 | hooks: 18 | - id:
check-byte-order-marker 19 | types: [file, python] 20 | - id: trailing-whitespace 21 | types: [file, python] 22 | - id: end-of-file-fixer 23 | types: [file, python] 24 | 25 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | --- 2 | language: 3 | - en 4 | - zh 5 | license: other 6 | tasks: 7 | - text-generation 8 | 9 | --- 10 | 11 | 12 | 13 | 14 | # 声明与协议/Terms and Conditions 15 | 16 | ## 声明 17 | 18 | 我们在此声明,不要利用Skywork模型进行任何危害国家社会安全或违法的活动。另外,我们也要求使用者不要将 Skywork 模型用于未经适当安全审查和备案的互联网服务。我们希望所有的使用者都能遵守这个原则,确保科技的发展能在规范和合法的环境下进行。 19 | 20 | 我们已经尽我们所能,来确保模型训练过程中使用的数据的合规性。然而,尽管我们已经做出了巨大的努力,但由于模型和数据的复杂性,仍有可能存在一些无法预见的问题。因此,如果由于使用skywork开源模型而导致的任何问题,包括但不限于数据安全问题、公共舆论风险,或模型被误导、滥用、传播或不当利用所带来的任何风险和问题,我们将不承担任何责任。 21 | 22 | We hereby declare that the Skywork model should not be used for any activities that pose a threat to national or societal security or engage in unlawful actions. Additionally, we request users not to deploy the Skywork model for internet services without appropriate security reviews and records. We hope that all users will adhere to this principle to ensure that technological advancements occur in a regulated and lawful environment. 23 | 24 | We have done our utmost to ensure the compliance of the data used during the model's training process. However, despite our extensive efforts, due to the complexity of the model and data, there may still be unpredictable risks and issues. Therefore, if any problems arise as a result of using the Skywork open-source model, including but not limited to data security issues, public opinion risks, or any risks and problems arising from the model being misled, abused, disseminated, or improperly utilized, we will not assume any responsibility. 
25 | 26 | ## 协议 27 | 28 | 社区使用Skywork模型需要遵循[《Skywork 模型社区许可协议》](https://github.com/SkyworkAI/Skywork/blob/main/Skywork%20模型社区许可协议.pdf)。Skywork模型支持商业用途,如果您计划将Skywork模型或其衍生品用于商业目的,无需再次申请, 但请您仔细阅读[《Skywork 模型社区许可协议》](https://github.com/SkyworkAI/Skywork/blob/main/Skywork%20模型社区许可协议.pdf)并严格遵守相关条款。 29 | 30 | 31 | Community use of the Skywork model requires the [Skywork Community License](https://github.com/SkyworkAI/Skywork/blob/main/Skywork%20Community%20License.pdf). The Skywork model supports commercial use. If you plan to use the Skywork model or its derivatives for commercial purposes, you do not need to apply separately, but you must carefully read and abide by the terms and conditions of the [Skywork Community License](https://github.com/SkyworkAI/Skywork/blob/main/Skywork%20Community%20License.pdf). 32 | 33 | 34 | 35 | [《Skywork 模型社区许可协议》]: https://github.com/SkyworkAI/Skywork/blob/main/Skywork%20模型社区许可协议.pdf 36 | 37 | 38 | [skywork-opensource@kunlun-inc.com]: mailto:skywork-opensource@kunlun-inc.com 39 | -------------------------------------------------------------------------------- /assets/logo2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/SkyReels-V2/6ace9655735f34e4cb8cae8cf8e35289142ecda7/assets/logo2.png -------------------------------------------------------------------------------- /assets/main_pipeline.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/SkyReels-V2/6ace9655735f34e4cb8cae8cf8e35289142ecda7/assets/main_pipeline.jpg -------------------------------------------------------------------------------- /generate_video.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import gc 3 | import os 4 | import random 5 | import time 6 | 7 | import imageio 8 | import torch 9 | from diffusers.utils import load_image 10 | 11 | from skyreels_v2_infer.modules import download_model 12 | from
skyreels_v2_infer.pipelines import Image2VideoPipeline 13 | from skyreels_v2_infer.pipelines import PromptEnhancer 14 | from skyreels_v2_infer.pipelines import resizecrop 15 | from skyreels_v2_infer.pipelines import Text2VideoPipeline 16 | 17 | MODEL_ID_CONFIG = { 18 | "text2video": [ 19 | "Skywork/SkyReels-V2-T2V-14B-540P", 20 | "Skywork/SkyReels-V2-T2V-14B-720P", 21 | ], 22 | "image2video": [ 23 | "Skywork/SkyReels-V2-I2V-1.3B-540P", 24 | "Skywork/SkyReels-V2-I2V-14B-540P", 25 | "Skywork/SkyReels-V2-I2V-14B-720P", 26 | ], 27 | } 28 | 29 | 30 | if __name__ == "__main__": 31 | 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument("--outdir", type=str, default="video_out") 34 | parser.add_argument("--model_id", type=str, default="Skywork/SkyReels-V2-T2V-14B-540P") 35 | parser.add_argument("--resolution", type=str, choices=["540P", "720P"]) 36 | parser.add_argument("--num_frames", type=int, default=97) 37 | parser.add_argument("--image", type=str, default=None) 38 | parser.add_argument("--guidance_scale", type=float, default=6.0) 39 | parser.add_argument("--shift", type=float, default=8.0) 40 | parser.add_argument("--inference_steps", type=int, default=30) 41 | parser.add_argument("--use_usp", action="store_true") 42 | parser.add_argument("--offload", action="store_true") 43 | parser.add_argument("--fps", type=int, default=24) 44 | parser.add_argument("--seed", type=int, default=None) 45 | parser.add_argument( 46 | "--prompt", 47 | type=str, 48 | default="A serene lake surrounded by towering mountains, with a few swans gracefully gliding across the water and sunlight dancing on the surface.", 49 | ) 50 | parser.add_argument("--prompt_enhancer", action="store_true") 51 | parser.add_argument("--teacache", action="store_true") 52 | parser.add_argument( 53 | "--teacache_thresh", 54 | type=float, 55 | default=0.2, 56 | help="Higher thresholds give more speedup at the cost of quality -- 0.1 for 2.0x speedup -- 0.2 for 3.0x speedup") 57 | parser.add_argument( 58 |
"--use_ret_steps", 59 | action="store_true", 60 | help="Using Retention Steps will result in faster generation speed and better generation quality.") 61 | args = parser.parse_args() 62 | 63 | args.model_id = download_model(args.model_id) 64 | print("model_id:", args.model_id) 65 | 66 | assert (args.use_usp and args.seed is not None) or (not args.use_usp), "`--use_usp` requires `--seed` to be set" 67 | if args.seed is None: 68 | random.seed(time.time()) 69 | args.seed = int(random.randrange(4294967294)) 70 | 71 | if args.resolution == "540P": 72 | height = 544 73 | width = 960 74 | elif args.resolution == "720P": 75 | height = 720 76 | width = 1280 77 | else: 78 | raise ValueError(f"Invalid resolution: {args.resolution}") 79 | 80 | image = load_image(args.image).convert("RGB") if args.image else None 81 | negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" 82 | local_rank = 0 83 | if args.use_usp: 84 | assert not args.prompt_enhancer, "`--prompt_enhancer` is not allowed if using `--use_usp`. We recommend running the skyreels_v2_infer/pipelines/prompt_enhancer.py script first to generate an enhanced prompt before enabling the `--use_usp` parameter."
85 | from xfuser.core.distributed import initialize_model_parallel, init_distributed_environment 86 | import torch.distributed as dist 87 | 88 | dist.init_process_group("nccl") 89 | local_rank = dist.get_rank() 90 | torch.cuda.set_device(dist.get_rank()) 91 | device = "cuda" 92 | 93 | init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size()) 94 | 95 | initialize_model_parallel( 96 | sequence_parallel_degree=dist.get_world_size(), 97 | ring_degree=1, 98 | ulysses_degree=dist.get_world_size(), 99 | ) 100 | 101 | prompt_input = args.prompt 102 | if args.prompt_enhancer and args.image is None:  # prompt enhancement is applied only to text-to-video runs 103 | print(f"init prompt enhancer") 104 | prompt_enhancer = PromptEnhancer() 105 | prompt_input = prompt_enhancer(prompt_input) 106 | print(f"enhanced prompt: {prompt_input}") 107 | del prompt_enhancer 108 | gc.collect() 109 | torch.cuda.empty_cache() 110 | 111 | if image is None: 112 | assert "T2V" in args.model_id, f"check model_id:{args.model_id}" 113 | print("init text2video pipeline") 114 | pipe = Text2VideoPipeline( 115 | model_path=args.model_id, dit_path=args.model_id, use_usp=args.use_usp, offload=args.offload 116 | ) 117 | else: 118 | assert "I2V" in args.model_id, f"check model_id:{args.model_id}" 119 | print("init img2video pipeline") 120 | pipe = Image2VideoPipeline( 121 | model_path=args.model_id, dit_path=args.model_id, use_usp=args.use_usp, offload=args.offload 122 | ) 123 | args.image = load_image(args.image) 124 | image_width, image_height = args.image.size 125 | if image_height > image_width: 126 | height, width = width, height 127 | args.image = resizecrop(args.image, height, width) 128 | 129 | if args.teacache: 130 | pipe.transformer.initialize_teacache(enable_teacache=True, num_steps=args.inference_steps, 131 | teacache_thresh=args.teacache_thresh, use_ret_steps=args.use_ret_steps, 132 | ckpt_dir=args.model_id) 133 |
134 | kwargs = { 135 | "prompt": prompt_input, 136 | "negative_prompt": negative_prompt, 137 | "num_frames": args.num_frames, 138 | "num_inference_steps": args.inference_steps, 139 | "guidance_scale": args.guidance_scale, 140 | "shift": args.shift, 141 | "generator": torch.Generator(device="cuda").manual_seed(args.seed), 142 | "height": height, 143 | "width": width, 144 | } 145 | 146 | if image is not None: 147 | kwargs["image"] = args.image.convert("RGB") 148 | 149 | save_dir = os.path.join("result", args.outdir) 150 | os.makedirs(save_dir, exist_ok=True) 151 | 152 | with torch.cuda.amp.autocast(dtype=pipe.transformer.dtype), torch.no_grad(): 153 | print(f"infer kwargs:{kwargs}") 154 | video_frames = pipe(**kwargs)[0] 155 | 156 | if local_rank == 0: 157 | current_time = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime()) 158 | video_out_file = f"{args.prompt[:100].replace('/','')}_{args.seed}_{current_time}.mp4" 159 | output_path = os.path.join(save_dir, video_out_file) 160 | imageio.mimwrite(output_path, video_frames, fps=args.fps, quality=8, output_params=["-loglevel", "error"]) 161 | -------------------------------------------------------------------------------- /generate_video_df.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import gc 3 | import os 4 | import random 5 | import time 6 | 7 | import imageio 8 | import torch 9 | from diffusers.utils import load_image 10 | 11 | from skyreels_v2_infer import DiffusionForcingPipeline 12 | from skyreels_v2_infer.modules import download_model 13 | from skyreels_v2_infer.pipelines import PromptEnhancer 14 | 15 | if __name__ == "__main__": 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--outdir", type=str, default="diffusion_forcing") 19 | parser.add_argument("--model_id", type=str, default="Skywork/SkyReels-V2-DF-1.3B-540P") 20 |
parser.add_argument("--resolution", type=str, choices=["540P", "720P"]) 21 | parser.add_argument("--num_frames", type=int, default=97) 22 | parser.add_argument("--image", type=str, default=None) 23 | parser.add_argument("--ar_step", type=int, default=0) 24 | parser.add_argument("--causal_attention", action="store_true") 25 | parser.add_argument("--causal_block_size", type=int, default=1) 26 | parser.add_argument("--base_num_frames", type=int, default=97) 27 | parser.add_argument("--overlap_history", type=int, default=None) 28 | parser.add_argument("--addnoise_condition", type=int, default=0) 29 | parser.add_argument("--guidance_scale", type=float, default=6.0) 30 | parser.add_argument("--shift", type=float, default=8.0) 31 | parser.add_argument("--inference_steps", type=int, default=30) 32 | parser.add_argument("--use_usp", action="store_true") 33 | parser.add_argument("--offload", action="store_true") 34 | parser.add_argument("--fps", type=int, default=24) 35 | parser.add_argument("--seed", type=int, default=None) 36 | parser.add_argument( 37 | "--prompt", 38 | type=str, 39 | default="A woman in a leather jacket and sunglasses riding a vintage motorcycle through a desert highway at sunset, her hair blowing wildly in the wind as the motorcycle kicks up dust, with the golden sun casting long shadows across the barren landscape.", 40 | ) 41 | parser.add_argument("--prompt_enhancer", action="store_true") 42 | parser.add_argument("--teacache", action="store_true") 43 | parser.add_argument( 44 | "--teacache_thresh", 45 | type=float, 46 | default=0.2, 47 | help="Higher thresholds give more speedup at the cost of quality -- 0.1 for 2.0x speedup -- 0.2 for 3.0x speedup") 48 | parser.add_argument( 49 | "--use_ret_steps", 50 | action="store_true", 51 | help="Using Retention Steps will result in faster generation speed and better generation quality.") 52 | args = parser.parse_args() 53 | 54 | args.model_id = download_model(args.model_id) 55 | print("model_id:", args.model_id) 56 | 57 |
assert (args.use_usp and args.seed is not None) or (not args.use_usp), "`--use_usp` requires `--seed` to be set" 58 | if args.seed is None: 59 | random.seed(time.time()) 60 | args.seed = int(random.randrange(4294967294)) 61 | 62 | if args.resolution == "540P": 63 | height = 544 64 | width = 960 65 | elif args.resolution == "720P": 66 | height = 720 67 | width = 1280 68 | else: 69 | raise ValueError(f"Invalid resolution: {args.resolution}") 70 | 71 | num_frames = args.num_frames 72 | fps = args.fps 73 | 74 | if num_frames > args.base_num_frames: 75 | assert ( 76 | args.overlap_history is not None 77 | ), 'You must specify "overlap_history" for long video generation (num_frames > base_num_frames); 17 or 37 is recommended.' 78 | if args.addnoise_condition > 60: 79 | print( 80 | f'You have set "addnoise_condition" to {args.addnoise_condition}. Values this large can cause inconsistency in long video generation; 20 is recommended.' 81 | ) 82 | 83 | guidance_scale = args.guidance_scale 84 | shift = args.shift 85 | image = load_image(args.image).convert("RGB") if args.image else None 86 | negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"  # Chinese equivalent of the English negative prompt in generate_video.py 87 | 88 | save_dir = os.path.join("result", args.outdir) 89 | os.makedirs(save_dir, exist_ok=True) 90 | local_rank = 0 91 | if args.use_usp: 92 | assert not args.prompt_enhancer, "`--prompt_enhancer` is not allowed if using `--use_usp`. We recommend running the skyreels_v2_infer/pipelines/prompt_enhancer.py script first to generate an enhanced prompt before enabling the `--use_usp` parameter."
93 | from xfuser.core.distributed import initialize_model_parallel, init_distributed_environment 94 | import torch.distributed as dist 95 | 96 | dist.init_process_group("nccl") 97 | local_rank = dist.get_rank() 98 | torch.cuda.set_device(dist.get_rank()) 99 | device = "cuda" 100 | 101 | init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size()) 102 | 103 | initialize_model_parallel( 104 | sequence_parallel_degree=dist.get_world_size(), 105 | ring_degree=1, 106 | ulysses_degree=dist.get_world_size(), 107 | ) 108 | 109 | prompt_input = args.prompt 110 | if args.prompt_enhancer and args.image is None: 111 | print(f"init prompt enhancer") 112 | prompt_enhancer = PromptEnhancer() 113 | prompt_input = prompt_enhancer(prompt_input) 114 | print(f"enhanced prompt: {prompt_input}") 115 | del prompt_enhancer 116 | gc.collect() 117 | torch.cuda.empty_cache() 118 | 119 | pipe = DiffusionForcingPipeline( 120 | args.model_id, 121 | dit_path=args.model_id, 122 | device=torch.device("cuda"), 123 | weight_dtype=torch.bfloat16, 124 | use_usp=args.use_usp, 125 | offload=args.offload, 126 | ) 127 | 128 | if args.causal_attention: 129 | pipe.transformer.set_ar_attention(args.causal_block_size) 130 | 131 | if args.teacache: 132 | if args.ar_step > 0: 133 | num_steps = args.inference_steps + (((args.base_num_frames - 1)//4 + 1) // args.causal_block_size - 1) * args.ar_step 134 | print('num_steps:', num_steps) 135 | else: 136 | num_steps = args.inference_steps 137 | pipe.transformer.initialize_teacache(enable_teacache=True, num_steps=num_steps, 138 | teacache_thresh=args.teacache_thresh, use_ret_steps=args.use_ret_steps, 139 | ckpt_dir=args.model_id) 140 | 141 | print(f"prompt:{prompt_input}") 142 | print(f"guidance_scale:{guidance_scale}") 143 | 144 | with torch.cuda.amp.autocast(dtype=pipe.transformer.dtype), torch.no_grad(): 145 | video_frames = pipe( 146 | prompt=prompt_input, 147 | negative_prompt=negative_prompt, 148 | image=image, 149 | height=height, 150 | 
width=width, 151 | num_frames=num_frames, 152 | num_inference_steps=args.inference_steps, 153 | shift=shift, 154 | guidance_scale=guidance_scale, 155 | generator=torch.Generator(device="cuda").manual_seed(args.seed), 156 | overlap_history=args.overlap_history, 157 | addnoise_condition=args.addnoise_condition, 158 | base_num_frames=args.base_num_frames, 159 | ar_step=args.ar_step, 160 | causal_block_size=args.causal_block_size, 161 | fps=fps, 162 | )[0] 163 | 164 | if local_rank == 0: 165 | current_time = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime()) 166 | video_out_file = f"{args.prompt[:100].replace('/','')}_{args.seed}_{current_time}.mp4" 167 | output_path = os.path.join(save_dir, video_out_file) 168 | imageio.mimwrite(output_path, video_frames, fps=fps, quality=8, output_params=["-loglevel", "error"]) 169 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.5.1 2 | torchvision==0.20.1 3 | opencv-python==4.10.0.84 4 | diffusers>=0.31.0 5 | transformers==4.49.0 6 | tokenizers==0.21.1 7 | accelerate==1.6.0 8 | tqdm 9 | imageio 10 | easydict 11 | ftfy 12 | dashscope 13 | imageio-ffmpeg 14 | flash_attn 15 | numpy>=1.23.5,<2 16 | xfuser 17 | -------------------------------------------------------------------------------- /skycaptioner_v1/README.md: -------------------------------------------------------------------------------- 1 | # SkyCaptioner-V1: A Structural Video Captioning Model 2 | 3 |

4 | 📑 Technical Report · 👋 Playground · 💬 Discord · 🤗 Hugging Face · 🤖 ModelScope 5 |

6 | 7 | --- 8 | 9 | Welcome to the SkyCaptioner-V1 repository! Here you'll find the model weights and inference code for SkyCaptioner-V1, our structural video captioning model, which labels video data efficiently and comprehensively. 10 | 11 | ## 🔥🔥🔥 News!! 12 | 13 | * Apr 21, 2025: 👋 We released the [vLLM](https://github.com/vllm-project/vllm) batch inference code for the SkyCaptioner-V1 model, along with the caption fusion inference code. 14 | * Apr 21, 2025: 👋 We released [SkyCaptioner-V1](https://huggingface.co/Skywork/SkyCaptioner-V1), the first shot-aware video captioning model. For more details, please check our [paper](https://arxiv.org/pdf/2504.13074). 15 | 16 | ## 📑 TODO List 17 | 18 | - SkyCaptioner-V1 19 | 20 | - [x] Checkpoints 21 | - [x] Batch Inference Code 22 | - [x] Caption Fusion Method 23 | - [ ] Web Demo (Gradio) 24 | 25 | ## 🌟 Overview 26 | 27 | SkyCaptioner-V1 is a structural video captioning model designed to generate high-quality, structured descriptions of video data. It integrates specialized sub-expert models and multimodal large language models (MLLMs) with human annotations to address the limitations of general-purpose captioners in capturing professional film-related details. Key aspects include: 28 | 29 | 1. **Structural Representation**: Combines general video descriptions (from MLLMs) with sub-expert captioners (e.g., shot type, shot angle, shot position, camera motion) and human annotations. 30 | 2. **Knowledge Distillation**: Distills expertise from the sub-expert captioners into a unified model. 31 | 3. **Application Flexibility**: Generates dense captions for text-to-video (T2V) and concise prompts for image-to-video (I2V) tasks. 32 | 33 | ## 🔑 Key Features 34 | 35 | ### Structural Captioning Framework 36 | 37 | Our video captioning model captures multi-dimensional details: 38 | 39 | * **Subjects**: Appearance, action, expression, position, and hierarchical categorization.
40 | * **Shot Metadata**: Shot type (e.g., close-up, long shot), shot angle, shot position, camera motion, environment, lighting, etc. 41 | 42 | ### Sub-Expert Integration 43 | 44 | * **Shot Captioner**: Classifies shot type, angle, and position with high precision. 45 | * **Expression Captioner**: Analyzes facial expressions, emotion intensity, and temporal dynamics. 46 | * **Camera Motion Captioner**: Tracks 6DoF camera movements and composite motion types. 47 | 48 | ### Training Pipeline 49 | 50 | * Trained on \~2M high-quality, concept-balanced videos curated from 10M raw samples. 51 | * Fine-tuned from Qwen2.5-VL-7B-Instruct with a global batch size of 512 across 32 A800 GPUs. 52 | * Optimized with AdamW (learning rate: 1e-5) for 2 epochs. 53 | 54 | ### Dynamic Caption Fusion 55 | 56 | * Adapts output length to the application (T2V/I2V). 57 | * Employs an LLM to fuse the structural fields into a natural, fluent caption for downstream tasks. 58 | 59 | ## 📊 Benchmark Results 60 | 61 | SkyCaptioner-V1 demonstrates significant improvements over existing models in key film-specific captioning tasks, particularly in **shot-language understanding** and **domain-specific precision**. The differences stem from its structural architecture and expert-guided training: 62 | 63 | 1. **Superior Shot-Language Understanding**: 64 | * SkyCaptioner-V1 outperforms Qwen2.5-VL-72B-Instruct by 11.2 percentage points in shot type, 16.1 in shot angle, and 50.4 in shot position accuracy. The gain comes from its specialized shot classifiers, whereas generalist MLLMs lack film-domain fine-tuning. 65 | * +40.0 percentage points in camera motion accuracy over Tarsier2-recap-7B (85.3% vs. 45.3% in the table below): 66 | Its 6DoF motion analysis and active learning pipeline address ambiguities in composite motions (e.g., tracking + panning) that challenge generic captioners. 67 | 2. **High Domain-Specific Precision**: 68 | * Expression accuracy: 68.8% vs. 54.3% (Tarsier2-recap-7B), leveraging temporal-aware S2D frameworks to capture dynamic facial changes. 69 | 70 |
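The shot-language gains quoted above are absolute differences in percentage points between the SkyCaptioner-V1 and Qwen2.5-VL-72B-Ins. columns of the benchmark table, not relative improvements. A small Python check makes the distinction concrete; the accuracies are copied from the table, and the dictionary and function names are illustrative only.

```python
# Accuracies (%) copied from the benchmark table:
# (Qwen2.5-VL-72B-Ins., SkyCaptioner-V1)
SHOT_METRICS = {
    "shot type": (82.5, 93.7),
    "shot angle": (73.7, 89.8),
    "shot position": (32.7, 83.1),
}


def point_gains(metrics):
    """Absolute gain over the baseline, in percentage points (one decimal)."""
    return {name: round(ours - base, 1) for name, (base, ours) in metrics.items()}


def relative_gains(metrics):
    """Relative gain over the baseline, in percent (one decimal), for contrast."""
    return {name: round(100 * (ours - base) / base, 1) for name, (base, ours) in metrics.items()}


if __name__ == "__main__":
    for name, gain in point_gains(SHOT_METRICS).items():
        print(f"{name}: +{gain} percentage points")
```

For example, the 50.4-point gain in shot position corresponds to a relative improvement of roughly 154% over the 32.7% baseline.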

| Metric | Qwen2.5-VL-7B-Ins. | Qwen2.5-VL-72B-Ins. | Tarsier2-recap-7B | SkyCaptioner-V1 |
|---|---|---|---|---|
| Avg accuracy | 51.4% | 58.7% | 49.4% | 76.3% |
| shot type | 76.8% | 82.5% | 60.2% | 93.7% |
| shot angle | 60.0% | 73.7% | 52.4% | 89.8% |
| shot position | 28.4% | 32.7% | 23.6% | 83.1% |
| camera motion | 62.0% | 61.2% | 45.3% | 85.3% |
| expression | 43.6% | 51.5% | 54.3% | 68.8% |
| TYPES_type | 43.5% | 49.7% | 47.6% | 82.5% |
| TYPES_sub_type | 38.9% | 44.9% | 45.9% | 75.4% |
| appearance | 40.9% | 52.0% | 45.6% | 59.3% |
| action | 32.4% | 52.0% | 69.8% | 68.8% |
| position | 35.4% | 48.6% | 45.5% | 57.5% |
| is_main_subject | 58.5% | 68.7% | 69.7% | 80.9% |
| environment | 70.4% | 72.7% | 61.4% | 70.5% |
| lighting | 77.1% | 80.0% | 21.2% | 76.5% |
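Several metric names in this table (`TYPES_type`, `TYPES_sub_type`, `is_main_subject`, `environment`, `lighting`, and the shot fields) mirror keys of the JSON structural caption stored in the `structural_caption` column of `examples/test_result.csv`. Below is a minimal sketch, assuming that two-column layout (`path`, `structural_caption`), for loading such a file and pulling out the shot-level fields. The helper names are hypothetical and not part of the repository's scripts.

```python
import csv
import io
import json

# A trimmed-down row in the same shape as examples/test_result.csv:
# a `path` column plus a `structural_caption` column holding JSON.
SAMPLE_CSV = (
    "path,structural_caption\n"
    './examples/data/3.mp4,"{""subjects"": [{""TYPES"": {""type"": ""Animal"", ""sub_type"": ""Insect""}, '
    '""is_main_subject"": true}], ""shot_type"": ""extreme_close_up"", ""shot_angle"": ""eye_level"", '
    '""shot_position"": ""front_view"", ""camera_motion"": """"}"\n'
)


def load_structural_captions(csv_file):
    """Yield (video_path, caption_dict) pairs; the csv module undoes the doubled quotes."""
    for row in csv.DictReader(csv_file):
        yield row["path"], json.loads(row["structural_caption"])


def summarize(caption):
    """Collect the shot-level fields and the main subject's TYPES.type."""
    main = next((s for s in caption.get("subjects", []) if s.get("is_main_subject")), None)
    return {
        "shot_type": caption.get("shot_type", ""),
        "shot_angle": caption.get("shot_angle", ""),
        "shot_position": caption.get("shot_position", ""),
        "camera_motion": caption.get("camera_motion", ""),
        "main_subject_type": main["TYPES"]["type"] if main else None,
    }


if __name__ == "__main__":
    # For the real output file, pass open("./examples/test_result.csv", newline="") instead.
    for path, caption in load_structural_captions(io.StringIO(SAMPLE_CSV)):
        print(path, summarize(caption))
```

An empty string (such as `camera_motion` for a static shot in the example data) is passed through unchanged rather than treated as missing.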

183 | 184 | ## 📦 Model Downloads 185 | 186 | Our SkyCaptioner-V1 model can be downloaded from [SkyCaptioner-V1 Model](https://huggingface.co/Skywork/SkyCaptioner-V1). 187 | We use [Qwen2.5-32B-Instruct](https://huggingface.co/Qwen/Qwen2.5-32B-Instruct) as our caption fusion model to combine the structured caption fields into either dense or sparse final captions, depending on application requirements. 188 | 189 | ```shell 190 | # download SkyCaptioner-V1 191 | huggingface-cli download Skywork/SkyCaptioner-V1 --local-dir /path/to/your_local_model_path 192 | # download Qwen2.5-32B-Instruct 193 | huggingface-cli download Qwen/Qwen2.5-32B-Instruct --local-dir /path/to/your_local_model_path2 194 | ``` 195 | 196 | ## 🛠️ Running Guide 197 | 198 | Begin by cloning the repository and entering the `skycaptioner_v1` directory: 199 | 200 | ```shell 201 | git clone https://github.com/SkyworkAI/SkyReels-V2 202 | cd SkyReels-V2/skycaptioner_v1 203 | ``` 204 | 205 | ### Installation Guide for Linux 206 | 207 | We recommend Python 3.10 and CUDA 12.2 for the manual installation.
208 | 209 | ```shell 210 | pip install -r requirements.txt 211 | ``` 212 | 213 | ### Running Command 214 | 215 | #### Get Structural Captions with SkyCaptioner-V1 216 | 217 | ```shell 218 | export SkyCaptioner_V1_Model_PATH="/path/to/your_local_model_path" 219 | 220 | python scripts/vllm_struct_caption.py \ 221 | --model_path ${SkyCaptioner_V1_Model_PATH} \ 222 | --input_csv "./examples/test.csv" \ 223 | --out_csv "./examples/test_result.csv" \ 224 | --tp 1 \ 225 | --bs 4 226 | ``` 227 | 228 | #### T2V/I2V Caption Fusion with the Qwen2.5-32B-Instruct Model 229 | 230 | ```shell 231 | export LLM_MODEL_PATH="/path/to/your_local_model_path2" 232 | 233 | python scripts/vllm_fusion_caption.py \ 234 | --model_path ${LLM_MODEL_PATH} \ 235 | --input_csv "./examples/test_result.csv" \ 236 | --out_csv "./examples/test_result_caption.csv" \ 237 | --bs 4 \ 238 | --tp 1 \ 239 | --task t2v 240 | ``` 241 | > **Note**: 242 | > - To generate I2V captions instead, change `--task t2v` to `--task i2v` in the command above. 243 | 244 | ## Acknowledgements 245 | 246 | We would like to thank the contributors of the Qwen2.5-VL, Tarsier2, and vLLM repositories for their open research and contributions.
247 | 248 | ## Citation 249 | 250 | ```bibtex 251 | @misc{chen2025skyreelsv2infinitelengthfilmgenerative, 252 | author = {Guibin Chen and Dixuan Lin and Jiangping Yang and Chunze Lin and Junchen Zhu and Mingyuan Fan and Hao Zhang and Sheng Chen and Zheng Chen and Chengcheng Ma and Weiming Xiong and Wei Wang and Nuo Pang and Kang Kang and Zhiheng Xu and Yuzhe Jin and Yupeng Liang and Yubing Song and Peng Zhao and Boyuan Xu and Di Qiu and Debang Li and Zhengcong Fei and Yang Li and Yahui Zhou}, 253 | title = {Skyreels V2: Infinite-Length Film Generative Model}, 254 | year = {2025}, 255 | eprint={2504.13074}, 256 | archivePrefix={arXiv}, 257 | primaryClass={cs.CV}, 258 | url={https://arxiv.org/abs/2504.13074} 259 | } 260 | ``` 261 | 262 | 263 | -------------------------------------------------------------------------------- /skycaptioner_v1/examples/data/1.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/SkyReels-V2/6ace9655735f34e4cb8cae8cf8e35289142ecda7/skycaptioner_v1/examples/data/1.mp4 -------------------------------------------------------------------------------- /skycaptioner_v1/examples/data/2.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/SkyReels-V2/6ace9655735f34e4cb8cae8cf8e35289142ecda7/skycaptioner_v1/examples/data/2.mp4 -------------------------------------------------------------------------------- /skycaptioner_v1/examples/data/3.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/SkyReels-V2/6ace9655735f34e4cb8cae8cf8e35289142ecda7/skycaptioner_v1/examples/data/3.mp4 -------------------------------------------------------------------------------- /skycaptioner_v1/examples/data/4.mp4: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/SkyworkAI/SkyReels-V2/6ace9655735f34e4cb8cae8cf8e35289142ecda7/skycaptioner_v1/examples/data/4.mp4 -------------------------------------------------------------------------------- /skycaptioner_v1/examples/test.csv: -------------------------------------------------------------------------------- 1 | path 2 | ./examples/data/1.mp4 3 | ./examples/data/2.mp4 4 | ./examples/data/3.mp4 5 | ./examples/data/4.mp4 -------------------------------------------------------------------------------- /skycaptioner_v1/examples/test_result.csv: -------------------------------------------------------------------------------- 1 | path,structural_caption 2 | ./examples/data/1.mp4,"{""subjects"": [{""TYPES"": {""type"": ""Sport"", ""sub_type"": ""Other""}, ""appearance"": ""Wearing winter sports gear, including a helmet and goggles."", ""action"": ""The video shows a snowy mountain landscape with a ski slope surrounded by dense trees and distant mountains. A skier is seen descending the slope, moving from the top left to the bottom right of the frame. The skier maintains a steady pace, navigating the curves of the slope. The background includes a ski lift with chairs moving along the cables, and the slope is marked with red and white lines indicating the path for skiers. 
The skier continues to descend, gradually getting closer to the bottom of the slope."", ""expression"": """", ""position"": ""Centered in the frame, moving downwards on the slope."", ""is_main_subject"": true}, {""TYPES"": {""type"": ""Scenery"", ""sub_type"": ""Mountain""}, ""appearance"": ""White, covering the ground and trees."", ""action"": """", ""expression"": """", ""position"": ""Surrounding the skier, covering the entire visible area."", ""is_main_subject"": false}, {""TYPES"": {""type"": ""Plant"", ""sub_type"": ""Tree""}, ""appearance"": ""Tall, evergreen, covered in snow."", ""action"": """", ""expression"": """", ""position"": ""Scattered throughout the scene, both in the foreground and background."", ""is_main_subject"": false}, {""TYPES"": {""type"": ""Scenery"", ""sub_type"": ""Mountain""}, ""appearance"": ""Snow-covered, with ski lifts and structures."", ""action"": """", ""expression"": """", ""position"": ""In the background, providing context to the location."", ""is_main_subject"": false}], ""shot_type"": ""long_shot"", ""shot_angle"": ""high_angle"", ""shot_position"": ""overlooking_view"", ""camera_motion"": ""the camera moves toward zooms in"", ""environment"": ""A snowy mountain landscape with a ski slope, trees, and a ski resort in the background."", ""lighting"": ""Bright daylight, casting shadows and highlighting the snow's texture.""}" 3 | ./examples/data/2.mp4,"{""subjects"": [{""TYPES"": {""type"": ""Human"", ""sub_type"": ""Woman""}, ""appearance"": ""Long, straight black hair, wearing a sparkling choker necklace with a diamond-like texture, light-colored top, subtle makeup with pink lipstick, stud earrings."", ""action"": ""A woman wearing a sparkling choker necklace and earrings is sitting in a car, looking to her left and speaking. A man, dressed in a suit, is sitting next to her, attentively watching her. 
The background outside the car is green, indicating a possible outdoor setting."", ""expression"": ""The individual in the video exhibits a neutral facial expression, characterized by slightly open lips and a gentle, soft-focus gaze. There are no noticeable signs of sadness or distress evident in their demeanor."", ""position"": ""Seated in the foreground of the car, facing slightly to the right."", ""is_main_subject"": true}, {""TYPES"": {""type"": ""Human"", ""sub_type"": ""Man""}, ""appearance"": ""Short hair, wearing a dark-colored suit with a white shirt."", ""action"": """", ""expression"": """", ""position"": ""Seated in the background of the car, facing the woman."", ""is_main_subject"": false}], ""shot_type"": ""close_up"", ""shot_angle"": ""eye_level"", ""shot_position"": ""side_view"", ""camera_motion"": """", ""environment"": ""Interior of a car with dark upholstery."", ""lighting"": ""Soft and natural lighting, suggesting daytime.""}" 4 | ./examples/data/3.mp4,"{""subjects"": [{""TYPES"": {""type"": ""Animal"", ""sub_type"": ""Insect""}, ""appearance"": ""The spider has a spherical, yellowish-green body with darker green stripes and spots. It has eight slender legs with visible joints and segments."", ""action"": ""A spider with a yellow and green body and black and white striped legs is hanging from its web in a natural setting with a blurred background of green and brown hues. 
The spider remains mostly still, with slight movements in its legs and body, indicating a gentle swaying motion."", ""expression"": """", ""position"": ""The spider is centrally positioned in the frame, hanging from a web."", ""is_main_subject"": true}], ""shot_type"": ""extreme_close_up"", ""shot_angle"": ""eye_level"", ""shot_position"": ""front_view"", ""camera_motion"": """", ""environment"": ""The background consists of vertical, out-of-focus lines in shades of green and brown, suggesting a natural environment with vegetation."", ""lighting"": ""The lighting is soft and diffused, with no harsh shadows, indicating an overcast sky or a shaded area.""}" 5 | ./examples/data/4.mp4,"{""subjects"": [{""TYPES"": {""type"": ""Sport"", ""sub_type"": ""Football""}, ""appearance"": ""Wearing a dark-colored jersey, black shorts, and bright blue soccer shoes with white soles."", ""action"": ""A man is on a grassy field with orange cones placed in a line. He is wearing a gray shirt, black shorts, and black socks with blue shoes. The man starts by standing still with his feet apart, then begins to move forward while keeping his eyes on the cones. He continues to run forward, maintaining his focus on the cones, and his feet move in a coordinated manner to navigate around them. 
The background shows a clear sky, some trees, and a few buildings in the distance."", ""expression"": """", ""position"": ""Centered in the frame, with the soccer ball positioned between the cones."", ""is_main_subject"": true}, {""TYPES"": {""type"": ""Sport"", ""sub_type"": ""Other""}, ""appearance"": ""Bright orange, conical shape."", ""action"": """", ""expression"": """", ""position"": ""Placed on the grass, creating a path for the soccer player."", ""is_main_subject"": false}], ""shot_type"": ""full_shot"", ""shot_angle"": ""low_angle"", ""shot_position"": ""front_view"", ""camera_motion"": ""use a tracking shot, the camera moves toward zooms in"", ""environment"": ""Outdoor sports field with well-maintained grass, trees, and a clear blue sky."", ""lighting"": ""Bright and natural, suggesting a sunny day.""}" -------------------------------------------------------------------------------- /skycaptioner_v1/infer_fusion_caption.sh: -------------------------------------------------------------------------------- 1 | export LLM_MODEL_PATH="/path/to/your_local_model_path2" 2 | 3 | python scripts/vllm_fusion_caption.py \ 4 | --model_path ${LLM_MODEL_PATH} \ 5 | --input_csv "./examples/test_result.csv" \ 6 | --out_csv "./examples/test_result_caption.csv" \ 7 | --bs 4 \ 8 | --tp 1 \ 9 | --task t2v 10 | -------------------------------------------------------------------------------- /skycaptioner_v1/infer_struct_caption.sh: -------------------------------------------------------------------------------- 1 | export SkyCaptioner_V1_Model_PATH="/path/to/your_local_model_path" 2 | 3 | python scripts/vllm_struct_caption.py \ 4 | --model_path ${SkyCaptioner_V1_Model_PATH} \ 5 | --input_csv "./examples/test.csv" \ 6 | --out_csv "./examples/test_result.csv" \ 7 | --tp 1 \ 8 | --bs 32 -------------------------------------------------------------------------------- /skycaptioner_v1/requirements.txt: 
--------------------------------------------------------------------------------
1 | decord==0.6.0 2 | transformers>=4.49.0 3 | vllm==0.8.4 -------------------------------------------------------------------------------- /skycaptioner_v1/scripts/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | 4 | def result_writer(indices_list: list, result_list: list, meta: pd.DataFrame, column): 5 | flat_indices = [] 6 | for x in zip(indices_list): 7 | flat_indices.extend(x) 8 | flat_results = [] 9 | for x in zip(result_list): 10 | flat_results.extend(x) 11 | 12 | flat_indices = np.array(flat_indices) 13 | flat_results = np.array(flat_results) 14 | 15 | unique_indices, unique_indices_idx = np.unique(flat_indices, return_index=True) 16 | meta.loc[unique_indices, column[0]] = flat_results[unique_indices_idx] 17 | 18 | meta = meta.loc[unique_indices] 19 | return meta -------------------------------------------------------------------------------- /skycaptioner_v1/scripts/vllm_fusion_caption.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | import argparse 4 | import glob 5 | import time 6 | import gc 7 | from tqdm import tqdm 8 | import torch 9 | from transformers import AutoTokenizer 10 | import pandas as pd 11 | from vllm import LLM, SamplingParams 12 | from torch.utils.data import DataLoader 13 | import json 14 | import random 15 | from utils import result_writer 16 | 17 | SYSTEM_PROMPT_I2V = """ 18 | You are an expert in video captioning. You are given a structured video caption and you need to compose it to be more natural and fluent in English. 19 | 20 | ## Structured Input 21 | {structured_input} 22 | 23 | ## Notes 24 | 1. If there has an empty field, just ignore it and do not mention it in the output. 25 | 2. Do not make any semantic changes to the original fields. Please be sure to follow the original meaning. 26 | 3. 
If the action field is not empty, eliminate the irrelevant information in the action field that is not related to the timing action(such as wearings, background and environment information) to make a pure action field. 27 | 28 | ## Output Principles and Orders 29 | 1. First, eliminate the static information in the action field that is not related to the timing action, such as background or environment information. 30 | 2. Second, describe each subject with its pure action and expression if these fields exist. 31 | 32 | ## Output 33 | Please directly output the final composed caption without any additional information. 34 | """ 35 | 36 | SYSTEM_PROMPT_T2V = """ 37 | You are an expert in video captioning. You are given a structured video caption and you need to compose it to be more natural and fluent in English. 38 | 39 | ## Structured Input 40 | {structured_input} 41 | 42 | ## Notes 43 | 1. According to the action field information, change its name field to the subject pronoun in the action. 44 | 2. If there has an empty field, just ignore it and do not mention it in the output. 45 | 3. Do not make any semantic changes to the original fields. Please be sure to follow the original meaning. 46 | 47 | ## Output Principles and Orders 48 | 1. First, declare the shot_type, then declare the shot_angle and the shot_position fields. 49 | 2. Second, eliminate information in the action field that is not related to the timing action, such as background or environment information if action is not empty. 50 | 3. Third, describe each subject with its pure action, appearance, expression, position if these fields exist. 51 | 4. Finally, declare the environment and lighting if the environment and lighting fields are not empty. 52 | 53 | ## Output 54 | Please directly output the final composed caption without any additional information. 
55 | """ 56 | 57 | SHOT_TYPE_LIST = [ 58 | 'close-up shot', 59 | 'extreme close-up shot', 60 | 'medium shot', 61 | 'long shot', 62 | 'full shot', 63 | ] 64 | 65 | 66 | class StructuralCaptionDataset(torch.utils.data.Dataset): 67 | def __init__(self, input_csv, model_path): 68 | self.meta = pd.read_csv(input_csv) 69 | self.task = args.task 70 | self.system_prompt = SYSTEM_PROMPT_T2V if self.task == 't2v' else SYSTEM_PROMPT_I2V 71 | self.tokenizer = AutoTokenizer.from_pretrained(model_path) 72 | 73 | 74 | def __len__(self): 75 | return len(self.meta) 76 | 77 | def __getitem__(self, index): 78 | row = self.meta.iloc[index] 79 | real_index = self.meta.index[index] 80 | 81 | struct_caption = json.loads(row["structural_caption"]) 82 | 83 | camera_movement = struct_caption.get('camera_motion', '') 84 | if camera_movement != '': 85 | camera_movement += '.' 86 | camera_movement = camera_movement.capitalize() 87 | 88 | fusion_by_llm = False 89 | cleaned_struct_caption = self.clean_struct_caption(struct_caption, self.task) 90 | if cleaned_struct_caption.get('num_subjects', 0) > 0: 91 | new_struct_caption = json.dumps(cleaned_struct_caption, indent=4, ensure_ascii=False) 92 | conversation = [ 93 | { 94 | "role": "system", 95 | "content": self.system_prompt.format(structured_input=new_struct_caption), 96 | }, 97 | ] 98 | text = self.tokenizer.apply_chat_template( 99 | conversation, 100 | tokenize=False, 101 | add_generation_prompt=True 102 | ) 103 | fusion_by_llm = True 104 | else: 105 | text = '-' 106 | return real_index, fusion_by_llm, text, '-', camera_movement 107 | 108 | def clean_struct_caption(self, struct_caption, task): 109 | raw_subjects = struct_caption.get('subjects', []) 110 | subjects = [] 111 | for subject in raw_subjects: 112 | subject_type = subject.get("TYPES", {}).get('type', '') 113 | subject_sub_type = subject.get("TYPES", {}).get('sub_type', '') 114 | if subject_type not in ["Human", "Animal"]: 115 | subject['expression'] = '' 116 | if subject_type == 
'Human' and subject_sub_type == 'Accessory': 117 | subject['expression'] = '' 118 | if subject_sub_type != '': 119 | subject['name'] = subject_sub_type 120 | if 'TYPES' in subject: 121 | del subject['TYPES'] 122 | if 'is_main_subject' in subject: 123 | del subject['is_main_subject'] 124 | subjects.append(subject) 125 | 126 | to_del_subject_ids = [] 127 | for idx, subject in enumerate(subjects): 128 | action = subject.get('action', '').strip() 129 | subject['action'] = action 130 | if random.random() > 0.9 and 'appearance' in subject: 131 | del subject['appearance'] 132 | if random.random() > 0.9 and 'position' in subject: 133 | del subject['position'] 134 | if task == 'i2v': 135 | # just keep name and action, expression in subjects 136 | dropped_keys = ['appearance', 'position'] 137 | for key in dropped_keys: 138 | if key in subject: 139 | del subject[key] 140 | if subject['action'] == '' and ('expression' not in subject or subject['expression'] == ''): 141 | to_del_subject_ids.append(idx) 142 | 143 | # delete the subjects according to the to_del_subject_ids 144 | for idx in sorted(to_del_subject_ids, reverse=True): 145 | del subjects[idx] 146 | 147 | 148 | shot_type = struct_caption.get('shot_type', '').replace('_', ' ') 149 | if shot_type not in SHOT_TYPE_LIST: 150 | struct_caption['shot_type'] = '' 151 | 152 | new_struct_caption = { 153 | 'num_subjects': len(subjects), 154 | 'subjects': subjects, 155 | 'shot_type': struct_caption.get('shot_type', ''), 156 | 'shot_angle': struct_caption.get('shot_angle', ''), 157 | 'shot_position': struct_caption.get('shot_position', ''), 158 | 'environment': struct_caption.get('environment', ''), 159 | 'lighting': struct_caption.get('lighting', ''), 160 | } 161 | 162 | if task == 't2v' and random.random() > 0.9: 163 | del new_struct_caption['lighting'] 164 | 165 | if task == 'i2v': 166 | drop_keys = ['environment', 'lighting', 'shot_type', 'shot_angle', 'shot_position'] 167 | for drop_key in drop_keys: 168 | del 
new_struct_caption[drop_key] 169 | return new_struct_caption 170 | 171 | def custom_collate_fn(batch): 172 | real_indices, fusion_by_llm, texts, original_texts, camera_movements = zip(*batch) 173 | return list(real_indices), list(fusion_by_llm), list(texts), list(original_texts), list(camera_movements) 174 | 175 | 176 | if __name__ == "__main__": 177 | parser = argparse.ArgumentParser(description="Caption Fusion by LLM") 178 | parser.add_argument("--input_csv", default="./examples/test_result.csv") 179 | parser.add_argument("--out_csv", default="./examples/test_result_caption.csv") 180 | parser.add_argument("--bs", type=int, default=4) 181 | parser.add_argument("--tp", type=int, default=1) 182 | parser.add_argument("--model_path", required=True, type=str, help="LLM model path") 183 | parser.add_argument("--task", default='t2v', help="t2v or i2v") 184 | 185 | args = parser.parse_args() 186 | 187 | sampling_params = SamplingParams( 188 | temperature=0.1, 189 | max_tokens=512, 190 | stop=['\n\n'] 191 | ) 192 | # model_path = "/maindata/data/shared/public/Common-Models/Qwen2.5-32B-Instruct/" 193 | 194 | 195 | llm = LLM( 196 | model=args.model_path, 197 | gpu_memory_utilization=0.9, 198 | max_model_len=4096, 199 | tensor_parallel_size = args.tp 200 | ) 201 | 202 | 203 | dataset = StructuralCaptionDataset(input_csv=args.input_csv, model_path=args.model_path) 204 | 205 | dataloader = DataLoader( 206 | dataset, 207 | batch_size=args.bs, 208 | num_workers=8, 209 | collate_fn=custom_collate_fn, 210 | shuffle=False, 211 | drop_last=False, 212 | ) 213 | 214 | indices_list = [] 215 | result_list = [] 216 | for indices, fusion_by_llms, texts, original_texts, camera_movements in tqdm(dataloader): 217 | llm_indices, llm_texts, llm_original_texts, llm_camera_movements = [], [], [], [] 218 | for idx, fusion_by_llm, text, original_text, camera_movement in zip(indices, fusion_by_llms, texts, original_texts, camera_movements): 219 | if fusion_by_llm: 220 | llm_indices.append(idx) 221 | 
llm_texts.append(text) 222 | llm_original_texts.append(original_text) 223 | llm_camera_movements.append(camera_movement) 224 | else: 225 | indices_list.append(idx) 226 | caption = original_text + " " + camera_movement 227 | result_list.append(caption) 228 | if len(llm_texts) > 0: 229 | try: 230 | outputs = llm.generate(llm_texts, sampling_params, use_tqdm=False) 231 | results = [] 232 | for output in outputs: 233 | result = output.outputs[0].text.strip() 234 | results.append(result) 235 | indices_list.extend(llm_indices) 236 | except Exception as e: 237 | print(f"Error at {llm_indices}: {str(e)}") 238 | indices_list.extend(llm_indices) 239 | results = llm_original_texts 240 | 241 | for result, camera_movement in zip(results, llm_camera_movements): 242 | # concat camera movement to fusion_caption 243 | llm_caption = result + " " + camera_movement 244 | result_list.append(llm_caption) 245 | torch.cuda.empty_cache() 246 | gc.collect() 247 | gathered_list = [indices_list, result_list] 248 | meta_new = result_writer(indices_list, result_list, dataset.meta, column=[f"{args.task}_fusion_caption"]) 249 | meta_new.to_csv(args.out_csv, index=False) 250 | 251 | -------------------------------------------------------------------------------- /skycaptioner_v1/scripts/vllm_struct_caption.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import decord 4 | import argparse 5 | 6 | import pandas as pd 7 | import numpy as np 8 | 9 | from tqdm import tqdm 10 | from vllm import LLM, SamplingParams 11 | from transformers import AutoTokenizer, AutoProcessor 12 | 13 | from torch.utils.data import DataLoader 14 | 15 | SYSTEM_PROMPT = "I need you to generate a structured and detailed caption for the provided video. 
The structured output and the requirements for each field are as shown in the following JSON content: {\"subjects\": [{\"appearance\": \"Main subject appearance description\", \"action\": \"Main subject action\", \"expression\": \"Main subject expression (Only for human/animal categories, empty otherwise)\", \"position\": \"Subject position in the video (Can be relative position to other objects or spatial description)\", \"TYPES\": {\"type\": \"Main category (e.g., Human)\", \"sub_type\": \"Sub-category (e.g., Man)\"}, \"is_main_subject\": true}, {\"appearance\": \"Non-main subject appearance description\", \"action\": \"Non-main subject action\", \"expression\": \"Non-main subject expression (Only for human/animal categories, empty otherwise)\", \"position\": \"Position of non-main subject 1\", \"TYPES\": {\"type\": \"Main category (e.g., Vehicles)\", \"sub_type\": \"Sub-category (e.g., Ship)\"}, \"is_main_subject\": false}], \"shot_type\": \"Shot type(Options: long_shot/full_shot/medium_shot/close_up/extreme_close_up/other)\", \"shot_angle\": \"Camera angle(Options: eye_level/high_angle/low_angle/other)\", \"shot_position\": \"Camera position(Options: front_view/back_view/side_view/over_the_shoulder/overhead_view/point_of_view/aerial_view/overlooking_view/other)\", \"camera_motion\": \"Camera movement description\", \"environment\": \"Video background/environment description\", \"lighting\": \"Lighting information in the video\"}" 16 | 17 | 18 | class VideoTextDataset(torch.utils.data.Dataset): 19 | def __init__(self, csv_path, model_path): 20 | self.meta = pd.read_csv(csv_path) 21 | self._path = 'path' 22 | self.tokenizer = AutoTokenizer.from_pretrained(model_path) 23 | self.processor = AutoProcessor.from_pretrained(model_path) 24 | 25 | def __getitem__(self, index): 26 | row = self.meta.iloc[index] 27 | path = row[self._path] 28 | real_index = self.meta.index[index] 29 | vr = decord.VideoReader(path, ctx=decord.cpu(0), width=360, height=420) 30 | start = 0 31 
| end = len(vr) 32 | # avg_fps = vr.get_avg_fps() 33 | index = self.get_index(end-start, 16, st=start) 34 | frames = vr.get_batch(index).asnumpy() # n h w c 35 | video_inputs = [torch.from_numpy(frames).permute(0, 3, 1, 2)] 36 | conversation = { 37 | "role": "user", 38 | "content": [ 39 | { 40 | "type": "video", 41 | "video": row['path'], 42 | "max_pixels": 360 * 420, # 151200 43 | "fps": 2.0, 44 | }, 45 | { 46 | "type": "text", 47 | "text": SYSTEM_PROMPT 48 | }, 49 | ], 50 | } 51 | 52 | # Build user_input 53 | user_input = self.processor.apply_chat_template( 54 | [conversation], 55 | tokenize=False, 56 | add_generation_prompt=True 57 | ) 58 | results = dict() 59 | inputs = { 60 | 'prompt': user_input, 61 | 'multi_modal_data': {'video': video_inputs} 62 | } 63 | results["index"] = real_index 64 | results['input'] = inputs 65 | return results 66 | 67 | def __len__(self): 68 | return len(self.meta) 69 | 70 | def get_index(self, video_size, num_frames, st=0): 71 | seg_size = max(0., float(video_size - 1) / num_frames) 72 | max_frame = int(video_size) - 1 73 | seq = [] 74 | # offset each sampled frame index by the start frame st 75 | for i in range(num_frames): 76 | start = int(np.round(seg_size * i)) 77 | # end = int(np.round(seg_size * (i + 1))) 78 | idx = min(start, max_frame) 79 | seq.append(idx+st) 80 | return seq 81 | 82 | def result_writer(indices_list: list, result_list: list, meta: pd.DataFrame, column): 83 | flat_indices = [] 84 | for x in zip(indices_list): 85 | flat_indices.extend(x) 86 | flat_results = [] 87 | for x in zip(result_list): 88 | flat_results.extend(x) 89 | 90 | flat_indices = np.array(flat_indices) 91 | flat_results = np.array(flat_results) 92 | 93 | unique_indices, unique_indices_idx = np.unique(flat_indices, return_index=True) 94 | meta.loc[unique_indices, column[0]] = flat_results[unique_indices_idx] 95 | 96 | meta = meta.loc[unique_indices] 97 | return meta 98 | 99 | 100 | def 
worker_init_fn(worker_id): 101 | # Set different seed for each worker 102 | worker_seed =
torch.initial_seed() % 2**32 103 | np.random.seed(worker_seed) 104 | # Limit each worker to a single intra-op thread to avoid CPU oversubscription 105 | torch.set_num_threads(1) 106 | 107 | def main(): 108 | parser = argparse.ArgumentParser(description="SkyCaptioner-V1 vllm batch inference") 109 | parser.add_argument("--input_csv", default="./examples/test.csv") 110 | parser.add_argument("--out_csv", default="./examples/test_result.csv") 111 | parser.add_argument("--bs", type=int, default=4) 112 | parser.add_argument("--tp", type=int, default=1) 113 | parser.add_argument("--model_path", required=True, type=str, help="skycaptioner-v1 model path") 114 | args = parser.parse_args() 115 | 116 | dataset = VideoTextDataset(csv_path=args.input_csv, model_path=args.model_path) 117 | dataloader = DataLoader( 118 | dataset, 119 | batch_size=args.bs, 120 | num_workers=4, 121 | worker_init_fn=worker_init_fn, 122 | persistent_workers=True, 123 | timeout=180, 124 | ) 125 | 126 | sampling_params = SamplingParams(temperature=0.05, max_tokens=2048) 127 | 128 | llm = LLM(model=args.model_path, 129 | gpu_memory_utilization=0.6, 130 | max_model_len=31920, 131 | tensor_parallel_size=args.tp) 132 | 133 | indices_list = [] 134 | caption_save = [] 135 | for video_batch in tqdm(dataloader): 136 | indices = video_batch["index"] 137 | inputs = video_batch["input"] 138 | batch_user_inputs = [] 139 | for prompt, video in zip(inputs['prompt'], inputs['multi_modal_data']['video'][0]): 140 | usi={'prompt':prompt, 'multi_modal_data':{'video':video}} 141 | batch_user_inputs.append(usi) 142 | outputs = llm.generate(batch_user_inputs, sampling_params, use_tqdm=False) 143 | struct_outputs = [output.outputs[0].text for output in outputs] 144 | 145 | indices_list.extend(indices.tolist()) 146 | caption_save.extend(struct_outputs) 147 | 148 | meta_new = result_writer(indices_list, caption_save, dataset.meta, column=["structural_caption"]) 149 | meta_new.to_csv(args.out_csv, index=False) 150 | 
print(f'Saved structural_caption to {args.out_csv}')
151 | 152 | if __name__ == '__main__': 153 | main() -------------------------------------------------------------------------------- /skyreels_v2_infer/__init__.py: -------------------------------------------------------------------------------- 1 | from .pipelines import DiffusionForcingPipeline 2 | -------------------------------------------------------------------------------- /skyreels_v2_infer/distributed/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/SkyReels-V2/6ace9655735f34e4cb8cae8cf8e35289142ecda7/skyreels_v2_infer/distributed/__init__.py -------------------------------------------------------------------------------- /skyreels_v2_infer/distributed/xdit_context_parallel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.amp as amp 3 | from torch.backends.cuda import sdp_kernel 4 | from xfuser.core.distributed import get_sequence_parallel_rank 5 | from xfuser.core.distributed import get_sequence_parallel_world_size 6 | from xfuser.core.distributed import get_sp_group 7 | from xfuser.core.long_ctx_attention import xFuserLongContextAttention 8 | 9 | from ..modules.transformer import sinusoidal_embedding_1d 10 | 11 | 12 | def pad_freqs(original_tensor, target_len): 13 | seq_len, s1, s2 = original_tensor.shape 14 | pad_size = target_len - seq_len 15 | padding_tensor = torch.ones(pad_size, s1, s2, dtype=original_tensor.dtype, device=original_tensor.device) 16 | padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0) 17 | return padded_tensor 18 | 19 | 20 | @amp.autocast("cuda", enabled=False) 21 | def rope_apply(x, grid_sizes, freqs): 22 | """ 23 | x: [B, L, N, C]. 24 | grid_sizes: [B, 3]. 25 | freqs: [M, C // 2]. 
26 | """ 27 | s, n, c = x.size(1), x.size(2), x.size(3) // 2 28 | # split freqs 29 | freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) 30 | 31 | # loop over samples 32 | output = [] 33 | grid = [grid_sizes.tolist()] * x.size(0) 34 | for i, (f, h, w) in enumerate(grid): 35 | seq_len = f * h * w 36 | 37 | # precompute multipliers 38 | x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(s, n, -1, 2)) 39 | freqs_i = torch.cat( 40 | [ 41 | freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), 42 | freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), 43 | freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1), 44 | ], 45 | dim=-1, 46 | ).reshape(seq_len, 1, -1) 47 | 48 | # apply rotary embedding 49 | sp_size = get_sequence_parallel_world_size() 50 | sp_rank = get_sequence_parallel_rank() 51 | freqs_i = pad_freqs(freqs_i, s * sp_size) 52 | s_per_rank = s 53 | freqs_i_rank = freqs_i[(sp_rank * s_per_rank) : ((sp_rank + 1) * s_per_rank), :, :] 54 | x_i = torch.view_as_real(x_i * freqs_i_rank.cuda()).flatten(2) 55 | x_i = torch.cat([x_i, x[i, s:]]) 56 | 57 | # append to collection 58 | output.append(x_i) 59 | return torch.stack(output).float() 60 | 61 | 62 | def usp_dit_forward(self, x, t, context, clip_fea=None, y=None, fps=None): 63 | """ 64 | x: A list of videos each with shape [C, T, H, W]. 65 | t: [B]. 66 | context: A list of text embeddings each with shape [L, C]. 
67 | """ 68 | if self.model_type == "i2v": 69 | assert clip_fea is not None and y is not None 70 | # params 71 | device = self.patch_embedding.weight.device 72 | if self.freqs.device != device: 73 | self.freqs = self.freqs.to(device) 74 | 75 | if y is not None: 76 | x = torch.cat([x, y], dim=1) 77 | 78 | # embeddings 79 | x = self.patch_embedding(x) 80 | grid_sizes = torch.tensor(x.shape[2:], dtype=torch.long) 81 | x = x.flatten(2).transpose(1, 2) 82 | 83 | if self.flag_causal_attention: 84 | frame_num = grid_sizes[0] 85 | height = grid_sizes[1] 86 | width = grid_sizes[2] 87 | block_num = frame_num // self.num_frame_per_block 88 | range_tensor = torch.arange(block_num).view(-1, 1) 89 | range_tensor = range_tensor.repeat(1, self.num_frame_per_block).flatten() 90 | casual_mask = range_tensor.unsqueeze(0) <= range_tensor.unsqueeze(1) # f, f 91 | casual_mask = casual_mask.view(frame_num, 1, 1, frame_num, 1, 1).to(x.device) 92 | casual_mask = casual_mask.repeat(1, height, width, 1, height, width) 93 | casual_mask = casual_mask.reshape(frame_num * height * width, frame_num * height * width) 94 | self.block_mask = casual_mask.unsqueeze(0).unsqueeze(0) 95 | 96 | # time embeddings 97 | with amp.autocast("cuda", dtype=torch.float32): 98 | if t.dim() == 2: 99 | b, f = t.shape 100 | _flag_df = True 101 | else: 102 | _flag_df = False 103 | e = self.time_embedding( 104 | sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(self.patch_embedding.weight.dtype) 105 | ) # b, dim 106 | e0 = self.time_projection(e).unflatten(1, (6, self.dim)) # b, 6, dim 107 | 108 | if self.inject_sample_info: 109 | fps = torch.tensor(fps, dtype=torch.long, device=device) 110 | 111 | fps_emb = self.fps_embedding(fps).float() 112 | if _flag_df: 113 | e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim)).repeat(t.shape[1], 1, 1) 114 | else: 115 | e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim)) 116 | 117 | if _flag_df: 118 | e = e.view(b, f, 1, 1, self.dim) 119 | e0 = 
e0.view(b, f, 1, 1, 6, self.dim) 120 | e = e.repeat(1, 1, grid_sizes[1], grid_sizes[2], 1).flatten(1, 3) 121 | e0 = e0.repeat(1, 1, grid_sizes[1], grid_sizes[2], 1, 1).flatten(1, 3) 122 | e0 = e0.transpose(1, 2).contiguous() 123 | 124 | assert e.dtype == torch.float32 and e0.dtype == torch.float32 125 | 126 | # context 127 | context = self.text_embedding(context) 128 | 129 | if clip_fea is not None: 130 | context_clip = self.img_emb(clip_fea) # bs x 257 x dim 131 | context = torch.concat([context_clip, context], dim=1) 132 | 133 | # arguments 134 | if e0.ndim == 4: 135 | e0 = torch.chunk(e0, get_sequence_parallel_world_size(), dim=2)[get_sequence_parallel_rank()] 136 | kwargs = dict(e=e0, grid_sizes=grid_sizes, freqs=self.freqs, context=context, block_mask=self.block_mask) 137 | 138 | # Context Parallel 139 | x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()] 140 | 141 | for block in self.blocks: 142 | x = block(x, **kwargs) 143 | 144 | # head 145 | if e.ndim == 3: 146 | e = torch.chunk(e, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()] 147 | x = self.head(x, e) 148 | 149 | # Context Parallel 150 | x = get_sp_group().all_gather(x, dim=1) 151 | 152 | # unpatchify 153 | x = self.unpatchify(x, grid_sizes) 154 | return x.float() 155 | 156 | 157 | def usp_attn_forward(self, x, grid_sizes, freqs, block_mask): 158 | 159 | r""" 160 | Args: 161 | x(Tensor): Shape [B, L, C] 162 | grid_sizes(Tensor): Shape [3], containing (F, H, W) 163 | freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] 164 | block_mask(Tensor or None): Attention mask passed to scaled_dot_product_attention when autoregressive attention is enabled 165 | """ 166 | b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim 167 | half_dtypes = (torch.float16, torch.bfloat16) 168 | 169 | def half(x): 170 | return x if x.dtype in half_dtypes else x.to(torch.bfloat16) 171 | 172 | # query, key, value function 173 | def qkv_fn(x): 174 | q = 
self.norm_q(self.q(x)).view(b, s, n, d) 175 | k =
self.norm_k(self.k(x)).view(b, s, n, d) 176 | v = self.v(x).view(b, s, n, d) 177 | return q, k, v 178 | 179 | x = x.to(self.q.weight.dtype) 180 | q, k, v = qkv_fn(x) 181 | 182 | if not self._flag_ar_attention: 183 | q = rope_apply(q, grid_sizes, freqs) 184 | k = rope_apply(k, grid_sizes, freqs) 185 | else: 186 | 187 | q = rope_apply(q, grid_sizes, freqs) 188 | k = rope_apply(k, grid_sizes, freqs) 189 | q = q.to(torch.bfloat16) 190 | k = k.to(torch.bfloat16) 191 | v = v.to(torch.bfloat16) 192 | # x = torch.nn.functional.scaled_dot_product_attention( 193 | # q.transpose(1, 2), 194 | # k.transpose(1, 2), 195 | # v.transpose(1, 2), 196 | # ).transpose(1, 2).contiguous() 197 | with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False): 198 | x = ( 199 | torch.nn.functional.scaled_dot_product_attention( 200 | q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), attn_mask=block_mask 201 | ) 202 | .transpose(1, 2) 203 | .contiguous() 204 | ) 205 | x = xFuserLongContextAttention()(None, query=half(q), key=half(k), value=half(v), window_size=self.window_size) 206 | 207 | # output 208 | x = x.flatten(2) 209 | x = self.o(x) 210 | return x 211 | -------------------------------------------------------------------------------- /skyreels_v2_infer/modules/__init__.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import os 3 | 4 | import torch 5 | from safetensors.torch import load_file 6 | 7 | from .clip import CLIPModel 8 | from .t5 import T5EncoderModel 9 | from .transformer import WanModel 10 | from .vae import WanVAE 11 | 12 | 13 | def download_model(model_id): 14 | if not os.path.exists(model_id): 15 | from huggingface_hub import snapshot_download 16 | 17 | model_id = snapshot_download(repo_id=model_id) 18 | return model_id 19 | 20 | 21 | def get_vae(model_path, device="cuda", weight_dtype=torch.float32) -> WanVAE: 22 | vae = WanVAE(model_path).to(device).to(weight_dtype) 23 | 
vae.vae.requires_grad_(False) 24 | vae.vae.eval() 25 | gc.collect() 26 | torch.cuda.empty_cache() 27 | return vae 28 | 29 | 30 | def get_transformer(model_path, device="cuda", weight_dtype=torch.bfloat16) -> WanModel: 31 | config_path = os.path.join(model_path, "config.json") 32 | transformer = WanModel.from_config(config_path).to(weight_dtype).to(device) 33 | 34 | for file in os.listdir(model_path): 35 | if file.endswith(".safetensors"): 36 | file_path = os.path.join(model_path, file) 37 | state_dict = load_file(file_path) 38 | transformer.load_state_dict(state_dict, strict=False) 39 | del state_dict 40 | gc.collect() 41 | torch.cuda.empty_cache() 42 | 43 | transformer.requires_grad_(False) 44 | transformer.eval() 45 | gc.collect() 46 | torch.cuda.empty_cache() 47 | return transformer 48 | 49 | 50 | def get_text_encoder(model_path, device="cuda", weight_dtype=torch.bfloat16) -> T5EncoderModel: 51 | t5_model = os.path.join(model_path, "models_t5_umt5-xxl-enc-bf16.pth") 52 | tokenizer_path = os.path.join(model_path, "google", "umt5-xxl") 53 | text_encoder = T5EncoderModel(checkpoint_path=t5_model, tokenizer_path=tokenizer_path).to(device).to(weight_dtype) 54 | text_encoder.requires_grad_(False) 55 | text_encoder.eval() 56 | gc.collect() 57 | torch.cuda.empty_cache() 58 | return text_encoder 59 | 60 | 61 | def get_image_encoder(model_path, device="cuda", weight_dtype=torch.bfloat16) -> CLIPModel: 62 | checkpoint_path = os.path.join(model_path, "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth") 63 | tokenizer_path = os.path.join(model_path, "xlm-roberta-large") 64 | image_enc = CLIPModel(checkpoint_path, tokenizer_path).to(weight_dtype).to(device) 65 | image_enc.requires_grad_(False) 66 | image_enc.eval() 67 | gc.collect() 68 | torch.cuda.empty_cache() 69 | return image_enc 70 | -------------------------------------------------------------------------------- /skyreels_v2_infer/modules/attention.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 2 | import torch 3 | 4 | try: 5 | import flash_attn_interface 6 | 7 | FLASH_ATTN_3_AVAILABLE = True 8 | except ModuleNotFoundError: 9 | FLASH_ATTN_3_AVAILABLE = False 10 | 11 | try: 12 | import flash_attn 13 | 14 | FLASH_ATTN_2_AVAILABLE = True 15 | except ModuleNotFoundError: 16 | FLASH_ATTN_2_AVAILABLE = False 17 | 18 | import warnings 19 | 20 | __all__ = [ 21 | "flash_attention", 22 | "attention", 23 | ] 24 | 25 | 26 | def flash_attention( 27 | q, 28 | k, 29 | v, 30 | q_lens=None, 31 | k_lens=None, 32 | dropout_p=0.0, 33 | softmax_scale=None, 34 | q_scale=None, 35 | causal=False, 36 | window_size=(-1, -1), 37 | deterministic=False, 38 | dtype=torch.bfloat16, 39 | version=None, 40 | ): 41 | """ 42 | q: [B, Lq, Nq, C1]. 43 | k: [B, Lk, Nk, C1]. 44 | v: [B, Lk, Nk, C2]. Nq must be divisible by Nk. 45 | q_lens: [B]. Currently ignored; every query sequence is treated as full length Lq. 46 | k_lens: [B]. Currently ignored; every key sequence is treated as full length Lk. 47 | dropout_p: float. Dropout probability. 48 | softmax_scale: float. Scale applied to QK^T before the softmax. 49 | causal: bool. Whether to apply a causal attention mask. 50 | window_size: (left, right). If not (-1, -1), apply sliding-window local attention. 51 | deterministic: bool. If True, use the deterministic backward pass (slightly slower, uses more memory). 52 | dtype: torch.dtype. Cast q/k/v to this dtype when they are not already float16/bfloat16.
53 | """ 54 | half_dtypes = (torch.float16, torch.bfloat16) 55 | assert dtype in half_dtypes 56 | assert q.device.type == "cuda" and q.size(-1) <= 256 57 | 58 | # params 59 | b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype 60 | 61 | def half(x): 62 | return x if x.dtype in half_dtypes else x.to(dtype) 63 | 64 | # preprocess query 65 | 66 | q = half(q.flatten(0, 1)) 67 | q_lens = torch.tensor([lq] * b, dtype=torch.int32).to(device=q.device, non_blocking=True) 68 | 69 | # preprocess key, value 70 | 71 | k = half(k.flatten(0, 1)) 72 | v = half(v.flatten(0, 1)) 73 | k_lens = torch.tensor([lk] * b, dtype=torch.int32).to(device=k.device, non_blocking=True) 74 | 75 | q = q.to(v.dtype) 76 | k = k.to(v.dtype) 77 | 78 | if q_scale is not None: 79 | q = q * q_scale 80 | 81 | if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE: 82 | warnings.warn("Flash attention 3 is not available; falling back to flash attention 2.") 83 | 84 | torch.cuda.nvtx.range_push(f"{list(q.shape)}-{list(k.shape)}-{list(v.shape)}-{q.dtype}-{k.dtype}-{v.dtype}") 85 | # apply attention 86 | if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE: 87 | # Note: FA3 does not currently support dropout_p or window_size.
88 | x = flash_attn_interface.flash_attn_varlen_func( 89 | q=q, 90 | k=k, 91 | v=v, 92 | cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]) 93 | .cumsum(0, dtype=torch.int32) 94 | .to(q.device, non_blocking=True), 95 | cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]) 96 | .cumsum(0, dtype=torch.int32) 97 | .to(q.device, non_blocking=True), 98 | seqused_q=None, 99 | seqused_k=None, 100 | max_seqlen_q=lq, 101 | max_seqlen_k=lk, 102 | softmax_scale=softmax_scale, 103 | causal=causal, 104 | deterministic=deterministic, 105 | )[0].unflatten(0, (b, lq)) 106 | else: 107 | assert FLASH_ATTN_2_AVAILABLE 108 | x = flash_attn.flash_attn_varlen_func( 109 | q=q, 110 | k=k, 111 | v=v, 112 | cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]) 113 | .cumsum(0, dtype=torch.int32) 114 | .to(q.device, non_blocking=True), 115 | cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]) 116 | .cumsum(0, dtype=torch.int32) 117 | .to(q.device, non_blocking=True), 118 | max_seqlen_q=lq, 119 | max_seqlen_k=lk, 120 | dropout_p=dropout_p, 121 | softmax_scale=softmax_scale, 122 | causal=causal, 123 | window_size=window_size, 124 | deterministic=deterministic, 125 | ).unflatten(0, (b, lq)) 126 | torch.cuda.nvtx.range_pop() 127 | 128 | # output 129 | return x 130 | 131 | 132 | def attention( 133 | q, 134 | k, 135 | v, 136 | q_lens=None, 137 | k_lens=None, 138 | dropout_p=0.0, 139 | softmax_scale=None, 140 | q_scale=None, 141 | causal=False, 142 | window_size=(-1, -1), 143 | deterministic=False, 144 | dtype=torch.bfloat16, 145 | fa_version=None, 146 | ): 147 | if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE: 148 | return flash_attention( 149 | q=q, 150 | k=k, 151 | v=v, 152 | q_lens=q_lens, 153 | k_lens=k_lens, 154 | dropout_p=dropout_p, 155 | softmax_scale=softmax_scale, 156 | q_scale=q_scale, 157 | causal=causal, 158 | window_size=window_size, 159 | deterministic=deterministic, 160 | dtype=dtype, 161 | version=fa_version, 162 | ) 163 | else: 164 | if q_lens is not None or 
k_lens is not None: 165 | warnings.warn( 166 | "Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance." 167 | ) 168 | attn_mask = None 169 | 170 | q = q.transpose(1, 2).to(dtype) 171 | k = k.transpose(1, 2).to(dtype) 172 | v = v.transpose(1, 2).to(dtype) 173 | 174 | out = torch.nn.functional.scaled_dot_product_attention( 175 | q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p 176 | ) 177 | 178 | out = out.transpose(1, 2).contiguous() 179 | return out 180 | -------------------------------------------------------------------------------- /skyreels_v2_infer/modules/clip.py: -------------------------------------------------------------------------------- 1 | # Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip'' 2 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 3 | import logging 4 | import math 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torchvision.transforms as T 10 | from diffusers.models import ModelMixin 11 | 12 | from .attention import flash_attention 13 | from .tokenizers import HuggingfaceTokenizer 14 | from .xlm_roberta import XLMRoberta 15 | 16 | __all__ = [ 17 | "XLMRobertaCLIP", 18 | "clip_xlm_roberta_vit_h_14", 19 | "CLIPModel", 20 | ] 21 | 22 | 23 | def pos_interpolate(pos, seq_len): 24 | if pos.size(1) == seq_len: 25 | return pos 26 | else: 27 | src_grid = int(math.sqrt(pos.size(1))) 28 | tar_grid = int(math.sqrt(seq_len)) 29 | n = pos.size(1) - src_grid * src_grid 30 | return torch.cat( 31 | [ 32 | pos[:, :n], 33 | F.interpolate( 34 | pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute(0, 3, 1, 2), 35 | size=(tar_grid, tar_grid), 36 | mode="bicubic", 37 | align_corners=False, 38 | ) 39 | .flatten(2) 40 | .transpose(1, 2), 41 | ], 42 | dim=1, 43 | ) 44 | 45 | 46 | class QuickGELU(nn.Module): 47 | def forward(self, x): 48 | return x * 
torch.sigmoid(1.702 * x) 49 | 50 | 51 | class LayerNorm(nn.LayerNorm): 52 | def forward(self, x): 53 | return super().forward(x.float()).type_as(x) 54 | 55 | 56 | class SelfAttention(nn.Module): 57 | def __init__(self, dim, num_heads, causal=False, attn_dropout=0.0, proj_dropout=0.0): 58 | assert dim % num_heads == 0 59 | super().__init__() 60 | self.dim = dim 61 | self.num_heads = num_heads 62 | self.head_dim = dim // num_heads 63 | self.causal = causal 64 | self.attn_dropout = attn_dropout 65 | self.proj_dropout = proj_dropout 66 | 67 | # layers 68 | self.to_qkv = nn.Linear(dim, dim * 3) 69 | self.proj = nn.Linear(dim, dim) 70 | 71 | def forward(self, x): 72 | """ 73 | x: [B, L, C]. 74 | """ 75 | b, s, c, n, d = *x.size(), self.num_heads, self.head_dim 76 | 77 | # compute query, key, value 78 | q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2) 79 | 80 | # compute attention 81 | p = self.attn_dropout if self.training else 0.0 82 | x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2) 83 | x = x.reshape(b, s, c) 84 | 85 | # output 86 | x = self.proj(x) 87 | x = F.dropout(x, self.proj_dropout, self.training) 88 | return x 89 | 90 | 91 | class SwiGLU(nn.Module): 92 | def __init__(self, dim, mid_dim): 93 | super().__init__() 94 | self.dim = dim 95 | self.mid_dim = mid_dim 96 | 97 | # layers 98 | self.fc1 = nn.Linear(dim, mid_dim) 99 | self.fc2 = nn.Linear(dim, mid_dim) 100 | self.fc3 = nn.Linear(mid_dim, dim) 101 | 102 | def forward(self, x): 103 | x = F.silu(self.fc1(x)) * self.fc2(x) 104 | x = self.fc3(x) 105 | return x 106 | 107 | 108 | class AttentionBlock(nn.Module): 109 | def __init__( 110 | self, 111 | dim, 112 | mlp_ratio, 113 | num_heads, 114 | post_norm=False, 115 | causal=False, 116 | activation="quick_gelu", 117 | attn_dropout=0.0, 118 | proj_dropout=0.0, 119 | norm_eps=1e-5, 120 | ): 121 | assert activation in ["quick_gelu", "gelu", "swi_glu"] 122 | super().__init__() 123 | self.dim = dim 124 | self.mlp_ratio = mlp_ratio 125 | 
self.num_heads = num_heads 126 | self.post_norm = post_norm 127 | self.causal = causal 128 | self.norm_eps = norm_eps 129 | 130 | # layers 131 | self.norm1 = LayerNorm(dim, eps=norm_eps) 132 | self.attn = SelfAttention(dim, num_heads, causal, attn_dropout, proj_dropout) 133 | self.norm2 = LayerNorm(dim, eps=norm_eps) 134 | if activation == "swi_glu": 135 | self.mlp = SwiGLU(dim, int(dim * mlp_ratio)) 136 | else: 137 | self.mlp = nn.Sequential( 138 | nn.Linear(dim, int(dim * mlp_ratio)), 139 | QuickGELU() if activation == "quick_gelu" else nn.GELU(), 140 | nn.Linear(int(dim * mlp_ratio), dim), 141 | nn.Dropout(proj_dropout), 142 | ) 143 | 144 | def forward(self, x): 145 | if self.post_norm: 146 | x = x + self.norm1(self.attn(x)) 147 | x = x + self.norm2(self.mlp(x)) 148 | else: 149 | x = x + self.attn(self.norm1(x)) 150 | x = x + self.mlp(self.norm2(x)) 151 | return x 152 | 153 | 154 | class AttentionPool(nn.Module): 155 | def __init__(self, dim, mlp_ratio, num_heads, activation="gelu", proj_dropout=0.0, norm_eps=1e-5): 156 | assert dim % num_heads == 0 157 | super().__init__() 158 | self.dim = dim 159 | self.mlp_ratio = mlp_ratio 160 | self.num_heads = num_heads 161 | self.head_dim = dim // num_heads 162 | self.proj_dropout = proj_dropout 163 | self.norm_eps = norm_eps 164 | 165 | # layers 166 | gain = 1.0 / math.sqrt(dim) 167 | self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim)) 168 | self.to_q = nn.Linear(dim, dim) 169 | self.to_kv = nn.Linear(dim, dim * 2) 170 | self.proj = nn.Linear(dim, dim) 171 | self.norm = LayerNorm(dim, eps=norm_eps) 172 | self.mlp = nn.Sequential( 173 | nn.Linear(dim, int(dim * mlp_ratio)), 174 | QuickGELU() if activation == "quick_gelu" else nn.GELU(), 175 | nn.Linear(int(dim * mlp_ratio), dim), 176 | nn.Dropout(proj_dropout), 177 | ) 178 | 179 | def forward(self, x): 180 | """ 181 | x: [B, L, C]. 
182 | """ 183 | b, s, c, n, d = *x.size(), self.num_heads, self.head_dim 184 | 185 | # compute query, key, value 186 | q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1) 187 | k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2) 188 | 189 | # compute attention 190 | x = flash_attention(q, k, v, version=2) 191 | x = x.reshape(b, 1, c) 192 | 193 | # output 194 | x = self.proj(x) 195 | x = F.dropout(x, self.proj_dropout, self.training) 196 | 197 | # mlp 198 | x = x + self.mlp(self.norm(x)) 199 | return x[:, 0] 200 | 201 | 202 | class VisionTransformer(nn.Module): 203 | def __init__( 204 | self, 205 | image_size=224, 206 | patch_size=16, 207 | dim=768, 208 | mlp_ratio=4, 209 | out_dim=512, 210 | num_heads=12, 211 | num_layers=12, 212 | pool_type="token", 213 | pre_norm=True, 214 | post_norm=False, 215 | activation="quick_gelu", 216 | attn_dropout=0.0, 217 | proj_dropout=0.0, 218 | embedding_dropout=0.0, 219 | norm_eps=1e-5, 220 | ): 221 | if image_size % patch_size != 0: 222 | print("[WARNING] image_size is not divisible by patch_size", flush=True) 223 | assert pool_type in ("token", "token_fc", "attn_pool") 224 | out_dim = out_dim or dim 225 | super().__init__() 226 | self.image_size = image_size 227 | self.patch_size = patch_size 228 | self.num_patches = (image_size // patch_size) ** 2 229 | self.dim = dim 230 | self.mlp_ratio = mlp_ratio 231 | self.out_dim = out_dim 232 | self.num_heads = num_heads 233 | self.num_layers = num_layers 234 | self.pool_type = pool_type 235 | self.post_norm = post_norm 236 | self.norm_eps = norm_eps 237 | 238 | # embeddings 239 | gain = 1.0 / math.sqrt(dim) 240 | self.patch_embedding = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size, bias=not pre_norm) 241 | if pool_type in ("token", "token_fc"): 242 | self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim)) 243 | self.pos_embedding = nn.Parameter( 244 | gain * torch.randn(1, self.num_patches + (1 if pool_type in ("token", "token_fc") else 0), dim) 
245 | ) 246 | self.dropout = nn.Dropout(embedding_dropout) 247 | 248 | # transformer 249 | self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None 250 | self.transformer = nn.Sequential( 251 | *[ 252 | AttentionBlock( 253 | dim, mlp_ratio, num_heads, post_norm, False, activation, attn_dropout, proj_dropout, norm_eps 254 | ) 255 | for _ in range(num_layers) 256 | ] 257 | ) 258 | self.post_norm = LayerNorm(dim, eps=norm_eps) 259 | 260 | # head 261 | if pool_type == "token": 262 | self.head = nn.Parameter(gain * torch.randn(dim, out_dim)) 263 | elif pool_type == "token_fc": 264 | self.head = nn.Linear(dim, out_dim) 265 | elif pool_type == "attn_pool": 266 | self.head = AttentionPool(dim, mlp_ratio, num_heads, activation, proj_dropout, norm_eps) 267 | 268 | def forward(self, x, interpolation=False, use_31_block=False): 269 | b = x.size(0) 270 | 271 | # embeddings 272 | x = self.patch_embedding(x).flatten(2).permute(0, 2, 1) 273 | if self.pool_type in ("token", "token_fc"): 274 | x = torch.cat([self.cls_embedding.expand(b, -1, -1), x], dim=1) 275 | if interpolation: 276 | e = pos_interpolate(self.pos_embedding, x.size(1)) 277 | else: 278 | e = self.pos_embedding 279 | x = self.dropout(x + e) 280 | if self.pre_norm is not None: 281 | x = self.pre_norm(x) 282 | 283 | # transformer 284 | if use_31_block: 285 | x = self.transformer[:-1](x) 286 | return x 287 | else: 288 | x = self.transformer(x) 289 | return x 290 | 291 | 292 | class XLMRobertaWithHead(XLMRoberta): 293 | def __init__(self, **kwargs): 294 | self.out_dim = kwargs.pop("out_dim") 295 | super().__init__(**kwargs) 296 | 297 | # head 298 | mid_dim = (self.dim + self.out_dim) // 2 299 | self.head = nn.Sequential( 300 | nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(), nn.Linear(mid_dim, self.out_dim, bias=False) 301 | ) 302 | 303 | def forward(self, ids): 304 | # xlm-roberta 305 | x = super().forward(ids) 306 | 307 | # average pooling 308 | mask = ids.ne(self.pad_id).unsqueeze(-1).to(x) 309 | x = (x * 
mask).sum(dim=1) / mask.sum(dim=1) 310 | 311 | # head 312 | x = self.head(x) 313 | return x 314 | 315 | 316 | class XLMRobertaCLIP(nn.Module): 317 | def __init__( 318 | self, 319 | embed_dim=1024, 320 | image_size=224, 321 | patch_size=14, 322 | vision_dim=1280, 323 | vision_mlp_ratio=4, 324 | vision_heads=16, 325 | vision_layers=32, 326 | vision_pool="token", 327 | vision_pre_norm=True, 328 | vision_post_norm=False, 329 | activation="gelu", 330 | vocab_size=250002, 331 | max_text_len=514, 332 | type_size=1, 333 | pad_id=1, 334 | text_dim=1024, 335 | text_heads=16, 336 | text_layers=24, 337 | text_post_norm=True, 338 | text_dropout=0.1, 339 | attn_dropout=0.0, 340 | proj_dropout=0.0, 341 | embedding_dropout=0.0, 342 | norm_eps=1e-5, 343 | ): 344 | super().__init__() 345 | self.embed_dim = embed_dim 346 | self.image_size = image_size 347 | self.patch_size = patch_size 348 | self.vision_dim = vision_dim 349 | self.vision_mlp_ratio = vision_mlp_ratio 350 | self.vision_heads = vision_heads 351 | self.vision_layers = vision_layers 352 | self.vision_pre_norm = vision_pre_norm 353 | self.vision_post_norm = vision_post_norm 354 | self.activation = activation 355 | self.vocab_size = vocab_size 356 | self.max_text_len = max_text_len 357 | self.type_size = type_size 358 | self.pad_id = pad_id 359 | self.text_dim = text_dim 360 | self.text_heads = text_heads 361 | self.text_layers = text_layers 362 | self.text_post_norm = text_post_norm 363 | self.norm_eps = norm_eps 364 | 365 | # models 366 | self.visual = VisionTransformer( 367 | image_size=image_size, 368 | patch_size=patch_size, 369 | dim=vision_dim, 370 | mlp_ratio=vision_mlp_ratio, 371 | out_dim=embed_dim, 372 | num_heads=vision_heads, 373 | num_layers=vision_layers, 374 | pool_type=vision_pool, 375 | pre_norm=vision_pre_norm, 376 | post_norm=vision_post_norm, 377 | activation=activation, 378 | attn_dropout=attn_dropout, 379 | proj_dropout=proj_dropout, 380 | embedding_dropout=embedding_dropout, 381 | norm_eps=norm_eps, 
382 | ) 383 | self.textual = XLMRobertaWithHead( 384 | vocab_size=vocab_size, 385 | max_seq_len=max_text_len, 386 | type_size=type_size, 387 | pad_id=pad_id, 388 | dim=text_dim, 389 | out_dim=embed_dim, 390 | num_heads=text_heads, 391 | num_layers=text_layers, 392 | post_norm=text_post_norm, 393 | dropout=text_dropout, 394 | ) 395 | self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([])) 396 | 397 | def forward(self, imgs, txt_ids): 398 | """ 399 | imgs: [B, 3, H, W] of torch.float32. 400 | - mean: [0.48145466, 0.4578275, 0.40821073] 401 | - std: [0.26862954, 0.26130258, 0.27577711] 402 | txt_ids: [B, L] of torch.long. 403 | Encoded by the XLM-RoBERTa HuggingfaceTokenizer. 404 | """ 405 | xi = self.visual(imgs) 406 | xt = self.textual(txt_ids) 407 | return xi, xt 408 | 409 | def param_groups(self): 410 | groups = [ 411 | { 412 | "params": [p for n, p in self.named_parameters() if "norm" in n or n.endswith("bias")], 413 | "weight_decay": 0.0, 414 | }, 415 | {"params": [p for n, p in self.named_parameters() if not ("norm" in n or n.endswith("bias"))]}, 416 | ] 417 | return groups 418 | 419 | 420 | def _clip( 421 | pretrained=False, 422 | pretrained_name=None, 423 | model_cls=XLMRobertaCLIP, 424 | return_transforms=False, 425 | return_tokenizer=False, 426 | tokenizer_padding="eos", 427 | dtype=torch.float32, 428 | device="cpu", 429 | **kwargs, 430 | ): 431 | # init a model on device 432 | with torch.device(device): 433 | model = model_cls(**kwargs) 434 | 435 | # set device 436 | model = model.to(dtype=dtype, device=device) 437 | output = (model,) 438 | 439 | # init transforms 440 | if return_transforms: 441 | # mean and std (guard against pretrained_name=None) 442 | if pretrained_name and "siglip" in pretrained_name.lower(): 443 | mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5] 444 | else: 445 | mean = [0.48145466, 0.4578275, 0.40821073] 446 | std = [0.26862954, 0.26130258, 0.27577711] 447 | 448 | # transforms 449 | transforms = T.Compose( 450 | [ 451 | T.Resize((model.image_size, model.image_size),
interpolation=T.InterpolationMode.BICUBIC), 452 | T.ToTensor(), 453 | T.Normalize(mean=mean, std=std), 454 | ] 455 | ) 456 | output += (transforms,) 457 | return output[0] if len(output) == 1 else output 458 | 459 | 460 | def clip_xlm_roberta_vit_h_14(pretrained=False, pretrained_name="open-clip-xlm-roberta-large-vit-huge-14", **kwargs): 461 | cfg = dict( 462 | embed_dim=1024, 463 | image_size=224, 464 | patch_size=14, 465 | vision_dim=1280, 466 | vision_mlp_ratio=4, 467 | vision_heads=16, 468 | vision_layers=32, 469 | vision_pool="token", 470 | activation="gelu", 471 | vocab_size=250002, 472 | max_text_len=514, 473 | type_size=1, 474 | pad_id=1, 475 | text_dim=1024, 476 | text_heads=16, 477 | text_layers=24, 478 | text_post_norm=True, 479 | text_dropout=0.1, 480 | attn_dropout=0.0, 481 | proj_dropout=0.0, 482 | embedding_dropout=0.0, 483 | ) 484 | cfg.update(**kwargs) 485 | return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg) 486 | 487 | 488 | class CLIPModel(ModelMixin): 489 | def __init__(self, checkpoint_path, tokenizer_path): 490 | self.checkpoint_path = checkpoint_path 491 | self.tokenizer_path = tokenizer_path 492 | 493 | super().__init__() 494 | # init model 495 | self.model, self.transforms = clip_xlm_roberta_vit_h_14( 496 | pretrained=False, return_transforms=True, return_tokenizer=False 497 | ) 498 | self.model = self.model.eval().requires_grad_(False) 499 | logging.info(f"loading {checkpoint_path}") 500 | self.model.load_state_dict(torch.load(checkpoint_path, map_location="cpu")) 501 | 502 | # init tokenizer 503 | self.tokenizer = HuggingfaceTokenizer( 504 | name=tokenizer_path, seq_len=self.model.max_text_len - 2, clean="whitespace" 505 | ) 506 | 507 | def encode_video(self, video): 508 | # preprocess 509 | b, c, t, h, w = video.shape 510 | video = video.transpose(1, 2) 511 | video = video.reshape(b * t, c, h, w) 512 | size = (self.model.image_size,) * 2 513 | video = F.interpolate( 514 | video, 515 | size=size, 516 | mode='bicubic', 517 | 
align_corners=False) 518 | 519 | video = self.transforms.transforms[-1](video.mul_(0.5).add_(0.5)) 520 | 521 | # forward 522 | with torch.amp.autocast(dtype=self.dtype, device_type=self.device.type): 523 | out = self.model.visual(video, use_31_block=True) 524 | 525 | return out 526 | -------------------------------------------------------------------------------- /skyreels_v2_infer/modules/t5.py: -------------------------------------------------------------------------------- 1 | # Modified from transformers.models.t5.modeling_t5 2 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 3 | import logging 4 | import math 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from diffusers.models import ModelMixin 10 | 11 | from .tokenizers import HuggingfaceTokenizer 12 | 13 | __all__ = [ 14 | "T5Model", 15 | "T5Encoder", 16 | "T5Decoder", 17 | "T5EncoderModel", 18 | ] 19 | 20 | 21 | def fp16_clamp(x): 22 | if x.dtype == torch.float16 and torch.isinf(x).any(): 23 | clamp = torch.finfo(x.dtype).max - 1000 24 | x = torch.clamp(x, min=-clamp, max=clamp) 25 | return x 26 | 27 | 28 | def init_weights(m): 29 | if isinstance(m, T5LayerNorm): 30 | nn.init.ones_(m.weight) 31 | elif isinstance(m, T5Model): 32 | nn.init.normal_(m.token_embedding.weight, std=1.0) 33 | elif isinstance(m, T5FeedForward): 34 | nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5) 35 | nn.init.normal_(m.fc1.weight, std=m.dim**-0.5) 36 | nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5) 37 | elif isinstance(m, T5Attention): 38 | nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn) ** -0.5) 39 | nn.init.normal_(m.k.weight, std=m.dim**-0.5) 40 | nn.init.normal_(m.v.weight, std=m.dim**-0.5) 41 | nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn) ** -0.5) 42 | elif isinstance(m, T5RelativeEmbedding): 43 | nn.init.normal_(m.embedding.weight, std=(2 * m.num_buckets * m.num_heads) ** -0.5) 44 | 45 | 46 | class GELU(nn.Module): 47 | def 
forward(self, x): 48 | return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) 49 | 50 | 51 | class T5LayerNorm(nn.Module): 52 | def __init__(self, dim, eps=1e-6): 53 | super(T5LayerNorm, self).__init__() 54 | self.dim = dim 55 | self.eps = eps 56 | self.weight = nn.Parameter(torch.ones(dim)) 57 | 58 | def forward(self, x): 59 | x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + self.eps) 60 | if self.weight.dtype in [torch.float16, torch.bfloat16]: 61 | x = x.type_as(self.weight) 62 | return self.weight * x 63 | 64 | 65 | class T5Attention(nn.Module): 66 | def __init__(self, dim, dim_attn, num_heads, dropout=0.1): 67 | assert dim_attn % num_heads == 0 68 | super(T5Attention, self).__init__() 69 | self.dim = dim 70 | self.dim_attn = dim_attn 71 | self.num_heads = num_heads 72 | self.head_dim = dim_attn // num_heads 73 | 74 | # layers 75 | self.q = nn.Linear(dim, dim_attn, bias=False) 76 | self.k = nn.Linear(dim, dim_attn, bias=False) 77 | self.v = nn.Linear(dim, dim_attn, bias=False) 78 | self.o = nn.Linear(dim_attn, dim, bias=False) 79 | self.dropout = nn.Dropout(dropout) 80 | 81 | def forward(self, x, context=None, mask=None, pos_bias=None): 82 | """ 83 | x: [B, L1, C]. 84 | context: [B, L2, C] or None. 85 | mask: [B, L2] or [B, L1, L2] or None. 
86 | """ 87 | # check inputs 88 | context = x if context is None else context 89 | b, n, c = x.size(0), self.num_heads, self.head_dim 90 | 91 | # compute query, key, value 92 | q = self.q(x).view(b, -1, n, c) 93 | k = self.k(context).view(b, -1, n, c) 94 | v = self.v(context).view(b, -1, n, c) 95 | 96 | # attention bias 97 | attn_bias = x.new_zeros(b, n, q.size(1), k.size(1)) 98 | if pos_bias is not None: 99 | attn_bias += pos_bias 100 | if mask is not None: 101 | assert mask.ndim in [2, 3] 102 | mask = mask.view(b, 1, 1, -1) if mask.ndim == 2 else mask.unsqueeze(1) 103 | attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min) 104 | 105 | # compute attention (T5 does not use scaling) 106 | attn = torch.einsum("binc,bjnc->bnij", q, k) + attn_bias 107 | attn = F.softmax(attn.float(), dim=-1).type_as(attn) 108 | x = torch.einsum("bnij,bjnc->binc", attn, v) 109 | 110 | # output 111 | x = x.reshape(b, -1, n * c) 112 | x = self.o(x) 113 | x = self.dropout(x) 114 | return x 115 | 116 | 117 | class T5FeedForward(nn.Module): 118 | def __init__(self, dim, dim_ffn, dropout=0.1): 119 | super(T5FeedForward, self).__init__() 120 | self.dim = dim 121 | self.dim_ffn = dim_ffn 122 | 123 | # layers 124 | self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU()) 125 | self.fc1 = nn.Linear(dim, dim_ffn, bias=False) 126 | self.fc2 = nn.Linear(dim_ffn, dim, bias=False) 127 | self.dropout = nn.Dropout(dropout) 128 | 129 | def forward(self, x): 130 | x = self.fc1(x) * self.gate(x) 131 | x = self.dropout(x) 132 | x = self.fc2(x) 133 | x = self.dropout(x) 134 | return x 135 | 136 | 137 | class T5SelfAttention(nn.Module): 138 | def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.1): 139 | super(T5SelfAttention, self).__init__() 140 | self.dim = dim 141 | self.dim_attn = dim_attn 142 | self.dim_ffn = dim_ffn 143 | self.num_heads = num_heads 144 | self.num_buckets = num_buckets 145 | self.shared_pos = shared_pos 146 | 147 | # layers 
148 | self.norm1 = T5LayerNorm(dim) 149 | self.attn = T5Attention(dim, dim_attn, num_heads, dropout) 150 | self.norm2 = T5LayerNorm(dim) 151 | self.ffn = T5FeedForward(dim, dim_ffn, dropout) 152 | self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True) 153 | 154 | def forward(self, x, mask=None, pos_bias=None): 155 | e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1)) 156 | x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e)) 157 | x = fp16_clamp(x + self.ffn(self.norm2(x))) 158 | return x 159 | 160 | 161 | class T5CrossAttention(nn.Module): 162 | def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.1): 163 | super(T5CrossAttention, self).__init__() 164 | self.dim = dim 165 | self.dim_attn = dim_attn 166 | self.dim_ffn = dim_ffn 167 | self.num_heads = num_heads 168 | self.num_buckets = num_buckets 169 | self.shared_pos = shared_pos 170 | 171 | # layers 172 | self.norm1 = T5LayerNorm(dim) 173 | self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout) 174 | self.norm2 = T5LayerNorm(dim) 175 | self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout) 176 | self.norm3 = T5LayerNorm(dim) 177 | self.ffn = T5FeedForward(dim, dim_ffn, dropout) 178 | self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False) 179 | 180 | def forward(self, x, mask=None, encoder_states=None, encoder_mask=None, pos_bias=None): 181 | e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1)) 182 | x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e)) 183 | x = fp16_clamp(x + self.cross_attn(self.norm2(x), context=encoder_states, mask=encoder_mask)) 184 | x = fp16_clamp(x + self.ffn(self.norm3(x))) 185 | return x 186 | 187 | 188 | class T5RelativeEmbedding(nn.Module): 189 | def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128): 190 | 
super(T5RelativeEmbedding, self).__init__() 191 | self.num_buckets = num_buckets 192 | self.num_heads = num_heads 193 | self.bidirectional = bidirectional 194 | self.max_dist = max_dist 195 | 196 | # layers 197 | self.embedding = nn.Embedding(num_buckets, num_heads) 198 | 199 | def forward(self, lq, lk): 200 | device = self.embedding.weight.device 201 | # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \ 202 | # torch.arange(lq).unsqueeze(1).to(device) 203 | rel_pos = torch.arange(lk, device=device).unsqueeze(0) - torch.arange(lq, device=device).unsqueeze(1) 204 | rel_pos = self._relative_position_bucket(rel_pos) 205 | rel_pos_embeds = self.embedding(rel_pos) 206 | rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(0) # [1, N, Lq, Lk] 207 | return rel_pos_embeds.contiguous() 208 | 209 | def _relative_position_bucket(self, rel_pos): 210 | # preprocess 211 | if self.bidirectional: 212 | num_buckets = self.num_buckets // 2 213 | rel_buckets = (rel_pos > 0).long() * num_buckets 214 | rel_pos = torch.abs(rel_pos) 215 | else: 216 | num_buckets = self.num_buckets 217 | rel_buckets = 0 218 | rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos)) 219 | 220 | # embeddings for small and large positions 221 | max_exact = num_buckets // 2 222 | rel_pos_large = ( 223 | max_exact 224 | + ( 225 | torch.log(rel_pos.float() / max_exact) / math.log(self.max_dist / max_exact) * (num_buckets - max_exact) 226 | ).long() 227 | ) 228 | rel_pos_large = torch.min(rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1)) 229 | rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large) 230 | return rel_buckets 231 | 232 | 233 | class T5Encoder(nn.Module): 234 | def __init__(self, vocab, dim, dim_attn, dim_ffn, num_heads, num_layers, num_buckets, shared_pos=True, dropout=0.1): 235 | super(T5Encoder, self).__init__() 236 | self.dim = dim 237 | self.dim_attn = dim_attn 238 | self.dim_ffn = dim_ffn 239 | self.num_heads = num_heads 240 | self.num_layers = 
num_layers 241 | self.num_buckets = num_buckets 242 | self.shared_pos = shared_pos 243 | 244 | # layers 245 | self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim) 246 | self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True) if shared_pos else None 247 | self.dropout = nn.Dropout(dropout) 248 | self.blocks = nn.ModuleList( 249 | [ 250 | T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout) 251 | for _ in range(num_layers) 252 | ] 253 | ) 254 | self.norm = T5LayerNorm(dim) 255 | 256 | # initialize weights 257 | self.apply(init_weights) 258 | 259 | def forward(self, ids, mask=None): 260 | x = self.token_embedding(ids) 261 | x = self.dropout(x) 262 | e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None 263 | for block in self.blocks: 264 | x = block(x, mask, pos_bias=e) 265 | x = self.norm(x) 266 | x = self.dropout(x) 267 | return x 268 | 269 | 270 | class T5Decoder(nn.Module): 271 | def __init__(self, vocab, dim, dim_attn, dim_ffn, num_heads, num_layers, num_buckets, shared_pos=True, dropout=0.1): 272 | super(T5Decoder, self).__init__() 273 | self.dim = dim 274 | self.dim_attn = dim_attn 275 | self.dim_ffn = dim_ffn 276 | self.num_heads = num_heads 277 | self.num_layers = num_layers 278 | self.num_buckets = num_buckets 279 | self.shared_pos = shared_pos 280 | 281 | # layers 282 | self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim) 283 | self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False) if shared_pos else None 284 | self.dropout = nn.Dropout(dropout) 285 | self.blocks = nn.ModuleList( 286 | [ 287 | T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout) 288 | for _ in range(num_layers) 289 | ] 290 | ) 291 | self.norm = T5LayerNorm(dim) 292 | 293 | # initialize weights 294 | self.apply(init_weights) 295 | 296 | def forward(self, ids, 
mask=None, encoder_states=None, encoder_mask=None): 297 | b, s = ids.size() 298 | 299 | # causal mask 300 | if mask is None: 301 | mask = torch.tril(torch.ones(1, s, s).to(ids.device)) 302 | elif mask.ndim == 2: 303 | mask = torch.tril(mask.unsqueeze(1).expand(-1, s, -1)) 304 | 305 | # layers 306 | x = self.token_embedding(ids) 307 | x = self.dropout(x) 308 | e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None 309 | for block in self.blocks: 310 | x = block(x, mask, encoder_states, encoder_mask, pos_bias=e) 311 | x = self.norm(x) 312 | x = self.dropout(x) 313 | return x 314 | 315 | 316 | class T5Model(nn.Module): 317 | def __init__( 318 | self, 319 | vocab_size, 320 | dim, 321 | dim_attn, 322 | dim_ffn, 323 | num_heads, 324 | encoder_layers, 325 | decoder_layers, 326 | num_buckets, 327 | shared_pos=True, 328 | dropout=0.1, 329 | ): 330 | super(T5Model, self).__init__() 331 | self.vocab_size = vocab_size 332 | self.dim = dim 333 | self.dim_attn = dim_attn 334 | self.dim_ffn = dim_ffn 335 | self.num_heads = num_heads 336 | self.encoder_layers = encoder_layers 337 | self.decoder_layers = decoder_layers 338 | self.num_buckets = num_buckets 339 | 340 | # layers 341 | self.token_embedding = nn.Embedding(vocab_size, dim) 342 | self.encoder = T5Encoder( 343 | self.token_embedding, dim, dim_attn, dim_ffn, num_heads, encoder_layers, num_buckets, shared_pos, dropout 344 | ) 345 | self.decoder = T5Decoder( 346 | self.token_embedding, dim, dim_attn, dim_ffn, num_heads, decoder_layers, num_buckets, shared_pos, dropout 347 | ) 348 | self.head = nn.Linear(dim, vocab_size, bias=False) 349 | 350 | # initialize weights 351 | self.apply(init_weights) 352 | 353 | def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask): 354 | x = self.encoder(encoder_ids, encoder_mask) 355 | x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask) 356 | x = self.head(x) 357 | return x 358 | 359 | 360 | def _t5( 361 | name, 362 | encoder_only=False, 363 | 
decoder_only=False, 364 | return_tokenizer=False, 365 | tokenizer_kwargs={}, 366 | dtype=torch.float32, 367 | device="cpu", 368 | **kwargs, 369 | ): 370 | # sanity check 371 | assert not (encoder_only and decoder_only) 372 | 373 | # params 374 | if encoder_only: 375 | model_cls = T5Encoder 376 | kwargs["vocab"] = kwargs.pop("vocab_size") 377 | kwargs["num_layers"] = kwargs.pop("encoder_layers") 378 | _ = kwargs.pop("decoder_layers") 379 | elif decoder_only: 380 | model_cls = T5Decoder 381 | kwargs["vocab"] = kwargs.pop("vocab_size") 382 | kwargs["num_layers"] = kwargs.pop("decoder_layers") 383 | _ = kwargs.pop("encoder_layers") 384 | else: 385 | model_cls = T5Model 386 | 387 | # init model 388 | with torch.device(device): 389 | model = model_cls(**kwargs) 390 | 391 | # set device 392 | model = model.to(dtype=dtype, device=device) 393 | 394 | # init tokenizer 395 | if return_tokenizer: 396 | from .tokenizers import HuggingfaceTokenizer 397 | 398 | tokenizer = HuggingfaceTokenizer(f"google/{name}", **tokenizer_kwargs) 399 | return model, tokenizer 400 | else: 401 | return model 402 | 403 | 404 | def umt5_xxl(**kwargs): 405 | cfg = dict( 406 | vocab_size=256384, 407 | dim=4096, 408 | dim_attn=4096, 409 | dim_ffn=10240, 410 | num_heads=64, 411 | encoder_layers=24, 412 | decoder_layers=24, 413 | num_buckets=32, 414 | shared_pos=False, 415 | dropout=0.1, 416 | ) 417 | cfg.update(**kwargs) 418 | return _t5("umt5-xxl", **cfg) 419 | 420 | 421 | class T5EncoderModel(ModelMixin): 422 | def __init__( 423 | self, 424 | checkpoint_path=None, 425 | tokenizer_path=None, 426 | text_len=512, 427 | shard_fn=None, 428 | ): 429 | self.text_len = text_len 430 | self.checkpoint_path = checkpoint_path 431 | self.tokenizer_path = tokenizer_path 432 | 433 | super().__init__() 434 | # init model 435 | model = umt5_xxl(encoder_only=True, return_tokenizer=False) 436 | logging.info(f"loading {checkpoint_path}") 437 | model.load_state_dict(torch.load(checkpoint_path, map_location="cpu")) 438 | 
self.model = model 439 | if shard_fn is not None: 440 | self.model = shard_fn(self.model, sync_module_states=False) 441 | else: 442 | self.model.eval().requires_grad_(False) 443 | # init tokenizer 444 | self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=text_len, clean="whitespace") 445 | 446 | def encode(self, texts): 447 | ids, mask = self.tokenizer(texts, return_mask=True, add_special_tokens=True) 448 | ids = ids.to(self.device) 449 | mask = mask.to(self.device) 450 | # seq_lens = mask.gt(0).sum(dim=1).long() 451 | context = self.model(ids, mask) 452 | context = context * mask.unsqueeze(-1) # zero out padded positions; mask is already on self.device 453 | 454 | return context 455 | -------------------------------------------------------------------------------- /skyreels_v2_infer/modules/tokenizers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 2 | import html 3 | import string 4 | 5 | import ftfy 6 | import regex as re 7 | from transformers import AutoTokenizer 8 | 9 | __all__ = ["HuggingfaceTokenizer"] 10 | 11 | 12 | def basic_clean(text): 13 | text = ftfy.fix_text(text) 14 | text = html.unescape(html.unescape(text)) 15 | return text.strip() 16 | 17 | 18 | def whitespace_clean(text): 19 | text = re.sub(r"\s+", " ", text) 20 | text = text.strip() 21 | return text 22 | 23 | 24 | def canonicalize(text, keep_punctuation_exact_string=None): 25 | text = text.replace("_", " ") 26 | if keep_punctuation_exact_string: 27 | text = keep_punctuation_exact_string.join( 28 | part.translate(str.maketrans("", "", string.punctuation)) 29 | for part in text.split(keep_punctuation_exact_string) 30 | ) 31 | else: 32 | text = text.translate(str.maketrans("", "", string.punctuation)) 33 | text = text.lower() 34 | text = re.sub(r"\s+", " ", text) 35 | return text.strip() 36 | 37 | 38 | class HuggingfaceTokenizer: 39 | def __init__(self, name, seq_len=None, clean=None, **kwargs): 40 | assert clean in (None,
"whitespace", "lower", "canonicalize") 41 | self.name = name 42 | self.seq_len = seq_len 43 | self.clean = clean 44 | 45 | # init tokenizer 46 | self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs) 47 | self.vocab_size = self.tokenizer.vocab_size 48 | 49 | def __call__(self, sequence, **kwargs): 50 | return_mask = kwargs.pop("return_mask", False) 51 | 52 | # arguments 53 | _kwargs = {"return_tensors": "pt"} 54 | if self.seq_len is not None: 55 | _kwargs.update({"padding": "max_length", "truncation": True, "max_length": self.seq_len}) 56 | _kwargs.update(**kwargs) 57 | 58 | # tokenization 59 | if isinstance(sequence, str): 60 | sequence = [sequence] 61 | if self.clean: 62 | sequence = [self._clean(u) for u in sequence] 63 | ids = self.tokenizer(sequence, **_kwargs) 64 | 65 | # output 66 | if return_mask: 67 | return ids.input_ids, ids.attention_mask 68 | else: 69 | return ids.input_ids 70 | 71 | def _clean(self, text): 72 | if self.clean == "whitespace": 73 | text = whitespace_clean(basic_clean(text)) 74 | elif self.clean == "lower": 75 | text = whitespace_clean(basic_clean(text)).lower() 76 | elif self.clean == "canonicalize": 77 | text = canonicalize(basic_clean(text)) 78 | return text 79 | -------------------------------------------------------------------------------- /skyreels_v2_infer/modules/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
2 | import math 3 | import numpy as np 4 | import torch 5 | import torch.amp as amp 6 | import torch.nn as nn 7 | from diffusers.configuration_utils import ConfigMixin 8 | from diffusers.configuration_utils import register_to_config 9 | from diffusers.loaders import PeftAdapterMixin 10 | from diffusers.models.modeling_utils import ModelMixin 11 | from torch.backends.cuda import sdp_kernel 12 | from torch.nn.attention.flex_attention import BlockMask 13 | from torch.nn.attention.flex_attention import create_block_mask 14 | from torch.nn.attention.flex_attention import flex_attention 15 | 16 | from .attention import flash_attention 17 | 18 | 19 | flex_attention = torch.compile(flex_attention, dynamic=False, mode="max-autotune") 20 | 21 | DISABLE_COMPILE = False # set to True to skip torch.compile (could be read from an environment variable) 22 | 23 | __all__ = ["WanModel"] 24 | 25 | 26 | def sinusoidal_embedding_1d(dim, position): 27 | # preprocess 28 | assert dim % 2 == 0 29 | half = dim // 2 30 | position = position.type(torch.float64) 31 | 32 | # calculation 33 | sinusoid = torch.outer(position, torch.pow(10000, -torch.arange(half).to(position).div(half))) 34 | x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) 35 | return x 36 | 37 | 38 | @amp.autocast("cuda", enabled=False) 39 | def rope_params(max_seq_len, dim, theta=10000): 40 | assert dim % 2 == 0 41 | freqs = torch.outer( 42 | torch.arange(max_seq_len), 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)) 43 | ) 44 | freqs = torch.polar(torch.ones_like(freqs), freqs) 45 | return freqs 46 | 47 | 48 | @amp.autocast("cuda", enabled=False) 49 | def rope_apply(x, grid_sizes, freqs): 50 | n, c = x.size(2), x.size(3) // 2 51 | bs = x.size(0) 52 | 53 | # split freqs into temporal, height and width parts 54 | freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) 55 | 56 | # every sample in the batch shares one (F, H, W) grid, so there is no per-sample loop 57 | f, h, w = grid_sizes.tolist() 58 | seq_len = f * h * w 59 | 60 | # view each pair of channels as one complex number 61 | 62 | x = torch.view_as_complex(x.to(torch.float32).reshape(bs, seq_len, n, -1, 2)) 63 |
freqs_i = torch.cat( 64 | [ 65 | freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), 66 | freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), 67 | freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1), 68 | ], 69 | dim=-1, 70 | ).reshape(seq_len, 1, -1) 71 | 72 | # apply rotary embedding 73 | x = torch.view_as_real(x * freqs_i).flatten(3) 74 | 75 | return x 76 | 77 | 78 | @torch.compile(dynamic=True, disable=DISABLE_COMPILE) 79 | def fast_rms_norm(x, weight, eps): 80 | y = x.float() 81 | y = y * torch.rsqrt(y.pow(2).mean(dim=-1, keepdim=True) + eps) 82 | y = y.type_as(x) * weight # cast back to the input dtype before scaling 83 | return y 84 | 85 | 86 | class WanRMSNorm(nn.Module): 87 | def __init__(self, dim, eps=1e-5): 88 | super().__init__() 89 | self.dim = dim 90 | self.eps = eps 91 | self.weight = nn.Parameter(torch.ones(dim)) 92 | 93 | def forward(self, x): 94 | r""" 95 | Args: 96 | x(Tensor): Shape [B, L, C] 97 | """ 98 | return fast_rms_norm(x, self.weight, self.eps) 99 | 100 | def _norm(self, x): 101 | return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) 102 | 103 | 104 | class WanLayerNorm(nn.LayerNorm): 105 | def __init__(self, dim, eps=1e-6, elementwise_affine=False): 106 | super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps) 107 | 108 | def forward(self, x): 109 | r""" 110 | Args: 111 | x(Tensor): Shape [B, L, C] 112 | """ 113 | return super().forward(x) 114 | 115 | 116 | class WanSelfAttention(nn.Module): 117 | def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True, eps=1e-6): 118 | assert dim % num_heads == 0 119 | super().__init__() 120 | self.dim = dim 121 | self.num_heads = num_heads 122 | self.head_dim = dim // num_heads 123 | self.window_size = window_size 124 | self.qk_norm = qk_norm 125 | self.eps = eps 126 | 127 | # layers 128 | self.q = nn.Linear(dim, dim) 129 | self.k = nn.Linear(dim, dim) 130 | self.v = nn.Linear(dim, dim) 131 | self.o = nn.Linear(dim, dim) 132 | self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() 133 |
self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() 134 | 135 | self._flag_ar_attention = False 136 | 137 | def set_ar_attention(self): 138 | self._flag_ar_attention = True 139 | 140 | def forward(self, x, grid_sizes, freqs, block_mask): 141 | r""" 142 | Args: 143 | x(Tensor): Shape [B, L, C] 144 | grid_sizes(Tensor): Shape [3], the (F, H, W) patch grid shared by the whole batch 145 | freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] 146 | block_mask(Tensor or None): Boolean attention mask, used only after set_ar_attention() 147 | """ 148 | b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim 149 | 150 | # query, key, value function 151 | def qkv_fn(x): 152 | q = self.norm_q(self.q(x)).view(b, s, n, d) 153 | k = self.norm_k(self.k(x)).view(b, s, n, d) 154 | v = self.v(x).view(b, s, n, d) 155 | return q, k, v 156 | 157 | x = x.to(self.q.weight.dtype) 158 | q, k, v = qkv_fn(x) 159 | 160 | if not self._flag_ar_attention: 161 | q = rope_apply(q, grid_sizes, freqs) 162 | k = rope_apply(k, grid_sizes, freqs) 163 | x = flash_attention(q=q, k=k, v=v, window_size=self.window_size) 164 | else: 165 | q = rope_apply(q, grid_sizes, freqs) 166 | k = rope_apply(k, grid_sizes, freqs) 167 | q = q.to(torch.bfloat16) 168 | k = k.to(torch.bfloat16) 169 | v = v.to(torch.bfloat16) 170 | 171 | with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False): 172 | x = ( 173 | torch.nn.functional.scaled_dot_product_attention( 174 | q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), attn_mask=block_mask 175 | ) 176 | .transpose(1, 2) 177 | .contiguous() 178 | ) 179 | 180 | # output 181 | x = x.flatten(2) 182 | x = self.o(x) 183 | return x 184 | 185 | 186 | class WanT2VCrossAttention(WanSelfAttention): 187 | def forward(self, x, context): 188 | r""" 189 | Args: 190 | x(Tensor): Shape [B, L1, C] 191 | context(Tensor): Shape [B, L2, C] 192 | Returns: Tensor of shape [B, L1, C] 193 | """ 194 | b, n, d = x.size(0), self.num_heads, self.head_dim 195 | 196 | # compute query, key, value 197 |
q = self.norm_q(self.q(x)).view(b, -1, n, d) 198 | k = self.norm_k(self.k(context)).view(b, -1, n, d) 199 | v = self.v(context).view(b, -1, n, d) 200 | 201 | # compute attention 202 | x = flash_attention(q, k, v) 203 | 204 | # output 205 | x = x.flatten(2) 206 | x = self.o(x) 207 | return x 208 | 209 | 210 | class WanI2VCrossAttention(WanSelfAttention): 211 | def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True, eps=1e-6): 212 | super().__init__(dim, num_heads, window_size, qk_norm, eps) 213 | 214 | self.k_img = nn.Linear(dim, dim) 215 | self.v_img = nn.Linear(dim, dim) 216 | # self.alpha = nn.Parameter(torch.zeros((1, ))) 217 | self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() 218 | 219 | def forward(self, x, context): 220 | r""" 221 | Args: 222 | x(Tensor): Shape [B, L1, C] 223 | context(Tensor): Shape [B, 257 + L2, C], the 257 CLIP image tokens followed by the text tokens 224 | Returns: Tensor of shape [B, L1, C] 225 | """ 226 | context_img = context[:, :257] 227 | context = context[:, 257:] 228 | b, n, d = x.size(0), self.num_heads, self.head_dim 229 | 230 | # compute query, key, value 231 | q = self.norm_q(self.q(x)).view(b, -1, n, d) 232 | k = self.norm_k(self.k(context)).view(b, -1, n, d) 233 | v = self.v(context).view(b, -1, n, d) 234 | k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d) 235 | v_img = self.v_img(context_img).view(b, -1, n, d) 236 | img_x = flash_attention(q, k_img, v_img) 237 | # compute text attention (the image branch is computed above) 238 | x = flash_attention(q, k, v) 239 | 240 | # output 241 | x = x.flatten(2) 242 | img_x = img_x.flatten(2) 243 | x = x + img_x 244 | x = self.o(x) 245 | return x 246 | 247 | 248 | WAN_CROSSATTENTION_CLASSES = { 249 | "t2v_cross_attn": WanT2VCrossAttention, 250 | "i2v_cross_attn": WanI2VCrossAttention, 251 | } 252 | 253 | 254 | def mul_add(x, y, z): 255 | return x.float() + y.float() * z.float() 256 | 257 | 258 | def mul_add_add(x, y, z): 259 | return x.float() * (1 + y) + z 260 | 261 | 262 | mul_add_compile = torch.compile(mul_add, dynamic=True,
disable=DISABLE_COMPILE) 263 | mul_add_add_compile = torch.compile(mul_add_add, dynamic=True, disable=DISABLE_COMPILE) 264 | 265 | 266 | class WanAttentionBlock(nn.Module): 267 | def __init__( 268 | self, 269 | cross_attn_type, 270 | dim, 271 | ffn_dim, 272 | num_heads, 273 | window_size=(-1, -1), 274 | qk_norm=True, 275 | cross_attn_norm=False, 276 | eps=1e-6, 277 | ): 278 | super().__init__() 279 | self.dim = dim 280 | self.ffn_dim = ffn_dim 281 | self.num_heads = num_heads 282 | self.window_size = window_size 283 | self.qk_norm = qk_norm 284 | self.cross_attn_norm = cross_attn_norm 285 | self.eps = eps 286 | 287 | # layers 288 | self.norm1 = WanLayerNorm(dim, eps) 289 | self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, eps) 290 | self.norm3 = WanLayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity() 291 | self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim, num_heads, (-1, -1), qk_norm, eps) 292 | self.norm2 = WanLayerNorm(dim, eps) 293 | self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU(approximate="tanh"), nn.Linear(ffn_dim, dim)) 294 | 295 | # modulation 296 | self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) 297 | 298 | def set_ar_attention(self): 299 | self.self_attn.set_ar_attention() 300 | 301 | def forward( 302 | self, 303 | x, 304 | e, 305 | grid_sizes, 306 | freqs, 307 | context, 308 | block_mask, 309 | ): 310 | r""" 311 | Args: 312 | x(Tensor): Shape [B, L, C] 313 | e(Tensor): Shape [B, 6, C], or [B, 6, L, C] when every token has its own (diffusion-forcing) timestep 314 | grid_sizes(Tensor): Shape [3], the (F, H, W) patch grid shared by the whole batch 315 | freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] 316 | context(Tensor): Shape [B, L2, C]; block_mask(Tensor or None): causal self-attention mask 317 | """ 318 | if e.dim() == 3: 319 | modulation = self.modulation # 1, 6, dim 320 | with amp.autocast("cuda", dtype=torch.float32): 321 | e = (modulation + e).chunk(6, dim=1) 322 | elif e.dim() == 4: 323 | modulation = self.modulation.unsqueeze(2) # 1, 6, 1, dim 324 |
| with amp.autocast("cuda", dtype=torch.float32): 325 | e = (modulation + e).chunk(6, dim=1) 326 | e = [ei.squeeze(1) for ei in e] 327 | 328 | # self-attention 329 | out = mul_add_add_compile(self.norm1(x), e[1], e[0]) 330 | y = self.self_attn(out, grid_sizes, freqs, block_mask) 331 | with amp.autocast("cuda", dtype=torch.float32): 332 | x = mul_add_compile(x, y, e[2]) 333 | 334 | # cross-attention & ffn function 335 | def cross_attn_ffn(x, context, e): 336 | dtype = context.dtype 337 | x = x + self.cross_attn(self.norm3(x.to(dtype)), context) 338 | y = self.ffn(mul_add_add_compile(self.norm2(x), e[4], e[3]).to(dtype)) 339 | with amp.autocast("cuda", dtype=torch.float32): 340 | x = mul_add_compile(x, y, e[5]) 341 | return x 342 | 343 | x = cross_attn_ffn(x, context, e) 344 | return x.to(torch.bfloat16) 345 | 346 | 347 | class Head(nn.Module): 348 | def __init__(self, dim, out_dim, patch_size, eps=1e-6): 349 | super().__init__() 350 | self.dim = dim 351 | self.out_dim = out_dim 352 | self.patch_size = patch_size 353 | self.eps = eps 354 | 355 | # layers 356 | out_dim = math.prod(patch_size) * out_dim 357 | self.norm = WanLayerNorm(dim, eps) 358 | self.head = nn.Linear(dim, out_dim) 359 | 360 | # modulation 361 | self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5) 362 | 363 | def forward(self, x, e): 364 | r""" 365 | Args: 366 | x(Tensor): Shape [B, L1, C] 367 | e(Tensor): Shape [B, C] 368 | """ 369 | with amp.autocast("cuda", dtype=torch.float32): 370 | if e.dim() == 2: 371 | modulation = self.modulation # 1, 2, dim 372 | e = (modulation + e.unsqueeze(1)).chunk(2, dim=1) 373 | 374 | elif e.dim() == 3: 375 | modulation = self.modulation.unsqueeze(2) # 1, 2, seq, dim 376 | e = (modulation + e.unsqueeze(1)).chunk(2, dim=1) 377 | e = [ei.squeeze(1) for ei in e] 378 | x = self.head(self.norm(x) * (1 + e[1]) + e[0]) 379 | return x 380 | 381 | 382 | class MLPProj(torch.nn.Module): 383 | def __init__(self, in_dim, out_dim): 384 | super().__init__() 385 | 386 | 
self.proj = torch.nn.Sequential( 387 | torch.nn.LayerNorm(in_dim), 388 | torch.nn.Linear(in_dim, in_dim), 389 | torch.nn.GELU(), 390 | torch.nn.Linear(in_dim, out_dim), 391 | torch.nn.LayerNorm(out_dim), 392 | ) 393 | 394 | def forward(self, image_embeds): 395 | clip_extra_context_tokens = self.proj(image_embeds) 396 | return clip_extra_context_tokens 397 | 398 | 399 | class WanModel(ModelMixin, ConfigMixin, PeftAdapterMixin): 400 | r""" 401 | Wan diffusion backbone supporting both text-to-video and image-to-video. 402 | """ 403 | 404 | ignore_for_config = ["patch_size", "cross_attn_norm", "qk_norm", "text_dim", "window_size"] 405 | _no_split_modules = ["WanAttentionBlock"] 406 | 407 | _supports_gradient_checkpointing = True 408 | 409 | @register_to_config 410 | def __init__( 411 | self, 412 | model_type="t2v", 413 | patch_size=(1, 2, 2), 414 | text_len=512, 415 | in_dim=16, 416 | dim=2048, 417 | ffn_dim=8192, 418 | freq_dim=256, 419 | text_dim=4096, 420 | out_dim=16, 421 | num_heads=16, 422 | num_layers=32, 423 | window_size=(-1, -1), 424 | qk_norm=True, 425 | cross_attn_norm=True, 426 | inject_sample_info=False, 427 | eps=1e-6, 428 | ): 429 | r""" 430 | Initialize the diffusion model backbone. 
431 | 432 | Args: 433 | model_type (`str`, *optional*, defaults to 't2v'): 434 | Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video) 435 | patch_size (`tuple`, *optional*, defaults to (1, 2, 2)): 436 | 3D patch dimensions for video embedding (t_patch, h_patch, w_patch) 437 | text_len (`int`, *optional*, defaults to 512): 438 | Fixed length for text embeddings 439 | in_dim (`int`, *optional*, defaults to 16): 440 | Input video channels (C_in) 441 | dim (`int`, *optional*, defaults to 2048): 442 | Hidden dimension of the transformer 443 | ffn_dim (`int`, *optional*, defaults to 8192): 444 | Intermediate dimension of the feed-forward network 445 | freq_dim (`int`, *optional*, defaults to 256): 446 | Dimension for sinusoidal time embeddings 447 | text_dim (`int`, *optional*, defaults to 4096): 448 | Input dimension for text embeddings 449 | out_dim (`int`, *optional*, defaults to 16): 450 | Output video channels (C_out) 451 | num_heads (`int`, *optional*, defaults to 16): 452 | Number of attention heads 453 | num_layers (`int`, *optional*, defaults to 32): 454 | Number of transformer blocks 455 | window_size (`tuple`, *optional*, defaults to (-1, -1)): 456 | Window size for local attention (-1 indicates global attention) 457 | qk_norm (`bool`, *optional*, defaults to True): 458 | Enable query/key normalization 459 | cross_attn_norm (`bool`, *optional*, defaults to True): 460 | Apply an affine LayerNorm (norm3) to the input of cross-attention 461 | eps (`float`, *optional*, defaults to 1e-6): 462 | Epsilon value for normalization layers 463 | """ 464 | 465 | super().__init__() 466 | 467 | assert model_type in ["t2v", "i2v"] 468 | self.model_type = model_type 469 | 470 | self.patch_size = patch_size 471 | self.text_len = text_len 472 | self.in_dim = in_dim 473 | self.dim = dim 474 | self.ffn_dim = ffn_dim 475 | self.freq_dim = freq_dim 476 | self.text_dim = text_dim 477 | self.out_dim = out_dim 478 | self.num_heads = num_heads 479 | self.num_layers = num_layers 480 | self.window_size =
window_size 481 | self.qk_norm = qk_norm 482 | self.cross_attn_norm = cross_attn_norm 483 | self.eps = eps 484 | self.num_frame_per_block = 1 485 | self.flag_causal_attention = False 486 | self.block_mask = None 487 | self.enable_teacache = False 488 | 489 | # embeddings 490 | self.patch_embedding = nn.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size) 491 | self.text_embedding = nn.Sequential(nn.Linear(text_dim, dim), nn.GELU(approximate="tanh"), nn.Linear(dim, dim)) 492 | 493 | self.time_embedding = nn.Sequential(nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim)) 494 | self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6)) 495 | 496 | if inject_sample_info: 497 | self.fps_embedding = nn.Embedding(2, dim) 498 | self.fps_projection = nn.Sequential(nn.Linear(dim, dim), nn.SiLU(), nn.Linear(dim, dim * 6)) 499 | 500 | # blocks 501 | cross_attn_type = "t2v_cross_attn" if model_type == "t2v" else "i2v_cross_attn" 502 | self.blocks = nn.ModuleList( 503 | [ 504 | WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps) 505 | for _ in range(num_layers) 506 | ] 507 | ) 508 | 509 | # head 510 | self.head = Head(dim, out_dim, patch_size, eps) 511 | 512 | # buffers (don't use register_buffer otherwise dtype will be changed in to()) 513 | assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0 514 | d = dim // num_heads 515 | self.freqs = torch.cat( 516 | [rope_params(1024, d - 4 * (d // 6)), rope_params(1024, 2 * (d // 6)), rope_params(1024, 2 * (d // 6))], 517 | dim=1, 518 | ) 519 | 520 | if model_type == "i2v": 521 | self.img_emb = MLPProj(1280, dim) 522 | 523 | self.gradient_checkpointing = False 524 | 525 | self.cpu_offloading = False 526 | 527 | self.inject_sample_info = inject_sample_info 528 | # initialize weights 529 | self.init_weights() 530 | 531 | def _set_gradient_checkpointing(self, module, value=False): 532 | self.gradient_checkpointing = value 533 | 534 | def 
zero_init_i2v_cross_attn(self): 535 | print("zero init i2v cross attn") 536 | for i in range(self.num_layers): 537 | self.blocks[i].cross_attn.v_img.weight.data.zero_() 538 | self.blocks[i].cross_attn.v_img.bias.data.zero_() 539 | 540 | @staticmethod 541 | def _prepare_blockwise_causal_attn_mask( 542 | device: torch.device | str, num_frames: int = 21, frame_seqlen: int = 1560, num_frame_per_block=1 543 | ) -> BlockMask: 544 | """ 545 | we will divide the token sequence into the following format 546 | [1 latent frame] [1 latent frame] ... [1 latent frame] 547 | We use flexattention to construct the attention mask 548 | """ 549 | total_length = num_frames * frame_seqlen 550 | 551 | # we do right padding to get to a multiple of 128 552 | padded_length = math.ceil(total_length / 128) * 128 - total_length 553 | 554 | ends = torch.zeros(total_length + padded_length, device=device, dtype=torch.long) 555 | 556 | # Block-wise causal mask will attend to all elements that are before the end of the current chunk 557 | frame_indices = torch.arange(start=0, end=total_length, step=frame_seqlen * num_frame_per_block, device=device) 558 | 559 | for tmp in frame_indices: 560 | ends[tmp : tmp + frame_seqlen * num_frame_per_block] = tmp + frame_seqlen * num_frame_per_block 561 | 562 | def attention_mask(b, h, q_idx, kv_idx): 563 | return (kv_idx < ends[q_idx]) | (q_idx == kv_idx) 564 | # return ((kv_idx < total_length) & (q_idx < total_length)) | (q_idx == kv_idx) # bidirectional mask 565 | 566 | block_mask = create_block_mask( 567 | attention_mask, 568 | B=None, 569 | H=None, 570 | Q_LEN=total_length + padded_length, 571 | KV_LEN=total_length + padded_length, 572 | _compile=False, 573 | device=device, 574 | ) 575 | 576 | return block_mask 577 | 578 | def initialize_teacache(self, enable_teacache=True, num_steps=25, teacache_thresh=0.15, use_ret_steps=False, ckpt_dir=''): 579 | self.enable_teacache = enable_teacache 580 | print('using teacache') 581 | self.cnt = 0 582 | self.num_steps 
= num_steps 583 | self.teacache_thresh = teacache_thresh 584 | self.accumulated_rel_l1_distance_even = 0 585 | self.accumulated_rel_l1_distance_odd = 0 586 | self.previous_e0_even = None 587 | self.previous_e0_odd = None 588 | self.previous_residual_even = None 589 | self.previous_residual_odd = None 590 | self.use_ref_steps = use_ret_steps 591 | if "I2V" in ckpt_dir: 592 | if use_ret_steps: 593 | if '540P' in ckpt_dir: 594 | self.coefficients = [ 2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01] 595 | if '720P' in ckpt_dir: 596 | self.coefficients = [ 8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02] 597 | self.ret_steps = 5*2 598 | self.cutoff_steps = num_steps*2 599 | else: 600 | if '540P' in ckpt_dir: 601 | self.coefficients = [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01] 602 | if '720P' in ckpt_dir: 603 | self.coefficients = [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683] 604 | self.ret_steps = 1*2 605 | self.cutoff_steps = num_steps*2 - 2 606 | else: 607 | if use_ret_steps: 608 | if '1.3B' in ckpt_dir: 609 | self.coefficients = [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02] 610 | if '14B' in ckpt_dir: 611 | self.coefficients = [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01] 612 | self.ret_steps = 5*2 613 | self.cutoff_steps = num_steps*2 614 | else: 615 | if '1.3B' in ckpt_dir: 616 | self.coefficients = [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01] 617 | if '14B' in ckpt_dir: 618 | self.coefficients = [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404] 619 | self.ret_steps = 1*2 620 | self.cutoff_steps = num_steps*2 - 2 621 | 622 | def forward(self, x, t, context, clip_fea=None, y=None, fps=None): 623 | r""" 624 | Forward pass through the diffusion model 625 | 626 | Args: 627 | x (Tensor): 628 |
Input video latents of shape [B, C_in, F, H, W] 629 | t (Tensor): 630 | Diffusion timesteps of shape [B], or [B, F] for per-frame (diffusion-forcing) timesteps 631 | context (Tensor): 632 | Text embeddings of shape [B, L, text_dim] 633 | fps (list of `int`, *optional*): 634 | Indices into the two-entry fps_embedding table; used only when inject_sample_info is True 635 | clip_fea (Tensor, *optional*): 636 | CLIP image features for image-to-video mode 637 | y (Tensor, *optional*): 638 | Conditional video input for image-to-video mode, concatenated with x along the channel dimension 639 | 640 | Returns: 641 | Tensor: 642 | Denoised latents of shape [B, C_out, F, H, W], returned as float32 643 | """ 644 | if self.model_type == "i2v": 645 | assert clip_fea is not None and y is not None 646 | # params 647 | device = self.patch_embedding.weight.device 648 | if self.freqs.device != device: 649 | self.freqs = self.freqs.to(device) 650 | 651 | if y is not None: 652 | x = torch.cat([x, y], dim=1) 653 | 654 | # embeddings 655 | x = self.patch_embedding(x) 656 | grid_sizes = torch.tensor(x.shape[2:], dtype=torch.long) 657 | x = x.flatten(2).transpose(1, 2) 658 | 659 | if self.flag_causal_attention: 660 | frame_num = grid_sizes[0] 661 | height = grid_sizes[1] 662 | width = grid_sizes[2] 663 | block_num = frame_num // self.num_frame_per_block 664 | range_tensor = torch.arange(block_num).view(-1, 1) 665 | range_tensor = range_tensor.repeat(1, self.num_frame_per_block).flatten() 666 | causal_mask = range_tensor.unsqueeze(0) <= range_tensor.unsqueeze(1) # f, f 667 | causal_mask = causal_mask.view(frame_num, 1, 1, frame_num, 1, 1).to(x.device) 668 | causal_mask = causal_mask.repeat(1, height, width, 1, height, width) 669 | causal_mask = causal_mask.reshape(frame_num * height * width, frame_num * height * width) 670 | self.block_mask = causal_mask.unsqueeze(0).unsqueeze(0) 671 | 672 | # time embeddings 673 | with amp.autocast("cuda", dtype=torch.float32): 674 | if t.dim() == 2: 675 | b, f = t.shape 676 | _flag_df = True 677 | else: 678 | _flag_df = False 679 | 680 | e =
self.time_embedding( 681 | sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(self.patch_embedding.weight.dtype) 682 | ) # b, dim 683 | e0 = self.time_projection(e).unflatten(1, (6, self.dim)) # b, 6, dim 684 | 685 | if self.inject_sample_info: 686 | fps = torch.tensor(fps, dtype=torch.long, device=device) 687 | 688 | fps_emb = self.fps_embedding(fps).float() 689 | if _flag_df: 690 | e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim)).repeat(t.shape[1], 1, 1) 691 | else: 692 | e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim)) 693 | 694 | if _flag_df: 695 | e = e.view(b, f, 1, 1, self.dim) 696 | e0 = e0.view(b, f, 1, 1, 6, self.dim) 697 | e = e.repeat(1, 1, grid_sizes[1], grid_sizes[2], 1).flatten(1, 3) 698 | e0 = e0.repeat(1, 1, grid_sizes[1], grid_sizes[2], 1, 1).flatten(1, 3) 699 | e0 = e0.transpose(1, 2).contiguous() 700 | 701 | assert e.dtype == torch.float32 and e0.dtype == torch.float32 702 | 703 | # context 704 | context = self.text_embedding(context) 705 | 706 | if clip_fea is not None: 707 | context_clip = self.img_emb(clip_fea) # bs x 257 x dim 708 | context = torch.concat([context_clip, context], dim=1) 709 | 710 | # arguments 711 | kwargs = dict(e=e0, grid_sizes=grid_sizes, freqs=self.freqs, context=context, block_mask=self.block_mask) 712 | if self.enable_teacache: 713 | modulated_inp = e0 if self.use_ref_steps else e 714 | # TeaCache: skip the blocks while the accumulated, rescaled relative L1 change of modulated_inp stays below teacache_thresh 715 | if self.cnt%2==0: # even calls -> conditional pass 716 | self.is_even = True 717 | if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps: 718 | should_calc_even = True 719 | self.accumulated_rel_l1_distance_even = 0 720 | else: 721 | rescale_func = np.poly1d(self.coefficients) 722 | self.accumulated_rel_l1_distance_even += rescale_func(((modulated_inp-self.previous_e0_even).abs().mean() / self.previous_e0_even.abs().mean()).cpu().item()) 723 | if self.accumulated_rel_l1_distance_even < self.teacache_thresh: 724 | should_calc_even = False 725 | else: 726 | should_calc_even = True 727 |
self.accumulated_rel_l1_distance_even = 0 728 | self.previous_e0_even = modulated_inp.clone() 729 | 730 | else: # odd calls -> unconditional pass 731 | self.is_even = False 732 | if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps: 733 | should_calc_odd = True 734 | self.accumulated_rel_l1_distance_odd = 0 735 | else: 736 | rescale_func = np.poly1d(self.coefficients) 737 | self.accumulated_rel_l1_distance_odd += rescale_func(((modulated_inp-self.previous_e0_odd).abs().mean() / self.previous_e0_odd.abs().mean()).cpu().item()) 738 | if self.accumulated_rel_l1_distance_odd < self.teacache_thresh: 739 | should_calc_odd = False 740 | else: 741 | should_calc_odd = True 742 | self.accumulated_rel_l1_distance_odd = 0 743 | self.previous_e0_odd = modulated_inp.clone() 744 | 745 | if self.enable_teacache: 746 | if self.is_even: 747 | if not should_calc_even: 748 | x += self.previous_residual_even 749 | else: 750 | ori_x = x.clone() 751 | for block in self.blocks: 752 | x = block(x, **kwargs) 753 | self.previous_residual_even = x - ori_x 754 | else: 755 | if not should_calc_odd: 756 | x += self.previous_residual_odd 757 | else: 758 | ori_x = x.clone() 759 | for block in self.blocks: 760 | x = block(x, **kwargs) 761 | self.previous_residual_odd = x - ori_x 762 | 763 | self.cnt += 1 764 | if self.cnt >= self.num_steps: 765 | self.cnt = 0 766 | else: 767 | for block in self.blocks: 768 | x = block(x, **kwargs) 769 | 770 | x = self.head(x, e) 771 | 772 | # unpatchify 773 | x = self.unpatchify(x, grid_sizes) 774 | 775 | return x.float() 776 | 777 | def unpatchify(self, x, grid_sizes): 778 | r""" 779 | Reconstruct video tensors from patch embeddings.
780 | 781 | Args: 782 | x (Tensor): 783 | Patchified features with shape [B, L, C_out * prod(patch_size)] 784 | grid_sizes (Tensor): 785 | Grid dimensions of the patch sequence, shared by every sample in the batch, 786 | 3 elements (F_patches, H_patches, W_patches) 787 | 788 | Returns: 789 | Tensor: 790 | Reconstructed latent video with shape [B, C_out, F_patches * p_t, H_patches * p_h, W_patches * p_w] 791 | """ 792 | 793 | c = self.out_dim 794 | bs = x.shape[0] 795 | x = x.view(bs, *grid_sizes, *self.patch_size, c) 796 | x = torch.einsum("bfhwpqrc->bcfphqwr", x) 797 | x = x.reshape(bs, c, *[i * j for i, j in zip(grid_sizes, self.patch_size)]) 798 | 799 | return x 800 | 801 | def set_ar_attention(self, causal_block_size): 802 | self.num_frame_per_block = causal_block_size 803 | self.flag_causal_attention = True 804 | for block in self.blocks: 805 | block.set_ar_attention() 806 | 807 | def init_weights(self): 808 | r""" 809 | Initialize parameters: Xavier-uniform for linear layers and the patch embedding, normal (std=0.02) for the embedding MLPs, zeros for the output head. 810 | """ 811 | 812 | # basic init 813 | for m in self.modules(): 814 | if isinstance(m, nn.Linear): 815 | nn.init.xavier_uniform_(m.weight) 816 | if m.bias is not None: 817 | nn.init.zeros_(m.bias) 818 | 819 | # init embeddings 820 | nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1)) 821 | for m in self.text_embedding.modules(): 822 | if isinstance(m, nn.Linear): 823 | nn.init.normal_(m.weight, std=0.02) 824 | for m in self.time_embedding.modules(): 825 | if isinstance(m, nn.Linear): 826 | nn.init.normal_(m.weight, std=0.02) 827 | 828 | if self.inject_sample_info: 829 | nn.init.normal_(self.fps_embedding.weight, std=0.02) 830 | 831 | for m in self.fps_projection.modules(): 832 | if isinstance(m, nn.Linear): 833 | nn.init.normal_(m.weight, std=0.02) 834 | 835 | nn.init.zeros_(self.fps_projection[-1].weight) 836 | nn.init.zeros_(self.fps_projection[-1].bias) 837 | 838 | # init output layer 839 | nn.init.zeros_(self.head.head.weight) 840 |
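The TeaCache branch in `forward` above decides, separately for the conditional (even) and unconditional (odd) calls, whether to run every transformer block or to add back the residual cached from the last full pass. The following is a minimal sketch of that decision under stated assumptions. It uses NumPy arrays instead of torch tensors and tracks a single branch. It also takes the step index as an argument rather than reading the shared `self.cnt` counter. `TeaCacheGate`, `should_calc` and `run` are hypothetical names and not part of this repository.

```python
import numpy as np


class TeaCacheGate:
    """Hypothetical single-branch sketch of the TeaCache skip rule (not the repo's API)."""

    def __init__(self, coefficients, thresh, ret_steps, cutoff_steps):
        self.rescale = np.poly1d(coefficients)  # rescales the raw relative L1 distance
        self.thresh = thresh
        self.ret_steps = ret_steps  # always compute before this step...
        self.cutoff_steps = cutoff_steps  # ...and from this step onward
        self.accumulated = 0.0
        self.previous = None  # modulated input seen at the previous step
        self.residual = None  # (output - input) of the last full pass

    def should_calc(self, step, modulated_inp):
        modulated_inp = np.asarray(modulated_inp, dtype=np.float64)
        if step < self.ret_steps or step >= self.cutoff_steps or self.previous is None:
            calc = True
            self.accumulated = 0.0
        else:
            rel_l1 = np.abs(modulated_inp - self.previous).mean() / np.abs(self.previous).mean()
            self.accumulated += float(self.rescale(rel_l1))
            calc = self.accumulated >= self.thresh
            if calc:
                self.accumulated = 0.0  # reset after a full recomputation
        self.previous = modulated_inp.copy()
        return calc

    def run(self, step, modulated_inp, x, blocks):
        x = np.asarray(x, dtype=np.float64)
        if self.should_calc(step, modulated_inp) or self.residual is None:
            out = x
            for block in blocks:
                out = block(out)
            self.residual = out - x  # cache the residual of the full pass
            return out
        return x + self.residual  # cheap path: reuse the cached residual
```

The repository runs this logic twice per denoising step, once per classifier-free-guidance branch. That is why it keeps separate `_even`/`_odd` accumulators, previous inputs and residuals. With a sketch like this you would create one gate per branch.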
-------------------------------------------------------------------------------- /skyreels_v2_infer/modules/vae.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 2 | import logging 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from einops import rearrange 8 | 9 | 10 | __all__ = [ 11 | "WanVAE", 12 | ] 13 | 14 | CACHE_T = 2 15 | 16 | 17 | class CausalConv3d(nn.Conv3d): 18 | """ 19 | Causal 3D convolution: temporal padding is applied only before the first frame, so outputs never depend on future frames. 20 | """ 21 | 22 | def __init__(self, *args, **kwargs): 23 | super().__init__(*args, **kwargs) 24 | self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0) 25 | self.padding = (0, 0, 0) 26 | 27 | def forward(self, x, cache_x=None): 28 | padding = list(self._padding) 29 | if cache_x is not None and self._padding[4] > 0: 30 | cache_x = cache_x.to(x.device) 31 | x = torch.cat([cache_x, x], dim=2) 32 | padding[4] -= cache_x.shape[2] 33 | x = F.pad(x, padding) 34 | 35 | return super().forward(x) 36 | 37 | 38 | class RMS_norm(nn.Module): 39 | def __init__(self, dim, channel_first=True, images=True, bias=False): 40 | super().__init__() 41 | broadcastable_dims = (1, 1, 1) if not images else (1, 1) 42 | shape = (dim, *broadcastable_dims) if channel_first else (dim,) 43 | 44 | self.channel_first = channel_first 45 | self.scale = dim**0.5 46 | self.gamma = nn.Parameter(torch.ones(shape)) 47 | self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 48 | 49 | def forward(self, x): 50 | return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias 51 | 52 | 53 | class Upsample(nn.Upsample): 54 | def forward(self, x): 55 | """ 56 | Fix bfloat16 support for nearest neighbor interpolation.
57 | """ 58 | return super().forward(x.float()).type_as(x) 59 | 60 | 61 | class Resample(nn.Module): 62 | def __init__(self, dim, mode): 63 | assert mode in ("none", "upsample2d", "upsample3d", "downsample2d", "downsample3d") 64 | super().__init__() 65 | self.dim = dim 66 | self.mode = mode 67 | 68 | # layers 69 | if mode == "upsample2d": 70 | self.resample = nn.Sequential( 71 | Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1) 72 | ) 73 | elif mode == "upsample3d": 74 | self.resample = nn.Sequential( 75 | Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1) 76 | ) 77 | self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) 78 | 79 | elif mode == "downsample2d": 80 | self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) 81 | elif mode == "downsample3d": 82 | self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) 83 | self.time_conv = CausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) 84 | 85 | else: 86 | self.resample = nn.Identity() 87 | 88 | def forward(self, x, feat_cache=None, feat_idx=[0]): 89 | b, c, t, h, w = x.size() 90 | if self.mode == "upsample3d": 91 | if feat_cache is not None: 92 | idx = feat_idx[0] 93 | if feat_cache[idx] is None: 94 | feat_cache[idx] = "Rep" 95 | feat_idx[0] += 1 96 | else: 97 | 98 | cache_x = x[:, :, -CACHE_T:, :, :].clone() 99 | if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep": 100 | # cache last frame of last two chunk 101 | cache_x = torch.cat( 102 | [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2 103 | ) 104 | if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep": 105 | cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2) 106 | if feat_cache[idx] == "Rep": 107 | x = self.time_conv(x) 108 | 
else: 109 | x = self.time_conv(x, feat_cache[idx]) 110 | feat_cache[idx] = cache_x 111 | feat_idx[0] += 1 112 | 113 | x = x.reshape(b, 2, c, t, h, w) 114 | x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3) 115 | x = x.reshape(b, c, t * 2, h, w) 116 | t = x.shape[2] 117 | x = rearrange(x, "b c t h w -> (b t) c h w") 118 | x = self.resample(x) 119 | x = rearrange(x, "(b t) c h w -> b c t h w", t=t) 120 | 121 | if self.mode == "downsample3d": 122 | if feat_cache is not None: 123 | idx = feat_idx[0] 124 | if feat_cache[idx] is None: 125 | feat_cache[idx] = x.clone() 126 | feat_idx[0] += 1 127 | else: 128 | 129 | cache_x = x[:, :, -1:, :, :].clone() 130 | # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep': 131 | # # cache last frame of last two chunk 132 | # cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) 133 | 134 | x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) 135 | feat_cache[idx] = cache_x 136 | feat_idx[0] += 1 137 | return x 138 | 139 | def init_weight(self, conv): 140 | conv_weight = conv.weight 141 | nn.init.zeros_(conv_weight) 142 | c1, c2, t, h, w = conv_weight.size() 143 | one_matrix = torch.eye(c1, c2) 144 | init_matrix = one_matrix 145 | nn.init.zeros_(conv_weight) 146 | # conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5 147 | conv_weight.data[:, :, 1, 0, 0] = init_matrix # * 0.5 148 | conv.weight.data.copy_(conv_weight) 149 | nn.init.zeros_(conv.bias.data) 150 | 151 | def init_weight2(self, conv): 152 | conv_weight = conv.weight.data 153 | nn.init.zeros_(conv_weight) 154 | c1, c2, t, h, w = conv_weight.size() 155 | init_matrix = torch.eye(c1 // 2, c2) 156 | # init_matrix = repeat(init_matrix, 'o ... 
-> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2) 157 | conv_weight[: c1 // 2, :, -1, 0, 0] = init_matrix 158 | conv_weight[c1 // 2 :, :, -1, 0, 0] = init_matrix 159 | conv.weight.data.copy_(conv_weight) 160 | nn.init.zeros_(conv.bias.data) 161 | 162 | 163 | class ResidualBlock(nn.Module): 164 | def __init__(self, in_dim, out_dim, dropout=0.0): 165 | super().__init__() 166 | self.in_dim = in_dim 167 | self.out_dim = out_dim 168 | 169 | # layers 170 | self.residual = nn.Sequential( 171 | RMS_norm(in_dim, images=False), 172 | nn.SiLU(), 173 | CausalConv3d(in_dim, out_dim, 3, padding=1), 174 | RMS_norm(out_dim, images=False), 175 | nn.SiLU(), 176 | nn.Dropout(dropout), 177 | CausalConv3d(out_dim, out_dim, 3, padding=1), 178 | ) 179 | self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity() 180 | 181 | def forward(self, x, feat_cache=None, feat_idx=[0]): 182 | h = self.shortcut(x) 183 | for layer in self.residual: 184 | if isinstance(layer, CausalConv3d) and feat_cache is not None: 185 | idx = feat_idx[0] 186 | cache_x = x[:, :, -CACHE_T:, :, :].clone() 187 | if cache_x.shape[2] < 2 and feat_cache[idx] is not None: 188 | # cache last frame of last two chunk 189 | cache_x = torch.cat( 190 | [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2 191 | ) 192 | x = layer(x, feat_cache[idx]) 193 | feat_cache[idx] = cache_x 194 | feat_idx[0] += 1 195 | else: 196 | x = layer(x) 197 | return x + h 198 | 199 | 200 | class AttentionBlock(nn.Module): 201 | """ 202 | Causal self-attention with a single head. 
203 | """ 204 | 205 | def __init__(self, dim): 206 | super().__init__() 207 | self.dim = dim 208 | 209 | # layers 210 | self.norm = RMS_norm(dim) 211 | self.to_qkv = nn.Conv2d(dim, dim * 3, 1) 212 | self.proj = nn.Conv2d(dim, dim, 1) 213 | 214 | # zero out the last layer params 215 | nn.init.zeros_(self.proj.weight) 216 | 217 | def forward(self, x): 218 | identity = x 219 | b, c, t, h, w = x.size() 220 | x = rearrange(x, "b c t h w -> (b t) c h w") 221 | x = self.norm(x) 222 | # compute query, key, value 223 | q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute(0, 1, 3, 2).contiguous().chunk(3, dim=-1) 224 | 225 | # apply attention 226 | x = F.scaled_dot_product_attention( 227 | q, 228 | k, 229 | v, 230 | ) 231 | x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w) 232 | 233 | # output 234 | x = self.proj(x) 235 | x = rearrange(x, "(b t) c h w-> b c t h w", t=t) 236 | return x + identity 237 | 238 | 239 | class Encoder3d(nn.Module): 240 | def __init__( 241 | self, 242 | dim=128, 243 | z_dim=4, 244 | dim_mult=[1, 2, 4, 4], 245 | num_res_blocks=2, 246 | attn_scales=[], 247 | temperal_downsample=[True, True, False], 248 | dropout=0.0, 249 | ): 250 | super().__init__() 251 | self.dim = dim 252 | self.z_dim = z_dim 253 | self.dim_mult = dim_mult 254 | self.num_res_blocks = num_res_blocks 255 | self.attn_scales = attn_scales 256 | self.temperal_downsample = temperal_downsample 257 | 258 | # dimensions 259 | dims = [dim * u for u in [1] + dim_mult] 260 | scale = 1.0 261 | 262 | # init block 263 | self.conv1 = CausalConv3d(3, dims[0], 3, padding=1) 264 | 265 | # downsample blocks 266 | downsamples = [] 267 | for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): 268 | # residual (+attention) blocks 269 | for _ in range(num_res_blocks): 270 | downsamples.append(ResidualBlock(in_dim, out_dim, dropout)) 271 | if scale in attn_scales: 272 | downsamples.append(AttentionBlock(out_dim)) 273 | in_dim = out_dim 274 | 275 | # downsample block 276 | if i != 
len(dim_mult) - 1: 277 | mode = "downsample3d" if temperal_downsample[i] else "downsample2d" 278 | downsamples.append(Resample(out_dim, mode=mode)) 279 | scale /= 2.0 280 | self.downsamples = nn.Sequential(*downsamples) 281 | 282 | # middle blocks 283 | self.middle = nn.Sequential( 284 | ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim), ResidualBlock(out_dim, out_dim, dropout) 285 | ) 286 | 287 | # output blocks 288 | self.head = nn.Sequential( 289 | RMS_norm(out_dim, images=False), nn.SiLU(), CausalConv3d(out_dim, z_dim, 3, padding=1) 290 | ) 291 | 292 | def forward(self, x, feat_cache=None, feat_idx=[0]): 293 | if feat_cache is not None: 294 | idx = feat_idx[0] 295 | cache_x = x[:, :, -CACHE_T:, :, :].clone() 296 | if cache_x.shape[2] < 2 and feat_cache[idx] is not None: 297 | # cache last frame of last two chunk 298 | cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) 299 | x = self.conv1(x, feat_cache[idx]) 300 | feat_cache[idx] = cache_x 301 | feat_idx[0] += 1 302 | else: 303 | x = self.conv1(x) 304 | 305 | ## downsamples 306 | for layer in self.downsamples: 307 | if feat_cache is not None: 308 | x = layer(x, feat_cache, feat_idx) 309 | else: 310 | x = layer(x) 311 | 312 | ## middle 313 | for layer in self.middle: 314 | if isinstance(layer, ResidualBlock) and feat_cache is not None: 315 | x = layer(x, feat_cache, feat_idx) 316 | else: 317 | x = layer(x) 318 | 319 | ## head 320 | for layer in self.head: 321 | if isinstance(layer, CausalConv3d) and feat_cache is not None: 322 | idx = feat_idx[0] 323 | cache_x = x[:, :, -CACHE_T:, :, :].clone() 324 | if cache_x.shape[2] < 2 and feat_cache[idx] is not None: 325 | # cache last frame of last two chunk 326 | cache_x = torch.cat( 327 | [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2 328 | ) 329 | x = layer(x, feat_cache[idx]) 330 | feat_cache[idx] = cache_x 331 | feat_idx[0] += 1 332 | else: 333 | x = layer(x) 334 | 
return x 335 | 336 | 337 | class Decoder3d(nn.Module): 338 | def __init__( 339 | self, 340 | dim=128, 341 | z_dim=4, 342 | dim_mult=[1, 2, 4, 4], 343 | num_res_blocks=2, 344 | attn_scales=[], 345 | temperal_upsample=[False, True, True], 346 | dropout=0.0, 347 | ): 348 | super().__init__() 349 | self.dim = dim 350 | self.z_dim = z_dim 351 | self.dim_mult = dim_mult 352 | self.num_res_blocks = num_res_blocks 353 | self.attn_scales = attn_scales 354 | self.temperal_upsample = temperal_upsample 355 | 356 | # dimensions 357 | dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] 358 | scale = 1.0 / 2 ** (len(dim_mult) - 2) 359 | 360 | # init block 361 | self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1) 362 | 363 | # middle blocks 364 | self.middle = nn.Sequential( 365 | ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]), ResidualBlock(dims[0], dims[0], dropout) 366 | ) 367 | 368 | # upsample blocks 369 | upsamples = [] 370 | for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): 371 | # residual (+attention) blocks 372 | if i == 1 or i == 2 or i == 3: 373 | in_dim = in_dim // 2 374 | for _ in range(num_res_blocks + 1): 375 | upsamples.append(ResidualBlock(in_dim, out_dim, dropout)) 376 | if scale in attn_scales: 377 | upsamples.append(AttentionBlock(out_dim)) 378 | in_dim = out_dim 379 | 380 | # upsample block 381 | if i != len(dim_mult) - 1: 382 | mode = "upsample3d" if temperal_upsample[i] else "upsample2d" 383 | upsamples.append(Resample(out_dim, mode=mode)) 384 | scale *= 2.0 385 | self.upsamples = nn.Sequential(*upsamples) 386 | 387 | # output blocks 388 | self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(), CausalConv3d(out_dim, 3, 3, padding=1)) 389 | 390 | def forward(self, x, feat_cache=None, feat_idx=[0]): 391 | ## conv1 392 | if feat_cache is not None: 393 | idx = feat_idx[0] 394 | cache_x = x[:, :, -CACHE_T:, :, :].clone() 395 | if cache_x.shape[2] < 2 and feat_cache[idx] is not None: 396 | # cache last 
frame of last two chunk 397 | cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) 398 | x = self.conv1(x, feat_cache[idx]) 399 | feat_cache[idx] = cache_x 400 | feat_idx[0] += 1 401 | else: 402 | x = self.conv1(x) 403 | 404 | ## middle 405 | for layer in self.middle: 406 | if isinstance(layer, ResidualBlock) and feat_cache is not None: 407 | x = layer(x, feat_cache, feat_idx) 408 | else: 409 | x = layer(x) 410 | 411 | ## upsamples 412 | for layer in self.upsamples: 413 | if feat_cache is not None: 414 | x = layer(x, feat_cache, feat_idx) 415 | else: 416 | x = layer(x) 417 | 418 | ## head 419 | for layer in self.head: 420 | if isinstance(layer, CausalConv3d) and feat_cache is not None: 421 | idx = feat_idx[0] 422 | cache_x = x[:, :, -CACHE_T:, :, :].clone() 423 | if cache_x.shape[2] < 2 and feat_cache[idx] is not None: 424 | # cache last frame of last two chunk 425 | cache_x = torch.cat( 426 | [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2 427 | ) 428 | x = layer(x, feat_cache[idx]) 429 | feat_cache[idx] = cache_x 430 | feat_idx[0] += 1 431 | else: 432 | x = layer(x) 433 | return x 434 | 435 | 436 | def count_conv3d(model): 437 | count = 0 438 | for m in model.modules(): 439 | if isinstance(m, CausalConv3d): 440 | count += 1 441 | return count 442 | 443 | 444 | class WanVAE_(nn.Module): 445 | def __init__( 446 | self, 447 | dim=128, 448 | z_dim=4, 449 | dim_mult=[1, 2, 4, 4], 450 | num_res_blocks=2, 451 | attn_scales=[], 452 | temperal_downsample=[True, True, False], 453 | dropout=0.0, 454 | ): 455 | super().__init__() 456 | self.dim = dim 457 | self.z_dim = z_dim 458 | self.dim_mult = dim_mult 459 | self.num_res_blocks = num_res_blocks 460 | self.attn_scales = attn_scales 461 | self.temperal_downsample = temperal_downsample 462 | self.temperal_upsample = temperal_downsample[::-1] 463 | 464 | # modules 465 | self.encoder = Encoder3d( 466 | dim, z_dim * 2, dim_mult, num_res_blocks, 
attn_scales, self.temperal_downsample, dropout 467 | ) 468 | self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1) 469 | self.conv2 = CausalConv3d(z_dim, z_dim, 1) 470 | self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout) 471 | 472 | def forward(self, x): 473 | mu, log_var = self.encode(x) 474 | z = self.reparameterize(mu, log_var) 475 | x_recon = self.decode(z) 476 | return x_recon, mu, log_var 477 | 478 | def encode(self, x, scale): 479 | self.clear_cache() 480 | ## cache 481 | t = x.shape[2] 482 | iter_ = 1 + (t - 1) // 4 483 | ## split the encoder input x along time into chunks of 1, 4, 4, 4, ... frames 484 | for i in range(iter_): 485 | self._enc_conv_idx = [0] 486 | if i == 0: 487 | out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) 488 | else: 489 | out_ = self.encoder( 490 | x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :], 491 | feat_cache=self._enc_feat_map, 492 | feat_idx=self._enc_conv_idx, 493 | ) 494 | out = torch.cat([out, out_], 2) 495 | mu, log_var = self.conv1(out).chunk(2, dim=1) 496 | if isinstance(scale[0], torch.Tensor): 497 | mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1) 498 | else: 499 | mu = (mu - scale[0]) * scale[1] 500 | self.clear_cache() 501 | return mu 502 | 503 | def decode(self, z, scale): 504 | self.clear_cache() 505 | # z: [b,c,t,h,w] 506 | if isinstance(scale[0], torch.Tensor): 507 | z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1) 508 | else: 509 | z = z / scale[1] + scale[0] 510 | iter_ = z.shape[2] 511 | x = self.conv2(z) 512 | for i in range(iter_): 513 | self._conv_idx = [0] 514 | if i == 0: 515 | out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) 516 | else: 517 | out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) 518 | out = torch.cat([out, out_], 2) 519 | self.clear_cache() 520 | return out 521 | 522 |
def reparameterize(self, mu, log_var): 523 | std = torch.exp(0.5 * log_var) 524 | eps = torch.randn_like(std) 525 | return eps * std + mu 526 | 527 | def sample(self, imgs, deterministic=False): 528 | mu, log_var = self.encode(imgs) 529 | if deterministic: 530 | return mu 531 | std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0)) 532 | return mu + std * torch.randn_like(std) 533 | 534 | def clear_cache(self): 535 | self._conv_num = count_conv3d(self.decoder) 536 | self._conv_idx = [0] 537 | self._feat_map = [None] * self._conv_num 538 | # cache encode 539 | self._enc_conv_num = count_conv3d(self.encoder) 540 | self._enc_conv_idx = [0] 541 | self._enc_feat_map = [None] * self._enc_conv_num 542 | 543 | 544 | def _video_vae(pretrained_path=None, z_dim=None, device="cpu", **kwargs): 545 | """ 546 | Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL. 547 | """ 548 | # params 549 | cfg = dict( 550 | dim=96, 551 | z_dim=z_dim, 552 | dim_mult=[1, 2, 4, 4], 553 | num_res_blocks=2, 554 | attn_scales=[], 555 | temperal_downsample=[False, True, True], 556 | dropout=0.0, 557 | ) 558 | cfg.update(**kwargs) 559 | 560 | # init model 561 | with torch.device("meta"): 562 | model = WanVAE_(**cfg) 563 | 564 | # load checkpoint 565 | logging.info(f"loading {pretrained_path}") 566 | model.load_state_dict(torch.load(pretrained_path, map_location=device), assign=True) 567 | 568 | return model 569 | 570 | 571 | class WanVAE: 572 | def __init__(self, vae_pth="cache/vae_step_411000.pth", z_dim=16): 573 | 574 | mean = [ 575 | -0.7571, 576 | -0.7089, 577 | -0.9113, 578 | 0.1075, 579 | -0.1745, 580 | 0.9653, 581 | -0.1517, 582 | 1.5508, 583 | 0.4134, 584 | -0.0715, 585 | 0.5517, 586 | -0.3632, 587 | -0.1922, 588 | -0.9497, 589 | 0.2503, 590 | -0.2921, 591 | ] 592 | std = [ 593 | 2.8184, 594 | 1.4541, 595 | 2.3275, 596 | 2.6558, 597 | 1.2196, 598 | 1.7708, 599 | 2.6052, 600 | 2.0743, 601 | 3.2687, 602 | 2.1526, 603 | 2.8652, 604 | 1.5579, 605 | 1.6382, 606 | 1.1253, 607 | 2.8251, 608 | 
1.9160, 609 | ] 610 | self.vae_stride = (4, 8, 8) 611 | self.mean = torch.tensor(mean) 612 | self.std = torch.tensor(std) 613 | self.scale = [self.mean, 1.0 / self.std] 614 | 615 | # init model 616 | self.vae = ( 617 | _video_vae( 618 | pretrained_path=vae_pth, 619 | z_dim=z_dim, 620 | ) 621 | .eval() 622 | .requires_grad_(False) 623 | ) 624 | 625 | def encode(self, video): 626 | """ 627 | video: Tensor of shape [B, C, T, H, W] with values in [-1, 1]; returns normalized latents of shape [B, z_dim, 1 + (T - 1) // 4, H / 8, W / 8]. 628 | """ 629 | return self.vae.encode(video, self.scale).float() 630 | 631 | def to(self, *args, **kwargs): 632 | self.mean = self.mean.to(*args, **kwargs) 633 | self.std = self.std.to(*args, **kwargs) 634 | self.scale = [self.mean, 1.0 / self.std] 635 | self.vae = self.vae.to(*args, **kwargs) 636 | return self 637 | 638 | def decode(self, z): 639 | return self.vae.decode(z, self.scale).float().clamp_(-1, 1) 640 | -------------------------------------------------------------------------------- /skyreels_v2_infer/modules/xlm_roberta.py: -------------------------------------------------------------------------------- 1 | # Modified from transformers.models.xlm_roberta.modeling_xlm_roberta 2 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | __all__ = ["XLMRoberta", "xlm_roberta_large"] 8 | 9 | 10 | class SelfAttention(nn.Module): 11 | def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5): 12 | assert dim % num_heads == 0 13 | super().__init__() 14 | self.dim = dim 15 | self.num_heads = num_heads 16 | self.head_dim = dim // num_heads 17 | self.eps = eps 18 | 19 | # layers 20 | self.q = nn.Linear(dim, dim) 21 | self.k = nn.Linear(dim, dim) 22 | self.v = nn.Linear(dim, dim) 23 | self.o = nn.Linear(dim, dim) 24 | self.dropout = nn.Dropout(dropout) 25 | 26 | def forward(self, x, mask): 27 | """ 28 | x: [B, L, C].
29 | """ 30 | b, s, c, n, d = *x.size(), self.num_heads, self.head_dim 31 | 32 | # compute query, key, value 33 | q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3) 34 | k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3) 35 | v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3) 36 | 37 | # compute attention 38 | p = self.dropout.p if self.training else 0.0 39 | x = F.scaled_dot_product_attention(q, k, v, mask, p) 40 | x = x.permute(0, 2, 1, 3).reshape(b, s, c) 41 | 42 | # output 43 | x = self.o(x) 44 | x = self.dropout(x) 45 | return x 46 | 47 | 48 | class AttentionBlock(nn.Module): 49 | def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5): 50 | super().__init__() 51 | self.dim = dim 52 | self.num_heads = num_heads 53 | self.post_norm = post_norm 54 | self.eps = eps 55 | 56 | # layers 57 | self.attn = SelfAttention(dim, num_heads, dropout, eps) 58 | self.norm1 = nn.LayerNorm(dim, eps=eps) 59 | self.ffn = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim), nn.Dropout(dropout)) 60 | self.norm2 = nn.LayerNorm(dim, eps=eps) 61 | 62 | def forward(self, x, mask): 63 | if self.post_norm: 64 | x = self.norm1(x + self.attn(x, mask)) 65 | x = self.norm2(x + self.ffn(x)) 66 | else: 67 | x = x + self.attn(self.norm1(x), mask) 68 | x = x + self.ffn(self.norm2(x)) 69 | return x 70 | 71 | 72 | class XLMRoberta(nn.Module): 73 | """ 74 | XLMRobertaModel with no pooler and no LM head. 
75 | """ 76 | 77 | def __init__( 78 | self, 79 | vocab_size=250002, 80 | max_seq_len=514, 81 | type_size=1, 82 | pad_id=1, 83 | dim=1024, 84 | num_heads=16, 85 | num_layers=24, 86 | post_norm=True, 87 | dropout=0.1, 88 | eps=1e-5, 89 | ): 90 | super().__init__() 91 | self.vocab_size = vocab_size 92 | self.max_seq_len = max_seq_len 93 | self.type_size = type_size 94 | self.pad_id = pad_id 95 | self.dim = dim 96 | self.num_heads = num_heads 97 | self.num_layers = num_layers 98 | self.post_norm = post_norm 99 | self.eps = eps 100 | 101 | # embeddings 102 | self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id) 103 | self.type_embedding = nn.Embedding(type_size, dim) 104 | self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id) 105 | self.dropout = nn.Dropout(dropout) 106 | 107 | # blocks 108 | self.blocks = nn.ModuleList( 109 | [AttentionBlock(dim, num_heads, post_norm, dropout, eps) for _ in range(num_layers)] 110 | ) 111 | 112 | # norm layer 113 | self.norm = nn.LayerNorm(dim, eps=eps) 114 | 115 | def forward(self, ids): 116 | """ 117 | ids: [B, L] of torch.LongTensor. 118 | """ 119 | b, s = ids.shape 120 | mask = ids.ne(self.pad_id).long() 121 | 122 | # embeddings 123 | x = ( 124 | self.token_embedding(ids) 125 | + self.type_embedding(torch.zeros_like(ids)) 126 | + self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask) 127 | ) 128 | if self.post_norm: 129 | x = self.norm(x) 130 | x = self.dropout(x) 131 | 132 | # blocks 133 | mask = torch.where(mask.view(b, 1, 1, s).gt(0), 0.0, torch.finfo(x.dtype).min) 134 | for block in self.blocks: 135 | x = block(x, mask) 136 | 137 | # output 138 | if not self.post_norm: 139 | x = self.norm(x) 140 | return x 141 | 142 | 143 | def xlm_roberta_large(pretrained=False, return_tokenizer=False, device="cpu", **kwargs): 144 | """ 145 | XLMRobertaLarge adapted from Huggingface. 
146 | """ 147 | # params 148 | cfg = dict( 149 | vocab_size=250002, 150 | max_seq_len=514, 151 | type_size=1, 152 | pad_id=1, 153 | dim=1024, 154 | num_heads=16, 155 | num_layers=24, 156 | post_norm=True, 157 | dropout=0.1, 158 | eps=1e-5, 159 | ) 160 | cfg.update(**kwargs) 161 | 162 | # init a model on device 163 | with torch.device(device): 164 | model = XLMRoberta(**cfg) 165 | return model 166 | -------------------------------------------------------------------------------- /skyreels_v2_infer/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | from .diffusion_forcing_pipeline import DiffusionForcingPipeline 2 | from .image2video_pipeline import Image2VideoPipeline 3 | from .image2video_pipeline import resizecrop 4 | from .prompt_enhancer import PromptEnhancer 5 | from .text2video_pipeline import Text2VideoPipeline 6 | -------------------------------------------------------------------------------- /skyreels_v2_infer/pipelines/diffusion_forcing_pipeline.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | from typing import List 4 | from typing import Optional 5 | from typing import Tuple 6 | from typing import Union 7 | 8 | import numpy as np 9 | import torch 10 | from diffusers.image_processor import PipelineImageInput 11 | from diffusers.utils.torch_utils import randn_tensor 12 | from diffusers.video_processor import VideoProcessor 13 | from tqdm import tqdm 14 | 15 | from ..modules import get_text_encoder 16 | from ..modules import get_transformer 17 | from ..modules import get_vae 18 | from ..scheduler.fm_solvers_unipc import FlowUniPCMultistepScheduler 19 | 20 | 21 | class DiffusionForcingPipeline: 22 | """ 23 | A pipeline for diffusion-based video generation tasks. 
24 | 25 | This pipeline supports two main tasks: 26 | - Image-to-Video (i2v): Generates a video sequence from a source image 27 | - Text-to-Video (t2v): Generates a video sequence from a text description 28 | 29 | The pipeline integrates multiple components including: 30 | - A transformer model for diffusion 31 | - A VAE for encoding/decoding 32 | - A text encoder for processing text prompts 33 | - In i2v mode, the source image is VAE-encoded into prefix latents (there is no separate image encoder) 34 | """ 35 | 36 | def __init__( 37 | self, 38 | model_path: str, 39 | dit_path: str, 40 | device: str = "cuda", 41 | weight_dtype=torch.bfloat16, 42 | use_usp=False, 43 | offload=False, 44 | ): 45 | """ 46 | Initialize the diffusion forcing pipeline. 47 | 48 | Args: 49 | model_path (str): Directory containing the VAE weights (Wan2.1_VAE.pth) and the text encoder 50 | dit_path (str): Path to the DiT model, containing the configuration file (config.json) and weight files (*.safetensors) 51 | device (str): Device to run on, defaults to 'cuda' 52 | weight_dtype: Weight data type, defaults to torch.bfloat16 53 | """ 54 | load_device = "cpu" if offload else device 55 | self.transformer = get_transformer(dit_path, load_device, weight_dtype) 56 | vae_model_path = os.path.join(model_path, "Wan2.1_VAE.pth") 57 | self.vae = get_vae(vae_model_path, device, weight_dtype=torch.float32) 58 | self.text_encoder = get_text_encoder(model_path, load_device, weight_dtype) 59 | self.video_processor = VideoProcessor(vae_scale_factor=16) 60 | self.device = device 61 | self.offload = offload 62 | 63 | if use_usp: 64 | from xfuser.core.distributed import get_sequence_parallel_world_size 65 | from ..distributed.xdit_context_parallel import usp_attn_forward, usp_dit_forward 66 | import types 67 | 68 | for block in self.transformer.blocks: 69 | block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn) 70 | self.transformer.forward = types.MethodType(usp_dit_forward, self.transformer) 71 | self.sp_size = get_sequence_parallel_world_size() 72 | 73 | self.scheduler =
FlowUniPCMultistepScheduler() 74 | 75 | @property 76 | def do_classifier_free_guidance(self) -> bool: 77 | return self._guidance_scale > 1 78 | 79 | def encode_image( 80 | self, image: PipelineImageInput, height: int, width: int, num_frames: int 81 | ) -> Tuple[List[torch.Tensor], int]: 82 | 83 | # prefix_video 84 | prefix_video = np.array(image.resize((width, height))).transpose(2, 0, 1) 85 | prefix_video = torch.tensor(prefix_video).unsqueeze(1) # (C, 1, H, W) 86 | if prefix_video.dtype == torch.uint8: 87 | prefix_video = (prefix_video.float() / (255.0 / 2.0)) - 1.0 88 | prefix_video = prefix_video.to(self.device) 89 | prefix_video = [self.vae.encode(prefix_video.unsqueeze(0))[0]] # [(c, f, h, w)] 90 | causal_block_size = self.transformer.num_frame_per_block 91 | if prefix_video[0].shape[1] % causal_block_size != 0: 92 | truncate_len = prefix_video[0].shape[1] % causal_block_size 93 | print("the length of prefix video is truncated for the causal block size alignment.") 94 | prefix_video[0] = prefix_video[0][:, : prefix_video[0].shape[1] - truncate_len] 95 | prefix_video_latent_length = prefix_video[0].shape[1] 96 | return prefix_video, prefix_video_latent_length 97 | 98 | def prepare_latents( 99 | self, 100 | shape: Tuple[int], 101 | dtype: Optional[torch.dtype] = None, 102 | device: Optional[torch.device] = None, 103 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 104 | ) -> torch.Tensor: 105 | return randn_tensor(shape, generator, device=device, dtype=dtype) 106 | 107 | def generate_timestep_matrix( 108 | self, 109 | num_frames, 110 | step_template, 111 | base_num_frames, 112 | ar_step=5, 113 | num_pre_ready=0, 114 | casual_block_size=1, 115 | shrink_interval_with_mask=False, 116 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[tuple]]: 117 | step_matrix, step_index = [], [] 118 | update_mask, valid_interval = [], [] 119 | num_iterations = len(step_template) + 1 120 |
        num_frames_block = num_frames // casual_block_size
        base_num_frames_block = base_num_frames // casual_block_size
        if base_num_frames_block < num_frames_block:
            infer_step_num = len(step_template)
            gen_block = base_num_frames_block
            min_ar_step = infer_step_num / gen_block
            assert ar_step >= min_ar_step, f"ar_step should be at least {math.ceil(min_ar_step)} in your setting"
        # Pad the template so that row counter 0 maps to 999 (pure noise) and counter num_iterations
        # maps to 0 (clean); counters 1..len(step_template) map to the real timesteps.
        step_template = torch.cat(
            [
                torch.tensor([999], dtype=torch.int64, device=step_template.device),
                step_template.long(),
                torch.tensor([0], dtype=torch.int64, device=step_template.device),
            ]
        )
        pre_row = torch.zeros(num_frames_block, dtype=torch.long)
        if num_pre_ready > 0:
            # prefix blocks (e.g. an encoded conditioning video) are already clean
            pre_row[: num_pre_ready // casual_block_size] = num_iterations

        while not torch.all(pre_row >= (num_iterations - 1)):
            new_row = torch.zeros(num_frames_block, dtype=torch.long)
            for i in range(num_frames_block):
                if i == 0 or pre_row[i - 1] >= (num_iterations - 1):
                    # the first block, or the previous block is already fully denoised: advance one step
                    new_row[i] = pre_row[i] + 1
                else:
                    # otherwise trail the previous block by ar_step steps
                    new_row[i] = new_row[i - 1] - ar_step
            new_row = new_row.clamp(0, num_iterations)

            update_mask.append(
                (new_row != pre_row) & (new_row != num_iterations)
            )  # False: no need to update, True: need to update
            step_index.append(new_row)
            step_matrix.append(step_template[new_row])
            pre_row = new_row

        # For long videos only a sliding window of at most base_num_frames_block blocks is denoised per
        # iteration; the window end advances when the block just past it is scheduled for an update.
        terminal_flag = base_num_frames_block
        if shrink_interval_with_mask:
            idx_sequence = torch.arange(num_frames_block, dtype=torch.int64)
            # use a separate name so the list of per-iteration masks is not overwritten
            first_update_mask = update_mask[0]
            update_mask_idx = idx_sequence[first_update_mask]
            last_update_idx = update_mask_idx[-1].item()
            terminal_flag = last_update_idx + 1
        for curr_mask in update_mask:
            if terminal_flag < num_frames_block and curr_mask[terminal_flag]:
                terminal_flag += 1
            valid_interval.append((max(terminal_flag - base_num_frames_block, 0), terminal_flag))

        step_update_mask = torch.stack(update_mask, dim=0)
        step_index = torch.stack(step_index, dim=0)
        step_matrix = torch.stack(step_matrix, dim=0)

        if casual_block_size > 1:
            # expand the per-block schedule to a per-frame schedule
            step_update_mask = step_update_mask.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous()
            step_index = step_index.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous()
            step_matrix = step_matrix.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous()
            valid_interval = [(s * casual_block_size, e * casual_block_size) for s, e in valid_interval]

        return step_matrix, step_index, step_update_mask, valid_interval

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Union[str, List[str]] = "",
        image: PipelineImageInput = None,
        height: int = 480,
        width: int = 832,
        num_frames: int = 97,
        num_inference_steps: int = 50,
        shift: float = 1.0,
        guidance_scale: float = 5.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        overlap_history: Optional[int] = None,
        addnoise_condition: int = 0,
        base_num_frames: int = 97,
        ar_step: int = 5,
        causal_block_size: Optional[int] = None,
        fps: int = 24,
    ):
        latent_height = height // 8
        latent_width = width // 8
        latent_length = (num_frames - 1) // 4 + 1

        self._guidance_scale = guidance_scale

        i2v_extra_kwrags = {}
        prefix_video = None
        predix_video_latent_length = 0
        if image is not None:
            prefix_video, predix_video_latent_length = self.encode_image(image, height, width, num_frames)

        self.text_encoder.to(self.device)
        prompt_embeds = self.text_encoder.encode(prompt).to(self.transformer.dtype)
        if self.do_classifier_free_guidance:
            negative_prompt_embeds = self.text_encoder.encode(negative_prompt).to(self.transformer.dtype)
        if self.offload:
            self.text_encoder.cpu()
            torch.cuda.empty_cache()

        self.scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift)
        init_timesteps = self.scheduler.timesteps
        if causal_block_size is None:
            causal_block_size = self.transformer.num_frame_per_block
        # the transformer receives a binary fps flag per prompt: 0 for 16 fps, 1 for any other rate
        fps_embeds = [fps] * prompt_embeds.shape[0]
        fps_embeds = [0 if i == 16 else 1 for i in fps_embeds]
        transformer_dtype = self.transformer.dtype
        if overlap_history is None or base_num_frames is None or num_frames <= base_num_frames:
            # short video generation (also used when overlap_history is not set)
            latent_shape = [16, latent_length, latent_height, latent_width]
            latents = self.prepare_latents(
                latent_shape, dtype=transformer_dtype, device=prompt_embeds.device, generator=generator
            )
            latents = [latents]
            if prefix_video is not None:
                latents[0][:, :predix_video_latent_length] = prefix_video[0].to(transformer_dtype)
            base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else latent_length
            step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix(
                latent_length, init_timesteps, base_num_frames, ar_step, predix_video_latent_length, causal_block_size
            )
            # one scheduler per latent frame, because frames sit at different timesteps in each iteration
            sample_schedulers = []
            for _ in range(latent_length):
                sample_scheduler = FlowUniPCMultistepScheduler(
                    num_train_timesteps=1000, shift=1, use_dynamic_shifting=False
                )
                sample_scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift)
                sample_schedulers.append(sample_scheduler)
            sample_schedulers_counter = [0] * latent_length
            self.transformer.to(self.device)
            for i, timestep_i in enumerate(tqdm(step_matrix)):
                update_mask_i = step_update_mask[i]
                valid_interval_i = valid_interval[i]
                valid_interval_start, valid_interval_end = valid_interval_i
                timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone()
                latent_model_input = [latents[0][:, valid_interval_start:valid_interval_end, :, :].clone()]
                if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length:
                    # re-noise the clean prefix frames as x = (1 - s) * x + s * eps with s = addnoise_condition / 1000,
                    # and give the model the matching timestep
                    noise_factor = 0.001 * addnoise_condition
                    timestep_for_noised_condition = addnoise_condition
                    latent_model_input[0][:, valid_interval_start:predix_video_latent_length] = (
                        latent_model_input[0][:, valid_interval_start:predix_video_latent_length] * (1.0 - noise_factor)
                        + torch.randn_like(latent_model_input[0][:, valid_interval_start:predix_video_latent_length])
                        * noise_factor
                    )
                    timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition
                if not self.do_classifier_free_guidance:
                    noise_pred = self.transformer(
                        torch.stack([latent_model_input[0]]),
                        t=timestep,
                        context=prompt_embeds,
                        fps=fps_embeds,
                        **i2v_extra_kwrags,
                    )[0]
                else:
                    noise_pred_cond = self.transformer(
                        torch.stack([latent_model_input[0]]),
                        t=timestep,
                        context=prompt_embeds,
                        fps=fps_embeds,
                        **i2v_extra_kwrags,
                    )[0]
                    noise_pred_uncond = self.transformer(
                        torch.stack([latent_model_input[0]]),
                        t=timestep,
                        context=negative_prompt_embeds,
                        fps=fps_embeds,
                        **i2v_extra_kwrags,
                    )[0]
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
                # step only the frames scheduled for an update in this iteration
                for idx in range(valid_interval_start, valid_interval_end):
                    if update_mask_i[idx].item():
                        latents[0][:, idx] = sample_schedulers[idx].step(
                            noise_pred[:, idx - valid_interval_start],
                            timestep_i[idx],
                            latents[0][:, idx],
                            return_dict=False,
                            generator=generator,
                        )[0]
                        sample_schedulers_counter[idx] += 1
            if self.offload:
                self.transformer.cpu()
                torch.cuda.empty_cache()
            x0 = latents[0].unsqueeze(0)
            videos = self.vae.decode(x0)
            videos = (videos / 2 + 0.5).clamp(0, 1)
            videos = [video for video in videos]
            videos = [video.permute(1, 2, 3, 0) * 255 for video in videos]
            videos = [video.cpu().numpy().astype(np.uint8) for video in videos]
            return videos
        else:
            # long video generation: denoise windows of base_num_frames latent frames, each conditioned
            # on the last overlap_history decoded frames of the video generated so far
            base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else latent_length
            overlap_history_frames = (overlap_history - 1) // 4 + 1
            n_iter = 1 + (latent_length - base_num_frames - 1) // (base_num_frames - overlap_history_frames) + 1
            print(f"n_iter:{n_iter}")
            output_video = None
            for i in range(n_iter):
                if output_video is not None:  # i != 0
                    prefix_video = output_video[:, -overlap_history:].to(prompt_embeds.device)
                    prefix_video = [self.vae.encode(prefix_video.unsqueeze(0))[0]]  # [(c, f, h, w)]
                    if prefix_video[0].shape[1] % causal_block_size != 0:
                        truncate_len = prefix_video[0].shape[1] % causal_block_size
                        print("the length of prefix video is truncated for the causal block size alignment.")
                        prefix_video[0] = prefix_video[0][:, : prefix_video[0].shape[1] - truncate_len]
                    predix_video_latent_length = prefix_video[0].shape[1]
                    finished_frame_num = i * (base_num_frames - overlap_history_frames) + overlap_history_frames
                    left_frame_num = latent_length - finished_frame_num
                    base_num_frames_iter = min(left_frame_num + overlap_history_frames, base_num_frames)
                    if ar_step > 0 and self.transformer.enable_teacache:
                        num_steps = (
                            num_inference_steps
                            + ((base_num_frames_iter - overlap_history_frames) // causal_block_size - 1) * ar_step
                        )
                        self.transformer.num_steps = num_steps
                else:  # i == 0
                    base_num_frames_iter = base_num_frames
                latent_shape = [16, base_num_frames_iter, latent_height, latent_width]
                latents = self.prepare_latents(
                    latent_shape, dtype=transformer_dtype, device=prompt_embeds.device, generator=generator
                )
                latents = [latents]
                if prefix_video is not None:
                    latents[0][:, :predix_video_latent_length] = prefix_video[0].to(transformer_dtype)
                step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix(
                    base_num_frames_iter,
                    init_timesteps,
                    base_num_frames_iter,
                    ar_step,
                    predix_video_latent_length,
                    causal_block_size,
                )
                sample_schedulers = []
                for _ in range(base_num_frames_iter):
                    sample_scheduler = FlowUniPCMultistepScheduler(
                        num_train_timesteps=1000, shift=1, use_dynamic_shifting=False
                    )
                    sample_scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift)
                    sample_schedulers.append(sample_scheduler)
                sample_schedulers_counter = [0] * base_num_frames_iter
                self.transformer.to(self.device)
                # step_idx avoids shadowing the window index i of the enclosing loop
                for step_idx, timestep_i in enumerate(tqdm(step_matrix)):
                    update_mask_i = step_update_mask[step_idx]
                    valid_interval_i = valid_interval[step_idx]
                    valid_interval_start, valid_interval_end = valid_interval_i
                    timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone()
                    latent_model_input = [latents[0][:, valid_interval_start:valid_interval_end, :, :].clone()]
                    if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length:
                        noise_factor = 0.001 * addnoise_condition
                        timestep_for_noised_condition = addnoise_condition
                        latent_model_input[0][:, valid_interval_start:predix_video_latent_length] = (
                            latent_model_input[0][:, valid_interval_start:predix_video_latent_length]
                            * (1.0 - noise_factor)
                            + torch.randn_like(
                                latent_model_input[0][:, valid_interval_start:predix_video_latent_length]
                            )
                            * noise_factor
                        )
                        timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition
                    if not self.do_classifier_free_guidance:
                        noise_pred = self.transformer(
                            torch.stack([latent_model_input[0]]),
                            t=timestep,
                            context=prompt_embeds,
                            fps=fps_embeds,
                            **i2v_extra_kwrags,
                        )[0]
                    else:
                        noise_pred_cond = self.transformer(
                            torch.stack([latent_model_input[0]]),
                            t=timestep,
                            context=prompt_embeds,
                            fps=fps_embeds,
                            **i2v_extra_kwrags,
                        )[0]
                        noise_pred_uncond = self.transformer(
                            torch.stack([latent_model_input[0]]),
                            t=timestep,
                            context=negative_prompt_embeds,
                            fps=fps_embeds,
                            **i2v_extra_kwrags,
                        )[0]
                        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
                    for idx in range(valid_interval_start, valid_interval_end):
                        if update_mask_i[idx].item():
                            latents[0][:, idx] = sample_schedulers[idx].step(
                                noise_pred[:, idx - valid_interval_start],
                                timestep_i[idx],
                                latents[0][:, idx],
                                return_dict=False,
                                generator=generator,
                            )[0]
                            sample_schedulers_counter[idx] += 1
                if self.offload:
                    self.transformer.cpu()
                    torch.cuda.empty_cache()
                x0 = latents[0].unsqueeze(0)
                videos = [self.vae.decode(x0)[0]]
                if output_video is None:
                    output_video = videos[0].clamp(-1, 1).cpu()  # c, f, h, w
                else:
                    # drop the first overlap_history frames, which re-decode the end of the previous window
                    output_video = torch.cat(
                        [output_video, videos[0][:, overlap_history:].clamp(-1, 1).cpu()], 1
                    )  # c, f, h, w
            output_video = [(output_video / 2 + 0.5).clamp(0, 1)]
            output_video = [video for video in output_video]
            output_video = [video.permute(1, 2, 3, 0) * 255 for video in output_video]
            output_video = [video.cpu().numpy().astype(np.uint8) for video in output_video]
            return output_video
--------------------------------------------------------------------------------
/skyreels_v2_infer/pipelines/image2video_pipeline.py:
--------------------------------------------------------------------------------
import os
from typing import List
from typing import Optional
from typing import Union

import numpy as np
import torch
from diffusers.image_processor import PipelineImageInput
from diffusers.video_processor import VideoProcessor
from PIL import Image
from tqdm import tqdm

from ..modules import get_image_encoder
from ..modules import get_text_encoder
from ..modules import get_transformer
from ..modules import get_vae
from ..scheduler.fm_solvers_unipc import FlowUniPCMultistepScheduler


def resizecrop(image: Image.Image, th, tw):
    """Center-crop `image` to the aspect ratio th:tw (target height and width); no resizing is done."""
    w, h = image.size
    if w == tw and h == th:
        return image
    if h / w > th / tw:
        new_w = int(w)
        new_h = int(new_w * th / tw)
    else:
        new_h = int(h)
        new_w = int(new_h * tw / th)
    left = (w - new_w) / 2
    top = (h - new_h) / 2
    right = (w + new_w) / 2
    bottom = (h + new_h) / 2
    image = image.crop((left, top, right, bottom))
    return image


class Image2VideoPipeline:
    def __init__(
        self, model_path, dit_path, device: str = "cuda", weight_dtype=torch.bfloat16, use_usp=False, offload=False
    ):
        load_device = "cpu" if offload else device
        self.transformer = get_transformer(dit_path, load_device, weight_dtype)
        vae_model_path = os.path.join(model_path, "Wan2.1_VAE.pth")
        self.vae = get_vae(vae_model_path, device, weight_dtype=torch.float32)
        self.text_encoder = get_text_encoder(model_path, load_device, weight_dtype)
        self.clip = get_image_encoder(model_path, load_device, weight_dtype)
        self.sp_size = 1
        self.device = device
        self.offload = offload
        self.video_processor = VideoProcessor(vae_scale_factor=16)
        if use_usp:
            from xfuser.core.distributed import get_sequence_parallel_world_size
            from ..distributed.xdit_context_parallel import usp_attn_forward, usp_dit_forward
            import types

            for block in self.transformer.blocks:
                block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
            self.transformer.forward = types.MethodType(usp_dit_forward, self.transformer)
            self.sp_size = get_sequence_parallel_world_size()

        self.scheduler = FlowUniPCMultistepScheduler()
        self.vae_stride = (4, 8, 8)
        self.patch_size = (1, 2, 2)

    @torch.no_grad()
    def __call__(
        self,
        image: PipelineImageInput,
        prompt: Union[str, List[str]] = None,
        negative_prompt: Union[str, List[str]] = None,
        height: int = 544,
        width: int = 960,
        num_frames: int = 97,
        num_inference_steps: int = 50,
        guidance_scale: float = 5.0,
        shift: float = 5.0,
        generator: Optional[torch.Generator] = None,
    ):
        F = num_frames

        # latent sizes: spatial stride 8 rounded down to an even size, temporal stride 4
        latent_height = height // 8 // 2 * 2
        latent_width = width // 8 // 2 * 2
        latent_length = (F - 1) // 4 + 1

        h = latent_height * 8
        w = latent_width * 8

        img = self.video_processor.preprocess(image, height=h, width=w)

        img = img.to(device=self.device, dtype=self.transformer.dtype)

        # conditioning video: the image as frame 0 followed by F - 1 zero frames, then VAE-encoded
        padding_video = torch.zeros(img.shape[0], 3, F - 1, h, w, device=self.device)

        img = img.unsqueeze(2)
        img_cond = torch.concat([img, padding_video], dim=2)
        img_cond = self.vae.encode(img_cond)
        # the mask marks latent frame 0 as given (1) and the remaining frames as to be generated (0)
        mask = torch.ones_like(img_cond)
        mask[:, :, 1:] = 0
        y = torch.cat([mask[:, :4], img_cond], dim=1)
        self.clip.to(self.device)
        clip_context = self.clip.encode_video(img)
        if self.offload:
            self.clip.cpu()
            torch.cuda.empty_cache()

        # encode the positive and negative prompts
        self.text_encoder.to(self.device)
        context = self.text_encoder.encode(prompt).to(self.device)
        context_null = self.text_encoder.encode(negative_prompt).to(self.device)
        if self.offload:
            self.text_encoder.cpu()
            torch.cuda.empty_cache()

        latent = torch.randn(
            16, latent_length, latent_height, latent_width, dtype=torch.float32, generator=generator, device=self.device
        )

        self.transformer.to(self.device)
        with torch.cuda.amp.autocast(dtype=self.transformer.dtype), torch.no_grad():
            self.scheduler.set_timesteps(num_inference_steps, device=self.device, shift=shift)
            timesteps = self.scheduler.timesteps

            # both branches share the image conditions and differ only in the text context
            arg_c = {
                "context": context,
                "clip_fea": clip_context,
                "y": y,
            }

            arg_null = {
                "context": context_null,
                "clip_fea": clip_context,
                "y": y,
            }

            for _, t in enumerate(tqdm(timesteps)):
                latent_model_input = torch.stack([latent]).to(self.device)
                timestep = torch.stack([t]).to(self.device)
                noise_pred_cond = self.transformer(latent_model_input, t=timestep, **arg_c)[0].to(self.device)
                noise_pred_uncond = self.transformer(latent_model_input, t=timestep, **arg_null)[0].to(self.device)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

                temp_x0 = self.scheduler.step(
                    noise_pred.unsqueeze(0), t, latent.unsqueeze(0), return_dict=False, generator=generator
                )[0]
                latent = temp_x0.squeeze(0)
            if self.offload:
                self.transformer.cpu()
                torch.cuda.empty_cache()
            videos = self.vae.decode(latent)
            videos = (videos / 2 + 0.5).clamp(0, 1)
            videos = [video for video in videos]
            videos = [video.permute(1, 2, 3, 0) * 255 for video in videos]
            videos = [video.cpu().numpy().astype(np.uint8) for video in videos]
        return videos
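The shape arithmetic at the top of `Image2VideoPipeline.__call__`, and the crop box that the module-level `resizecrop` helper computes, can be checked without loading any weights. The sketch below re-derives both in plain Python. `i2v_latent_geometry` and `center_crop_box` are hypothetical helper names, the 16 latent channels are taken from the `torch.randn(16, ...)` call in the listing, and the token count is an assumption: it supposes the DiT patchifies with `self.patch_size = (1, 2, 2)` and no padding.

```python
from typing import Dict, Tuple


def i2v_latent_geometry(height: int, width: int, num_frames: int) -> Dict[str, object]:
    """Re-derive the latent and pixel sizes used by Image2VideoPipeline.__call__ (sketch)."""
    latent_h = height // 8 // 2 * 2  # spatial stride 8, rounded down to an even size
    latent_w = width // 8 // 2 * 2
    latent_t = (num_frames - 1) // 4 + 1  # temporal stride 4: 1 + (F - 1) // 4 latent frames
    pixel_h, pixel_w = latent_h * 8, latent_w * 8  # size the input image is preprocessed to
    # Assumption: (1, 2, 2) patches without padding, i.e. one token per 1x2x2 latent patch
    tokens = latent_t * (latent_h // 2) * (latent_w // 2)
    return {
        "latent_shape": (16, latent_t, latent_h, latent_w),
        "pixel_size": (pixel_h, pixel_w),
        "tokens": tokens,
    }


def center_crop_box(w: int, h: int, th: int, tw: int) -> Tuple[float, float, float, float]:
    """Box (left, top, right, bottom) that resizecrop would pass to PIL's Image.crop."""
    if w == tw and h == th:
        return (0, 0, w, h)  # resizecrop returns the image unchanged in this case
    if h / w > th / tw:
        # relatively taller than the target: keep the full width, trim top and bottom
        new_w = int(w)
        new_h = int(new_w * th / tw)
    else:
        # relatively wider than the target: keep the full height, trim left and right
        new_h = int(h)
        new_w = int(new_h * tw / th)
    return ((w - new_w) / 2, (h - new_h) / 2, (w + new_w) / 2, (h + new_h) / 2)


if __name__ == "__main__":
    print(i2v_latent_geometry(544, 960, 97))
    print(center_crop_box(1920, 1080, th=544, tw=960))
```

For example, a 540-pixel-high request is snapped down to 528 pixels (66 latent rows), because the latent height is rounded down to an even number before the conditioning image is preprocessed.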
--------------------------------------------------------------------------------
/skyreels_v2_infer/pipelines/prompt_enhancer.py:
--------------------------------------------------------------------------------
import argparse
from transformers import AutoModelForCausalLM, AutoTokenizer

sys_prompt = """
Transform the short prompt into a detailed video-generation caption using this structure:
Opening shot type (long/medium/close-up/extreme close-up/full shot)
Primary subject(s) with vivid attributes (colors, textures, actions, interactions)
Dynamic elements (movement, transitions, or changes over time, e.g., 'gradually lowers,' 'begins to climb,' 'camera moves toward...')
Scene composition (background, environment, spatial relationships)
Lighting/atmosphere (natural/artificial, time of day, mood)
Camera motion (zooms, pans, static/handheld shots) if applicable.

Pattern Summary from Examples:
[Shot Type] of [Subject+Action] + [Detailed Subject Description] + [Environmental Context] + [Lighting Conditions] + [Camera Movement]

One case:
Short prompt: a person is playing football
Long prompt: Medium shot of a young athlete in a red jersey sprinting across a muddy field, dribbling a soccer ball with precise footwork. The player glances toward the goalpost, adjusts their stance, and kicks the ball forcefully into the net. Raindrops fall lightly, creating reflections under stadium floodlights. The camera follows the ball’s trajectory in a smooth pan.

Note: If the subject is stationary, incorporate camera movement to ensure the generated video remains dynamic.

Now expand this short prompt: [{}]. Please only output the final long prompt in English.
"""


class PromptEnhancer:
    def __init__(self, model_name="Qwen/Qwen2.5-32B-Instruct"):
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype="auto",
            device_map="cuda:0",
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    def __call__(self, prompt):
        prompt = prompt.strip()
        prompt = sys_prompt.format(prompt)
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt},
        ]
        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)
        generated_ids = self.model.generate(
            **model_inputs,
            max_new_tokens=2048,
        )
        # keep only the newly generated tokens by stripping the echoed prompt tokens
        generated_ids = [
            output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]
        rewritten_prompt = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return rewritten_prompt


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt", type=str, default="In a still frame, a stop sign")
    args = parser.parse_args()

    prompt_enhancer = PromptEnhancer()
    enhanced_prompt = prompt_enhancer(args.prompt)
    print(f'Original prompt: {args.prompt}')
    print(f'Enhanced prompt: {enhanced_prompt}')
--------------------------------------------------------------------------------
/skyreels_v2_infer/pipelines/text2video_pipeline.py:
--------------------------------------------------------------------------------
import os
from typing import List
from typing import Optional
from typing import Union

import numpy as np
import torch
from diffusers.video_processor import VideoProcessor
from tqdm import tqdm

from ..modules import get_text_encoder
from ..modules import get_transformer
from ..modules import get_vae
from ..scheduler.fm_solvers_unipc import FlowUniPCMultistepScheduler


class Text2VideoPipeline:
    def __init__(
        self, model_path, dit_path, device: str = "cuda", weight_dtype=torch.bfloat16, use_usp=False, offload=False
    ):
        load_device = "cpu" if offload else device
        self.transformer = get_transformer(dit_path, load_device, weight_dtype)
        vae_model_path = os.path.join(model_path, "Wan2.1_VAE.pth")
        self.vae = get_vae(vae_model_path, device, weight_dtype=torch.float32)
        self.text_encoder = get_text_encoder(model_path, load_device, weight_dtype)
        self.video_processor = VideoProcessor(vae_scale_factor=16)
        self.sp_size = 1
        self.device = device
        self.offload = offload
        if use_usp:
            from xfuser.core.distributed import get_sequence_parallel_world_size
            from ..distributed.xdit_context_parallel import usp_attn_forward, usp_dit_forward
            import types

            for block in self.transformer.blocks:
                block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
            self.transformer.forward = types.MethodType(usp_dit_forward, self.transformer)
            self.sp_size = get_sequence_parallel_world_size()

        self.scheduler = FlowUniPCMultistepScheduler()
        self.vae_stride = (4, 8, 8)
        self.patch_size = (1, 2, 2)

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        negative_prompt: Union[str, List[str]] = None,
        width: int = 544,
        height: int = 960,
        num_frames: int = 97,
        num_inference_steps: int = 50,
        guidance_scale: float = 5.0,
        shift: float = 5.0,
        generator: Optional[torch.Generator] = None,
    ):
        # latent shape: (z_dim, latent frames, height / 8, width / 8), from the VAE strides (4, 8, 8)
        F = num_frames
        target_shape = (
            self.vae.vae.z_dim,
            (F - 1) // self.vae_stride[0] + 1,
            height // self.vae_stride[1],
            width // self.vae_stride[2],
        )
        self.text_encoder.to(self.device)
        context = self.text_encoder.encode(prompt).to(self.device)
        context_null = self.text_encoder.encode(negative_prompt).to(self.device)
        if self.offload:
            self.text_encoder.cpu()
            torch.cuda.empty_cache()

        # initial Gaussian noise in latent space
        latents = [
            torch.randn(
                *target_shape,
                dtype=torch.float32,
                device=self.device,
                generator=generator,
            )
        ]

        # denoising loop with classifier-free guidance
        self.transformer.to(self.device)
        with torch.cuda.amp.autocast(dtype=self.transformer.dtype), torch.no_grad():
            self.scheduler.set_timesteps(num_inference_steps, device=self.device, shift=shift)
            timesteps = self.scheduler.timesteps

            for _, t in enumerate(tqdm(timesteps)):
                latent_model_input = torch.stack(latents)
                timestep = torch.stack([t])
                noise_pred_cond = self.transformer(latent_model_input, t=timestep, context=context)[0]
                noise_pred_uncond = self.transformer(latent_model_input, t=timestep, context=context_null)[0]

                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

                temp_x0 = self.scheduler.step(
                    noise_pred.unsqueeze(0), t, latents[0].unsqueeze(0), return_dict=False, generator=generator
                )[0]
                latents = [temp_x0.squeeze(0)]
            if self.offload:
                self.transformer.cpu()
                torch.cuda.empty_cache()
            videos = self.vae.decode(latents[0])
            videos = (videos / 2 + 0.5).clamp(0, 1)
            videos = [video for video in videos]
            videos = [video.permute(1, 2, 3, 0) * 255 for video in videos]
            videos = [video.cpu().numpy().astype(np.uint8) for video in videos]
        return videos
--------------------------------------------------------------------------------
/skyreels_v2_infer/scheduler/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SkyworkAI/SkyReels-V2/6ace9655735f34e4cb8cae8cf8e35289142ecda7/skyreels_v2_infer/scheduler/__init__.py --------------------------------------------------------------------------------
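The asynchronous schedule that `DiffusionForcingPipeline.generate_timestep_matrix` builds is easier to see on a toy input. The sketch below re-implements only its row-update rule, assuming a causal block size of 1 and ignoring the `valid_interval` sliding window. `diffusion_forcing_schedule` and `counters_to_timesteps` are hypothetical names, and the timesteps in the example are made-up values rather than ones produced by `FlowUniPCMultistepScheduler`.

```python
import torch


def diffusion_forcing_schedule(num_blocks: int, num_steps: int, ar_step: int, num_pre_ready: int = 0):
    """Per-iteration step counters and update masks (sketch of the row rule in generate_timestep_matrix).

    Counter 0 means "still pure noise", counter c in 1..num_steps means "apply the c-th timestep",
    and num_steps + 1 marks a block that is already clean (e.g. a conditioning prefix).
    """
    num_iterations = num_steps + 1
    pre_row = torch.zeros(num_blocks, dtype=torch.long)
    pre_row[:num_pre_ready] = num_iterations  # prefix blocks start out clean
    rows, masks = [], []
    while not torch.all(pre_row >= num_iterations - 1):
        new_row = torch.zeros(num_blocks, dtype=torch.long)
        for i in range(num_blocks):
            if i == 0 or pre_row[i - 1] >= num_iterations - 1:
                new_row[i] = pre_row[i] + 1  # first block, or predecessor finished: advance freely
            else:
                new_row[i] = new_row[i - 1] - ar_step  # otherwise trail the predecessor by ar_step
        new_row = new_row.clamp(0, num_iterations)
        masks.append((new_row != pre_row) & (new_row != num_iterations))  # blocks stepped this iteration
        rows.append(new_row)
        pre_row = new_row
    return torch.stack(rows), torch.stack(masks)


def counters_to_timesteps(rows: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
    """Map counters to timesteps using the same 999 / 0 padding as the pipeline."""
    padded = torch.cat([torch.tensor([999]), timesteps.long(), torch.tensor([0])])
    return padded[rows]


if __name__ == "__main__":
    rows, masks = diffusion_forcing_schedule(num_blocks=3, num_steps=3, ar_step=1)
    print(rows.tolist())  # [[1, 0, 0], [2, 1, 0], [3, 2, 1], [4, 3, 2], [4, 4, 3]]
    print(counters_to_timesteps(rows, torch.tensor([900, 600, 300])).tolist())
    # [[900, 999, 999], [600, 900, 999], [300, 600, 900], [0, 300, 600], [0, 0, 300]]
```

With `ar_step=0` every block advances together, which reduces to ordinary synchronous denoising. With `ar_step > 0` each block starts `ar_step` iterations after its predecessor, so `n` blocks that are not already clean take `num_steps + (n - 1) * ar_step` iterations; this has the same form as the value the long-video branch assigns to `self.transformer.num_steps` when TeaCache is enabled.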