├── gvm
│   ├── __init__.py
│   ├── models
│   │   ├── __init__.py
│   │   └── unet_spatio_temporal_condition.py
│   ├── utils
│   │   ├── __init__.py
│   │   └── inference_utils.py
│   └── pipelines
│       └── pipeline_gvm.py
├── requirements.txt
├── setup.py
├── README.md
└── demo.py

/gvm/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/gvm/models/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/gvm/utils/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | av==11.0.0
2 | diffusers==0.35.0.dev0
3 | easydict==1.13
4 | imageio==2.37.0
5 | matplotlib==3.10.5
6 | numpy==2.3.2
7 | opencv_python_headless==4.11.0.86
8 | peft==0.17.0
9 | Pillow==11.3.0
10 | PIMS==0.7
11 | setuptools==75.9.1
12 | torch==2.6.0
13 | torchvision==0.21.0
14 | tqdm==4.67.1
15 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | import torch
4 | from setuptools import find_packages, setup
5 | from torch.utils.cpp_extension import (BuildExtension, CppExtension,
6 |                                        CUDAExtension)
7 | 
8 | with open('README.md', 'r') as fh:
9 |     long_description = fh.read()
10 | 
11 | BUILD_CUDA = os.getenv("BUILD_CUDA", "1") == "1"
12 | BUILD_ALLOW_ERRORS = os.getenv("BUILD_ALLOW_ERRORS", "1") == "1"
13 | 
14 | CUDA_ERROR_MSG = (
15 |     "{}\n\n"
16 |     "Failed to build the CUDA extension due to the error above. 
" 17 | "You can still use and it's OK to ignore the error above, although some " 18 | "post-processing functionality may be limited (which doesn't affect the results in most cases; " 19 | "(see https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).\n" 20 | ) 21 | 22 | 23 | setup( 24 | name='gvm', 25 | version='0.0.1', 26 | author='yongtaoge', 27 | author_email='yongtao.ge@adelaide.edu.au', 28 | description='Code for Generative Video Matting.', 29 | long_description=long_description, 30 | long_description_content_type='text/markdown', 31 | url=None, 32 | packages=find_packages(exclude=('configs', 'docs', 'scripts', 'extensions', 'data', 'requirements'),), 33 | classifiers=[ 34 | 'Programming Language :: Python :: 3', 35 | 'Operating System :: OS Independent', 36 | ], 37 | install_requires=[], 38 | ) 39 | -------------------------------------------------------------------------------- /gvm/utils/inference_utils.py: -------------------------------------------------------------------------------- 1 | import av 2 | import os 3 | import pims 4 | import numpy as np 5 | import torch 6 | from torch.utils.data import Dataset 7 | from torchvision.transforms.functional import to_pil_image 8 | from PIL import Image 9 | 10 | 11 | class VideoReader(Dataset): 12 | def __init__(self, path, max_frames=None, transform=None): 13 | self.video = pims.PyAVVideoReader(path) 14 | self.rate = self.video.frame_rate 15 | self.transform = transform 16 | self.max_frames = max_frames 17 | 18 | @property 19 | def frame_rate(self): 20 | return self.rate 21 | 22 | @property 23 | def origin_shape(self): 24 | return self.video[0].shape[:2] 25 | 26 | def __len__(self): 27 | if self.max_frames is not None and self.max_frames > 0: 28 | return min(len(self.video), self.max_frames) 29 | else: 30 | return len(self.video) 31 | 32 | def __getitem__(self, idx): 33 | frame = self.video[idx] 34 | frame = Image.fromarray(np.asarray(frame)) 35 | if self.transform is not None: 36 | frame = self.transform(frame) 37 | return frame 38 | 39 | 40 | class VideoWriter: 41 | def __init__(self, path, frame_rate, bit_rate=1000000): 42 | self.container = av.open(path, mode='w') 43 | # self.container.add_stream('h264', rate=30) 44 | self.stream = self.container.add_stream('h264', rate=f'{frame_rate:.4f}') 45 | self.stream.pix_fmt = 'yuv420p' 46 | self.stream.bit_rate = bit_rate 47 | 48 | def write(self, frames): 49 | 50 | # frames: [T, C, H, W] 51 | self.stream.width = frames.size(3) 52 | self.stream.height = frames.size(2) 53 | if frames.size(1) == 1: 54 | frames = frames.repeat(1, 3, 1, 1) # convert grayscale to RGB 55 | frames = frames.mul(255).byte().cpu().permute(0, 2, 3, 1).numpy() 56 | 57 | for t in range(frames.shape[0]): 58 | frame = frames[t] 59 | frame = av.VideoFrame.from_ndarray(frame, format='rgb24') 60 | self.container.mux(self.stream.encode(frame)) 61 | 62 | def write_numpy(self, frames): 63 | 64 | # frames: [T, H, W, C] 65 | self.stream.height = frames.shape[1] 66 | self.stream.width = frames.shape[2] 67 | 68 | for t in range(frames.shape[0]): 69 | frame = frames[t] 70 | frame = av.VideoFrame.from_ndarray(frame, format='rgb24') 71 | self.container.mux(self.stream.encode(frame)) 72 | 73 | def close(self): 74 | self.container.mux(self.stream.encode()) 75 | self.container.close() 76 | 77 | 78 | class ImageSequenceReader(Dataset): 79 | def __init__(self, path, transform=None): 80 | self.path = path 81 | self.files = sorted(os.listdir(path)) 82 | self.transform = transform 83 | 84 | @property 85 | def 
origin_shape(self): 86 | return np.array(Image.open(os.path.join(self.path, self.files[0]))).shape[:2] 87 | 88 | def __len__(self): 89 | return len(self.files) 90 | 91 | def __getitem__(self, idx): 92 | with Image.open(os.path.join(self.path, self.files[idx])) as img: 93 | img.load() 94 | 95 | origin_shape = torch.from_numpy(np.asarray(np.array(img).shape[:2])) 96 | 97 | if self.transform is not None: 98 | img, filename = self.transform(img), self.files[idx] 99 | else: 100 | filename = self.files[idx] 101 | 102 | return {"image": img, "filename": filename, "origin_shape": origin_shape} 103 | 104 | 105 | class ImageSequenceWriter: 106 | def __init__(self, path, extension='jpg'): 107 | self.path = path 108 | self.extension = extension 109 | self.counter = 0 110 | os.makedirs(path, exist_ok=True) 111 | 112 | def write(self, frames, filenames=None): 113 | # frames: [T, C, H, W] 114 | for t in range(frames.shape[0]): 115 | if filenames is None: 116 | filename = str(self.counter).zfill(4) + '.' + self.extension 117 | else: 118 | filename = filenames[t].split('.')[0] + '.' + self.extension 119 | 120 | to_pil_image(frames[t]).save(os.path.join( 121 | self.path, filename)) 122 | self.counter += 1 123 | 124 | def close(self): 125 | pass 126 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
17 | 
18 | 
19 | ## 📖 Table of Contents
20 | 
21 | - [Generative Video Matting](#-generative-video-matting)
22 | - [🔥 News](#-news)
23 | - [🚀 Getting Started](#-getting-started)
24 |   - [Environment Requirement 🌍](#environment-requirement-)
25 |   - [Download Model Weights ⬇️](#download-️model-weights-)
26 | - [🏃🏼 Run](#-run)
27 |   - [Inference 📜](#inference-)
28 |   - [Evaluation 📏](#evaluation-)
29 | - [🎫 License](#-license)
30 | - [📢 Disclaimer](#-disclaimer)
31 | - [🤝 Cite Us](#-cite-us)
32 | 
33 | ## 🔥 News
34 | - **August 10, 2025:** Released the inference code and model checkpoints.
35 | - **June 11, 2025:** Repo created. The code and dataset for this project are currently being prepared for release and will be available here soon. Please stay tuned!
36 | 
37 | 
38 | ## 🚀 Getting Started
39 | 
40 | ### Environment Requirement 🌍
41 | 
42 | First, clone the repo:
43 | 
44 | ```
45 | git clone https://github.com/aim-uofa/GVM.git
46 | cd GVM
47 | ```
48 | 
49 | Then, we recommend using `conda` to create a virtual environment and install the required libraries. For example:
50 | 
51 | ```
52 | conda create -n gvm python=3.10 -y
53 | conda activate gvm
54 | pip install -r requirements.txt
55 | python setup.py develop
56 | ```
57 | 
58 | ### Download Model Weights ⬇️
59 | 
60 | Download the model weights with:
61 | 
62 | ```
63 | huggingface-cli download geyongtao/gvm --local-dir data/weights
64 | ```
65 | 
66 | The checkpoint directory should be structured as follows:
67 | 
68 | ```
69 | |-- GVM
70 |     |-- data
71 |         |-- weights
72 |             |-- vae
73 |                 |-- config.json
74 |                 |-- diffusion_pytorch_model.safetensors
75 |             |-- unet
76 |                 |-- config.json
77 |                 |-- diffusion_pytorch_model.safetensors
78 |             |-- scheduler
79 |                 |-- scheduler_config.json
80 |         |-- datasets
81 |         |-- demo_videos
82 | ```
83 | 
84 | 
85 | 
86 | ## 🏃🏼 Run
87 | 
88 | ### Inference 📜
89 | 
90 | You can run generative video matting with:
91 | 
92 | ```
93 | python demo.py \
94 | --model_base 'data/weights/' \
95 | --unet_base data/weights/unet \
96 | --lora_base data/weights/unet \
97 | --mode 'matte' \
98 | --num_frames_per_batch 8 \
99 | --num_interp_frames 1 \
100 | --num_overlap_frames 1 \
101 | --denoise_steps 1 \
102 | --decode_chunk_size 8 \
103 | --max_resolution 960 \
104 | --pretrain_type 'svd' \
105 | --data_dir 'data/demo_videos/xxx.mp4' \
106 | --output_dir 'output_path'
107 | ```
108 | 
109 | 
110 | ### Evaluation 📏
111 | 
112 | ```
113 | TODO
114 | ```
115 | 
116 | 
117 | ## 🎫 License
118 | 
119 | For academic usage, this project is licensed under [the 2-clause BSD License](LICENSE). For commercial inquiries, please contact [Chunhua Shen](mailto:chhshen@gmail.com).
120 | 
121 | 
122 | ## 📢 Disclaimer
123 | 
124 | This repository provides a one-step model for faster inference. Its performance differs slightly from the results reported in the original SIGGRAPH paper.
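For reference, `demo.py` assembles the released weights into `GVMPipeline` roughly as in the sketch below, using the same one-step setting (`num_inference_steps=1`, `noise_type='zeros'`) mentioned above. This is an illustrative, condensed snippet: the clip path, frame count, and resize values are placeholders, and the LoRA loading and multiple-of-32 padding that `demo.py` performs are omitted, so treat it as a starting point rather than a drop-in replacement for the demo script.

```
import torch
from torchvision.transforms import Compose, Resize, ToTensor
from diffusers import AutoencoderKLTemporalDecoder, FlowMatchEulerDiscreteScheduler

from gvm.models.unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
from gvm.pipelines.pipeline_gvm import GVMPipeline
from gvm.utils.inference_utils import VideoReader

weights = "data/weights"  # directory created by the download step above

vae = AutoencoderKLTemporalDecoder.from_pretrained(weights, subfolder="vae")
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(weights, subfolder="scheduler")
unet = UNetSpatioTemporalConditionModel.from_pretrained(weights, subfolder="unet", class_embed_type=None)

pipe = GVMPipeline(vae=vae, unet=unet, scheduler=scheduler).to("cuda")

# Read a short clip as an (N, 3, H, W) tensor in [0, 1]; demo.py additionally pads H/W to a multiple of 32.
reader = VideoReader(
    "data/demo_videos/example.mp4",  # placeholder path
    max_frames=16,
    transform=Compose([ToTensor(), Resize(size=720, max_size=960)]),
)
frames = torch.stack([reader[i] for i in range(len(reader))])

with torch.no_grad(), torch.autocast("cuda", dtype=torch.float16):
    out = pipe(
        frames.cuda(),
        num_frames=8,
        num_overlap_frames=1,
        num_interp_frames=1,
        decode_chunk_size=8,
        num_inference_steps=1,
        noise_type="zeros",
        mode="matte",
    )

alpha = out.alpha  # (N, 1, H, W) matte in [0, 1]
```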
125 | 126 | ## 🤝 Cite Us 127 | 128 | If you find this work helpful for your research, please cite: 129 | ``` 130 | @inproceedings{ge2025gvm, 131 | author = {Ge, Yongtao and Xie, Kangyang and Xu, Guangkai and Ke, Li and Liu, Mingyu and Huang, Longtao and Xue, Hui and Chen, Hao and Shen, Chunhua}, 132 | title = {Generative Video Matting}, 133 | publisher = {Association for Computing Machinery}, 134 | url = {https://doi.org/10.1145/3721238.3730642}, 135 | doi = {10.1145/3721238.3730642}, 136 | booktitle = {Proceedings of the Special Interest Group on Computer Graphics and Interactive Techniques Conference Conference Papers}, 137 | series = {SIGGRAPH Conference Papers '25} 138 | } 139 | ``` -------------------------------------------------------------------------------- /gvm/pipelines/pipeline_gvm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import tqdm 3 | import numpy as np 4 | from diffusers import DiffusionPipeline 5 | from diffusers.utils import ( 6 | BaseOutput, 7 | USE_PEFT_BACKEND, 8 | is_peft_available, 9 | is_peft_version, 10 | is_torch_version, 11 | logging 12 | ) 13 | from diffusers.loaders.lora_pipeline import ( 14 | _LOW_CPU_MEM_USAGE_DEFAULT_LORA, 15 | StableDiffusionLoraLoaderMixin 16 | ) 17 | from peft import LoraConfig, LoraModel, set_peft_model_state_dict 18 | import os 19 | 20 | import matplotlib 21 | from typing import Union, Dict 22 | logger = logging.get_logger(__name__) 23 | 24 | 25 | class GVMLoraLoader(StableDiffusionLoraLoaderMixin): 26 | _lora_loadable_modules = ["unet"] 27 | def __init__(self, *args, **kwargs): 28 | super().__init__(*args, **kwargs) 29 | 30 | def load_lora_weights( 31 | self, 32 | pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], 33 | adapter_name=None, 34 | hotswap: bool = False, 35 | **kwargs 36 | ): 37 | 38 | unet_lora_config = LoraConfig.from_pretrained(pretrained_model_name_or_path_or_dict) 39 | checkpoint = os.path.join(pretrained_model_name_or_path_or_dict, f"unet_lora.pt") 40 | unet_lora_ckpt = torch.load(checkpoint) 41 | self.unet = LoraModel(self.unet, unet_lora_config, "default") 42 | set_peft_model_state_dict(self.unet, unet_lora_ckpt) 43 | 44 | 45 | class GVMOutput(BaseOutput): 46 | r""" 47 | Output class for zero-shot text-to-video pipeline. 48 | 49 | Args: 50 | frames (`[List[PIL.Image.Image]`, `np.ndarray`]): 51 | List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width, 52 | num_channels)`. 
53 | """ 54 | alpha: np.ndarray 55 | image: np.ndarray 56 | 57 | class GVMPipeline(DiffusionPipeline, GVMLoraLoader): 58 | def __init__(self, vae, unet, scheduler): 59 | super().__init__() 60 | self.register_modules( 61 | vae=vae, unet=unet, scheduler=scheduler 62 | ) 63 | 64 | def encode(self, input): 65 | num_frames = input.shape[1] 66 | input = input.flatten(0, 1) 67 | latent = self.vae.encode(input.to(self.vae.dtype)).latent_dist.mode() 68 | latent = latent * self.vae.config.scaling_factor 69 | latent = latent.reshape(-1, num_frames, *latent.shape[1:]) 70 | return latent 71 | 72 | def decode(self, latents, decode_chunk_size=16): 73 | # [batch, frames, channels, height, width] -> [batch*frames, channels, height, width] 74 | num_frames = latents.shape[1] 75 | latents = latents.flatten(0, 1) 76 | latents = latents / self.vae.config.scaling_factor 77 | 78 | # decode decode_chunk_size frames at a time to avoid OOM 79 | frames = [] 80 | for i in range(0, latents.shape[0], decode_chunk_size): 81 | num_frames_in = latents[i : i + decode_chunk_size].shape[0] 82 | frame = self.vae.decode( 83 | latents[i : i + decode_chunk_size].to(self.vae.dtype), 84 | num_frames=num_frames_in, 85 | ).sample 86 | frames.append(frame) 87 | frames = torch.cat(frames, dim=0) 88 | 89 | # [batch, frames, channels, height, width] 90 | frames = frames.reshape(-1, num_frames, *frames.shape[1:]) 91 | return frames.to(torch.float32) 92 | 93 | 94 | def single_infer(self, rgb, position_ids=None, num_inference_steps=None, class_labels=None, noise_type="gaussian"): 95 | rgb_latent = self.encode(rgb) 96 | 97 | self.scheduler.set_timesteps(num_inference_steps, device=rgb.device) 98 | 99 | if noise_type == "gaussian": 100 | noise_latent = torch.randn_like(rgb_latent) 101 | timesteps = self.scheduler.timesteps 102 | elif noise_type == "zeros": 103 | noise_latent = torch.zeros_like(rgb_latent) 104 | timesteps = torch.ones_like(self.scheduler.timesteps) * (self.scheduler.config.num_train_timesteps - 1) # 999 105 | timesteps = timesteps.long() 106 | else: 107 | raise NotImplementedError 108 | 109 | image_embeddings = torch.zeros((noise_latent.shape[0], 1, 1024)).to( 110 | noise_latent 111 | ) 112 | 113 | for i, t in enumerate(timesteps): 114 | latent_model_input = noise_latent 115 | latent_model_input = torch.cat([latent_model_input, rgb_latent], dim=2) 116 | # [batch_size, num_frame, 4, h, w] 117 | model_output = self.unet( 118 | latent_model_input, 119 | t, 120 | encoder_hidden_states=image_embeddings, 121 | position_ids=position_ids, 122 | class_labels=class_labels, 123 | ).sample 124 | 125 | if noise_type == 'zeros': 126 | noise_latent = model_output 127 | else: 128 | # compute the previous noisy sample x_t -> x_t-1 129 | noise_latent = self.scheduler.step( 130 | model_output, t, noise_latent 131 | ).prev_sample 132 | 133 | return noise_latent 134 | 135 | 136 | def __call__( 137 | self, 138 | image, 139 | num_frames, 140 | num_overlap_frames, 141 | num_interp_frames, 142 | decode_chunk_size, 143 | num_inference_steps, 144 | use_clip_img_emb=False, 145 | noise_type='zeros', 146 | mode='matte', 147 | ensemble_size: int = 3, 148 | ): 149 | 150 | assert ensemble_size >= 1 151 | self.vae.to(dtype=torch.float16) 152 | class_embedding = None 153 | 154 | # (1, N, 3, H, W) 155 | image = image.unsqueeze(0) 156 | B, N = image.shape[:2] 157 | rgb_norm = image * 2 - 1 # [-1, 1] 158 | 159 | rgb = rgb_norm.expand(ensemble_size, -1, -1, -1, -1) 160 | if N <= num_frames: 161 | position_ids = torch.arange(N).unsqueeze(0).repeat(B, 
1).to(rgb.device) 162 | position_ids = torch.zeros_like(position_ids) 163 | position_ids = None 164 | 165 | latent_all = self.single_infer( 166 | rgb, 167 | num_inference_steps=num_inference_steps, 168 | class_labels=class_embedding, 169 | position_ids=position_ids, 170 | noise_type=noise_type 171 | ) 172 | else: 173 | # assert 2 <= num_overlap_frames <= (num_interp_frames + 2 + 1) // 2 174 | assert num_frames % 2 == 0 175 | # num_interp_frames = num_frames - 2 176 | key_frame_indices = [] 177 | for i in range(0, N, num_frames - num_overlap_frames): 178 | if ( 179 | i + num_frames - 1 >= N 180 | or len(key_frame_indices) >= num_frames 181 | ): 182 | 183 | # print(i) 184 | pass 185 | 186 | key_frame_indices.append(i) 187 | key_frame_indices.append(min(N - 1, i + num_frames - 1)) 188 | 189 | key_frame_indices = torch.tensor(key_frame_indices, device=rgb.device) 190 | 191 | latent_all = None 192 | pre_latent = None 193 | 194 | for i in tqdm.tqdm(range(0, len(key_frame_indices), 2)): 195 | position_ids = torch.arange(0, key_frame_indices[i + 1] - key_frame_indices[i] + 1).to(rgb.device) 196 | position_ids = position_ids.unsqueeze(0).repeat(B, 1) 197 | position_ids = None 198 | latent = self.single_infer( 199 | rgb[:, key_frame_indices[i] : key_frame_indices[i + 1] + 1], 200 | position_ids=position_ids, 201 | num_inference_steps=num_inference_steps, 202 | class_labels=class_embedding 203 | ) 204 | 205 | if pre_latent is not None: 206 | ratio = ( 207 | torch.linspace(0, 1, num_overlap_frames) 208 | .to(latent) 209 | .view(1, -1, 1, 1, 1) 210 | ) 211 | try: 212 | latent_all[:, -num_overlap_frames:] = latent[:,:num_overlap_frames] * ratio + latent_all[:, -num_overlap_frames:] * (1 - ratio) 213 | except: 214 | num_overlap_frames = min(num_overlap_frames, latent.shape[1]) 215 | ratio = ( 216 | torch.linspace(0, 1, num_overlap_frames) 217 | .to(latent) 218 | .view(1, -1, 1, 1, 1) 219 | ) 220 | latent_all[:, -num_overlap_frames:] = latent[:,:num_overlap_frames] * ratio + latent_all[:, -num_overlap_frames:] * (1 - ratio) 221 | latent_all = torch.cat([latent_all, latent[:,num_overlap_frames:]], dim=1) 222 | else: 223 | latent_all = latent.clone() 224 | 225 | pre_latent = latent 226 | torch.cuda.empty_cache() 227 | 228 | assert latent_all.shape[1] == image.shape[1] 229 | 230 | alpha = self.decode(latent_all, decode_chunk_size=decode_chunk_size) 231 | 232 | # (N_videos, num_frames, H, W, 3) 233 | alpha = alpha.mean(dim=2, keepdim=True) 234 | alpha, _ = torch.max(alpha, dim=0) 235 | alpha = torch.clamp(alpha * 0.5 + 0.5, 0.0, 1.0) 236 | 237 | if alpha.dim() == 5: 238 | alpha = alpha.squeeze(0) 239 | 240 | # (N, H, W, 3) 241 | image = image.squeeze(0) 242 | 243 | return GVMOutput( 244 | alpha=alpha, 245 | image=image, 246 | ) -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import os.path as osp 5 | import cv2 6 | import random 7 | 8 | from easydict import EasyDict 9 | import numpy as np 10 | import torch 11 | import torch.nn.functional as F 12 | from torch.utils.data import DataLoader 13 | from torchvision import transforms 14 | from torchvision.transforms import ToTensor, Resize, Compose 15 | from diffusers import AutoencoderKLTemporalDecoder, FlowMatchEulerDiscreteScheduler 16 | 17 | from gvm.pipelines.pipeline_gvm import GVMPipeline 18 | from gvm.utils.inference_utils import VideoReader, VideoWriter, ImageSequenceReader, 
ImageSequenceWriter 19 | from gvm.models.unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel 20 | from tqdm import tqdm 21 | 22 | def sequence_collate_fn(examples): 23 | rgb_values = torch.stack([example["image"] for example in examples]) 24 | rgb_values = rgb_values.to(memory_format=torch.contiguous_format).float() 25 | return {'rgb_values': rgb_values, 'rgb_names': rgb_names} 26 | 27 | def impad_multi(img, multiple=32): 28 | 29 | target_h = int(np.ceil(img.shape[2] / multiple) * multiple) 30 | target_w = int(np.ceil(img.shape[3] / multiple) * multiple) 31 | 32 | pad_top = (target_h - img.shape[2]) // 2 33 | pad_bottom = target_h - img.shape[2] - pad_top 34 | pad_left = (target_w - img.shape[3]) // 2 35 | pad_right = target_w - img.shape[3] - pad_left 36 | 37 | padded = torch.zeros((img.shape[0], img.shape[1], target_h, target_w), dtype=img.dtype) 38 | padded[:, :, pad_top:pad_top + img.shape[2], pad_left:pad_left + img.shape[3]] = img 39 | 40 | return padded, (pad_top, pad_left, pad_bottom, pad_right) 41 | 42 | 43 | def seed_all(seed: int = 0): 44 | """ 45 | Set random seeds of all components. 46 | """ 47 | random.seed(seed) 48 | np.random.seed(seed) 49 | torch.manual_seed(seed) 50 | torch.cuda.manual_seed_all(seed) 51 | 52 | 53 | def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90): 54 | blurred_alpha = cv2.blur(alpha, (r, r))[:, :, None] 55 | 56 | blurred_FA = cv2.blur(F * alpha, (r, r)) 57 | blurred_F = blurred_FA / (blurred_alpha + 1e-5) 58 | 59 | blurred_B1A = cv2.blur(B * (1 - alpha), (r, r)) 60 | blurred_B = blurred_B1A / ((1 - blurred_alpha) + 1e-5) 61 | F = blurred_F + alpha * \ 62 | (image - alpha * blurred_F - (1 - alpha) * blurred_B) 63 | F = np.clip(F, 0, 1) 64 | return F, blurred_B 65 | 66 | 67 | def FB_blur_fusion_foreground_estimator_1(image, alpha, r=90): 68 | alpha = alpha[:, :, None] 69 | return FB_blur_fusion_foreground_estimator(image, F=image, B=image, alpha=alpha, r=r)[0] 70 | 71 | 72 | def FB_blur_fusion_foreground_estimator_2(image, alpha, r=90): 73 | alpha = alpha[:, :, None] 74 | F, blur_B = FB_blur_fusion_foreground_estimator( 75 | image, image, image, alpha, r) 76 | return FB_blur_fusion_foreground_estimator(image, F, blur_B, alpha, r=6)[0] 77 | 78 | 79 | if "__main__" == __name__: 80 | logging.basicConfig(level=logging.INFO) 81 | 82 | parser = argparse.ArgumentParser( 83 | description="Run video matte." 84 | ) 85 | parser.add_argument( 86 | "--mode", 87 | type=str, 88 | default="matte", 89 | help="Inference mode.", 90 | ) 91 | parser.add_argument( 92 | "--variant", 93 | type=str, 94 | default="fp16", 95 | help=".", 96 | ) 97 | parser.add_argument( 98 | "--model_base", 99 | type=str, 100 | default="data/weights", 101 | help="Checkpoint path or hub name.", 102 | ) 103 | parser.add_argument( 104 | "--pretrain_type", 105 | type=str, 106 | default="dav", 107 | help="Checkpoint path or hub name.", 108 | ) 109 | parser.add_argument( 110 | "--unet_base", 111 | type=str, 112 | default=None, 113 | help="Checkpoint path or hub name.", 114 | ) 115 | parser.add_argument( 116 | "--lora_base", 117 | type=str, 118 | default=None, 119 | help="Checkpoint path or hub name.", 120 | ) 121 | parser.add_argument( 122 | "--noise_type", 123 | type=str, 124 | default='zeros', 125 | choices=['gaussian', 'zeros'], 126 | ) 127 | # data setting 128 | parser.add_argument( 129 | "--data_dir", type=str, required=True, help="input data directory." 130 | ) 131 | 132 | parser.add_argument( 133 | "--output_dir", type=str, required=True, help="Output directory." 
134 |     )
135 | 
136 |     # inference setting
137 |     parser.add_argument(
138 |         "--denoise_steps",
139 |         type=int,
140 |         default=1,
141 |         help="Number of denoising steps; 1-3 steps work fine.",
142 |     )
143 |     parser.add_argument(
144 |         "--num_frames_per_batch",
145 |         type=int,
146 |         default=32,
147 |         help="Number of frames to infer per forward pass.",
148 |     )
149 |     parser.add_argument(
150 |         "--max_frames",
151 |         type=int,
152 |         default=None,
153 |         help="Maximum number of frames to read from the input video.",
154 |     )
155 |     parser.add_argument(
156 |         "--decode_chunk_size",
157 |         type=int,
158 |         default=16,
159 |         help="Number of frames to decode per forward pass.",
160 |     )
161 |     parser.add_argument(
162 |         "--num_interp_frames",
163 |         type=int,
164 |         default=16,
165 |         help="Number of interpolation frames.",
166 |     )
167 |     parser.add_argument(
168 |         "--num_overlap_frames",
169 |         type=int,
170 |         default=6,
171 |         help="Number of frames to overlap between windows.",
172 |     )
173 |     parser.add_argument(
174 |         "--use_unet_interp",
175 |         action="store_true",
176 |         default=False,
177 |         help="Whether to use the interpolation UNet.",
178 |     )
179 |     parser.add_argument(
180 |         "--use_clip_img_emb",
181 |         action="store_true",
182 |         default=False,
183 |         help="Whether to use CLIP image embeddings.",
184 |     )
185 |     parser.add_argument(
186 |         "--size",
187 |         type=int,
188 |         default=720,  # decrease for faster inference and lower memory usage
189 |         help="Target size of the shorter side for inference.",
190 |     )
191 |     parser.add_argument(
192 |         "--max_resolution",
193 |         type=int,
194 |         default=1024,  # decrease for faster inference and lower memory usage
195 |         help="Maximum resolution for inference.",
196 |     )
197 |     parser.add_argument(
198 |         "--output_image_seq_only",
199 |         action="store_true",
200 |         default=False,
201 |         help="Whether to write only the alpha image sequence instead of the composited videos.",
202 |     )
203 | 
204 |     parser.add_argument("--seed", type=int, default=None, help="Random seed.")
205 | 
206 |     args = parser.parse_args()
207 |     cfg = EasyDict(vars(args))
208 | 
209 |     upper_bound = 240. / 255.  # predicted alpha above this is snapped to pure foreground (1.0)
210 |     lower_bound = 25. / 255.   # predicted alpha below this is snapped to pure background (0.0)
211 | 212 | file_name = cfg.data_dir.split("/")[-1].split(".")[0] 213 | is_video = cfg.data_dir.endswith(".mp4") or cfg.data_dir.endswith(".mkv") or cfg.data_dir.endswith(".gif") 214 | is_gif = cfg.data_dir.endswith(".gif") 215 | 216 | is_image_sequence = (not is_video) and (cfg.data_dir.split('.')[-1] not in ['jpg', 'png', 'jpeg', 'JPG']) 217 | if is_image_sequence and osp.exists(cfg.output_dir) and len(os.listdir(cfg.output_dir)) != 0: 218 | exit() 219 | 220 | if cfg.seed is None: 221 | import time 222 | 223 | cfg.seed = int(time.time()) 224 | seed_all(cfg.seed) 225 | 226 | device_type = "cuda" 227 | device = torch.device(device_type) 228 | 229 | os.makedirs(cfg.output_dir, exist_ok=True) 230 | logging.info(f"output dir = {cfg.output_dir}") 231 | 232 | vae = AutoencoderKLTemporalDecoder.from_pretrained(cfg.model_base, subfolder="vae") 233 | scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( 234 | cfg.model_base, 235 | subfolder="scheduler" 236 | ) 237 | unet_folder = cfg.unet_base if cfg.unet_base is not None else cfg.model_base 238 | 239 | if args.pretrain_type == 'dav': 240 | unet = UNetSpatioTemporalRopeConditionModel.from_pretrained( 241 | unet_folder, 242 | subfolder="unet" 243 | ) 244 | else: 245 | unet = UNetSpatioTemporalConditionModel.from_pretrained( 246 | unet_folder, 247 | subfolder="unet", 248 | # variant=args.variant, 249 | class_embed_type=None, 250 | # low_cpu_mem_usage=True, 251 | ) 252 | # import pdb;pdb.set_trace() 253 | 254 | pipe = GVMPipeline( 255 | vae=vae, 256 | unet=unet, 257 | scheduler=scheduler, 258 | ) 259 | # pipe.to(args.variant) 260 | 261 | if args.lora_base is not None: 262 | pipe.load_lora_weights(f"{args.lora_base}/pytorch_lora_weights.safetensors") 263 | 264 | pipe = pipe.to(device) 265 | 266 | # import pdb;pdb.set_trace() 267 | if is_video: 268 | num_interp_frames = cfg.num_interp_frames 269 | num_overlap_frames = cfg.num_overlap_frames 270 | num_frames = cfg.num_frames_per_batch 271 | 272 | reader = VideoReader( 273 | cfg.data_dir, 274 | max_frames=cfg.max_frames, 275 | # transform=None, 276 | transform=Compose( 277 | [ 278 | ToTensor(), 279 | Resize(size=cfg.size, max_size=cfg.max_resolution) 280 | ] 281 | ) 282 | ) 283 | fps = reader.frame_rate 284 | origin_shape = reader.origin_shape 285 | total_frames = len(reader) 286 | print('total video frames: {}'.format(total_frames)) 287 | 288 | # elif is_image_sequence: 289 | # reader = ImageSequenceReader(cfg.data_dir, transform=None, dataset_name='video240k') 290 | # total_frames = len(reader) 291 | # print('total video frames: {}'.format(total_frames)) 292 | else: 293 | input_root_list = file_name 294 | 295 | # raise NotImplementedError 296 | # image, file_name = img_utils.read_image_sequence(cfg.data_dir) 297 | # else: 298 | # image = img_utils.read_image(cfg.data_dir) 299 | # origin_shape = image[0].shape[:2] 300 | 301 | # origin_shape_list = [img.shape[:2] for img in image] 302 | # image = img_utils.imresize_max(image, cfg.max_resolution) 303 | # image = img_utils.imcrop_multi(image) 304 | # image, pad_info_list = img_utils.impad_multi(image) 305 | if cfg.output_image_seq_only: 306 | writer_alpha = ImageSequenceWriter(cfg.output_dir) 307 | 308 | else: 309 | writer_alpha = VideoWriter('{}/{}'.format(cfg.output_dir, f"{file_name}.mp4"), frame_rate=fps) 310 | writer_green = VideoWriter('{}/{}'.format(cfg.output_dir, f"{file_name}_green.mp4"), frame_rate=fps) 311 | # writer_green_1 = VideoWriter('{}/{}'.format(cfg.output_dir, f"{file_name}_green_use_pred_fg.mp4"), frame_rate=int(fps)) 312 | # 
writer_fg = VideoWriter('{}/{}'.format(cfg.output_dir, f"{file_name}_fg.mp4"), frame_rate=fps) 313 | writer_green_seq = ImageSequenceWriter(cfg.output_dir) 314 | writer_alpha_seq = ImageSequenceWriter(cfg.output_dir) 315 | # writer_fg_seq = ImageSequenceWriter(cfg.output_dir) 316 | 317 | with torch.no_grad(), torch.autocast(device_type=device_type, dtype=torch.float16): 318 | # RGB tensor normalized to 0 ~ 1. 319 | if is_video: 320 | dataloader = DataLoader(reader, batch_size=cfg.num_frames_per_batch) 321 | else: 322 | reader = ImageSequenceReader( 323 | input_root, 324 | transform=Compose( 325 | [ 326 | ToTensor(), 327 | Resize(size=cfg.size, max_size=cfg.max_resolution) 328 | ]) 329 | ) 330 | 331 | # dataloader = DataLoader( 332 | # reader, 333 | # batch_size=cfg.num_frames_per_batch, 334 | # collate_fn=sequence_collate_fn, 335 | # ) 336 | 337 | for batch_id, batch in tqdm(enumerate(dataloader)): 338 | # src, filename = batch 339 | # filename = filename[0] 340 | # import pdb;pdb.set_trace() 341 | if is_video: 342 | b, _, h, w = batch.shape 343 | filenames = [] 344 | filenames_green = [] 345 | for i in range(0, b): 346 | file_id = batch_id * b + i 347 | filenames.append("{}.jpg".format(file_id)) 348 | filenames_green.append("{}_comp.jpg".format(file_id)) 349 | else: 350 | filenames = batch['rgb_names'] 351 | batch = batch['rgb_values'] 352 | 353 | # origin_shape_list = [img.shape[:2] for img in batch] 354 | # batch = img_utils.imresize_max(batch, cfg.max_resolution) 355 | batch, pad_info = impad_multi(batch) 356 | 357 | pipe_out = pipe( 358 | batch.to(device), 359 | # num_frames=num_frames, 360 | num_frames=args.num_frames_per_batch, 361 | num_overlap_frames=cfg.num_overlap_frames, 362 | num_interp_frames=cfg.num_interp_frames, 363 | decode_chunk_size=cfg.decode_chunk_size, 364 | num_inference_steps=cfg.denoise_steps, 365 | # mode='matte' 366 | mode=args.mode, 367 | use_clip_img_emb=args.use_clip_img_emb, 368 | noise_type=args.noise_type, 369 | ) 370 | image = pipe_out.image 371 | alpha = pipe_out.alpha 372 | 373 | # alpha = np.repeat(alpha[...,None], 3, axis=-1) 374 | # crop and resize to the origin shape 375 | out_h, out_w = image.shape[2:] 376 | pad_t, pad_l, pad_b, pad_r = pad_info 377 | image = image[:, :, pad_t:(out_h-pad_b), pad_l:(out_w-pad_r)] # N, 3, H, W 378 | alpha = alpha[:, :, pad_t:(out_h-pad_b), pad_l:(out_w-pad_r)] # N, 3, H, W 379 | 380 | image = F.interpolate(image, origin_shape, mode='bilinear') 381 | alpha = F.interpolate(alpha, origin_shape, mode='bilinear') 382 | 383 | alpha[alpha>=upper_bound] = 1.0 384 | alpha[alpha<=lower_bound] = 0.0 385 | 386 | if cfg.output_image_seq_only: 387 | pass 388 | else: 389 | writer_alpha.write(alpha) 390 | writer_alpha_seq.write(alpha, filenames=filenames) -------------------------------------------------------------------------------- /gvm/models/unet_spatio_temporal_condition.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Dict, Optional, Tuple, Union 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from diffusers.configuration_utils import ConfigMixin, register_to_config 8 | from diffusers.loaders import UNet2DConditionLoadersMixin 9 | from diffusers.utils import BaseOutput, logging 10 | from diffusers.models.attention_processor import ( 11 | CROSS_ATTENTION_PROCESSORS, 12 | AttentionProcessor, 13 | AttnProcessor, 14 | ) 15 | from diffusers.models.embeddings import TimestepEmbedding, Timesteps 16 | from 
diffusers.models.modeling_utils import ModelMixin 17 | from diffusers.models.unets.unet_3d_blocks import UNetMidBlockSpatioTemporal, get_down_block, get_up_block 18 | from diffusers.loaders import PeftAdapterMixin 19 | from diffusers.models.unets.unet_spatio_temporal_condition import ( 20 | UNetSpatioTemporalConditionOutput, 21 | ) 22 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 23 | 24 | 25 | class UNetSpatioTemporalConditionModel( 26 | ModelMixin, 27 | ConfigMixin, 28 | UNet2DConditionLoadersMixin, 29 | PeftAdapterMixin, 30 | # LoraLoaderMixin, 31 | ): 32 | r""" 33 | A conditional Spatio-Temporal UNet model that takes a noisy video frames, conditional state, and a timestep and returns a sample 34 | shaped output. 35 | 36 | This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented 37 | for all models (such as downloading or saving). 38 | 39 | Parameters: 40 | sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): 41 | Height and width of input/output sample. 42 | in_channels (`int`, *optional*, defaults to 8): Number of channels in the input sample. 43 | out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. 44 | down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "DownBlockSpatioTemporal")`): 45 | The tuple of downsample blocks to use. 46 | up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal")`): 47 | The tuple of upsample blocks to use. 48 | block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): 49 | The tuple of output channels for each block. 50 | addition_time_embed_dim: (`int`, defaults to 256): 51 | Dimension to to encode the additional time ids. 52 | projection_class_embeddings_input_dim (`int`, defaults to 768): 53 | The dimension of the projection of encoded `added_time_ids`. 54 | layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. 55 | cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): 56 | The dimension of the cross attention features. 57 | transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1): 58 | The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for 59 | [`~models.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`], [`~models.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`], 60 | [`~models.unet_3d_blocks.UNetMidBlockSpatioTemporal`]. 61 | num_attention_heads (`int`, `Tuple[int]`, defaults to `(5, 10, 10, 20)`): 62 | The number of attention heads. 63 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 
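    Example (illustrative only; shapes follow the default configuration, where the 8 input channels are
    the 4 noisy matte-latent channels concatenated with the 4 RGB-latent channels, as done in
    `GVMPipeline.single_infer`):

        >>> unet = UNetSpatioTemporalConditionModel()
        >>> sample = torch.randn(1, 2, 8, 64, 64)        # (batch, num_frames, in_channels, height, width)
        >>> image_emb = torch.zeros(1, 1, 1024)          # (batch, 1, cross_attention_dim)
        >>> out = unet(sample, timestep=999, encoder_hidden_states=image_emb)
        >>> out.sample.shape                             # (batch, num_frames, out_channels, height, width)
        torch.Size([1, 2, 4, 64, 64])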
64 | """ 65 | 66 | _supports_gradient_checkpointing = True 67 | 68 | @register_to_config 69 | def __init__( 70 | self, 71 | sample_size: Optional[int] = None, 72 | in_channels: int = 8, 73 | out_channels: int = 4, 74 | down_block_types: Tuple[str] = ( 75 | "CrossAttnDownBlockSpatioTemporal", 76 | "CrossAttnDownBlockSpatioTemporal", 77 | "CrossAttnDownBlockSpatioTemporal", 78 | "DownBlockSpatioTemporal", 79 | ), 80 | up_block_types: Tuple[str] = ( 81 | "UpBlockSpatioTemporal", 82 | "CrossAttnUpBlockSpatioTemporal", 83 | "CrossAttnUpBlockSpatioTemporal", 84 | "CrossAttnUpBlockSpatioTemporal", 85 | ), 86 | block_out_channels: Tuple[int] = (320, 640, 1280, 1280), 87 | addition_time_embed_dim: int = 256, 88 | projection_class_embeddings_input_dim: int = 768, 89 | layers_per_block: Union[int, Tuple[int]] = 2, 90 | cross_attention_dim: Union[int, Tuple[int]] = 1024, 91 | transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, 92 | num_attention_heads: Union[int, Tuple[int]] = (5, 10, 20, 20), 93 | num_frames: int = 25, 94 | 95 | class_embed_type: Optional[str] = None, # 'projection', 96 | num_class_embeds: Optional[int] = None, 97 | act_fn: str = "silu", 98 | ): 99 | super().__init__() 100 | 101 | self.sample_size = sample_size 102 | 103 | # Check inputs 104 | if len(down_block_types) != len(up_block_types): 105 | raise ValueError( 106 | f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." 107 | ) 108 | 109 | if len(block_out_channels) != len(down_block_types): 110 | raise ValueError( 111 | f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." 112 | ) 113 | 114 | if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len( 115 | down_block_types 116 | ): 117 | raise ValueError( 118 | f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." 119 | ) 120 | 121 | if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len( 122 | down_block_types 123 | ): 124 | raise ValueError( 125 | f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." 126 | ) 127 | 128 | if not isinstance(layers_per_block, int) and len(layers_per_block) != len( 129 | down_block_types 130 | ): 131 | raise ValueError( 132 | f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." 
133 | ) 134 | 135 | # input 136 | self.conv_in = nn.Conv2d( 137 | in_channels, 138 | block_out_channels[0], 139 | kernel_size=3, 140 | padding=1, 141 | ) 142 | 143 | # time 144 | time_embed_dim = block_out_channels[0] * 4 145 | 146 | self.time_proj = Timesteps(block_out_channels[0], True, downscale_freq_shift=0) 147 | timestep_input_dim = block_out_channels[0] 148 | 149 | self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) 150 | 151 | # self.add_time_proj = Timesteps(addition_time_embed_dim, True, downscale_freq_shift=0) 152 | # self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) 153 | 154 | self.down_blocks = nn.ModuleList([]) 155 | self.up_blocks = nn.ModuleList([]) 156 | 157 | if isinstance(num_attention_heads, int): 158 | num_attention_heads = (num_attention_heads,) * len(down_block_types) 159 | 160 | if isinstance(cross_attention_dim, int): 161 | cross_attention_dim = (cross_attention_dim,) * len(down_block_types) 162 | 163 | if isinstance(layers_per_block, int): 164 | layers_per_block = [layers_per_block] * len(down_block_types) 165 | 166 | if isinstance(transformer_layers_per_block, int): 167 | transformer_layers_per_block = [transformer_layers_per_block] * len( 168 | down_block_types 169 | ) 170 | 171 | blocks_time_embed_dim = time_embed_dim 172 | 173 | # down 174 | output_channel = block_out_channels[0] 175 | for i, down_block_type in enumerate(down_block_types): 176 | input_channel = output_channel 177 | output_channel = block_out_channels[i] 178 | is_final_block = i == len(block_out_channels) - 1 179 | 180 | down_block = get_down_block( 181 | down_block_type, 182 | num_layers=layers_per_block[i], 183 | transformer_layers_per_block=transformer_layers_per_block[i], 184 | in_channels=input_channel, 185 | out_channels=output_channel, 186 | temb_channels=blocks_time_embed_dim, 187 | add_downsample=not is_final_block, 188 | resnet_eps=1e-5, 189 | cross_attention_dim=cross_attention_dim[i], 190 | num_attention_heads=num_attention_heads[i], 191 | resnet_act_fn="silu", 192 | ) 193 | self.down_blocks.append(down_block) 194 | 195 | # mid 196 | self.mid_block = UNetMidBlockSpatioTemporal( 197 | block_out_channels[-1], 198 | temb_channels=blocks_time_embed_dim, 199 | transformer_layers_per_block=transformer_layers_per_block[-1], 200 | cross_attention_dim=cross_attention_dim[-1], 201 | num_attention_heads=num_attention_heads[-1], 202 | ) 203 | 204 | # count how many layers upsample the images 205 | self.num_upsamplers = 0 206 | 207 | # up 208 | reversed_block_out_channels = list(reversed(block_out_channels)) 209 | reversed_num_attention_heads = list(reversed(num_attention_heads)) 210 | reversed_layers_per_block = list(reversed(layers_per_block)) 211 | reversed_cross_attention_dim = list(reversed(cross_attention_dim)) 212 | reversed_transformer_layers_per_block = list( 213 | reversed(transformer_layers_per_block) 214 | ) 215 | 216 | output_channel = reversed_block_out_channels[0] 217 | for i, up_block_type in enumerate(up_block_types): 218 | is_final_block = i == len(block_out_channels) - 1 219 | 220 | prev_output_channel = output_channel 221 | output_channel = reversed_block_out_channels[i] 222 | input_channel = reversed_block_out_channels[ 223 | min(i + 1, len(block_out_channels) - 1) 224 | ] 225 | 226 | # add upsample block for all BUT final layer 227 | if not is_final_block: 228 | add_upsample = True 229 | self.num_upsamplers += 1 230 | else: 231 | add_upsample = False 232 | 233 | up_block = get_up_block( 234 | 
up_block_type, 235 | num_layers=reversed_layers_per_block[i] + 1, 236 | transformer_layers_per_block=reversed_transformer_layers_per_block[i], 237 | in_channels=input_channel, 238 | out_channels=output_channel, 239 | prev_output_channel=prev_output_channel, 240 | temb_channels=blocks_time_embed_dim, 241 | add_upsample=add_upsample, 242 | resnet_eps=1e-5, 243 | resolution_idx=i, 244 | cross_attention_dim=reversed_cross_attention_dim[i], 245 | num_attention_heads=reversed_num_attention_heads[i], 246 | resnet_act_fn="silu", 247 | ) 248 | self.up_blocks.append(up_block) 249 | prev_output_channel = output_channel 250 | 251 | # out 252 | self.conv_norm_out = nn.GroupNorm( 253 | num_channels=block_out_channels[0], num_groups=32, eps=1e-5 254 | ) 255 | self.conv_act = nn.SiLU() 256 | 257 | self.conv_out = nn.Conv2d( 258 | block_out_channels[0], 259 | out_channels, 260 | kernel_size=3, 261 | padding=1, 262 | ) 263 | # class embedding 264 | if class_embed_type is not None: 265 | self._set_class_embedding( 266 | class_embed_type, 267 | act_fn=act_fn, 268 | num_class_embeds=num_class_embeds, 269 | projection_class_embeddings_input_dim=projection_class_embeddings_input_dim, 270 | time_embed_dim=time_embed_dim, 271 | timestep_input_dim=timestep_input_dim, 272 | ) 273 | 274 | def _set_class_embedding( 275 | self, 276 | class_embed_type: Optional[str], 277 | act_fn: str, 278 | num_class_embeds: Optional[int], 279 | projection_class_embeddings_input_dim: Optional[int], 280 | time_embed_dim: int, 281 | timestep_input_dim: int, 282 | ): 283 | if class_embed_type is None and num_class_embeds is not None: 284 | self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) 285 | elif class_embed_type == "timestep": 286 | self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) 287 | elif class_embed_type == "identity": 288 | self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) 289 | elif class_embed_type == "projection": 290 | if projection_class_embeddings_input_dim is None: 291 | raise ValueError( 292 | "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" 293 | ) 294 | # The projection `class_embed_type` is the same as the timestep `class_embed_type` except 295 | # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings 296 | # 2. it projects from an arbitrary input dimension. 297 | # 298 | # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. 299 | # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. 300 | # As a result, `TimestepEmbedding` can be passed arbitrary vectors. 
301 | self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) 302 | elif class_embed_type == "simple_projection": 303 | if projection_class_embeddings_input_dim is None: 304 | raise ValueError( 305 | "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" 306 | ) 307 | self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) 308 | else: 309 | self.class_embedding = None 310 | 311 | def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]: 312 | class_emb = None 313 | if self.class_embedding is not None: 314 | if class_labels is None: 315 | raise ValueError("class_labels should be provided when num_class_embeds > 0") 316 | 317 | if self.config.class_embed_type == "timestep": 318 | class_labels = self.time_proj(class_labels) 319 | 320 | # `Timesteps` does not contain any weights and will always return f32 tensors 321 | # there might be better ways to encapsulate this. 322 | class_labels = class_labels.to(dtype=sample.dtype) 323 | 324 | class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) 325 | return class_emb 326 | 327 | 328 | @property 329 | def attn_processors(self) -> Dict[str, AttentionProcessor]: 330 | r""" 331 | Returns: 332 | `dict` of attention processors: A dictionary containing all attention processors used in the model with 333 | indexed by its weight name. 334 | """ 335 | # set recursively 336 | processors = {} 337 | 338 | def fn_recursive_add_processors( 339 | name: str, 340 | module: torch.nn.Module, 341 | processors: Dict[str, AttentionProcessor], 342 | ): 343 | if hasattr(module, "get_processor"): 344 | processors[f"{name}.processor"] = module.get_processor( 345 | # return_deprecated_lora=True 346 | ) 347 | 348 | for sub_name, child in module.named_children(): 349 | fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) 350 | 351 | return processors 352 | 353 | for name, module in self.named_children(): 354 | fn_recursive_add_processors(name, module, processors) 355 | 356 | return processors 357 | 358 | def set_attn_processor( 359 | self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]] 360 | ): 361 | r""" 362 | Sets the attention processor to use to compute attention. 363 | 364 | Parameters: 365 | processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): 366 | The instantiated processor class or a dictionary of processor classes that will be set as the processor 367 | for **all** `Attention` layers. 368 | 369 | If `processor` is a dict, the key needs to define the path to the corresponding cross attention 370 | processor. This is strongly recommended when setting trainable attention processors. 371 | 372 | """ 373 | count = len(self.attn_processors.keys()) 374 | 375 | if isinstance(processor, dict) and len(processor) != count: 376 | raise ValueError( 377 | f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" 378 | f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 
379 | ) 380 | 381 | def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): 382 | if hasattr(module, "set_processor"): 383 | if not isinstance(processor, dict): 384 | module.set_processor(processor) 385 | else: 386 | module.set_processor(processor.pop(f"{name}.processor")) 387 | 388 | for sub_name, child in module.named_children(): 389 | fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) 390 | 391 | for name, module in self.named_children(): 392 | fn_recursive_attn_processor(name, module, processor) 393 | 394 | def set_default_attn_processor(self): 395 | """ 396 | Disables custom attention processors and sets the default attention implementation. 397 | """ 398 | if all( 399 | proc.__class__ in CROSS_ATTENTION_PROCESSORS 400 | for proc in self.attn_processors.values() 401 | ): 402 | processor = AttnProcessor() 403 | else: 404 | raise ValueError( 405 | f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" 406 | ) 407 | 408 | self.set_attn_processor(processor) 409 | 410 | def _set_gradient_checkpointing(self, module, value=False): 411 | if hasattr(module, "gradient_checkpointing"): 412 | module.gradient_checkpointing = value 413 | 414 | # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking 415 | def enable_forward_chunking( 416 | self, chunk_size: Optional[int] = None, dim: int = 0 417 | ) -> None: 418 | """ 419 | Sets the attention processor to use [feed forward 420 | chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). 421 | 422 | Parameters: 423 | chunk_size (`int`, *optional*): 424 | The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually 425 | over each tensor of dim=`dim`. 426 | dim (`int`, *optional*, defaults to `0`): 427 | The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) 428 | or dim=1 (sequence length). 429 | """ 430 | if dim not in [0, 1]: 431 | raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") 432 | 433 | # By default chunk size is 1 434 | chunk_size = chunk_size or 1 435 | 436 | def fn_recursive_feed_forward( 437 | module: torch.nn.Module, chunk_size: int, dim: int 438 | ): 439 | if hasattr(module, "set_chunk_feed_forward"): 440 | module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) 441 | 442 | for child in module.children(): 443 | fn_recursive_feed_forward(child, chunk_size, dim) 444 | 445 | for module in self.children(): 446 | fn_recursive_feed_forward(module, chunk_size, dim) 447 | 448 | def forward( 449 | self, 450 | sample: torch.FloatTensor, 451 | timestep: Union[torch.Tensor, float, int], 452 | encoder_hidden_states: torch.Tensor, 453 | return_dict: bool = True, 454 | position_ids=None, 455 | class_labels=None, 456 | ) -> Union[UNetSpatioTemporalConditionOutput, Tuple]: 457 | r""" 458 | The [`UNetSpatioTemporalConditionModel`] forward method. 459 | 460 | Args: 461 | sample (`torch.FloatTensor`): 462 | The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`. 463 | timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. 464 | encoder_hidden_states (`torch.FloatTensor`): 465 | The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`. 
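            position_ids (`torch.Tensor`, *optional*):
                Per-frame position indices, accepted for interface compatibility; they are not used in the
                current implementation (all `position_ids` usages below are commented out).
            class_labels (`torch.Tensor`, *optional*):
                Class labels for the optional class embedding; the class-embedding branch in this forward pass
                is currently commented out, so the argument is accepted but ignored.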
466 | return_dict (`bool`, *optional*, defaults to `True`): 467 | Whether or not to return a [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] instead of a plain 468 | tuple. 469 | Returns: 470 | [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] or `tuple`: 471 | If `return_dict` is True, an [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] is returned, otherwise 472 | a `tuple` is returned where the first element is the sample tensor. 473 | """ 474 | default_overall_up_factor = 2**self.num_upsamplers 475 | 476 | # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` 477 | forward_upsample_size = False 478 | upsample_size = None 479 | 480 | # for dim in sample.shape[-2:]: 481 | # if dim % default_overall_up_factor != 0: 482 | # # Forward upsample size to force interpolation output size. 483 | # forward_upsample_size = True 484 | # break 485 | 486 | if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): 487 | # logger.info("Forward upsample size to force interpolation output size.") 488 | forward_upsample_size = True 489 | 490 | # 1. time 491 | timesteps = timestep 492 | if not torch.is_tensor(timesteps): 493 | # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can 494 | # This would be a good case for the `match` statement (Python 3.10+) 495 | is_mps = sample.device.type == "mps" 496 | if isinstance(timestep, float): 497 | dtype = torch.float32 if is_mps else torch.float64 498 | else: 499 | dtype = torch.int32 if is_mps else torch.int64 500 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) 501 | elif len(timesteps.shape) == 0: 502 | timesteps = timesteps[None].to(sample.device) 503 | 504 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 505 | batch_size, num_frames = sample.shape[:2] 506 | timesteps = timesteps.expand(batch_size) 507 | 508 | t_emb = self.time_proj(timesteps) 509 | 510 | # `Timesteps` does not contain any weights and will always return f32 tensors 511 | # but time_embedding might actually be running in fp16. so we need to cast here. 512 | # there might be better ways to encapsulate this. 513 | t_emb = t_emb.to(dtype=sample.dtype) 514 | emb = self.time_embedding(t_emb) 515 | 516 | 517 | # time_embeds = self.add_time_proj(added_time_ids.flatten()) 518 | # time_embeds = time_embeds.reshape((batch_size, -1)) 519 | # time_embeds = time_embeds.to(emb.dtype) 520 | # aug_emb = self.add_embedding(time_embeds) 521 | # emb = emb + aug_emb 522 | 523 | # if class_labels is not None: 524 | # class_emb = self.get_class_embed(sample=sample, class_labels=class_labels) 525 | # emb = emb + class_emb 526 | 527 | # else: 528 | # class_emb = None 529 | # # import pdb;pdb.set_trace() 530 | # if class_emb is not None: 531 | # emb = emb + class_emb 532 | 533 | # Flatten the batch and frames dimensions 534 | # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width] 535 | sample = sample.flatten(0, 1) 536 | # Repeat the embeddings num_video_frames times 537 | # emb: [batch, channels] -> [batch * frames, channels] 538 | emb = emb.repeat_interleave(num_frames, dim=0) 539 | # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels] 540 | encoder_hidden_states = encoder_hidden_states.repeat_interleave( 541 | num_frames, dim=0 542 | ) 543 | 544 | # 2. 
pre-process 545 | sample = self.conv_in(sample) 546 | 547 | image_only_indicator = torch.zeros( 548 | batch_size, num_frames, dtype=sample.dtype, device=sample.device 549 | ) 550 | 551 | down_block_res_samples = (sample,) 552 | for downsample_block in self.down_blocks: 553 | if ( 554 | hasattr(downsample_block, "has_cross_attention") 555 | and downsample_block.has_cross_attention 556 | ): 557 | sample, res_samples = downsample_block( 558 | hidden_states=sample, 559 | temb=emb, 560 | encoder_hidden_states=encoder_hidden_states, 561 | image_only_indicator=image_only_indicator, 562 | # position_ids=position_ids, 563 | ) 564 | else: 565 | sample, res_samples = downsample_block( 566 | hidden_states=sample, 567 | temb=emb, 568 | image_only_indicator=image_only_indicator, 569 | # position_ids=position_ids, 570 | ) 571 | 572 | down_block_res_samples += res_samples 573 | 574 | # 4. mid 575 | sample = self.mid_block( 576 | hidden_states=sample, 577 | temb=emb, 578 | encoder_hidden_states=encoder_hidden_states, 579 | image_only_indicator=image_only_indicator, 580 | # position_ids=position_ids, 581 | ) 582 | 583 | # 5. up 584 | for i, upsample_block in enumerate(self.up_blocks): 585 | is_final_block = i == len(self.up_blocks) - 1 586 | 587 | res_samples = down_block_res_samples[-len(upsample_block.resnets) :] 588 | down_block_res_samples = down_block_res_samples[ 589 | : -len(upsample_block.resnets) 590 | ] 591 | if not is_final_block and forward_upsample_size: 592 | upsample_size = down_block_res_samples[-1].shape[2:] 593 | 594 | if ( 595 | hasattr(upsample_block, "has_cross_attention") 596 | and upsample_block.has_cross_attention 597 | ): 598 | sample = upsample_block( 599 | hidden_states=sample, 600 | temb=emb, 601 | res_hidden_states_tuple=res_samples, 602 | encoder_hidden_states=encoder_hidden_states, 603 | upsample_size=upsample_size, 604 | image_only_indicator=image_only_indicator, 605 | # position_ids=position_ids, 606 | ) 607 | else: 608 | # print('unet 611 upsample_size:', upsample_size) 609 | sample = upsample_block( 610 | hidden_states=sample, 611 | temb=emb, 612 | res_hidden_states_tuple=res_samples, 613 | upsample_size=upsample_size, 614 | image_only_indicator=image_only_indicator, 615 | # position_ids=position_ids, 616 | ) 617 | 618 | # 6. post-process 619 | sample = self.conv_norm_out(sample) 620 | sample = self.conv_act(sample) 621 | sample = self.conv_out(sample) 622 | 623 | # 7. Reshape back to original shape 624 | sample = sample.reshape(batch_size, num_frames, *sample.shape[1:]) 625 | 626 | if not return_dict: 627 | return (sample,) 628 | 629 | return UNetSpatioTemporalConditionOutput(sample=sample) 630 | --------------------------------------------------------------------------------
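A note on the sliding-window inference in `gvm/pipelines/pipeline_gvm.py`: when a clip is longer than `num_frames`, `GVMPipeline.__call__` denoises overlapping windows and linearly cross-fades the overlapping latent frames before appending the rest of each new window. The snippet below is a small, self-contained illustration of that merging step; the `blend_overlap` helper and the toy tensors are illustrative and not part of the repository.

```
import torch


def blend_overlap(prev_chunk: torch.Tensor, next_chunk: torch.Tensor, overlap: int) -> torch.Tensor:
    """Cross-fade `overlap` frames between two latent chunks shaped [B, T, C, H, W].

    Mirrors the merging step in GVMPipeline.__call__: the first `overlap` frames of the new
    window fade in (weight 0 -> 1) while the last `overlap` frames of the accumulated result
    fade out, and the remaining frames of the new window are appended unchanged.
    """
    overlap = min(overlap, next_chunk.shape[1])
    ratio = torch.linspace(0, 1, overlap).view(1, -1, 1, 1, 1).to(prev_chunk)
    blended = next_chunk[:, :overlap] * ratio + prev_chunk[:, -overlap:] * (1 - ratio)
    return torch.cat([prev_chunk[:, :-overlap], blended, next_chunk[:, overlap:]], dim=1)


# Toy check: two 8-frame chunks of constant latents merged with a 4-frame overlap.
a = torch.zeros(1, 8, 4, 8, 8)   # accumulated result so far
b = torch.ones(1, 8, 4, 8, 8)    # freshly denoised window
merged = blend_overlap(a, b, overlap=4)
print(merged.shape)              # torch.Size([1, 12, 4, 8, 8])
print(merged[0, 4:8, 0, 0, 0])   # tensor([0.0000, 0.3333, 0.6667, 1.0000]) -- the cross-faded frames
```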