├── inputs
│   ├── b2.mp4
│   └── ship.mp4
├── src
│   ├── __pycache__
│   │   └── __init__.cpython-310.pyc
│   ├── models
│   │   ├── __pycache__
│   │   │   ├── transformer_v.cpython-310.pyc
│   │   │   └── transformer_wan.cpython-310.pyc
│   │   ├── transformer_wan.py
│   │   └── transformer_v.py
│   ├── pipelines
│   │   ├── __pycache__
│   │   │   ├── pipeline_v.cpython-310.pyc
│   │   │   └── pipeline_wan.cpython-310.pyc
│   │   ├── pipeline_wan.py
│   │   └── pipeline_v.py
│   └── __init__.py
├── utils
│   ├── __pycache__
│   │   ├── native.cpython-310.pyc
│   │   ├── __init__.cpython-310.pyc
│   │   ├── cube2equi.cpython-310.pyc
│   │   ├── save_vid.cpython-310.pyc
│   │   ├── get_cubemap.cpython-310.pyc
│   │   ├── infer_utils.cpython-310.pyc
│   │   ├── rotate_equi.cpython-310.pyc
│   │   └── affine_transform.cpython-310.pyc
│   ├── __init__.py
│   ├── save_vid.py
│   ├── native.py
│   ├── rotate_equi.py
│   ├── get_cubemap.py
│   ├── affine_transform.py
│   ├── cube2equi.py
│   └── infer_utils.py
├── configs
│   ├── text_driven
│   │   └── forest.yaml
│   └── vid_driven
│       ├── ship.yaml
│       └── b2.yaml
├── requirements.txt
├── README.md
└── inference.py

/inputs/b2.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/ViewPoint/HEAD/inputs/b2.mp4
--------------------------------------------------------------------------------
/inputs/ship.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/ViewPoint/HEAD/inputs/ship.mp4
--------------------------------------------------------------------------------
/src/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/ViewPoint/HEAD/src/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/native.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/ViewPoint/HEAD/utils/__pycache__/native.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/ViewPoint/HEAD/utils/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/cube2equi.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/ViewPoint/HEAD/utils/__pycache__/cube2equi.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/save_vid.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/ViewPoint/HEAD/utils/__pycache__/save_vid.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/get_cubemap.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/ViewPoint/HEAD/utils/__pycache__/get_cubemap.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/infer_utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/ViewPoint/HEAD/utils/__pycache__/infer_utils.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/rotate_equi.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/ViewPoint/HEAD/utils/__pycache__/rotate_equi.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/affine_transform.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/ViewPoint/HEAD/utils/__pycache__/affine_transform.cpython-310.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/transformer_v.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/ViewPoint/HEAD/src/models/__pycache__/transformer_v.cpython-310.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/transformer_wan.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/ViewPoint/HEAD/src/models/__pycache__/transformer_wan.cpython-310.pyc -------------------------------------------------------------------------------- /src/pipelines/__pycache__/pipeline_v.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/ViewPoint/HEAD/src/pipelines/__pycache__/pipeline_v.cpython-310.pyc -------------------------------------------------------------------------------- /src/pipelines/__pycache__/pipeline_wan.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/ViewPoint/HEAD/src/pipelines/__pycache__/pipeline_wan.cpython-310.pyc -------------------------------------------------------------------------------- /configs/text_driven/forest.yaml: -------------------------------------------------------------------------------- 1 | model_path: "path/to/Wan2.1-T2V-1.3B-Diffusers" 2 | transformer_id: "path/to/ViewPoint" 3 | 4 | seed: 38 5 | output_dir: "output_forest" 6 | prompt: "A scenic view of a river and a forest." 
7 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | from .pipelines.pipeline_wan import WanPipeline 2 | from .pipelines.pipeline_v import ViewPipeline 3 | from .models.transformer_wan import WanTransformer3DModel 4 | from .models.transformer_v import ViewTransformer3DModel -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=2.4.0 2 | torchvision>=0.19.0 3 | opencv-python>=4.9.0.80 4 | diffusers>=0.31.0 5 | transformers>=4.49.0 6 | tokenizers>=0.20.3 7 | accelerate>=1.1.1 8 | tqdm 9 | imageio 10 | easydict 11 | ftfy 12 | dashscope 13 | imageio-ffmpeg 14 | flash_attn 15 | numpy>=1.23.5,<2 -------------------------------------------------------------------------------- /configs/vid_driven/ship.yaml: -------------------------------------------------------------------------------- 1 | model_path: "path/to/Wan2.1-T2V-1.3B-Diffusers" 2 | transformer_id: "path/to/ViewPoint" 3 | 4 | seed: 38 5 | output_dir: "output_ship" 6 | prompt: "夜晚的海洋,乌云和海浪包围,营造出一种神秘的气氛,展示了水和天空的浩瀚,高对比度摄影,空中视角,高分辨率,超逼真风格,电影般的灯光效果,超现实主义。" 7 | 8 | video: "./inputs/ship.mp4" 9 | direction: "F" -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .get_cubemap import get_cubemap 2 | from .rotate_equi import get_random_rotate, rotate_equi 3 | from .save_vid import save, save_cube_video 4 | from .affine_transform import get_affine_tensors, get_mask, get_condition_tensors, get_transform_matrix_1 5 | from .cube2equi import cube2equi -------------------------------------------------------------------------------- /configs/vid_driven/b2.yaml: -------------------------------------------------------------------------------- 1 | model_path: "path/to/Wan2.1-T2V-1.3B-Diffusers" 2 | transformer_id: "path/to/ViewPoint" 3 | 4 | seed: 43 5 | output_dir: "output_b2" 6 | prompt: "An aerial scene of dark cloud layers, with many airplanes traversing the sky, amid thunder and lightning, an apocalyptic scenario, surrealism." 
7 | 8 | video: "./inputs/b2.mp4" 9 | direction: "F" -------------------------------------------------------------------------------- /utils/save_vid.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from diffusers.utils import export_to_video, load_image, load_video 4 | 5 | def save(video, savepath="./test.mp4", fps=16, rescale=True, nrow=None, permute=True): 6 | if permute: 7 | video = video.permute(1,2,3,0) 8 | video = video.cpu().float() 9 | video = video.numpy() 10 | if rescale: 11 | video = (video + 1.0) / 2.0 12 | export_to_video(video, savepath, fps=fps) 13 | 14 | def save_cube_video(cube_tensor, output_folder='./cubemaps', fps=16, rescale=True): 15 | cube_tensor = cube_tensor.cpu().float() 16 | face_list = ['F','R','B','L','U','D'] 17 | os.makedirs(output_folder,exist_ok=True) 18 | 19 | cube_tensor = torch.chunk(cube_tensor,6,dim=-1) 20 | for i, face in enumerate(face_list): 21 | save(cube_tensor[i],os.path.join(output_folder,face+'.mp4'),fps=fps,rescale=rescale, permute=True, nrow=None) 22 | 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ViewPoint 2 | ViewPoint: Panoramic Video Generation with Pretrained Diffusion Models 3 | 4 | [![arXiv](https://img.shields.io/badge/Paper-Arxiv-b31b1b.svg)](https://arxiv.org/abs/2506.23513) 5 | [![Project Page](https://img.shields.io/badge/Project-ViewPoint-green)](https://becauseimbatman0.github.io/ViewPoint) 6 | [![Modelscope](https://img.shields.io/badge/Models-ModelScope-purple)](https://www.modelscope.cn/models/highanddry/ViewPoint) 7 | 8 | ## Installation 9 | 10 | ``` 11 | git clone https://github.com/ali-vilab/ViewPoint 12 | cd ViewPoint 13 | ``` 14 | 15 | ## Environment 16 | ``` 17 | conda create -n viewpoint python=3.10 18 | conda activate viewpoint 19 | pip install -r requirements.txt 20 | ``` 21 | 22 | ## Pretrained Models 23 | | Models | DownloadLink | 24 | |-----------|---------| 25 | | Wan2.1-1.3B | [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B-Diffusers) | 26 | | ViewPoint ckpt | [ModelScope](https://www.modelscope.cn/models/highanddry/ViewPoint) | 27 | 28 | ## Inference 29 | ### 1. Modify your config file in ```configs/*_driven/*.yaml```. 30 | ``` 31 | model_path: "/path/to/your/Wan2.1-T2V-1.3B-Diffusers" 32 | transformer_id: "/path/to/your/ViewPoint" 33 | 34 | seed: 4396 35 | output_dir: "/path/to/your/output_dir" 36 | prompt: "English and Chinese prompts are okay. 英文和中文都行,可以适当做点prompt engineering." 37 | 38 | video: "/path/to/your/input_video" # Optional 39 | direction: "F" # Optional 40 | ``` 41 | ### 2. Text-driven generation. 42 | ``` 43 | python inference.py --config configs/text_driven/forest.yaml 44 | ``` 45 | ### 3. Video-driven generation. 
46 | ``` 47 | python inference.py --config configs/vid_driven/ship.yaml 48 | ``` 49 | ## Citing 50 | ``` 51 | @misc{fang2025viewpointpanoramicvideogeneration, 52 | title={ViewPoint: Panoramic Video Generation with Pretrained Diffusion Models}, 53 | author={Zixun Fang and Kai Zhu and Zhiheng Liu and Yu Liu and Wei Zhai and Yang Cao and Zheng-Jun Zha}, 54 | year={2025}, 55 | eprint={2506.23513}, 56 | archivePrefix={arXiv}, 57 | primaryClass={cs.CV}, 58 | url={https://arxiv.org/abs/2506.23513}, 59 | } 60 | ``` -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from omegaconf import OmegaConf 4 | 5 | import torch 6 | from src import ViewPipeline, ViewTransformer3DModel 7 | from utils.infer_utils import * 8 | from utils import save, save_cube_video 9 | from decord import VideoReader, cpu 10 | from diffusers.utils import export_to_video, load_image, load_video 11 | import torch.nn.functional as F 12 | 13 | def generate_video(model_path, transformer_id, prompt, video_path, direction, rotate, output_dir, H, W, L, fps): 14 | transformer = ViewTransformer3DModel.from_pretrained( 15 | transformer_id, torch_dtype=torch.bfloat16, 16 | ) 17 | pipe = ViewPipeline.from_pretrained(model_path, transformer=transformer, torch_dtype=torch.bfloat16).to("cuda") 18 | 19 | if video_path is not None: 20 | vid = VideoReader(video_path, ctx=cpu(0), width=W, height=H) 21 | vid = vid.get_batch(range(len(vid))).asnumpy() 22 | vid = torch.from_numpy(vid).permute(0,3,1,2).to(device="cuda", dtype=torch.bfloat16) 23 | save(vid[:L,...].permute(1,0,2,3)/255.0, os.path.join(output_dir,"input.mp4"), fps=fps,rescale=False) 24 | print("input_shape:",vid[:L,...].shape) 25 | else: 26 | vid = None 27 | direction = "null" 28 | 29 | video = pipe( 30 | prompt=prompt, 31 | video=vid, 32 | selected_cube=direction, 33 | height = H, 34 | width = W, 35 | num_frames = L, 36 | rotate = rotate, 37 | ).frames[0] 38 | 39 | export_to_video(video,os.path.join(output_dir,"viewpoint.mp4"),fps=fps) 40 | 41 | cubemap = convert_to_cubemap2(video) 42 | 43 | save_cube_video(cubemap.permute(1,0,2,3), os.path.join(output_dir,"cubemap") ,fps=fps,rescale=False) 44 | 45 | equirec = cube2equi(cubemap) 46 | save(equirec.permute(1,0,2,3), os.path.join(output_dir,"equi.mp4"), fps=fps,rescale=False) 47 | 48 | 49 | 50 | def main(): 51 | parser = argparse.ArgumentParser(description="Generate video using Wan2.1 and LoRA weights") 52 | parser.add_argument("-H", type=int, default=256) 53 | parser.add_argument("-W", type=int, default=256) 54 | parser.add_argument("-L", type=int, default=49) 55 | parser.add_argument("--fps", type=int, default=16, help="Frames per second for the output video") 56 | 57 | parser.add_argument("--config", type=str, default="./configs/VanGogh.yaml") 58 | 59 | args = parser.parse_args() 60 | config = OmegaConf.load(args.config) 61 | 62 | os.makedirs(config.output_dir,exist_ok=True) 63 | torch.manual_seed(config.seed) 64 | print("generating:",config.prompt) 65 | generate_video( 66 | config.model_path, 67 | config.transformer_id, 68 | config.prompt, 69 | config.get("video", None), 70 | config.get("direction", "null"), 71 | config.get("rotate",False), 72 | config.output_dir, 73 | args.H, 74 | args.W, 75 | args.L, 76 | args.fps 77 | ) 78 | 79 | 80 | if __name__ == "__main__": 81 | main() 82 | -------------------------------------------------------------------------------- 
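Taken together, the README configs and `inference.py` above mean a run can also be launched from Python by calling `generate_video` directly with the values a YAML config would supply. Below is a minimal sketch of a text-driven run; the checkpoint paths are placeholders for your own downloads, and it assumes the repository root is the working directory:

```
# Minimal sketch: programmatic equivalent of
#   python inference.py --config configs/text_driven/forest.yaml
# Paths below are placeholders for locally downloaded checkpoints.
import os
import torch
from inference import generate_video

output_dir = "output_forest"
os.makedirs(output_dir, exist_ok=True)
torch.manual_seed(38)  # seed from configs/text_driven/forest.yaml

generate_video(
    model_path="path/to/Wan2.1-T2V-1.3B-Diffusers",  # base Wan2.1 Diffusers checkpoint
    transformer_id="path/to/ViewPoint",              # ViewPoint transformer weights
    prompt="A scenic view of a river and a forest.",
    video_path=None,    # text-driven: no conditioning video
    direction="null",   # no cube face is selected without a video
    rotate=False,
    output_dir=output_dir,
    H=256, W=256, L=49, fps=16,  # defaults of inference.py's CLI flags
)
```

A video-driven run only differs in passing a `video_path` (e.g. `./inputs/ship.mp4`) and a `direction` such as `"F"`, mirroring `configs/vid_driven/*.yaml`.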
/utils/native.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from functools import partial 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | __all__ = ["native", "native_bicubic", "native_bilinear", "native_nearest"] 9 | 10 | 11 | def native( 12 | img: torch.Tensor, grid: torch.Tensor, mode: str = "bilinear" 13 | ) -> torch.Tensor: 14 | """Torch Grid Sample (default) 15 | 16 | - Uses `torch.nn.functional.grid_sample` 17 | - By far the best way to sample 18 | 19 | params: 20 | - img (torch.Tensor): Tensor[B, C, H, W] or Tensor[C, H, W] 21 | - grid (torch.Tensor): Tensor[B, 2, H, W] or Tensor[2, H, W] 22 | - device (int or str): torch.device 23 | - mode (str): (`bilinear`, `bicubic`, `nearest`) 24 | 25 | returns: 26 | - out (torch.Tensor): Tensor[B, C, H, W] or Tensor[C, H, W] 27 | where H, W are grid size 28 | 29 | NOTE: `img` and `grid` needs to be on the same device 30 | 31 | NOTE: `img` and `grid` is somehow mutated (inplace?), so if you need 32 | to reuse `img` and `grid` somewhere else, use `.clone()` before 33 | passing it to this function 34 | 35 | NOTE: this method is different from other grid sampling that 36 | the padding cannot be wrapped. There might be pixel inaccuracies 37 | when sampling from the boundaries of the image (the seam). 38 | 39 | I hope later on, we can add wrap padding to this since the function 40 | is super fast. 41 | 42 | """ 43 | 44 | assert ( 45 | grid.dtype == img.dtype 46 | ), "ERR: img and grid should have the same dtype" 47 | 48 | _, _, h, w = img.shape 49 | 50 | # grid in shape: (batch, channel, h_out, w_out) 51 | # grid out shape: (batch, h_out, w_out, channel) 52 | grid = grid.permute(0, 2, 3, 1) 53 | 54 | """Preprocess for grid_sample 55 | normalize grid -1 ~ 1 56 | 57 | assumptions: 58 | - values of `grid` is between `0 ~ (h-1)` and `0 ~ (w-1)` 59 | - input of `grid_sample` need to be between `-1 ~ 1` 60 | - maybe lose some precision when we map the values (int to float)? 61 | 62 | mapping (e.g. mapping of height): 63 | 1. 0 <= y <= (h-1) 64 | 2. -1/2 <= y' <= 1/2 <- y' = y/(h-1) - 1/2 65 | 3. 
-1 <= y" <= 1 <- y" = 2y' 66 | """ 67 | 68 | # FIXME: this is not necessary when we are already preprocessing grid before 69 | # this method is called 70 | # grid[..., 0] %= h 71 | # grid[..., 1] %= w 72 | 73 | norm_uj = torch.clamp(2 * grid[..., 0] / (h - 1) - 1, -1, 1) 74 | norm_ui = torch.clamp(2 * grid[..., 1] / (w - 1) - 1, -1, 1) 75 | 76 | # reverse: grid sample takes xy, not (height, width) 77 | grid[..., 0] = norm_ui 78 | grid[..., 1] = norm_uj 79 | 80 | out = F.grid_sample( 81 | img, 82 | grid, 83 | mode=mode, 84 | # use center of pixel instead of corner 85 | align_corners=True, 86 | # padding mode defaults to 'zeros' and there is no 'wrapping' mode 87 | padding_mode="reflection", 88 | ) 89 | 90 | return out 91 | 92 | 93 | # aliases 94 | native_nearest = partial(native, mode="nearest") 95 | native_bilinear = partial(native, mode="bilinear") 96 | native_bicubic = partial(native, mode="bicubic") 97 | -------------------------------------------------------------------------------- /utils/rotate_equi.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import random 4 | from .native import native_bilinear 5 | from typing import Optional, List, Dict 6 | from einops import rearrange 7 | from .get_cubemap import torch_grid_sample 8 | 9 | def get_random_rotate(): 10 | roll_ang = 0 11 | pitch_ang = random.randint(-10, 10) 12 | yaw_ang = random.randint(0, 360) 13 | rot = [{ # TODO: define random rotate 14 | "roll": np.deg2rad(roll_ang), # 15 | "pitch": np.deg2rad(pitch_ang), # vertical 16 | "yaw": np.deg2rad(yaw_ang), # horizontal 17 | }] 18 | return rot 19 | 20 | 21 | def matmul(m: torch.Tensor, R: torch.Tensor) -> torch.Tensor: 22 | M = torch.matmul(R[:, None, None, ...], m) 23 | M = M.squeeze(-1) 24 | 25 | return M 26 | 27 | def create_normalized_grid( 28 | height: int, 29 | width: int, 30 | batch: Optional[int] = None, 31 | dtype: torch.dtype = torch.float32, 32 | device: torch.device = torch.device("cpu"), 33 | ) -> torch.Tensor: 34 | """Create coordinate grid with height and width 35 | 36 | NOTE: primarly used for equi2equi 37 | 38 | params: 39 | - height (int) 40 | - width (int) 41 | - batch (Optional[int]) 42 | - dtype (torch.dtype) 43 | 44 | return: 45 | - grid (torch.Tensor) 46 | 47 | """ 48 | 49 | # NOTE: RuntimeError: "linspace_cpu" not implemented for Half 50 | if device.type == "cpu": 51 | assert dtype in (torch.float32, torch.float64), ( 52 | f"ERR: {dtype} is not supported by {device.type}\n" 53 | "If device is `cpu`, use float32 or float64" 54 | ) 55 | 56 | xs = torch.linspace(0, width - 1, width, dtype=dtype, device=device) 57 | ys = torch.linspace(0, height - 1, height, dtype=dtype, device=device) 58 | theta = xs * 2 * np.pi / width - np.pi 59 | phi = ys * np.pi / height - np.pi / 2 60 | phi, theta = torch.meshgrid([phi, theta], indexing="ij") 61 | a = torch.stack((theta, phi), dim=-1) 62 | norm_A = 1 63 | x = norm_A * torch.cos(a[..., 1]) * torch.cos(a[..., 0]) 64 | y = norm_A * torch.cos(a[..., 1]) * torch.sin(a[..., 0]) 65 | z = norm_A * torch.sin(a[..., 1]) 66 | grid = torch.stack((x, y, z), dim=-1) 67 | 68 | # batched (stacked copies) 69 | if batch is not None: 70 | assert isinstance( 71 | batch, int 72 | ), f"ERR: batch needs to be integer: batch={batch}" 73 | assert ( 74 | batch > 0 75 | ), f"ERR: batch size needs to be larger than 0: batch={batch}" 76 | # FIXME: faster way of copying? 
77 | grid = torch.cat([grid.unsqueeze(0)] * batch) 78 | # grid shape is (b, h, w, 3) 79 | 80 | return grid 81 | 82 | def convert_grid( 83 | M: torch.Tensor, h_equi: int, w_equi: int, method: str = "robust" 84 | ) -> torch.Tensor: 85 | # convert to rotation 86 | phi = torch.asin(M[..., 2] / torch.norm(M, dim=-1)) 87 | theta = torch.atan2(M[..., 1], M[..., 0]) 88 | 89 | if method == "robust": 90 | ui = (theta - np.pi) * w_equi / (2 * np.pi) 91 | uj = (phi - np.pi / 2) * h_equi / np.pi 92 | ui += 0.5 93 | uj += 0.5 94 | ui %= w_equi 95 | uj %= h_equi 96 | elif method == "faster": 97 | ui = (theta - np.pi) * w_equi / (2 * np.pi) 98 | uj = (phi - np.pi / 2) * h_equi / np.pi 99 | ui += 0.5 100 | uj += 0.5 101 | ui = torch.where(ui < 0, ui + w_equi, ui) 102 | ui = torch.where(ui >= w_equi, ui - w_equi, ui) 103 | uj = torch.where(uj < 0, uj + h_equi, uj) 104 | uj = torch.where(uj >= h_equi, uj - h_equi, uj) 105 | else: 106 | raise ValueError(f"ERR: {method} is not supported") 107 | 108 | # stack the pixel maps into a grid 109 | grid = torch.stack((uj, ui), dim=-3) 110 | 111 | return grid 112 | 113 | def create_rotation_matrix( 114 | roll: float, 115 | pitch: float, 116 | yaw: float, 117 | z_down: bool = True, 118 | dtype: torch.dtype = torch.float32, 119 | device: torch.device = torch.device("cpu"), 120 | ) -> torch.Tensor: 121 | """Create Rotation Matrix 122 | 123 | params: 124 | - roll, pitch, yaw (float): in radians 125 | - z_down (bool): flips pitch and yaw directions 126 | - dtype (torch.dtype): data types 127 | 128 | returns: 129 | - R (torch.Tensor): 3x3 rotation matrix 130 | """ 131 | 132 | # calculate rotation about the x-axis 133 | R_x = torch.tensor( 134 | [ 135 | [1.0, 0.0, 0.0], 136 | [0.0, np.cos(roll), -np.sin(roll)], 137 | [0.0, np.sin(roll), np.cos(roll)], 138 | ], 139 | dtype=dtype, 140 | ) 141 | # calculate rotation about the y-axis 142 | if not z_down: 143 | pitch = -pitch 144 | R_y = torch.tensor( 145 | [ 146 | [np.cos(pitch), 0.0, np.sin(pitch)], 147 | [0.0, 1.0, 0.0], 148 | [-np.sin(pitch), 0.0, np.cos(pitch)], 149 | ], 150 | dtype=dtype, 151 | ) 152 | # calculate rotation about the z-axis 153 | if not z_down: 154 | yaw = -yaw 155 | R_z = torch.tensor( 156 | [ 157 | [np.cos(yaw), -np.sin(yaw), 0.0], 158 | [np.sin(yaw), np.cos(yaw), 0.0], 159 | [0.0, 0.0, 1.0], 160 | ], 161 | dtype=dtype, 162 | ) 163 | R = R_z @ R_y @ R_x 164 | return R.to(device) 165 | 166 | def create_rotation_matrices( 167 | rots: List[Dict[str, float]], 168 | z_down: bool = True, 169 | dtype: torch.dtype = torch.float32, 170 | device: torch.device = torch.device("cpu"), 171 | ) -> torch.Tensor: 172 | """Create rotation matrices from batch of rotations 173 | 174 | This methods creates a bx3x3 np.ndarray where `b` referes to the number 175 | of rotations (rots) given in the input 176 | """ 177 | 178 | R = torch.empty((len(rots), 3, 3), dtype=dtype, device=device) 179 | for i, rot in enumerate(rots): 180 | # FIXME: maybe default to `create_rotation_matrix_at_once`? 181 | # NOTE: at_once is faster with cpu, while slower on GPU 182 | R[i, ...] 
= create_rotation_matrix(
183 |             **rot, z_down=z_down, dtype=dtype, device=device
184 |         )
185 | 
186 |     return R
187 | 
188 | def rotate_equi(equi_tensor, rots, z_down=True):
189 |     if equi_tensor.ndim == 4: # image
190 |         b,c,h,w = equi_tensor.shape
191 |         batch_size = b
192 |         resize_back = False
193 |     else:
194 |         b,c,f,h,w = equi_tensor.shape
195 |         equi_tensor = rearrange(equi_tensor, 'b c f h w -> (b f) c h w')
196 |         batch_size = b*f
197 |         resize_back = True
198 | 
199 |     out = torch.empty((batch_size, c, h, w)).to(device=equi_tensor.device, dtype=equi_tensor.dtype)
200 | 
201 |     m = create_normalized_grid(
202 |         height=h, width=w, batch=batch_size, device=equi_tensor.device, dtype=equi_tensor.dtype
203 |     )
204 |     m = m.unsqueeze(-1)
205 | 
206 |     # create batched rotation matrices
207 |     R = create_rotation_matrices(
208 |         rots=rots, z_down=z_down, device=equi_tensor.device, dtype=equi_tensor.dtype
209 |     )
210 | 
211 |     # rotate the grid
212 |     M = matmul(m, R)
213 | 
214 |     grid = convert_grid(M=M, h_equi=h, w_equi=w, method="robust")
215 |     # grid sample
216 |     out = torch_grid_sample(
217 |         img=equi_tensor,
218 |         grid=grid,
219 |         out=out, # FIXME: is this necessary?
220 |         backend="pure",
221 |     )
222 | 
223 |     if resize_back:
224 |         out = rearrange(out,'(b f) c h w -> b c f h w',f=f)
225 |     return out
--------------------------------------------------------------------------------
/utils/get_cubemap.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import warnings
4 | from .native import native_bilinear, native_nearest, native_bicubic
5 | from typing import Optional
6 | from einops import rearrange
7 | def linear_interp(v0, v1, d, L):
8 |     return v0 * (1 - d) / L + v1 * d / L
9 | 
10 | 
11 | def interp2d(q00, q10, q01, q11, dy, dx):
12 |     f0 = linear_interp(q00, q01, dx, 1)
13 |     f1 = linear_interp(q10, q11, dx, 1)
14 |     return linear_interp(f0, f1, dy, 1)
15 | 
16 | 
17 | def bilinear(
18 |     img: torch.Tensor, grid: torch.Tensor, out: torch.Tensor
19 | ) -> torch.Tensor:
20 |     b, _, h, w = img.shape
21 | 
22 |     min_grid = torch.floor(grid).type(torch.int64)
23 |     max_grid = min_grid + 1
24 |     d_grid = grid - min_grid
25 | 
26 |     min_grid[:, 0, :, :] %= h
27 |     min_grid[:, 1, :, :] %= w
28 |     max_grid[:, 0, :, :] %= h
29 |     max_grid[:, 1, :, :] %= w
30 | 
31 |     # FIXME: anyway to do efficient batch?
32 |     for i in range(b):
33 |         dy = d_grid[i, 0, ...]
34 |         dx = d_grid[i, 1, ...]
35 |         min_ys = min_grid[i, 0, ...]
36 |         min_xs = min_grid[i, 1, ...]
37 |         max_ys = max_grid[i, 0, ...]
38 |         max_xs = max_grid[i, 1, ...]
39 | 
40 |         min_ys %= h
41 |         min_xs %= w
42 | 
43 |         p00 = img[i][:, min_ys, min_xs]  # top-left corner
44 |         p10 = img[i][:, max_ys, min_xs]  # bottom-left corner
45 |         p01 = img[i][:, min_ys, max_xs]  # top-right corner
46 |         p11 = img[i][:, max_ys, max_xs]  # bottom-right corner
47 | 
48 |         out[i, ...]
= interp2d(p00, p10, p01, p11, dy, dx) 49 | 50 | return out 51 | 52 | 53 | def create_xyz_grid( 54 | w_face: int, 55 | batch: Optional[int] = None, 56 | dtype: torch.dtype = torch.float32, 57 | device: torch.device = torch.device("cpu"), 58 | ) -> torch.Tensor: 59 | """xyz coordinates of the faces of the cube""" 60 | 61 | ratio = (w_face - 1) / w_face 62 | 63 | out = torch.zeros((w_face, w_face * 6, 3), dtype=dtype, device=device) 64 | rng = torch.linspace( 65 | -0.5 * ratio, 0.5 * ratio, w_face, dtype=dtype, device=device 66 | ) 67 | 68 | # NOTE: https://github.com/pytorch/pytorch/issues/15301 69 | # Torch meshgrid behaves differently than numpy 70 | 71 | # Front face (x = 0.5) 72 | out[:, 0 * w_face : 1 * w_face, [2, 1]] = torch.stack( 73 | torch.meshgrid([-rng, rng], indexing="ij"), -1 74 | ) 75 | out[:, 0 * w_face : 1 * w_face, 0] = 0.5 76 | 77 | # Right face (y = -0.5) 78 | out[:, 1 * w_face : 2 * w_face, [2, 0]] = torch.stack( 79 | torch.meshgrid([-rng, -rng], indexing="ij"), -1 80 | ) 81 | out[:, 1 * w_face : 2 * w_face, 1] = 0.5 82 | 83 | # Back face (x = -0.5) 84 | out[:, 2 * w_face : 3 * w_face, [2, 1]] = torch.stack( 85 | torch.meshgrid([-rng, -rng], indexing="ij"), -1 86 | ) 87 | out[:, 2 * w_face : 3 * w_face, 0] = -0.5 88 | 89 | # Left face (y = 0.5) 90 | out[:, 3 * w_face : 4 * w_face, [2, 0]] = torch.stack( 91 | torch.meshgrid([-rng, rng], indexing="ij"), -1 92 | ) 93 | out[:, 3 * w_face : 4 * w_face, 1] = -0.5 94 | 95 | # Up face (z = 0.5) 96 | out[:, 4 * w_face : 5 * w_face, [0, 1]] = torch.stack( 97 | torch.meshgrid([rng, rng], indexing="ij"), -1 98 | ) 99 | out[:, 4 * w_face : 5 * w_face, 2] = 0.5 100 | 101 | # Down face (z = -0.5) 102 | out[:, 5 * w_face : 6 * w_face, [0, 1]] = torch.stack( 103 | torch.meshgrid([-rng, rng], indexing="ij"), -1 104 | ) 105 | out[:, 5 * w_face : 6 * w_face, 2] = -0.5 106 | 107 | if batch is not None: 108 | assert isinstance( 109 | batch, int 110 | ), f"ERR: batch needs to be integer: batch={batch}" 111 | assert ( 112 | batch > 0 113 | ), f"ERR: batch size needs to be larger than 0: batch={batch}" 114 | # FIXME: faster way of copying? 
115 | out = torch.cat([out.unsqueeze(0)] * batch) 116 | # grid shape is (b, h, w, 3) 117 | 118 | return out 119 | def convert_grid( 120 | xyz: torch.Tensor, h_equi: int, w_equi: int, method: str = "robust" 121 | ) -> torch.Tensor: 122 | # convert to rotation 123 | phi = torch.asin(xyz[..., 2] / torch.norm(xyz, dim=-1)) 124 | theta = torch.atan2(xyz[..., 1], xyz[..., 0]) 125 | 126 | if method == "robust": 127 | ui = (theta / (2 * np.pi) - 0.5) * w_equi - 0.5 128 | uj = (0.5 - phi / np.pi) * h_equi - 0.5 129 | ui %= w_equi 130 | uj %= h_equi 131 | elif method == "faster": 132 | ui = (theta / (2 * np.pi) - 0.5) * w_equi - 0.5 133 | uj = (0.5 - phi / np.pi) * h_equi - 0.5 134 | ui = torch.where(ui < 0, ui + w_equi, ui) 135 | ui = torch.where(ui >= w_equi, ui - w_equi, ui) 136 | uj = torch.where(uj < 0, uj + h_equi, uj) 137 | uj = torch.where(uj >= h_equi, uj - h_equi, uj) 138 | else: 139 | raise ValueError(f"ERR: {method} is not supported") 140 | 141 | # stack the pixel maps into a grid 142 | grid = torch.stack((uj, ui), dim=-3) 143 | return grid 144 | 145 | def torch_grid_sample( 146 | img: torch.Tensor, 147 | grid: torch.Tensor, 148 | out: Optional[torch.Tensor] = None, 149 | mode: str = "bilinear", 150 | backend: str = "native", 151 | ) -> torch.Tensor: 152 | """Torch grid sampling algorithm 153 | 154 | params: 155 | - img (torch.Tensor) 156 | - grid (torch.Tensor) 157 | - out (Optional[torch.Tensor]): defaults to None 158 | - mode (str): ('bilinear', 'bicubic', 'nearest') 159 | - backend (str): ('native', 'pure') 160 | 161 | return: 162 | - img (torch.Tensor) 163 | 164 | NOTE: for `backend`, `pure` is relatively efficient since grid doesn't need 165 | to be in the same device as the `img`. However, `native` is faster. 166 | 167 | NOTE: for `pure` backends, we need to pass reference to `out`. 
168 | 169 | NOTE: for `native` backends, we should pass anything for `out` 170 | 171 | """ 172 | 173 | if backend == "native": 174 | if out is not None: 175 | # NOTE: out is created 176 | warnings.warn( 177 | "don't need to pass preallocated `out` to `grid_sample`" 178 | ) 179 | assert img.device == grid.device, ( 180 | f"ERR: when using {backend}, the devices of `img` and `grid` need" 181 | "to be on the same device" 182 | ) 183 | if mode == "nearest": 184 | out = native_nearest(img, grid) 185 | elif mode == "bilinear": 186 | out = native_bilinear(img, grid) 187 | elif mode == "bicubic": 188 | out = native_bicubic(img, grid) 189 | else: 190 | raise ValueError(f"ERR: {mode} is not supported") 191 | elif backend == "pure": 192 | # NOTE: img and grid can be on different devices, but grid should be on the cpu 193 | # FIXME: since bilinear implementation depends on `grid` being on device, I'm removing 194 | # this warning and will put `grid` onto the same device until a fix is found 195 | # if grid.device.type == "cuda": 196 | # warnings.warn("input `grid` should be on the cpu, but got a cuda tensor") 197 | assert ( 198 | out is not None 199 | ), "ERR: need to pass reference to `out`, but got None" 200 | assert img.device == grid.device, ( 201 | f"ERR: when using {backend}, the devices of `img` and `grid` need" 202 | "to be on the same device" 203 | ) 204 | if mode == "nearest": 205 | out = nearest(img, grid, out) 206 | elif mode == "bilinear": 207 | out = bilinear(img, grid, out) 208 | elif mode == "bicubic": 209 | out = bicubic(img, grid, out) 210 | else: 211 | raise ValueError(f"ERR: {mode} is not supported") 212 | else: 213 | raise ValueError(f"ERR: {backend} is not supported") 214 | 215 | return out 216 | 217 | 218 | 219 | def get_cubemap(equi_tensor): 220 | if equi_tensor.ndim == 4: # image 221 | b,c,h,w = equi_tensor.shape 222 | batch_size = b 223 | resize_back = False 224 | else: 225 | b,c,f,h,w = equi_tensor.shape 226 | equi_tensor = rearrange(equi_tensor, 'b c f h w -> (b f) c h w') 227 | batch_size = b*f 228 | resize_back = True 229 | 230 | res = w//4 231 | out = torch.empty((batch_size, c, res, res * 6)).to(device=equi_tensor.device, dtype=equi_tensor.dtype) 232 | xyz = create_xyz_grid(res,batch_size,device=equi_tensor.device, dtype=equi_tensor.dtype) 233 | # TODO: add rotation 234 | grid = convert_grid(xyz=xyz, h_equi=h, w_equi=w, method="robust").to(device=equi_tensor.device) 235 | out = torch_grid_sample(img=equi_tensor, grid=grid, out=out, backend="pure") 236 | 237 | out = rearrange(out, 'b c h (n w) -> b n c h w',n=6) 238 | if resize_back: 239 | out = rearrange(out,'(b f) n c h w -> b n c f h w',f=f) 240 | return out -------------------------------------------------------------------------------- /utils/affine_transform.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.nn import functional as F 4 | from einops import rearrange 5 | import math 6 | from . 
import save 7 | def get_condition_tensors(cubeface, cube_tensors, new_h=None, new_w=None): 8 | 9 | b,n,c,f,h,w, = cube_tensors.shape 10 | device = cube_tensors.device 11 | cube_tensors = rearrange(cube_tensors,'b n c f h w -> (b n f) c h w') 12 | matrix_ = get_transform_matrix_1(cubeface,cube_tensors.shape, new_h, new_w).to(device=device) 13 | affine_ = F.grid_sample(cube_tensors, matrix_, mode='bilinear', align_corners=False) 14 | affine_ = rearrange(affine_,'(b f) c h w -> b c f h w',f=f) 15 | zero_tensor = torch.zeros_like(affine_) 16 | 17 | if cubeface=='L': 18 | row1 = torch.cat([affine_,zero_tensor],dim=-1) 19 | row2 = torch.zeros_like(row1) 20 | 21 | if cubeface=='F': 22 | row1 = torch.cat([zero_tensor,affine_],dim=-1) 23 | row2 = torch.zeros_like(row1) 24 | 25 | if cubeface=='B': 26 | row2 = torch.cat([affine_,zero_tensor],dim=-1) 27 | row1 = torch.zeros_like(row2) 28 | 29 | if cubeface=='R': 30 | row2 = torch.cat([zero_tensor,affine_],dim=-1) 31 | row1 = torch.zeros_like(row2) 32 | 33 | res = torch.cat([row1,row2],dim=-2) 34 | return res 35 | 36 | 37 | def get_mask(cubeface, tensors_shape): 38 | b,c,h,w = tensors_shape 39 | if cubeface == 'null': 40 | mask = torch.zeros((b,1,h,w)) 41 | else: 42 | mask = torch.ones((b,1,h,w)) 43 | 44 | matrix_ = get_transform_matrix_1('F', tensors_shape) 45 | mask = F.grid_sample(mask, matrix_, mode='bilinear',align_corners=False) 46 | 47 | return mask 48 | 49 | def get_transform_matrix_1(cubeface, tensors_shape, scale=None, new_h=None, new_w=None, reverse=False): 50 | if cubeface == 'F': 51 | angle = -45 52 | elif cubeface == 'L': 53 | angle = 45 54 | elif cubeface == 'B': 55 | angle = 135 56 | elif cubeface == 'R': 57 | angle = -135 58 | elif cubeface == 'D': 59 | angle = -45 60 | elif cubeface == 'U': 61 | angle = -135 62 | 63 | if reverse: 64 | angle = -angle 65 | b,c,h,w = tensors_shape 66 | theta = np.radians(angle) 67 | if scale is None: 68 | scale = np.sqrt(2) 69 | if new_h is None: 70 | new_h = math.ceil(scale*h/8)*8 71 | if new_w is None: 72 | new_w = math.ceil(scale*w/8)*8 73 | 74 | cos_theta = np.cos(theta)*(scale) 75 | sin_theta = np.sin(theta)*(scale) 76 | transform_matrix = torch.tensor( 77 | [[cos_theta, -sin_theta, 0], 78 | [sin_theta, cos_theta, 0]] 79 | ).float().unsqueeze(0).repeat(b,1,1) # (1, 2, 3) 80 | grid = F.affine_grid(transform_matrix, (b, c, new_h, new_w), align_corners=False) 81 | return grid 82 | 83 | def get_transform_matrix(cubeface, tensors_shape): 84 | if cubeface == 'F': 85 | angle = -45 86 | elif cubeface == 'L': 87 | angle = 45 88 | elif cubeface == 'B': 89 | angle = 135 90 | elif cubeface == 'R': 91 | angle = -135 92 | elif cubeface == 'D': 93 | angle = -45 94 | 95 | b,c,h,w = tensors_shape 96 | theta = np.radians(angle) 97 | scale = np.sqrt(2) 98 | new_h = math.ceil(scale*h/8)*8 99 | new_w = math.ceil(scale*w/8)*8 100 | 101 | cos_theta = np.cos(theta)*(scale/3) 102 | sin_theta = np.sin(theta)*(scale/3) 103 | transform_matrix = torch.tensor( 104 | [[cos_theta, -sin_theta, 0], 105 | [sin_theta, cos_theta, 0]] 106 | ).float().unsqueeze(0).repeat(b,1,1) # (1, 2, 3) 107 | grid = F.affine_grid(transform_matrix, (b, c, new_h, new_w), align_corners=False) 108 | return grid 109 | 110 | 111 | def affine_U(U_tensors): 112 | # b c h w 113 | U_tensors = torch.cat(torch.chunk(U_tensors,2,dim=-2)[: :-1],dim=-2) 114 | b,c,H,W = U_tensors.shape 115 | 116 | y, x = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij') 117 | x = x.float() 118 | y = y.float() 119 | 120 | # 坐标系参数 121 | center = torch.tensor([H/2, W/2], 
device=x.device) 122 | 123 | # 计算相对坐标(考虑图像y轴方向) 124 | dx = x - center[0] 125 | dy = center[1] - y # 反转y轴 126 | 127 | # 极坐标转换 128 | theta = torch.atan2(dy, dx) 129 | r_prime = torch.sqrt(dx**2 + dy**2) 130 | 131 | # 计算菱形边界函数 132 | cos_theta = torch.cos(theta) 133 | sin_theta = torch.sin(theta) 134 | denominator = torch.abs(cos_theta) + torch.abs(sin_theta) 135 | #r_max = (H/2) / (denominator + 1e-6) # 防止除零 136 | r_max = (H/2) / (denominator) 137 | # 有效点判断 138 | valid_mask = r_prime <= r_max 139 | 140 | # 计算原图坐标 141 | r = torch.where(valid_mask, r_prime * ((H/2) / r_max), 0.0) 142 | px = center[0] + r * cos_theta 143 | py = center[1] - r * sin_theta # 注意y轴反转 144 | 145 | # 归一化到[-1, 1] 146 | grid_x = (px / (W-1)) * 2 - 1 147 | grid_y = (py / (H-1)) * 2 - 1 148 | 149 | # 无效点设置为超出范围 150 | grid_x = torch.where(valid_mask, grid_x, -2.0) 151 | grid_y = torch.where(valid_mask, grid_y, -2.0) 152 | 153 | grid = torch.stack((grid_x, grid_y), dim=-1).unsqueeze(0).repeat(b,1,1,1).to(device=U_tensors.device) 154 | affine_U_tensors = F.grid_sample(U_tensors, grid, mode='bilinear', align_corners=False) 155 | affine_U_tensors = torch.cat(torch.chunk(affine_U_tensors,2,dim=-2)[: :-1],dim=-2) 156 | return affine_U_tensors 157 | 158 | def get_cat_tensors(cube_tensors, cube): # cube_tensors is a list 159 | 160 | F_tensors = cube_tensors[0] 161 | R_tensors = cube_tensors[1] 162 | B_tensors = cube_tensors[2] 163 | L_tensors = cube_tensors[3] 164 | U_tensors = cube_tensors[4] 165 | D_tensors = cube_tensors[5] 166 | 167 | b,c,h,w = F_tensors.shape 168 | 169 | cat_tensors = torch.zeros((b,c,3*h,3*w)).to(device=F_tensors.device, dtype=F_tensors.dtype) 170 | 171 | if cube == 'D': 172 | cat_tensors[:, :, 0*h:1*h, 1*w:2*w] = F_tensors 173 | cat_tensors[:, :, 1*h:2*h, 0*w:1*w] = torch.rot90(L_tensors, k=1, dims=(2, 3)) 174 | cat_tensors[:, :, 1*h:2*h, 1*w:2*w] = D_tensors 175 | cat_tensors[:, :, 1*h:2*h, 2*w:3*w] = torch.rot90(R_tensors, k=-1, dims=(2, 3)) 176 | cat_tensors[:, :, 2*h:3*h, 1*w:2*w] = torch.rot90(B_tensors, k=2, dims=(2, 3)) 177 | 178 | if cube == 'F': 179 | cat_tensors[:, :, 0*h:1*h, 1*w:2*w] = affine_U(U_tensors) 180 | cat_tensors[:, :, 1*h:2*h, 0*w:1*w] = L_tensors 181 | cat_tensors[:, :, 1*h:2*h, 1*w:2*w] = F_tensors 182 | cat_tensors[:, :, 1*h:2*h, 2*w:3*w] = R_tensors 183 | cat_tensors[:, :, 2*h:3*h, 1*w:2*w] = D_tensors 184 | 185 | if cube == 'L': 186 | U_tensors = torch.rot90(U_tensors, k=1, dims=(2, 3)) 187 | cat_tensors[:, :, 0*h:1*h, 1*w:2*w] = affine_U(U_tensors) 188 | cat_tensors[:, :, 1*h:2*h, 0*w:1*w] = B_tensors 189 | cat_tensors[:, :, 1*h:2*h, 1*w:2*w] = L_tensors 190 | cat_tensors[:, :, 1*h:2*h, 2*w:3*w] = F_tensors 191 | cat_tensors[:, :, 2*h:3*h, 1*w:2*w] = torch.rot90(D_tensors, k=-1, dims=(2, 3)) 192 | 193 | if cube == 'R': 194 | U_tensors = torch.rot90(U_tensors, k=-1, dims=(2, 3)) 195 | cat_tensors[:, :, 0*h:1*h, 1*w:2*w] = affine_U(U_tensors) 196 | cat_tensors[:, :, 1*h:2*h, 0*w:1*w] = F_tensors 197 | cat_tensors[:, :, 1*h:2*h, 1*w:2*w] = R_tensors 198 | cat_tensors[:, :, 1*h:2*h, 2*w:3*w] = B_tensors 199 | cat_tensors[:, :, 2*h:3*h, 1*w:2*w] = torch.rot90(D_tensors, k=1, dims=(2, 3)) 200 | 201 | if cube == 'B': 202 | U_tensors = torch.rot90(U_tensors, k=2, dims=(2, 3)) 203 | cat_tensors[:, :, 0*h:1*h, 1*w:2*w] = affine_U(U_tensors) 204 | cat_tensors[:, :, 1*h:2*h, 0*w:1*w] = R_tensors 205 | cat_tensors[:, :, 1*h:2*h, 1*w:2*w] = B_tensors 206 | cat_tensors[:, :, 1*h:2*h, 2*w:3*w] = L_tensors 207 | cat_tensors[:, :, 2*h:3*h, 1*w:2*w] = torch.rot90(D_tensors, k=2, dims=(2, 3)) 208 | 
209 | return cat_tensors 210 | 211 | def get_affine_tensors(cube_tensors, concat=False): 212 | b,n,c,f,h,w = cube_tensors[0].shape 213 | device = cube_tensors[0].device 214 | F_tensors = rearrange(cube_tensors[0],'b n c f h w -> (b n f) c h w') 215 | R_tensors = rearrange(cube_tensors[1],'b n c f h w -> (b n f) c h w') 216 | B_tensors = rearrange(cube_tensors[2],'b n c f h w -> (b n f) c h w') 217 | L_tensors = rearrange(cube_tensors[3],'b n c f h w -> (b n f) c h w') 218 | U_tensors = rearrange(cube_tensors[4],'b n c f h w -> (b n f) c h w') 219 | D_tensors = rearrange(cube_tensors[5],'b n c f h w -> (b n f) c h w') 220 | 221 | rearranged_list = [F_tensors,R_tensors,B_tensors,L_tensors,U_tensors,D_tensors] 222 | tensor_shape = F_tensors.shape 223 | 224 | cat_F = get_cat_tensors(rearranged_list, 'F') 225 | matrix_F = get_transform_matrix('F',tensor_shape).to(device=device) 226 | affine_F = F.grid_sample(cat_F, matrix_F, mode='bilinear', align_corners=False) 227 | affine_F = rearrange(affine_F,'(b f) c h w -> b c f h w',f=f) 228 | 229 | cat_L = get_cat_tensors(rearranged_list, 'L') 230 | matrix_L = get_transform_matrix('L',tensor_shape).to(device=device) 231 | affine_L = F.grid_sample(cat_L, matrix_L, mode='bilinear', align_corners=False) 232 | affine_L = rearrange(affine_L,'(b f) c h w -> b c f h w',f=f) 233 | 234 | cat_B = get_cat_tensors(rearranged_list, 'B') 235 | matrix_B = get_transform_matrix('B',tensor_shape).to(device=device) 236 | affine_B = F.grid_sample(cat_B, matrix_B, mode='bilinear', align_corners=False) 237 | affine_B = rearrange(affine_B,'(b f) c h w -> b c f h w',f=f) 238 | 239 | cat_R = get_cat_tensors(rearranged_list, 'R') 240 | matrix_R = get_transform_matrix('R',tensor_shape).to(device=device) 241 | affine_R = F.grid_sample(cat_R, matrix_R, mode='bilinear', align_corners=False) 242 | affine_R = rearrange(affine_R,'(b f) c h w -> b c f h w',f=f) 243 | 244 | if concat: 245 | L_F = torch.cat([affine_L,affine_F],dim=-1) 246 | B_R = torch.cat([affine_B,affine_R],dim=-1) 247 | L_F_B_R = torch.cat([L_F,B_R],dim=-2) 248 | else: 249 | L_F_B_R = torch.stack([affine_L,affine_F,affine_B,affine_R], dim=1) 250 | # L_F = torch.cat([affine_L,affine_F],dim=-1) 251 | # B_R = torch.cat([affine_B,affine_R],dim=-1) 252 | 253 | # tensors = torch.cat([L_F,B_R],dim=-2) 254 | 255 | return L_F_B_R -------------------------------------------------------------------------------- /utils/cube2equi.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import math 4 | import numpy as np 5 | from scipy.ndimage import map_coordinates 6 | import cv2 7 | # from .get_cubemap import torch_grid_sample 8 | 9 | # def _equirect_facetype(h: int, w: int) -> torch.Tensor: 10 | # """0F 1R 2B 3L 4U 5D""" 11 | 12 | # int_dtype = torch.int64 13 | 14 | # w_ratio = (w - 1) / w 15 | # h_ratio = (h - 1) / h 16 | 17 | # tp = torch.roll( 18 | # torch.arange(4) # 1 19 | # .repeat_interleave(w // 4) # 2 same as np.repeat 20 | # .unsqueeze(0) 21 | # .transpose(0, 1) # 3 22 | # .repeat(1, h) # 4 23 | # .view(-1, h) # 5 24 | # .transpose(0, 1), # 6 25 | # shifts=3 * w // 8, 26 | # dims=1, 27 | # ) 28 | 29 | # # Prepare ceil mask 30 | # mask = torch.zeros((h, w // 4), dtype=torch.bool) 31 | # idx = torch.linspace(-(math.pi * w_ratio), math.pi * w_ratio, w // 4) / 4 32 | # idx = h // 2 - torch.round( 33 | # torch.atan(torch.cos(idx)) * h / (math.pi * h_ratio) 34 | # ) 35 | # idx = idx.type(int_dtype) 36 | # for i, j in enumerate(idx): 37 | # mask[:j, i] = 1 38 | # mask = 
torch.roll(torch.cat([mask] * 4, 1), 3 * w // 8, 1) 39 | 40 | # tp[mask] = 4 41 | # tp[torch.flip(mask, dims=(0,))] = 5 42 | 43 | # return tp.type(int_dtype) 44 | 45 | 46 | # def create_equi_grid( 47 | # h_out: int, 48 | # w_out: int, 49 | # w_face: int, 50 | # batch: int, 51 | # dtype: torch.dtype = torch.float32, 52 | # device: torch.device = torch.device("cpu"), 53 | # ) -> torch.Tensor: 54 | # w_ratio = (w_out - 1) / w_out 55 | # h_ratio = (h_out - 1) / h_out 56 | # theta = torch.linspace( 57 | # -(math.pi * w_ratio), 58 | # math.pi * w_ratio, 59 | # steps=w_out, 60 | # dtype=dtype, 61 | # device=device, 62 | # ) 63 | # phi = torch.linspace( 64 | # (math.pi * h_ratio) / 2, 65 | # -(math.pi * h_ratio) / 2, 66 | # steps=h_out, 67 | # dtype=dtype, 68 | # device=device, 69 | # ) 70 | # phi, theta = torch.meshgrid([phi, theta], indexing="ij") 71 | 72 | # # Get face id to each pixel: 0F 1R 2B 3L 4U 5D 73 | # tp = _equirect_facetype(h_out, w_out) 74 | 75 | # # xy coordinate map 76 | # coor_x = torch.zeros((h_out, w_out), dtype=dtype, device=device) 77 | # coor_y = torch.zeros((h_out, w_out), dtype=dtype, device=device) 78 | 79 | # # FIXME: there's a bug where left section (3L) has artifacts 80 | # # on top and bottom 81 | # # It might have to do with 4U or 5D 82 | # for i in range(6): 83 | # mask = tp == i 84 | 85 | # if i < 4: 86 | # coor_x[mask] = 0.5 * torch.tan(theta[mask] - math.pi * i / 2) 87 | # coor_y[mask] = ( 88 | # -0.5 89 | # * torch.tan(phi[mask]) 90 | # / torch.cos(theta[mask] - math.pi * i / 2) 91 | # ) 92 | # elif i == 4: 93 | # c = 0.5 * torch.tan(math.pi / 2 - phi[mask]) 94 | # coor_x[mask] = c * torch.sin(theta[mask]) 95 | # coor_y[mask] = c * torch.cos(theta[mask]) 96 | # elif i == 5: 97 | # c = 0.5 * torch.tan(math.pi / 2 - torch.abs(phi[mask])) 98 | # coor_x[mask] = c * torch.sin(theta[mask]) 99 | # coor_y[mask] = -c * torch.cos(theta[mask]) 100 | 101 | # # Final renormalize 102 | # coor_x = torch.clamp( 103 | # torch.clamp(coor_x + 0.5, 0, 1) * w_face, 0, w_face - 1 104 | # ) 105 | # coor_y = torch.clamp( 106 | # torch.clamp(coor_y + 0.5, 0, 1) * w_face, 0, w_face - 1 107 | # ) 108 | 109 | # # change x axis of the x coordinate map 110 | # for i in range(6): 111 | # mask = tp == i 112 | # coor_x[mask] = coor_x[mask] + w_face * i 113 | 114 | # # repeat batch 115 | # coor_x = coor_x.repeat(batch, 1, 1) - 0.5 116 | # coor_y = coor_y.repeat(batch, 1, 1) - 0.5 117 | 118 | # grid = torch.stack((coor_y, coor_x), dim=-3).to(device) 119 | # return grid 120 | 121 | # def cube2equi(cubemap): 122 | # b,c,h,w = cubemap.shape # w=6h 123 | # cubemap = cubemap * 255 124 | 125 | # dtype = cubemap.dtype 126 | # device = cubemap.device 127 | # grid = create_equi_grid( 128 | # h_out=2*h, 129 | # w_out=4*h, 130 | # w_face=h, 131 | # batch=b, 132 | # dtype=dtype, 133 | # device=device, 134 | # ) 135 | # out = torch.empty((b, c, 2*h, 4*h), dtype=dtype, device=device) 136 | # out = torch_grid_sample( 137 | # img=cubemap, grid=grid, out=out, backend="pure" 138 | # ) 139 | # out = out / 255 140 | # return out 141 | 142 | def sample_cubefaces(cube_faces, tp, coor_y, coor_x, order): 143 | cube_faces = cube_faces.copy() 144 | cube_faces[1] = np.flip(cube_faces[1], 1) 145 | cube_faces[2] = np.flip(cube_faces[2], 1) 146 | cube_faces[4] = np.flip(cube_faces[4], 0) 147 | 148 | # Pad up down 149 | pad_ud = np.zeros((6, 2, cube_faces.shape[2])) 150 | pad_ud[0, 0] = cube_faces[5, 0, :] 151 | pad_ud[0, 1] = cube_faces[4, -1, :] 152 | pad_ud[1, 0] = cube_faces[5, :, -1] 153 | pad_ud[1, 1] = cube_faces[4, 
::-1, -1] 154 | pad_ud[2, 0] = cube_faces[5, -1, ::-1] 155 | pad_ud[2, 1] = cube_faces[4, 0, ::-1] 156 | pad_ud[3, 0] = cube_faces[5, ::-1, 0] 157 | pad_ud[3, 1] = cube_faces[4, :, 0] 158 | pad_ud[4, 0] = cube_faces[0, 0, :] 159 | pad_ud[4, 1] = cube_faces[2, 0, ::-1] 160 | pad_ud[5, 0] = cube_faces[2, -1, ::-1] 161 | pad_ud[5, 1] = cube_faces[0, -1, :] 162 | cube_faces = np.concatenate([cube_faces, pad_ud], 1) 163 | 164 | # Pad left right 165 | pad_lr = np.zeros((6, cube_faces.shape[1], 2)) 166 | pad_lr[0, :, 0] = cube_faces[1, :, 0] 167 | pad_lr[0, :, 1] = cube_faces[3, :, -1] 168 | pad_lr[1, :, 0] = cube_faces[2, :, 0] 169 | pad_lr[1, :, 1] = cube_faces[0, :, -1] 170 | pad_lr[2, :, 0] = cube_faces[3, :, 0] 171 | pad_lr[2, :, 1] = cube_faces[1, :, -1] 172 | pad_lr[3, :, 0] = cube_faces[0, :, 0] 173 | pad_lr[3, :, 1] = cube_faces[2, :, -1] 174 | pad_lr[4, 1:-1, 0] = cube_faces[1, 0, ::-1] 175 | pad_lr[4, 1:-1, 1] = cube_faces[3, 0, :] 176 | pad_lr[5, 1:-1, 0] = cube_faces[1, -2, :] 177 | pad_lr[5, 1:-1, 1] = cube_faces[3, -2, ::-1] 178 | cube_faces = np.concatenate([cube_faces, pad_lr], 2) 179 | 180 | return map_coordinates(cube_faces, [tp, coor_y, coor_x], order=order, mode='wrap') 181 | 182 | def equirect_uvgrid(h, w): 183 | u = np.linspace(-np.pi, np.pi, num=w, dtype=np.float32) 184 | v = np.linspace(np.pi, -np.pi, num=h, dtype=np.float32) / 2 185 | 186 | return np.stack(np.meshgrid(u, v), axis=-1) 187 | 188 | def equirect_facetype(h, w): 189 | ''' 190 | 0F 1R 2B 3L 4U 5D 191 | ''' 192 | tp = np.roll(np.arange(4).repeat(w // 4)[None, :].repeat(h, 0), 3 * w // 8, 1) 193 | 194 | # Prepare ceil mask 195 | mask = np.zeros((h, w // 4), bool) 196 | idx = np.linspace(-np.pi, np.pi, w // 4) / 4 197 | idx = h // 2 - np.round(np.arctan(np.cos(idx)) * h / np.pi).astype(int) 198 | for i, j in enumerate(idx): 199 | mask[:j, i] = 1 200 | mask = np.roll(np.concatenate([mask] * 4, 1), 3 * w // 8, 1) 201 | 202 | tp[mask] = 4 203 | tp[np.flip(mask, 0)] = 5 204 | 205 | return tp.astype(np.int32) 206 | 207 | def c2e(cubemap, h, w, mode='bilinear', cube_format='horizon'): 208 | if mode == 'bilinear': 209 | order = 1 210 | elif mode == 'nearest': 211 | order = 0 212 | else: 213 | raise NotImplementedError('unknown mode') 214 | 215 | if cube_format == 'horizon': 216 | pass 217 | elif cube_format == 'list': 218 | cubemap = utils.cube_list2h(cubemap) 219 | elif cube_format == 'dict': 220 | cubemap = utils.cube_dict2h(cubemap) 221 | elif cube_format == 'dice': 222 | cubemap = utils.cube_dice2h(cubemap) 223 | else: 224 | raise NotImplementedError('unknown cube_format') 225 | assert len(cubemap.shape) == 3 226 | assert cubemap.shape[0] * 6 == cubemap.shape[1] 227 | assert w % 8 == 0 228 | face_w = cubemap.shape[0] # h 6*h 3 229 | 230 | uv = equirect_uvgrid(h, w) 231 | u, v = np.split(uv, 2, axis=-1) 232 | u = u[..., 0] 233 | v = v[..., 0] 234 | cube_faces = np.stack(np.split(cubemap, 6, 1), 0) 235 | 236 | # Get face id to each pixel: 0F 1R 2B 3L 4U 5D 237 | tp = equirect_facetype(h, w) 238 | coor_x = np.zeros((h, w)) 239 | coor_y = np.zeros((h, w)) 240 | 241 | for i in range(4): 242 | mask = (tp == i) 243 | coor_x[mask] = 0.5 * np.tan(u[mask] - np.pi * i / 2) 244 | coor_y[mask] = -0.5 * np.tan(v[mask]) / np.cos(u[mask] - np.pi * i / 2) 245 | 246 | mask = (tp == 4) 247 | c = 0.5 * np.tan(np.pi / 2 - v[mask]) 248 | coor_x[mask] = c * np.sin(u[mask]) 249 | coor_y[mask] = c * np.cos(u[mask]) 250 | 251 | mask = (tp == 5) 252 | c = 0.5 * np.tan(np.pi / 2 - np.abs(v[mask])) 253 | coor_x[mask] = c * np.sin(u[mask]) 254 | 
coor_y[mask] = -c * np.cos(u[mask]) 255 | 256 | # Final renormalize 257 | coor_x = (np.clip(coor_x, -0.5, 0.5) + 0.5) * face_w 258 | coor_y = (np.clip(coor_y, -0.5, 0.5) + 0.5) * face_w 259 | 260 | equirec = np.stack([ 261 | sample_cubefaces(cube_faces[..., i], tp, coor_y, coor_x, order=order) 262 | for i in range(cube_faces.shape[3]) 263 | ], axis=-1) 264 | 265 | return equirec 266 | 267 | def inpainting_mask(equi_h,equi_w): 268 | cube_h = equi_h//2 269 | cube = np.zeros((cube_h, 6*cube_h, 3), dtype=np.uint8) 270 | U = np.zeros((cube_h, cube_h, 3), dtype=np.uint8) 271 | 272 | radius = 32 273 | center = (radius,radius) 274 | 275 | circle = np.zeros((2*radius, 2*radius, 3), dtype=np.uint8) 276 | # # U[:,:4,:]= 1 277 | cv2.circle(circle, center, radius, (1, 1, 1), -1) 278 | circle = 1-circle 279 | U[:radius,:radius,:] = circle[:radius,:radius,:] 280 | U[:radius,cube_h-radius:,:] = circle[:radius,radius:,:] 281 | U[cube_h-radius:,:radius,:] = circle[radius:,:radius,:] 282 | U[cube_h-radius:,cube_h-radius:,:] = circle[radius:,radius:,:] 283 | 284 | mask_len = cube_h//3 285 | U[:mask_len,:4,:] = 1 286 | U[-mask_len:,:4,:] = 1 287 | U[:mask_len,-4:,:] = 1 288 | U[-mask_len:,-4:,:] = 1 289 | U[:4,:mask_len,:] = 1 290 | U[:4,-mask_len:,:] = 1 291 | U[-4:,:mask_len,:] = 1 292 | U[-4:,-mask_len:,:] = 1 293 | 294 | 295 | cube[:,cube_h*4:cube_h*5,:] = U 296 | equi = c2e(cube,equi_h,equi_w) 297 | equi=equi[:,:,0] 298 | equi[2:,:] = equi[:-2,:].copy() 299 | 300 | return equi 301 | 302 | def cube2equi(cubemap): 303 | 304 | b,c,h,w = cubemap.shape # w=6h 305 | # alignment 306 | cubemap[:,:,:,h:2*h] = torch.flip(cubemap[:,:,:,h:2*h],dims=[-1]) 307 | cubemap[:,:,:,2*h:3*h] = torch.flip(cubemap[:,:,:,2*h:3*h],dims=[-1]) 308 | cubemap[:,:,:,4*h:5*h] = torch.rot90(torch.flip(cubemap[:,:,:,4*h:5*h],dims=[-1]), k=2, dims=(-2, -1)) 309 | 310 | equi_h = h*2 311 | equi_w = h*4 312 | equi_list = [] 313 | inp_mask = inpainting_mask(equi_h,equi_w).astype(np.uint8) 314 | 315 | #equi_list.append(torch.from_numpy(inp_mask).unsqueeze(2).repeat(1,1,3).permute(2,0,1)) 316 | 317 | for frame in cubemap: 318 | cube = frame.permute(1,2,0).numpy() 319 | equi = (c2e(cube,equi_h,equi_w)*255).astype(np.uint8) 320 | 321 | equi = cv2.inpaint(equi, inp_mask, inpaintRadius=3, flags=cv2.INPAINT_TELEA) 322 | equi = equi / 255.0 323 | equi = torch.from_numpy(equi) 324 | equi_list.append(equi.permute(2,0,1)) 325 | equi = torch.stack(equi_list) 326 | return equi -------------------------------------------------------------------------------- /src/models/transformer_wan.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import math 16 | from typing import Any, Dict, Optional, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | 22 | from diffusers.configuration_utils import ConfigMixin, register_to_config 23 | from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin 24 | from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers 25 | from diffusers.models.attention import FeedForward 26 | from diffusers.models.attention_processor import Attention 27 | from diffusers.models.embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed 28 | from diffusers.models.modeling_outputs import Transformer2DModelOutput 29 | from diffusers.models.modeling_utils import ModelMixin 30 | from diffusers.models.normalization import FP32LayerNorm 31 | 32 | 33 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 34 | 35 | 36 | class WanAttnProcessor2_0: 37 | def __init__(self): 38 | if not hasattr(F, "scaled_dot_product_attention"): 39 | raise ImportError("WanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.") 40 | 41 | def __call__( 42 | self, 43 | attn: Attention, 44 | hidden_states: torch.Tensor, 45 | encoder_hidden_states: Optional[torch.Tensor] = None, 46 | attention_mask: Optional[torch.Tensor] = None, 47 | rotary_emb: Optional[torch.Tensor] = None, 48 | ) -> torch.Tensor: 49 | encoder_hidden_states_img = None 50 | if attn.add_k_proj is not None: 51 | encoder_hidden_states_img = encoder_hidden_states[:, :257] 52 | encoder_hidden_states = encoder_hidden_states[:, 257:] 53 | if encoder_hidden_states is None: 54 | encoder_hidden_states = hidden_states 55 | 56 | query = attn.to_q(hidden_states) 57 | key = attn.to_k(encoder_hidden_states) 58 | value = attn.to_v(encoder_hidden_states) 59 | 60 | if attn.norm_q is not None: 61 | query = attn.norm_q(query) 62 | if attn.norm_k is not None: 63 | key = attn.norm_k(key) 64 | 65 | query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2) 66 | key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2) 67 | value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2) 68 | 69 | if rotary_emb is not None: 70 | 71 | def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor): 72 | x_rotated = torch.view_as_complex(hidden_states.to(torch.float64).unflatten(3, (-1, 2))) 73 | x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4) 74 | return x_out.type_as(hidden_states) 75 | 76 | query = apply_rotary_emb(query, rotary_emb) 77 | key = apply_rotary_emb(key, rotary_emb) 78 | 79 | # I2V task 80 | hidden_states_img = None 81 | if encoder_hidden_states_img is not None: 82 | key_img = attn.add_k_proj(encoder_hidden_states_img) 83 | key_img = attn.norm_added_k(key_img) 84 | value_img = attn.add_v_proj(encoder_hidden_states_img) 85 | 86 | key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2) 87 | value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2) 88 | 89 | hidden_states_img = F.scaled_dot_product_attention( 90 | query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False 91 | ) 92 | hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3) 93 | hidden_states_img = hidden_states_img.type_as(query) 94 | 95 | hidden_states = F.scaled_dot_product_attention( 96 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False 97 | ) 98 | hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) 99 | hidden_states = 
hidden_states.type_as(query) 100 | 101 | if hidden_states_img is not None: 102 | hidden_states = hidden_states + hidden_states_img 103 | 104 | hidden_states = attn.to_out[0](hidden_states) 105 | hidden_states = attn.to_out[1](hidden_states) 106 | return hidden_states 107 | 108 | 109 | class WanImageEmbedding(torch.nn.Module): 110 | def __init__(self, in_features: int, out_features: int): 111 | super().__init__() 112 | 113 | self.norm1 = FP32LayerNorm(in_features) 114 | self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu") 115 | self.norm2 = FP32LayerNorm(out_features) 116 | 117 | def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor: 118 | hidden_states = self.norm1(encoder_hidden_states_image) 119 | hidden_states = self.ff(hidden_states) 120 | hidden_states = self.norm2(hidden_states) 121 | return hidden_states 122 | 123 | 124 | class WanTimeTextImageEmbedding(nn.Module): 125 | def __init__( 126 | self, 127 | dim: int, 128 | time_freq_dim: int, 129 | time_proj_dim: int, 130 | text_embed_dim: int, 131 | image_embed_dim: Optional[int] = None, 132 | ): 133 | super().__init__() 134 | 135 | self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0) 136 | self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim) 137 | self.act_fn = nn.SiLU() 138 | self.time_proj = nn.Linear(dim, time_proj_dim) 139 | self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh") 140 | 141 | self.image_embedder = None 142 | if image_embed_dim is not None: 143 | self.image_embedder = WanImageEmbedding(image_embed_dim, dim) 144 | 145 | def forward( 146 | self, 147 | timestep: torch.Tensor, 148 | encoder_hidden_states: torch.Tensor, 149 | encoder_hidden_states_image: Optional[torch.Tensor] = None, 150 | ): 151 | timestep = self.timesteps_proj(timestep) 152 | 153 | time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype 154 | if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8: 155 | timestep = timestep.to(time_embedder_dtype) 156 | temb = self.time_embedder(timestep).type_as(encoder_hidden_states) 157 | timestep_proj = self.time_proj(self.act_fn(temb)) 158 | 159 | encoder_hidden_states = self.text_embedder(encoder_hidden_states) 160 | if encoder_hidden_states_image is not None: 161 | encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image) 162 | 163 | return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image 164 | 165 | 166 | class WanRotaryPosEmbed(nn.Module): 167 | def __init__( 168 | self, attention_head_dim: int, patch_size: Tuple[int, int, int], max_seq_len: int, theta: float = 10000.0 169 | ): 170 | super().__init__() 171 | 172 | self.attention_head_dim = attention_head_dim 173 | self.patch_size = patch_size 174 | self.max_seq_len = max_seq_len 175 | 176 | h_dim = w_dim = 2 * (attention_head_dim // 6) 177 | t_dim = attention_head_dim - h_dim - w_dim 178 | 179 | freqs = [] 180 | for dim in [t_dim, h_dim, w_dim]: 181 | freq = get_1d_rotary_pos_embed( 182 | dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=torch.float64 183 | ) 184 | freqs.append(freq) 185 | self.freqs = torch.cat(freqs, dim=1) 186 | 187 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: 188 | batch_size, num_channels, num_frames, height, width = hidden_states.shape 189 | p_t, p_h, p_w = self.patch_size 190 | ppf, pph, ppw = num_frames // p_t, height // p_h, width // 
p_w 191 | 192 | self.freqs = self.freqs.to(hidden_states.device) 193 | freqs = self.freqs.split_with_sizes( 194 | [ 195 | self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6), 196 | self.attention_head_dim // 6, 197 | self.attention_head_dim // 6, 198 | ], 199 | dim=1, 200 | ) 201 | 202 | freqs_f = freqs[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1) 203 | freqs_h = freqs[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1) 204 | freqs_w = freqs[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1) 205 | freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1) 206 | return freqs 207 | 208 | 209 | class WanTransformerBlock(nn.Module): 210 | def __init__( 211 | self, 212 | dim: int, 213 | ffn_dim: int, 214 | num_heads: int, 215 | qk_norm: str = "rms_norm_across_heads", 216 | cross_attn_norm: bool = False, 217 | eps: float = 1e-6, 218 | added_kv_proj_dim: Optional[int] = None, 219 | ): 220 | super().__init__() 221 | 222 | # 1. Self-attention 223 | self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False) 224 | self.attn1 = Attention( 225 | query_dim=dim, 226 | heads=num_heads, 227 | kv_heads=num_heads, 228 | dim_head=dim // num_heads, 229 | qk_norm=qk_norm, 230 | eps=eps, 231 | bias=True, 232 | cross_attention_dim=None, 233 | out_bias=True, 234 | processor=WanAttnProcessor2_0(), 235 | ) 236 | 237 | # 2. Cross-attention 238 | self.attn2 = Attention( 239 | query_dim=dim, 240 | heads=num_heads, 241 | kv_heads=num_heads, 242 | dim_head=dim // num_heads, 243 | qk_norm=qk_norm, 244 | eps=eps, 245 | bias=True, 246 | cross_attention_dim=None, 247 | out_bias=True, 248 | added_kv_proj_dim=added_kv_proj_dim, 249 | added_proj_bias=True, 250 | processor=WanAttnProcessor2_0(), 251 | ) 252 | self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity() 253 | 254 | # 3. Feed-forward 255 | self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate") 256 | self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False) 257 | 258 | self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) 259 | 260 | def forward( 261 | self, 262 | hidden_states: torch.Tensor, 263 | encoder_hidden_states: torch.Tensor, 264 | temb: torch.Tensor, 265 | rotary_emb: torch.Tensor, 266 | ) -> torch.Tensor: 267 | shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( 268 | self.scale_shift_table + temb.float() 269 | ).chunk(6, dim=1) 270 | 271 | # 1. Self-attention 272 | norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states) 273 | attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb) 274 | hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states) 275 | 276 | # 2. Cross-attention 277 | norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states) 278 | attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states) 279 | hidden_states = hidden_states + attn_output 280 | 281 | # 3. 
Feed-forward 282 | norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as( 283 | hidden_states 284 | ) 285 | ff_output = self.ffn(norm_hidden_states) 286 | hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states) 287 | 288 | return hidden_states 289 | 290 | 291 | class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): 292 | r""" 293 | A Transformer model for video-like data used in the Wan model. 294 | 295 | Args: 296 | patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`): 297 | 3D patch dimensions for video embedding (t_patch, h_patch, w_patch). 298 | num_attention_heads (`int`, defaults to `40`): 299 | Fixed length for text embeddings. 300 | attention_head_dim (`int`, defaults to `128`): 301 | The number of channels in each head. 302 | in_channels (`int`, defaults to `16`): 303 | The number of channels in the input. 304 | out_channels (`int`, defaults to `16`): 305 | The number of channels in the output. 306 | text_dim (`int`, defaults to `512`): 307 | Input dimension for text embeddings. 308 | freq_dim (`int`, defaults to `256`): 309 | Dimension for sinusoidal time embeddings. 310 | ffn_dim (`int`, defaults to `13824`): 311 | Intermediate dimension in feed-forward network. 312 | num_layers (`int`, defaults to `40`): 313 | The number of layers of transformer blocks to use. 314 | window_size (`Tuple[int]`, defaults to `(-1, -1)`): 315 | Window size for local attention (-1 indicates global attention). 316 | cross_attn_norm (`bool`, defaults to `True`): 317 | Enable cross-attention normalization. 318 | qk_norm (`bool`, defaults to `True`): 319 | Enable query/key normalization. 320 | eps (`float`, defaults to `1e-6`): 321 | Epsilon value for normalization layers. 322 | add_img_emb (`bool`, defaults to `False`): 323 | Whether to use img_emb. 324 | added_kv_proj_dim (`int`, *optional*, defaults to `None`): 325 | The number of channels to use for the added key and value projections. If `None`, no projection is used. 326 | """ 327 | 328 | _supports_gradient_checkpointing = True 329 | _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"] 330 | _no_split_modules = ["WanTransformerBlock"] 331 | _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"] 332 | _keys_to_ignore_on_load_unexpected = ["norm_added_q"] 333 | 334 | @register_to_config 335 | def __init__( 336 | self, 337 | patch_size: Tuple[int] = (1, 2, 2), 338 | num_attention_heads: int = 40, 339 | attention_head_dim: int = 128, 340 | in_channels: int = 16, 341 | out_channels: int = 16, 342 | text_dim: int = 4096, 343 | freq_dim: int = 256, 344 | ffn_dim: int = 13824, 345 | num_layers: int = 40, 346 | cross_attn_norm: bool = True, 347 | qk_norm: Optional[str] = "rms_norm_across_heads", 348 | eps: float = 1e-6, 349 | image_dim: Optional[int] = None, 350 | added_kv_proj_dim: Optional[int] = None, 351 | rope_max_seq_len: int = 1024, 352 | ) -> None: 353 | super().__init__() 354 | 355 | inner_dim = num_attention_heads * attention_head_dim 356 | out_channels = out_channels or in_channels 357 | 358 | # 1. Patch & position embedding 359 | self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len) 360 | self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size) 361 | 362 | # 2. 
Condition embeddings 363 | # image_embedding_dim=1280 for I2V model 364 | self.condition_embedder = WanTimeTextImageEmbedding( 365 | dim=inner_dim, 366 | time_freq_dim=freq_dim, 367 | time_proj_dim=inner_dim * 6, 368 | text_embed_dim=text_dim, 369 | image_embed_dim=image_dim, 370 | ) 371 | 372 | # 3. Transformer blocks 373 | self.blocks = nn.ModuleList( 374 | [ 375 | WanTransformerBlock( 376 | inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim 377 | ) 378 | for _ in range(num_layers) 379 | ] 380 | ) 381 | 382 | # 4. Output norm & projection 383 | self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False) 384 | self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size)) 385 | self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5) 386 | 387 | self.gradient_checkpointing = False 388 | 389 | def forward( 390 | self, 391 | hidden_states: torch.Tensor, 392 | timestep: torch.LongTensor, 393 | encoder_hidden_states: torch.Tensor, 394 | encoder_hidden_states_image: Optional[torch.Tensor] = None, 395 | return_dict: bool = True, 396 | attention_kwargs: Optional[Dict[str, Any]] = None, 397 | ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: 398 | if attention_kwargs is not None: 399 | attention_kwargs = attention_kwargs.copy() 400 | lora_scale = attention_kwargs.pop("scale", 1.0) 401 | else: 402 | lora_scale = 1.0 403 | 404 | if USE_PEFT_BACKEND: 405 | # weight the lora layers by setting `lora_scale` for each PEFT layer 406 | scale_lora_layers(self, lora_scale) 407 | else: 408 | if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: 409 | logger.warning( 410 | "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." 411 | ) 412 | 413 | batch_size, num_channels, num_frames, height, width = hidden_states.shape 414 | p_t, p_h, p_w = self.config.patch_size 415 | post_patch_num_frames = num_frames // p_t 416 | post_patch_height = height // p_h 417 | post_patch_width = width // p_w 418 | 419 | rotary_emb = self.rope(hidden_states) 420 | 421 | hidden_states = self.patch_embedding(hidden_states) 422 | hidden_states = hidden_states.flatten(2).transpose(1, 2) 423 | 424 | temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( 425 | timestep, encoder_hidden_states, encoder_hidden_states_image 426 | ) 427 | timestep_proj = timestep_proj.unflatten(1, (6, -1)) 428 | 429 | if encoder_hidden_states_image is not None: 430 | encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1) 431 | 432 | # 4. Transformer blocks 433 | if torch.is_grad_enabled() and self.gradient_checkpointing: 434 | for block in self.blocks: 435 | hidden_states = self._gradient_checkpointing_func( 436 | block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb 437 | ) 438 | else: 439 | for block in self.blocks: 440 | hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) 441 | 442 | # 5. Output norm, projection & unpatchify 443 | shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1) 444 | 445 | # Move the shift and scale tensors to the same device as hidden_states. 446 | # When using multi-GPU inference via accelerate these will be on the 447 | # first device rather than the last device, which hidden_states ends up 448 | # on. 
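        # The remaining steps apply the final scale/shift-modulated LayerNorm, project each
        # token back to out_channels * prod(patch_size) values, and unpatchify the
        # (batch, seq_len, dim) sequence into a (batch, out_channels, F, H, W) latent video.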
449 | shift = shift.to(hidden_states.device) 450 | scale = scale.to(hidden_states.device) 451 | 452 | hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states) 453 | hidden_states = self.proj_out(hidden_states) 454 | 455 | hidden_states = hidden_states.reshape( 456 | batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1 457 | ) 458 | hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) 459 | output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) 460 | 461 | if USE_PEFT_BACKEND: 462 | # remove `lora_scale` from each PEFT layer 463 | unscale_lora_layers(self, lora_scale) 464 | 465 | if not return_dict: 466 | return (output,) 467 | 468 | return Transformer2DModelOutput(sample=output) 469 | -------------------------------------------------------------------------------- /src/models/transformer_v.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import math 16 | from typing import Any, Dict, Optional, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | 22 | from diffusers.configuration_utils import ConfigMixin, register_to_config 23 | from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin 24 | from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers 25 | from diffusers.models.attention import FeedForward 26 | from diffusers.models.attention_processor import Attention 27 | from diffusers.models.embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed 28 | from diffusers.models.modeling_outputs import Transformer2DModelOutput 29 | from diffusers.models.modeling_utils import ModelMixin 30 | from diffusers.models.normalization import FP32LayerNorm 31 | from einops import rearrange 32 | 33 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 34 | 35 | 36 | class WanAttnProcessor2_0: 37 | def __init__(self): 38 | if not hasattr(F, "scaled_dot_product_attention"): 39 | raise ImportError("WanAttnProcessor2_0 requires PyTorch 2.0. 
To use it, please upgrade PyTorch to 2.0.") 40 | 41 | def __call__( 42 | self, 43 | attn: Attention, 44 | hidden_states: torch.Tensor, 45 | encoder_hidden_states: Optional[torch.Tensor] = None, 46 | attention_mask: Optional[torch.Tensor] = None, 47 | rotary_emb: Optional[torch.Tensor] = None, 48 | ) -> torch.Tensor: 49 | encoder_hidden_states_img = None 50 | if attn.add_k_proj is not None: 51 | encoder_hidden_states_img = encoder_hidden_states[:, :257] 52 | encoder_hidden_states = encoder_hidden_states[:, 257:] 53 | if encoder_hidden_states is None: 54 | encoder_hidden_states = hidden_states 55 | 56 | query = attn.to_q(hidden_states) 57 | key = attn.to_k(encoder_hidden_states) 58 | value = attn.to_v(encoder_hidden_states) 59 | 60 | if attn.norm_q is not None: 61 | query = attn.norm_q(query) 62 | if attn.norm_k is not None: 63 | key = attn.norm_k(key) 64 | 65 | query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2) 66 | key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2) 67 | value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2) 68 | 69 | if rotary_emb is not None: 70 | 71 | def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor): 72 | x_rotated = torch.view_as_complex(hidden_states.to(torch.float64).unflatten(3, (-1, 2))) 73 | x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4) 74 | return x_out.type_as(hidden_states) 75 | 76 | query = apply_rotary_emb(query, rotary_emb) 77 | key = apply_rotary_emb(key, rotary_emb) 78 | 79 | # I2V task 80 | hidden_states_img = None 81 | if encoder_hidden_states_img is not None: 82 | key_img = attn.add_k_proj(encoder_hidden_states_img) 83 | key_img = attn.norm_added_k(key_img) 84 | value_img = attn.add_v_proj(encoder_hidden_states_img) 85 | 86 | key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2) 87 | value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2) 88 | 89 | hidden_states_img = F.scaled_dot_product_attention( 90 | query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False 91 | ) 92 | hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3) 93 | hidden_states_img = hidden_states_img.type_as(query) 94 | 95 | hidden_states = F.scaled_dot_product_attention( 96 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False 97 | ) 98 | hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) 99 | hidden_states = hidden_states.type_as(query) 100 | 101 | if hidden_states_img is not None: 102 | hidden_states = hidden_states + hidden_states_img 103 | 104 | hidden_states = attn.to_out[0](hidden_states) 105 | hidden_states = attn.to_out[1](hidden_states) 106 | return hidden_states 107 | 108 | 109 | class WanImageEmbedding(torch.nn.Module): 110 | def __init__(self, in_features: int, out_features: int): 111 | super().__init__() 112 | 113 | self.norm1 = FP32LayerNorm(in_features) 114 | self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu") 115 | self.norm2 = FP32LayerNorm(out_features) 116 | 117 | def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor: 118 | hidden_states = self.norm1(encoder_hidden_states_image) 119 | hidden_states = self.ff(hidden_states) 120 | hidden_states = self.norm2(hidden_states) 121 | return hidden_states 122 | 123 | 124 | class WanTimeTextImageEmbedding(nn.Module): 125 | def __init__( 126 | self, 127 | dim: int, 128 | time_freq_dim: int, 129 | time_proj_dim: int, 130 | text_embed_dim: int, 131 | image_embed_dim: Optional[int] = None, 132 | ): 133 | super().__init__() 134 | 135 | 
self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0) 136 | self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim) 137 | self.act_fn = nn.SiLU() 138 | self.time_proj = nn.Linear(dim, time_proj_dim) 139 | self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh") 140 | 141 | self.image_embedder = None 142 | if image_embed_dim is not None: 143 | self.image_embedder = WanImageEmbedding(image_embed_dim, dim) 144 | 145 | def forward( 146 | self, 147 | timestep: torch.Tensor, 148 | encoder_hidden_states: torch.Tensor, 149 | encoder_hidden_states_image: Optional[torch.Tensor] = None, 150 | ): 151 | timestep = self.timesteps_proj(timestep) 152 | 153 | time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype 154 | if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8: 155 | timestep = timestep.to(time_embedder_dtype) 156 | temb = self.time_embedder(timestep).type_as(encoder_hidden_states) 157 | timestep_proj = self.time_proj(self.act_fn(temb)) 158 | 159 | encoder_hidden_states = self.text_embedder(encoder_hidden_states) 160 | if encoder_hidden_states_image is not None: 161 | encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image) 162 | 163 | return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image 164 | 165 | 166 | class WanRotaryPosEmbed(nn.Module): 167 | def __init__( 168 | self, attention_head_dim: int, patch_size: Tuple[int, int, int], max_seq_len: int, theta: float = 10000.0 169 | ): 170 | super().__init__() 171 | 172 | self.attention_head_dim = attention_head_dim 173 | self.patch_size = patch_size 174 | self.max_seq_len = max_seq_len 175 | 176 | h_dim = w_dim = 2 * (attention_head_dim // 6) 177 | t_dim = attention_head_dim - h_dim - w_dim 178 | 179 | freqs = [] 180 | for dim in [t_dim, h_dim, w_dim]: 181 | freq = get_1d_rotary_pos_embed( 182 | dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=torch.float64 183 | ) 184 | freqs.append(freq) 185 | self.freqs = torch.cat(freqs, dim=1) 186 | 187 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: 188 | batch_size, num_channels, num_frames, height, width = hidden_states.shape 189 | p_t, p_h, p_w = self.patch_size 190 | ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w 191 | 192 | self.freqs = self.freqs.to(hidden_states.device) 193 | freqs = self.freqs.split_with_sizes( 194 | [ 195 | self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6), 196 | self.attention_head_dim // 6, 197 | self.attention_head_dim // 6, 198 | ], 199 | dim=1, 200 | ) 201 | 202 | freqs_f = freqs[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1) 203 | freqs_h = freqs[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1) 204 | freqs_w = freqs[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1) 205 | freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1) 206 | return freqs 207 | 208 | 209 | class WanTransformerBlock(nn.Module): 210 | def __init__( 211 | self, 212 | dim: int, 213 | ffn_dim: int, 214 | num_heads: int, 215 | qk_norm: str = "rms_norm_across_heads", 216 | cross_attn_norm: bool = False, 217 | eps: float = 1e-6, 218 | added_kv_proj_dim: Optional[int] = None, 219 | ): 220 | super().__init__() 221 | 222 | # 1. 
Self-attention 223 | self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False) 224 | self.attn1 = Attention( 225 | query_dim=dim, 226 | heads=num_heads, 227 | kv_heads=num_heads, 228 | dim_head=dim // num_heads, 229 | qk_norm=qk_norm, 230 | eps=eps, 231 | bias=True, 232 | cross_attention_dim=None, 233 | out_bias=True, 234 | processor=WanAttnProcessor2_0(), 235 | ) 236 | 237 | # 2. Cross-attention 238 | self.attn2 = Attention( 239 | query_dim=dim, 240 | heads=num_heads, 241 | kv_heads=num_heads, 242 | dim_head=dim // num_heads, 243 | qk_norm=qk_norm, 244 | eps=eps, 245 | bias=True, 246 | cross_attention_dim=None, 247 | out_bias=True, 248 | added_kv_proj_dim=added_kv_proj_dim, 249 | added_proj_bias=True, 250 | processor=WanAttnProcessor2_0(), 251 | ) 252 | self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity() 253 | 254 | # 3. Feed-forward 255 | self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate") 256 | self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False) 257 | 258 | self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) 259 | 260 | def forward( 261 | self, 262 | hidden_states: torch.Tensor, 263 | encoder_hidden_states: torch.Tensor, 264 | temb: torch.Tensor, 265 | rotary_emb: torch.Tensor, 266 | ) -> torch.Tensor: 267 | shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( 268 | self.scale_shift_table + temb.float() 269 | ).chunk(6, dim=1) 270 | 271 | # 1. Self-attention 272 | norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states) 273 | attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb) 274 | hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states) 275 | 276 | # 2. Cross-attention 277 | norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states) 278 | attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states) 279 | hidden_states = hidden_states + attn_output 280 | 281 | # 3. Feed-forward 282 | norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as( 283 | hidden_states 284 | ) 285 | ff_output = self.ffn(norm_hidden_states) 286 | hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states) 287 | 288 | return hidden_states 289 | 290 | 291 | class ViewTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): 292 | r""" 293 | A Transformer model for video-like data used in the Wan model. 294 | 295 | Args: 296 | patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`): 297 | 3D patch dimensions for video embedding (t_patch, h_patch, w_patch). 298 | num_attention_heads (`int`, defaults to `40`): 299 | Fixed length for text embeddings. 300 | attention_head_dim (`int`, defaults to `128`): 301 | The number of channels in each head. 302 | in_channels (`int`, defaults to `16`): 303 | The number of channels in the input. 304 | out_channels (`int`, defaults to `16`): 305 | The number of channels in the output. 306 | text_dim (`int`, defaults to `512`): 307 | Input dimension for text embeddings. 308 | freq_dim (`int`, defaults to `256`): 309 | Dimension for sinusoidal time embeddings. 310 | ffn_dim (`int`, defaults to `13824`): 311 | Intermediate dimension in feed-forward network. 312 | num_layers (`int`, defaults to `40`): 313 | The number of layers of transformer blocks to use. 
314 | window_size (`Tuple[int]`, defaults to `(-1, -1)`): 315 | Window size for local attention (-1 indicates global attention). 316 | cross_attn_norm (`bool`, defaults to `True`): 317 | Enable cross-attention normalization. 318 | qk_norm (`bool`, defaults to `True`): 319 | Enable query/key normalization. 320 | eps (`float`, defaults to `1e-6`): 321 | Epsilon value for normalization layers. 322 | add_img_emb (`bool`, defaults to `False`): 323 | Whether to use img_emb. 324 | added_kv_proj_dim (`int`, *optional*, defaults to `None`): 325 | The number of channels to use for the added key and value projections. If `None`, no projection is used. 326 | """ 327 | 328 | _supports_gradient_checkpointing = True 329 | _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"] 330 | _no_split_modules = ["WanTransformerBlock"] 331 | _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"] 332 | _keys_to_ignore_on_load_unexpected = ["norm_added_q"] 333 | 334 | @register_to_config 335 | def __init__( 336 | self, 337 | patch_size: Tuple[int] = (1, 2, 2), 338 | num_attention_heads: int = 40, 339 | attention_head_dim: int = 128, 340 | in_channels: int = 16, 341 | out_channels: int = 16, 342 | text_dim: int = 4096, 343 | freq_dim: int = 256, 344 | ffn_dim: int = 13824, 345 | num_layers: int = 40, 346 | cross_attn_norm: bool = True, 347 | qk_norm: Optional[str] = "rms_norm_across_heads", 348 | eps: float = 1e-6, 349 | image_dim: Optional[int] = None, 350 | added_kv_proj_dim: Optional[int] = None, 351 | rope_max_seq_len: int = 1024, 352 | ) -> None: 353 | super().__init__() 354 | 355 | inner_dim = num_attention_heads * attention_head_dim 356 | out_channels = out_channels or in_channels 357 | self.inner_dim = inner_dim 358 | # 1. Patch & position embedding 359 | self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len) 360 | self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size) 361 | 362 | # 2. Condition embeddings 363 | # image_embedding_dim=1280 for I2V model 364 | self.condition_embedder = WanTimeTextImageEmbedding( 365 | dim=inner_dim, 366 | time_freq_dim=freq_dim, 367 | time_proj_dim=inner_dim * 6, 368 | text_embed_dim=text_dim, 369 | image_embed_dim=image_dim, 370 | ) 371 | 372 | # 3. Transformer blocks 373 | self.blocks = nn.ModuleList( 374 | [ 375 | WanTransformerBlock( 376 | inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim 377 | ) 378 | for _ in range(num_layers) 379 | ] 380 | ) 381 | 382 | # 4. 
Output norm & projection 383 | self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False) 384 | self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size)) 385 | self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5) 386 | 387 | self.gradient_checkpointing = False 388 | 389 | def forward( 390 | self, 391 | hidden_states: torch.Tensor, 392 | timestep: torch.LongTensor, 393 | encoder_hidden_states: torch.Tensor, 394 | encoder_hidden_states_image: Optional[torch.Tensor] = None, 395 | return_dict: bool = True, 396 | attention_kwargs: Optional[Dict[str, Any]] = None, 397 | ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: 398 | if attention_kwargs is not None: 399 | attention_kwargs = attention_kwargs.copy() 400 | lora_scale = attention_kwargs.pop("scale", 1.0) 401 | else: 402 | lora_scale = 1.0 403 | 404 | if USE_PEFT_BACKEND: 405 | # weight the lora layers by setting `lora_scale` for each PEFT layer 406 | scale_lora_layers(self, lora_scale) 407 | else: 408 | if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: 409 | logger.warning( 410 | "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." 411 | ) 412 | 413 | batch_size, num_channels, num_frames, height, width = hidden_states.shape 414 | p_t, p_h, p_w = self.config.patch_size 415 | post_patch_num_frames = num_frames // p_t 416 | post_patch_height = height // p_h 417 | post_patch_width = width // p_w 418 | 419 | rotary_emb_global = self.rope(hidden_states) 420 | 421 | L_F, B_R = torch.chunk(hidden_states,2,dim=-2) 422 | L, F = torch.chunk(L_F,2,dim=-1) 423 | B, R = torch.chunk(B_R,2,dim=-1) 424 | split_hidden_states = torch.cat([L,F,B,R],dim=0) 425 | rotary_emb_cube = self.rope(split_hidden_states) 426 | 427 | 428 | hidden_states = self.patch_embedding(hidden_states) 429 | hidden_states = hidden_states.flatten(2).transpose(1, 2) 430 | 431 | temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( 432 | timestep, encoder_hidden_states, encoder_hidden_states_image 433 | ) 434 | timestep_proj = timestep_proj.unflatten(1, (6, -1)) 435 | 436 | if encoder_hidden_states_image is not None: 437 | encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1) 438 | 439 | # 4. 
Transformer blocks 440 | if torch.is_grad_enabled() and self.gradient_checkpointing: 441 | for block_idx in range(len(self.blocks)//2): 442 | global_block = self.blocks[2*block_idx] 443 | cube_block = self.blocks[2*block_idx+1] 444 | 445 | hidden_states = self._gradient_checkpointing_func( 446 | global_block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb_global 447 | ) 448 | hidden_states = hidden_states.transpose(1, 2).unflatten(2,(post_patch_num_frames,post_patch_height,post_patch_width)) 449 | hidden_states = hidden_states.view((batch_size, -1, post_patch_num_frames, 2, post_patch_height//2, 2, post_patch_width//2)) 450 | hidden_states = hidden_states.permute(0, 3, 5, 1, 2, 4, 6).reshape(4*batch_size, -1, post_patch_num_frames, post_patch_height//2, post_patch_width//2) 451 | hidden_states = hidden_states.flatten(2).transpose(1, 2) 452 | 453 | hidden_states = self._gradient_checkpointing_func( 454 | cube_block, hidden_states, encoder_hidden_states.repeat(4*batch_size,1,1), timestep_proj.repeat(4*batch_size,1,1), rotary_emb_cube 455 | ) 456 | 457 | hidden_states = hidden_states.transpose(1, 2).unflatten(2,(post_patch_num_frames,post_patch_height//2,post_patch_width//2)) 458 | hidden_states = hidden_states.view((batch_size, 2, 2, -1, post_patch_num_frames, post_patch_height//2, post_patch_width//2)) 459 | hidden_states = hidden_states.permute(0, 3, 4, 1, 5, 2, 6).reshape(batch_size, -1, post_patch_num_frames, post_patch_height, post_patch_width) 460 | hidden_states = hidden_states.flatten(2).transpose(1, 2) 461 | else: 462 | for block_idx in range(len(self.blocks)//2): 463 | global_block = self.blocks[2*block_idx] 464 | cube_block = self.blocks[2*block_idx+1] 465 | 466 | hidden_states = global_block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb_global) 467 | 468 | hidden_states = hidden_states.transpose(1, 2).unflatten(2,(post_patch_num_frames,post_patch_height,post_patch_width)) 469 | hidden_states = hidden_states.view((batch_size, -1, post_patch_num_frames, 2, post_patch_height//2, 2, post_patch_width//2)) 470 | hidden_states = hidden_states.permute(0, 3, 5, 1, 2, 4, 6).reshape(4*batch_size, -1, post_patch_num_frames, post_patch_height//2, post_patch_width//2) 471 | hidden_states = hidden_states.flatten(2).transpose(1, 2) 472 | 473 | hidden_states = cube_block(hidden_states, encoder_hidden_states.repeat(4*batch_size,1,1), timestep_proj.repeat(4*batch_size,1,1), rotary_emb_cube) 474 | 475 | hidden_states = hidden_states.transpose(1, 2).unflatten(2,(post_patch_num_frames,post_patch_height//2,post_patch_width//2)) 476 | hidden_states = hidden_states.view((batch_size, 2, 2, -1, post_patch_num_frames, post_patch_height//2, post_patch_width//2)) 477 | hidden_states = hidden_states.permute(0, 3, 4, 1, 5, 2, 6).reshape(batch_size, -1, post_patch_num_frames, post_patch_height, post_patch_width) 478 | hidden_states = hidden_states.flatten(2).transpose(1, 2) 479 | 480 | # 5. Output norm, projection & unpatchify 481 | shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1) 482 | 483 | # Move the shift and scale tensors to the same device as hidden_states. 484 | # When using multi-GPU inference via accelerate these will be on the 485 | # first device rather than the last device, which hidden_states ends up 486 | # on. 
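        # hidden_states is back in the full (un-split) token layout here, so the output
        # path below matches WanTransformer3DModel: modulated LayerNorm, projection, and
        # unpatchify to (batch, out_channels, F, H, W).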
487 | shift = shift.to(hidden_states.device) 488 | scale = scale.to(hidden_states.device) 489 | 490 | hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states) 491 | hidden_states = self.proj_out(hidden_states) 492 | 493 | hidden_states = hidden_states.reshape( 494 | batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1 495 | ) 496 | hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) 497 | output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) 498 | 499 | if USE_PEFT_BACKEND: 500 | # remove `lora_scale` from each PEFT layer 501 | unscale_lora_layers(self, lora_scale) 502 | 503 | if not return_dict: 504 | return (output,) 505 | 506 | return Transformer2DModelOutput(sample=output) 507 | -------------------------------------------------------------------------------- /utils/infer_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import cv2 3 | from torch.nn import functional as F 4 | import numpy as np 5 | import math 6 | from . import cube2equi, get_transform_matrix_1 7 | from PIL import Image 8 | import torchvision.transforms as T 9 | 10 | 11 | def convert_to_cubemap(L_F_D_B_R): 12 | L_F_D_B_R = torch.tensor(L_F_D_B_R).permute(0,3,1,2) 13 | L_F, B_R = torch.chunk(L_F_D_B_R, 2, dim=-2) 14 | L_, F_ = torch.chunk(L_F, 2, dim=-1) 15 | B_, R_ = torch.chunk(B_R, 2, dim=-1) 16 | 17 | scale = np.sqrt(2)/2 18 | L_affine = F.grid_sample(L_, get_transform_matrix_1('L',L_.shape,scale,reverse=True), mode='bilinear', align_corners=False) 19 | F_affine = F.grid_sample(F_, get_transform_matrix_1('F',F_.shape,scale,reverse=True), mode='bilinear', align_corners=False) 20 | B_affine = F.grid_sample(B_, get_transform_matrix_1('B',B_.shape,scale,reverse=True), mode='bilinear', align_corners=False) 21 | R_affine = F.grid_sample(R_, get_transform_matrix_1('R',R_.shape,scale,reverse=True), mode='bilinear', align_corners=False) 22 | 23 | scale = np.sqrt(2)/4 24 | D_affine = F.grid_sample(L_F_D_B_R, get_transform_matrix_1('D',L_F_D_B_R.shape,scale,reverse=True), mode='bilinear', align_corners=False) 25 | 26 | L_ = torch.rot90(torch.flip(L_,dims=[-1]), k=-1, dims=(-2, -1)) 27 | F_ = torch.rot90(torch.flip(F_,dims=[-1]), k=1, dims=(-2, -1)) 28 | B_ = torch.rot90(torch.flip(B_,dims=[-1]), k=1, dims=(-2, -1)) 29 | R_ = torch.rot90(torch.flip(R_,dims=[-1]), k=-1, dims=(-2, -1)) 30 | L_F_U_B_R = torch.cat([torch.cat([L_,F_],dim=-1),torch.cat([B_,R_],dim=-1)],dim=-2) 31 | L_F_U_B_R = torch.flip(L_F_U_B_R,dims=[-1]) 32 | U_affine = F.grid_sample(L_F_U_B_R, get_transform_matrix_1('U',L_F_U_B_R.shape,scale,reverse=True), mode='bilinear', align_corners=False) 33 | return torch.cat([F_affine,R_affine,B_affine,L_affine,U_affine,D_affine],dim=-1) 34 | 35 | def post_process_U(U_tensors): 36 | 37 | U_tensors_F_B = torch.cat(torch.chunk(U_tensors,2,dim=-2)[: :-1],dim=-2) 38 | U_tensors_L_R = torch.cat(torch.chunk(U_tensors,2,dim=-1)[: :-1],dim=-1) 39 | 40 | b,c,H,W = U_tensors_F_B.shape 41 | y, x = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij') 42 | x = x.float() 43 | y = y.float() 44 | 45 | center = torch.tensor([H/2, H/2], device=x.device) 46 | 47 | # 计算相对坐标 48 | dx = x - center[0] 49 | dy = center[1] - y # 反转y轴 50 | 51 | # 极坐标参数 52 | theta = torch.atan2(dy, dx) 53 | r = torch.sqrt(dx**2 + dy**2) 54 | 55 | # 圆形有效区域判断 56 | valid_mask = r <= (H/2) 57 | 58 | # 计算菱形映射参数 59 | cos_theta = torch.cos(theta) 60 | sin_theta = torch.sin(theta) 61 | denominator = 
torch.abs(cos_theta) + torch.abs(sin_theta) 62 | r_max = (H/2) / (denominator + 1e-6) 63 | 64 | # 计算原图坐标 65 | r_prime = torch.where(valid_mask, r * (r_max / (H/2)), 0.0) 66 | px = center[0] + r_prime * cos_theta 67 | py = center[1] - r_prime * sin_theta 68 | 69 | # 归一化处理 70 | grid_x = (px / (W-1)) * 2 - 1 71 | grid_y = (py / (H-1)) * 2 - 1 72 | 73 | # 无效点处理 74 | grid_x = torch.where(valid_mask, grid_x, -2.0) 75 | grid_y = torch.where(valid_mask, grid_y, -2.0) 76 | 77 | grid = torch.stack((grid_x, grid_y), dim=-1).unsqueeze(0).repeat(b,1,1,1).to(device=U_tensors_F_B.device) 78 | affine_U_tensors_F_B = F.grid_sample(U_tensors_F_B, grid, mode='bilinear', align_corners=False) 79 | affine_U_tensors_L_R = F.grid_sample(U_tensors_L_R, grid, mode='bilinear', align_corners=False) 80 | 81 | 82 | U_FB = torch.cat(torch.chunk(affine_U_tensors_F_B,2,dim=-2)[: :-1],dim=-2) 83 | U_LR = torch.cat(torch.chunk(affine_U_tensors_L_R,2,dim=-1)[: :-1],dim=-1) 84 | b_,c_,h_,w_ = U_FB.shape 85 | FB_mask = generate_diamond_mask(h_) 86 | FB_mask = np.concatenate((FB_mask[h_//2:,:], FB_mask[0:h_//2,:]), axis=0) 87 | 88 | U_affine = (torch.from_numpy(FB_mask))*U_FB + (torch.from_numpy(1-FB_mask))*U_LR 89 | # kernel_size = (5, 5) 90 | # sigma = (2.0, 2.0) 91 | # transform = T.GaussianBlur(kernel_size=kernel_size, sigma=sigma) 92 | # U_affine = torch.stack([transform(image) for image in U_affine]) 93 | mask_np = np.zeros((h_, w_)).astype(np.uint8) 94 | inpaint_size = 8 95 | h_off = (h_-inpaint_size)//2 96 | w_off = (w_-inpaint_size)//2 97 | mask_np[0:inpaint_size,0:inpaint_size] = 1 98 | mask_np[(h_-inpaint_size):,0:inpaint_size] = 1 99 | mask_np[0:inpaint_size,(w_-inpaint_size):] = 1 100 | mask_np[(h_-inpaint_size):,(w_-inpaint_size):] = 1 101 | mask_np[h_off:h_off+inpaint_size,w_off:w_off+inpaint_size] = 1 102 | repaired_images = [] 103 | for i in range(U_affine.shape[0]): 104 | # 将张量转换为 NumPy 数组并将通道移至最后 105 | image_np = (U_affine[i].permute(1, 2, 0).cpu().numpy()*255).astype(np.uint8) 106 | 107 | # 使用 OpenCV 的 inpaint 进行修复 108 | repaired_np = cv2.inpaint(image_np, mask_np, inpaintRadius=3, flags=cv2.INPAINT_TELEA) 109 | 110 | # 将修复后的结果转换回 PyTorch 张量并将通道移至最前 111 | repaired_tensor = torch.from_numpy(repaired_np).permute(2, 0, 1) 112 | 113 | # 添加到结果列表中 114 | repaired_images.append(repaired_tensor) 115 | 116 | U_affine = torch.stack(repaired_images)/255.0 117 | return U_affine 118 | 119 | def post_process_U2(U_tensors): 120 | 121 | U_tensors_F_B = torch.cat(torch.chunk(U_tensors,2,dim=-2)[: :-1],dim=-2) 122 | U_tensors_L_R = torch.cat(torch.chunk(U_tensors,2,dim=-1)[: :-1],dim=-1) 123 | 124 | b,c,H,W = U_tensors_F_B.shape 125 | y, x = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij') 126 | x = x.float() 127 | y = y.float() 128 | 129 | center = torch.tensor([H/2, H/2], device=x.device) 130 | 131 | # 计算相对坐标 132 | dx = x - center[0] 133 | dy = center[1] - y # 反转y轴 134 | 135 | # 极坐标参数 136 | theta = torch.atan2(dy, dx) 137 | r = torch.sqrt(dx**2 + dy**2) 138 | 139 | # 圆形有效区域判断 140 | valid_mask = r <= (H/2) 141 | 142 | # 计算菱形映射参数 143 | cos_theta = torch.cos(theta) 144 | sin_theta = torch.sin(theta) 145 | denominator = torch.abs(cos_theta) + torch.abs(sin_theta) 146 | r_max = (H/2) / (denominator + 1e-6) 147 | 148 | # 计算原图坐标 149 | r_prime = torch.where(valid_mask, r * (r_max / (H/2)), 0.0) 150 | px = center[0] + r_prime * cos_theta 151 | py = center[1] - r_prime * sin_theta 152 | 153 | # 归一化处理 154 | grid_x = (px / (W-1)) * 2 - 1 155 | grid_y = (py / (H-1)) * 2 - 1 156 | 157 | # 无效点处理 158 | grid_x = 
torch.where(valid_mask, grid_x, -2.0) 159 | grid_y = torch.where(valid_mask, grid_y, -2.0) 160 | 161 | grid = torch.stack((grid_x, grid_y), dim=-1).unsqueeze(0).repeat(b,1,1,1).to(device=U_tensors_F_B.device) 162 | affine_U_tensors_F_B = F.grid_sample(U_tensors_F_B, grid, mode='bilinear', align_corners=False) 163 | affine_U_tensors_L_R = F.grid_sample(U_tensors_L_R, grid, mode='bilinear', align_corners=False) 164 | 165 | 166 | U_FB = torch.cat(torch.chunk(affine_U_tensors_F_B,2,dim=-2)[: :-1],dim=-2) 167 | U_LR = torch.cat(torch.chunk(affine_U_tensors_L_R,2,dim=-1)[: :-1],dim=-1) 168 | 169 | b_,c_,h_,w_ = U_FB.shape 170 | c_mask = create_circular_mask(h_,w_) 171 | 172 | FB_mask = np.concatenate((c_mask[h_//2:,:], c_mask[0:h_//2,:]), axis=0) 173 | LR_mask = np.concatenate((c_mask[:,w_//2:], c_mask[:,0:w_//2]), axis=1) 174 | fusion_mask = (FB_mask & LR_mask) 175 | 176 | fusion_part = torch.from_numpy(fusion_mask)*(U_FB+U_LR)/2 177 | U_affine = fusion_part+(torch.from_numpy(1-LR_mask))*U_FB + (torch.from_numpy(1-FB_mask))*U_LR 178 | # kernel_size = (5, 5) 179 | # sigma = (2.0, 2.0) 180 | # transform = T.GaussianBlur(kernel_size=kernel_size, sigma=sigma) 181 | # U_affine = torch.stack([transform(image) for image in U_affine]) 182 | mask_np = create_circle_line_mask(h_,8) 183 | mask_np1 = np.concatenate((mask_np[h_//2:,:], mask_np[0:h_//2,:]), axis=0) 184 | mask_np2 = np.concatenate((mask_np[:,w_//2:], mask_np[:,0:w_//2]), axis=1) 185 | mask_np = mask_np1|mask_np2 186 | 187 | # mask_np = np.zeros((h_, w_)).astype(np.uint8) 188 | # inpaint_size = 8 189 | # h_off = (h_-inpaint_size)//2 190 | # w_off = (w_-inpaint_size)//2 191 | # mask_np[0:inpaint_size,0:inpaint_size] = 1 192 | # mask_np[(h_-inpaint_size):,0:inpaint_size] = 1 193 | # mask_np[0:inpaint_size,(w_-inpaint_size):] = 1 194 | # mask_np[(h_-inpaint_size):,(w_-inpaint_size):] = 1 195 | # mask_np[h_off:h_off+inpaint_size,w_off:w_off+inpaint_size] = 1 196 | 197 | repaired_images = [] 198 | for i in range(U_affine.shape[0]): 199 | # 将张量转换为 NumPy 数组并将通道移至最后 200 | image_np = (U_affine[i].permute(1, 2, 0).cpu().numpy()*255).astype(np.uint8) 201 | 202 | 203 | # 使用 OpenCV 的 inpaint 进行修复 204 | repaired_np = cv2.inpaint(image_np, mask_np, inpaintRadius=3, flags=cv2.INPAINT_TELEA) 205 | 206 | # 将修复后的结果转换回 PyTorch 张量并将通道移至最前 207 | repaired_tensor = torch.from_numpy(repaired_np).permute(2, 0, 1) 208 | 209 | # 添加到结果列表中 210 | repaired_images.append(repaired_tensor) 211 | 212 | U_affine = torch.stack(repaired_images)/255.0 213 | return U_affine 214 | 215 | def post_process_U3(U_tensors): 216 | 217 | U_tensors_F_B = torch.cat(torch.chunk(U_tensors,2,dim=-2)[: :-1],dim=-2) 218 | U_tensors_L_R = torch.cat(torch.chunk(U_tensors,2,dim=-1)[: :-1],dim=-1) 219 | 220 | b,c,H,W = U_tensors_F_B.shape 221 | y, x = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij') 222 | x = x.float() 223 | y = y.float() 224 | 225 | center = torch.tensor([H/2, H/2], device=x.device) 226 | 227 | # 计算相对坐标 228 | dx = x - center[0] 229 | dy = center[1] - y # 反转y轴 230 | 231 | # 极坐标参数 232 | theta = torch.atan2(dy, dx) 233 | r = torch.sqrt(dx**2 + dy**2) 234 | 235 | # 圆形有效区域判断 236 | valid_mask = r <= (H/2) 237 | 238 | # 计算菱形映射参数 239 | cos_theta = torch.cos(theta) 240 | sin_theta = torch.sin(theta) 241 | denominator = torch.abs(cos_theta) + torch.abs(sin_theta) 242 | r_max = (H/2) / (denominator + 1e-6) 243 | 244 | # 计算原图坐标 245 | r_prime = torch.where(valid_mask, r * (r_max / (H/2)), 0.0) 246 | px = center[0] + r_prime * cos_theta 247 | py = center[1] - r_prime * sin_theta 248 | 
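    # Normalize the source pixel coordinates (px, py) to the [-1, 1] range expected by
    # F.grid_sample; points outside the valid circle are mapped to -2 so they land outside
    # the sampling grid and read back as zeros (grid_sample's default zero padding).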
249 | # 归一化处理 250 | grid_x = (px / (W-1)) * 2 - 1 251 | grid_y = (py / (H-1)) * 2 - 1 252 | 253 | # 无效点处理 254 | grid_x = torch.where(valid_mask, grid_x, -2.0) 255 | grid_y = torch.where(valid_mask, grid_y, -2.0) 256 | 257 | grid = torch.stack((grid_x, grid_y), dim=-1).unsqueeze(0).repeat(b,1,1,1).to(device=U_tensors_F_B.device) 258 | affine_U_tensors_F_B = F.grid_sample(U_tensors_F_B, grid, mode='bilinear', align_corners=False) 259 | affine_U_tensors_L_R = F.grid_sample(U_tensors_L_R, grid, mode='bilinear', align_corners=False) 260 | 261 | 262 | U_FB = torch.cat(torch.chunk(affine_U_tensors_F_B,2,dim=-2)[: :-1],dim=-2) 263 | U_LR = torch.cat(torch.chunk(affine_U_tensors_L_R,2,dim=-1)[: :-1],dim=-1) 264 | 265 | b_,c_,h_,w_ = U_FB.shape 266 | c_mask = create_circular_mask(h_,w_) 267 | 268 | FB_mask = np.concatenate((c_mask[h_//2:,:], c_mask[0:h_//2,:]), axis=0) 269 | LR_mask = np.concatenate((c_mask[:,w_//2:], c_mask[:,0:w_//2]), axis=1) 270 | fusion_mask = (FB_mask & LR_mask) 271 | 272 | inner_radius = (np.sqrt(2)-1)*(h_//2) 273 | fusion_mask1 = create_gradient_ring(h_,inner_radius=inner_radius,r1=0,r2=1) 274 | fusion_mask1_1 = np.concatenate((fusion_mask1[h_//2:,:], fusion_mask1[0:h_//2,:]), axis=0) 275 | fusion_mask1_2 = np.concatenate((fusion_mask1[:,w_//2:], fusion_mask1[:,0:w_//2]), axis=1) 276 | fusion_part_1 = (torch.from_numpy(fusion_mask1_1)*U_LR +torch.from_numpy(1-fusion_mask1_1)*U_FB)*fusion_mask 277 | fusion_part_2 = (torch.from_numpy(1-fusion_mask1_2)*U_LR +torch.from_numpy(fusion_mask1_2)*U_FB)*fusion_mask 278 | fusion_part = (fusion_part_1 + fusion_part_2)/2 279 | 280 | U_affine = fusion_part+(torch.from_numpy(1-LR_mask))*U_FB + (torch.from_numpy(1-FB_mask))*U_LR 281 | # kernel_size = (5, 5) 282 | # sigma = (2.0, 2.0) 283 | # transform = T.GaussianBlur(kernel_size=kernel_size, sigma=sigma) 284 | # U_affine = torch.stack([transform(image) for image in U_affine]) 285 | mask_np = create_circle_line_mask(h_,6) 286 | mask_np1 = np.concatenate((mask_np[h_//2:,:], mask_np[0:h_//2,:]), axis=0) 287 | mask_np2 = np.concatenate((mask_np[:,w_//2:], mask_np[:,0:w_//2]), axis=1) 288 | mask_np = mask_np1|mask_np2 289 | mask_np[:h_//4,:] = 0 290 | mask_np[3*(h_//4):,:] = 0 291 | mask_np[:,:w_//4] = 0 292 | mask_np[:,3*(w_//4):] = 0 293 | # mask_np = np.zeros((h_, w_)).astype(np.uint8) 294 | # inpaint_size = 8 295 | # h_off = (h_-inpaint_size)//2 296 | # w_off = (w_-inpaint_size)//2 297 | # mask_np[0:inpaint_size,0:inpaint_size] = 1 298 | # mask_np[(h_-inpaint_size):,0:inpaint_size] = 1 299 | # mask_np[0:inpaint_size,(w_-inpaint_size):] = 1 300 | # mask_np[(h_-inpaint_size):,(w_-inpaint_size):] = 1 301 | # mask_np[h_off:h_off+inpaint_size,w_off:w_off+inpaint_size] = 1 302 | 303 | repaired_images = [] 304 | for i in range(U_affine.shape[0]): 305 | # 将张量转换为 NumPy 数组并将通道移至最后 306 | image_np = (U_affine[i].permute(1, 2, 0).cpu().numpy()*255).astype(np.uint8) 307 | 308 | 309 | # 使用 OpenCV 的 inpaint 进行修复 310 | repaired_np = cv2.inpaint(image_np, mask_np, inpaintRadius=3, flags=cv2.INPAINT_TELEA) 311 | 312 | # 将修复后的结果转换回 PyTorch 张量并将通道移至最前 313 | repaired_tensor = torch.from_numpy(repaired_np).permute(2, 0, 1) 314 | 315 | # 添加到结果列表中 316 | repaired_images.append(repaired_tensor) 317 | U_affine = torch.stack(repaired_images)/255.0 318 | 319 | return U_affine 320 | 321 | def create_gradient_ring(image_size, inner_radius, r1, r2): 322 | # Create coordinate grids 323 | y, x = np.ogrid[:image_size, :image_size] 324 | center = (image_size//2,image_size//2) 325 | outer_radius = image_size//2 326 | # 
Calculate Euclidean distance from each point to the center 327 | distance_from_center = np.sqrt((x - center[0])**2 + (y - center[1])**2) 328 | 329 | # Create mask with gradient between two circles 330 | mask = np.zeros((image_size, image_size), dtype=np.float32) 331 | within_outer_circle = distance_from_center <= outer_radius 332 | outside_inner_circle = distance_from_center >= inner_radius 333 | 334 | # Apply gradient only in the ring section 335 | ring_section = within_outer_circle & outside_inner_circle 336 | normalized_distance = (distance_from_center[ring_section] - inner_radius) / (outer_radius - inner_radius) 337 | mask[ring_section] = r2 + (r1 - r2) * (1 - normalized_distance) 338 | 339 | return mask 340 | 341 | def create_gradient_circle(image_size, r1, r2): 342 | # Create coordinate grids 343 | y, x = np.ogrid[:image_size, :image_size] 344 | center = (image_size//2,image_size//2) 345 | radius = image_size//2 346 | # Calculate Euclidean distance from each point to the center 347 | distance_from_center = np.sqrt((x - center[0])**2 + (y - center[1])**2) 348 | 349 | # Create a mask that defines a circular gradient 350 | mask = np.zeros((image_size, image_size), dtype=np.float32) 351 | inside_circle = distance_from_center <= radius 352 | 353 | # Interpolate values between r1 (at the perimeter) and r2 (at the center) 354 | mask[inside_circle] = r1 + (r2 - r1) * (distance_from_center[inside_circle] / radius) 355 | 356 | return mask 357 | 358 | def create_circle_line_mask(image_size, thickness): 359 | # Create a blank black image 360 | mask = np.zeros((image_size, image_size), dtype=np.uint8) 361 | center = (image_size//2,image_size//2) 362 | radius = image_size//2 363 | # Draw a circle on the mask 364 | cv2.circle(mask, center, radius, (1), thickness=thickness) 365 | 366 | return mask 367 | 368 | def create_circular_mask(height, width, center=None, radius=None): 369 | # Default center and radius 370 | if center is None: # Compute center if not provided 371 | center = (int(width / 2), int(height / 2)) 372 | if radius is None: # Compute radius if not provided 373 | radius = min(center[0], center[1], width - center[0], height - center[1]) 374 | 375 | # Create a grid of coordinates 376 | Y, X = np.ogrid[:height, :width] 377 | 378 | # Compute the distance from the circle's center 379 | distance_from_center = np.sqrt((X - center[0])**2 + (Y - center[1])**2) 380 | 381 | # Apply mask 382 | mask = distance_from_center <= radius 383 | 384 | return mask.astype(np.uint8) 385 | 386 | def generate_diamond_mask(size): 387 | # 创建一个空的掩码,将其初始化为全零(False) 388 | mask = np.zeros((size, size)) 389 | 390 | # 计算中间点 391 | center = size // 2 392 | 393 | for i in range(size): 394 | for j in range(size): 395 | # 计算单元格到中心的曼哈顿距离(菱形距离) 396 | distance = abs(center - i) + abs(center - j) 397 | 398 | # 如果距离小于或等于中心的距离,则将该位置设为True 399 | if distance <= center: 400 | mask[i, j] = 1 401 | 402 | return mask.astype(np.uint8) 403 | 404 | def convert_to_cubemap1(L_F_D_B_R, norm_overlap=True): 405 | 406 | L_F_D_B_R = torch.tensor(L_F_D_B_R).permute(0,3,1,2) 407 | L_F, B_R = torch.chunk(L_F_D_B_R, 2, dim=-2) 408 | L_, F_ = torch.chunk(L_F, 2, dim=-1) 409 | B_, R_ = torch.chunk(B_R, 2, dim=-1) 410 | 411 | if norm_overlap: 412 | b,c,h,w = L_.shape 413 | # L 414 | L_norm = L_.clone() 415 | L_norm[:,:,0:h//2,w//2:w] = (L_[:,:,0:h//2,w//2:w] + torch.rot90(F_, k=1, dims=(-2, -1))[:,:,h//2:h,0:w//2])/2 416 | L_norm[:,:,h//2:h,0:w//2] = (L_[:,:,h//2:h,0:w//2] + torch.rot90(B_, k=-1, dims=(-2, -1))[:,:,0:h//2,w//2:w])/2 417 | # F 418 | 
F_norm = F_.clone() 419 | F_norm[:,:,0:h//2,0:w//2] = (F_[:,:,0:h//2,0:w//2] + torch.rot90(L_, k=-1, dims=(-2, -1))[:,:,h//2:h,w//2:w])/2 420 | F_norm[:,:,h//2:h,w//2:w] = (F_[:,:,h//2:h,w//2:w] + torch.rot90(R_, k=1, dims=(-2, -1))[:,:,0:h//2,0:w//2])/2 421 | # B 422 | B_norm = B_.clone() 423 | B_norm[:,:,0:h//2,0:w//2] = (B_[:,:,0:h//2,0:w//2] + torch.rot90(L_, k=1, dims=(-2, -1))[:,:,h//2:h,w//2:w])/2 424 | B_norm[:,:,h//2:h,w//2:w] = (B_[:,:,h//2:h,w//2:w] + torch.rot90(R_, k=-1, dims=(-2, -1))[:,:,0:h//2,0:w//2])/2 425 | # R 426 | R_norm = R_.clone() 427 | R_norm[:,:,0:h//2,w//2:w] = (R_[:,:,0:h//2,w//2:w] + torch.rot90(F_, k=-1, dims=(-2, -1))[:,:,h//2:h,0:w//2])/2 428 | R_norm[:,:,h//2:h,0:w//2] = (R_[:,:,h//2:h,0:w//2] + torch.rot90(B_, k=1, dims=(-2, -1))[:,:,0:h//2,w//2:w])/2 429 | 430 | L_ = L_norm 431 | F_ = F_norm 432 | B_ = B_norm 433 | R_ = R_norm 434 | 435 | scale = np.sqrt(2)/2 436 | L_affine = F.grid_sample(L_, get_transform_matrix_1('L',L_.shape,scale,reverse=True), mode='bilinear', align_corners=False) 437 | F_affine = F.grid_sample(F_, get_transform_matrix_1('F',F_.shape,scale,reverse=True), mode='bilinear', align_corners=False) 438 | B_affine = F.grid_sample(B_, get_transform_matrix_1('B',B_.shape,scale,reverse=True), mode='bilinear', align_corners=False) 439 | R_affine = F.grid_sample(R_, get_transform_matrix_1('R',R_.shape,scale,reverse=True), mode='bilinear', align_corners=False) 440 | 441 | scale = np.sqrt(2)/4 442 | D_affine = F.grid_sample(L_F_D_B_R, get_transform_matrix_1('D',L_F_D_B_R.shape,scale,reverse=True), mode='bilinear', align_corners=False) 443 | 444 | L_ = torch.rot90(torch.flip(L_,dims=[-1]), k=-1, dims=(-2, -1)) 445 | F_ = torch.rot90(torch.flip(F_,dims=[-1]), k=1, dims=(-2, -1)) 446 | B_ = torch.rot90(torch.flip(B_,dims=[-1]), k=1, dims=(-2, -1)) 447 | R_ = torch.rot90(torch.flip(R_,dims=[-1]), k=-1, dims=(-2, -1)) 448 | L_F_U_B_R = torch.cat([torch.cat([L_,F_],dim=-1),torch.cat([B_,R_],dim=-1)],dim=-2) 449 | L_F_U_B_R = torch.flip(L_F_U_B_R,dims=[-1]) 450 | U_affine = F.grid_sample(L_F_U_B_R, get_transform_matrix_1('U',L_F_U_B_R.shape,scale,reverse=True), mode='bilinear', align_corners=False) 451 | 452 | U_affine = post_process_U(U_affine) 453 | 454 | 455 | 456 | return torch.cat([F_affine,R_affine,B_affine,L_affine,U_affine,D_affine],dim=-1) 457 | 458 | def save_array_as_image(array, filename): 459 | # Ensure array is in uint8 format if not already 460 | array = (array * 255).astype(np.uint8) # Normalize if array values are [0, 1] 461 | 462 | # Convert NumPy array to PIL Image 463 | image = Image.fromarray(array) 464 | 465 | # Save the image 466 | image.save(filename) 467 | 468 | def generate_pyramid(n): 469 | # 创建一个空的正方形矩阵 470 | pyramid = torch.zeros((n, n), dtype=torch.float32) 471 | 472 | # 最大的对角线距离 473 | max_distance = np.sqrt((n - 1)**2 + (n - 1)**2) 474 | 475 | for i in range(n): 476 | for j in range(n): 477 | # 当前点到左上角 (0,0) 的欧几里德距离 478 | distance = np.sqrt(i**2 + j**2) 479 | 480 | # 计算高度值 481 | pyramid[i, j] = 1 - (distance / max_distance) 482 | 483 | return pyramid 484 | 485 | def generate_linear_gradient(n): 486 | # 创建正方形矩阵 487 | gradient = torch.zeros((n, n), dtype=torch.float32) 488 | 489 | for i in range(n): 490 | # 线性插值计算从1 (上边) 到 0 (下边) 的值 491 | gradient[i, :] = 1 - (i / (n - 1)) 492 | 493 | return (gradient + torch.rot90(gradient))/2 494 | 495 | def convert_to_cubemap2(L_F_D_B_R, norm_overlap=True): 496 | L_F_D_B_R = torch.tensor(L_F_D_B_R).permute(0,3,1,2) 497 | L_F, B_R = torch.chunk(L_F_D_B_R, 2, dim=-2) 498 | L_, F_ = 
torch.chunk(L_F, 2, dim=-1) 499 | B_, R_ = torch.chunk(B_R, 2, dim=-1) 500 | 501 | if norm_overlap: 502 | b,c,h,w = L_.shape 503 | weighted_pyramid1 = generate_pyramid(h//2) 504 | weighted_pyramid2 = generate_linear_gradient(h//2) 505 | 506 | # L 507 | L_norm = L_.clone() 508 | rot_weighted_pyramid1 = torch.rot90(weighted_pyramid2,k=1) 509 | rot_weighted_pyramid2 = torch.rot90(weighted_pyramid2,k=-1) 510 | L_norm[:,:,0:h//2,w//2:w] = rot_weighted_pyramid1*L_[:,:,0:h//2,w//2:w] + (1-rot_weighted_pyramid1)*torch.rot90(F_, k=1, dims=(-2, -1))[:,:,h//2:h,0:w//2] 511 | L_norm[:,:,h//2:h,0:w//2] = rot_weighted_pyramid2*L_[:,:,h//2:h,0:w//2] + (1-rot_weighted_pyramid2)*torch.rot90(B_, k=-1, dims=(-2, -1))[:,:,0:h//2,w//2:w] 512 | # F 513 | F_norm = F_.clone() 514 | rot_weighted_pyramid1 = torch.rot90(weighted_pyramid2,k=-2) 515 | rot_weighted_pyramid2 = weighted_pyramid2 516 | F_norm[:,:,0:h//2,0:w//2] = rot_weighted_pyramid1*F_[:,:,0:h//2,0:w//2] + (1-rot_weighted_pyramid1)*torch.rot90(L_, k=-1, dims=(-2, -1))[:,:,h//2:h,w//2:w] 517 | F_norm[:,:,h//2:h,w//2:w] = rot_weighted_pyramid2*F_[:,:,h//2:h,w//2:w] + (1-rot_weighted_pyramid2)*torch.rot90(R_, k=1, dims=(-2, -1))[:,:,0:h//2,0:w//2] 518 | # B 519 | B_norm = B_.clone() 520 | rot_weighted_pyramid1 = torch.rot90(weighted_pyramid2,k=-2) 521 | rot_weighted_pyramid2 = weighted_pyramid2 522 | B_norm[:,:,0:h//2,0:w//2] = rot_weighted_pyramid1*B_[:,:,0:h//2,0:w//2] + (1-rot_weighted_pyramid1)*torch.rot90(L_, k=1, dims=(-2, -1))[:,:,h//2:h,w//2:w] 523 | B_norm[:,:,h//2:h,w//2:w] = rot_weighted_pyramid2*B_[:,:,h//2:h,w//2:w] + (1-rot_weighted_pyramid2)*torch.rot90(R_, k=-1, dims=(-2, -1))[:,:,0:h//2,0:w//2] 524 | # R 525 | R_norm = R_.clone() 526 | rot_weighted_pyramid1 = torch.rot90(weighted_pyramid2,k=1) 527 | rot_weighted_pyramid2 = torch.rot90(weighted_pyramid2,k=-1) 528 | R_norm[:,:,0:h//2,w//2:w] = rot_weighted_pyramid1*R_[:,:,0:h//2,w//2:w] + (1-rot_weighted_pyramid1)*torch.rot90(F_, k=-1, dims=(-2, -1))[:,:,h//2:h,0:w//2] 529 | R_norm[:,:,h//2:h,0:w//2] = rot_weighted_pyramid2*R_[:,:,h//2:h,0:w//2] + (1-rot_weighted_pyramid2)*torch.rot90(B_, k=1, dims=(-2, -1))[:,:,0:h//2,w//2:w] 530 | 531 | L_ = L_norm 532 | F_ = F_norm 533 | B_ = B_norm 534 | R_ = R_norm 535 | 536 | scale = np.sqrt(2)/2 537 | 538 | L_affine = F.grid_sample(L_, get_transform_matrix_1('L',L_.shape,scale,reverse=True), mode='bilinear', align_corners=False) 539 | F_affine = F.grid_sample(F_, get_transform_matrix_1('F',F_.shape,scale,reverse=True), mode='bilinear', align_corners=False) 540 | B_affine = F.grid_sample(B_, get_transform_matrix_1('B',B_.shape,scale,reverse=True), mode='bilinear', align_corners=False) 541 | R_affine = F.grid_sample(R_, get_transform_matrix_1('R',R_.shape,scale,reverse=True), mode='bilinear', align_corners=False) 542 | 543 | scale = np.sqrt(2)/4 544 | D_affine = F.grid_sample(L_F_D_B_R, get_transform_matrix_1('D',L_F_D_B_R.shape,scale,reverse=True), mode='bilinear', align_corners=False) 545 | 546 | L_ = torch.rot90(torch.flip(L_,dims=[-1]), k=-1, dims=(-2, -1)) 547 | F_ = torch.rot90(torch.flip(F_,dims=[-1]), k=1, dims=(-2, -1)) 548 | B_ = torch.rot90(torch.flip(B_,dims=[-1]), k=1, dims=(-2, -1)) 549 | R_ = torch.rot90(torch.flip(R_,dims=[-1]), k=-1, dims=(-2, -1)) 550 | L_F_U_B_R = torch.cat([torch.cat([L_,F_],dim=-1),torch.cat([B_,R_],dim=-1)],dim=-2) 551 | L_F_U_B_R = torch.flip(L_F_U_B_R,dims=[-1]) 552 | U_affine = F.grid_sample(L_F_U_B_R, get_transform_matrix_1('U',L_F_U_B_R.shape,scale,reverse=True), mode='bilinear', align_corners=False) 553 | 
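    # post_process_U3 blends the front/back and left/right reassemblies of the warped top
    # view, feathers their overlap with a gradient-ring mask, and repairs the residual
    # seams with cv2.inpaint (TELEA) before the face is placed into the cubemap strip.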
U_affine = post_process_U3(U_affine) 554 | return torch.cat([F_affine,R_affine,B_affine,L_affine,U_affine,D_affine],dim=-1) 555 | 556 | -------------------------------------------------------------------------------- /src/pipelines/pipeline_wan.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import html 16 | from typing import Any, Callable, Dict, List, Optional, Union 17 | from dataclasses import dataclass 18 | import ftfy 19 | import regex as re 20 | import torch 21 | from transformers import AutoTokenizer, UMT5EncoderModel 22 | 23 | from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback 24 | from diffusers.loaders import WanLoraLoaderMixin 25 | from diffusers.models import AutoencoderKLWan 26 | from diffusers.schedulers import FlowMatchEulerDiscreteScheduler 27 | from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring 28 | from diffusers.utils.torch_utils import randn_tensor 29 | from diffusers.video_processor import VideoProcessor 30 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline 31 | from diffusers.utils import BaseOutput 32 | from ..models.transformer_wan import WanTransformer3DModel 33 | 34 | if is_torch_xla_available(): 35 | import torch_xla.core.xla_model as xm 36 | 37 | XLA_AVAILABLE = True 38 | else: 39 | XLA_AVAILABLE = False 40 | 41 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 42 | 43 | 44 | EXAMPLE_DOC_STRING = """ 45 | Examples: 46 | ```python 47 | >>> import torch 48 | >>> from diffusers.utils import export_to_video 49 | >>> from diffusers import AutoencoderKLWan, WanPipeline 50 | >>> from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler 51 | 52 | >>> # Available models: Wan-AI/Wan2.1-T2V-14B-Diffusers, Wan-AI/Wan2.1-T2V-1.3B-Diffusers 53 | >>> model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers" 54 | >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) 55 | >>> pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16) 56 | >>> flow_shift = 5.0 # 5.0 for 720P, 3.0 for 480P 57 | >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift) 58 | >>> pipe.to("cuda") 59 | 60 | >>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." 
61 | >>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" 62 | 63 | >>> output = pipe( 64 | ... prompt=prompt, 65 | ... negative_prompt=negative_prompt, 66 | ... height=720, 67 | ... width=1280, 68 | ... num_frames=81, 69 | ... guidance_scale=5.0, 70 | ... ).frames[0] 71 | >>> export_to_video(output, "output.mp4", fps=16) 72 | ``` 73 | """ 74 | 75 | 76 | def basic_clean(text): 77 | text = ftfy.fix_text(text) 78 | text = html.unescape(html.unescape(text)) 79 | return text.strip() 80 | 81 | 82 | def whitespace_clean(text): 83 | text = re.sub(r"\s+", " ", text) 84 | text = text.strip() 85 | return text 86 | 87 | 88 | def prompt_clean(text): 89 | text = whitespace_clean(basic_clean(text)) 90 | return text 91 | 92 | 93 | class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin): 94 | r""" 95 | Pipeline for text-to-video generation using Wan. 96 | 97 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods 98 | implemented for all pipelines (downloading, saving, running on a particular device, etc.). 99 | 100 | Args: 101 | tokenizer ([`T5Tokenizer`]): 102 | Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer), 103 | specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. 104 | text_encoder ([`T5EncoderModel`]): 105 | [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically 106 | the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. 107 | transformer ([`WanTransformer3DModel`]): 108 | Conditional Transformer to denoise the input latents. 109 | scheduler ([`UniPCMultistepScheduler`]): 110 | A scheduler to be used in combination with `transformer` to denoise the encoded image latents. 111 | vae ([`AutoencoderKLWan`]): 112 | Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. 
113 | """ 114 | 115 | model_cpu_offload_seq = "text_encoder->transformer->vae" 116 | _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] 117 | 118 | def __init__( 119 | self, 120 | tokenizer: AutoTokenizer, 121 | text_encoder: UMT5EncoderModel, 122 | transformer: WanTransformer3DModel, 123 | vae: AutoencoderKLWan, 124 | scheduler: FlowMatchEulerDiscreteScheduler, 125 | ): 126 | super().__init__() 127 | 128 | self.register_modules( 129 | vae=vae, 130 | text_encoder=text_encoder, 131 | tokenizer=tokenizer, 132 | transformer=transformer, 133 | scheduler=scheduler, 134 | ) 135 | 136 | self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 137 | self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 138 | self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) 139 | 140 | def _get_t5_prompt_embeds( 141 | self, 142 | prompt: Union[str, List[str]] = None, 143 | num_videos_per_prompt: int = 1, 144 | max_sequence_length: int = 226, 145 | device: Optional[torch.device] = None, 146 | dtype: Optional[torch.dtype] = None, 147 | ): 148 | device = device or self._execution_device 149 | dtype = dtype or self.text_encoder.dtype 150 | 151 | prompt = [prompt] if isinstance(prompt, str) else prompt 152 | prompt = [prompt_clean(u) for u in prompt] 153 | batch_size = len(prompt) 154 | 155 | text_inputs = self.tokenizer( 156 | prompt, 157 | padding="max_length", 158 | max_length=max_sequence_length, 159 | truncation=True, 160 | add_special_tokens=True, 161 | return_attention_mask=True, 162 | return_tensors="pt", 163 | ) 164 | text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask 165 | seq_lens = mask.gt(0).sum(dim=1).long() 166 | 167 | prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state 168 | prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) 169 | prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] 170 | prompt_embeds = torch.stack( 171 | [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 172 | ) 173 | 174 | # duplicate text embeddings for each generation per prompt, using mps friendly method 175 | _, seq_len, _ = prompt_embeds.shape 176 | prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) 177 | prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) 178 | 179 | return prompt_embeds 180 | 181 | def encode_prompt( 182 | self, 183 | prompt: Union[str, List[str]], 184 | negative_prompt: Optional[Union[str, List[str]]] = None, 185 | do_classifier_free_guidance: bool = True, 186 | num_videos_per_prompt: int = 1, 187 | prompt_embeds: Optional[torch.Tensor] = None, 188 | negative_prompt_embeds: Optional[torch.Tensor] = None, 189 | max_sequence_length: int = 226, 190 | device: Optional[torch.device] = None, 191 | dtype: Optional[torch.dtype] = None, 192 | ): 193 | r""" 194 | Encodes the prompt into text encoder hidden states. 195 | 196 | Args: 197 | prompt (`str` or `List[str]`, *optional*): 198 | prompt to be encoded 199 | negative_prompt (`str` or `List[str]`, *optional*): 200 | The prompt or prompts not to guide the image generation. If not defined, one has to pass 201 | `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is 202 | less than `1`). 
203 | do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): 204 | Whether to use classifier free guidance or not. 205 | num_videos_per_prompt (`int`, *optional*, defaults to 1): 206 | Number of videos that should be generated per prompt. torch device to place the resulting embeddings on 207 | prompt_embeds (`torch.Tensor`, *optional*): 208 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 209 | provided, text embeddings will be generated from `prompt` input argument. 210 | negative_prompt_embeds (`torch.Tensor`, *optional*): 211 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt 212 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input 213 | argument. 214 | device: (`torch.device`, *optional*): 215 | torch device 216 | dtype: (`torch.dtype`, *optional*): 217 | torch dtype 218 | """ 219 | device = device or self._execution_device 220 | 221 | prompt = [prompt] if isinstance(prompt, str) else prompt 222 | if prompt is not None: 223 | batch_size = len(prompt) 224 | else: 225 | batch_size = prompt_embeds.shape[0] 226 | 227 | if prompt_embeds is None: 228 | prompt_embeds = self._get_t5_prompt_embeds( 229 | prompt=prompt, 230 | num_videos_per_prompt=num_videos_per_prompt, 231 | max_sequence_length=max_sequence_length, 232 | device=device, 233 | dtype=dtype, 234 | ) 235 | 236 | if do_classifier_free_guidance and negative_prompt_embeds is None: 237 | negative_prompt = negative_prompt or "" 238 | negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt 239 | 240 | if prompt is not None and type(prompt) is not type(negative_prompt): 241 | raise TypeError( 242 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" 243 | f" {type(prompt)}." 244 | ) 245 | elif batch_size != len(negative_prompt): 246 | raise ValueError( 247 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" 248 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" 249 | " the batch size of `prompt`." 250 | ) 251 | 252 | negative_prompt_embeds = self._get_t5_prompt_embeds( 253 | prompt=negative_prompt, 254 | num_videos_per_prompt=num_videos_per_prompt, 255 | max_sequence_length=max_sequence_length, 256 | device=device, 257 | dtype=dtype, 258 | ) 259 | 260 | return prompt_embeds, negative_prompt_embeds 261 | 262 | def check_inputs( 263 | self, 264 | prompt, 265 | negative_prompt, 266 | height, 267 | width, 268 | prompt_embeds=None, 269 | negative_prompt_embeds=None, 270 | callback_on_step_end_tensor_inputs=None, 271 | ): 272 | if height % 16 != 0 or width % 16 != 0: 273 | raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") 274 | 275 | if callback_on_step_end_tensor_inputs is not None and not all( 276 | k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs 277 | ): 278 | raise ValueError( 279 | f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" 280 | ) 281 | 282 | if prompt is not None and prompt_embeds is not None: 283 | raise ValueError( 284 | f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" 285 | " only forward one of the two." 
286 | ) 287 | elif negative_prompt is not None and negative_prompt_embeds is not None: 288 | raise ValueError( 289 | f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to" 290 | " only forward one of the two." 291 | ) 292 | elif prompt is None and prompt_embeds is None: 293 | raise ValueError( 294 | "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 295 | ) 296 | elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): 297 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") 298 | elif negative_prompt is not None and ( 299 | not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) 300 | ): 301 | raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") 302 | 303 | def prepare_latents( 304 | self, 305 | batch_size: int, 306 | num_channels_latents: int = 16, 307 | height: int = 480, 308 | width: int = 832, 309 | num_frames: int = 81, 310 | dtype: Optional[torch.dtype] = None, 311 | device: Optional[torch.device] = None, 312 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 313 | latents: Optional[torch.Tensor] = None, 314 | ) -> torch.Tensor: 315 | if latents is not None: 316 | return latents.to(device=device, dtype=dtype) 317 | 318 | num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 319 | shape = ( 320 | batch_size, 321 | num_channels_latents, 322 | num_latent_frames, 323 | int(height) // self.vae_scale_factor_spatial, 324 | int(width) // self.vae_scale_factor_spatial, 325 | ) 326 | if isinstance(generator, list) and len(generator) != batch_size: 327 | raise ValueError( 328 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 329 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
330 | ) 331 | 332 | latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) 333 | return latents 334 | 335 | @property 336 | def guidance_scale(self): 337 | return self._guidance_scale 338 | 339 | @property 340 | def do_classifier_free_guidance(self): 341 | return self._guidance_scale > 1.0 342 | 343 | @property 344 | def num_timesteps(self): 345 | return self._num_timesteps 346 | 347 | @property 348 | def current_timestep(self): 349 | return self._current_timestep 350 | 351 | @property 352 | def interrupt(self): 353 | return self._interrupt 354 | 355 | @property 356 | def attention_kwargs(self): 357 | return self._attention_kwargs 358 | 359 | @torch.no_grad() 360 | @replace_example_docstring(EXAMPLE_DOC_STRING) 361 | def __call__( 362 | self, 363 | prompt: Union[str, List[str]] = None, 364 | negative_prompt: Union[str, List[str]] = None, 365 | height: int = 480, 366 | width: int = 832, 367 | num_frames: int = 81, 368 | num_inference_steps: int = 50, 369 | guidance_scale: float = 5.0, 370 | num_videos_per_prompt: Optional[int] = 1, 371 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 372 | latents: Optional[torch.Tensor] = None, 373 | prompt_embeds: Optional[torch.Tensor] = None, 374 | negative_prompt_embeds: Optional[torch.Tensor] = None, 375 | output_type: Optional[str] = "np", 376 | return_dict: bool = True, 377 | attention_kwargs: Optional[Dict[str, Any]] = None, 378 | callback_on_step_end: Optional[ 379 | Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] 380 | ] = None, 381 | callback_on_step_end_tensor_inputs: List[str] = ["latents"], 382 | max_sequence_length: int = 512, 383 | ): 384 | r""" 385 | The call function to the pipeline for generation. 386 | 387 | Args: 388 | prompt (`str` or `List[str]`, *optional*): 389 | The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. 390 | instead. 391 | height (`int`, defaults to `480`): 392 | The height in pixels of the generated image. 393 | width (`int`, defaults to `832`): 394 | The width in pixels of the generated image. 395 | num_frames (`int`, defaults to `81`): 396 | The number of frames in the generated video. 397 | num_inference_steps (`int`, defaults to `50`): 398 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 399 | expense of slower inference. 400 | guidance_scale (`float`, defaults to `5.0`): 401 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). 402 | `guidance_scale` is defined as `w` of equation 2. of [Imagen 403 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 404 | 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, 405 | usually at the expense of lower image quality. 406 | num_videos_per_prompt (`int`, *optional*, defaults to 1): 407 | The number of images to generate per prompt. 408 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*): 409 | A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make 410 | generation deterministic. 411 | latents (`torch.Tensor`, *optional*): 412 | Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image 413 | generation. Can be used to tweak the same generation with different prompts. 
If not provided, a latents 414 | tensor is generated by sampling using the supplied random `generator`. 415 | prompt_embeds (`torch.Tensor`, *optional*): 416 | Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not 417 | provided, text embeddings are generated from the `prompt` input argument. 418 | output_type (`str`, *optional*, defaults to `"pil"`): 419 | The output format of the generated image. Choose between `PIL.Image` or `np.array`. 420 | return_dict (`bool`, *optional*, defaults to `True`): 421 | Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple. 422 | attention_kwargs (`dict`, *optional*): 423 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under 424 | `self.processor` in 425 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 426 | callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): 427 | A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of 428 | each denoising step during the inference. with the following arguments: `callback_on_step_end(self: 429 | DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a 430 | list of all tensors as specified by `callback_on_step_end_tensor_inputs`. 431 | callback_on_step_end_tensor_inputs (`List`, *optional*): 432 | The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list 433 | will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the 434 | `._callback_tensor_inputs` attribute of your pipeline class. 435 | autocast_dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`): 436 | The dtype to use for the torch.amp.autocast. 437 | 438 | Examples: 439 | 440 | Returns: 441 | [`~WanPipelineOutput`] or `tuple`: 442 | If `return_dict` is `True`, [`WanPipelineOutput`] is returned, otherwise a `tuple` is returned where 443 | the first element is a list with the generated images and the second element is a list of `bool`s 444 | indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. 445 | """ 446 | 447 | if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): 448 | callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs 449 | 450 | # 1. Check inputs. Raise error if not correct 451 | self.check_inputs( 452 | prompt, 453 | negative_prompt, 454 | height, 455 | width, 456 | prompt_embeds, 457 | negative_prompt_embeds, 458 | callback_on_step_end_tensor_inputs, 459 | ) 460 | 461 | self._guidance_scale = guidance_scale 462 | self._attention_kwargs = attention_kwargs 463 | self._current_timestep = None 464 | self._interrupt = False 465 | 466 | device = self._execution_device 467 | 468 | # 2. Define call parameters 469 | if prompt is not None and isinstance(prompt, str): 470 | batch_size = 1 471 | elif prompt is not None and isinstance(prompt, list): 472 | batch_size = len(prompt) 473 | else: 474 | batch_size = prompt_embeds.shape[0] 475 | 476 | # 3. 
Encode input prompt 477 | prompt_embeds, negative_prompt_embeds = self.encode_prompt( 478 | prompt=prompt, 479 | negative_prompt=negative_prompt, 480 | do_classifier_free_guidance=self.do_classifier_free_guidance, 481 | num_videos_per_prompt=num_videos_per_prompt, 482 | prompt_embeds=prompt_embeds, 483 | negative_prompt_embeds=negative_prompt_embeds, 484 | max_sequence_length=max_sequence_length, 485 | device=device, 486 | ) 487 | 488 | transformer_dtype = self.transformer.dtype 489 | prompt_embeds = prompt_embeds.to(transformer_dtype) 490 | if negative_prompt_embeds is not None: 491 | negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) 492 | 493 | # 4. Prepare timesteps 494 | self.scheduler.set_timesteps(num_inference_steps, device=device) 495 | timesteps = self.scheduler.timesteps 496 | 497 | # 5. Prepare latent variables 498 | num_channels_latents = self.transformer.config.in_channels 499 | latents = self.prepare_latents( 500 | batch_size * num_videos_per_prompt, 501 | num_channels_latents, 502 | height, 503 | width, 504 | num_frames, 505 | torch.float32, 506 | device, 507 | generator, 508 | latents, 509 | ) 510 | 511 | # 6. Denoising loop 512 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 513 | self._num_timesteps = len(timesteps) 514 | 515 | with self.progress_bar(total=num_inference_steps) as progress_bar: 516 | for i, t in enumerate(timesteps): 517 | if self.interrupt: 518 | continue 519 | 520 | self._current_timestep = t 521 | latent_model_input = latents.to(transformer_dtype) 522 | timestep = t.expand(latents.shape[0]) 523 | 524 | noise_pred = self.transformer( 525 | hidden_states=latent_model_input, 526 | timestep=timestep, 527 | encoder_hidden_states=prompt_embeds, 528 | attention_kwargs=attention_kwargs, 529 | return_dict=False, 530 | )[0] 531 | 532 | if self.do_classifier_free_guidance: 533 | noise_uncond = self.transformer( 534 | hidden_states=latent_model_input, 535 | timestep=timestep, 536 | encoder_hidden_states=negative_prompt_embeds, 537 | attention_kwargs=attention_kwargs, 538 | return_dict=False, 539 | )[0] 540 | noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) 541 | 542 | # compute the previous noisy sample x_t -> x_t-1 543 | latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] 544 | 545 | if callback_on_step_end is not None: 546 | callback_kwargs = {} 547 | for k in callback_on_step_end_tensor_inputs: 548 | callback_kwargs[k] = locals()[k] 549 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) 550 | 551 | latents = callback_outputs.pop("latents", latents) 552 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) 553 | negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) 554 | 555 | # call the callback, if provided 556 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 557 | progress_bar.update() 558 | 559 | if XLA_AVAILABLE: 560 | xm.mark_step() 561 | 562 | self._current_timestep = None 563 | 564 | if not output_type == "latent": 565 | latents = latents.to(self.vae.dtype) 566 | latents_mean = ( 567 | torch.tensor(self.vae.config.latents_mean) 568 | .view(1, self.vae.config.z_dim, 1, 1, 1) 569 | .to(latents.device, latents.dtype) 570 | ) 571 | latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( 572 | latents.device, latents.dtype 573 | ) 574 | latents = latents / latents_std + latents_mean 
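# Note: latents_std above stores the reciprocal of the per-channel std from the VAE
# config, so dividing by it and adding latents_mean undoes the normalization applied
# when the latents were produced, before they are decoded back to video frames.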
575 | video = self.vae.decode(latents, return_dict=False)[0] 576 | video = self.video_processor.postprocess_video(video, output_type=output_type) 577 | else: 578 | video = latents 579 | 580 | # Offload all models 581 | self.maybe_free_model_hooks() 582 | 583 | if not return_dict: 584 | return (video,) 585 | 586 | return WanPipelineOutput(frames=video) 587 | 588 | @dataclass 589 | class WanPipelineOutput(BaseOutput): 590 | r""" 591 | Output class for Wan pipelines. 592 | 593 | Args: 594 | frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): 595 | List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing 596 | denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape 597 | `(batch_size, num_frames, channels, height, width)`. 598 | """ 599 | 600 | frames: torch.Tensor -------------------------------------------------------------------------------- /src/pipelines/pipeline_v.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import html 16 | from typing import Any, Callable, Dict, List, Optional, Union 17 | from dataclasses import dataclass 18 | import ftfy 19 | import regex as re 20 | import math 21 | import numpy as np 22 | import torch 23 | import torch.nn.functional as F 24 | from transformers import AutoTokenizer, UMT5EncoderModel 25 | from PIL import Image 26 | from einops import rearrange 27 | from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback 28 | from diffusers.loaders import WanLoraLoaderMixin 29 | from diffusers.models import AutoencoderKLWan 30 | from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution 31 | 32 | from diffusers.schedulers import FlowMatchEulerDiscreteScheduler 33 | from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring 34 | from diffusers.utils.torch_utils import randn_tensor 35 | from diffusers.video_processor import VideoProcessor 36 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline 37 | from diffusers.utils import BaseOutput 38 | from utils import get_condition_tensors, get_mask 39 | from ..models.transformer_v import ViewTransformer3DModel 40 | 41 | if is_torch_xla_available(): 42 | import torch_xla.core.xla_model as xm 43 | 44 | XLA_AVAILABLE = True 45 | else: 46 | XLA_AVAILABLE = False 47 | 48 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 49 | 50 | 51 | EXAMPLE_DOC_STRING = """ 52 | Examples: 53 | ```python 54 | >>> import torch 55 | >>> from diffusers.utils import export_to_video 56 | >>> from diffusers import AutoencoderKLWan, WanPipeline 57 | >>> from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler 58 | 59 | >>> # Available models: Wan-AI/Wan2.1-T2V-14B-Diffusers, Wan-AI/Wan2.1-T2V-1.3B-Diffusers 60 | 
>>> model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers" 61 | >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) 62 | >>> pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16) 63 | >>> flow_shift = 5.0 # 5.0 for 720P, 3.0 for 480P 64 | >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift) 65 | >>> pipe.to("cuda") 66 | 67 | >>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." 68 | >>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" 69 | 70 | >>> output = pipe( 71 | ... prompt=prompt, 72 | ... negative_prompt=negative_prompt, 73 | ... height=720, 74 | ... width=1280, 75 | ... num_frames=81, 76 | ... guidance_scale=5.0, 77 | ... ).frames[0] 78 | >>> export_to_video(output, "output.mp4", fps=16) 79 | ``` 80 | """ 81 | 82 | 83 | def basic_clean(text): 84 | text = ftfy.fix_text(text) 85 | text = html.unescape(html.unescape(text)) 86 | return text.strip() 87 | 88 | 89 | def whitespace_clean(text): 90 | text = re.sub(r"\s+", " ", text) 91 | text = text.strip() 92 | return text 93 | 94 | 95 | def prompt_clean(text): 96 | text = whitespace_clean(basic_clean(text)) 97 | return text 98 | 99 | 100 | class ViewPipeline(DiffusionPipeline, WanLoraLoaderMixin): 101 | r""" 102 | Pipeline for text-to-video generation using Wan. 103 | 104 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods 105 | implemented for all pipelines (downloading, saving, running on a particular device, etc.). 106 | 107 | Args: 108 | tokenizer ([`T5Tokenizer`]): 109 | Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer), 110 | specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. 111 | text_encoder ([`T5EncoderModel`]): 112 | [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically 113 | the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. 114 | transformer ([`WanTransformer3DModel`]): 115 | Conditional Transformer to denoise the input latents. 116 | scheduler ([`UniPCMultistepScheduler`]): 117 | A scheduler to be used in combination with `transformer` to denoise the encoded image latents. 118 | vae ([`AutoencoderKLWan`]): 119 | Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. 
120 | """ 121 | 122 | model_cpu_offload_seq = "text_encoder->transformer->vae" 123 | _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] 124 | 125 | def __init__( 126 | self, 127 | tokenizer: AutoTokenizer, 128 | text_encoder: UMT5EncoderModel, 129 | transformer: ViewTransformer3DModel, 130 | vae: AutoencoderKLWan, 131 | scheduler: FlowMatchEulerDiscreteScheduler, 132 | ): 133 | super().__init__() 134 | 135 | self.register_modules( 136 | vae=vae, 137 | text_encoder=text_encoder, 138 | tokenizer=tokenizer, 139 | transformer=transformer, 140 | scheduler=scheduler, 141 | ) 142 | 143 | self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 144 | self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 145 | self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) 146 | 147 | def _get_t5_prompt_embeds( 148 | self, 149 | prompt: Union[str, List[str]] = None, 150 | num_videos_per_prompt: int = 1, 151 | max_sequence_length: int = 226, 152 | device: Optional[torch.device] = None, 153 | dtype: Optional[torch.dtype] = None, 154 | ): 155 | device = device or self._execution_device 156 | dtype = dtype or self.text_encoder.dtype 157 | 158 | prompt = [prompt] if isinstance(prompt, str) else prompt 159 | prompt = [prompt_clean(u) for u in prompt] 160 | batch_size = len(prompt) 161 | 162 | text_inputs = self.tokenizer( 163 | prompt, 164 | padding="max_length", 165 | max_length=max_sequence_length, 166 | truncation=True, 167 | add_special_tokens=True, 168 | return_attention_mask=True, 169 | return_tensors="pt", 170 | ) 171 | text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask 172 | seq_lens = mask.gt(0).sum(dim=1).long() 173 | 174 | prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state 175 | prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) 176 | prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] 177 | prompt_embeds = torch.stack( 178 | [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 179 | ) 180 | 181 | # duplicate text embeddings for each generation per prompt, using mps friendly method 182 | _, seq_len, _ = prompt_embeds.shape 183 | prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) 184 | prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) 185 | 186 | return prompt_embeds 187 | 188 | def encode_prompt( 189 | self, 190 | prompt: Union[str, List[str]], 191 | negative_prompt: Optional[Union[str, List[str]]] = None, 192 | do_classifier_free_guidance: bool = True, 193 | num_videos_per_prompt: int = 1, 194 | prompt_embeds: Optional[torch.Tensor] = None, 195 | negative_prompt_embeds: Optional[torch.Tensor] = None, 196 | max_sequence_length: int = 226, 197 | device: Optional[torch.device] = None, 198 | dtype: Optional[torch.dtype] = None, 199 | ): 200 | r""" 201 | Encodes the prompt into text encoder hidden states. 202 | 203 | Args: 204 | prompt (`str` or `List[str]`, *optional*): 205 | prompt to be encoded 206 | negative_prompt (`str` or `List[str]`, *optional*): 207 | The prompt or prompts not to guide the image generation. If not defined, one has to pass 208 | `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is 209 | less than `1`). 
210 | do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): 211 | Whether to use classifier free guidance or not. 212 | num_videos_per_prompt (`int`, *optional*, defaults to 1): 213 | Number of videos that should be generated per prompt. torch device to place the resulting embeddings on 214 | prompt_embeds (`torch.Tensor`, *optional*): 215 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 216 | provided, text embeddings will be generated from `prompt` input argument. 217 | negative_prompt_embeds (`torch.Tensor`, *optional*): 218 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt 219 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input 220 | argument. 221 | device: (`torch.device`, *optional*): 222 | torch device 223 | dtype: (`torch.dtype`, *optional*): 224 | torch dtype 225 | """ 226 | device = device or self._execution_device 227 | 228 | prompt = [prompt] if isinstance(prompt, str) else prompt 229 | if prompt is not None: 230 | batch_size = len(prompt) 231 | else: 232 | batch_size = prompt_embeds.shape[0] 233 | 234 | if prompt_embeds is None: 235 | prompt_embeds = self._get_t5_prompt_embeds( 236 | prompt=prompt, 237 | num_videos_per_prompt=num_videos_per_prompt, 238 | max_sequence_length=max_sequence_length, 239 | device=device, 240 | dtype=dtype, 241 | ) 242 | 243 | if do_classifier_free_guidance and negative_prompt_embeds is None: 244 | negative_prompt = negative_prompt or "" 245 | negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt 246 | 247 | if prompt is not None and type(prompt) is not type(negative_prompt): 248 | raise TypeError( 249 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" 250 | f" {type(prompt)}." 251 | ) 252 | elif batch_size != len(negative_prompt): 253 | raise ValueError( 254 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" 255 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" 256 | " the batch size of `prompt`." 257 | ) 258 | 259 | negative_prompt_embeds = self._get_t5_prompt_embeds( 260 | prompt=negative_prompt, 261 | num_videos_per_prompt=num_videos_per_prompt, 262 | max_sequence_length=max_sequence_length, 263 | device=device, 264 | dtype=dtype, 265 | ) 266 | 267 | return prompt_embeds, negative_prompt_embeds 268 | 269 | def check_inputs( 270 | self, 271 | prompt, 272 | negative_prompt, 273 | height, 274 | width, 275 | prompt_embeds=None, 276 | negative_prompt_embeds=None, 277 | callback_on_step_end_tensor_inputs=None, 278 | ): 279 | if height % 16 != 0 or width % 16 != 0: 280 | raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") 281 | 282 | if callback_on_step_end_tensor_inputs is not None and not all( 283 | k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs 284 | ): 285 | raise ValueError( 286 | f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" 287 | ) 288 | 289 | if prompt is not None and prompt_embeds is not None: 290 | raise ValueError( 291 | f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" 292 | " only forward one of the two." 
293 | ) 294 | elif negative_prompt is not None and negative_prompt_embeds is not None: 295 | raise ValueError( 296 | f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to" 297 | " only forward one of the two." 298 | ) 299 | elif prompt is None and prompt_embeds is None: 300 | raise ValueError( 301 | "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 302 | ) 303 | elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): 304 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") 305 | elif negative_prompt is not None and ( 306 | not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) 307 | ): 308 | raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") 309 | 310 | @staticmethod 311 | def _normalize_latents( 312 | latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor 313 | ) -> torch.Tensor: 314 | latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(device=latents.device) 315 | latents_std = latents_std.view(1, -1, 1, 1, 1).to(device=latents.device) 316 | latents = ((latents.float() - latents_mean) * latents_std).to(latents) 317 | return latents 318 | 319 | @torch.no_grad() 320 | def prepare_conditon_latents_and_mask( 321 | self, 322 | batch_size: int, 323 | num_channels_latents: int = 16, 324 | height: int = 480, 325 | width: int = 832, 326 | video: Optional[torch.Tensor] = None, 327 | selected_cube: str = 'null', 328 | num_frames: int = 81, 329 | dtype: Optional[torch.dtype] = None, 330 | device: Optional[torch.device] = None, 331 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 332 | compute_posterior: bool = False, 333 | ) -> torch.Tensor: 334 | 335 | if selected_cube == 'null': 336 | cube_mask = get_mask(selected_cube,(batch_size,num_channels_latents,height,width)) 337 | cube_mask = F.interpolate(cube_mask, size=(cube_mask.shape[-2]//8,cube_mask.shape[-1]//8), mode='bilinear', align_corners=False) 338 | cube_mask = cube_mask.unsqueeze(2).repeat(1,1,(num_frames-1)//4+1,2,2) 339 | conditon_latents_and_mask = cube_mask.repeat(1,17,1,1,1) 340 | conditon_latents_and_mask = conditon_latents_and_mask.to(device=device,dtype=dtype) 341 | return conditon_latents_and_mask 342 | 343 | video = video[:num_frames,:,:,:] 344 | video = video.unsqueeze(0)/127.5 - 1.0 345 | assert video.ndim == 5, f"Expected 5D tensor, got {video.ndim}D tensor" 346 | video = video.to(device=device, non_blocking=True) 347 | video = video.permute(0, 2, 1, 3, 4) # [B, F, C, H, W] -> [B, C, F, H, W] 348 | 349 | b_,c_,f_,h_,w_ = video.shape 350 | 351 | cube_mask = get_mask(selected_cube,(b_,c_,h_,w_)) 352 | cube_mask = F.interpolate(cube_mask, size=(cube_mask.shape[-2]//8,cube_mask.shape[-1]//8), mode='bilinear', align_corners=False) 353 | cube_mask = cube_mask.unsqueeze(2).repeat(1,1,(f_-1)//4+1,1,1) 354 | cube_mask = cube_mask.to(device=device,dtype=dtype) 355 | 356 | # video = video.unsqueeze(1) 357 | # video = get_condition_tensors(selected_cube, video) 358 | 359 | video = get_condition_tensors(selected_cube,video.unsqueeze(1).float()) 360 | 361 | video = video.to(dtype=dtype).contiguous() 362 | 363 | 364 | if compute_posterior: 365 | latents = self.vae.encode(video).latent_dist.sample(generator=generator) 366 | latents = latents.to(dtype=dtype) 367 | else: 368 | # TODO(aryan): refactor in diffusers to 
have use_slicing attribute 369 | # if vae.use_slicing and video.shape[0] > 1: 370 | # encoded_slices = [vae._encode(x_slice) for x_slice in video.split(1)] 371 | # moments = torch.cat(encoded_slices) 372 | # else: 373 | # moments = vae._encode(video) 374 | moments = self.vae._encode(video) 375 | latents = moments.to(dtype=dtype) 376 | 377 | latents_mean = torch.tensor(self.vae.config.latents_mean) 378 | latents_std = 1.0 / torch.tensor(self.vae.config.latents_std) 379 | mu, logvar = torch.chunk(latents, 2, dim=1) 380 | mu = self._normalize_latents(mu, latents_mean, latents_std) 381 | logvar = self._normalize_latents(logvar, latents_mean, latents_std) 382 | latents = torch.cat([mu, logvar], dim=1) 383 | 384 | posterior = DiagonalGaussianDistribution(latents) 385 | latents = posterior.sample(generator=generator) 386 | del posterior 387 | 388 | # new_h = math.ceil(np.sqrt(2)*height/8) 389 | # new_w = math.ceil(np.sqrt(2)*width/8) 390 | #latents = get_condition_tensors(selected_cube, latents.unsqueeze(1).float(),new_h,new_w).to(latents) 391 | b,c,f,H,W = latents.shape 392 | 393 | condition_latents_canvas = latents #torch.zeros((b,c,f,h,w)).to(latents) 394 | condition_latents_mask = torch.zeros((b,1,f,H,W)).to(latents) 395 | 396 | h = H//2 397 | w = W//2 398 | if selected_cube == 'L': 399 | # condition_latents_canvas[:,:,:,0:h,0:w] = latents 400 | condition_latents_mask[:,:,:,0:h,0:w] = cube_mask 401 | elif selected_cube == 'F': 402 | #condition_latents_canvas[:,:,:,0:h,w:2*w] = latents 403 | condition_latents_mask[:,:,:,0:h,w:2*w] = cube_mask 404 | elif selected_cube == 'B': 405 | # condition_latents_canvas[:,:,:,h:2*h,0:w] = latents 406 | condition_latents_mask[:,:,:,h:2*h,0:w] = cube_mask 407 | elif selected_cube == 'R': 408 | # condition_latents_canvas[:,:,:,h:2*h,w:2*w] = latents 409 | condition_latents_mask[:,:,:,h:2*h,w:2*w] = cube_mask 410 | elif selected_cube == 'D': 411 | # condition_latents_canvas[:,:,:,h//2:3*(h//2),w//2:3*(w//2)] = latents 412 | condition_latents_mask[:,:,:,h//2:3*(h//2),w//2:3*(w//2)] = cube_mask 413 | return torch.cat([condition_latents_canvas,condition_latents_mask],dim=1) 414 | 415 | def prepare_latents( 416 | self, 417 | batch_size: int, 418 | num_channels_latents: int = 16, 419 | height: int = 480, 420 | width: int = 832, 421 | num_frames: int = 81, 422 | dtype: Optional[torch.dtype] = None, 423 | device: Optional[torch.device] = None, 424 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 425 | latents: Optional[torch.Tensor] = None, 426 | ) -> torch.Tensor: 427 | if latents is not None: 428 | return latents.to(device=device, dtype=dtype) 429 | 430 | num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 431 | shape = ( 432 | batch_size, 433 | num_channels_latents, 434 | num_latent_frames, 435 | int(height), 436 | int(width), 437 | ) 438 | if isinstance(generator, list) and len(generator) != batch_size: 439 | raise ValueError( 440 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 441 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
442 | ) 443 | 444 | latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) 445 | return latents 446 | 447 | @property 448 | def guidance_scale(self): 449 | return self._guidance_scale 450 | 451 | @property 452 | def do_classifier_free_guidance(self): 453 | return self._guidance_scale > 1.0 454 | 455 | @property 456 | def num_timesteps(self): 457 | return self._num_timesteps 458 | 459 | @property 460 | def current_timestep(self): 461 | return self._current_timestep 462 | 463 | @property 464 | def interrupt(self): 465 | return self._interrupt 466 | 467 | @property 468 | def attention_kwargs(self): 469 | return self._attention_kwargs 470 | 471 | @torch.no_grad() 472 | @replace_example_docstring(EXAMPLE_DOC_STRING) 473 | def __call__( 474 | self, 475 | prompt: Union[str, List[str]] = None, 476 | negative_prompt: Union[str, List[str]] = None, 477 | video: Optional[torch.Tensor] = None, 478 | selected_cube: str = 'null', 479 | height: int = 480, 480 | width: int = 832, 481 | num_frames: int = 81, 482 | num_inference_steps: int = 50, 483 | guidance_scale: float = 5.0, 484 | num_videos_per_prompt: Optional[int] = 1, 485 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 486 | latents: Optional[torch.Tensor] = None, 487 | prompt_embeds: Optional[torch.Tensor] = None, 488 | negative_prompt_embeds: Optional[torch.Tensor] = None, 489 | output_type: Optional[str] = "np", 490 | return_dict: bool = True, 491 | attention_kwargs: Optional[Dict[str, Any]] = None, 492 | callback_on_step_end: Optional[ 493 | Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] 494 | ] = None, 495 | callback_on_step_end_tensor_inputs: List[str] = ["latents"], 496 | max_sequence_length: int = 512, 497 | 498 | rotate: bool = False, 499 | ): 500 | r""" 501 | The call function to the pipeline for generation. 502 | 503 | Args: 504 | prompt (`str` or `List[str]`, *optional*): 505 | The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. 506 | instead. 507 | height (`int`, defaults to `480`): 508 | The height in pixels of the generated image. 509 | width (`int`, defaults to `832`): 510 | The width in pixels of the generated image. 511 | num_frames (`int`, defaults to `81`): 512 | The number of frames in the generated video. 513 | num_inference_steps (`int`, defaults to `50`): 514 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 515 | expense of slower inference. 516 | guidance_scale (`float`, defaults to `5.0`): 517 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). 518 | `guidance_scale` is defined as `w` of equation 2. of [Imagen 519 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 520 | 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, 521 | usually at the expense of lower image quality. 522 | num_videos_per_prompt (`int`, *optional*, defaults to 1): 523 | The number of images to generate per prompt. 524 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*): 525 | A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make 526 | generation deterministic. 527 | latents (`torch.Tensor`, *optional*): 528 | Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image 529 | generation. 
Can be used to tweak the same generation with different prompts. If not provided, a latents 530 | tensor is generated by sampling using the supplied random `generator`. 531 | prompt_embeds (`torch.Tensor`, *optional*): 532 | Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not 533 | provided, text embeddings are generated from the `prompt` input argument. 534 | output_type (`str`, *optional*, defaults to `"pil"`): 535 | The output format of the generated image. Choose between `PIL.Image` or `np.array`. 536 | return_dict (`bool`, *optional*, defaults to `True`): 537 | Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple. 538 | attention_kwargs (`dict`, *optional*): 539 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under 540 | `self.processor` in 541 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 542 | callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): 543 | A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of 544 | each denoising step during the inference. with the following arguments: `callback_on_step_end(self: 545 | DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a 546 | list of all tensors as specified by `callback_on_step_end_tensor_inputs`. 547 | callback_on_step_end_tensor_inputs (`List`, *optional*): 548 | The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list 549 | will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the 550 | `._callback_tensor_inputs` attribute of your pipeline class. 551 | autocast_dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`): 552 | The dtype to use for the torch.amp.autocast. 553 | 554 | Examples: 555 | 556 | Returns: 557 | [`~WanPipelineOutput`] or `tuple`: 558 | If `return_dict` is `True`, [`WanPipelineOutput`] is returned, otherwise a `tuple` is returned where 559 | the first element is a list with the generated images and the second element is a list of `bool`s 560 | indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. 561 | """ 562 | 563 | if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): 564 | callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs 565 | 566 | # 1. Check inputs. Raise error if not correct 567 | self.check_inputs( 568 | prompt, 569 | negative_prompt, 570 | height, 571 | width, 572 | prompt_embeds, 573 | negative_prompt_embeds, 574 | callback_on_step_end_tensor_inputs, 575 | ) 576 | 577 | self._guidance_scale = guidance_scale 578 | self._attention_kwargs = attention_kwargs 579 | self._current_timestep = None 580 | self._interrupt = False 581 | 582 | device = self._execution_device 583 | 584 | # 2. Define call parameters 585 | if prompt is not None and isinstance(prompt, str): 586 | batch_size = 1 587 | elif prompt is not None and isinstance(prompt, list): 588 | batch_size = len(prompt) 589 | else: 590 | batch_size = prompt_embeds.shape[0] 591 | 592 | # 3. 
Encode input prompt 593 | prompt_embeds, negative_prompt_embeds = self.encode_prompt( 594 | prompt=prompt, 595 | negative_prompt=negative_prompt, 596 | do_classifier_free_guidance=self.do_classifier_free_guidance, 597 | num_videos_per_prompt=num_videos_per_prompt, 598 | prompt_embeds=prompt_embeds, 599 | negative_prompt_embeds=negative_prompt_embeds, 600 | max_sequence_length=max_sequence_length, 601 | device=device, 602 | ) 603 | 604 | transformer_dtype = self.transformer.dtype 605 | prompt_embeds = prompt_embeds.to(transformer_dtype) 606 | if negative_prompt_embeds is not None: 607 | negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) 608 | 609 | # 4. Prepare timesteps 610 | self.scheduler.set_timesteps(num_inference_steps, device=device) 611 | timesteps = self.scheduler.timesteps 612 | 613 | # 5. Prepare latent variables 614 | num_channels_latents = 16 615 | 616 | # TODO(fzx): condition latents and mask 617 | condition_latents_mask = self.prepare_conditon_latents_and_mask( 618 | batch_size * num_videos_per_prompt, 619 | num_channels_latents, 620 | height, 621 | width, 622 | video, 623 | selected_cube, 624 | num_frames, 625 | self.vae.dtype, 626 | device, 627 | generator, 628 | ) 629 | 630 | latents = self.prepare_latents( 631 | batch_size * num_videos_per_prompt, 632 | num_channels_latents, 633 | condition_latents_mask.shape[-2], 634 | condition_latents_mask.shape[-1], 635 | num_frames, 636 | torch.float32, 637 | device, 638 | generator, 639 | latents, 640 | ) 641 | 642 | # 6. Denoising loop 643 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 644 | self._num_timesteps = len(timesteps) 645 | 646 | with self.progress_bar(total=num_inference_steps) as progress_bar: 647 | for i, t in enumerate(timesteps): 648 | if self.interrupt: 649 | continue 650 | 651 | self._current_timestep = t 652 | latent_model_input = torch.cat([latents.to(transformer_dtype),condition_latents_mask],dim=1) 653 | timestep = t.expand(latents.shape[0]) 654 | 655 | noise_pred = self.transformer( 656 | hidden_states=latent_model_input, 657 | timestep=timestep, 658 | encoder_hidden_states=prompt_embeds, 659 | attention_kwargs=attention_kwargs, 660 | return_dict=False, 661 | )[0] 662 | 663 | if self.do_classifier_free_guidance: 664 | noise_uncond = self.transformer( 665 | hidden_states=latent_model_input, 666 | timestep=timestep, 667 | encoder_hidden_states=negative_prompt_embeds, 668 | attention_kwargs=attention_kwargs, 669 | return_dict=False, 670 | )[0] 671 | noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) 672 | 673 | # compute the previous noisy sample x_t -> x_t-1 674 | latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] 675 | if rotate and (i < 8): 676 | latents = torch.rot90(latents, k=1, dims=(-2, -1)) 677 | condition_latents_mask = torch.rot90(condition_latents_mask, k=1, dims=(-2, -1)) 678 | 679 | if callback_on_step_end is not None: 680 | callback_kwargs = {} 681 | for k in callback_on_step_end_tensor_inputs: 682 | callback_kwargs[k] = locals()[k] 683 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) 684 | 685 | latents = callback_outputs.pop("latents", latents) 686 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) 687 | negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) 688 | 689 | # call the callback, if provided 690 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 691 | 
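# Advance the progress bar only on the final timestep, or once the warm-up steps have
# passed and the step index lines up with the scheduler order.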
progress_bar.update() 692 | 693 | if XLA_AVAILABLE: 694 | xm.mark_step() 695 | 696 | self._current_timestep = None 697 | 698 | if not output_type == "latent": 699 | # TODO(fzx): reshape to decode 700 | 701 | # L_F, B_R = torch.chunk(latents, 2, dim=-2) 702 | # L,F = torch.chunk(L_F, 2, dim=-1) 703 | # B,R = torch.chunk(B_R, 2, dim=-1) 704 | # latents = torch.cat([L,F,B,R], dim=0) 705 | 706 | latents = latents.to(self.vae.dtype) 707 | latents_mean = ( 708 | torch.tensor(self.vae.config.latents_mean) 709 | .view(1, self.vae.config.z_dim, 1, 1, 1) 710 | .to(latents.device, latents.dtype) 711 | ) 712 | latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( 713 | latents.device, latents.dtype 714 | ) 715 | latents = latents / latents_std + latents_mean 716 | video = self.vae.decode(latents, return_dict=False)[0] 717 | video = self.video_processor.postprocess_video(video, output_type=output_type) 718 | else: 719 | video = latents 720 | 721 | # Offload all models 722 | self.maybe_free_model_hooks() 723 | 724 | if not return_dict: 725 | return (video,) 726 | 727 | return WanPipelineOutput(frames=video) 728 | 729 | @dataclass 730 | class WanPipelineOutput(BaseOutput): 731 | r""" 732 | Output class for Wan pipelines. 733 | 734 | Args: 735 | frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): 736 | List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing 737 | denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape 738 | `(batch_size, num_frames, channels, height, width)`. 739 | """ 740 | 741 | frames: torch.Tensor --------------------------------------------------------------------------------
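The overlap handling in `convert_to_cubemap2` above cross-fades each side face with a 90-degree-rotated copy of its neighbour over the quadrant they share, using the symmetric ramp from `generate_linear_gradient` as the blend weight. The standalone sketch below illustrates only that weighting step on toy tensors; `linear_gradient`, `face_a` and `face_b` are illustrative stand-ins, not helpers from this repository.

import torch

def linear_gradient(n: int) -> torch.Tensor:
    # Average of a top-to-bottom ramp and its 90-degree rotation, mirroring
    # generate_linear_gradient above.
    g = torch.zeros((n, n))
    for i in range(n):
        g[i, :] = 1 - i / (n - 1)
    return (g + torch.rot90(g)) / 2

# Two toy overlap quadrants shaped like the (batch, channel, h, w) slices used above.
n = 4
face_a = torch.ones(1, 3, n, n)    # stands in for e.g. F_'s corner quadrant
face_b = torch.zeros(1, 3, n, n)   # stands in for the rotated neighbour's quadrant

w = linear_gradient(n)
# Convex blend, matching the pattern w * face + (1 - w) * rotated_neighbour used for
# the *_norm tensors in convert_to_cubemap2.
blended = w * face_a + (1 - w) * face_b
print(blended.shape)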