├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE.txt
├── README.md
├── assets
│   ├── logo2.png
│   └── main_pipeline.jpg
├── generate_video.py
├── generate_video_df.py
├── requirements.txt
├── skycaptioner_v1
│   ├── README.md
│   ├── examples
│   │   ├── data
│   │   │   ├── 1.mp4
│   │   │   ├── 2.mp4
│   │   │   ├── 3.mp4
│   │   │   └── 4.mp4
│   │   ├── test.csv
│   │   └── test_result.csv
│   ├── infer_fusion_caption.sh
│   ├── infer_struct_caption.sh
│   ├── requirements.txt
│   └── scripts
│       ├── utils.py
│       ├── vllm_fusion_caption.py
│       └── vllm_struct_caption.py
└── skyreels_v2_infer
    ├── __init__.py
    ├── distributed
    │   ├── __init__.py
    │   └── xdit_context_parallel.py
    ├── modules
    │   ├── __init__.py
    │   ├── attention.py
    │   ├── clip.py
    │   ├── t5.py
    │   ├── tokenizers.py
    │   ├── transformer.py
    │   ├── vae.py
    │   └── xlm_roberta.py
    ├── pipelines
    │   ├── __init__.py
    │   ├── diffusion_forcing_pipeline.py
    │   ├── image2video_pipeline.py
    │   ├── prompt_enhancer.py
    │   └── text2video_pipeline.py
    └── scheduler
        ├── __init__.py
        └── fm_solvers_unipc.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | checkpoint/*
3 | checkpoint
4 | results/*
5 | .DS_Store
6 | results/*
7 | *.png
8 | *.jpg
9 | *.mp4
10 | *.log*
11 | *.json
12 | scripts/transformer/*
13 | compile_cache
14 | scripts/.gradio/*
15 | *.pkl
16 | # *.csv
17 | *.jsonl
18 | out/*
19 | model/
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/asottile/reorder-python-imports.git
3 | rev: v3.8.3
4 | hooks:
5 | - id: reorder-python-imports
6 | name: Reorder Python imports
7 | types: [file, python]
8 | - repo: https://github.com/psf/black.git
9 | rev: 22.8.0
10 | hooks:
11 | - id: black
12 | additional_dependencies: ['click==8.0.4']
13 | args: [--line-length=120]
14 | types: [file, python]
15 | - repo: https://github.com/pre-commit/pre-commit-hooks.git
16 | rev: v4.3.0
17 | hooks:
18 | - id: check-byte-order-marker
19 | types: [file, python]
20 | - id: trailing-whitespace
21 | types: [file, python]
22 | - id: end-of-file-fixer
23 | types: [file, python]
24 |
25 |
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | ---
2 | language:
3 | - en
4 | - zh
5 | license: other
6 | tasks:
7 | - text-generation
8 |
9 | ---
10 |
11 |
12 |
13 |
14 | # 声明与协议/Terms and Conditions
15 |
16 | ## 声明/Statement
17 |
18 | 我们在此声明,不要利用Skywork模型进行任何危害国家社会安全或违法的活动。另外,我们也要求使用者不要将 Skywork 模型用于未经适当安全审查和备案的互联网服务。我们希望所有的使用者都能遵守这个原则,确保科技的发展能在规范和合法的环境下进行。
19 |
20 | 我们已经尽我们所能,来确保模型训练过程中使用的数据的合规性。然而,尽管我们已经做出了巨大的努力,但由于模型和数据的复杂性,仍有可能存在一些无法预见的问题。因此,如果由于使用skywork开源模型而导致的任何问题,包括但不限于数据安全问题、公共舆论风险,或模型被误导、滥用、传播或不当利用所带来的任何风险和问题,我们将不承担任何责任。
21 |
22 | We hereby declare that the Skywork model should not be used for any activities that pose a threat to national or societal security or engage in unlawful actions. Additionally, we request users not to deploy the Skywork model for internet services without appropriate security reviews and records. We hope that all users will adhere to this principle to ensure that technological advancements occur in a regulated and lawful environment.
23 |
24 | We have done our utmost to ensure the compliance of the data used during the model's training process. However, despite our extensive efforts, due to the complexity of the model and data, there may still be unpredictable risks and issues. Therefore, if any problems arise as a result of using the Skywork open-source model, including but not limited to data security issues, public opinion risks, or any risks and problems arising from the model being misled, abused, disseminated, or improperly utilized, we will not assume any responsibility.
25 |
26 | ## 协议/Agreement
27 |
28 | 社区使用Skywork模型需要遵循[《Skywork 模型社区许可协议》](https://github.com/SkyworkAI/Skywork/blob/main/Skywork%20模型社区许可协议.pdf)。Skywork模型支持商业用途,如果您计划将Skywork模型或其衍生品用于商业目的,无需再次申请, 但请您仔细阅读[《Skywork 模型社区许可协议》](https://github.com/SkyworkAI/Skywork/blob/main/Skywork%20模型社区许可协议.pdf)并严格遵守相关条款。
29 |
30 |
31 | The community usage of Skywork model requires [Skywork Community License](https://github.com/SkyworkAI/Skywork/blob/main/Skywork%20Community%20License.pdf). The Skywork model supports commercial use. If you plan to use the Skywork model or its derivatives for commercial purposes, you must abide by terms and conditions within [Skywork Community License](https://github.com/SkyworkAI/Skywork/blob/main/Skywork%20Community%20License.pdf).
32 |
33 |
34 |
35 | [《Skywork 模型社区许可协议》]: https://github.com/SkyworkAI/Skywork/blob/main/Skywork%20模型社区许可协议.pdf
36 |
37 |
38 | [skywork-opensource@kunlun-inc.com]: mailto:skywork-opensource@kunlun-inc.com
39 |
--------------------------------------------------------------------------------
/assets/logo2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SkyworkAI/SkyReels-V2/6ace9655735f34e4cb8cae8cf8e35289142ecda7/assets/logo2.png
--------------------------------------------------------------------------------
/assets/main_pipeline.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SkyworkAI/SkyReels-V2/6ace9655735f34e4cb8cae8cf8e35289142ecda7/assets/main_pipeline.jpg
--------------------------------------------------------------------------------
/generate_video.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import gc
3 | import os
4 | import random
5 | import time
6 |
7 | import imageio
8 | import torch
9 | from diffusers.utils import load_image
10 |
11 | from skyreels_v2_infer.modules import download_model
12 | from skyreels_v2_infer.pipelines import Image2VideoPipeline
13 | from skyreels_v2_infer.pipelines import PromptEnhancer
14 | from skyreels_v2_infer.pipelines import resizecrop
15 | from skyreels_v2_infer.pipelines import Text2VideoPipeline
16 |
17 | MODEL_ID_CONFIG = {
18 | "text2video": [
19 | "Skywork/SkyReels-V2-T2V-14B-540P",
20 | "Skywork/SkyReels-V2-T2V-14B-720P",
21 | ],
22 | "image2video": [
23 | "Skywork/SkyReels-V2-I2V-1.3B-540P",
24 | "Skywork/SkyReels-V2-I2V-14B-540P",
25 | "Skywork/SkyReels-V2-I2V-14B-720P",
26 | ],
27 | }
28 |
29 |
30 | if __name__ == "__main__":
31 |
32 | parser = argparse.ArgumentParser()
33 | parser.add_argument("--outdir", type=str, default="video_out")
34 | parser.add_argument("--model_id", type=str, default="Skywork/SkyReels-V2-T2V-14B-540P")
35 | parser.add_argument("--resolution", type=str, choices=["540P", "720P"])
36 | parser.add_argument("--num_frames", type=int, default=97)
37 | parser.add_argument("--image", type=str, default=None)
38 | parser.add_argument("--guidance_scale", type=float, default=6.0)
39 | parser.add_argument("--shift", type=float, default=8.0)
40 | parser.add_argument("--inference_steps", type=int, default=30)
41 | parser.add_argument("--use_usp", action="store_true")
42 | parser.add_argument("--offload", action="store_true")
43 | parser.add_argument("--fps", type=int, default=24)
44 | parser.add_argument("--seed", type=int, default=None)
45 | parser.add_argument(
46 | "--prompt",
47 | type=str,
48 | default="A serene lake surrounded by towering mountains, with a few swans gracefully gliding across the water and sunlight dancing on the surface.",
49 | )
50 | parser.add_argument("--prompt_enhancer", action="store_true")
51 | parser.add_argument("--teacache", action="store_true")
52 | parser.add_argument(
53 | "--teacache_thresh",
54 | type=float,
55 | default=0.2,
56 |         help="Higher values give a larger speedup at the cost of quality: 0.1 for 2.0x speedup, 0.2 for 3.0x speedup")
57 | parser.add_argument(
58 | "--use_ret_steps",
59 | action="store_true",
60 | help="Using Retention Steps will result in faster generation speed and better generation quality.")
61 | args = parser.parse_args()
62 |
63 | args.model_id = download_model(args.model_id)
64 | print("model_id:", args.model_id)
65 |
66 |     assert (args.use_usp and args.seed is not None) or (not args.use_usp), "usp mode needs a seed"
67 | if args.seed is None:
68 | random.seed(time.time())
69 | args.seed = int(random.randrange(4294967294))
70 |
71 | if args.resolution == "540P":
72 | height = 544
73 | width = 960
74 | elif args.resolution == "720P":
75 | height = 720
76 | width = 1280
77 | else:
78 | raise ValueError(f"Invalid resolution: {args.resolution}")
79 |
80 | image = load_image(args.image).convert("RGB") if args.image else None
81 | negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
82 | local_rank = 0
83 | if args.use_usp:
84 |         assert not args.prompt_enhancer, "`--prompt_enhancer` is not allowed if using `--use_usp`. We recommend running the skyreels_v2_infer/pipelines/prompt_enhancer.py script first to generate an enhanced prompt before enabling the `--use_usp` parameter."
85 | from xfuser.core.distributed import initialize_model_parallel, init_distributed_environment
86 | import torch.distributed as dist
87 |
88 | dist.init_process_group("nccl")
89 | local_rank = dist.get_rank()
90 | torch.cuda.set_device(dist.get_rank())
91 | device = "cuda"
92 |
93 | init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
94 |
95 | initialize_model_parallel(
96 | sequence_parallel_degree=dist.get_world_size(),
97 | ring_degree=1,
98 | ulysses_degree=dist.get_world_size(),
99 | )
100 |
101 | prompt_input = args.prompt
102 | if args.prompt_enhancer and args.image is None:
103 | print(f"init prompt enhancer")
104 | prompt_enhancer = PromptEnhancer()
105 | prompt_input = prompt_enhancer(prompt_input)
106 | print(f"enhanced prompt: {prompt_input}")
107 | del prompt_enhancer
108 | gc.collect()
109 | torch.cuda.empty_cache()
110 |
111 | if image is None:
112 | assert "T2V" in args.model_id, f"check model_id:{args.model_id}"
113 | print("init text2video pipeline")
114 | pipe = Text2VideoPipeline(
115 | model_path=args.model_id, dit_path=args.model_id, use_usp=args.use_usp, offload=args.offload
116 | )
117 | else:
118 | assert "I2V" in args.model_id, f"check model_id:{args.model_id}"
119 | print("init img2video pipeline")
120 | pipe = Image2VideoPipeline(
121 | model_path=args.model_id, dit_path=args.model_id, use_usp=args.use_usp, offload=args.offload
122 | )
123 | args.image = load_image(args.image)
124 | image_width, image_height = args.image.size
125 | if image_height > image_width:
126 | height, width = width, height
127 | args.image = resizecrop(args.image, height, width)
128 |
129 | if args.teacache:
130 | pipe.transformer.initialize_teacache(enable_teacache=True, num_steps=args.inference_steps,
131 | teacache_thresh=args.teacache_thresh, use_ret_steps=args.use_ret_steps,
132 | ckpt_dir=args.model_id)
133 |
134 |     if args.prompt_enhancer and image is not None:
135 |         prompt_enhancer = PromptEnhancer()  # the enhancer is only created above for the text-to-video path
136 |         prompt_input = prompt_enhancer(prompt_input)
137 |         print(f"enhanced prompt: {prompt_input}")
138 |
139 | kwargs = {
140 | "prompt": prompt_input,
141 | "negative_prompt": negative_prompt,
142 | "num_frames": args.num_frames,
143 | "num_inference_steps": args.inference_steps,
144 | "guidance_scale": args.guidance_scale,
145 | "shift": args.shift,
146 | "generator": torch.Generator(device="cuda").manual_seed(args.seed),
147 | "height": height,
148 | "width": width,
149 | }
150 |
151 | if image is not None:
152 | kwargs["image"] = args.image.convert("RGB")
153 |
154 | save_dir = os.path.join("result", args.outdir)
155 | os.makedirs(save_dir, exist_ok=True)
156 |
157 | with torch.cuda.amp.autocast(dtype=pipe.transformer.dtype), torch.no_grad():
158 | print(f"infer kwargs:{kwargs}")
159 | video_frames = pipe(**kwargs)[0]
160 |
161 | if local_rank == 0:
162 | current_time = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())
163 | video_out_file = f"{args.prompt[:100].replace('/','')}_{args.seed}_{current_time}.mp4"
164 | output_path = os.path.join(save_dir, video_out_file)
165 | imageio.mimwrite(output_path, video_frames, fps=args.fps, quality=8, output_params=["-loglevel", "error"])
166 |
--------------------------------------------------------------------------------
/generate_video_df.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import gc
3 | import os
4 | import random
5 | import time
6 |
7 | import imageio
8 | import torch
9 | from diffusers.utils import load_image
10 |
11 | from skyreels_v2_infer import DiffusionForcingPipeline
12 | from skyreels_v2_infer.modules import download_model
13 | from skyreels_v2_infer.pipelines import PromptEnhancer
14 |
15 | if __name__ == "__main__":
16 |
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument("--outdir", type=str, default="diffusion_forcing")
19 | parser.add_argument("--model_id", type=str, default="Skywork/SkyReels-V2-DF-1.3B-540P")
20 | parser.add_argument("--resolution", type=str, choices=["540P", "720P"])
21 | parser.add_argument("--num_frames", type=int, default=97)
22 | parser.add_argument("--image", type=str, default=None)
23 | parser.add_argument("--ar_step", type=int, default=0)
24 | parser.add_argument("--causal_attention", action="store_true")
25 | parser.add_argument("--causal_block_size", type=int, default=1)
26 | parser.add_argument("--base_num_frames", type=int, default=97)
27 | parser.add_argument("--overlap_history", type=int, default=None)
28 | parser.add_argument("--addnoise_condition", type=int, default=0)
29 | parser.add_argument("--guidance_scale", type=float, default=6.0)
30 | parser.add_argument("--shift", type=float, default=8.0)
31 | parser.add_argument("--inference_steps", type=int, default=30)
32 | parser.add_argument("--use_usp", action="store_true")
33 | parser.add_argument("--offload", action="store_true")
34 | parser.add_argument("--fps", type=int, default=24)
35 | parser.add_argument("--seed", type=int, default=None)
36 | parser.add_argument(
37 | "--prompt",
38 | type=str,
39 | default="A woman in a leather jacket and sunglasses riding a vintage motorcycle through a desert highway at sunset, her hair blowing wildly in the wind as the motorcycle kicks up dust, with the golden sun casting long shadows across the barren landscape.",
40 | )
41 | parser.add_argument("--prompt_enhancer", action="store_true")
42 | parser.add_argument("--teacache", action="store_true")
43 | parser.add_argument(
44 | "--teacache_thresh",
45 | type=float,
46 | default=0.2,
47 |         help="Higher values give a larger speedup at the cost of quality: 0.1 for 2.0x speedup, 0.2 for 3.0x speedup")
48 | parser.add_argument(
49 | "--use_ret_steps",
50 | action="store_true",
51 | help="Using Retention Steps will result in faster generation speed and better generation quality.")
52 | args = parser.parse_args()
53 |
54 | args.model_id = download_model(args.model_id)
55 | print("model_id:", args.model_id)
56 |
57 |     assert (args.use_usp and args.seed is not None) or (not args.use_usp), "usp mode needs a seed"
58 | if args.seed is None:
59 | random.seed(time.time())
60 | args.seed = int(random.randrange(4294967294))
61 |
62 | if args.resolution == "540P":
63 | height = 544
64 | width = 960
65 | elif args.resolution == "720P":
66 | height = 720
67 | width = 1280
68 | else:
69 | raise ValueError(f"Invalid resolution: {args.resolution}")
70 |
71 | num_frames = args.num_frames
72 | fps = args.fps
73 |
74 | if num_frames > args.base_num_frames:
75 | assert (
76 | args.overlap_history is not None
77 |         ), 'You must specify "overlap_history" to enable long video generation. Values of 17 or 37 are recommended.'
78 | if args.addnoise_condition > 60:
79 | print(
80 |             f'You have set "addnoise_condition" to {args.addnoise_condition}. Such a large value can cause inconsistency in long video generation; a value of 20 is recommended.'
81 | )
82 |
83 | guidance_scale = args.guidance_scale
84 | shift = args.shift
85 | image = load_image(args.image).convert("RGB") if args.image else None
86 | negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
87 |
88 | save_dir = os.path.join("result", args.outdir)
89 | os.makedirs(save_dir, exist_ok=True)
90 | local_rank = 0
91 | if args.use_usp:
92 |         assert not args.prompt_enhancer, "`--prompt_enhancer` is not allowed if using `--use_usp`. We recommend running the skyreels_v2_infer/pipelines/prompt_enhancer.py script first to generate an enhanced prompt before enabling the `--use_usp` parameter."
93 | from xfuser.core.distributed import initialize_model_parallel, init_distributed_environment
94 | import torch.distributed as dist
95 |
96 | dist.init_process_group("nccl")
97 | local_rank = dist.get_rank()
98 | torch.cuda.set_device(dist.get_rank())
99 | device = "cuda"
100 |
101 | init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
102 |
103 | initialize_model_parallel(
104 | sequence_parallel_degree=dist.get_world_size(),
105 | ring_degree=1,
106 | ulysses_degree=dist.get_world_size(),
107 | )
108 |
109 | prompt_input = args.prompt
110 | if args.prompt_enhancer and args.image is None:
111 | print(f"init prompt enhancer")
112 | prompt_enhancer = PromptEnhancer()
113 | prompt_input = prompt_enhancer(prompt_input)
114 | print(f"enhanced prompt: {prompt_input}")
115 | del prompt_enhancer
116 | gc.collect()
117 | torch.cuda.empty_cache()
118 |
119 | pipe = DiffusionForcingPipeline(
120 | args.model_id,
121 | dit_path=args.model_id,
122 | device=torch.device("cuda"),
123 | weight_dtype=torch.bfloat16,
124 | use_usp=args.use_usp,
125 | offload=args.offload,
126 | )
127 |
128 | if args.causal_attention:
129 | pipe.transformer.set_ar_attention(args.causal_block_size)
130 |
131 | if args.teacache:
132 | if args.ar_step > 0:
133 | num_steps = args.inference_steps + (((args.base_num_frames - 1)//4 + 1) // args.causal_block_size - 1) * args.ar_step
134 | print('num_steps:', num_steps)
135 | else:
136 | num_steps = args.inference_steps
137 | pipe.transformer.initialize_teacache(enable_teacache=True, num_steps=num_steps,
138 | teacache_thresh=args.teacache_thresh, use_ret_steps=args.use_ret_steps,
139 | ckpt_dir=args.model_id)
140 |
141 | print(f"prompt:{prompt_input}")
142 | print(f"guidance_scale:{guidance_scale}")
143 |
144 | with torch.cuda.amp.autocast(dtype=pipe.transformer.dtype), torch.no_grad():
145 | video_frames = pipe(
146 | prompt=prompt_input,
147 | negative_prompt=negative_prompt,
148 | image=image,
149 | height=height,
150 | width=width,
151 | num_frames=num_frames,
152 | num_inference_steps=args.inference_steps,
153 | shift=shift,
154 | guidance_scale=guidance_scale,
155 | generator=torch.Generator(device="cuda").manual_seed(args.seed),
156 | overlap_history=args.overlap_history,
157 | addnoise_condition=args.addnoise_condition,
158 | base_num_frames=args.base_num_frames,
159 | ar_step=args.ar_step,
160 | causal_block_size=args.causal_block_size,
161 | fps=fps,
162 | )[0]
163 |
164 | if local_rank == 0:
165 | current_time = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())
166 | video_out_file = f"{args.prompt[:100].replace('/','')}_{args.seed}_{current_time}.mp4"
167 | output_path = os.path.join(save_dir, video_out_file)
168 | imageio.mimwrite(output_path, video_frames, fps=fps, quality=8, output_params=["-loglevel", "error"])
169 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==2.5.1
2 | torchvision==0.20.1
3 | opencv-python==4.10.0.84
4 | diffusers>=0.31.0
5 | transformers==4.49.0
6 | tokenizers==0.21.1
7 | accelerate==1.6.0
8 | tqdm
9 | imageio
10 | easydict
11 | ftfy
12 | dashscope
13 | imageio-ffmpeg
14 | flash_attn
15 | numpy>=1.23.5,<2
16 | xfuser
17 |
--------------------------------------------------------------------------------
/skycaptioner_v1/README.md:
--------------------------------------------------------------------------------
1 | # SkyCaptioner-V1: A Structural Video Captioning Model
2 |
3 |
4 | 📑 Technical Report · 👋 Playground · 💬 Discord · 🤗 Hugging Face · 🤖 ModelScope
5 |
6 |
7 | ---
8 |
9 | Welcome to the SkyCaptioner-V1 repository! Here you'll find the model weights and inference code for our structural video captioner, which labels video data efficiently and comprehensively.
10 |
11 | ## 🔥🔥🔥 News!!
12 |
13 | * Apr 21, 2025: 👋 We release the [vllm](https://github.com/vllm-project/vllm) batch inference code for SkyCaptioner-V1 Model and caption fusion inference code.
14 | * Apr 21, 2025: 👋 We release the first shot-aware video captioning model [SkyCaptioner-V1 Model](https://huggingface.co/Skywork/SkyCaptioner-V1). For more details, please check our [paper](https://arxiv.org/pdf/2504.13074).
15 |
16 | ## 📑 TODO List
17 |
18 | - SkyCaptioner-V1
19 |
20 | - [x] Checkpoints
21 | - [x] Batch Inference Code
22 | - [x] Caption Fusion Method
23 | - [ ] Web Demo (Gradio)
24 |
25 | ## 🌟 Overview
26 |
27 | SkyCaptioner-V1 is a structural video captioning model designed to generate high-quality, structural descriptions for video data. It integrates specialized sub-expert models and multimodal large language models (MLLMs) with human annotations to address the limitations of general captioners in capturing professional film-related details. Key aspects include:
28 |
29 | 1. **Structural Representation**: Combines general video descriptions (from MLLMs) with sub-expert captioners (e.g., shot types, shot angles, shot positions, camera motions) and human annotations.
30 | 2. **Knowledge Distillation**: Distills expertise from sub-expert captioners into a unified model.
31 | 3. **Application Flexibility**: Generates dense captions for text-to-video (T2V) and concise prompts for image-to-video (I2V) tasks.
32 |
33 | ## 🔑 Key Features
34 |
35 | ### Structural Captioning Framework
36 |
37 | Our video captioning model captures multi-dimensional details (see the example below):
38 |
39 | * **Subjects**: Appearance, action, expression, position, and hierarchical categorization.
40 | * **Shot Metadata**: Shot type (e.g., close-up, long shot), shot angle, shot position, camera motion, environment, lighting, etc.
41 |
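For reference, the sketch below shows the shape of a single structural caption as a Python dict; the field values are abridged from `examples/test_result.csv`.

```python
# Abridged structural caption; the fields mirror those in examples/test_result.csv.
structural_caption = {
    "subjects": [
        {
            "TYPES": {"type": "Human", "sub_type": "Woman"},
            "appearance": "Long, straight black hair, wearing a sparkling choker necklace.",
            "action": "Sitting in a car, looking to her left and speaking.",
            "expression": "Neutral expression with slightly open lips and a soft gaze.",
            "position": "Seated in the foreground of the car.",
            "is_main_subject": True,
        }
    ],
    "shot_type": "close_up",
    "shot_angle": "eye_level",
    "shot_position": "side_view",
    "camera_motion": "",
    "environment": "Interior of a car with dark upholstery.",
    "lighting": "Soft and natural lighting, suggesting daytime.",
}
```
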
42 | ### Sub-Expert Integration
43 |
44 | * **Shot Captioner**: Classifies shot type, angle, and position with high precision.
45 | * **Expression Captioner**: Analyzes facial expressions, emotion intensity, and temporal dynamics.
46 | * **Camera Motion Captioner**: Tracks 6DoF camera movements and composite motion types.
47 |
48 | ### Training Pipeline
49 |
50 | * Trained on \~2M high-quality, concept-balanced videos curated from 10M raw samples.
51 | * Fine-tuned from Qwen2.5-VL-7B-Instruct with a global batch size of 512 across 32 A800 GPUs.
52 | * Optimized using AdamW (learning rate: 1e-5) for 2 epochs (summarized in the sketch below).
53 |
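For quick reference, the reported setup can be summarized as follows; this is an illustrative summary of the bullets above, not a released training configuration.

```python
# Illustrative summary of the reported fine-tuning setup (not an official training script).
TRAINING_SETUP = {
    "base_model": "Qwen2.5-VL-7B-Instruct",
    "training_data": "~2M concept-balanced videos curated from 10M raw samples",
    "global_batch_size": 512,
    "hardware": "32x A800 GPUs",
    "optimizer": "AdamW",
    "learning_rate": 1e-5,
    "epochs": 2,
}
```
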
54 | ### Dynamic Caption Fusion
55 |
56 | * Adapts output length based on application (T2V/I2V).
57 | * Employs an LLM to fuse the structural fields into a natural, fluent caption for downstream tasks (see the sketch below).
58 |
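The sketch below condenses what `scripts/vllm_fusion_caption.py` does for a single structural caption: the cleaned fields are embedded into the script's T2V or I2V system prompt, rewritten by the fusion LLM, and the camera motion is appended afterwards. It assumes the prompt templates are imported from that script and omits the batching, field dropping, and error handling of the full pipeline.

```python
import json

from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

# Prompt templates defined in scripts/vllm_fusion_caption.py
from vllm_fusion_caption import SYSTEM_PROMPT_I2V, SYSTEM_PROMPT_T2V


def fuse_caption(structural_caption: dict, model_path: str, task: str = "t2v") -> str:
    # camera_motion stays out of the LLM input and is appended at the end, as in the batch script
    camera_motion = structural_caption.pop("camera_motion", "")
    system_prompt = SYSTEM_PROMPT_T2V if task == "t2v" else SYSTEM_PROMPT_I2V

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    prompt = tokenizer.apply_chat_template(
        [{"role": "system", "content": system_prompt.format(structured_input=json.dumps(structural_caption, indent=4))}],
        tokenize=False,
        add_generation_prompt=True,
    )

    llm = LLM(model=model_path, max_model_len=4096)
    output = llm.generate([prompt], SamplingParams(temperature=0.1, max_tokens=512))[0]
    fused = output.outputs[0].text.strip()
    return f"{fused} {camera_motion.capitalize()}." if camera_motion else fused
```
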
59 | ## 📊 Benchmark Results
60 |
61 | SkyCaptioner-V1 demonstrates significant improvements over existing models in key film-specific captioning tasks, particularly in **shot-language understanding** and **domain-specific precision**. The differences stem from its structural architecture and expert-guided training:
62 |
63 | 1. **Superior Shot-Language Understanding**:
64 |   * Our captioner outperforms Qwen2.5-VL-72B with +11.2% in shot type, +16.1% in shot angle, and +50.4% in shot position accuracy, because SkyCaptioner-V1’s specialized shot classifiers outperform generalist MLLMs that lack film-domain fine-tuning.
65 |   * +28.5% accuracy in camera motion vs. Tarsier2-recap-7B (88.8% vs. 41.5%):
66 |     its 6DoF motion analysis and active learning pipeline address ambiguities in composite motions (e.g., tracking + panning) that challenge generic captioners.
67 | 2. **High domain-specific precision**:
68 | * Expression accuracy: 68.8% vs. 54.3% (Tarsier2-recap-7B), leveraging temporal-aware S2D frameworks to capture dynamic facial changes.
69 |
70 |
71 |
72 |
73 |
74 | | Metric | Qwen2.5-VL-7B-Ins. | Qwen2.5-VL-72B-Ins. | Tarsier2-recap-7B | SkyCaptioner-V1 |
75 | | :--- | :---: | :---: | :---: | :---: |
76 | | Avg accuracy | 51.4% | 58.7% | 49.4% | 76.3% |
77 | | shot type | 76.8% | 82.5% | 60.2% | 93.7% |
78 | | shot angle | 60.0% | 73.7% | 52.4% | 89.8% |
79 | | shot position | 28.4% | 32.7% | 23.6% | 83.1% |
80 | | camera motion | 62.0% | 61.2% | 45.3% | 85.3% |
81 | | expression | 43.6% | 51.5% | 54.3% | 68.8% |
82 | | TYPES_type | 43.5% | 49.7% | 47.6% | 82.5% |
83 | | TYPES_sub_type | 38.9% | 44.9% | 45.9% | 75.4% |
84 | | appearance | 40.9% | 52.0% | 45.6% | 59.3% |
85 | | action | 32.4% | 52.0% | 69.8% | 68.8% |
86 | | position | 35.4% | 48.6% | 45.5% | 57.5% |
87 | | is_main_subject | 58.5% | 68.7% | 69.7% | 80.9% |
88 | | environment | 70.4% | 72.7% | 61.4% | 70.5% |
89 | | lighting | 77.1% | 80.0% | 21.2% | 76.5% |
90 |
184 | ## 📦 Model Downloads
185 |
186 | Our SkyCaptioner-V1 model can be downloaded from [SkyCaptioner-V1 Model](https://huggingface.co/Skywork/SkyCaptioner-V1).
187 | We use [Qwen2.5-32B-Instruct](https://huggingface.co/Qwen/Qwen2.5-32B-Instruct) as our caption fusion model to intelligently combine structured caption fields, producing either dense or sparse final captions depending on application requirements.
188 |
189 | ```shell
190 | # download SkyCaptioner-V1
191 | huggingface-cli download Skywork/SkyCaptioner-V1 --local-dir /path/to/your_local_model_path
192 | # download Qwen2.5-32B-Instruct
193 | huggingface-cli download Qwen/Qwen2.5-32B-Instruct --local-dir /path/to/your_local_model_path2
194 | ```
195 |
196 | ## 🛠️ Running Guide
197 |
198 | Begin by cloning the repository:
199 |
200 | ```shell
201 | git clone https://github.com/SkyworkAI/SkyReels-V2
202 | cd skycaptioner_v1
203 | ```
204 |
205 | ### Installation Guide for Linux
206 |
207 | We recommend Python 3.10 and CUDA version 12.2 for the manual installation.
208 |
209 | ```shell
210 | pip install -r requirements.txt
211 | ```
212 |
213 | ### Running Command
214 |
215 | #### Get Structural Caption by SkyCaptioner-V1
216 |
217 | ```shell
218 | export SkyCaptioner_V1_Model_PATH="/path/to/your_local_model_path"
219 |
220 | python scripts/vllm_struct_caption.py \
221 | --model_path ${SkyCaptioner_V1_Model_PATH} \
222 | --input_csv "./examples/test.csv" \
223 | --out_csv "./examples/test_result.csv" \
224 | --tp 1 \
225 | --bs 4
226 | ```
227 |
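The output CSV stores one JSON string per video in the `structural_caption` column (see `examples/test_result.csv` for real outputs). It can be inspected directly, for example:

```python
import json

import pandas as pd

df = pd.read_csv("./examples/test_result.csv")
for _, row in df.iterrows():
    caption = json.loads(row["structural_caption"])
    main_subjects = [s for s in caption["subjects"] if s.get("is_main_subject")]
    print(row["path"], caption["shot_type"], caption["shot_angle"], len(main_subjects))
```
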
228 | #### T2V/I2V Caption Fusion by Qwen2.5-32B-Instruct Model
229 |
230 | ```shell
231 | export LLM_MODEL_PATH="/path/to/your_local_model_path2"
232 |
233 | python scripts/vllm_fusion_caption.py \
234 | --model_path ${LLM_MODEL_PATH} \
235 | --input_csv "./examples/test_result.csv" \
236 | --out_csv "./examples/test_result_caption.csv" \
237 | --bs 4 \
238 | --tp 1 \
239 | --task t2v
240 | ```
241 | > **Note**:
242 | > - If you want the I2V caption, just change `--task t2v` to `--task i2v` in the command above.
243 |
244 | ## Acknowledgements
245 |
246 | We would like to thank the contributors of the Qwen2.5-VL, Tarsier2, and vLLM repositories for their open research and contributions.
247 |
248 | ## Citation
249 |
250 | ```bibtex
251 | @misc{chen2025skyreelsv2infinitelengthfilmgenerative,
252 | author = {Guibin Chen and Dixuan Lin and Jiangping Yang and Chunze Lin and Junchen Zhu and Mingyuan Fan and Hao Zhang and Sheng Chen and Zheng Chen and Chengcheng Ma and Weiming Xiong and Wei Wang and Nuo Pang and Kang Kang and Zhiheng Xu and Yuzhe Jin and Yupeng Liang and Yubing Song and Peng Zhao and Boyuan Xu and Di Qiu and Debang Li and Zhengcong Fei and Yang Li and Yahui Zhou},
253 |       title = {SkyReels-V2: Infinite-length Film Generative Model},
254 | year = {2025},
255 | eprint={2504.13074},
256 | archivePrefix={arXiv},
257 | primaryClass={cs.CV},
258 | url={https://arxiv.org/abs/2504.13074}
259 | }
260 | ```
261 |
262 |
263 |
--------------------------------------------------------------------------------
/skycaptioner_v1/examples/data/1.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SkyworkAI/SkyReels-V2/6ace9655735f34e4cb8cae8cf8e35289142ecda7/skycaptioner_v1/examples/data/1.mp4
--------------------------------------------------------------------------------
/skycaptioner_v1/examples/data/2.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SkyworkAI/SkyReels-V2/6ace9655735f34e4cb8cae8cf8e35289142ecda7/skycaptioner_v1/examples/data/2.mp4
--------------------------------------------------------------------------------
/skycaptioner_v1/examples/data/3.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SkyworkAI/SkyReels-V2/6ace9655735f34e4cb8cae8cf8e35289142ecda7/skycaptioner_v1/examples/data/3.mp4
--------------------------------------------------------------------------------
/skycaptioner_v1/examples/data/4.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SkyworkAI/SkyReels-V2/6ace9655735f34e4cb8cae8cf8e35289142ecda7/skycaptioner_v1/examples/data/4.mp4
--------------------------------------------------------------------------------
/skycaptioner_v1/examples/test.csv:
--------------------------------------------------------------------------------
1 | path
2 | ./examples/data/1.mp4
3 | ./examples/data/2.mp4
4 | ./examples/data/3.mp4
5 | ./examples/data/4.mp4
--------------------------------------------------------------------------------
/skycaptioner_v1/examples/test_result.csv:
--------------------------------------------------------------------------------
1 | path,structural_caption
2 | ./examples/data/1.mp4,"{""subjects"": [{""TYPES"": {""type"": ""Sport"", ""sub_type"": ""Other""}, ""appearance"": ""Wearing winter sports gear, including a helmet and goggles."", ""action"": ""The video shows a snowy mountain landscape with a ski slope surrounded by dense trees and distant mountains. A skier is seen descending the slope, moving from the top left to the bottom right of the frame. The skier maintains a steady pace, navigating the curves of the slope. The background includes a ski lift with chairs moving along the cables, and the slope is marked with red and white lines indicating the path for skiers. The skier continues to descend, gradually getting closer to the bottom of the slope."", ""expression"": """", ""position"": ""Centered in the frame, moving downwards on the slope."", ""is_main_subject"": true}, {""TYPES"": {""type"": ""Scenery"", ""sub_type"": ""Mountain""}, ""appearance"": ""White, covering the ground and trees."", ""action"": """", ""expression"": """", ""position"": ""Surrounding the skier, covering the entire visible area."", ""is_main_subject"": false}, {""TYPES"": {""type"": ""Plant"", ""sub_type"": ""Tree""}, ""appearance"": ""Tall, evergreen, covered in snow."", ""action"": """", ""expression"": """", ""position"": ""Scattered throughout the scene, both in the foreground and background."", ""is_main_subject"": false}, {""TYPES"": {""type"": ""Scenery"", ""sub_type"": ""Mountain""}, ""appearance"": ""Snow-covered, with ski lifts and structures."", ""action"": """", ""expression"": """", ""position"": ""In the background, providing context to the location."", ""is_main_subject"": false}], ""shot_type"": ""long_shot"", ""shot_angle"": ""high_angle"", ""shot_position"": ""overlooking_view"", ""camera_motion"": ""the camera moves toward zooms in"", ""environment"": ""A snowy mountain landscape with a ski slope, trees, and a ski resort in the background."", ""lighting"": ""Bright daylight, casting shadows and highlighting the snow's texture.""}"
3 | ./examples/data/2.mp4,"{""subjects"": [{""TYPES"": {""type"": ""Human"", ""sub_type"": ""Woman""}, ""appearance"": ""Long, straight black hair, wearing a sparkling choker necklace with a diamond-like texture, light-colored top, subtle makeup with pink lipstick, stud earrings."", ""action"": ""A woman wearing a sparkling choker necklace and earrings is sitting in a car, looking to her left and speaking. A man, dressed in a suit, is sitting next to her, attentively watching her. The background outside the car is green, indicating a possible outdoor setting."", ""expression"": ""The individual in the video exhibits a neutral facial expression, characterized by slightly open lips and a gentle, soft-focus gaze. There are no noticeable signs of sadness or distress evident in their demeanor."", ""position"": ""Seated in the foreground of the car, facing slightly to the right."", ""is_main_subject"": true}, {""TYPES"": {""type"": ""Human"", ""sub_type"": ""Man""}, ""appearance"": ""Short hair, wearing a dark-colored suit with a white shirt."", ""action"": """", ""expression"": """", ""position"": ""Seated in the background of the car, facing the woman."", ""is_main_subject"": false}], ""shot_type"": ""close_up"", ""shot_angle"": ""eye_level"", ""shot_position"": ""side_view"", ""camera_motion"": """", ""environment"": ""Interior of a car with dark upholstery."", ""lighting"": ""Soft and natural lighting, suggesting daytime.""}"
4 | ./examples/data/3.mp4,"{""subjects"": [{""TYPES"": {""type"": ""Animal"", ""sub_type"": ""Insect""}, ""appearance"": ""The spider has a spherical, yellowish-green body with darker green stripes and spots. It has eight slender legs with visible joints and segments."", ""action"": ""A spider with a yellow and green body and black and white striped legs is hanging from its web in a natural setting with a blurred background of green and brown hues. The spider remains mostly still, with slight movements in its legs and body, indicating a gentle swaying motion."", ""expression"": """", ""position"": ""The spider is centrally positioned in the frame, hanging from a web."", ""is_main_subject"": true}], ""shot_type"": ""extreme_close_up"", ""shot_angle"": ""eye_level"", ""shot_position"": ""front_view"", ""camera_motion"": """", ""environment"": ""The background consists of vertical, out-of-focus lines in shades of green and brown, suggesting a natural environment with vegetation."", ""lighting"": ""The lighting is soft and diffused, with no harsh shadows, indicating an overcast sky or a shaded area.""}"
5 | ./examples/data/4.mp4,"{""subjects"": [{""TYPES"": {""type"": ""Sport"", ""sub_type"": ""Football""}, ""appearance"": ""Wearing a dark-colored jersey, black shorts, and bright blue soccer shoes with white soles."", ""action"": ""A man is on a grassy field with orange cones placed in a line. He is wearing a gray shirt, black shorts, and black socks with blue shoes. The man starts by standing still with his feet apart, then begins to move forward while keeping his eyes on the cones. He continues to run forward, maintaining his focus on the cones, and his feet move in a coordinated manner to navigate around them. The background shows a clear sky, some trees, and a few buildings in the distance."", ""expression"": """", ""position"": ""Centered in the frame, with the soccer ball positioned between the cones."", ""is_main_subject"": true}, {""TYPES"": {""type"": ""Sport"", ""sub_type"": ""Other""}, ""appearance"": ""Bright orange, conical shape."", ""action"": """", ""expression"": """", ""position"": ""Placed on the grass, creating a path for the soccer player."", ""is_main_subject"": false}], ""shot_type"": ""full_shot"", ""shot_angle"": ""low_angle"", ""shot_position"": ""front_view"", ""camera_motion"": ""use a tracking shot, the camera moves toward zooms in"", ""environment"": ""Outdoor sports field with well-maintained grass, trees, and a clear blue sky."", ""lighting"": ""Bright and natural, suggesting a sunny day.""}"
--------------------------------------------------------------------------------
/skycaptioner_v1/infer_fusion_caption.sh:
--------------------------------------------------------------------------------
1 | export LLM_MODEL_PATH="/path/to/your_local_model_path2"
2 |
3 | python scripts/vllm_fusion_caption.py \
4 | --model_path ${LLM_MODEL_PATH} \
5 | --input_csv "./examples/test_result.csv" \
6 | --out_csv "./examples/test_result_caption.csv" \
7 | --bs 4 \
8 | --tp 1 \
9 | --task t2v
10 |
--------------------------------------------------------------------------------
/skycaptioner_v1/infer_struct_caption.sh:
--------------------------------------------------------------------------------
1 | export SkyCaptioner_V1_Model_PATH="/path/to/your_local_model_path"
2 |
3 | python scripts/vllm_struct_caption.py \
4 | --model_path ${SkyCaptioner_V1_Model_PATH} \
5 | --input_csv "./examples/test.csv" \
6 |    --out_csv "./examples/test_result.csv" \
7 | --tp 1 \
8 | --bs 32
--------------------------------------------------------------------------------
/skycaptioner_v1/requirements.txt:
--------------------------------------------------------------------------------
1 | decord==0.6.0
2 | transformers>=4.49.0
3 | vllm==0.8.4
--------------------------------------------------------------------------------
/skycaptioner_v1/scripts/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 |
4 | def result_writer(indices_list: list, result_list: list, meta: pd.DataFrame, column):
5 | flat_indices = []
6 | for x in zip(indices_list):
7 | flat_indices.extend(x)
8 | flat_results = []
9 | for x in zip(result_list):
10 | flat_results.extend(x)
11 |
12 | flat_indices = np.array(flat_indices)
13 | flat_results = np.array(flat_results)
14 |
15 | unique_indices, unique_indices_idx = np.unique(flat_indices, return_index=True)
16 | meta.loc[unique_indices, column[0]] = flat_results[unique_indices_idx]
17 |
18 | meta = meta.loc[unique_indices]
19 | return meta
--------------------------------------------------------------------------------
/skycaptioner_v1/scripts/vllm_fusion_caption.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | import argparse
4 | import glob
5 | import time
6 | import gc
7 | from tqdm import tqdm
8 | import torch
9 | from transformers import AutoTokenizer
10 | import pandas as pd
11 | from vllm import LLM, SamplingParams
12 | from torch.utils.data import DataLoader
13 | import json
14 | import random
15 | from utils import result_writer
16 |
17 | SYSTEM_PROMPT_I2V = """
18 | You are an expert in video captioning. You are given a structured video caption and you need to compose it to be more natural and fluent in English.
19 |
20 | ## Structured Input
21 | {structured_input}
22 |
23 | ## Notes
24 | 1. If there has an empty field, just ignore it and do not mention it in the output.
25 | 2. Do not make any semantic changes to the original fields. Please be sure to follow the original meaning.
26 | 3. If the action field is not empty, eliminate the irrelevant information in the action field that is not related to the timing action(such as wearings, background and environment information) to make a pure action field.
27 |
28 | ## Output Principles and Orders
29 | 1. First, eliminate the static information in the action field that is not related to the timing action, such as background or environment information.
30 | 2. Second, describe each subject with its pure action and expression if these fields exist.
31 |
32 | ## Output
33 | Please directly output the final composed caption without any additional information.
34 | """
35 |
36 | SYSTEM_PROMPT_T2V = """
37 | You are an expert in video captioning. You are given a structured video caption and you need to compose it to be more natural and fluent in English.
38 |
39 | ## Structured Input
40 | {structured_input}
41 |
42 | ## Notes
43 | 1. According to the action field information, change its name field to the subject pronoun in the action.
44 | 2. If there has an empty field, just ignore it and do not mention it in the output.
45 | 3. Do not make any semantic changes to the original fields. Please be sure to follow the original meaning.
46 |
47 | ## Output Principles and Orders
48 | 1. First, declare the shot_type, then declare the shot_angle and the shot_position fields.
49 | 2. Second, eliminate information in the action field that is not related to the timing action, such as background or environment information if action is not empty.
50 | 3. Third, describe each subject with its pure action, appearance, expression, position if these fields exist.
51 | 4. Finally, declare the environment and lighting if the environment and lighting fields are not empty.
52 |
53 | ## Output
54 | Please directly output the final composed caption without any additional information.
55 | """
56 |
57 | SHOT_TYPE_LIST = [
58 | 'close-up shot',
59 | 'extreme close-up shot',
60 | 'medium shot',
61 | 'long shot',
62 | 'full shot',
63 | ]
64 |
65 |
66 | class StructuralCaptionDataset(torch.utils.data.Dataset):
67 | def __init__(self, input_csv, model_path):
68 | self.meta = pd.read_csv(input_csv)
69 | self.task = args.task
70 | self.system_prompt = SYSTEM_PROMPT_T2V if self.task == 't2v' else SYSTEM_PROMPT_I2V
71 | self.tokenizer = AutoTokenizer.from_pretrained(model_path)
72 |
73 |
74 | def __len__(self):
75 | return len(self.meta)
76 |
77 | def __getitem__(self, index):
78 | row = self.meta.iloc[index]
79 | real_index = self.meta.index[index]
80 |
81 | struct_caption = json.loads(row["structural_caption"])
82 |
83 | camera_movement = struct_caption.get('camera_motion', '')
84 | if camera_movement != '':
85 | camera_movement += '.'
86 | camera_movement = camera_movement.capitalize()
87 |
88 | fusion_by_llm = False
89 | cleaned_struct_caption = self.clean_struct_caption(struct_caption, self.task)
90 | if cleaned_struct_caption.get('num_subjects', 0) > 0:
91 | new_struct_caption = json.dumps(cleaned_struct_caption, indent=4, ensure_ascii=False)
92 | conversation = [
93 | {
94 | "role": "system",
95 | "content": self.system_prompt.format(structured_input=new_struct_caption),
96 | },
97 | ]
98 | text = self.tokenizer.apply_chat_template(
99 | conversation,
100 | tokenize=False,
101 | add_generation_prompt=True
102 | )
103 | fusion_by_llm = True
104 | else:
105 | text = '-'
106 | return real_index, fusion_by_llm, text, '-', camera_movement
107 |
108 | def clean_struct_caption(self, struct_caption, task):
109 | raw_subjects = struct_caption.get('subjects', [])
110 | subjects = []
111 | for subject in raw_subjects:
112 | subject_type = subject.get("TYPES", {}).get('type', '')
113 | subject_sub_type = subject.get("TYPES", {}).get('sub_type', '')
114 | if subject_type not in ["Human", "Animal"]:
115 | subject['expression'] = ''
116 | if subject_type == 'Human' and subject_sub_type == 'Accessory':
117 | subject['expression'] = ''
118 | if subject_sub_type != '':
119 | subject['name'] = subject_sub_type
120 | if 'TYPES' in subject:
121 | del subject['TYPES']
122 | if 'is_main_subject' in subject:
123 | del subject['is_main_subject']
124 | subjects.append(subject)
125 |
126 | to_del_subject_ids = []
127 | for idx, subject in enumerate(subjects):
128 | action = subject.get('action', '').strip()
129 | subject['action'] = action
130 | if random.random() > 0.9 and 'appearance' in subject:
131 | del subject['appearance']
132 | if random.random() > 0.9 and 'position' in subject:
133 | del subject['position']
134 | if task == 'i2v':
135 | # just keep name and action, expression in subjects
136 | dropped_keys = ['appearance', 'position']
137 | for key in dropped_keys:
138 | if key in subject:
139 | del subject[key]
140 | if subject['action'] == '' and ('expression' not in subject or subject['expression'] == ''):
141 | to_del_subject_ids.append(idx)
142 |
143 | # delete the subjects according to the to_del_subject_ids
144 | for idx in sorted(to_del_subject_ids, reverse=True):
145 | del subjects[idx]
146 |
147 |
148 | shot_type = struct_caption.get('shot_type', '').replace('_', ' ')
149 | if shot_type not in SHOT_TYPE_LIST:
150 | struct_caption['shot_type'] = ''
151 |
152 | new_struct_caption = {
153 | 'num_subjects': len(subjects),
154 | 'subjects': subjects,
155 | 'shot_type': struct_caption.get('shot_type', ''),
156 | 'shot_angle': struct_caption.get('shot_angle', ''),
157 | 'shot_position': struct_caption.get('shot_position', ''),
158 | 'environment': struct_caption.get('environment', ''),
159 | 'lighting': struct_caption.get('lighting', ''),
160 | }
161 |
162 | if task == 't2v' and random.random() > 0.9:
163 | del new_struct_caption['lighting']
164 |
165 | if task == 'i2v':
166 | drop_keys = ['environment', 'lighting', 'shot_type', 'shot_angle', 'shot_position']
167 | for drop_key in drop_keys:
168 | del new_struct_caption[drop_key]
169 | return new_struct_caption
170 |
171 | def custom_collate_fn(batch):
172 | real_indices, fusion_by_llm, texts, original_texts, camera_movements = zip(*batch)
173 | return list(real_indices), list(fusion_by_llm), list(texts), list(original_texts), list(camera_movements)
174 |
175 |
176 | if __name__ == "__main__":
177 | parser = argparse.ArgumentParser(description="Caption Fusion by LLM")
178 | parser.add_argument("--input_csv", default="./examples/test_result.csv")
179 | parser.add_argument("--out_csv", default="./examples/test_result_caption.csv")
180 | parser.add_argument("--bs", type=int, default=4)
181 | parser.add_argument("--tp", type=int, default=1)
182 | parser.add_argument("--model_path", required=True, type=str, help="LLM model path")
183 | parser.add_argument("--task", default='t2v', help="t2v or i2v")
184 |
185 | args = parser.parse_args()
186 |
187 | sampling_params = SamplingParams(
188 | temperature=0.1,
189 | max_tokens=512,
190 | stop=['\n\n']
191 | )
192 | # model_path = "/maindata/data/shared/public/Common-Models/Qwen2.5-32B-Instruct/"
193 |
194 |
195 | llm = LLM(
196 | model=args.model_path,
197 | gpu_memory_utilization=0.9,
198 | max_model_len=4096,
199 | tensor_parallel_size = args.tp
200 | )
201 |
202 |
203 | dataset = StructuralCaptionDataset(input_csv=args.input_csv, model_path=args.model_path)
204 |
205 | dataloader = DataLoader(
206 | dataset,
207 | batch_size=args.bs,
208 | num_workers=8,
209 | collate_fn=custom_collate_fn,
210 | shuffle=False,
211 | drop_last=False,
212 | )
213 |
214 | indices_list = []
215 | result_list = []
216 | for indices, fusion_by_llms, texts, original_texts, camera_movements in tqdm(dataloader):
217 | llm_indices, llm_texts, llm_original_texts, llm_camera_movements = [], [], [], []
218 | for idx, fusion_by_llm, text, original_text, camera_movement in zip(indices, fusion_by_llms, texts, original_texts, camera_movements):
219 | if fusion_by_llm:
220 | llm_indices.append(idx)
221 | llm_texts.append(text)
222 | llm_original_texts.append(original_text)
223 | llm_camera_movements.append(camera_movement)
224 | else:
225 | indices_list.append(idx)
226 | caption = original_text + " " + camera_movement
227 | result_list.append(caption)
228 | if len(llm_texts) > 0:
229 | try:
230 | outputs = llm.generate(llm_texts, sampling_params, use_tqdm=False)
231 | results = []
232 | for output in outputs:
233 | result = output.outputs[0].text.strip()
234 | results.append(result)
235 | indices_list.extend(llm_indices)
236 | except Exception as e:
237 | print(f"Error at {llm_indices}: {str(e)}")
238 | indices_list.extend(llm_indices)
239 | results = llm_original_texts
240 |
241 | for result, camera_movement in zip(results, llm_camera_movements):
242 | # concat camera movement to fusion_caption
243 | llm_caption = result + " " + camera_movement
244 | result_list.append(llm_caption)
245 | torch.cuda.empty_cache()
246 | gc.collect()
247 | gathered_list = [indices_list, result_list]
248 | meta_new = result_writer(indices_list, result_list, dataset.meta, column=[f"{args.task}_fusion_caption"])
249 | meta_new.to_csv(args.out_csv, index=False)
250 |
251 |
--------------------------------------------------------------------------------
/skycaptioner_v1/scripts/vllm_struct_caption.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import decord
4 | import argparse
5 |
6 | import pandas as pd
7 | import numpy as np
8 |
9 | from tqdm import tqdm
10 | from vllm import LLM, SamplingParams
11 | from transformers import AutoTokenizer, AutoProcessor
12 |
13 | from torch.utils.data import DataLoader
14 |
15 | SYSTEM_PROMPT = "I need you to generate a structured and detailed caption for the provided video. The structured output and the requirements for each field are as shown in the following JSON content: {\"subjects\": [{\"appearance\": \"Main subject appearance description\", \"action\": \"Main subject action\", \"expression\": \"Main subject expression (Only for human/animal categories, empty otherwise)\", \"position\": \"Subject position in the video (Can be relative position to other objects or spatial description)\", \"TYPES\": {\"type\": \"Main category (e.g., Human)\", \"sub_type\": \"Sub-category (e.g., Man)\"}, \"is_main_subject\": true}, {\"appearance\": \"Non-main subject appearance description\", \"action\": \"Non-main subject action\", \"expression\": \"Non-main subject expression (Only for human/animal categories, empty otherwise)\", \"position\": \"Position of non-main subject 1\", \"TYPES\": {\"type\": \"Main category (e.g., Vehicles)\", \"sub_type\": \"Sub-category (e.g., Ship)\"}, \"is_main_subject\": false}], \"shot_type\": \"Shot type(Options: long_shot/full_shot/medium_shot/close_up/extreme_close_up/other)\", \"shot_angle\": \"Camera angle(Options: eye_level/high_angle/low_angle/other)\", \"shot_position\": \"Camera position(Options: front_view/back_view/side_view/over_the_shoulder/overhead_view/point_of_view/aerial_view/overlooking_view/other)\", \"camera_motion\": \"Camera movement description\", \"environment\": \"Video background/environment description\", \"lighting\": \"Lighting information in the video\"}"
16 |
17 |
18 | class VideoTextDataset(torch.utils.data.Dataset):
19 | def __init__(self, csv_path, model_path):
20 | self.meta = pd.read_csv(csv_path)
21 | self._path = 'path'
22 | self.tokenizer = AutoTokenizer.from_pretrained(model_path)
23 | self.processor = AutoProcessor.from_pretrained(model_path)
24 |
25 | def __getitem__(self, index):
26 | row = self.meta.iloc[index]
27 | path = row[self._path]
28 | real_index = self.meta.index[index]
29 | vr = decord.VideoReader(path, ctx=decord.cpu(0), width=360, height=420)
30 | start = 0
31 | end = len(vr)
32 | # avg_fps = vr.get_avg_fps()
33 | index = self.get_index(end-start, 16, st=start)
34 | frames = vr.get_batch(index).asnumpy() # n h w c
35 | video_inputs = [torch.from_numpy(frames).permute(0, 3, 1, 2)]
36 | conversation = {
37 | "role": "user",
38 | "content": [
39 | {
40 | "type": "video",
41 | "video": row['path'],
42 | "max_pixels": 360 * 420, # 460800
43 | "fps": 2.0,
44 | },
45 | {
46 | "type": "text",
47 | "text": SYSTEM_PROMPT
48 | },
49 | ],
50 | }
51 |
52 |         # build the chat-formatted user input
53 | user_input = self.processor.apply_chat_template(
54 | [conversation],
55 | tokenize=False,
56 | add_generation_prompt=True
57 | )
58 | results = dict()
59 | inputs = {
60 | 'prompt': user_input,
61 | 'multi_modal_data': {'video': video_inputs}
62 | }
63 | results["index"] = real_index
64 | results['input'] = inputs
65 | return results
66 |
67 | def __len__(self):
68 | return len(self.meta)
69 |
70 | def get_index(self, video_size, num_frames, st=0):
71 | seg_size = max(0., float(video_size - 1) / num_frames)
72 | max_frame = int(video_size) - 1
73 | seq = []
74 | # index from 1, must add 1
75 | for i in range(num_frames):
76 | start = int(np.round(seg_size * i))
77 | # end = int(np.round(seg_size * (i + 1)))
78 | idx = min(start, max_frame)
79 | seq.append(idx+st)
80 | return seq
81 |
82 | def result_writer(indices_list: list, result_list: list, meta: pd.DataFrame, column):
83 | flat_indices = []
84 | for x in zip(indices_list):
85 | flat_indices.extend(x)
86 | flat_results = []
87 | for x in zip(result_list):
88 | flat_results.extend(x)
89 |
90 | flat_indices = np.array(flat_indices)
91 | flat_results = np.array(flat_results)
92 |
93 | unique_indices, unique_indices_idx = np.unique(flat_indices, return_index=True)
94 | meta.loc[unique_indices, column[0]] = flat_results[unique_indices_idx]
95 |
96 | meta = meta.loc[unique_indices]
97 | return meta
98 |
99 |
100 | def worker_init_fn(worker_id):
101 | # Set different seed for each worker
102 | worker_seed = torch.initial_seed() % 2**32
103 | np.random.seed(worker_seed)
104 | # Prevent deadlocks by setting timeout
105 | torch.set_num_threads(1)
106 |
107 | def main():
108 | parser = argparse.ArgumentParser(description="SkyCaptioner-V1 vllm batch inference")
109 | parser.add_argument("--input_csv", default="./examples/test.csv")
110 | parser.add_argument("--out_csv", default="./examples/test_result.csv")
111 | parser.add_argument("--bs", type=int, default=4)
112 | parser.add_argument("--tp", type=int, default=1)
113 | parser.add_argument("--model_path", required=True, type=str, help="skycaptioner-v1 model path")
114 | args = parser.parse_args()
115 |
116 | dataset = VideoTextDataset(csv_path=args.input_csv, model_path=args.model_path)
117 | dataloader = DataLoader(
118 | dataset,
119 | batch_size=args.bs,
120 | num_workers=4,
121 | worker_init_fn=worker_init_fn,
122 | persistent_workers=True,
123 | timeout=180,
124 | )
125 |
126 | sampling_params = SamplingParams(temperature=0.05, max_tokens=2048)
127 |
128 | llm = LLM(model=args.model_path,
129 | gpu_memory_utilization=0.6,
130 | max_model_len=31920,
131 | tensor_parallel_size=args.tp)
132 |
133 | indices_list = []
134 | caption_save = []
135 | for video_batch in tqdm(dataloader):
136 | indices = video_batch["index"]
137 | inputs = video_batch["input"]
138 | batch_user_inputs = []
139 | for prompt, video in zip(inputs['prompt'], inputs['multi_modal_data']['video'][0]):
140 | usi={'prompt':prompt, 'multi_modal_data':{'video':video}}
141 | batch_user_inputs.append(usi)
142 | outputs = llm.generate(batch_user_inputs, sampling_params, use_tqdm=False)
143 | struct_outputs = [output.outputs[0].text for output in outputs]
144 |
145 | indices_list.extend(indices.tolist())
146 | caption_save.extend(struct_outputs)
147 |
148 | meta_new = result_writer(indices_list, caption_save, dataset.meta, column=["structural_caption"])
149 | meta_new.to_csv(args.out_csv, index=False)
150 | print(f'Saved structural_caption to {args.out_csv}')
151 |
152 | if __name__ == '__main__':
153 | main()
--------------------------------------------------------------------------------
/skyreels_v2_infer/__init__.py:
--------------------------------------------------------------------------------
1 | from .pipelines import DiffusionForcingPipeline
2 |
--------------------------------------------------------------------------------
/skyreels_v2_infer/distributed/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SkyworkAI/SkyReels-V2/6ace9655735f34e4cb8cae8cf8e35289142ecda7/skyreels_v2_infer/distributed/__init__.py
--------------------------------------------------------------------------------
/skyreels_v2_infer/distributed/xdit_context_parallel.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.amp as amp
3 | from torch.backends.cuda import sdp_kernel
4 | from xfuser.core.distributed import get_sequence_parallel_rank
5 | from xfuser.core.distributed import get_sequence_parallel_world_size
6 | from xfuser.core.distributed import get_sp_group
7 | from xfuser.core.long_ctx_attention import xFuserLongContextAttention
8 |
9 | from ..modules.transformer import sinusoidal_embedding_1d
10 |
11 |
12 | def pad_freqs(original_tensor, target_len):
13 | seq_len, s1, s2 = original_tensor.shape
14 | pad_size = target_len - seq_len
15 | padding_tensor = torch.ones(pad_size, s1, s2, dtype=original_tensor.dtype, device=original_tensor.device)
16 | padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
17 | return padded_tensor
18 |
19 |
20 | @amp.autocast("cuda", enabled=False)
21 | def rope_apply(x, grid_sizes, freqs):
22 | """
23 | x: [B, L, N, C].
24 | grid_sizes: [B, 3].
25 | freqs: [M, C // 2].
26 | """
27 | s, n, c = x.size(1), x.size(2), x.size(3) // 2
28 | # split freqs
29 | freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
30 |
31 | # loop over samples
32 | output = []
33 | grid = [grid_sizes.tolist()] * x.size(0)
34 | for i, (f, h, w) in enumerate(grid):
35 | seq_len = f * h * w
36 |
37 | # precompute multipliers
38 | x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(s, n, -1, 2))
39 | freqs_i = torch.cat(
40 | [
41 | freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
42 | freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
43 | freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
44 | ],
45 | dim=-1,
46 | ).reshape(seq_len, 1, -1)
47 |
48 | # apply rotary embedding
49 | sp_size = get_sequence_parallel_world_size()
50 | sp_rank = get_sequence_parallel_rank()
51 | freqs_i = pad_freqs(freqs_i, s * sp_size)
52 | s_per_rank = s
53 | freqs_i_rank = freqs_i[(sp_rank * s_per_rank) : ((sp_rank + 1) * s_per_rank), :, :]
54 | x_i = torch.view_as_real(x_i * freqs_i_rank.cuda()).flatten(2)
55 | x_i = torch.cat([x_i, x[i, s:]])
56 |
57 | # append to collection
58 | output.append(x_i)
59 | return torch.stack(output).float()
60 |
61 |
62 | def usp_dit_forward(self, x, t, context, clip_fea=None, y=None, fps=None):
63 | """
64 | x: A list of videos each with shape [C, T, H, W].
65 | t: [B].
66 | context: A list of text embeddings each with shape [L, C].
67 | """
68 | if self.model_type == "i2v":
69 | assert clip_fea is not None and y is not None
70 | # params
71 | device = self.patch_embedding.weight.device
72 | if self.freqs.device != device:
73 | self.freqs = self.freqs.to(device)
74 |
75 | if y is not None:
76 | x = torch.cat([x, y], dim=1)
77 |
78 | # embeddings
79 | x = self.patch_embedding(x)
80 | grid_sizes = torch.tensor(x.shape[2:], dtype=torch.long)
81 | x = x.flatten(2).transpose(1, 2)
82 |
83 | if self.flag_causal_attention:
84 | frame_num = grid_sizes[0]
85 | height = grid_sizes[1]
86 | width = grid_sizes[2]
87 | block_num = frame_num // self.num_frame_per_block
88 | range_tensor = torch.arange(block_num).view(-1, 1)
89 | range_tensor = range_tensor.repeat(1, self.num_frame_per_block).flatten()
90 | causal_mask = range_tensor.unsqueeze(0) <= range_tensor.unsqueeze(1) # f, f
91 | causal_mask = causal_mask.view(frame_num, 1, 1, frame_num, 1, 1).to(x.device)
92 | causal_mask = causal_mask.repeat(1, height, width, 1, height, width)
93 | causal_mask = causal_mask.reshape(frame_num * height * width, frame_num * height * width)
94 | self.block_mask = causal_mask.unsqueeze(0).unsqueeze(0)
95 |
96 | # time embeddings
97 | with amp.autocast("cuda", dtype=torch.float32):
98 | if t.dim() == 2:
99 | b, f = t.shape
100 | _flag_df = True
101 | else:
102 | _flag_df = False
103 | e = self.time_embedding(
104 | sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(self.patch_embedding.weight.dtype)
105 | ) # b, dim
106 | e0 = self.time_projection(e).unflatten(1, (6, self.dim)) # b, 6, dim
107 |
108 | if self.inject_sample_info:
109 | fps = torch.tensor(fps, dtype=torch.long, device=device)
110 |
111 | fps_emb = self.fps_embedding(fps).float()
112 | if _flag_df:
113 | e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim)).repeat(t.shape[1], 1, 1)
114 | else:
115 | e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim))
116 |
117 | if _flag_df:
118 | e = e.view(b, f, 1, 1, self.dim)
119 | e0 = e0.view(b, f, 1, 1, 6, self.dim)
120 | e = e.repeat(1, 1, grid_sizes[1], grid_sizes[2], 1).flatten(1, 3)
121 | e0 = e0.repeat(1, 1, grid_sizes[1], grid_sizes[2], 1, 1).flatten(1, 3)
122 | e0 = e0.transpose(1, 2).contiguous()
123 |
124 | assert e.dtype == torch.float32 and e0.dtype == torch.float32
125 |
126 | # context
127 | context = self.text_embedding(context)
128 |
129 | if clip_fea is not None:
130 | context_clip = self.img_emb(clip_fea) # bs x 257 x dim
131 | context = torch.concat([context_clip, context], dim=1)
132 |
133 | # arguments
134 | if e0.ndim == 4:
135 | e0 = torch.chunk(e0, get_sequence_parallel_world_size(), dim=2)[get_sequence_parallel_rank()]
136 | kwargs = dict(e=e0, grid_sizes=grid_sizes, freqs=self.freqs, context=context, block_mask=self.block_mask)
137 |
138 | # Context Parallel
139 | x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
140 |
141 | for block in self.blocks:
142 | x = block(x, **kwargs)
143 |
144 | # head
145 | if e.ndim == 3:
146 | e = torch.chunk(e, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
147 | x = self.head(x, e)
148 |
149 | # Context Parallel
150 | x = get_sp_group().all_gather(x, dim=1)
151 |
152 | # unpatchify
153 | x = self.unpatchify(x, grid_sizes)
154 | return x.float()
155 |
156 |
157 | def usp_attn_forward(self, x, grid_sizes, freqs, block_mask):
158 |
159 | r"""
160 | Args:
161 | x(Tensor): Shape [B, L, C]
162 | grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
163 | freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
164 | block_mask(Tensor or None): Optional attention mask applied in the autoregressive attention branch
165 | """
166 | b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
167 | half_dtypes = (torch.float16, torch.bfloat16)
168 |
169 | def half(x):
170 | return x if x.dtype in half_dtypes else x.to(torch.bfloat16)
171 |
172 | # query, key, value function
173 | def qkv_fn(x):
174 | q = self.norm_q(self.q(x)).view(b, s, n, d)
175 | k = self.norm_k(self.k(x)).view(b, s, n, d)
176 | v = self.v(x).view(b, s, n, d)
177 | return q, k, v
178 |
179 | x = x.to(self.q.weight.dtype)
180 | q, k, v = qkv_fn(x)
181 |
182 | if not self._flag_ar_attention:
183 | q = rope_apply(q, grid_sizes, freqs)
184 | k = rope_apply(k, grid_sizes, freqs)
185 | else:
186 |
187 | q = rope_apply(q, grid_sizes, freqs)
188 | k = rope_apply(k, grid_sizes, freqs)
189 | q = q.to(torch.bfloat16)
190 | k = k.to(torch.bfloat16)
191 | v = v.to(torch.bfloat16)
192 | # x = torch.nn.functional.scaled_dot_product_attention(
193 | # q.transpose(1, 2),
194 | # k.transpose(1, 2),
195 | # v.transpose(1, 2),
196 | # ).transpose(1, 2).contiguous()
197 | with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
198 | x = (
199 | torch.nn.functional.scaled_dot_product_attention(
200 | q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), attn_mask=block_mask
201 | )
202 | .transpose(1, 2)
203 | .contiguous()
204 | )
205 | x = xFuserLongContextAttention()(None, query=half(q), key=half(k), value=half(v), window_size=self.window_size)
206 |
207 | # output
208 | x = x.flatten(2)
209 | x = self.o(x)
210 | return x
211 |
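212 | # Wiring sketch (illustrative; the actual glue lives in distributed/__init__.py, whose body is
213 | # elided above): these forwards are typically monkey-patched onto a loaded WanModel so the
214 | # existing pipelines run with xDiT sequence parallelism unchanged, e.g.
215 | #   import types
216 | #   transformer.forward = types.MethodType(usp_dit_forward, transformer)
217 | #   for block in transformer.blocks:
218 | #       block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)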
--------------------------------------------------------------------------------
/skyreels_v2_infer/modules/__init__.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import os
3 |
4 | import torch
5 | from safetensors.torch import load_file
6 |
7 | from .clip import CLIPModel
8 | from .t5 import T5EncoderModel
9 | from .transformer import WanModel
10 | from .vae import WanVAE
11 |
12 |
13 | def download_model(model_id):
14 | if not os.path.exists(model_id):
15 | from huggingface_hub import snapshot_download
16 |
17 | model_id = snapshot_download(repo_id=model_id)
18 | return model_id
19 |
20 |
21 | def get_vae(model_path, device="cuda", weight_dtype=torch.float32) -> WanVAE:
22 | vae = WanVAE(model_path).to(device).to(weight_dtype)
23 | vae.vae.requires_grad_(False)
24 | vae.vae.eval()
25 | gc.collect()
26 | torch.cuda.empty_cache()
27 | return vae
28 |
29 |
30 | def get_transformer(model_path, device="cuda", weight_dtype=torch.bfloat16) -> WanModel:
31 | config_path = os.path.join(model_path, "config.json")
32 | transformer = WanModel.from_config(config_path).to(weight_dtype).to(device)
33 |
34 | for file in os.listdir(model_path):
35 | if file.endswith(".safetensors"):
36 | file_path = os.path.join(model_path, file)
37 | state_dict = load_file(file_path)
38 | transformer.load_state_dict(state_dict, strict=False)
39 | del state_dict
40 | gc.collect()
41 | torch.cuda.empty_cache()
42 |
43 | transformer.requires_grad_(False)
44 | transformer.eval()
45 | gc.collect()
46 | torch.cuda.empty_cache()
47 | return transformer
48 |
49 |
50 | def get_text_encoder(model_path, device="cuda", weight_dtype=torch.bfloat16) -> T5EncoderModel:
51 | t5_model = os.path.join(model_path, "models_t5_umt5-xxl-enc-bf16.pth")
52 | tokenizer_path = os.path.join(model_path, "google", "umt5-xxl")
53 | text_encoder = T5EncoderModel(checkpoint_path=t5_model, tokenizer_path=tokenizer_path).to(device).to(weight_dtype)
54 | text_encoder.requires_grad_(False)
55 | text_encoder.eval()
56 | gc.collect()
57 | torch.cuda.empty_cache()
58 | return text_encoder
59 |
60 |
61 | def get_image_encoder(model_path, device="cuda", weight_dtype=torch.bfloat16) -> CLIPModel:
62 | checkpoint_path = os.path.join(model_path, "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth")
63 | tokenizer_path = os.path.join(model_path, "xlm-roberta-large")
64 | image_enc = CLIPModel(checkpoint_path, tokenizer_path).to(weight_dtype).to(device)
65 | image_enc.requires_grad_(False)
66 | image_enc.eval()
67 | gc.collect()
68 | torch.cuda.empty_cache()
69 | return image_enc
70 |
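71 | # Usage sketch (illustrative; "<model_dir>" stands for a local SkyReels-V2 checkpoint folder or a
72 | # Hugging Face repo id resolved by download_model, and a CUDA device is assumed):
73 | #   model_dir = download_model("<model_dir>")
74 | #   transformer = get_transformer(model_dir, device="cuda", weight_dtype=torch.bfloat16)
75 | #   text_encoder = get_text_encoder(model_dir, device="cuda", weight_dtype=torch.bfloat16)
76 | #   image_encoder = get_image_encoder(model_dir, device="cuda", weight_dtype=torch.bfloat16)  # i2v only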
--------------------------------------------------------------------------------
/skyreels_v2_infer/modules/attention.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2 | import torch
3 |
4 | try:
5 | import flash_attn_interface
6 |
7 | FLASH_ATTN_3_AVAILABLE = True
8 | except ModuleNotFoundError:
9 | FLASH_ATTN_3_AVAILABLE = False
10 |
11 | try:
12 | import flash_attn
13 |
14 | FLASH_ATTN_2_AVAILABLE = True
15 | except ModuleNotFoundError:
16 | FLASH_ATTN_2_AVAILABLE = False
17 |
18 | import warnings
19 |
20 | __all__ = [
21 | "flash_attention",
22 | "attention",
23 | ]
24 |
25 |
26 | def flash_attention(
27 | q,
28 | k,
29 | v,
30 | q_lens=None,
31 | k_lens=None,
32 | dropout_p=0.0,
33 | softmax_scale=None,
34 | q_scale=None,
35 | causal=False,
36 | window_size=(-1, -1),
37 | deterministic=False,
38 | dtype=torch.bfloat16,
39 | version=None,
40 | ):
41 | """
42 | q: [B, Lq, Nq, C1].
43 | k: [B, Lk, Nk, C1].
44 | v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
45 | q_lens: [B].
46 | k_lens: [B].
47 | dropout_p: float. Dropout probability.
48 | softmax_scale: float. The scaling of QK^T before applying softmax.
49 | causal: bool. Whether to apply causal attention mask.
50 | window_size: (left, right). If not (-1, -1), apply sliding window local attention.
51 | deterministic: bool. If True, slightly slower and uses more memory.
52 | dtype: torch.dtype. Target dtype for q/k/v when they are not already float16/bfloat16.
53 | """
54 | half_dtypes = (torch.float16, torch.bfloat16)
55 | assert dtype in half_dtypes
56 | assert q.device.type == "cuda" and q.size(-1) <= 256
57 |
58 | # params
59 | b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
60 |
61 | def half(x):
62 | return x if x.dtype in half_dtypes else x.to(dtype)
63 |
64 | # preprocess query
65 |
66 | q = half(q.flatten(0, 1))
67 | q_lens = torch.tensor([lq] * b, dtype=torch.int32).to(device=q.device, non_blocking=True)
68 |
69 | # preprocess key, value
70 |
71 | k = half(k.flatten(0, 1))
72 | v = half(v.flatten(0, 1))
73 | k_lens = torch.tensor([lk] * b, dtype=torch.int32).to(device=k.device, non_blocking=True)
74 |
75 | q = q.to(v.dtype)
76 | k = k.to(v.dtype)
77 |
78 | if q_scale is not None:
79 | q = q * q_scale
80 |
81 | if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
82 | warnings.warn("Flash attention 3 is not available; falling back to flash attention 2.")
83 |
84 | torch.cuda.nvtx.range_push(f"{list(q.shape)}-{list(k.shape)}-{list(v.shape)}-{q.dtype}-{k.dtype}-{v.dtype}")
85 | # apply attention
86 | if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
87 | # Note: dropout_p, window_size are not supported in FA3 now.
88 | x = flash_attn_interface.flash_attn_varlen_func(
89 | q=q,
90 | k=k,
91 | v=v,
92 | cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens])
93 | .cumsum(0, dtype=torch.int32)
94 | .to(q.device, non_blocking=True),
95 | cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens])
96 | .cumsum(0, dtype=torch.int32)
97 | .to(q.device, non_blocking=True),
98 | seqused_q=None,
99 | seqused_k=None,
100 | max_seqlen_q=lq,
101 | max_seqlen_k=lk,
102 | softmax_scale=softmax_scale,
103 | causal=causal,
104 | deterministic=deterministic,
105 | )[0].unflatten(0, (b, lq))
106 | else:
107 | assert FLASH_ATTN_2_AVAILABLE
108 | x = flash_attn.flash_attn_varlen_func(
109 | q=q,
110 | k=k,
111 | v=v,
112 | cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens])
113 | .cumsum(0, dtype=torch.int32)
114 | .to(q.device, non_blocking=True),
115 | cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens])
116 | .cumsum(0, dtype=torch.int32)
117 | .to(q.device, non_blocking=True),
118 | max_seqlen_q=lq,
119 | max_seqlen_k=lk,
120 | dropout_p=dropout_p,
121 | softmax_scale=softmax_scale,
122 | causal=causal,
123 | window_size=window_size,
124 | deterministic=deterministic,
125 | ).unflatten(0, (b, lq))
126 | torch.cuda.nvtx.range_pop()
127 |
128 | # output
129 | return x
130 |
131 |
132 | def attention(
133 | q,
134 | k,
135 | v,
136 | q_lens=None,
137 | k_lens=None,
138 | dropout_p=0.0,
139 | softmax_scale=None,
140 | q_scale=None,
141 | causal=False,
142 | window_size=(-1, -1),
143 | deterministic=False,
144 | dtype=torch.bfloat16,
145 | fa_version=None,
146 | ):
147 | if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
148 | return flash_attention(
149 | q=q,
150 | k=k,
151 | v=v,
152 | q_lens=q_lens,
153 | k_lens=k_lens,
154 | dropout_p=dropout_p,
155 | softmax_scale=softmax_scale,
156 | q_scale=q_scale,
157 | causal=causal,
158 | window_size=window_size,
159 | deterministic=deterministic,
160 | dtype=dtype,
161 | version=fa_version,
162 | )
163 | else:
164 | if q_lens is not None or k_lens is not None:
165 | warnings.warn(
166 | "Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance."
167 | )
168 | attn_mask = None
169 |
170 | q = q.transpose(1, 2).to(dtype)
171 | k = k.transpose(1, 2).to(dtype)
172 | v = v.transpose(1, 2).to(dtype)
173 |
174 | out = torch.nn.functional.scaled_dot_product_attention(
175 | q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p
176 | )
177 |
178 | out = out.transpose(1, 2).contiguous()
179 | return out
180 |
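181 | # Minimal call sketch (illustrative; shapes and device are assumptions): with flash-attn
182 | # installed the varlen kernels run, otherwise attention() falls back to
183 | # torch.nn.functional.scaled_dot_product_attention.
184 | #   q = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.bfloat16)
185 | #   k = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.bfloat16)
186 | #   v = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.bfloat16)
187 | #   out = attention(q, k, v)  # -> [1, 128, 8, 64], same [B, L, N, C] layout as the inputs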
--------------------------------------------------------------------------------
/skyreels_v2_infer/modules/clip.py:
--------------------------------------------------------------------------------
1 | # Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip''
2 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3 | import logging
4 | import math
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | import torchvision.transforms as T
10 | from diffusers.models import ModelMixin
11 |
12 | from .attention import flash_attention
13 | from .tokenizers import HuggingfaceTokenizer
14 | from .xlm_roberta import XLMRoberta
15 |
16 | __all__ = [
17 | "XLMRobertaCLIP",
18 | "clip_xlm_roberta_vit_h_14",
19 | "CLIPModel",
20 | ]
21 |
22 |
23 | def pos_interpolate(pos, seq_len):
24 | if pos.size(1) == seq_len:
25 | return pos
26 | else:
27 | src_grid = int(math.sqrt(pos.size(1)))
28 | tar_grid = int(math.sqrt(seq_len))
29 | n = pos.size(1) - src_grid * src_grid
30 | return torch.cat(
31 | [
32 | pos[:, :n],
33 | F.interpolate(
34 | pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute(0, 3, 1, 2),
35 | size=(tar_grid, tar_grid),
36 | mode="bicubic",
37 | align_corners=False,
38 | )
39 | .flatten(2)
40 | .transpose(1, 2),
41 | ],
42 | dim=1,
43 | )
44 |
45 |
46 | class QuickGELU(nn.Module):
47 | def forward(self, x):
48 | return x * torch.sigmoid(1.702 * x)
49 |
50 |
51 | class LayerNorm(nn.LayerNorm):
52 | def forward(self, x):
53 | return super().forward(x.float()).type_as(x)
54 |
55 |
56 | class SelfAttention(nn.Module):
57 | def __init__(self, dim, num_heads, causal=False, attn_dropout=0.0, proj_dropout=0.0):
58 | assert dim % num_heads == 0
59 | super().__init__()
60 | self.dim = dim
61 | self.num_heads = num_heads
62 | self.head_dim = dim // num_heads
63 | self.causal = causal
64 | self.attn_dropout = attn_dropout
65 | self.proj_dropout = proj_dropout
66 |
67 | # layers
68 | self.to_qkv = nn.Linear(dim, dim * 3)
69 | self.proj = nn.Linear(dim, dim)
70 |
71 | def forward(self, x):
72 | """
73 | x: [B, L, C].
74 | """
75 | b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
76 |
77 | # compute query, key, value
78 | q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2)
79 |
80 | # compute attention
81 | p = self.attn_dropout if self.training else 0.0
82 | x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2)
83 | x = x.reshape(b, s, c)
84 |
85 | # output
86 | x = self.proj(x)
87 | x = F.dropout(x, self.proj_dropout, self.training)
88 | return x
89 |
90 |
91 | class SwiGLU(nn.Module):
92 | def __init__(self, dim, mid_dim):
93 | super().__init__()
94 | self.dim = dim
95 | self.mid_dim = mid_dim
96 |
97 | # layers
98 | self.fc1 = nn.Linear(dim, mid_dim)
99 | self.fc2 = nn.Linear(dim, mid_dim)
100 | self.fc3 = nn.Linear(mid_dim, dim)
101 |
102 | def forward(self, x):
103 | x = F.silu(self.fc1(x)) * self.fc2(x)
104 | x = self.fc3(x)
105 | return x
106 |
107 |
108 | class AttentionBlock(nn.Module):
109 | def __init__(
110 | self,
111 | dim,
112 | mlp_ratio,
113 | num_heads,
114 | post_norm=False,
115 | causal=False,
116 | activation="quick_gelu",
117 | attn_dropout=0.0,
118 | proj_dropout=0.0,
119 | norm_eps=1e-5,
120 | ):
121 | assert activation in ["quick_gelu", "gelu", "swi_glu"]
122 | super().__init__()
123 | self.dim = dim
124 | self.mlp_ratio = mlp_ratio
125 | self.num_heads = num_heads
126 | self.post_norm = post_norm
127 | self.causal = causal
128 | self.norm_eps = norm_eps
129 |
130 | # layers
131 | self.norm1 = LayerNorm(dim, eps=norm_eps)
132 | self.attn = SelfAttention(dim, num_heads, causal, attn_dropout, proj_dropout)
133 | self.norm2 = LayerNorm(dim, eps=norm_eps)
134 | if activation == "swi_glu":
135 | self.mlp = SwiGLU(dim, int(dim * mlp_ratio))
136 | else:
137 | self.mlp = nn.Sequential(
138 | nn.Linear(dim, int(dim * mlp_ratio)),
139 | QuickGELU() if activation == "quick_gelu" else nn.GELU(),
140 | nn.Linear(int(dim * mlp_ratio), dim),
141 | nn.Dropout(proj_dropout),
142 | )
143 |
144 | def forward(self, x):
145 | if self.post_norm:
146 | x = x + self.norm1(self.attn(x))
147 | x = x + self.norm2(self.mlp(x))
148 | else:
149 | x = x + self.attn(self.norm1(x))
150 | x = x + self.mlp(self.norm2(x))
151 | return x
152 |
153 |
154 | class AttentionPool(nn.Module):
155 | def __init__(self, dim, mlp_ratio, num_heads, activation="gelu", proj_dropout=0.0, norm_eps=1e-5):
156 | assert dim % num_heads == 0
157 | super().__init__()
158 | self.dim = dim
159 | self.mlp_ratio = mlp_ratio
160 | self.num_heads = num_heads
161 | self.head_dim = dim // num_heads
162 | self.proj_dropout = proj_dropout
163 | self.norm_eps = norm_eps
164 |
165 | # layers
166 | gain = 1.0 / math.sqrt(dim)
167 | self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
168 | self.to_q = nn.Linear(dim, dim)
169 | self.to_kv = nn.Linear(dim, dim * 2)
170 | self.proj = nn.Linear(dim, dim)
171 | self.norm = LayerNorm(dim, eps=norm_eps)
172 | self.mlp = nn.Sequential(
173 | nn.Linear(dim, int(dim * mlp_ratio)),
174 | QuickGELU() if activation == "quick_gelu" else nn.GELU(),
175 | nn.Linear(int(dim * mlp_ratio), dim),
176 | nn.Dropout(proj_dropout),
177 | )
178 |
179 | def forward(self, x):
180 | """
181 | x: [B, L, C].
182 | """
183 | b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
184 |
185 | # compute query, key, value
186 | q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1)
187 | k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)
188 |
189 | # compute attention
190 | x = flash_attention(q, k, v, version=2)
191 | x = x.reshape(b, 1, c)
192 |
193 | # output
194 | x = self.proj(x)
195 | x = F.dropout(x, self.proj_dropout, self.training)
196 |
197 | # mlp
198 | x = x + self.mlp(self.norm(x))
199 | return x[:, 0]
200 |
201 |
202 | class VisionTransformer(nn.Module):
203 | def __init__(
204 | self,
205 | image_size=224,
206 | patch_size=16,
207 | dim=768,
208 | mlp_ratio=4,
209 | out_dim=512,
210 | num_heads=12,
211 | num_layers=12,
212 | pool_type="token",
213 | pre_norm=True,
214 | post_norm=False,
215 | activation="quick_gelu",
216 | attn_dropout=0.0,
217 | proj_dropout=0.0,
218 | embedding_dropout=0.0,
219 | norm_eps=1e-5,
220 | ):
221 | if image_size % patch_size != 0:
222 | print("[WARNING] image_size is not divisible by patch_size", flush=True)
223 | assert pool_type in ("token", "token_fc", "attn_pool")
224 | out_dim = out_dim or dim
225 | super().__init__()
226 | self.image_size = image_size
227 | self.patch_size = patch_size
228 | self.num_patches = (image_size // patch_size) ** 2
229 | self.dim = dim
230 | self.mlp_ratio = mlp_ratio
231 | self.out_dim = out_dim
232 | self.num_heads = num_heads
233 | self.num_layers = num_layers
234 | self.pool_type = pool_type
235 | self.post_norm = post_norm
236 | self.norm_eps = norm_eps
237 |
238 | # embeddings
239 | gain = 1.0 / math.sqrt(dim)
240 | self.patch_embedding = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size, bias=not pre_norm)
241 | if pool_type in ("token", "token_fc"):
242 | self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
243 | self.pos_embedding = nn.Parameter(
244 | gain * torch.randn(1, self.num_patches + (1 if pool_type in ("token", "token_fc") else 0), dim)
245 | )
246 | self.dropout = nn.Dropout(embedding_dropout)
247 |
248 | # transformer
249 | self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
250 | self.transformer = nn.Sequential(
251 | *[
252 | AttentionBlock(
253 | dim, mlp_ratio, num_heads, post_norm, False, activation, attn_dropout, proj_dropout, norm_eps
254 | )
255 | for _ in range(num_layers)
256 | ]
257 | )
258 | self.post_norm = LayerNorm(dim, eps=norm_eps)
259 |
260 | # head
261 | if pool_type == "token":
262 | self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
263 | elif pool_type == "token_fc":
264 | self.head = nn.Linear(dim, out_dim)
265 | elif pool_type == "attn_pool":
266 | self.head = AttentionPool(dim, mlp_ratio, num_heads, activation, proj_dropout, norm_eps)
267 |
268 | def forward(self, x, interpolation=False, use_31_block=False):
269 | b = x.size(0)
270 |
271 | # embeddings
272 | x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
273 | if self.pool_type in ("token", "token_fc"):
274 | x = torch.cat([self.cls_embedding.expand(b, -1, -1), x], dim=1)
275 | if interpolation:
276 | e = pos_interpolate(self.pos_embedding, x.size(1))
277 | else:
278 | e = self.pos_embedding
279 | x = self.dropout(x + e)
280 | if self.pre_norm is not None:
281 | x = self.pre_norm(x)
282 |
283 | # transformer
284 | if use_31_block:
285 | x = self.transformer[:-1](x)
286 | return x
287 | else:
288 | x = self.transformer(x)
289 | return x
290 |
291 |
292 | class XLMRobertaWithHead(XLMRoberta):
293 | def __init__(self, **kwargs):
294 | self.out_dim = kwargs.pop("out_dim")
295 | super().__init__(**kwargs)
296 |
297 | # head
298 | mid_dim = (self.dim + self.out_dim) // 2
299 | self.head = nn.Sequential(
300 | nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(), nn.Linear(mid_dim, self.out_dim, bias=False)
301 | )
302 |
303 | def forward(self, ids):
304 | # xlm-roberta
305 | x = super().forward(ids)
306 |
307 | # average pooling
308 | mask = ids.ne(self.pad_id).unsqueeze(-1).to(x)
309 | x = (x * mask).sum(dim=1) / mask.sum(dim=1)
310 |
311 | # head
312 | x = self.head(x)
313 | return x
314 |
315 |
316 | class XLMRobertaCLIP(nn.Module):
317 | def __init__(
318 | self,
319 | embed_dim=1024,
320 | image_size=224,
321 | patch_size=14,
322 | vision_dim=1280,
323 | vision_mlp_ratio=4,
324 | vision_heads=16,
325 | vision_layers=32,
326 | vision_pool="token",
327 | vision_pre_norm=True,
328 | vision_post_norm=False,
329 | activation="gelu",
330 | vocab_size=250002,
331 | max_text_len=514,
332 | type_size=1,
333 | pad_id=1,
334 | text_dim=1024,
335 | text_heads=16,
336 | text_layers=24,
337 | text_post_norm=True,
338 | text_dropout=0.1,
339 | attn_dropout=0.0,
340 | proj_dropout=0.0,
341 | embedding_dropout=0.0,
342 | norm_eps=1e-5,
343 | ):
344 | super().__init__()
345 | self.embed_dim = embed_dim
346 | self.image_size = image_size
347 | self.patch_size = patch_size
348 | self.vision_dim = vision_dim
349 | self.vision_mlp_ratio = vision_mlp_ratio
350 | self.vision_heads = vision_heads
351 | self.vision_layers = vision_layers
352 | self.vision_pre_norm = vision_pre_norm
353 | self.vision_post_norm = vision_post_norm
354 | self.activation = activation
355 | self.vocab_size = vocab_size
356 | self.max_text_len = max_text_len
357 | self.type_size = type_size
358 | self.pad_id = pad_id
359 | self.text_dim = text_dim
360 | self.text_heads = text_heads
361 | self.text_layers = text_layers
362 | self.text_post_norm = text_post_norm
363 | self.norm_eps = norm_eps
364 |
365 | # models
366 | self.visual = VisionTransformer(
367 | image_size=image_size,
368 | patch_size=patch_size,
369 | dim=vision_dim,
370 | mlp_ratio=vision_mlp_ratio,
371 | out_dim=embed_dim,
372 | num_heads=vision_heads,
373 | num_layers=vision_layers,
374 | pool_type=vision_pool,
375 | pre_norm=vision_pre_norm,
376 | post_norm=vision_post_norm,
377 | activation=activation,
378 | attn_dropout=attn_dropout,
379 | proj_dropout=proj_dropout,
380 | embedding_dropout=embedding_dropout,
381 | norm_eps=norm_eps,
382 | )
383 | self.textual = XLMRobertaWithHead(
384 | vocab_size=vocab_size,
385 | max_seq_len=max_text_len,
386 | type_size=type_size,
387 | pad_id=pad_id,
388 | dim=text_dim,
389 | out_dim=embed_dim,
390 | num_heads=text_heads,
391 | num_layers=text_layers,
392 | post_norm=text_post_norm,
393 | dropout=text_dropout,
394 | )
395 | self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
396 |
397 | def forward(self, imgs, txt_ids):
398 | """
399 | imgs: [B, 3, H, W] of torch.float32.
400 | - mean: [0.48145466, 0.4578275, 0.40821073]
401 | - std: [0.26862954, 0.26130258, 0.27577711]
402 | txt_ids: [B, L] of torch.long.
403 | Encoded by data.CLIPTokenizer.
404 | """
405 | xi = self.visual(imgs)
406 | xt = self.textual(txt_ids)
407 | return xi, xt
408 |
409 | def param_groups(self):
410 | groups = [
411 | {
412 | "params": [p for n, p in self.named_parameters() if "norm" in n or n.endswith("bias")],
413 | "weight_decay": 0.0,
414 | },
415 | {"params": [p for n, p in self.named_parameters() if not ("norm" in n or n.endswith("bias"))]},
416 | ]
417 | return groups
418 |
419 |
420 | def _clip(
421 | pretrained=False,
422 | pretrained_name=None,
423 | model_cls=XLMRobertaCLIP,
424 | return_transforms=False,
425 | return_tokenizer=False,
426 | tokenizer_padding="eos",
427 | dtype=torch.float32,
428 | device="cpu",
429 | **kwargs,
430 | ):
431 | # init a model on device
432 | with torch.device(device):
433 | model = model_cls(**kwargs)
434 |
435 | # set device
436 | model = model.to(dtype=dtype, device=device)
437 | output = (model,)
438 |
439 | # init transforms
440 | if return_transforms:
441 | # mean and std
442 | if "siglip" in pretrained_name.lower():
443 | mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
444 | else:
445 | mean = [0.48145466, 0.4578275, 0.40821073]
446 | std = [0.26862954, 0.26130258, 0.27577711]
447 |
448 | # transforms
449 | transforms = T.Compose(
450 | [
451 | T.Resize((model.image_size, model.image_size), interpolation=T.InterpolationMode.BICUBIC),
452 | T.ToTensor(),
453 | T.Normalize(mean=mean, std=std),
454 | ]
455 | )
456 | output += (transforms,)
457 | return output[0] if len(output) == 1 else output
458 |
459 |
460 | def clip_xlm_roberta_vit_h_14(pretrained=False, pretrained_name="open-clip-xlm-roberta-large-vit-huge-14", **kwargs):
461 | cfg = dict(
462 | embed_dim=1024,
463 | image_size=224,
464 | patch_size=14,
465 | vision_dim=1280,
466 | vision_mlp_ratio=4,
467 | vision_heads=16,
468 | vision_layers=32,
469 | vision_pool="token",
470 | activation="gelu",
471 | vocab_size=250002,
472 | max_text_len=514,
473 | type_size=1,
474 | pad_id=1,
475 | text_dim=1024,
476 | text_heads=16,
477 | text_layers=24,
478 | text_post_norm=True,
479 | text_dropout=0.1,
480 | attn_dropout=0.0,
481 | proj_dropout=0.0,
482 | embedding_dropout=0.0,
483 | )
484 | cfg.update(**kwargs)
485 | return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg)
486 |
487 |
488 | class CLIPModel(ModelMixin):
489 | def __init__(self, checkpoint_path, tokenizer_path):
490 | self.checkpoint_path = checkpoint_path
491 | self.tokenizer_path = tokenizer_path
492 |
493 | super().__init__()
494 | # init model
495 | self.model, self.transforms = clip_xlm_roberta_vit_h_14(
496 | pretrained=False, return_transforms=True, return_tokenizer=False
497 | )
498 | self.model = self.model.eval().requires_grad_(False)
499 | logging.info(f"loading {checkpoint_path}")
500 | self.model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
501 |
502 | # init tokenizer
503 | self.tokenizer = HuggingfaceTokenizer(
504 | name=tokenizer_path, seq_len=self.model.max_text_len - 2, clean="whitespace"
505 | )
506 |
507 | def encode_video(self, video):
508 | # preprocess
509 | b, c, t, h, w = video.shape
510 | video = video.transpose(1, 2)
511 | video = video.reshape(b * t, c, h, w)
512 | size = (self.model.image_size,) * 2
513 | video = F.interpolate(
514 | video,
515 | size=size,
516 | mode='bicubic',
517 | align_corners=False)
518 |
519 | video = self.transforms.transforms[-1](video.mul_(0.5).add_(0.5))
520 |
521 | # forward
522 | with torch.amp.autocast(dtype=self.dtype, device_type=self.device.type):
523 | out = self.model.visual(video, use_31_block=True)
524 |
525 | return out
526 |
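527 | # Usage sketch (illustrative; the checkpoint/tokenizer paths are placeholders and a CUDA
528 | # device is assumed; get_image_encoder() in modules/__init__.py performs this same setup):
529 | #   clip = CLIPModel("<models_clip_*.pth>", "<xlm-roberta-large dir>").to("cuda", torch.bfloat16)
530 | #   video = torch.rand(1, 3, 1, 544, 960, device="cuda") * 2 - 1  # [B, C, T, H, W] in [-1, 1]
531 | #   feats = clip.encode_video(video)  # penultimate-block ViT tokens used as clip_fea for i2v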
--------------------------------------------------------------------------------
/skyreels_v2_infer/modules/t5.py:
--------------------------------------------------------------------------------
1 | # Modified from transformers.models.t5.modeling_t5
2 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3 | import logging
4 | import math
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from diffusers.models import ModelMixin
10 |
11 | from .tokenizers import HuggingfaceTokenizer
12 |
13 | __all__ = [
14 | "T5Model",
15 | "T5Encoder",
16 | "T5Decoder",
17 | "T5EncoderModel",
18 | ]
19 |
20 |
21 | def fp16_clamp(x):
22 | if x.dtype == torch.float16 and torch.isinf(x).any():
23 | clamp = torch.finfo(x.dtype).max - 1000
24 | x = torch.clamp(x, min=-clamp, max=clamp)
25 | return x
26 |
27 |
28 | def init_weights(m):
29 | if isinstance(m, T5LayerNorm):
30 | nn.init.ones_(m.weight)
31 | elif isinstance(m, T5Model):
32 | nn.init.normal_(m.token_embedding.weight, std=1.0)
33 | elif isinstance(m, T5FeedForward):
34 | nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
35 | nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
36 | nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
37 | elif isinstance(m, T5Attention):
38 | nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn) ** -0.5)
39 | nn.init.normal_(m.k.weight, std=m.dim**-0.5)
40 | nn.init.normal_(m.v.weight, std=m.dim**-0.5)
41 | nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn) ** -0.5)
42 | elif isinstance(m, T5RelativeEmbedding):
43 | nn.init.normal_(m.embedding.weight, std=(2 * m.num_buckets * m.num_heads) ** -0.5)
44 |
45 |
46 | class GELU(nn.Module):
47 | def forward(self, x):
48 | return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
49 |
50 |
51 | class T5LayerNorm(nn.Module):
52 | def __init__(self, dim, eps=1e-6):
53 | super(T5LayerNorm, self).__init__()
54 | self.dim = dim
55 | self.eps = eps
56 | self.weight = nn.Parameter(torch.ones(dim))
57 |
58 | def forward(self, x):
59 | x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + self.eps)
60 | if self.weight.dtype in [torch.float16, torch.bfloat16]:
61 | x = x.type_as(self.weight)
62 | return self.weight * x
63 |
64 |
65 | class T5Attention(nn.Module):
66 | def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
67 | assert dim_attn % num_heads == 0
68 | super(T5Attention, self).__init__()
69 | self.dim = dim
70 | self.dim_attn = dim_attn
71 | self.num_heads = num_heads
72 | self.head_dim = dim_attn // num_heads
73 |
74 | # layers
75 | self.q = nn.Linear(dim, dim_attn, bias=False)
76 | self.k = nn.Linear(dim, dim_attn, bias=False)
77 | self.v = nn.Linear(dim, dim_attn, bias=False)
78 | self.o = nn.Linear(dim_attn, dim, bias=False)
79 | self.dropout = nn.Dropout(dropout)
80 |
81 | def forward(self, x, context=None, mask=None, pos_bias=None):
82 | """
83 | x: [B, L1, C].
84 | context: [B, L2, C] or None.
85 | mask: [B, L2] or [B, L1, L2] or None.
86 | """
87 | # check inputs
88 | context = x if context is None else context
89 | b, n, c = x.size(0), self.num_heads, self.head_dim
90 |
91 | # compute query, key, value
92 | q = self.q(x).view(b, -1, n, c)
93 | k = self.k(context).view(b, -1, n, c)
94 | v = self.v(context).view(b, -1, n, c)
95 |
96 | # attention bias
97 | attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
98 | if pos_bias is not None:
99 | attn_bias += pos_bias
100 | if mask is not None:
101 | assert mask.ndim in [2, 3]
102 | mask = mask.view(b, 1, 1, -1) if mask.ndim == 2 else mask.unsqueeze(1)
103 | attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)
104 |
105 | # compute attention (T5 does not use scaling)
106 | attn = torch.einsum("binc,bjnc->bnij", q, k) + attn_bias
107 | attn = F.softmax(attn.float(), dim=-1).type_as(attn)
108 | x = torch.einsum("bnij,bjnc->binc", attn, v)
109 |
110 | # output
111 | x = x.reshape(b, -1, n * c)
112 | x = self.o(x)
113 | x = self.dropout(x)
114 | return x
115 |
116 |
117 | class T5FeedForward(nn.Module):
118 | def __init__(self, dim, dim_ffn, dropout=0.1):
119 | super(T5FeedForward, self).__init__()
120 | self.dim = dim
121 | self.dim_ffn = dim_ffn
122 |
123 | # layers
124 | self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
125 | self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
126 | self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
127 | self.dropout = nn.Dropout(dropout)
128 |
129 | def forward(self, x):
130 | x = self.fc1(x) * self.gate(x)
131 | x = self.dropout(x)
132 | x = self.fc2(x)
133 | x = self.dropout(x)
134 | return x
135 |
136 |
137 | class T5SelfAttention(nn.Module):
138 | def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.1):
139 | super(T5SelfAttention, self).__init__()
140 | self.dim = dim
141 | self.dim_attn = dim_attn
142 | self.dim_ffn = dim_ffn
143 | self.num_heads = num_heads
144 | self.num_buckets = num_buckets
145 | self.shared_pos = shared_pos
146 |
147 | # layers
148 | self.norm1 = T5LayerNorm(dim)
149 | self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
150 | self.norm2 = T5LayerNorm(dim)
151 | self.ffn = T5FeedForward(dim, dim_ffn, dropout)
152 | self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
153 |
154 | def forward(self, x, mask=None, pos_bias=None):
155 | e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1))
156 | x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
157 | x = fp16_clamp(x + self.ffn(self.norm2(x)))
158 | return x
159 |
160 |
161 | class T5CrossAttention(nn.Module):
162 | def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.1):
163 | super(T5CrossAttention, self).__init__()
164 | self.dim = dim
165 | self.dim_attn = dim_attn
166 | self.dim_ffn = dim_ffn
167 | self.num_heads = num_heads
168 | self.num_buckets = num_buckets
169 | self.shared_pos = shared_pos
170 |
171 | # layers
172 | self.norm1 = T5LayerNorm(dim)
173 | self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout)
174 | self.norm2 = T5LayerNorm(dim)
175 | self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
176 | self.norm3 = T5LayerNorm(dim)
177 | self.ffn = T5FeedForward(dim, dim_ffn, dropout)
178 | self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False)
179 |
180 | def forward(self, x, mask=None, encoder_states=None, encoder_mask=None, pos_bias=None):
181 | e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1))
182 | x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e))
183 | x = fp16_clamp(x + self.cross_attn(self.norm2(x), context=encoder_states, mask=encoder_mask))
184 | x = fp16_clamp(x + self.ffn(self.norm3(x)))
185 | return x
186 |
187 |
188 | class T5RelativeEmbedding(nn.Module):
189 | def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
190 | super(T5RelativeEmbedding, self).__init__()
191 | self.num_buckets = num_buckets
192 | self.num_heads = num_heads
193 | self.bidirectional = bidirectional
194 | self.max_dist = max_dist
195 |
196 | # layers
197 | self.embedding = nn.Embedding(num_buckets, num_heads)
198 |
199 | def forward(self, lq, lk):
200 | device = self.embedding.weight.device
201 | # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \
202 | # torch.arange(lq).unsqueeze(1).to(device)
203 | rel_pos = torch.arange(lk, device=device).unsqueeze(0) - torch.arange(lq, device=device).unsqueeze(1)
204 | rel_pos = self._relative_position_bucket(rel_pos)
205 | rel_pos_embeds = self.embedding(rel_pos)
206 | rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(0) # [1, N, Lq, Lk]
207 | return rel_pos_embeds.contiguous()
208 |
209 | def _relative_position_bucket(self, rel_pos):
210 | # preprocess
211 | if self.bidirectional:
212 | num_buckets = self.num_buckets // 2
213 | rel_buckets = (rel_pos > 0).long() * num_buckets
214 | rel_pos = torch.abs(rel_pos)
215 | else:
216 | num_buckets = self.num_buckets
217 | rel_buckets = 0
218 | rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))
219 |
220 | # embeddings for small and large positions
221 | max_exact = num_buckets // 2
222 | rel_pos_large = (
223 | max_exact
224 | + (
225 | torch.log(rel_pos.float() / max_exact) / math.log(self.max_dist / max_exact) * (num_buckets - max_exact)
226 | ).long()
227 | )
228 | rel_pos_large = torch.min(rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
229 | rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
230 | return rel_buckets
231 |
232 |
233 | class T5Encoder(nn.Module):
234 | def __init__(self, vocab, dim, dim_attn, dim_ffn, num_heads, num_layers, num_buckets, shared_pos=True, dropout=0.1):
235 | super(T5Encoder, self).__init__()
236 | self.dim = dim
237 | self.dim_attn = dim_attn
238 | self.dim_ffn = dim_ffn
239 | self.num_heads = num_heads
240 | self.num_layers = num_layers
241 | self.num_buckets = num_buckets
242 | self.shared_pos = shared_pos
243 |
244 | # layers
245 | self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim)
246 | self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True) if shared_pos else None
247 | self.dropout = nn.Dropout(dropout)
248 | self.blocks = nn.ModuleList(
249 | [
250 | T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout)
251 | for _ in range(num_layers)
252 | ]
253 | )
254 | self.norm = T5LayerNorm(dim)
255 |
256 | # initialize weights
257 | self.apply(init_weights)
258 |
259 | def forward(self, ids, mask=None):
260 | x = self.token_embedding(ids)
261 | x = self.dropout(x)
262 | e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None
263 | for block in self.blocks:
264 | x = block(x, mask, pos_bias=e)
265 | x = self.norm(x)
266 | x = self.dropout(x)
267 | return x
268 |
269 |
270 | class T5Decoder(nn.Module):
271 | def __init__(self, vocab, dim, dim_attn, dim_ffn, num_heads, num_layers, num_buckets, shared_pos=True, dropout=0.1):
272 | super(T5Decoder, self).__init__()
273 | self.dim = dim
274 | self.dim_attn = dim_attn
275 | self.dim_ffn = dim_ffn
276 | self.num_heads = num_heads
277 | self.num_layers = num_layers
278 | self.num_buckets = num_buckets
279 | self.shared_pos = shared_pos
280 |
281 | # layers
282 | self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim)
283 | self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False) if shared_pos else None
284 | self.dropout = nn.Dropout(dropout)
285 | self.blocks = nn.ModuleList(
286 | [
287 | T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout)
288 | for _ in range(num_layers)
289 | ]
290 | )
291 | self.norm = T5LayerNorm(dim)
292 |
293 | # initialize weights
294 | self.apply(init_weights)
295 |
296 | def forward(self, ids, mask=None, encoder_states=None, encoder_mask=None):
297 | b, s = ids.size()
298 |
299 | # causal mask
300 | if mask is None:
301 | mask = torch.tril(torch.ones(1, s, s).to(ids.device))
302 | elif mask.ndim == 2:
303 | mask = torch.tril(mask.unsqueeze(1).expand(-1, s, -1))
304 |
305 | # layers
306 | x = self.token_embedding(ids)
307 | x = self.dropout(x)
308 | e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None
309 | for block in self.blocks:
310 | x = block(x, mask, encoder_states, encoder_mask, pos_bias=e)
311 | x = self.norm(x)
312 | x = self.dropout(x)
313 | return x
314 |
315 |
316 | class T5Model(nn.Module):
317 | def __init__(
318 | self,
319 | vocab_size,
320 | dim,
321 | dim_attn,
322 | dim_ffn,
323 | num_heads,
324 | encoder_layers,
325 | decoder_layers,
326 | num_buckets,
327 | shared_pos=True,
328 | dropout=0.1,
329 | ):
330 | super(T5Model, self).__init__()
331 | self.vocab_size = vocab_size
332 | self.dim = dim
333 | self.dim_attn = dim_attn
334 | self.dim_ffn = dim_ffn
335 | self.num_heads = num_heads
336 | self.encoder_layers = encoder_layers
337 | self.decoder_layers = decoder_layers
338 | self.num_buckets = num_buckets
339 |
340 | # layers
341 | self.token_embedding = nn.Embedding(vocab_size, dim)
342 | self.encoder = T5Encoder(
343 | self.token_embedding, dim, dim_attn, dim_ffn, num_heads, encoder_layers, num_buckets, shared_pos, dropout
344 | )
345 | self.decoder = T5Decoder(
346 | self.token_embedding, dim, dim_attn, dim_ffn, num_heads, decoder_layers, num_buckets, shared_pos, dropout
347 | )
348 | self.head = nn.Linear(dim, vocab_size, bias=False)
349 |
350 | # initialize weights
351 | self.apply(init_weights)
352 |
353 | def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
354 | x = self.encoder(encoder_ids, encoder_mask)
355 | x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask)
356 | x = self.head(x)
357 | return x
358 |
359 |
360 | def _t5(
361 | name,
362 | encoder_only=False,
363 | decoder_only=False,
364 | return_tokenizer=False,
365 | tokenizer_kwargs={},
366 | dtype=torch.float32,
367 | device="cpu",
368 | **kwargs,
369 | ):
370 | # sanity check
371 | assert not (encoder_only and decoder_only)
372 |
373 | # params
374 | if encoder_only:
375 | model_cls = T5Encoder
376 | kwargs["vocab"] = kwargs.pop("vocab_size")
377 | kwargs["num_layers"] = kwargs.pop("encoder_layers")
378 | _ = kwargs.pop("decoder_layers")
379 | elif decoder_only:
380 | model_cls = T5Decoder
381 | kwargs["vocab"] = kwargs.pop("vocab_size")
382 | kwargs["num_layers"] = kwargs.pop("decoder_layers")
383 | _ = kwargs.pop("encoder_layers")
384 | else:
385 | model_cls = T5Model
386 |
387 | # init model
388 | with torch.device(device):
389 | model = model_cls(**kwargs)
390 |
391 | # set device
392 | model = model.to(dtype=dtype, device=device)
393 |
394 | # init tokenizer
395 | if return_tokenizer:
396 | from .tokenizers import HuggingfaceTokenizer
397 |
398 | tokenizer = HuggingfaceTokenizer(f"google/{name}", **tokenizer_kwargs)
399 | return model, tokenizer
400 | else:
401 | return model
402 |
403 |
404 | def umt5_xxl(**kwargs):
405 | cfg = dict(
406 | vocab_size=256384,
407 | dim=4096,
408 | dim_attn=4096,
409 | dim_ffn=10240,
410 | num_heads=64,
411 | encoder_layers=24,
412 | decoder_layers=24,
413 | num_buckets=32,
414 | shared_pos=False,
415 | dropout=0.1,
416 | )
417 | cfg.update(**kwargs)
418 | return _t5("umt5-xxl", **cfg)
419 |
420 |
421 | class T5EncoderModel(ModelMixin):
422 | def __init__(
423 | self,
424 | checkpoint_path=None,
425 | tokenizer_path=None,
426 | text_len=512,
427 | shard_fn=None,
428 | ):
429 | self.text_len = text_len
430 | self.checkpoint_path = checkpoint_path
431 | self.tokenizer_path = tokenizer_path
432 |
433 | super().__init__()
434 | # init model
435 | model = umt5_xxl(encoder_only=True, return_tokenizer=False)
436 | logging.info(f"loading {checkpoint_path}")
437 | model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
438 | self.model = model
439 | if shard_fn is not None:
440 | self.model = shard_fn(self.model, sync_module_states=False)
441 | else:
442 | self.model.eval().requires_grad_(False)
443 | # init tokenizer
444 | self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=text_len, clean="whitespace")
445 |
446 | def encode(self, texts):
447 | ids, mask = self.tokenizer(texts, return_mask=True, add_special_tokens=True)
448 | ids = ids.to(self.device)
449 | mask = mask.to(self.device)
450 | # seq_lens = mask.gt(0).sum(dim=1).long()
451 | context = self.model(ids, mask)
452 | context = context * mask.unsqueeze(-1).cuda()
453 |
454 | return context
455 |
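456 | # Usage sketch (illustrative; paths are placeholders mirroring get_text_encoder() in
457 | # modules/__init__.py, and a CUDA device is assumed since encode() moves the mask to cuda):
458 | #   enc = T5EncoderModel("<models_t5_umt5-xxl-enc-bf16.pth>", "<google/umt5-xxl dir>").to("cuda", torch.bfloat16)
459 | #   context = enc.encode(["A hummingbird hovering over a red flower, macro shot."])
460 | #   # context: [1, 512, 4096], zeroed at padded positions by the attention mask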
--------------------------------------------------------------------------------
/skyreels_v2_infer/modules/tokenizers.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2 | import html
3 | import string
4 |
5 | import ftfy
6 | import regex as re
7 | from transformers import AutoTokenizer
8 |
9 | __all__ = ["HuggingfaceTokenizer"]
10 |
11 |
12 | def basic_clean(text):
13 | text = ftfy.fix_text(text)
14 | text = html.unescape(html.unescape(text))
15 | return text.strip()
16 |
17 |
18 | def whitespace_clean(text):
19 | text = re.sub(r"\s+", " ", text)
20 | text = text.strip()
21 | return text
22 |
23 |
24 | def canonicalize(text, keep_punctuation_exact_string=None):
25 | text = text.replace("_", " ")
26 | if keep_punctuation_exact_string:
27 | text = keep_punctuation_exact_string.join(
28 | part.translate(str.maketrans("", "", string.punctuation))
29 | for part in text.split(keep_punctuation_exact_string)
30 | )
31 | else:
32 | text = text.translate(str.maketrans("", "", string.punctuation))
33 | text = text.lower()
34 | text = re.sub(r"\s+", " ", text)
35 | return text.strip()
36 |
37 |
38 | class HuggingfaceTokenizer:
39 | def __init__(self, name, seq_len=None, clean=None, **kwargs):
40 | assert clean in (None, "whitespace", "lower", "canonicalize")
41 | self.name = name
42 | self.seq_len = seq_len
43 | self.clean = clean
44 |
45 | # init tokenizer
46 | self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
47 | self.vocab_size = self.tokenizer.vocab_size
48 |
49 | def __call__(self, sequence, **kwargs):
50 | return_mask = kwargs.pop("return_mask", False)
51 |
52 | # arguments
53 | _kwargs = {"return_tensors": "pt"}
54 | if self.seq_len is not None:
55 | _kwargs.update({"padding": "max_length", "truncation": True, "max_length": self.seq_len})
56 | _kwargs.update(**kwargs)
57 |
58 | # tokenization
59 | if isinstance(sequence, str):
60 | sequence = [sequence]
61 | if self.clean:
62 | sequence = [self._clean(u) for u in sequence]
63 | ids = self.tokenizer(sequence, **_kwargs)
64 |
65 | # output
66 | if return_mask:
67 | return ids.input_ids, ids.attention_mask
68 | else:
69 | return ids.input_ids
70 |
71 | def _clean(self, text):
72 | if self.clean == "whitespace":
73 | text = whitespace_clean(basic_clean(text))
74 | elif self.clean == "lower":
75 | text = whitespace_clean(basic_clean(text)).lower()
76 | elif self.clean == "canonicalize":
77 | text = canonicalize(basic_clean(text))
78 | return text
79 |
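80 | # Usage sketch (illustrative; "google/umt5-xxl" matches the tokenizer directory that
81 | # T5EncoderModel above is constructed with):
82 | #   tok = HuggingfaceTokenizer(name="google/umt5-xxl", seq_len=512, clean="whitespace")
83 | #   ids, mask = tok(["a misty mountain valley at sunrise"], return_mask=True, add_special_tokens=True)
84 | #   # ids, mask: tensors of shape [1, 512], padded/truncated to seq_len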
--------------------------------------------------------------------------------
/skyreels_v2_infer/modules/transformer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2 | import math
3 | import numpy as np
4 | import torch
5 | import torch.amp as amp
6 | import torch.nn as nn
7 | from diffusers.configuration_utils import ConfigMixin
8 | from diffusers.configuration_utils import register_to_config
9 | from diffusers.loaders import PeftAdapterMixin
10 | from diffusers.models.modeling_utils import ModelMixin
11 | from torch.backends.cuda import sdp_kernel
12 | from torch.nn.attention.flex_attention import BlockMask
13 | from torch.nn.attention.flex_attention import create_block_mask
14 | from torch.nn.attention.flex_attention import flex_attention
15 |
16 | from .attention import flash_attention
17 |
18 |
19 | flex_attention = torch.compile(flex_attention, dynamic=False, mode="max-autotune")
20 |
21 | DISABLE_COMPILE = False  # could be read from an environment variable; hard-coded off for now
22 |
23 | __all__ = ["WanModel"]
24 |
25 |
26 | def sinusoidal_embedding_1d(dim, position):
27 | # preprocess
28 | assert dim % 2 == 0
29 | half = dim // 2
30 | position = position.type(torch.float64)
31 |
32 | # calculation
33 | sinusoid = torch.outer(position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
34 | x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
35 | return x
36 |
37 |
38 | @amp.autocast("cuda", enabled=False)
39 | def rope_params(max_seq_len, dim, theta=10000):
40 | assert dim % 2 == 0
41 | freqs = torch.outer(
42 | torch.arange(max_seq_len), 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim))
43 | )
44 | freqs = torch.polar(torch.ones_like(freqs), freqs)
45 | return freqs
46 |
47 |
48 | @amp.autocast("cuda", enabled=False)
49 | def rope_apply(x, grid_sizes, freqs):
50 | n, c = x.size(2), x.size(3) // 2
51 | bs = x.size(0)
52 |
53 | # split freqs
54 | freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
55 |
56 | # loop over samples
57 | f, h, w = grid_sizes.tolist()
58 | seq_len = f * h * w
59 |
60 | # precompute multipliers
61 |
62 | x = torch.view_as_complex(x.to(torch.float32).reshape(bs, seq_len, n, -1, 2))
63 | freqs_i = torch.cat(
64 | [
65 | freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
66 | freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
67 | freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
68 | ],
69 | dim=-1,
70 | ).reshape(seq_len, 1, -1)
71 |
72 | # apply rotary embedding
73 | x = torch.view_as_real(x * freqs_i).flatten(3)
74 |
75 | return x
76 |
77 |
78 | @torch.compile(dynamic=True, disable=DISABLE_COMPILE)
79 | def fast_rms_norm(x, weight, eps):
80 | x = x.float()
81 | x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
82 | x = x * weight  # x is float32 here; callers cast the result as needed
83 | return x
84 |
85 |
86 | class WanRMSNorm(nn.Module):
87 | def __init__(self, dim, eps=1e-5):
88 | super().__init__()
89 | self.dim = dim
90 | self.eps = eps
91 | self.weight = nn.Parameter(torch.ones(dim))
92 |
93 | def forward(self, x):
94 | r"""
95 | Args:
96 | x(Tensor): Shape [B, L, C]
97 | """
98 | return fast_rms_norm(x, self.weight, self.eps)
99 |
100 | def _norm(self, x):
101 | return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
102 |
103 |
104 | class WanLayerNorm(nn.LayerNorm):
105 | def __init__(self, dim, eps=1e-6, elementwise_affine=False):
106 | super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)
107 |
108 | def forward(self, x):
109 | r"""
110 | Args:
111 | x(Tensor): Shape [B, L, C]
112 | """
113 | return super().forward(x)
114 |
115 |
116 | class WanSelfAttention(nn.Module):
117 | def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True, eps=1e-6):
118 | assert dim % num_heads == 0
119 | super().__init__()
120 | self.dim = dim
121 | self.num_heads = num_heads
122 | self.head_dim = dim // num_heads
123 | self.window_size = window_size
124 | self.qk_norm = qk_norm
125 | self.eps = eps
126 |
127 | # layers
128 | self.q = nn.Linear(dim, dim)
129 | self.k = nn.Linear(dim, dim)
130 | self.v = nn.Linear(dim, dim)
131 | self.o = nn.Linear(dim, dim)
132 | self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
133 | self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
134 |
135 | self._flag_ar_attention = False
136 |
137 | def set_ar_attention(self):
138 | self._flag_ar_attention = True
139 |
140 | def forward(self, x, grid_sizes, freqs, block_mask):
141 | r"""
142 | Args:
143 | x(Tensor): Shape [B, L, C]
144 | grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
145 | freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
146 | block_mask(Tensor or None): Optional attention mask applied in the autoregressive attention branch
147 | """
148 | b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
149 |
150 | # query, key, value function
151 | def qkv_fn(x):
152 | q = self.norm_q(self.q(x)).view(b, s, n, d)
153 | k = self.norm_k(self.k(x)).view(b, s, n, d)
154 | v = self.v(x).view(b, s, n, d)
155 | return q, k, v
156 |
157 | x = x.to(self.q.weight.dtype)
158 | q, k, v = qkv_fn(x)
159 |
160 | if not self._flag_ar_attention:
161 | q = rope_apply(q, grid_sizes, freqs)
162 | k = rope_apply(k, grid_sizes, freqs)
163 | x = flash_attention(q=q, k=k, v=v, window_size=self.window_size)
164 | else:
165 | q = rope_apply(q, grid_sizes, freqs)
166 | k = rope_apply(k, grid_sizes, freqs)
167 | q = q.to(torch.bfloat16)
168 | k = k.to(torch.bfloat16)
169 | v = v.to(torch.bfloat16)
170 |
171 | with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
172 | x = (
173 | torch.nn.functional.scaled_dot_product_attention(
174 | q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), attn_mask=block_mask
175 | )
176 | .transpose(1, 2)
177 | .contiguous()
178 | )
179 |
180 | # output
181 | x = x.flatten(2)
182 | x = self.o(x)
183 | return x
184 |
185 |
186 | class WanT2VCrossAttention(WanSelfAttention):
187 | def forward(self, x, context):
188 | r"""
189 | Args:
190 | x(Tensor): Shape [B, L1, C]
191 | context(Tensor): Shape [B, L2, C]
192 | context_lens(Tensor): Shape [B]
193 | """
194 | b, n, d = x.size(0), self.num_heads, self.head_dim
195 |
196 | # compute query, key, value
197 | q = self.norm_q(self.q(x)).view(b, -1, n, d)
198 | k = self.norm_k(self.k(context)).view(b, -1, n, d)
199 | v = self.v(context).view(b, -1, n, d)
200 |
201 | # compute attention
202 | x = flash_attention(q, k, v)
203 |
204 | # output
205 | x = x.flatten(2)
206 | x = self.o(x)
207 | return x
208 |
209 |
210 | class WanI2VCrossAttention(WanSelfAttention):
211 | def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True, eps=1e-6):
212 | super().__init__(dim, num_heads, window_size, qk_norm, eps)
213 |
214 | self.k_img = nn.Linear(dim, dim)
215 | self.v_img = nn.Linear(dim, dim)
216 | # self.alpha = nn.Parameter(torch.zeros((1, )))
217 | self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
218 |
219 | def forward(self, x, context):
220 | r"""
221 | Args:
222 | x(Tensor): Shape [B, L1, C]
223 | context(Tensor): Shape [B, L2, C]
224 | context_lens(Tensor): Shape [B]
225 | """
226 | context_img = context[:, :257]
227 | context = context[:, 257:]
228 | b, n, d = x.size(0), self.num_heads, self.head_dim
229 |
230 | # compute query, key, value
231 | q = self.norm_q(self.q(x)).view(b, -1, n, d)
232 | k = self.norm_k(self.k(context)).view(b, -1, n, d)
233 | v = self.v(context).view(b, -1, n, d)
234 | k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d)
235 | v_img = self.v_img(context_img).view(b, -1, n, d)
236 | img_x = flash_attention(q, k_img, v_img)
237 | # compute attention
238 | x = flash_attention(q, k, v)
239 |
240 | # output
241 | x = x.flatten(2)
242 | img_x = img_x.flatten(2)
243 | x = x + img_x
244 | x = self.o(x)
245 | return x
246 |
247 |
248 | WAN_CROSSATTENTION_CLASSES = {
249 | "t2v_cross_attn": WanT2VCrossAttention,
250 | "i2v_cross_attn": WanI2VCrossAttention,
251 | }
252 |
253 |
254 | def mul_add(x, y, z):
255 | return x.float() + y.float() * z.float()
256 |
257 |
258 | def mul_add_add(x, y, z):
259 | return x.float() * (1 + y) + z
260 |
261 |
262 | mul_add_compile = torch.compile(mul_add, dynamic=True, disable=DISABLE_COMPILE)
263 | mul_add_add_compile = torch.compile(mul_add_add, dynamic=True, disable=DISABLE_COMPILE)
264 |
265 |
266 | class WanAttentionBlock(nn.Module):
267 | def __init__(
268 | self,
269 | cross_attn_type,
270 | dim,
271 | ffn_dim,
272 | num_heads,
273 | window_size=(-1, -1),
274 | qk_norm=True,
275 | cross_attn_norm=False,
276 | eps=1e-6,
277 | ):
278 | super().__init__()
279 | self.dim = dim
280 | self.ffn_dim = ffn_dim
281 | self.num_heads = num_heads
282 | self.window_size = window_size
283 | self.qk_norm = qk_norm
284 | self.cross_attn_norm = cross_attn_norm
285 | self.eps = eps
286 |
287 | # layers
288 | self.norm1 = WanLayerNorm(dim, eps)
289 | self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, eps)
290 | self.norm3 = WanLayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
291 | self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim, num_heads, (-1, -1), qk_norm, eps)
292 | self.norm2 = WanLayerNorm(dim, eps)
293 | self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU(approximate="tanh"), nn.Linear(ffn_dim, dim))
294 |
295 | # modulation
296 | self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
297 |
298 | def set_ar_attention(self):
299 | self.self_attn.set_ar_attention()
300 |
301 | def forward(
302 | self,
303 | x,
304 | e,
305 | grid_sizes,
306 | freqs,
307 | context,
308 | block_mask,
309 | ):
310 | r"""
311 | Args:
312 | x(Tensor): Shape [B, L, C]
313 | e(Tensor): Shape [B, 6, C]
314 |             context(Tensor): Text (and image) context embeddings, shape [B, L2, C]
315 |             grid_sizes(Tensor): Shape [3], (F, H, W) of the latent grid
316 | freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
317 | """
318 | if e.dim() == 3:
319 | modulation = self.modulation # 1, 6, dim
320 | with amp.autocast("cuda", dtype=torch.float32):
321 | e = (modulation + e).chunk(6, dim=1)
322 | elif e.dim() == 4:
323 | modulation = self.modulation.unsqueeze(2) # 1, 6, 1, dim
324 | with amp.autocast("cuda", dtype=torch.float32):
325 | e = (modulation + e).chunk(6, dim=1)
326 | e = [ei.squeeze(1) for ei in e]
327 |
328 | # self-attention
329 | out = mul_add_add_compile(self.norm1(x), e[1], e[0])
330 | y = self.self_attn(out, grid_sizes, freqs, block_mask)
331 | with amp.autocast("cuda", dtype=torch.float32):
332 | x = mul_add_compile(x, y, e[2])
333 |
334 | # cross-attention & ffn function
335 | def cross_attn_ffn(x, context, e):
336 | dtype = context.dtype
337 | x = x + self.cross_attn(self.norm3(x.to(dtype)), context)
338 | y = self.ffn(mul_add_add_compile(self.norm2(x), e[4], e[3]).to(dtype))
339 | with amp.autocast("cuda", dtype=torch.float32):
340 | x = mul_add_compile(x, y, e[5])
341 | return x
342 |
343 | x = cross_attn_ffn(x, context, e)
344 | return x.to(torch.bfloat16)
345 |
346 |
347 | class Head(nn.Module):
348 | def __init__(self, dim, out_dim, patch_size, eps=1e-6):
349 | super().__init__()
350 | self.dim = dim
351 | self.out_dim = out_dim
352 | self.patch_size = patch_size
353 | self.eps = eps
354 |
355 | # layers
356 | out_dim = math.prod(patch_size) * out_dim
357 | self.norm = WanLayerNorm(dim, eps)
358 | self.head = nn.Linear(dim, out_dim)
359 |
360 | # modulation
361 | self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
362 |
363 | def forward(self, x, e):
364 | r"""
365 | Args:
366 | x(Tensor): Shape [B, L1, C]
367 | e(Tensor): Shape [B, C]
368 | """
369 | with amp.autocast("cuda", dtype=torch.float32):
370 | if e.dim() == 2:
371 | modulation = self.modulation # 1, 2, dim
372 | e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
373 |
374 | elif e.dim() == 3:
375 |                 modulation = self.modulation.unsqueeze(2)  # 1, 2, 1, dim
376 | e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
377 | e = [ei.squeeze(1) for ei in e]
378 | x = self.head(self.norm(x) * (1 + e[1]) + e[0])
379 | return x
380 |
381 |
382 | class MLPProj(torch.nn.Module):
383 | def __init__(self, in_dim, out_dim):
384 | super().__init__()
385 |
386 | self.proj = torch.nn.Sequential(
387 | torch.nn.LayerNorm(in_dim),
388 | torch.nn.Linear(in_dim, in_dim),
389 | torch.nn.GELU(),
390 | torch.nn.Linear(in_dim, out_dim),
391 | torch.nn.LayerNorm(out_dim),
392 | )
393 |
394 | def forward(self, image_embeds):
395 | clip_extra_context_tokens = self.proj(image_embeds)
396 | return clip_extra_context_tokens
397 |
398 |
399 | class WanModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
400 | r"""
401 | Wan diffusion backbone supporting both text-to-video and image-to-video.
402 | """
403 |
404 | ignore_for_config = ["patch_size", "cross_attn_norm", "qk_norm", "text_dim", "window_size"]
405 | _no_split_modules = ["WanAttentionBlock"]
406 |
407 | _supports_gradient_checkpointing = True
408 |
409 | @register_to_config
410 | def __init__(
411 | self,
412 | model_type="t2v",
413 | patch_size=(1, 2, 2),
414 | text_len=512,
415 | in_dim=16,
416 | dim=2048,
417 | ffn_dim=8192,
418 | freq_dim=256,
419 | text_dim=4096,
420 | out_dim=16,
421 | num_heads=16,
422 | num_layers=32,
423 | window_size=(-1, -1),
424 | qk_norm=True,
425 | cross_attn_norm=True,
426 | inject_sample_info=False,
427 | eps=1e-6,
428 | ):
429 | r"""
430 | Initialize the diffusion model backbone.
431 |
432 | Args:
433 | model_type (`str`, *optional*, defaults to 't2v'):
434 | Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
435 | patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
436 | 3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
437 | text_len (`int`, *optional*, defaults to 512):
438 | Fixed length for text embeddings
439 | in_dim (`int`, *optional*, defaults to 16):
440 | Input video channels (C_in)
441 | dim (`int`, *optional*, defaults to 2048):
442 | Hidden dimension of the transformer
443 | ffn_dim (`int`, *optional*, defaults to 8192):
444 | Intermediate dimension in feed-forward network
445 | freq_dim (`int`, *optional*, defaults to 256):
446 | Dimension for sinusoidal time embeddings
447 | text_dim (`int`, *optional*, defaults to 4096):
448 | Input dimension for text embeddings
449 | out_dim (`int`, *optional*, defaults to 16):
450 | Output video channels (C_out)
451 | num_heads (`int`, *optional*, defaults to 16):
452 | Number of attention heads
453 | num_layers (`int`, *optional*, defaults to 32):
454 | Number of transformer blocks
455 | window_size (`tuple`, *optional*, defaults to (-1, -1)):
456 | Window size for local attention (-1 indicates global attention)
457 | qk_norm (`bool`, *optional*, defaults to True):
458 | Enable query/key normalization
459 |             cross_attn_norm (`bool`, *optional*, defaults to True):
460 | Enable cross-attention normalization
461 | eps (`float`, *optional*, defaults to 1e-6):
462 | Epsilon value for normalization layers
463 | """
464 |
465 | super().__init__()
466 |
467 | assert model_type in ["t2v", "i2v"]
468 | self.model_type = model_type
469 |
470 | self.patch_size = patch_size
471 | self.text_len = text_len
472 | self.in_dim = in_dim
473 | self.dim = dim
474 | self.ffn_dim = ffn_dim
475 | self.freq_dim = freq_dim
476 | self.text_dim = text_dim
477 | self.out_dim = out_dim
478 | self.num_heads = num_heads
479 | self.num_layers = num_layers
480 | self.window_size = window_size
481 | self.qk_norm = qk_norm
482 | self.cross_attn_norm = cross_attn_norm
483 | self.eps = eps
484 | self.num_frame_per_block = 1
485 | self.flag_causal_attention = False
486 | self.block_mask = None
487 | self.enable_teacache = False
488 |
489 | # embeddings
490 | self.patch_embedding = nn.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size)
491 | self.text_embedding = nn.Sequential(nn.Linear(text_dim, dim), nn.GELU(approximate="tanh"), nn.Linear(dim, dim))
492 |
493 | self.time_embedding = nn.Sequential(nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
494 | self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
495 |
496 | if inject_sample_info:
497 | self.fps_embedding = nn.Embedding(2, dim)
498 | self.fps_projection = nn.Sequential(nn.Linear(dim, dim), nn.SiLU(), nn.Linear(dim, dim * 6))
499 |
500 | # blocks
501 | cross_attn_type = "t2v_cross_attn" if model_type == "t2v" else "i2v_cross_attn"
502 | self.blocks = nn.ModuleList(
503 | [
504 | WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps)
505 | for _ in range(num_layers)
506 | ]
507 | )
508 |
509 | # head
510 | self.head = Head(dim, out_dim, patch_size, eps)
511 |
512 | # buffers (don't use register_buffer otherwise dtype will be changed in to())
513 | assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
514 | d = dim // num_heads
515 | self.freqs = torch.cat(
516 | [rope_params(1024, d - 4 * (d // 6)), rope_params(1024, 2 * (d // 6)), rope_params(1024, 2 * (d // 6))],
517 | dim=1,
518 | )
519 |
520 | if model_type == "i2v":
521 | self.img_emb = MLPProj(1280, dim)
522 |
523 | self.gradient_checkpointing = False
524 |
525 | self.cpu_offloading = False
526 |
527 | self.inject_sample_info = inject_sample_info
528 | # initialize weights
529 | self.init_weights()
530 |
531 | def _set_gradient_checkpointing(self, module, value=False):
532 | self.gradient_checkpointing = value
533 |
534 | def zero_init_i2v_cross_attn(self):
535 | print("zero init i2v cross attn")
536 | for i in range(self.num_layers):
537 | self.blocks[i].cross_attn.v_img.weight.data.zero_()
538 | self.blocks[i].cross_attn.v_img.bias.data.zero_()
539 |
540 | @staticmethod
541 | def _prepare_blockwise_causal_attn_mask(
542 | device: torch.device | str, num_frames: int = 21, frame_seqlen: int = 1560, num_frame_per_block=1
543 | ) -> BlockMask:
544 | """
545 |         We divide the token sequence into blocks of latent frames:
546 |         [1 latent frame] [1 latent frame] ... [1 latent frame]
547 |         and use FlexAttention to construct the block-wise causal attention mask.
548 | """
549 | total_length = num_frames * frame_seqlen
550 |
551 | # we do right padding to get to a multiple of 128
552 | padded_length = math.ceil(total_length / 128) * 128 - total_length
553 |
554 | ends = torch.zeros(total_length + padded_length, device=device, dtype=torch.long)
555 |
556 | # Block-wise causal mask will attend to all elements that are before the end of the current chunk
557 | frame_indices = torch.arange(start=0, end=total_length, step=frame_seqlen * num_frame_per_block, device=device)
558 |
559 | for tmp in frame_indices:
560 | ends[tmp : tmp + frame_seqlen * num_frame_per_block] = tmp + frame_seqlen * num_frame_per_block
561 |
562 | def attention_mask(b, h, q_idx, kv_idx):
563 | return (kv_idx < ends[q_idx]) | (q_idx == kv_idx)
564 | # return ((kv_idx < total_length) & (q_idx < total_length)) | (q_idx == kv_idx) # bidirectional mask
565 |
566 | block_mask = create_block_mask(
567 | attention_mask,
568 | B=None,
569 | H=None,
570 | Q_LEN=total_length + padded_length,
571 | KV_LEN=total_length + padded_length,
572 | _compile=False,
573 | device=device,
574 | )
575 |
576 | return block_mask
577 |
578 | def initialize_teacache(self, enable_teacache=True, num_steps=25, teacache_thresh=0.15, use_ret_steps=False, ckpt_dir=''):
579 | self.enable_teacache = enable_teacache
580 | print('using teacache')
581 | self.cnt = 0
582 | self.num_steps = num_steps
583 | self.teacache_thresh = teacache_thresh
584 | self.accumulated_rel_l1_distance_even = 0
585 | self.accumulated_rel_l1_distance_odd = 0
586 | self.previous_e0_even = None
587 | self.previous_e0_odd = None
588 | self.previous_residual_even = None
589 | self.previous_residual_odd = None
590 | self.use_ref_steps = use_ret_steps
591 | if "I2V" in ckpt_dir:
592 | if use_ret_steps:
593 | if '540P' in ckpt_dir:
594 | self.coefficients = [ 2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01]
595 | if '720P' in ckpt_dir:
596 | self.coefficients = [ 8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02]
597 | self.ret_steps = 5*2
598 | self.cutoff_steps = num_steps*2
599 | else:
600 | if '540P' in ckpt_dir:
601 | self.coefficients = [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01]
602 | if '720P' in ckpt_dir:
603 | self.coefficients = [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683]
604 | self.ret_steps = 1*2
605 | self.cutoff_steps = num_steps*2 - 2
606 | else:
607 | if use_ret_steps:
608 | if '1.3B' in ckpt_dir:
609 | self.coefficients = [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02]
610 | if '14B' in ckpt_dir:
611 | self.coefficients = [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01]
612 | self.ret_steps = 5*2
613 | self.cutoff_steps = num_steps*2
614 | else:
615 | if '1.3B' in ckpt_dir:
616 | self.coefficients = [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01]
617 | if '14B' in ckpt_dir:
618 | self.coefficients = [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404]
619 | self.ret_steps = 1*2
620 | self.cutoff_steps = num_steps*2 - 2
621 |
622 | def forward(self, x, t, context, clip_fea=None, y=None, fps=None):
623 | r"""
624 | Forward pass through the diffusion model
625 |
626 | Args:
627 |             x (Tensor):
628 |                 Input video latent tensor with shape [B, C_in, F, H, W]
629 |             t (Tensor):
630 |                 Diffusion timesteps tensor of shape [B] (or [B, F] in diffusion-forcing mode)
631 |             context (Tensor):
632 |                 Text embeddings with shape [B, L, C]
633 |             fps (*optional*):
634 |                 Frame-rate flag(s); only used when `inject_sample_info` is enabled
635 |             clip_fea (Tensor, *optional*):
636 |                 CLIP image features for image-to-video mode
637 |             y (Tensor, *optional*):
638 |                 Conditional video inputs for image-to-video mode, concatenated with x along the channel dim
639 | 
640 |         Returns:
641 |             Tensor:
642 |                 Denoised video latent tensor with shape [B, C_out, F, H, W]
643 | """
644 | if self.model_type == "i2v":
645 | assert clip_fea is not None and y is not None
646 | # params
647 | device = self.patch_embedding.weight.device
648 | if self.freqs.device != device:
649 | self.freqs = self.freqs.to(device)
650 |
651 | if y is not None:
652 | x = torch.cat([x, y], dim=1)
653 |
654 | # embeddings
655 | x = self.patch_embedding(x)
656 | grid_sizes = torch.tensor(x.shape[2:], dtype=torch.long)
657 | x = x.flatten(2).transpose(1, 2)
658 |
659 | if self.flag_causal_attention:
660 | frame_num = grid_sizes[0]
661 | height = grid_sizes[1]
662 | width = grid_sizes[2]
663 | block_num = frame_num // self.num_frame_per_block
664 | range_tensor = torch.arange(block_num).view(-1, 1)
665 | range_tensor = range_tensor.repeat(1, self.num_frame_per_block).flatten()
666 |             causal_mask = range_tensor.unsqueeze(0) <= range_tensor.unsqueeze(1)  # f, f
667 |             causal_mask = causal_mask.view(frame_num, 1, 1, frame_num, 1, 1).to(x.device)
668 |             causal_mask = causal_mask.repeat(1, height, width, 1, height, width)
669 |             causal_mask = causal_mask.reshape(frame_num * height * width, frame_num * height * width)
670 |             self.block_mask = causal_mask.unsqueeze(0).unsqueeze(0)
671 |
672 | # time embeddings
673 | with amp.autocast("cuda", dtype=torch.float32):
674 | if t.dim() == 2:
675 | b, f = t.shape
676 | _flag_df = True
677 | else:
678 | _flag_df = False
679 |
680 | e = self.time_embedding(
681 | sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(self.patch_embedding.weight.dtype)
682 | ) # b, dim
683 | e0 = self.time_projection(e).unflatten(1, (6, self.dim)) # b, 6, dim
684 |
685 | if self.inject_sample_info:
686 | fps = torch.tensor(fps, dtype=torch.long, device=device)
687 |
688 | fps_emb = self.fps_embedding(fps).float()
689 | if _flag_df:
690 | e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim)).repeat(t.shape[1], 1, 1)
691 | else:
692 | e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim))
693 |
694 | if _flag_df:
695 | e = e.view(b, f, 1, 1, self.dim)
696 | e0 = e0.view(b, f, 1, 1, 6, self.dim)
697 | e = e.repeat(1, 1, grid_sizes[1], grid_sizes[2], 1).flatten(1, 3)
698 | e0 = e0.repeat(1, 1, grid_sizes[1], grid_sizes[2], 1, 1).flatten(1, 3)
699 | e0 = e0.transpose(1, 2).contiguous()
700 |
701 | assert e.dtype == torch.float32 and e0.dtype == torch.float32
702 |
703 | # context
704 | context = self.text_embedding(context)
705 |
706 | if clip_fea is not None:
707 | context_clip = self.img_emb(clip_fea) # bs x 257 x dim
708 | context = torch.concat([context_clip, context], dim=1)
709 |
710 | # arguments
711 | kwargs = dict(e=e0, grid_sizes=grid_sizes, freqs=self.freqs, context=context, block_mask=self.block_mask)
712 | if self.enable_teacache:
713 | modulated_inp = e0 if self.use_ref_steps else e
714 | # teacache
715 |             if self.cnt % 2 == 0:  # even -> conditional pass
716 | self.is_even = True
717 | if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
718 | should_calc_even = True
719 | self.accumulated_rel_l1_distance_even = 0
720 | else:
721 | rescale_func = np.poly1d(self.coefficients)
722 | self.accumulated_rel_l1_distance_even += rescale_func(((modulated_inp-self.previous_e0_even).abs().mean() / self.previous_e0_even.abs().mean()).cpu().item())
723 | if self.accumulated_rel_l1_distance_even < self.teacache_thresh:
724 | should_calc_even = False
725 | else:
726 | should_calc_even = True
727 | self.accumulated_rel_l1_distance_even = 0
728 | self.previous_e0_even = modulated_inp.clone()
729 |
730 |             else:  # odd -> unconditional pass
731 | self.is_even = False
732 | if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
733 | should_calc_odd = True
734 | self.accumulated_rel_l1_distance_odd = 0
735 | else:
736 | rescale_func = np.poly1d(self.coefficients)
737 | self.accumulated_rel_l1_distance_odd += rescale_func(((modulated_inp-self.previous_e0_odd).abs().mean() / self.previous_e0_odd.abs().mean()).cpu().item())
738 | if self.accumulated_rel_l1_distance_odd < self.teacache_thresh:
739 | should_calc_odd = False
740 | else:
741 | should_calc_odd = True
742 | self.accumulated_rel_l1_distance_odd = 0
743 | self.previous_e0_odd = modulated_inp.clone()
744 |
745 | if self.enable_teacache:
746 | if self.is_even:
747 | if not should_calc_even:
748 | x += self.previous_residual_even
749 | else:
750 | ori_x = x.clone()
751 | for block in self.blocks:
752 | x = block(x, **kwargs)
753 | self.previous_residual_even = x - ori_x
754 | else:
755 | if not should_calc_odd:
756 | x += self.previous_residual_odd
757 | else:
758 | ori_x = x.clone()
759 | for block in self.blocks:
760 | x = block(x, **kwargs)
761 | self.previous_residual_odd = x - ori_x
762 |
763 | self.cnt += 1
764 | if self.cnt >= self.num_steps:
765 | self.cnt = 0
766 | else:
767 | for block in self.blocks:
768 | x = block(x, **kwargs)
769 |
770 | x = self.head(x, e)
771 |
772 | # unpatchify
773 | x = self.unpatchify(x, grid_sizes)
774 |
775 | return x.float()
776 |
777 | def unpatchify(self, x, grid_sizes):
778 | r"""
779 | Reconstruct video tensors from patch embeddings.
780 |
781 | Args:
782 |             x (Tensor):
783 |                 Patchified features with shape [B, L, C_out * prod(patch_size)]
784 | grid_sizes (Tensor):
785 | Original spatial-temporal grid dimensions before patching,
786 |                 shape [3] (the dimensions correspond to F_patches, H_patches, W_patches)
787 |
788 | Returns:
789 |             Tensor:
790 |                 Reconstructed video latent tensor with shape [B, C_out, F, H / 8, W / 8]
791 | """
792 |
793 | c = self.out_dim
794 | bs = x.shape[0]
795 | x = x.view(bs, *grid_sizes, *self.patch_size, c)
796 | x = torch.einsum("bfhwpqrc->bcfphqwr", x)
797 | x = x.reshape(bs, c, *[i * j for i, j in zip(grid_sizes, self.patch_size)])
798 |
799 | return x
800 |
801 | def set_ar_attention(self, causal_block_size):
802 | self.num_frame_per_block = causal_block_size
803 | self.flag_causal_attention = True
804 | for block in self.blocks:
805 | block.set_ar_attention()
806 |
807 | def init_weights(self):
808 | r"""
809 | Initialize model parameters using Xavier initialization.
810 | """
811 |
812 | # basic init
813 | for m in self.modules():
814 | if isinstance(m, nn.Linear):
815 | nn.init.xavier_uniform_(m.weight)
816 | if m.bias is not None:
817 | nn.init.zeros_(m.bias)
818 |
819 | # init embeddings
820 | nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
821 | for m in self.text_embedding.modules():
822 | if isinstance(m, nn.Linear):
823 | nn.init.normal_(m.weight, std=0.02)
824 | for m in self.time_embedding.modules():
825 | if isinstance(m, nn.Linear):
826 | nn.init.normal_(m.weight, std=0.02)
827 |
828 | if self.inject_sample_info:
829 | nn.init.normal_(self.fps_embedding.weight, std=0.02)
830 |
831 | for m in self.fps_projection.modules():
832 | if isinstance(m, nn.Linear):
833 | nn.init.normal_(m.weight, std=0.02)
834 |
835 | nn.init.zeros_(self.fps_projection[-1].weight)
836 | nn.init.zeros_(self.fps_projection[-1].bias)
837 |
838 | # init output layer
839 | nn.init.zeros_(self.head.head.weight)
840 |
--------------------------------------------------------------------------------
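A minimal sketch of the attention pattern that `WanModel._prepare_blockwise_causal_attn_mask` above expresses through FlexAttention: every query token may attend to all tokens up to the end of its own block of latent frames, plus itself. The helper `dense_blockwise_causal_mask` is hypothetical and not part of the repository; it is a dense-tensor equivalent intended only for inspecting the pattern at toy sizes, and the frame count and per-frame sequence length below are made up.

import torch


def dense_blockwise_causal_mask(num_frames: int, frame_seqlen: int, num_frame_per_block: int = 1) -> torch.Tensor:
    # Dense boolean equivalent of the FlexAttention predicate above (toy sizes only):
    # a query may attend to every key before the end of its own frame block, plus itself.
    total = num_frames * frame_seqlen
    block = frame_seqlen * num_frame_per_block
    ends = torch.zeros(total, dtype=torch.long)
    for start in range(0, total, block):
        ends[start:start + block] = start + block
    q_idx = torch.arange(total).unsqueeze(1)   # [total, 1]
    kv_idx = torch.arange(total).unsqueeze(0)  # [1, total]
    return (kv_idx < ends.unsqueeze(1)) | (q_idx == kv_idx)


# e.g. 3 latent frames of 4 tokens each, one frame per block -> block lower-triangular pattern
print(dense_blockwise_causal_mask(num_frames=3, frame_seqlen=4).int())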
/skyreels_v2_infer/modules/vae.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2 | import logging
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from einops import rearrange
8 |
9 |
10 | __all__ = [
11 | "WanVAE",
12 | ]
13 |
14 | CACHE_T = 2
15 |
16 |
17 | class CausalConv3d(nn.Conv3d):
18 | """
19 |     Causal 3D convolution.
20 | """
21 |
22 | def __init__(self, *args, **kwargs):
23 | super().__init__(*args, **kwargs)
24 | self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0)
25 | self.padding = (0, 0, 0)
26 |
27 | def forward(self, x, cache_x=None):
28 | padding = list(self._padding)
29 | if cache_x is not None and self._padding[4] > 0:
30 | cache_x = cache_x.to(x.device)
31 | x = torch.cat([cache_x, x], dim=2)
32 | padding[4] -= cache_x.shape[2]
33 | x = F.pad(x, padding)
34 |
35 | return super().forward(x)
36 |
37 |
38 | class RMS_norm(nn.Module):
39 | def __init__(self, dim, channel_first=True, images=True, bias=False):
40 | super().__init__()
41 | broadcastable_dims = (1, 1, 1) if not images else (1, 1)
42 | shape = (dim, *broadcastable_dims) if channel_first else (dim,)
43 |
44 | self.channel_first = channel_first
45 | self.scale = dim**0.5
46 | self.gamma = nn.Parameter(torch.ones(shape))
47 | self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
48 |
49 | def forward(self, x):
50 | return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
51 |
52 |
53 | class Upsample(nn.Upsample):
54 | def forward(self, x):
55 | """
56 | Fix bfloat16 support for nearest neighbor interpolation.
57 | """
58 | return super().forward(x.float()).type_as(x)
59 |
60 |
61 | class Resample(nn.Module):
62 | def __init__(self, dim, mode):
63 | assert mode in ("none", "upsample2d", "upsample3d", "downsample2d", "downsample3d")
64 | super().__init__()
65 | self.dim = dim
66 | self.mode = mode
67 |
68 | # layers
69 | if mode == "upsample2d":
70 | self.resample = nn.Sequential(
71 | Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
72 | )
73 | elif mode == "upsample3d":
74 | self.resample = nn.Sequential(
75 | Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
76 | )
77 | self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
78 |
79 | elif mode == "downsample2d":
80 | self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
81 | elif mode == "downsample3d":
82 | self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
83 | self.time_conv = CausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
84 |
85 | else:
86 | self.resample = nn.Identity()
87 |
88 | def forward(self, x, feat_cache=None, feat_idx=[0]):
89 | b, c, t, h, w = x.size()
90 | if self.mode == "upsample3d":
91 | if feat_cache is not None:
92 | idx = feat_idx[0]
93 | if feat_cache[idx] is None:
94 | feat_cache[idx] = "Rep"
95 | feat_idx[0] += 1
96 | else:
97 |
98 | cache_x = x[:, :, -CACHE_T:, :, :].clone()
99 | if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
100 | # cache last frame of last two chunk
101 | cache_x = torch.cat(
102 | [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2
103 | )
104 | if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep":
105 | cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2)
106 | if feat_cache[idx] == "Rep":
107 | x = self.time_conv(x)
108 | else:
109 | x = self.time_conv(x, feat_cache[idx])
110 | feat_cache[idx] = cache_x
111 | feat_idx[0] += 1
112 |
113 | x = x.reshape(b, 2, c, t, h, w)
114 | x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
115 | x = x.reshape(b, c, t * 2, h, w)
116 | t = x.shape[2]
117 | x = rearrange(x, "b c t h w -> (b t) c h w")
118 | x = self.resample(x)
119 | x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
120 |
121 | if self.mode == "downsample3d":
122 | if feat_cache is not None:
123 | idx = feat_idx[0]
124 | if feat_cache[idx] is None:
125 | feat_cache[idx] = x.clone()
126 | feat_idx[0] += 1
127 | else:
128 |
129 | cache_x = x[:, :, -1:, :, :].clone()
130 | # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
131 | # # cache last frame of last two chunk
132 | # cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
133 |
134 | x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
135 | feat_cache[idx] = cache_x
136 | feat_idx[0] += 1
137 | return x
138 |
139 | def init_weight(self, conv):
140 | conv_weight = conv.weight
141 | nn.init.zeros_(conv_weight)
142 | c1, c2, t, h, w = conv_weight.size()
143 | one_matrix = torch.eye(c1, c2)
144 | init_matrix = one_matrix
145 | nn.init.zeros_(conv_weight)
146 | # conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
147 | conv_weight.data[:, :, 1, 0, 0] = init_matrix # * 0.5
148 | conv.weight.data.copy_(conv_weight)
149 | nn.init.zeros_(conv.bias.data)
150 |
151 | def init_weight2(self, conv):
152 | conv_weight = conv.weight.data
153 | nn.init.zeros_(conv_weight)
154 | c1, c2, t, h, w = conv_weight.size()
155 | init_matrix = torch.eye(c1 // 2, c2)
156 | # init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
157 | conv_weight[: c1 // 2, :, -1, 0, 0] = init_matrix
158 | conv_weight[c1 // 2 :, :, -1, 0, 0] = init_matrix
159 | conv.weight.data.copy_(conv_weight)
160 | nn.init.zeros_(conv.bias.data)
161 |
162 |
163 | class ResidualBlock(nn.Module):
164 | def __init__(self, in_dim, out_dim, dropout=0.0):
165 | super().__init__()
166 | self.in_dim = in_dim
167 | self.out_dim = out_dim
168 |
169 | # layers
170 | self.residual = nn.Sequential(
171 | RMS_norm(in_dim, images=False),
172 | nn.SiLU(),
173 | CausalConv3d(in_dim, out_dim, 3, padding=1),
174 | RMS_norm(out_dim, images=False),
175 | nn.SiLU(),
176 | nn.Dropout(dropout),
177 | CausalConv3d(out_dim, out_dim, 3, padding=1),
178 | )
179 | self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()
180 |
181 | def forward(self, x, feat_cache=None, feat_idx=[0]):
182 | h = self.shortcut(x)
183 | for layer in self.residual:
184 | if isinstance(layer, CausalConv3d) and feat_cache is not None:
185 | idx = feat_idx[0]
186 | cache_x = x[:, :, -CACHE_T:, :, :].clone()
187 | if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
188 | # cache last frame of last two chunk
189 | cache_x = torch.cat(
190 | [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2
191 | )
192 | x = layer(x, feat_cache[idx])
193 | feat_cache[idx] = cache_x
194 | feat_idx[0] += 1
195 | else:
196 | x = layer(x)
197 | return x + h
198 |
199 |
200 | class AttentionBlock(nn.Module):
201 | """
202 |     Single-head self-attention applied spatially within each frame.
203 | """
204 |
205 | def __init__(self, dim):
206 | super().__init__()
207 | self.dim = dim
208 |
209 | # layers
210 | self.norm = RMS_norm(dim)
211 | self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
212 | self.proj = nn.Conv2d(dim, dim, 1)
213 |
214 | # zero out the last layer params
215 | nn.init.zeros_(self.proj.weight)
216 |
217 | def forward(self, x):
218 | identity = x
219 | b, c, t, h, w = x.size()
220 | x = rearrange(x, "b c t h w -> (b t) c h w")
221 | x = self.norm(x)
222 | # compute query, key, value
223 | q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute(0, 1, 3, 2).contiguous().chunk(3, dim=-1)
224 |
225 | # apply attention
226 | x = F.scaled_dot_product_attention(
227 | q,
228 | k,
229 | v,
230 | )
231 | x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
232 |
233 | # output
234 | x = self.proj(x)
235 | x = rearrange(x, "(b t) c h w-> b c t h w", t=t)
236 | return x + identity
237 |
238 |
239 | class Encoder3d(nn.Module):
240 | def __init__(
241 | self,
242 | dim=128,
243 | z_dim=4,
244 | dim_mult=[1, 2, 4, 4],
245 | num_res_blocks=2,
246 | attn_scales=[],
247 | temperal_downsample=[True, True, False],
248 | dropout=0.0,
249 | ):
250 | super().__init__()
251 | self.dim = dim
252 | self.z_dim = z_dim
253 | self.dim_mult = dim_mult
254 | self.num_res_blocks = num_res_blocks
255 | self.attn_scales = attn_scales
256 | self.temperal_downsample = temperal_downsample
257 |
258 | # dimensions
259 | dims = [dim * u for u in [1] + dim_mult]
260 | scale = 1.0
261 |
262 | # init block
263 | self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
264 |
265 | # downsample blocks
266 | downsamples = []
267 | for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
268 | # residual (+attention) blocks
269 | for _ in range(num_res_blocks):
270 | downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
271 | if scale in attn_scales:
272 | downsamples.append(AttentionBlock(out_dim))
273 | in_dim = out_dim
274 |
275 | # downsample block
276 | if i != len(dim_mult) - 1:
277 | mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
278 | downsamples.append(Resample(out_dim, mode=mode))
279 | scale /= 2.0
280 | self.downsamples = nn.Sequential(*downsamples)
281 |
282 | # middle blocks
283 | self.middle = nn.Sequential(
284 | ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim), ResidualBlock(out_dim, out_dim, dropout)
285 | )
286 |
287 | # output blocks
288 | self.head = nn.Sequential(
289 | RMS_norm(out_dim, images=False), nn.SiLU(), CausalConv3d(out_dim, z_dim, 3, padding=1)
290 | )
291 |
292 | def forward(self, x, feat_cache=None, feat_idx=[0]):
293 | if feat_cache is not None:
294 | idx = feat_idx[0]
295 | cache_x = x[:, :, -CACHE_T:, :, :].clone()
296 | if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
297 | # cache last frame of last two chunk
298 | cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
299 | x = self.conv1(x, feat_cache[idx])
300 | feat_cache[idx] = cache_x
301 | feat_idx[0] += 1
302 | else:
303 | x = self.conv1(x)
304 |
305 | ## downsamples
306 | for layer in self.downsamples:
307 | if feat_cache is not None:
308 | x = layer(x, feat_cache, feat_idx)
309 | else:
310 | x = layer(x)
311 |
312 | ## middle
313 | for layer in self.middle:
314 | if isinstance(layer, ResidualBlock) and feat_cache is not None:
315 | x = layer(x, feat_cache, feat_idx)
316 | else:
317 | x = layer(x)
318 |
319 | ## head
320 | for layer in self.head:
321 | if isinstance(layer, CausalConv3d) and feat_cache is not None:
322 | idx = feat_idx[0]
323 | cache_x = x[:, :, -CACHE_T:, :, :].clone()
324 | if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
325 | # cache last frame of last two chunk
326 | cache_x = torch.cat(
327 | [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2
328 | )
329 | x = layer(x, feat_cache[idx])
330 | feat_cache[idx] = cache_x
331 | feat_idx[0] += 1
332 | else:
333 | x = layer(x)
334 | return x
335 |
336 |
337 | class Decoder3d(nn.Module):
338 | def __init__(
339 | self,
340 | dim=128,
341 | z_dim=4,
342 | dim_mult=[1, 2, 4, 4],
343 | num_res_blocks=2,
344 | attn_scales=[],
345 | temperal_upsample=[False, True, True],
346 | dropout=0.0,
347 | ):
348 | super().__init__()
349 | self.dim = dim
350 | self.z_dim = z_dim
351 | self.dim_mult = dim_mult
352 | self.num_res_blocks = num_res_blocks
353 | self.attn_scales = attn_scales
354 | self.temperal_upsample = temperal_upsample
355 |
356 | # dimensions
357 | dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
358 | scale = 1.0 / 2 ** (len(dim_mult) - 2)
359 |
360 | # init block
361 | self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
362 |
363 | # middle blocks
364 | self.middle = nn.Sequential(
365 | ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]), ResidualBlock(dims[0], dims[0], dropout)
366 | )
367 |
368 | # upsample blocks
369 | upsamples = []
370 | for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
371 | # residual (+attention) blocks
372 | if i == 1 or i == 2 or i == 3:
373 | in_dim = in_dim // 2
374 | for _ in range(num_res_blocks + 1):
375 | upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
376 | if scale in attn_scales:
377 | upsamples.append(AttentionBlock(out_dim))
378 | in_dim = out_dim
379 |
380 | # upsample block
381 | if i != len(dim_mult) - 1:
382 | mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
383 | upsamples.append(Resample(out_dim, mode=mode))
384 | scale *= 2.0
385 | self.upsamples = nn.Sequential(*upsamples)
386 |
387 | # output blocks
388 | self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(), CausalConv3d(out_dim, 3, 3, padding=1))
389 |
390 | def forward(self, x, feat_cache=None, feat_idx=[0]):
391 | ## conv1
392 | if feat_cache is not None:
393 | idx = feat_idx[0]
394 | cache_x = x[:, :, -CACHE_T:, :, :].clone()
395 | if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
396 | # cache last frame of last two chunk
397 | cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
398 | x = self.conv1(x, feat_cache[idx])
399 | feat_cache[idx] = cache_x
400 | feat_idx[0] += 1
401 | else:
402 | x = self.conv1(x)
403 |
404 | ## middle
405 | for layer in self.middle:
406 | if isinstance(layer, ResidualBlock) and feat_cache is not None:
407 | x = layer(x, feat_cache, feat_idx)
408 | else:
409 | x = layer(x)
410 |
411 | ## upsamples
412 | for layer in self.upsamples:
413 | if feat_cache is not None:
414 | x = layer(x, feat_cache, feat_idx)
415 | else:
416 | x = layer(x)
417 |
418 | ## head
419 | for layer in self.head:
420 | if isinstance(layer, CausalConv3d) and feat_cache is not None:
421 | idx = feat_idx[0]
422 | cache_x = x[:, :, -CACHE_T:, :, :].clone()
423 | if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
424 | # cache last frame of last two chunk
425 | cache_x = torch.cat(
426 | [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2
427 | )
428 | x = layer(x, feat_cache[idx])
429 | feat_cache[idx] = cache_x
430 | feat_idx[0] += 1
431 | else:
432 | x = layer(x)
433 | return x
434 |
435 |
436 | def count_conv3d(model):
437 | count = 0
438 | for m in model.modules():
439 | if isinstance(m, CausalConv3d):
440 | count += 1
441 | return count
442 |
443 |
444 | class WanVAE_(nn.Module):
445 | def __init__(
446 | self,
447 | dim=128,
448 | z_dim=4,
449 | dim_mult=[1, 2, 4, 4],
450 | num_res_blocks=2,
451 | attn_scales=[],
452 | temperal_downsample=[True, True, False],
453 | dropout=0.0,
454 | ):
455 | super().__init__()
456 | self.dim = dim
457 | self.z_dim = z_dim
458 | self.dim_mult = dim_mult
459 | self.num_res_blocks = num_res_blocks
460 | self.attn_scales = attn_scales
461 | self.temperal_downsample = temperal_downsample
462 | self.temperal_upsample = temperal_downsample[::-1]
463 |
464 | # modules
465 | self.encoder = Encoder3d(
466 | dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout
467 | )
468 | self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
469 | self.conv2 = CausalConv3d(z_dim, z_dim, 1)
470 | self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout)
471 |
472 | def forward(self, x):
473 | mu, log_var = self.encode(x)
474 | z = self.reparameterize(mu, log_var)
475 | x_recon = self.decode(z)
476 | return x_recon, mu, log_var
477 |
478 | def encode(self, x, scale):
479 | self.clear_cache()
480 | ## cache
481 | t = x.shape[2]
482 | iter_ = 1 + (t - 1) // 4
483 |         ## Split the encoder input x along the time axis into chunks of 1, 4, 4, 4, ...
484 | for i in range(iter_):
485 | self._enc_conv_idx = [0]
486 | if i == 0:
487 | out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
488 | else:
489 | out_ = self.encoder(
490 | x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :],
491 | feat_cache=self._enc_feat_map,
492 | feat_idx=self._enc_conv_idx,
493 | )
494 | out = torch.cat([out, out_], 2)
495 | mu, log_var = self.conv1(out).chunk(2, dim=1)
496 | if isinstance(scale[0], torch.Tensor):
497 | mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1)
498 | else:
499 | mu = (mu - scale[0]) * scale[1]
500 | self.clear_cache()
501 | return mu
502 |
503 | def decode(self, z, scale):
504 | self.clear_cache()
505 | # z: [b,c,t,h,w]
506 | if isinstance(scale[0], torch.Tensor):
507 | z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1)
508 | else:
509 | z = z / scale[1] + scale[0]
510 | iter_ = z.shape[2]
511 | x = self.conv2(z)
512 | for i in range(iter_):
513 | self._conv_idx = [0]
514 | if i == 0:
515 | out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
516 | else:
517 | out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
518 | out = torch.cat([out, out_], 2)
519 | self.clear_cache()
520 | return out
521 |
522 | def reparameterize(self, mu, log_var):
523 | std = torch.exp(0.5 * log_var)
524 | eps = torch.randn_like(std)
525 | return eps * std + mu
526 |
527 | def sample(self, imgs, deterministic=False):
528 | mu, log_var = self.encode(imgs)
529 | if deterministic:
530 | return mu
531 | std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
532 | return mu + std * torch.randn_like(std)
533 |
534 | def clear_cache(self):
535 | self._conv_num = count_conv3d(self.decoder)
536 | self._conv_idx = [0]
537 | self._feat_map = [None] * self._conv_num
538 | # cache encode
539 | self._enc_conv_num = count_conv3d(self.encoder)
540 | self._enc_conv_idx = [0]
541 | self._enc_feat_map = [None] * self._enc_conv_num
542 |
543 |
544 | def _video_vae(pretrained_path=None, z_dim=None, device="cpu", **kwargs):
545 | """
546 | Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
547 | """
548 | # params
549 | cfg = dict(
550 | dim=96,
551 | z_dim=z_dim,
552 | dim_mult=[1, 2, 4, 4],
553 | num_res_blocks=2,
554 | attn_scales=[],
555 | temperal_downsample=[False, True, True],
556 | dropout=0.0,
557 | )
558 | cfg.update(**kwargs)
559 |
560 | # init model
561 | with torch.device("meta"):
562 | model = WanVAE_(**cfg)
563 |
564 | # load checkpoint
565 | logging.info(f"loading {pretrained_path}")
566 | model.load_state_dict(torch.load(pretrained_path, map_location=device), assign=True)
567 |
568 | return model
569 |
570 |
571 | class WanVAE:
572 | def __init__(self, vae_pth="cache/vae_step_411000.pth", z_dim=16):
573 |
574 | mean = [
575 | -0.7571,
576 | -0.7089,
577 | -0.9113,
578 | 0.1075,
579 | -0.1745,
580 | 0.9653,
581 | -0.1517,
582 | 1.5508,
583 | 0.4134,
584 | -0.0715,
585 | 0.5517,
586 | -0.3632,
587 | -0.1922,
588 | -0.9497,
589 | 0.2503,
590 | -0.2921,
591 | ]
592 | std = [
593 | 2.8184,
594 | 1.4541,
595 | 2.3275,
596 | 2.6558,
597 | 1.2196,
598 | 1.7708,
599 | 2.6052,
600 | 2.0743,
601 | 3.2687,
602 | 2.1526,
603 | 2.8652,
604 | 1.5579,
605 | 1.6382,
606 | 1.1253,
607 | 2.8251,
608 | 1.9160,
609 | ]
610 | self.vae_stride = (4, 8, 8)
611 | self.mean = torch.tensor(mean)
612 | self.std = torch.tensor(std)
613 | self.scale = [self.mean, 1.0 / self.std]
614 |
615 | # init model
616 | self.vae = (
617 | _video_vae(
618 | pretrained_path=vae_pth,
619 | z_dim=z_dim,
620 | )
621 | .eval()
622 | .requires_grad_(False)
623 | )
624 |
625 | def encode(self, video):
626 | """
627 |         video: A batched video tensor with shape [B, C, T, H, W].
628 | """
629 | return self.vae.encode(video, self.scale).float()
630 |
631 | def to(self, *args, **kwargs):
632 | self.mean = self.mean.to(*args, **kwargs)
633 | self.std = self.std.to(*args, **kwargs)
634 | self.scale = [self.mean, 1.0 / self.std]
635 | self.vae = self.vae.to(*args, **kwargs)
636 | return self
637 |
638 | def decode(self, z):
639 | return self.vae.decode(z, self.scale).float().clamp_(-1, 1)
640 |
--------------------------------------------------------------------------------
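A small usage sketch for the `WanVAE` wrapper defined above, assuming a locally available checkpoint (the `checkpoint/Wan2.1_VAE.pth` path is a placeholder) and a CUDA device. The shape arithmetic follows the `vae_stride = (4, 8, 8)` set in the wrapper: time is compressed as 1 + (T - 1) / 4, height and width by 8.

import torch

from skyreels_v2_infer.modules.vae import WanVAE

# Placeholder checkpoint path; DiffusionForcingPipeline later loads "<model_path>/Wan2.1_VAE.pth".
vae = WanVAE(vae_pth="checkpoint/Wan2.1_VAE.pth", z_dim=16).to("cuda", torch.float32)

# [B, C, T, H, W] video scaled to [-1, 1]; with vae_stride (4, 8, 8):
# T' = 1 + (T - 1) // 4, H' = H // 8, W' = W // 8.
video = torch.rand(1, 3, 17, 480, 832, device="cuda") * 2 - 1
with torch.no_grad():
    latents = vae.encode(video)   # -> [1, 16, 5, 60, 104]
    recon = vae.decode(latents)   # -> [1, 3, 17, 480, 832], clamped to [-1, 1]
print(latents.shape, recon.shape)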
/skyreels_v2_infer/modules/xlm_roberta.py:
--------------------------------------------------------------------------------
1 | # Modified from transformers.models.xlm_roberta.modeling_xlm_roberta
2 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | __all__ = ["XLMRoberta", "xlm_roberta_large"]
8 |
9 |
10 | class SelfAttention(nn.Module):
11 | def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
12 | assert dim % num_heads == 0
13 | super().__init__()
14 | self.dim = dim
15 | self.num_heads = num_heads
16 | self.head_dim = dim // num_heads
17 | self.eps = eps
18 |
19 | # layers
20 | self.q = nn.Linear(dim, dim)
21 | self.k = nn.Linear(dim, dim)
22 | self.v = nn.Linear(dim, dim)
23 | self.o = nn.Linear(dim, dim)
24 | self.dropout = nn.Dropout(dropout)
25 |
26 | def forward(self, x, mask):
27 | """
28 | x: [B, L, C].
29 | """
30 | b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
31 |
32 | # compute query, key, value
33 | q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
34 | k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
35 | v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
36 |
37 | # compute attention
38 | p = self.dropout.p if self.training else 0.0
39 | x = F.scaled_dot_product_attention(q, k, v, mask, p)
40 | x = x.permute(0, 2, 1, 3).reshape(b, s, c)
41 |
42 | # output
43 | x = self.o(x)
44 | x = self.dropout(x)
45 | return x
46 |
47 |
48 | class AttentionBlock(nn.Module):
49 | def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
50 | super().__init__()
51 | self.dim = dim
52 | self.num_heads = num_heads
53 | self.post_norm = post_norm
54 | self.eps = eps
55 |
56 | # layers
57 | self.attn = SelfAttention(dim, num_heads, dropout, eps)
58 | self.norm1 = nn.LayerNorm(dim, eps=eps)
59 | self.ffn = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim), nn.Dropout(dropout))
60 | self.norm2 = nn.LayerNorm(dim, eps=eps)
61 |
62 | def forward(self, x, mask):
63 | if self.post_norm:
64 | x = self.norm1(x + self.attn(x, mask))
65 | x = self.norm2(x + self.ffn(x))
66 | else:
67 | x = x + self.attn(self.norm1(x), mask)
68 | x = x + self.ffn(self.norm2(x))
69 | return x
70 |
71 |
72 | class XLMRoberta(nn.Module):
73 | """
74 | XLMRobertaModel with no pooler and no LM head.
75 | """
76 |
77 | def __init__(
78 | self,
79 | vocab_size=250002,
80 | max_seq_len=514,
81 | type_size=1,
82 | pad_id=1,
83 | dim=1024,
84 | num_heads=16,
85 | num_layers=24,
86 | post_norm=True,
87 | dropout=0.1,
88 | eps=1e-5,
89 | ):
90 | super().__init__()
91 | self.vocab_size = vocab_size
92 | self.max_seq_len = max_seq_len
93 | self.type_size = type_size
94 | self.pad_id = pad_id
95 | self.dim = dim
96 | self.num_heads = num_heads
97 | self.num_layers = num_layers
98 | self.post_norm = post_norm
99 | self.eps = eps
100 |
101 | # embeddings
102 | self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id)
103 | self.type_embedding = nn.Embedding(type_size, dim)
104 | self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id)
105 | self.dropout = nn.Dropout(dropout)
106 |
107 | # blocks
108 | self.blocks = nn.ModuleList(
109 | [AttentionBlock(dim, num_heads, post_norm, dropout, eps) for _ in range(num_layers)]
110 | )
111 |
112 | # norm layer
113 | self.norm = nn.LayerNorm(dim, eps=eps)
114 |
115 | def forward(self, ids):
116 | """
117 | ids: [B, L] of torch.LongTensor.
118 | """
119 | b, s = ids.shape
120 | mask = ids.ne(self.pad_id).long()
121 |
122 | # embeddings
123 | x = (
124 | self.token_embedding(ids)
125 | + self.type_embedding(torch.zeros_like(ids))
126 | + self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask)
127 | )
128 | if self.post_norm:
129 | x = self.norm(x)
130 | x = self.dropout(x)
131 |
132 | # blocks
133 | mask = torch.where(mask.view(b, 1, 1, s).gt(0), 0.0, torch.finfo(x.dtype).min)
134 | for block in self.blocks:
135 | x = block(x, mask)
136 |
137 | # output
138 | if not self.post_norm:
139 | x = self.norm(x)
140 | return x
141 |
142 |
143 | def xlm_roberta_large(pretrained=False, return_tokenizer=False, device="cpu", **kwargs):
144 | """
145 | XLMRobertaLarge adapted from Huggingface.
146 | """
147 | # params
148 | cfg = dict(
149 | vocab_size=250002,
150 | max_seq_len=514,
151 | type_size=1,
152 | pad_id=1,
153 | dim=1024,
154 | num_heads=16,
155 | num_layers=24,
156 | post_norm=True,
157 | dropout=0.1,
158 | eps=1e-5,
159 | )
160 | cfg.update(**kwargs)
161 |
162 | # init a model on device
163 | with torch.device(device):
164 | model = XLMRoberta(**cfg)
165 | return model
166 |
--------------------------------------------------------------------------------
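A minimal sketch of running the `XLMRoberta` encoder above on a batch of token ids. No pretrained weights are loaded here (the factory only builds the randomly initialised architecture), and the token ids are made up; padding positions are masked inside each attention block via `pad_id`.

import torch

from skyreels_v2_infer.modules.xlm_roberta import xlm_roberta_large

# Randomly initialised architecture from the config above; no pretrained weights are loaded here.
model = xlm_roberta_large(device="cpu").eval()

pad_id = 1
ids = torch.tensor([[0, 9100, 83, 10, 3034, 2, pad_id, pad_id]])  # made-up token ids, right-padded
with torch.no_grad():
    features = model(ids)   # [B, L, 1024]; padded positions are masked inside each attention block
print(features.shape)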
/skyreels_v2_infer/pipelines/__init__.py:
--------------------------------------------------------------------------------
1 | from .diffusion_forcing_pipeline import DiffusionForcingPipeline
2 | from .image2video_pipeline import Image2VideoPipeline
3 | from .image2video_pipeline import resizecrop
4 | from .prompt_enhancer import PromptEnhancer
5 | from .text2video_pipeline import Text2VideoPipeline
6 |
--------------------------------------------------------------------------------
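The `DiffusionForcingPipeline` exported here is defined in the next file. Below is a hedged usage sketch: the checkpoint directories are placeholders and the sampling settings (resolution, steps, guidance, shift, fps) are illustrative only; the argument names mirror `__init__` and `__call__` of `diffusion_forcing_pipeline.py`.

import torch

from skyreels_v2_infer.pipelines import DiffusionForcingPipeline

# Placeholder checkpoint directories; model_path must contain Wan2.1_VAE.pth and the text encoder,
# dit_path must contain config.json plus the transformer weights.
pipe = DiffusionForcingPipeline(
    model_path="checkpoint/df_model",
    dit_path="checkpoint/df_model",
    device="cuda",
    weight_dtype=torch.bfloat16,
    offload=True,
)

frames = pipe(
    prompt="A red fox trotting through fresh snow at sunrise",
    height=544,
    width=960,
    num_frames=97,
    base_num_frames=97,
    num_inference_steps=30,
    guidance_scale=6.0,
    shift=8.0,
    ar_step=0,
    fps=24,
)[0]

print(frames.shape)  # uint8 numpy array, [num_frames, height, width, 3]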
/skyreels_v2_infer/pipelines/diffusion_forcing_pipeline.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | from typing import List
4 | from typing import Optional
5 | from typing import Tuple
6 | from typing import Union
7 |
8 | import numpy as np
9 | import torch
10 | from diffusers.image_processor import PipelineImageInput
11 | from diffusers.utils.torch_utils import randn_tensor
12 | from diffusers.video_processor import VideoProcessor
13 | from tqdm import tqdm
14 |
15 | from ..modules import get_text_encoder
16 | from ..modules import get_transformer
17 | from ..modules import get_vae
18 | from ..scheduler.fm_solvers_unipc import FlowUniPCMultistepScheduler
19 |
20 |
21 | class DiffusionForcingPipeline:
22 | """
23 | A pipeline for diffusion-based video generation tasks.
24 |
25 | This pipeline supports two main tasks:
26 | - Image-to-Video (i2v): Generates a video sequence from a source image
27 | - Text-to-Video (t2v): Generates a video sequence from a text description
28 |
29 | The pipeline integrates multiple components including:
30 | - A transformer model for diffusion
31 | - A VAE for encoding/decoding
32 | - A text encoder for processing text prompts
33 | - An image encoder for processing image inputs (i2v mode only)
34 | """
35 |
36 | def __init__(
37 | self,
38 | model_path: str,
39 | dit_path: str,
40 | device: str = "cuda",
41 | weight_dtype=torch.bfloat16,
42 | use_usp=False,
43 | offload=False,
44 | ):
45 | """
46 | Initialize the diffusion forcing pipeline class
47 |
48 | Args:
49 | model_path (str): Path to the model
50 |             dit_path (str): Path to the DiT model directory, containing the model configuration file (config.json) and weight files (*.safetensors)
51 | device (str): Device to run on, defaults to 'cuda'
52 | weight_dtype: Weight data type, defaults to torch.bfloat16
53 | """
54 | load_device = "cpu" if offload else device
55 | self.transformer = get_transformer(dit_path, load_device, weight_dtype)
56 | vae_model_path = os.path.join(model_path, "Wan2.1_VAE.pth")
57 | self.vae = get_vae(vae_model_path, device, weight_dtype=torch.float32)
58 | self.text_encoder = get_text_encoder(model_path, load_device, weight_dtype)
59 | self.video_processor = VideoProcessor(vae_scale_factor=16)
60 | self.device = device
61 | self.offload = offload
62 |
63 | if use_usp:
64 | from xfuser.core.distributed import get_sequence_parallel_world_size
65 | from ..distributed.xdit_context_parallel import usp_attn_forward, usp_dit_forward
66 | import types
67 |
68 | for block in self.transformer.blocks:
69 | block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
70 | self.transformer.forward = types.MethodType(usp_dit_forward, self.transformer)
71 | self.sp_size = get_sequence_parallel_world_size()
72 |
73 | self.scheduler = FlowUniPCMultistepScheduler()
74 |
75 | @property
76 | def do_classifier_free_guidance(self) -> bool:
77 | return self._guidance_scale > 1
78 |
79 | def encode_image(
80 | self, image: PipelineImageInput, height: int, width: int, num_frames: int
81 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
82 |
83 | # prefix_video
84 | prefix_video = np.array(image.resize((width, height))).transpose(2, 0, 1)
85 | prefix_video = torch.tensor(prefix_video).unsqueeze(1) # .to(image_embeds.dtype).unsqueeze(1)
86 | if prefix_video.dtype == torch.uint8:
87 | prefix_video = (prefix_video.float() / (255.0 / 2.0)) - 1.0
88 | prefix_video = prefix_video.to(self.device)
89 | prefix_video = [self.vae.encode(prefix_video.unsqueeze(0))[0]] # [(c, f, h, w)]
90 | causal_block_size = self.transformer.num_frame_per_block
91 | if prefix_video[0].shape[1] % causal_block_size != 0:
92 | truncate_len = prefix_video[0].shape[1] % causal_block_size
93 |             print("The prefix video length is truncated for causal block size alignment.")
94 | prefix_video[0] = prefix_video[0][:, : prefix_video[0].shape[1] - truncate_len]
95 | predix_video_latent_length = prefix_video[0].shape[1]
96 | return prefix_video, predix_video_latent_length
97 |
98 | def prepare_latents(
99 | self,
100 | shape: Tuple[int],
101 | dtype: Optional[torch.dtype] = None,
102 | device: Optional[torch.device] = None,
103 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
104 | ) -> torch.Tensor:
105 | return randn_tensor(shape, generator, device=device, dtype=dtype)
106 |
107 | def generate_timestep_matrix(
108 | self,
109 | num_frames,
110 | step_template,
111 | base_num_frames,
112 | ar_step=5,
113 | num_pre_ready=0,
114 | casual_block_size=1,
115 | shrink_interval_with_mask=False,
116 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[tuple]]:
117 | step_matrix, step_index = [], []
118 | update_mask, valid_interval = [], []
119 | num_iterations = len(step_template) + 1
120 | num_frames_block = num_frames // casual_block_size
121 | base_num_frames_block = base_num_frames // casual_block_size
122 | if base_num_frames_block < num_frames_block:
123 | infer_step_num = len(step_template)
124 | gen_block = base_num_frames_block
125 | min_ar_step = infer_step_num / gen_block
126 | assert ar_step >= min_ar_step, f"ar_step should be at least {math.ceil(min_ar_step)} in your setting"
127 | # print(num_frames, step_template, base_num_frames, ar_step, num_pre_ready, casual_block_size, num_frames_block, base_num_frames_block)
128 | step_template = torch.cat(
129 | [
130 | torch.tensor([999], dtype=torch.int64, device=step_template.device),
131 | step_template.long(),
132 | torch.tensor([0], dtype=torch.int64, device=step_template.device),
133 | ]
134 |         )  # sentinel entries: a row counter of 0 maps to t=999 (not started) and the final counter maps to t=0 (fully denoised)
135 | pre_row = torch.zeros(num_frames_block, dtype=torch.long)
136 | if num_pre_ready > 0:
137 | pre_row[: num_pre_ready // casual_block_size] = num_iterations
138 |
139 |         while not torch.all(pre_row >= (num_iterations - 1)):
140 | new_row = torch.zeros(num_frames_block, dtype=torch.long)
141 | for i in range(num_frames_block):
142 | if i == 0 or pre_row[i - 1] >= (
143 | num_iterations - 1
144 |                 ):  # the first block, or the previous block is already fully denoised
145 | new_row[i] = pre_row[i] + 1
146 | else:
147 | new_row[i] = new_row[i - 1] - ar_step
148 | new_row = new_row.clamp(0, num_iterations)
149 |
150 | update_mask.append(
151 | (new_row != pre_row) & (new_row != num_iterations)
152 | ) # False: no need to update, True: need to update
153 | step_index.append(new_row)
154 | step_matrix.append(step_template[new_row])
155 | pre_row = new_row
156 |
157 | # for long video we split into several sequences, base_num_frames is set to the model max length (for training)
158 | terminal_flag = base_num_frames_block
159 | if shrink_interval_with_mask:
160 | idx_sequence = torch.arange(num_frames_block, dtype=torch.int64)
161 | update_mask = update_mask[0]
162 | update_mask_idx = idx_sequence[update_mask]
163 | last_update_idx = update_mask_idx[-1].item()
164 | terminal_flag = last_update_idx + 1
165 | # for i in range(0, len(update_mask)):
166 | for curr_mask in update_mask:
167 | if terminal_flag < num_frames_block and curr_mask[terminal_flag]:
168 | terminal_flag += 1
169 | valid_interval.append((max(terminal_flag - base_num_frames_block, 0), terminal_flag))
170 |
171 | step_update_mask = torch.stack(update_mask, dim=0)
172 | step_index = torch.stack(step_index, dim=0)
173 | step_matrix = torch.stack(step_matrix, dim=0)
174 |
175 | if casual_block_size > 1:
176 | step_update_mask = step_update_mask.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous()
177 | step_index = step_index.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous()
178 | step_matrix = step_matrix.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous()
179 | valid_interval = [(s * casual_block_size, e * casual_block_size) for s, e in valid_interval]
180 |
181 | return step_matrix, step_index, step_update_mask, valid_interval
182 |
183 | @torch.no_grad()
184 | def __call__(
185 | self,
186 | prompt: Union[str, List[str]],
187 | negative_prompt: Union[str, List[str]] = "",
188 | image: PipelineImageInput = None,
189 | height: int = 480,
190 | width: int = 832,
191 | num_frames: int = 97,
192 | num_inference_steps: int = 50,
193 | shift: float = 1.0,
194 | guidance_scale: float = 5.0,
195 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
196 | overlap_history: int = None,
197 | addnoise_condition: int = 0,
198 | base_num_frames: int = 97,
199 | ar_step: int = 5,
200 | causal_block_size: int = None,
201 | fps: int = 24,
202 | ):
203 | latent_height = height // 8
204 | latent_width = width // 8
205 | latent_length = (num_frames - 1) // 4 + 1
206 |
207 | self._guidance_scale = guidance_scale
208 |
209 | i2v_extra_kwrags = {}
210 | prefix_video = None
211 | predix_video_latent_length = 0
212 | if image:
213 | prefix_video, predix_video_latent_length = self.encode_image(image, height, width, num_frames)
214 |
215 | self.text_encoder.to(self.device)
216 | prompt_embeds = self.text_encoder.encode(prompt).to(self.transformer.dtype)
217 | if self.do_classifier_free_guidance:
218 | negative_prompt_embeds = self.text_encoder.encode(negative_prompt).to(self.transformer.dtype)
219 | if self.offload:
220 | self.text_encoder.cpu()
221 | torch.cuda.empty_cache()
222 |
223 | self.scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift)
224 | init_timesteps = self.scheduler.timesteps
225 | if causal_block_size is None:
226 | causal_block_size = self.transformer.num_frame_per_block
227 | fps_embeds = [fps] * prompt_embeds.shape[0]
228 | fps_embeds = [0 if i == 16 else 1 for i in fps_embeds]
229 | transformer_dtype = self.transformer.dtype
230 | # with torch.cuda.amp.autocast(dtype=self.transformer.dtype), torch.no_grad():
231 | if overlap_history is None or base_num_frames is None or num_frames <= base_num_frames:
232 | # short video generation
233 | latent_shape = [16, latent_length, latent_height, latent_width]
234 | latents = self.prepare_latents(
235 | latent_shape, dtype=transformer_dtype, device=prompt_embeds.device, generator=generator
236 | )
237 | latents = [latents]
238 | if prefix_video is not None:
239 | latents[0][:, :predix_video_latent_length] = prefix_video[0].to(transformer_dtype)
240 | base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else latent_length
241 | step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix(
242 | latent_length, init_timesteps, base_num_frames, ar_step, predix_video_latent_length, causal_block_size
243 | )
244 | sample_schedulers = []
245 | for _ in range(latent_length):
246 | sample_scheduler = FlowUniPCMultistepScheduler(
247 | num_train_timesteps=1000, shift=1, use_dynamic_shifting=False
248 | )
249 | sample_scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift)
250 | sample_schedulers.append(sample_scheduler)
251 | sample_schedulers_counter = [0] * latent_length
252 | self.transformer.to(self.device)
253 | for i, timestep_i in enumerate(tqdm(step_matrix)):
254 | update_mask_i = step_update_mask[i]
255 | valid_interval_i = valid_interval[i]
256 | valid_interval_start, valid_interval_end = valid_interval_i
257 | timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone()
258 | latent_model_input = [latents[0][:, valid_interval_start:valid_interval_end, :, :].clone()]
259 | if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length:
260 | noise_factor = 0.001 * addnoise_condition
261 | timestep_for_noised_condition = addnoise_condition
262 | latent_model_input[0][:, valid_interval_start:predix_video_latent_length] = (
263 | latent_model_input[0][:, valid_interval_start:predix_video_latent_length] * (1.0 - noise_factor)
264 | + torch.randn_like(latent_model_input[0][:, valid_interval_start:predix_video_latent_length])
265 | * noise_factor
266 | )
267 | timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition
268 | if not self.do_classifier_free_guidance:
269 | noise_pred = self.transformer(
270 | torch.stack([latent_model_input[0]]),
271 | t=timestep,
272 | context=prompt_embeds,
273 | fps=fps_embeds,
274 | **i2v_extra_kwrags,
275 | )[0]
276 | else:
277 | noise_pred_cond = self.transformer(
278 | torch.stack([latent_model_input[0]]),
279 | t=timestep,
280 | context=prompt_embeds,
281 | fps=fps_embeds,
282 | **i2v_extra_kwrags,
283 | )[0]
284 | noise_pred_uncond = self.transformer(
285 | torch.stack([latent_model_input[0]]),
286 | t=timestep,
287 | context=negative_prompt_embeds,
288 | fps=fps_embeds,
289 | **i2v_extra_kwrags,
290 | )[0]
291 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
292 | for idx in range(valid_interval_start, valid_interval_end):
293 | if update_mask_i[idx].item():
294 | latents[0][:, idx] = sample_schedulers[idx].step(
295 | noise_pred[:, idx - valid_interval_start],
296 | timestep_i[idx],
297 | latents[0][:, idx],
298 | return_dict=False,
299 | generator=generator,
300 | )[0]
301 | sample_schedulers_counter[idx] += 1
302 | if self.offload:
303 | self.transformer.cpu()
304 | torch.cuda.empty_cache()
305 | x0 = latents[0].unsqueeze(0)
306 | videos = self.vae.decode(x0)
307 | videos = (videos / 2 + 0.5).clamp(0, 1)
308 | videos = [video for video in videos]
309 | videos = [video.permute(1, 2, 3, 0) * 255 for video in videos]
310 | videos = [video.cpu().numpy().astype(np.uint8) for video in videos]
311 | return videos
312 | else:
313 | # long video generation
314 | base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else latent_length
315 | overlap_history_frames = (overlap_history - 1) // 4 + 1
316 | n_iter = 1 + (latent_length - base_num_frames - 1) // (base_num_frames - overlap_history_frames) + 1  # number of sliding windows needed to cover latent_length
317 | print(f"n_iter:{n_iter}")
318 | output_video = None
319 | for i in range(n_iter):
320 | if output_video is not None: # i !=0
321 | prefix_video = output_video[:, -overlap_history:].to(prompt_embeds.device)
322 | prefix_video = [self.vae.encode(prefix_video.unsqueeze(0))[0]] # [(c, f, h, w)]
323 | if prefix_video[0].shape[1] % causal_block_size != 0:
324 | truncate_len = prefix_video[0].shape[1] % causal_block_size
325 | print("the length of prefix video is truncated for the casual block size alignment.")
326 | prefix_video[0] = prefix_video[0][:, : prefix_video[0].shape[1] - truncate_len]
327 | predix_video_latent_length = prefix_video[0].shape[1]
328 | finished_frame_num = i * (base_num_frames - overlap_history_frames) + overlap_history_frames
329 | left_frame_num = latent_length - finished_frame_num
330 | base_num_frames_iter = min(left_frame_num + overlap_history_frames, base_num_frames)
331 | if ar_step > 0 and self.transformer.enable_teacache:
332 | num_steps = num_inference_steps + ((base_num_frames_iter - overlap_history_frames) // causal_block_size - 1) * ar_step
333 | self.transformer.num_steps = num_steps
334 | else: # i == 0
335 | base_num_frames_iter = base_num_frames
336 | latent_shape = [16, base_num_frames_iter, latent_height, latent_width]
337 | latents = self.prepare_latents(
338 | latent_shape, dtype=transformer_dtype, device=prompt_embeds.device, generator=generator
339 | )
340 | latents = [latents]
341 | if prefix_video is not None:
342 | latents[0][:, :predix_video_latent_length] = prefix_video[0].to(transformer_dtype)
343 | step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix(
344 | base_num_frames_iter,
345 | init_timesteps,
346 | base_num_frames_iter,
347 | ar_step,
348 | predix_video_latent_length,
349 | causal_block_size,
350 | )
351 | sample_schedulers = []
352 | for _ in range(base_num_frames_iter):
353 | sample_scheduler = FlowUniPCMultistepScheduler(
354 | num_train_timesteps=1000, shift=1, use_dynamic_shifting=False
355 | )
356 | sample_scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift)
357 | sample_schedulers.append(sample_scheduler)
358 | sample_schedulers_counter = [0] * base_num_frames_iter
359 | self.transformer.to(self.device)
360 | for i, timestep_i in enumerate(tqdm(step_matrix)):
361 | update_mask_i = step_update_mask[i]
362 | valid_interval_i = valid_interval[i]
363 | valid_interval_start, valid_interval_end = valid_interval_i
364 | timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone()
365 | latent_model_input = [latents[0][:, valid_interval_start:valid_interval_end, :, :].clone()]
366 | if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length:
367 | noise_factor = 0.001 * addnoise_condition
368 | timestep_for_noised_condition = addnoise_condition
369 | latent_model_input[0][:, valid_interval_start:predix_video_latent_length] = (
370 | latent_model_input[0][:, valid_interval_start:predix_video_latent_length]
371 | * (1.0 - noise_factor)
372 | + torch.randn_like(
373 | latent_model_input[0][:, valid_interval_start:predix_video_latent_length]
374 | )
375 | * noise_factor
376 | )
377 | timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition
378 | if not self.do_classifier_free_guidance:
379 | noise_pred = self.transformer(
380 | torch.stack([latent_model_input[0]]),
381 | t=timestep,
382 | context=prompt_embeds,
383 | fps=fps_embeds,
384 | **i2v_extra_kwrags,
385 | )[0]
386 | else:
387 | noise_pred_cond = self.transformer(
388 | torch.stack([latent_model_input[0]]),
389 | t=timestep,
390 | context=prompt_embeds,
391 | fps=fps_embeds,
392 | **i2v_extra_kwrags,
393 | )[0]
394 | noise_pred_uncond = self.transformer(
395 | torch.stack([latent_model_input[0]]),
396 | t=timestep,
397 | context=negative_prompt_embeds,
398 | fps=fps_embeds,
399 | **i2v_extra_kwrags,
400 | )[0]
401 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
402 | for idx in range(valid_interval_start, valid_interval_end):
403 | if update_mask_i[idx].item():
404 | latents[0][:, idx] = sample_schedulers[idx].step(
405 | noise_pred[:, idx - valid_interval_start],
406 | timestep_i[idx],
407 | latents[0][:, idx],
408 | return_dict=False,
409 | generator=generator,
410 | )[0]
411 | sample_schedulers_counter[idx] += 1
412 | if self.offload:
413 | self.transformer.cpu()
414 | torch.cuda.empty_cache()
415 | x0 = latents[0].unsqueeze(0)
416 | videos = [self.vae.decode(x0)[0]]
417 | if output_video is None:
418 | output_video = videos[0].clamp(-1, 1).cpu() # c, f, h, w
419 | else:
420 | output_video = torch.cat(
421 | [output_video, videos[0][:, overlap_history:].clamp(-1, 1).cpu()], 1
422 | ) # c, f, h, w
423 | output_video = [(output_video / 2 + 0.5).clamp(0, 1)]
424 | output_video = [video for video in output_video]
425 | output_video = [video.permute(1, 2, 3, 0) * 255 for video in output_video]
426 | output_video = [video.cpu().numpy().astype(np.uint8) for video in output_video]
427 | return output_video
428 |
--------------------------------------------------------------------------------
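
Usage note: the `__call__` above takes a single-pass branch when `num_frames` fits inside `base_num_frames`, and otherwise slides a window over the latent sequence, re-encoding the last `overlap_history` pixel frames as the conditioning prefix of the next chunk. A minimal invocation sketch follows; it assumes the class defined earlier in this file is `DiffusionForcingPipeline` with a constructor mirroring the other pipelines, and the checkpoint paths, prompt, seed, and the `imageio` write are placeholders rather than the repository's official entry point.

import imageio
import torch

# Assumption: the class defined earlier in this file is DiffusionForcingPipeline and its
# constructor takes (model_path, dit_path, device, weight_dtype, use_usp, offload) like the
# other pipelines; all paths below are placeholders.
from skyreels_v2_infer.pipelines.diffusion_forcing_pipeline import DiffusionForcingPipeline

pipe = DiffusionForcingPipeline(
    model_path="./checkpoint",            # placeholder: directory holding Wan2.1_VAE.pth and the text encoder
    dit_path="./checkpoint/transformer",  # placeholder: DiT weights
    device="cuda",
    weight_dtype=torch.bfloat16,
    offload=True,
)

frames = pipe(
    prompt="A sailboat crossing a calm bay at sunset, camera slowly tracking right",
    negative_prompt="blurry, distorted, low quality",
    height=544,
    width=960,
    num_frames=257,          # > base_num_frames, so the sliding-window (long video) branch runs
    base_num_frames=97,      # frames denoised per window
    overlap_history=17,      # pixel frames re-encoded as the prefix of each following window
    addnoise_condition=20,   # lightly re-noise the prefix latents to smooth chunk seams
    ar_step=5,               # asynchronous schedule; 0 denoises all frames of a window in lockstep
    num_inference_steps=30,
    guidance_scale=6.0,
    shift=8.0,
    fps=24,
    generator=torch.Generator(device="cuda").manual_seed(42),
)[0]

imageio.mimwrite("df_output.mp4", frames, fps=24, quality=8)
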
/skyreels_v2_infer/pipelines/image2video_pipeline.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import List
3 | from typing import Optional
4 | from typing import Union
5 |
6 | import numpy as np
7 | import torch
8 | from diffusers.image_processor import PipelineImageInput
9 | from diffusers.video_processor import VideoProcessor
10 | from PIL import Image
11 | from tqdm import tqdm
12 |
13 | from ..modules import get_image_encoder
14 | from ..modules import get_text_encoder
15 | from ..modules import get_transformer
16 | from ..modules import get_vae
17 | from ..scheduler.fm_solvers_unipc import FlowUniPCMultistepScheduler
18 |
19 |
20 | def resizecrop(image: Image.Image, th, tw):
21 | w, h = image.size
22 | if w == tw and h == th:
23 | return image
24 | if h / w > th / tw:
25 | new_w = int(w)
26 | new_h = int(new_w * th / tw)
27 | else:
28 | new_h = int(h)
29 | new_w = int(new_h * tw / th)
30 | left = (w - new_w) / 2
31 | top = (h - new_h) / 2
32 | right = (w + new_w) / 2
33 | bottom = (h + new_h) / 2
34 | image = image.crop((left, top, right, bottom))
35 | return image
36 |
37 |
38 | class Image2VideoPipeline:
39 | def __init__(
40 | self, model_path, dit_path, device: str = "cuda", weight_dtype=torch.bfloat16, use_usp=False, offload=False
41 | ):
42 | load_device = "cpu" if offload else device
43 | self.transformer = get_transformer(dit_path, load_device, weight_dtype)
44 | vae_model_path = os.path.join(model_path, "Wan2.1_VAE.pth")
45 | self.vae = get_vae(vae_model_path, device, weight_dtype=torch.float32)
46 | self.text_encoder = get_text_encoder(model_path, load_device, weight_dtype)
47 | self.clip = get_image_encoder(model_path, load_device, weight_dtype)
48 | self.sp_size = 1
49 | self.device = device
50 | self.offload = offload
51 | self.video_processor = VideoProcessor(vae_scale_factor=16)
52 | if use_usp:
53 | from xfuser.core.distributed import get_sequence_parallel_world_size
54 | from ..distributed.xdit_context_parallel import usp_attn_forward, usp_dit_forward
55 | import types
56 |
57 | for block in self.transformer.blocks:
58 | block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
59 | self.transformer.forward = types.MethodType(usp_dit_forward, self.transformer)
60 | self.sp_size = get_sequence_parallel_world_size()
61 |
62 | self.scheduler = FlowUniPCMultistepScheduler()
63 | self.vae_stride = (4, 8, 8)
64 | self.patch_size = (1, 2, 2)
65 |
66 | @torch.no_grad()
67 | def __call__(
68 | self,
69 | image: PipelineImageInput,
70 | prompt: Union[str, List[str]] = None,
71 | negative_prompt: Union[str, List[str]] = None,
72 | height: int = 544,
73 | width: int = 960,
74 | num_frames: int = 97,
75 | num_inference_steps: int = 50,
76 | guidance_scale: float = 5.0,
77 | shift: float = 5.0,
78 | generator: Optional[torch.Generator] = None,
79 | ):
80 | F = num_frames
81 |
82 | latent_height = height // 8 // 2 * 2
83 | latent_width = width // 8 // 2 * 2
84 | latent_length = (F - 1) // 4 + 1
85 |
86 | h = latent_height * 8
87 | w = latent_width * 8
88 |
89 | img = self.video_processor.preprocess(image, height=h, width=w)
90 |
91 | img = img.to(device=self.device, dtype=self.transformer.dtype)
92 |
93 | padding_video = torch.zeros(img.shape[0], 3, F - 1, h, w, device=self.device)
94 |
95 | img = img.unsqueeze(2)
96 | img_cond = torch.concat([img, padding_video], dim=2)
97 | img_cond = self.vae.encode(img_cond)
98 | mask = torch.ones_like(img_cond)
99 | mask[:, :, 1:] = 0
100 | y = torch.cat([mask[:, :4], img_cond], dim=1)
101 | self.clip.to(self.device)
102 | clip_context = self.clip.encode_video(img)
103 | if self.offload:
104 | self.clip.cpu()
105 | torch.cuda.empty_cache()
106 |
107 | # preprocess
108 | self.text_encoder.to(self.device)
109 | context = self.text_encoder.encode(prompt).to(self.device)
110 | context_null = self.text_encoder.encode(negative_prompt).to(self.device)
111 | if self.offload:
112 | self.text_encoder.cpu()
113 | torch.cuda.empty_cache()
114 |
115 | latent = torch.randn(
116 | 16, latent_length, latent_height, latent_width, dtype=torch.float32, generator=generator, device=self.device
117 | )
118 |
119 | self.transformer.to(self.device)
120 | with torch.cuda.amp.autocast(dtype=self.transformer.dtype), torch.no_grad():
121 | self.scheduler.set_timesteps(num_inference_steps, device=self.device, shift=shift)
122 | timesteps = self.scheduler.timesteps
123 |
124 | arg_c = {
125 | "context": context,
126 | "clip_fea": clip_context,
127 | "y": y,
128 | }
129 |
130 | arg_null = {
131 | "context": context_null,
132 | "clip_fea": clip_context,
133 | "y": y,
134 | }
135 |
136 | self.transformer.to(self.device)
137 | for _, t in enumerate(tqdm(timesteps)):
138 | latent_model_input = torch.stack([latent]).to(self.device)
139 | timestep = torch.stack([t]).to(self.device)
140 | noise_pred_cond = self.transformer(latent_model_input, t=timestep, **arg_c)[0].to(self.device)
141 | noise_pred_uncond = self.transformer(latent_model_input, t=timestep, **arg_null)[0].to(self.device)
142 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
143 |
144 | temp_x0 = self.scheduler.step(
145 | noise_pred.unsqueeze(0), t, latent.unsqueeze(0), return_dict=False, generator=generator
146 | )[0]
147 | latent = temp_x0.squeeze(0)
148 | if self.offload:
149 | self.transformer.cpu()
150 | torch.cuda.empty_cache()
151 | videos = self.vae.decode(latent)
152 | videos = (videos / 2 + 0.5).clamp(0, 1)
153 | videos = [video for video in videos]
154 | videos = [video.permute(1, 2, 3, 0) * 255 for video in videos]
155 | videos = [video.cpu().numpy().astype(np.uint8) for video in videos]
156 | return videos
157 |
--------------------------------------------------------------------------------
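
For reference, a minimal sketch of calling `Image2VideoPipeline` directly (checkpoint paths, prompt, and seed are placeholders; `resizecrop` above only center-crops to the target aspect ratio, and the pipeline's `VideoProcessor` performs the actual resize):

import imageio
import torch
from PIL import Image

from skyreels_v2_infer.pipelines.image2video_pipeline import Image2VideoPipeline, resizecrop

pipe = Image2VideoPipeline(
    model_path="./checkpoint",            # placeholder: contains Wan2.1_VAE.pth plus text and image encoders
    dit_path="./checkpoint/transformer",  # placeholder: DiT weights
    device="cuda",
    weight_dtype=torch.bfloat16,
    offload=True,
)

height, width = 544, 960
image = resizecrop(Image.open("input.jpg").convert("RGB"), height, width)

frames = pipe(
    image=image,
    prompt="The camera slowly pushes in as waves roll onto the shore",
    negative_prompt="blurry, distorted, low quality",
    height=height,
    width=width,
    num_frames=97,
    num_inference_steps=30,
    guidance_scale=5.0,
    shift=3.0,
    generator=torch.Generator(device="cuda").manual_seed(42),
)[0]

imageio.mimwrite("i2v_output.mp4", frames, fps=24, quality=8)
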
/skyreels_v2_infer/pipelines/prompt_enhancer.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from transformers import AutoModelForCausalLM, AutoTokenizer
3 |
4 | sys_prompt = """
5 | Transform the short prompt into a detailed video-generation caption using this structure:
6 | Opening shot type (long/medium/close-up/extreme close-up/full shot)
7 | Primary subject(s) with vivid attributes (colors, textures, actions, interactions)
8 | Dynamic elements (movement, transitions, or changes over time, e.g., 'gradually lowers,' 'begins to climb,' 'camera moves toward...')
9 | Scene composition (background, environment, spatial relationships)
10 | Lighting/atmosphere (natural/artificial, time of day, mood)
11 | Camera motion (zooms, pans, static/handheld shots) if applicable.
12 |
13 | Pattern Summary from Examples:
14 | [Shot Type] of [Subject+Action] + [Detailed Subject Description] + [Environmental Context] + [Lighting Conditions] + [Camera Movement]
15 |
16 | One case:
17 | Short prompt: a person is playing football
18 | Long prompt: Medium shot of a young athlete in a red jersey sprinting across a muddy field, dribbling a soccer ball with precise footwork. The player glances toward the goalpost, adjusts their stance, and kicks the ball forcefully into the net. Raindrops fall lightly, creating reflections under stadium floodlights. The camera follows the ball’s trajectory in a smooth pan.
19 |
20 | Note: If the subject is stationary, incorporate camera movement to ensure the generated video remains dynamic.
21 |
22 | Now expand this short prompt: [{}]. Please only output the final long prompt in English.
23 | """
24 |
25 | class PromptEnhancer:
26 | def __init__(self, model_name="Qwen/Qwen2.5-32B-Instruct"):
27 | self.model = AutoModelForCausalLM.from_pretrained(
28 | model_name,
29 | torch_dtype="auto",
30 | device_map="cuda:0",
31 | )
32 | self.tokenizer = AutoTokenizer.from_pretrained(model_name)
33 |
34 | def __call__(self, prompt):
35 | prompt = prompt.strip()
36 | prompt = sys_prompt.format(prompt)
37 | messages = [
38 | {"role": "system", "content": "You are a helpful assistant."},
39 | {"role": "user", "content": prompt}
40 | ]
41 | text = self.tokenizer.apply_chat_template(
42 | messages,
43 | tokenize=False,
44 | add_generation_prompt=True
45 | )
46 | model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)
47 | generated_ids = self.model.generate(
48 | **model_inputs,
49 | max_new_tokens=2048,
50 | )
51 | generated_ids = [
52 | output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
53 | ]
54 | rewritten_prompt = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
55 | return rewritten_prompt
56 |
57 | if __name__ == '__main__':
58 | parser = argparse.ArgumentParser()
59 | parser.add_argument("--prompt", type=str, default="In a still frame, a stop sign")
60 | args = parser.parse_args()
61 |
62 | prompt_enhancer = PromptEnhancer()
63 | enhanced_prompt = prompt_enhancer(args.prompt)
64 | print(f'Original prompt: {args.prompt}')
65 | print(f'Enhanced prompt: {enhanced_prompt}')
66 |
--------------------------------------------------------------------------------
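
A short sketch of using `PromptEnhancer` programmatically ahead of a video pipeline; since Qwen2.5-32B-Instruct is loaded onto `cuda:0`, releasing it before instantiating a pipeline is assumed here to keep a single GPU within memory:

import gc

import torch

from skyreels_v2_infer.pipelines.prompt_enhancer import PromptEnhancer

enhancer = PromptEnhancer()  # loads Qwen/Qwen2.5-32B-Instruct onto cuda:0
long_prompt = enhancer("a person is playing football")
print(long_prompt)

# Release the LLM before loading a video pipeline so both do not occupy GPU memory at once.
del enhancer
gc.collect()
torch.cuda.empty_cache()
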
/skyreels_v2_infer/pipelines/text2video_pipeline.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import List
3 | from typing import Optional
4 | from typing import Union
5 |
6 | import numpy as np
7 | import torch
8 | from diffusers.video_processor import VideoProcessor
9 | from tqdm import tqdm
10 |
11 | from ..modules import get_text_encoder
12 | from ..modules import get_transformer
13 | from ..modules import get_vae
14 | from ..scheduler.fm_solvers_unipc import FlowUniPCMultistepScheduler
15 |
16 |
17 | class Text2VideoPipeline:
18 | def __init__(
19 | self, model_path, dit_path, device: str = "cuda", weight_dtype=torch.bfloat16, use_usp=False, offload=False
20 | ):
21 | load_device = "cpu" if offload else device
22 | self.transformer = get_transformer(dit_path, load_device, weight_dtype)
23 | vae_model_path = os.path.join(model_path, "Wan2.1_VAE.pth")
24 | self.vae = get_vae(vae_model_path, device, weight_dtype=torch.float32)
25 | self.text_encoder = get_text_encoder(model_path, load_device, weight_dtype)
26 | self.video_processor = VideoProcessor(vae_scale_factor=16)
27 | self.sp_size = 1
28 | self.device = device
29 | self.offload = offload
30 | if use_usp:
31 | from xfuser.core.distributed import get_sequence_parallel_world_size
32 | from ..distributed.xdit_context_parallel import usp_attn_forward, usp_dit_forward
33 | import types
34 |
35 | for block in self.transformer.blocks:
36 | block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
37 | self.transformer.forward = types.MethodType(usp_dit_forward, self.transformer)
38 | self.sp_size = get_sequence_parallel_world_size()
39 |
40 | self.scheduler = FlowUniPCMultistepScheduler()
41 | self.vae_stride = (4, 8, 8)
42 | self.patch_size = (1, 2, 2)
43 |
44 | @torch.no_grad()
45 | def __call__(
46 | self,
47 | prompt: Union[str, List[str]] = None,
48 | negative_prompt: Union[str, List[str]] = None,
49 | width: int = 544,
50 | height: int = 960,
51 | num_frames: int = 97,
52 | num_inference_steps: int = 50,
53 | guidance_scale: float = 5.0,
54 | shift: float = 5.0,
55 | generator: Optional[torch.Generator] = None,
56 | ):
57 | # preprocess
58 | F = num_frames
59 | target_shape = (
60 | self.vae.vae.z_dim,
61 | (F - 1) // self.vae_stride[0] + 1,
62 | height // self.vae_stride[1],
63 | width // self.vae_stride[2],
64 | )
65 | self.text_encoder.to(self.device)
66 | context = self.text_encoder.encode(prompt).to(self.device)
67 | context_null = self.text_encoder.encode(negative_prompt).to(self.device)
68 | if self.offload:
69 | self.text_encoder.cpu()
70 | torch.cuda.empty_cache()
71 |
72 | latents = [
73 | torch.randn(
74 | target_shape[0],
75 | target_shape[1],
76 | target_shape[2],
77 | target_shape[3],
78 | dtype=torch.float32,
79 | device=self.device,
80 | generator=generator,
81 | )
82 | ]
83 |
84 | # evaluation mode
85 | self.transformer.to(self.device)
86 | with torch.cuda.amp.autocast(dtype=self.transformer.dtype), torch.no_grad():
87 | self.scheduler.set_timesteps(num_inference_steps, device=self.device, shift=shift)
88 | timesteps = self.scheduler.timesteps
89 |
90 | for _, t in enumerate(tqdm(timesteps)):
91 | latent_model_input = torch.stack(latents)
92 | timestep = torch.stack([t])
93 | noise_pred_cond = self.transformer(latent_model_input, t=timestep, context=context)[0]
94 | noise_pred_uncond = self.transformer(latent_model_input, t=timestep, context=context_null)[0]
95 |
96 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
97 |
98 | temp_x0 = self.scheduler.step(
99 | noise_pred.unsqueeze(0), t, latents[0].unsqueeze(0), return_dict=False, generator=generator
100 | )[0]
101 | latents = [temp_x0.squeeze(0)]
102 | if self.offload:
103 | self.transformer.cpu()
104 | torch.cuda.empty_cache()
105 | videos = self.vae.decode(latents[0])
106 | videos = (videos / 2 + 0.5).clamp(0, 1)
107 | videos = [video for video in videos]
108 | videos = [video.permute(1, 2, 3, 0) * 255 for video in videos]
109 | videos = [video.cpu().numpy().astype(np.uint8) for video in videos]
110 | return videos
111 |
--------------------------------------------------------------------------------
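
And the matching sketch for `Text2VideoPipeline`; the signature above defaults to a portrait 544x960 output, so landscape generation passes `height`/`width` explicitly (paths, prompt, and seed are placeholders):

import imageio
import torch

from skyreels_v2_infer.pipelines.text2video_pipeline import Text2VideoPipeline

pipe = Text2VideoPipeline(
    model_path="./checkpoint",            # placeholder: directory with Wan2.1_VAE.pth and the text encoder
    dit_path="./checkpoint/transformer",  # placeholder: DiT weights
    device="cuda",
    weight_dtype=torch.bfloat16,
    offload=True,
)

frames = pipe(
    prompt="A timelapse of clouds drifting over snow-capped mountains",
    negative_prompt="blurry, distorted, low quality",
    height=544,
    width=960,
    num_frames=97,
    num_inference_steps=30,
    guidance_scale=6.0,
    shift=8.0,
    generator=torch.Generator(device="cuda").manual_seed(0),
)[0]

imageio.mimwrite("t2v_output.mp4", frames, fps=24, quality=8)
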
/skyreels_v2_infer/scheduler/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SkyworkAI/SkyReels-V2/6ace9655735f34e4cb8cae8cf8e35289142ecda7/skyreels_v2_infer/scheduler/__init__.py
--------------------------------------------------------------------------------